1#include "llvm/ADT/APFloat.h"
2#include "llvm/ADT/STLExtras.h"
3#include "llvm/IR/BasicBlock.h"
4#include "llvm/IR/Constants.h"
5#include "llvm/IR/DerivedTypes.h"
6#include "llvm/IR/Function.h"
7#include "llvm/IR/Instructions.h"
8#include "llvm/IR/IRBuilder.h"
9#include "llvm/IR/LLVMContext.h"
10#include "llvm/IR/Module.h"
11#include "llvm/IR/Type.h"
12#include "llvm/IR/Verifier.h"
13#include "llvm/Support/TargetSelect.h"
14#include "llvm/Target/TargetMachine.h"
15#include "KaleidoscopeJIT.h"
16#include <algorithm>
17#include <cassert>
18#include <cctype>
19#include <cstdint>
20#include <cstdio>
21#include <cstdlib>
22#include <map>
23#include <memory>
24#include <string>
25#include <utility>
26#include <vector>
27
28using namespace llvm;
29using namespace llvm::orc;
30
31//===----------------------------------------------------------------------===//
32// Lexer
33//===----------------------------------------------------------------------===//
34
/// Token codes produced by the lexer.  Every recognized token has a negative
/// code; a character the lexer does not recognize is returned as its own
/// character value in the range [0-255] instead.
enum Token {
  // End of input.
  tok_eof = -1,

  // Top-level commands.
  tok_def = -2,
  tok_extern = -3,

  // Primary expressions.
  tok_identifier = -4,
  tok_number = -5,

  // Control-flow keywords.
  tok_if = -6,
  tok_then = -7,
  tok_else = -8,
  tok_for = -9,
  tok_in = -10,

  // Keywords that introduce user-defined operators.
  tok_binary = -11,
  tok_unary = -12,

