// CrossDSOCFI.cpp, revision 292915
1213496Scognet//===-- CrossDSOCFI.cpp - Externalize this module's CFI checks ------------===// 2213496Scognet// 3213496Scognet// The LLVM Compiler Infrastructure 4213496Scognet// 5213496Scognet// This file is distributed under the University of Illinois Open Source 6213496Scognet// License. See LICENSE.TXT for details. 7213496Scognet// 8213496Scognet//===----------------------------------------------------------------------===// 9213496Scognet// 10213496Scognet// This pass exports all llvm.bitset's found in the module in the form of a 11213496Scognet// __cfi_check function, which can be used to verify cross-DSO call targets. 12213496Scognet// 13213496Scognet//===----------------------------------------------------------------------===// 14213496Scognet 15213496Scognet#include "llvm/Transforms/IPO.h" 16213496Scognet#include "llvm/ADT/DenseSet.h" 17213496Scognet#include "llvm/ADT/EquivalenceClasses.h" 18213496Scognet#include "llvm/ADT/Statistic.h" 19213496Scognet#include "llvm/IR/Constant.h" 20213496Scognet#include "llvm/IR/Constants.h" 21213496Scognet#include "llvm/IR/Function.h" 22213496Scognet#include "llvm/IR/GlobalObject.h" 23213496Scognet#include "llvm/IR/GlobalVariable.h" 24213496Scognet#include "llvm/IR/IRBuilder.h" 25213496Scognet#include "llvm/IR/Instructions.h" 26213496Scognet#include "llvm/IR/Intrinsics.h" 27234281Smarius#include "llvm/IR/MDBuilder.h" 28234281Smarius#include "llvm/IR/Module.h" 29234281Smarius#include "llvm/IR/Operator.h" 30213496Scognet#include "llvm/Pass.h" 31213496Scognet#include "llvm/Support/Debug.h" 32266196Sian#include "llvm/Support/raw_ostream.h" 33266196Sian#include "llvm/Transforms/Utils/BasicBlockUtils.h" 34213496Scognet 35213496Scognetusing namespace llvm; 36213496Scognet 37213496Scognet#define DEBUG_TYPE "cross-dso-cfi" 38213496Scognet 39213496ScognetSTATISTIC(TypeIds, "Number of unique type identifiers"); 40213496Scognet 41213496Scognetnamespace { 42213496Scognet 43213496Scognetstruct CrossDSOCFI : public ModulePass { 
44213496Scognet static char ID; 45213496Scognet CrossDSOCFI() : ModulePass(ID) { 46213496Scognet initializeCrossDSOCFIPass(*PassRegistry::getPassRegistry()); 47213496Scognet } 48213496Scognet 49213496Scognet Module *M; 50266196Sian MDNode *VeryLikelyWeights; 51266196Sian 52266196Sian ConstantInt *extractBitSetTypeId(MDNode *MD); 53266196Sian void buildCFICheck(); 54266196Sian 55266196Sian bool doInitialization(Module &M) override; 56213496Scognet bool runOnModule(Module &M) override; 57213496Scognet}; 58213496Scognet 59213496Scognet} // anonymous namespace 60213496Scognet 61213496ScognetINITIALIZE_PASS_BEGIN(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false, 62213496Scognet false) 63234281SmariusINITIALIZE_PASS_END(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false, false) 64234281Smariuschar CrossDSOCFI::ID = 0; 65213496Scognet 66213496ScognetModulePass *llvm::createCrossDSOCFIPass() { return new CrossDSOCFI; } 67213496Scognet 68213496Scognetbool CrossDSOCFI::doInitialization(Module &Mod) { 69213496Scognet M = &Mod; 70234281Smarius VeryLikelyWeights = 71213496Scognet MDBuilder(M->getContext()).createBranchWeights((1U << 20) - 1, 1); 72213496Scognet 73213496Scognet return false; 74213496Scognet} 75213496Scognet 76213496Scognet/// extractBitSetTypeId - Extracts TypeId from a hash-based bitset MDNode. 77234281SmariusConstantInt *CrossDSOCFI::extractBitSetTypeId(MDNode *MD) { 78213496Scognet // This check excludes vtables for classes inside anonymous namespaces. 79213496Scognet auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(0)); 80213496Scognet if (!TM) 81213496Scognet return nullptr; 82213496Scognet auto C = dyn_cast_or_null<ConstantInt>(TM->getValue()); 83213496Scognet if (!C) return nullptr; 84213496Scognet // We are looking for i64 constants. 85213496Scognet if (C->getBitWidth() != 64) return nullptr; 86213496Scognet 87213496Scognet // Sanity check. 
88213496Scognet auto FM = dyn_cast_or_null<ValueAsMetadata>(MD->getOperand(1)); 89213496Scognet // Can be null if a function was removed by an optimization. 90213496Scognet if (FM) { 91213496Scognet auto F = dyn_cast<Function>(FM->getValue()); 92213496Scognet // But can never be a function declaration. 93213496Scognet assert(!F || !F->isDeclaration()); 94213496Scognet (void)F; // Suppress unused variable warning in the no-asserts build. 95213496Scognet } 96213496Scognet return C; 97213496Scognet} 98213496Scognet 99213496Scognet/// buildCFICheck - emits __cfi_check for the current module. 100213496Scognetvoid CrossDSOCFI::buildCFICheck() { 101213496Scognet // FIXME: verify that __cfi_check ends up near the end of the code section, 102213496Scognet // but before the jump slots created in LowerBitSets. 103213496Scognet llvm::DenseSet<uint64_t> BitSetIds; 104213496Scognet NamedMDNode *BitSetNM = M->getNamedMetadata("llvm.bitsets"); 105213496Scognet 106213496Scognet if (BitSetNM) 107213496Scognet for (unsigned I = 0, E = BitSetNM->getNumOperands(); I != E; ++I) 108213496Scognet if (ConstantInt *TypeId = extractBitSetTypeId(BitSetNM->getOperand(I))) 109213496Scognet BitSetIds.insert(TypeId->getZExtValue()); 110213496Scognet 111213496Scognet LLVMContext &Ctx = M->getContext(); 112213496Scognet Constant *C = M->getOrInsertFunction( 113213496Scognet "__cfi_check", 114213496Scognet FunctionType::get( 115221025Scognet Type::getVoidTy(Ctx), 116213496Scognet {Type::getInt64Ty(Ctx), PointerType::getUnqual(Type::getInt8Ty(Ctx))}, 117213496Scognet false)); 118213496Scognet Function *F = dyn_cast<Function>(C); 119213496Scognet F->setAlignment(4096); 120213496Scognet auto args = F->arg_begin(); 121213496Scognet Argument &CallSiteTypeId = *(args++); 122213496Scognet CallSiteTypeId.setName("CallSiteTypeId"); 123213496Scognet Argument &Addr = *(args++); 124213496Scognet Addr.setName("Addr"); 125213496Scognet assert(args == F->arg_end()); 126213496Scognet 127213496Scognet BasicBlock *BB 
= BasicBlock::Create(Ctx, "entry", F); 128213496Scognet 129213496Scognet BasicBlock *TrapBB = BasicBlock::Create(Ctx, "trap", F); 130213496Scognet IRBuilder<> IRBTrap(TrapBB); 131213496Scognet Function *TrapFn = Intrinsic::getDeclaration(M, Intrinsic::trap); 132213496Scognet llvm::CallInst *TrapCall = IRBTrap.CreateCall(TrapFn); 133213496Scognet TrapCall->setDoesNotReturn(); 134213496Scognet TrapCall->setDoesNotThrow(); 135213496Scognet IRBTrap.CreateUnreachable(); 136213496Scognet 137213496Scognet BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F); 138213496Scognet IRBuilder<> IRBExit(ExitBB); 139213496Scognet IRBExit.CreateRetVoid(); 140213496Scognet 141213496Scognet IRBuilder<> IRB(BB); 142266196Sian SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, BitSetIds.size()); 143266196Sian for (uint64_t TypeId : BitSetIds) { 144266196Sian ConstantInt *CaseTypeId = ConstantInt::get(Type::getInt64Ty(Ctx), TypeId); 145266196Sian BasicBlock *TestBB = BasicBlock::Create(Ctx, "test", F); 146266196Sian IRBuilder<> IRBTest(TestBB); 147266196Sian Function *BitsetTestFn = 148213496Scognet Intrinsic::getDeclaration(M, Intrinsic::bitset_test); 149213496Scognet 150213496Scognet Value *Test = IRBTest.CreateCall( 151213496Scognet BitsetTestFn, {&Addr, MetadataAsValue::get( 152213496Scognet Ctx, ConstantAsMetadata::get(CaseTypeId))}); 153213496Scognet BranchInst *BI = IRBTest.CreateCondBr(Test, ExitBB, TrapBB); 154213496Scognet BI->setMetadata(LLVMContext::MD_prof, VeryLikelyWeights); 155213496Scognet 156213496Scognet SI->addCase(CaseTypeId, TestBB); 157213496Scognet ++TypeIds; 158213496Scognet } 159213496Scognet} 160213496Scognet 161213496Scognetbool CrossDSOCFI::runOnModule(Module &M) { 162213496Scognet if (M.getModuleFlag("Cross-DSO CFI") == nullptr) 163213496Scognet return false; 164213496Scognet buildCFICheck(); 165213496Scognet return true; 166213496Scognet} 167213496Scognet