1//===-- BrainF.cpp - BrainF compiler example ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This class compiles the BrainF language into LLVM assembly.
10//
11// The BrainF language has 8 commands:
12// Command   Equivalent C    Action
13// -------   ------------    ------
14// ,         *h=getchar();   Read a character from stdin, 255 on EOF
15// .         putchar(*h);    Write a character to stdout
16// -         --*h;           Decrement tape
17// +         ++*h;           Increment tape
18// <         --h;            Move head left
19// >         ++h;            Move head right
20// [         while(*h) {     Start loop
21// ]         }               End loop
22//
23//===----------------------------------------------------------------------===//
24
25#include "BrainF.h"
26#include "llvm/ADT/APInt.h"
27#include "llvm/IR/BasicBlock.h"
28#include "llvm/IR/Constant.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/DerivedTypes.h"
31#include "llvm/IR/Function.h"
32#include "llvm/IR/GlobalValue.h"
33#include "llvm/IR/GlobalVariable.h"
34#include "llvm/IR/InstrTypes.h"
35#include "llvm/IR/Instruction.h"
36#include "llvm/IR/Instructions.h"
37#include "llvm/IR/Intrinsics.h"
38#include "llvm/IR/Module.h"
39#include "llvm/IR/Type.h"
40#include "llvm/Support/Casting.h"
41#include <cstdlib>
42#include <iostream>
43
44using namespace llvm;
45
46//Set the constants for naming
47const char *BrainF::tapereg = "tape";
48const char *BrainF::headreg = "head";
49const char *BrainF::label   = "brainf";
50const char *BrainF::testreg = "test";
51
52Module *BrainF::parse(std::istream *in1, int mem, CompileFlags cf,
53                      LLVMContext& Context) {
54  in       = in1;
55  memtotal = mem;
56  comflag  = cf;
57
58  header(Context);
59  readloop(nullptr, nullptr, nullptr, Context);
60  delete builder;
61  return module;
62}
63
64void BrainF::header(LLVMContext& C) {
65  module = new Module("BrainF", C);
66
67  //Function prototypes
68
69  //declare void @llvm.memset.p0i8.i32(i8 *, i8, i32, i1)
70  Type *Tys[] = { Type::getInt8PtrTy(C), Type::getInt32Ty(C) };
71  Function *memset_func = Intrinsic::getDeclaration(module, Intrinsic::memset,
72                                                    Tys);
73
74  //declare i32 @getchar()
75  getchar_func =
76      module->getOrInsertFunction("getchar", IntegerType::getInt32Ty(C));
77
78  //declare i32 @putchar(i32)
79  putchar_func = module->getOrInsertFunction(
80      "putchar", IntegerType::getInt32Ty(C), IntegerType::getInt32Ty(C));
81
82  //Function header
83
84  //define void @brainf()
85  brainf_func = Function::Create(FunctionType::get(Type::getVoidTy(C), false),
86                                 Function::ExternalLinkage, "brainf", module);
87
88  builder = new IRBuilder<>(BasicBlock::Create(C, label, brainf_func));
89
90  //%arr = malloc i8, i32 %d
91  ConstantInt *val_mem = ConstantInt::get(C, APInt(32, memtotal));
92  BasicBlock* BB = builder->GetInsertBlock();
93  Type* IntPtrTy = IntegerType::getInt32Ty(C);
94  Type* Int8Ty = IntegerType::getInt8Ty(C);
95  Constant* allocsize = ConstantExpr::getSizeOf(Int8Ty);
96  allocsize = ConstantExpr::getTruncOrBitCast(allocsize, IntPtrTy);
97  ptr_arr = CallInst::CreateMalloc(BB, IntPtrTy, Int8Ty, allocsize, val_mem,
98                                   nullptr, "arr");
99  BB->getInstList().push_back(cast<Instruction>(ptr_arr));
100
101  //call void @llvm.memset.p0i8.i32(i8 *%arr, i8 0, i32 %d, i1 0)
102  {
103    Value *memset_params[] = {
104      ptr_arr,
105      ConstantInt::get(C, APInt(8, 0)),
106      val_mem,
107      ConstantInt::get(C, APInt(1, 0))
108    };
109
110    CallInst *memset_call = builder->
111      CreateCall(memset_func, memset_params);
112    memset_call->setTailCall(false);
113  }
114
115  //%arrmax = getelementptr i8 *%arr, i32 %d
116  if (comflag & flag_arraybounds) {
117    ptr_arrmax = builder->
118      CreateGEP(ptr_arr, ConstantInt::get(C, APInt(32, memtotal)), "arrmax");
119  }
120
121  //%head.%d = getelementptr i8 *%arr, i32 %d
122  curhead = builder->CreateGEP(ptr_arr,
123                               ConstantInt::get(C, APInt(32, memtotal/2)),
124                               headreg);
125
126  //Function footer
127
128  //brainf.end:
129  endbb = BasicBlock::Create(C, label, brainf_func);
130
131  //call free(i8 *%arr)
132  endbb->getInstList().push_back(CallInst::CreateFree(ptr_arr, endbb));
133
134  //ret void
135  ReturnInst::Create(C, endbb);
136
137  //Error block for array out of bounds
138  if (comflag & flag_arraybounds)
139  {
140    //@aberrormsg = internal constant [%d x i8] c"\00"
141    Constant *msg_0 =
142      ConstantDataArray::getString(C, "Error: The head has left the tape.",
143                                   true);
144
145    GlobalVariable *aberrormsg = new GlobalVariable(
146      *module,
147      msg_0->getType(),
148      true,
149      GlobalValue::InternalLinkage,
150      msg_0,
151      "aberrormsg");
152
153    //declare i32 @puts(i8 *)
154    FunctionCallee puts_func = module->getOrInsertFunction(
155        "puts", IntegerType::getInt32Ty(C),
156        PointerType::getUnqual(IntegerType::getInt8Ty(C)));
157
158    //brainf.aberror:
159    aberrorbb = BasicBlock::Create(C, label, brainf_func);
160
161    //call i32 @puts(i8 *getelementptr([%d x i8] *@aberrormsg, i32 0, i32 0))
162    {
163      Constant *zero_32 = Constant::getNullValue(IntegerType::getInt32Ty(C));
164
165      Constant *gep_params[] = {
166        zero_32,
167        zero_32
168      };
169
170      Constant *msgptr = ConstantExpr::
171        getGetElementPtr(aberrormsg->getValueType(), aberrormsg, gep_params);
172
173      Value *puts_params[] = {
174        msgptr
175      };
176
177      CallInst *puts_call =
178        CallInst::Create(puts_func,
179                         puts_params,
180                         "", aberrorbb);
181      puts_call->setTailCall(false);
182    }
183
184    //br label %brainf.end
185    BranchInst::Create(endbb, aberrorbb);
186  }
187}
188
189void BrainF::readloop(PHINode *phi, BasicBlock *oldbb, BasicBlock *testbb,
190                      LLVMContext &C) {
191  Symbol cursym = SYM_NONE;
192  int curvalue = 0;
193  Symbol nextsym = SYM_NONE;
194  int nextvalue = 0;
195  char c;
196  int loop;
197  int direction;
198
199  while(cursym != SYM_EOF && cursym != SYM_ENDLOOP) {
200    // Write out commands
201    switch(cursym) {
202      case SYM_NONE:
203        // Do nothing
204        break;
205
206      case SYM_READ:
207        {
208          //%tape.%d = call i32 @getchar()
209          CallInst *getchar_call =
210              builder->CreateCall(getchar_func, {}, tapereg);
211          getchar_call->setTailCall(false);
212          Value *tape_0 = getchar_call;
213
214          //%tape.%d = trunc i32 %tape.%d to i8
215          Value *tape_1 = builder->
216            CreateTrunc(tape_0, IntegerType::getInt8Ty(C), tapereg);
217
218          //store i8 %tape.%d, i8 *%head.%d
219          builder->CreateStore(tape_1, curhead);
220        }
221        break;
222
223      case SYM_WRITE:
224        {
225          //%tape.%d = load i8 *%head.%d
226          LoadInst *tape_0 =
227              builder->CreateLoad(IntegerType::getInt8Ty(C), curhead, tapereg);
228
229          //%tape.%d = sext i8 %tape.%d to i32
230          Value *tape_1 = builder->
231            CreateSExt(tape_0, IntegerType::getInt32Ty(C), tapereg);
232
233          //call i32 @putchar(i32 %tape.%d)
234          Value *putchar_params[] = {
235            tape_1
236          };
237          CallInst *putchar_call = builder->
238            CreateCall(putchar_func,
239                       putchar_params);
240          putchar_call->setTailCall(false);
241        }
242        break;
243
244      case SYM_MOVE:
245        {
246          //%head.%d = getelementptr i8 *%head.%d, i32 %d
247          curhead = builder->
248            CreateGEP(curhead, ConstantInt::get(C, APInt(32, curvalue)),
249                      headreg);
250
251          //Error block for array out of bounds
252          if (comflag & flag_arraybounds)
253          {
254            //%test.%d = icmp uge i8 *%head.%d, %arrmax
255            Value *test_0 = builder->
256              CreateICmpUGE(curhead, ptr_arrmax, testreg);
257
258            //%test.%d = icmp ult i8 *%head.%d, %arr
259            Value *test_1 = builder->
260              CreateICmpULT(curhead, ptr_arr, testreg);
261
262            //%test.%d = or i1 %test.%d, %test.%d
263            Value *test_2 = builder->
264              CreateOr(test_0, test_1, testreg);
265
266            //br i1 %test.%d, label %main.%d, label %main.%d
267            BasicBlock *nextbb = BasicBlock::Create(C, label, brainf_func);
268            builder->CreateCondBr(test_2, aberrorbb, nextbb);
269
270            //main.%d:
271            builder->SetInsertPoint(nextbb);
272          }
273        }
274        break;
275
276      case SYM_CHANGE:
277        {
278          //%tape.%d = load i8 *%head.%d
279          LoadInst *tape_0 =
280              builder->CreateLoad(IntegerType::getInt8Ty(C), curhead, tapereg);
281
282          //%tape.%d = add i8 %tape.%d, %d
283          Value *tape_1 = builder->
284            CreateAdd(tape_0, ConstantInt::get(C, APInt(8, curvalue)), tapereg);
285
286          //store i8 %tape.%d, i8 *%head.%d\n"
287          builder->CreateStore(tape_1, curhead);
288        }
289        break;
290
291      case SYM_LOOP:
292        {
293          //br label %main.%d
294          BasicBlock *testbb = BasicBlock::Create(C, label, brainf_func);
295          builder->CreateBr(testbb);
296
297          //main.%d:
298          BasicBlock *bb_0 = builder->GetInsertBlock();
299          BasicBlock *bb_1 = BasicBlock::Create(C, label, brainf_func);
300          builder->SetInsertPoint(bb_1);
301
302          // Make part of PHI instruction now, wait until end of loop to finish
303          PHINode *phi_0 =
304            PHINode::Create(PointerType::getUnqual(IntegerType::getInt8Ty(C)),
305                            2, headreg, testbb);
306          phi_0->addIncoming(curhead, bb_0);
307          curhead = phi_0;
308
309          readloop(phi_0, bb_1, testbb, C);
310        }
311        break;
312
313      default:
314        std::cerr << "Error: Unknown symbol.\n";
315        abort();
316        break;
317    }
318
319    cursym = nextsym;
320    curvalue = nextvalue;
321    nextsym = SYM_NONE;
322
323    // Reading stdin loop
324    loop = (cursym == SYM_NONE)
325        || (cursym == SYM_MOVE)
326        || (cursym == SYM_CHANGE);
327    while(loop) {
328      *in>>c;
329      if (in->eof()) {
330        if (cursym == SYM_NONE) {
331          cursym = SYM_EOF;
332        } else {
333          nextsym = SYM_EOF;
334        }
335        loop = 0;
336      } else {
337        direction = 1;
338        switch(c) {
339          case '-':
340            direction = -1;
341            LLVM_FALLTHROUGH;
342
343          case '+':
344            if (cursym == SYM_CHANGE) {
345              curvalue += direction;
346              // loop = 1
347            } else {
348              if (cursym == SYM_NONE) {
349                cursym = SYM_CHANGE;
350                curvalue = direction;
351                // loop = 1
352              } else {
353                nextsym = SYM_CHANGE;
354                nextvalue = direction;
355                loop = 0;
356              }
357            }
358            break;
359
360          case '<':
361            direction = -1;
362            LLVM_FALLTHROUGH;
363
364          case '>':
365            if (cursym == SYM_MOVE) {
366              curvalue += direction;
367              // loop = 1
368            } else {
369              if (cursym == SYM_NONE) {
370                cursym = SYM_MOVE;
371                curvalue = direction;
372                // loop = 1
373              } else {
374                nextsym = SYM_MOVE;
375                nextvalue = direction;
376                loop = 0;
377              }
378            }
379            break;
380
381          case ',':
382            if (cursym == SYM_NONE) {
383              cursym = SYM_READ;
384            } else {
385              nextsym = SYM_READ;
386            }
387            loop = 0;
388            break;
389
390          case '.':
391            if (cursym == SYM_NONE) {
392              cursym = SYM_WRITE;
393            } else {
394              nextsym = SYM_WRITE;
395            }
396            loop = 0;
397            break;
398
399          case '[':
400            if (cursym == SYM_NONE) {
401              cursym = SYM_LOOP;
402            } else {
403              nextsym = SYM_LOOP;
404            }
405            loop = 0;
406            break;
407
408          case ']':
409            if (cursym == SYM_NONE) {
410              cursym = SYM_ENDLOOP;
411            } else {
412              nextsym = SYM_ENDLOOP;
413            }
414            loop = 0;
415            break;
416
417          // Ignore other characters
418          default:
419            break;
420        }
421      }
422    }
423  }
424
425  if (cursym == SYM_ENDLOOP) {
426    if (!phi) {
427      std::cerr << "Error: Extra ']'\n";
428      abort();
429    }
430
431    // Write loop test
432    {
433      //br label %main.%d
434      builder->CreateBr(testbb);
435
436      //main.%d:
437
438      //%head.%d = phi i8 *[%head.%d, %main.%d], [%head.%d, %main.%d]
439      //Finish phi made at beginning of loop
440      phi->addIncoming(curhead, builder->GetInsertBlock());
441      Value *head_0 = phi;
442
443      //%tape.%d = load i8 *%head.%d
444      LoadInst *tape_0 = new LoadInst(IntegerType::getInt8Ty(C), head_0,
445                                      tapereg, testbb);
446
447      //%test.%d = icmp eq i8 %tape.%d, 0
448      ICmpInst *test_0 = new ICmpInst(*testbb, ICmpInst::ICMP_EQ, tape_0,
449                                    ConstantInt::get(C, APInt(8, 0)), testreg);
450
451      //br i1 %test.%d, label %main.%d, label %main.%d
452      BasicBlock *bb_0 = BasicBlock::Create(C, label, brainf_func);
453      BranchInst::Create(bb_0, oldbb, test_0, testbb);
454
455      //main.%d:
456      builder->SetInsertPoint(bb_0);
457
458      //%head.%d = phi i8 *[%head.%d, %main.%d]
459      PHINode *phi_1 = builder->
460        CreatePHI(PointerType::getUnqual(IntegerType::getInt8Ty(C)), 1,
461                  headreg);
462      phi_1->addIncoming(head_0, testbb);
463      curhead = phi_1;
464    }
465
466    return;
467  }
468
469  //End of the program, so go to return block
470  builder->CreateBr(endbb);
471
472  if (phi) {
473    std::cerr << "Error: Missing ']'\n";
474    abort();
475  }
476}
477