1//===- CoroEarly.cpp - Coroutine Early Function Pass ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "llvm/Transforms/Coroutines/CoroEarly.h"
#include "CoroInternal.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
15
16using namespace llvm;
17
18#define DEBUG_TYPE "coro-early"
19
20namespace {
21// Created on demand if the coro-early pass has work to do.
22class Lowerer : public coro::LowererBase {
23  IRBuilder<> Builder;
24  PointerType *const AnyResumeFnPtrTy;
25  Constant *NoopCoro = nullptr;
26
27  void lowerResumeOrDestroy(CallBase &CB, CoroSubFnInst::ResumeKind);
28  void lowerCoroPromise(CoroPromiseInst *Intrin);
29  void lowerCoroDone(IntrinsicInst *II);
30  void lowerCoroNoop(IntrinsicInst *II);
31
32public:
33  Lowerer(Module &M)
34      : LowererBase(M), Builder(Context),
35        AnyResumeFnPtrTy(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
36                                           /*isVarArg=*/false)
37                             ->getPointerTo()) {}
38  bool lowerEarlyIntrinsics(Function &F);
39};
40}
41
42// Replace a direct call to coro.resume or coro.destroy with an indirect call to
43// an address returned by coro.subfn.addr intrinsic. This is done so that
44// CGPassManager recognizes devirtualization when CoroElide pass replaces a call
45// to coro.subfn.addr with an appropriate function address.
46void Lowerer::lowerResumeOrDestroy(CallBase &CB,
47                                   CoroSubFnInst::ResumeKind Index) {
48  Value *ResumeAddr = makeSubFnCall(CB.getArgOperand(0), Index, &CB);
49  CB.setCalledOperand(ResumeAddr);
50  CB.setCallingConv(CallingConv::Fast);
51}
52
53// Coroutine promise field is always at the fixed offset from the beginning of
54// the coroutine frame. i8* coro.promise(i8*, i1 from) intrinsic adds an offset
55// to a passed pointer to move from coroutine frame to coroutine promise and
56// vice versa. Since we don't know exactly which coroutine frame it is, we build
57// a coroutine frame mock up starting with two function pointers, followed by a
58// properly aligned coroutine promise field.
59// TODO: Handle the case when coroutine promise alloca has align override.
60void Lowerer::lowerCoroPromise(CoroPromiseInst *Intrin) {
61  Value *Operand = Intrin->getArgOperand(0);
62  Align Alignment = Intrin->getAlignment();
63  Type *Int8Ty = Builder.getInt8Ty();
64
65  auto *SampleStruct =
66      StructType::get(Context, {AnyResumeFnPtrTy, AnyResumeFnPtrTy, Int8Ty});
67  const DataLayout &DL = TheModule.getDataLayout();
68  int64_t Offset = alignTo(
69      DL.getStructLayout(SampleStruct)->getElementOffset(2), Alignment);
70  if (Intrin->isFromPromise())
71    Offset = -Offset;
72
73  Builder.SetInsertPoint(Intrin);
74  Value *Replacement =
75      Builder.CreateConstInBoundsGEP1_32(Int8Ty, Operand, Offset);
76
77  Intrin->replaceAllUsesWith(Replacement);
78  Intrin->eraseFromParent();
79}
80
81// When a coroutine reaches final suspend point, it zeros out ResumeFnAddr in
82// the coroutine frame (it is UB to resume from a final suspend point).
83// The llvm.coro.done intrinsic is used to check whether a coroutine is
84// suspended at the final suspend point or not.
85void Lowerer::lowerCoroDone(IntrinsicInst *II) {
86  Value *Operand = II->getArgOperand(0);
87
88  // ResumeFnAddr is the first pointer sized element of the coroutine frame.
89  static_assert(coro::Shape::SwitchFieldIndex::Resume == 0,
90                "resume function not at offset zero");
91  auto *FrameTy = Int8Ptr;
92  PointerType *FramePtrTy = FrameTy->getPointerTo();
93
94  Builder.SetInsertPoint(II);
95  auto *BCI = Builder.CreateBitCast(Operand, FramePtrTy);
96  auto *Load = Builder.CreateLoad(FrameTy, BCI);
97  auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
98
99  II->replaceAllUsesWith(Cond);
100  II->eraseFromParent();
101}
102
103void Lowerer::lowerCoroNoop(IntrinsicInst *II) {
104  if (!NoopCoro) {
105    LLVMContext &C = Builder.getContext();
106    Module &M = *II->getModule();
107
108    // Create a noop.frame struct type.
109    StructType *FrameTy = StructType::create(C, "NoopCoro.Frame");
110    auto *FramePtrTy = FrameTy->getPointerTo();
111    auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy,
112                                   /*isVarArg=*/false);
113    auto *FnPtrTy = FnTy->getPointerTo();
114    FrameTy->setBody({FnPtrTy, FnPtrTy});
115
116    // Create a Noop function that does nothing.
117    Function *NoopFn =
118        Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
119                         "NoopCoro.ResumeDestroy", &M);
120    NoopFn->setCallingConv(CallingConv::Fast);
121    auto *Entry = BasicBlock::Create(C, "entry", NoopFn);
122    ReturnInst::Create(C, Entry);
123
124    // Create a constant struct for the frame.
125    Constant* Values[] = {NoopFn, NoopFn};
126    Constant* NoopCoroConst = ConstantStruct::get(FrameTy, Values);
127    NoopCoro = new GlobalVariable(M, NoopCoroConst->getType(), /*isConstant=*/true,
128                                GlobalVariable::PrivateLinkage, NoopCoroConst,
129                                "NoopCoro.Frame.Const");
130  }
131
132  Builder.SetInsertPoint(II);
133  auto *NoopCoroVoidPtr = Builder.CreateBitCast(NoopCoro, Int8Ptr);
134  II->replaceAllUsesWith(NoopCoroVoidPtr);
135  II->eraseFromParent();
136}
137
138// Prior to CoroSplit, calls to coro.begin needs to be marked as NoDuplicate,
139// as CoroSplit assumes there is exactly one coro.begin. After CoroSplit,
140// NoDuplicate attribute will be removed from coro.begin otherwise, it will
141// interfere with inlining.
142static void setCannotDuplicate(CoroIdInst *CoroId) {
143  for (User *U : CoroId->users())
144    if (auto *CB = dyn_cast<CoroBeginInst>(U))
145      CB->setCannotDuplicate();
146}
147
148bool Lowerer::lowerEarlyIntrinsics(Function &F) {
149  bool Changed = false;
150  CoroIdInst *CoroId = nullptr;
151  SmallVector<CoroFreeInst *, 4> CoroFrees;
152  for (auto IB = inst_begin(F), IE = inst_end(F); IB != IE;) {
153    Instruction &I = *IB++;
154    if (auto *CB = dyn_cast<CallBase>(&I)) {
155      switch (CB->getIntrinsicID()) {
156      default:
157        continue;
158      case Intrinsic::coro_free:
159        CoroFrees.push_back(cast<CoroFreeInst>(&I));
160        break;
161      case Intrinsic::coro_suspend:
162        // Make sure that final suspend point is not duplicated as CoroSplit
163        // pass expects that there is at most one final suspend point.
164        if (cast<CoroSuspendInst>(&I)->isFinal())
165          CB->setCannotDuplicate();
166        break;
167      case Intrinsic::coro_end:
168        // Make sure that fallthrough coro.end is not duplicated as CoroSplit
169        // pass expects that there is at most one fallthrough coro.end.
170        if (cast<CoroEndInst>(&I)->isFallthrough())
171          CB->setCannotDuplicate();
172        break;
173      case Intrinsic::coro_noop:
174        lowerCoroNoop(cast<IntrinsicInst>(&I));
175        break;
176      case Intrinsic::coro_id:
177        // Mark a function that comes out of the frontend that has a coro.id
178        // with a coroutine attribute.
179        if (auto *CII = cast<CoroIdInst>(&I)) {
180          if (CII->getInfo().isPreSplit()) {
181            F.addFnAttr(CORO_PRESPLIT_ATTR, UNPREPARED_FOR_SPLIT);
182            setCannotDuplicate(CII);
183            CII->setCoroutineSelf();
184            CoroId = cast<CoroIdInst>(&I);
185          }
186        }
187        break;
188      case Intrinsic::coro_id_retcon:
189      case Intrinsic::coro_id_retcon_once:
190        F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
191        break;
192      case Intrinsic::coro_resume:
193        lowerResumeOrDestroy(*CB, CoroSubFnInst::ResumeIndex);
194        break;
195      case Intrinsic::coro_destroy:
196        lowerResumeOrDestroy(*CB, CoroSubFnInst::DestroyIndex);
197        break;
198      case Intrinsic::coro_promise:
199        lowerCoroPromise(cast<CoroPromiseInst>(&I));
200        break;
201      case Intrinsic::coro_done:
202        lowerCoroDone(cast<IntrinsicInst>(&I));
203        break;
204      }
205      Changed = true;
206    }
207  }
208  // Make sure that all CoroFree reference the coro.id intrinsic.
209  // Token type is not exposed through coroutine C/C++ builtins to plain C, so
210  // we allow specifying none and fixing it up here.
211  if (CoroId)
212    for (CoroFreeInst *CF : CoroFrees)
213      CF->setArgOperand(0, CoroId);
214  return Changed;
215}
216
217static bool declaresCoroEarlyIntrinsics(const Module &M) {
218  return coro::declaresIntrinsics(
219      M, {"llvm.coro.id", "llvm.coro.id.retcon", "llvm.coro.id.retcon.once",
220          "llvm.coro.destroy", "llvm.coro.done", "llvm.coro.end",
221          "llvm.coro.noop", "llvm.coro.free", "llvm.coro.promise",
222          "llvm.coro.resume", "llvm.coro.suspend"});
223}
224
225PreservedAnalyses CoroEarlyPass::run(Function &F, FunctionAnalysisManager &) {
226  Module &M = *F.getParent();
227  if (!declaresCoroEarlyIntrinsics(M) || !Lowerer(M).lowerEarlyIntrinsics(F))
228    return PreservedAnalyses::all();
229
230  PreservedAnalyses PA;
231  PA.preserveSet<CFGAnalyses>();
232  return PA;
233}
234
235namespace {
236
237struct CoroEarlyLegacy : public FunctionPass {
238  static char ID; // Pass identification, replacement for typeid.
239  CoroEarlyLegacy() : FunctionPass(ID) {
240    initializeCoroEarlyLegacyPass(*PassRegistry::getPassRegistry());
241  }
242
243  std::unique_ptr<Lowerer> L;
244
245  // This pass has work to do only if we find intrinsics we are going to lower
246  // in the module.
247  bool doInitialization(Module &M) override {
248    if (declaresCoroEarlyIntrinsics(M))
249      L = std::make_unique<Lowerer>(M);
250    return false;
251  }
252
253  bool runOnFunction(Function &F) override {
254    if (!L)
255      return false;
256
257    return L->lowerEarlyIntrinsics(F);
258  }
259
260  void getAnalysisUsage(AnalysisUsage &AU) const override {
261    AU.setPreservesCFG();
262  }
263  StringRef getPassName() const override {
264    return "Lower early coroutine intrinsics";
265  }
266};
267}
268
269char CoroEarlyLegacy::ID = 0;
270INITIALIZE_PASS(CoroEarlyLegacy, "coro-early",
271                "Lower early coroutine intrinsics", false, false)
272
273Pass *llvm::createCoroEarlyLegacyPass() { return new CoroEarlyLegacy(); }
274