1286425Sdim//===- LoopVersioning.cpp - Utility to version a loop ---------------------===//
2286425Sdim//
3353358Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4353358Sdim// See https://llvm.org/LICENSE.txt for license information.
5353358Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6286425Sdim//
7286425Sdim//===----------------------------------------------------------------------===//
8286425Sdim//
9286425Sdim// This file defines a utility class to perform loop versioning.  The versioned
10286425Sdim// loop speculates that otherwise may-aliasing memory accesses don't overlap and
11286425Sdim// emits checks to prove this.
12286425Sdim//
13286425Sdim//===----------------------------------------------------------------------===//
14286425Sdim
15296417Sdim#include "llvm/Transforms/Utils/LoopVersioning.h"
16286425Sdim#include "llvm/Analysis/LoopAccessAnalysis.h"
17286425Sdim#include "llvm/Analysis/LoopInfo.h"
18296417Sdim#include "llvm/Analysis/ScalarEvolutionExpander.h"
19286425Sdim#include "llvm/IR/Dominators.h"
20309124Sdim#include "llvm/IR/MDBuilder.h"
21360784Sdim#include "llvm/InitializePasses.h"
22360784Sdim#include "llvm/Support/CommandLine.h"
23286425Sdim#include "llvm/Transforms/Utils/BasicBlockUtils.h"
24286425Sdim#include "llvm/Transforms/Utils/Cloning.h"
25286425Sdim
26286425Sdimusing namespace llvm;
27286425Sdim
28309124Sdimstatic cl::opt<bool>
29309124Sdim    AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true),
30309124Sdim                    cl::Hidden,
31309124Sdim                    cl::desc("Add no-alias annotation for instructions that "
32309124Sdim                             "are disambiguated by memchecks"));
33309124Sdim
34286425SdimLoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
35296417Sdim                               DominatorTree *DT, ScalarEvolution *SE,
36296417Sdim                               bool UseLAIChecks)
37296417Sdim    : VersionedLoop(L), NonVersionedLoop(nullptr), LAI(LAI), LI(LI), DT(DT),
38296417Sdim      SE(SE) {
39286425Sdim  assert(L->getExitBlock() && "No single exit block");
40314564Sdim  assert(L->isLoopSimplifyForm() && "Loop is not in loop-simplify form");
41296417Sdim  if (UseLAIChecks) {
42296417Sdim    setAliasChecks(LAI.getRuntimePointerChecking()->getChecks());
43309124Sdim    setSCEVChecks(LAI.getPSE().getUnionPredicate());
44296417Sdim  }
45286425Sdim}
46286425Sdim
47296417Sdimvoid LoopVersioning::setAliasChecks(
48309124Sdim    SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks) {
49296417Sdim  AliasChecks = std::move(Checks);
50286425Sdim}
51286425Sdim
52296417Sdimvoid LoopVersioning::setSCEVChecks(SCEVUnionPredicate Check) {
53296417Sdim  Preds = std::move(Check);
54296417Sdim}
55296417Sdim
56296417Sdimvoid LoopVersioning::versionLoop(
57296417Sdim    const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
58286425Sdim  Instruction *FirstCheckInst;
59286425Sdim  Instruction *MemRuntimeCheck;
60296417Sdim  Value *SCEVRuntimeCheck;
61296417Sdim  Value *RuntimeCheck = nullptr;
62296417Sdim
63286425Sdim  // Add the memcheck in the original preheader (this is empty initially).
64296417Sdim  BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
65286425Sdim  std::tie(FirstCheckInst, MemRuntimeCheck) =
66296417Sdim      LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks);
67286425Sdim
68309124Sdim  const SCEVUnionPredicate &Pred = LAI.getPSE().getUnionPredicate();
69296417Sdim  SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
70296417Sdim                   "scev.check");
71296417Sdim  SCEVRuntimeCheck =
72296417Sdim      Exp.expandCodeForPredicate(&Pred, RuntimeCheckBB->getTerminator());
73296417Sdim  auto *CI = dyn_cast<ConstantInt>(SCEVRuntimeCheck);
74296417Sdim
75296417Sdim  // Discard the SCEV runtime check if it is always true.
76296417Sdim  if (CI && CI->isZero())
77296417Sdim    SCEVRuntimeCheck = nullptr;
78296417Sdim
79296417Sdim  if (MemRuntimeCheck && SCEVRuntimeCheck) {
80296417Sdim    RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck,
81309124Sdim                                          SCEVRuntimeCheck, "lver.safe");
82296417Sdim    if (auto *I = dyn_cast<Instruction>(RuntimeCheck))
83296417Sdim      I->insertBefore(RuntimeCheckBB->getTerminator());
84296417Sdim  } else
85296417Sdim    RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
86296417Sdim
87296417Sdim  assert(RuntimeCheck && "called even though we don't need "
88296417Sdim                         "any runtime checks");
89296417Sdim
90286425Sdim  // Rename the block to make the IR more readable.
91296417Sdim  RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
92296417Sdim                          ".lver.check");
93286425Sdim
94286425Sdim  // Create empty preheader for the loop (and after cloning for the
95286425Sdim  // non-versioned loop).
96296417Sdim  BasicBlock *PH =
97360784Sdim      SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI,
98360784Sdim                 nullptr, VersionedLoop->getHeader()->getName() + ".ph");
99286425Sdim
100286425Sdim  // Clone the loop including the preheader.
101286425Sdim  //
102286425Sdim  // FIXME: This does not currently preserve SimplifyLoop because the exit
103286425Sdim  // block is a join between the two loops.
104286425Sdim  SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks;
105286425Sdim  NonVersionedLoop =
106296417Sdim      cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
107296417Sdim                             ".lver.orig", LI, DT, NonVersionedLoopBlocks);
108286425Sdim  remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
109286425Sdim
110286425Sdim  // Insert the conditional branch based on the result of the memchecks.
111296417Sdim  Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
112286425Sdim  BranchInst::Create(NonVersionedLoop->getLoopPreheader(),
113296417Sdim                     VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm);
114286425Sdim  OrigTerm->eraseFromParent();
115286425Sdim
116286425Sdim  // The loops merge in the original exit block.  This is now dominated by the
117286425Sdim  // memchecking block.
118296417Sdim  DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
119296417Sdim
120296417Sdim  // Adds the necessary PHI nodes for the versioned loops based on the
121296417Sdim  // loop-defined values used outside of the loop.
122296417Sdim  addPHINodes(DefsUsedOutside);
123286425Sdim}
124286425Sdim
125286425Sdimvoid LoopVersioning::addPHINodes(
126286425Sdim    const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
127286425Sdim  BasicBlock *PHIBlock = VersionedLoop->getExitBlock();
128286425Sdim  assert(PHIBlock && "No single successor to loop exit block");
129309124Sdim  PHINode *PN;
130286425Sdim
131309124Sdim  // First add a single-operand PHI for each DefsUsedOutside if one does not
132309124Sdim  // exists yet.
133286425Sdim  for (auto *Inst : DefsUsedOutside) {
134309124Sdim    // See if we have a single-operand PHI with the value defined by the
135286425Sdim    // original loop.
136286425Sdim    for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
137286425Sdim      if (PN->getIncomingValue(0) == Inst)
138286425Sdim        break;
139286425Sdim    }
140286425Sdim    // If not create it.
141286425Sdim    if (!PN) {
142286425Sdim      PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver",
143296417Sdim                           &PHIBlock->front());
144341825Sdim      SmallVector<User*, 8> UsersToUpdate;
145341825Sdim      for (User *U : Inst->users())
146341825Sdim        if (!VersionedLoop->contains(cast<Instruction>(U)->getParent()))
147341825Sdim          UsersToUpdate.push_back(U);
148341825Sdim      for (User *U : UsersToUpdate)
149341825Sdim        U->replaceUsesOfWith(Inst, PN);
150286425Sdim      PN->addIncoming(Inst, VersionedLoop->getExitingBlock());
151286425Sdim    }
152286425Sdim  }
153309124Sdim
154309124Sdim  // Then for each PHI add the operand for the edge from the cloned loop.
155309124Sdim  for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
156309124Sdim    assert(PN->getNumOperands() == 1 &&
157309124Sdim           "Exit block should only have on predecessor");
158309124Sdim
159309124Sdim    // If the definition was cloned used that otherwise use the same value.
160309124Sdim    Value *ClonedValue = PN->getIncomingValue(0);
161309124Sdim    auto Mapped = VMap.find(ClonedValue);
162309124Sdim    if (Mapped != VMap.end())
163309124Sdim      ClonedValue = Mapped->second;
164309124Sdim
165309124Sdim    PN->addIncoming(ClonedValue, NonVersionedLoop->getExitingBlock());
166309124Sdim  }
167286425Sdim}
168309124Sdim
169309124Sdimvoid LoopVersioning::prepareNoAliasMetadata() {
170309124Sdim  // We need to turn the no-alias relation between pointer checking groups into
171309124Sdim  // no-aliasing annotations between instructions.
172309124Sdim  //
173309124Sdim  // We accomplish this by mapping each pointer checking group (a set of
174309124Sdim  // pointers memchecked together) to an alias scope and then also mapping each
175309124Sdim  // group to the list of scopes it can't alias.
176309124Sdim
177309124Sdim  const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking();
178309124Sdim  LLVMContext &Context = VersionedLoop->getHeader()->getContext();
179309124Sdim
180309124Sdim  // First allocate an aliasing scope for each pointer checking group.
181309124Sdim  //
182309124Sdim  // While traversing through the checking groups in the loop, also create a
183309124Sdim  // reverse map from pointers to the pointer checking group they were assigned
184309124Sdim  // to.
185309124Sdim  MDBuilder MDB(Context);
186309124Sdim  MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain");
187309124Sdim
188309124Sdim  for (const auto &Group : RtPtrChecking->CheckingGroups) {
189309124Sdim    GroupToScope[&Group] = MDB.createAnonymousAliasScope(Domain);
190309124Sdim
191309124Sdim    for (unsigned PtrIdx : Group.Members)
192309124Sdim      PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &Group;
193309124Sdim  }
194309124Sdim
195309124Sdim  // Go through the checks and for each pointer group, collect the scopes for
196309124Sdim  // each non-aliasing pointer group.
197309124Sdim  DenseMap<const RuntimePointerChecking::CheckingPtrGroup *,
198309124Sdim           SmallVector<Metadata *, 4>>
199309124Sdim      GroupToNonAliasingScopes;
200309124Sdim
201309124Sdim  for (const auto &Check : AliasChecks)
202309124Sdim    GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]);
203309124Sdim
204309124Sdim  // Finally, transform the above to actually map to scope list which is what
205309124Sdim  // the metadata uses.
206309124Sdim
207309124Sdim  for (auto Pair : GroupToNonAliasingScopes)
208309124Sdim    GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second);
209309124Sdim}
210309124Sdim
211309124Sdimvoid LoopVersioning::annotateLoopWithNoAlias() {
212309124Sdim  if (!AnnotateNoAlias)
213309124Sdim    return;
214309124Sdim
215309124Sdim  // First prepare the maps.
216309124Sdim  prepareNoAliasMetadata();
217309124Sdim
218309124Sdim  // Add the scope and no-alias metadata to the instructions.
219309124Sdim  for (Instruction *I : LAI.getDepChecker().getMemoryInstructions()) {
220309124Sdim    annotateInstWithNoAlias(I);
221309124Sdim  }
222309124Sdim}
223309124Sdim
224309124Sdimvoid LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst,
225309124Sdim                                             const Instruction *OrigInst) {
226309124Sdim  if (!AnnotateNoAlias)
227309124Sdim    return;
228309124Sdim
229309124Sdim  LLVMContext &Context = VersionedLoop->getHeader()->getContext();
230309124Sdim  const Value *Ptr = isa<LoadInst>(OrigInst)
231309124Sdim                         ? cast<LoadInst>(OrigInst)->getPointerOperand()
232309124Sdim                         : cast<StoreInst>(OrigInst)->getPointerOperand();
233309124Sdim
234309124Sdim  // Find the group for the pointer and then add the scope metadata.
235309124Sdim  auto Group = PtrToGroup.find(Ptr);
236309124Sdim  if (Group != PtrToGroup.end()) {
237309124Sdim    VersionedInst->setMetadata(
238309124Sdim        LLVMContext::MD_alias_scope,
239309124Sdim        MDNode::concatenate(
240309124Sdim            VersionedInst->getMetadata(LLVMContext::MD_alias_scope),
241309124Sdim            MDNode::get(Context, GroupToScope[Group->second])));
242309124Sdim
243309124Sdim    // Add the no-alias metadata.
244309124Sdim    auto NonAliasingScopeList = GroupToNonAliasingScopeList.find(Group->second);
245309124Sdim    if (NonAliasingScopeList != GroupToNonAliasingScopeList.end())
246309124Sdim      VersionedInst->setMetadata(
247309124Sdim          LLVMContext::MD_noalias,
248309124Sdim          MDNode::concatenate(
249309124Sdim              VersionedInst->getMetadata(LLVMContext::MD_noalias),
250309124Sdim              NonAliasingScopeList->second));
251309124Sdim  }
252309124Sdim}
253309124Sdim
254309124Sdimnamespace {
255341825Sdim/// Also expose this is a pass.  Currently this is only used for
256309124Sdim/// unit-testing.  It adds all memchecks necessary to remove all may-aliasing
257309124Sdim/// array accesses from the loop.
258309124Sdimclass LoopVersioningPass : public FunctionPass {
259309124Sdimpublic:
260309124Sdim  LoopVersioningPass() : FunctionPass(ID) {
261309124Sdim    initializeLoopVersioningPassPass(*PassRegistry::getPassRegistry());
262309124Sdim  }
263309124Sdim
264309124Sdim  bool runOnFunction(Function &F) override {
265309124Sdim    auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
266309124Sdim    auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>();
267309124Sdim    auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
268309124Sdim    auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
269309124Sdim
270309124Sdim    // Build up a worklist of inner-loops to version. This is necessary as the
271309124Sdim    // act of versioning a loop creates new loops and can invalidate iterators
272309124Sdim    // across the loops.
273309124Sdim    SmallVector<Loop *, 8> Worklist;
274309124Sdim
275309124Sdim    for (Loop *TopLevelLoop : *LI)
276309124Sdim      for (Loop *L : depth_first(TopLevelLoop))
277309124Sdim        // We only handle inner-most loops.
278309124Sdim        if (L->empty())
279309124Sdim          Worklist.push_back(L);
280309124Sdim
281309124Sdim    // Now walk the identified inner loops.
282309124Sdim    bool Changed = false;
283309124Sdim    for (Loop *L : Worklist) {
284309124Sdim      const LoopAccessInfo &LAI = LAA->getInfo(L);
285353358Sdim      if (L->isLoopSimplifyForm() && !LAI.hasConvergentOp() &&
286353358Sdim          (LAI.getNumRuntimePointerChecks() ||
287353358Sdim           !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) {
288309124Sdim        LoopVersioning LVer(LAI, L, LI, DT, SE);
289309124Sdim        LVer.versionLoop();
290309124Sdim        LVer.annotateLoopWithNoAlias();
291309124Sdim        Changed = true;
292309124Sdim      }
293309124Sdim    }
294309124Sdim
295309124Sdim    return Changed;
296309124Sdim  }
297309124Sdim
298309124Sdim  void getAnalysisUsage(AnalysisUsage &AU) const override {
299309124Sdim    AU.addRequired<LoopInfoWrapperPass>();
300309124Sdim    AU.addPreserved<LoopInfoWrapperPass>();
301309124Sdim    AU.addRequired<LoopAccessLegacyAnalysis>();
302309124Sdim    AU.addRequired<DominatorTreeWrapperPass>();
303309124Sdim    AU.addPreserved<DominatorTreeWrapperPass>();
304309124Sdim    AU.addRequired<ScalarEvolutionWrapperPass>();
305309124Sdim  }
306309124Sdim
307309124Sdim  static char ID;
308309124Sdim};
309309124Sdim}
310309124Sdim
311309124Sdim#define LVER_OPTION "loop-versioning"
312309124Sdim#define DEBUG_TYPE LVER_OPTION
313309124Sdim
314309124Sdimchar LoopVersioningPass::ID;
315309124Sdimstatic const char LVer_name[] = "Loop Versioning";
316309124Sdim
317309124SdimINITIALIZE_PASS_BEGIN(LoopVersioningPass, LVER_OPTION, LVer_name, false, false)
318309124SdimINITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
319309124SdimINITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
320309124SdimINITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
321309124SdimINITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
322309124SdimINITIALIZE_PASS_END(LoopVersioningPass, LVER_OPTION, LVer_name, false, false)
323309124Sdim
324309124Sdimnamespace llvm {
325309124SdimFunctionPass *createLoopVersioningPass() {
326309124Sdim  return new LoopVersioningPass();
327309124Sdim}
328309124Sdim}
329