1//===- MergeFunctions.cpp - Merge identical functions ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This pass looks for equivalent functions that are mergeable and folds them.
10//
// An order relation is defined on the set of functions. It is implemented by
// a special function comparison procedure that returns
// 0 when the functions are equal,
// -1 when the left function is less than the right function, and
// 1 in the opposite case. We need a total order, so we need to maintain
// four properties on the set of functions:
17// a <= a (reflexivity)
18// if a <= b and b <= a then a = b (antisymmetry)
19// if a <= b and b <= c then a <= c (transitivity).
20// for all a and b: a <= b or b <= a (totality).
21//
22// Comparison iterates through each instruction in each basic block.
// Functions are kept in a binary tree. For each new function F we perform
// a lookup in that tree.
// In practice it works the following way:
// -- We define a Function* container class (FunctionNode) and a comparator
//    with a custom "operator()" (FunctionNodeCmp).
// -- FunctionNode instances are stored in a std::set collection, so every
//    std::set::insert operation takes O(log N) comparisons.
29//
30// As an optimization, a hash of the function structure is calculated first, and
31// two functions are only compared if they have the same hash. This hash is
32// cheap to compute, and has the property that if function F == G according to
33// the comparison function, then hash(F) == hash(G). This consistency property
34// is critical to ensuring all possible merging opportunities are exploited.
35// Collisions in the hash affect the speed of the pass but not the correctness
36// or determinism of the resulting transformation.
37//
38// When a match is found the functions are folded. If both functions are
39// overridable, we move the functionality into a new internal function and
40// leave two overridable thunks to it.
41//
42//===----------------------------------------------------------------------===//
43//
44// Future work:
45//
46// * virtual functions.
47//
48// Many functions have their address taken by the virtual function table for
49// the object they belong to. However, as long as it's only used for a lookup
50// and call, this is irrelevant, and we'd like to fold such functions.
51//
52// * be smarter about bitcasts.
53//
54// In order to fold functions, we will sometimes add either bitcast instructions
55// or bitcast constant expressions. Unfortunately, this can confound further
56// analysis since the two functions differ where one has a bitcast and the
57// other doesn't. We should learn to look through bitcasts.
58//
59// * Compare complex types with pointer types inside.
60// * Compare cross-reference cases.
61// * Compare complex expressions.
62//
// All three issues above could be described as the ability to prove that
// fA == fB == fE == fF == fG in the example below:
65//
66//  void fA() {
67//    fB();
68//  }
69//  void fB() {
70//    fA();
71//  }
72//
73//  void fE() {
74//    fF();
75//  }
76//  void fF() {
77//    fG();
78//  }
79//  void fG() {
80//    fE();
81//  }
82//
// The simplest cross-reference case (fA <--> fB) was implemented in previous
// versions of MergeFunctions, though it occurred in only two function pairs
// in the test-suite (which contains more than 50k functions).
// However, the ability to detect complex cross-referencing (e.g.
// A->B->C->D->A) could cover many more cases.
88//
89//===----------------------------------------------------------------------===//
90
91#include "llvm/ADT/ArrayRef.h"
92#include "llvm/ADT/SmallPtrSet.h"
93#include "llvm/ADT/SmallVector.h"
94#include "llvm/ADT/Statistic.h"
95#include "llvm/IR/Argument.h"
96#include "llvm/IR/Attributes.h"
97#include "llvm/IR/BasicBlock.h"
98#include "llvm/IR/CallSite.h"
99#include "llvm/IR/Constant.h"
100#include "llvm/IR/Constants.h"
101#include "llvm/IR/DebugInfoMetadata.h"
102#include "llvm/IR/DebugLoc.h"
103#include "llvm/IR/DerivedTypes.h"
104#include "llvm/IR/Function.h"
105#include "llvm/IR/GlobalValue.h"
106#include "llvm/IR/IRBuilder.h"
107#include "llvm/IR/InstrTypes.h"
108#include "llvm/IR/Instruction.h"
109#include "llvm/IR/Instructions.h"
110#include "llvm/IR/IntrinsicInst.h"
111#include "llvm/IR/Module.h"
112#include "llvm/IR/Type.h"
113#include "llvm/IR/Use.h"
114#include "llvm/IR/User.h"
115#include "llvm/IR/Value.h"
116#include "llvm/IR/ValueHandle.h"
117#include "llvm/IR/ValueMap.h"
118#include "llvm/InitializePasses.h"
119#include "llvm/Pass.h"
120#include "llvm/Support/Casting.h"
121#include "llvm/Support/CommandLine.h"
122#include "llvm/Support/Debug.h"
123#include "llvm/Support/raw_ostream.h"
124#include "llvm/Transforms/IPO.h"
125#include "llvm/Transforms/IPO/MergeFunctions.h"
126#include "llvm/Transforms/Utils/FunctionComparator.h"
127#include <algorithm>
128#include <cassert>
129#include <iterator>
130#include <set>
131#include <utility>
132#include <vector>
133
using namespace llvm;

// Debug tag for this pass; used by LLVM_DEBUG and STATISTIC below.
#define DEBUG_TYPE "mergefunc"

