1//===-- SpeculateAnalyses.cpp  --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/ExecutionEngine/Orc/SpeculateAnalyses.h"
10#include "llvm/ADT/ArrayRef.h"
11#include "llvm/ADT/DenseMap.h"
12#include "llvm/ADT/STLExtras.h"
13#include "llvm/ADT/SmallPtrSet.h"
14#include "llvm/ADT/SmallVector.h"
15#include "llvm/Analysis/BlockFrequencyInfo.h"
16#include "llvm/Analysis/BranchProbabilityInfo.h"
17#include "llvm/Analysis/CFG.h"
18#include "llvm/IR/PassManager.h"
19#include "llvm/Passes/PassBuilder.h"
20#include "llvm/Support/ErrorHandling.h"
21
22#include <algorithm>
23
24namespace {
25using namespace llvm;
26SmallVector<const BasicBlock *, 8> findBBwithCalls(const Function &F,
27                                                   bool IndirectCall = false) {
28  SmallVector<const BasicBlock *, 8> BBs;
29
30  auto findCallInst = [&IndirectCall](const Instruction &I) {
31    if (auto Call = dyn_cast<CallBase>(&I))
32      return Call->isIndirectCall() ? IndirectCall : true;
33    else
34      return false;
35  };
36  for (auto &BB : F)
37    if (findCallInst(*BB.getTerminator()) ||
38        llvm::any_of(BB.instructionsWithoutDebug(), findCallInst))
39      BBs.emplace_back(&BB);
40
41  return BBs;
42}
43} // namespace
44
// Query implementations should not need to lock resources such as the
// LLVMContext: each argument (function) has its own, non-shared LLVMContext.
// However, if a query holds state, it must provide its own locking scheme.
48namespace llvm {
49namespace orc {
50
51// Collect direct calls only
52void SpeculateQuery::findCalles(const BasicBlock *BB,
53                                DenseSet<StringRef> &CallesNames) {
54  assert(BB != nullptr && "Traversing Null BB to find calls?");
55
56  auto getCalledFunction = [&CallesNames](const CallBase *Call) {
57    auto CalledValue = Call->getCalledOperand()->stripPointerCasts();
58    if (auto DirectCall = dyn_cast<Function>(CalledValue))
59      CallesNames.insert(DirectCall->getName());
60  };
61  for (auto &I : BB->instructionsWithoutDebug())
62    if (auto CI = dyn_cast<CallInst>(&I))
63      getCalledFunction(CI);
64
65  if (auto II = dyn_cast<InvokeInst>(BB->getTerminator()))
66    getCalledFunction(II);
67}
68
69bool SpeculateQuery::isStraightLine(const Function &F) {
70  return llvm::all_of(F.getBasicBlockList(), [](const BasicBlock &BB) {
71    return BB.getSingleSuccessor() != nullptr;
72  });
73}
74
75// BlockFreqQuery Implementations
76
77size_t BlockFreqQuery::numBBToGet(size_t numBB) {
78  // small CFG
79  if (numBB < 4)
80    return numBB;
81  // mid-size CFG
82  else if (numBB < 20)
83    return (numBB / 2);
84  else
85    return (numBB / 2) + (numBB / 4);
86}
87
88BlockFreqQuery::ResultTy BlockFreqQuery::operator()(Function &F) {
89  DenseMap<StringRef, DenseSet<StringRef>> CallerAndCalles;
90  DenseSet<StringRef> Calles;
91  SmallVector<std::pair<const BasicBlock *, uint64_t>, 8> BBFreqs;
92
93  PassBuilder PB;
94  FunctionAnalysisManager FAM;
95  PB.registerFunctionAnalyses(FAM);
96
97  auto IBBs = findBBwithCalls(F);
98
99  if (IBBs.empty())
100    return None;
101
102  auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
103
104  for (const auto I : IBBs)
105    BBFreqs.push_back({I, BFI.getBlockFreq(I).getFrequency()});
106
107  assert(IBBs.size() == BBFreqs.size() && "BB Count Mismatch");
108
109  llvm::sort(BBFreqs.begin(), BBFreqs.end(),
110             [](decltype(BBFreqs)::const_reference BBF,
111                decltype(BBFreqs)::const_reference BBS) {
112               return BBF.second > BBS.second ? true : false;
113             });
114
115  // ignoring number of direct calls in a BB
116  auto Topk = numBBToGet(BBFreqs.size());
117
118  for (size_t i = 0; i < Topk; i++)
119    findCalles(BBFreqs[i].first, Calles);
120
121  assert(!Calles.empty() && "Running Analysis on Function with no calls?");
122
123  CallerAndCalles.insert({F.getName(), std::move(Calles)});
124
125  return CallerAndCalles;
126}
127
128// SequenceBBQuery Implementation
129std::size_t SequenceBBQuery::getHottestBlocks(std::size_t TotalBlocks) {
130  if (TotalBlocks == 1)
131    return TotalBlocks;
132  return TotalBlocks / 2;
133}
134
135// FIXME : find good implementation.
136SequenceBBQuery::BlockListTy
137SequenceBBQuery::rearrangeBB(const Function &F, const BlockListTy &BBList) {
138  BlockListTy RearrangedBBSet;
139
140  for (auto &Block : F.getBasicBlockList())
141    if (llvm::is_contained(BBList, &Block))
142      RearrangedBBSet.push_back(&Block);
143
144  assert(RearrangedBBSet.size() == BBList.size() &&
145         "BasicBlock missing while rearranging?");
146  return RearrangedBBSet;
147}
148
149void SequenceBBQuery::traverseToEntryBlock(const BasicBlock *AtBB,
150                                           const BlockListTy &CallerBlocks,
151                                           const BackEdgesInfoTy &BackEdgesInfo,
152                                           const BranchProbabilityInfo *BPI,
153                                           VisitedBlocksInfoTy &VisitedBlocks) {
154  auto Itr = VisitedBlocks.find(AtBB);
155  if (Itr != VisitedBlocks.end()) { // already visited.
156    if (!Itr->second.Upward)
157      return;
158    Itr->second.Upward = false;
159  } else {
160    // Create hint for newly discoverd blocks.
161    WalkDirection BlockHint;
162    BlockHint.Upward = false;
163    // FIXME: Expensive Check
164    if (llvm::is_contained(CallerBlocks, AtBB))
165      BlockHint.CallerBlock = true;
166    VisitedBlocks.insert(std::make_pair(AtBB, BlockHint));
167  }
168
169  const_pred_iterator PIt = pred_begin(AtBB), EIt = pred_end(AtBB);
170  // Move this check to top, when we have code setup to launch speculative
171  // compiles for function in entry BB, this triggers the speculative compiles
172  // before running the program.
173  if (PIt == EIt) // No Preds.
174    return;
175
176  DenseSet<const BasicBlock *> PredSkipNodes;
177
178  // Since we are checking for predecessor's backedges, this Block
179  // occurs in second position.
180  for (auto &I : BackEdgesInfo)
181    if (I.second == AtBB)
182      PredSkipNodes.insert(I.first);
183
184  // Skip predecessors which source of back-edges.
185  for (; PIt != EIt; ++PIt)
186    // checking EdgeHotness is cheaper
187    if (BPI->isEdgeHot(*PIt, AtBB) && !PredSkipNodes.count(*PIt))
188      traverseToEntryBlock(*PIt, CallerBlocks, BackEdgesInfo, BPI,
189                           VisitedBlocks);
190}
191
192void SequenceBBQuery::traverseToExitBlock(const BasicBlock *AtBB,
193                                          const BlockListTy &CallerBlocks,
194                                          const BackEdgesInfoTy &BackEdgesInfo,
195                                          const BranchProbabilityInfo *BPI,
196                                          VisitedBlocksInfoTy &VisitedBlocks) {
197  auto Itr = VisitedBlocks.find(AtBB);
198  if (Itr != VisitedBlocks.end()) { // already visited.
199    if (!Itr->second.Downward)
200      return;
201    Itr->second.Downward = false;
202  } else {
203    // Create hint for newly discoverd blocks.
204    WalkDirection BlockHint;
205    BlockHint.Downward = false;
206    // FIXME: Expensive Check
207    if (llvm::is_contained(CallerBlocks, AtBB))
208      BlockHint.CallerBlock = true;
209    VisitedBlocks.insert(std::make_pair(AtBB, BlockHint));
210  }
211
212  succ_const_iterator PIt = succ_begin(AtBB), EIt = succ_end(AtBB);
213  if (PIt == EIt) // No succs.
214    return;
215
216  // If there are hot edges, then compute SuccSkipNodes.
217  DenseSet<const BasicBlock *> SuccSkipNodes;
218
219  // Since we are checking for successor's backedges, this Block
220  // occurs in first position.
221  for (auto &I : BackEdgesInfo)
222    if (I.first == AtBB)
223      SuccSkipNodes.insert(I.second);
224
225  for (; PIt != EIt; ++PIt)
226    if (BPI->isEdgeHot(AtBB, *PIt) && !SuccSkipNodes.count(*PIt))
227      traverseToExitBlock(*PIt, CallerBlocks, BackEdgesInfo, BPI,
228                          VisitedBlocks);
229}
230
231// Get Block frequencies for blocks and take most frquently executed block,
232// walk towards the entry block from those blocks and discover the basic blocks
233// with call.
234SequenceBBQuery::BlockListTy
235SequenceBBQuery::queryCFG(Function &F, const BlockListTy &CallerBlocks) {
236
237  BlockFreqInfoTy BBFreqs;
238  VisitedBlocksInfoTy VisitedBlocks;
239  BackEdgesInfoTy BackEdgesInfo;
240
241  PassBuilder PB;
242  FunctionAnalysisManager FAM;
243  PB.registerFunctionAnalyses(FAM);
244
245  auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
246
247  llvm::FindFunctionBackedges(F, BackEdgesInfo);
248
249  for (const auto I : CallerBlocks)
250    BBFreqs.push_back({I, BFI.getBlockFreq(I).getFrequency()});
251
252  llvm::sort(BBFreqs, [](decltype(BBFreqs)::const_reference Bbf,
253                         decltype(BBFreqs)::const_reference Bbs) {
254    return Bbf.second > Bbs.second;
255  });
256
257  ArrayRef<std::pair<const BasicBlock *, uint64_t>> HotBlocksRef(BBFreqs);
258  HotBlocksRef =
259      HotBlocksRef.drop_back(BBFreqs.size() - getHottestBlocks(BBFreqs.size()));
260
261  BranchProbabilityInfo *BPI =
262      FAM.getCachedResult<BranchProbabilityAnalysis>(F);
263
264  // visit NHotBlocks,
265  // traverse upwards to entry
266  // traverse downwards to end.
267
268  for (auto I : HotBlocksRef) {
269    traverseToEntryBlock(I.first, CallerBlocks, BackEdgesInfo, BPI,
270                         VisitedBlocks);
271    traverseToExitBlock(I.first, CallerBlocks, BackEdgesInfo, BPI,
272                        VisitedBlocks);
273  }
274
275  BlockListTy MinCallerBlocks;
276  for (auto &I : VisitedBlocks)
277    if (I.second.CallerBlock)
278      MinCallerBlocks.push_back(std::move(I.first));
279
280  return rearrangeBB(F, MinCallerBlocks);
281}
282
283SpeculateQuery::ResultTy SequenceBBQuery::operator()(Function &F) {
284  // reduce the number of lists!
285  DenseMap<StringRef, DenseSet<StringRef>> CallerAndCalles;
286  DenseSet<StringRef> Calles;
287  BlockListTy SequencedBlocks;
288  BlockListTy CallerBlocks;
289
290  CallerBlocks = findBBwithCalls(F);
291  if (CallerBlocks.empty())
292    return None;
293
294  if (isStraightLine(F))
295    SequencedBlocks = rearrangeBB(F, CallerBlocks);
296  else
297    SequencedBlocks = queryCFG(F, CallerBlocks);
298
299  for (auto BB : SequencedBlocks)
300    findCalles(BB, Calles);
301
302  CallerAndCalles.insert({F.getName(), std::move(Calles)});
303  return CallerAndCalles;
304}
305
306} // namespace orc
307} // namespace llvm
308