167117Sdfr//===- CallPrinter.cpp - DOT printer for call graph -----------------------===//
267117Sdfr//
3106973Smarcel// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4134287Smarcel// See https://llvm.org/LICENSE.txt for license information.
5134287Smarcel// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6231981Skib//
7134287Smarcel//===----------------------------------------------------------------------===//
894569Smarcel//
// This file defines '-dot-callgraph', which emits a <module>.callgraph.dot
// file containing the call graph of a module.
1194569Smarcel//
12// There is also a pass available to directly call dotty ('-view-callgraph').
13//
14//===----------------------------------------------------------------------===//
15
16#include "llvm/Analysis/CallPrinter.h"
17#include "llvm/ADT/DenseMap.h"
18#include "llvm/ADT/SmallSet.h"
19#include "llvm/Analysis/BlockFrequencyInfo.h"
20#include "llvm/Analysis/CallGraph.h"
21#include "llvm/Analysis/HeatUtils.h"
22#include "llvm/IR/Instructions.h"
23#include "llvm/InitializePasses.h"
24#include "llvm/Support/CommandLine.h"
25#include "llvm/Support/DOTGraphTraits.h"
26#include "llvm/Support/GraphWriter.h"
27
28using namespace llvm;
29
namespace llvm {
// Forward declaration of the GraphTraits class template; an explicit
// specialization for CallGraphDOTInfo * is defined further down in this file.
template <class GraphType> struct GraphTraits;
}
33
34// This option shows static (relative) call counts.
35// FIXME:
36// Need to show real counts when profile data is available
37static cl::opt<bool> ShowHeatColors("callgraph-heat-colors", cl::init(false),
38                                    cl::Hidden,
39                                    cl::desc("Show heat colors in call-graph"));
40
41static cl::opt<bool>
42    ShowEdgeWeight("callgraph-show-weights", cl::init(false), cl::Hidden,
43                       cl::desc("Show edges labeled with weights"));
44
45static cl::opt<bool>
46    CallMultiGraph("callgraph-multigraph", cl::init(false), cl::Hidden,
47            cl::desc("Show call-multigraph (do not remove parallel edges)"));
48
49static cl::opt<std::string> CallGraphDotFilenamePrefix(
50    "callgraph-dot-filename-prefix", cl::Hidden,
51    cl::desc("The prefix used for the CallGraph dot file names."));
52
53namespace llvm {
54
55class CallGraphDOTInfo {
56private:
57  Module *M;
58  CallGraph *CG;
59  DenseMap<const Function *, uint64_t> Freq;
60  uint64_t MaxFreq;
61
62public:
63  std::function<BlockFrequencyInfo *(Function &)> LookupBFI;
64
65  CallGraphDOTInfo(Module *M, CallGraph *CG,
66                   function_ref<BlockFrequencyInfo *(Function &)> LookupBFI)
67      : M(M), CG(CG), LookupBFI(LookupBFI) {
68    MaxFreq = 0;
69
70    for (Function &F : M->getFunctionList()) {
71      uint64_t localSumFreq = 0;
72      SmallSet<Function *, 16> Callers;
73      for (User *U : F.users())
74        if (isa<CallInst>(U))
75          Callers.insert(cast<Instruction>(U)->getFunction());
76      for (Function *Caller : Callers)
77        localSumFreq += getNumOfCalls(*Caller, F);
78      if (localSumFreq >= MaxFreq)
79        MaxFreq = localSumFreq;
80      Freq[&F] = localSumFreq;
81    }
82    if (!CallMultiGraph)
83      removeParallelEdges();
84  }
85
86  Module *getModule() const { return M; }
87
88  CallGraph *getCallGraph() const { return CG; }
89
90  uint64_t getFreq(const Function *F) { return Freq[F]; }
91
92  uint64_t getMaxFreq() { return MaxFreq; }
93
94private:
95  void removeParallelEdges() {
96    for (auto &I : (*CG)) {
97      CallGraphNode *Node = I.second.get();
98
99      bool FoundParallelEdge = true;
100      while (FoundParallelEdge) {
101        SmallSet<Function *, 16> Visited;
102        FoundParallelEdge = false;
103        for (auto CI = Node->begin(), CE = Node->end(); CI != CE; CI++) {
104          if (!(Visited.insert(CI->second->getFunction())).second) {
105            FoundParallelEdge = true;
106            Node->removeCallEdge(CI);
107            break;
108          }
109        }
110      }
111    }
112  }
113};
114
115template <>
116struct GraphTraits<CallGraphDOTInfo *>
117    : public GraphTraits<const CallGraphNode *> {
118  static NodeRef getEntryNode(CallGraphDOTInfo *CGInfo) {
119    // Start at the external node!
120    return CGInfo->getCallGraph()->getExternalCallingNode();
121  }
122
123  typedef std::pair<const Function *const, std::unique_ptr<CallGraphNode>>
124      PairTy;
125  static const CallGraphNode *CGGetValuePtr(const PairTy &P) {
126    return P.second.get();
127  }
128
129  // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
130  typedef mapped_iterator<CallGraph::const_iterator, decltype(&CGGetValuePtr)>
131      nodes_iterator;
132
133  static nodes_iterator nodes_begin(CallGraphDOTInfo *CGInfo) {
134    return nodes_iterator(CGInfo->getCallGraph()->begin(), &CGGetValuePtr);
135  }
136  static nodes_iterator nodes_end(CallGraphDOTInfo *CGInfo) {
137    return nodes_iterator(CGInfo->getCallGraph()->end(), &CGGetValuePtr);
138  }
139};
140
141template <>
142struct DOTGraphTraits<CallGraphDOTInfo *> : public DefaultDOTGraphTraits {
143
144  DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
145
146  static std::string getGraphName(CallGraphDOTInfo *CGInfo) {
147    return "Call graph: " +
148           std::string(CGInfo->getModule()->getModuleIdentifier());
149  }
150
151  static bool isNodeHidden(const CallGraphNode *Node,
152                           const CallGraphDOTInfo *CGInfo) {
153    if (CallMultiGraph || Node->getFunction())
154      return false;
155    return true;
156  }
157
158  std::string getNodeLabel(const CallGraphNode *Node,
159                           CallGraphDOTInfo *CGInfo) {
160    if (Node == CGInfo->getCallGraph()->getExternalCallingNode())
161      return "external caller";
162    if (Node == CGInfo->getCallGraph()->getCallsExternalNode())
163      return "external callee";
164
165    if (Function *Func = Node->getFunction())
166      return std::string(Func->getName());
167    return "external node";
168  }
169  static const CallGraphNode *CGGetValuePtr(CallGraphNode::CallRecord P) {
170    return P.second;
171  }
172
173  // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
174  typedef mapped_iterator<CallGraphNode::const_iterator,
175                          decltype(&CGGetValuePtr)>
176      nodes_iterator;
177
178  std::string getEdgeAttributes(const CallGraphNode *Node, nodes_iterator I,
179                                CallGraphDOTInfo *CGInfo) {
180    if (!ShowEdgeWeight)
181      return "";
182
183    Function *Caller = Node->getFunction();
184    if (Caller == nullptr || Caller->isDeclaration())
185      return "";
186
187    Function *Callee = (*I)->getFunction();
188    if (Callee == nullptr)
189      return "";
190
191    uint64_t Counter = getNumOfCalls(*Caller, *Callee);
192    double Width =
193        1 + 2 * (double(Counter) / CGInfo->getMaxFreq());
194    std::string Attrs = "label=\"" + std::to_string(Counter) +
195                        "\" penwidth=" + std::to_string(Width);
196    return Attrs;
197  }
198
199  std::string getNodeAttributes(const CallGraphNode *Node,
200                                CallGraphDOTInfo *CGInfo) {
201    Function *F = Node->getFunction();
202    if (F == nullptr)
203      return "";
204    std::string attrs;
205    if (ShowHeatColors) {
206      uint64_t freq = CGInfo->getFreq(F);
207      std::string color = getHeatColor(freq, CGInfo->getMaxFreq());
208      std::string edgeColor = (freq <= (CGInfo->getMaxFreq() / 2))
209                                  ? getHeatColor(0)
210                                  : getHeatColor(1);
211      attrs = "color=\"" + edgeColor + "ff\", style=filled, fillcolor=\"" +
212              color + "80\"";
213    }
214    return attrs;
215  }
216};
217
218} // end llvm namespace
219
220namespace {
221void doCallGraphDOTPrinting(
222    Module &M, function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
223  std::string Filename;
224  if (!CallGraphDotFilenamePrefix.empty())
225    Filename = (CallGraphDotFilenamePrefix + ".callgraph.dot");
226  else
227    Filename = (std::string(M.getModuleIdentifier()) + ".callgraph.dot");
228  errs() << "Writing '" << Filename << "'...";
229
230  std::error_code EC;
231  raw_fd_ostream File(Filename, EC, sys::fs::OF_Text);
232
233  CallGraph CG(M);
234  CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
235
236  if (!EC)
237    WriteGraph(File, &CFGInfo);
238  else
239    errs() << "  error opening file for writing!";
240  errs() << "\n";
241}
242
243void viewCallGraph(Module &M,
244                   function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
245  CallGraph CG(M);
246  CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
247
248  std::string Title =
249      DOTGraphTraits<CallGraphDOTInfo *>::getGraphName(&CFGInfo);
250  ViewGraph(&CFGInfo, "callgraph", true, Title);
251}
252} // namespace
253
254namespace llvm {
255PreservedAnalyses CallGraphDOTPrinterPass::run(Module &M,
256                                               ModuleAnalysisManager &AM) {
257  FunctionAnalysisManager &FAM =
258      AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
259
260  auto LookupBFI = [&FAM](Function &F) {
261    return &FAM.getResult<BlockFrequencyAnalysis>(F);
262  };
263
264  doCallGraphDOTPrinting(M, LookupBFI);
265
266  return PreservedAnalyses::all();
267}
268
269PreservedAnalyses CallGraphViewerPass::run(Module &M,
270                                           ModuleAnalysisManager &AM) {
271
272  FunctionAnalysisManager &FAM =
273      AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
274
275  auto LookupBFI = [&FAM](Function &F) {
276    return &FAM.getResult<BlockFrequencyAnalysis>(F);
277  };
278
279  viewCallGraph(M, LookupBFI);
280
281  return PreservedAnalyses::all();
282}
283} // namespace llvm
284
285namespace {
286// Viewer
287class CallGraphViewer : public ModulePass {
288public:
289  static char ID;
290  CallGraphViewer() : ModulePass(ID) {}
291
292  void getAnalysisUsage(AnalysisUsage &AU) const override;
293  bool runOnModule(Module &M) override;
294};
295
296void CallGraphViewer::getAnalysisUsage(AnalysisUsage &AU) const {
297  ModulePass::getAnalysisUsage(AU);
298  AU.addRequired<BlockFrequencyInfoWrapperPass>();
299  AU.setPreservesAll();
300}
301
302bool CallGraphViewer::runOnModule(Module &M) {
303  auto LookupBFI = [this](Function &F) {
304    return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
305  };
306
307  viewCallGraph(M, LookupBFI);
308
309  return false;
310}
311
312// DOT Printer
313
314class CallGraphDOTPrinter : public ModulePass {
315public:
316  static char ID;
317  CallGraphDOTPrinter() : ModulePass(ID) {}
318
319  void getAnalysisUsage(AnalysisUsage &AU) const override;
320  bool runOnModule(Module &M) override;
321};
322
323void CallGraphDOTPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
324  ModulePass::getAnalysisUsage(AU);
325  AU.addRequired<BlockFrequencyInfoWrapperPass>();
326  AU.setPreservesAll();
327}
328
329bool CallGraphDOTPrinter::runOnModule(Module &M) {
330  auto LookupBFI = [this](Function &F) {
331    return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
332  };
333
334  doCallGraphDOTPrinting(M, LookupBFI);
335
336  return false;
337}
338
339} // end anonymous namespace
340
// Definition of the viewer pass ID and registration of the pass under the
// command-line name "view-callgraph".
char CallGraphViewer::ID = 0;
INITIALIZE_PASS(CallGraphViewer, "view-callgraph", "View call graph", false,
                false)

// Definition of the DOT printer pass ID and registration of the pass under
// the command-line name "dot-callgraph".
char CallGraphDOTPrinter::ID = 0;
INITIALIZE_PASS(CallGraphDOTPrinter, "dot-callgraph",
                "Print call graph to 'dot' file", false, false)
348
// Create methods available outside of this file, so that they can be used
// from "include/llvm/LinkAllPasses.h". Otherwise the pass would be deleted by
// the link time optimization.
352
353ModulePass *llvm::createCallGraphViewerPass() { return new CallGraphViewer(); }
354
355ModulePass *llvm::createCallGraphDOTPrinterPass() {
356  return new CallGraphDOTPrinter();
357}
358