1133589Smarius//===-- CrossDSOCFI.cpp - Externalize this module's CFI checks ------------===// 2133589Smarius// 3133589Smarius// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4133589Smarius// See https://llvm.org/LICENSE.txt for license information. 5133589Smarius// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6133589Smarius// 7133589Smarius//===----------------------------------------------------------------------===// 8133589Smarius// 9133589Smarius// This pass exports all llvm.bitset's found in the module in the form of a 10133589Smarius// __cfi_check function, which can be used to verify cross-DSO call targets. 11133589Smarius// 12133589Smarius//===----------------------------------------------------------------------===// 13133589Smarius 14133589Smarius#include "llvm/Transforms/IPO/CrossDSOCFI.h" 15133589Smarius#include "llvm/ADT/SetVector.h" 16133589Smarius#include "llvm/ADT/Statistic.h" 17133589Smarius#include "llvm/IR/Constants.h" 18133589Smarius#include "llvm/IR/Function.h" 19133589Smarius#include "llvm/IR/GlobalObject.h" 20133589Smarius#include "llvm/IR/IRBuilder.h" 21133589Smarius#include "llvm/IR/Instructions.h" 22133589Smarius#include "llvm/IR/Intrinsics.h" 23133589Smarius#include "llvm/IR/MDBuilder.h" 24133589Smarius#include "llvm/IR/Module.h" 25133589Smarius#include "llvm/TargetParser/Triple.h" 26133589Smarius#include "llvm/Transforms/IPO.h" 27133589Smarius 28133589Smariususing namespace llvm; 29133589Smarius 30133589Smarius#define DEBUG_TYPE "cross-dso-cfi" 31133589Smarius 32133589SmariusSTATISTIC(NumTypeIds, "Number of unique type identifiers"); 33133589Smarius 34133589Smariusnamespace { 35133589Smarius 36133589Smariusstruct CrossDSOCFI { 37133589Smarius MDNode *VeryLikelyWeights; 38133589Smarius 39133589Smarius ConstantInt *extractNumericTypeId(MDNode *MD); 40133589Smarius void buildCFICheck(Module &M); 41133589Smarius bool runOnModule(Module &M); 42133589Smarius}; 43133589Smarius 44133589Smarius} // anonymous namespace 45133589Smarius 46133589Smarius/// Extracts a numeric type identifier from an MDNode containing type metadata. 47133589SmariusConstantInt *CrossDSOCFI::extractNumericTypeId(MDNode *MD) { 48133589Smarius // This check excludes vtables for classes inside anonymous namespaces. 49133589Smarius auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(1)); 50133589Smarius if (!TM) 51133589Smarius return nullptr; 52133589Smarius auto C = dyn_cast_or_null<ConstantInt>(TM->getValue()); 53133589Smarius if (!C) return nullptr; 54133589Smarius // We are looking for i64 constants. 55133589Smarius if (C->getBitWidth() != 64) return nullptr; 56133589Smarius 57133589Smarius return C; 58133589Smarius} 59133589Smarius 60133589Smarius/// buildCFICheck - emits __cfi_check for the current module. 61133589Smariusvoid CrossDSOCFI::buildCFICheck(Module &M) { 62133589Smarius // FIXME: verify that __cfi_check ends up near the end of the code section, 63133589Smarius // but before the jump slots created in LowerTypeTests. 64133589Smarius SetVector<uint64_t> TypeIds; 65133589Smarius SmallVector<MDNode *, 2> Types; 66133589Smarius for (GlobalObject &GO : M.global_objects()) { 67133589Smarius Types.clear(); 68133589Smarius GO.getMetadata(LLVMContext::MD_type, Types); 69133589Smarius for (MDNode *Type : Types) 70133589Smarius if (ConstantInt *TypeId = extractNumericTypeId(Type)) 71133589Smarius TypeIds.insert(TypeId->getZExtValue()); 72133589Smarius } 73133589Smarius 74 NamedMDNode *CfiFunctionsMD = M.getNamedMetadata("cfi.functions"); 75 if (CfiFunctionsMD) { 76 for (auto *Func : CfiFunctionsMD->operands()) { 77 assert(Func->getNumOperands() >= 2); 78 for (unsigned I = 2; I < Func->getNumOperands(); ++I) 79 if (ConstantInt *TypeId = 80 extractNumericTypeId(cast<MDNode>(Func->getOperand(I).get()))) 81 TypeIds.insert(TypeId->getZExtValue()); 82 } 83 } 84 85 LLVMContext &Ctx = M.getContext(); 86 FunctionCallee C = M.getOrInsertFunction( 87 "__cfi_check", Type::getVoidTy(Ctx), Type::getInt64Ty(Ctx), 88 PointerType::getUnqual(Ctx), PointerType::getUnqual(Ctx)); 89 Function *F = cast<Function>(C.getCallee()); 90 // Take over the existing function. The frontend emits a weak stub so that the 91 // linker knows about the symbol; this pass replaces the function body. 92 F->deleteBody(); 93 F->setAlignment(Align(4096)); 94 95 Triple T(M.getTargetTriple()); 96 if (T.isARM() || T.isThumb()) 97 F->addFnAttr("target-features", "+thumb-mode"); 98 99 auto args = F->arg_begin(); 100 Value &CallSiteTypeId = *(args++); 101 CallSiteTypeId.setName("CallSiteTypeId"); 102 Value &Addr = *(args++); 103 Addr.setName("Addr"); 104 Value &CFICheckFailData = *(args++); 105 CFICheckFailData.setName("CFICheckFailData"); 106 assert(args == F->arg_end()); 107 108 BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F); 109 BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F); 110 111 BasicBlock *TrapBB = BasicBlock::Create(Ctx, "fail", F); 112 IRBuilder<> IRBFail(TrapBB); 113 FunctionCallee CFICheckFailFn = M.getOrInsertFunction( 114 "__cfi_check_fail", Type::getVoidTy(Ctx), PointerType::getUnqual(Ctx), 115 PointerType::getUnqual(Ctx)); 116 IRBFail.CreateCall(CFICheckFailFn, {&CFICheckFailData, &Addr}); 117 IRBFail.CreateBr(ExitBB); 118 119 IRBuilder<> IRBExit(ExitBB); 120 IRBExit.CreateRetVoid(); 121 122 IRBuilder<> IRB(BB); 123 SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, TypeIds.size()); 124 for (uint64_t TypeId : TypeIds) { 125 ConstantInt *CaseTypeId = ConstantInt::get(Type::getInt64Ty(Ctx), TypeId); 126 BasicBlock *TestBB = BasicBlock::Create(Ctx, "test", F); 127 IRBuilder<> IRBTest(TestBB); 128 Function *BitsetTestFn = Intrinsic::getDeclaration(&M, Intrinsic::type_test); 129 130 Value *Test = IRBTest.CreateCall( 131 BitsetTestFn, {&Addr, MetadataAsValue::get( 132 Ctx, ConstantAsMetadata::get(CaseTypeId))}); 133 BranchInst *BI = IRBTest.CreateCondBr(Test, ExitBB, TrapBB); 134 BI->setMetadata(LLVMContext::MD_prof, VeryLikelyWeights); 135 136 SI->addCase(CaseTypeId, TestBB); 137 ++NumTypeIds; 138 } 139} 140 141bool CrossDSOCFI::runOnModule(Module &M) { 142 VeryLikelyWeights = 143 MDBuilder(M.getContext()).createBranchWeights((1U << 20) - 1, 1); 144 if (M.getModuleFlag("Cross-DSO CFI") == nullptr) 145 return false; 146 buildCFICheck(M); 147 return true; 148} 149 150PreservedAnalyses CrossDSOCFIPass::run(Module &M, ModuleAnalysisManager &AM) { 151 CrossDSOCFI Impl; 152 bool Changed = Impl.runOnModule(M); 153 if (!Changed) 154 return PreservedAnalyses::all(); 155 return PreservedAnalyses::none(); 156} 157