1//===- AssumptionCache.cpp - Cache finding @llvm.assume calls -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a pass that keeps track of @llvm.assume and
10// @llvm.experimental.guard intrinsics in the functions of a module.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Analysis/AssumptionCache.h"
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/SmallPtrSet.h"
17#include "llvm/ADT/SmallVector.h"
18#include "llvm/Analysis/AssumeBundleQueries.h"
19#include "llvm/Analysis/TargetTransformInfo.h"
20#include "llvm/IR/BasicBlock.h"
21#include "llvm/IR/Function.h"
22#include "llvm/IR/InstrTypes.h"
23#include "llvm/IR/Instruction.h"
24#include "llvm/IR/Instructions.h"
25#include "llvm/IR/PassManager.h"
26#include "llvm/IR/PatternMatch.h"
27#include "llvm/InitializePasses.h"
28#include "llvm/Pass.h"
29#include "llvm/Support/Casting.h"
30#include "llvm/Support/CommandLine.h"
31#include "llvm/Support/ErrorHandling.h"
32#include "llvm/Support/raw_ostream.h"
33#include <cassert>
34#include <utility>
35
36using namespace llvm;
37using namespace llvm::PatternMatch;
38
39static cl::opt<bool>
40    VerifyAssumptionCache("verify-assumption-cache", cl::Hidden,
41                          cl::desc("Enable verification of assumption cache"),
42                          cl::init(false));
43
44SmallVector<AssumptionCache::ResultElem, 1> &
45AssumptionCache::getOrInsertAffectedValues(Value *V) {
46  // Try using find_as first to avoid creating extra value handles just for the
47  // purpose of doing the lookup.
48  auto AVI = AffectedValues.find_as(V);
49  if (AVI != AffectedValues.end())
50    return AVI->second;
51
52  auto AVIP = AffectedValues.insert(
53      {AffectedValueCallbackVH(V, this), SmallVector<ResultElem, 1>()});
54  return AVIP.first->second;
55}
56
/// Collect the values whose facts may be refined by the assume/guard call
/// \p CI and append them to \p Affected. Each entry pairs the value with an
/// index that says where the knowledge comes from:
///   * the operand-bundle index, for values taken from an operand bundle, or
///   * AssumptionCache::ExprResultIdx, for values derived from the condition
///     operand (argument 0), including the pointer reported by
///     TTI->getPredicatedAddrSpace when \p TTI is non-null.
///
/// Only Arguments and Instructions are ever recorded (see AddAffected).
/// The output may contain the same (value, index) pair more than once, e.g.
/// when both icmp operands are the same value; callers must tolerate that.
/// Entries appear in the order the checks below are performed.
static void
findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
                   SmallVectorImpl<AssumptionCache::ResultElem> &Affected) {
  // Note: This code must be kept in-sync with the code in
  // computeKnownBitsFromAssume in ValueTracking.

  // Record V under index Idx if it is an Argument or Instruction. For an
  // Instruction that is a bitcast, ptrtoint, or `not`, also record its
  // operand (again only if that operand is an Instruction or Argument).
  auto AddAffected = [&Affected](Value *V, unsigned Idx =
                                               AssumptionCache::ExprResultIdx) {
    if (isa<Argument>(V)) {
      Affected.push_back({V, Idx});
    } else if (auto *I = dyn_cast<Instruction>(V)) {
      Affected.push_back({I, Idx});

      // Peek through unary operators to find the source of the condition.
      Value *Op;
      if (match(I, m_BitCast(m_Value(Op))) ||
          match(I, m_PtrToInt(m_Value(Op))) || match(I, m_Not(m_Value(Op)))) {
        if (isa<Instruction>(Op) || isa<Argument>(Op))
          Affected.push_back({Op, Idx});
      }
    }
  };

  // Bundle-derived knowledge: for every operand bundle that has an
  // ABA_WasOn input and is not tagged IgnoreBundleTag, record that input
  // under the bundle's own index.
  for (unsigned Idx = 0; Idx != CI->getNumOperandBundles(); Idx++) {
    if (CI->getOperandBundleAt(Idx).Inputs.size() > ABA_WasOn &&
        CI->getOperandBundleAt(Idx).getTagName() != IgnoreBundleTag)
      AddAffected(CI->getOperandBundleAt(Idx).Inputs[ABA_WasOn], Idx);
  }

  // Condition-derived knowledge: the condition itself, then (for integer
  // comparisons) both operands and a few predicate-specific sub-patterns.
  Value *Cond = CI->getArgOperand(0), *A, *B;
  AddAffected(Cond);

  CmpInst::Predicate Pred;
  if (match(Cond, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
    AddAffected(A);
    AddAffected(B);

    if (Pred == ICmpInst::ICMP_EQ) {
      // For equality comparisons, we handle the case of bit inversion.
      // For one comparison operand V: look through a `not`, then record the
      // operands of a bitwise and/or/xor, or the shifted value of a shift by
      // a constant.
      auto AddAffectedFromEq = [&AddAffected](Value *V) {
        Value *A;
        if (match(V, m_Not(m_Value(A)))) {
          AddAffected(A);
          V = A;
        }

        Value *B;
        // (A & B) or (A | B) or (A ^ B).
        if (match(V, m_BitwiseLogic(m_Value(A), m_Value(B)))) {
          AddAffected(A);
          AddAffected(B);
        } else if (match(V, m_Shift(m_Value(A), m_ConstantInt()))) {
          // (A << C) or (A >>_s C) or (A >>_u C) where C is some constant.
          AddAffected(A);
        }
      };

      AddAffectedFromEq(A);
      AddAffectedFromEq(B);
    } else if (Pred == ICmpInst::ICMP_NE) {
      Value *X, *Y;
      // Handle (a & b != 0). If a/b is a power of 2 we can use this
      // information.
      if (match(A, m_And(m_Value(X), m_Value(Y))) && match(B, m_Zero())) {
        AddAffected(X);
        AddAffected(Y);
      }
    } else if (Pred == ICmpInst::ICMP_ULT) {
      Value *X;
      // Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4,
      // and recognized by LVI at least.
      if (match(A, m_Add(m_Value(X), m_ConstantInt())) &&
          match(B, m_ConstantInt()))
        AddAffected(X);
    }
  }

  // Target-specific knowledge: if TTI reports a pointer associated with Cond
  // via getPredicatedAddrSpace, record that pointer with in-bounds offsets
  // stripped. The address space value (AS) itself is not used here.
  if (TTI) {
    const Value *Ptr;
    unsigned AS;
    std::tie(Ptr, AS) = TTI->getPredicatedAddrSpace(Cond);
    if (Ptr)
      AddAffected(const_cast<Value *>(Ptr->stripInBoundsOffsets()));
  }
}
142
143void AssumptionCache::updateAffectedValues(CondGuardInst *CI) {
144  SmallVector<AssumptionCache::ResultElem, 16> Affected;
145  findAffectedValues(CI, TTI, Affected);
146
147  for (auto &AV : Affected) {
148    auto &AVV = getOrInsertAffectedValues(AV.Assume);
149    if (llvm::none_of(AVV, [&](ResultElem &Elem) {
150          return Elem.Assume == CI && Elem.Index == AV.Index;
151        }))
152      AVV.push_back({CI, AV.Index});
153  }
154}
155
156void AssumptionCache::unregisterAssumption(CondGuardInst *CI) {
157  SmallVector<AssumptionCache::ResultElem, 16> Affected;
158  findAffectedValues(CI, TTI, Affected);
159
160  for (auto &AV : Affected) {
161    auto AVI = AffectedValues.find_as(AV.Assume);
162    if (AVI == AffectedValues.end())
163      continue;
164    bool Found = false;
165    bool HasNonnull = false;
166    for (ResultElem &Elem : AVI->second) {
167      if (Elem.Assume == CI) {
168        Found = true;
169        Elem.Assume = nullptr;
170      }
171      HasNonnull |= !!Elem.Assume;
172      if (HasNonnull && Found)
173        break;
174    }
175    assert(Found && "already unregistered or incorrect cache state");
176    if (!HasNonnull)
177      AffectedValues.erase(AVI);
178  }
179
180  erase_value(AssumeHandles, CI);
181}
182
183void AssumptionCache::AffectedValueCallbackVH::deleted() {
184  AC->AffectedValues.erase(getValPtr());
185  // 'this' now dangles!
186}
187
188void AssumptionCache::transferAffectedValuesInCache(Value *OV, Value *NV) {
189  auto &NAVV = getOrInsertAffectedValues(NV);
190  auto AVI = AffectedValues.find(OV);
191  if (AVI == AffectedValues.end())
192    return;
193
194  for (auto &A : AVI->second)
195    if (!llvm::is_contained(NAVV, A))
196      NAVV.push_back(A);
197  AffectedValues.erase(OV);
198}
199
200void AssumptionCache::AffectedValueCallbackVH::allUsesReplacedWith(Value *NV) {
201  if (!isa<Instruction>(NV) && !isa<Argument>(NV))
202    return;
203
204  // Any assumptions that affected this value now affect the new value.
205
206  AC->transferAffectedValuesInCache(getValPtr(), NV);
207  // 'this' now might dangle! If the AffectedValues map was resized to add an
208  // entry for NV then this object might have been destroyed in favor of some
209  // copy in the grown map.
210}
211
212void AssumptionCache::scanFunction() {
213  assert(!Scanned && "Tried to scan the function twice!");
214  assert(AssumeHandles.empty() && "Already have assumes when scanning!");
215
216  // Go through all instructions in all blocks, add all calls to @llvm.assume
217  // to this cache.
218  for (BasicBlock &B : F)
219    for (Instruction &I : B)
220      if (isa<CondGuardInst>(&I))
221        AssumeHandles.push_back({&I, ExprResultIdx});
222
223  // Mark the scan as complete.
224  Scanned = true;
225
226  // Update affected values.
227  for (auto &A : AssumeHandles)
228    updateAffectedValues(cast<CondGuardInst>(A));
229}
230
231void AssumptionCache::registerAssumption(CondGuardInst *CI) {
232  // If we haven't scanned the function yet, just drop this assumption. It will
233  // be found when we scan later.
234  if (!Scanned)
235    return;
236
237  AssumeHandles.push_back({CI, ExprResultIdx});
238
239#ifndef NDEBUG
240  assert(CI->getParent() &&
241         "Cannot a register CondGuardInst not in a basic block");
242  assert(&F == CI->getParent()->getParent() &&
243         "Cannot a register CondGuardInst not in this function");
244
245  // We expect the number of assumptions to be small, so in an asserts build
246  // check that we don't accumulate duplicates and that all assumptions point
247  // to the same function.
248  SmallPtrSet<Value *, 16> AssumptionSet;
249  for (auto &VH : AssumeHandles) {
250    if (!VH)
251      continue;
252
253    assert(&F == cast<Instruction>(VH)->getParent()->getParent() &&
254           "Cached assumption not inside this function!");
255    assert(isa<CondGuardInst>(VH) &&
256           "Cached something other than CondGuardInst!");
257    assert(AssumptionSet.insert(VH).second &&
258           "Cache contains multiple copies of a call!");
259  }
260#endif
261
262  updateAffectedValues(CI);
263}
264
265AssumptionCache AssumptionAnalysis::run(Function &F,
266                                        FunctionAnalysisManager &FAM) {
267  auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
268  return AssumptionCache(F, &TTI);
269}
270
271AnalysisKey AssumptionAnalysis::Key;
272
273PreservedAnalyses AssumptionPrinterPass::run(Function &F,
274                                             FunctionAnalysisManager &AM) {
275  AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
276
277  OS << "Cached assumptions for function: " << F.getName() << "\n";
278  for (auto &VH : AC.assumptions())
279    if (VH)
280      OS << "  " << *cast<CallInst>(VH)->getArgOperand(0) << "\n";
281
282  return PreservedAnalyses::all();
283}
284
285void AssumptionCacheTracker::FunctionCallbackVH::deleted() {
286  auto I = ACT->AssumptionCaches.find_as(cast<Function>(getValPtr()));
287  if (I != ACT->AssumptionCaches.end())
288    ACT->AssumptionCaches.erase(I);
289  // 'this' now dangles!
290}
291
292AssumptionCache &AssumptionCacheTracker::getAssumptionCache(Function &F) {
293  // We probe the function map twice to try and avoid creating a value handle
294  // around the function in common cases. This makes insertion a bit slower,
295  // but if we have to insert we're going to scan the whole function so that
296  // shouldn't matter.
297  auto I = AssumptionCaches.find_as(&F);
298  if (I != AssumptionCaches.end())
299    return *I->second;
300
301  auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
302  auto *TTI = TTIWP ? &TTIWP->getTTI(F) : nullptr;
303
304  // Ok, build a new cache by scanning the function, insert it and the value
305  // handle into our map, and return the newly populated cache.
306  auto IP = AssumptionCaches.insert(std::make_pair(
307      FunctionCallbackVH(&F, this), std::make_unique<AssumptionCache>(F, TTI)));
308  assert(IP.second && "Scanning function already in the map?");
309  return *IP.first->second;
310}
311
312AssumptionCache *AssumptionCacheTracker::lookupAssumptionCache(Function &F) {
313  auto I = AssumptionCaches.find_as(&F);
314  if (I != AssumptionCaches.end())
315    return I->second.get();
316  return nullptr;
317}
318
319void AssumptionCacheTracker::verifyAnalysis() const {
320  // FIXME: In the long term the verifier should not be controllable with a
321  // flag. We should either fix all passes to correctly update the assumption
322  // cache and enable the verifier unconditionally or somehow arrange for the
323  // assumption list to be updated automatically by passes.
324  if (!VerifyAssumptionCache)
325    return;
326
327  SmallPtrSet<const CallInst *, 4> AssumptionSet;
328  for (const auto &I : AssumptionCaches) {
329    for (auto &VH : I.second->assumptions())
330      if (VH)
331        AssumptionSet.insert(cast<CallInst>(VH));
332
333    for (const BasicBlock &B : cast<Function>(*I.first))
334      for (const Instruction &II : B)
335        if (match(&II, m_Intrinsic<Intrinsic::assume>()) &&
336            !AssumptionSet.count(cast<CallInst>(&II)))
337          report_fatal_error("Assumption in scanned function not in cache");
338  }
339}
340
/// Legacy-pass-manager wrapper that owns one AssumptionCache per function.
/// The constructor makes sure the pass is initialized in the global
/// PassRegistry.
AssumptionCacheTracker::AssumptionCacheTracker() : ImmutablePass(ID) {
  initializeAssumptionCacheTrackerPass(*PassRegistry::getPassRegistry());
}

// Defined out of line in this file. NOTE(review): presumably so the owned
// per-function caches are destroyed where AssumptionCache is a complete
// type; confirm against the header.
AssumptionCacheTracker::~AssumptionCacheTracker() = default;

// Pass identifier handed to ImmutablePass above.
char AssumptionCacheTracker::ID = 0;

// Register the pass under the command-line name "assumption-cache-tracker".
// NOTE(review): the trailing (false, true) arguments are presumably
// "CFG-only = false, is-analysis = true" per INITIALIZE_PASS's parameter
// order; confirm.
INITIALIZE_PASS(AssumptionCacheTracker, "assumption-cache-tracker",
                "Assumption Cache Tracker", false, true)
351