1//===-------- BasicOrcV2CBindings.c - Basic OrcV2 C Bindings Demo ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "llvm-c/Core.h"
#include "llvm-c/Error.h"
#include "llvm-c/Initialization.h"
#include "llvm-c/LLJIT.h"
#include "llvm-c/Support.h"
#include "llvm-c/Target.h"

#include <assert.h>
#include <stdint.h>
#include <stdio.h>
18
19int handleError(LLVMErrorRef Err) {
20  char *ErrMsg = LLVMGetErrorMessage(Err);
21  fprintf(stderr, "Error: %s\n", ErrMsg);
22  LLVMDisposeErrorMessage(ErrMsg);
23  return 1;
24}
25
// Host-side 'add' helper. The demo module declares an external function with
// this name, and main() places the mangled name "add" on the allow list.
//
// The sum is computed in unsigned 32-bit arithmetic so that overflow wraps
// modulo 2^32 instead of invoking undefined behavior (signed overflow). The
// conversion back to int32_t is implementation-defined for out-of-range
// values; mainstream compilers produce the two's-complement result.
//
// Returns X + Y, wrapped to 32 bits.
int32_t add(int32_t X, int32_t Y) {
  uint32_t Sum = (uint32_t)X + (uint32_t)Y;
  return (int32_t)Sum;
}
27
// Host-side 'mul' helper. The demo module declares an external function with
// this name, and main() places the mangled name "mul" on the allow list.
//
// The product is computed in unsigned 32-bit arithmetic so that overflow wraps
// modulo 2^32 instead of invoking undefined behavior (signed overflow). The
// conversion back to int32_t is implementation-defined for out-of-range
// values; mainstream compilers produce the two's-complement result.
//
// Returns X * Y, wrapped to 32 bits.
int32_t mul(int32_t X, int32_t Y) {
  uint32_t Product = (uint32_t)X * (uint32_t)Y;
  return (int32_t)Product;
}
29
30int allowedSymbols(void *Ctx, LLVMOrcSymbolStringPoolEntryRef Sym) {
31  assert(Ctx && "Cannot call allowedSymbols with a null context");
32
33  LLVMOrcSymbolStringPoolEntryRef *AllowList =
34      (LLVMOrcSymbolStringPoolEntryRef *)Ctx;
35
36  // If Sym appears in the allowed list then return true.
37  LLVMOrcSymbolStringPoolEntryRef *P = AllowList;
38  while (*P) {
39    if (Sym == *P)
40      return 1;
41    ++P;
42  }
43
44  // otherwise return false.
45  return 0;
46}
47
48LLVMOrcThreadSafeModuleRef createDemoModule() {
49  // Create a new ThreadSafeContext and underlying LLVMContext.
50  LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext();
51
52  // Get a reference to the underlying LLVMContext.
53  LLVMContextRef Ctx = LLVMOrcThreadSafeContextGetContext(TSCtx);
54
55  // Create a new LLVM module.
56  LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx);
57
58  // Add a "sum" function":
59  //  - Create the function type and function instance.
60  LLVMTypeRef I32BinOpParamTypes[] = {LLVMInt32Type(), LLVMInt32Type()};
61  LLVMTypeRef I32BinOpFunctionType =
62      LLVMFunctionType(LLVMInt32Type(), I32BinOpParamTypes, 2, 0);
63  LLVMValueRef AddI32Function = LLVMAddFunction(M, "add", I32BinOpFunctionType);
64  LLVMValueRef MulI32Function = LLVMAddFunction(M, "mul", I32BinOpFunctionType);
65
66  LLVMTypeRef MulAddParamTypes[] = {LLVMInt32Type(), LLVMInt32Type(),
67                                    LLVMInt32Type()};
68  LLVMTypeRef MulAddFunctionType =
69      LLVMFunctionType(LLVMInt32Type(), MulAddParamTypes, 3, 0);
70  LLVMValueRef MulAddFunction =
71      LLVMAddFunction(M, "mul_add", MulAddFunctionType);
72
73  //  - Add a basic block to the function.
74  LLVMBasicBlockRef EntryBB = LLVMAppendBasicBlock(MulAddFunction, "entry");
75
76  //  - Add an IR builder and point it at the end of the basic block.
77  LLVMBuilderRef Builder = LLVMCreateBuilder();
78  LLVMPositionBuilderAtEnd(Builder, EntryBB);
79
80  //  - Get the three function arguments and use them co construct calls to
81  //    'mul' and 'add':
82  //
83  //    i32 mul_add(i32 %0, i32 %1, i32 %2) {
84  //      %t = call i32 @mul(i32 %0, i32 %1)
85  //      %r = call i32 @add(i32 %t, i32 %2)
86  //      ret i32 %r
87  //    }
88  LLVMValueRef SumArg0 = LLVMGetParam(MulAddFunction, 0);
89  LLVMValueRef SumArg1 = LLVMGetParam(MulAddFunction, 1);
90  LLVMValueRef SumArg2 = LLVMGetParam(MulAddFunction, 2);
91
92  LLVMValueRef MulArgs[] = {SumArg0, SumArg1};
93  LLVMValueRef MulResult = LLVMBuildCall2(Builder, I32BinOpFunctionType,
94                                          MulI32Function, MulArgs, 2, "t");
95
96  LLVMValueRef AddArgs[] = {MulResult, SumArg2};
97  LLVMValueRef AddResult = LLVMBuildCall2(Builder, I32BinOpFunctionType,
98                                          AddI32Function, AddArgs, 2, "r");
99
100  //  - Build the return instruction.
101  LLVMBuildRet(Builder, AddResult);
102
103  // Our demo module is now complete. Wrap it and our ThreadSafeContext in a
104  // ThreadSafeModule.
105  LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx);
106
107  // Dispose of our local ThreadSafeContext value. The underlying LLVMContext
108  // will be kept alive by our ThreadSafeModule, TSM.
109  LLVMOrcDisposeThreadSafeContext(TSCtx);
110
111  // Return the result.
112  return TSM;
113}
114
115int main(int argc, char *argv[]) {
116
117  int MainResult = 0;
118
119  // Parse command line arguments and initialize LLVM Core.
120  LLVMParseCommandLineOptions(argc, (const char **)argv, "");
121  LLVMInitializeCore(LLVMGetGlobalPassRegistry());
122
123  // Initialize native target codegen and asm printer.
124  LLVMInitializeNativeTarget();
125  LLVMInitializeNativeAsmPrinter();
126
127  // Create the JIT instance.
128  LLVMOrcLLJITRef J;
129  {
130    LLVMErrorRef Err;
131    if ((Err = LLVMOrcCreateLLJIT(&J, 0))) {
132      MainResult = handleError(Err);
133      goto llvm_shutdown;
134    }
135  }
136
137  // Build a filter to allow JIT'd code to only access allowed symbols.
138  // This filter is optional: If a null value is suppled for the Filter
139  // argument to LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess then
140  // all process symbols will be reflected.
141  LLVMOrcSymbolStringPoolEntryRef AllowList[] = {
142      LLVMOrcLLJITMangleAndIntern(J, "mul"),
143      LLVMOrcLLJITMangleAndIntern(J, "add"), 0};
144
145  {
146    LLVMOrcDefinitionGeneratorRef ProcessSymbolsGenerator = 0;
147    LLVMErrorRef Err;
148    if ((Err = LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess(
149             &ProcessSymbolsGenerator, LLVMOrcLLJITGetGlobalPrefix(J),
150             allowedSymbols, AllowList))) {
151      MainResult = handleError(Err);
152      goto jit_cleanup;
153    }
154
155    LLVMOrcJITDylibAddGenerator(LLVMOrcLLJITGetMainJITDylib(J),
156                                ProcessSymbolsGenerator);
157  }
158
159  // Create our demo module.
160  LLVMOrcThreadSafeModuleRef TSM = createDemoModule();
161
162  // Add our demo module to the JIT.
163  {
164    LLVMOrcJITDylibRef MainJD = LLVMOrcLLJITGetMainJITDylib(J);
165    LLVMErrorRef Err;
166    if ((Err = LLVMOrcLLJITAddLLVMIRModule(J, MainJD, TSM))) {
167      // If adding the ThreadSafeModule fails then we need to clean it up
168      // ourselves. If adding it succeeds the JIT will manage the memory.
169      LLVMOrcDisposeThreadSafeModule(TSM);
170      MainResult = handleError(Err);
171      goto jit_cleanup;
172    }
173  }
174
175  // Look up the address of our demo entry point.
176  LLVMOrcJITTargetAddress MulAddAddr;
177  {
178    LLVMErrorRef Err;
179    if ((Err = LLVMOrcLLJITLookup(J, &MulAddAddr, "mul_add"))) {
180      MainResult = handleError(Err);
181      goto jit_cleanup;
182    }
183  }
184
185  // If we made it here then everything succeeded. Execute our JIT'd code.
186  int32_t (*MulAdd)(int32_t, int32_t, int32_t) =
187      (int32_t(*)(int32_t, int32_t, int32_t))MulAddAddr;
188  int32_t Result = MulAdd(3, 4, 5);
189
190  // Print the result.
191  printf("3 * 4 + 5 = %i\n", Result);
192
193jit_cleanup:
194  // Release all symbol string pool entries that we have allocated. In this
195  // example that's just our allowed entries.
196  {
197    LLVMOrcSymbolStringPoolEntryRef *P = AllowList;
198    while (*P)
199      LLVMOrcReleaseSymbolStringPoolEntry(*P++);
200  }
201
202  // Destroy our JIT instance. This will clean up any memory that the JIT has
203  // taken ownership of. This operation is non-trivial (e.g. it may need to
204  // JIT static destructors) and may also fail. In that case we want to render
205  // the error to stderr, but not overwrite any existing return value.
206  {
207    LLVMErrorRef Err;
208    if ((Err = LLVMOrcDisposeLLJIT(J))) {
209      int NewFailureResult = handleError(Err);
210      if (MainResult == 0)
211        MainResult = NewFailureResult;
212    }
213  }
214
215llvm_shutdown:
216  // Shut down LLVM.
217  LLVMShutdown();
218
219  return MainResult;
220}
221