1311116Sdim//===- CoroCleanup.cpp - Coroutine Cleanup Pass ---------------------------===//
2311116Sdim//
3353358Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4353358Sdim// See https://llvm.org/LICENSE.txt for license information.
5353358Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6311116Sdim//
7311116Sdim//===----------------------------------------------------------------------===//
8311116Sdim// This pass lowers all remaining coroutine intrinsics.
9311116Sdim//===----------------------------------------------------------------------===//
10311116Sdim
11311116Sdim#include "CoroInternal.h"
12311116Sdim#include "llvm/IR/IRBuilder.h"
13311116Sdim#include "llvm/IR/InstIterator.h"
14311116Sdim#include "llvm/IR/LegacyPassManager.h"
15311116Sdim#include "llvm/Pass.h"
16311116Sdim#include "llvm/Transforms/Scalar.h"
17311116Sdim
18311116Sdimusing namespace llvm;
19311116Sdim
20311116Sdim#define DEBUG_TYPE "coro-cleanup"
21311116Sdim
22311116Sdimnamespace {
23311116Sdim// Created on demand if CoroCleanup pass has work to do.
24311116Sdimstruct Lowerer : coro::LowererBase {
25311116Sdim  IRBuilder<> Builder;
26311116Sdim  Lowerer(Module &M) : LowererBase(M), Builder(Context) {}
27311116Sdim  bool lowerRemainingCoroIntrinsics(Function &F);
28311116Sdim};
29311116Sdim}
30311116Sdim
31311116Sdimstatic void simplifyCFG(Function &F) {
32311116Sdim  llvm::legacy::FunctionPassManager FPM(F.getParent());
33311116Sdim  FPM.add(createCFGSimplificationPass());
34311116Sdim
35311116Sdim  FPM.doInitialization();
36311116Sdim  FPM.run(F);
37311116Sdim  FPM.doFinalization();
38311116Sdim}
39311116Sdim
40311116Sdimstatic void lowerSubFn(IRBuilder<> &Builder, CoroSubFnInst *SubFn) {
41311116Sdim  Builder.SetInsertPoint(SubFn);
42311116Sdim  Value *FrameRaw = SubFn->getFrame();
43311116Sdim  int Index = SubFn->getIndex();
44311116Sdim
45311116Sdim  auto *FrameTy = StructType::get(
46311116Sdim      SubFn->getContext(), {Builder.getInt8PtrTy(), Builder.getInt8PtrTy()});
47311116Sdim  PointerType *FramePtrTy = FrameTy->getPointerTo();
48311116Sdim
49311116Sdim  Builder.SetInsertPoint(SubFn);
50311116Sdim  auto *FramePtr = Builder.CreateBitCast(FrameRaw, FramePtrTy);
51311116Sdim  auto *Gep = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0, Index);
52353358Sdim  auto *Load = Builder.CreateLoad(FrameTy->getElementType(Index), Gep);
53311116Sdim
54311116Sdim  SubFn->replaceAllUsesWith(Load);
55311116Sdim}
56311116Sdim
57311116Sdimbool Lowerer::lowerRemainingCoroIntrinsics(Function &F) {
58311116Sdim  bool Changed = false;
59311116Sdim
60311116Sdim  for (auto IB = inst_begin(F), E = inst_end(F); IB != E;) {
61311116Sdim    Instruction &I = *IB++;
62311116Sdim    if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
63311116Sdim      switch (II->getIntrinsicID()) {
64311116Sdim      default:
65311116Sdim        continue;
66311116Sdim      case Intrinsic::coro_begin:
67311116Sdim        II->replaceAllUsesWith(II->getArgOperand(1));
68311116Sdim        break;
69311116Sdim      case Intrinsic::coro_free:
70311116Sdim        II->replaceAllUsesWith(II->getArgOperand(1));
71311116Sdim        break;
72311116Sdim      case Intrinsic::coro_alloc:
73311116Sdim        II->replaceAllUsesWith(ConstantInt::getTrue(Context));
74311116Sdim        break;
75311116Sdim      case Intrinsic::coro_id:
76360784Sdim      case Intrinsic::coro_id_retcon:
77360784Sdim      case Intrinsic::coro_id_retcon_once:
78311116Sdim        II->replaceAllUsesWith(ConstantTokenNone::get(Context));
79311116Sdim        break;
80311116Sdim      case Intrinsic::coro_subfn_addr:
81311116Sdim        lowerSubFn(Builder, cast<CoroSubFnInst>(II));
82311116Sdim        break;
83311116Sdim      }
84311116Sdim      II->eraseFromParent();
85311116Sdim      Changed = true;
86311116Sdim    }
87311116Sdim  }
88311116Sdim
89311116Sdim  if (Changed) {
90311116Sdim    // After replacement were made we can cleanup the function body a little.
91311116Sdim    simplifyCFG(F);
92311116Sdim  }
93311116Sdim  return Changed;
94311116Sdim}
95311116Sdim
96311116Sdim//===----------------------------------------------------------------------===//
97311116Sdim//                              Top Level Driver
98311116Sdim//===----------------------------------------------------------------------===//
99311116Sdim
100311116Sdimnamespace {
101311116Sdim
102360784Sdimstruct CoroCleanupLegacy : FunctionPass {
103311116Sdim  static char ID; // Pass identification, replacement for typeid
104311116Sdim
105360784Sdim  CoroCleanupLegacy() : FunctionPass(ID) {
106360784Sdim    initializeCoroCleanupLegacyPass(*PassRegistry::getPassRegistry());
107321369Sdim  }
108311116Sdim
109311116Sdim  std::unique_ptr<Lowerer> L;
110311116Sdim
111311116Sdim  // This pass has work to do only if we find intrinsics we are going to lower
112311116Sdim  // in the module.
113311116Sdim  bool doInitialization(Module &M) override {
114311116Sdim    if (coro::declaresIntrinsics(M, {"llvm.coro.alloc", "llvm.coro.begin",
115311116Sdim                                     "llvm.coro.subfn.addr", "llvm.coro.free",
116360784Sdim                                     "llvm.coro.id", "llvm.coro.id.retcon",
117360784Sdim                                     "llvm.coro.id.retcon.once"}))
118360784Sdim      L = std::make_unique<Lowerer>(M);
119311116Sdim    return false;
120311116Sdim  }
121311116Sdim
122311116Sdim  bool runOnFunction(Function &F) override {
123311116Sdim    if (L)
124311116Sdim      return L->lowerRemainingCoroIntrinsics(F);
125311116Sdim    return false;
126311116Sdim  }
127311116Sdim  void getAnalysisUsage(AnalysisUsage &AU) const override {
128311116Sdim    if (!L)
129311116Sdim      AU.setPreservesAll();
130311116Sdim  }
131321369Sdim  StringRef getPassName() const override { return "Coroutine Cleanup"; }
132311116Sdim};
133311116Sdim}
134311116Sdim
135360784Sdimchar CoroCleanupLegacy::ID = 0;
136360784SdimINITIALIZE_PASS(CoroCleanupLegacy, "coro-cleanup",
137311116Sdim                "Lower all coroutine related intrinsics", false, false)
138311116Sdim
139360784SdimPass *llvm::createCoroCleanupLegacyPass() { return new CoroCleanupLegacy(); }
140