1#include "llvm/ADT/APFloat.h"
2#include "llvm/ADT/STLExtras.h"
3#include "llvm/IR/BasicBlock.h"
4#include "llvm/IR/Constants.h"
5#include "llvm/IR/DerivedTypes.h"
6#include "llvm/IR/Function.h"
7#include "llvm/IR/Instructions.h"
8#include "llvm/IR/IRBuilder.h"
9#include "llvm/IR/LLVMContext.h"
10#include "llvm/IR/Module.h"
11#include "llvm/IR/Type.h"
12#include "llvm/IR/Verifier.h"
13#include "llvm/Support/TargetSelect.h"
14#include "llvm/Target/TargetMachine.h"
15#include "KaleidoscopeJIT.h"
16#include <algorithm>
17#include <cassert>
18#include <cctype>
19#include <cstdint>
20#include <cstdio>
21#include <cstdlib>
22#include <map>
23#include <memory>
24#include <string>
25#include <utility>
26#include <vector>
27
28using namespace llvm;
29using namespace llvm::orc;
30
31//===----------------------------------------------------------------------===//
32// Lexer
33//===----------------------------------------------------------------------===//
34
// Token kinds produced by gettok().  A character the lexer has no special
// meaning for is returned as its own value (0-255); every named kind below is
// negative so it can never be confused with such a character.
enum Token {
  tok_eof = -1, // end of input

  // Commands.
  tok_def = -2,    // "def"
  tok_extern = -3, // "extern"

  // Primary expressions.
  tok_identifier = -4, // spelling left in IdentifierStr
  tok_number = -5,     // value left in NumVal

  // Control flow.
  tok_if = -6,    // "if"
  tok_then = -7,  // "then"
  tok_else = -8,  // "else"
  tok_for = -9,   // "for"
  tok_in = -10,   // "in"

  // User-defined operators.
  tok_binary = -11, // "binary"
  tok_unary = -12,  // "unary"

