//===- AMDGPUUnifyDivergentExitNodes.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a variant of the UnifyFunctionExitNodes pass. Rather than ensuring
// there is at most one ret and one unreachable instruction, it ensures there is
// at most one divergent exiting block.
//
// StructurizeCFG can't deal with multi-exit regions formed by branches to
// multiple return nodes. It is not desirable to structurize regions with
// uniform branches, and unifying those exits into the same return block as
// the divergent exits would inhibit the use of scalar branching, so uniformly
// reached exits are left untouched. StructurizeCFG still can't deal with the
// case where one branch goes to a return and another to an unreachable;
// replace the unreachable with a return in that case.
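//
// For illustration only (a sketch, not taken from any actual test): given two
// returns that are reached divergently,
//
//   ret.a:                                ret.b:
//     ret float %a                          ret float %b
//
// the pass rewrites both returns into branches to a single block:
//
//   UnifiedReturnBlock:
//     %UnifiedRetVal = phi float [ %a, %ret.a ], [ %b, %ret.b ]
//     ret float %UnifiedRetVal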
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/LegacyDivergenceAnalysis.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;

#define DEBUG_TYPE "amdgpu-unify-divergent-exit-nodes"

namespace {

class AMDGPUUnifyDivergentExitNodes : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  AMDGPUUnifyDivergentExitNodes() : FunctionPass(ID) {
    initializeAMDGPUUnifyDivergentExitNodesPass(
        *PassRegistry::getPassRegistry());
  }

  // We can preserve non-critical-edgeness when we unify function exit nodes
  void getAnalysisUsage(AnalysisUsage &AU) const override;
  bool runOnFunction(Function &F) override;
};

} // end anonymous namespace

char AMDGPUUnifyDivergentExitNodes::ID = 0;

char &llvm::AMDGPUUnifyDivergentExitNodesID = AMDGPUUnifyDivergentExitNodes::ID;

INITIALIZE_PASS_BEGIN(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
                      "Unify divergent function exit nodes", false, false)
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LegacyDivergenceAnalysis)
INITIALIZE_PASS_END(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
                    "Unify divergent function exit nodes", false, false)

void AMDGPUUnifyDivergentExitNodes::getAnalysisUsage(AnalysisUsage &AU) const {
  // TODO: Preserve dominator tree.
  AU.addRequired<PostDominatorTreeWrapperPass>();

  AU.addRequired<LegacyDivergenceAnalysis>();

  // No divergent values are changed, only blocks and branch edges.
  AU.addPreserved<LegacyDivergenceAnalysis>();

  // We preserve the non-critical-edgeness property
  AU.addPreservedID(BreakCriticalEdgesID);

  // This is a cluster of orthogonal Transforms
  AU.addPreservedID(LowerSwitchID);
  FunctionPass::getAnalysisUsage(AU);

  AU.addRequired<TargetTransformInfoWrapperPass>();
}

/// \returns true if \p BB is reachable through only uniform branches.
/// XXX - Is there a more efficient way to find this?
static bool isUniformlyReached(const LegacyDivergenceAnalysis &DA,
                               BasicBlock &BB) {
  SmallVector<BasicBlock *, 8> Stack;
  SmallPtrSet<BasicBlock *, 8> Visited;

  for (BasicBlock *Pred : predecessors(&BB))
    Stack.push_back(Pred);

  while (!Stack.empty()) {
    BasicBlock *Top = Stack.pop_back_val();
    if (!DA.isUniform(Top->getTerminator()))
      return false;

    for (BasicBlock *Pred : predecessors(Top)) {
      if (Visited.insert(Pred).second)
        Stack.push_back(Pred);
    }
  }

  return true;
}

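/// Clear the "done" bit on every export intrinsic in \p F. The unified return
/// block inserts its own null export with the "done" bit set, and more than
/// one "done" export can lead to undefined behavior.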
static void removeDoneExport(Function &F) {
  ConstantInt *BoolFalse = ConstantInt::getFalse(F.getContext());
  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      if (IntrinsicInst *Intrin = llvm::dyn_cast<IntrinsicInst>(&I)) {
        if (Intrin->getIntrinsicID() == Intrinsic::amdgcn_exp) {
          Intrin->setArgOperand(6, BoolFalse); // done
        } else if (Intrin->getIntrinsicID() == Intrinsic::amdgcn_exp_compr) {
          Intrin->setArgOperand(4, BoolFalse); // done
        }
      }
    }
  }
}

static BasicBlock *unifyReturnBlockSet(Function &F,
                                       ArrayRef<BasicBlock *> ReturningBlocks,
                                       bool InsertExport,
                                       const TargetTransformInfo &TTI,
                                       StringRef Name) {
  // We need to insert a new basic block into the function, add a PHI node (if
  // the function returns a value), and convert all of the return instructions
  // into unconditional branches.
  BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(), Name, &F);
  IRBuilder<> B(NewRetBlock);

  if (InsertExport) {
    // Ensure that there's only one "done" export in the shader by removing the
    // "done" bit set on the original final export. More than one "done" export
    // can lead to undefined behavior.
    removeDoneExport(F);

    Value *Undef = UndefValue::get(B.getFloatTy());
    B.CreateIntrinsic(Intrinsic::amdgcn_exp, { B.getFloatTy() },
                      {
                        B.getInt32(9), // target, SQ_EXP_NULL
                        B.getInt32(0), // enabled channels
                        Undef, Undef, Undef, Undef, // values
                        B.getTrue(), // done
                        B.getTrue(), // valid mask
                      });
  }

  PHINode *PN = nullptr;
  if (F.getReturnType()->isVoidTy()) {
    B.CreateRetVoid();
  } else {
    // If the function doesn't return void... add a PHI node to the block...
    PN = B.CreatePHI(F.getReturnType(), ReturningBlocks.size(),
                     "UnifiedRetVal");
    assert(!InsertExport);
    B.CreateRet(PN);
  }

  // Loop over all of the blocks, replacing the return instruction with an
  // unconditional branch.
  for (BasicBlock *BB : ReturningBlocks) {
    // Add an incoming element to the PHI node for every return instruction
    // that is merging into this new block...
    if (PN)
      PN->addIncoming(BB->getTerminator()->getOperand(0), BB);

    // Remove and delete the return inst.
    BB->getTerminator()->eraseFromParent();
    BranchInst::Create(NewRetBlock, BB);
  }

  for (BasicBlock *BB : ReturningBlocks) {
    // Cleanup possible branch to unconditional branch to the return.
    simplifyCFG(BB, TTI, {2});
  }

  return NewRetBlock;
}

bool AMDGPUUnifyDivergentExitNodes::runOnFunction(Function &F) {
  auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
  if (PDT.getRoots().size() <= 1)
    return false;

  LegacyDivergenceAnalysis &DA = getAnalysis<LegacyDivergenceAnalysis>();

  // Loop over all of the blocks in a function, tracking all of the blocks that
  // return.
  SmallVector<BasicBlock *, 4> ReturningBlocks;
  SmallVector<BasicBlock *, 4> UnreachableBlocks;

  // Dummy return block for infinite loop.
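  // A post-dominator tree root whose terminator is a branch (rather than a
  // return or unreachable) sits in an infinite loop. Such a branch is given an
  // artificial edge to this block, e.g. an unconditional "br label %loop"
  // becomes "br i1 true, label %loop, label %DummyReturnBlock", so the
  // function gains a reachable exit that the structurizer can work with.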
  BasicBlock *DummyReturnBB = nullptr;

  bool InsertExport = false;

  for (BasicBlock *BB : PDT.getRoots()) {
    if (isa<ReturnInst>(BB->getTerminator())) {
      if (!isUniformlyReached(DA, *BB))
        ReturningBlocks.push_back(BB);
    } else if (isa<UnreachableInst>(BB->getTerminator())) {
      if (!isUniformlyReached(DA, *BB))
        UnreachableBlocks.push_back(BB);
    } else if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) {

      ConstantInt *BoolTrue = ConstantInt::getTrue(F.getContext());
      if (DummyReturnBB == nullptr) {
        DummyReturnBB = BasicBlock::Create(F.getContext(),
                                           "DummyReturnBlock", &F);
        Type *RetTy = F.getReturnType();
        Value *RetVal = RetTy->isVoidTy() ? nullptr : UndefValue::get(RetTy);

        // For pixel shaders, the producer guarantees that an export is
        // executed before each return instruction. However, if there is an
        // infinite loop and we insert a return ourselves, we need to uphold
        // that guarantee by inserting a null export. This can happen e.g. in
        // an infinite loop with kill instructions, which is supposed to
        // terminate. However, we don't need to do this if there is a non-void
        // return value, since then there is an epilog afterwards which will
        // still export.
        //
        // Note: In the case where only some threads enter the infinite loop,
        // this can result in the null export happening redundantly after the
        // original exports. However, the last "real" export happens after all
        // the threads that didn't enter an infinite loop converged, which
        // means that the only extra threads to execute the null export are
        // threads that entered the infinite loop, and they only could've
        // exited through being killed which sets their exec bit to 0.
        // Therefore, unless there's an actual infinite loop, which can have
        // invalid results, or there's a kill after the last export, which we
        // assume the frontend won't do, this export will have the same exec
        // mask as the last "real" export, and therefore the valid mask will be
        // overwritten with the same value and will still be correct. Also,
        // even though this forces an extra unnecessary export wait, we assume
        // that this happens rarely enough in practice that we don't have to
        // worry about performance.
        if (F.getCallingConv() == CallingConv::AMDGPU_PS &&
            RetTy->isVoidTy()) {
          InsertExport = true;
        }

        ReturnInst::Create(F.getContext(), RetVal, DummyReturnBB);
        ReturningBlocks.push_back(DummyReturnBB);
      }

      if (BI->isUnconditional()) {
        BasicBlock *LoopHeaderBB = BI->getSuccessor(0);
        BI->eraseFromParent(); // Delete the unconditional branch.
        // Add a new conditional branch with a dummy edge to the return block.
        BranchInst::Create(LoopHeaderBB, DummyReturnBB, BoolTrue, BB);
      } else { // Conditional branch.
        // Create a new transition block to hold the conditional branch.
        BasicBlock *TransitionBB = BB->splitBasicBlock(BI, "TransitionBlock");

        // Create a branch that will always branch to the transition block and
        // references DummyReturnBB.
        BB->getTerminator()->eraseFromParent();
        BranchInst::Create(TransitionBB, DummyReturnBB, BoolTrue, BB);
      }
    }
  }

  if (!UnreachableBlocks.empty()) {
    BasicBlock *UnreachableBlock = nullptr;

    if (UnreachableBlocks.size() == 1) {
      UnreachableBlock = UnreachableBlocks.front();
    } else {
      UnreachableBlock = BasicBlock::Create(F.getContext(),
                                            "UnifiedUnreachableBlock", &F);
      new UnreachableInst(F.getContext(), UnreachableBlock);

      for (BasicBlock *BB : UnreachableBlocks) {
        // Remove and delete the unreachable inst.
        BB->getTerminator()->eraseFromParent();
        BranchInst::Create(UnreachableBlock, BB);
      }
    }

    if (!ReturningBlocks.empty()) {
      // Don't create a new unreachable inst if we have a return. The
      // structurizer/annotator can't handle the multiple exits.

      Type *RetTy = F.getReturnType();
      Value *RetVal = RetTy->isVoidTy() ? nullptr : UndefValue::get(RetTy);
      // Remove and delete the unreachable inst.
      UnreachableBlock->getTerminator()->eraseFromParent();

      Function *UnreachableIntrin =
        Intrinsic::getDeclaration(F.getParent(), Intrinsic::amdgcn_unreachable);

      // Insert a call to an intrinsic tracking that this is an unreachable
      // point, in case we want to kill the active lanes or something later.
      CallInst::Create(UnreachableIntrin, {}, "", UnreachableBlock);

      // Don't create a scalar trap. We would only want to trap if this code was
      // really reached, but a scalar trap would happen even if no lanes
      // actually reached here.
      ReturnInst::Create(F.getContext(), RetVal, UnreachableBlock);
      ReturningBlocks.push_back(UnreachableBlock);
    }
  }

  // Now handle return blocks.
  if (ReturningBlocks.empty())
    return false; // No blocks return

  if (ReturningBlocks.size() == 1)
    return false; // Already has a single return block

  const TargetTransformInfo &TTI
    = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  unifyReturnBlockSet(F, ReturningBlocks, InsertExport, TTI,
                      "UnifiedReturnBlock");
  return true;
}