1292915Sdim//===-- CrossDSOCFI.cpp - Externalize this module's CFI checks ------------===// 2292915Sdim// 3292915Sdim// The LLVM Compiler Infrastructure 4292915Sdim// 5292915Sdim// This file is distributed under the University of Illinois Open Source 6292915Sdim// License. See LICENSE.TXT for details. 7292915Sdim// 8292915Sdim//===----------------------------------------------------------------------===// 9292915Sdim// 10292915Sdim// This pass exports all llvm.bitset's found in the module in the form of a 11292915Sdim// __cfi_check function, which can be used to verify cross-DSO call targets. 12292915Sdim// 13292915Sdim//===----------------------------------------------------------------------===// 14292915Sdim 15292915Sdim#include "llvm/Transforms/IPO.h" 16292915Sdim#include "llvm/ADT/DenseSet.h" 17292915Sdim#include "llvm/ADT/EquivalenceClasses.h" 18292915Sdim#include "llvm/ADT/Statistic.h" 19292915Sdim#include "llvm/IR/Constant.h" 20292915Sdim#include "llvm/IR/Constants.h" 21292915Sdim#include "llvm/IR/Function.h" 22292915Sdim#include "llvm/IR/GlobalObject.h" 23292915Sdim#include "llvm/IR/GlobalVariable.h" 24292915Sdim#include "llvm/IR/IRBuilder.h" 25292915Sdim#include "llvm/IR/Instructions.h" 26292915Sdim#include "llvm/IR/Intrinsics.h" 27292915Sdim#include "llvm/IR/MDBuilder.h" 28292915Sdim#include "llvm/IR/Module.h" 29292915Sdim#include "llvm/IR/Operator.h" 30292915Sdim#include "llvm/Pass.h" 31292915Sdim#include "llvm/Support/Debug.h" 32292915Sdim#include "llvm/Support/raw_ostream.h" 33292915Sdim#include "llvm/Transforms/Utils/BasicBlockUtils.h" 34292915Sdim 35292915Sdimusing namespace llvm; 36292915Sdim 37292915Sdim#define DEBUG_TYPE "cross-dso-cfi" 38292915Sdim 39292915SdimSTATISTIC(TypeIds, "Number of unique type identifiers"); 40292915Sdim 41292915Sdimnamespace { 42292915Sdim 43292915Sdimstruct CrossDSOCFI : public ModulePass { 44292915Sdim static char ID; 45292915Sdim CrossDSOCFI() : ModulePass(ID) { 46292915Sdim initializeCrossDSOCFIPass(*PassRegistry::getPassRegistry()); 47292915Sdim } 48292915Sdim 49292915Sdim Module *M; 50292915Sdim MDNode *VeryLikelyWeights; 51292915Sdim 52292915Sdim ConstantInt *extractBitSetTypeId(MDNode *MD); 53292915Sdim void buildCFICheck(); 54292915Sdim 55292915Sdim bool doInitialization(Module &M) override; 56292915Sdim bool runOnModule(Module &M) override; 57292915Sdim}; 58292915Sdim 59292915Sdim} // anonymous namespace 60292915Sdim 61292915SdimINITIALIZE_PASS_BEGIN(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false, 62292915Sdim false) 63292915SdimINITIALIZE_PASS_END(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false, false) 64292915Sdimchar CrossDSOCFI::ID = 0; 65292915Sdim 66292915SdimModulePass *llvm::createCrossDSOCFIPass() { return new CrossDSOCFI; } 67292915Sdim 68292915Sdimbool CrossDSOCFI::doInitialization(Module &Mod) { 69292915Sdim M = &Mod; 70292915Sdim VeryLikelyWeights = 71292915Sdim MDBuilder(M->getContext()).createBranchWeights((1U << 20) - 1, 1); 72292915Sdim 73292915Sdim return false; 74292915Sdim} 75292915Sdim 76292915Sdim/// extractBitSetTypeId - Extracts TypeId from a hash-based bitset MDNode. 77292915SdimConstantInt *CrossDSOCFI::extractBitSetTypeId(MDNode *MD) { 78292915Sdim // This check excludes vtables for classes inside anonymous namespaces. 79292915Sdim auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(0)); 80292915Sdim if (!TM) 81292915Sdim return nullptr; 82292915Sdim auto C = dyn_cast_or_null<ConstantInt>(TM->getValue()); 83292915Sdim if (!C) return nullptr; 84292915Sdim // We are looking for i64 constants. 85292915Sdim if (C->getBitWidth() != 64) return nullptr; 86292915Sdim 87292915Sdim // Sanity check. 88292915Sdim auto FM = dyn_cast_or_null<ValueAsMetadata>(MD->getOperand(1)); 89292915Sdim // Can be null if a function was removed by an optimization. 90292915Sdim if (FM) { 91292915Sdim auto F = dyn_cast<Function>(FM->getValue()); 92292915Sdim // But can never be a function declaration. 93292915Sdim assert(!F || !F->isDeclaration()); 94292915Sdim (void)F; // Suppress unused variable warning in the no-asserts build. 95292915Sdim } 96292915Sdim return C; 97292915Sdim} 98292915Sdim 99292915Sdim/// buildCFICheck - emits __cfi_check for the current module. 100292915Sdimvoid CrossDSOCFI::buildCFICheck() { 101292915Sdim // FIXME: verify that __cfi_check ends up near the end of the code section, 102292915Sdim // but before the jump slots created in LowerBitSets. 103292915Sdim llvm::DenseSet<uint64_t> BitSetIds; 104292915Sdim NamedMDNode *BitSetNM = M->getNamedMetadata("llvm.bitsets"); 105292915Sdim 106292915Sdim if (BitSetNM) 107292915Sdim for (unsigned I = 0, E = BitSetNM->getNumOperands(); I != E; ++I) 108292915Sdim if (ConstantInt *TypeId = extractBitSetTypeId(BitSetNM->getOperand(I))) 109292915Sdim BitSetIds.insert(TypeId->getZExtValue()); 110292915Sdim 111292915Sdim LLVMContext &Ctx = M->getContext(); 112292915Sdim Constant *C = M->getOrInsertFunction( 113292915Sdim "__cfi_check", 114292915Sdim FunctionType::get( 115292915Sdim Type::getVoidTy(Ctx), 116292915Sdim {Type::getInt64Ty(Ctx), PointerType::getUnqual(Type::getInt8Ty(Ctx))}, 117292915Sdim false)); 118292915Sdim Function *F = dyn_cast<Function>(C); 119292915Sdim F->setAlignment(4096); 120292915Sdim auto args = F->arg_begin(); 121292915Sdim Argument &CallSiteTypeId = *(args++); 122292915Sdim CallSiteTypeId.setName("CallSiteTypeId"); 123292915Sdim Argument &Addr = *(args++); 124292915Sdim Addr.setName("Addr"); 125292915Sdim assert(args == F->arg_end()); 126292915Sdim 127292915Sdim BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F); 128292915Sdim 129292915Sdim BasicBlock *TrapBB = BasicBlock::Create(Ctx, "trap", F); 130292915Sdim IRBuilder<> IRBTrap(TrapBB); 131292915Sdim Function *TrapFn = Intrinsic::getDeclaration(M, Intrinsic::trap); 132292915Sdim llvm::CallInst *TrapCall = IRBTrap.CreateCall(TrapFn); 133292915Sdim TrapCall->setDoesNotReturn(); 134292915Sdim TrapCall->setDoesNotThrow(); 135292915Sdim IRBTrap.CreateUnreachable(); 136292915Sdim 137292915Sdim BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F); 138292915Sdim IRBuilder<> IRBExit(ExitBB); 139292915Sdim IRBExit.CreateRetVoid(); 140292915Sdim 141292915Sdim IRBuilder<> IRB(BB); 142292915Sdim SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, BitSetIds.size()); 143292915Sdim for (uint64_t TypeId : BitSetIds) { 144292915Sdim ConstantInt *CaseTypeId = ConstantInt::get(Type::getInt64Ty(Ctx), TypeId); 145292915Sdim BasicBlock *TestBB = BasicBlock::Create(Ctx, "test", F); 146292915Sdim IRBuilder<> IRBTest(TestBB); 147292915Sdim Function *BitsetTestFn = 148292915Sdim Intrinsic::getDeclaration(M, Intrinsic::bitset_test); 149292915Sdim 150292915Sdim Value *Test = IRBTest.CreateCall( 151292915Sdim BitsetTestFn, {&Addr, MetadataAsValue::get( 152292915Sdim Ctx, ConstantAsMetadata::get(CaseTypeId))}); 153292915Sdim BranchInst *BI = IRBTest.CreateCondBr(Test, ExitBB, TrapBB); 154292915Sdim BI->setMetadata(LLVMContext::MD_prof, VeryLikelyWeights); 155292915Sdim 156292915Sdim SI->addCase(CaseTypeId, TestBB); 157292915Sdim ++TypeIds; 158292915Sdim } 159292915Sdim} 160292915Sdim 161292915Sdimbool CrossDSOCFI::runOnModule(Module &M) { 162292915Sdim if (M.getModuleFlag("Cross-DSO CFI") == nullptr) 163292915Sdim return false; 164292915Sdim buildCFICheck(); 165292915Sdim return true; 166292915Sdim} 167