1239310Sdim//===- BoundsChecking.cpp - Instrumentation for run-time bounds checking --===//
2239310Sdim//
3239310Sdim//                     The LLVM Compiler Infrastructure
4239310Sdim//
5239310Sdim// This file is distributed under the University of Illinois Open Source
6239310Sdim// License. See LICENSE.TXT for details.
7239310Sdim//
8239310Sdim//===----------------------------------------------------------------------===//
9239310Sdim//
10239310Sdim// This file implements a pass that instruments the code to perform run-time
11239310Sdim// bounds checking on loads, stores, and other memory intrinsics.
12239310Sdim//
13239310Sdim//===----------------------------------------------------------------------===//
14239310Sdim
15239310Sdim#define DEBUG_TYPE "bounds-checking"
16249423Sdim#include "llvm/Transforms/Instrumentation.h"
17239310Sdim#include "llvm/ADT/Statistic.h"
18239310Sdim#include "llvm/Analysis/MemoryBuiltins.h"
19249423Sdim#include "llvm/IR/DataLayout.h"
20249423Sdim#include "llvm/IR/IRBuilder.h"
21249423Sdim#include "llvm/IR/Intrinsics.h"
22249423Sdim#include "llvm/Pass.h"
23239310Sdim#include "llvm/Support/CommandLine.h"
24239310Sdim#include "llvm/Support/Debug.h"
25239310Sdim#include "llvm/Support/InstIterator.h"
26239310Sdim#include "llvm/Support/TargetFolder.h"
27239310Sdim#include "llvm/Support/raw_ostream.h"
28243830Sdim#include "llvm/Target/TargetLibraryInfo.h"
29239310Sdimusing namespace llvm;
30239310Sdim
31239310Sdimstatic cl::opt<bool> SingleTrapBB("bounds-checking-single-trap",
32239310Sdim                                  cl::desc("Use one trap block per function"));
33239310Sdim
34239310SdimSTATISTIC(ChecksAdded, "Bounds checks added");
35239310SdimSTATISTIC(ChecksSkipped, "Bounds checks skipped");
36239310SdimSTATISTIC(ChecksUnable, "Bounds checks unable to add");
37239310Sdim
38239310Sdimtypedef IRBuilder<true, TargetFolder> BuilderTy;
39239310Sdim
40239310Sdimnamespace {
41239310Sdim  struct BoundsChecking : public FunctionPass {
42239310Sdim    static char ID;
43239310Sdim
44249423Sdim    BoundsChecking() : FunctionPass(ID) {
45239310Sdim      initializeBoundsCheckingPass(*PassRegistry::getPassRegistry());
46239310Sdim    }
47239310Sdim
48239310Sdim    virtual bool runOnFunction(Function &F);
49239310Sdim
50239310Sdim    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
51243830Sdim      AU.addRequired<DataLayout>();
52243830Sdim      AU.addRequired<TargetLibraryInfo>();
53239310Sdim    }
54239310Sdim
55239310Sdim  private:
56243830Sdim    const DataLayout *TD;
57243830Sdim    const TargetLibraryInfo *TLI;
58239310Sdim    ObjectSizeOffsetEvaluator *ObjSizeEval;
59239310Sdim    BuilderTy *Builder;
60239310Sdim    Instruction *Inst;
61239310Sdim    BasicBlock *TrapBB;
62239310Sdim
63239310Sdim    BasicBlock *getTrapBB();
64239310Sdim    void emitBranchToTrap(Value *Cmp = 0);
65239310Sdim    bool computeAllocSize(Value *Ptr, APInt &Offset, Value* &OffsetValue,
66239310Sdim                          APInt &Size, Value* &SizeValue);
67239310Sdim    bool instrument(Value *Ptr, Value *Val);
68239310Sdim };
69239310Sdim}
70239310Sdim
71239310Sdimchar BoundsChecking::ID = 0;
72239310SdimINITIALIZE_PASS(BoundsChecking, "bounds-checking", "Run-time bounds checking",
73239310Sdim                false, false)
74239310Sdim
75239310Sdim
76239310Sdim/// getTrapBB - create a basic block that traps. All overflowing conditions
77239310Sdim/// branch to this block. There's only one trap block per function.
78239310SdimBasicBlock *BoundsChecking::getTrapBB() {
79239310Sdim  if (TrapBB && SingleTrapBB)
80239310Sdim    return TrapBB;
81239310Sdim
82239310Sdim  Function *Fn = Inst->getParent()->getParent();
83263508Sdim  IRBuilder<>::InsertPointGuard Guard(*Builder);
84239310Sdim  TrapBB = BasicBlock::Create(Fn->getContext(), "trap", Fn);
85239310Sdim  Builder->SetInsertPoint(TrapBB);
86239310Sdim
87239310Sdim  llvm::Value *F = Intrinsic::getDeclaration(Fn->getParent(), Intrinsic::trap);
88239310Sdim  CallInst *TrapCall = Builder->CreateCall(F);
89239310Sdim  TrapCall->setDoesNotReturn();
90239310Sdim  TrapCall->setDoesNotThrow();
91239310Sdim  TrapCall->setDebugLoc(Inst->getDebugLoc());
92239310Sdim  Builder->CreateUnreachable();
93239310Sdim
94239310Sdim  return TrapBB;
95239310Sdim}
96239310Sdim
97239310Sdim
98239310Sdim/// emitBranchToTrap - emit a branch instruction to a trap block.
99239310Sdim/// If Cmp is non-null, perform a jump only if its value evaluates to true.
100239310Sdimvoid BoundsChecking::emitBranchToTrap(Value *Cmp) {
101239310Sdim  // check if the comparison is always false
102239310Sdim  ConstantInt *C = dyn_cast_or_null<ConstantInt>(Cmp);
103239310Sdim  if (C) {
104239310Sdim    ++ChecksSkipped;
105239310Sdim    if (!C->getZExtValue())
106239310Sdim      return;
107239310Sdim    else
108239310Sdim      Cmp = 0; // unconditional branch
109239310Sdim  }
110249423Sdim  ++ChecksAdded;
111239310Sdim
112239310Sdim  Instruction *Inst = Builder->GetInsertPoint();
113239310Sdim  BasicBlock *OldBB = Inst->getParent();
114239310Sdim  BasicBlock *Cont = OldBB->splitBasicBlock(Inst);
115239310Sdim  OldBB->getTerminator()->eraseFromParent();
116239310Sdim
117239310Sdim  if (Cmp)
118239310Sdim    BranchInst::Create(getTrapBB(), Cont, Cmp, OldBB);
119239310Sdim  else
120239310Sdim    BranchInst::Create(getTrapBB(), OldBB);
121239310Sdim}
122239310Sdim
123239310Sdim
124239310Sdim/// instrument - adds run-time bounds checks to memory accessing instructions.
125239310Sdim/// Ptr is the pointer that will be read/written, and InstVal is either the
126239310Sdim/// result from the load or the value being stored. It is used to determine the
127239310Sdim/// size of memory block that is touched.
128239310Sdim/// Returns true if any change was made to the IR, false otherwise.
129239310Sdimbool BoundsChecking::instrument(Value *Ptr, Value *InstVal) {
130239310Sdim  uint64_t NeededSize = TD->getTypeStoreSize(InstVal->getType());
131239310Sdim  DEBUG(dbgs() << "Instrument " << *Ptr << " for " << Twine(NeededSize)
132239310Sdim              << " bytes\n");
133239310Sdim
134239310Sdim  SizeOffsetEvalType SizeOffset = ObjSizeEval->compute(Ptr);
135239310Sdim
136239310Sdim  if (!ObjSizeEval->bothKnown(SizeOffset)) {
137239310Sdim    ++ChecksUnable;
138239310Sdim    return false;
139239310Sdim  }
140239310Sdim
141239310Sdim  Value *Size   = SizeOffset.first;
142239310Sdim  Value *Offset = SizeOffset.second;
143239310Sdim  ConstantInt *SizeCI = dyn_cast<ConstantInt>(Size);
144239310Sdim
145243830Sdim  Type *IntTy = TD->getIntPtrType(Ptr->getType());
146239310Sdim  Value *NeededSizeVal = ConstantInt::get(IntTy, NeededSize);
147239310Sdim
148239310Sdim  // three checks are required to ensure safety:
149239310Sdim  // . Offset >= 0  (since the offset is given from the base ptr)
150239310Sdim  // . Size >= Offset  (unsigned)
151239310Sdim  // . Size - Offset >= NeededSize  (unsigned)
152239310Sdim  //
153239310Sdim  // optimization: if Size >= 0 (signed), skip 1st check
154239310Sdim  // FIXME: add NSW/NUW here?  -- we dont care if the subtraction overflows
155239310Sdim  Value *ObjSize = Builder->CreateSub(Size, Offset);
156239310Sdim  Value *Cmp2 = Builder->CreateICmpULT(Size, Offset);
157239310Sdim  Value *Cmp3 = Builder->CreateICmpULT(ObjSize, NeededSizeVal);
158239310Sdim  Value *Or = Builder->CreateOr(Cmp2, Cmp3);
159239310Sdim  if (!SizeCI || SizeCI->getValue().slt(0)) {
160239310Sdim    Value *Cmp1 = Builder->CreateICmpSLT(Offset, ConstantInt::get(IntTy, 0));
161239310Sdim    Or = Builder->CreateOr(Cmp1, Or);
162239310Sdim  }
163239310Sdim  emitBranchToTrap(Or);
164239310Sdim
165239310Sdim  return true;
166239310Sdim}
167239310Sdim
168239310Sdimbool BoundsChecking::runOnFunction(Function &F) {
169243830Sdim  TD = &getAnalysis<DataLayout>();
170243830Sdim  TLI = &getAnalysis<TargetLibraryInfo>();
171239310Sdim
172239310Sdim  TrapBB = 0;
173239310Sdim  BuilderTy TheBuilder(F.getContext(), TargetFolder(TD));
174239310Sdim  Builder = &TheBuilder;
175263508Sdim  ObjectSizeOffsetEvaluator TheObjSizeEval(TD, TLI, F.getContext(),
176263508Sdim                                           /*RoundToAlign=*/true);
177239310Sdim  ObjSizeEval = &TheObjSizeEval;
178239310Sdim
179239310Sdim  // check HANDLE_MEMORY_INST in include/llvm/Instruction.def for memory
180239310Sdim  // touching instructions
181239310Sdim  std::vector<Instruction*> WorkList;
182239310Sdim  for (inst_iterator i = inst_begin(F), e = inst_end(F); i != e; ++i) {
183239310Sdim    Instruction *I = &*i;
184239310Sdim    if (isa<LoadInst>(I) || isa<StoreInst>(I) || isa<AtomicCmpXchgInst>(I) ||
185239310Sdim        isa<AtomicRMWInst>(I))
186239310Sdim        WorkList.push_back(I);
187239310Sdim  }
188239310Sdim
189239310Sdim  bool MadeChange = false;
190239310Sdim  for (std::vector<Instruction*>::iterator i = WorkList.begin(),
191239310Sdim       e = WorkList.end(); i != e; ++i) {
192239310Sdim    Inst = *i;
193239310Sdim
194239310Sdim    Builder->SetInsertPoint(Inst);
195239310Sdim    if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
196239310Sdim      MadeChange |= instrument(LI->getPointerOperand(), LI);
197239310Sdim    } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
198239310Sdim      MadeChange |= instrument(SI->getPointerOperand(), SI->getValueOperand());
199239310Sdim    } else if (AtomicCmpXchgInst *AI = dyn_cast<AtomicCmpXchgInst>(Inst)) {
200239310Sdim      MadeChange |= instrument(AI->getPointerOperand(),AI->getCompareOperand());
201239310Sdim    } else if (AtomicRMWInst *AI = dyn_cast<AtomicRMWInst>(Inst)) {
202239310Sdim      MadeChange |= instrument(AI->getPointerOperand(), AI->getValOperand());
203239310Sdim    } else {
204239310Sdim      llvm_unreachable("unknown Instruction type");
205239310Sdim    }
206239310Sdim  }
207239310Sdim  return MadeChange;
208239310Sdim}
209239310Sdim
210249423SdimFunctionPass *llvm::createBoundsCheckingPass() {
211249423Sdim  return new BoundsChecking();
212239310Sdim}
213