1//===- MergeFunctions.cpp - Merge identical functions ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This pass looks for equivalent functions that are mergeable and folds them.
10//
// An order relation is defined on the set of functions. It is established by a
// special function comparison procedure that returns
13// 0 when functions are equal,
14// -1 when Left function is less than right function, and
// 1 for the opposite case. We need a total ordering, so we need to maintain
// four properties on the set of functions:
17// a <= a (reflexivity)
18// if a <= b and b <= a then a = b (antisymmetry)
19// if a <= b and b <= c then a <= c (transitivity).
20// for all a and b: a <= b or b <= a (totality).
21//
22// Comparison iterates through each instruction in each basic block.
// Functions are kept in a binary tree. For each new function F we perform a
// lookup in the binary tree.
// In practice it works the following way:
// -- We define a Function* container class (FunctionNode) together with a
//    custom comparator for it (FunctionNodeCmp).
// -- FunctionNode instances are stored in a std::set collection, so every
//    std::set::insert operation needs only O(log(N)) comparisons.
29//
30// As an optimization, a hash of the function structure is calculated first, and
31// two functions are only compared if they have the same hash. This hash is
32// cheap to compute, and has the property that if function F == G according to
33// the comparison function, then hash(F) == hash(G). This consistency property
34// is critical to ensuring all possible merging opportunities are exploited.
35// Collisions in the hash affect the speed of the pass but not the correctness
36// or determinism of the resulting transformation.
37//
38// When a match is found the functions are folded. If both functions are
39// overridable, we move the functionality into a new internal function and
40// leave two overridable thunks to it.
41//
42//===----------------------------------------------------------------------===//
43//
44// Future work:
45//
46// * virtual functions.
47//
48// Many functions have their address taken by the virtual function table for
49// the object they belong to. However, as long as it's only used for a lookup
50// and call, this is irrelevant, and we'd like to fold such functions.
51//
52// * be smarter about bitcasts.
53//
54// In order to fold functions, we will sometimes add either bitcast instructions
55// or bitcast constant expressions. Unfortunately, this can confound further
56// analysis since the two functions differ where one has a bitcast and the
57// other doesn't. We should learn to look through bitcasts.
58//
59// * Compare complex types with pointer types inside.
60// * Compare cross-reference cases.
61// * Compare complex expressions.
62//
// All three issues above could be described as the ability to prove that
// fA == fB == fE == fF == fG in the example below:
65//
66//  void fA() {
67//    fB();
68//  }
69//  void fB() {
70//    fA();
71//  }
72//
73//  void fE() {
74//    fF();
75//  }
76//  void fF() {
77//    fG();
78//  }
79//  void fG() {
80//    fE();
81//  }
82//
// The simplest cross-reference case (fA <--> fB) was implemented in previous
// versions of MergeFunctions, though it occurred in only two function pairs
// in the test-suite (which contains >50k functions).
// However, the ability to detect complex cross-referencing (e.g.
// A->B->C->D->A) could cover many more cases.
88//
89//===----------------------------------------------------------------------===//
90
91#include "llvm/ADT/ArrayRef.h"
92#include "llvm/ADT/SmallPtrSet.h"
93#include "llvm/ADT/SmallVector.h"
94#include "llvm/ADT/Statistic.h"
95#include "llvm/IR/Argument.h"
96#include "llvm/IR/Attributes.h"
97#include "llvm/IR/BasicBlock.h"
98#include "llvm/IR/Constant.h"
99#include "llvm/IR/Constants.h"
100#include "llvm/IR/DebugInfoMetadata.h"
101#include "llvm/IR/DebugLoc.h"
102#include "llvm/IR/DerivedTypes.h"
103#include "llvm/IR/Function.h"
104#include "llvm/IR/GlobalValue.h"
105#include "llvm/IR/IRBuilder.h"
106#include "llvm/IR/InstrTypes.h"
107#include "llvm/IR/Instruction.h"
108#include "llvm/IR/Instructions.h"
109#include "llvm/IR/IntrinsicInst.h"
110#include "llvm/IR/Module.h"
111#include "llvm/IR/Type.h"
112#include "llvm/IR/Use.h"
113#include "llvm/IR/User.h"
114#include "llvm/IR/Value.h"
115#include "llvm/IR/ValueHandle.h"
116#include "llvm/IR/ValueMap.h"
117#include "llvm/InitializePasses.h"
118#include "llvm/Pass.h"
119#include "llvm/Support/Casting.h"
120#include "llvm/Support/CommandLine.h"
121#include "llvm/Support/Debug.h"
122#include "llvm/Support/raw_ostream.h"
123#include "llvm/Transforms/IPO.h"
124#include "llvm/Transforms/IPO/MergeFunctions.h"
125#include "llvm/Transforms/Utils/FunctionComparator.h"
126#include <algorithm>
127#include <cassert>
128#include <iterator>
129#include <set>
130#include <utility>
131#include <vector>
132
133using namespace llvm;
134
135#define DEBUG_TYPE "mergefunc"
136
// Pass statistics (counters reported through LLVM's statistics machinery).
// Incremented each time a function is merged away (erased, or replaced by a
// thunk/alias), see mergeTwoFunctions().
STATISTIC(NumFunctionsMerged, "Number of functions merged");
// Incremented by writeThunk().
STATISTIC(NumThunksWritten, "Number of thunks generated");
// Incremented by writeAlias().
STATISTIC(NumAliasesWritten, "Number of aliases generated");
// Incremented once per merge of two interposable functions; such a merge
// creates one new function (NewF in mergeTwoFunctions()).
STATISTIC(NumDoubleWeak, "Number of new functions created");
141
// Number of functions, taken from the front of each worklist round, on which
// doSanityCheck() verifies that the function comparator behaves as an order
// relation. '0' disables the check. The check is only compiled into builds
// without NDEBUG and is only invoked from inside LLVM_DEBUG, hence the need
// for '-debug'.
static cl::opt<unsigned> NumFunctionsForSanityCheck(
    "mergefunc-sanity",
    cl::desc("How many functions in module could be used for "
             "MergeFunctions pass sanity check. "
             "'0' disables this check. Works only with '-debug' key."),
    cl::init(0), cl::Hidden);