// Pass statistics.
STATISTIC(NumFunctionsMerged, "Number of functions merged");
STATISTIC(NumThunksWritten, "Number of thunks generated");
STATISTIC(NumAliasesWritten, "Number of aliases generated");
// Incremented by mergeTwoFunctions() when both functions are interposable and
// a new function is created to hold the shared body.
STATISTIC(NumDoubleWeak, "Number of new functions created");

// Debug-only: how many leading worklist functions doSanityCheck() uses when
// verifying the ordering properties of the function comparator.
static cl::opt<unsigned> NumFunctionsForSanityCheck(
    "mergefunc-sanity",
    cl::desc("How many functions in module could be used for "
             "MergeFunctions pass sanity check. "
             "'0' disables this check. Works only with '-debug' key."),
    cl::init(0), cl::Hidden);

// Under option -mergefunc-preserve-debug-info we:
// - Do not create a new function for a thunk.
// - Retain the debug info for a thunk's parameters (and associated
//   instructions for the debug info) from the entry block.
//   Note: -debug will display the algorithm at work.
// - Create debug-info for the call (to the shared implementation) made by
//   a thunk and its return value.
// - Erase the rest of the function, retaining the (minimally sized) entry
//   block to create a thunk.
// - Preserve a thunk's call site to point to the thunk even when both occur
//   within the same translation unit, to aid debugability. Note that this
//   behaviour differs from the underlying -mergefunc implementation which
//   modifies the thunk's call site to point to the shared implementation
//   when both occur within the same translation unit.
static cl::opt<bool>
    MergeFunctionsPDI("mergefunc-preserve-debug-info", cl::Hidden,
                      cl::init(false),
                      cl::desc("Preserve debug info in thunk when mergefunc "
                               "transformations are made."));

// Consulted by canCreateAliasFor(): while false, a merged function is never
// replaced by a GlobalAlias.
static cl::opt<bool>
    MergeFunctionsAliases("mergefunc-use-aliases", cl::Hidden,
                          cl::init(false),
                          cl::desc("Allow mergefunc to create aliases"));
174
175namespace {
176
177class FunctionNode {
178  mutable AssertingVH<Function> F;
179  FunctionComparator::FunctionHash Hash;
180
181public:
182  // Note the hash is recalculated potentially multiple times, but it is cheap.
183  FunctionNode(Function *F)
184    : F(F), Hash(FunctionComparator::functionHash(*F))  {}
185
186  Function *getFunc() const { return F; }
187  FunctionComparator::FunctionHash getHash() const { return Hash; }
188
189  /// Replace the reference to the function F by the function G, assuming their
190  /// implementations are equal.
191  void replaceBy(Function *G) const {
192    F = G;
193  }
194};
195
/// MergeFunctions finds functions which will generate identical machine code,
/// by considering all pointer types to be equivalent. Once identified,
/// MergeFunctions will fold them by replacing a call to one to a call to a
/// bitcast of the other.
///
/// One instance holds the state for a single run over a module: candidate
/// functions are inserted into FnTree, a function that compares equal to one
/// already in the tree is merged away, and functions that may have changed as
/// a result are queued in Deferred for another round.
class MergeFunctions {
public:
  MergeFunctions() : FnTree(FunctionNodeCmp(&GlobalNumbers)) {
  }

  /// Run the merging over \p M. Returns true if the module was modified.
  bool runOnModule(Module &M);

private:
  // The function comparison operator is provided here so that FunctionNodes do
  // not need to become larger with another pointer.
  //
  // Ordering used by FnTree: nodes are ordered by their structural hash, and
  // only fall back to a full FunctionComparator comparison when the hashes
  // are equal.
  class FunctionNodeCmp {
    GlobalNumberState* GlobalNumbers;

  public:
    FunctionNodeCmp(GlobalNumberState* GN) : GlobalNumbers(GN) {}

    bool operator()(const FunctionNode &LHS, const FunctionNode &RHS) const {
      // Order first by hashes, then full function comparison.
      if (LHS.getHash() != RHS.getHash())
        return LHS.getHash() < RHS.getHash();
      FunctionComparator FCmp(LHS.getFunc(), RHS.getFunc(), GlobalNumbers);
      return FCmp.compare() == -1;
    }
  };
  using FnTreeType = std::set<FunctionNode, FunctionNodeCmp>;

  // Numbering of global values shared by every FunctionComparator this pass
  // creates (FunctionNodeCmp and doSanityCheck() use it by address).
  GlobalNumberState GlobalNumbers;

  /// A work queue of functions that may have been modified and should be
  /// analyzed again. The handles are weak: an entry may have become null by
  /// the time it is processed.
  std::vector<WeakTrackingVH> Deferred;

#ifndef NDEBUG
  /// Checks the rules of order relation introduced among functions set.
  /// Returns true, if sanity check has been passed, and false if failed.
  bool doSanityCheck(std::vector<WeakTrackingVH> &Worklist);
#endif

  /// Insert NewFunction into the FnTree, or merge it away if it's equal to
  /// one that's already present. The result is OR-ed into runOnModule()'s
  /// "changed" flag.
  bool insert(Function *NewFunction);

  /// Remove a Function from the FnTree and queue it up for a second sweep of
  /// analysis.
  void remove(Function *F);

  /// Find the functions that use this Value and remove them from FnTree and
  /// queue the functions.
  void removeUsers(Value *V);

  /// Replace all direct calls of Old with calls of New. Will bitcast New if
  /// necessary to make types match.
  void replaceDirectCallers(Function *Old, Function *New);

  /// Merge two equivalent functions. Upon completion, G may be deleted, or may
  /// be converted into a thunk. In either case, it should never be visited
  /// again.
  void mergeTwoFunctions(Function *F, Function *G);

  /// Fill PDIUnrelatedWL with instructions from the entry block that are
  /// unrelated to parameter related debug info.
  void filterInstsUnrelatedToPDI(BasicBlock *GEntryBlock,
                                 std::vector<Instruction *> &PDIUnrelatedWL);

  /// Erase the rest of the CFG (i.e. barring the entry block).
  void eraseTail(Function *G);

  /// Erase the instructions in PDIUnrelatedWL as they are unrelated to the
  /// parameter debug info, from the entry block.
  void eraseInstsUnrelatedToPDI(std::vector<Instruction *> &PDIUnrelatedWL);

  /// Turn G into a thunk that tail-calls F. Unless MergeFunctionsPDI holds,
  /// the thunk is built in a fresh function that takes over G's name and
  /// attributes, all uses of G are redirected to it, and G is deleted. Under
  /// MergeFunctionsPDI, G itself is rewritten in place.
  void writeThunk(Function *F, Function *G);

  // Replace G with an alias to F (deleting function G)
  void writeAlias(Function *F, Function *G);

  // Replace G with an alias to F if possible, or a thunk to F if possible.
  // Returns false if neither is the case.
  bool writeThunkOrAlias(Function *F, Function *G);

  /// Make the tree node FN (and the FNodesInTree mapping) refer to G instead
  /// of the function it currently holds.
  void replaceFunctionInTree(const FunctionNode &FN, Function *G);

  /// The set of all distinct functions, ordered by FunctionNodeCmp. Use the
  /// insert() and remove() methods to modify it; FNodesInTree allows efficient
  /// lookup and deferring of Functions.
  FnTreeType FnTree;

  // Map functions to the iterators of the FunctionNode which contains them
  // in the FnTree. This must be updated carefully whenever the FnTree is
  // modified, i.e. in insert(), remove(), and replaceFunctionInTree(), to avoid
  // dangling iterators into FnTree. The invariant that preserves this is that
  // there is exactly one mapping F -> FN for each FunctionNode FN in FnTree.
  DenseMap<AssertingVH<Function>, FnTreeType::iterator> FNodesInTree;
};
297
298class MergeFunctionsLegacyPass : public ModulePass {
299public:
300  static char ID;
301
302  MergeFunctionsLegacyPass(): ModulePass(ID) {
303    initializeMergeFunctionsLegacyPassPass(*PassRegistry::getPassRegistry());
304  }
305
306  bool runOnModule(Module &M) override {
307    if (skipModule(M))
308      return false;
309
310    MergeFunctions MF;
311    return MF.runOnModule(M);
312  }
313};
314
315} // end anonymous namespace
316
// Pass identifier and registration of the legacy pass under "mergefunc".
char MergeFunctionsLegacyPass::ID = 0;
INITIALIZE_PASS(MergeFunctionsLegacyPass, "mergefunc",
                "Merge Functions", false, false)
