1//===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass tries to expand memcmp() calls into optimally-sized loads and
10// compares for the target.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/ADT/Statistic.h"
15#include "llvm/Analysis/ConstantFolding.h"
16#include "llvm/Analysis/LazyBlockFrequencyInfo.h"
17#include "llvm/Analysis/ProfileSummaryInfo.h"
18#include "llvm/Analysis/TargetLibraryInfo.h"
19#include "llvm/Analysis/TargetTransformInfo.h"
20#include "llvm/Analysis/ValueTracking.h"
21#include "llvm/CodeGen/TargetLowering.h"
22#include "llvm/CodeGen/TargetPassConfig.h"
23#include "llvm/CodeGen/TargetSubtargetInfo.h"
24#include "llvm/IR/IRBuilder.h"
25#include "llvm/InitializePasses.h"
26#include "llvm/Transforms/Utils/SizeOpts.h"
27
28using namespace llvm;
29
30#define DEBUG_TYPE "expandmemcmp"
31
// Counters describing this pass's expansion decisions (LLVM STATISTIC
// counters; all of them are incremented in expandMemCmp).
STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

// Overrides the target's number of loads per load-compare block for memcmp
// calls whose result is only compared against zero. It is consulted only when
// given explicitly on the command line (getNumOccurrences()), so cl::init(1)
// does not replace the target's own choice.
static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));

// Overrides the target's maximum number of loads per expansion when not
// optimizing for size. Has no cl::init; only read when explicitly specified.
static cl::opt<unsigned> MaxLoadsPerMemcmp(
    "max-loads-per-memcmp", cl::Hidden,
    cl::desc("Set maximum number of loads used in expanded memcmp"));

// Same override, applied when optimizing for size (optsize attribute or
// profile-guided shouldOptimizeForSize). NOTE(review): expandMemCmp returns
// early for minsize (-Oz) functions, so in practice this only affects -Os.
static cl::opt<unsigned> MaxLoadsPerMemcmpOptSize(
    "max-loads-per-memcmp-opt-size", cl::Hidden,
    cl::desc("Set maximum number of loads used in expanded memcmp for -Os/Oz"));
