//===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a Loop Data Prefetching Pass.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LoopDataPrefetch.h"
#include "llvm/InitializePasses.h"

#define DEBUG_TYPE "loop-data-prefetch"

#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
using namespace llvm;

// Prefetching of write addresses is disabled by default; when this flag is
// not given explicitly, the target's TTI hook decides (see
// doPrefetchWrites() below).
static cl::opt<bool>
PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false),
               cl::desc("Prefetch write addresses"));

static cl::opt<unsigned>
    PrefetchDistance("prefetch-distance",
                     cl::desc("Number of instructions to prefetch ahead"),
                     cl::Hidden);

static cl::opt<unsigned>
    MinPrefetchStride("min-prefetch-stride",
                      cl::desc("Min stride to add prefetches"), cl::Hidden);

static cl::opt<unsigned> MaxPrefetchIterationsAhead(
    "max-prefetch-iters-ahead",
    cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden);

STATISTIC(NumPrefetches, "Number of prefetches inserted");

namespace {

/// Loop prefetch implementation class.
class LoopDataPrefetch {
public:
  LoopDataPrefetch(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI,
                   ScalarEvolution *SE, const TargetTransformInfo *TTI,
                   OptimizationRemarkEmitter *ORE)
      : AC(AC), DT(DT), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}

  bool run();

private:
  bool runOnLoop(Loop *L);

  /// Check if the stride of the accesses is large enough to
  /// warrant a prefetch.
  bool isStrideLargeEnough(const SCEVAddRecExpr *AR, unsigned TargetMinStride);

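  // Each of the getters below prefers an explicitly set command-line value
  // and otherwise falls back to the target's TargetTransformInfo default.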
  unsigned getMinPrefetchStride(unsigned NumMemAccesses,
                                unsigned NumStridedMemAccesses,
                                unsigned NumPrefetches,
                                bool HasCall) {
    if (MinPrefetchStride.getNumOccurrences() > 0)
      return MinPrefetchStride;
    return TTI->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
                                     NumPrefetches, HasCall);
  }

  unsigned getPrefetchDistance() {
    if (PrefetchDistance.getNumOccurrences() > 0)
      return PrefetchDistance;
    return TTI->getPrefetchDistance();
  }

  unsigned getMaxPrefetchIterationsAhead() {
    if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0)
      return MaxPrefetchIterationsAhead;
    return TTI->getMaxPrefetchIterationsAhead();
  }

  bool doPrefetchWrites() {
    if (PrefetchWrites.getNumOccurrences() > 0)
      return PrefetchWrites;
    return TTI->enableWritePrefetching();
  }

  AssumptionCache *AC;
  DominatorTree *DT;
  LoopInfo *LI;
  ScalarEvolution *SE;
  const TargetTransformInfo *TTI;
  OptimizationRemarkEmitter *ORE;
};

/// Legacy class for inserting loop data prefetches.
class LoopDataPrefetchLegacyPass : public FunctionPass {
public:
  static char ID; // Pass ID, replacement for typeid
  LoopDataPrefetchLegacyPass() : FunctionPass(ID) {
    initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

  bool runOnFunction(Function &F) override;
};

} // end anonymous namespace

char LoopDataPrefetchLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
                      "Loop Data Prefetch", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
                    "Loop Data Prefetch", false, false)

FunctionPass *llvm::createLoopDataPrefetchPass() {
  return new LoopDataPrefetchLegacyPass();
}

bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR,
                                           unsigned TargetMinStride) {
  // If the target accepts any stride, there is nothing to check.
  if (TargetMinStride <= 1)
    return true;

  const auto *ConstStride = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE));
  // If a minimum stride is required, don't prefetch unless we can ensure
  // that the stride is at least that large.
  if (!ConstStride)
    return false;

  unsigned AbsStride = std::abs(ConstStride->getAPInt().getSExtValue());
  return TargetMinStride <= AbsStride;
}

PreservedAnalyses LoopDataPrefetchPass::run(Function &F,
                                            FunctionAnalysisManager &AM) {
  DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
  LoopInfo *LI = &AM.getResult<LoopAnalysis>(F);
  ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
  AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);
  OptimizationRemarkEmitter *ORE =
      &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);

  LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
  bool Changed = LDP.run();

  if (Changed) {
    PreservedAnalyses PA;
    PA.preserve<DominatorTreeAnalysis>();
    PA.preserve<LoopAnalysis>();
    return PA;
  }

  return PreservedAnalyses::all();
}

bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  AssumptionCache *AC =
      &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  OptimizationRemarkEmitter *ORE =
      &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
  const TargetTransformInfo *TTI =
      &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
  return LDP.run();
}

bool LoopDataPrefetch::run() {
  // If PrefetchDistance is not set, don't run the pass.  This gives an
  // opportunity for targets to run this pass for selected subtargets only
  // (whose TTI sets PrefetchDistance).
  if (getPrefetchDistance() == 0)
    return false;
  assert(TTI->getCacheLineSize() && "Cache line size is not set for target");

  bool MadeChange = false;

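  // Walk every loop in the function in depth-first order; runOnLoop itself
  // only transforms innermost loops.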
  for (Loop *I : *LI)
    for (auto L = df_begin(I), LE = df_end(I); L != LE; ++L)
      MadeChange |= runOnLoop(*L);

  return MadeChange;
}

/// A record for a potential prefetch made during the initial scan of the
/// loop. This is used to let a single prefetch target multiple memory
/// accesses.
struct Prefetch {
  /// The address formula for this prefetch as returned by ScalarEvolution.
  const SCEVAddRecExpr *LSCEVAddRec;
  /// The point of insertion for the prefetch instruction.
  Instruction *InsertPt;
  /// True if targeting a write memory access.
  bool Writes;
  /// The (first seen) prefetched instruction.
  Instruction *MemI;

  /// Constructor to create a new Prefetch for \p I.
  Prefetch(const SCEVAddRecExpr *L, Instruction *I)
      : LSCEVAddRec(L), InsertPt(nullptr), Writes(false), MemI(nullptr) {
    addInstruction(I);
  }

  /// Add the instruction \p I to this prefetch. If it's not the first one,
  /// 'InsertPt' and 'Writes' will be updated as required.
  /// \param PtrDiff the known constant address difference to the first added
  /// instruction.
  void addInstruction(Instruction *I, DominatorTree *DT = nullptr,
                      int64_t PtrDiff = 0) {
    if (!InsertPt) {
      MemI = I;
      InsertPt = I;
      Writes = isa<StoreInst>(I);
    } else {
      BasicBlock *PrefBB = InsertPt->getParent();
      BasicBlock *InsBB = I->getParent();
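      // The accesses may live in different blocks; hoist the insertion point
      // to the nearest common dominator so the single prefetch is reached on
      // every path leading to either access.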
      if (PrefBB != InsBB) {
        BasicBlock *DomBB = DT->findNearestCommonDominator(PrefBB, InsBB);
        if (DomBB != PrefBB)
          InsertPt = DomBB->getTerminator();
      }

      if (isa<StoreInst>(I) && PtrDiff == 0)
        Writes = true;
    }
  }
};

bool LoopDataPrefetch::runOnLoop(Loop *L) {
  bool MadeChange = false;

  // Only prefetch in innermost loops.
  if (!L->empty())
    return MadeChange;

  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, AC, EphValues);

  // Calculate the number of iterations ahead to prefetch.
  CodeMetrics Metrics;
  bool HasCall = false;
  for (const auto BB : L->blocks()) {
    // If the loop already has prefetches, then assume that the user knows
    // what they are doing and don't add any more.
    for (auto &I : *BB) {
      if (isa<CallInst>(&I) || isa<InvokeInst>(&I)) {
        if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
          if (F->getIntrinsicID() == Intrinsic::prefetch)
            return MadeChange;
          if (TTI->isLoweredToCall(F))
            HasCall = true;
        } else { // Indirect call.
          HasCall = true;
        }
      }
    }
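    // Estimate the loop size in instructions, ignoring ephemeral values
    // (values used only by llvm.assume and the like).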
    Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
  }
  unsigned LoopSize = Metrics.NumInsts;
  if (!LoopSize)
    LoopSize = 1;

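  // The prefetch distance is expressed in instructions; dividing it by the
  // loop size converts it into a number of iterations to look ahead.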
  unsigned ItersAhead = getPrefetchDistance() / LoopSize;
  if (!ItersAhead)
    ItersAhead = 1;

  if (ItersAhead > getMaxPrefetchIterationsAhead())
    return MadeChange;

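  // If the loop is known to run fewer than ItersAhead + 1 iterations, the
  // prefetched addresses would never be accessed; skip it.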
  unsigned ConstantMaxTripCount = SE->getSmallConstantMaxTripCount(L);
  if (ConstantMaxTripCount && ConstantMaxTripCount < ItersAhead + 1)
    return MadeChange;

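  // Scan the loop for loads (and, when enabled, stores) whose addresses are
  // strided SCEV add-recurrences, merging accesses that fall on the same
  // cache line into a single planned prefetch.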
  unsigned NumMemAccesses = 0;
  unsigned NumStridedMemAccesses = 0;
  SmallVector<Prefetch, 16> Prefetches;
  for (const auto BB : L->blocks())
    for (auto &I : *BB) {
      Value *PtrValue;
      Instruction *MemI;

      if (LoadInst *LMemI = dyn_cast<LoadInst>(&I)) {
        MemI = LMemI;
        PtrValue = LMemI->getPointerOperand();
      } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) {
        if (!doPrefetchWrites()) continue;
        MemI = SMemI;
        PtrValue = SMemI->getPointerOperand();
      } else continue;

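      // Only prefetch addresses in the default address space (0).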
      unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
      if (PtrAddrSpace)
        continue;
      NumMemAccesses++;
      if (L->isLoopInvariant(PtrValue))
        continue;

      const SCEV *LSCEV = SE->getSCEV(PtrValue);
      const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
      if (!LSCEVAddRec)
        continue;
      NumStridedMemAccesses++;

      // We don't want to double prefetch individual cache lines. If this
      // access is known to be within one cache line of some other one that
      // has already been prefetched, then don't prefetch this one as well.
      bool DupPref = false;
      for (auto &Pref : Prefetches) {
        const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, Pref.LSCEVAddRec);
        if (const SCEVConstant *ConstPtrDiff =
            dyn_cast<SCEVConstant>(PtrDiff)) {
          int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
          if (PD < (int64_t) TTI->getCacheLineSize()) {
            Pref.addInstruction(MemI, DT, PD);
            DupPref = true;
            break;
          }
        }
      }
      if (!DupPref)
        Prefetches.push_back(Prefetch(LSCEVAddRec, MemI));
    }

  unsigned TargetMinStride =
    getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
                         Prefetches.size(), HasCall);

  LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
             << " iterations ahead (loop size: " << LoopSize << ") in "
             << L->getHeader()->getParent()->getName() << ": " << *L);
  LLVM_DEBUG(dbgs() << "Loop has: "
             << NumMemAccesses << " memory accesses, "
             << NumStridedMemAccesses << " strided memory accesses, "
             << Prefetches.size() << " potential prefetch(es), "
             << "a minimum stride of " << TargetMinStride << ", "
             << (HasCall ? "calls" : "no calls") << ".\n");

  for (auto &P : Prefetches) {
    // Check if the stride of the accesses is large enough to warrant a
    // prefetch.
    if (!isStrideLargeEnough(P.LSCEVAddRec, TargetMinStride))
      continue;

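    // Compute the address ItersAhead iterations ahead of the current one:
    // {A,+,S} + ItersAhead * S.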
    const SCEV *NextLSCEV = SE->getAddExpr(P.LSCEVAddRec, SE->getMulExpr(
      SE->getConstant(P.LSCEVAddRec->getType(), ItersAhead),
      P.LSCEVAddRec->getStepRecurrence(*SE)));
    if (!isSafeToExpand(NextLSCEV, *SE))
      continue;

    BasicBlock *BB = P.InsertPt->getParent();
    Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), 0 /*PtrAddrSpace*/);
    SCEVExpander SCEVE(*SE, BB->getModule()->getDataLayout(), "prefaddr");
    Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt);

    IRBuilder<> Builder(P.InsertPt);
    Module *M = BB->getParent()->getParent();
    Type *I32 = Type::getInt32Ty(BB->getContext());
    Function *PrefetchFunc = Intrinsic::getDeclaration(
        M, Intrinsic::prefetch, PrefPtrValue->getType());
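    // llvm.prefetch(ptr, rw, locality, cache-type): rw is 1 for writes,
    // locality 3 requests maximum temporal locality, and cache-type 1
    // selects the data cache.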
    Builder.CreateCall(
        PrefetchFunc,
        {PrefPtrValue,
         ConstantInt::get(I32, P.Writes),
         ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
    ++NumPrefetches;
    LLVM_DEBUG(dbgs() << "  Access: "
               << *P.MemI->getOperand(isa<LoadInst>(P.MemI) ? 0 : 1)
               << ", SCEV: " << *P.LSCEVAddRec << "\n");
    ORE->emit([&]() {
        return OptimizationRemark(DEBUG_TYPE, "Prefetched", P.MemI)
          << "prefetched memory access";
      });

    MadeChange = true;
  }

  return MadeChange;
}