1//===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements whole program optimization of virtual calls in cases
10// where we know (via !type metadata) that the list of callees is fixed. This
11// includes the following:
12// - Single implementation devirtualization: if a virtual call has a single
13//   possible callee, replace all calls with a direct call to that callee.
14// - Virtual constant propagation: if the virtual function's return type is an
15//   integer <=64 bits and all possible callees are readnone, for each class and
16//   each list of constant arguments: evaluate the function, store the return
17//   value alongside the virtual table, and rewrite each virtual call as a load
18//   from the virtual table.
19// - Uniform return value optimization: if the conditions for virtual constant
20//   propagation hold and each function returns the same constant value, replace
21//   each virtual call with that constant.
22// - Unique return value optimization for i1 return values: if the conditions
23//   for virtual constant propagation hold and a single vtable's function
24//   returns 0, or a single vtable's function returns 1, replace each virtual
25//   call with a comparison of the vptr against that vtable's address.
26//
27// This pass is intended to be used during the regular and thin LTO pipelines:
28//
29// During regular LTO, the pass determines the best optimization for each
30// virtual call and applies the resolutions directly to virtual calls that are
31// eligible for virtual call optimization (i.e. calls that use either of the
32// llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics).
33//
34// During hybrid Regular/ThinLTO, the pass operates in two phases:
35// - Export phase: this is run during the thin link over a single merged module
36//   that contains all vtables with !type metadata that participate in the link.
37//   The pass computes a resolution for each virtual call and stores it in the
38//   type identifier summary.
39// - Import phase: this is run during the thin backends over the individual
40//   modules. The pass applies the resolutions previously computed during the
41//   import phase to each eligible virtual call.
42//
43// During ThinLTO, the pass operates in two phases:
44// - Export phase: this is run during the thin link over the index which
45//   contains a summary of all vtables with !type metadata that participate in
46//   the link. It computes a resolution for each virtual call and stores it in
47//   the type identifier summary. Only single implementation devirtualization
48//   is supported.
49// - Import phase: (same as with hybrid case above).
50//
51//===----------------------------------------------------------------------===//
52
53#include "llvm/Transforms/IPO/WholeProgramDevirt.h"
54#include "llvm/ADT/ArrayRef.h"
55#include "llvm/ADT/DenseMap.h"
56#include "llvm/ADT/DenseMapInfo.h"
57#include "llvm/ADT/DenseSet.h"
58#include "llvm/ADT/MapVector.h"
59#include "llvm/ADT/SmallVector.h"
60#include "llvm/ADT/Statistic.h"
61#include "llvm/ADT/Triple.h"
62#include "llvm/ADT/iterator_range.h"
63#include "llvm/Analysis/AssumptionCache.h"
64#include "llvm/Analysis/BasicAliasAnalysis.h"
65#include "llvm/Analysis/OptimizationRemarkEmitter.h"
66#include "llvm/Analysis/TypeMetadataUtils.h"
67#include "llvm/Bitcode/BitcodeReader.h"
68#include "llvm/Bitcode/BitcodeWriter.h"
69#include "llvm/IR/Constants.h"
70#include "llvm/IR/DataLayout.h"
71#include "llvm/IR/DebugLoc.h"
72#include "llvm/IR/DerivedTypes.h"
73#include "llvm/IR/Dominators.h"
74#include "llvm/IR/Function.h"
75#include "llvm/IR/GlobalAlias.h"
76#include "llvm/IR/GlobalVariable.h"
77#include "llvm/IR/IRBuilder.h"
78#include "llvm/IR/InstrTypes.h"
79#include "llvm/IR/Instruction.h"
80#include "llvm/IR/Instructions.h"
81#include "llvm/IR/Intrinsics.h"
82#include "llvm/IR/LLVMContext.h"
83#include "llvm/IR/MDBuilder.h"
84#include "llvm/IR/Metadata.h"
85#include "llvm/IR/Module.h"
86#include "llvm/IR/ModuleSummaryIndexYAML.h"
87#include "llvm/InitializePasses.h"
88#include "llvm/Pass.h"
89#include "llvm/PassRegistry.h"
90#include "llvm/Support/Casting.h"
91#include "llvm/Support/CommandLine.h"
92#include "llvm/Support/Errc.h"
93#include "llvm/Support/Error.h"
94#include "llvm/Support/FileSystem.h"
95#include "llvm/Support/GlobPattern.h"
96#include "llvm/Support/MathExtras.h"
97#include "llvm/Transforms/IPO.h"
98#include "llvm/Transforms/IPO/FunctionAttrs.h"
99#include "llvm/Transforms/Utils/BasicBlockUtils.h"
100#include "llvm/Transforms/Utils/CallPromotionUtils.h"
101#include "llvm/Transforms/Utils/Evaluator.h"
102#include <algorithm>
103#include <cstddef>
104#include <map>
105#include <set>
106#include <string>
107
108using namespace llvm;
109using namespace wholeprogramdevirt;
110
111#define DEBUG_TYPE "wholeprogramdevirt"
112
113STATISTIC(NumDevirtTargets, "Number of whole program devirtualization targets");
114STATISTIC(NumSingleImpl, "Number of single implementation devirtualizations");
115STATISTIC(NumBranchFunnel, "Number of branch funnels");
116STATISTIC(NumUniformRetVal, "Number of uniform return value optimizations");
117STATISTIC(NumUniqueRetVal, "Number of unique return value optimizations");
118STATISTIC(NumVirtConstProp1Bit,
119          "Number of 1 bit virtual constant propagations");
120STATISTIC(NumVirtConstProp, "Number of virtual constant propagations");
121
122static cl::opt<PassSummaryAction> ClSummaryAction(
123    "wholeprogramdevirt-summary-action",
124    cl::desc("What to do with the summary when running this pass"),
125    cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
126               clEnumValN(PassSummaryAction::Import, "import",
127                          "Import typeid resolutions from summary and globals"),
128               clEnumValN(PassSummaryAction::Export, "export",
129                          "Export typeid resolutions to summary and globals")),
130    cl::Hidden);
131
132static cl::opt<std::string> ClReadSummary(
133    "wholeprogramdevirt-read-summary",
134    cl::desc(
135        "Read summary from given bitcode or YAML file before running pass"),
136    cl::Hidden);
137
138static cl::opt<std::string> ClWriteSummary(
139    "wholeprogramdevirt-write-summary",
140    cl::desc("Write summary to given bitcode or YAML file after running pass. "
141             "Output file format is deduced from extension: *.bc means writing "
142             "bitcode, otherwise YAML"),
143    cl::Hidden);
144
145static cl::opt<unsigned>
146    ClThreshold("wholeprogramdevirt-branch-funnel-threshold", cl::Hidden,
147                cl::init(10),
148                cl::desc("Maximum number of call targets per "
149                         "call site to enable branch funnels"));
150
151static cl::opt<bool>
152    PrintSummaryDevirt("wholeprogramdevirt-print-index-based", cl::Hidden,
153                       cl::desc("Print index-based devirtualization messages"));
154
155/// Provide a way to force enable whole program visibility in tests.
156/// This is needed to support legacy tests that don't contain
157/// !vcall_visibility metadata (the mere presense of type tests
158/// previously implied hidden visibility).
159static cl::opt<bool>
160    WholeProgramVisibility("whole-program-visibility", cl::Hidden,
161                           cl::desc("Enable whole program visibility"));
162
163/// Provide a way to force disable whole program for debugging or workarounds,
164/// when enabled via the linker.
165static cl::opt<bool> DisableWholeProgramVisibility(
166    "disable-whole-program-visibility", cl::Hidden,
167    cl::desc("Disable whole program visibility (overrides enabling options)"));
168
169/// Provide way to prevent certain function from being devirtualized
170static cl::list<std::string>
171    SkipFunctionNames("wholeprogramdevirt-skip",
172                      cl::desc("Prevent function(s) from being devirtualized"),
173                      cl::Hidden, cl::CommaSeparated);
174
175/// Mechanism to add runtime checking of devirtualization decisions, optionally
176/// trapping or falling back to indirect call on any that are not correct.
177/// Trapping mode is useful for debugging undefined behavior leading to failures
178/// with WPD. Fallback mode is useful for ensuring safety when whole program
179/// visibility may be compromised.
180enum WPDCheckMode { None, Trap, Fallback };
181static cl::opt<WPDCheckMode> DevirtCheckMode(
182    "wholeprogramdevirt-check", cl::Hidden,
183    cl::desc("Type of checking for incorrect devirtualizations"),
184    cl::values(clEnumValN(WPDCheckMode::None, "none", "No checking"),
185               clEnumValN(WPDCheckMode::Trap, "trap", "Trap when incorrect"),
186               clEnumValN(WPDCheckMode::Fallback, "fallback",
187                          "Fallback to indirect when incorrect")));
188
189namespace {
190struct PatternList {
191  std::vector<GlobPattern> Patterns;
192  template <class T> void init(const T &StringList) {
193    for (const auto &S : StringList)
194      if (Expected<GlobPattern> Pat = GlobPattern::create(S))
195        Patterns.push_back(std::move(*Pat));
196  }
197  bool match(StringRef S) {
198    for (const GlobPattern &P : Patterns)
199      if (P.match(S))
200        return true;
201    return false;
202  }
203};
204} // namespace
205
206// Find the minimum offset that we may store a value of size Size bits at. If
207// IsAfter is set, look for an offset before the object, otherwise look for an
208// offset after the object.
209uint64_t
210wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
211                                     bool IsAfter, uint64_t Size) {
212  // Find a minimum offset taking into account only vtable sizes.
213  uint64_t MinByte = 0;
214  for (const VirtualCallTarget &Target : Targets) {
215    if (IsAfter)
216      MinByte = std::max(MinByte, Target.minAfterBytes());
217    else
218      MinByte = std::max(MinByte, Target.minBeforeBytes());
219  }
220
221  // Build a vector of arrays of bytes covering, for each target, a slice of the
222  // used region (see AccumBitVector::BytesUsed in
223  // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte. Effectively,
224  // this aligns the used regions to start at MinByte.
225  //
226  // In this example, A, B and C are vtables, # is a byte already allocated for
227  // a virtual function pointer, AAAA... (etc.) are the used regions for the
228  // vtables and Offset(X) is the value computed for the Offset variable below
229  // for X.
230  //
231  //                    Offset(A)
232  //                    |       |
233  //                            |MinByte
234  // A: ################AAAAAAAA|AAAAAAAA
235  // B: ########BBBBBBBBBBBBBBBB|BBBB
236  // C: ########################|CCCCCCCCCCCCCCCC
237  //            |   Offset(B)   |
238  //
239  // This code produces the slices of A, B and C that appear after the divider
240  // at MinByte.
241  std::vector<ArrayRef<uint8_t>> Used;
242  for (const VirtualCallTarget &Target : Targets) {
243    ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed
244                                       : Target.TM->Bits->Before.BytesUsed;
245    uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
246                              : MinByte - Target.minBeforeBytes();
247
248    // Disregard used regions that are smaller than Offset. These are
249    // effectively all-free regions that do not need to be checked.
250    if (VTUsed.size() > Offset)
251      Used.push_back(VTUsed.slice(Offset));
252  }
253
254  if (Size == 1) {
255    // Find a free bit in each member of Used.
256    for (unsigned I = 0;; ++I) {
257      uint8_t BitsUsed = 0;
258      for (auto &&B : Used)
259        if (I < B.size())
260          BitsUsed |= B[I];
261      if (BitsUsed != 0xff)
262        return (MinByte + I) * 8 + countTrailingZeros(uint8_t(~BitsUsed));
263    }
264  } else {
265    // Find a free (Size/8) byte region in each member of Used.
266    // FIXME: see if alignment helps.
267    for (unsigned I = 0;; ++I) {
268      for (auto &&B : Used) {
269        unsigned Byte = 0;
270        while ((I + Byte) < B.size() && Byte < (Size / 8)) {
271          if (B[I + Byte])
272            goto NextI;
273          ++Byte;
274        }
275      }
276      return (MinByte + I) * 8;
277    NextI:;
278    }
279  }
280}
281
282void wholeprogramdevirt::setBeforeReturnValues(
283    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore,
284    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
285  if (BitWidth == 1)
286    OffsetByte = -(AllocBefore / 8 + 1);
287  else
288    OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8);
289  OffsetBit = AllocBefore % 8;
290
291  for (VirtualCallTarget &Target : Targets) {
292    if (BitWidth == 1)
293      Target.setBeforeBit(AllocBefore);
294    else
295      Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8);
296  }
297}
298
299void wholeprogramdevirt::setAfterReturnValues(
300    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter,
301    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
302  if (BitWidth == 1)
303    OffsetByte = AllocAfter / 8;
304  else
305    OffsetByte = (AllocAfter + 7) / 8;
306  OffsetBit = AllocAfter % 8;
307
308  for (VirtualCallTarget &Target : Targets) {
309    if (BitWidth == 1)
310      Target.setAfterBit(AllocAfter);
311    else
312      Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8);
313  }
314}
315
316VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM)
317    : Fn(Fn), TM(TM),
318      IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), WasDevirt(false) {}
319
320namespace {
321
322// A slot in a set of virtual tables. The TypeID identifies the set of virtual
323// tables, and the ByteOffset is the offset in bytes from the address point to
324// the virtual function pointer.
325struct VTableSlot {
326  Metadata *TypeID;
327  uint64_t ByteOffset;
328};
329
330} // end anonymous namespace
331
332namespace llvm {
333
334template <> struct DenseMapInfo<VTableSlot> {
335  static VTableSlot getEmptyKey() {
336    return {DenseMapInfo<Metadata *>::getEmptyKey(),
337            DenseMapInfo<uint64_t>::getEmptyKey()};
338  }
339  static VTableSlot getTombstoneKey() {
340    return {DenseMapInfo<Metadata *>::getTombstoneKey(),
341            DenseMapInfo<uint64_t>::getTombstoneKey()};
342  }
343  static unsigned getHashValue(const VTableSlot &I) {
344    return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^
345           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
346  }
347  static bool isEqual(const VTableSlot &LHS,
348                      const VTableSlot &RHS) {
349    return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
350  }
351};
352
353template <> struct DenseMapInfo<VTableSlotSummary> {
354  static VTableSlotSummary getEmptyKey() {
355    return {DenseMapInfo<StringRef>::getEmptyKey(),
356            DenseMapInfo<uint64_t>::getEmptyKey()};
357  }
358  static VTableSlotSummary getTombstoneKey() {
359    return {DenseMapInfo<StringRef>::getTombstoneKey(),
360            DenseMapInfo<uint64_t>::getTombstoneKey()};
361  }
362  static unsigned getHashValue(const VTableSlotSummary &I) {
363    return DenseMapInfo<StringRef>::getHashValue(I.TypeID) ^
364           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
365  }
366  static bool isEqual(const VTableSlotSummary &LHS,
367                      const VTableSlotSummary &RHS) {
368    return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
369  }
370};
371
372} // end namespace llvm
373
374namespace {
375
376// Returns true if the function must be unreachable based on ValueInfo.
377//
378// In particular, identifies a function as unreachable in the following
379// conditions
380//   1) All summaries are live.
381//   2) All function summaries indicate it's unreachable
382bool mustBeUnreachableFunction(ValueInfo TheFnVI) {
383  if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) {
384    // Returns false if ValueInfo is absent, or the summary list is empty
385    // (e.g., function declarations).
386    return false;
387  }
388
389  for (const auto &Summary : TheFnVI.getSummaryList()) {
390    // Conservatively returns false if any non-live functions are seen.
391    // In general either all summaries should be live or all should be dead.
392    if (!Summary->isLive())
393      return false;
394    if (auto *FS = dyn_cast<FunctionSummary>(Summary.get())) {
395      if (!FS->fflags().MustBeUnreachable)
396        return false;
397    }
398    // Do nothing if a non-function has the same GUID (which is rare).
399    // This is correct since non-function summaries are not relevant.
400  }
401  // All function summaries are live and all of them agree that the function is
402  // unreachble.
403  return true;
404}
405
406// A virtual call site. VTable is the loaded virtual table pointer, and CS is
407// the indirect virtual call.
408struct VirtualCallSite {
409  Value *VTable = nullptr;
410  CallBase &CB;
411
412  // If non-null, this field points to the associated unsafe use count stored in
413  // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
414  // of that field for details.
415  unsigned *NumUnsafeUses = nullptr;
416
417  void
418  emitRemark(const StringRef OptName, const StringRef TargetName,
419             function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) {
420    Function *F = CB.getCaller();
421    DebugLoc DLoc = CB.getDebugLoc();
422    BasicBlock *Block = CB.getParent();
423
424    using namespace ore;
425    OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block)
426                      << NV("Optimization", OptName)
427                      << ": devirtualized a call to "
428                      << NV("FunctionName", TargetName));
429  }
430
431  void replaceAndErase(
432      const StringRef OptName, const StringRef TargetName, bool RemarksEnabled,
433      function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
434      Value *New) {
435    if (RemarksEnabled)
436      emitRemark(OptName, TargetName, OREGetter);
437    CB.replaceAllUsesWith(New);
438    if (auto *II = dyn_cast<InvokeInst>(&CB)) {
439      BranchInst::Create(II->getNormalDest(), &CB);
440      II->getUnwindDest()->removePredecessor(II->getParent());
441    }
442    CB.eraseFromParent();
443    // This use is no longer unsafe.
444    if (NumUnsafeUses)
445      --*NumUnsafeUses;
446  }
447};
448
449// Call site information collected for a specific VTableSlot and possibly a list
450// of constant integer arguments. The grouping by arguments is handled by the
451// VTableSlotInfo class.
452struct CallSiteInfo {
453  /// The set of call sites for this slot. Used during regular LTO and the
454  /// import phase of ThinLTO (as well as the export phase of ThinLTO for any
455  /// call sites that appear in the merged module itself); in each of these
456  /// cases we are directly operating on the call sites at the IR level.
457  std::vector<VirtualCallSite> CallSites;
458
459  /// Whether all call sites represented by this CallSiteInfo, including those
460  /// in summaries, have been devirtualized. This starts off as true because a
461  /// default constructed CallSiteInfo represents no call sites.
462  bool AllCallSitesDevirted = true;
463
464  // These fields are used during the export phase of ThinLTO and reflect
465  // information collected from function summaries.
466
467  /// Whether any function summary contains an llvm.assume(llvm.type.test) for
468  /// this slot.
469  bool SummaryHasTypeTestAssumeUsers = false;
470
471  /// CFI-specific: a vector containing the list of function summaries that use
472  /// the llvm.type.checked.load intrinsic and therefore will require
473  /// resolutions for llvm.type.test in order to implement CFI checks if
474  /// devirtualization was unsuccessful. If devirtualization was successful, the
475  /// pass will clear this vector by calling markDevirt(). If at the end of the
476  /// pass the vector is non-empty, we will need to add a use of llvm.type.test
477  /// to each of the function summaries in the vector.
478  std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers;
479  std::vector<FunctionSummary *> SummaryTypeTestAssumeUsers;
480
481  bool isExported() const {
482    return SummaryHasTypeTestAssumeUsers ||
483           !SummaryTypeCheckedLoadUsers.empty();
484  }
485
486  void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) {
487    SummaryTypeCheckedLoadUsers.push_back(FS);
488    AllCallSitesDevirted = false;
489  }
490
491  void addSummaryTypeTestAssumeUser(FunctionSummary *FS) {
492    SummaryTypeTestAssumeUsers.push_back(FS);
493    SummaryHasTypeTestAssumeUsers = true;
494    AllCallSitesDevirted = false;
495  }
496
497  void markDevirt() {
498    AllCallSitesDevirted = true;
499
500    // As explained in the comment for SummaryTypeCheckedLoadUsers.
501    SummaryTypeCheckedLoadUsers.clear();
502  }
503};
504
505// Call site information collected for a specific VTableSlot.
506struct VTableSlotInfo {
507  // The set of call sites which do not have all constant integer arguments
508  // (excluding "this").
509  CallSiteInfo CSInfo;
510
511  // The set of call sites with all constant integer arguments (excluding
512  // "this"), grouped by argument list.
513  std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;
514
515  void addCallSite(Value *VTable, CallBase &CB, unsigned *NumUnsafeUses);
516
517private:
518  CallSiteInfo &findCallSiteInfo(CallBase &CB);
519};
520
521CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallBase &CB) {
522  std::vector<uint64_t> Args;
523  auto *CBType = dyn_cast<IntegerType>(CB.getType());
524  if (!CBType || CBType->getBitWidth() > 64 || CB.arg_empty())
525    return CSInfo;
526  for (auto &&Arg : drop_begin(CB.args())) {
527    auto *CI = dyn_cast<ConstantInt>(Arg);
528    if (!CI || CI->getBitWidth() > 64)
529      return CSInfo;
530    Args.push_back(CI->getZExtValue());
531  }
532  return ConstCSInfo[Args];
533}
534
535void VTableSlotInfo::addCallSite(Value *VTable, CallBase &CB,
536                                 unsigned *NumUnsafeUses) {
537  auto &CSI = findCallSiteInfo(CB);
538  CSI.AllCallSitesDevirted = false;
539  CSI.CallSites.push_back({VTable, CB, NumUnsafeUses});
540}
541
542struct DevirtModule {
543  Module &M;
544  function_ref<AAResults &(Function &)> AARGetter;
545  function_ref<DominatorTree &(Function &)> LookupDomTree;
546
547  ModuleSummaryIndex *ExportSummary;
548  const ModuleSummaryIndex *ImportSummary;
549
550  IntegerType *Int8Ty;
551  PointerType *Int8PtrTy;
552  IntegerType *Int32Ty;
553  IntegerType *Int64Ty;
554  IntegerType *IntPtrTy;
555  /// Sizeless array type, used for imported vtables. This provides a signal
556  /// to analyzers that these imports may alias, as they do for example
557  /// when multiple unique return values occur in the same vtable.
558  ArrayType *Int8Arr0Ty;
559
560  bool RemarksEnabled;
561  function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter;
562
563  MapVector<VTableSlot, VTableSlotInfo> CallSlots;
564
565  // Calls that have already been optimized. We may add a call to multiple
566  // VTableSlotInfos if vtable loads are coalesced and need to make sure not to
567  // optimize a call more than once.
568  SmallPtrSet<CallBase *, 8> OptimizedCalls;
569
570  // This map keeps track of the number of "unsafe" uses of a loaded function
571  // pointer. The key is the associated llvm.type.test intrinsic call generated
572  // by this pass. An unsafe use is one that calls the loaded function pointer
573  // directly. Every time we eliminate an unsafe use (for example, by
574  // devirtualizing it or by applying virtual constant propagation), we
575  // decrement the value stored in this map. If a value reaches zero, we can
576  // eliminate the type check by RAUWing the associated llvm.type.test call with
577  // true.
578  std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;
579  PatternList FunctionsToSkip;
580
581  DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter,
582               function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
583               function_ref<DominatorTree &(Function &)> LookupDomTree,
584               ModuleSummaryIndex *ExportSummary,
585               const ModuleSummaryIndex *ImportSummary)
586      : M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree),
587        ExportSummary(ExportSummary), ImportSummary(ImportSummary),
588        Int8Ty(Type::getInt8Ty(M.getContext())),
589        Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
590        Int32Ty(Type::getInt32Ty(M.getContext())),
591        Int64Ty(Type::getInt64Ty(M.getContext())),
592        IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)),
593        Int8Arr0Ty(ArrayType::get(Type::getInt8Ty(M.getContext()), 0)),
594        RemarksEnabled(areRemarksEnabled()), OREGetter(OREGetter) {
595    assert(!(ExportSummary && ImportSummary));
596    FunctionsToSkip.init(SkipFunctionNames);
597  }
598
599  bool areRemarksEnabled();
600
601  void
602  scanTypeTestUsers(Function *TypeTestFunc,
603                    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
604  void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);
605
606  void buildTypeIdentifierMap(
607      std::vector<VTableBits> &Bits,
608      DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
609
610  bool
611  tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
612                            const std::set<TypeMemberInfo> &TypeMemberInfos,
613                            uint64_t ByteOffset,
614                            ModuleSummaryIndex *ExportSummary);
615
616  void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn,
617                             bool &IsExported);
618  bool trySingleImplDevirt(ModuleSummaryIndex *ExportSummary,
619                           MutableArrayRef<VirtualCallTarget> TargetsForSlot,
620                           VTableSlotInfo &SlotInfo,
621                           WholeProgramDevirtResolution *Res);
622
623  void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT,
624                              bool &IsExported);
625  void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
626                            VTableSlotInfo &SlotInfo,
627                            WholeProgramDevirtResolution *Res, VTableSlot Slot);
628
629  bool tryEvaluateFunctionsWithArgs(
630      MutableArrayRef<VirtualCallTarget> TargetsForSlot,
631      ArrayRef<uint64_t> Args);
632
633  void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
634                             uint64_t TheRetVal);
635  bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
636                           CallSiteInfo &CSInfo,
637                           WholeProgramDevirtResolution::ByArg *Res);
638
639  // Returns the global symbol name that is used to export information about the
640  // given vtable slot and list of arguments.
641  std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args,
642                            StringRef Name);
643
644  bool shouldExportConstantsAsAbsoluteSymbols();
645
646  // This function is called during the export phase to create a symbol
647  // definition containing information about the given vtable slot and list of
648  // arguments.
649  void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
650                    Constant *C);
651  void exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
652                      uint32_t Const, uint32_t &Storage);
653
654  // This function is called during the import phase to create a reference to
655  // the symbol definition created during the export phase.
656  Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
657                         StringRef Name);
658  Constant *importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
659                           StringRef Name, IntegerType *IntTy,
660                           uint32_t Storage);
661
662  Constant *getMemberAddr(const TypeMemberInfo *M);
663
664  void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
665                            Constant *UniqueMemberAddr);
666  bool tryUniqueRetValOpt(unsigned BitWidth,
667                          MutableArrayRef<VirtualCallTarget> TargetsForSlot,
668                          CallSiteInfo &CSInfo,
669                          WholeProgramDevirtResolution::ByArg *Res,
670                          VTableSlot Slot, ArrayRef<uint64_t> Args);
671
672  void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
673                             Constant *Byte, Constant *Bit);
674  bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
675                           VTableSlotInfo &SlotInfo,
676                           WholeProgramDevirtResolution *Res, VTableSlot Slot);
677
678  void rebuildGlobal(VTableBits &B);
679
680  // Apply the summary resolution for Slot to all virtual calls in SlotInfo.
681  void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo);
682
683  // If we were able to eliminate all unsafe uses for a type checked load,
684  // eliminate the associated type tests by replacing them with true.
685  void removeRedundantTypeTests();
686
687  bool run();
688
689  // Look up the corresponding ValueInfo entry of `TheFn` in `ExportSummary`.
690  //
691  // Caller guarantees that `ExportSummary` is not nullptr.
692  static ValueInfo lookUpFunctionValueInfo(Function *TheFn,
693                                           ModuleSummaryIndex *ExportSummary);
694
695  // Returns true if the function definition must be unreachable.
696  //
697  // Note if this helper function returns true, `F` is guaranteed
698  // to be unreachable; if it returns false, `F` might still
699  // be unreachable but not covered by this helper function.
700  //
701  // Implementation-wise, if function definition is present, IR is analyzed; if
702  // not, look up function flags from ExportSummary as a fallback.
703  static bool mustBeUnreachableFunction(Function *const F,
704                                        ModuleSummaryIndex *ExportSummary);
705
706  // Lower the module using the action and summary passed as command line
707  // arguments. For testing purposes only.
708  static bool
709  runForTesting(Module &M, function_ref<AAResults &(Function &)> AARGetter,
710                function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
711                function_ref<DominatorTree &(Function &)> LookupDomTree);
712};
713
714struct DevirtIndex {
715  ModuleSummaryIndex &ExportSummary;
716  // The set in which to record GUIDs exported from their module by
717  // devirtualization, used by client to ensure they are not internalized.
718  std::set<GlobalValue::GUID> &ExportedGUIDs;
719  // A map in which to record the information necessary to locate the WPD
720  // resolution for local targets in case they are exported by cross module
721  // importing.
722  std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap;
723
724  MapVector<VTableSlotSummary, VTableSlotInfo> CallSlots;
725
726  PatternList FunctionsToSkip;
727
728  DevirtIndex(
729      ModuleSummaryIndex &ExportSummary,
730      std::set<GlobalValue::GUID> &ExportedGUIDs,
731      std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap)
732      : ExportSummary(ExportSummary), ExportedGUIDs(ExportedGUIDs),
733        LocalWPDTargetsMap(LocalWPDTargetsMap) {
734    FunctionsToSkip.init(SkipFunctionNames);
735  }
736
737  bool tryFindVirtualCallTargets(std::vector<ValueInfo> &TargetsForSlot,
738                                 const TypeIdCompatibleVtableInfo TIdInfo,
739                                 uint64_t ByteOffset);
740
741  bool trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
742                           VTableSlotSummary &SlotSummary,
743                           VTableSlotInfo &SlotInfo,
744                           WholeProgramDevirtResolution *Res,
745                           std::set<ValueInfo> &DevirtTargets);
746
747  void run();
748};
749} // end anonymous namespace
750
751PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
752                                              ModuleAnalysisManager &AM) {
753  auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
754  auto AARGetter = [&](Function &F) -> AAResults & {
755    return FAM.getResult<AAManager>(F);
756  };
757  auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & {
758    return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
759  };
760  auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & {
761    return FAM.getResult<DominatorTreeAnalysis>(F);
762  };
763  if (UseCommandLine) {
764    if (DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree))
765      return PreservedAnalyses::all();
766    return PreservedAnalyses::none();
767  }
768  if (!DevirtModule(M, AARGetter, OREGetter, LookupDomTree, ExportSummary,
769                    ImportSummary)
770           .run())
771    return PreservedAnalyses::all();
772  return PreservedAnalyses::none();
773}
774
775namespace llvm {
776// Enable whole program visibility if enabled by client (e.g. linker) or
777// internal option, and not force disabled.
778bool hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) {
779  return (WholeProgramVisibilityEnabledInLTO || WholeProgramVisibility) &&
780         !DisableWholeProgramVisibility;
781}
782
783/// If whole program visibility asserted, then upgrade all public vcall
784/// visibility metadata on vtable definitions to linkage unit visibility in
785/// Module IR (for regular or hybrid LTO).
786void updateVCallVisibilityInModule(
787    Module &M, bool WholeProgramVisibilityEnabledInLTO,
788    const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) {
789  if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
790    return;
791  for (GlobalVariable &GV : M.globals()) {
792    // Add linkage unit visibility to any variable with type metadata, which are
793    // the vtable definitions. We won't have an existing vcall_visibility
794    // metadata on vtable definitions with public visibility.
795    if (GV.hasMetadata(LLVMContext::MD_type) &&
796        GV.getVCallVisibility() == GlobalObject::VCallVisibilityPublic &&
797        // Don't upgrade the visibility for symbols exported to the dynamic
798        // linker, as we have no information on their eventual use.
799        !DynamicExportSymbols.count(GV.getGUID()))
800      GV.setVCallVisibilityMetadata(GlobalObject::VCallVisibilityLinkageUnit);
801  }
802}
803
804void updatePublicTypeTestCalls(Module &M,
805                               bool WholeProgramVisibilityEnabledInLTO) {
806  Function *PublicTypeTestFunc =
807      M.getFunction(Intrinsic::getName(Intrinsic::public_type_test));
808  if (!PublicTypeTestFunc)
809    return;
810  if (hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) {
811    Function *TypeTestFunc =
812        Intrinsic::getDeclaration(&M, Intrinsic::type_test);
813    for (Use &U : make_early_inc_range(PublicTypeTestFunc->uses())) {
814      auto *CI = cast<CallInst>(U.getUser());
815      auto *NewCI = CallInst::Create(
816          TypeTestFunc, {CI->getArgOperand(0), CI->getArgOperand(1)},
817          std::nullopt, "", CI);
818      CI->replaceAllUsesWith(NewCI);
819      CI->eraseFromParent();
820    }
821  } else {
822    auto *True = ConstantInt::getTrue(M.getContext());
823    for (Use &U : make_early_inc_range(PublicTypeTestFunc->uses())) {
824      auto *CI = cast<CallInst>(U.getUser());
825      CI->replaceAllUsesWith(True);
826      CI->eraseFromParent();
827    }
828  }
829}
830
831/// If whole program visibility asserted, then upgrade all public vcall
832/// visibility metadata on vtable definition summaries to linkage unit
833/// visibility in Module summary index (for ThinLTO).
834void updateVCallVisibilityInIndex(
835    ModuleSummaryIndex &Index, bool WholeProgramVisibilityEnabledInLTO,
836    const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) {
837  if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
838    return;
839  for (auto &P : Index) {
840    // Don't upgrade the visibility for symbols exported to the dynamic
841    // linker, as we have no information on their eventual use.
842    if (DynamicExportSymbols.count(P.first))
843      continue;
844    for (auto &S : P.second.SummaryList) {
845      auto *GVar = dyn_cast<GlobalVarSummary>(S.get());
846      if (!GVar ||
847          GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic)
848        continue;
849      GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit);
850    }
851  }
852}
853
854void runWholeProgramDevirtOnIndex(
855    ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs,
856    std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
857  DevirtIndex(Summary, ExportedGUIDs, LocalWPDTargetsMap).run();
858}
859
860void updateIndexWPDForExports(
861    ModuleSummaryIndex &Summary,
862    function_ref<bool(StringRef, ValueInfo)> isExported,
863    std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
864  for (auto &T : LocalWPDTargetsMap) {
865    auto &VI = T.first;
866    // This was enforced earlier during trySingleImplDevirt.
867    assert(VI.getSummaryList().size() == 1 &&
868           "Devirt of local target has more than one copy");
869    auto &S = VI.getSummaryList()[0];
870    if (!isExported(S->modulePath(), VI))
871      continue;
872
873    // It's been exported by a cross module import.
874    for (auto &SlotSummary : T.second) {
875      auto *TIdSum = Summary.getTypeIdSummary(SlotSummary.TypeID);
876      assert(TIdSum);
877      auto WPDRes = TIdSum->WPDRes.find(SlotSummary.ByteOffset);
878      assert(WPDRes != TIdSum->WPDRes.end());
879      WPDRes->second.SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
880          WPDRes->second.SingleImplName,
881          Summary.getModuleHash(S->modulePath()));
882    }
883  }
884}
885
886} // end namespace llvm
887
888static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) {
889  // Check that summary index contains regular LTO module when performing
890  // export to prevent occasional use of index from pure ThinLTO compilation
891  // (-fno-split-lto-module). This kind of summary index is passed to
892  // DevirtIndex::run, not to DevirtModule::run used by opt/runForTesting.
893  const auto &ModPaths = Summary->modulePaths();
894  if (ClSummaryAction != PassSummaryAction::Import &&
895      ModPaths.find(ModuleSummaryIndex::getRegularLTOModuleName()) ==
896          ModPaths.end())
897    return createStringError(
898        errc::invalid_argument,
899        "combined summary should contain Regular LTO module");
900  return ErrorSuccess();
901}
902
903bool DevirtModule::runForTesting(
904    Module &M, function_ref<AAResults &(Function &)> AARGetter,
905    function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
906    function_ref<DominatorTree &(Function &)> LookupDomTree) {
907  std::unique_ptr<ModuleSummaryIndex> Summary =
908      std::make_unique<ModuleSummaryIndex>(/*HaveGVs=*/false);
909
910  // Handle the command-line summary arguments. This code is for testing
911  // purposes only, so we handle errors directly.
912  if (!ClReadSummary.empty()) {
913    ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary +
914                          ": ");
915    auto ReadSummaryFile =
916        ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
917    if (Expected<std::unique_ptr<ModuleSummaryIndex>> SummaryOrErr =
918            getModuleSummaryIndex(*ReadSummaryFile)) {
919      Summary = std::move(*SummaryOrErr);
920      ExitOnErr(checkCombinedSummaryForTesting(Summary.get()));
921    } else {
922      // Try YAML if we've failed with bitcode.
923      consumeError(SummaryOrErr.takeError());
924      yaml::Input In(ReadSummaryFile->getBuffer());
925      In >> *Summary;
926      ExitOnErr(errorCodeToError(In.error()));
927    }
928  }
929
930  bool Changed =
931      DevirtModule(M, AARGetter, OREGetter, LookupDomTree,
932                   ClSummaryAction == PassSummaryAction::Export ? Summary.get()
933                                                                : nullptr,
934                   ClSummaryAction == PassSummaryAction::Import ? Summary.get()
935                                                                : nullptr)
936          .run();
937
938  if (!ClWriteSummary.empty()) {
939    ExitOnError ExitOnErr(
940        "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
941    std::error_code EC;
942    if (StringRef(ClWriteSummary).endswith(".bc")) {
943      raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_None);
944      ExitOnErr(errorCodeToError(EC));
945      writeIndexToFile(*Summary, OS);
946    } else {
947      raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_TextWithCRLF);
948      ExitOnErr(errorCodeToError(EC));
949      yaml::Output Out(OS);
950      Out << *Summary;
951    }
952  }
953
954  return Changed;
955}
956
957void DevirtModule::buildTypeIdentifierMap(
958    std::vector<VTableBits> &Bits,
959    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
960  DenseMap<GlobalVariable *, VTableBits *> GVToBits;
961  Bits.reserve(M.getGlobalList().size());
962  SmallVector<MDNode *, 2> Types;
963  for (GlobalVariable &GV : M.globals()) {
964    Types.clear();
965    GV.getMetadata(LLVMContext::MD_type, Types);
966    if (GV.isDeclaration() || Types.empty())
967      continue;
968
969    VTableBits *&BitsPtr = GVToBits[&GV];
970    if (!BitsPtr) {
971      Bits.emplace_back();
972      Bits.back().GV = &GV;
973      Bits.back().ObjectSize =
974          M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType());
975      BitsPtr = &Bits.back();
976    }
977
978    for (MDNode *Type : Types) {
979      auto TypeID = Type->getOperand(1).get();
980
981      uint64_t Offset =
982          cast<ConstantInt>(
983              cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
984              ->getZExtValue();
985
986      TypeIdMap[TypeID].insert({BitsPtr, Offset});
987    }
988  }
989}
990
991bool DevirtModule::tryFindVirtualCallTargets(
992    std::vector<VirtualCallTarget> &TargetsForSlot,
993    const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset,
994    ModuleSummaryIndex *ExportSummary) {
995  for (const TypeMemberInfo &TM : TypeMemberInfos) {
996    if (!TM.Bits->GV->isConstant())
997      return false;
998
999    // We cannot perform whole program devirtualization analysis on a vtable
1000    // with public LTO visibility.
1001    if (TM.Bits->GV->getVCallVisibility() ==
1002        GlobalObject::VCallVisibilityPublic)
1003      return false;
1004
1005    Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(),
1006                                       TM.Offset + ByteOffset, M);
1007    if (!Ptr)
1008      return false;
1009
1010    auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts());
1011    if (!Fn)
1012      return false;
1013
1014    if (FunctionsToSkip.match(Fn->getName()))
1015      return false;
1016
1017    // We can disregard __cxa_pure_virtual as a possible call target, as
1018    // calls to pure virtuals are UB.
1019    if (Fn->getName() == "__cxa_pure_virtual")
1020      continue;
1021
1022    // We can disregard unreachable functions as possible call targets, as
1023    // unreachable functions shouldn't be called.
1024    if (mustBeUnreachableFunction(Fn, ExportSummary))
1025      continue;
1026
1027    TargetsForSlot.push_back({Fn, &TM});
1028  }
1029
1030  // Give up if we couldn't find any targets.
1031  return !TargetsForSlot.empty();
1032}
1033
1034bool DevirtIndex::tryFindVirtualCallTargets(
1035    std::vector<ValueInfo> &TargetsForSlot, const TypeIdCompatibleVtableInfo TIdInfo,
1036    uint64_t ByteOffset) {
1037  for (const TypeIdOffsetVtableInfo &P : TIdInfo) {
1038    // Find a representative copy of the vtable initializer.
1039    // We can have multiple available_externally, linkonce_odr and weak_odr
1040    // vtable initializers. We can also have multiple external vtable
1041    // initializers in the case of comdats, which we cannot check here.
1042    // The linker should give an error in this case.
1043    //
1044    // Also, handle the case of same-named local Vtables with the same path
1045    // and therefore the same GUID. This can happen if there isn't enough
1046    // distinguishing path when compiling the source file. In that case we
1047    // conservatively return false early.
1048    const GlobalVarSummary *VS = nullptr;
1049    bool LocalFound = false;
1050    for (const auto &S : P.VTableVI.getSummaryList()) {
1051      if (GlobalValue::isLocalLinkage(S->linkage())) {
1052        if (LocalFound)
1053          return false;
1054        LocalFound = true;
1055      }
1056      auto *CurVS = cast<GlobalVarSummary>(S->getBaseObject());
1057      if (!CurVS->vTableFuncs().empty() ||
1058          // Previously clang did not attach the necessary type metadata to
1059          // available_externally vtables, in which case there would not
1060          // be any vtable functions listed in the summary and we need
1061          // to treat this case conservatively (in case the bitcode is old).
1062          // However, we will also not have any vtable functions in the
1063          // case of a pure virtual base class. In that case we do want
1064          // to set VS to avoid treating it conservatively.
1065          !GlobalValue::isAvailableExternallyLinkage(S->linkage())) {
1066        VS = CurVS;
1067        // We cannot perform whole program devirtualization analysis on a vtable
1068        // with public LTO visibility.
1069        if (VS->getVCallVisibility() == GlobalObject::VCallVisibilityPublic)
1070          return false;
1071      }
1072    }
1073    // There will be no VS if all copies are available_externally having no
1074    // type metadata. In that case we can't safely perform WPD.
1075    if (!VS)
1076      return false;
1077    if (!VS->isLive())
1078      continue;
1079    for (auto VTP : VS->vTableFuncs()) {
1080      if (VTP.VTableOffset != P.AddressPointOffset + ByteOffset)
1081        continue;
1082
1083      if (mustBeUnreachableFunction(VTP.FuncVI))
1084        continue;
1085
1086      TargetsForSlot.push_back(VTP.FuncVI);
1087    }
1088  }
1089
1090  // Give up if we couldn't find any targets.
1091  return !TargetsForSlot.empty();
1092}
1093
1094void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
1095                                         Constant *TheFn, bool &IsExported) {
1096  // Don't devirtualize function if we're told to skip it
1097  // in -wholeprogramdevirt-skip.
1098  if (FunctionsToSkip.match(TheFn->stripPointerCasts()->getName()))
1099    return;
1100  auto Apply = [&](CallSiteInfo &CSInfo) {
1101    for (auto &&VCallSite : CSInfo.CallSites) {
1102      if (!OptimizedCalls.insert(&VCallSite.CB).second)
1103        continue;
1104
1105      if (RemarksEnabled)
1106        VCallSite.emitRemark("single-impl",
1107                             TheFn->stripPointerCasts()->getName(), OREGetter);
1108      NumSingleImpl++;
1109      auto &CB = VCallSite.CB;
1110      assert(!CB.getCalledFunction() && "devirtualizing direct call?");
1111      IRBuilder<> Builder(&CB);
1112      Value *Callee =
1113          Builder.CreateBitCast(TheFn, CB.getCalledOperand()->getType());
1114
1115      // If trap checking is enabled, add support to compare the virtual
1116      // function pointer to the devirtualized target. In case of a mismatch,
1117      // perform a debug trap.
1118      if (DevirtCheckMode == WPDCheckMode::Trap) {
1119        auto *Cond = Builder.CreateICmpNE(CB.getCalledOperand(), Callee);
1120        Instruction *ThenTerm =
1121            SplitBlockAndInsertIfThen(Cond, &CB, /*Unreachable=*/false);
1122        Builder.SetInsertPoint(ThenTerm);
1123        Function *TrapFn = Intrinsic::getDeclaration(&M, Intrinsic::debugtrap);
1124        auto *CallTrap = Builder.CreateCall(TrapFn);
1125        CallTrap->setDebugLoc(CB.getDebugLoc());
1126      }
1127
1128      // If fallback checking is enabled, add support to compare the virtual
1129      // function pointer to the devirtualized target. In case of a mismatch,
1130      // fall back to indirect call.
1131      if (DevirtCheckMode == WPDCheckMode::Fallback) {
1132        MDNode *Weights =
1133            MDBuilder(M.getContext()).createBranchWeights((1U << 20) - 1, 1);
1134        // Version the indirect call site. If the called value is equal to the
1135        // given callee, 'NewInst' will be executed, otherwise the original call
1136        // site will be executed.
1137        CallBase &NewInst = versionCallSite(CB, Callee, Weights);
1138        NewInst.setCalledOperand(Callee);
1139        // Since the new call site is direct, we must clear metadata that
1140        // is only appropriate for indirect calls. This includes !prof and
1141        // !callees metadata.
1142        NewInst.setMetadata(LLVMContext::MD_prof, nullptr);
1143        NewInst.setMetadata(LLVMContext::MD_callees, nullptr);
1144        // Additionally, we should remove them from the fallback indirect call,
1145        // so that we don't attempt to perform indirect call promotion later.
1146        CB.setMetadata(LLVMContext::MD_prof, nullptr);
1147        CB.setMetadata(LLVMContext::MD_callees, nullptr);
1148      }
1149
1150      // In either trapping or non-checking mode, devirtualize original call.
1151      else {
1152        // Devirtualize unconditionally.
1153        CB.setCalledOperand(Callee);
1154        // Since the call site is now direct, we must clear metadata that
1155        // is only appropriate for indirect calls. This includes !prof and
1156        // !callees metadata.
1157        CB.setMetadata(LLVMContext::MD_prof, nullptr);
1158        CB.setMetadata(LLVMContext::MD_callees, nullptr);
1159      }
1160
1161      // This use is no longer unsafe.
1162      if (VCallSite.NumUnsafeUses)
1163        --*VCallSite.NumUnsafeUses;
1164    }
1165    if (CSInfo.isExported())
1166      IsExported = true;
1167    CSInfo.markDevirt();
1168  };
1169  Apply(SlotInfo.CSInfo);
1170  for (auto &P : SlotInfo.ConstCSInfo)
1171    Apply(P.second);
1172}
1173
1174static bool AddCalls(VTableSlotInfo &SlotInfo, const ValueInfo &Callee) {
1175  // We can't add calls if we haven't seen a definition
1176  if (Callee.getSummaryList().empty())
1177    return false;
1178
1179  // Insert calls into the summary index so that the devirtualized targets
1180  // are eligible for import.
1181  // FIXME: Annotate type tests with hotness. For now, mark these as hot
1182  // to better ensure we have the opportunity to inline them.
1183  bool IsExported = false;
1184  auto &S = Callee.getSummaryList()[0];
1185  CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* RelBF = */ 0);
1186  auto AddCalls = [&](CallSiteInfo &CSInfo) {
1187    for (auto *FS : CSInfo.SummaryTypeCheckedLoadUsers) {
1188      FS->addCall({Callee, CI});
1189      IsExported |= S->modulePath() != FS->modulePath();
1190    }
1191    for (auto *FS : CSInfo.SummaryTypeTestAssumeUsers) {
1192      FS->addCall({Callee, CI});
1193      IsExported |= S->modulePath() != FS->modulePath();
1194    }
1195  };
1196  AddCalls(SlotInfo.CSInfo);
1197  for (auto &P : SlotInfo.ConstCSInfo)
1198    AddCalls(P.second);
1199  return IsExported;
1200}
1201
1202bool DevirtModule::trySingleImplDevirt(
1203    ModuleSummaryIndex *ExportSummary,
1204    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
1205    WholeProgramDevirtResolution *Res) {
1206  // See if the program contains a single implementation of this virtual
1207  // function.
1208  Function *TheFn = TargetsForSlot[0].Fn;
1209  for (auto &&Target : TargetsForSlot)
1210    if (TheFn != Target.Fn)
1211      return false;
1212
1213  // If so, update each call site to call that implementation directly.
1214  if (RemarksEnabled || AreStatisticsEnabled())
1215    TargetsForSlot[0].WasDevirt = true;
1216
1217  bool IsExported = false;
1218  applySingleImplDevirt(SlotInfo, TheFn, IsExported);
1219  if (!IsExported)
1220    return false;
1221
1222  // If the only implementation has local linkage, we must promote to external
1223  // to make it visible to thin LTO objects. We can only get here during the
1224  // ThinLTO export phase.
1225  if (TheFn->hasLocalLinkage()) {
1226    std::string NewName = (TheFn->getName() + ".llvm.merged").str();
1227
1228    // Since we are renaming the function, any comdats with the same name must
1229    // also be renamed. This is required when targeting COFF, as the comdat name
1230    // must match one of the names of the symbols in the comdat.
1231    if (Comdat *C = TheFn->getComdat()) {
1232      if (C->getName() == TheFn->getName()) {
1233        Comdat *NewC = M.getOrInsertComdat(NewName);
1234        NewC->setSelectionKind(C->getSelectionKind());
1235        for (GlobalObject &GO : M.global_objects())
1236          if (GO.getComdat() == C)
1237            GO.setComdat(NewC);
1238      }
1239    }
1240
1241    TheFn->setLinkage(GlobalValue::ExternalLinkage);
1242    TheFn->setVisibility(GlobalValue::HiddenVisibility);
1243    TheFn->setName(NewName);
1244  }
1245  if (ValueInfo TheFnVI = ExportSummary->getValueInfo(TheFn->getGUID()))
1246    // Any needed promotion of 'TheFn' has already been done during
1247    // LTO unit split, so we can ignore return value of AddCalls.
1248    AddCalls(SlotInfo, TheFnVI);
1249
1250  Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
1251  Res->SingleImplName = std::string(TheFn->getName());
1252
1253  return true;
1254}
1255
1256bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
1257                                      VTableSlotSummary &SlotSummary,
1258                                      VTableSlotInfo &SlotInfo,
1259                                      WholeProgramDevirtResolution *Res,
1260                                      std::set<ValueInfo> &DevirtTargets) {
1261  // See if the program contains a single implementation of this virtual
1262  // function.
1263  auto TheFn = TargetsForSlot[0];
1264  for (auto &&Target : TargetsForSlot)
1265    if (TheFn != Target)
1266      return false;
1267
1268  // Don't devirtualize if we don't have target definition.
1269  auto Size = TheFn.getSummaryList().size();
1270  if (!Size)
1271    return false;
1272
1273  // Don't devirtualize function if we're told to skip it
1274  // in -wholeprogramdevirt-skip.
1275  if (FunctionsToSkip.match(TheFn.name()))
1276    return false;
1277
1278  // If the summary list contains multiple summaries where at least one is
1279  // a local, give up, as we won't know which (possibly promoted) name to use.
1280  for (const auto &S : TheFn.getSummaryList())
1281    if (GlobalValue::isLocalLinkage(S->linkage()) && Size > 1)
1282      return false;
1283
1284  // Collect functions devirtualized at least for one call site for stats.
1285  if (PrintSummaryDevirt || AreStatisticsEnabled())
1286    DevirtTargets.insert(TheFn);
1287
1288  auto &S = TheFn.getSummaryList()[0];
1289  bool IsExported = AddCalls(SlotInfo, TheFn);
1290  if (IsExported)
1291    ExportedGUIDs.insert(TheFn.getGUID());
1292
1293  // Record in summary for use in devirtualization during the ThinLTO import
1294  // step.
1295  Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
1296  if (GlobalValue::isLocalLinkage(S->linkage())) {
1297    if (IsExported)
1298      // If target is a local function and we are exporting it by
1299      // devirtualizing a call in another module, we need to record the
1300      // promoted name.
1301      Res->SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
1302          TheFn.name(), ExportSummary.getModuleHash(S->modulePath()));
1303    else {
1304      LocalWPDTargetsMap[TheFn].push_back(SlotSummary);
1305      Res->SingleImplName = std::string(TheFn.name());
1306    }
1307  } else
1308    Res->SingleImplName = std::string(TheFn.name());
1309
1310  // Name will be empty if this thin link driven off of serialized combined
1311  // index (e.g. llvm-lto). However, WPD is not supported/invoked for the
1312  // legacy LTO API anyway.
1313  assert(!Res->SingleImplName.empty());
1314
1315  return true;
1316}
1317
1318void DevirtModule::tryICallBranchFunnel(
1319    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
1320    WholeProgramDevirtResolution *Res, VTableSlot Slot) {
1321  Triple T(M.getTargetTriple());
1322  if (T.getArch() != Triple::x86_64)
1323    return;
1324
1325  if (TargetsForSlot.size() > ClThreshold)
1326    return;
1327
1328  bool HasNonDevirt = !SlotInfo.CSInfo.AllCallSitesDevirted;
1329  if (!HasNonDevirt)
1330    for (auto &P : SlotInfo.ConstCSInfo)
1331      if (!P.second.AllCallSitesDevirted) {
1332        HasNonDevirt = true;
1333        break;
1334      }
1335
1336  if (!HasNonDevirt)
1337    return;
1338
1339  FunctionType *FT =
1340      FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true);
1341  Function *JT;
1342  if (isa<MDString>(Slot.TypeID)) {
1343    JT = Function::Create(FT, Function::ExternalLinkage,
1344                          M.getDataLayout().getProgramAddressSpace(),
1345                          getGlobalName(Slot, {}, "branch_funnel"), &M);
1346    JT->setVisibility(GlobalValue::HiddenVisibility);
1347  } else {
1348    JT = Function::Create(FT, Function::InternalLinkage,
1349                          M.getDataLayout().getProgramAddressSpace(),
1350                          "branch_funnel", &M);
1351  }
1352  JT->addParamAttr(0, Attribute::Nest);
1353
1354  std::vector<Value *> JTArgs;
1355  JTArgs.push_back(JT->arg_begin());
1356  for (auto &T : TargetsForSlot) {
1357    JTArgs.push_back(getMemberAddr(T.TM));
1358    JTArgs.push_back(T.Fn);
1359  }
1360
1361  BasicBlock *BB = BasicBlock::Create(M.getContext(), "", JT, nullptr);
1362  Function *Intr =
1363      Intrinsic::getDeclaration(&M, llvm::Intrinsic::icall_branch_funnel, {});
1364
1365  auto *CI = CallInst::Create(Intr, JTArgs, "", BB);
1366  CI->setTailCallKind(CallInst::TCK_MustTail);
1367  ReturnInst::Create(M.getContext(), nullptr, BB);
1368
1369  bool IsExported = false;
1370  applyICallBranchFunnel(SlotInfo, JT, IsExported);
1371  if (IsExported)
1372    Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;
1373}
1374
1375void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
1376                                          Constant *JT, bool &IsExported) {
1377  auto Apply = [&](CallSiteInfo &CSInfo) {
1378    if (CSInfo.isExported())
1379      IsExported = true;
1380    if (CSInfo.AllCallSitesDevirted)
1381      return;
1382    for (auto &&VCallSite : CSInfo.CallSites) {
1383      CallBase &CB = VCallSite.CB;
1384
1385      // Jump tables are only profitable if the retpoline mitigation is enabled.
1386      Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features");
1387      if (!FSAttr.isValid() ||
1388          !FSAttr.getValueAsString().contains("+retpoline"))
1389        continue;
1390
1391      NumBranchFunnel++;
1392      if (RemarksEnabled)
1393        VCallSite.emitRemark("branch-funnel",
1394                             JT->stripPointerCasts()->getName(), OREGetter);
1395
1396      // Pass the address of the vtable in the nest register, which is r10 on
1397      // x86_64.
1398      std::vector<Type *> NewArgs;
1399      NewArgs.push_back(Int8PtrTy);
1400      append_range(NewArgs, CB.getFunctionType()->params());
1401      FunctionType *NewFT =
1402          FunctionType::get(CB.getFunctionType()->getReturnType(), NewArgs,
1403                            CB.getFunctionType()->isVarArg());
1404      PointerType *NewFTPtr = PointerType::getUnqual(NewFT);
1405
1406      IRBuilder<> IRB(&CB);
1407      std::vector<Value *> Args;
1408      Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
1409      llvm::append_range(Args, CB.args());
1410
1411      CallBase *NewCS = nullptr;
1412      if (isa<CallInst>(CB))
1413        NewCS = IRB.CreateCall(NewFT, IRB.CreateBitCast(JT, NewFTPtr), Args);
1414      else
1415        NewCS = IRB.CreateInvoke(NewFT, IRB.CreateBitCast(JT, NewFTPtr),
1416                                 cast<InvokeInst>(CB).getNormalDest(),
1417                                 cast<InvokeInst>(CB).getUnwindDest(), Args);
1418      NewCS->setCallingConv(CB.getCallingConv());
1419
1420      AttributeList Attrs = CB.getAttributes();
1421      std::vector<AttributeSet> NewArgAttrs;
1422      NewArgAttrs.push_back(AttributeSet::get(
1423          M.getContext(), ArrayRef<Attribute>{Attribute::get(
1424                              M.getContext(), Attribute::Nest)}));
1425      for (unsigned I = 0; I + 2 <  Attrs.getNumAttrSets(); ++I)
1426        NewArgAttrs.push_back(Attrs.getParamAttrs(I));
1427      NewCS->setAttributes(
1428          AttributeList::get(M.getContext(), Attrs.getFnAttrs(),
1429                             Attrs.getRetAttrs(), NewArgAttrs));
1430
1431      CB.replaceAllUsesWith(NewCS);
1432      CB.eraseFromParent();
1433
1434      // This use is no longer unsafe.
1435      if (VCallSite.NumUnsafeUses)
1436        --*VCallSite.NumUnsafeUses;
1437    }
1438    // Don't mark as devirtualized because there may be callers compiled without
1439    // retpoline mitigation, which would mean that they are lowered to
1440    // llvm.type.test and therefore require an llvm.type.test resolution for the
1441    // type identifier.
1442  };
1443  Apply(SlotInfo.CSInfo);
1444  for (auto &P : SlotInfo.ConstCSInfo)
1445    Apply(P.second);
1446}
1447
1448bool DevirtModule::tryEvaluateFunctionsWithArgs(
1449    MutableArrayRef<VirtualCallTarget> TargetsForSlot,
1450    ArrayRef<uint64_t> Args) {
1451  // Evaluate each function and store the result in each target's RetVal
1452  // field.
1453  for (VirtualCallTarget &Target : TargetsForSlot) {
1454    if (Target.Fn->arg_size() != Args.size() + 1)
1455      return false;
1456
1457    Evaluator Eval(M.getDataLayout(), nullptr);
1458    SmallVector<Constant *, 2> EvalArgs;
1459    EvalArgs.push_back(
1460        Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
1461    for (unsigned I = 0; I != Args.size(); ++I) {
1462      auto *ArgTy = dyn_cast<IntegerType>(
1463          Target.Fn->getFunctionType()->getParamType(I + 1));
1464      if (!ArgTy)
1465        return false;
1466      EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
1467    }
1468
1469    Constant *RetVal;
1470    if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
1471        !isa<ConstantInt>(RetVal))
1472      return false;
1473    Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
1474  }
1475  return true;
1476}
1477
1478void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
1479                                         uint64_t TheRetVal) {
1480  for (auto Call : CSInfo.CallSites) {
1481    if (!OptimizedCalls.insert(&Call.CB).second)
1482      continue;
1483    NumUniformRetVal++;
1484    Call.replaceAndErase(
1485        "uniform-ret-val", FnName, RemarksEnabled, OREGetter,
1486        ConstantInt::get(cast<IntegerType>(Call.CB.getType()), TheRetVal));
1487  }
1488  CSInfo.markDevirt();
1489}
1490
1491bool DevirtModule::tryUniformRetValOpt(
1492    MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo,
1493    WholeProgramDevirtResolution::ByArg *Res) {
1494  // Uniform return value optimization. If all functions return the same
1495  // constant, replace all calls with that constant.
1496  uint64_t TheRetVal = TargetsForSlot[0].RetVal;
1497  for (const VirtualCallTarget &Target : TargetsForSlot)
1498    if (Target.RetVal != TheRetVal)
1499      return false;
1500
1501  if (CSInfo.isExported()) {
1502    Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal;
1503    Res->Info = TheRetVal;
1504  }
1505
1506  applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal);
1507  if (RemarksEnabled || AreStatisticsEnabled())
1508    for (auto &&Target : TargetsForSlot)
1509      Target.WasDevirt = true;
1510  return true;
1511}
1512
1513std::string DevirtModule::getGlobalName(VTableSlot Slot,
1514                                        ArrayRef<uint64_t> Args,
1515                                        StringRef Name) {
1516  std::string FullName = "__typeid_";
1517  raw_string_ostream OS(FullName);
1518  OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset;
1519  for (uint64_t Arg : Args)
1520    OS << '_' << Arg;
1521  OS << '_' << Name;
1522  return OS.str();
1523}
1524
1525bool DevirtModule::shouldExportConstantsAsAbsoluteSymbols() {
1526  Triple T(M.getTargetTriple());
1527  return T.isX86() && T.getObjectFormat() == Triple::ELF;
1528}
1529
1530void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
1531                                StringRef Name, Constant *C) {
1532  GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage,
1533                                        getGlobalName(Slot, Args, Name), C, &M);
1534  GA->setVisibility(GlobalValue::HiddenVisibility);
1535}
1536
1537void DevirtModule::exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
1538                                  StringRef Name, uint32_t Const,
1539                                  uint32_t &Storage) {
1540  if (shouldExportConstantsAsAbsoluteSymbols()) {
1541    exportGlobal(
1542        Slot, Args, Name,
1543        ConstantExpr::getIntToPtr(ConstantInt::get(Int32Ty, Const), Int8PtrTy));
1544    return;
1545  }
1546
1547  Storage = Const;
1548}
1549
1550Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
1551                                     StringRef Name) {
1552  Constant *C =
1553      M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Arr0Ty);
1554  auto *GV = dyn_cast<GlobalVariable>(C);
1555  if (GV)
1556    GV->setVisibility(GlobalValue::HiddenVisibility);
1557  return C;
1558}
1559
1560Constant *DevirtModule::importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
1561                                       StringRef Name, IntegerType *IntTy,
1562                                       uint32_t Storage) {
1563  if (!shouldExportConstantsAsAbsoluteSymbols())
1564    return ConstantInt::get(IntTy, Storage);
1565
1566  Constant *C = importGlobal(Slot, Args, Name);
1567  auto *GV = cast<GlobalVariable>(C->stripPointerCasts());
1568  C = ConstantExpr::getPtrToInt(C, IntTy);
1569
1570  // We only need to set metadata if the global is newly created, in which
1571  // case it would not have hidden visibility.
1572  if (GV->hasMetadata(LLVMContext::MD_absolute_symbol))
1573    return C;
1574
1575  auto SetAbsRange = [&](uint64_t Min, uint64_t Max) {
1576    auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min));
1577    auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max));
1578    GV->setMetadata(LLVMContext::MD_absolute_symbol,
1579                    MDNode::get(M.getContext(), {MinC, MaxC}));
1580  };
1581  unsigned AbsWidth = IntTy->getBitWidth();
1582  if (AbsWidth == IntPtrTy->getBitWidth())
1583    SetAbsRange(~0ull, ~0ull); // Full set.
1584  else
1585    SetAbsRange(0, 1ull << AbsWidth);
1586  return C;
1587}
1588
1589void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
1590                                        bool IsOne,
1591                                        Constant *UniqueMemberAddr) {
1592  for (auto &&Call : CSInfo.CallSites) {
1593    if (!OptimizedCalls.insert(&Call.CB).second)
1594      continue;
1595    IRBuilder<> B(&Call.CB);
1596    Value *Cmp =
1597        B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable,
1598                     B.CreateBitCast(UniqueMemberAddr, Call.VTable->getType()));
1599    Cmp = B.CreateZExt(Cmp, Call.CB.getType());
1600    NumUniqueRetVal++;
1601    Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter,
1602                         Cmp);
1603  }
1604  CSInfo.markDevirt();
1605}
1606
1607Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) {
1608  Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy);
1609  return ConstantExpr::getGetElementPtr(Int8Ty, C,
1610                                        ConstantInt::get(Int64Ty, M->Offset));
1611}
1612
1613bool DevirtModule::tryUniqueRetValOpt(
1614    unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
1615    CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res,
1616    VTableSlot Slot, ArrayRef<uint64_t> Args) {
1617  // IsOne controls whether we look for a 0 or a 1.
1618  auto tryUniqueRetValOptFor = [&](bool IsOne) {
1619    const TypeMemberInfo *UniqueMember = nullptr;
1620    for (const VirtualCallTarget &Target : TargetsForSlot) {
1621      if (Target.RetVal == (IsOne ? 1 : 0)) {
1622        if (UniqueMember)
1623          return false;
1624        UniqueMember = Target.TM;
1625      }
1626    }
1627
1628    // We should have found a unique member or bailed out by now. We already
1629    // checked for a uniform return value in tryUniformRetValOpt.
1630    assert(UniqueMember);
1631
1632    Constant *UniqueMemberAddr = getMemberAddr(UniqueMember);
1633    if (CSInfo.isExported()) {
1634      Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal;
1635      Res->Info = IsOne;
1636
1637      exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr);
1638    }
1639
1640    // Replace each call with the comparison.
1641    applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne,
1642                         UniqueMemberAddr);
1643
1644    // Update devirtualization statistics for targets.
1645    if (RemarksEnabled || AreStatisticsEnabled())
1646      for (auto &&Target : TargetsForSlot)
1647        Target.WasDevirt = true;
1648
1649    return true;
1650  };
1651
1652  if (BitWidth == 1) {
1653    if (tryUniqueRetValOptFor(true))
1654      return true;
1655    if (tryUniqueRetValOptFor(false))
1656      return true;
1657  }
1658  return false;
1659}
1660
1661void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
1662                                         Constant *Byte, Constant *Bit) {
1663  for (auto Call : CSInfo.CallSites) {
1664    if (!OptimizedCalls.insert(&Call.CB).second)
1665      continue;
1666    auto *RetType = cast<IntegerType>(Call.CB.getType());
1667    IRBuilder<> B(&Call.CB);
1668    Value *Addr =
1669        B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte);
1670    if (RetType->getBitWidth() == 1) {
1671      Value *Bits = B.CreateLoad(Int8Ty, Addr);
1672      Value *BitsAndBit = B.CreateAnd(Bits, Bit);
1673      auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
1674      NumVirtConstProp1Bit++;
1675      Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled,
1676                           OREGetter, IsBitSet);
1677    } else {
1678      Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
1679      Value *Val = B.CreateLoad(RetType, ValAddr);
1680      NumVirtConstProp++;
1681      Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled,
1682                           OREGetter, Val);
1683    }
1684  }
1685  CSInfo.markDevirt();
1686}
1687
1688bool DevirtModule::tryVirtualConstProp(
1689    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
1690    WholeProgramDevirtResolution *Res, VTableSlot Slot) {
1691  // This only works if the function returns an integer.
1692  auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
1693  if (!RetType)
1694    return false;
1695  unsigned BitWidth = RetType->getBitWidth();
1696  if (BitWidth > 64)
1697    return false;
1698
1699  // Make sure that each function is defined, does not access memory, takes at
1700  // least one argument, does not use its first argument (which we assume is
1701  // 'this'), and has the same return type.
1702  //
1703  // Note that we test whether this copy of the function is readnone, rather
1704  // than testing function attributes, which must hold for any copy of the
1705  // function, even a less optimized version substituted at link time. This is
1706  // sound because the virtual constant propagation optimizations effectively
1707  // inline all implementations of the virtual function into each call site,
1708  // rather than using function attributes to perform local optimization.
1709  for (VirtualCallTarget &Target : TargetsForSlot) {
1710    if (Target.Fn->isDeclaration() ||
1711        !computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn))
1712             .doesNotAccessMemory() ||
1713        Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() ||
1714        Target.Fn->getReturnType() != RetType)
1715      return false;
1716  }
1717
1718  for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {
1719    if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
1720      continue;
1721
1722    WholeProgramDevirtResolution::ByArg *ResByArg = nullptr;
1723    if (Res)
1724      ResByArg = &Res->ResByArg[CSByConstantArg.first];
1725
1726    if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg))
1727      continue;
1728
1729    if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second,
1730                           ResByArg, Slot, CSByConstantArg.first))
1731      continue;
1732
1733    // Find an allocation offset in bits in all vtables associated with the
1734    // type.
1735    uint64_t AllocBefore =
1736        findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth);
1737    uint64_t AllocAfter =
1738        findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth);
1739
1740    // Calculate the total amount of padding needed to store a value at both
1741    // ends of the object.
1742    uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0;
1743    for (auto &&Target : TargetsForSlot) {
1744      TotalPaddingBefore += std::max<int64_t>(
1745          (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0);
1746      TotalPaddingAfter += std::max<int64_t>(
1747          (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0);
1748    }
1749
1750    // If the amount of padding is too large, give up.
1751    // FIXME: do something smarter here.
1752    if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128)
1753      continue;
1754
1755    // Calculate the offset to the value as a (possibly negative) byte offset
1756    // and (if applicable) a bit offset, and store the values in the targets.
1757    int64_t OffsetByte;
1758    uint64_t OffsetBit;
1759    if (TotalPaddingBefore <= TotalPaddingAfter)
1760      setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte,
1761                            OffsetBit);
1762    else
1763      setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
1764                           OffsetBit);
1765
1766    if (RemarksEnabled || AreStatisticsEnabled())
1767      for (auto &&Target : TargetsForSlot)
1768        Target.WasDevirt = true;
1769
1770
1771    if (CSByConstantArg.second.isExported()) {
1772      ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp;
1773      exportConstant(Slot, CSByConstantArg.first, "byte", OffsetByte,
1774                     ResByArg->Byte);
1775      exportConstant(Slot, CSByConstantArg.first, "bit", 1ULL << OffsetBit,
1776                     ResByArg->Bit);
1777    }
1778
1779    // Rewrite each call to a load from OffsetByte/OffsetBit.
1780    Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte);
1781    Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
1782    applyVirtualConstProp(CSByConstantArg.second,
1783                          TargetsForSlot[0].Fn->getName(), ByteConst, BitConst);
1784  }
1785  return true;
1786}
1787
1788void DevirtModule::rebuildGlobal(VTableBits &B) {
1789  if (B.Before.Bytes.empty() && B.After.Bytes.empty())
1790    return;
1791
1792  // Align the before byte array to the global's minimum alignment so that we
1793  // don't break any alignment requirements on the global.
1794  Align Alignment = M.getDataLayout().getValueOrABITypeAlignment(
1795      B.GV->getAlign(), B.GV->getValueType());
1796  B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), Alignment));
1797
1798  // Before was stored in reverse order; flip it now.
1799  for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
1800    std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]);
1801
1802  // Build an anonymous global containing the before bytes, followed by the
1803  // original initializer, followed by the after bytes.
1804  auto NewInit = ConstantStruct::getAnon(
1805      {ConstantDataArray::get(M.getContext(), B.Before.Bytes),
1806       B.GV->getInitializer(),
1807       ConstantDataArray::get(M.getContext(), B.After.Bytes)});
1808  auto NewGV =
1809      new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(),
1810                         GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
1811  NewGV->setSection(B.GV->getSection());
1812  NewGV->setComdat(B.GV->getComdat());
1813  NewGV->setAlignment(B.GV->getAlign());
1814
1815  // Copy the original vtable's metadata to the anonymous global, adjusting
1816  // offsets as required.
1817  NewGV->copyMetadata(B.GV, B.Before.Bytes.size());
1818
1819  // Build an alias named after the original global, pointing at the second
1820  // element (the original initializer).
1821  auto Alias = GlobalAlias::create(
1822      B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "",
1823      ConstantExpr::getGetElementPtr(
1824          NewInit->getType(), NewGV,
1825          ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0),
1826                               ConstantInt::get(Int32Ty, 1)}),
1827      &M);
1828  Alias->setVisibility(B.GV->getVisibility());
1829  Alias->takeName(B.GV);
1830
1831  B.GV->replaceAllUsesWith(Alias);
1832  B.GV->eraseFromParent();
1833}
1834
1835bool DevirtModule::areRemarksEnabled() {
1836  const auto &FL = M.getFunctionList();
1837  for (const Function &Fn : FL) {
1838    if (Fn.empty())
1839      continue;
1840    auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &Fn.front());
1841    return DI.isEnabled();
1842  }
1843  return false;
1844}
1845
1846void DevirtModule::scanTypeTestUsers(
1847    Function *TypeTestFunc,
1848    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
1849  // Find all virtual calls via a virtual table pointer %p under an assumption
1850  // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
1851  // points to a member of the type identifier %md. Group calls by (type ID,
1852  // offset) pair (effectively the identity of the virtual function) and store
1853  // to CallSlots.
1854  for (Use &U : llvm::make_early_inc_range(TypeTestFunc->uses())) {
1855    auto *CI = dyn_cast<CallInst>(U.getUser());
1856    if (!CI)
1857      continue;
1858
1859    // Search for virtual calls based on %p and add them to DevirtCalls.
1860    SmallVector<DevirtCallSite, 1> DevirtCalls;
1861    SmallVector<CallInst *, 1> Assumes;
1862    auto &DT = LookupDomTree(*CI->getFunction());
1863    findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT);
1864
1865    Metadata *TypeId =
1866        cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
1867    // If we found any, add them to CallSlots.
1868    if (!Assumes.empty()) {
1869      Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
1870      for (DevirtCallSite Call : DevirtCalls)
1871        CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB, nullptr);
1872    }
1873
1874    auto RemoveTypeTestAssumes = [&]() {
1875      // We no longer need the assumes or the type test.
1876      for (auto *Assume : Assumes)
1877        Assume->eraseFromParent();
1878      // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we
1879      // may use the vtable argument later.
1880      if (CI->use_empty())
1881        CI->eraseFromParent();
1882    };
1883
1884    // At this point we could remove all type test assume sequences, as they
1885    // were originally inserted for WPD. However, we can keep these in the
1886    // code stream for later analysis (e.g. to help drive more efficient ICP
1887    // sequences). They will eventually be removed by a second LowerTypeTests
1888    // invocation that cleans them up. In order to do this correctly, the first
1889    // LowerTypeTests invocation needs to know that they have "Unknown" type
1890    // test resolution, so that they aren't treated as Unsat and lowered to
1891    // False, which will break any uses on assumes. Below we remove any type
1892    // test assumes that will not be treated as Unknown by LTT.
1893
1894    // The type test assumes will be treated by LTT as Unsat if the type id is
1895    // not used on a global (in which case it has no entry in the TypeIdMap).
1896    if (!TypeIdMap.count(TypeId))
1897      RemoveTypeTestAssumes();
1898
1899    // For ThinLTO importing, we need to remove the type test assumes if this is
1900    // an MDString type id without a corresponding TypeIdSummary. Any
1901    // non-MDString type ids are ignored and treated as Unknown by LTT, so their
1902    // type test assumes can be kept. If the MDString type id is missing a
1903    // TypeIdSummary (e.g. because there was no use on a vcall, preventing the
1904    // exporting phase of WPD from analyzing it), then it would be treated as
1905    // Unsat by LTT and we need to remove its type test assumes here. If not
1906    // used on a vcall we don't need them for later optimization use in any
1907    // case.
1908    else if (ImportSummary && isa<MDString>(TypeId)) {
1909      const TypeIdSummary *TidSummary =
1910          ImportSummary->getTypeIdSummary(cast<MDString>(TypeId)->getString());
1911      if (!TidSummary)
1912        RemoveTypeTestAssumes();
1913      else
1914        // If one was created it should not be Unsat, because if we reached here
1915        // the type id was used on a global.
1916        assert(TidSummary->TTRes.TheKind != TypeTestResolution::Unsat);
1917    }
1918  }
1919}
1920
1921void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
1922  Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test);
1923
1924  for (Use &U : llvm::make_early_inc_range(TypeCheckedLoadFunc->uses())) {
1925    auto *CI = dyn_cast<CallInst>(U.getUser());
1926    if (!CI)
1927      continue;
1928
1929    Value *Ptr = CI->getArgOperand(0);
1930    Value *Offset = CI->getArgOperand(1);
1931    Value *TypeIdValue = CI->getArgOperand(2);
1932    Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();
1933
1934    SmallVector<DevirtCallSite, 1> DevirtCalls;
1935    SmallVector<Instruction *, 1> LoadedPtrs;
1936    SmallVector<Instruction *, 1> Preds;
1937    bool HasNonCallUses = false;
1938    auto &DT = LookupDomTree(*CI->getFunction());
1939    findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
1940                                               HasNonCallUses, CI, DT);
1941
1942    // Start by generating "pessimistic" code that explicitly loads the function
1943    // pointer from the vtable and performs the type check. If possible, we will
1944    // eliminate the load and the type check later.
1945
1946    // If possible, only generate the load at the point where it is used.
1947    // This helps avoid unnecessary spills.
1948    IRBuilder<> LoadB(
1949        (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);
1950    Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
1951    Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
1952    Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);
1953
1954    for (Instruction *LoadedPtr : LoadedPtrs) {
1955      LoadedPtr->replaceAllUsesWith(LoadedValue);
1956      LoadedPtr->eraseFromParent();
1957    }
1958
1959    // Likewise for the type test.
1960    IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI);
1961    CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue});
1962
1963    for (Instruction *Pred : Preds) {
1964      Pred->replaceAllUsesWith(TypeTestCall);
1965      Pred->eraseFromParent();
1966    }
1967
1968    // We have already erased any extractvalue instructions that refer to the
1969    // intrinsic call, but the intrinsic may have other non-extractvalue uses
1970    // (although this is unlikely). In that case, explicitly build a pair and
1971    // RAUW it.
1972    if (!CI->use_empty()) {
1973      Value *Pair = PoisonValue::get(CI->getType());
1974      IRBuilder<> B(CI);
1975      Pair = B.CreateInsertValue(Pair, LoadedValue, {0});
1976      Pair = B.CreateInsertValue(Pair, TypeTestCall, {1});
1977      CI->replaceAllUsesWith(Pair);
1978    }
1979
1980    // The number of unsafe uses is initially the number of uses.
1981    auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall];
1982    NumUnsafeUses = DevirtCalls.size();
1983
1984    // If the function pointer has a non-call user, we cannot eliminate the type
1985    // check, as one of those users may eventually call the pointer. Increment
1986    // the unsafe use count to make sure it cannot reach zero.
1987    if (HasNonCallUses)
1988      ++NumUnsafeUses;
1989    for (DevirtCallSite Call : DevirtCalls) {
1990      CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB,
1991                                                   &NumUnsafeUses);
1992    }
1993
1994    CI->eraseFromParent();
1995  }
1996}
1997
1998void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
1999  auto *TypeId = dyn_cast<MDString>(Slot.TypeID);
2000  if (!TypeId)
2001    return;
2002  const TypeIdSummary *TidSummary =
2003      ImportSummary->getTypeIdSummary(TypeId->getString());
2004  if (!TidSummary)
2005    return;
2006  auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset);
2007  if (ResI == TidSummary->WPDRes.end())
2008    return;
2009  const WholeProgramDevirtResolution &Res = ResI->second;
2010
2011  if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) {
2012    assert(!Res.SingleImplName.empty());
2013    // The type of the function in the declaration is irrelevant because every
2014    // call site will cast it to the correct type.
2015    Constant *SingleImpl =
2016        cast<Constant>(M.getOrInsertFunction(Res.SingleImplName,
2017                                             Type::getVoidTy(M.getContext()))
2018                           .getCallee());
2019
2020    // This is the import phase so we should not be exporting anything.
2021    bool IsExported = false;
2022    applySingleImplDevirt(SlotInfo, SingleImpl, IsExported);
2023    assert(!IsExported);
2024  }
2025
2026  for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) {
2027    auto I = Res.ResByArg.find(CSByConstantArg.first);
2028    if (I == Res.ResByArg.end())
2029      continue;
2030    auto &ResByArg = I->second;
2031    // FIXME: We should figure out what to do about the "function name" argument
2032    // to the apply* functions, as the function names are unavailable during the
2033    // importing phase. For now we just pass the empty string. This does not
2034    // impact correctness because the function names are just used for remarks.
2035    switch (ResByArg.TheKind) {
2036    case WholeProgramDevirtResolution::ByArg::UniformRetVal:
2037      applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info);
2038      break;
2039    case WholeProgramDevirtResolution::ByArg::UniqueRetVal: {
2040      Constant *UniqueMemberAddr =
2041          importGlobal(Slot, CSByConstantArg.first, "unique_member");
2042      applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info,
2043                           UniqueMemberAddr);
2044      break;
2045    }
2046    case WholeProgramDevirtResolution::ByArg::VirtualConstProp: {
2047      Constant *Byte = importConstant(Slot, CSByConstantArg.first, "byte",
2048                                      Int32Ty, ResByArg.Byte);
2049      Constant *Bit = importConstant(Slot, CSByConstantArg.first, "bit", Int8Ty,
2050                                     ResByArg.Bit);
2051      applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit);
2052      break;
2053    }
2054    default:
2055      break;
2056    }
2057  }
2058
2059  if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) {
2060    // The type of the function is irrelevant, because it's bitcast at calls
2061    // anyhow.
2062    Constant *JT = cast<Constant>(
2063        M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"),
2064                              Type::getVoidTy(M.getContext()))
2065            .getCallee());
2066    bool IsExported = false;
2067    applyICallBranchFunnel(SlotInfo, JT, IsExported);
2068    assert(!IsExported);
2069  }
2070}
2071
2072void DevirtModule::removeRedundantTypeTests() {
2073  auto True = ConstantInt::getTrue(M.getContext());
2074  for (auto &&U : NumUnsafeUsesForTypeTest) {
2075    if (U.second == 0) {
2076      U.first->replaceAllUsesWith(True);
2077      U.first->eraseFromParent();
2078    }
2079  }
2080}
2081
2082ValueInfo
2083DevirtModule::lookUpFunctionValueInfo(Function *TheFn,
2084                                      ModuleSummaryIndex *ExportSummary) {
2085  assert((ExportSummary != nullptr) &&
2086         "Caller guarantees ExportSummary is not nullptr");
2087
2088  const auto TheFnGUID = TheFn->getGUID();
2089  const auto TheFnGUIDWithExportedName = GlobalValue::getGUID(TheFn->getName());
2090  // Look up ValueInfo with the GUID in the current linkage.
2091  ValueInfo TheFnVI = ExportSummary->getValueInfo(TheFnGUID);
2092  // If no entry is found and GUID is different from GUID computed using
2093  // exported name, look up ValueInfo with the exported name unconditionally.
2094  // This is a fallback.
2095  //
2096  // The reason to have a fallback:
2097  // 1. LTO could enable global value internalization via
2098  // `enable-lto-internalization`.
2099  // 2. The GUID in ExportedSummary is computed using exported name.
2100  if ((!TheFnVI) && (TheFnGUID != TheFnGUIDWithExportedName)) {
2101    TheFnVI = ExportSummary->getValueInfo(TheFnGUIDWithExportedName);
2102  }
2103  return TheFnVI;
2104}
2105
2106bool DevirtModule::mustBeUnreachableFunction(
2107    Function *const F, ModuleSummaryIndex *ExportSummary) {
2108  // First, learn unreachability by analyzing function IR.
2109  if (!F->isDeclaration()) {
2110    // A function must be unreachable if its entry block ends with an
2111    // 'unreachable'.
2112    return isa<UnreachableInst>(F->getEntryBlock().getTerminator());
2113  }
2114  // Learn unreachability from ExportSummary if ExportSummary is present.
2115  return ExportSummary &&
2116         ::mustBeUnreachableFunction(
2117             DevirtModule::lookUpFunctionValueInfo(F, ExportSummary));
2118}
2119
2120bool DevirtModule::run() {
2121  // If only some of the modules were split, we cannot correctly perform
2122  // this transformation. We already checked for the presense of type tests
2123  // with partially split modules during the thin link, and would have emitted
2124  // an error if any were found, so here we can simply return.
2125  if ((ExportSummary && ExportSummary->partiallySplitLTOUnits()) ||
2126      (ImportSummary && ImportSummary->partiallySplitLTOUnits()))
2127    return false;
2128
2129  Function *TypeTestFunc =
2130      M.getFunction(Intrinsic::getName(Intrinsic::type_test));
2131  Function *TypeCheckedLoadFunc =
2132      M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
2133  Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
2134
2135  // Normally if there are no users of the devirtualization intrinsics in the
2136  // module, this pass has nothing to do. But if we are exporting, we also need
2137  // to handle any users that appear only in the function summaries.
2138  if (!ExportSummary &&
2139      (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
2140       AssumeFunc->use_empty()) &&
2141      (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
2142    return false;
2143
2144  // Rebuild type metadata into a map for easy lookup.
2145  std::vector<VTableBits> Bits;
2146  DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
2147  buildTypeIdentifierMap(Bits, TypeIdMap);
2148
2149  if (TypeTestFunc && AssumeFunc)
2150    scanTypeTestUsers(TypeTestFunc, TypeIdMap);
2151
2152  if (TypeCheckedLoadFunc)
2153    scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);
2154
2155  if (ImportSummary) {
2156    for (auto &S : CallSlots)
2157      importResolution(S.first, S.second);
2158
2159    removeRedundantTypeTests();
2160
2161    // We have lowered or deleted the type intrinsics, so we will no longer have
2162    // enough information to reason about the liveness of virtual function
2163    // pointers in GlobalDCE.
2164    for (GlobalVariable &GV : M.globals())
2165      GV.eraseMetadata(LLVMContext::MD_vcall_visibility);
2166
2167    // The rest of the code is only necessary when exporting or during regular
2168    // LTO, so we are done.
2169    return true;
2170  }
2171
2172  if (TypeIdMap.empty())
2173    return true;
2174
2175  // Collect information from summary about which calls to try to devirtualize.
2176  if (ExportSummary) {
2177    DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID;
2178    for (auto &P : TypeIdMap) {
2179      if (auto *TypeId = dyn_cast<MDString>(P.first))
2180        MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back(
2181            TypeId);
2182    }
2183
2184    for (auto &P : *ExportSummary) {
2185      for (auto &S : P.second.SummaryList) {
2186        auto *FS = dyn_cast<FunctionSummary>(S.get());
2187        if (!FS)
2188          continue;
2189        // FIXME: Only add live functions.
2190        for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
2191          for (Metadata *MD : MetadataByGUID[VF.GUID]) {
2192            CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS);
2193          }
2194        }
2195        for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
2196          for (Metadata *MD : MetadataByGUID[VF.GUID]) {
2197            CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
2198          }
2199        }
2200        for (const FunctionSummary::ConstVCall &VC :
2201             FS->type_test_assume_const_vcalls()) {
2202          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
2203            CallSlots[{MD, VC.VFunc.Offset}]
2204                .ConstCSInfo[VC.Args]
2205                .addSummaryTypeTestAssumeUser(FS);
2206          }
2207        }
2208        for (const FunctionSummary::ConstVCall &VC :
2209             FS->type_checked_load_const_vcalls()) {
2210          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
2211            CallSlots[{MD, VC.VFunc.Offset}]
2212                .ConstCSInfo[VC.Args]
2213                .addSummaryTypeCheckedLoadUser(FS);
2214          }
2215        }
2216      }
2217    }
2218  }
2219
2220  // For each (type, offset) pair:
2221  bool DidVirtualConstProp = false;
2222  std::map<std::string, Function*> DevirtTargets;
2223  for (auto &S : CallSlots) {
2224    // Search each of the members of the type identifier for the virtual
2225    // function implementation at offset S.first.ByteOffset, and add to
2226    // TargetsForSlot.
2227    std::vector<VirtualCallTarget> TargetsForSlot;
2228    WholeProgramDevirtResolution *Res = nullptr;
2229    const std::set<TypeMemberInfo> &TypeMemberInfos = TypeIdMap[S.first.TypeID];
2230    if (ExportSummary && isa<MDString>(S.first.TypeID) &&
2231        TypeMemberInfos.size())
2232      // For any type id used on a global's type metadata, create the type id
2233      // summary resolution regardless of whether we can devirtualize, so that
2234      // lower type tests knows the type id is not Unsat. If it was not used on
2235      // a global's type metadata, the TypeIdMap entry set will be empty, and
2236      // we don't want to create an entry (with the default Unknown type
2237      // resolution), which can prevent detection of the Unsat.
2238      Res = &ExportSummary
2239                 ->getOrInsertTypeIdSummary(
2240                     cast<MDString>(S.first.TypeID)->getString())
2241                 .WPDRes[S.first.ByteOffset];
2242    if (tryFindVirtualCallTargets(TargetsForSlot, TypeMemberInfos,
2243                                  S.first.ByteOffset, ExportSummary)) {
2244
2245      if (!trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res)) {
2246        DidVirtualConstProp |=
2247            tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first);
2248
2249        tryICallBranchFunnel(TargetsForSlot, S.second, Res, S.first);
2250      }
2251
2252      // Collect functions devirtualized at least for one call site for stats.
2253      if (RemarksEnabled || AreStatisticsEnabled())
2254        for (const auto &T : TargetsForSlot)
2255          if (T.WasDevirt)
2256            DevirtTargets[std::string(T.Fn->getName())] = T.Fn;
2257    }
2258
2259    // CFI-specific: if we are exporting and any llvm.type.checked.load
2260    // intrinsics were *not* devirtualized, we need to add the resulting
2261    // llvm.type.test intrinsics to the function summaries so that the
2262    // LowerTypeTests pass will export them.
2263    if (ExportSummary && isa<MDString>(S.first.TypeID)) {
2264      auto GUID =
2265          GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString());
2266      for (auto *FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers)
2267        FS->addTypeTest(GUID);
2268      for (auto &CCS : S.second.ConstCSInfo)
2269        for (auto *FS : CCS.second.SummaryTypeCheckedLoadUsers)
2270          FS->addTypeTest(GUID);
2271    }
2272  }
2273
2274  if (RemarksEnabled) {
2275    // Generate remarks for each devirtualized function.
2276    for (const auto &DT : DevirtTargets) {
2277      Function *F = DT.second;
2278
2279      using namespace ore;
2280      OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F)
2281                        << "devirtualized "
2282                        << NV("FunctionName", DT.first));
2283    }
2284  }
2285
2286  NumDevirtTargets += DevirtTargets.size();
2287
2288  removeRedundantTypeTests();
2289
2290  // Rebuild each global we touched as part of virtual constant propagation to
2291  // include the before and after bytes.
2292  if (DidVirtualConstProp)
2293    for (VTableBits &B : Bits)
2294      rebuildGlobal(B);
2295
2296  // We have lowered or deleted the type intrinsics, so we will no longer have
2297  // enough information to reason about the liveness of virtual function
2298  // pointers in GlobalDCE.
2299  for (GlobalVariable &GV : M.globals())
2300    GV.eraseMetadata(LLVMContext::MD_vcall_visibility);
2301
2302  return true;
2303}
2304
2305void DevirtIndex::run() {
2306  if (ExportSummary.typeIdCompatibleVtableMap().empty())
2307    return;
2308
2309  DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID;
2310  for (const auto &P : ExportSummary.typeIdCompatibleVtableMap()) {
2311    NameByGUID[GlobalValue::getGUID(P.first)].push_back(P.first);
2312    // Create the type id summary resolution regardlness of whether we can
2313    // devirtualize, so that lower type tests knows the type id is used on
2314    // a global and not Unsat. We do this here rather than in the loop over the
2315    // CallSlots, since that handling will only see type tests that directly
2316    // feed assumes, and we would miss any that aren't currently handled by WPD
2317    // (such as type tests that feed assumes via phis).
2318    ExportSummary.getOrInsertTypeIdSummary(P.first);
2319  }
2320
2321  // Collect information from summary about which calls to try to devirtualize.
2322  for (auto &P : ExportSummary) {
2323    for (auto &S : P.second.SummaryList) {
2324      auto *FS = dyn_cast<FunctionSummary>(S.get());
2325      if (!FS)
2326        continue;
2327      // FIXME: Only add live functions.
2328      for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
2329        for (StringRef Name : NameByGUID[VF.GUID]) {
2330          CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS);
2331        }
2332      }
2333      for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
2334        for (StringRef Name : NameByGUID[VF.GUID]) {
2335          CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
2336        }
2337      }
2338      for (const FunctionSummary::ConstVCall &VC :
2339           FS->type_test_assume_const_vcalls()) {
2340        for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
2341          CallSlots[{Name, VC.VFunc.Offset}]
2342              .ConstCSInfo[VC.Args]
2343              .addSummaryTypeTestAssumeUser(FS);
2344        }
2345      }
2346      for (const FunctionSummary::ConstVCall &VC :
2347           FS->type_checked_load_const_vcalls()) {
2348        for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
2349          CallSlots[{Name, VC.VFunc.Offset}]
2350              .ConstCSInfo[VC.Args]
2351              .addSummaryTypeCheckedLoadUser(FS);
2352        }
2353      }
2354    }
2355  }
2356
2357  std::set<ValueInfo> DevirtTargets;
2358  // For each (type, offset) pair:
2359  for (auto &S : CallSlots) {
2360    // Search each of the members of the type identifier for the virtual
2361    // function implementation at offset S.first.ByteOffset, and add to
2362    // TargetsForSlot.
2363    std::vector<ValueInfo> TargetsForSlot;
2364    auto TidSummary = ExportSummary.getTypeIdCompatibleVtableSummary(S.first.TypeID);
2365    assert(TidSummary);
2366    // The type id summary would have been created while building the NameByGUID
2367    // map earlier.
2368    WholeProgramDevirtResolution *Res =
2369        &ExportSummary.getTypeIdSummary(S.first.TypeID)
2370             ->WPDRes[S.first.ByteOffset];
2371    if (tryFindVirtualCallTargets(TargetsForSlot, *TidSummary,
2372                                  S.first.ByteOffset)) {
2373
2374      if (!trySingleImplDevirt(TargetsForSlot, S.first, S.second, Res,
2375                               DevirtTargets))
2376        continue;
2377    }
2378  }
2379
2380  // Optionally have the thin link print message for each devirtualized
2381  // function.
2382  if (PrintSummaryDevirt)
2383    for (const auto &DT : DevirtTargets)
2384      errs() << "Devirtualized call to " << DT << "\n";
2385
2386  NumDevirtTargets += DevirtTargets.size();
2387}
2388