1//===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements an algorithm that returns for a divergent branch
10// the set of basic blocks whose phi nodes become divergent due to divergent
11// control. These are the blocks that are reachable by two disjoint paths from
12// the branch or loop exits that have a reaching path that is disjoint from a
13// path to the loop latch.
14//
15// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
16// control-induced divergence in phi nodes.
17//
18//
19// -- Reference --
20// The algorithm is presented in Section 5 of
21//
22//   An abstract interpretation for SPMD divergence
23//       on reducible control flow graphs.
24//   Julian Rosemann, Simon Moll and Sebastian Hack
25//   POPL '21
26//
27//
28// -- Sync dependence --
29// Sync dependence characterizes the control flow aspect of the
30// propagation of branch divergence. For example,
31//
32//   %cond = icmp slt i32 %tid, 10
33//   br i1 %cond, label %then, label %else
34// then:
35//   br label %merge
36// else:
37//   br label %merge
38// merge:
39//   %a = phi i32 [ 0, %then ], [ 1, %else ]
40//
41// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
42// because %tid is not on its use-def chains, %a is sync dependent on %tid
43// because the branch "br i1 %cond" depends on %tid and affects which value %a
44// is assigned to.
45//
46//
47// -- Reduction to SSA construction --
48// There are two disjoint paths from A to X, if a certain variant of SSA
49// construction places a phi node in X under the following set-up scheme.
50//
51// This variant of SSA construction ignores incoming undef values.
// That is, paths from the entry without a definition do not result in
53// phi nodes.
54//
55//       entry
56//     /      \
57//    A        \
58//  /   \       Y
59// B     C     /
60//  \   /  \  /
61//    D     E
62//     \   /
63//       F
64//
65// Assume that A contains a divergent branch. We are interested
66// in the set of all blocks where each block is reachable from A
67// via two disjoint paths. This would be the set {D, F} in this
68// case.
69// To generally reduce this query to SSA construction we introduce
70// a virtual variable x and assign to x different values in each
71// successor block of A.
72//
73//           entry
74//         /      \
75//        A        \
76//      /   \       Y
77// x = 0   x = 1   /
78//      \  /   \  /
79//        D     E
80//         \   /
81//           F
82//
83// Our flavor of SSA construction for x will construct the following
84//
85//            entry
86//          /      \
87//         A        \
88//       /   \       Y
89// x0 = 0   x1 = 1  /
90//       \   /   \ /
91//     x2 = phi   E
92//         \     /
93//         x3 = phi
94//
95// The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
97//
98// -- Remarks --
99// * In case of loop exits we need to check the disjoint path criterion for loops.
100//   To this end, we check whether the definition of x differs between the
101//   loop exit and the loop header (_after_ SSA construction).
102//
103// -- Known Limitations & Future Work --
104// * The algorithm requires reducible loops because the implementation
105//   implicitly performs a single iteration of the underlying data flow analysis.
106//   This was done for pragmatism, simplicity and speed.
107//
108//   Relevant related work for extending the algorithm to irreducible control:
109//     A simple algorithm for global data flow analysis problems.
110//     Matthew S. Hecht and Jeffrey D. Ullman.
//     SIAM Journal on Computing, 4(4):519–532, December 1975.
112//
113// * Another reason for requiring reducible loops is that points of
114//   synchronization in irreducible loops aren't 'obvious' - there is no unique
115//   header where threads 'should' synchronize when entering or coming back
116//   around from the latch.
117//
118//===----------------------------------------------------------------------===//
119
120#include "llvm/Analysis/SyncDependenceAnalysis.h"
121#include "llvm/ADT/SmallPtrSet.h"
122#include "llvm/Analysis/LoopInfo.h"
123#include "llvm/IR/BasicBlock.h"
124#include "llvm/IR/CFG.h"
125#include "llvm/IR/Dominators.h"
126#include "llvm/IR/Function.h"
127
128#include <functional>
129
130#define DEBUG_TYPE "sync-dependence"
131
132// The SDA algorithm operates on a modified CFG - we modify the edges leaving
133// loop headers as follows:
134//
135// * We remove all edges leaving all loop headers.
136// * We add additional edges from the loop headers to their exit blocks.
137//
138// The modification is virtual, that is whenever we visit a loop header we
139// pretend it had different successors.
140namespace {
141using namespace llvm;
142
// Custom Post-Order Traversal
//
// We cannot use the vanilla (R)PO computation of LLVM because:
// * We (virtually) modify the CFG.
// * We want a loop-compact block enumeration, that is the numbers assigned to
//   blocks of a loop form an interval
//
// Callback invoked for each block when it is finalized in the post-order.
using POCB = std::function<void(const BasicBlock &)>;
// Blocks that have already been finalized (reported to the callback).
using VisitedSet = std::set<const BasicBlock *>;
// Worklist of blocks still to be processed; the top of the stack is back().
using BlockStack = std::vector<const BasicBlock *>;

// Forward declaration: computeStackPO and computeLoopPO call each other.
static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
                          VisitedSet &Finalized);