  // Mutable local variables.
  tok_var = -13 // "var"
};
62
// Payloads of the most recently lexed token; gettok() writes them.
static std::string IdentifierStr; // spelling of the last identifier or keyword
static double NumVal = 0.0;       // value of the last tok_number
65
66/// gettok - Return the next token from standard input.
67static int gettok() {
68  static int LastChar = ' ';
69
70  // Skip any whitespace.
71  while (isspace(LastChar))
72    LastChar = getchar();
73
74  if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
75    IdentifierStr = LastChar;
76    while (isalnum((LastChar = getchar())))
77      IdentifierStr += LastChar;
78
79    if (IdentifierStr == "def")
80      return tok_def;
81    if (IdentifierStr == "extern")
82      return tok_extern;
83    if (IdentifierStr == "if")
84      return tok_if;
85    if (IdentifierStr == "then")
86      return tok_then;
87    if (IdentifierStr == "else")
88      return tok_else;
89    if (IdentifierStr == "for")
90      return tok_for;
91    if (IdentifierStr == "in")
92      return tok_in;
93    if (IdentifierStr == "binary")
94      return tok_binary;
95    if (IdentifierStr == "unary")
96      return tok_unary;
97    if (IdentifierStr == "var")
98      return tok_var;
99    return tok_identifier;
100  }
101
102  if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
103    std::string NumStr;
104    do {
105      NumStr += LastChar;
106      LastChar = getchar();
107    } while (isdigit(LastChar) || LastChar == '.');
108
109    NumVal = strtod(NumStr.c_str(), nullptr);
110    return tok_number;
111  }
112
113  if (LastChar == '#') {
114    // Comment until end of line.
115    do
116      LastChar = getchar();
117    while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
118
119    if (LastChar != EOF)
120      return gettok();
121  }
122
123  // Check for end of file.  Don't eat the EOF.
124  if (LastChar == EOF)
125    return tok_eof;
126
127  // Otherwise, just return the character as its ascii value.
128  int ThisChar = LastChar;
129  LastChar = getchar();
130  return ThisChar;
131}
132
133//===----------------------------------------------------------------------===//
134// Abstract Syntax Tree (aka Parse Tree)
135//===----------------------------------------------------------------------===//
136
137namespace {
138
139/// ExprAST - Base class for all expression nodes.
140class ExprAST {
141public:
142  virtual ~ExprAST() = default;
143
144  virtual Value *codegen() = 0;
145};
146
147/// NumberExprAST - Expression class for numeric literals like "1.0".
148class NumberExprAST : public ExprAST {
149  double Val;
150
151public:
152  NumberExprAST(double Val) : Val(Val) {}
153
154  Value *codegen() override;
155};
156
157/// VariableExprAST - Expression class for referencing a variable, like "a".
158class VariableExprAST : public ExprAST {
159  std::string Name;
160
161public:
162  VariableExprAST(const std::string &Name) : Name(Name) {}
163
164  Value *codegen() override;
165  const std::string &getName() const { return Name; }
166};
167
168/// UnaryExprAST - Expression class for a unary operator.
169class UnaryExprAST : public ExprAST {
170  char Opcode;
171  std::unique_ptr<ExprAST> Operand;
172
173public:
174  UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
175      : Opcode(Opcode), Operand(std::move(Operand)) {}
176
177  Value *codegen() override;
178};
179
180/// BinaryExprAST - Expression class for a binary operator.
181class BinaryExprAST : public ExprAST {
182  char Op;
183  std::unique_ptr<ExprAST> LHS, RHS;
184
185public:
186  BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
187                std::unique_ptr<ExprAST> RHS)
188      : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
189
190  Value *codegen() override;
191};
192
193/// CallExprAST - Expression class for function calls.
194class CallExprAST : public ExprAST {
195  std::string Callee;
196  std::vector<std::unique_ptr<ExprAST>> Args;
197
198public:
199  CallExprAST(const std::string &Callee,
200              std::vector<std::unique_ptr<ExprAST>> Args)
201      : Callee(Callee), Args(std::move(Args)) {}
202
203  Value *codegen() override;
204};
205
206/// IfExprAST - Expression class for if/then/else.
207class IfExprAST : public ExprAST {
208  std::unique_ptr<ExprAST> Cond, Then, Else;
209
210public:
211  IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
212            std::unique_ptr<ExprAST> Else)
213      : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
214
215  Value *codegen() override;
216};
217
218/// ForExprAST - Expression class for for/in.
219class ForExprAST : public ExprAST {
220  std::string VarName;
221  std::unique_ptr<ExprAST> Start, End, Step, Body;
222
223public:
224  ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
225             std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
226             std::unique_ptr<ExprAST> Body)
227      : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
228        Step(std::move(Step)), Body(std::move(Body)) {}
229
230  Value *codegen() override;
231};
232
233/// VarExprAST - Expression class for var/in
234class VarExprAST : public ExprAST {
235  std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames;
236  std::unique_ptr<ExprAST> Body;
237
238public:
239  VarExprAST(
240      std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames,
241      std::unique_ptr<ExprAST> Body)
242      : VarNames(std::move(VarNames)), Body(std::move(Body)) {}
243
244  Value *codegen() override;
245};
246
247/// PrototypeAST - This class represents the "prototype" for a function,
248/// which captures its name, and its argument names (thus implicitly the number
249/// of arguments the function takes), as well as if it is an operator.
250class PrototypeAST {
251  std::string Name;
252  std::vector<std::string> Args;
253  bool IsOperator;
254  unsigned Precedence; // Precedence if a binary op.
255
256public:
257  PrototypeAST(const std::string &Name, std::vector<std::string> Args,
258               bool IsOperator = false, unsigned Prec = 0)
259      : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
260        Precedence(Prec) {}
261
262  Function *codegen();
263  const std::string &getName() const { return Name; }
264
265  bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
266  bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
267
268  char getOperatorName() const {
269    assert(isUnaryOp() || isBinaryOp());
270    return Name[Name.size() - 1];
271  }
272
273  unsigned getBinaryPrecedence() const { return Precedence; }
274};
275
276/// FunctionAST - This class represents a function definition itself.
277class FunctionAST {
278  std::unique_ptr<PrototypeAST> Proto;
279  std::unique_ptr<ExprAST> Body;
280
281public:
282  FunctionAST(std::unique_ptr<PrototypeAST> Proto,
283              std::unique_ptr<ExprAST> Body)
284      : Proto(std::move(Proto)), Body(std::move(Body)) {}
285
286  Function *codegen();
287};
288
289} // end anonymous namespace
290
291//===----------------------------------------------------------------------===//
292// Parser
293//===----------------------------------------------------------------------===//
294
295/// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
296/// token the parser is looking at.  getNextToken reads another token from the
297/// lexer and updates CurTok with its results.
298static int CurTok;
299static int getNextToken() { return CurTok = gettok(); }
300
/// BinopPrecedence - Precedence of each currently defined binary operator,
/// keyed by its character.  Higher values bind more tightly; an entry <= 0 is
/// treated as "not a binary operator".
static std::map<char, int> BinopPrecedence{};
304
305/// GetTokPrecedence - Get the precedence of the pending binary operator token.
306static int GetTokPrecedence() {
307  if (!isascii(CurTok))
308    return -1;
309
310  // Make sure it's a declared binop.
311  int TokPrec = BinopPrecedence[CurTok];
312  if (TokPrec <= 0)
313    return -1;
314  return TokPrec;
315}
316
317/// LogError* - These are little helper functions for error handling.
318std::unique_ptr<ExprAST> LogError(const char *Str) {
319  fprintf(stderr, "Error: %s\n", Str);
320  return nullptr;
321}
322
323std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
324  LogError(Str);
325  return nullptr;
326}
327
328static std::unique_ptr<ExprAST> ParseExpression();
329
330/// numberexpr ::= number
331static std::unique_ptr<ExprAST> ParseNumberExpr() {
332  auto Result = std::make_unique<NumberExprAST>(NumVal);
333  getNextToken(); // consume the number
334  return std::move(Result);
335}
336
337/// parenexpr ::= '(' expression ')'
338static std::unique_ptr<ExprAST> ParseParenExpr() {
339  getNextToken(); // eat (.
340  auto V = ParseExpression();
341  if (!V)
342    return nullptr;
343
344  if (CurTok != ')')
345    return LogError("expected ')'");
346  getNextToken(); // eat ).
347  return V;
348}
349
350/// identifierexpr
351///   ::= identifier
352///   ::= identifier '(' expression* ')'
353static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
354  std::string IdName = IdentifierStr;
355
356  getNextToken(); // eat identifier.
357
358  if (CurTok != '(') // Simple variable ref.
359    return std::make_unique<VariableExprAST>(IdName);
360
361  // Call.
362  getNextToken(); // eat (
363  std::vector<std::unique_ptr<ExprAST>> Args;
364  if (CurTok != ')') {
365    while (true) {
366      if (auto Arg = ParseExpression())
367        Args.push_back(std::move(Arg));
368      else
369        return nullptr;
370
371      if (CurTok == ')')
372        break;
373
374      if (CurTok != ',')
375        return LogError("Expected ')' or ',' in argument list");
376      getNextToken();
377    }
378  }
379
380  // Eat the ')'.
381  getNextToken();
382
383  return std::make_unique<CallExprAST>(IdName, std::move(Args));
384}
385
386/// ifexpr ::= 'if' expression 'then' expression 'else' expression
387static std::unique_ptr<ExprAST> ParseIfExpr() {
388  getNextToken(); // eat the if.
389
390  // condition.
391  auto Cond = ParseExpression();
392  if (!Cond)
393    return nullptr;
394
395  if (CurTok != tok_then)
396    return LogError("expected then");
397  getNextToken(); // eat the then
398
399  auto Then = ParseExpression();
400  if (!Then)
401    return nullptr;
402
403  if (CurTok != tok_else)
404    return LogError("expected else");
405
406  getNextToken();
407
408  auto Else = ParseExpression();
409  if (!Else)
410    return nullptr;
411
412  return std::make_unique<IfExprAST>(std::move(Cond), std::move(Then),
413                                      std::move(Else));
414}
415
416/// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
417static std::unique_ptr<ExprAST> ParseForExpr() {
418  getNextToken(); // eat the for.
419
420  if (CurTok != tok_identifier)
421    return LogError("expected identifier after for");
422
423  std::string IdName = IdentifierStr;
424  getNextToken(); // eat identifier.
425
426  if (CurTok != '=')
427    return LogError("expected '=' after for");
428  getNextToken(); // eat '='.
429
430  auto Start = ParseExpression();
431  if (!Start)
432    return nullptr;
433  if (CurTok != ',')
434    return LogError("expected ',' after for start value");
435  getNextToken();
436
437  auto End = ParseExpression();
438  if (!End)
439    return nullptr;
440
441  // The step value is optional.
442  std::unique_ptr<ExprAST> Step;
443  if (CurTok == ',') {
444    getNextToken();
445    Step = ParseExpression();
446    if (!Step)
447      return nullptr;
448  }
449
450  if (CurTok != tok_in)
451    return LogError("expected 'in' after for");
452  getNextToken(); // eat 'in'.
453
454  auto Body = ParseExpression();
455  if (!Body)
456    return nullptr;
457
458  return std::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
459                                       std::move(Step), std::move(Body));
460}
461
462/// varexpr ::= 'var' identifier ('=' expression)?
463//                    (',' identifier ('=' expression)?)* 'in' expression
464static std::unique_ptr<ExprAST> ParseVarExpr() {
465  getNextToken(); // eat the var.
466
467  std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames;
468
469  // At least one variable name is required.
470  if (CurTok != tok_identifier)
471    return LogError("expected identifier after var");
472
473  while (true) {
474    std::string Name = IdentifierStr;
475    getNextToken(); // eat identifier.
476
477    // Read the optional initializer.
478    std::unique_ptr<ExprAST> Init = nullptr;
479    if (CurTok == '=') {
480      getNextToken(); // eat the '='.
481
482      Init = ParseExpression();
483      if (!Init)
484        return nullptr;
485    }
486
487    VarNames.push_back(std::make_pair(Name, std::move(Init)));
488
489    // End of var list, exit loop.
490    if (CurTok != ',')
491      break;
492    getNextToken(); // eat the ','.
493
494    if (CurTok != tok_identifier)
495      return LogError("expected identifier list after var");
496  }
497
498  // At this point, we have to have 'in'.
499  if (CurTok != tok_in)
500    return LogError("expected 'in' keyword after 'var'");
501  getNextToken(); // eat 'in'.
502
503  auto Body = ParseExpression();
504  if (!Body)
505    return nullptr;
506
507  return std::make_unique<VarExprAST>(std::move(VarNames), std::move(Body));
508}
509
510/// primary
511///   ::= identifierexpr
512///   ::= numberexpr
513///   ::= parenexpr
514///   ::= ifexpr
515///   ::= forexpr
516///   ::= varexpr
517static std::unique_ptr<ExprAST> ParsePrimary() {
518  switch (CurTok) {
519  default:
520    return LogError("unknown token when expecting an expression");
521  case tok_identifier:
522    return ParseIdentifierExpr();
523  case tok_number:
524    return ParseNumberExpr();
525  case '(':
526    return ParseParenExpr();
527  case tok_if:
528    return ParseIfExpr();
529  case tok_for:
530    return ParseForExpr();
531  case tok_var:
532    return ParseVarExpr();
533  }
534}
535
536/// unary
537///   ::= primary
538///   ::= '!' unary
539static std::unique_ptr<ExprAST> ParseUnary() {
540  // If the current token is not an operator, it must be a primary expr.
541  if (!isascii(CurTok) || CurTok == '(' || CurTok == ',')
542    return ParsePrimary();
543
544  // If this is a unary operator, read it.
545  int Opc = CurTok;
546  getNextToken();
547  if (auto Operand = ParseUnary())
548    return std::make_unique<UnaryExprAST>(Opc, std::move(Operand));
549  return nullptr;
550}
551
552/// binoprhs
553///   ::= ('+' unary)*
554static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
555                                              std::unique_ptr<ExprAST> LHS) {
556  // If this is a binop, find its precedence.
557  while (true) {
558    int TokPrec = GetTokPrecedence();
559
560    // If this is a binop that binds at least as tightly as the current binop,
561    // consume it, otherwise we are done.
562    if (TokPrec < ExprPrec)
563      return LHS;
564
565    // Okay, we know this is a binop.
566    int BinOp = CurTok;
567    getNextToken(); // eat binop
568
569    // Parse the unary expression after the binary operator.
570    auto RHS = ParseUnary();
571    if (!RHS)
572      return nullptr;
573
574    // If BinOp binds less tightly with RHS than the operator after RHS, let
575    // the pending operator take RHS as its LHS.
576    int NextPrec = GetTokPrecedence();
577    if (TokPrec < NextPrec) {
578      RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
579      if (!RHS)
580        return nullptr;
581    }
582
583    // Merge LHS/RHS.
584    LHS =
585        std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
586  }
587}
588
589/// expression
590///   ::= unary binoprhs
591///
592static std::unique_ptr<ExprAST> ParseExpression() {
593  auto LHS = ParseUnary();
594  if (!LHS)
595    return nullptr;
596
597  return ParseBinOpRHS(0, std::move(LHS));
598}
599
600/// prototype
601///   ::= id '(' id* ')'
602///   ::= binary LETTER number? (id, id)
603///   ::= unary LETTER (id)
604static std::unique_ptr<PrototypeAST> ParsePrototype() {
605  std::string FnName;
606
607  unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary.
608  unsigned BinaryPrecedence = 30;
609
610  switch (CurTok) {
611  default:
612    return LogErrorP("Expected function name in prototype");
613  case tok_identifier:
614    FnName = IdentifierStr;
615    Kind = 0;
616    getNextToken();
617    break;
618  case tok_unary:
619    getNextToken();
620    if (!isascii(CurTok))
621      return LogErrorP("Expected unary operator");
622    FnName = "unary";
623    FnName += (char)CurTok;
624    Kind = 1;
625    getNextToken();
626    break;
627  case tok_binary:
628    getNextToken();
629    if (!isascii(CurTok))
630      return LogErrorP("Expected binary operator");
631    FnName = "binary";
632    FnName += (char)CurTok;
633    Kind = 2;
634    getNextToken();
635
636    // Read the precedence if present.
637    if (CurTok == tok_number) {
638      if (NumVal < 1 || NumVal > 100)
639        return LogErrorP("Invalid precedecnce: must be 1..100");
640      BinaryPrecedence = (unsigned)NumVal;
641      getNextToken();
642    }
643    break;
644  }
645
646  if (CurTok != '(')
647    return LogErrorP("Expected '(' in prototype");
648
649  std::vector<std::string> ArgNames;
650  while (getNextToken() == tok_identifier)
651    ArgNames.push_back(IdentifierStr);
652  if (CurTok != ')')
653    return LogErrorP("Expected ')' in prototype");
654
655  // success.
656  getNextToken(); // eat ')'.
657
658  // Verify right number of names for operator.
659  if (Kind && ArgNames.size() != Kind)
660    return LogErrorP("Invalid number of operands for operator");
661
662  return std::make_unique<PrototypeAST>(FnName, ArgNames, Kind != 0,
663                                         BinaryPrecedence);
664}
665
666/// definition ::= 'def' prototype expression
667static std::unique_ptr<FunctionAST> ParseDefinition() {
668  getNextToken(); // eat def.
669  auto Proto = ParsePrototype();
670  if (!Proto)
671    return nullptr;
672
673  if (auto E = ParseExpression())
674    return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
675  return nullptr;
676}
677
678/// toplevelexpr ::= expression
679static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
680  if (auto E = ParseExpression()) {
681    // Make an anonymous proto.
682    auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
683                                                std::vector<std::string>());
684    return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
685  }
686  return nullptr;
687}
688
689/// external ::= 'extern' prototype
690static std::unique_ptr<PrototypeAST> ParseExtern() {
691  getNextToken(); // eat extern.
692  return ParsePrototype();
693}
694
695//===----------------------------------------------------------------------===//
696// Code Generation
697//===----------------------------------------------------------------------===//
698
// Global code-generation state shared by the codegen() methods below.
//
// NOTE(review): keep this declaration order.  Namespace-scope statics are torn
// down in reverse order, so as declared TheModule and Builder go away before
// TheContext.  They are presumably created against *TheContext by setup code
// outside this view — confirm before reordering.
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
static std::unique_ptr<LLVMContext> TheContext;
static std::unique_ptr<IRBuilder<>> Builder;
static std::unique_ptr<Module> TheModule;
// Variables currently in scope, mapped to the alloca holding each one.  The
// lookups in this file treat a null mapped value as "not defined".
static std::map<std::string, AllocaInst *> NamedValues;
// Most recent prototype for each function name.  getFunction() uses it to emit
// a declaration into TheModule when that module does not have one yet.
static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
// Error-exit helper; not referenced in the code visible here.
static ExitOnError ExitOnErr;
706
707Value *LogErrorV(const char *Str) {
708  LogError(Str);
709  return nullptr;
710}
711
712Function *getFunction(std::string Name) {
713  // First, see if the function has already been added to the current module.
714  if (auto *F = TheModule->getFunction(Name))
715    return F;
716
717  // If not, check whether we can codegen the declaration from some existing
718  // prototype.
719  auto FI = FunctionProtos.find(Name);
720  if (FI != FunctionProtos.end())
721    return FI->second->codegen();
722
723  // If no existing prototype exists, return null.
724  return nullptr;
725}
726
727/// CreateEntryBlockAlloca - Create an alloca instruction in the entry block of
728/// the function.  This is used for mutable variables etc.
729static AllocaInst *CreateEntryBlockAlloca(Function *TheFunction,
730                                          StringRef VarName) {
731  IRBuilder<> TmpB(&TheFunction->getEntryBlock(),
732                   TheFunction->getEntryBlock().begin());
733  return TmpB.CreateAlloca(Type::getDoubleTy(*TheContext), nullptr, VarName);
734}
735
736Value *NumberExprAST::codegen() {
737  return ConstantFP::get(*TheContext, APFloat(Val));
738}
739
740Value *VariableExprAST::codegen() {
741  // Look this variable up in the function.
742  Value *V = NamedValues[Name];
743  if (!V)
744    return LogErrorV("Unknown variable name");
745
746  // Load the value.
747  return Builder->CreateLoad(Type::getDoubleTy(*TheContext), V, Name.c_str());
748}
749
750Value *UnaryExprAST::codegen() {
751  Value *OperandV = Operand->codegen();
752  if (!OperandV)
753    return nullptr;
754
755  Function *F = getFunction(std::string("unary") + Opcode);
756  if (!F)
757    return LogErrorV("Unknown unary operator");
758
759  return Builder->CreateCall(F, OperandV, "unop");
760}
761
762Value *BinaryExprAST::codegen() {
763  // Special case '=' because we don't want to emit the LHS as an expression.
764  if (Op == '=') {
765    // Assignment requires the LHS to be an identifier.
766    // This assume we're building without RTTI because LLVM builds that way by
767    // default.  If you build LLVM with RTTI this can be changed to a
768    // dynamic_cast for automatic error checking.
769    VariableExprAST *LHSE = static_cast<VariableExprAST *>(LHS.get());
770    if (!LHSE)
771      return LogErrorV("destination of '=' must be a variable");
772    // Codegen the RHS.
773    Value *Val = RHS->codegen();
774    if (!Val)
775      return nullptr;
776
777    // Look up the name.
778    Value *Variable = NamedValues[LHSE->getName()];
779    if (!Variable)
780      return LogErrorV("Unknown variable name");
781
782    Builder->CreateStore(Val, Variable);
783    return Val;
784  }
785
786  Value *L = LHS->codegen();
787  Value *R = RHS->codegen();
788  if (!L || !R)
789    return nullptr;
790
791  switch (Op) {
792  case '+':
793    return Builder->CreateFAdd(L, R, "addtmp");
794  case '-':
795    return Builder->CreateFSub(L, R, "subtmp");
796  case '*':
797    return Builder->CreateFMul(L, R, "multmp");
798  case '<':
799    L = Builder->CreateFCmpULT(L, R, "cmptmp");
800    // Convert bool 0/1 to double 0.0 or 1.0
801    return Builder->CreateUIToFP(L, Type::getDoubleTy(*TheContext), "booltmp");
802  default:
803    break;
804  }
805
806  // If it wasn't a builtin binary operator, it must be a user defined one. Emit
807  // a call to it.
808  Function *F = getFunction(std::string("binary") + Op);
809  assert(F && "binary operator not found!");
810
811  Value *Ops[] = {L, R};
812  return Builder->CreateCall(F, Ops, "binop");
813}
814
815Value *CallExprAST::codegen() {
816  // Look up the name in the global module table.
817  Function *CalleeF = getFunction(Callee);
818  if (!CalleeF)
819    return LogErrorV("Unknown function referenced");
820
821  // If argument mismatch error.
822  if (CalleeF->arg_size() != Args.size())
823    return LogErrorV("Incorrect # arguments passed");
824
825  std::vector<Value *> ArgsV;
826  for (unsigned i = 0, e = Args.size(); i != e; ++i) {
827    ArgsV.push_back(Args[i]->codegen());
828    if (!ArgsV.back())
829      return nullptr;
830  }
831
832  return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
833}
834
835Value *IfExprAST::codegen() {
836  Value *CondV = Cond->codegen();
837  if (!CondV)
838    return nullptr;
839
840  // Convert condition to a bool by comparing equal to 0.0.
841  CondV = Builder->CreateFCmpONE(
842      CondV, ConstantFP::get(*TheContext, APFloat(0.0)), "ifcond");
843
844  Function *TheFunction = Builder->GetInsertBlock()->getParent();
845
846  // Create blocks for the then and else cases.  Insert the 'then' block at the
847  // end of the function.
848  BasicBlock *ThenBB = BasicBlock::Create(*TheContext, "then", TheFunction);
849  BasicBlock *ElseBB = BasicBlock::Create(*TheContext, "else");
850  BasicBlock *MergeBB = BasicBlock::Create(*TheContext, "ifcont");
851
852  Builder->CreateCondBr(CondV, ThenBB, ElseBB);
853
854  // Emit then value.
855  Builder->SetInsertPoint(ThenBB);
856
857  Value *ThenV = Then->codegen();
858  if (!ThenV)
859    return nullptr;
860
861  Builder->CreateBr(MergeBB);
862  // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
863  ThenBB = Builder->GetInsertBlock();
864
865  // Emit else block.
866  TheFunction->insert(TheFunction->end(), ElseBB);
867  Builder->SetInsertPoint(ElseBB);
868
869  Value *ElseV = Else->codegen();
870  if (!ElseV)
871    return nullptr;
872
873  Builder->CreateBr(MergeBB);
874  // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
875  ElseBB = Builder->GetInsertBlock();
876
877  // Emit merge block.
878  TheFunction->insert(TheFunction->end(), MergeBB);
879  Builder->SetInsertPoint(MergeBB);
880  PHINode *PN = Builder->CreatePHI(Type::getDoubleTy(*TheContext), 2, "iftmp");
881
882  PN->addIncoming(ThenV, ThenBB);
883  PN->addIncoming(ElseV, ElseBB);
884  return PN;
885}
886
887// Output for-loop as:
888//   var = alloca double
889//   ...
890//   start = startexpr
891//   store start -> var
892//   goto loop
893// loop:
894//   ...
895//   bodyexpr
896//   ...
897// loopend:
898//   step = stepexpr
899//   endcond = endexpr
900//
901//   curvar = load var
902//   nextvar = curvar + step
903//   store nextvar -> var
904//   br endcond, loop, endloop
905// outloop:
906Value *ForExprAST::codegen() {
907  Function *TheFunction = Builder->GetInsertBlock()->getParent();
908
909  // Create an alloca for the variable in the entry block.
910  AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName);
911
912  // Emit the start code first, without 'variable' in scope.
913  Value *StartVal = Start->codegen();
914  if (!StartVal)
915    return nullptr;
916
917  // Store the value into the alloca.
918  Builder->CreateStore(StartVal, Alloca);
919
920  // Make the new basic block for the loop header, inserting after current
921  // block.
922  BasicBlock *LoopBB = BasicBlock::Create(*TheContext, "loop", TheFunction);
923
924  // Insert an explicit fall through from the current block to the LoopBB.
925  Builder->CreateBr(LoopBB);
926
927  // Start insertion in LoopBB.
928  Builder->SetInsertPoint(LoopBB);
929
930  // Within the loop, the variable is defined equal to the PHI node.  If it
931  // shadows an existing variable, we have to restore it, so save it now.
932  AllocaInst *OldVal = NamedValues[VarName];
933  NamedValues[VarName] = Alloca;
934
935  // Emit the body of the loop.  This, like any other expr, can change the
936  // current BB.  Note that we ignore the value computed by the body, but don't
937  // allow an error.
938  if (!Body->codegen())
939    return nullptr;
940
941  // Emit the step value.
942  Value *StepVal = nullptr;
943  if (Step) {
944    StepVal = Step->codegen();
945    if (!StepVal)
946      return nullptr;
947  } else {
948    // If not specified, use 1.0.
949    StepVal = ConstantFP::get(*TheContext, APFloat(1.0));
950  }
951
952  // Compute the end condition.
953  Value *EndCond = End->codegen();
954  if (!EndCond)
955    return nullptr;
956
957  // Reload, increment, and restore the alloca.  This handles the case where
958  // the body of the loop mutates the variable.
959  Value *CurVar = Builder->CreateLoad(Type::getDoubleTy(*TheContext), Alloca,
960                                      VarName.c_str());
961  Value *NextVar = Builder->CreateFAdd(CurVar, StepVal, "nextvar");
962  Builder->CreateStore(NextVar, Alloca);
963
964  // Convert condition to a bool by comparing equal to 0.0.
965  EndCond = Builder->CreateFCmpONE(
966      EndCond, ConstantFP::get(*TheContext, APFloat(0.0)), "loopcond");
967
968  // Create the "after loop" block and insert it.
969  BasicBlock *AfterBB =
970      BasicBlock::Create(*TheContext, "afterloop", TheFunction);
971
972  // Insert the conditional branch into the end of LoopEndBB.
973  Builder->CreateCondBr(EndCond, LoopBB, AfterBB);
974
975  // Any new code will be inserted in AfterBB.
976  Builder->SetInsertPoint(AfterBB);
977
978  // Restore the unshadowed variable.
979  if (OldVal)
980    NamedValues[VarName] = OldVal;
981  else
982    NamedValues.erase(VarName);
983
984  // for expr always returns 0.0.
985  return Constant::getNullValue(Type::getDoubleTy(*TheContext));
986}
987
988Value *VarExprAST::codegen() {
989  std::vector<AllocaInst *> OldBindings;
990
991  Function *TheFunction = Builder->GetInsertBlock()->getParent();
992
993  // Register all variables and emit their initializer.
994  for (unsigned i = 0, e = VarNames.size(); i != e; ++i) {
995    const std::string &VarName = VarNames[i].first;
996    ExprAST *Init = VarNames[i].second.get();
997
998    // Emit the initializer before adding the variable to scope, this prevents
999    // the initializer from referencing the variable itself, and permits stuff
1000    // like this:
1001    //  var a = 1 in
1002    //    var a = a in ...   # refers to outer 'a'.
1003    Value *InitVal;
1004    if (Init) {
1005      InitVal = Init->codegen();
1006      if (!InitVal)
1007        return nullptr;
1008    } else { // If not specified, use 0.0.
1009      InitVal = ConstantFP::get(*TheContext, APFloat(0.0));
1010    }
1011
1012    AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName);
1013    Builder->CreateStore(InitVal, Alloca);
1014
1015    // Remember the old variable binding so that we can restore the binding when
1016    // we unrecurse.
1017    OldBindings.push_back(NamedValues[VarName]);
1018
1019    // Remember this binding.
1020    NamedValues[VarName] = Alloca;
1021  }
1022
1023  // Codegen the body, now that all vars are in scope.
1024  Value *BodyVal = Body->codegen();
1025  if (!BodyVal)
1026    return nullptr;
1027
1028  // Pop all our variables from scope.
1029  for (unsigned i = 0, e = VarNames.size(); i != e; ++i)
1030    NamedValues[VarNames[i].first] = OldBindings[i];
1031
1032  // Return the body computation.
1033  return BodyVal;
1034}
1035
1036Function *PrototypeAST::codegen() {
1037  // Make the function type:  double(double,double) etc.
1038  std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(*TheContext));
1039  FunctionType *FT =
1040      FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);
1041
1042  Function *F =
1043      Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
1044
1045  // Set names for all arguments.
1046  unsigned Idx = 0;
1047  for (auto &Arg : F->args())
1048    Arg.setName(Args[Idx++]);
1049
1050  return F;
1051}
1052
1053Function *FunctionAST::codegen() {
1054  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
1055  // reference to it for use below.
1056  auto &P = *Proto;
1057  FunctionProtos[Proto->getName()] = std::move(Proto);
1058  Function *TheFunction = getFunction(P.getName());
1059  if (!TheFunction)
1060    return nullptr;
1061
1062  // If this is an operator, install it.
1063  if (P.isBinaryOp())
1064    BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
1065
1066  // Create a new basic block to start insertion into.
1067  BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
1068  Builder->SetInsertPoint(BB);
1069
1070  // Record the function arguments in the NamedValues map.
1071  NamedValues.clear();
1072  for (auto &Arg : TheFunction->args()) {
1073    // Create an alloca for this variable.
1074    AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, Arg.getName());
1075
1076    // Store the initial value into the alloca.
1077    Builder->CreateStore(&Arg, Alloca);
1078
1079    // Add arguments to variable symbol table.
1080    NamedValues[std::string(Arg.getName())] = Alloca;
1081  }
1082
1083  if (Value *RetVal = Body->codegen()) {
1084    // Finish off the function.
1085    Builder->CreateRet(RetVal);
1086
1087    // Validate the generated code, checking for consistency.
1088    verifyFunction(*TheFunction);
1089
1090    return TheFunction;
1091  }
1092
1093  // Error reading body, remove function.
1094  TheFunction->eraseFromParent();
1095
1096  if (P.isBinaryOp())
1097    BinopPrecedence.erase(P.getOperatorName());
1098  return nullptr;
1099}
1100
1101//===----------------------------------------------------------------------===//
1102// Top-Level parsing and JIT Driver
1103//===----------------------------------------------------------------------===//
1104
1105static void InitializeModule() {
1106  // Open a new context and module.
1107  TheContext = std::make_unique<LLVMContext>();
1108  TheModule = std::make_unique<Module>("my cool jit", *TheContext);
1109  TheModule->setDataLayout(TheJIT->getDataLayout());
1110
1111  // Create a new builder for the module.
1112  Builder = std::make_unique<IRBuilder<>>(*TheContext);
1113}
1114
1115static void HandleDefinition() {
1116  if (auto FnAST = ParseDefinition()) {
1117    if (auto *FnIR = FnAST->codegen()) {
1118      fprintf(stderr, "Read function definition:");
1119      FnIR->print(errs());
1120      fprintf(stderr, "\n");
1121      auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
1122      ExitOnErr(TheJIT->addModule(std::move(TSM)));
1123      InitializeModule();
1124    }
1125  } else {
1126    // Skip token for error recovery.
1127    getNextToken();
1128  }
1129}
1130
1131static void HandleExtern() {
1132  if (auto ProtoAST = ParseExtern()) {
1133    if (auto *FnIR = ProtoAST->codegen()) {
1134      fprintf(stderr, "Read extern: ");
1135      FnIR->print(errs());
1136      fprintf(stderr, "\n");
1137      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
1138    }
1139  } else {
1140    // Skip token for error recovery.
1141    getNextToken();
1142  }
1143}
1144
1145static void HandleTopLevelExpression() {
1146  // Evaluate a top-level expression into an anonymous function.
1147  if (auto FnAST = ParseTopLevelExpr()) {
1148    if (FnAST->codegen()) {
1149      // Create a ResourceTracker to track JIT'd memory allocated to our
1150      // anonymous expression -- that way we can free it after executing.
1151      auto RT = TheJIT->getMainJITDylib().createResourceTracker();
1152
1153      auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
1154      ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
1155      InitializeModule();
1156
1157      // Get the anonymous expression's JITSymbol.
1158      auto Sym = ExitOnErr(TheJIT->lookup("__anon_expr"));
1159
1160      // Get the symbol's address and cast it to the right type (takes no
1161      // arguments, returns a double) so we can call it as a native function.
1162      auto *FP = (double (*)())(intptr_t)Sym.getAddress();
1163      fprintf(stderr, "Evaluated to %f\n", FP());
1164
1165      // Delete the anonymous expression module from the JIT.
1166      ExitOnErr(RT->remove());
1167    }
1168  } else {
1169    // Skip token for error recovery.
1170    getNextToken();
1171  }
1172}
1173
1174/// top ::= definition | external | expression | ';'
1175static void MainLoop() {
1176  while (true) {
1177    fprintf(stderr, "ready> ");
1178    switch (CurTok) {
1179    case tok_eof:
1180      return;
1181    case ';': // ignore top-level semicolons.
1182      getNextToken();
1183      break;
1184    case tok_def:
1185      HandleDefinition();
1186      break;
1187    case tok_extern:
1188      HandleExtern();
1189      break;
1190    default:
1191      HandleTopLevelExpression();
1192      break;
1193    }
1194  }
1195}
1196
1197//===----------------------------------------------------------------------===//
1198// "Library" functions that can be "extern'd" from user code.
1199//===----------------------------------------------------------------------===//
1200
/// putchard - putchar that takes a double and returns 0.
///
/// The value is truncated toward zero and written to stderr as one byte;
/// fputc reduces the int to an unsigned char. Converting a double whose
/// truncated value does not fit the target integer type is undefined behavior,
/// as the previous `(char)X` cast did for NaN, infinities and out-of-range
/// values. Those inputs are therefore ignored rather than converted.
extern "C" double putchard(double X) {
  // NaN fails both comparisons, so it is rejected along with +/-infinity.
  if (X > static_cast<double>(INT32_MIN) - 1.0 &&
      X < static_cast<double>(INT32_MAX) + 1.0)
    fputc(static_cast<int32_t>(X), stderr);
  return 0;
}
1206
/// printd - runtime helper callable from Kaleidoscope code.
///
/// Writes X to stderr using the "%f\n" format and always yields 0.0.
extern "C" double printd(double X) {
  std::fprintf(stderr, "%f\n", X);
  return 0.0;
}
1212
1213//===----------------------------------------------------------------------===//
1214// Main driver code.
1215//===----------------------------------------------------------------------===//
1216
1217int main() {
1218  InitializeNativeTarget();
1219  InitializeNativeTargetAsmPrinter();
1220  InitializeNativeTargetAsmParser();
1221
1222  // Install standard binary operators.
1223  // 1 is lowest precedence.
1224  BinopPrecedence['='] = 2;
1225  BinopPrecedence['<'] = 10;
1226  BinopPrecedence['+'] = 20;
1227  BinopPrecedence['-'] = 20;
1228  BinopPrecedence['*'] = 40; // highest.
1229
1230  // Prime the first token.
1231  fprintf(stderr, "ready> ");
1232  getNextToken();
1233
1234  TheJIT = ExitOnErr(KaleidoscopeJIT::Create());
1235  InitializeModule();
1236
1237  // Run the main "interpreter loop" now.
1238  MainLoop();
1239
1240  return 0;
1241}
1242