1#include "llvm/DerivedTypes.h"
2#include "llvm/ExecutionEngine/ExecutionEngine.h"
3#include "llvm/ExecutionEngine/JIT.h"
4#include "llvm/IRBuilder.h"
5#include "llvm/LLVMContext.h"
6#include "llvm/Module.h"
7#include "llvm/PassManager.h"
8#include "llvm/Analysis/Verifier.h"
9#include "llvm/Analysis/Passes.h"
10#include "llvm/Target/TargetData.h"
11#include "llvm/Transforms/Scalar.h"
12#include "llvm/Support/TargetSelect.h"
13#include <cstdio>
14#include <string>
15#include <map>
16#include <vector>
17using namespace llvm;
18
19//===----------------------------------------------------------------------===//
20// Lexer
21//===----------------------------------------------------------------------===//
22
// Token - Values the lexer returns for the lexemes it recognises.  Any other
// character is handed back unchanged as its own character code [0-255].
enum Token {
  tok_eof = -1,

  // commands
  tok_def = -2,
  tok_extern = -3,

  // primary
  tok_identifier = -4,
  tok_number = -5
};

static std::string IdentifierStr; // Set whenever gettok() yields tok_identifier.
static double NumVal;             // Set whenever gettok() yields tok_number.

/// gettok - Lex and return the next token from standard input.
///
/// One character of lookahead survives between calls in the static LastChar,
/// so nothing is ever pushed back onto stdin.  '#' starts a comment that runs
/// to the end of the line.  Once input is exhausted, tok_eof is returned on
/// this and every later call (the EOF marker is never consumed).
static int gettok() {
  static int LastChar = ' ';

  for (;;) {
    // Discard leading whitespace.
    while (isspace(LastChar))
      LastChar = getchar();

    // identifier: [a-zA-Z][a-zA-Z0-9]*
    if (isalpha(LastChar)) {
      IdentifierStr = static_cast<char>(LastChar);
      for (LastChar = getchar(); isalnum(LastChar); LastChar = getchar())
        IdentifierStr += static_cast<char>(LastChar);

      if (IdentifierStr == "def")
        return tok_def;
      if (IdentifierStr == "extern")
        return tok_extern;
      return tok_identifier;
    }

    // Number: [0-9.]+  -- not validated here; strtod parses the longest
    // valid prefix, so "1.2.3" yields 1.2.
    if (isdigit(LastChar) || LastChar == '.') {
      std::string NumStr;
      while (isdigit(LastChar) || LastChar == '.') {
        NumStr += static_cast<char>(LastChar);
        LastChar = getchar();
      }
      NumVal = strtod(NumStr.c_str(), 0);
      return tok_number;
    }

    // Anything that does not start a comment is handled after the loop.
    if (LastChar != '#')
      break;

    // Comment: drop everything up to the end of the line, then lex again.
    do
      LastChar = getchar();
    while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');

    if (LastChar == EOF)
      break;
  }

  // End of input: report it but leave LastChar alone so it is seen again.
  if (LastChar == EOF)
    return tok_eof;

  // Otherwise the character itself is the token.
  int ThisChar = LastChar;
  LastChar = getchar();
  return ThisChar;
}
85
86//===----------------------------------------------------------------------===//
87// Abstract Syntax Tree (aka Parse Tree)
88//===----------------------------------------------------------------------===//
89
90/// ExprAST - Base class for all expression nodes.
91class ExprAST {
92public:
93  virtual ~ExprAST() {}
94  virtual Value *Codegen() = 0;
95};
96
97/// NumberExprAST - Expression class for numeric literals like "1.0".
98class NumberExprAST : public ExprAST {
99  double Val;
100public:
101  NumberExprAST(double val) : Val(val) {}
102  virtual Value *Codegen();
103};
104
105/// VariableExprAST - Expression class for referencing a variable, like "a".
106class VariableExprAST : public ExprAST {
107  std::string Name;
108public:
109  VariableExprAST(const std::string &name) : Name(name) {}
110  virtual Value *Codegen();
111};
112
113/// BinaryExprAST - Expression class for a binary operator.
114class BinaryExprAST : public ExprAST {
115  char Op;
116  ExprAST *LHS, *RHS;
117public:
118  BinaryExprAST(char op, ExprAST *lhs, ExprAST *rhs)
119    : Op(op), LHS(lhs), RHS(rhs) {}
120  virtual Value *Codegen();
121};
122
123/// CallExprAST - Expression class for function calls.
124class CallExprAST : public ExprAST {
125  std::string Callee;
126  std::vector<ExprAST*> Args;
127public:
128  CallExprAST(const std::string &callee, std::vector<ExprAST*> &args)
129    : Callee(callee), Args(args) {}
130  virtual Value *Codegen();
131};
132
133/// PrototypeAST - This class represents the "prototype" for a function,
134/// which captures its name, and its argument names (thus implicitly the number
135/// of arguments the function takes).
136class PrototypeAST {
137  std::string Name;
138  std::vector<std::string> Args;
139public:
140  PrototypeAST(const std::string &name, const std::vector<std::string> &args)
141    : Name(name), Args(args) {}
142
143  Function *Codegen();
144};
145
146/// FunctionAST - This class represents a function definition itself.
147class FunctionAST {
148  PrototypeAST *Proto;
149  ExprAST *Body;
150public:
151  FunctionAST(PrototypeAST *proto, ExprAST *body)
152    : Proto(proto), Body(body) {}
153
154  Function *Codegen();
155};
156
157//===----------------------------------------------------------------------===//
158// Parser
159//===----------------------------------------------------------------------===//
160
161/// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
162/// token the parser is looking at.  getNextToken reads another token from the
163/// lexer and updates CurTok with its results.
164static int CurTok;
165static int getNextToken() {
166  return CurTok = gettok();
167}
168
169/// BinopPrecedence - This holds the precedence for each binary operator that is
170/// defined.
171static std::map<char, int> BinopPrecedence;
172
173/// GetTokPrecedence - Get the precedence of the pending binary operator token.
174static int GetTokPrecedence() {
175  if (!isascii(CurTok))
176    return -1;
177
178  // Make sure it's a declared binop.
179  int TokPrec = BinopPrecedence[CurTok];
180  if (TokPrec <= 0) return -1;
181  return TokPrec;
182}
183
184/// Error* - These are little helper functions for error handling.
185ExprAST *Error(const char *Str) { fprintf(stderr, "Error: %s\n", Str);return 0;}
186PrototypeAST *ErrorP(const char *Str) { Error(Str); return 0; }
187FunctionAST *ErrorF(const char *Str) { Error(Str); return 0; }
188
189static ExprAST *ParseExpression();
190
191/// identifierexpr
192///   ::= identifier
193///   ::= identifier '(' expression* ')'
194static ExprAST *ParseIdentifierExpr() {
195  std::string IdName = IdentifierStr;
196
197  getNextToken();  // eat identifier.
198
199  if (CurTok != '(') // Simple variable ref.
200    return new VariableExprAST(IdName);
201
202  // Call.
203  getNextToken();  // eat (
204  std::vector<ExprAST*> Args;
205  if (CurTok != ')') {
206    while (1) {
207      ExprAST *Arg = ParseExpression();
208      if (!Arg) return 0;
209      Args.push_back(Arg);
210
211      if (CurTok == ')') break;
212
213      if (CurTok != ',')
214        return Error("Expected ')' or ',' in argument list");
215      getNextToken();
216    }
217  }
218
219  // Eat the ')'.
220  getNextToken();
221
222  return new CallExprAST(IdName, Args);
223}
224
225/// numberexpr ::= number
226static ExprAST *ParseNumberExpr() {
227  ExprAST *Result = new NumberExprAST(NumVal);
228  getNextToken(); // consume the number
229  return Result;
230}
231
232/// parenexpr ::= '(' expression ')'
233static ExprAST *ParseParenExpr() {
234  getNextToken();  // eat (.
235  ExprAST *V = ParseExpression();
236  if (!V) return 0;
237
238  if (CurTok != ')')
239    return Error("expected ')'");
240  getNextToken();  // eat ).
241  return V;
242}
243
244/// primary
245///   ::= identifierexpr
246///   ::= numberexpr
247///   ::= parenexpr
248static ExprAST *ParsePrimary() {
249  switch (CurTok) {
250  default: return Error("unknown token when expecting an expression");
251  case tok_identifier: return ParseIdentifierExpr();
252  case tok_number:     return ParseNumberExpr();
253  case '(':            return ParseParenExpr();
254  }
255}
256
257/// binoprhs
258///   ::= ('+' primary)*
259static ExprAST *ParseBinOpRHS(int ExprPrec, ExprAST *LHS) {
260  // If this is a binop, find its precedence.
261  while (1) {
262    int TokPrec = GetTokPrecedence();
263
264    // If this is a binop that binds at least as tightly as the current binop,
265    // consume it, otherwise we are done.
266    if (TokPrec < ExprPrec)
267      return LHS;
268
269    // Okay, we know this is a binop.
270    int BinOp = CurTok;
271    getNextToken();  // eat binop
272
273    // Parse the primary expression after the binary operator.
274    ExprAST *RHS = ParsePrimary();
275    if (!RHS) return 0;
276
277    // If BinOp binds less tightly with RHS than the operator after RHS, let
278    // the pending operator take RHS as its LHS.
279    int NextPrec = GetTokPrecedence();
280    if (TokPrec < NextPrec) {
281      RHS = ParseBinOpRHS(TokPrec+1, RHS);
282      if (RHS == 0) return 0;
283    }
284
285    // Merge LHS/RHS.
286    LHS = new BinaryExprAST(BinOp, LHS, RHS);
287  }
288}
289
290/// expression
291///   ::= primary binoprhs
292///
293static ExprAST *ParseExpression() {
294  ExprAST *LHS = ParsePrimary();
295  if (!LHS) return 0;
296
297  return ParseBinOpRHS(0, LHS);
298}
299
300/// prototype
301///   ::= id '(' id* ')'
302static PrototypeAST *ParsePrototype() {
303  if (CurTok != tok_identifier)
304    return ErrorP("Expected function name in prototype");
305
306  std::string FnName = IdentifierStr;
307  getNextToken();
308
309  if (CurTok != '(')
310    return ErrorP("Expected '(' in prototype");
311
312  std::vector<std::string> ArgNames;
313  while (getNextToken() == tok_identifier)
314    ArgNames.push_back(IdentifierStr);
315  if (CurTok != ')')
316    return ErrorP("Expected ')' in prototype");
317
318  // success.
319  getNextToken();  // eat ')'.
320
321  return new PrototypeAST(FnName, ArgNames);
322}
323
324/// definition ::= 'def' prototype expression
325static FunctionAST *ParseDefinition() {
326  getNextToken();  // eat def.
327  PrototypeAST *Proto = ParsePrototype();
328  if (Proto == 0) return 0;
329
330  if (ExprAST *E = ParseExpression())
331    return new FunctionAST(Proto, E);
332  return 0;
333}
334
335/// toplevelexpr ::= expression
336static FunctionAST *ParseTopLevelExpr() {
337  if (ExprAST *E = ParseExpression()) {
338    // Make an anonymous proto.
339    PrototypeAST *Proto = new PrototypeAST("", std::vector<std::string>());
340    return new FunctionAST(Proto, E);
341  }
342  return 0;
343}
344
345/// external ::= 'extern' prototype
346static PrototypeAST *ParseExtern() {
347  getNextToken();  // eat extern.
348  return ParsePrototype();
349}
350
351//===----------------------------------------------------------------------===//
352// Code Generation
353//===----------------------------------------------------------------------===//
354
355static Module *TheModule;
356static IRBuilder<> Builder(getGlobalContext());
357static std::map<std::string, Value*> NamedValues;
358static FunctionPassManager *TheFPM;
359
360Value *ErrorV(const char *Str) { Error(Str); return 0; }
361
362Value *NumberExprAST::Codegen() {
363  return ConstantFP::get(getGlobalContext(), APFloat(Val));
364}
365
366Value *VariableExprAST::Codegen() {
367  // Look this variable up in the function.
368  Value *V = NamedValues[Name];
369  return V ? V : ErrorV("Unknown variable name");
370}
371
372Value *BinaryExprAST::Codegen() {
373  Value *L = LHS->Codegen();
374  Value *R = RHS->Codegen();
375  if (L == 0 || R == 0) return 0;
376
377  switch (Op) {
378  case '+': return Builder.CreateFAdd(L, R, "addtmp");
379  case '-': return Builder.CreateFSub(L, R, "subtmp");
380  case '*': return Builder.CreateFMul(L, R, "multmp");
381  case '<':
382    L = Builder.CreateFCmpULT(L, R, "cmptmp");
383    // Convert bool 0/1 to double 0.0 or 1.0
384    return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()),
385                                "booltmp");
386  default: return ErrorV("invalid binary operator");
387  }
388}
389
390Value *CallExprAST::Codegen() {
391  // Look up the name in the global module table.
392  Function *CalleeF = TheModule->getFunction(Callee);
393  if (CalleeF == 0)
394    return ErrorV("Unknown function referenced");
395
396  // If argument mismatch error.
397  if (CalleeF->arg_size() != Args.size())
398    return ErrorV("Incorrect # arguments passed");
399
400  std::vector<Value*> ArgsV;
401  for (unsigned i = 0, e = Args.size(); i != e; ++i) {
402    ArgsV.push_back(Args[i]->Codegen());
403    if (ArgsV.back() == 0) return 0;
404  }
405
406  return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
407}
408
409Function *PrototypeAST::Codegen() {
410  // Make the function type:  double(double,double) etc.
411  std::vector<Type*> Doubles(Args.size(),
412                             Type::getDoubleTy(getGlobalContext()));
413  FunctionType *FT = FunctionType::get(Type::getDoubleTy(getGlobalContext()),
414                                       Doubles, false);
415
416  Function *F = Function::Create(FT, Function::ExternalLinkage, Name, TheModule);
417
418  // If F conflicted, there was already something named 'Name'.  If it has a
419  // body, don't allow redefinition or reextern.
420  if (F->getName() != Name) {
421    // Delete the one we just made and get the existing one.
422    F->eraseFromParent();
423    F = TheModule->getFunction(Name);
424
425    // If F already has a body, reject this.
426    if (!F->empty()) {
427      ErrorF("redefinition of function");
428      return 0;
429    }
430
431    // If F took a different number of args, reject.
432    if (F->arg_size() != Args.size()) {
433      ErrorF("redefinition of function with different # args");
434      return 0;
435    }
436  }
437
438  // Set names for all arguments.
439  unsigned Idx = 0;
440  for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size();
441       ++AI, ++Idx) {
442    AI->setName(Args[Idx]);
443
444    // Add arguments to variable symbol table.
445    NamedValues[Args[Idx]] = AI;
446  }
447
448  return F;
449}
450
451Function *FunctionAST::Codegen() {
452  NamedValues.clear();
453
454  Function *TheFunction = Proto->Codegen();
455  if (TheFunction == 0)
456    return 0;
457
458  // Create a new basic block to start insertion into.
459  BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
460  Builder.SetInsertPoint(BB);
461
462  if (Value *RetVal = Body->Codegen()) {
463    // Finish off the function.
464    Builder.CreateRet(RetVal);
465
466    // Validate the generated code, checking for consistency.
467    verifyFunction(*TheFunction);
468
469    // Optimize the function.
470    TheFPM->run(*TheFunction);
471
472    return TheFunction;
473  }
474
475  // Error reading body, remove function.
476  TheFunction->eraseFromParent();
477  return 0;
478}
479
480//===----------------------------------------------------------------------===//
481// Top-Level parsing and JIT Driver
482//===----------------------------------------------------------------------===//
483
484static ExecutionEngine *TheExecutionEngine;
485
486static void HandleDefinition() {
487  if (FunctionAST *F = ParseDefinition()) {
488    if (Function *LF = F->Codegen()) {
489      fprintf(stderr, "Read function definition:");
490      LF->dump();
491    }
492  } else {
493    // Skip token for error recovery.
494    getNextToken();
495  }
496}
497
498static void HandleExtern() {
499  if (PrototypeAST *P = ParseExtern()) {
500    if (Function *F = P->Codegen()) {
501      fprintf(stderr, "Read extern: ");
502      F->dump();
503    }
504  } else {
505    // Skip token for error recovery.
506    getNextToken();
507  }
508}
509
510static void HandleTopLevelExpression() {
511  // Evaluate a top-level expression into an anonymous function.
512  if (FunctionAST *F = ParseTopLevelExpr()) {
513    if (Function *LF = F->Codegen()) {
514      // JIT the function, returning a function pointer.
515      void *FPtr = TheExecutionEngine->getPointerToFunction(LF);
516
517      // Cast it to the right type (takes no arguments, returns a double) so we
518      // can call it as a native function.
519      double (*FP)() = (double (*)())(intptr_t)FPtr;
520      fprintf(stderr, "Evaluated to %f\n", FP());
521    }
522  } else {
523    // Skip token for error recovery.
524    getNextToken();
525  }
526}
527
528/// top ::= definition | external | expression | ';'
529static void MainLoop() {
530  while (1) {
531    fprintf(stderr, "ready> ");
532    switch (CurTok) {
533    case tok_eof:    return;
534    case ';':        getNextToken(); break;  // ignore top-level semicolons.
535    case tok_def:    HandleDefinition(); break;
536    case tok_extern: HandleExtern(); break;
537    default:         HandleTopLevelExpression(); break;
538    }
539  }
540}
541
542//===----------------------------------------------------------------------===//
543// "Library" functions that can be "extern'd" from user code.
544//===----------------------------------------------------------------------===//
545
/// putchard - putchar that takes a double and returns 0.
///
/// Exported with C linkage so user code can reach it via
/// 'extern putchard(x);'.  X is truncated toward zero and written as one byte.
/// Converting a floating-point value that does not fit the target integer
/// type is undefined behaviour.  Values outside (-129, 256), and NaN, are
/// therefore ignored instead of being cast.
extern "C"
double putchard(double X) {
  if (X > -129.0 && X < 256.0)
    putchar(static_cast<int>(X));  // putchar stores it as an unsigned char
  return 0;
}
552
553//===----------------------------------------------------------------------===//
554// Main driver code.
555//===----------------------------------------------------------------------===//
556
557int main() {
558  InitializeNativeTarget();
559  LLVMContext &Context = getGlobalContext();
560
561  // Install standard binary operators.
562  // 1 is lowest precedence.
563  BinopPrecedence['<'] = 10;
564  BinopPrecedence['+'] = 20;
565  BinopPrecedence['-'] = 20;
566  BinopPrecedence['*'] = 40;  // highest.
567
568  // Prime the first token.
569  fprintf(stderr, "ready> ");
570  getNextToken();
571
572  // Make the module, which holds all the code.
573  TheModule = new Module("my cool jit", Context);
574
575  // Create the JIT.  This takes ownership of the module.
576  std::string ErrStr;
577  TheExecutionEngine = EngineBuilder(TheModule).setErrorStr(&ErrStr).create();
578  if (!TheExecutionEngine) {
579    fprintf(stderr, "Could not create ExecutionEngine: %s\n", ErrStr.c_str());
580    exit(1);
581  }
582
583  FunctionPassManager OurFPM(TheModule);
584
585  // Set up the optimizer pipeline.  Start with registering info about how the
586  // target lays out data structures.
587  OurFPM.add(new TargetData(*TheExecutionEngine->getTargetData()));
588  // Provide basic AliasAnalysis support for GVN.
589  OurFPM.add(createBasicAliasAnalysisPass());
590  // Do simple "peephole" optimizations and bit-twiddling optzns.
591  OurFPM.add(createInstructionCombiningPass());
592  // Reassociate expressions.
593  OurFPM.add(createReassociatePass());
594  // Eliminate Common SubExpressions.
595  OurFPM.add(createGVNPass());
596  // Simplify the control flow graph (deleting unreachable blocks, etc).
597  OurFPM.add(createCFGSimplificationPass());
598
599  OurFPM.doInitialization();
600
601  // Set the global so the code gen can use this.
602  TheFPM = &OurFPM;
603
604  // Run the main "interpreter loop" now.
605  MainLoop();
606
607  TheFPM = 0;
608
609  // Print out all of the generated code.
610  TheModule->dump();
611
612  return 0;
613}
614