LegacyDivergenceAnalysis.cpp revision 360784
1//===- LegacyDivergenceAnalysis.cpp --------- Legacy Divergence Analysis
2//Implementation -==//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This file implements divergence analysis which determines whether a branch
// in a GPU program is divergent. It can help branch optimizations such as jump
12// threading and loop unswitching to make better decisions.
13//
14// GPU programs typically use the SIMD execution model, where multiple threads
15// in the same execution group have to execute in lock-step. Therefore, if the
16// code contains divergent branches (i.e., threads in a group do not agree on
17// which path of the branch to take), the group of threads has to execute all
18// the paths from that branch with different subsets of threads enabled until
19// they converge at the immediately post-dominating BB of the paths.
20//
21// Due to this execution model, some optimizations such as jump
22// threading and loop unswitching can be unfortunately harmful when performed on
23// divergent branches. Therefore, an analysis that computes which branches in a
24// GPU program are divergent can help the compiler to selectively run these
25// optimizations.
26//
27// This file defines divergence analysis which computes a conservative but
28// non-trivial approximation of all divergent branches in a GPU program. It
29// partially implements the approach described in
30//
31//   Divergence Analysis
32//   Sampaio, Souza, Collange, Pereira
33//   TOPLAS '13
34//
35// The divergence analysis identifies the sources of divergence (e.g., special
36// variables that hold the thread ID), and recursively marks variables that are
37// data or sync dependent on a source of divergence as divergent.
38//
39// While data dependency is a well-known concept, the notion of sync dependency
40// is worth more explanation. Sync dependence characterizes the control flow
41// aspect of the propagation of branch divergence. For example,
42//
43//   %cond = icmp slt i32 %tid, 10
44//   br i1 %cond, label %then, label %else
45// then:
46//   br label %merge
47// else:
48//   br label %merge
49// merge:
50//   %a = phi i32 [ 0, %then ], [ 1, %else ]
51//
52// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
53// because %tid is not on its use-def chains, %a is sync dependent on %tid
54// because the branch "br i1 %cond" depends on %tid and affects which value %a
55// is assigned to.
56//
57// The current implementation has the following limitations:
58// 1. intra-procedural. It conservatively considers the arguments of a
59//    non-kernel-entry function and the return value of a function call as
60//    divergent.
61// 2. memory as black box. It conservatively considers values loaded from
62//    generic or local address as divergent. This can be improved by leveraging
63//    pointer analysis.
64//
65//===----------------------------------------------------------------------===//
66
67#include "llvm/Analysis/LegacyDivergenceAnalysis.h"
68#include "llvm/ADT/PostOrderIterator.h"
69#include "llvm/Analysis/CFG.h"
70#include "llvm/Analysis/DivergenceAnalysis.h"
71#include "llvm/Analysis/Passes.h"
72#include "llvm/Analysis/PostDominators.h"
73#include "llvm/Analysis/TargetTransformInfo.h"
74#include "llvm/IR/Dominators.h"
75#include "llvm/IR/InstIterator.h"
76#include "llvm/IR/Instructions.h"
77#include "llvm/IR/Value.h"
78#include "llvm/InitializePasses.h"
79#include "llvm/Support/CommandLine.h"
80#include "llvm/Support/Debug.h"
81#include "llvm/Support/raw_ostream.h"
82#include <vector>
83using namespace llvm;
84
85#define DEBUG_TYPE "divergence"
86
87// transparently use the GPUDivergenceAnalysis
88static cl::opt<bool> UseGPUDA("use-gpu-divergence-analysis", cl::init(false),
89                              cl::Hidden,
90                              cl::desc("turn the LegacyDivergenceAnalysis into "
91                                       "a wrapper for GPUDivergenceAnalysis"));
92
93namespace {
94
95class DivergencePropagator {
96public:
97  DivergencePropagator(Function &F, TargetTransformInfo &TTI, DominatorTree &DT,
98                       PostDominatorTree &PDT, DenseSet<const Value *> &DV,
99                       DenseSet<const Use *> &DU)
100      : F(F), TTI(TTI), DT(DT), PDT(PDT), DV(DV), DU(DU) {}
101  void populateWithSourcesOfDivergence();
102  void propagate();
103
104private:
105  // A helper function that explores data dependents of V.
106  void exploreDataDependency(Value *V);
107  // A helper function that explores sync dependents of TI.
108  void exploreSyncDependency(Instruction *TI);
109  // Computes the influence region from Start to End. This region includes all
110  // basic blocks on any simple path from Start to End.
111  void computeInfluenceRegion(BasicBlock *Start, BasicBlock *End,
112                              DenseSet<BasicBlock *> &InfluenceRegion);
113  // Finds all users of I that are outside the influence region, and add these
114  // users to Worklist.
115  void findUsersOutsideInfluenceRegion(
116      Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion);
117
118  Function &F;
119  TargetTransformInfo &TTI;
120  DominatorTree &DT;
121  PostDominatorTree &PDT;
122  std::vector<Value *> Worklist; // Stack for DFS.
123  DenseSet<const Value *> &DV;   // Stores all divergent values.
124  DenseSet<const Use *> &DU;   // Stores divergent uses of possibly uniform
125                               // values.
126};
127
128void DivergencePropagator::populateWithSourcesOfDivergence() {
129  Worklist.clear();
130  DV.clear();
131  DU.clear();
132  for (auto &I : instructions(F)) {
133    if (TTI.isSourceOfDivergence(&I)) {
134      Worklist.push_back(&I);
135      DV.insert(&I);
136    }
137  }
138  for (auto &Arg : F.args()) {
139    if (TTI.isSourceOfDivergence(&Arg)) {
140      Worklist.push_back(&Arg);
141      DV.insert(&Arg);
142    }
143  }
144}
145
146void DivergencePropagator::exploreSyncDependency(Instruction *TI) {
147  // Propagation rule 1: if branch TI is divergent, all PHINodes in TI's
148  // immediate post dominator are divergent. This rule handles if-then-else
149  // patterns. For example,
150  //
151  // if (tid < 5)
152  //   a1 = 1;
153  // else
154  //   a2 = 2;
155  // a = phi(a1, a2); // sync dependent on (tid < 5)
156  BasicBlock *ThisBB = TI->getParent();
157
158  // Unreachable blocks may not be in the dominator tree.
159  if (!DT.isReachableFromEntry(ThisBB))
160    return;
161
162  // If the function has no exit blocks or doesn't reach any exit blocks, the
163  // post dominator may be null.
164  DomTreeNode *ThisNode = PDT.getNode(ThisBB);
165  if (!ThisNode)
166    return;
167
168  BasicBlock *IPostDom = ThisNode->getIDom()->getBlock();
169  if (IPostDom == nullptr)
170    return;
171
172  for (auto I = IPostDom->begin(); isa<PHINode>(I); ++I) {
173    // A PHINode is uniform if it returns the same value no matter which path is
174    // taken.
175    if (!cast<PHINode>(I)->hasConstantOrUndefValue() && DV.insert(&*I).second)
176      Worklist.push_back(&*I);
177  }
178
179  // Propagation rule 2: if a value defined in a loop is used outside, the user
180  // is sync dependent on the condition of the loop exits that dominate the
181  // user. For example,
182  //
183  // int i = 0;
184  // do {
185  //   i++;
186  //   if (foo(i)) ... // uniform
187  // } while (i < tid);
188  // if (bar(i)) ...   // divergent
189  //
190  // A program may contain unstructured loops. Therefore, we cannot leverage
191  // LoopInfo, which only recognizes natural loops.
192  //
193  // The algorithm used here handles both natural and unstructured loops.  Given
194  // a branch TI, we first compute its influence region, the union of all simple
195  // paths from TI to its immediate post dominator (IPostDom). Then, we search
196  // for all the values defined in the influence region but used outside. All
197  // these users are sync dependent on TI.
198  DenseSet<BasicBlock *> InfluenceRegion;
199  computeInfluenceRegion(ThisBB, IPostDom, InfluenceRegion);
200  // An insight that can speed up the search process is that all the in-region
201  // values that are used outside must dominate TI. Therefore, instead of
202  // searching every basic blocks in the influence region, we search all the
203  // dominators of TI until it is outside the influence region.
204  BasicBlock *InfluencedBB = ThisBB;
205  while (InfluenceRegion.count(InfluencedBB)) {
206    for (auto &I : *InfluencedBB) {
207      if (!DV.count(&I))
208        findUsersOutsideInfluenceRegion(I, InfluenceRegion);
209    }
210    DomTreeNode *IDomNode = DT.getNode(InfluencedBB)->getIDom();
211    if (IDomNode == nullptr)
212      break;
213    InfluencedBB = IDomNode->getBlock();
214  }
215}
216
217void DivergencePropagator::findUsersOutsideInfluenceRegion(
218    Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion) {
219  for (Use &Use : I.uses()) {
220    Instruction *UserInst = cast<Instruction>(Use.getUser());
221    if (!InfluenceRegion.count(UserInst->getParent())) {
222      DU.insert(&Use);
223      if (DV.insert(UserInst).second)
224        Worklist.push_back(UserInst);
225    }
226  }
227}
228
229// A helper function for computeInfluenceRegion that adds successors of "ThisBB"
230// to the influence region.
231static void
232addSuccessorsToInfluenceRegion(BasicBlock *ThisBB, BasicBlock *End,
233                               DenseSet<BasicBlock *> &InfluenceRegion,
234                               std::vector<BasicBlock *> &InfluenceStack) {
235  for (BasicBlock *Succ : successors(ThisBB)) {
236    if (Succ != End && InfluenceRegion.insert(Succ).second)
237      InfluenceStack.push_back(Succ);
238  }
239}
240
241void DivergencePropagator::computeInfluenceRegion(
242    BasicBlock *Start, BasicBlock *End,
243    DenseSet<BasicBlock *> &InfluenceRegion) {
244  assert(PDT.properlyDominates(End, Start) &&
245         "End does not properly dominate Start");
246
247  // The influence region starts from the end of "Start" to the beginning of
248  // "End". Therefore, "Start" should not be in the region unless "Start" is in
249  // a loop that doesn't contain "End".
250  std::vector<BasicBlock *> InfluenceStack;
251  addSuccessorsToInfluenceRegion(Start, End, InfluenceRegion, InfluenceStack);
252  while (!InfluenceStack.empty()) {
253    BasicBlock *BB = InfluenceStack.back();
254    InfluenceStack.pop_back();
255    addSuccessorsToInfluenceRegion(BB, End, InfluenceRegion, InfluenceStack);
256  }
257}
258
259void DivergencePropagator::exploreDataDependency(Value *V) {
260  // Follow def-use chains of V.
261  for (User *U : V->users()) {
262    if (!TTI.isAlwaysUniform(U) && DV.insert(U).second)
263      Worklist.push_back(U);
264  }
265}
266
267void DivergencePropagator::propagate() {
268  // Traverse the dependency graph using DFS.
269  while (!Worklist.empty()) {
270    Value *V = Worklist.back();
271    Worklist.pop_back();
272    if (Instruction *I = dyn_cast<Instruction>(V)) {
273      // Terminators with less than two successors won't introduce sync
274      // dependency. Ignore them.
275      if (I->isTerminator() && I->getNumSuccessors() > 1)
276        exploreSyncDependency(I);
277    }
278    exploreDataDependency(V);
279  }
280}
281
282} // namespace
283
284// Register this pass.
285char LegacyDivergenceAnalysis::ID = 0;
286LegacyDivergenceAnalysis::LegacyDivergenceAnalysis() : FunctionPass(ID) {
287  initializeLegacyDivergenceAnalysisPass(*PassRegistry::getPassRegistry());
288}
289INITIALIZE_PASS_BEGIN(LegacyDivergenceAnalysis, "divergence",
290                      "Legacy Divergence Analysis", false, true)
291INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
292INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
293INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
294INITIALIZE_PASS_END(LegacyDivergenceAnalysis, "divergence",
295                    "Legacy Divergence Analysis", false, true)
296
297FunctionPass *llvm::createLegacyDivergenceAnalysisPass() {
298  return new LegacyDivergenceAnalysis();
299}
300
301void LegacyDivergenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
302  AU.addRequired<DominatorTreeWrapperPass>();
303  AU.addRequired<PostDominatorTreeWrapperPass>();
304  if (UseGPUDA)
305    AU.addRequired<LoopInfoWrapperPass>();
306  AU.setPreservesAll();
307}
308
309bool LegacyDivergenceAnalysis::shouldUseGPUDivergenceAnalysis(
310    const Function &F) const {
311  if (!UseGPUDA)
312    return false;
313
314  // GPUDivergenceAnalysis requires a reducible CFG.
315  auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
316  using RPOTraversal = ReversePostOrderTraversal<const Function *>;
317  RPOTraversal FuncRPOT(&F);
318  return !containsIrreducibleCFG<const BasicBlock *, const RPOTraversal,
319                                 const LoopInfo>(FuncRPOT, LI);
320}
321
322bool LegacyDivergenceAnalysis::runOnFunction(Function &F) {
323  auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
324  if (TTIWP == nullptr)
325    return false;
326
327  TargetTransformInfo &TTI = TTIWP->getTTI(F);
328  // Fast path: if the target does not have branch divergence, we do not mark
329  // any branch as divergent.
330  if (!TTI.hasBranchDivergence())
331    return false;
332
333  DivergentValues.clear();
334  DivergentUses.clear();
335  gpuDA = nullptr;
336
337  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
338  auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
339
340  if (shouldUseGPUDivergenceAnalysis(F)) {
341    // run the new GPU divergence analysis
342    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
343    gpuDA = std::make_unique<GPUDivergenceAnalysis>(F, DT, PDT, LI, TTI);
344
345  } else {
346    // run LLVM's existing DivergenceAnalysis
347    DivergencePropagator DP(F, TTI, DT, PDT, DivergentValues, DivergentUses);
348    DP.populateWithSourcesOfDivergence();
349    DP.propagate();
350  }
351
352  LLVM_DEBUG(dbgs() << "\nAfter divergence analysis on " << F.getName()
353                    << ":\n";
354             print(dbgs(), F.getParent()));
355
356  return false;
357}
358
359bool LegacyDivergenceAnalysis::isDivergent(const Value *V) const {
360  if (gpuDA) {
361    return gpuDA->isDivergent(*V);
362  }
363  return DivergentValues.count(V);
364}
365
366bool LegacyDivergenceAnalysis::isDivergentUse(const Use *U) const {
367  if (gpuDA) {
368    return gpuDA->isDivergentUse(*U);
369  }
370  return DivergentValues.count(U->get()) || DivergentUses.count(U);
371}
372
373void LegacyDivergenceAnalysis::print(raw_ostream &OS, const Module *) const {
374  if ((!gpuDA || !gpuDA->hasDivergence()) && DivergentValues.empty())
375    return;
376
377  const Function *F = nullptr;
378  if (!DivergentValues.empty()) {
379    const Value *FirstDivergentValue = *DivergentValues.begin();
380    if (const Argument *Arg = dyn_cast<Argument>(FirstDivergentValue)) {
381      F = Arg->getParent();
382    } else if (const Instruction *I =
383                   dyn_cast<Instruction>(FirstDivergentValue)) {
384      F = I->getParent()->getParent();
385    } else {
386      llvm_unreachable("Only arguments and instructions can be divergent");
387    }
388  } else if (gpuDA) {
389    F = &gpuDA->getFunction();
390  }
391  if (!F)
392    return;
393
394  // Dumps all divergent values in F, arguments and then instructions.
395  for (auto &Arg : F->args()) {
396    OS << (isDivergent(&Arg) ? "DIVERGENT: " : "           ");
397    OS << Arg << "\n";
398  }
399  // Iterate instructions using instructions() to ensure a deterministic order.
400  for (auto BI = F->begin(), BE = F->end(); BI != BE; ++BI) {
401    auto &BB = *BI;
402    OS << "\n           " << BB.getName() << ":\n";
403    for (auto &I : BB.instructionsWithoutDebug()) {
404      OS << (isDivergent(&I) ? "DIVERGENT:     " : "               ");
405      OS << I << "\n";
406    }
407  }
408  OS << "\n";
409}
410