50
51namespace {
52
53
54// This class provides helper functions to expand a memcmp library call into an
55// inline expansion.
56class MemCmpExpansion {
57  struct ResultBlock {
58    BasicBlock *BB = nullptr;
59    PHINode *PhiSrc1 = nullptr;
60    PHINode *PhiSrc2 = nullptr;
61
62    ResultBlock() = default;
63  };
64
65  CallInst *const CI;
66  ResultBlock ResBlock;
67  const uint64_t Size;
68  unsigned MaxLoadSize;
69  uint64_t NumLoadsNonOneByte;
70  const uint64_t NumLoadsPerBlockForZeroCmp;
71  std::vector<BasicBlock *> LoadCmpBlocks;
72  BasicBlock *EndBlock;
73  PHINode *PhiRes;
74  const bool IsUsedForZeroCmp;
75  const DataLayout &DL;
76  IRBuilder<> Builder;
77  // Represents the decomposition in blocks of the expansion. For example,
78  // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
79  // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {32, 1}.
80  struct LoadEntry {
81    LoadEntry(unsigned LoadSize, uint64_t Offset)
82        : LoadSize(LoadSize), Offset(Offset) {
83    }
84
85    // The size of the load for this block, in bytes.
86    unsigned LoadSize;
87    // The offset of this load from the base pointer, in bytes.
88    uint64_t Offset;
89  };
90  using LoadEntryVector = SmallVector<LoadEntry, 8>;
91  LoadEntryVector LoadSequence;
92
93  void createLoadCmpBlocks();
94  void createResultBlock();
95  void setupResultBlockPHINodes();
96  void setupEndBlockPHINodes();
97  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
98  void emitLoadCompareBlock(unsigned BlockIndex);
99  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
100                                         unsigned &LoadIndex);
101  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes);
102  void emitMemCmpResultBlock();
103  Value *getMemCmpExpansionZeroCase();
104  Value *getMemCmpEqZeroOneBlock();
105  Value *getMemCmpOneBlock();
106  Value *getPtrToElementAtOffset(Value *Source, Type *LoadSizeType,
107                                 uint64_t OffsetBytes);
108
109  static LoadEntryVector
110  computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
111                            unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte);
112  static LoadEntryVector
113  computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize,
114                                 unsigned MaxNumLoads,
115                                 unsigned &NumLoadsNonOneByte);
116
117public:
118  MemCmpExpansion(CallInst *CI, uint64_t Size,
119                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
120                  const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout);
121
122  unsigned getNumBlocks();
123  uint64_t getNumLoads() const { return LoadSequence.size(); }
124
125  Value *getMemCmpExpansion();
126};
127
128MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence(
129    uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
130    const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) {
131  NumLoadsNonOneByte = 0;
132  LoadEntryVector LoadSequence;
133  uint64_t Offset = 0;
134  while (Size && !LoadSizes.empty()) {
135    const unsigned LoadSize = LoadSizes.front();
136    const uint64_t NumLoadsForThisSize = Size / LoadSize;
137    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
138      // Do not expand if the total number of loads is larger than what the
139      // target allows. Note that it's important that we exit before completing
140      // the expansion to avoid using a ton of memory to store the expansion for
141      // large sizes.
142      return {};
143    }
144    if (NumLoadsForThisSize > 0) {
145      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
146        LoadSequence.push_back({LoadSize, Offset});
147        Offset += LoadSize;
148      }
149      if (LoadSize > 1)
150        ++NumLoadsNonOneByte;
151      Size = Size % LoadSize;
152    }
153    LoadSizes = LoadSizes.drop_front();
154  }
155  return LoadSequence;
156}
157
158MemCmpExpansion::LoadEntryVector
159MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
160                                                const unsigned MaxLoadSize,
161                                                const unsigned MaxNumLoads,
162                                                unsigned &NumLoadsNonOneByte) {
163  // These are already handled by the greedy approach.
164  if (Size < 2 || MaxLoadSize < 2)
165    return {};
166
167  // We try to do as many non-overlapping loads as possible starting from the
168  // beginning.
169  const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize;
170  assert(NumNonOverlappingLoads && "there must be at least one load");
171  // There remain 0 to (MaxLoadSize - 1) bytes to load, this will be done with
172  // an overlapping load.
173  Size = Size - NumNonOverlappingLoads * MaxLoadSize;
174  // Bail if we do not need an overloapping store, this is already handled by
175  // the greedy approach.
176  if (Size == 0)
177    return {};
178  // Bail if the number of loads (non-overlapping + potential overlapping one)
179  // is larger than the max allowed.
180  if ((NumNonOverlappingLoads + 1) > MaxNumLoads)
181    return {};
182
183  // Add non-overlapping loads.
184  LoadEntryVector LoadSequence;
185  uint64_t Offset = 0;
186  for (uint64_t I = 0; I < NumNonOverlappingLoads; ++I) {
187    LoadSequence.push_back({MaxLoadSize, Offset});
188    Offset += MaxLoadSize;
189  }
190
191  // Add the last overlapping load.
192  assert(Size > 0 && Size < MaxLoadSize && "broken invariant");
193  LoadSequence.push_back({MaxLoadSize, Offset - (MaxLoadSize - Size)});
194  NumLoadsNonOneByte = 1;
195  return LoadSequence;
196}
197
198// Initialize the basic block structure required for expansion of memcmp call
199// with given maximum load size and memcmp size parameter.
200// This structure includes:
201// 1. A list of load compare blocks - LoadCmpBlocks.
202// 2. An EndBlock, split from original instruction point, which is the block to
203// return from.
204// 3. ResultBlock, block to branch to for early exit when a
205// LoadCmpBlock finds a difference.
206MemCmpExpansion::MemCmpExpansion(
207    CallInst *const CI, uint64_t Size,
208    const TargetTransformInfo::MemCmpExpansionOptions &Options,
209    const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout)
210    : CI(CI), Size(Size), MaxLoadSize(0), NumLoadsNonOneByte(0),
211      NumLoadsPerBlockForZeroCmp(Options.NumLoadsPerBlock),
212      IsUsedForZeroCmp(IsUsedForZeroCmp), DL(TheDataLayout), Builder(CI) {
213  assert(Size > 0 && "zero blocks");
214  // Scale the max size down if the target can load more bytes than we need.
215  llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes);
216  while (!LoadSizes.empty() && LoadSizes.front() > Size) {
217    LoadSizes = LoadSizes.drop_front();
218  }
219  assert(!LoadSizes.empty() && "cannot load Size bytes");
220  MaxLoadSize = LoadSizes.front();
221  // Compute the decomposition.
222  unsigned GreedyNumLoadsNonOneByte = 0;
223  LoadSequence = computeGreedyLoadSequence(Size, LoadSizes, Options.MaxNumLoads,
224                                           GreedyNumLoadsNonOneByte);
225  NumLoadsNonOneByte = GreedyNumLoadsNonOneByte;
226  assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
227  // If we allow overlapping loads and the load sequence is not already optimal,
228  // use overlapping loads.
229  if (Options.AllowOverlappingLoads &&
230      (LoadSequence.empty() || LoadSequence.size() > 2)) {
231    unsigned OverlappingNumLoadsNonOneByte = 0;
232    auto OverlappingLoads = computeOverlappingLoadSequence(
233        Size, MaxLoadSize, Options.MaxNumLoads, OverlappingNumLoadsNonOneByte);
234    if (!OverlappingLoads.empty() &&
235        (LoadSequence.empty() ||
236         OverlappingLoads.size() < LoadSequence.size())) {
237      LoadSequence = OverlappingLoads;
238      NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte;
239    }
240  }
241  assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
242}
243
244unsigned MemCmpExpansion::getNumBlocks() {
245  if (IsUsedForZeroCmp)
246    return getNumLoads() / NumLoadsPerBlockForZeroCmp +
247           (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
248  return getNumLoads();
249}
250
251void MemCmpExpansion::createLoadCmpBlocks() {
252  for (unsigned i = 0; i < getNumBlocks(); i++) {
253    BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
254                                        EndBlock->getParent(), EndBlock);
255    LoadCmpBlocks.push_back(BB);
256  }
257}
258
259void MemCmpExpansion::createResultBlock() {
260  ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
261                                   EndBlock->getParent(), EndBlock);
262}
263
264/// Return a pointer to an element of type `LoadSizeType` at offset
265/// `OffsetBytes`.
266Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source,
267                                                Type *LoadSizeType,
268                                                uint64_t OffsetBytes) {
269  if (OffsetBytes > 0) {
270    auto *ByteType = Type::getInt8Ty(CI->getContext());
271    Source = Builder.CreateConstGEP1_64(
272        ByteType, Builder.CreateBitCast(Source, ByteType->getPointerTo()),
273        OffsetBytes);
274  }
275  return Builder.CreateBitCast(Source, LoadSizeType->getPointerTo());
276}
277
278// This function creates the IR instructions for loading and comparing 1 byte.
279// It loads 1 byte from each source of the memcmp parameters with the given
280// GEPIndex. It then subtracts the two loaded values and adds this result to the
281// final phi node for selecting the memcmp result.
282void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
283                                               unsigned OffsetBytes) {
284  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
285  Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
286  Value *Source1 =
287      getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, OffsetBytes);
288  Value *Source2 =
289      getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, OffsetBytes);
290
291  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
292  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
293
294  LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
295  LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
296  Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);
297
298  PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);
299
300  if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
301    // Early exit branch if difference found to EndBlock. Otherwise, continue to
302    // next LoadCmpBlock,
303    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
304                                    ConstantInt::get(Diff->getType(), 0));
305    BranchInst *CmpBr =
306        BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
307    Builder.Insert(CmpBr);
308  } else {
309    // The last block has an unconditional branch to EndBlock.
310    BranchInst *CmpBr = BranchInst::Create(EndBlock);
311    Builder.Insert(CmpBr);
312  }
313}
314
315/// Generate an equality comparison for one or more pairs of loaded values.
316/// This is used in the case where the memcmp() call is compared equal or not
317/// equal to zero.
318Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
319                                            unsigned &LoadIndex) {
320  assert(LoadIndex < getNumLoads() &&
321         "getCompareLoadPairs() called with no remaining loads");
322  std::vector<Value *> XorList, OrList;
323  Value *Diff = nullptr;
324
325  const unsigned NumLoads =
326      std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);
327
328  // For a single-block expansion, start inserting before the memcmp call.
329  if (LoadCmpBlocks.empty())
330    Builder.SetInsertPoint(CI);
331  else
332    Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
333
334  Value *Cmp = nullptr;
335  // If we have multiple loads per block, we need to generate a composite
336  // comparison using xor+or. The type for the combinations is the largest load
337  // type.
338  IntegerType *const MaxLoadType =
339      NumLoads == 1 ? nullptr
340                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
341  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
342    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
343
344    IntegerType *LoadSizeType =
345        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
346
347    Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
348                                             CurLoadEntry.Offset);
349    Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
350                                             CurLoadEntry.Offset);
351
352    // Get a constant or load a value for each source address.
353    Value *LoadSrc1 = nullptr;
354    if (auto *Source1C = dyn_cast<Constant>(Source1))
355      LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
356    if (!LoadSrc1)
357      LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
358
359    Value *LoadSrc2 = nullptr;
360    if (auto *Source2C = dyn_cast<Constant>(Source2))
361      LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
362    if (!LoadSrc2)
363      LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
364
365    if (NumLoads != 1) {
366      if (LoadSizeType != MaxLoadType) {
367        LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
368        LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
369      }
370      // If we have multiple loads per block, we need to generate a composite
371      // comparison using xor+or.
372      Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
373      Diff = Builder.CreateZExt(Diff, MaxLoadType);
374      XorList.push_back(Diff);
375    } else {
376      // If there's only one load per block, we just compare the loaded values.
377      Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
378    }
379  }
380
381  auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
382    std::vector<Value *> OutList;
383    for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
384      Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
385      OutList.push_back(Or);
386    }
387    if (InList.size() % 2 != 0)
388      OutList.push_back(InList.back());
389    return OutList;
390  };
391
392  if (!Cmp) {
393    // Pairwise OR the XOR results.
394    OrList = pairWiseOr(XorList);
395
396    // Pairwise OR the OR results until one result left.
397    while (OrList.size() != 1) {
398      OrList = pairWiseOr(OrList);
399    }
400
401    assert(Diff && "Failed to find comparison diff");
402    Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
403  }
404
405  return Cmp;
406}
407
408void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
409                                                        unsigned &LoadIndex) {
410  Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);
411
412  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
413                           ? EndBlock
414                           : LoadCmpBlocks[BlockIndex + 1];
415  // Early exit branch if difference found to ResultBlock. Otherwise,
416  // continue to next LoadCmpBlock or EndBlock.
417  BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
418  Builder.Insert(CmpBr);
419
420  // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
421  // since early exit to ResultBlock was not taken (no difference was found in
422  // any of the bytes).
423  if (BlockIndex == LoadCmpBlocks.size() - 1) {
424    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
425    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
426  }
427}
428
429// This function creates the IR intructions for loading and comparing using the
430// given LoadSize. It loads the number of bytes specified by LoadSize from each
431// source of the memcmp parameters. It then does a subtract to see if there was
432// a difference in the loaded values. If a difference is found, it branches
433// with an early exit to the ResultBlock for calculating which source was
434// larger. Otherwise, it falls through to the either the next LoadCmpBlock or
435// the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled with
436// a special case through emitLoadCompareByteBlock. The special handling can
437// simply subtract the loaded values and add it to the result phi node.
438void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
439  // There is one load per block in this case, BlockIndex == LoadIndex.
440  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];
441
442  if (CurLoadEntry.LoadSize == 1) {
443    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset);
444    return;
445  }
446
447  Type *LoadSizeType =
448      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
449  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
450  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");
451
452  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
453
454  Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
455                                           CurLoadEntry.Offset);
456  Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
457                                           CurLoadEntry.Offset);
458
459  // Load LoadSizeType from the base address.
460  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
461  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
462
463  if (DL.isLittleEndian()) {
464    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
465                                                Intrinsic::bswap, LoadSizeType);
466    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
467    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
468  }
469
470  if (LoadSizeType != MaxLoadType) {
471    LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
472    LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
473  }
474
475  // Add the loaded values to the phi nodes for calculating memcmp result only
476  // if result is not used in a zero equality.
477  if (!IsUsedForZeroCmp) {
478    ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
479    ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
480  }
481
482  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
483  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
484                           ? EndBlock
485                           : LoadCmpBlocks[BlockIndex + 1];
486  // Early exit branch if difference found to ResultBlock. Otherwise, continue
487  // to next LoadCmpBlock or EndBlock.
488  BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
489  Builder.Insert(CmpBr);
490
491  // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
492  // since early exit to ResultBlock was not taken (no difference was found in
493  // any of the bytes).
494  if (BlockIndex == LoadCmpBlocks.size() - 1) {
495    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
496    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
497  }
498}
499
500// This function populates the ResultBlock with a sequence to calculate the
501// memcmp result. It compares the two loaded source values and returns -1 if
502// src1 < src2 and 1 if src1 > src2.
503void MemCmpExpansion::emitMemCmpResultBlock() {
504  // Special case: if memcmp result is used in a zero equality, result does not
505  // need to be calculated and can simply return 1.
506  if (IsUsedForZeroCmp) {
507    BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
508    Builder.SetInsertPoint(ResBlock.BB, InsertPt);
509    Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
510    PhiRes->addIncoming(Res, ResBlock.BB);
511    BranchInst *NewBr = BranchInst::Create(EndBlock);
512    Builder.Insert(NewBr);
513    return;
514  }
515  BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
516  Builder.SetInsertPoint(ResBlock.BB, InsertPt);
517
518  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
519                                  ResBlock.PhiSrc2);
520
521  Value *Res =
522      Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
523                           ConstantInt::get(Builder.getInt32Ty(), 1));
524
525  BranchInst *NewBr = BranchInst::Create(EndBlock);
526  Builder.Insert(NewBr);
527  PhiRes->addIncoming(Res, ResBlock.BB);
528}
529
530void MemCmpExpansion::setupResultBlockPHINodes() {
531  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
532  Builder.SetInsertPoint(ResBlock.BB);
533  // Note: this assumes one load per block.
534  ResBlock.PhiSrc1 =
535      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
536  ResBlock.PhiSrc2 =
537      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
538}
539
540void MemCmpExpansion::setupEndBlockPHINodes() {
541  Builder.SetInsertPoint(&EndBlock->front());
542  PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
543}
544
545Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
546  unsigned LoadIndex = 0;
547  // This loop populates each of the LoadCmpBlocks with the IR sequence to
548  // handle multiple loads per block.
549  for (unsigned I = 0; I < getNumBlocks(); ++I) {
550    emitLoadCompareBlockMultipleLoads(I, LoadIndex);
551  }
552
553  emitMemCmpResultBlock();
554  return PhiRes;
555}
556
557/// A memcmp expansion that compares equality with 0 and only has one block of
558/// load and compare can bypass the compare, branch, and phi IR that is required
559/// in the general case.
560Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
561  unsigned LoadIndex = 0;
562  Value *Cmp = getCompareLoadPairs(0, LoadIndex);
563  assert(LoadIndex == getNumLoads() && "some entries were not consumed");
564  return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
565}
566
567/// A memcmp expansion that only has one block of load and compare can bypass
568/// the compare, branch, and phi IR that is required in the general case.
569Value *MemCmpExpansion::getMemCmpOneBlock() {
570  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
571  Value *Source1 = CI->getArgOperand(0);
572  Value *Source2 = CI->getArgOperand(1);
573
574  // Cast source to LoadSizeType*.
575  if (Source1->getType() != LoadSizeType)
576    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
577  if (Source2->getType() != LoadSizeType)
578    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
579
580  // Load LoadSizeType from the base address.
581  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
582  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
583
584  if (DL.isLittleEndian() && Size != 1) {
585    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
586                                                Intrinsic::bswap, LoadSizeType);
587    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
588    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
589  }
590
591  if (Size < 4) {
592    // The i8 and i16 cases don't need compares. We zext the loaded values and
593    // subtract them to get the suitable negative, zero, or positive i32 result.
594    LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
595    LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
596    return Builder.CreateSub(LoadSrc1, LoadSrc2);
597  }
598
599  // The result of memcmp is negative, zero, or positive, so produce that by
600  // subtracting 2 extended compare bits: sub (ugt, ult).
601  // If a target prefers to use selects to get -1/0/1, they should be able
602  // to transform this later. The inverse transform (going from selects to math)
603  // may not be possible in the DAG because the selects got converted into
604  // branches before we got there.
605  Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
606  Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
607  Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
608  Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
609  return Builder.CreateSub(ZextUGT, ZextULT);
610}
611
612// This function expands the memcmp call into an inline expansion and returns
613// the memcmp result.
614Value *MemCmpExpansion::getMemCmpExpansion() {
615  // Create the basic block framework for a multi-block expansion.
616  if (getNumBlocks() != 1) {
617    BasicBlock *StartBlock = CI->getParent();
618    EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
619    setupEndBlockPHINodes();
620    createResultBlock();
621
622    // If return value of memcmp is not used in a zero equality, we need to
623    // calculate which source was larger. The calculation requires the
624    // two loaded source values of each load compare block.
625    // These will be saved in the phi nodes created by setupResultBlockPHINodes.
626    if (!IsUsedForZeroCmp) setupResultBlockPHINodes();
627
628    // Create the number of required load compare basic blocks.
629    createLoadCmpBlocks();
630
631    // Update the terminator added by splitBasicBlock to branch to the first
632    // LoadCmpBlock.
633    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
634  }
635
636  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
637
638  if (IsUsedForZeroCmp)
639    return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
640                               : getMemCmpExpansionZeroCase();
641
642  if (getNumBlocks() == 1)
643    return getMemCmpOneBlock();
644
645  for (unsigned I = 0; I < getNumBlocks(); ++I) {
646    emitLoadCompareBlock(I);
647  }
648
649  emitMemCmpResultBlock();
650  return PhiRes;
651}
652
653// This function checks to see if an expansion of memcmp can be generated.
654// It checks for constant compare size that is less than the max inline size.
655// If an expansion cannot occur, returns false to leave as a library call.
656// Otherwise, the library call is replaced with a new IR instruction sequence.
657/// We want to transform:
658/// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
659/// To:
660/// loadbb:
661///  %0 = bitcast i32* %buffer2 to i8*
662///  %1 = bitcast i32* %buffer1 to i8*
663///  %2 = bitcast i8* %1 to i64*
664///  %3 = bitcast i8* %0 to i64*
665///  %4 = load i64, i64* %2
666///  %5 = load i64, i64* %3
667///  %6 = call i64 @llvm.bswap.i64(i64 %4)
668///  %7 = call i64 @llvm.bswap.i64(i64 %5)
669///  %8 = sub i64 %6, %7
670///  %9 = icmp ne i64 %8, 0
671///  br i1 %9, label %res_block, label %loadbb1
672/// res_block:                                        ; preds = %loadbb2,
673/// %loadbb1, %loadbb
674///  %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
675///  %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
676///  %10 = icmp ult i64 %phi.src1, %phi.src2
677///  %11 = select i1 %10, i32 -1, i32 1
678///  br label %endblock
679/// loadbb1:                                          ; preds = %loadbb
680///  %12 = bitcast i32* %buffer2 to i8*
681///  %13 = bitcast i32* %buffer1 to i8*
682///  %14 = bitcast i8* %13 to i32*
683///  %15 = bitcast i8* %12 to i32*
684///  %16 = getelementptr i32, i32* %14, i32 2
685///  %17 = getelementptr i32, i32* %15, i32 2
686///  %18 = load i32, i32* %16
687///  %19 = load i32, i32* %17
688///  %20 = call i32 @llvm.bswap.i32(i32 %18)
689///  %21 = call i32 @llvm.bswap.i32(i32 %19)
690///  %22 = zext i32 %20 to i64
691///  %23 = zext i32 %21 to i64
692///  %24 = sub i64 %22, %23
693///  %25 = icmp ne i64 %24, 0
694///  br i1 %25, label %res_block, label %loadbb2
695/// loadbb2:                                          ; preds = %loadbb1
696///  %26 = bitcast i32* %buffer2 to i8*
697///  %27 = bitcast i32* %buffer1 to i8*
698///  %28 = bitcast i8* %27 to i16*
699///  %29 = bitcast i8* %26 to i16*
700///  %30 = getelementptr i16, i16* %28, i16 6
701///  %31 = getelementptr i16, i16* %29, i16 6
702///  %32 = load i16, i16* %30
703///  %33 = load i16, i16* %31
704///  %34 = call i16 @llvm.bswap.i16(i16 %32)
705///  %35 = call i16 @llvm.bswap.i16(i16 %33)
706///  %36 = zext i16 %34 to i64
707///  %37 = zext i16 %35 to i64
708///  %38 = sub i64 %36, %37
709///  %39 = icmp ne i64 %38, 0
710///  br i1 %39, label %res_block, label %loadbb3
711/// loadbb3:                                          ; preds = %loadbb2
712///  %40 = bitcast i32* %buffer2 to i8*
713///  %41 = bitcast i32* %buffer1 to i8*
714///  %42 = getelementptr i8, i8* %41, i8 14
715///  %43 = getelementptr i8, i8* %40, i8 14
716///  %44 = load i8, i8* %42
717///  %45 = load i8, i8* %43
718///  %46 = zext i8 %44 to i32
719///  %47 = zext i8 %45 to i32
720///  %48 = sub i32 %46, %47
721///  br label %endblock
722/// endblock:                                         ; preds = %res_block,
723/// %loadbb3
724///  %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
725///  ret i32 %phi.res
726static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
727                         const TargetLowering *TLI, const DataLayout *DL,
728                         ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI) {
729  NumMemCmpCalls++;
730
731  // Early exit from expansion if -Oz.
732  if (CI->getFunction()->hasMinSize())
733    return false;
734
735  // Early exit from expansion if size is not a constant.
736  ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
737  if (!SizeCast) {
738    NumMemCmpNotConstant++;
739    return false;
740  }
741  const uint64_t SizeVal = SizeCast->getZExtValue();
742
743  if (SizeVal == 0) {
744    return false;
745  }
746  // TTI call to check if target would like to expand memcmp. Also, get the
747  // available load sizes.
748  const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
749  bool OptForSize = CI->getFunction()->hasOptSize() ||
750                    llvm::shouldOptimizeForSize(CI->getParent(), PSI, BFI);
751  auto Options = TTI->enableMemCmpExpansion(OptForSize,
752                                            IsUsedForZeroCmp);
753  if (!Options) return false;
754
755  if (MemCmpEqZeroNumLoadsPerBlock.getNumOccurrences())
756    Options.NumLoadsPerBlock = MemCmpEqZeroNumLoadsPerBlock;
757
758  if (OptForSize &&
759      MaxLoadsPerMemcmpOptSize.getNumOccurrences())
760    Options.MaxNumLoads = MaxLoadsPerMemcmpOptSize;
761
762  if (!OptForSize && MaxLoadsPerMemcmp.getNumOccurrences())
763    Options.MaxNumLoads = MaxLoadsPerMemcmp;
764
765  MemCmpExpansion Expansion(CI, SizeVal, Options, IsUsedForZeroCmp, *DL);
766
767  // Don't expand if this will require more loads than desired by the target.
768  if (Expansion.getNumLoads() == 0) {
769    NumMemCmpGreaterThanMax++;
770    return false;
771  }
772
773  NumMemCmpInlined++;
774
775  Value *Res = Expansion.getMemCmpExpansion();
776
777  // Replace call with result of expansion and erase call.
778  CI->replaceAllUsesWith(Res);
779  CI->eraseFromParent();
780
781  return true;
782}
783
784
785
786class ExpandMemCmpPass : public FunctionPass {
787public:
788  static char ID;
789
790  ExpandMemCmpPass() : FunctionPass(ID) {
791    initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
792  }
793
794  bool runOnFunction(Function &F) override {
795    if (skipFunction(F)) return false;
796
797    auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
798    if (!TPC) {
799      return false;
800    }
801    const TargetLowering* TL =
802        TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();
803
804    const TargetLibraryInfo *TLI =
805        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
806    const TargetTransformInfo *TTI =
807        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
808    auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
809    auto *BFI = (PSI && PSI->hasProfileSummary()) ?
810           &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() :
811           nullptr;
812    auto PA = runImpl(F, TLI, TTI, TL, PSI, BFI);
813    return !PA.areAllPreserved();
814  }
815
816private:
817  void getAnalysisUsage(AnalysisUsage &AU) const override {
818    AU.addRequired<TargetLibraryInfoWrapperPass>();
819    AU.addRequired<TargetTransformInfoWrapperPass>();
820    AU.addRequired<ProfileSummaryInfoWrapperPass>();
821    LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU);
822    FunctionPass::getAnalysisUsage(AU);
823  }
824
825  PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
826                            const TargetTransformInfo *TTI,
827                            const TargetLowering* TL,
828                            ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI);
829  // Returns true if a change was made.
830  bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
831                  const TargetTransformInfo *TTI, const TargetLowering* TL,
832                  const DataLayout& DL, ProfileSummaryInfo *PSI,
833                  BlockFrequencyInfo *BFI);
834};
835
836bool ExpandMemCmpPass::runOnBlock(
837    BasicBlock &BB, const TargetLibraryInfo *TLI,
838    const TargetTransformInfo *TTI, const TargetLowering* TL,
839    const DataLayout& DL, ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI) {
840  for (Instruction& I : BB) {
841    CallInst *CI = dyn_cast<CallInst>(&I);
842    if (!CI) {
843      continue;
844    }
845    LibFunc Func;
846    if (TLI->getLibFunc(ImmutableCallSite(CI), Func) &&
847        (Func == LibFunc_memcmp || Func == LibFunc_bcmp) &&
848        expandMemCmp(CI, TTI, TL, &DL, PSI, BFI)) {
849      return true;
850    }
851  }
852  return false;
853}
854
855
856PreservedAnalyses ExpandMemCmpPass::runImpl(
857    Function &F, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI,
858    const TargetLowering* TL, ProfileSummaryInfo *PSI,
859    BlockFrequencyInfo *BFI) {
860  const DataLayout& DL = F.getParent()->getDataLayout();
861  bool MadeChanges = false;
862  for (auto BBIt = F.begin(); BBIt != F.end();) {
863    if (runOnBlock(*BBIt, TLI, TTI, TL, DL, PSI, BFI)) {
864      MadeChanges = true;
865      // If changes were made, restart the function from the beginning, since
866      // the structure of the function was changed.
867      BBIt = F.begin();
868    } else {
869      ++BBIt;
870    }
871  }
872  return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all();
873}
874
875} // namespace
876
// Storage for the pass ID that the constructor passes to FunctionPass(ID).
char ExpandMemCmpPass::ID = 0;
// Register the pass under the name "expandmemcmp" together with the analyses
// it depends on. These are the same analyses requested in getAnalysisUsage.
INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
                      "Expand memcmp() to load/stores", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass)
INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
                    "Expand memcmp() to load/stores", false, false)
886
887FunctionPass *llvm::createExpandMemCmpPass() {
888  return new ExpandMemCmpPass();
889}
890