//===- CoroElide.cpp - Coroutine Frame Allocation Elision Pass ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Coroutines/CoroElide.h"
#include "CoroInternal.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"

using namespace llvm;

#define DEBUG_TYPE "coro-elide"

namespace {
// Created on demand if the coro-elide pass has work to do.
struct Lowerer : coro::LowererBase {
  // Post-split llvm.coro.id calls found in the function being processed
  // (populated by collectPostSplitCoroIds).
  SmallVector<CoroIdInst *, 4> CoroIds;
  // The following containers hold the users of the single coro.id currently
  // being processed; processCoroId clears and repopulates them each time.
  SmallVector<CoroBeginInst *, 1> CoroBegins;
  SmallVector<CoroAllocInst *, 1> CoroAllocs;
  // coro.subfn.addr intrinsics asking for the resume function pointer.
  SmallVector<CoroSubFnInst *, 4> ResumeAddr;
  // coro.subfn.addr intrinsics asking for the destroy function pointer,
  // keyed by the coro.begin they refer to.
  DenseMap<CoroBeginInst *, SmallVector<CoroSubFnInst *, 4>> DestroyAddr;
  SmallVector<CoroFreeInst *, 1> CoroFrees;
  // Switches fed directly by a coro.suspend result; consulted by the escape
  // analysis in hasEscapePath to skip the suspend successor.
  SmallPtrSet<const SwitchInst *, 4> CoroSuspendSwitches;

  Lowerer(Module &M) : LowererBase(M) {}

  // Replaces the heap-allocated coroutine frame with a stack alloca of the
  // given size/alignment and removes coro.alloc / coro.begin.
  void elideHeapAllocations(Function *F, uint64_t FrameSize, Align FrameAlign,
                            AAResults &AA);
  // Decides whether it is safe to move the frame of the callee coroutines
  // referenced in F onto F's stack.
  bool shouldElide(Function *F, DominatorTree &DT) const;
  void collectPostSplitCoroIds(Function *F);
  bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT);
  bool hasEscapePath(const CoroBeginInst *,
                     const SmallPtrSetImpl<BasicBlock *> &) const;
};
} // end anonymous namespace

// Go through the list of coro.subfn.addr intrinsics and replace them with the
// provided constant.
static void replaceWithConstant(Constant *Value,
                                SmallVectorImpl<CoroSubFnInst *> &Users) {
  if (Users.empty())
    return;

  // See if we need to bitcast the constant to match the type of the intrinsic
  // being replaced. Note: All coro.subfn.addr intrinsics return the same type,
  // so we only need to examine the type of the first one in the list.
  Type *IntrTy = Users.front()->getType();
  Type *ValueTy = Value->getType();
  if (ValueTy != IntrTy) {
    // May need to tweak the function type to match the type expected at the
    // use site.
    assert(ValueTy->isPointerTy() && IntrTy->isPointerTy());
    Value = ConstantExpr::getBitCast(Value, IntrTy);
  }

  // Now the value type matches the type of the intrinsic. Replace them all!
  // Recursive simplification also folds users of the replaced value where
  // possible (e.g. an indirect call through it becomes a direct call).
  for (CoroSubFnInst *I : Users)
    replaceAndRecursivelySimplify(I, Value);
}

// See if any operand of the call instruction references the coroutine frame.
static bool operandReferences(CallInst *CI, AllocaInst *Frame, AAResults &AA) {
  for (Value *Op : CI->operand_values())
    // Conservative: any result other than NoAlias counts as a reference.
    if (AA.alias(Op, Frame) != NoAlias)
      return true;
  return false;
}

// Look for any tail calls referencing the coroutine frame and remove tail
// attribute from them, since now coroutine frame resides on the stack and tail
// call implies that the function does not reference anything on the stack.
static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) {
  Function &F = *Frame->getFunction();
  for (Instruction &I : instructions(F))
    if (auto *Call = dyn_cast<CallInst>(&I))
      if (Call->isTailCall() && operandReferences(Call, Frame, AA)) {
        // FIXME: If we ever hit this check. Evaluate whether it is more
        // appropriate to retain musttail and allow the code to compile.
        if (Call->isMustTailCall())
          report_fatal_error("Call referring to the coroutine frame cannot be "
                             "marked as musttail");
        Call->setTailCall(false);
      }
}

