1//===------------ BPFCheckAndAdjustIR.cpp - Check and Adjust IR -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Check the IR and adjust it so that verifier-friendly code is generated.
10// The following are done for IR checking:
11//   - no relocation globals in PHI node.
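// The following are done for IR adjustment:
//   - remove __builtin_bpf_passthrough builtins. Target independent IR
//     optimizations are done and those builtins can be removed.
//   - remove __builtin_bpf_compare builtins, replacing them with plain icmp
//     instructions.
//   - sink min/max calls into the icmp instructions that use them, when the
//     icmp is inside a loop and the min/max call is outside of that loop.
//   - remove llvm.bpf.getelementptr.and.load builtins.
//   - remove llvm.bpf.getelementptr.and.store builtins.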
17//
18//===----------------------------------------------------------------------===//
19
#include "BPF.h"
#include "BPFCORE.h"
#include "BPFTargetMachine.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsBPF.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
36
37#define DEBUG_TYPE "bpf-check-and-opt-ir"
38
39using namespace llvm;
40
41namespace {
42
43class BPFCheckAndAdjustIR final : public ModulePass {
44  bool runOnModule(Module &F) override;
45
46public:
47  static char ID;
48  BPFCheckAndAdjustIR() : ModulePass(ID) {}
49  virtual void getAnalysisUsage(AnalysisUsage &AU) const override;
50
51private:
52  void checkIR(Module &M);
53  bool adjustIR(Module &M);
54  bool removePassThroughBuiltin(Module &M);
55  bool removeCompareBuiltin(Module &M);
56  bool sinkMinMax(Module &M);
57  bool removeGEPBuiltins(Module &M);
58};
59} // End anonymous namespace
60
// Out-of-class definition of the static pass identifier declared in
// BPFCheckAndAdjustIR; the constructor hands it to ModulePass.
char BPFCheckAndAdjustIR::ID = 0;
// Registers the pass under the DEBUG_TYPE name with a human-readable
// description.
// NOTE(review): the two trailing `false` arguments are presumably the
// "CFG only" and "is analysis" flags of INITIALIZE_PASS — confirm against
// llvm/PassSupport.h.
INITIALIZE_PASS(BPFCheckAndAdjustIR, DEBUG_TYPE, "BPF Check And Adjust IR",
                false, false)
64
65ModulePass *llvm::createBPFCheckAndAdjustIR() {
66  return new BPFCheckAndAdjustIR();
67}
68
69void BPFCheckAndAdjustIR::checkIR(Module &M) {
70  // Ensure relocation global won't appear in PHI node
71  // This may happen if the compiler generated the following code:
72  //   B1:
73  //      g1 = @llvm.skb_buff:0:1...
74  //      ...
75  //      goto B_COMMON
76  //   B2:
77  //      g2 = @llvm.skb_buff:0:2...
78  //      ...
79  //      goto B_COMMON
80  //   B_COMMON:
81  //      g = PHI(g1, g2)
82  //      x = load g
83  //      ...
84  // If anything likes the above "g = PHI(g1, g2)", issue a fatal error.
85  for (Function &F : M)
86    for (auto &BB : F)
87      for (auto &I : BB) {
88        PHINode *PN = dyn_cast<PHINode>(&I);
89        if (!PN || PN->use_empty())
90          continue;
91        for (int i = 0, e = PN->getNumIncomingValues(); i < e; ++i) {
92          auto *GV = dyn_cast<GlobalVariable>(PN->getIncomingValue(i));
93          if (!GV)
94            continue;
95          if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
96              GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
97            report_fatal_error("relocation global in PHI node");
98        }
99      }
100}
101
102bool BPFCheckAndAdjustIR::removePassThroughBuiltin(Module &M) {
103  // Remove __builtin_bpf_passthrough()'s which are used to prevent
104  // certain IR optimizations. Now major IR optimizations are done,
105  // remove them.
106  bool Changed = false;
107  CallInst *ToBeDeleted = nullptr;
108  for (Function &F : M)
109    for (auto &BB : F)
110      for (auto &I : BB) {
111        if (ToBeDeleted) {
112          ToBeDeleted->eraseFromParent();
113          ToBeDeleted = nullptr;
114        }
115
116        auto *Call = dyn_cast<CallInst>(&I);
117        if (!Call)
118          continue;
119        auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
120        if (!GV)
121          continue;
122        if (!GV->getName().starts_with("llvm.bpf.passthrough"))
123          continue;
124        Changed = true;
125        Value *Arg = Call->getArgOperand(1);
126        Call->replaceAllUsesWith(Arg);
127        ToBeDeleted = Call;
128      }
129  return Changed;
130}
131
132bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) {
133  // Remove __builtin_bpf_compare()'s which are used to prevent
134  // certain IR optimizations. Now major IR optimizations are done,
135  // remove them.
136  bool Changed = false;
137  CallInst *ToBeDeleted = nullptr;
138  for (Function &F : M)
139    for (auto &BB : F)
140      for (auto &I : BB) {
141        if (ToBeDeleted) {
142          ToBeDeleted->eraseFromParent();
143          ToBeDeleted = nullptr;
144        }
145
146        auto *Call = dyn_cast<CallInst>(&I);
147        if (!Call)
148          continue;
149        auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
150        if (!GV)
151          continue;
152        if (!GV->getName().starts_with("llvm.bpf.compare"))
153          continue;
154
155        Changed = true;
156        Value *Arg0 = Call->getArgOperand(0);
157        Value *Arg1 = Call->getArgOperand(1);
158        Value *Arg2 = Call->getArgOperand(2);
159
160        auto OpVal = cast<ConstantInt>(Arg0)->getValue().getZExtValue();
161        CmpInst::Predicate Opcode = (CmpInst::Predicate)OpVal;
162
163        auto *ICmp = new ICmpInst(Opcode, Arg1, Arg2);
164        ICmp->insertBefore(Call);
165
166        Call->replaceAllUsesWith(ICmp);
167        ToBeDeleted = Call;
168      }
169  return Changed;
170}
171
172struct MinMaxSinkInfo {
173  ICmpInst *ICmp;
174  Value *Other;
175  ICmpInst::Predicate Predicate;
176  CallInst *MinMax;
177  ZExtInst *ZExt;
178  SExtInst *SExt;
179
180  MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate)
181      : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr),
182        ZExt(nullptr), SExt(nullptr) {}
183};
184
185static bool sinkMinMaxInBB(BasicBlock &BB,
186                           const std::function<bool(Instruction *)> &Filter) {
187  // Check if V is:
188  //   (fn %a %b) or (ext (fn %a %b))
189  // Where:
190  //   ext := sext | zext
191  //   fn  := smin | umin | smax | umax
192  auto IsMinMaxCall = [=](Value *V, MinMaxSinkInfo &Info) {
193    if (auto *ZExt = dyn_cast<ZExtInst>(V)) {
194      V = ZExt->getOperand(0);
195      Info.ZExt = ZExt;
196    } else if (auto *SExt = dyn_cast<SExtInst>(V)) {
197      V = SExt->getOperand(0);
198      Info.SExt = SExt;
199    }
200
201    auto *Call = dyn_cast<CallInst>(V);
202    if (!Call)
203      return false;
204
205    auto *Called = dyn_cast<Function>(Call->getCalledOperand());
206    if (!Called)
207      return false;
208
209    switch (Called->getIntrinsicID()) {
210    case Intrinsic::smin:
211    case Intrinsic::umin:
212    case Intrinsic::smax:
213    case Intrinsic::umax:
214      break;
215    default:
216      return false;
217    }
218
219    if (!Filter(Call))
220      return false;
221
222    Info.MinMax = Call;
223
224    return true;
225  };
226
227  auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V,
228                             MinMaxSinkInfo &Info) {
229    if (Info.SExt) {
230      if (Info.SExt->getType() == V->getType())
231        return V;
232      return Builder.CreateSExt(V, Info.SExt->getType());
233    }
234    if (Info.ZExt) {
235      if (Info.ZExt->getType() == V->getType())
236        return V;
237      return Builder.CreateZExt(V, Info.ZExt->getType());
238    }
239    return V;
240  };
241
242  bool Changed = false;
243  SmallVector<MinMaxSinkInfo, 2> SinkList;
244
245  // Check BB for instructions like:
246  //   insn := (icmp %a (fn ...)) | (icmp (fn ...)  %a)
247  //
248  // Where:
249  //   fn := min | max | (sext (min ...)) | (sext (max ...))
250  //
251  // Put such instructions to SinkList.
252  for (Instruction &I : BB) {
253    ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
254    if (!ICmp)
255      continue;
256    if (!ICmp->isRelational())
257      continue;
258    MinMaxSinkInfo First(ICmp, ICmp->getOperand(1),
259                         ICmpInst::getSwappedPredicate(ICmp->getPredicate()));
260    MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate());
261    bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First);
262    bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second);
263    if (!(FirstMinMax ^ SecondMinMax))
264      continue;
265    SinkList.push_back(FirstMinMax ? First : Second);
266  }
267
268  // Iterate SinkList and replace each (icmp ...) with corresponding
269  // `x < a && x < b` or similar expression.
270  for (auto &Info : SinkList) {
271    ICmpInst *ICmp = Info.ICmp;
272    CallInst *MinMax = Info.MinMax;
273    Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID();
274    ICmpInst::Predicate P = Info.Predicate;
275    if (ICmpInst::isSigned(P) && IID != Intrinsic::smin &&
276        IID != Intrinsic::smax)
277      continue;
278
279    IRBuilder<> Builder(ICmp);
280    Value *X = Info.Other;
281    Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info);
282    Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info);
283    bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin;
284    bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax;
285    bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P);
286    bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P);
287    assert(IsMin ^ IsMax);
288    assert(IsLess ^ IsGreater);
289
290    Value *Replacement;
291    Value *LHS = Builder.CreateICmp(P, X, A);
292    Value *RHS = Builder.CreateICmp(P, X, B);
293    if ((IsLess && IsMin) || (IsGreater && IsMax))
294      // x < min(a, b) -> x < a && x < b
295      // x > max(a, b) -> x > a && x > b
296      Replacement = Builder.CreateLogicalAnd(LHS, RHS);
297    else
298      // x > min(a, b) -> x > a || x > b
299      // x < max(a, b) -> x < a || x < b
300      Replacement = Builder.CreateLogicalOr(LHS, RHS);
301
302    ICmp->replaceAllUsesWith(Replacement);
303
304    Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax};
305    for (Instruction *I : ToRemove)
306      if (I && I->use_empty())
307        I->eraseFromParent();
308
309    Changed = true;
310  }
311
312  return Changed;
313}
314
315// Do the following transformation:
316//
317//   x < min(a, b) -> x < a && x < b
318//   x > min(a, b) -> x > a || x > b
319//   x < max(a, b) -> x < a || x < b
320//   x > max(a, b) -> x > a && x > b
321//
322// Such patterns are introduced by LICM.cpp:hoistMinMax()
323// transformation and might lead to BPF verification failures for
324// older kernels.
325//
326// To minimize "collateral" changes only do it for icmp + min/max
327// calls when icmp is inside a loop and min/max is outside of that
328// loop.
329//
330// Verification failure happens when:
331// - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1;
332// - verifier can recognize RHS as a constant scalar in some context;
333// - verifier can't recognize RHS1 as a constant scalar in the same
334//   context;
335//
336// The "constant scalar" is not a compile time constant, but a register
337// that holds a scalar value known to verifier at some point in time
338// during abstract interpretation.
339//
340// See also:
341//   https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/
342bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) {
343  bool Changed = false;
344
345  for (Function &F : M) {
346    if (F.isDeclaration())
347      continue;
348
349    LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo();
350    for (Loop *L : LI)
351      for (BasicBlock *BB : L->blocks()) {
352        // Filter out instructions coming from the same loop
353        Loop *BBLoop = LI.getLoopFor(BB);
354        auto OtherLoopFilter = [&](Instruction *I) {
355          return LI.getLoopFor(I->getParent()) != BBLoop;
356        };
357        Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter);
358      }
359  }
360
361  return Changed;
362}
363
// Declares the analyses this pass needs: sinkMinMax() requests
// LoopInfoWrapperPass for every function definition, so it is marked as
// required here.
void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<LoopInfoWrapperPass>();
}
367
368static void unrollGEPLoad(CallInst *Call) {
369  auto [GEP, Load] = BPFPreserveStaticOffsetPass::reconstructLoad(Call);
370  GEP->insertBefore(Call);
371  Load->insertBefore(Call);
372  Call->replaceAllUsesWith(Load);
373  Call->eraseFromParent();
374}
375
376static void unrollGEPStore(CallInst *Call) {
377  auto [GEP, Store] = BPFPreserveStaticOffsetPass::reconstructStore(Call);
378  GEP->insertBefore(Call);
379  Store->insertBefore(Call);
380  Call->eraseFromParent();
381}
382
383static bool removeGEPBuiltinsInFunc(Function &F) {
384  SmallVector<CallInst *> GEPLoads;
385  SmallVector<CallInst *> GEPStores;
386  for (auto &BB : F)
387    for (auto &Insn : BB)
388      if (auto *Call = dyn_cast<CallInst>(&Insn))
389        if (auto *Called = Call->getCalledFunction())
390          switch (Called->getIntrinsicID()) {
391          case Intrinsic::bpf_getelementptr_and_load:
392            GEPLoads.push_back(Call);
393            break;
394          case Intrinsic::bpf_getelementptr_and_store:
395            GEPStores.push_back(Call);
396            break;
397          }
398
399  if (GEPLoads.empty() && GEPStores.empty())
400    return false;
401
402  for_each(GEPLoads, unrollGEPLoad);
403  for_each(GEPStores, unrollGEPStore);
404
405  return true;
406}
407
408// Rewrites the following builtins:
409// - llvm.bpf.getelementptr.and.load
410// - llvm.bpf.getelementptr.and.store
411// As (load (getelementptr ...)) or (store (getelementptr ...)).
412bool BPFCheckAndAdjustIR::removeGEPBuiltins(Module &M) {
413  bool Changed = false;
414  for (auto &F : M)
415    Changed = removeGEPBuiltinsInFunc(F) || Changed;
416  return Changed;
417}
418
419bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
420  bool Changed = removePassThroughBuiltin(M);
421  Changed = removeCompareBuiltin(M) || Changed;
422  Changed = sinkMinMax(M) || Changed;
423  Changed = removeGEPBuiltins(M) || Changed;
424  return Changed;
425}
426
427bool BPFCheckAndAdjustIR::runOnModule(Module &M) {
428  checkIR(M);
429  return adjustIR(M);
430}
431