1326938Sdim//===- EntryExitInstrumenter.cpp - Function Entry/Exit Instrumentation ----===//
2326938Sdim//
3353358Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4353358Sdim// See https://llvm.org/LICENSE.txt for license information.
5353358Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6326938Sdim//
7326938Sdim//===----------------------------------------------------------------------===//
8326938Sdim
9326938Sdim#include "llvm/Transforms/Utils/EntryExitInstrumenter.h"
10326938Sdim#include "llvm/Analysis/GlobalsModRef.h"
11326938Sdim#include "llvm/IR/DebugInfoMetadata.h"
12326938Sdim#include "llvm/IR/Function.h"
13326938Sdim#include "llvm/IR/Instructions.h"
14326938Sdim#include "llvm/IR/Module.h"
15326938Sdim#include "llvm/IR/Type.h"
16360784Sdim#include "llvm/InitializePasses.h"
17326938Sdim#include "llvm/Pass.h"
18341825Sdim#include "llvm/Transforms/Utils.h"
19326938Sdimusing namespace llvm;
20326938Sdim
21326938Sdimstatic void insertCall(Function &CurFn, StringRef Func,
22326938Sdim                       Instruction *InsertionPt, DebugLoc DL) {
23326938Sdim  Module &M = *InsertionPt->getParent()->getParent()->getParent();
24326938Sdim  LLVMContext &C = InsertionPt->getParent()->getContext();
25326938Sdim
26326938Sdim  if (Func == "mcount" ||
27326938Sdim      Func == ".mcount" ||
28360784Sdim      Func == "llvm.arm.gnu.eabi.mcount" ||
29326938Sdim      Func == "\01_mcount" ||
30326938Sdim      Func == "\01mcount" ||
31326938Sdim      Func == "__mcount" ||
32326938Sdim      Func == "_mcount" ||
33326938Sdim      Func == "__cyg_profile_func_enter_bare") {
34353358Sdim    FunctionCallee Fn = M.getOrInsertFunction(Func, Type::getVoidTy(C));
35326938Sdim    CallInst *Call = CallInst::Create(Fn, "", InsertionPt);
36326938Sdim    Call->setDebugLoc(DL);
37326938Sdim    return;
38326938Sdim  }
39326938Sdim
40326938Sdim  if (Func == "__cyg_profile_func_enter" || Func == "__cyg_profile_func_exit") {
41326938Sdim    Type *ArgTypes[] = {Type::getInt8PtrTy(C), Type::getInt8PtrTy(C)};
42326938Sdim
43353358Sdim    FunctionCallee Fn = M.getOrInsertFunction(
44326938Sdim        Func, FunctionType::get(Type::getVoidTy(C), ArgTypes, false));
45326938Sdim
46326938Sdim    Instruction *RetAddr = CallInst::Create(
47326938Sdim        Intrinsic::getDeclaration(&M, Intrinsic::returnaddress),
48326938Sdim        ArrayRef<Value *>(ConstantInt::get(Type::getInt32Ty(C), 0)), "",
49326938Sdim        InsertionPt);
50326938Sdim    RetAddr->setDebugLoc(DL);
51326938Sdim
52326938Sdim    Value *Args[] = {ConstantExpr::getBitCast(&CurFn, Type::getInt8PtrTy(C)),
53326938Sdim                     RetAddr};
54326938Sdim
55326938Sdim    CallInst *Call =
56326938Sdim        CallInst::Create(Fn, ArrayRef<Value *>(Args), "", InsertionPt);
57326938Sdim    Call->setDebugLoc(DL);
58326938Sdim    return;
59326938Sdim  }
60326938Sdim
61326938Sdim  // We only know how to call a fixed set of instrumentation functions, because
62326938Sdim  // they all expect different arguments, etc.
63326938Sdim  report_fatal_error(Twine("Unknown instrumentation function: '") + Func + "'");
64326938Sdim}
65326938Sdim
66326938Sdimstatic bool runOnFunction(Function &F, bool PostInlining) {
67326938Sdim  StringRef EntryAttr = PostInlining ? "instrument-function-entry-inlined"
68326938Sdim                                     : "instrument-function-entry";
69326938Sdim
70326938Sdim  StringRef ExitAttr = PostInlining ? "instrument-function-exit-inlined"
71326938Sdim                                    : "instrument-function-exit";
72326938Sdim
73326938Sdim  StringRef EntryFunc = F.getFnAttribute(EntryAttr).getValueAsString();
74326938Sdim  StringRef ExitFunc = F.getFnAttribute(ExitAttr).getValueAsString();
75326938Sdim
76326938Sdim  bool Changed = false;
77326938Sdim
78326938Sdim  // If the attribute is specified, insert instrumentation and then "consume"
79326938Sdim  // the attribute so that it's not inserted again if the pass should happen to
80326938Sdim  // run later for some reason.
81326938Sdim
82326938Sdim  if (!EntryFunc.empty()) {
83326938Sdim    DebugLoc DL;
84326938Sdim    if (auto SP = F.getSubprogram())
85326938Sdim      DL = DebugLoc::get(SP->getScopeLine(), 0, SP);
86326938Sdim
87326938Sdim    insertCall(F, EntryFunc, &*F.begin()->getFirstInsertionPt(), DL);
88326938Sdim    Changed = true;
89326938Sdim    F.removeAttribute(AttributeList::FunctionIndex, EntryAttr);
90326938Sdim  }
91326938Sdim
92326938Sdim  if (!ExitFunc.empty()) {
93326938Sdim    for (BasicBlock &BB : F) {
94341825Sdim      Instruction *T = BB.getTerminator();
95341825Sdim      if (!isa<ReturnInst>(T))
96341825Sdim        continue;
97341825Sdim
98341825Sdim      // If T is preceded by a musttail call, that's the real terminator.
99341825Sdim      Instruction *Prev = T->getPrevNode();
100341825Sdim      if (BitCastInst *BCI = dyn_cast_or_null<BitCastInst>(Prev))
101341825Sdim        Prev = BCI->getPrevNode();
102341825Sdim      if (CallInst *CI = dyn_cast_or_null<CallInst>(Prev)) {
103341825Sdim        if (CI->isMustTailCall())
104341825Sdim          T = CI;
105341825Sdim      }
106341825Sdim
107326938Sdim      DebugLoc DL;
108326938Sdim      if (DebugLoc TerminatorDL = T->getDebugLoc())
109326938Sdim        DL = TerminatorDL;
110326938Sdim      else if (auto SP = F.getSubprogram())
111326938Sdim        DL = DebugLoc::get(0, 0, SP);
112326938Sdim
113341825Sdim      insertCall(F, ExitFunc, T, DL);
114341825Sdim      Changed = true;
115326938Sdim    }
116326938Sdim    F.removeAttribute(AttributeList::FunctionIndex, ExitAttr);
117326938Sdim  }
118326938Sdim
119326938Sdim  return Changed;
120326938Sdim}
121326938Sdim
122326938Sdimnamespace {
123326938Sdimstruct EntryExitInstrumenter : public FunctionPass {
124326938Sdim  static char ID;
125326938Sdim  EntryExitInstrumenter() : FunctionPass(ID) {
126326938Sdim    initializeEntryExitInstrumenterPass(*PassRegistry::getPassRegistry());
127326938Sdim  }
128326938Sdim  void getAnalysisUsage(AnalysisUsage &AU) const override {
129326938Sdim    AU.addPreserved<GlobalsAAWrapperPass>();
130326938Sdim  }
131326938Sdim  bool runOnFunction(Function &F) override { return ::runOnFunction(F, false); }
132326938Sdim};
133326938Sdimchar EntryExitInstrumenter::ID = 0;
134326938Sdim
135326938Sdimstruct PostInlineEntryExitInstrumenter : public FunctionPass {
136326938Sdim  static char ID;
137326938Sdim  PostInlineEntryExitInstrumenter() : FunctionPass(ID) {
138326938Sdim    initializePostInlineEntryExitInstrumenterPass(
139326938Sdim        *PassRegistry::getPassRegistry());
140326938Sdim  }
141326938Sdim  void getAnalysisUsage(AnalysisUsage &AU) const override {
142326938Sdim    AU.addPreserved<GlobalsAAWrapperPass>();
143326938Sdim  }
144326938Sdim  bool runOnFunction(Function &F) override { return ::runOnFunction(F, true); }
145326938Sdim};
146326938Sdimchar PostInlineEntryExitInstrumenter::ID = 0;
147326938Sdim}
148326938Sdim
149326938SdimINITIALIZE_PASS(
150326938Sdim    EntryExitInstrumenter, "ee-instrument",
151326938Sdim    "Instrument function entry/exit with calls to e.g. mcount() (pre inlining)",
152326938Sdim    false, false)
153326938SdimINITIALIZE_PASS(PostInlineEntryExitInstrumenter, "post-inline-ee-instrument",
154326938Sdim                "Instrument function entry/exit with calls to e.g. mcount() "
155326938Sdim                "(post inlining)",
156326938Sdim                false, false)
157326938Sdim
158326938SdimFunctionPass *llvm::createEntryExitInstrumenterPass() {
159326938Sdim  return new EntryExitInstrumenter();
160326938Sdim}
161326938Sdim
162326938SdimFunctionPass *llvm::createPostInlineEntryExitInstrumenterPass() {
163326938Sdim  return new PostInlineEntryExitInstrumenter();
164326938Sdim}
165326938Sdim
166326938SdimPreservedAnalyses
167326938Sdimllvm::EntryExitInstrumenterPass::run(Function &F, FunctionAnalysisManager &AM) {
168326938Sdim  runOnFunction(F, PostInlining);
169326938Sdim  PreservedAnalyses PA;
170326938Sdim  PA.preserveSet<CFGAnalyses>();
171326938Sdim  return PA;
172326938Sdim}
173