CrossDSOCFI.cpp revision 292915
1213496Scognet//===-- CrossDSOCFI.cpp - Externalize this module's CFI checks ------------===//
2213496Scognet//
3213496Scognet//                     The LLVM Compiler Infrastructure
4213496Scognet//
5213496Scognet// This file is distributed under the University of Illinois Open Source
6213496Scognet// License. See LICENSE.TXT for details.
7213496Scognet//
8213496Scognet//===----------------------------------------------------------------------===//
9213496Scognet//
// This pass exports the hash-based type identifiers of all !llvm.bitsets
// entries found in the module in the form of a __cfi_check function, which
// can be used to verify cross-DSO call targets.
12213496Scognet//
13213496Scognet//===----------------------------------------------------------------------===//
14213496Scognet
#include "llvm/Transforms/IPO.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalObject.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
34213496Scognet
35213496Scognetusing namespace llvm;
36213496Scognet
37213496Scognet#define DEBUG_TYPE "cross-dso-cfi"
38213496Scognet
39213496ScognetSTATISTIC(TypeIds, "Number of unique type identifiers");
40213496Scognet
41213496Scognetnamespace {
42213496Scognet
43213496Scognetstruct CrossDSOCFI : public ModulePass {
44213496Scognet  static char ID;
45213496Scognet  CrossDSOCFI() : ModulePass(ID) {
46213496Scognet    initializeCrossDSOCFIPass(*PassRegistry::getPassRegistry());
47213496Scognet  }
48213496Scognet
49213496Scognet  Module *M;
50266196Sian  MDNode *VeryLikelyWeights;
51266196Sian
52266196Sian  ConstantInt *extractBitSetTypeId(MDNode *MD);
53266196Sian  void buildCFICheck();
54266196Sian
55266196Sian  bool doInitialization(Module &M) override;
56213496Scognet  bool runOnModule(Module &M) override;
57213496Scognet};
58213496Scognet
59213496Scognet} // anonymous namespace
60213496Scognet
61213496ScognetINITIALIZE_PASS_BEGIN(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false,
62213496Scognet                      false)
63234281SmariusINITIALIZE_PASS_END(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false, false)
64234281Smariuschar CrossDSOCFI::ID = 0;
65213496Scognet
66213496ScognetModulePass *llvm::createCrossDSOCFIPass() { return new CrossDSOCFI; }
67213496Scognet
68213496Scognetbool CrossDSOCFI::doInitialization(Module &Mod) {
69213496Scognet  M = &Mod;
70234281Smarius  VeryLikelyWeights =
71213496Scognet      MDBuilder(M->getContext()).createBranchWeights((1U << 20) - 1, 1);
72213496Scognet
73213496Scognet  return false;
74213496Scognet}
75213496Scognet
76213496Scognet/// extractBitSetTypeId - Extracts TypeId from a hash-based bitset MDNode.
77234281SmariusConstantInt *CrossDSOCFI::extractBitSetTypeId(MDNode *MD) {
78213496Scognet  // This check excludes vtables for classes inside anonymous namespaces.
79213496Scognet  auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(0));
80213496Scognet  if (!TM)
81213496Scognet    return nullptr;
82213496Scognet  auto C = dyn_cast_or_null<ConstantInt>(TM->getValue());
83213496Scognet  if (!C) return nullptr;
84213496Scognet  // We are looking for i64 constants.
85213496Scognet  if (C->getBitWidth() != 64) return nullptr;
86213496Scognet
87213496Scognet  // Sanity check.
88213496Scognet  auto FM = dyn_cast_or_null<ValueAsMetadata>(MD->getOperand(1));
89213496Scognet  // Can be null if a function was removed by an optimization.
90213496Scognet  if (FM) {
91213496Scognet    auto F = dyn_cast<Function>(FM->getValue());
92213496Scognet    // But can never be a function declaration.
93213496Scognet    assert(!F || !F->isDeclaration());
94213496Scognet    (void)F; // Suppress unused variable warning in the no-asserts build.
95213496Scognet  }
96213496Scognet  return C;
97213496Scognet}
98213496Scognet
99213496Scognet/// buildCFICheck - emits __cfi_check for the current module.
100213496Scognetvoid CrossDSOCFI::buildCFICheck() {
101213496Scognet  // FIXME: verify that __cfi_check ends up near the end of the code section,
102213496Scognet  // but before the jump slots created in LowerBitSets.
103213496Scognet  llvm::DenseSet<uint64_t> BitSetIds;
104213496Scognet  NamedMDNode *BitSetNM = M->getNamedMetadata("llvm.bitsets");
105213496Scognet
106213496Scognet  if (BitSetNM)
107213496Scognet    for (unsigned I = 0, E = BitSetNM->getNumOperands(); I != E; ++I)
108213496Scognet      if (ConstantInt *TypeId = extractBitSetTypeId(BitSetNM->getOperand(I)))
109213496Scognet        BitSetIds.insert(TypeId->getZExtValue());
110213496Scognet
111213496Scognet  LLVMContext &Ctx = M->getContext();
112213496Scognet  Constant *C = M->getOrInsertFunction(
113213496Scognet      "__cfi_check",
114213496Scognet      FunctionType::get(
115221025Scognet          Type::getVoidTy(Ctx),
116213496Scognet          {Type::getInt64Ty(Ctx), PointerType::getUnqual(Type::getInt8Ty(Ctx))},
117213496Scognet          false));
118213496Scognet  Function *F = dyn_cast<Function>(C);
119213496Scognet  F->setAlignment(4096);
120213496Scognet  auto args = F->arg_begin();
121213496Scognet  Argument &CallSiteTypeId = *(args++);
122213496Scognet  CallSiteTypeId.setName("CallSiteTypeId");
123213496Scognet  Argument &Addr = *(args++);
124213496Scognet  Addr.setName("Addr");
125213496Scognet  assert(args == F->arg_end());
126213496Scognet
127213496Scognet  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
128213496Scognet
129213496Scognet  BasicBlock *TrapBB = BasicBlock::Create(Ctx, "trap", F);
130213496Scognet  IRBuilder<> IRBTrap(TrapBB);
131213496Scognet  Function *TrapFn = Intrinsic::getDeclaration(M, Intrinsic::trap);
132213496Scognet  llvm::CallInst *TrapCall = IRBTrap.CreateCall(TrapFn);
133213496Scognet  TrapCall->setDoesNotReturn();
134213496Scognet  TrapCall->setDoesNotThrow();
135213496Scognet  IRBTrap.CreateUnreachable();
136213496Scognet
137213496Scognet  BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F);
138213496Scognet  IRBuilder<> IRBExit(ExitBB);
139213496Scognet  IRBExit.CreateRetVoid();
140213496Scognet
141213496Scognet  IRBuilder<> IRB(BB);
142266196Sian  SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, BitSetIds.size());
143266196Sian  for (uint64_t TypeId : BitSetIds) {
144266196Sian    ConstantInt *CaseTypeId = ConstantInt::get(Type::getInt64Ty(Ctx), TypeId);
145266196Sian    BasicBlock *TestBB = BasicBlock::Create(Ctx, "test", F);
146266196Sian    IRBuilder<> IRBTest(TestBB);
147266196Sian    Function *BitsetTestFn =
148213496Scognet        Intrinsic::getDeclaration(M, Intrinsic::bitset_test);
149213496Scognet
150213496Scognet    Value *Test = IRBTest.CreateCall(
151213496Scognet        BitsetTestFn, {&Addr, MetadataAsValue::get(
152213496Scognet                                  Ctx, ConstantAsMetadata::get(CaseTypeId))});
153213496Scognet    BranchInst *BI = IRBTest.CreateCondBr(Test, ExitBB, TrapBB);
154213496Scognet    BI->setMetadata(LLVMContext::MD_prof, VeryLikelyWeights);
155213496Scognet
156213496Scognet    SI->addCase(CaseTypeId, TestBB);
157213496Scognet    ++TypeIds;
158213496Scognet  }
159213496Scognet}
160213496Scognet
161213496Scognetbool CrossDSOCFI::runOnModule(Module &M) {
162213496Scognet  if (M.getModuleFlag("Cross-DSO CFI") == nullptr)
163213496Scognet    return false;
164213496Scognet  buildCFICheck();
165213496Scognet  return true;
166213496Scognet}
167213496Scognet