// Under option -mergefunc-preserve-debug-info we:
// - Do not create a new function for a thunk.
// - Retain the debug info for a thunk's parameters (and associated
//   instructions for the debug info) from the entry block.
//   Note: -debug will display the algorithm at work.
// - Create debug-info for the call (to the shared implementation) made by
//   a thunk and its return value.
// - Erase the rest of the function, retaining the (minimally sized) entry
//   block to create a thunk.
// - Preserve a thunk's call site to point to the thunk even when both occur
//   within the same translation unit, to aid debuggability. Note that this
//   behaviour differs from the underlying -mergefunc implementation which
//   modifies the thunk's call site to point to the shared implementation
//   when both occur within the same translation unit.
static cl::opt<bool>
    MergeFunctionsPDI("mergefunc-preserve-debug-info", cl::Hidden,
                      cl::init(false),
                      cl::desc("Preserve debug info in thunk when mergefunc "
                               "transformations are made."));

// When set, a merged-away function may be replaced by a GlobalAlias to the
// surviving function instead of a thunk. canCreateAliasFor() additionally
// requires the function to have global unnamed_addr.
static cl::opt<bool>
    MergeFunctionsAliases("mergefunc-use-aliases", cl::Hidden,
                          cl::init(false),
                          cl::desc("Allow mergefunc to create aliases"));