  // Keyword that introduces local variable definitions.
  tok_var = -13
};
62
63static std::string IdentifierStr; // Filled in if tok_identifier
64static double NumVal;             // Filled in if tok_number
65
66/// gettok - Return the next token from standard input.
67static int gettok() {
68  static int LastChar = ' ';
69
70  // Skip any whitespace.
71  while (isspace(LastChar))
72    LastChar = getchar();
73
74  if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
75    IdentifierStr = LastChar;
76    while (isalnum((LastChar = getchar())))
77      IdentifierStr += LastChar;
78
79    if (IdentifierStr == "def")
80      return tok_def;
81    if (IdentifierStr == "extern")
82      return tok_extern;
83    if (IdentifierStr == "if")
84      return tok_if;
85    if (IdentifierStr == "then")
86      return tok_then;
87    if (IdentifierStr == "else")
88      return tok_else;
89    if (IdentifierStr == "for")
90      return tok_for;
91    if (IdentifierStr == "in")
92      return tok_in;
93    if (IdentifierStr == "binary")
94      return tok_binary;
95    if (IdentifierStr == "unary")
96      return tok_unary;
97    if (IdentifierStr == "var")
98      return tok_var;
99    return tok_identifier;
100  }
101
102  if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
103    std::string NumStr;
104    do {
105      NumStr += LastChar;
106      LastChar = getchar();
107    } while (isdigit(LastChar) || LastChar == '.');
108
109    NumVal = strtod(NumStr.c_str(), nullptr);
110    return tok_number;
111  }
112
113  if (LastChar == '#') {
114    // Comment until end of line.
115    do
116      LastChar = getchar();
117    while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
118
119    if (LastChar != EOF)
120      return gettok();
121  }
122
123  // Check for end of file.  Don't eat the EOF.
124  if (LastChar == EOF)
125    return tok_eof;
126
127  // Otherwise, just return the character as its ascii value.
128  int ThisChar = LastChar;
129  LastChar = getchar();
130  return ThisChar;
131}
132
133//===----------------------------------------------------------------------===//
134// Abstract Syntax Tree (aka Parse Tree)
135//===----------------------------------------------------------------------===//
136
137/// ExprAST - Base class for all expression nodes.
138class ExprAST {
139public:
140  virtual ~ExprAST() = default;
141
142  virtual Value *codegen() = 0;
143};
144
145/// NumberExprAST - Expression class for numeric literals like "1.0".
146class NumberExprAST : public ExprAST {
147  double Val;
148
149public:
150  NumberExprAST(double Val) : Val(Val) {}
151
152  Value *codegen() override;
153};
154
155/// VariableExprAST - Expression class for referencing a variable, like "a".
156class VariableExprAST : public ExprAST {
157  std::string Name;
158
159public:
160  VariableExprAST(const std::string &Name) : Name(Name) {}
161
162  Value *codegen() override;
163  const std::string &getName() const { return Name; }
164};
165
166/// UnaryExprAST - Expression class for a unary operator.
167class UnaryExprAST : public ExprAST {
168  char Opcode;
169  std::unique_ptr<ExprAST> Operand;
170
171public:
172  UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
173      : Opcode(Opcode), Operand(std::move(Operand)) {}
174
175  Value *codegen() override;
176};
177
178/// BinaryExprAST - Expression class for a binary operator.
179class BinaryExprAST : public ExprAST {
180  char Op;
181  std::unique_ptr<ExprAST> LHS, RHS;
182
183public:
184  BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
185                std::unique_ptr<ExprAST> RHS)
186      : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
187
188  Value *codegen() override;
189};
190
191/// CallExprAST - Expression class for function calls.
192class CallExprAST : public ExprAST {
193  std::string Callee;
194  std::vector<std::unique_ptr<ExprAST>> Args;
195
196public:
197  CallExprAST(const std::string &Callee,
198              std::vector<std::unique_ptr<ExprAST>> Args)
199      : Callee(Callee), Args(std::move(Args)) {}
200
201  Value *codegen() override;
202};
203
204/// IfExprAST - Expression class for if/then/else.
205class IfExprAST : public ExprAST {
206  std::unique_ptr<ExprAST> Cond, Then, Else;
207
208public:
209  IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
210            std::unique_ptr<ExprAST> Else)
211      : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
212
213  Value *codegen() override;
214};
215
216/// ForExprAST - Expression class for for/in.
217class ForExprAST : public ExprAST {
218  std::string VarName;
219  std::unique_ptr<ExprAST> Start, End, Step, Body;
220
221public:
222  ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
223             std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
224             std::unique_ptr<ExprAST> Body)
225      : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
226        Step(std::move(Step)), Body(std::move(Body)) {}
227
228  Value *codegen() override;
229};
230
231/// VarExprAST - Expression class for var/in
232class VarExprAST : public ExprAST {
233  std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames;
234  std::unique_ptr<ExprAST> Body;
235
236public:
237  VarExprAST(
238      std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames,
239      std::unique_ptr<ExprAST> Body)
240      : VarNames(std::move(VarNames)), Body(std::move(Body)) {}
241
242  Value *codegen() override;
243};
244
245/// PrototypeAST - This class represents the "prototype" for a function,
246/// which captures its name, and its argument names (thus implicitly the number
247/// of arguments the function takes), as well as if it is an operator.
248class PrototypeAST {
249  std::string Name;
250  std::vector<std::string> Args;
251  bool IsOperator;
252  unsigned Precedence; // Precedence if a binary op.
253
254public:
255  PrototypeAST(const std::string &Name, std::vector<std::string> Args,
256               bool IsOperator = false, unsigned Prec = 0)
257      : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
258        Precedence(Prec) {}
259
260  Function *codegen();
261  const std::string &getName() const { return Name; }
262
263  bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
264  bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
265
266  char getOperatorName() const {
267    assert(isUnaryOp() || isBinaryOp());
268    return Name[Name.size() - 1];
269  }
270
271  unsigned getBinaryPrecedence() const { return Precedence; }
272};
273
274//===----------------------------------------------------------------------===//
275// Parser
276//===----------------------------------------------------------------------===//
277
278/// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
279/// token the parser is looking at.  getNextToken reads another token from the
280/// lexer and updates CurTok with its results.
281static int CurTok;
282static int getNextToken() { return CurTok = gettok(); }
283
284/// BinopPrecedence - This holds the precedence for each binary operator that is
285/// defined.
286static std::map<char, int> BinopPrecedence;
287
288/// GetTokPrecedence - Get the precedence of the pending binary operator token.
289static int GetTokPrecedence() {
290  if (!isascii(CurTok))
291    return -1;
292
293  // Make sure it's a declared binop.
294  int TokPrec = BinopPrecedence[CurTok];
295  if (TokPrec <= 0)
296    return -1;
297  return TokPrec;
298}
299
300/// LogError* - These are little helper functions for error handling.
301std::unique_ptr<ExprAST> LogError(const char *Str) {
302  fprintf(stderr, "Error: %s\n", Str);
303  return nullptr;
304}
305
306std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
307  LogError(Str);
308  return nullptr;
309}
310
311static std::unique_ptr<ExprAST> ParseExpression();
312
313/// numberexpr ::= number
314static std::unique_ptr<ExprAST> ParseNumberExpr() {
315  auto Result = std::make_unique<NumberExprAST>(NumVal);
316  getNextToken(); // consume the number
317  return std::move(Result);
318}
319
320/// parenexpr ::= '(' expression ')'
321static std::unique_ptr<ExprAST> ParseParenExpr() {
322  getNextToken(); // eat (.
323  auto V = ParseExpression();
324  if (!V)
325    return nullptr;
326
327  if (CurTok != ')')
328    return LogError("expected ')'");
329  getNextToken(); // eat ).
330  return V;
331}
332
333/// identifierexpr
334///   ::= identifier
335///   ::= identifier '(' expression* ')'
336static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
337  std::string IdName = IdentifierStr;
338
339  getNextToken(); // eat identifier.
340
341  if (CurTok != '(') // Simple variable ref.
342    return std::make_unique<VariableExprAST>(IdName);
343
344  // Call.
345  getNextToken(); // eat (
346  std::vector<std::unique_ptr<ExprAST>> Args;
347  if (CurTok != ')') {
348    while (true) {
349      if (auto Arg = ParseExpression())
350        Args.push_back(std::move(Arg));
351      else
352        return nullptr;
353
354      if (CurTok == ')')
355        break;
356
357      if (CurTok != ',')
358        return LogError("Expected ')' or ',' in argument list");
359      getNextToken();
360    }
361  }
362
363  // Eat the ')'.
364  getNextToken();
365
366  return std::make_unique<CallExprAST>(IdName, std::move(Args));
367}
368
369/// ifexpr ::= 'if' expression 'then' expression 'else' expression
370static std::unique_ptr<ExprAST> ParseIfExpr() {
371  getNextToken(); // eat the if.
372
373  // condition.
374  auto Cond = ParseExpression();
375  if (!Cond)
376    return nullptr;
377
378  if (CurTok != tok_then)
379    return LogError("expected then");
380  getNextToken(); // eat the then
381
382  auto Then = ParseExpression();
383  if (!Then)
384    return nullptr;
385
386  if (CurTok != tok_else)
387    return LogError("expected else");
388
389  getNextToken();
390
391  auto Else = ParseExpression();
392  if (!Else)
393    return nullptr;
394
395  return std::make_unique<IfExprAST>(std::move(Cond), std::move(Then),
396                                      std::move(Else));
397}
398
399/// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
400static std::unique_ptr<ExprAST> ParseForExpr() {
401  getNextToken(); // eat the for.
402
403  if (CurTok != tok_identifier)
404    return LogError("expected identifier after for");
405
406  std::string IdName = IdentifierStr;
407  getNextToken(); // eat identifier.
408
409  if (CurTok != '=')
410    return LogError("expected '=' after for");
411  getNextToken(); // eat '='.
412
413  auto Start = ParseExpression();
414  if (!Start)
415    return nullptr;
416  if (CurTok != ',')
417    return LogError("expected ',' after for start value");
418  getNextToken();
419
420  auto End = ParseExpression();
421  if (!End)
422    return nullptr;
423
424  // The step value is optional.
425  std::unique_ptr<ExprAST> Step;
426  if (CurTok == ',') {
427    getNextToken();
428    Step = ParseExpression();
429    if (!Step)
430      return nullptr;
431  }
432
433  if (CurTok != tok_in)
434    return LogError("expected 'in' after for");
435  getNextToken(); // eat 'in'.
436
437  auto Body = ParseExpression();
438  if (!Body)
439    return nullptr;
440
441  return std::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
442                                       std::move(Step), std::move(Body));
443}
444
445/// varexpr ::= 'var' identifier ('=' expression)?
446//                    (',' identifier ('=' expression)?)* 'in' expression
447static std::unique_ptr<ExprAST> ParseVarExpr() {
448  getNextToken(); // eat the var.
449
450  std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames;
451
452  // At least one variable name is required.
453  if (CurTok != tok_identifier)
454    return LogError("expected identifier after var");
455
456  while (true) {
457    std::string Name = IdentifierStr;
458    getNextToken(); // eat identifier.
459
460    // Read the optional initializer.
461    std::unique_ptr<ExprAST> Init = nullptr;
462    if (CurTok == '=') {
463      getNextToken(); // eat the '='.
464
465      Init = ParseExpression();
466      if (!Init)
467        return nullptr;
468    }
469
470    VarNames.push_back(std::make_pair(Name, std::move(Init)));
471
472    // End of var list, exit loop.
473    if (CurTok != ',')
474      break;
475    getNextToken(); // eat the ','.
476
477    if (CurTok != tok_identifier)
478      return LogError("expected identifier list after var");
479  }
480
481  // At this point, we have to have 'in'.
482  if (CurTok != tok_in)
483    return LogError("expected 'in' keyword after 'var'");
484  getNextToken(); // eat 'in'.
485
486  auto Body = ParseExpression();
487  if (!Body)
488    return nullptr;
489
490  return std::make_unique<VarExprAST>(std::move(VarNames), std::move(Body));
491}
492
493/// primary
494///   ::= identifierexpr
495///   ::= numberexpr
496///   ::= parenexpr
497///   ::= ifexpr
498///   ::= forexpr
499///   ::= varexpr
500static std::unique_ptr<ExprAST> ParsePrimary() {
501  switch (CurTok) {
502  default:
503    return LogError("unknown token when expecting an expression");
504  case tok_identifier:
505    return ParseIdentifierExpr();
506  case tok_number:
507    return ParseNumberExpr();
508  case '(':
509    return ParseParenExpr();
510  case tok_if:
511    return ParseIfExpr();
512  case tok_for:
513    return ParseForExpr();
514  case tok_var:
515    return ParseVarExpr();
516  }
517}
518
519/// unary
520///   ::= primary
521///   ::= '!' unary
522static std::unique_ptr<ExprAST> ParseUnary() {
523  // If the current token is not an operator, it must be a primary expr.
524  if (!isascii(CurTok) || CurTok == '(' || CurTok == ',')
525    return ParsePrimary();
526
527  // If this is a unary operator, read it.
528  int Opc = CurTok;
529  getNextToken();
530  if (auto Operand = ParseUnary())
531    return std::make_unique<UnaryExprAST>(Opc, std::move(Operand));
532  return nullptr;
533}
534
535/// binoprhs
536///   ::= ('+' unary)*
537static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
538                                              std::unique_ptr<ExprAST> LHS) {
539  // If this is a binop, find its precedence.
540  while (true) {
541    int TokPrec = GetTokPrecedence();
542
543    // If this is a binop that binds at least as tightly as the current binop,
544    // consume it, otherwise we are done.
545    if (TokPrec < ExprPrec)
546      return LHS;
547
548    // Okay, we know this is a binop.
549    int BinOp = CurTok;
550    getNextToken(); // eat binop
551
552    // Parse the unary expression after the binary operator.
553    auto RHS = ParseUnary();
554    if (!RHS)
555      return nullptr;
556
557    // If BinOp binds less tightly with RHS than the operator after RHS, let
558    // the pending operator take RHS as its LHS.
559    int NextPrec = GetTokPrecedence();
560    if (TokPrec < NextPrec) {
561      RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
562      if (!RHS)
563        return nullptr;
564    }
565
566    // Merge LHS/RHS.
567    LHS =
568        std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
569  }
570}
571
572/// expression
573///   ::= unary binoprhs
574///
575static std::unique_ptr<ExprAST> ParseExpression() {
576  auto LHS = ParseUnary();
577  if (!LHS)
578    return nullptr;
579
580  return ParseBinOpRHS(0, std::move(LHS));
581}
582
583/// prototype
584///   ::= id '(' id* ')'
585///   ::= binary LETTER number? (id, id)
586///   ::= unary LETTER (id)
587static std::unique_ptr<PrototypeAST> ParsePrototype() {
588  std::string FnName;
589
590  unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary.
591  unsigned BinaryPrecedence = 30;
592
593  switch (CurTok) {
594  default:
595    return LogErrorP("Expected function name in prototype");
596  case tok_identifier:
597    FnName = IdentifierStr;
598    Kind = 0;
599    getNextToken();
600    break;
601  case tok_unary:
602    getNextToken();
603    if (!isascii(CurTok))
604      return LogErrorP("Expected unary operator");
605    FnName = "unary";
606    FnName += (char)CurTok;
607    Kind = 1;
608    getNextToken();
609    break;
610  case tok_binary:
611    getNextToken();
612    if (!isascii(CurTok))
613      return LogErrorP("Expected binary operator");
614    FnName = "binary";
615    FnName += (char)CurTok;
616    Kind = 2;
617    getNextToken();
618
619    // Read the precedence if present.
620    if (CurTok == tok_number) {
621      if (NumVal < 1 || NumVal > 100)
622        return LogErrorP("Invalid precedecnce: must be 1..100");
623      BinaryPrecedence = (unsigned)NumVal;
624      getNextToken();
625    }
626    break;
627  }
628
629  if (CurTok != '(')
630    return LogErrorP("Expected '(' in prototype");
631
632  std::vector<std::string> ArgNames;
633  while (getNextToken() == tok_identifier)
634    ArgNames.push_back(IdentifierStr);
635  if (CurTok != ')')
636    return LogErrorP("Expected ')' in prototype");
637
638  // success.
639  getNextToken(); // eat ')'.
640
641  // Verify right number of names for operator.
642  if (Kind && ArgNames.size() != Kind)
643    return LogErrorP("Invalid number of operands for operator");
644
645  return std::make_unique<PrototypeAST>(FnName, ArgNames, Kind != 0,
646                                         BinaryPrecedence);
647}
648
649/// definition ::= 'def' prototype expression
650static std::unique_ptr<FunctionAST> ParseDefinition() {
651  getNextToken(); // eat def.
652  auto Proto = ParsePrototype();
653  if (!Proto)
654    return nullptr;
655
656  if (auto E = ParseExpression())
657    return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
658  return nullptr;
659}
660
661/// toplevelexpr ::= expression
662static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
663  if (auto E = ParseExpression()) {
664    // Make an anonymous proto.
665    auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
666                                                std::vector<std::string>());
667    return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
668  }
669  return nullptr;
670}
671
672/// external ::= 'extern' prototype
673static std::unique_ptr<PrototypeAST> ParseExtern() {
674  getNextToken(); // eat extern.
675  return ParsePrototype();
676}
677
678//===----------------------------------------------------------------------===//
679// Code Generation
680//===----------------------------------------------------------------------===//
681
682static std::unique_ptr<KaleidoscopeJIT> TheJIT;
683static std::unique_ptr<LLVMContext> TheContext;
684static std::unique_ptr<IRBuilder<>> Builder;
685static std::unique_ptr<Module> TheModule;
686static std::map<std::string, AllocaInst *> NamedValues;
687static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
688static ExitOnError ExitOnErr;
689
690Value *LogErrorV(const char *Str) {
691  LogError(Str);
692  return nullptr;
693}
694
695Function *getFunction(std::string Name) {
696  // First, see if the function has already been added to the current module.
697  if (auto *F = TheModule->getFunction(Name))
698    return F;
699
700  // If not, check whether we can codegen the declaration from some existing
701  // prototype.
702  auto FI = FunctionProtos.find(Name);
703  if (FI != FunctionProtos.end())
704    return FI->second->codegen();
705
706  // If no existing prototype exists, return null.
707  return nullptr;
708}
709
710/// CreateEntryBlockAlloca - Create an alloca instruction in the entry block of
711/// the function.  This is used for mutable variables etc.
712static AllocaInst *CreateEntryBlockAlloca(Function *TheFunction,
713                                          StringRef VarName) {
714  IRBuilder<> TmpB(&TheFunction->getEntryBlock(),
715                   TheFunction->getEntryBlock().begin());
716  return TmpB.CreateAlloca(Type::getDoubleTy(*TheContext), nullptr, VarName);
717}
718
719Value *NumberExprAST::codegen() {
720  return ConstantFP::get(*TheContext, APFloat(Val));
721}
722
723Value *VariableExprAST::codegen() {
724  // Look this variable up in the function.
725  Value *V = NamedValues[Name];
726  if (!V)
727    return LogErrorV("Unknown variable name");
728
729  // Load the value.
730  return Builder->CreateLoad(Type::getDoubleTy(*TheContext), V, Name.c_str());
731}
732
733Value *UnaryExprAST::codegen() {
734  Value *OperandV = Operand->codegen();
735  if (!OperandV)
736    return nullptr;
737
738  Function *F = getFunction(std::string("unary") + Opcode);
739  if (!F)
740    return LogErrorV("Unknown unary operator");
741
742  return Builder->CreateCall(F, OperandV, "unop");
743}
744
745Value *BinaryExprAST::codegen() {
746  // Special case '=' because we don't want to emit the LHS as an expression.
747  if (Op == '=') {
748    // Assignment requires the LHS to be an identifier.
749    // This assume we're building without RTTI because LLVM builds that way by
750    // default.  If you build LLVM with RTTI this can be changed to a
751    // dynamic_cast for automatic error checking.
752    VariableExprAST *LHSE = static_cast<VariableExprAST *>(LHS.get());
753    if (!LHSE)
754      return LogErrorV("destination of '=' must be a variable");
755    // Codegen the RHS.
756    Value *Val = RHS->codegen();
757    if (!Val)
758      return nullptr;
759
760    // Look up the name.
761    Value *Variable = NamedValues[LHSE->getName()];
762    if (!Variable)
763      return LogErrorV("Unknown variable name");
764
765    Builder->CreateStore(Val, Variable);
766    return Val;
767  }
768
769  Value *L = LHS->codegen();
770  Value *R = RHS->codegen();
771  if (!L || !R)
772    return nullptr;
773
774  switch (Op) {
775  case '+':
776    return Builder->CreateFAdd(L, R, "addtmp");
777  case '-':
778    return Builder->CreateFSub(L, R, "subtmp");
779  case '*':
780    return Builder->CreateFMul(L, R, "multmp");
781  case '<':
782    L = Builder->CreateFCmpULT(L, R, "cmptmp");
783    // Convert bool 0/1 to double 0.0 or 1.0
784    return Builder->CreateUIToFP(L, Type::getDoubleTy(*TheContext), "booltmp");
785  default:
786    break;
787  }
788
789  // If it wasn't a builtin binary operator, it must be a user defined one. Emit
790  // a call to it.
791  Function *F = getFunction(std::string("binary") + Op);
792  assert(F && "binary operator not found!");
793
794  Value *Ops[] = {L, R};
795  return Builder->CreateCall(F, Ops, "binop");
796}
797
798Value *CallExprAST::codegen() {
799  // Look up the name in the global module table.
800  Function *CalleeF = getFunction(Callee);
801  if (!CalleeF)
802    return LogErrorV("Unknown function referenced");
803
804  // If argument mismatch error.
805  if (CalleeF->arg_size() != Args.size())
806    return LogErrorV("Incorrect # arguments passed");
807
808  std::vector<Value *> ArgsV;
809  for (unsigned i = 0, e = Args.size(); i != e; ++i) {
810    ArgsV.push_back(Args[i]->codegen());
811    if (!ArgsV.back())
812      return nullptr;
813  }
814
815  return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
816}
817
818Value *IfExprAST::codegen() {
819  Value *CondV = Cond->codegen();
820  if (!CondV)
821    return nullptr;
822
823  // Convert condition to a bool by comparing equal to 0.0.
824  CondV = Builder->CreateFCmpONE(
825      CondV, ConstantFP::get(*TheContext, APFloat(0.0)), "ifcond");
826
827  Function *TheFunction = Builder->GetInsertBlock()->getParent();
828
829  // Create blocks for the then and else cases.  Insert the 'then' block at the
830  // end of the function.
831  BasicBlock *ThenBB = BasicBlock::Create(*TheContext, "then", TheFunction);
832  BasicBlock *ElseBB = BasicBlock::Create(*TheContext, "else");
833  BasicBlock *MergeBB = BasicBlock::Create(*TheContext, "ifcont");
834
835  Builder->CreateCondBr(CondV, ThenBB, ElseBB);
836
837  // Emit then value.
838  Builder->SetInsertPoint(ThenBB);
839
840  Value *ThenV = Then->codegen();
841  if (!ThenV)
842    return nullptr;
843
844  Builder->CreateBr(MergeBB);
845  // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
846  ThenBB = Builder->GetInsertBlock();
847
848  // Emit else block.
849  TheFunction->insert(TheFunction->end(), ElseBB);
850  Builder->SetInsertPoint(ElseBB);
851
852  Value *ElseV = Else->codegen();
853  if (!ElseV)
854    return nullptr;
855
856  Builder->CreateBr(MergeBB);
857  // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
858  ElseBB = Builder->GetInsertBlock();
859
860  // Emit merge block.
861  TheFunction->insert(TheFunction->end(), MergeBB);
862  Builder->SetInsertPoint(MergeBB);
863  PHINode *PN = Builder->CreatePHI(Type::getDoubleTy(*TheContext), 2, "iftmp");
864
865  PN->addIncoming(ThenV, ThenBB);
866  PN->addIncoming(ElseV, ElseBB);
867  return PN;
868}
869
870// Output for-loop as:
871//   var = alloca double
872//   ...
873//   start = startexpr
874//   store start -> var
875//   goto loop
876// loop:
877//   ...
878//   bodyexpr
879//   ...
880// loopend:
881//   step = stepexpr
882//   endcond = endexpr
883//
884//   curvar = load var
885//   nextvar = curvar + step
886//   store nextvar -> var
887//   br endcond, loop, endloop
888// outloop:
889Value *ForExprAST::codegen() {
890  Function *TheFunction = Builder->GetInsertBlock()->getParent();
891
892  // Create an alloca for the variable in the entry block.
893  AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName);
894
895  // Emit the start code first, without 'variable' in scope.
896  Value *StartVal = Start->codegen();
897  if (!StartVal)
898    return nullptr;
899
900  // Store the value into the alloca.
901  Builder->CreateStore(StartVal, Alloca);
902
903  // Make the new basic block for the loop header, inserting after current
904  // block.
905  BasicBlock *LoopBB = BasicBlock::Create(*TheContext, "loop", TheFunction);
906
907  // Insert an explicit fall through from the current block to the LoopBB.
908  Builder->CreateBr(LoopBB);
909
910  // Start insertion in LoopBB.
911  Builder->SetInsertPoint(LoopBB);
912
913  // Within the loop, the variable is defined equal to the PHI node.  If it
914  // shadows an existing variable, we have to restore it, so save it now.
915  AllocaInst *OldVal = NamedValues[VarName];
916  NamedValues[VarName] = Alloca;
917
918  // Emit the body of the loop.  This, like any other expr, can change the
919  // current BB.  Note that we ignore the value computed by the body, but don't
920  // allow an error.
921  if (!Body->codegen())
922    return nullptr;
923
924  // Emit the step value.
925  Value *StepVal = nullptr;
926  if (Step) {
927    StepVal = Step->codegen();
928    if (!StepVal)
929      return nullptr;
930  } else {
931    // If not specified, use 1.0.
932    StepVal = ConstantFP::get(*TheContext, APFloat(1.0));
933  }
934
935  // Compute the end condition.
936  Value *EndCond = End->codegen();
937  if (!EndCond)
938    return nullptr;
939
940  // Reload, increment, and restore the alloca.  This handles the case where
941  // the body of the loop mutates the variable.
942  Value *CurVar = Builder->CreateLoad(Type::getDoubleTy(*TheContext), Alloca,
943                                      VarName.c_str());
944  Value *NextVar = Builder->CreateFAdd(CurVar, StepVal, "nextvar");
945  Builder->CreateStore(NextVar, Alloca);
946
947  // Convert condition to a bool by comparing equal to 0.0.
948  EndCond = Builder->CreateFCmpONE(
949      EndCond, ConstantFP::get(*TheContext, APFloat(0.0)), "loopcond");
950
951  // Create the "after loop" block and insert it.
952  BasicBlock *AfterBB =
953      BasicBlock::Create(*TheContext, "afterloop", TheFunction);
954
955  // Insert the conditional branch into the end of LoopEndBB.
956  Builder->CreateCondBr(EndCond, LoopBB, AfterBB);
957
958  // Any new code will be inserted in AfterBB.
959  Builder->SetInsertPoint(AfterBB);
960
961  // Restore the unshadowed variable.
962  if (OldVal)
963    NamedValues[VarName] = OldVal;
964  else
965    NamedValues.erase(VarName);
966
967  // for expr always returns 0.0.
968  return Constant::getNullValue(Type::getDoubleTy(*TheContext));
969}
970
971Value *VarExprAST::codegen() {
972  std::vector<AllocaInst *> OldBindings;
973
974  Function *TheFunction = Builder->GetInsertBlock()->getParent();
975
976  // Register all variables and emit their initializer.
977  for (unsigned i = 0, e = VarNames.size(); i != e; ++i) {
978    const std::string &VarName = VarNames[i].first;
979    ExprAST *Init = VarNames[i].second.get();
980
981    // Emit the initializer before adding the variable to scope, this prevents
982    // the initializer from referencing the variable itself, and permits stuff
983    // like this:
984    //  var a = 1 in
985    //    var a = a in ...   # refers to outer 'a'.
986    Value *InitVal;
987    if (Init) {
988      InitVal = Init->codegen();
989      if (!InitVal)
990        return nullptr;
991    } else { // If not specified, use 0.0.
992      InitVal = ConstantFP::get(*TheContext, APFloat(0.0));
993    }
994
995    AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName);
996    Builder->CreateStore(InitVal, Alloca);
997
998    // Remember the old variable binding so that we can restore the binding when
999    // we unrecurse.
1000    OldBindings.push_back(NamedValues[VarName]);
1001
1002    // Remember this binding.
1003    NamedValues[VarName] = Alloca;
1004  }
1005
1006  // Codegen the body, now that all vars are in scope.
1007  Value *BodyVal = Body->codegen();
1008  if (!BodyVal)
1009    return nullptr;
1010
1011  // Pop all our variables from scope.
1012  for (unsigned i = 0, e = VarNames.size(); i != e; ++i)
1013    NamedValues[VarNames[i].first] = OldBindings[i];
1014
1015  // Return the body computation.
1016  return BodyVal;
1017}
1018
1019Function *PrototypeAST::codegen() {
1020  // Make the function type:  double(double,double) etc.
1021  std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(*TheContext));
1022  FunctionType *FT =
1023      FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);
1024
1025  Function *F =
1026      Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
1027
1028  // Set names for all arguments.
1029  unsigned Idx = 0;
1030  for (auto &Arg : F->args())
1031    Arg.setName(Args[Idx++]);
1032
1033  return F;
1034}
1035
1036const PrototypeAST& FunctionAST::getProto() const {
1037  return *Proto;
1038}
1039
1040const std::string& FunctionAST::getName() const {
1041  return Proto->getName();
1042}
1043
1044Function *FunctionAST::codegen() {
1045  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
1046  // reference to it for use below.
1047  auto &P = *Proto;
1048  FunctionProtos[Proto->getName()] = std::move(Proto);
1049  Function *TheFunction = getFunction(P.getName());
1050  if (!TheFunction)
1051    return nullptr;
1052
1053  // If this is an operator, install it.
1054  if (P.isBinaryOp())
1055    BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
1056
1057  // Create a new basic block to start insertion into.
1058  BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
1059  Builder->SetInsertPoint(BB);
1060
1061  // Record the function arguments in the NamedValues map.
1062  NamedValues.clear();
1063  for (auto &Arg : TheFunction->args()) {
1064    // Create an alloca for this variable.
1065    AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, Arg.getName());
1066
1067    // Store the initial value into the alloca.
1068    Builder->CreateStore(&Arg, Alloca);
1069
1070    // Add arguments to variable symbol table.
1071    NamedValues[std::string(Arg.getName())] = Alloca;
1072  }
1073
1074  if (Value *RetVal = Body->codegen()) {
1075    // Finish off the function.
1076    Builder->CreateRet(RetVal);
1077
1078    // Validate the generated code, checking for consistency.
1079    verifyFunction(*TheFunction);
1080
1081    return TheFunction;
1082  }
1083
1084  // Error reading body, remove function.
1085  TheFunction->eraseFromParent();
1086
1087  if (P.isBinaryOp())
1088    BinopPrecedence.erase(P.getOperatorName());
1089  return nullptr;
1090}
1091
1092//===----------------------------------------------------------------------===//
1093// Top-Level parsing and JIT Driver
1094//===----------------------------------------------------------------------===//
1095
1096static void InitializeModule() {
1097  // Open a new context and module.
1098  TheContext = std::make_unique<LLVMContext>();
1099  TheModule = std::make_unique<Module>("my cool jit", *TheContext);
1100  TheModule->setDataLayout(TheJIT->getDataLayout());
1101
1102  // Create a new builder for the module.
1103  Builder = std::make_unique<IRBuilder<>>(*TheContext);
1104}
1105
1106ThreadSafeModule irgenAndTakeOwnership(FunctionAST &FnAST,
1107                                       const std::string &Suffix) {
1108  if (auto *F = FnAST.codegen()) {
1109    F->setName(F->getName() + Suffix);
1110    auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
1111    // Start a new module.
1112    InitializeModule();
1113    return TSM;
1114  } else
1115    report_fatal_error("Couldn't compile lazily JIT'd function");
1116}
1117
1118static void HandleDefinition() {
1119  if (auto FnAST = ParseDefinition()) {
1120    FunctionProtos[FnAST->getProto().getName()] =
1121      std::make_unique<PrototypeAST>(FnAST->getProto());
1122    ExitOnErr(TheJIT->addAST(std::move(FnAST)));
1123  } else {
1124    // Skip token for error recovery.
1125    getNextToken();
1126  }
1127}
1128
1129static void HandleExtern() {
1130  if (auto ProtoAST = ParseExtern()) {
1131    if (auto *FnIR = ProtoAST->codegen()) {
1132      fprintf(stderr, "Read extern: ");
1133      FnIR->print(errs());
1134      fprintf(stderr, "\n");
1135      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
1136    }
1137  } else {
1138    // Skip token for error recovery.
1139    getNextToken();
1140  }
1141}
1142
1143static void HandleTopLevelExpression() {
1144  // Evaluate a top-level expression into an anonymous function.
1145  if (auto FnAST = ParseTopLevelExpr()) {
1146    if (FnAST->codegen()) {
1147      // Create a ResourceTracker to track JIT'd memory allocated to our
1148      // anonymous expression -- that way we can free it after executing.
1149      auto RT = TheJIT->getMainJITDylib().createResourceTracker();
1150
1151      auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
1152      ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
1153      InitializeModule();
1154
1155      // Get the anonymous expression's JITSymbol.
1156      auto Sym = ExitOnErr(TheJIT->lookup("__anon_expr"));
1157
1158      // Get the symbol's address and cast it to the right type (takes no
1159      // arguments, returns a double) so we can call it as a native function.
1160      auto *FP = (double (*)())(intptr_t)Sym.getAddress();
1161      fprintf(stderr, "Evaluated to %f\n", FP());
1162
1163      // Delete the anonymous expression module from the JIT.
1164      ExitOnErr(RT->remove());
1165    }
1166  } else {
1167    // Skip token for error recovery.
1168    getNextToken();
1169  }
1170}
1171
1172/// top ::= definition | external | expression | ';'
1173static void MainLoop() {
1174  while (true) {
1175    fprintf(stderr, "ready> ");
1176    switch (CurTok) {
1177    case tok_eof:
1178      return;
1179    case ';': // ignore top-level semicolons.
1180      getNextToken();
1181      break;
1182    case tok_def:
1183      HandleDefinition();
1184      break;
1185    case tok_extern:
1186      HandleExtern();
1187      break;
1188    default:
1189      HandleTopLevelExpression();
1190      break;
1191    }
1192  }
1193}
1194
1195//===----------------------------------------------------------------------===//
1196// "Library" functions that can be "extern'd" from user code.
1197//===----------------------------------------------------------------------===//
1198
/// putchard - convert X to a char, write that character to stderr, and
/// return 0.
extern "C" double putchard(double X) {
  const char Ch = static_cast<char>(X);
  fputc(Ch, stderr);
  return 0.0;
}
1204
/// printd - write X to stderr using the format "%f\n" and return 0.
extern "C" double printd(double X) {
  static const char Format[] = "%f\n";
  fprintf(stderr, Format, X);
  return 0.0;
}
1210
1211//===----------------------------------------------------------------------===//
1212// Main driver code.
1213//===----------------------------------------------------------------------===//
1214
1215int main() {
1216  InitializeNativeTarget();
1217  InitializeNativeTargetAsmPrinter();
1218  InitializeNativeTargetAsmParser();
1219
1220  // Install standard binary operators.
1221  // 1 is lowest precedence.
1222  BinopPrecedence['='] = 2;
1223  BinopPrecedence['<'] = 10;
1224  BinopPrecedence['+'] = 20;
1225  BinopPrecedence['-'] = 20;
1226  BinopPrecedence['*'] = 40; // highest.
1227
1228  // Prime the first token.
1229  fprintf(stderr, "ready> ");
1230  getNextToken();
1231
1232  TheJIT = ExitOnErr(KaleidoscopeJIT::Create());
1233  InitializeModule();
1234
1235  // Run the main "interpreter loop" now.
1236  MainLoop();
1237
1238  return 0;
1239}
1240