1292915Sdim//===-- CrossDSOCFI.cpp - Externalize this module's CFI checks ------------===//
2292915Sdim//
3292915Sdim//                     The LLVM Compiler Infrastructure
4292915Sdim//
5292915Sdim// This file is distributed under the University of Illinois Open Source
6292915Sdim// License. See LICENSE.TXT for details.
7292915Sdim//
8292915Sdim//===----------------------------------------------------------------------===//
9292915Sdim//
// This pass exports the numeric type identifiers of all llvm.bitsets entries
// found in the module in the form of a __cfi_check function, which can be used
// to verify cross-DSO call targets.
12292915Sdim//
13292915Sdim//===----------------------------------------------------------------------===//
14292915Sdim
#include "llvm/Transforms/IPO.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalObject.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
34292915Sdim
35292915Sdimusing namespace llvm;
36292915Sdim
37292915Sdim#define DEBUG_TYPE "cross-dso-cfi"
38292915Sdim
39292915SdimSTATISTIC(TypeIds, "Number of unique type identifiers");
40292915Sdim
41292915Sdimnamespace {
42292915Sdim
43292915Sdimstruct CrossDSOCFI : public ModulePass {
44292915Sdim  static char ID;
45292915Sdim  CrossDSOCFI() : ModulePass(ID) {
46292915Sdim    initializeCrossDSOCFIPass(*PassRegistry::getPassRegistry());
47292915Sdim  }
48292915Sdim
49292915Sdim  Module *M;
50292915Sdim  MDNode *VeryLikelyWeights;
51292915Sdim
52292915Sdim  ConstantInt *extractBitSetTypeId(MDNode *MD);
53292915Sdim  void buildCFICheck();
54292915Sdim
55292915Sdim  bool doInitialization(Module &M) override;
56292915Sdim  bool runOnModule(Module &M) override;
57292915Sdim};
58292915Sdim
59292915Sdim} // anonymous namespace
60292915Sdim
61292915SdimINITIALIZE_PASS_BEGIN(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false,
62292915Sdim                      false)
63292915SdimINITIALIZE_PASS_END(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false, false)
64292915Sdimchar CrossDSOCFI::ID = 0;
65292915Sdim
66292915SdimModulePass *llvm::createCrossDSOCFIPass() { return new CrossDSOCFI; }
67292915Sdim
68292915Sdimbool CrossDSOCFI::doInitialization(Module &Mod) {
69292915Sdim  M = &Mod;
70292915Sdim  VeryLikelyWeights =
71292915Sdim      MDBuilder(M->getContext()).createBranchWeights((1U << 20) - 1, 1);
72292915Sdim
73292915Sdim  return false;
74292915Sdim}
75292915Sdim
76292915Sdim/// extractBitSetTypeId - Extracts TypeId from a hash-based bitset MDNode.
77292915SdimConstantInt *CrossDSOCFI::extractBitSetTypeId(MDNode *MD) {
78292915Sdim  // This check excludes vtables for classes inside anonymous namespaces.
79292915Sdim  auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(0));
80292915Sdim  if (!TM)
81292915Sdim    return nullptr;
82292915Sdim  auto C = dyn_cast_or_null<ConstantInt>(TM->getValue());
83292915Sdim  if (!C) return nullptr;
84292915Sdim  // We are looking for i64 constants.
85292915Sdim  if (C->getBitWidth() != 64) return nullptr;
86292915Sdim
87292915Sdim  // Sanity check.
88292915Sdim  auto FM = dyn_cast_or_null<ValueAsMetadata>(MD->getOperand(1));
89292915Sdim  // Can be null if a function was removed by an optimization.
90292915Sdim  if (FM) {
91292915Sdim    auto F = dyn_cast<Function>(FM->getValue());
92292915Sdim    // But can never be a function declaration.
93292915Sdim    assert(!F || !F->isDeclaration());
94292915Sdim    (void)F; // Suppress unused variable warning in the no-asserts build.
95292915Sdim  }
96292915Sdim  return C;
97292915Sdim}
98292915Sdim
99292915Sdim/// buildCFICheck - emits __cfi_check for the current module.
100292915Sdimvoid CrossDSOCFI::buildCFICheck() {
101292915Sdim  // FIXME: verify that __cfi_check ends up near the end of the code section,
102292915Sdim  // but before the jump slots created in LowerBitSets.
103292915Sdim  llvm::DenseSet<uint64_t> BitSetIds;
104292915Sdim  NamedMDNode *BitSetNM = M->getNamedMetadata("llvm.bitsets");
105292915Sdim
106292915Sdim  if (BitSetNM)
107292915Sdim    for (unsigned I = 0, E = BitSetNM->getNumOperands(); I != E; ++I)
108292915Sdim      if (ConstantInt *TypeId = extractBitSetTypeId(BitSetNM->getOperand(I)))
109292915Sdim        BitSetIds.insert(TypeId->getZExtValue());
110292915Sdim
111292915Sdim  LLVMContext &Ctx = M->getContext();
112292915Sdim  Constant *C = M->getOrInsertFunction(
113292915Sdim      "__cfi_check",
114292915Sdim      FunctionType::get(
115292915Sdim          Type::getVoidTy(Ctx),
116292915Sdim          {Type::getInt64Ty(Ctx), PointerType::getUnqual(Type::getInt8Ty(Ctx))},
117292915Sdim          false));
118292915Sdim  Function *F = dyn_cast<Function>(C);
119292915Sdim  F->setAlignment(4096);
120292915Sdim  auto args = F->arg_begin();
121292915Sdim  Argument &CallSiteTypeId = *(args++);
122292915Sdim  CallSiteTypeId.setName("CallSiteTypeId");
123292915Sdim  Argument &Addr = *(args++);
124292915Sdim  Addr.setName("Addr");
125292915Sdim  assert(args == F->arg_end());
126292915Sdim
127292915Sdim  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
128292915Sdim
129292915Sdim  BasicBlock *TrapBB = BasicBlock::Create(Ctx, "trap", F);
130292915Sdim  IRBuilder<> IRBTrap(TrapBB);
131292915Sdim  Function *TrapFn = Intrinsic::getDeclaration(M, Intrinsic::trap);
132292915Sdim  llvm::CallInst *TrapCall = IRBTrap.CreateCall(TrapFn);
133292915Sdim  TrapCall->setDoesNotReturn();
134292915Sdim  TrapCall->setDoesNotThrow();
135292915Sdim  IRBTrap.CreateUnreachable();
136292915Sdim
137292915Sdim  BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F);
138292915Sdim  IRBuilder<> IRBExit(ExitBB);
139292915Sdim  IRBExit.CreateRetVoid();
140292915Sdim
141292915Sdim  IRBuilder<> IRB(BB);
142292915Sdim  SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, BitSetIds.size());
143292915Sdim  for (uint64_t TypeId : BitSetIds) {
144292915Sdim    ConstantInt *CaseTypeId = ConstantInt::get(Type::getInt64Ty(Ctx), TypeId);
145292915Sdim    BasicBlock *TestBB = BasicBlock::Create(Ctx, "test", F);
146292915Sdim    IRBuilder<> IRBTest(TestBB);
147292915Sdim    Function *BitsetTestFn =
148292915Sdim        Intrinsic::getDeclaration(M, Intrinsic::bitset_test);
149292915Sdim
150292915Sdim    Value *Test = IRBTest.CreateCall(
151292915Sdim        BitsetTestFn, {&Addr, MetadataAsValue::get(
152292915Sdim                                  Ctx, ConstantAsMetadata::get(CaseTypeId))});
153292915Sdim    BranchInst *BI = IRBTest.CreateCondBr(Test, ExitBB, TrapBB);
154292915Sdim    BI->setMetadata(LLVMContext::MD_prof, VeryLikelyWeights);
155292915Sdim
156292915Sdim    SI->addCase(CaseTypeId, TestBB);
157292915Sdim    ++TypeIds;
158292915Sdim  }
159292915Sdim}
160292915Sdim
161292915Sdimbool CrossDSOCFI::runOnModule(Module &M) {
162292915Sdim  if (M.getModuleFlag("Cross-DSO CFI") == nullptr)
163292915Sdim    return false;
164292915Sdim  buildCFICheck();
165292915Sdim  return true;
166292915Sdim}
167