//===--- ByteCodeStmtGen.cpp - Code generator for statements ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "ByteCodeStmtGen.h"
10#include "ByteCodeEmitter.h"
11#include "ByteCodeGenError.h"
12#include "Context.h"
13#include "Function.h"
14#include "PrimType.h"
15#include "Program.h"
16#include "State.h"
17#include "clang/Basic/LLVM.h"
18
19using namespace clang;
20using namespace clang::interp;
21
22namespace clang {
23namespace interp {
24
25/// Scope managing label targets.
26template <class Emitter> class LabelScope {
27public:
28  virtual ~LabelScope() {  }
29
30protected:
31  LabelScope(ByteCodeStmtGen<Emitter> *Ctx) : Ctx(Ctx) {}
32  /// ByteCodeStmtGen instance.
33  ByteCodeStmtGen<Emitter> *Ctx;
34};
35
36/// Sets the context for break/continue statements.
37template <class Emitter> class LoopScope final : public LabelScope<Emitter> {
38public:
39  using LabelTy = typename ByteCodeStmtGen<Emitter>::LabelTy;
40  using OptLabelTy = typename ByteCodeStmtGen<Emitter>::OptLabelTy;
41
42  LoopScope(ByteCodeStmtGen<Emitter> *Ctx, LabelTy BreakLabel,
43            LabelTy ContinueLabel)
44      : LabelScope<Emitter>(Ctx), OldBreakLabel(Ctx->BreakLabel),
45        OldContinueLabel(Ctx->ContinueLabel) {
46    this->Ctx->BreakLabel = BreakLabel;
47    this->Ctx->ContinueLabel = ContinueLabel;
48  }
49
50  ~LoopScope() {
51    this->Ctx->BreakLabel = OldBreakLabel;
52    this->Ctx->ContinueLabel = OldContinueLabel;
53  }
54
55private:
56  OptLabelTy OldBreakLabel;
57  OptLabelTy OldContinueLabel;
58};
59
60// Sets the context for a switch scope, mapping labels.
61template <class Emitter> class SwitchScope final : public LabelScope<Emitter> {
62public:
63  using LabelTy = typename ByteCodeStmtGen<Emitter>::LabelTy;
64  using OptLabelTy = typename ByteCodeStmtGen<Emitter>::OptLabelTy;
65  using CaseMap = typename ByteCodeStmtGen<Emitter>::CaseMap;
66
67  SwitchScope(ByteCodeStmtGen<Emitter> *Ctx, CaseMap &&CaseLabels,
68              LabelTy BreakLabel, OptLabelTy DefaultLabel)
69      : LabelScope<Emitter>(Ctx), OldBreakLabel(Ctx->BreakLabel),
70        OldDefaultLabel(this->Ctx->DefaultLabel),
71        OldCaseLabels(std::move(this->Ctx->CaseLabels)) {
72    this->Ctx->BreakLabel = BreakLabel;
73    this->Ctx->DefaultLabel = DefaultLabel;
74    this->Ctx->CaseLabels = std::move(CaseLabels);
75  }
76
77  ~SwitchScope() {
78    this->Ctx->BreakLabel = OldBreakLabel;
79    this->Ctx->DefaultLabel = OldDefaultLabel;
80    this->Ctx->CaseLabels = std::move(OldCaseLabels);
81  }
82
83private:
84  OptLabelTy OldBreakLabel;
85  OptLabelTy OldDefaultLabel;
86  CaseMap OldCaseLabels;
87};
88
89} // namespace interp
90} // namespace clang
91
92template <class Emitter>
93bool ByteCodeStmtGen<Emitter>::visitFunc(const FunctionDecl *F) {
94  // Classify the return type.
95  ReturnType = this->classify(F->getReturnType());
96
97  // Constructor. Set up field initializers.
98  if (const auto Ctor = dyn_cast<CXXConstructorDecl>(F)) {
99    const RecordDecl *RD = Ctor->getParent();
100    const Record *R = this->getRecord(RD);
101    if (!R)
102      return false;
103
104    for (const auto *Init : Ctor->inits()) {
105      const Expr *InitExpr = Init->getInit();
106      if (const FieldDecl *Member = Init->getMember()) {
107        const Record::Field *F = R->getField(Member);
108
109        if (std::optional<PrimType> T = this->classify(InitExpr)) {
110          if (!this->emitThis(InitExpr))
111            return false;
112
113          if (!this->visit(InitExpr))
114            return false;
115
116          if (!this->emitInitField(*T, F->Offset, InitExpr))
117            return false;
118
119          if (!this->emitPopPtr(InitExpr))
120            return false;
121        } else {
122          // Non-primitive case. Get a pointer to the field-to-initialize
123          // on the stack and call visitInitialzer() for it.
124          if (!this->emitThis(InitExpr))
125            return false;
126
127          if (!this->emitGetPtrField(F->Offset, InitExpr))
128            return false;
129
130          if (!this->visitInitializer(InitExpr))
131            return false;
132
133          if (!this->emitPopPtr(InitExpr))
134            return false;
135        }
136      } else if (const Type *Base = Init->getBaseClass()) {
137        // Base class initializer.
138        // Get This Base and call initializer on it.
139        auto *BaseDecl = Base->getAsCXXRecordDecl();
140        assert(BaseDecl);
141        const Record::Base *B = R->getBase(BaseDecl);
142        assert(B);
143        if (!this->emitGetPtrThisBase(B->Offset, InitExpr))
144          return false;
145        if (!this->visitInitializer(InitExpr))
146          return false;
147        if (!this->emitPopPtr(InitExpr))
148          return false;
149      }
150    }
151  }
152
153  if (const auto *Body = F->getBody())
154    if (!visitStmt(Body))
155      return false;
156
157  // Emit a guard return to protect against a code path missing one.
158  if (F->getReturnType()->isVoidType())
159    return this->emitRetVoid(SourceInfo{});
160  else
161    return this->emitNoRet(SourceInfo{});
162}
163
164template <class Emitter>
165bool ByteCodeStmtGen<Emitter>::visitStmt(const Stmt *S) {
166  switch (S->getStmtClass()) {
167  case Stmt::CompoundStmtClass:
168    return visitCompoundStmt(cast<CompoundStmt>(S));
169  case Stmt::DeclStmtClass:
170    return visitDeclStmt(cast<DeclStmt>(S));
171  case Stmt::ReturnStmtClass:
172    return visitReturnStmt(cast<ReturnStmt>(S));
173  case Stmt::IfStmtClass:
174    return visitIfStmt(cast<IfStmt>(S));
175  case Stmt::WhileStmtClass:
176    return visitWhileStmt(cast<WhileStmt>(S));
177  case Stmt::DoStmtClass:
178    return visitDoStmt(cast<DoStmt>(S));
179  case Stmt::ForStmtClass:
180    return visitForStmt(cast<ForStmt>(S));
181  case Stmt::BreakStmtClass:
182    return visitBreakStmt(cast<BreakStmt>(S));
183  case Stmt::ContinueStmtClass:
184    return visitContinueStmt(cast<ContinueStmt>(S));
185  case Stmt::NullStmtClass:
186    return true;
187  default: {
188    if (auto *Exp = dyn_cast<Expr>(S))
189      return this->discard(Exp);
190    return this->bail(S);
191  }
192  }
193}
194
195template <class Emitter>
196bool ByteCodeStmtGen<Emitter>::visitCompoundStmt(
197    const CompoundStmt *CompoundStmt) {
198  BlockScope<Emitter> Scope(this);
199  for (auto *InnerStmt : CompoundStmt->body())
200    if (!visitStmt(InnerStmt))
201      return false;
202  return true;
203}
204
205template <class Emitter>
206bool ByteCodeStmtGen<Emitter>::visitDeclStmt(const DeclStmt *DS) {
207  for (auto *D : DS->decls()) {
208    // Variable declarator.
209    if (auto *VD = dyn_cast<VarDecl>(D)) {
210      if (!this->visitVarDecl(VD))
211        return false;
212      continue;
213    }
214
215    // Decomposition declarator.
216    if (auto *DD = dyn_cast<DecompositionDecl>(D)) {
217      return this->bail(DD);
218    }
219  }
220
221  return true;
222}
223
224template <class Emitter>
225bool ByteCodeStmtGen<Emitter>::visitReturnStmt(const ReturnStmt *RS) {
226  if (const Expr *RE = RS->getRetValue()) {
227    ExprScope<Emitter> RetScope(this);
228    if (ReturnType) {
229      // Primitive types are simply returned.
230      if (!this->visit(RE))
231        return false;
232      this->emitCleanup();
233      return this->emitRet(*ReturnType, RS);
234    } else {
235      // RVO - construct the value in the return location.
236      if (!this->emitRVOPtr(RE))
237        return false;
238      if (!this->visitInitializer(RE))
239        return false;
240      if (!this->emitPopPtr(RE))
241        return false;
242
243      this->emitCleanup();
244      return this->emitRetVoid(RS);
245    }
246  }
247
248  // Void return.
249  this->emitCleanup();
250  return this->emitRetVoid(RS);
251}
252
253template <class Emitter>
254bool ByteCodeStmtGen<Emitter>::visitIfStmt(const IfStmt *IS) {
255  BlockScope<Emitter> IfScope(this);
256
257  if (IS->isNonNegatedConsteval())
258    return visitStmt(IS->getThen());
259  if (IS->isNegatedConsteval())
260    return IS->getElse() ? visitStmt(IS->getElse()) : true;
261
262  if (auto *CondInit = IS->getInit())
263    if (!visitStmt(IS->getInit()))
264      return false;
265
266  if (const DeclStmt *CondDecl = IS->getConditionVariableDeclStmt())
267    if (!visitDeclStmt(CondDecl))
268      return false;
269
270  if (!this->visitBool(IS->getCond()))
271    return false;
272
273  if (const Stmt *Else = IS->getElse()) {
274    LabelTy LabelElse = this->getLabel();
275    LabelTy LabelEnd = this->getLabel();
276    if (!this->jumpFalse(LabelElse))
277      return false;
278    if (!visitStmt(IS->getThen()))
279      return false;
280    if (!this->jump(LabelEnd))
281      return false;
282    this->emitLabel(LabelElse);
283    if (!visitStmt(Else))
284      return false;
285    this->emitLabel(LabelEnd);
286  } else {
287    LabelTy LabelEnd = this->getLabel();
288    if (!this->jumpFalse(LabelEnd))
289      return false;
290    if (!visitStmt(IS->getThen()))
291      return false;
292    this->emitLabel(LabelEnd);
293  }
294
295  return true;
296}
297
298template <class Emitter>
299bool ByteCodeStmtGen<Emitter>::visitWhileStmt(const WhileStmt *S) {
300  const Expr *Cond = S->getCond();
301  const Stmt *Body = S->getBody();
302
303  LabelTy CondLabel = this->getLabel(); // Label before the condition.
304  LabelTy EndLabel = this->getLabel();  // Label after the loop.
305  LoopScope<Emitter> LS(this, EndLabel, CondLabel);
306
307  this->emitLabel(CondLabel);
308  if (!this->visitBool(Cond))
309    return false;
310  if (!this->jumpFalse(EndLabel))
311    return false;
312
313  if (!this->visitStmt(Body))
314    return false;
315  if (!this->jump(CondLabel))
316    return false;
317
318  this->emitLabel(EndLabel);
319
320  return true;
321}
322
323template <class Emitter>
324bool ByteCodeStmtGen<Emitter>::visitDoStmt(const DoStmt *S) {
325  const Expr *Cond = S->getCond();
326  const Stmt *Body = S->getBody();
327
328  LabelTy StartLabel = this->getLabel();
329  LabelTy EndLabel = this->getLabel();
330  LabelTy CondLabel = this->getLabel();
331  LoopScope<Emitter> LS(this, EndLabel, CondLabel);
332
333  this->emitLabel(StartLabel);
334  if (!this->visitStmt(Body))
335    return false;
336  this->emitLabel(CondLabel);
337  if (!this->visitBool(Cond))
338    return false;
339  if (!this->jumpTrue(StartLabel))
340    return false;
341  this->emitLabel(EndLabel);
342  return true;
343}
344
345template <class Emitter>
346bool ByteCodeStmtGen<Emitter>::visitForStmt(const ForStmt *S) {
347  // for (Init; Cond; Inc) { Body }
348  const Stmt *Init = S->getInit();
349  const Expr *Cond = S->getCond();
350  const Expr *Inc = S->getInc();
351  const Stmt *Body = S->getBody();
352
353  LabelTy EndLabel = this->getLabel();
354  LabelTy CondLabel = this->getLabel();
355  LabelTy IncLabel = this->getLabel();
356  LoopScope<Emitter> LS(this, EndLabel, IncLabel);
357
358  if (Init && !this->visitStmt(Init))
359    return false;
360  this->emitLabel(CondLabel);
361  if (Cond) {
362    if (!this->visitBool(Cond))
363      return false;
364    if (!this->jumpFalse(EndLabel))
365      return false;
366  }
367  if (Body && !this->visitStmt(Body))
368    return false;
369  this->emitLabel(IncLabel);
370  if (Inc && !this->discard(Inc))
371    return false;
372  if (!this->jump(CondLabel))
373    return false;
374  this->emitLabel(EndLabel);
375  return true;
376}
377
378template <class Emitter>
379bool ByteCodeStmtGen<Emitter>::visitBreakStmt(const BreakStmt *S) {
380  if (!BreakLabel)
381    return false;
382
383  return this->jump(*BreakLabel);
384}
385
386template <class Emitter>
387bool ByteCodeStmtGen<Emitter>::visitContinueStmt(const ContinueStmt *S) {
388  if (!ContinueLabel)
389    return false;
390
391  return this->jump(*ContinueLabel);
392}
393
394namespace clang {
395namespace interp {
396
397template class ByteCodeStmtGen<ByteCodeEmitter>;
398
399} // namespace interp
400} // namespace clang
401