// Given a resume function @f.resume(%f.frame* %frame), returns the size
// and expected alignment of %f.frame type.
static std::pair<uint64_t, Align> getFrameLayout(Function *Resume) {
  // Prefer to pull information from the function attributes.
  // NOTE: the local `Align` holds an optional alignment and shadows the
  // llvm::Align type name within this function.
  auto Size = Resume->getParamDereferenceableBytes(0);
  auto Align = Resume->getParamAlign(0);

  // If those aren't given, extract them from the type.
  if (Size == 0 || !Align) {
    auto *FrameTy = Resume->arg_begin()->getType()->getPointerElementType();

    const DataLayout &DL = Resume->getParent()->getDataLayout();
    if (!Size) Size = DL.getTypeAllocSize(FrameTy);
    if (!Align) Align = DL.getABITypeAlign(FrameTy);
  }

  // Both values are guaranteed to be set by now, so dereferencing is safe.
  return std::make_pair(Size, *Align);
}

// Finds first non alloca instruction in the entry block of a function.
static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
  for (Instruction &I : F->getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  // The terminator is never an alloca, so the loop above must return.
  llvm_unreachable("no terminator in the entry block");
}

// To elide heap allocations we need to suppress code blocks guarded by
// llvm.coro.alloc and llvm.coro.free instructions.
void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize,
                                   Align FrameAlign, AAResults &AA) {
  LLVMContext &C = F->getContext();
  // Insert the frame alloca right after the existing allocas so it stays in
  // the entry block's static-alloca prefix.
  auto *InsertPt =
      getFirstNonAllocaInTheEntryBlock(CoroIds.front()->getFunction());

  // Replacing llvm.coro.alloc with false will suppress dynamic
  // allocation as it is expected for the frontend to generate the code that
  // looks like:
  //   id = coro.id(...)
  //   mem = coro.alloc(id) ? malloc(coro.size()) : 0;
  //   coro.begin(id, mem)
  auto *False = ConstantInt::getFalse(C);
  for (auto *CA : CoroAllocs) {
    CA->replaceAllUsesWith(False);
    CA->eraseFromParent();
  }

  // FIXME: Design how to transmit alignment information for every alloca that
  // is spilled into the coroutine frame and recreate the alignment information
  // here. Possibly we will need to do a mini SROA here and break the coroutine
  // frame into individual AllocaInst recreating the original alignment.
  const DataLayout &DL = F->getParent()->getDataLayout();
  auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize);
  auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
  Frame->setAlignment(FrameAlign);
  auto *FrameVoidPtr =
      new BitCastInst(Frame, Type::getInt8PtrTy(C), "vFrame", InsertPt);

  // Every coro.begin now simply produces the stack frame pointer.
  for (auto *CB : CoroBegins) {
    CB->replaceAllUsesWith(FrameVoidPtr);
    CB->eraseFromParent();
  }

  // Since now coroutine frame lives on the stack we need to make sure that
  // any tail call referencing it, must be made non-tail call.
  removeTailCallAttribute(Frame, AA);
}

// Returns true if there may be a CFG path from CB's block to one of the
// normal terminators in TIs that does not pass through a coro.destroy of CB.
bool Lowerer::hasEscapePath(const CoroBeginInst *CB,
                            const SmallPtrSetImpl<BasicBlock *> &TIs) const {
  const auto &It = DestroyAddr.find(CB);
  assert(It != DestroyAddr.end());

  // Limit the number of blocks we visit.
  unsigned Limit = 32 * (1 + It->second.size());

  SmallVector<const BasicBlock *, 32> Worklist;
  Worklist.push_back(CB->getParent());

  SmallPtrSet<const BasicBlock *, 32> Visited;
  // Consider basicblock of coro.destroy as visited one, so that we
  // skip the path passing through coro.destroy.
  for (auto *DA : It->second)
    Visited.insert(DA->getParent());

  do {
    const auto *BB = Worklist.pop_back_val();
    if (!Visited.insert(BB).second)
      continue;
    if (TIs.count(BB))
      return true;

    // Conservatively say that there is potentially a path.
    if (!--Limit)
      return true;

    auto TI = BB->getTerminator();
    // Although the default dest of coro.suspend switches is suspend pointer
    // which means an escape path to normal terminator, it is reasonable to
    // skip it since coroutine frame doesn't change outside the coroutine body.
    if (isa<SwitchInst>(TI) &&
        CoroSuspendSwitches.count(cast<SwitchInst>(TI))) {
      // Follow only the resume (1) and cleanup (2) successors.
      Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(1));
      Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(2));
    } else
      Worklist.append(succ_begin(BB), succ_end(BB));

  } while (!Worklist.empty());

  // We have exhausted all possible paths and are certain that coro.begin can
  // not reach to any of terminators.
  return false;
}

bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
  // If no CoroAllocs, we cannot suppress allocation, so elision is not
  // possible.
  if (CoroAllocs.empty())
    return false;

  // Check that for every coro.begin there is at least one coro.destroy directly
  // referencing the SSA value of that coro.begin along each
  // non-exceptional path.
  // If the value escaped, then coro.destroy would have been referencing a
  // memory location storing that value and not the virtual register.

  SmallPtrSet<BasicBlock *, 8> Terminators;
  // First gather all of the non-exceptional terminators for the function.
  // Consider the final coro.suspend as the real terminator when the current
  // function is a coroutine.
  for (BasicBlock &B : *F) {
    auto *TI = B.getTerminator();
    if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() &&
        !isa<UnreachableInst>(TI))
      Terminators.insert(&B);
  }

  // Filter out the coro.destroy that lie along exceptional paths.
  SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins;
  for (auto &It : DestroyAddr) {
    // A coro.begin is covered when one of its coro.destroys dominates a
    // normal terminator, i.e. runs on every path reaching that terminator.
    for (Instruction *DA : It.second) {
      for (BasicBlock *TI : Terminators) {
        if (DT.dominates(DA, TI->getTerminator())) {
          ReferencedCoroBegins.insert(It.first);
          break;
        }
      }
    }

    // Whether there is any path from coro.begin to Terminators which does not
    // pass through any of the coro.destroys.
    if (!ReferencedCoroBegins.count(It.first) &&
        !hasEscapePath(It.first, Terminators))
      ReferencedCoroBegins.insert(It.first);
  }

  // If size of the set is the same as total number of coro.begin, that means we
  // found a coro.free or coro.destroy referencing each coro.begin, so we can
  // perform heap elision.
  return ReferencedCoroBegins.size() == CoroBegins.size();
}

void Lowerer::collectPostSplitCoroIds(Function *F) {
  CoroIds.clear();
  CoroSuspendSwitches.clear();
  for (auto &I : instructions(F)) {
    if (auto *CII = dyn_cast<CoroIdInst>(&I))
      if (CII->getInfo().isPostSplit())
        // If it is the coroutine itself, don't touch it.
        if (CII->getCoroutine() != CII->getFunction())
          CoroIds.push_back(CII);

    // Consider case like:
    // %0 = call i8 @llvm.coro.suspend(...)
    // switch i8 %0, label %suspend [i8 0, label %resume
    //                               i8 1, label %cleanup]
    // and collect the SwitchInsts which are used by escape analysis later.
    if (auto *CSI = dyn_cast<CoroSuspendInst>(&I))
      if (CSI->hasOneUse() && isa<SwitchInst>(CSI->use_begin()->getUser())) {
        SwitchInst *SWI = cast<SwitchInst>(CSI->use_begin()->getUser());
        if (SWI->getNumCases() == 2)
          CoroSuspendSwitches.insert(SWI);
      }
  }
}

bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
                            DominatorTree &DT) {
  // Reset per-coro.id state from any previous invocation.
  CoroBegins.clear();
  CoroAllocs.clear();
  CoroFrees.clear();
  ResumeAddr.clear();
  DestroyAddr.clear();

  // Collect all coro.begin and coro.allocs associated with this coro.id.
  for (User *U : CoroId->users()) {
    if (auto *CB = dyn_cast<CoroBeginInst>(U))
      CoroBegins.push_back(CB);
    else if (auto *CA = dyn_cast<CoroAllocInst>(U))
      CoroAllocs.push_back(CA);
    else if (auto *CF = dyn_cast<CoroFreeInst>(U))
      CoroFrees.push_back(CF);
  }

  // Collect all coro.subfn.addrs associated with coro.begin.
  // Note, we only devirtualize the calls if their coro.subfn.addr refers to
  // coro.begin directly. If we run into cases where this check is too
  // conservative, we can consider relaxing the check.
  for (CoroBeginInst *CB : CoroBegins) {
    for (User *U : CB->users())
      if (auto *II = dyn_cast<CoroSubFnInst>(U))
        switch (II->getIndex()) {
        case CoroSubFnInst::ResumeIndex:
          ResumeAddr.push_back(II);
          break;
        case CoroSubFnInst::DestroyIndex:
          DestroyAddr[CB].push_back(II);
          break;
        default:
          llvm_unreachable("unexpected coro.subfn.addr constant");
        }
  }

  // PostSplit coro.id refers to an array of subfunctions in its Info
  // argument.
  ConstantArray *Resumers = CoroId->getInfo().Resumers;
  assert(Resumers && "PostSplit coro.id Info argument must refer to an array"
                     "of coroutine subfunctions");
  auto *ResumeAddrConstant =
      ConstantExpr::getExtractValue(Resumers, CoroSubFnInst::ResumeIndex);

  // Resume pointers can always be devirtualized, whether we elide or not.
  replaceWithConstant(ResumeAddrConstant, ResumeAddr);

  bool ShouldElide = shouldElide(CoroId->getFunction(), DT);

  // When eliding, destroy calls are redirected to the cleanup variant, which
  // does not free the (now stack-allocated) frame.
  auto *DestroyAddrConstant = ConstantExpr::getExtractValue(
      Resumers,
      ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex);

  for (auto &It : DestroyAddr)
    replaceWithConstant(DestroyAddrConstant, It.second);

  if (ShouldElide) {
    auto FrameSizeAndAlign = getFrameLayout(cast<Function>(ResumeAddrConstant));
    elideHeapAllocations(CoroId->getFunction(), FrameSizeAndAlign.first,
                         FrameSizeAndAlign.second, AA);
    coro::replaceCoroFree(CoroId, /*Elide=*/true);
  }

  return true;
}