173
174namespace {
175
176class FunctionNode {
177  mutable AssertingVH<Function> F;
178  FunctionComparator::FunctionHash Hash;
179
180public:
181  // Note the hash is recalculated potentially multiple times, but it is cheap.
182  FunctionNode(Function *F)
183    : F(F), Hash(FunctionComparator::functionHash(*F))  {}
184
185  Function *getFunc() const { return F; }
186  FunctionComparator::FunctionHash getHash() const { return Hash; }
187
188  /// Replace the reference to the function F by the function G, assuming their
189  /// implementations are equal.
190  void replaceBy(Function *G) const {
191    F = G;
192  }
193};
194
/// MergeFunctions finds functions which will generate identical machine code,
/// by considering all pointer types to be equivalent. Once identified,
/// MergeFunctions folds them: uses or calls of one are redirected to (a
/// bitcast of) the other, or the redundant function is turned into a thunk
/// or an alias of the other.
///
/// NOTE(review): FnTree's comparator stores a pointer to this object's
/// GlobalNumbers, so a copied or moved MergeFunctions would compare using the
/// source object's GlobalNumbers. In the code visible here, instances are
/// only created as locals and never copied.
class MergeFunctions {
public:
  MergeFunctions() : FnTree(FunctionNodeCmp(&GlobalNumbers)) {
  }

  /// Run function merging over \p M. Returns true if the module was changed.
  bool runOnModule(Module &M);

private:
  // The function comparison operator is provided here so that FunctionNodes do
  // not need to become larger with another pointer.
  class FunctionNodeCmp {
    GlobalNumberState* GlobalNumbers;

  public:
    FunctionNodeCmp(GlobalNumberState* GN) : GlobalNumbers(GN) {}

    /// "Less than" for FnTree. Hashes are compared first; the (expensive) full
    /// structural comparison only runs on a hash tie. This is a valid ordering
    /// only if FunctionComparator is; doSanityCheck() can verify that.
    bool operator()(const FunctionNode &LHS, const FunctionNode &RHS) const {
      // Order first by hashes, then full function comparison.
      if (LHS.getHash() != RHS.getHash())
        return LHS.getHash() < RHS.getHash();
      FunctionComparator FCmp(LHS.getFunc(), RHS.getFunc(), GlobalNumbers);
      return FCmp.compare() == -1;
    }
  };
  using FnTreeType = std::set<FunctionNode, FunctionNodeCmp>;

  /// Numbering of global values shared by every FunctionComparator this
  /// object creates (via FunctionNodeCmp and doSanityCheck()).
  GlobalNumberState GlobalNumbers;

  /// A work queue of functions that may have been modified and should be
  /// analyzed again. Entries are weak handles: they become null if the
  /// function is deleted while queued.
  std::vector<WeakTrackingVH> Deferred;

#ifndef NDEBUG
  /// Checks the rules of order relation introduced among functions set.
  /// Returns true, if sanity check has been passed, and false if failed.
  bool doSanityCheck(std::vector<WeakTrackingVH> &Worklist);
#endif

  /// Insert NewFunction into the FnTree, or merge it away if it is equal to a
  /// function that is already present. runOnModule() accumulates the result
  /// into its "module changed" flag.
  bool insert(Function *NewFunction);

  /// Remove a Function from the FnTree and queue it up for a second sweep of
  /// analysis.
  void remove(Function *F);

  /// Find the functions that use this Value and remove them from FnTree and
  /// queue the functions.
  void removeUsers(Value *V);

  /// Replace all direct calls of Old with calls of New. Will bitcast New if
  /// necessary to make types match.
  void replaceDirectCallers(Function *Old, Function *New);

  /// Merge two equivalent functions. Upon completion, G may be deleted,
  /// replaced by an alias, or converted into a thunk, or (when none of these
  /// is possible) left in place. It should never be visited again.
  void mergeTwoFunctions(Function *F, Function *G);

  /// Fill PDIUnrelatedWL with instructions from the entry block that are
  /// unrelated to parameter related debug info.
  void filterInstsUnrelatedToPDI(BasicBlock *GEntryBlock,
                                 std::vector<Instruction *> &PDIUnrelatedWL);

  /// Erase the rest of the CFG (i.e. barring the entry block).
  void eraseTail(Function *G);

  /// Erase the instructions in PDIUnrelatedWL as they are unrelated to the
  /// parameter debug info, from the entry block.
  void eraseInstsUnrelatedToPDI(std::vector<Instruction *> &PDIUnrelatedWL);

  /// Replace G with a thunk that tail-calls F (casting arguments and the
  /// return value as needed). Also (unless MergeFunctionsPDI holds) replace
  /// uses of G with the new thunk and delete G.
  void writeThunk(Function *F, Function *G);

  // Replace G with an alias to F (deleting function G)
  void writeAlias(Function *F, Function *G);

  // Replace G with an alias to F if possible, or a thunk to F if profitable.
  // Returns false if neither is the case.
  bool writeThunkOrAlias(Function *F, Function *G);

  /// Replace the function held by the tree node FN with G in the function
  /// tree.
  void replaceFunctionInTree(const FunctionNode &FN, Function *G);

  /// The set of all distinct functions. Use the insert() and remove() methods
  /// to modify it. FNodesInTree below provides efficient lookup of the node
  /// holding a given Function.
  FnTreeType FnTree;

  // Map functions to the iterators of the FunctionNode which contains them
  // in the FnTree. This must be updated carefully whenever the FnTree is
  // modified, i.e. in insert(), remove(), and replaceFunctionInTree(), to avoid
  // dangling iterators into FnTree. The invariant that preserves this is that
  // there is exactly one mapping F -> FN for each FunctionNode FN in FnTree.
  DenseMap<AssertingVH<Function>, FnTreeType::iterator> FNodesInTree;
};
296
297class MergeFunctionsLegacyPass : public ModulePass {
298public:
299  static char ID;
300
301  MergeFunctionsLegacyPass(): ModulePass(ID) {
302    initializeMergeFunctionsLegacyPassPass(*PassRegistry::getPassRegistry());
303  }
304
305  bool runOnModule(Module &M) override {
306    if (skipModule(M))
307      return false;
308
309    MergeFunctions MF;
310    return MF.runOnModule(M);
311  }
312};
313
314} // end anonymous namespace
315
// Legacy pass-manager registration under the "mergefunc" command-line name.
// The two trailing flags mark the pass as neither CFG-only nor an analysis.
char MergeFunctionsLegacyPass::ID = 0;
INITIALIZE_PASS(MergeFunctionsLegacyPass, "mergefunc",
                "Merge Functions", false, false)

