1//===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===// 2// 3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4// See https://llvm.org/LICENSE.txt for license information. 5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6// 7//===----------------------------------------------------------------------===// 8// 9// This file implements an algorithm that returns for a divergent branch 10// the set of basic blocks whose phi nodes become divergent due to divergent 11// control. These are the blocks that are reachable by two disjoint paths from 12// the branch or loop exits that have a reaching path that is disjoint from a 13// path to the loop latch. 14// 15// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model 16// control-induced divergence in phi nodes. 17// 18// 19// -- Reference -- 20// The algorithm is presented in Section 5 of 21// 22// An abstract interpretation for SPMD divergence 23// on reducible control flow graphs. 24// Julian Rosemann, Simon Moll and Sebastian Hack 25// POPL '21 26// 27// 28// -- Sync dependence -- 29// Sync dependence characterizes the control flow aspect of the 30// propagation of branch divergence. For example, 31// 32// %cond = icmp slt i32 %tid, 10 33// br i1 %cond, label %then, label %else 34// then: 35// br label %merge 36// else: 37// br label %merge 38// merge: 39// %a = phi i32 [ 0, %then ], [ 1, %else ] 40// 41// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid 42// because %tid is not on its use-def chains, %a is sync dependent on %tid 43// because the branch "br i1 %cond" depends on %tid and affects which value %a 44// is assigned to. 45// 46// 47// -- Reduction to SSA construction -- 48// There are two disjoint paths from A to X, if a certain variant of SSA 49// construction places a phi node in X under the following set-up scheme. 50// 51// This variant of SSA construction ignores incoming undef values. 
52// That is, paths from the entry without a definition do not result in 53// phi nodes. 54// 55// entry 56// / \ 57// A \ 58// / \ Y 59// B C / 60// \ / \ / 61// D E 62// \ / 63// F 64// 65// Assume that A contains a divergent branch. We are interested 66// in the set of all blocks where each block is reachable from A 67// via two disjoint paths. This would be the set {D, F} in this 68// case. 69// To generally reduce this query to SSA construction we introduce 70// a virtual variable x and assign to x different values in each 71// successor block of A. 72// 73// entry 74// / \ 75// A \ 76// / \ Y 77// x = 0 x = 1 / 78// \ / \ / 79// D E 80// \ / 81// F 82// 83// Our flavor of SSA construction for x will construct the following 84// 85// entry 86// / \ 87// A \ 88// / \ Y 89// x0 = 0 x1 = 1 / 90// \ / \ / 91// x2 = phi E 92// \ / 93// x3 = phi 94// 95// The blocks D and F contain phi nodes and are thus each reachable 96// by two disjoint paths from A. 97// 98// -- Remarks -- 99// * In case of loop exits we need to check the disjoint path criterion for loops. 100// To this end, we check whether the definition of x differs between the 101// loop exit and the loop header (_after_ SSA construction). 102// 103// -- Known Limitations & Future Work -- 104// * The algorithm requires reducible loops because the implementation 105// implicitly performs a single iteration of the underlying data flow analysis. 106// This was done for pragmatism, simplicity and speed. 107// 108// Relevant related work for extending the algorithm to irreducible control: 109// A simple algorithm for global data flow analysis problems. 110// Matthew S. Hecht and Jeffrey D. Ullman. 111// SIAM Journal on Computing, 4(4):519-532, December 1975. 
112// 113// * Another reason for requiring reducible loops is that points of 114// synchronization in irreducible loops aren't 'obvious' - there is no unique 115// header where threads 'should' synchronize when entering or coming back 116// around from the latch. 117// 118//===----------------------------------------------------------------------===// 119 120#include "llvm/Analysis/SyncDependenceAnalysis.h" 121#include "llvm/ADT/SmallPtrSet.h" 122#include "llvm/Analysis/LoopInfo.h" 123#include "llvm/IR/BasicBlock.h" 124#include "llvm/IR/CFG.h" 125#include "llvm/IR/Dominators.h" 126#include "llvm/IR/Function.h" 127 128#include <functional> 129 130#define DEBUG_TYPE "sync-dependence" 131 132// The SDA algorithm operates on a modified CFG - we modify the edges leaving 133// loop headers as follows: 134// 135// * We remove all edges leaving all loop headers. 136// * We add additional edges from the loop headers to their exit blocks. 137// 138// The modification is virtual, that is whenever we visit a loop header we 139// pretend it had different successors. 140namespace { 141using namespace llvm; 142 143// Custom Post-Order Traveral 144// 145// We cannot use the vanilla (R)PO computation of LLVM because: 146// * We (virtually) modify the CFG. 147// * We want a loop-compact block enumeration, that is the numbers assigned to 148// blocks of a loop form an interval 149// 150using POCB = std::function<void(const BasicBlock &)>; 151using VisitedSet = std::set<const BasicBlock *>; 152using BlockStack = std::vector<const BasicBlock *>; 153 154// forward 155static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, 156 VisitedSet &Finalized); 157 158// for a nested region (top-level loop or nested loop) 159static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop, 160 POCB CallBack, VisitedSet &Finalized) { 161 const auto *LoopHeader = Loop ? 
Loop->getHeader() : nullptr; 162 while (!Stack.empty()) { 163 const auto *NextBB = Stack.back(); 164 165 auto *NestedLoop = LI.getLoopFor(NextBB); 166 bool IsNestedLoop = NestedLoop != Loop; 167 168 // Treat the loop as a node 169 if (IsNestedLoop) { 170 SmallVector<BasicBlock *, 3> NestedExits; 171 NestedLoop->getUniqueExitBlocks(NestedExits); 172 bool PushedNodes = false; 173 for (const auto *NestedExitBB : NestedExits) { 174 if (NestedExitBB == LoopHeader) 175 continue; 176 if (Loop && !Loop->contains(NestedExitBB)) 177 continue; 178 if (Finalized.count(NestedExitBB)) 179 continue; 180 PushedNodes = true; 181 Stack.push_back(NestedExitBB); 182 } 183 if (!PushedNodes) { 184 // All loop exits finalized -> finish this node 185 Stack.pop_back(); 186 computeLoopPO(LI, *NestedLoop, CallBack, Finalized); 187 } 188 continue; 189 } 190 191 // DAG-style 192 bool PushedNodes = false; 193 for (const auto *SuccBB : successors(NextBB)) { 194 if (SuccBB == LoopHeader) 195 continue; 196 if (Loop && !Loop->contains(SuccBB)) 197 continue; 198 if (Finalized.count(SuccBB)) 199 continue; 200 PushedNodes = true; 201 Stack.push_back(SuccBB); 202 } 203 if (!PushedNodes) { 204 // Never push nodes twice 205 Stack.pop_back(); 206 if (!Finalized.insert(NextBB).second) 207 continue; 208 CallBack(*NextBB); 209 } 210 } 211} 212 213static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) { 214 VisitedSet Finalized; 215 BlockStack Stack; 216 Stack.reserve(24); // FIXME made-up number 217 Stack.push_back(&F.getEntryBlock()); 218 computeStackPO(Stack, LI, nullptr, CallBack, Finalized); 219} 220 221static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, 222 VisitedSet &Finalized) { 223 /// Call CallBack on all loop blocks. 
224 std::vector<const BasicBlock *> Stack; 225 const auto *LoopHeader = Loop.getHeader(); 226 227 // Visit the header last 228 Finalized.insert(LoopHeader); 229 CallBack(*LoopHeader); 230 231 // Initialize with immediate successors 232 for (const auto *BB : successors(LoopHeader)) { 233 if (!Loop.contains(BB)) 234 continue; 235 if (BB == LoopHeader) 236 continue; 237 Stack.push_back(BB); 238 } 239 240 // Compute PO inside region 241 computeStackPO(Stack, LI, &Loop, CallBack, Finalized); 242} 243 244} // namespace 245 246namespace llvm { 247 248ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc; 249 250SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, 251 const PostDominatorTree &PDT, 252 const LoopInfo &LI) 253 : DT(DT), PDT(PDT), LI(LI) { 254 computeTopLevelPO(*DT.getRoot()->getParent(), LI, 255 [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); }); 256} 257 258SyncDependenceAnalysis::~SyncDependenceAnalysis() = default; 259 260namespace { 261// divergence propagator for reducible CFGs 262struct DivergencePropagator { 263 const ModifiedPO &LoopPOT; 264 const DominatorTree &DT; 265 const PostDominatorTree &PDT; 266 const LoopInfo &LI; 267 const BasicBlock &DivTermBlock; 268 269 // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at 270 // block B 271 // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet 272 // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths 273 // from X or B is an immediate successor of X (initial value). 274 using BlockLabelVec = std::vector<const BasicBlock *>; 275 BlockLabelVec BlockLabels; 276 // divergent join and loop exit descriptor. 
277 std::unique_ptr<ControlDivergenceDesc> DivDesc; 278 279 DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT, 280 const PostDominatorTree &PDT, const LoopInfo &LI, 281 const BasicBlock &DivTermBlock) 282 : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock), 283 BlockLabels(LoopPOT.size(), nullptr), 284 DivDesc(new ControlDivergenceDesc) {} 285 286 void printDefs(raw_ostream &Out) { 287 Out << "Propagator::BlockLabels {\n"; 288 for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) { 289 const auto *Label = BlockLabels[BlockIdx]; 290 Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx 291 << ") : "; 292 if (!Label) { 293 Out << "<null>\n"; 294 } else { 295 Out << Label->getName() << "\n"; 296 } 297 } 298 Out << "}\n"; 299 } 300 301 // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this 302 // causes a divergent join. 303 bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) { 304 auto SuccIdx = LoopPOT.getIndexOf(SuccBlock); 305 306 // unset or same reaching label 307 const auto *OldLabel = BlockLabels[SuccIdx]; 308 if (!OldLabel || (OldLabel == &PushedLabel)) { 309 BlockLabels[SuccIdx] = &PushedLabel; 310 return false; 311 } 312 313 // Update the definition 314 BlockLabels[SuccIdx] = &SuccBlock; 315 return true; 316 } 317 318 // visiting a virtual loop exit edge from the loop header --> temporal 319 // divergence on join 320 bool visitLoopExitEdge(const BasicBlock &ExitBlock, 321 const BasicBlock &DefBlock, bool FromParentLoop) { 322 // Pushing from a non-parent loop cannot cause temporal divergence. 
323 if (!FromParentLoop) 324 return visitEdge(ExitBlock, DefBlock); 325 326 if (!computeJoin(ExitBlock, DefBlock)) 327 return false; 328 329 // Identified a divergent loop exit 330 DivDesc->LoopDivBlocks.insert(&ExitBlock); 331 LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName() 332 << "\n"); 333 return true; 334 } 335 336 // process \p SuccBlock with reaching definition \p DefBlock 337 bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) { 338 if (!computeJoin(SuccBlock, DefBlock)) 339 return false; 340 341 // Divergent, disjoint paths join. 342 DivDesc->JoinDivBlocks.insert(&SuccBlock); 343 LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName()); 344 return true; 345 } 346 347 std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() { 348 assert(DivDesc); 349 350 LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName() 351 << "\n"); 352 353 const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock); 354 355 // Early stopping criterion 356 int FloorIdx = LoopPOT.size() - 1; 357 const BasicBlock *FloorLabel = nullptr; 358 359 // bootstrap with branch targets 360 int BlockIdx = 0; 361 362 for (const auto *SuccBlock : successors(&DivTermBlock)) { 363 auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock); 364 BlockLabels[SuccIdx] = SuccBlock; 365 366 // Find the successor with the highest index to start with 367 BlockIdx = std::max<int>(BlockIdx, SuccIdx); 368 FloorIdx = std::min<int>(FloorIdx, SuccIdx); 369 370 // Identify immediate divergent loop exits 371 if (!DivBlockLoop) 372 continue; 373 374 const auto *BlockLoop = LI.getLoopFor(SuccBlock); 375 if (BlockLoop && DivBlockLoop->contains(BlockLoop)) 376 continue; 377 DivDesc->LoopDivBlocks.insert(SuccBlock); 378 LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: " 379 << SuccBlock->getName() << "\n"); 380 } 381 382 // propagate definitions at the immediate successors of the node in RPO 383 for (; BlockIdx >= FloorIdx; --BlockIdx) { 384 LLVM_DEBUG(dbgs() << 
"Before next visit:\n"; printDefs(dbgs())); 385 386 // Any label available here 387 const auto *Label = BlockLabels[BlockIdx]; 388 if (!Label) 389 continue; 390 391 // Ok. Get the block 392 const auto *Block = LoopPOT.getBlockAt(BlockIdx); 393 LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n"); 394 395 auto *BlockLoop = LI.getLoopFor(Block); 396 bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block; 397 bool CausedJoin = false; 398 int LoweredFloorIdx = FloorIdx; 399 if (IsLoopHeader) { 400 // Disconnect from immediate successors and propagate directly to loop 401 // exits. 402 SmallVector<BasicBlock *, 4> BlockLoopExits; 403 BlockLoop->getExitBlocks(BlockLoopExits); 404 405 bool IsParentLoop = BlockLoop->contains(&DivTermBlock); 406 for (const auto *BlockLoopExit : BlockLoopExits) { 407 CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop); 408 LoweredFloorIdx = std::min<int>(LoweredFloorIdx, 409 LoopPOT.getIndexOf(*BlockLoopExit)); 410 } 411 } else { 412 // Acyclic successor case 413 for (const auto *SuccBlock : successors(Block)) { 414 CausedJoin |= visitEdge(*SuccBlock, *Label); 415 LoweredFloorIdx = 416 std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock)); 417 } 418 } 419 420 // Floor update 421 if (CausedJoin) { 422 // 1. Different labels pushed to successors 423 FloorIdx = LoweredFloorIdx; 424 } else if (FloorLabel != Label) { 425 // 2. No join caused BUT we pushed a label that is different than the 426 // last pushed label 427 FloorIdx = LoweredFloorIdx; 428 FloorLabel = Label; 429 } 430 } 431 432 LLVM_DEBUG(dbgs() << "SDA::joins. 
After propagation:\n"; printDefs(dbgs())); 433 434 return std::move(DivDesc); 435 } 436}; 437} // end anonymous namespace 438 439#ifndef NDEBUG 440static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) { 441 Out << "["; 442 ListSeparator LS; 443 for (const auto *BB : Blocks) 444 Out << LS << BB->getName(); 445 Out << "]"; 446} 447#endif 448 449const ControlDivergenceDesc & 450SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) { 451 // trivial case 452 if (Term.getNumSuccessors() <= 1) { 453 return EmptyDivergenceDesc; 454 } 455 456 // already available in cache? 457 auto ItCached = CachedControlDivDescs.find(&Term); 458 if (ItCached != CachedControlDivDescs.end()) 459 return *ItCached->second; 460 461 // compute all join points 462 // Special handling of divergent loop exits is not needed for LCSSA 463 const auto &TermBlock = *Term.getParent(); 464 DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock); 465 auto DivDesc = Propagator.computeJoinPoints(); 466 467 LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n"; 468 dbgs() << "JoinDivBlocks: "; 469 printBlockSet(DivDesc->JoinDivBlocks, dbgs()); 470 dbgs() << "\nLoopDivBlocks: "; 471 printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";); 472 473 auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc)); 474 assert(ItInserted.second); 475 return *ItInserted.first->second; 476} 477 478} // namespace llvm 479