157
158// for a nested region (top-level loop or nested loop)
159static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop,
160                           POCB CallBack, VisitedSet &Finalized) {
161  const auto *LoopHeader = Loop ? Loop->getHeader() : nullptr;
162  while (!Stack.empty()) {
163    const auto *NextBB = Stack.back();
164
165    auto *NestedLoop = LI.getLoopFor(NextBB);
166    bool IsNestedLoop = NestedLoop != Loop;
167
168    // Treat the loop as a node
169    if (IsNestedLoop) {
170      SmallVector<BasicBlock *, 3> NestedExits;
171      NestedLoop->getUniqueExitBlocks(NestedExits);
172      bool PushedNodes = false;
173      for (const auto *NestedExitBB : NestedExits) {
174        if (NestedExitBB == LoopHeader)
175          continue;
176        if (Loop && !Loop->contains(NestedExitBB))
177          continue;
178        if (Finalized.count(NestedExitBB))
179          continue;
180        PushedNodes = true;
181        Stack.push_back(NestedExitBB);
182      }
183      if (!PushedNodes) {
184        // All loop exits finalized -> finish this node
185        Stack.pop_back();
186        computeLoopPO(LI, *NestedLoop, CallBack, Finalized);
187      }
188      continue;
189    }
190
191    // DAG-style
192    bool PushedNodes = false;
193    for (const auto *SuccBB : successors(NextBB)) {
194      if (SuccBB == LoopHeader)
195        continue;
196      if (Loop && !Loop->contains(SuccBB))
197        continue;
198      if (Finalized.count(SuccBB))
199        continue;
200      PushedNodes = true;
201      Stack.push_back(SuccBB);
202    }
203    if (!PushedNodes) {
204      // Never push nodes twice
205      Stack.pop_back();
206      if (!Finalized.insert(NextBB).second)
207        continue;
208      CallBack(*NextBB);
209    }
210  }
211}
212
213static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) {
214  VisitedSet Finalized;
215  BlockStack Stack;
216  Stack.reserve(24); // FIXME made-up number
217  Stack.push_back(&F.getEntryBlock());
218  computeStackPO(Stack, LI, nullptr, CallBack, Finalized);
219}
220
221static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
222                          VisitedSet &Finalized) {
223  /// Call CallBack on all loop blocks.
224  std::vector<const BasicBlock *> Stack;
225  const auto *LoopHeader = Loop.getHeader();
226
227  // Visit the header last
228  Finalized.insert(LoopHeader);
229  CallBack(*LoopHeader);
230
231  // Initialize with immediate successors
232  for (const auto *BB : successors(LoopHeader)) {
233    if (!Loop.contains(BB))
234      continue;
235    if (BB == LoopHeader)
236      continue;
237    Stack.push_back(BB);
238  }
239
240  // Compute PO inside region
241  computeStackPO(Stack, LI, &Loop, CallBack, Finalized);
242}
243
244} // namespace
245
246namespace llvm {
247
// Shared result returned by getJoinBlocks for terminators with fewer than two
// successors, which cannot be divergent branches.
ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc;
249
250SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT,
251                                               const PostDominatorTree &PDT,
252                                               const LoopInfo &LI)
253    : DT(DT), PDT(PDT), LI(LI) {
254  computeTopLevelPO(*DT.getRoot()->getParent(), LI,
255                    [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); });
256}
257
// Destructor defaulted out of line. NOTE(review): presumably so members that
// own types forward-declared in the header are destroyed where those types
// are complete — confirm against SyncDependenceAnalysis.h.
SyncDependenceAnalysis::~SyncDependenceAnalysis() = default;
259
260namespace {
261// divergence propagator for reducible CFGs
262struct DivergencePropagator {
263  const ModifiedPO &LoopPOT;
264  const DominatorTree &DT;
265  const PostDominatorTree &PDT;
266  const LoopInfo &LI;
267  const BasicBlock &DivTermBlock;
268
269  // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at
270  //   block B
271  // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet
272  // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths
273  // from X or B is an immediate successor of X (initial value).
274  using BlockLabelVec = std::vector<const BasicBlock *>;
275  BlockLabelVec BlockLabels;
276  // divergent join and loop exit descriptor.
277  std::unique_ptr<ControlDivergenceDesc> DivDesc;
278
279  DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT,
280                       const PostDominatorTree &PDT, const LoopInfo &LI,
281                       const BasicBlock &DivTermBlock)
282      : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock),
283        BlockLabels(LoopPOT.size(), nullptr),
284        DivDesc(new ControlDivergenceDesc) {}
285
286  void printDefs(raw_ostream &Out) {
287    Out << "Propagator::BlockLabels {\n";
288    for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) {
289      const auto *Label = BlockLabels[BlockIdx];
290      Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx
291          << ") : ";
292      if (!Label) {
293        Out << "<null>\n";
294      } else {
295        Out << Label->getName() << "\n";
296      }
297    }
298    Out << "}\n";
299  }
300
301  // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
302  // causes a divergent join.
303  bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) {
304    auto SuccIdx = LoopPOT.getIndexOf(SuccBlock);
305
306    // unset or same reaching label
307    const auto *OldLabel = BlockLabels[SuccIdx];
308    if (!OldLabel || (OldLabel == &PushedLabel)) {
309      BlockLabels[SuccIdx] = &PushedLabel;
310      return false;
311    }
312
313    // Update the definition
314    BlockLabels[SuccIdx] = &SuccBlock;
315    return true;
316  }
317
318  // visiting a virtual loop exit edge from the loop header --> temporal
319  // divergence on join
320  bool visitLoopExitEdge(const BasicBlock &ExitBlock,
321                         const BasicBlock &DefBlock, bool FromParentLoop) {
322    // Pushing from a non-parent loop cannot cause temporal divergence.
323    if (!FromParentLoop)
324      return visitEdge(ExitBlock, DefBlock);
325
326    if (!computeJoin(ExitBlock, DefBlock))
327      return false;
328
329    // Identified a divergent loop exit
330    DivDesc->LoopDivBlocks.insert(&ExitBlock);
331    LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName()
332                      << "\n");
333    return true;
334  }
335
336  // process \p SuccBlock with reaching definition \p DefBlock
337  bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) {
338    if (!computeJoin(SuccBlock, DefBlock))
339      return false;
340
341    // Divergent, disjoint paths join.
342    DivDesc->JoinDivBlocks.insert(&SuccBlock);
343    LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName());
344    return true;
345  }
346
347  std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() {
348    assert(DivDesc);
349
350    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName()
351                      << "\n");
352
353    const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock);
354
355    // Early stopping criterion
356    int FloorIdx = LoopPOT.size() - 1;
357    const BasicBlock *FloorLabel = nullptr;
358
359    // bootstrap with branch targets
360    int BlockIdx = 0;
361
362    for (const auto *SuccBlock : successors(&DivTermBlock)) {
363      auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock);
364      BlockLabels[SuccIdx] = SuccBlock;
365
366      // Find the successor with the highest index to start with
367      BlockIdx = std::max<int>(BlockIdx, SuccIdx);
368      FloorIdx = std::min<int>(FloorIdx, SuccIdx);
369
370      // Identify immediate divergent loop exits
371      if (!DivBlockLoop)
372        continue;
373
374      const auto *BlockLoop = LI.getLoopFor(SuccBlock);
375      if (BlockLoop && DivBlockLoop->contains(BlockLoop))
376        continue;
377      DivDesc->LoopDivBlocks.insert(SuccBlock);
378      LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: "
379                        << SuccBlock->getName() << "\n");
380    }
381
382    // propagate definitions at the immediate successors of the node in RPO
383    for (; BlockIdx >= FloorIdx; --BlockIdx) {
384      LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs()));
385
386      // Any label available here
387      const auto *Label = BlockLabels[BlockIdx];
388      if (!Label)
389        continue;
390
391      // Ok. Get the block
392      const auto *Block = LoopPOT.getBlockAt(BlockIdx);
393      LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");
394
395      auto *BlockLoop = LI.getLoopFor(Block);
396      bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block;
397      bool CausedJoin = false;
398      int LoweredFloorIdx = FloorIdx;
399      if (IsLoopHeader) {
400        // Disconnect from immediate successors and propagate directly to loop
401        // exits.
402        SmallVector<BasicBlock *, 4> BlockLoopExits;
403        BlockLoop->getExitBlocks(BlockLoopExits);
404
405        bool IsParentLoop = BlockLoop->contains(&DivTermBlock);
406        for (const auto *BlockLoopExit : BlockLoopExits) {
407          CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop);
408          LoweredFloorIdx = std::min<int>(LoweredFloorIdx,
409                                          LoopPOT.getIndexOf(*BlockLoopExit));
410        }
411      } else {
412        // Acyclic successor case
413        for (const auto *SuccBlock : successors(Block)) {
414          CausedJoin |= visitEdge(*SuccBlock, *Label);
415          LoweredFloorIdx =
416              std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock));
417        }
418      }
419
420      // Floor update
421      if (CausedJoin) {
422        // 1. Different labels pushed to successors
423        FloorIdx = LoweredFloorIdx;
424      } else if (FloorLabel != Label) {
425        // 2. No join caused BUT we pushed a label that is different than the
426        // last pushed label
427        FloorIdx = LoweredFloorIdx;
428        FloorLabel = Label;
429      }
430    }
431
432    LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
433
434    return std::move(DivDesc);
435  }
436};
437} // end anonymous namespace
438
439#ifndef NDEBUG
440static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) {
441  Out << "[";
442  ListSeparator LS;
443  for (const auto *BB : Blocks)
444    Out << LS << BB->getName();
445  Out << "]";
446}
447#endif
448
449const ControlDivergenceDesc &
450SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) {
451  // trivial case
452  if (Term.getNumSuccessors() <= 1) {
453    return EmptyDivergenceDesc;
454  }
455
456  // already available in cache?
457  auto ItCached = CachedControlDivDescs.find(&Term);
458  if (ItCached != CachedControlDivDescs.end())
459    return *ItCached->second;
460
461  // compute all join points
462  // Special handling of divergent loop exits is not needed for LCSSA
463  const auto &TermBlock = *Term.getParent();
464  DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock);
465  auto DivDesc = Propagator.computeJoinPoints();
466
467  LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n";
468             dbgs() << "JoinDivBlocks: ";
469             printBlockSet(DivDesc->JoinDivBlocks, dbgs());
470             dbgs() << "\nLoopDivBlocks: ";
471             printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";);
472
473  auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc));
474  assert(ItInserted.second);
475  return *ItInserted.first->second;
476}
477
478} // namespace llvm
479