/// Create a new legacy-pass-manager MergeFunctions pass. The returned object
/// is heap-allocated; the caller takes ownership of it.
ModulePass *llvm::createMergeFunctionsPass() {
  return new MergeFunctionsLegacyPass();
}
323
324PreservedAnalyses MergeFunctionsPass::run(Module &M,
325                                          ModuleAnalysisManager &AM) {
326  MergeFunctions MF;
327  if (!MF.runOnModule(M))
328    return PreservedAnalyses::all();
329  return PreservedAnalyses::none();
330}
331
332#ifndef NDEBUG
333bool MergeFunctions::doSanityCheck(std::vector<WeakTrackingVH> &Worklist) {
334  if (const unsigned Max = NumFunctionsForSanityCheck) {
335    unsigned TripleNumber = 0;
336    bool Valid = true;
337
338    dbgs() << "MERGEFUNC-SANITY: Started for first " << Max << " functions.\n";
339
340    unsigned i = 0;
341    for (std::vector<WeakTrackingVH>::iterator I = Worklist.begin(),
342                                               E = Worklist.end();
343         I != E && i < Max; ++I, ++i) {
344      unsigned j = i;
345      for (std::vector<WeakTrackingVH>::iterator J = I; J != E && j < Max;
346           ++J, ++j) {
347        Function *F1 = cast<Function>(*I);
348        Function *F2 = cast<Function>(*J);
349        int Res1 = FunctionComparator(F1, F2, &GlobalNumbers).compare();
350        int Res2 = FunctionComparator(F2, F1, &GlobalNumbers).compare();
351
352        // If F1 <= F2, then F2 >= F1, otherwise report failure.
353        if (Res1 != -Res2) {
354          dbgs() << "MERGEFUNC-SANITY: Non-symmetric; triple: " << TripleNumber
355                 << "\n";
356          dbgs() << *F1 << '\n' << *F2 << '\n';
357          Valid = false;
358        }
359
360        if (Res1 == 0)
361          continue;
362
363        unsigned k = j;
364        for (std::vector<WeakTrackingVH>::iterator K = J; K != E && k < Max;
365             ++k, ++K, ++TripleNumber) {
366          if (K == J)
367            continue;
368
369          Function *F3 = cast<Function>(*K);
370          int Res3 = FunctionComparator(F1, F3, &GlobalNumbers).compare();
371          int Res4 = FunctionComparator(F2, F3, &GlobalNumbers).compare();
372
373          bool Transitive = true;
374
375          if (Res1 != 0 && Res1 == Res4) {
376            // F1 > F2, F2 > F3 => F1 > F3
377            Transitive = Res3 == Res1;
378          } else if (Res3 != 0 && Res3 == -Res4) {
379            // F1 > F3, F3 > F2 => F1 > F2
380            Transitive = Res3 == Res1;
381          } else if (Res4 != 0 && -Res3 == Res4) {
382            // F2 > F3, F3 > F1 => F2 > F1
383            Transitive = Res4 == -Res1;
384          }
385
386          if (!Transitive) {
387            dbgs() << "MERGEFUNC-SANITY: Non-transitive; triple: "
388                   << TripleNumber << "\n";
389            dbgs() << "Res1, Res3, Res4: " << Res1 << ", " << Res3 << ", "
390                   << Res4 << "\n";
391            dbgs() << *F1 << '\n' << *F2 << '\n' << *F3 << '\n';
392            Valid = false;
393          }
394        }
395      }
396    }
397
398    dbgs() << "MERGEFUNC-SANITY: " << (Valid ? "Passed." : "Failed.") << "\n";
399    return Valid;
400  }
401  return true;
402}
403#endif
404
405/// Check whether \p F is eligible for function merging.
406static bool isEligibleForMerging(Function &F) {
407  return !F.isDeclaration() && !F.hasAvailableExternallyLinkage();
408}
409
410bool MergeFunctions::runOnModule(Module &M) {
411  bool Changed = false;
412
413  // All functions in the module, ordered by hash. Functions with a unique
414  // hash value are easily eliminated.
415  std::vector<std::pair<FunctionComparator::FunctionHash, Function *>>
416    HashedFuncs;
417  for (Function &Func : M) {
418    if (isEligibleForMerging(Func)) {
419      HashedFuncs.push_back({FunctionComparator::functionHash(Func), &Func});
420    }
421  }
422
423  llvm::stable_sort(HashedFuncs, less_first());
424
425  auto S = HashedFuncs.begin();
426  for (auto I = HashedFuncs.begin(), IE = HashedFuncs.end(); I != IE; ++I) {
427    // If the hash value matches the previous value or the next one, we must
428    // consider merging it. Otherwise it is dropped and never considered again.
429    if ((I != S && std::prev(I)->first == I->first) ||
430        (std::next(I) != IE && std::next(I)->first == I->first) ) {
431      Deferred.push_back(WeakTrackingVH(I->second));
432    }
433  }
434
435  do {
436    std::vector<WeakTrackingVH> Worklist;
437    Deferred.swap(Worklist);
438
439    LLVM_DEBUG(doSanityCheck(Worklist));
440
441    LLVM_DEBUG(dbgs() << "size of module: " << M.size() << '\n');
442    LLVM_DEBUG(dbgs() << "size of worklist: " << Worklist.size() << '\n');
443
444    // Insert functions and merge them.
445    for (WeakTrackingVH &I : Worklist) {
446      if (!I)
447        continue;
448      Function *F = cast<Function>(I);
449      if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage()) {
450        Changed |= insert(F);
451      }
452    }
453    LLVM_DEBUG(dbgs() << "size of FnTree: " << FnTree.size() << '\n');
454  } while (!Deferred.empty());
455
456  FnTree.clear();
457  FNodesInTree.clear();
458  GlobalNumbers.clear();
459
460  return Changed;
461}
462
463// Replace direct callers of Old with New.
464void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
465  Constant *BitcastNew = ConstantExpr::getBitCast(New, Old->getType());
466  for (auto UI = Old->use_begin(), UE = Old->use_end(); UI != UE;) {
467    Use *U = &*UI;
468    ++UI;
469    CallBase *CB = dyn_cast<CallBase>(U->getUser());
470    if (CB && CB->isCallee(U)) {
471      // Do not copy attributes from the called function to the call-site.
472      // Function comparison ensures that the attributes are the same up to
473      // type congruences in byval(), in which case we need to keep the byval
474      // type of the call-site, not the callee function.
475      remove(CB->getFunction());
476      U->set(BitcastNew);
477    }
478  }
479}
480
481// Helper for writeThunk,
482// Selects proper bitcast operation,
483// but a bit simpler then CastInst::getCastOpcode.
484static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
485  Type *SrcTy = V->getType();
486  if (SrcTy->isStructTy()) {
487    assert(DestTy->isStructTy());
488    assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
489    Value *Result = UndefValue::get(DestTy);
490    for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
491      Value *Element = createCast(
492          Builder, Builder.CreateExtractValue(V, makeArrayRef(I)),
493          DestTy->getStructElementType(I));
494
495      Result =
496          Builder.CreateInsertValue(Result, Element, makeArrayRef(I));
497    }
498    return Result;
499  }
500  assert(!DestTy->isStructTy());
501  if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
502    return Builder.CreateIntToPtr(V, DestTy);
503  else if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
504    return Builder.CreatePtrToInt(V, DestTy);
505  else
506    return Builder.CreateBitCast(V, DestTy);
507}
508
509// Erase the instructions in PDIUnrelatedWL as they are unrelated to the
510// parameter debug info, from the entry block.
511void MergeFunctions::eraseInstsUnrelatedToPDI(
512    std::vector<Instruction *> &PDIUnrelatedWL) {
513  LLVM_DEBUG(
514      dbgs() << " Erasing instructions (in reverse order of appearance in "
515                "entry block) unrelated to parameter debug info from entry "
516                "block: {\n");
517  while (!PDIUnrelatedWL.empty()) {
518    Instruction *I = PDIUnrelatedWL.back();
519    LLVM_DEBUG(dbgs() << "  Deleting Instruction: ");
520    LLVM_DEBUG(I->print(dbgs()));
521    LLVM_DEBUG(dbgs() << "\n");
522    I->eraseFromParent();
523    PDIUnrelatedWL.pop_back();
524  }
525  LLVM_DEBUG(dbgs() << " } // Done erasing instructions unrelated to parameter "
526                       "debug info from entry block. \n");
527}
528
529// Reduce G to its entry block.
530void MergeFunctions::eraseTail(Function *G) {
531  std::vector<BasicBlock *> WorklistBB;
532  for (Function::iterator BBI = std::next(G->begin()), BBE = G->end();
533       BBI != BBE; ++BBI) {
534    BBI->dropAllReferences();
535    WorklistBB.push_back(&*BBI);
536  }
537  while (!WorklistBB.empty()) {
538    BasicBlock *BB = WorklistBB.back();
539    BB->eraseFromParent();
540    WorklistBB.pop_back();
541  }
542}
543
544// We are interested in the following instructions from the entry block as being
545// related to parameter debug info:
546// - @llvm.dbg.declare
547// - stores from the incoming parameters to locations on the stack-frame
548// - allocas that create these locations on the stack-frame
549// - @llvm.dbg.value
550// - the entry block's terminator
551// The rest are unrelated to debug info for the parameters; fill up
552// PDIUnrelatedWL with such instructions.
553void MergeFunctions::filterInstsUnrelatedToPDI(
554    BasicBlock *GEntryBlock, std::vector<Instruction *> &PDIUnrelatedWL) {
555  std::set<Instruction *> PDIRelated;
556  for (BasicBlock::iterator BI = GEntryBlock->begin(), BIE = GEntryBlock->end();
557       BI != BIE; ++BI) {
558    if (auto *DVI = dyn_cast<DbgValueInst>(&*BI)) {
559      LLVM_DEBUG(dbgs() << " Deciding: ");
560      LLVM_DEBUG(BI->print(dbgs()));
561      LLVM_DEBUG(dbgs() << "\n");
562      DILocalVariable *DILocVar = DVI->getVariable();
563      if (DILocVar->isParameter()) {
564        LLVM_DEBUG(dbgs() << "  Include (parameter): ");
565        LLVM_DEBUG(BI->print(dbgs()));
566        LLVM_DEBUG(dbgs() << "\n");
567        PDIRelated.insert(&*BI);
568      } else {
569        LLVM_DEBUG(dbgs() << "  Delete (!parameter): ");
570        LLVM_DEBUG(BI->print(dbgs()));
571        LLVM_DEBUG(dbgs() << "\n");
572      }
573    } else if (auto *DDI = dyn_cast<DbgDeclareInst>(&*BI)) {
574      LLVM_DEBUG(dbgs() << " Deciding: ");
575      LLVM_DEBUG(BI->print(dbgs()));
576      LLVM_DEBUG(dbgs() << "\n");
577      DILocalVariable *DILocVar = DDI->getVariable();
578      if (DILocVar->isParameter()) {
579        LLVM_DEBUG(dbgs() << "  Parameter: ");
580        LLVM_DEBUG(DILocVar->print(dbgs()));
581        AllocaInst *AI = dyn_cast_or_null<AllocaInst>(DDI->getAddress());
582        if (AI) {
583          LLVM_DEBUG(dbgs() << "  Processing alloca users: ");
584          LLVM_DEBUG(dbgs() << "\n");
585          for (User *U : AI->users()) {
586            if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
587              if (Value *Arg = SI->getValueOperand()) {
588                if (dyn_cast<Argument>(Arg)) {
589                  LLVM_DEBUG(dbgs() << "  Include: ");
590                  LLVM_DEBUG(AI->print(dbgs()));
591                  LLVM_DEBUG(dbgs() << "\n");
592                  PDIRelated.insert(AI);
593                  LLVM_DEBUG(dbgs() << "   Include (parameter): ");
594                  LLVM_DEBUG(SI->print(dbgs()));
595                  LLVM_DEBUG(dbgs() << "\n");
596                  PDIRelated.insert(SI);
597                  LLVM_DEBUG(dbgs() << "  Include: ");
598                  LLVM_DEBUG(BI->print(dbgs()));
599                  LLVM_DEBUG(dbgs() << "\n");
600                  PDIRelated.insert(&*BI);
601                } else {
602                  LLVM_DEBUG(dbgs() << "   Delete (!parameter): ");
603                  LLVM_DEBUG(SI->print(dbgs()));
604                  LLVM_DEBUG(dbgs() << "\n");
605                }
606              }
607            } else {
608              LLVM_DEBUG(dbgs() << "   Defer: ");
609              LLVM_DEBUG(U->print(dbgs()));
610              LLVM_DEBUG(dbgs() << "\n");
611            }
612          }
613        } else {
614          LLVM_DEBUG(dbgs() << "  Delete (alloca NULL): ");
615          LLVM_DEBUG(BI->print(dbgs()));
616          LLVM_DEBUG(dbgs() << "\n");
617        }
618      } else {
619        LLVM_DEBUG(dbgs() << "  Delete (!parameter): ");
620        LLVM_DEBUG(BI->print(dbgs()));
621        LLVM_DEBUG(dbgs() << "\n");
622      }
623    } else if (BI->isTerminator() && &*BI == GEntryBlock->getTerminator()) {
624      LLVM_DEBUG(dbgs() << " Will Include Terminator: ");
625      LLVM_DEBUG(BI->print(dbgs()));
626      LLVM_DEBUG(dbgs() << "\n");
627      PDIRelated.insert(&*BI);
628    } else {
629      LLVM_DEBUG(dbgs() << " Defer: ");
630      LLVM_DEBUG(BI->print(dbgs()));
631      LLVM_DEBUG(dbgs() << "\n");
632    }
633  }
634  LLVM_DEBUG(
635      dbgs()
636      << " Report parameter debug info related/related instructions: {\n");
637  for (BasicBlock::iterator BI = GEntryBlock->begin(), BE = GEntryBlock->end();
638       BI != BE; ++BI) {
639
640    Instruction *I = &*BI;
641    if (PDIRelated.find(I) == PDIRelated.end()) {
642      LLVM_DEBUG(dbgs() << "  !PDIRelated: ");
643      LLVM_DEBUG(I->print(dbgs()));
644      LLVM_DEBUG(dbgs() << "\n");
645      PDIUnrelatedWL.push_back(I);
646    } else {
647      LLVM_DEBUG(dbgs() << "   PDIRelated: ");
648      LLVM_DEBUG(I->print(dbgs()));
649      LLVM_DEBUG(dbgs() << "\n");
650    }
651  }
652  LLVM_DEBUG(dbgs() << " }\n");
653}
654
655/// Whether this function may be replaced by a forwarding thunk.
656static bool canCreateThunkFor(Function *F) {
657  if (F->isVarArg())
658    return false;
659
660  // Don't merge tiny functions using a thunk, since it can just end up
661  // making the function larger.
662  if (F->size() == 1) {
663    if (F->front().size() <= 2) {
664      LLVM_DEBUG(dbgs() << "canCreateThunkFor: " << F->getName()
665                        << " is too small to bother creating a thunk for\n");
666      return false;
667    }
668  }
669  return true;
670}
671
672// Replace G with a simple tail call to bitcast(F). Also (unless
673// MergeFunctionsPDI holds) replace direct uses of G with bitcast(F),
674// delete G. Under MergeFunctionsPDI, we use G itself for creating
675// the thunk as we preserve the debug info (and associated instructions)
676// from G's entry block pertaining to G's incoming arguments which are
677// passed on as corresponding arguments in the call that G makes to F.
678// For better debugability, under MergeFunctionsPDI, we do not modify G's
679// call sites to point to F even when within the same translation unit.
680void MergeFunctions::writeThunk(Function *F, Function *G) {
681  BasicBlock *GEntryBlock = nullptr;
682  std::vector<Instruction *> PDIUnrelatedWL;
683  BasicBlock *BB = nullptr;
684  Function *NewG = nullptr;
685  if (MergeFunctionsPDI) {
686    LLVM_DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) Do not create a new "
687                         "function as thunk; retain original: "
688                      << G->getName() << "()\n");
689    GEntryBlock = &G->getEntryBlock();
690    LLVM_DEBUG(
691        dbgs() << "writeThunk: (MergeFunctionsPDI) filter parameter related "
692                  "debug info for "
693               << G->getName() << "() {\n");
694    filterInstsUnrelatedToPDI(GEntryBlock, PDIUnrelatedWL);
695    GEntryBlock->getTerminator()->eraseFromParent();
696    BB = GEntryBlock;
697  } else {
698    NewG = Function::Create(G->getFunctionType(), G->getLinkage(),
699                            G->getAddressSpace(), "", G->getParent());
700    NewG->setComdat(G->getComdat());
701    BB = BasicBlock::Create(F->getContext(), "", NewG);
702  }
703
704  IRBuilder<> Builder(BB);
705  Function *H = MergeFunctionsPDI ? G : NewG;
706  SmallVector<Value *, 16> Args;
707  unsigned i = 0;
708  FunctionType *FFTy = F->getFunctionType();
709  for (Argument &AI : H->args()) {
710    Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i)));
711    ++i;
712  }
713
714  CallInst *CI = Builder.CreateCall(F, Args);
715  ReturnInst *RI = nullptr;
716  CI->setTailCall();
717  CI->setCallingConv(F->getCallingConv());
718  CI->setAttributes(F->getAttributes());
719  if (H->getReturnType()->isVoidTy()) {
720    RI = Builder.CreateRetVoid();
721  } else {
722    RI = Builder.CreateRet(createCast(Builder, CI, H->getReturnType()));
723  }
724
725  if (MergeFunctionsPDI) {
726    DISubprogram *DIS = G->getSubprogram();
727    if (DIS) {
728      DebugLoc CIDbgLoc = DebugLoc::get(DIS->getScopeLine(), 0, DIS);
729      DebugLoc RIDbgLoc = DebugLoc::get(DIS->getScopeLine(), 0, DIS);
730      CI->setDebugLoc(CIDbgLoc);
731      RI->setDebugLoc(RIDbgLoc);
732    } else {
733      LLVM_DEBUG(
734          dbgs() << "writeThunk: (MergeFunctionsPDI) No DISubprogram for "
735                 << G->getName() << "()\n");
736    }
737    eraseTail(G);
738    eraseInstsUnrelatedToPDI(PDIUnrelatedWL);
739    LLVM_DEBUG(
740        dbgs() << "} // End of parameter related debug info filtering for: "
741               << G->getName() << "()\n");
742  } else {
743    NewG->copyAttributesFrom(G);
744    NewG->takeName(G);
745    removeUsers(G);
746    G->replaceAllUsesWith(NewG);
747    G->eraseFromParent();
748  }
749
750  LLVM_DEBUG(dbgs() << "writeThunk: " << H->getName() << '\n');
751  ++NumThunksWritten;
752}
753
754// Whether this function may be replaced by an alias
755static bool canCreateAliasFor(Function *F) {
756  if (!MergeFunctionsAliases || !F->hasGlobalUnnamedAddr())
757    return false;
758
759  // We should only see linkages supported by aliases here
760  assert(F->hasLocalLinkage() || F->hasExternalLinkage()
761      || F->hasWeakLinkage() || F->hasLinkOnceLinkage());
762  return true;
763}
764
765// Replace G with an alias to F (deleting function G)
766void MergeFunctions::writeAlias(Function *F, Function *G) {
767  Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType());
768  PointerType *PtrType = G->getType();
769  auto *GA = GlobalAlias::create(
770      PtrType->getElementType(), PtrType->getAddressSpace(),
771      G->getLinkage(), "", BitcastF, G->getParent());
772
773  F->setAlignment(MaybeAlign(std::max(F->getAlignment(), G->getAlignment())));
774  GA->takeName(G);
775  GA->setVisibility(G->getVisibility());
776  GA->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
777
778  removeUsers(G);
779  G->replaceAllUsesWith(GA);
780  G->eraseFromParent();
781
782  LLVM_DEBUG(dbgs() << "writeAlias: " << GA->getName() << '\n');
783  ++NumAliasesWritten;
784}
785
786// Replace G with an alias to F if possible, or a thunk to F if
787// profitable. Returns false if neither is the case.
788bool MergeFunctions::writeThunkOrAlias(Function *F, Function *G) {
789  if (canCreateAliasFor(G)) {
790    writeAlias(F, G);
791    return true;
792  }
793  if (canCreateThunkFor(F)) {
794    writeThunk(F, G);
795    return true;
796  }
797  return false;
798}
799
// Merge two equivalent functions F and G, keeping F's implementation.
//
// If F is interposable (G is asserted to be too), nothing happens unless a
// thunk for F is possible or aliases can be created for both F and G.
// Otherwise F keeps the body but becomes private, a new function NewF takes
// over F's name, attributes and uses, and both G and NewF are turned into
// thunks or aliases to F.
//
// Otherwise, uses or direct callers of G are redirected to F where allowed.
// G is then either erased (when discardable and no longer used), or replaced
// by an alias or a thunk. If neither an alias nor a thunk can be written, G
// stays in place (some of its callers may already point at F).
void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) {
  if (F->isInterposable()) {
    assert(G->isInterposable());

    // Both writeThunkOrAlias() calls below must succeed, either because we can
    // create aliases for G and NewF, or because a thunk for F is profitable.
    // F here has the same signature as NewF below, so that's what we check.
    if (!canCreateThunkFor(F) &&
        (!canCreateAliasFor(F) || !canCreateAliasFor(G)))
      return;

    // Make them both thunks to the same internal function.
    Function *NewF = Function::Create(F->getFunctionType(), F->getLinkage(),
                                      F->getAddressSpace(), "", F->getParent());
    NewF->copyAttributesFrom(F);
    NewF->takeName(F);
    removeUsers(F);
    F->replaceAllUsesWith(NewF);

    // Read the alignments now: the writeThunkOrAlias() calls below may erase
    // G and NewF.
    MaybeAlign MaxAlignment(std::max(G->getAlignment(), NewF->getAlignment()));

    writeThunkOrAlias(F, G);
    writeThunkOrAlias(F, NewF);

    // F now holds the shared body for both wrappers.
    F->setAlignment(MaxAlignment);
    F->setLinkage(GlobalValue::PrivateLinkage);
    ++NumDoubleWeak;
    ++NumFunctionsMerged;
  } else {
    // For better debugability, under MergeFunctionsPDI, we do not modify G's
    // call sites to point to F even when within the same translation unit.
    if (!G->isInterposable() && !MergeFunctionsPDI) {
      if (G->hasGlobalUnnamedAddr()) {
        // G might have been a key in our GlobalNumberState, and it's illegal
        // to replace a key in ValueMap<GlobalValue *> with a non-global.
        GlobalNumbers.erase(G);
        // If G's address is not significant, replace it entirely.
        Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType());
        removeUsers(G);
        G->replaceAllUsesWith(BitcastF);
      } else {
        // Redirect direct callers of G to F. (See note on MergeFunctionsPDI
        // above).
        replaceDirectCallers(G, F);
      }
    }

    // If G was internal then we may have replaced all uses of G with F. If so,
    // stop here and delete G. There's no need for a thunk. (See note on
    // MergeFunctionsPDI above).
    if (G->isDiscardableIfUnused() && G->use_empty() && !MergeFunctionsPDI) {
      G->eraseFromParent();
      ++NumFunctionsMerged;
      return;
    }

    // Only counted as merged if G actually became an alias or a thunk.
    if (writeThunkOrAlias(F, G)) {
      ++NumFunctionsMerged;
    }
  }
}
862
863/// Replace function F by function G.
864void MergeFunctions::replaceFunctionInTree(const FunctionNode &FN,
865                                           Function *G) {
866  Function *F = FN.getFunc();
867  assert(FunctionComparator(F, G, &GlobalNumbers).compare() == 0 &&
868         "The two functions must be equal");
869
870  auto I = FNodesInTree.find(F);
871  assert(I != FNodesInTree.end() && "F should be in FNodesInTree");
872  assert(FNodesInTree.count(G) == 0 && "FNodesInTree should not contain G");
873
874  FnTreeType::iterator IterToFNInFnTree = I->second;
875  assert(&(*IterToFNInFnTree) == &FN && "F should map to FN in FNodesInTree.");
876  // Remove F -> FN and insert G -> FN
877  FNodesInTree.erase(I);
878  FNodesInTree.insert({G, IterToFNInFnTree});
879  // Replace F with G in FN, which is stored inside the FnTree.
880  FN.replaceBy(G);
881}
882
883// Ordering for functions that are equal under FunctionComparator
884static bool isFuncOrderCorrect(const Function *F, const Function *G) {
885  if (F->isInterposable() != G->isInterposable()) {
886    // Strong before weak, because the weak function may call the strong
887    // one, but not the other way around.
888    return !F->isInterposable();
889  }
890  if (F->hasLocalLinkage() != G->hasLocalLinkage()) {
891    // External before local, because we definitely have to keep the external
892    // function, but may be able to drop the local one.
893    return !F->hasLocalLinkage();
894  }
895  // Impose a total order (by name) on the replacement of functions. This is
896  // important when operating on more than one module independently to prevent
897  // cycles of thunks calling each other when the modules are linked together.
898  return F->getName() <= G->getName();
899}
900
901// Insert a ComparableFunction into the FnTree, or merge it away if equal to one
902// that was already inserted.
903bool MergeFunctions::insert(Function *NewFunction) {
904  std::pair<FnTreeType::iterator, bool> Result =
905      FnTree.insert(FunctionNode(NewFunction));
906
907  if (Result.second) {
908    assert(FNodesInTree.count(NewFunction) == 0);
909    FNodesInTree.insert({NewFunction, Result.first});
910    LLVM_DEBUG(dbgs() << "Inserting as unique: " << NewFunction->getName()
911                      << '\n');
912    return false;
913  }
914
915  const FunctionNode &OldF = *Result.first;
916
917  if (!isFuncOrderCorrect(OldF.getFunc(), NewFunction)) {
918    // Swap the two functions.
919    Function *F = OldF.getFunc();
920    replaceFunctionInTree(*Result.first, NewFunction);
921    NewFunction = F;
922    assert(OldF.getFunc() != F && "Must have swapped the functions.");
923  }
924
925  LLVM_DEBUG(dbgs() << "  " << OldF.getFunc()->getName()
926                    << " == " << NewFunction->getName() << '\n');
927
928  Function *DeleteF = NewFunction;
929  mergeTwoFunctions(OldF.getFunc(), DeleteF);
930  return true;
931}
932
933// Remove a function from FnTree. If it was already in FnTree, add
934// it to Deferred so that we'll look at it in the next round.
935void MergeFunctions::remove(Function *F) {
936  auto I = FNodesInTree.find(F);
937  if (I != FNodesInTree.end()) {
938    LLVM_DEBUG(dbgs() << "Deferred " << F->getName() << ".\n");
939    FnTree.erase(I->second);
940    // I->second has been invalidated, remove it from the FNodesInTree map to
941    // preserve the invariant.
942    FNodesInTree.erase(I);
943    Deferred.emplace_back(F);
944  }
945}
946
947// For each instruction used by the value, remove() the function that contains
948// the instruction. This should happen right before a call to RAUW.
949void MergeFunctions::removeUsers(Value *V) {
950  for (User *U : V->users())
951    if (auto *I = dyn_cast<Instruction>(U))
952      remove(I->getFunction());
953}
954