320
321ModulePass *llvm::createMergeFunctionsPass() {
322  return new MergeFunctionsLegacyPass();
323}
324
325PreservedAnalyses MergeFunctionsPass::run(Module &M,
326                                          ModuleAnalysisManager &AM) {
327  MergeFunctions MF;
328  if (!MF.runOnModule(M))
329    return PreservedAnalyses::all();
330  return PreservedAnalyses::none();
331}
332
333#ifndef NDEBUG
334bool MergeFunctions::doSanityCheck(std::vector<WeakTrackingVH> &Worklist) {
335  if (const unsigned Max = NumFunctionsForSanityCheck) {
336    unsigned TripleNumber = 0;
337    bool Valid = true;
338
339    dbgs() << "MERGEFUNC-SANITY: Started for first " << Max << " functions.\n";
340
341    unsigned i = 0;
342    for (std::vector<WeakTrackingVH>::iterator I = Worklist.begin(),
343                                               E = Worklist.end();
344         I != E && i < Max; ++I, ++i) {
345      unsigned j = i;
346      for (std::vector<WeakTrackingVH>::iterator J = I; J != E && j < Max;
347           ++J, ++j) {
348        Function *F1 = cast<Function>(*I);
349        Function *F2 = cast<Function>(*J);
350        int Res1 = FunctionComparator(F1, F2, &GlobalNumbers).compare();
351        int Res2 = FunctionComparator(F2, F1, &GlobalNumbers).compare();
352
353        // If F1 <= F2, then F2 >= F1, otherwise report failure.
354        if (Res1 != -Res2) {
355          dbgs() << "MERGEFUNC-SANITY: Non-symmetric; triple: " << TripleNumber
356                 << "\n";
357          dbgs() << *F1 << '\n' << *F2 << '\n';
358          Valid = false;
359        }
360
361        if (Res1 == 0)
362          continue;
363
364        unsigned k = j;
365        for (std::vector<WeakTrackingVH>::iterator K = J; K != E && k < Max;
366             ++k, ++K, ++TripleNumber) {
367          if (K == J)
368            continue;
369
370          Function *F3 = cast<Function>(*K);
371          int Res3 = FunctionComparator(F1, F3, &GlobalNumbers).compare();
372          int Res4 = FunctionComparator(F2, F3, &GlobalNumbers).compare();
373
374          bool Transitive = true;
375
376          if (Res1 != 0 && Res1 == Res4) {
377            // F1 > F2, F2 > F3 => F1 > F3
378            Transitive = Res3 == Res1;
379          } else if (Res3 != 0 && Res3 == -Res4) {
380            // F1 > F3, F3 > F2 => F1 > F2
381            Transitive = Res3 == Res1;
382          } else if (Res4 != 0 && -Res3 == Res4) {
383            // F2 > F3, F3 > F1 => F2 > F1
384            Transitive = Res4 == -Res1;
385          }
386
387          if (!Transitive) {
388            dbgs() << "MERGEFUNC-SANITY: Non-transitive; triple: "
389                   << TripleNumber << "\n";
390            dbgs() << "Res1, Res3, Res4: " << Res1 << ", " << Res3 << ", "
391                   << Res4 << "\n";
392            dbgs() << *F1 << '\n' << *F2 << '\n' << *F3 << '\n';
393            Valid = false;
394          }
395        }
396      }
397    }
398
399    dbgs() << "MERGEFUNC-SANITY: " << (Valid ? "Passed." : "Failed.") << "\n";
400    return Valid;
401  }
402  return true;
403}
404#endif
405
406/// Check whether \p F is eligible for function merging.
407static bool isEligibleForMerging(Function &F) {
408  return !F.isDeclaration() && !F.hasAvailableExternallyLinkage();
409}
410
411bool MergeFunctions::runOnModule(Module &M) {
412  bool Changed = false;
413
414  // All functions in the module, ordered by hash. Functions with a unique
415  // hash value are easily eliminated.
416  std::vector<std::pair<FunctionComparator::FunctionHash, Function *>>
417    HashedFuncs;
418  for (Function &Func : M) {
419    if (isEligibleForMerging(Func)) {
420      HashedFuncs.push_back({FunctionComparator::functionHash(Func), &Func});
421    }
422  }
423
424  llvm::stable_sort(HashedFuncs, less_first());
425
426  auto S = HashedFuncs.begin();
427  for (auto I = HashedFuncs.begin(), IE = HashedFuncs.end(); I != IE; ++I) {
428    // If the hash value matches the previous value or the next one, we must
429    // consider merging it. Otherwise it is dropped and never considered again.
430    if ((I != S && std::prev(I)->first == I->first) ||
431        (std::next(I) != IE && std::next(I)->first == I->first) ) {
432      Deferred.push_back(WeakTrackingVH(I->second));
433    }
434  }
435
436  do {
437    std::vector<WeakTrackingVH> Worklist;
438    Deferred.swap(Worklist);
439
440    LLVM_DEBUG(doSanityCheck(Worklist));
441
442    LLVM_DEBUG(dbgs() << "size of module: " << M.size() << '\n');
443    LLVM_DEBUG(dbgs() << "size of worklist: " << Worklist.size() << '\n');
444
445    // Insert functions and merge them.
446    for (WeakTrackingVH &I : Worklist) {
447      if (!I)
448        continue;
449      Function *F = cast<Function>(I);
450      if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage()) {
451        Changed |= insert(F);
452      }
453    }
454    LLVM_DEBUG(dbgs() << "size of FnTree: " << FnTree.size() << '\n');
455  } while (!Deferred.empty());
456
457  FnTree.clear();
458  FNodesInTree.clear();
459  GlobalNumbers.clear();
460
461  return Changed;
462}
463
464// Replace direct callers of Old with New.
465void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
466  Constant *BitcastNew = ConstantExpr::getBitCast(New, Old->getType());
467  for (auto UI = Old->use_begin(), UE = Old->use_end(); UI != UE;) {
468    Use *U = &*UI;
469    ++UI;
470    CallSite CS(U->getUser());
471    if (CS && CS.isCallee(U)) {
472      // Do not copy attributes from the called function to the call-site.
473      // Function comparison ensures that the attributes are the same up to
474      // type congruences in byval(), in which case we need to keep the byval
475      // type of the call-site, not the callee function.
476      remove(CS.getInstruction()->getFunction());
477      U->set(BitcastNew);
478    }
479  }
480}
481
482// Helper for writeThunk,
483// Selects proper bitcast operation,
484// but a bit simpler then CastInst::getCastOpcode.
485static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
486  Type *SrcTy = V->getType();
487  if (SrcTy->isStructTy()) {
488    assert(DestTy->isStructTy());
489    assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
490    Value *Result = UndefValue::get(DestTy);
491    for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
492      Value *Element = createCast(
493          Builder, Builder.CreateExtractValue(V, makeArrayRef(I)),
494          DestTy->getStructElementType(I));
495
496      Result =
497          Builder.CreateInsertValue(Result, Element, makeArrayRef(I));
498    }
499    return Result;
500  }
501  assert(!DestTy->isStructTy());
502  if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
503    return Builder.CreateIntToPtr(V, DestTy);
504  else if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
505    return Builder.CreatePtrToInt(V, DestTy);
506  else
507    return Builder.CreateBitCast(V, DestTy);
508}
509
510// Erase the instructions in PDIUnrelatedWL as they are unrelated to the
511// parameter debug info, from the entry block.
512void MergeFunctions::eraseInstsUnrelatedToPDI(
513    std::vector<Instruction *> &PDIUnrelatedWL) {
514  LLVM_DEBUG(
515      dbgs() << " Erasing instructions (in reverse order of appearance in "
516                "entry block) unrelated to parameter debug info from entry "
517                "block: {\n");
518  while (!PDIUnrelatedWL.empty()) {
519    Instruction *I = PDIUnrelatedWL.back();
520    LLVM_DEBUG(dbgs() << "  Deleting Instruction: ");
521    LLVM_DEBUG(I->print(dbgs()));
522    LLVM_DEBUG(dbgs() << "\n");
523    I->eraseFromParent();
524    PDIUnrelatedWL.pop_back();
525  }
526  LLVM_DEBUG(dbgs() << " } // Done erasing instructions unrelated to parameter "
527                       "debug info from entry block. \n");
528}
529
530// Reduce G to its entry block.
531void MergeFunctions::eraseTail(Function *G) {
532  std::vector<BasicBlock *> WorklistBB;
533  for (Function::iterator BBI = std::next(G->begin()), BBE = G->end();
534       BBI != BBE; ++BBI) {
535    BBI->dropAllReferences();
536    WorklistBB.push_back(&*BBI);
537  }
538  while (!WorklistBB.empty()) {
539    BasicBlock *BB = WorklistBB.back();
540    BB->eraseFromParent();
541    WorklistBB.pop_back();
542  }
543}
544
545// We are interested in the following instructions from the entry block as being
546// related to parameter debug info:
547// - @llvm.dbg.declare
548// - stores from the incoming parameters to locations on the stack-frame
549// - allocas that create these locations on the stack-frame
550// - @llvm.dbg.value
551// - the entry block's terminator
552// The rest are unrelated to debug info for the parameters; fill up
553// PDIUnrelatedWL with such instructions.
554void MergeFunctions::filterInstsUnrelatedToPDI(
555    BasicBlock *GEntryBlock, std::vector<Instruction *> &PDIUnrelatedWL) {
556  std::set<Instruction *> PDIRelated;
557  for (BasicBlock::iterator BI = GEntryBlock->begin(), BIE = GEntryBlock->end();
558       BI != BIE; ++BI) {
559    if (auto *DVI = dyn_cast<DbgValueInst>(&*BI)) {
560      LLVM_DEBUG(dbgs() << " Deciding: ");
561      LLVM_DEBUG(BI->print(dbgs()));
562      LLVM_DEBUG(dbgs() << "\n");
563      DILocalVariable *DILocVar = DVI->getVariable();
564      if (DILocVar->isParameter()) {
565        LLVM_DEBUG(dbgs() << "  Include (parameter): ");
566        LLVM_DEBUG(BI->print(dbgs()));
567        LLVM_DEBUG(dbgs() << "\n");
568        PDIRelated.insert(&*BI);
569      } else {
570        LLVM_DEBUG(dbgs() << "  Delete (!parameter): ");
571        LLVM_DEBUG(BI->print(dbgs()));
572        LLVM_DEBUG(dbgs() << "\n");
573      }
574    } else if (auto *DDI = dyn_cast<DbgDeclareInst>(&*BI)) {
575      LLVM_DEBUG(dbgs() << " Deciding: ");
576      LLVM_DEBUG(BI->print(dbgs()));
577      LLVM_DEBUG(dbgs() << "\n");
578      DILocalVariable *DILocVar = DDI->getVariable();
579      if (DILocVar->isParameter()) {
580        LLVM_DEBUG(dbgs() << "  Parameter: ");
581        LLVM_DEBUG(DILocVar->print(dbgs()));
582        AllocaInst *AI = dyn_cast_or_null<AllocaInst>(DDI->getAddress());
583        if (AI) {
584          LLVM_DEBUG(dbgs() << "  Processing alloca users: ");
585          LLVM_DEBUG(dbgs() << "\n");
586          for (User *U : AI->users()) {
587            if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
588              if (Value *Arg = SI->getValueOperand()) {
589                if (dyn_cast<Argument>(Arg)) {
590                  LLVM_DEBUG(dbgs() << "  Include: ");
591                  LLVM_DEBUG(AI->print(dbgs()));
592                  LLVM_DEBUG(dbgs() << "\n");
593                  PDIRelated.insert(AI);
594                  LLVM_DEBUG(dbgs() << "   Include (parameter): ");
595                  LLVM_DEBUG(SI->print(dbgs()));
596                  LLVM_DEBUG(dbgs() << "\n");
597                  PDIRelated.insert(SI);
598                  LLVM_DEBUG(dbgs() << "  Include: ");
599                  LLVM_DEBUG(BI->print(dbgs()));
600                  LLVM_DEBUG(dbgs() << "\n");
601                  PDIRelated.insert(&*BI);
602                } else {
603                  LLVM_DEBUG(dbgs() << "   Delete (!parameter): ");
604                  LLVM_DEBUG(SI->print(dbgs()));
605                  LLVM_DEBUG(dbgs() << "\n");
606                }
607              }
608            } else {
609              LLVM_DEBUG(dbgs() << "   Defer: ");
610              LLVM_DEBUG(U->print(dbgs()));
611              LLVM_DEBUG(dbgs() << "\n");
612            }
613          }
614        } else {
615          LLVM_DEBUG(dbgs() << "  Delete (alloca NULL): ");
616          LLVM_DEBUG(BI->print(dbgs()));
617          LLVM_DEBUG(dbgs() << "\n");
618        }
619      } else {
620        LLVM_DEBUG(dbgs() << "  Delete (!parameter): ");
621        LLVM_DEBUG(BI->print(dbgs()));
622        LLVM_DEBUG(dbgs() << "\n");
623      }
624    } else if (BI->isTerminator() && &*BI == GEntryBlock->getTerminator()) {
625      LLVM_DEBUG(dbgs() << " Will Include Terminator: ");
626      LLVM_DEBUG(BI->print(dbgs()));
627      LLVM_DEBUG(dbgs() << "\n");
628      PDIRelated.insert(&*BI);
629    } else {
630      LLVM_DEBUG(dbgs() << " Defer: ");
631      LLVM_DEBUG(BI->print(dbgs()));
632      LLVM_DEBUG(dbgs() << "\n");
633    }
634  }
635  LLVM_DEBUG(
636      dbgs()
637      << " Report parameter debug info related/related instructions: {\n");
638  for (BasicBlock::iterator BI = GEntryBlock->begin(), BE = GEntryBlock->end();
639       BI != BE; ++BI) {
640
641    Instruction *I = &*BI;
642    if (PDIRelated.find(I) == PDIRelated.end()) {
643      LLVM_DEBUG(dbgs() << "  !PDIRelated: ");
644      LLVM_DEBUG(I->print(dbgs()));
645      LLVM_DEBUG(dbgs() << "\n");
646      PDIUnrelatedWL.push_back(I);
647    } else {
648      LLVM_DEBUG(dbgs() << "   PDIRelated: ");
649      LLVM_DEBUG(I->print(dbgs()));
650      LLVM_DEBUG(dbgs() << "\n");
651    }
652  }
653  LLVM_DEBUG(dbgs() << " }\n");
654}
655
656/// Whether this function may be replaced by a forwarding thunk.
657static bool canCreateThunkFor(Function *F) {
658  if (F->isVarArg())
659    return false;
660
661  // Don't merge tiny functions using a thunk, since it can just end up
662  // making the function larger.
663  if (F->size() == 1) {
664    if (F->front().size() <= 2) {
665      LLVM_DEBUG(dbgs() << "canCreateThunkFor: " << F->getName()
666                        << " is too small to bother creating a thunk for\n");
667      return false;
668    }
669  }
670  return true;
671}
672
// Turn G into a thunk: a simple tail call to F (with arguments and return
// value cast by createCast() as needed).
// Unless MergeFunctionsPDI holds, the thunk is built in a new function NewG
// that receives G's attributes and name; all uses of G are then redirected to
// NewG and G is deleted. Under MergeFunctionsPDI, we use G itself for creating
// the thunk as we preserve the debug info (and associated instructions)
// from G's entry block pertaining to G's incoming arguments which are
// passed on as corresponding arguments in the call that G makes to F.
// For better debugability, under MergeFunctionsPDI, we do not modify G's
// call sites to point to F even when within the same translation unit.
void MergeFunctions::writeThunk(Function *F, Function *G) {
  BasicBlock *GEntryBlock = nullptr;
  std::vector<Instruction *> PDIUnrelatedWL;
  BasicBlock *BB = nullptr;
  Function *NewG = nullptr;
  if (MergeFunctionsPDI) {
    LLVM_DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) Do not create a new "
                         "function as thunk; retain original: "
                      << G->getName() << "()\n");
    GEntryBlock = &G->getEntryBlock();
    LLVM_DEBUG(
        dbgs() << "writeThunk: (MergeFunctionsPDI) filter parameter related "
                  "debug info for "
               << G->getName() << "() {\n");
    // Classify the entry block while its terminator still exists (the filter
    // treats the terminator as related), then drop the terminator so the
    // thunk body can be appended to the entry block.
    filterInstsUnrelatedToPDI(GEntryBlock, PDIUnrelatedWL);
    GEntryBlock->getTerminator()->eraseFromParent();
    BB = GEntryBlock;
  } else {
    NewG = Function::Create(G->getFunctionType(), G->getLinkage(),
                            G->getAddressSpace(), "", G->getParent());
    NewG->setComdat(G->getComdat());
    BB = BasicBlock::Create(F->getContext(), "", NewG);
  }

  IRBuilder<> Builder(BB);
  // H is the function that becomes the thunk.
  Function *H = MergeFunctionsPDI ? G : NewG;
  SmallVector<Value *, 16> Args;
  unsigned i = 0;
  FunctionType *FFTy = F->getFunctionType();
  // Forward each of H's arguments, cast to the matching parameter type of F.
  for (Argument &AI : H->args()) {
    Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i)));
    ++i;
  }

  CallInst *CI = Builder.CreateCall(F, Args);
  ReturnInst *RI = nullptr;
  CI->setTailCall();
  CI->setCallingConv(F->getCallingConv());
  CI->setAttributes(F->getAttributes());
  if (H->getReturnType()->isVoidTy()) {
    RI = Builder.CreateRetVoid();
  } else {
    RI = Builder.CreateRet(createCast(Builder, CI, H->getReturnType()));
  }

  if (MergeFunctionsPDI) {
    // Give the new call and return a location at G's scope line, if G has
    // debug info.
    DISubprogram *DIS = G->getSubprogram();
    if (DIS) {
      DebugLoc CIDbgLoc = DebugLoc::get(DIS->getScopeLine(), 0, DIS);
      DebugLoc RIDbgLoc = DebugLoc::get(DIS->getScopeLine(), 0, DIS);
      CI->setDebugLoc(CIDbgLoc);
      RI->setDebugLoc(RIDbgLoc);
    } else {
      LLVM_DEBUG(
          dbgs() << "writeThunk: (MergeFunctionsPDI) No DISubprogram for "
                 << G->getName() << "()\n");
    }
    // Only after the thunk body exists: remove the non-entry blocks, then the
    // entry-block instructions that are unrelated to parameter debug info.
    eraseTail(G);
    eraseInstsUnrelatedToPDI(PDIUnrelatedWL);
    LLVM_DEBUG(
        dbgs() << "} // End of parameter related debug info filtering for: "
               << G->getName() << "()\n");
  } else {
    // NewG takes over G's attributes and name; functions using G are re-queued
    // before their references are redirected to NewG, then G is deleted.
    NewG->copyAttributesFrom(G);
    NewG->takeName(G);
    removeUsers(G);
    G->replaceAllUsesWith(NewG);
    G->eraseFromParent();
  }

  LLVM_DEBUG(dbgs() << "writeThunk: " << H->getName() << '\n');
  ++NumThunksWritten;
}
754
755// Whether this function may be replaced by an alias
756static bool canCreateAliasFor(Function *F) {
757  if (!MergeFunctionsAliases || !F->hasGlobalUnnamedAddr())
758    return false;
759
760  // We should only see linkages supported by aliases here
761  assert(F->hasLocalLinkage() || F->hasExternalLinkage()
762      || F->hasWeakLinkage() || F->hasLinkOnceLinkage());
763  return true;
764}
765
766// Replace G with an alias to F (deleting function G)
767void MergeFunctions::writeAlias(Function *F, Function *G) {
768  Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType());
769  PointerType *PtrType = G->getType();
770  auto *GA = GlobalAlias::create(
771      PtrType->getElementType(), PtrType->getAddressSpace(),
772      G->getLinkage(), "", BitcastF, G->getParent());
773
774  F->setAlignment(MaybeAlign(std::max(F->getAlignment(), G->getAlignment())));
775  GA->takeName(G);
776  GA->setVisibility(G->getVisibility());
777  GA->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
778
779  removeUsers(G);
780  G->replaceAllUsesWith(GA);
781  G->eraseFromParent();
782
783  LLVM_DEBUG(dbgs() << "writeAlias: " << GA->getName() << '\n');
784  ++NumAliasesWritten;
785}
786
787// Replace G with an alias to F if possible, or a thunk to F if
788// profitable. Returns false if neither is the case.
789bool MergeFunctions::writeThunkOrAlias(Function *F, Function *G) {
790  if (canCreateAliasFor(G)) {
791    writeAlias(F, G);
792    return true;
793  }
794  if (canCreateThunkFor(F)) {
795    writeThunk(F, G);
796    return true;
797  }
798  return false;
799}
800
// Merge two equivalent functions, keeping F's body. Depending on linkage and
// options, G ends up deleted, replaced by an alias, or turned into a thunk to
// F; if no profitable replacement exists (writeThunkOrAlias() returns false,
// or the interposable early-exit below), G is left as it is.
void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) {
  if (F->isInterposable()) {
    assert(G->isInterposable());

    // Both writeThunkOrAlias() calls below must succeed, either because we can
    // create aliases for G and NewF, or because a thunk for F is profitable.
    // F here has the same signature as NewF below, so that's what we check.
    if (!canCreateThunkFor(F) &&
        (!canCreateAliasFor(F) || !canCreateAliasFor(G)))
      return;

    // Make them both thunks to the same internal function.
    // NewF takes over F's attributes, name and uses; F itself keeps the body
    // and is made private below.
    Function *NewF = Function::Create(F->getFunctionType(), F->getLinkage(),
                                      F->getAddressSpace(), "", F->getParent());
    NewF->copyAttributesFrom(F);
    NewF->takeName(F);
    removeUsers(F);
    F->replaceAllUsesWith(NewF);

    // Captured before G is replaced; F must satisfy both G's and NewF's
    // alignment.
    MaybeAlign MaxAlignment(std::max(G->getAlignment(), NewF->getAlignment()));

    writeThunkOrAlias(F, G);
    writeThunkOrAlias(F, NewF);

    F->setAlignment(MaxAlignment);
    F->setLinkage(GlobalValue::PrivateLinkage);
    ++NumDoubleWeak;
    ++NumFunctionsMerged;
  } else {
    // For better debugability, under MergeFunctionsPDI, we do not modify G's
    // call sites to point to F even when within the same translation unit.
    if (!G->isInterposable() && !MergeFunctionsPDI) {
      if (G->hasGlobalUnnamedAddr()) {
        // G might have been a key in our GlobalNumberState, and it's illegal
        // to replace a key in ValueMap<GlobalValue *> with a non-global.
        GlobalNumbers.erase(G);
        // If G's address is not significant, replace it entirely.
        Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType());
        removeUsers(G);
        G->replaceAllUsesWith(BitcastF);
      } else {
        // Redirect direct callers of G to F. (See note on MergeFunctionsPDI
        // above).
        replaceDirectCallers(G, F);
      }
    }

    // If G was internal then we may have replaced all uses of G with F. If so,
    // stop here and delete G. There's no need for a thunk. (See note on
    // MergeFunctionsPDI above).
    if (G->isDiscardableIfUnused() && G->use_empty() && !MergeFunctionsPDI) {
      G->eraseFromParent();
      ++NumFunctionsMerged;
      return;
    }

    // G still has uses (or must be kept): replace its body instead. Only a
    // successful replacement counts as a merge.
    if (writeThunkOrAlias(F, G)) {
      ++NumFunctionsMerged;
    }
  }
}
863
864/// Replace function F by function G.
865void MergeFunctions::replaceFunctionInTree(const FunctionNode &FN,
866                                           Function *G) {
867  Function *F = FN.getFunc();
868  assert(FunctionComparator(F, G, &GlobalNumbers).compare() == 0 &&
869         "The two functions must be equal");
870
871  auto I = FNodesInTree.find(F);
872  assert(I != FNodesInTree.end() && "F should be in FNodesInTree");
873  assert(FNodesInTree.count(G) == 0 && "FNodesInTree should not contain G");
874
875  FnTreeType::iterator IterToFNInFnTree = I->second;
876  assert(&(*IterToFNInFnTree) == &FN && "F should map to FN in FNodesInTree.");
877  // Remove F -> FN and insert G -> FN
878  FNodesInTree.erase(I);
879  FNodesInTree.insert({G, IterToFNInFnTree});
880  // Replace F with G in FN, which is stored inside the FnTree.
881  FN.replaceBy(G);
882}
883
884// Ordering for functions that are equal under FunctionComparator
885static bool isFuncOrderCorrect(const Function *F, const Function *G) {
886  if (F->isInterposable() != G->isInterposable()) {
887    // Strong before weak, because the weak function may call the strong
888    // one, but not the other way around.
889    return !F->isInterposable();
890  }
891  if (F->hasLocalLinkage() != G->hasLocalLinkage()) {
892    // External before local, because we definitely have to keep the external
893    // function, but may be able to drop the local one.
894    return !F->hasLocalLinkage();
895  }
896  // Impose a total order (by name) on the replacement of functions. This is
897  // important when operating on more than one module independently to prevent
898  // cycles of thunks calling each other when the modules are linked together.
899  return F->getName() <= G->getName();
900}
901
902// Insert a ComparableFunction into the FnTree, or merge it away if equal to one
903// that was already inserted.
904bool MergeFunctions::insert(Function *NewFunction) {
905  std::pair<FnTreeType::iterator, bool> Result =
906      FnTree.insert(FunctionNode(NewFunction));
907
908  if (Result.second) {
909    assert(FNodesInTree.count(NewFunction) == 0);
910    FNodesInTree.insert({NewFunction, Result.first});
911    LLVM_DEBUG(dbgs() << "Inserting as unique: " << NewFunction->getName()
912                      << '\n');
913    return false;
914  }
915
916  const FunctionNode &OldF = *Result.first;
917
918  if (!isFuncOrderCorrect(OldF.getFunc(), NewFunction)) {
919    // Swap the two functions.
920    Function *F = OldF.getFunc();
921    replaceFunctionInTree(*Result.first, NewFunction);
922    NewFunction = F;
923    assert(OldF.getFunc() != F && "Must have swapped the functions.");
924  }
925
926  LLVM_DEBUG(dbgs() << "  " << OldF.getFunc()->getName()
927                    << " == " << NewFunction->getName() << '\n');
928
929  Function *DeleteF = NewFunction;
930  mergeTwoFunctions(OldF.getFunc(), DeleteF);
931  return true;
932}
933
934// Remove a function from FnTree. If it was already in FnTree, add
935// it to Deferred so that we'll look at it in the next round.
936void MergeFunctions::remove(Function *F) {
937  auto I = FNodesInTree.find(F);
938  if (I != FNodesInTree.end()) {
939    LLVM_DEBUG(dbgs() << "Deferred " << F->getName() << ".\n");
940    FnTree.erase(I->second);
941    // I->second has been invalidated, remove it from the FNodesInTree map to
942    // preserve the invariant.
943    FNodesInTree.erase(I);
944    Deferred.emplace_back(F);
945  }
946}
947
948// For each instruction used by the value, remove() the function that contains
949// the instruction. This should happen right before a call to RAUW.
950void MergeFunctions::removeUsers(Value *V) {
951  for (User *U : V->users())
952    if (auto *I = dyn_cast<Instruction>(U))
953      remove(I->getFunction());
954}
955