// See if there are any coro.subfn.addr instructions referring to coro.devirt
// trigger, if so, replace them with a direct call to devirt trigger function.
350static bool replaceDevirtTrigger(Function &F) { 351 SmallVector<CoroSubFnInst *, 1> DevirtAddr; 352 for (auto &I : instructions(F)) 353 if (auto *SubFn = dyn_cast<CoroSubFnInst>(&I)) 354 if (SubFn->getIndex() == CoroSubFnInst::RestartTrigger) 355 DevirtAddr.push_back(SubFn); 356 357 if (DevirtAddr.empty()) 358 return false; 359 360 Module &M = *F.getParent(); 361 Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN); 362 assert(DevirtFn && "coro.devirt.fn not found"); 363 replaceWithConstant(DevirtFn, DevirtAddr); 364 365 return true; 366} 367 368static bool declaresCoroElideIntrinsics(Module &M) { 369 return coro::declaresIntrinsics(M, {"llvm.coro.id"}); 370} 371 372PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) { 373 auto &M = *F.getParent(); 374 if (!declaresCoroElideIntrinsics(M)) 375 return PreservedAnalyses::all(); 376 377 Lowerer L(M); 378 L.CoroIds.clear(); 379 L.collectPostSplitCoroIds(&F); 380 // If we did not find any coro.id, there is nothing to do. 381 if (L.CoroIds.empty()) 382 return PreservedAnalyses::all(); 383 384 AAResults &AA = AM.getResult<AAManager>(F); 385 DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); 386 387 bool Changed = false; 388 for (auto *CII : L.CoroIds) 389 Changed |= L.processCoroId(CII, AA, DT); 390 391 return Changed ? 
PreservedAnalyses::none() : PreservedAnalyses::all(); 392} 393 394namespace { 395struct CoroElideLegacy : FunctionPass { 396 static char ID; 397 CoroElideLegacy() : FunctionPass(ID) { 398 initializeCoroElideLegacyPass(*PassRegistry::getPassRegistry()); 399 } 400 401 std::unique_ptr<Lowerer> L; 402 403 bool doInitialization(Module &M) override { 404 if (declaresCoroElideIntrinsics(M)) 405 L = std::make_unique<Lowerer>(M); 406 return false; 407 } 408 409 bool runOnFunction(Function &F) override { 410 if (!L) 411 return false; 412 413 bool Changed = false; 414 415 if (F.hasFnAttribute(CORO_PRESPLIT_ATTR)) 416 Changed = replaceDevirtTrigger(F); 417 418 L->CoroIds.clear(); 419 L->collectPostSplitCoroIds(&F); 420 // If we did not find any coro.id, there is nothing to do. 421 if (L->CoroIds.empty()) 422 return Changed; 423 424 AAResults &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); 425 DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 426 427 for (auto *CII : L->CoroIds) 428 Changed |= L->processCoroId(CII, AA, DT); 429 430 return Changed; 431 } 432 void getAnalysisUsage(AnalysisUsage &AU) const override { 433 AU.addRequired<AAResultsWrapperPass>(); 434 AU.addRequired<DominatorTreeWrapperPass>(); 435 } 436 StringRef getPassName() const override { return "Coroutine Elision"; } 437}; 438} 439 440char CoroElideLegacy::ID = 0; 441INITIALIZE_PASS_BEGIN( 442 CoroElideLegacy, "coro-elide", 443 "Coroutine frame allocation elision and indirect calls replacement", false, 444 false) 445INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) 446INITIALIZE_PASS_END( 447 CoroElideLegacy, "coro-elide", 448 "Coroutine frame allocation elision and indirect calls replacement", false, 449 false) 450 451Pass *llvm::createCoroElideLegacyPass() { return new CoroElideLegacy(); } 452