//===-- CFGMST.h - Minimum Spanning Tree for CFG ----------------*- C++ -*-===//
//
//                      The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file implements a union-find algorithm that computes a minimum
// spanning tree for a given CFG.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

28292915Sdimnamespace llvm {
29292915Sdim
30292915Sdim#define DEBUG_TYPE "cfgmst"
31292915Sdim
32292915Sdim/// \brief An union-find based Minimum Spanning Tree for CFG
33292915Sdim///
34292915Sdim/// Implements a Union-find algorithm to compute Minimum Spanning Tree
35292915Sdim/// for a given CFG.
36292915Sdimtemplate <class Edge, class BBInfo> class CFGMST {
37292915Sdimpublic:
38292915Sdim  Function &F;
39292915Sdim
40292915Sdim  // Store all the edges in CFG. It may contain some stale edges
41292915Sdim  // when Removed is set.
42292915Sdim  std::vector<std::unique_ptr<Edge>> AllEdges;
43292915Sdim
44292915Sdim  // This map records the auxiliary information for each BB.
45292915Sdim  DenseMap<const BasicBlock *, std::unique_ptr<BBInfo>> BBInfos;
46292915Sdim
47292915Sdim  // Find the root group of the G and compress the path from G to the root.
48292915Sdim  BBInfo *findAndCompressGroup(BBInfo *G) {
49292915Sdim    if (G->Group != G)
50292915Sdim      G->Group = findAndCompressGroup(static_cast<BBInfo *>(G->Group));
51292915Sdim    return static_cast<BBInfo *>(G->Group);
52292915Sdim  }
53292915Sdim
54292915Sdim  // Union BB1 and BB2 into the same group and return true.
55292915Sdim  // Returns false if BB1 and BB2 are already in the same group.
56292915Sdim  bool unionGroups(const BasicBlock *BB1, const BasicBlock *BB2) {
57292915Sdim    BBInfo *BB1G = findAndCompressGroup(&getBBInfo(BB1));
58292915Sdim    BBInfo *BB2G = findAndCompressGroup(&getBBInfo(BB2));
59292915Sdim
60292915Sdim    if (BB1G == BB2G)
61292915Sdim      return false;
62292915Sdim
63292915Sdim    // Make the smaller rank tree a direct child or the root of high rank tree.
64292915Sdim    if (BB1G->Rank < BB2G->Rank)
65292915Sdim      BB1G->Group = BB2G;
66292915Sdim    else {
67292915Sdim      BB2G->Group = BB1G;
68292915Sdim      // If the ranks are the same, increment root of one tree by one.
69292915Sdim      if (BB1G->Rank == BB2G->Rank)
70292915Sdim        BB1G->Rank++;
71292915Sdim    }
72292915Sdim    return true;
73292915Sdim  }
74292915Sdim
75292915Sdim  // Give BB, return the auxiliary information.
76292915Sdim  BBInfo &getBBInfo(const BasicBlock *BB) const {
77292915Sdim    auto It = BBInfos.find(BB);
78292915Sdim    assert(It->second.get() != nullptr);
79292915Sdim    return *It->second.get();
80292915Sdim  }
81292915Sdim
82292915Sdim  // Traverse the CFG using a stack. Find all the edges and assign the weight.
83292915Sdim  // Edges with large weight will be put into MST first so they are less likely
84292915Sdim  // to be instrumented.
85292915Sdim  void buildEdges() {
86292915Sdim    DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n");
87292915Sdim
88292915Sdim    const BasicBlock *BB = &(F.getEntryBlock());
89292915Sdim    uint64_t EntryWeight = (BFI != nullptr ? BFI->getEntryFreq() : 2);
90292915Sdim    // Add a fake edge to the entry.
91292915Sdim    addEdge(nullptr, BB, EntryWeight);
92292915Sdim
93292915Sdim    // Special handling for single BB functions.
94292915Sdim    if (succ_empty(BB)) {
95292915Sdim      addEdge(BB, nullptr, EntryWeight);
96292915Sdim      return;
97292915Sdim    }
98292915Sdim
99292915Sdim    static const uint32_t CriticalEdgeMultiplier = 1000;
100292915Sdim
101292915Sdim    for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) {
102292915Sdim      TerminatorInst *TI = BB->getTerminator();
103292915Sdim      uint64_t BBWeight =
104292915Sdim          (BFI != nullptr ? BFI->getBlockFreq(&*BB).getFrequency() : 2);
105292915Sdim      uint64_t Weight = 2;
106292915Sdim      if (int successors = TI->getNumSuccessors()) {
107292915Sdim        for (int i = 0; i != successors; ++i) {
108292915Sdim          BasicBlock *TargetBB = TI->getSuccessor(i);
109292915Sdim          bool Critical = isCriticalEdge(TI, i);
110292915Sdim          uint64_t scaleFactor = BBWeight;
111292915Sdim          if (Critical) {
112292915Sdim            if (scaleFactor < UINT64_MAX / CriticalEdgeMultiplier)
113292915Sdim              scaleFactor *= CriticalEdgeMultiplier;
114292915Sdim            else
115292915Sdim              scaleFactor = UINT64_MAX;
116292915Sdim          }
117292915Sdim          if (BPI != nullptr)
118292915Sdim            Weight = BPI->getEdgeProbability(&*BB, TargetBB).scale(scaleFactor);
119292915Sdim          addEdge(&*BB, TargetBB, Weight).IsCritical = Critical;
120292915Sdim          DEBUG(dbgs() << "  Edge: from " << BB->getName() << " to "
121292915Sdim                       << TargetBB->getName() << "  w=" << Weight << "\n");
122292915Sdim        }
123292915Sdim      } else {
124292915Sdim        addEdge(&*BB, nullptr, BBWeight);
125292915Sdim        DEBUG(dbgs() << "  Edge: from " << BB->getName() << " to exit"
126292915Sdim                     << " w = " << BBWeight << "\n");
127292915Sdim      }
128292915Sdim    }
129292915Sdim  }
130292915Sdim
131292915Sdim  // Sort CFG edges based on its weight.
132292915Sdim  void sortEdgesByWeight() {
133292915Sdim    std::stable_sort(AllEdges.begin(), AllEdges.end(),
134292915Sdim                     [](const std::unique_ptr<Edge> &Edge1,
135292915Sdim                        const std::unique_ptr<Edge> &Edge2) {
136292915Sdim                       return Edge1->Weight > Edge2->Weight;
137292915Sdim                     });
138292915Sdim  }
139292915Sdim
140292915Sdim  // Traverse all the edges and compute the Minimum Weight Spanning Tree
141292915Sdim  // using union-find algorithm.
142292915Sdim  void computeMinimumSpanningTree() {
143292915Sdim    // First, put all the critical edge with landing-pad as the Dest to MST.
144292915Sdim    // This works around the insufficient support of critical edges split
145292915Sdim    // when destination BB is a landing pad.
146292915Sdim    for (auto &Ei : AllEdges) {
147292915Sdim      if (Ei->Removed)
148292915Sdim        continue;
149292915Sdim      if (Ei->IsCritical) {
150292915Sdim        if (Ei->DestBB && Ei->DestBB->isLandingPad()) {
151292915Sdim          if (unionGroups(Ei->SrcBB, Ei->DestBB))
152292915Sdim            Ei->InMST = true;
153292915Sdim        }
154292915Sdim      }
155292915Sdim    }
156292915Sdim
157292915Sdim    for (auto &Ei : AllEdges) {
158292915Sdim      if (Ei->Removed)
159292915Sdim        continue;
160292915Sdim      if (unionGroups(Ei->SrcBB, Ei->DestBB))
161292915Sdim        Ei->InMST = true;
162292915Sdim    }
163292915Sdim  }
164292915Sdim
165292915Sdim  // Dump the Debug information about the instrumentation.
166292915Sdim  void dumpEdges(raw_ostream &OS, const Twine &Message) const {
167292915Sdim    if (!Message.str().empty())
168292915Sdim      OS << Message << "\n";
169292915Sdim    OS << "  Number of Basic Blocks: " << BBInfos.size() << "\n";
170292915Sdim    for (auto &BI : BBInfos) {
171292915Sdim      const BasicBlock *BB = BI.first;
172292915Sdim      OS << "  BB: " << (BB == nullptr ? "FakeNode" : BB->getName()) << "  "
173292915Sdim         << BI.second->infoString() << "\n";
174292915Sdim    }
175292915Sdim
176292915Sdim    OS << "  Number of Edges: " << AllEdges.size()
177292915Sdim       << " (*: Instrument, C: CriticalEdge, -: Removed)\n";
178292915Sdim    uint32_t Count = 0;
179292915Sdim    for (auto &EI : AllEdges)
180292915Sdim      OS << "  Edge " << Count++ << ": " << getBBInfo(EI->SrcBB).Index << "-->"
181292915Sdim         << getBBInfo(EI->DestBB).Index << EI->infoString() << "\n";
182292915Sdim  }
183292915Sdim
184292915Sdim  // Add an edge to AllEdges with weight W.
185292915Sdim  Edge &addEdge(const BasicBlock *Src, const BasicBlock *Dest, uint64_t W) {
186292915Sdim    uint32_t Index = BBInfos.size();
187292915Sdim    auto Iter = BBInfos.end();
188292915Sdim    bool Inserted;
189292915Sdim    std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Src, nullptr));
190292915Sdim    if (Inserted) {
191292915Sdim      // Newly inserted, update the real info.
192292915Sdim      Iter->second = std::move(llvm::make_unique<BBInfo>(Index));
193292915Sdim      Index++;
194292915Sdim    }
195292915Sdim    std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Dest, nullptr));
196292915Sdim    if (Inserted)
197292915Sdim      // Newly inserted, update the real info.
198292915Sdim      Iter->second = std::move(llvm::make_unique<BBInfo>(Index));
199292915Sdim    AllEdges.emplace_back(new Edge(Src, Dest, W));
200292915Sdim    return *AllEdges.back();
201292915Sdim  }
202292915Sdim
203292915Sdim  BranchProbabilityInfo *BPI;
204292915Sdim  BlockFrequencyInfo *BFI;
205292915Sdim
206292915Sdimpublic:
207292915Sdim  CFGMST(Function &Func, BranchProbabilityInfo *BPI_ = nullptr,
208292915Sdim         BlockFrequencyInfo *BFI_ = nullptr)
209292915Sdim      : F(Func), BPI(BPI_), BFI(BFI_) {
210292915Sdim    buildEdges();
211292915Sdim    sortEdgesByWeight();
212292915Sdim    computeMinimumSpanningTree();
213292915Sdim  }
214292915Sdim};
215292915Sdim
216292915Sdim#undef DEBUG_TYPE // "cfgmst"
217292915Sdim} // end namespace llvm
218