120469Smpp//===-- CFGPrinter.h - CFG printer external interface -----------*- C++ -*-===//
220469Smpp//
320469Smpp// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
420469Smpp// See https://llvm.org/LICENSE.txt for license information.
520469Smpp// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
620469Smpp//
720469Smpp//===----------------------------------------------------------------------===//
820469Smpp//
920469Smpp// This file defines a 'dot-cfg' analysis pass, which emits the
1020469Smpp// cfg.<fnname>.dot file for each function in the program, with a graph of the
1120469Smpp// CFG for that function.
1220469Smpp//
1320469Smpp// This file defines external functions that can be called to explicitly
1420469Smpp// instantiate the CFG printer.
1520469Smpp//
1620469Smpp//===----------------------------------------------------------------------===//
1720469Smpp
1820469Smpp#ifndef LLVM_ANALYSIS_CFGPRINTER_H
1920469Smpp#define LLVM_ANALYSIS_CFGPRINTER_H
2020469Smpp
2120469Smpp#include "llvm/ADT/STLExtras.h"
2220469Smpp#include "llvm/Analysis/BlockFrequencyInfo.h"
2320469Smpp#include "llvm/Analysis/BranchProbabilityInfo.h"
2420469Smpp#include "llvm/Analysis/HeatUtils.h"
2520469Smpp#include "llvm/IR/CFG.h"
2620469Smpp#include "llvm/IR/Constants.h"
2720469Smpp#include "llvm/IR/Function.h"
2820469Smpp#include "llvm/IR/Instructions.h"
2920469Smpp#include "llvm/IR/PassManager.h"
3020469Smpp#include "llvm/Support/FormatVariadic.h"
3120469Smpp#include "llvm/Support/GraphWriter.h"
3220469Smpp
3320469Smppnamespace llvm {
3420469Smppclass CFGViewerPass : public PassInfoMixin<CFGViewerPass> {
3520469Smpppublic:
3620469Smpp  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
3720469Smpp};
3823466Sjmg
3950476Speterclass CFGOnlyViewerPass : public PassInfoMixin<CFGOnlyViewerPass> {
4023466Sjmgpublic:
41212827Sgjb  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
4277042Sru};
4377042Sru
4420469Smppclass CFGPrinterPass : public PassInfoMixin<CFGPrinterPass> {
4577042Srupublic:
46107788Sru  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
4720469Smpp};
48131479Sru
4977042Sruclass CFGOnlyPrinterPass : public PassInfoMixin<CFGOnlyPrinterPass> {
50131479Srupublic:
5120469Smpp  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
52107788Sru};
5368962Sru
5420469Smppclass DOTFuncInfo {
55107788Sruprivate:
5620469Smpp  const Function *F;
5760407Schris  const BlockFrequencyInfo *BFI;
5820469Smpp  const BranchProbabilityInfo *BPI;
59107788Sru  uint64_t MaxFreq;
6020469Smpp  bool ShowHeat;
6120469Smpp  bool EdgeWeights;
6220469Smpp  bool RawWeights;
6320469Smpp
6420469Smpppublic:
6520469Smpp  DOTFuncInfo(const Function *F) : DOTFuncInfo(F, nullptr, nullptr, 0) {}
6620469Smpp
6720469Smpp  DOTFuncInfo(const Function *F, const BlockFrequencyInfo *BFI,
6820469Smpp              const BranchProbabilityInfo *BPI, uint64_t MaxFreq)
6920469Smpp      : F(F), BFI(BFI), BPI(BPI), MaxFreq(MaxFreq) {
7020469Smpp    ShowHeat = false;
7120469Smpp    EdgeWeights = !!BPI; // Print EdgeWeights when BPI is available.
7220469Smpp    RawWeights = !!BFI;  // Print RawWeights when BFI is available.
7320469Smpp  }
7420469Smpp
7520469Smpp  const BlockFrequencyInfo *getBFI() { return BFI; }
7620469Smpp
7720469Smpp  const BranchProbabilityInfo *getBPI() { return BPI; }
7820469Smpp
7920469Smpp  const Function *getFunction() { return this->F; }
8020469Smpp
8120469Smpp  uint64_t getMaxFreq() { return MaxFreq; }
8220469Smpp
8320469Smpp  uint64_t getFreq(const BasicBlock *BB) {
8420469Smpp    return BFI->getBlockFreq(BB).getFrequency();
8579727Sschweikh  }
8620469Smpp
8720469Smpp  void setHeatColors(bool ShowHeat) { this->ShowHeat = ShowHeat; }
8820469Smpp
8920469Smpp  bool showHeatColors() { return ShowHeat; }
90212827Sgjb
91212827Sgjb  void setRawEdgeWeights(bool RawWeights) { this->RawWeights = RawWeights; }
92212827Sgjb
93212827Sgjb  bool useRawEdgeWeights() { return RawWeights; }
94212827Sgjb
95212827Sgjb  void setEdgeWeights(bool EdgeWeights) { this->EdgeWeights = EdgeWeights; }
96212827Sgjb
97212827Sgjb  bool showEdgeWeights() { return EdgeWeights; }
98212827Sgjb};
99208028Suqs
100208028Suqstemplate <>
101208028Suqsstruct GraphTraits<DOTFuncInfo *> : public GraphTraits<const BasicBlock *> {
102208028Suqs  static NodeRef getEntryNode(DOTFuncInfo *CFGInfo) {
103164464Srodrigc    return &(CFGInfo->getFunction()->getEntryBlock());
104237216Seadler  }
105164464Srodrigc
106164464Srodrigc  // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
107164464Srodrigc  using nodes_iterator = pointer_iterator<Function::const_iterator>;
108164464Srodrigc
109164464Srodrigc  static nodes_iterator nodes_begin(DOTFuncInfo *CFGInfo) {
11020469Smpp    return nodes_iterator(CFGInfo->getFunction()->begin());
111212827Sgjb  }
112164464Srodrigc
11320469Smpp  static nodes_iterator nodes_end(DOTFuncInfo *CFGInfo) {
11420469Smpp    return nodes_iterator(CFGInfo->getFunction()->end());
11579727Sschweikh  }
116107788Sru
11749831Smpp  static size_t size(DOTFuncInfo *CFGInfo) {
11879727Sschweikh    return CFGInfo->getFunction()->size();
11920469Smpp  }
12079727Sschweikh};
12120469Smpp
12234504Scharniertemplate <>
12369027Srustruct DOTGraphTraits<DOTFuncInfo *> : public DefaultDOTGraphTraits {
12479727Sschweikh
12520469Smpp  // Cache for is hidden property
12634504Scharnier  llvm::DenseMap<const BasicBlock *, bool> isHiddenBasicBlock;
12734504Scharnier
12879727Sschweikh  DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
12934504Scharnier
13034504Scharnier  static std::string getGraphName(DOTFuncInfo *CFGInfo) {
131    return "CFG for '" + CFGInfo->getFunction()->getName().str() + "' function";
132  }
133
134  static std::string getSimpleNodeLabel(const BasicBlock *Node, DOTFuncInfo *) {
135    if (!Node->getName().empty())
136      return Node->getName().str();
137
138    std::string Str;
139    raw_string_ostream OS(Str);
140
141    Node->printAsOperand(OS, false);
142    return OS.str();
143  }
144
145  static void eraseComment(std::string &OutStr, unsigned &I, unsigned Idx) {
146    OutStr.erase(OutStr.begin() + I, OutStr.begin() + Idx);
147    --I;
148  }
149
150  static std::string getCompleteNodeLabel(
151      const BasicBlock *Node, DOTFuncInfo *,
152      llvm::function_ref<void(raw_string_ostream &, const BasicBlock &)>
153          HandleBasicBlock = [](raw_string_ostream &OS,
154                                const BasicBlock &Node) -> void { OS << Node; },
155      llvm::function_ref<void(std::string &, unsigned &, unsigned)>
156          HandleComment = eraseComment) {
157    enum { MaxColumns = 80 };
158    std::string Str;
159    raw_string_ostream OS(Str);
160
161    if (Node->getName().empty()) {
162      Node->printAsOperand(OS, false);
163      OS << ":";
164    }
165
166    HandleBasicBlock(OS, *Node);
167    std::string OutStr = OS.str();
168    if (OutStr[0] == '\n')
169      OutStr.erase(OutStr.begin());
170
171    // Process string output to make it nicer...
172    unsigned ColNum = 0;
173    unsigned LastSpace = 0;
174    for (unsigned i = 0; i != OutStr.length(); ++i) {
175      if (OutStr[i] == '\n') { // Left justify
176        OutStr[i] = '\\';
177        OutStr.insert(OutStr.begin() + i + 1, 'l');
178        ColNum = 0;
179        LastSpace = 0;
180      } else if (OutStr[i] == ';') {             // Delete comments!
181        unsigned Idx = OutStr.find('\n', i + 1); // Find end of line
182        HandleComment(OutStr, i, Idx);
183      } else if (ColNum == MaxColumns) { // Wrap lines.
184        // Wrap very long names even though we can't find a space.
185        if (!LastSpace)
186          LastSpace = i;
187        OutStr.insert(LastSpace, "\\l...");
188        ColNum = i - LastSpace;
189        LastSpace = 0;
190        i += 3; // The loop will advance 'i' again.
191      } else
192        ++ColNum;
193      if (OutStr[i] == ' ')
194        LastSpace = i;
195    }
196    return OutStr;
197  }
198
199  std::string getNodeLabel(const BasicBlock *Node, DOTFuncInfo *CFGInfo) {
200
201    if (isSimple())
202      return getSimpleNodeLabel(Node, CFGInfo);
203    else
204      return getCompleteNodeLabel(Node, CFGInfo);
205  }
206
207  static std::string getEdgeSourceLabel(const BasicBlock *Node,
208                                        const_succ_iterator I) {
209    // Label source of conditional branches with "T" or "F"
210    if (const BranchInst *BI = dyn_cast<BranchInst>(Node->getTerminator()))
211      if (BI->isConditional())
212        return (I == succ_begin(Node)) ? "T" : "F";
213
214    // Label source of switch edges with the associated value.
215    if (const SwitchInst *SI = dyn_cast<SwitchInst>(Node->getTerminator())) {
216      unsigned SuccNo = I.getSuccessorIndex();
217
218      if (SuccNo == 0)
219        return "def";
220
221      std::string Str;
222      raw_string_ostream OS(Str);
223      auto Case = *SwitchInst::ConstCaseIt::fromSuccessorIndex(SI, SuccNo);
224      OS << Case.getCaseValue()->getValue();
225      return OS.str();
226    }
227    return "";
228  }
229
230  /// Display the raw branch weights from PGO.
231  std::string getEdgeAttributes(const BasicBlock *Node, const_succ_iterator I,
232                                DOTFuncInfo *CFGInfo) {
233    if (!CFGInfo->showEdgeWeights())
234      return "";
235
236    const Instruction *TI = Node->getTerminator();
237    if (TI->getNumSuccessors() == 1)
238      return "penwidth=2";
239
240    unsigned OpNo = I.getSuccessorIndex();
241
242    if (OpNo >= TI->getNumSuccessors())
243      return "";
244
245    BasicBlock *SuccBB = TI->getSuccessor(OpNo);
246    auto BranchProb = CFGInfo->getBPI()->getEdgeProbability(Node, SuccBB);
247    double WeightPercent = ((double)BranchProb.getNumerator()) /
248                           ((double)BranchProb.getDenominator());
249    double Width = 1 + WeightPercent;
250
251    if (!CFGInfo->useRawEdgeWeights())
252      return formatv("label=\"{0:P}\" penwidth={1}", WeightPercent, Width)
253          .str();
254
255    // Prepend a 'W' to indicate that this is a weight rather than the actual
256    // profile count (due to scaling).
257
258    uint64_t Freq = CFGInfo->getFreq(Node);
259    std::string Attrs = formatv("label=\"W:{0}\" penwidth={1}",
260                                (uint64_t)(Freq * WeightPercent), Width);
261    if (Attrs.size())
262      return Attrs;
263
264    MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
265    if (!WeightsNode)
266      return "";
267
268    MDString *MDName = cast<MDString>(WeightsNode->getOperand(0));
269    if (MDName->getString() != "branch_weights")
270      return "";
271
272    OpNo = I.getSuccessorIndex() + 1;
273    if (OpNo >= WeightsNode->getNumOperands())
274      return "";
275    ConstantInt *Weight =
276        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(OpNo));
277    if (!Weight)
278      return "";
279    return ("label=\"W:" + std::to_string(Weight->getZExtValue()) +
280            "\" penwidth=" + std::to_string(Width));
281  }
282
283  std::string getNodeAttributes(const BasicBlock *Node, DOTFuncInfo *CFGInfo) {
284
285    if (!CFGInfo->showHeatColors())
286      return "";
287
288    uint64_t Freq = CFGInfo->getFreq(Node);
289    std::string Color = getHeatColor(Freq, CFGInfo->getMaxFreq());
290    std::string EdgeColor = (Freq <= (CFGInfo->getMaxFreq() / 2))
291                                ? (getHeatColor(0))
292                                : (getHeatColor(1));
293
294    std::string Attrs = "color=\"" + EdgeColor + "ff\", style=filled," +
295                        " fillcolor=\"" + Color + "70\"";
296    return Attrs;
297  }
298  bool isNodeHidden(const BasicBlock *Node, const DOTFuncInfo *CFGInfo);
299  void computeHiddenNodes(const Function *F);
300};
301} // End llvm namespace
302
namespace llvm {
class FunctionPass;

// Factory functions for the legacy-pass-manager CFG printers; both are defined
// out of line. The "Only" variant presumably omits instruction bodies from the
// node labels (TODO confirm against the definitions).
FunctionPass *createCFGPrinterLegacyPassPass();
FunctionPass *createCFGOnlyPrinterLegacyPassPass();
} // namespace llvm
308
309#endif
310