//===-- lib/CodeGen/GlobalISel/CombinerHelper.cpp -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
9#include "llvm/ADT/SetVector.h"
10#include "llvm/ADT/SmallBitVector.h"
11#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
12#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
13#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
14#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
15#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
16#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
17#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
18#include "llvm/CodeGen/GlobalISel/Utils.h"
19#include "llvm/CodeGen/LowLevelType.h"
20#include "llvm/CodeGen/MachineBasicBlock.h"
21#include "llvm/CodeGen/MachineDominators.h"
22#include "llvm/CodeGen/MachineInstr.h"
23#include "llvm/CodeGen/MachineMemOperand.h"
24#include "llvm/CodeGen/MachineRegisterInfo.h"
25#include "llvm/CodeGen/RegisterBankInfo.h"
26#include "llvm/CodeGen/TargetInstrInfo.h"
27#include "llvm/CodeGen/TargetLowering.h"
28#include "llvm/CodeGen/TargetOpcodes.h"
29#include "llvm/IR/DataLayout.h"
30#include "llvm/IR/InstrTypes.h"
31#include "llvm/Support/Casting.h"
32#include "llvm/Support/DivisionByConstantInfo.h"
33#include "llvm/Support/MathExtras.h"
34#include "llvm/Target/TargetMachine.h"
35#include <cmath>
36#include <optional>
37#include <tuple>
38
39#define DEBUG_TYPE "gi-combiner"
40
41using namespace llvm;
42using namespace MIPatternMatch;
43
44// Option to allow testing of the combiner while no targets know about indexed
45// addressing.
46static cl::opt<bool>
47    ForceLegalIndexing("force-legal-indexing", cl::Hidden, cl::init(false),
48                       cl::desc("Force all indexed operations to be "
49                                "legal for the GlobalISel combiner"));
50
51CombinerHelper::CombinerHelper(GISelChangeObserver &Observer,
52                               MachineIRBuilder &B, bool IsPreLegalize,
53                               GISelKnownBits *KB, MachineDominatorTree *MDT,
54                               const LegalizerInfo *LI)
55    : Builder(B), MRI(Builder.getMF().getRegInfo()), Observer(Observer), KB(KB),
56      MDT(MDT), IsPreLegalize(IsPreLegalize), LI(LI),
57      RBI(Builder.getMF().getSubtarget().getRegBankInfo()),
58      TRI(Builder.getMF().getSubtarget().getRegisterInfo()) {
59  (void)this->KB;
60}
61
62const TargetLowering &CombinerHelper::getTargetLowering() const {
63  return *Builder.getMF().getSubtarget().getTargetLowering();
64}
65
66/// \returns The little endian in-memory byte position of byte \p I in a
67/// \p ByteWidth bytes wide type.
68///
69/// E.g. Given a 4-byte type x, x[0] -> byte 0
70static unsigned littleEndianByteAt(const unsigned ByteWidth, const unsigned I) {
71  assert(I < ByteWidth && "I must be in [0, ByteWidth)");
72  return I;
73}
74
75/// Determines the LogBase2 value for a non-null input value using the
76/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
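/// As an illustrative check of the formula (values chosen arbitrarily): for a
/// 32-bit element type and V = 16, ctlz(16) = 27, so (32 - 1) - 27 = 4, which
/// is log2(16).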
77static Register buildLogBase2(Register V, MachineIRBuilder &MIB) {
78  auto &MRI = *MIB.getMRI();
79  LLT Ty = MRI.getType(V);
80  auto Ctlz = MIB.buildCTLZ(Ty, V);
81  auto Base = MIB.buildConstant(Ty, Ty.getScalarSizeInBits() - 1);
82  return MIB.buildSub(Ty, Base, Ctlz).getReg(0);
83}
84
85/// \returns The big endian in-memory byte position of byte \p I in a
86/// \p ByteWidth bytes wide type.
87///
88/// E.g. Given a 4-byte type x, x[0] -> byte 3
89static unsigned bigEndianByteAt(const unsigned ByteWidth, const unsigned I) {
90  assert(I < ByteWidth && "I must be in [0, ByteWidth)");
91  return ByteWidth - I - 1;
92}
93
94/// Given a map from byte offsets in memory to indices in a load/store,
95/// determine if that map corresponds to a little or big endian byte pattern.
96///
97/// \param MemOffset2Idx maps memory offsets to address offsets.
98/// \param LowestIdx is the lowest index in \p MemOffset2Idx.
99///
100/// \returns true if the map corresponds to a big endian byte pattern, false if
101/// it corresponds to a little endian byte pattern, and std::nullopt otherwise.
102///
103/// E.g. given a 32-bit type x, and x[AddrOffset], the in-memory byte patterns
104/// are as follows:
105///
106/// AddrOffset   Little endian    Big endian
107/// 0            0                3
108/// 1            1                2
109/// 2            2                1
110/// 3            3                0
111static std::optional<bool>
112isBigEndian(const SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
113            int64_t LowestIdx) {
114  // Need at least two byte positions to decide on endianness.
115  unsigned Width = MemOffset2Idx.size();
116  if (Width < 2)
117    return std::nullopt;
118  bool BigEndian = true, LittleEndian = true;
  for (unsigned MemOffset = 0; MemOffset < Width; ++MemOffset) {
120    auto MemOffsetAndIdx = MemOffset2Idx.find(MemOffset);
121    if (MemOffsetAndIdx == MemOffset2Idx.end())
122      return std::nullopt;
123    const int64_t Idx = MemOffsetAndIdx->second - LowestIdx;
124    assert(Idx >= 0 && "Expected non-negative byte offset?");
125    LittleEndian &= Idx == littleEndianByteAt(Width, MemOffset);
126    BigEndian &= Idx == bigEndianByteAt(Width, MemOffset);
127    if (!BigEndian && !LittleEndian)
128      return std::nullopt;
129  }
130
131  assert((BigEndian != LittleEndian) &&
132         "Pattern cannot be both big and little endian!");
133  return BigEndian;
134}
135
136bool CombinerHelper::isPreLegalize() const { return IsPreLegalize; }
137
138bool CombinerHelper::isLegal(const LegalityQuery &Query) const {
139  assert(LI && "Must have LegalizerInfo to query isLegal!");
140  return LI->getAction(Query).Action == LegalizeActions::Legal;
141}
142
143bool CombinerHelper::isLegalOrBeforeLegalizer(
144    const LegalityQuery &Query) const {
145  return isPreLegalize() || isLegal(Query);
146}
147
148bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const {
149  if (!Ty.isVector())
150    return isLegalOrBeforeLegalizer({TargetOpcode::G_CONSTANT, {Ty}});
151  // Vector constants are represented as a G_BUILD_VECTOR of scalar G_CONSTANTs.
152  if (isPreLegalize())
153    return true;
154  LLT EltTy = Ty.getElementType();
155  return isLegal({TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}}) &&
156         isLegal({TargetOpcode::G_CONSTANT, {EltTy}});
157}
158
159void CombinerHelper::replaceRegWith(MachineRegisterInfo &MRI, Register FromReg,
160                                    Register ToReg) const {
161  Observer.changingAllUsesOfReg(MRI, FromReg);
162
163  if (MRI.constrainRegAttrs(ToReg, FromReg))
164    MRI.replaceRegWith(FromReg, ToReg);
165  else
166    Builder.buildCopy(ToReg, FromReg);
167
168  Observer.finishedChangingAllUsesOfReg();
169}
170
171void CombinerHelper::replaceRegOpWith(MachineRegisterInfo &MRI,
172                                      MachineOperand &FromRegOp,
173                                      Register ToReg) const {
174  assert(FromRegOp.getParent() && "Expected an operand in an MI");
175  Observer.changingInstr(*FromRegOp.getParent());
176
177  FromRegOp.setReg(ToReg);
178
179  Observer.changedInstr(*FromRegOp.getParent());
180}
181
182void CombinerHelper::replaceOpcodeWith(MachineInstr &FromMI,
183                                       unsigned ToOpcode) const {
184  Observer.changingInstr(FromMI);
185
186  FromMI.setDesc(Builder.getTII().get(ToOpcode));
187
188  Observer.changedInstr(FromMI);
189}
190
191const RegisterBank *CombinerHelper::getRegBank(Register Reg) const {
192  return RBI->getRegBank(Reg, MRI, *TRI);
193}
194
195void CombinerHelper::setRegBank(Register Reg, const RegisterBank *RegBank) {
196  if (RegBank)
197    MRI.setRegBank(Reg, *RegBank);
198}
199
200bool CombinerHelper::tryCombineCopy(MachineInstr &MI) {
201  if (matchCombineCopy(MI)) {
202    applyCombineCopy(MI);
203    return true;
204  }
205  return false;
206}
207bool CombinerHelper::matchCombineCopy(MachineInstr &MI) {
208  if (MI.getOpcode() != TargetOpcode::COPY)
209    return false;
210  Register DstReg = MI.getOperand(0).getReg();
211  Register SrcReg = MI.getOperand(1).getReg();
212  return canReplaceReg(DstReg, SrcReg, MRI);
213}
214void CombinerHelper::applyCombineCopy(MachineInstr &MI) {
215  Register DstReg = MI.getOperand(0).getReg();
216  Register SrcReg = MI.getOperand(1).getReg();
217  MI.eraseFromParent();
218  replaceRegWith(MRI, DstReg, SrcReg);
219}
220
221bool CombinerHelper::tryCombineConcatVectors(MachineInstr &MI) {
222  bool IsUndef = false;
223  SmallVector<Register, 4> Ops;
224  if (matchCombineConcatVectors(MI, IsUndef, Ops)) {
225    applyCombineConcatVectors(MI, IsUndef, Ops);
226    return true;
227  }
228  return false;
229}
230
231bool CombinerHelper::matchCombineConcatVectors(MachineInstr &MI, bool &IsUndef,
232                                               SmallVectorImpl<Register> &Ops) {
233  assert(MI.getOpcode() == TargetOpcode::G_CONCAT_VECTORS &&
234         "Invalid instruction");
235  IsUndef = true;
236  MachineInstr *Undef = nullptr;
237
238  // Walk over all the operands of concat vectors and check if they are
239  // build_vector themselves or undef.
240  // Then collect their operands in Ops.
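  // For illustration only (register names and types invented for this
  // comment), flattening
  //   %v1:_(<2 x s32>) = G_BUILD_VECTOR %a(s32), %b(s32)
  //   %v2:_(<2 x s32>) = G_BUILD_VECTOR %c(s32), %d(s32)
  //   %cat:_(<4 x s32>) = G_CONCAT_VECTORS %v1(<2 x s32>), %v2(<2 x s32>)
  // collects Ops = {%a, %b, %c, %d}, so the apply step can emit a single
  //   %cat:_(<4 x s32>) = G_BUILD_VECTOR %a(s32), %b(s32), %c(s32), %d(s32)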
241  for (const MachineOperand &MO : MI.uses()) {
242    Register Reg = MO.getReg();
243    MachineInstr *Def = MRI.getVRegDef(Reg);
244    assert(Def && "Operand not defined");
245    switch (Def->getOpcode()) {
246    case TargetOpcode::G_BUILD_VECTOR:
247      IsUndef = false;
248      // Remember the operands of the build_vector to fold
249      // them into the yet-to-build flattened concat vectors.
250      for (const MachineOperand &BuildVecMO : Def->uses())
251        Ops.push_back(BuildVecMO.getReg());
252      break;
253    case TargetOpcode::G_IMPLICIT_DEF: {
254      LLT OpType = MRI.getType(Reg);
255      // Keep one undef value for all the undef operands.
256      if (!Undef) {
257        Builder.setInsertPt(*MI.getParent(), MI);
258        Undef = Builder.buildUndef(OpType.getScalarType());
259      }
260      assert(MRI.getType(Undef->getOperand(0).getReg()) ==
261                 OpType.getScalarType() &&
262             "All undefs should have the same type");
263      // Break the undef vector in as many scalar elements as needed
264      // for the flattening.
265      for (unsigned EltIdx = 0, EltEnd = OpType.getNumElements();
266           EltIdx != EltEnd; ++EltIdx)
267        Ops.push_back(Undef->getOperand(0).getReg());
268      break;
269    }
270    default:
271      return false;
272    }
273  }
274  return true;
275}
276void CombinerHelper::applyCombineConcatVectors(
277    MachineInstr &MI, bool IsUndef, const ArrayRef<Register> Ops) {
  // We determined that the concat_vectors can be flattened.
279  // Generate the flattened build_vector.
280  Register DstReg = MI.getOperand(0).getReg();
281  Builder.setInsertPt(*MI.getParent(), MI);
282  Register NewDstReg = MRI.cloneVirtualRegister(DstReg);
283
  // Note: IsUndef is somewhat redundant. We could have determined it by
  // checking that all Ops are undef.  Alternatively, we could have
  // generated a build_vector of undefs and relied on another combine to
  // clean that up.  For now, given we already gather this information
  // in tryCombineConcatVectors, just save compile time and issue the
  // right thing.
290  if (IsUndef)
291    Builder.buildUndef(NewDstReg);
292  else
293    Builder.buildBuildVector(NewDstReg, Ops);
294  MI.eraseFromParent();
295  replaceRegWith(MRI, DstReg, NewDstReg);
296}
297
298bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) {
299  SmallVector<Register, 4> Ops;
300  if (matchCombineShuffleVector(MI, Ops)) {
301    applyCombineShuffleVector(MI, Ops);
302    return true;
303  }
304  return false;
305}
306
307bool CombinerHelper::matchCombineShuffleVector(MachineInstr &MI,
308                                               SmallVectorImpl<Register> &Ops) {
309  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
310         "Invalid instruction kind");
311  LLT DstType = MRI.getType(MI.getOperand(0).getReg());
312  Register Src1 = MI.getOperand(1).getReg();
313  LLT SrcType = MRI.getType(Src1);
314  // As bizarre as it may look, shuffle vector can actually produce
315  // scalar! This is because at the IR level a <1 x ty> shuffle
316  // vector is perfectly valid.
317  unsigned DstNumElts = DstType.isVector() ? DstType.getNumElements() : 1;
318  unsigned SrcNumElts = SrcType.isVector() ? SrcType.getNumElements() : 1;
319
320  // If the resulting vector is smaller than the size of the source
321  // vectors being concatenated, we won't be able to replace the
322  // shuffle vector into a concat_vectors.
323  //
324  // Note: We may still be able to produce a concat_vectors fed by
325  //       extract_vector_elt and so on. It is less clear that would
326  //       be better though, so don't bother for now.
327  //
328  // If the destination is a scalar, the size of the sources doesn't
  // matter. We will lower the shuffle to a plain copy. This will
330  // work only if the source and destination have the same size. But
331  // that's covered by the next condition.
332  //
  // TODO: If the sizes of the source and destination don't match,
  //       we could still emit an extract vector element in that case.
335  if (DstNumElts < 2 * SrcNumElts && DstNumElts != 1)
336    return false;
337
338  // Check that the shuffle mask can be broken evenly between the
339  // different sources.
340  if (DstNumElts % SrcNumElts != 0)
341    return false;
342
343  // Mask length is a multiple of the source vector length.
344  // Check if the shuffle is some kind of concatenation of the input
345  // vectors.
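  // As an illustrative example (operands invented for this comment):
  //   %r:_(<4 x s32>) = G_SHUFFLE_VECTOR %a(<2 x s32>), %b(<2 x s32>),
  //                                      shufflemask(0, 1, 2, 3)
  // is a pure concatenation, so Ops becomes {%a, %b} and the apply step can
  // emit a G_CONCAT_VECTORS of the two sources instead.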
346  unsigned NumConcat = DstNumElts / SrcNumElts;
347  SmallVector<int, 8> ConcatSrcs(NumConcat, -1);
348  ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
349  for (unsigned i = 0; i != DstNumElts; ++i) {
350    int Idx = Mask[i];
351    // Undef value.
352    if (Idx < 0)
353      continue;
354    // Ensure the indices in each SrcType sized piece are sequential and that
355    // the same source is used for the whole piece.
356    if ((Idx % SrcNumElts != (i % SrcNumElts)) ||
357        (ConcatSrcs[i / SrcNumElts] >= 0 &&
358         ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts)))
359      return false;
360    // Remember which source this index came from.
361    ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts;
362  }
363
364  // The shuffle is concatenating multiple vectors together.
365  // Collect the different operands for that.
366  Register UndefReg;
367  Register Src2 = MI.getOperand(2).getReg();
368  for (auto Src : ConcatSrcs) {
369    if (Src < 0) {
370      if (!UndefReg) {
371        Builder.setInsertPt(*MI.getParent(), MI);
372        UndefReg = Builder.buildUndef(SrcType).getReg(0);
373      }
374      Ops.push_back(UndefReg);
375    } else if (Src == 0)
376      Ops.push_back(Src1);
377    else
378      Ops.push_back(Src2);
379  }
380  return true;
381}
382
383void CombinerHelper::applyCombineShuffleVector(MachineInstr &MI,
384                                               const ArrayRef<Register> Ops) {
385  Register DstReg = MI.getOperand(0).getReg();
386  Builder.setInsertPt(*MI.getParent(), MI);
387  Register NewDstReg = MRI.cloneVirtualRegister(DstReg);
388
389  if (Ops.size() == 1)
390    Builder.buildCopy(NewDstReg, Ops[0]);
391  else
392    Builder.buildMergeLikeInstr(NewDstReg, Ops);
393
394  MI.eraseFromParent();
395  replaceRegWith(MRI, DstReg, NewDstReg);
396}
397
398namespace {
399
400/// Select a preference between two uses. CurrentUse is the current preference
/// while the *ForCandidate arguments describe the candidate under
/// consideration.
402PreferredTuple ChoosePreferredUse(PreferredTuple &CurrentUse,
403                                  const LLT TyForCandidate,
404                                  unsigned OpcodeForCandidate,
405                                  MachineInstr *MIForCandidate) {
406  if (!CurrentUse.Ty.isValid()) {
407    if (CurrentUse.ExtendOpcode == OpcodeForCandidate ||
408        CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT)
409      return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
410    return CurrentUse;
411  }
412
413  // We permit the extend to hoist through basic blocks but this is only
414  // sensible if the target has extending loads. If you end up lowering back
415  // into a load and extend during the legalizer then the end result is
416  // hoisting the extend up to the load.
417
418  // Prefer defined extensions to undefined extensions as these are more
419  // likely to reduce the number of instructions.
420  if (OpcodeForCandidate == TargetOpcode::G_ANYEXT &&
421      CurrentUse.ExtendOpcode != TargetOpcode::G_ANYEXT)
422    return CurrentUse;
423  else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT &&
424           OpcodeForCandidate != TargetOpcode::G_ANYEXT)
425    return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
426
427  // Prefer sign extensions to zero extensions as sign-extensions tend to be
428  // more expensive.
429  if (CurrentUse.Ty == TyForCandidate) {
430    if (CurrentUse.ExtendOpcode == TargetOpcode::G_SEXT &&
431        OpcodeForCandidate == TargetOpcode::G_ZEXT)
432      return CurrentUse;
433    else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ZEXT &&
434             OpcodeForCandidate == TargetOpcode::G_SEXT)
435      return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
436  }
437
438  // This is potentially target specific. We've chosen the largest type
439  // because G_TRUNC is usually free. One potential catch with this is that
440  // some targets have a reduced number of larger registers than smaller
441  // registers and this choice potentially increases the live-range for the
442  // larger value.
443  if (TyForCandidate.getSizeInBits() > CurrentUse.Ty.getSizeInBits()) {
444    return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
445  }
446  return CurrentUse;
447}
448
449/// Find a suitable place to insert some instructions and insert them. This
450/// function accounts for special cases like inserting before a PHI node.
451/// The current strategy for inserting before PHI's is to duplicate the
452/// instructions for each predecessor. However, while that's ok for G_TRUNC
453/// on most targets since it generally requires no code, other targets/cases may
454/// want to try harder to find a dominating block.
455static void InsertInsnsWithoutSideEffectsBeforeUse(
456    MachineIRBuilder &Builder, MachineInstr &DefMI, MachineOperand &UseMO,
457    std::function<void(MachineBasicBlock *, MachineBasicBlock::iterator,
458                       MachineOperand &UseMO)>
459        Inserter) {
460  MachineInstr &UseMI = *UseMO.getParent();
461
462  MachineBasicBlock *InsertBB = UseMI.getParent();
463
464  // If the use is a PHI then we want the predecessor block instead.
465  if (UseMI.isPHI()) {
466    MachineOperand *PredBB = std::next(&UseMO);
467    InsertBB = PredBB->getMBB();
468  }
469
470  // If the block is the same block as the def then we want to insert just after
471  // the def instead of at the start of the block.
472  if (InsertBB == DefMI.getParent()) {
473    MachineBasicBlock::iterator InsertPt = &DefMI;
474    Inserter(InsertBB, std::next(InsertPt), UseMO);
475    return;
476  }
477
478  // Otherwise we want the start of the BB
479  Inserter(InsertBB, InsertBB->getFirstNonPHI(), UseMO);
480}
481} // end anonymous namespace
482
483bool CombinerHelper::tryCombineExtendingLoads(MachineInstr &MI) {
484  PreferredTuple Preferred;
485  if (matchCombineExtendingLoads(MI, Preferred)) {
486    applyCombineExtendingLoads(MI, Preferred);
487    return true;
488  }
489  return false;
490}
491
492static unsigned getExtLoadOpcForExtend(unsigned ExtOpc) {
493  unsigned CandidateLoadOpc;
494  switch (ExtOpc) {
495  case TargetOpcode::G_ANYEXT:
496    CandidateLoadOpc = TargetOpcode::G_LOAD;
497    break;
498  case TargetOpcode::G_SEXT:
499    CandidateLoadOpc = TargetOpcode::G_SEXTLOAD;
500    break;
501  case TargetOpcode::G_ZEXT:
502    CandidateLoadOpc = TargetOpcode::G_ZEXTLOAD;
503    break;
504  default:
505    llvm_unreachable("Unexpected extend opc");
506  }
507  return CandidateLoadOpc;
508}
509
510bool CombinerHelper::matchCombineExtendingLoads(MachineInstr &MI,
511                                                PreferredTuple &Preferred) {
512  // We match the loads and follow the uses to the extend instead of matching
513  // the extends and following the def to the load. This is because the load
514  // must remain in the same position for correctness (unless we also add code
515  // to find a safe place to sink it) whereas the extend is freely movable.
516  // It also prevents us from duplicating the load for the volatile case or just
517  // for performance.
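  // Overall, the combine turns e.g. (illustrative MIR, not from a real test):
  //   %v:_(s16) = G_LOAD %ptr, (load s16)
  //   %e:_(s32) = G_SEXT %v(s16)
  // into:
  //   %e:_(s32) = G_SEXTLOAD %ptr, (load s16)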
518  GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(&MI);
519  if (!LoadMI)
520    return false;
521
522  Register LoadReg = LoadMI->getDstReg();
523
524  LLT LoadValueTy = MRI.getType(LoadReg);
525  if (!LoadValueTy.isScalar())
526    return false;
527
528  // Most architectures are going to legalize <s8 loads into at least a 1 byte
529  // load, and the MMOs can only describe memory accesses in multiples of bytes.
530  // If we try to perform extload combining on those, we can end up with
531  // %a(s8) = extload %ptr (load 1 byte from %ptr)
532  // ... which is an illegal extload instruction.
533  if (LoadValueTy.getSizeInBits() < 8)
534    return false;
535
536  // For non power-of-2 types, they will very likely be legalized into multiple
537  // loads. Don't bother trying to match them into extending loads.
538  if (!isPowerOf2_32(LoadValueTy.getSizeInBits()))
539    return false;
540
541  // Find the preferred type aside from the any-extends (unless it's the only
  // one) and non-extending ops. We'll emit an extending load to that type and
  // emit a variant of (extend (trunc X)) for the others according to the
544  // relative type sizes. At the same time, pick an extend to use based on the
545  // extend involved in the chosen type.
546  unsigned PreferredOpcode =
547      isa<GLoad>(&MI)
548          ? TargetOpcode::G_ANYEXT
549          : isa<GSExtLoad>(&MI) ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
550  Preferred = {LLT(), PreferredOpcode, nullptr};
551  for (auto &UseMI : MRI.use_nodbg_instructions(LoadReg)) {
552    if (UseMI.getOpcode() == TargetOpcode::G_SEXT ||
553        UseMI.getOpcode() == TargetOpcode::G_ZEXT ||
554        (UseMI.getOpcode() == TargetOpcode::G_ANYEXT)) {
555      const auto &MMO = LoadMI->getMMO();
556      // For atomics, only form anyextending loads.
557      if (MMO.isAtomic() && UseMI.getOpcode() != TargetOpcode::G_ANYEXT)
558        continue;
559      // Check for legality.
560      if (!isPreLegalize()) {
561        LegalityQuery::MemDesc MMDesc(MMO);
562        unsigned CandidateLoadOpc = getExtLoadOpcForExtend(UseMI.getOpcode());
563        LLT UseTy = MRI.getType(UseMI.getOperand(0).getReg());
564        LLT SrcTy = MRI.getType(LoadMI->getPointerReg());
565        if (LI->getAction({CandidateLoadOpc, {UseTy, SrcTy}, {MMDesc}})
566                .Action != LegalizeActions::Legal)
567          continue;
568      }
569      Preferred = ChoosePreferredUse(Preferred,
570                                     MRI.getType(UseMI.getOperand(0).getReg()),
571                                     UseMI.getOpcode(), &UseMI);
572    }
573  }
574
575  // There were no extends
576  if (!Preferred.MI)
577    return false;
  // It should be impossible to choose an extend without selecting a different
579  // type since by definition the result of an extend is larger.
580  assert(Preferred.Ty != LoadValueTy && "Extending to same type?");
581
582  LLVM_DEBUG(dbgs() << "Preferred use is: " << *Preferred.MI);
583  return true;
584}
585
586void CombinerHelper::applyCombineExtendingLoads(MachineInstr &MI,
587                                                PreferredTuple &Preferred) {
588  // Rewrite the load to the chosen extending load.
589  Register ChosenDstReg = Preferred.MI->getOperand(0).getReg();
590
591  // Inserter to insert a truncate back to the original type at a given point
592  // with some basic CSE to limit truncate duplication to one per BB.
593  DenseMap<MachineBasicBlock *, MachineInstr *> EmittedInsns;
594  auto InsertTruncAt = [&](MachineBasicBlock *InsertIntoBB,
595                           MachineBasicBlock::iterator InsertBefore,
596                           MachineOperand &UseMO) {
597    MachineInstr *PreviouslyEmitted = EmittedInsns.lookup(InsertIntoBB);
598    if (PreviouslyEmitted) {
599      Observer.changingInstr(*UseMO.getParent());
600      UseMO.setReg(PreviouslyEmitted->getOperand(0).getReg());
601      Observer.changedInstr(*UseMO.getParent());
602      return;
603    }
604
605    Builder.setInsertPt(*InsertIntoBB, InsertBefore);
606    Register NewDstReg = MRI.cloneVirtualRegister(MI.getOperand(0).getReg());
607    MachineInstr *NewMI = Builder.buildTrunc(NewDstReg, ChosenDstReg);
608    EmittedInsns[InsertIntoBB] = NewMI;
609    replaceRegOpWith(MRI, UseMO, NewDstReg);
610  };
611
612  Observer.changingInstr(MI);
613  unsigned LoadOpc = getExtLoadOpcForExtend(Preferred.ExtendOpcode);
614  MI.setDesc(Builder.getTII().get(LoadOpc));
615
616  // Rewrite all the uses to fix up the types.
617  auto &LoadValue = MI.getOperand(0);
618  SmallVector<MachineOperand *, 4> Uses;
619  for (auto &UseMO : MRI.use_operands(LoadValue.getReg()))
620    Uses.push_back(&UseMO);
621
622  for (auto *UseMO : Uses) {
623    MachineInstr *UseMI = UseMO->getParent();
624
625    // If the extend is compatible with the preferred extend then we should fix
626    // up the type and extend so that it uses the preferred use.
627    if (UseMI->getOpcode() == Preferred.ExtendOpcode ||
628        UseMI->getOpcode() == TargetOpcode::G_ANYEXT) {
629      Register UseDstReg = UseMI->getOperand(0).getReg();
630      MachineOperand &UseSrcMO = UseMI->getOperand(1);
631      const LLT UseDstTy = MRI.getType(UseDstReg);
632      if (UseDstReg != ChosenDstReg) {
633        if (Preferred.Ty == UseDstTy) {
634          // If the use has the same type as the preferred use, then merge
635          // the vregs and erase the extend. For example:
636          //    %1:_(s8) = G_LOAD ...
637          //    %2:_(s32) = G_SEXT %1(s8)
638          //    %3:_(s32) = G_ANYEXT %1(s8)
639          //    ... = ... %3(s32)
640          // rewrites to:
641          //    %2:_(s32) = G_SEXTLOAD ...
642          //    ... = ... %2(s32)
643          replaceRegWith(MRI, UseDstReg, ChosenDstReg);
644          Observer.erasingInstr(*UseMO->getParent());
645          UseMO->getParent()->eraseFromParent();
646        } else if (Preferred.Ty.getSizeInBits() < UseDstTy.getSizeInBits()) {
647          // If the preferred size is smaller, then keep the extend but extend
648          // from the result of the extending load. For example:
649          //    %1:_(s8) = G_LOAD ...
650          //    %2:_(s32) = G_SEXT %1(s8)
651          //    %3:_(s64) = G_ANYEXT %1(s8)
652          //    ... = ... %3(s64)
          // rewrites to:
654          //    %2:_(s32) = G_SEXTLOAD ...
655          //    %3:_(s64) = G_ANYEXT %2:_(s32)
656          //    ... = ... %3(s64)
657          replaceRegOpWith(MRI, UseSrcMO, ChosenDstReg);
658        } else {
659          // If the preferred size is large, then insert a truncate. For
660          // example:
661          //    %1:_(s8) = G_LOAD ...
662          //    %2:_(s64) = G_SEXT %1(s8)
663          //    %3:_(s32) = G_ZEXT %1(s8)
664          //    ... = ... %3(s32)
          // rewrites to:
          //    %2:_(s64) = G_SEXTLOAD ...
          //    %4:_(s8) = G_TRUNC %2:_(s64)
          //    %3:_(s32) = G_ZEXT %4:_(s8)
          //    ... = ... %3(s32)
670          InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO,
671                                                 InsertTruncAt);
672        }
673        continue;
674      }
675      // The use is (one of) the uses of the preferred use we chose earlier.
676      // We're going to update the load to def this value later so just erase
677      // the old extend.
678      Observer.erasingInstr(*UseMO->getParent());
679      UseMO->getParent()->eraseFromParent();
680      continue;
681    }
682
683    // The use isn't an extend. Truncate back to the type we originally loaded.
684    // This is free on many targets.
685    InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO, InsertTruncAt);
686  }
687
688  MI.getOperand(0).setReg(ChosenDstReg);
689  Observer.changedInstr(MI);
690}
691
692bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI,
693                                                 BuildFnTy &MatchInfo) {
694  assert(MI.getOpcode() == TargetOpcode::G_AND);
695
696  // If we have the following code:
697  //  %mask = G_CONSTANT 255
698  //  %ld   = G_LOAD %ptr, (load s16)
699  //  %and  = G_AND %ld, %mask
700  //
701  // Try to fold it into
702  //   %ld = G_ZEXTLOAD %ptr, (load s8)
703
704  Register Dst = MI.getOperand(0).getReg();
705  if (MRI.getType(Dst).isVector())
706    return false;
707
708  auto MaybeMask =
709      getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
710  if (!MaybeMask)
711    return false;
712
713  APInt MaskVal = MaybeMask->Value;
714
715  if (!MaskVal.isMask())
716    return false;
717
718  Register SrcReg = MI.getOperand(1).getReg();
719  // Don't use getOpcodeDef() here since intermediate instructions may have
720  // multiple users.
721  GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(MRI.getVRegDef(SrcReg));
722  if (!LoadMI || !MRI.hasOneNonDBGUse(LoadMI->getDstReg()))
723    return false;
724
725  Register LoadReg = LoadMI->getDstReg();
726  LLT RegTy = MRI.getType(LoadReg);
727  Register PtrReg = LoadMI->getPointerReg();
728  unsigned RegSize = RegTy.getSizeInBits();
729  uint64_t LoadSizeBits = LoadMI->getMemSizeInBits();
730  unsigned MaskSizeBits = MaskVal.countTrailingOnes();
731
732  // The mask may not be larger than the in-memory type, as it might cover sign
733  // extended bits
734  if (MaskSizeBits > LoadSizeBits)
735    return false;
736
737  // If the mask covers the whole destination register, there's nothing to
738  // extend
739  if (MaskSizeBits >= RegSize)
740    return false;
741
742  // Most targets cannot deal with loads of size < 8 and need to re-legalize to
743  // at least byte loads. Avoid creating such loads here
744  if (MaskSizeBits < 8 || !isPowerOf2_32(MaskSizeBits))
745    return false;
746
747  const MachineMemOperand &MMO = LoadMI->getMMO();
748  LegalityQuery::MemDesc MemDesc(MMO);
749
750  // Don't modify the memory access size if this is atomic/volatile, but we can
751  // still adjust the opcode to indicate the high bit behavior.
752  if (LoadMI->isSimple())
753    MemDesc.MemoryTy = LLT::scalar(MaskSizeBits);
754  else if (LoadSizeBits > MaskSizeBits || LoadSizeBits == RegSize)
755    return false;
756
757  // TODO: Could check if it's legal with the reduced or original memory size.
758  if (!isLegalOrBeforeLegalizer(
759          {TargetOpcode::G_ZEXTLOAD, {RegTy, MRI.getType(PtrReg)}, {MemDesc}}))
760    return false;
761
762  MatchInfo = [=](MachineIRBuilder &B) {
763    B.setInstrAndDebugLoc(*LoadMI);
764    auto &MF = B.getMF();
765    auto PtrInfo = MMO.getPointerInfo();
766    auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, MemDesc.MemoryTy);
767    B.buildLoadInstr(TargetOpcode::G_ZEXTLOAD, Dst, PtrReg, *NewMMO);
768    LoadMI->eraseFromParent();
769  };
770  return true;
771}
772
773bool CombinerHelper::isPredecessor(const MachineInstr &DefMI,
774                                   const MachineInstr &UseMI) {
775  assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
776         "shouldn't consider debug uses");
777  assert(DefMI.getParent() == UseMI.getParent());
778  if (&DefMI == &UseMI)
779    return true;
780  const MachineBasicBlock &MBB = *DefMI.getParent();
781  auto DefOrUse = find_if(MBB, [&DefMI, &UseMI](const MachineInstr &MI) {
782    return &MI == &DefMI || &MI == &UseMI;
783  });
784  if (DefOrUse == MBB.end())
785    llvm_unreachable("Block must contain both DefMI and UseMI!");
786  return &*DefOrUse == &DefMI;
787}
788
789bool CombinerHelper::dominates(const MachineInstr &DefMI,
790                               const MachineInstr &UseMI) {
791  assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
792         "shouldn't consider debug uses");
793  if (MDT)
794    return MDT->dominates(&DefMI, &UseMI);
795  else if (DefMI.getParent() != UseMI.getParent())
796    return false;
797
798  return isPredecessor(DefMI, UseMI);
799}
800
801bool CombinerHelper::matchSextTruncSextLoad(MachineInstr &MI) {
802  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
803  Register SrcReg = MI.getOperand(1).getReg();
804  Register LoadUser = SrcReg;
805
806  if (MRI.getType(SrcReg).isVector())
807    return false;
808
809  Register TruncSrc;
810  if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc))))
811    LoadUser = TruncSrc;
812
813  uint64_t SizeInBits = MI.getOperand(2).getImm();
814  // If the source is a G_SEXTLOAD from the same bit width, then we don't
815  // need any extend at all, just a truncate.
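  // For example (illustrative MIR):
  //   %ld:_(s32) = G_SEXTLOAD %ptr, (load 1)
  //   %t:_(s16) = G_TRUNC %ld(s32)
  //   %x:_(s16) = G_SEXT_INREG %t, 8
  // The loaded value was already sign-extended from 8 bits, so %x can simply
  // be replaced by a copy of %t.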
816  if (auto *LoadMI = getOpcodeDef<GSExtLoad>(LoadUser, MRI)) {
817    // If truncating more than the original extended value, abort.
818    auto LoadSizeBits = LoadMI->getMemSizeInBits();
819    if (TruncSrc && MRI.getType(TruncSrc).getSizeInBits() < LoadSizeBits)
820      return false;
821    if (LoadSizeBits == SizeInBits)
822      return true;
823  }
824  return false;
825}
826
827void CombinerHelper::applySextTruncSextLoad(MachineInstr &MI) {
828  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
829  Builder.setInstrAndDebugLoc(MI);
830  Builder.buildCopy(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
831  MI.eraseFromParent();
832}
833
834bool CombinerHelper::matchSextInRegOfLoad(
835    MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) {
836  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
837
838  Register DstReg = MI.getOperand(0).getReg();
839  LLT RegTy = MRI.getType(DstReg);
840
841  // Only supports scalars for now.
842  if (RegTy.isVector())
843    return false;
844
845  Register SrcReg = MI.getOperand(1).getReg();
846  auto *LoadDef = getOpcodeDef<GLoad>(SrcReg, MRI);
847  if (!LoadDef || !MRI.hasOneNonDBGUse(DstReg))
848    return false;
849
850  uint64_t MemBits = LoadDef->getMemSizeInBits();
851
852  // If the sign extend extends from a narrower width than the load's width,
853  // then we can narrow the load width when we combine to a G_SEXTLOAD.
854  // Avoid widening the load at all.
855  unsigned NewSizeBits = std::min((uint64_t)MI.getOperand(2).getImm(), MemBits);
856
857  // Don't generate G_SEXTLOADs with a < 1 byte width.
858  if (NewSizeBits < 8)
859    return false;
860  // Don't bother creating a non-power-2 sextload, it will likely be broken up
861  // anyway for most targets.
862  if (!isPowerOf2_32(NewSizeBits))
863    return false;
864
865  const MachineMemOperand &MMO = LoadDef->getMMO();
866  LegalityQuery::MemDesc MMDesc(MMO);
867
868  // Don't modify the memory access size if this is atomic/volatile, but we can
869  // still adjust the opcode to indicate the high bit behavior.
870  if (LoadDef->isSimple())
871    MMDesc.MemoryTy = LLT::scalar(NewSizeBits);
872  else if (MemBits > NewSizeBits || MemBits == RegTy.getSizeInBits())
873    return false;
874
875  // TODO: Could check if it's legal with the reduced or original memory size.
876  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SEXTLOAD,
877                                 {MRI.getType(LoadDef->getDstReg()),
878                                  MRI.getType(LoadDef->getPointerReg())},
879                                 {MMDesc}}))
880    return false;
881
882  MatchInfo = std::make_tuple(LoadDef->getDstReg(), NewSizeBits);
883  return true;
884}
885
886void CombinerHelper::applySextInRegOfLoad(
887    MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) {
888  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
889  Register LoadReg;
890  unsigned ScalarSizeBits;
891  std::tie(LoadReg, ScalarSizeBits) = MatchInfo;
892  GLoad *LoadDef = cast<GLoad>(MRI.getVRegDef(LoadReg));
893
894  // If we have the following:
895  // %ld = G_LOAD %ptr, (load 2)
896  // %ext = G_SEXT_INREG %ld, 8
897  //    ==>
898  // %ld = G_SEXTLOAD %ptr (load 1)
899
900  auto &MMO = LoadDef->getMMO();
901  Builder.setInstrAndDebugLoc(*LoadDef);
902  auto &MF = Builder.getMF();
903  auto PtrInfo = MMO.getPointerInfo();
904  auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, ScalarSizeBits / 8);
905  Builder.buildLoadInstr(TargetOpcode::G_SEXTLOAD, MI.getOperand(0).getReg(),
906                         LoadDef->getPointerReg(), *NewMMO);
907  MI.eraseFromParent();
908}
909
910bool CombinerHelper::findPostIndexCandidate(MachineInstr &MI, Register &Addr,
911                                            Register &Base, Register &Offset) {
912  auto &MF = *MI.getParent()->getParent();
913  const auto &TLI = *MF.getSubtarget().getTargetLowering();
914
915#ifndef NDEBUG
916  unsigned Opcode = MI.getOpcode();
917  assert(Opcode == TargetOpcode::G_LOAD || Opcode == TargetOpcode::G_SEXTLOAD ||
918         Opcode == TargetOpcode::G_ZEXTLOAD || Opcode == TargetOpcode::G_STORE);
919#endif
920
921  Base = MI.getOperand(1).getReg();
922  MachineInstr *BaseDef = MRI.getUniqueVRegDef(Base);
923  if (BaseDef && BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX)
924    return false;
925
926  LLVM_DEBUG(dbgs() << "Searching for post-indexing opportunity for: " << MI);
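  // The shape we are looking for is roughly (illustrative):
  //   %val:_(s32) = G_LOAD %base(p0)            ; MI
  //   %addr:_(p0) = G_PTR_ADD %base, %offset
  // which a target with post-indexed addressing could fold into a single
  // post-incrementing memory operation that also defines %addr.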
  // FIXME: The following use traversal needs a bail out for pathological cases.
928  for (auto &Use : MRI.use_nodbg_instructions(Base)) {
929    if (Use.getOpcode() != TargetOpcode::G_PTR_ADD)
930      continue;
931
932    Offset = Use.getOperand(2).getReg();
933    if (!ForceLegalIndexing &&
934        !TLI.isIndexingLegal(MI, Base, Offset, /*IsPre*/ false, MRI)) {
935      LLVM_DEBUG(dbgs() << "    Ignoring candidate with illegal addrmode: "
936                        << Use);
937      continue;
938    }
939
940    // Make sure the offset calculation is before the potentially indexed op.
941    // FIXME: we really care about dependency here. The offset calculation might
942    // be movable.
943    MachineInstr *OffsetDef = MRI.getUniqueVRegDef(Offset);
944    if (!OffsetDef || !dominates(*OffsetDef, MI)) {
945      LLVM_DEBUG(dbgs() << "    Ignoring candidate with offset after mem-op: "
946                        << Use);
947      continue;
948    }
949
950    // FIXME: check whether all uses of Base are load/store with foldable
951    // addressing modes. If so, using the normal addr-modes is better than
952    // forming an indexed one.
953
954    bool MemOpDominatesAddrUses = true;
955    for (auto &PtrAddUse :
956         MRI.use_nodbg_instructions(Use.getOperand(0).getReg())) {
957      if (!dominates(MI, PtrAddUse)) {
958        MemOpDominatesAddrUses = false;
959        break;
960      }
961    }
962
963    if (!MemOpDominatesAddrUses) {
964      LLVM_DEBUG(
965          dbgs() << "    Ignoring candidate as memop does not dominate uses: "
966                 << Use);
967      continue;
968    }
969
970    LLVM_DEBUG(dbgs() << "    Found match: " << Use);
971    Addr = Use.getOperand(0).getReg();
972    return true;
973  }
974
975  return false;
976}
977
978bool CombinerHelper::findPreIndexCandidate(MachineInstr &MI, Register &Addr,
979                                           Register &Base, Register &Offset) {
980  auto &MF = *MI.getParent()->getParent();
981  const auto &TLI = *MF.getSubtarget().getTargetLowering();
982
983#ifndef NDEBUG
984  unsigned Opcode = MI.getOpcode();
985  assert(Opcode == TargetOpcode::G_LOAD || Opcode == TargetOpcode::G_SEXTLOAD ||
986         Opcode == TargetOpcode::G_ZEXTLOAD || Opcode == TargetOpcode::G_STORE);
987#endif
988
989  Addr = MI.getOperand(1).getReg();
990  MachineInstr *AddrDef = getOpcodeDef(TargetOpcode::G_PTR_ADD, Addr, MRI);
991  if (!AddrDef || MRI.hasOneNonDBGUse(Addr))
992    return false;
993
994  Base = AddrDef->getOperand(1).getReg();
995  Offset = AddrDef->getOperand(2).getReg();
996
997  LLVM_DEBUG(dbgs() << "Found potential pre-indexed load_store: " << MI);
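  // That is, we have matched something like (illustrative):
  //   %addr:_(p0) = G_PTR_ADD %base, %offset
  //   %val:_(s32) = G_LOAD %addr(p0)            ; MI
  // and are now checking whether it can become a pre-indexed access that also
  // defines %addr.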
998
999  if (!ForceLegalIndexing &&
1000      !TLI.isIndexingLegal(MI, Base, Offset, /*IsPre*/ true, MRI)) {
1001    LLVM_DEBUG(dbgs() << "    Skipping, not legal for target");
1002    return false;
1003  }
1004
1005  MachineInstr *BaseDef = getDefIgnoringCopies(Base, MRI);
1006  if (BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX) {
1007    LLVM_DEBUG(dbgs() << "    Skipping, frame index would need copy anyway.");
1008    return false;
1009  }
1010
1011  if (MI.getOpcode() == TargetOpcode::G_STORE) {
1012    // Would require a copy.
1013    if (Base == MI.getOperand(0).getReg()) {
1014      LLVM_DEBUG(dbgs() << "    Skipping, storing base so need copy anyway.");
1015      return false;
1016    }
1017
1018    // We're expecting one use of Addr in MI, but it could also be the
1019    // value stored, which isn't actually dominated by the instruction.
1020    if (MI.getOperand(0).getReg() == Addr) {
1021      LLVM_DEBUG(dbgs() << "    Skipping, does not dominate all addr uses");
1022      return false;
1023    }
1024  }
1025
1026  // FIXME: check whether all uses of the base pointer are constant PtrAdds.
1027  // That might allow us to end base's liveness here by adjusting the constant.
1028
1029  for (auto &UseMI : MRI.use_nodbg_instructions(Addr)) {
1030    if (!dominates(MI, UseMI)) {
1031      LLVM_DEBUG(dbgs() << "    Skipping, does not dominate all addr uses.");
1032      return false;
1033    }
1034  }
1035
1036  return true;
1037}
1038
1039bool CombinerHelper::tryCombineIndexedLoadStore(MachineInstr &MI) {
1040  IndexedLoadStoreMatchInfo MatchInfo;
1041  if (matchCombineIndexedLoadStore(MI, MatchInfo)) {
1042    applyCombineIndexedLoadStore(MI, MatchInfo);
1043    return true;
1044  }
1045  return false;
1046}
1047
bool CombinerHelper::matchCombineIndexedLoadStore(
    MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) {
1049  unsigned Opcode = MI.getOpcode();
1050  if (Opcode != TargetOpcode::G_LOAD && Opcode != TargetOpcode::G_SEXTLOAD &&
1051      Opcode != TargetOpcode::G_ZEXTLOAD && Opcode != TargetOpcode::G_STORE)
1052    return false;
1053
1054  // For now, no targets actually support these opcodes so don't waste time
1055  // running these unless we're forced to for testing.
1056  if (!ForceLegalIndexing)
1057    return false;
1058
1059  MatchInfo.IsPre = findPreIndexCandidate(MI, MatchInfo.Addr, MatchInfo.Base,
1060                                          MatchInfo.Offset);
1061  if (!MatchInfo.IsPre &&
1062      !findPostIndexCandidate(MI, MatchInfo.Addr, MatchInfo.Base,
1063                              MatchInfo.Offset))
1064    return false;
1065
1066  return true;
1067}
1068
1069void CombinerHelper::applyCombineIndexedLoadStore(
1070    MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) {
1071  MachineInstr &AddrDef = *MRI.getUniqueVRegDef(MatchInfo.Addr);
1072  MachineIRBuilder MIRBuilder(MI);
1073  unsigned Opcode = MI.getOpcode();
1074  bool IsStore = Opcode == TargetOpcode::G_STORE;
1075  unsigned NewOpcode;
1076  switch (Opcode) {
1077  case TargetOpcode::G_LOAD:
1078    NewOpcode = TargetOpcode::G_INDEXED_LOAD;
1079    break;
1080  case TargetOpcode::G_SEXTLOAD:
1081    NewOpcode = TargetOpcode::G_INDEXED_SEXTLOAD;
1082    break;
1083  case TargetOpcode::G_ZEXTLOAD:
1084    NewOpcode = TargetOpcode::G_INDEXED_ZEXTLOAD;
1085    break;
1086  case TargetOpcode::G_STORE:
1087    NewOpcode = TargetOpcode::G_INDEXED_STORE;
1088    break;
1089  default:
1090    llvm_unreachable("Unknown load/store opcode");
1091  }
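  // The instruction built below has the writeback address as an extra def
  // (sketched for illustration; pre-indexed forms shown with the trailing 1):
  //   %val:_(sN), %addr:_(p0) = G_INDEXED_LOAD %base, %offset, 1
  //   %addr:_(p0) = G_INDEXED_STORE %val, %base, %offset, 1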
1092
1093  auto MIB = MIRBuilder.buildInstr(NewOpcode);
1094  if (IsStore) {
1095    MIB.addDef(MatchInfo.Addr);
1096    MIB.addUse(MI.getOperand(0).getReg());
1097  } else {
1098    MIB.addDef(MI.getOperand(0).getReg());
1099    MIB.addDef(MatchInfo.Addr);
1100  }
1101
1102  MIB.addUse(MatchInfo.Base);
1103  MIB.addUse(MatchInfo.Offset);
1104  MIB.addImm(MatchInfo.IsPre);
1105  MI.eraseFromParent();
1106  AddrDef.eraseFromParent();
1107
  LLVM_DEBUG(dbgs() << "    Combined to indexed operation");
1109}
1110
1111bool CombinerHelper::matchCombineDivRem(MachineInstr &MI,
1112                                        MachineInstr *&OtherMI) {
1113  unsigned Opcode = MI.getOpcode();
1114  bool IsDiv, IsSigned;
1115
1116  switch (Opcode) {
1117  default:
1118    llvm_unreachable("Unexpected opcode!");
1119  case TargetOpcode::G_SDIV:
1120  case TargetOpcode::G_UDIV: {
1121    IsDiv = true;
1122    IsSigned = Opcode == TargetOpcode::G_SDIV;
1123    break;
1124  }
1125  case TargetOpcode::G_SREM:
1126  case TargetOpcode::G_UREM: {
1127    IsDiv = false;
1128    IsSigned = Opcode == TargetOpcode::G_SREM;
1129    break;
1130  }
1131  }
1132
1133  Register Src1 = MI.getOperand(1).getReg();
1134  unsigned DivOpcode, RemOpcode, DivremOpcode;
1135  if (IsSigned) {
1136    DivOpcode = TargetOpcode::G_SDIV;
1137    RemOpcode = TargetOpcode::G_SREM;
1138    DivremOpcode = TargetOpcode::G_SDIVREM;
1139  } else {
1140    DivOpcode = TargetOpcode::G_UDIV;
1141    RemOpcode = TargetOpcode::G_UREM;
1142    DivremOpcode = TargetOpcode::G_UDIVREM;
1143  }
1144
1145  if (!isLegalOrBeforeLegalizer({DivremOpcode, {MRI.getType(Src1)}}))
1146    return false;
1147
1148  // Combine:
1149  //   %div:_ = G_[SU]DIV %src1:_, %src2:_
1150  //   %rem:_ = G_[SU]REM %src1:_, %src2:_
1151  // into:
1152  //  %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_
1153
1154  // Combine:
1155  //   %rem:_ = G_[SU]REM %src1:_, %src2:_
1156  //   %div:_ = G_[SU]DIV %src1:_, %src2:_
1157  // into:
1158  //  %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_
1159
1160  for (auto &UseMI : MRI.use_nodbg_instructions(Src1)) {
1161    if (MI.getParent() == UseMI.getParent() &&
1162        ((IsDiv && UseMI.getOpcode() == RemOpcode) ||
1163         (!IsDiv && UseMI.getOpcode() == DivOpcode)) &&
1164        matchEqualDefs(MI.getOperand(2), UseMI.getOperand(2)) &&
1165        matchEqualDefs(MI.getOperand(1), UseMI.getOperand(1))) {
1166      OtherMI = &UseMI;
1167      return true;
1168    }
1169  }
1170
1171  return false;
1172}
1173
1174void CombinerHelper::applyCombineDivRem(MachineInstr &MI,
1175                                        MachineInstr *&OtherMI) {
1176  unsigned Opcode = MI.getOpcode();
1177  assert(OtherMI && "OtherMI shouldn't be empty.");
1178
1179  Register DestDivReg, DestRemReg;
1180  if (Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_UDIV) {
1181    DestDivReg = MI.getOperand(0).getReg();
1182    DestRemReg = OtherMI->getOperand(0).getReg();
1183  } else {
1184    DestDivReg = OtherMI->getOperand(0).getReg();
1185    DestRemReg = MI.getOperand(0).getReg();
1186  }
1187
1188  bool IsSigned =
1189      Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM;
1190
1191  // Check which instruction is first in the block so we don't break def-use
1192  // deps by "moving" the instruction incorrectly.
1193  if (dominates(MI, *OtherMI))
1194    Builder.setInstrAndDebugLoc(MI);
1195  else
1196    Builder.setInstrAndDebugLoc(*OtherMI);
1197
1198  Builder.buildInstr(IsSigned ? TargetOpcode::G_SDIVREM
1199                              : TargetOpcode::G_UDIVREM,
1200                     {DestDivReg, DestRemReg},
1201                     {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
1202  MI.eraseFromParent();
1203  OtherMI->eraseFromParent();
1204}
1205
1206bool CombinerHelper::matchOptBrCondByInvertingCond(MachineInstr &MI,
1207                                                   MachineInstr *&BrCond) {
1208  assert(MI.getOpcode() == TargetOpcode::G_BR);
1209
1210  // Try to match the following:
1211  // bb1:
1212  //   G_BRCOND %c1, %bb2
1213  //   G_BR %bb3
1214  // bb2:
1215  // ...
1216  // bb3:
1217
1218  // The above pattern does not have a fall through to the successor bb2, always
1219  // resulting in a branch no matter which path is taken. Here we try to find
  // and replace that pattern with a conditional branch to bb3, otherwise
  // falling through to bb2. This is generally better for branch predictors.
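  // Concretely (illustrative), the apply step rewrites:
  //   G_BRCOND %c1, %bb2
  //   G_BR %bb3
  // into:
  //   %inv = G_XOR %c1, <true>
  //   G_BRCOND %inv, %bb3
  //   G_BR %bb2
  // where <true> is the target's icmp-true constant and the trailing G_BR now
  // targets the layout successor, so it can be removed later.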
1222
1223  MachineBasicBlock *MBB = MI.getParent();
1224  MachineBasicBlock::iterator BrIt(MI);
1225  if (BrIt == MBB->begin())
1226    return false;
1227  assert(std::next(BrIt) == MBB->end() && "expected G_BR to be a terminator");
1228
1229  BrCond = &*std::prev(BrIt);
1230  if (BrCond->getOpcode() != TargetOpcode::G_BRCOND)
1231    return false;
1232
1233  // Check that the next block is the conditional branch target. Also make sure
1234  // that it isn't the same as the G_BR's target (otherwise, this will loop.)
1235  MachineBasicBlock *BrCondTarget = BrCond->getOperand(1).getMBB();
1236  return BrCondTarget != MI.getOperand(0).getMBB() &&
1237         MBB->isLayoutSuccessor(BrCondTarget);
1238}
1239
1240void CombinerHelper::applyOptBrCondByInvertingCond(MachineInstr &MI,
1241                                                   MachineInstr *&BrCond) {
1242  MachineBasicBlock *BrTarget = MI.getOperand(0).getMBB();
1243  Builder.setInstrAndDebugLoc(*BrCond);
1244  LLT Ty = MRI.getType(BrCond->getOperand(0).getReg());
1245  // FIXME: Does int/fp matter for this? If so, we might need to restrict
1246  // this to i1 only since we might not know for sure what kind of
1247  // compare generated the condition value.
1248  auto True = Builder.buildConstant(
1249      Ty, getICmpTrueVal(getTargetLowering(), false, false));
1250  auto Xor = Builder.buildXor(Ty, BrCond->getOperand(0), True);
1251
1252  auto *FallthroughBB = BrCond->getOperand(1).getMBB();
1253  Observer.changingInstr(MI);
1254  MI.getOperand(0).setMBB(FallthroughBB);
1255  Observer.changedInstr(MI);
1256
1257  // Change the conditional branch to use the inverted condition and
1258  // new target block.
1259  Observer.changingInstr(*BrCond);
1260  BrCond->getOperand(0).setReg(Xor.getReg(0));
1261  BrCond->getOperand(1).setMBB(BrTarget);
1262  Observer.changedInstr(*BrCond);
1263}
1264
1265static Type *getTypeForLLT(LLT Ty, LLVMContext &C) {
1266  if (Ty.isVector())
1267    return FixedVectorType::get(IntegerType::get(C, Ty.getScalarSizeInBits()),
1268                                Ty.getNumElements());
1269  return IntegerType::get(C, Ty.getSizeInBits());
1270}
1271
1272bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) {
1273  MachineIRBuilder HelperBuilder(MI);
1274  GISelObserverWrapper DummyObserver;
1275  LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1276  return Helper.lowerMemcpyInline(MI) ==
1277         LegalizerHelper::LegalizeResult::Legalized;
1278}
1279
1280bool CombinerHelper::tryCombineMemCpyFamily(MachineInstr &MI, unsigned MaxLen) {
1281  MachineIRBuilder HelperBuilder(MI);
1282  GISelObserverWrapper DummyObserver;
1283  LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1284  return Helper.lowerMemCpyFamily(MI, MaxLen) ==
1285         LegalizerHelper::LegalizeResult::Legalized;
1286}
1287
1288static std::optional<APFloat>
1289constantFoldFpUnary(unsigned Opcode, LLT DstTy, const Register Op,
1290                    const MachineRegisterInfo &MRI) {
1291  const ConstantFP *MaybeCst = getConstantFPVRegVal(Op, MRI);
1292  if (!MaybeCst)
1293    return std::nullopt;
1294
1295  APFloat V = MaybeCst->getValueAPF();
1296  switch (Opcode) {
1297  default:
1298    llvm_unreachable("Unexpected opcode!");
1299  case TargetOpcode::G_FNEG: {
1300    V.changeSign();
1301    return V;
1302  }
1303  case TargetOpcode::G_FABS: {
1304    V.clearSign();
1305    return V;
1306  }
1307  case TargetOpcode::G_FPTRUNC:
1308    break;
1309  case TargetOpcode::G_FSQRT: {
1310    bool Unused;
1311    V.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &Unused);
1312    V = APFloat(sqrt(V.convertToDouble()));
1313    break;
1314  }
1315  case TargetOpcode::G_FLOG2: {
1316    bool Unused;
1317    V.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &Unused);
1318    V = APFloat(log2(V.convertToDouble()));
1319    break;
1320  }
1321  }
1322  // Convert `APFloat` to appropriate IEEE type depending on `DstTy`. Otherwise,
1323  // `buildFConstant` will assert on size mismatch. Only `G_FPTRUNC`, `G_FSQRT`,
1324  // and `G_FLOG2` reach here.
1325  bool Unused;
1326  V.convert(getFltSemanticForLLT(DstTy), APFloat::rmNearestTiesToEven, &Unused);
1327  return V;
1328}
1329
1330bool CombinerHelper::matchCombineConstantFoldFpUnary(
1331    MachineInstr &MI, std::optional<APFloat> &Cst) {
1332  Register DstReg = MI.getOperand(0).getReg();
1333  Register SrcReg = MI.getOperand(1).getReg();
1334  LLT DstTy = MRI.getType(DstReg);
1335  Cst = constantFoldFpUnary(MI.getOpcode(), DstTy, SrcReg, MRI);
1336  return Cst.has_value();
1337}
1338
1339void CombinerHelper::applyCombineConstantFoldFpUnary(
1340    MachineInstr &MI, std::optional<APFloat> &Cst) {
1341  assert(Cst && "Optional is unexpectedly empty!");
1342  Builder.setInstrAndDebugLoc(MI);
1343  MachineFunction &MF = Builder.getMF();
1344  auto *FPVal = ConstantFP::get(MF.getFunction().getContext(), *Cst);
1345  Register DstReg = MI.getOperand(0).getReg();
1346  Builder.buildFConstant(DstReg, *FPVal);
1347  MI.eraseFromParent();
1348}
1349
1350bool CombinerHelper::matchPtrAddImmedChain(MachineInstr &MI,
1351                                           PtrAddChain &MatchInfo) {
1352  // We're trying to match the following pattern:
1353  //   %t1 = G_PTR_ADD %base, G_CONSTANT imm1
1354  //   %root = G_PTR_ADD %t1, G_CONSTANT imm2
1355  // -->
1356  //   %root = G_PTR_ADD %base, G_CONSTANT (imm1 + imm2)
1357
1358  if (MI.getOpcode() != TargetOpcode::G_PTR_ADD)
1359    return false;
1360
1361  Register Add2 = MI.getOperand(1).getReg();
1362  Register Imm1 = MI.getOperand(2).getReg();
1363  auto MaybeImmVal = getIConstantVRegValWithLookThrough(Imm1, MRI);
1364  if (!MaybeImmVal)
1365    return false;
1366
1367  MachineInstr *Add2Def = MRI.getVRegDef(Add2);
1368  if (!Add2Def || Add2Def->getOpcode() != TargetOpcode::G_PTR_ADD)
1369    return false;
1370
1371  Register Base = Add2Def->getOperand(1).getReg();
1372  Register Imm2 = Add2Def->getOperand(2).getReg();
1373  auto MaybeImm2Val = getIConstantVRegValWithLookThrough(Imm2, MRI);
1374  if (!MaybeImm2Val)
1375    return false;
1376
1377  // Check if the new combined immediate forms an illegal addressing mode.
  // Do not combine if it was legal before but would become illegal.
1379  // To do so, we need to find a load/store user of the pointer to get
1380  // the access type.
1381  Type *AccessTy = nullptr;
1382  auto &MF = *MI.getMF();
1383  for (auto &UseMI : MRI.use_nodbg_instructions(MI.getOperand(0).getReg())) {
1384    if (auto *LdSt = dyn_cast<GLoadStore>(&UseMI)) {
1385      AccessTy = getTypeForLLT(MRI.getType(LdSt->getReg(0)),
1386                               MF.getFunction().getContext());
1387      break;
1388    }
1389  }
1390  TargetLoweringBase::AddrMode AMNew;
1391  APInt CombinedImm = MaybeImmVal->Value + MaybeImm2Val->Value;
1392  AMNew.BaseOffs = CombinedImm.getSExtValue();
1393  if (AccessTy) {
1394    AMNew.HasBaseReg = true;
1395    TargetLoweringBase::AddrMode AMOld;
1396    AMOld.BaseOffs = MaybeImm2Val->Value.getSExtValue();
1397    AMOld.HasBaseReg = true;
1398    unsigned AS = MRI.getType(Add2).getAddressSpace();
1399    const auto &TLI = *MF.getSubtarget().getTargetLowering();
1400    if (TLI.isLegalAddressingMode(MF.getDataLayout(), AMOld, AccessTy, AS) &&
1401        !TLI.isLegalAddressingMode(MF.getDataLayout(), AMNew, AccessTy, AS))
1402      return false;
1403  }
1404
1405  // Pass the combined immediate to the apply function.
1406  MatchInfo.Imm = AMNew.BaseOffs;
1407  MatchInfo.Base = Base;
1408  MatchInfo.Bank = getRegBank(Imm2);
1409  return true;
1410}
1411
1412void CombinerHelper::applyPtrAddImmedChain(MachineInstr &MI,
1413                                           PtrAddChain &MatchInfo) {
1414  assert(MI.getOpcode() == TargetOpcode::G_PTR_ADD && "Expected G_PTR_ADD");
1415  MachineIRBuilder MIB(MI);
1416  LLT OffsetTy = MRI.getType(MI.getOperand(2).getReg());
1417  auto NewOffset = MIB.buildConstant(OffsetTy, MatchInfo.Imm);
1418  setRegBank(NewOffset.getReg(0), MatchInfo.Bank);
1419  Observer.changingInstr(MI);
1420  MI.getOperand(1).setReg(MatchInfo.Base);
1421  MI.getOperand(2).setReg(NewOffset.getReg(0));
1422  Observer.changedInstr(MI);
1423}
1424
1425bool CombinerHelper::matchShiftImmedChain(MachineInstr &MI,
1426                                          RegisterImmPair &MatchInfo) {
1427  // We're trying to match the following pattern with any of
1428  // G_SHL/G_ASHR/G_LSHR/G_SSHLSAT/G_USHLSAT shift instructions:
1429  //   %t1 = SHIFT %base, G_CONSTANT imm1
1430  //   %root = SHIFT %t1, G_CONSTANT imm2
1431  // -->
1432  //   %root = SHIFT %base, G_CONSTANT (imm1 + imm2)
1433
1434  unsigned Opcode = MI.getOpcode();
1435  assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1436          Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1437          Opcode == TargetOpcode::G_USHLSAT) &&
1438         "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1439
1440  Register Shl2 = MI.getOperand(1).getReg();
1441  Register Imm1 = MI.getOperand(2).getReg();
1442  auto MaybeImmVal = getIConstantVRegValWithLookThrough(Imm1, MRI);
1443  if (!MaybeImmVal)
1444    return false;
1445
1446  MachineInstr *Shl2Def = MRI.getUniqueVRegDef(Shl2);
1447  if (Shl2Def->getOpcode() != Opcode)
1448    return false;
1449
1450  Register Base = Shl2Def->getOperand(1).getReg();
1451  Register Imm2 = Shl2Def->getOperand(2).getReg();
1452  auto MaybeImm2Val = getIConstantVRegValWithLookThrough(Imm2, MRI);
1453  if (!MaybeImm2Val)
1454    return false;
1455
1456  // Pass the combined immediate to the apply function.
1457  MatchInfo.Imm =
1458      (MaybeImmVal->Value.getSExtValue() + MaybeImm2Val->Value).getSExtValue();
1459  MatchInfo.Reg = Base;
1460
1461  // There is no simple replacement for a saturating unsigned left shift that
1462  // exceeds the scalar size.
1463  if (Opcode == TargetOpcode::G_USHLSAT &&
1464      MatchInfo.Imm >= MRI.getType(Shl2).getScalarSizeInBits())
1465    return false;
1466
1467  return true;
1468}
1469
1470void CombinerHelper::applyShiftImmedChain(MachineInstr &MI,
1471                                          RegisterImmPair &MatchInfo) {
1472  unsigned Opcode = MI.getOpcode();
1473  assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1474          Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1475          Opcode == TargetOpcode::G_USHLSAT) &&
1476         "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1477
1478  Builder.setInstrAndDebugLoc(MI);
1479  LLT Ty = MRI.getType(MI.getOperand(1).getReg());
1480  unsigned const ScalarSizeInBits = Ty.getScalarSizeInBits();
1481  auto Imm = MatchInfo.Imm;
1482
1483  if (Imm >= ScalarSizeInBits) {
1484    // Any logical shift that exceeds scalar size will produce zero.
1485    if (Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR) {
1486      Builder.buildConstant(MI.getOperand(0), 0);
1487      MI.eraseFromParent();
1488      return;
1489    }
1490    // Arithmetic shift and saturating signed left shift have no effect beyond
1491    // scalar size.
1492    Imm = ScalarSizeInBits - 1;
1493  }
1494
1495  LLT ImmTy = MRI.getType(MI.getOperand(2).getReg());
1496  Register NewImm = Builder.buildConstant(ImmTy, Imm).getReg(0);
1497  Observer.changingInstr(MI);
1498  MI.getOperand(1).setReg(MatchInfo.Reg);
1499  MI.getOperand(2).setReg(NewImm);
1500  Observer.changedInstr(MI);
1501}
1502
1503bool CombinerHelper::matchShiftOfShiftedLogic(MachineInstr &MI,
1504                                              ShiftOfShiftedLogic &MatchInfo) {
1505  // We're trying to match the following pattern with any of
1506  // G_SHL/G_ASHR/G_LSHR/G_USHLSAT/G_SSHLSAT shift instructions in combination
1507  // with any of G_AND/G_OR/G_XOR logic instructions.
1508  //   %t1 = SHIFT %X, G_CONSTANT C0
1509  //   %t2 = LOGIC %t1, %Y
1510  //   %root = SHIFT %t2, G_CONSTANT C1
1511  // -->
1512  //   %t3 = SHIFT %X, G_CONSTANT (C0+C1)
1513  //   %t4 = SHIFT %Y, G_CONSTANT C1
1514  //   %root = LOGIC %t3, %t4
1515  unsigned ShiftOpcode = MI.getOpcode();
1516  assert((ShiftOpcode == TargetOpcode::G_SHL ||
1517          ShiftOpcode == TargetOpcode::G_ASHR ||
1518          ShiftOpcode == TargetOpcode::G_LSHR ||
1519          ShiftOpcode == TargetOpcode::G_USHLSAT ||
1520          ShiftOpcode == TargetOpcode::G_SSHLSAT) &&
1521         "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT or G_SSHLSAT");
1522
1523  // Match a one-use bitwise logic op.
1524  Register LogicDest = MI.getOperand(1).getReg();
1525  if (!MRI.hasOneNonDBGUse(LogicDest))
1526    return false;
1527
1528  MachineInstr *LogicMI = MRI.getUniqueVRegDef(LogicDest);
1529  unsigned LogicOpcode = LogicMI->getOpcode();
1530  if (LogicOpcode != TargetOpcode::G_AND && LogicOpcode != TargetOpcode::G_OR &&
1531      LogicOpcode != TargetOpcode::G_XOR)
1532    return false;
1533
1534  // Find a matching one-use shift by constant.
1535  const Register C1 = MI.getOperand(2).getReg();
1536  auto MaybeImmVal = getIConstantVRegValWithLookThrough(C1, MRI);
1537  if (!MaybeImmVal)
1538    return false;
1539
1540  const uint64_t C1Val = MaybeImmVal->Value.getZExtValue();
1541
1542  auto matchFirstShift = [&](const MachineInstr *MI, uint64_t &ShiftVal) {
1543    // Shift should match the previous one and have only one non-debug use.
1544    if (MI->getOpcode() != ShiftOpcode ||
1545        !MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
1546      return false;
1547
1548    // Must be a constant.
1549    auto MaybeImmVal =
1550        getIConstantVRegValWithLookThrough(MI->getOperand(2).getReg(), MRI);
1551    if (!MaybeImmVal)
1552      return false;
1553
1554    ShiftVal = MaybeImmVal->Value.getSExtValue();
1555    return true;
1556  };
1557
1558  // Logic ops are commutative, so check each operand for a match.
1559  Register LogicMIReg1 = LogicMI->getOperand(1).getReg();
1560  MachineInstr *LogicMIOp1 = MRI.getUniqueVRegDef(LogicMIReg1);
1561  Register LogicMIReg2 = LogicMI->getOperand(2).getReg();
1562  MachineInstr *LogicMIOp2 = MRI.getUniqueVRegDef(LogicMIReg2);
1563  uint64_t C0Val;
1564
1565  if (matchFirstShift(LogicMIOp1, C0Val)) {
1566    MatchInfo.LogicNonShiftReg = LogicMIReg2;
1567    MatchInfo.Shift2 = LogicMIOp1;
1568  } else if (matchFirstShift(LogicMIOp2, C0Val)) {
1569    MatchInfo.LogicNonShiftReg = LogicMIReg1;
1570    MatchInfo.Shift2 = LogicMIOp2;
1571  } else
1572    return false;
1573
1574  MatchInfo.ValSum = C0Val + C1Val;
1575
1576  // The fold is not valid if the sum of the shift values exceeds bitwidth.
1577  if (MatchInfo.ValSum >= MRI.getType(LogicDest).getScalarSizeInBits())
1578    return false;
1579
1580  MatchInfo.Logic = LogicMI;
1581  return true;
1582}
1583
1584void CombinerHelper::applyShiftOfShiftedLogic(MachineInstr &MI,
1585                                              ShiftOfShiftedLogic &MatchInfo) {
1586  unsigned Opcode = MI.getOpcode();
1587  assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1588          Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_USHLSAT ||
1589          Opcode == TargetOpcode::G_SSHLSAT) &&
1590         "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT or G_SSHLSAT");
1591
1592  LLT ShlType = MRI.getType(MI.getOperand(2).getReg());
1593  LLT DestType = MRI.getType(MI.getOperand(0).getReg());
1594  Builder.setInstrAndDebugLoc(MI);
1595
1596  Register Const = Builder.buildConstant(ShlType, MatchInfo.ValSum).getReg(0);
1597
1598  Register Shift1Base = MatchInfo.Shift2->getOperand(1).getReg();
1599  Register Shift1 =
1600      Builder.buildInstr(Opcode, {DestType}, {Shift1Base, Const}).getReg(0);
1601
1602  // If LogicNonShiftReg is the same as Shift1Base and the shift1 constant is
1603  // the same as the MatchInfo.Shift2 constant, CSEMIRBuilder will reuse the
1604  // old shift1 when building shift2. Erasing MatchInfo.Shift2 at the end would
1605  // then actually erase the old shift1 and cause a crash later, so erase it
1606  // here instead.
1607  MatchInfo.Shift2->eraseFromParent();
1608
1609  Register Shift2Const = MI.getOperand(2).getReg();
1610  Register Shift2 = Builder
1611                        .buildInstr(Opcode, {DestType},
1612                                    {MatchInfo.LogicNonShiftReg, Shift2Const})
1613                        .getReg(0);
1614
1615  Register Dest = MI.getOperand(0).getReg();
1616  Builder.buildInstr(MatchInfo.Logic->getOpcode(), {Dest}, {Shift1, Shift2});
1617
1618  // The logic op had a single non-debug use, so it's safe to remove it.
1619  MatchInfo.Logic->eraseFromParent();
1620
1621  MI.eraseFromParent();
1622}
1623
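// Fold a multiplication by a power-of-two constant into a left shift:
//   %root = G_MUL %x, G_CONSTANT (1 << N)
// -->
//   %root = G_SHL %x, G_CONSTANT N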
1624bool CombinerHelper::matchCombineMulToShl(MachineInstr &MI,
1625                                          unsigned &ShiftVal) {
1626  assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
1627  auto MaybeImmVal =
1628      getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
1629  if (!MaybeImmVal)
1630    return false;
1631
1632  ShiftVal = MaybeImmVal->Value.exactLogBase2();
1633  return (static_cast<int32_t>(ShiftVal) != -1);
1634}
1635
1636void CombinerHelper::applyCombineMulToShl(MachineInstr &MI,
1637                                          unsigned &ShiftVal) {
1638  assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
1639  MachineIRBuilder MIB(MI);
1640  LLT ShiftTy = MRI.getType(MI.getOperand(0).getReg());
1641  auto ShiftCst = MIB.buildConstant(ShiftTy, ShiftVal);
1642  Observer.changingInstr(MI);
1643  MI.setDesc(MIB.getTII().get(TargetOpcode::G_SHL));
1644  MI.getOperand(2).setReg(ShiftCst.getReg(0));
1645  Observer.changedInstr(MI);
1646}
1647
1648// shl ([sza]ext x), y => zext (shl x, y), if shift does not overflow source
1649bool CombinerHelper::matchCombineShlOfExtend(MachineInstr &MI,
1650                                             RegisterImmPair &MatchData) {
1651  assert(MI.getOpcode() == TargetOpcode::G_SHL && KB);
1652
1653  Register LHS = MI.getOperand(1).getReg();
1654
1655  Register ExtSrc;
1656  if (!mi_match(LHS, MRI, m_GAnyExt(m_Reg(ExtSrc))) &&
1657      !mi_match(LHS, MRI, m_GZExt(m_Reg(ExtSrc))) &&
1658      !mi_match(LHS, MRI, m_GSExt(m_Reg(ExtSrc))))
1659    return false;
1660
1661  // TODO: Should handle vector splat.
1662  Register RHS = MI.getOperand(2).getReg();
1663  auto MaybeShiftAmtVal = getIConstantVRegValWithLookThrough(RHS, MRI);
1664  if (!MaybeShiftAmtVal)
1665    return false;
1666
1667  if (LI) {
1668    LLT SrcTy = MRI.getType(ExtSrc);
1669
1670    // We only really care about the legality with the shifted value. We can
1671    // pick any type for the constant shift amount, so ask the target what to
1672    // use. Otherwise we would have to guess and hope it is reported as legal.
1673    LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(SrcTy);
1674    if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SHL, {SrcTy, ShiftAmtTy}}))
1675      return false;
1676  }
1677
1678  int64_t ShiftAmt = MaybeShiftAmtVal->Value.getSExtValue();
1679  MatchData.Reg = ExtSrc;
1680  MatchData.Imm = ShiftAmt;
1681
1682  unsigned MinLeadingZeros = KB->getKnownZeroes(ExtSrc).countLeadingOnes();
1683  return MinLeadingZeros >= ShiftAmt;
1684}
1685
1686void CombinerHelper::applyCombineShlOfExtend(MachineInstr &MI,
1687                                             const RegisterImmPair &MatchData) {
1688  Register ExtSrcReg = MatchData.Reg;
1689  int64_t ShiftAmtVal = MatchData.Imm;
1690
1691  LLT ExtSrcTy = MRI.getType(ExtSrcReg);
1692  Builder.setInstrAndDebugLoc(MI);
1693  auto ShiftAmt = Builder.buildConstant(ExtSrcTy, ShiftAmtVal);
1694  auto NarrowShift =
1695      Builder.buildShl(ExtSrcTy, ExtSrcReg, ShiftAmt, MI.getFlags());
1696  Builder.buildZExt(MI.getOperand(0), NarrowShift);
1697  MI.eraseFromParent();
1698}
1699
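// Match a G_MERGE_VALUES whose sources are, in order, exactly the results of a
// single G_UNMERGE_VALUES. Such a merge just reassembles the unmerged value,
// which is recorded in MatchInfo for the apply step.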
1700bool CombinerHelper::matchCombineMergeUnmerge(MachineInstr &MI,
1701                                              Register &MatchInfo) {
1702  GMerge &Merge = cast<GMerge>(MI);
1703  SmallVector<Register, 16> MergedValues;
1704  for (unsigned I = 0; I < Merge.getNumSources(); ++I)
1705    MergedValues.emplace_back(Merge.getSourceReg(I));
1706
1707  auto *Unmerge = getOpcodeDef<GUnmerge>(MergedValues[0], MRI);
1708  if (!Unmerge || Unmerge->getNumDefs() != Merge.getNumSources())
1709    return false;
1710
1711  for (unsigned I = 0; I < MergedValues.size(); ++I)
1712    if (MergedValues[I] != Unmerge->getReg(I))
1713      return false;
1714
1715  MatchInfo = Unmerge->getSourceReg();
1716  return true;
1717}
1718
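/// Look through a chain of G_BITCASTs and return the underlying register.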
1719static Register peekThroughBitcast(Register Reg,
1720                                   const MachineRegisterInfo &MRI) {
1721  while (mi_match(Reg, MRI, m_GBitcast(m_Reg(Reg))))
1722    ;
1723
1724  return Reg;
1725}
1726
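// Fold an unmerge of a merge-like instruction (possibly behind a bitcast) back
// into the merge's plain source values:
//   %x = G_MERGE_VALUES %a, %b, ...
//   %d0, %d1, ... = G_UNMERGE_VALUES %x
// -->
//   use %a, %b, ... directly (inserting casts when the types differ but have
//   the same size)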
1727bool CombinerHelper::matchCombineUnmergeMergeToPlainValues(
1728    MachineInstr &MI, SmallVectorImpl<Register> &Operands) {
1729  assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
1730         "Expected an unmerge");
1731  auto &Unmerge = cast<GUnmerge>(MI);
1732  Register SrcReg = peekThroughBitcast(Unmerge.getSourceReg(), MRI);
1733
1734  auto *SrcInstr = getOpcodeDef<GMergeLikeInstr>(SrcReg, MRI);
1735  if (!SrcInstr)
1736    return false;
1737
1738  // Check the source type of the merge.
1739  LLT SrcMergeTy = MRI.getType(SrcInstr->getSourceReg(0));
1740  LLT Dst0Ty = MRI.getType(Unmerge.getReg(0));
1741  bool SameSize = Dst0Ty.getSizeInBits() == SrcMergeTy.getSizeInBits();
1742  if (SrcMergeTy != Dst0Ty && !SameSize)
1743    return false;
1744  // They are the same now (modulo a bitcast).
1745  // We can collect all the src registers.
1746  for (unsigned Idx = 0; Idx < SrcInstr->getNumSources(); ++Idx)
1747    Operands.push_back(SrcInstr->getSourceReg(Idx));
1748  return true;
1749}
1750
1751void CombinerHelper::applyCombineUnmergeMergeToPlainValues(
1752    MachineInstr &MI, SmallVectorImpl<Register> &Operands) {
1753  assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
1754         "Expected an unmerge");
1755  assert((MI.getNumOperands() - 1 == Operands.size()) &&
1756         "Not enough operands to replace all defs");
1757  unsigned NumElems = MI.getNumOperands() - 1;
1758
1759  LLT SrcTy = MRI.getType(Operands[0]);
1760  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
1761  bool CanReuseInputDirectly = DstTy == SrcTy;
1762  Builder.setInstrAndDebugLoc(MI);
1763  for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
1764    Register DstReg = MI.getOperand(Idx).getReg();
1765    Register SrcReg = Operands[Idx];
1766    if (CanReuseInputDirectly)
1767      replaceRegWith(MRI, DstReg, SrcReg);
1768    else
1769      Builder.buildCast(DstReg, SrcReg);
1770  }
1771  MI.eraseFromParent();
1772}
1773
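// Fold an unmerge of a G_CONSTANT or G_FCONSTANT by splitting the wide value
// into one narrow constant per destination, starting from the low bits.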
1774bool CombinerHelper::matchCombineUnmergeConstant(MachineInstr &MI,
1775                                                 SmallVectorImpl<APInt> &Csts) {
1776  unsigned SrcIdx = MI.getNumOperands() - 1;
1777  Register SrcReg = MI.getOperand(SrcIdx).getReg();
1778  MachineInstr *SrcInstr = MRI.getVRegDef(SrcReg);
1779  if (SrcInstr->getOpcode() != TargetOpcode::G_CONSTANT &&
1780      SrcInstr->getOpcode() != TargetOpcode::G_FCONSTANT)
1781    return false;
1782  // Break down the big constant into smaller ones.
1783  const MachineOperand &CstVal = SrcInstr->getOperand(1);
1784  APInt Val = SrcInstr->getOpcode() == TargetOpcode::G_CONSTANT
1785                  ? CstVal.getCImm()->getValue()
1786                  : CstVal.getFPImm()->getValueAPF().bitcastToAPInt();
1787
1788  LLT Dst0Ty = MRI.getType(MI.getOperand(0).getReg());
1789  unsigned ShiftAmt = Dst0Ty.getSizeInBits();
1790  // Unmerge a constant.
1791  for (unsigned Idx = 0; Idx != SrcIdx; ++Idx) {
1792    Csts.emplace_back(Val.trunc(ShiftAmt));
1793    Val = Val.lshr(ShiftAmt);
1794  }
1795
1796  return true;
1797}
1798
1799void CombinerHelper::applyCombineUnmergeConstant(MachineInstr &MI,
1800                                                 SmallVectorImpl<APInt> &Csts) {
1801  assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
1802         "Expected an unmerge");
1803  assert((MI.getNumOperands() - 1 == Csts.size()) &&
1804         "Not enough operands to replace all defs");
1805  unsigned NumElems = MI.getNumOperands() - 1;
1806  Builder.setInstrAndDebugLoc(MI);
1807  for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
1808    Register DstReg = MI.getOperand(Idx).getReg();
1809    Builder.buildConstant(DstReg, Csts[Idx]);
1810  }
1811
1812  MI.eraseFromParent();
1813}
1814
1815bool CombinerHelper::matchCombineUnmergeUndef(
1816    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
1817  unsigned SrcIdx = MI.getNumOperands() - 1;
1818  Register SrcReg = MI.getOperand(SrcIdx).getReg();
1819  MatchInfo = [&MI](MachineIRBuilder &B) {
1820    unsigned NumElems = MI.getNumOperands() - 1;
1821    for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
1822      Register DstReg = MI.getOperand(Idx).getReg();
1823      B.buildUndef(DstReg);
1824    }
1825  };
1826  return isa<GImplicitDef>(MRI.getVRegDef(SrcReg));
1827}
1828
1829bool CombinerHelper::matchCombineUnmergeWithDeadLanesToTrunc(MachineInstr &MI) {
1830  assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
1831         "Expected an unmerge");
1832  // Check that all the lanes are dead except the first one.
1833  for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
1834    if (!MRI.use_nodbg_empty(MI.getOperand(Idx).getReg()))
1835      return false;
1836  }
1837  return true;
1838}
1839
1840void CombinerHelper::applyCombineUnmergeWithDeadLanesToTrunc(MachineInstr &MI) {
1841  Builder.setInstrAndDebugLoc(MI);
1842  Register SrcReg = MI.getOperand(MI.getNumDefs()).getReg();
1843  // Truncating a vector is going to truncate every single lane,
1844  // whereas we want the low bits of the full value.
1845  // Do the operation on a scalar instead.
1846  LLT SrcTy = MRI.getType(SrcReg);
1847  if (SrcTy.isVector())
1848    SrcReg =
1849        Builder.buildCast(LLT::scalar(SrcTy.getSizeInBits()), SrcReg).getReg(0);
1850
1851  Register Dst0Reg = MI.getOperand(0).getReg();
1852  LLT Dst0Ty = MRI.getType(Dst0Reg);
1853  if (Dst0Ty.isVector()) {
1854    auto MIB = Builder.buildTrunc(LLT::scalar(Dst0Ty.getSizeInBits()), SrcReg);
1855    Builder.buildCast(Dst0Reg, MIB);
1856  } else
1857    Builder.buildTrunc(Dst0Reg, SrcReg);
1858  MI.eraseFromParent();
1859}
1860
1861bool CombinerHelper::matchCombineUnmergeZExtToZExt(MachineInstr &MI) {
1862  assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
1863         "Expected an unmerge");
1864  Register Dst0Reg = MI.getOperand(0).getReg();
1865  LLT Dst0Ty = MRI.getType(Dst0Reg);
1866  // G_ZEXT on vector applies to each lane, so it will
1867  // affect all destinations. Therefore we won't be able
1868  // to simplify the unmerge to just the first definition.
1869  if (Dst0Ty.isVector())
1870    return false;
1871  Register SrcReg = MI.getOperand(MI.getNumDefs()).getReg();
1872  LLT SrcTy = MRI.getType(SrcReg);
1873  if (SrcTy.isVector())
1874    return false;
1875
1876  Register ZExtSrcReg;
1877  if (!mi_match(SrcReg, MRI, m_GZExt(m_Reg(ZExtSrcReg))))
1878    return false;
1879
1880  // Finally we can replace the first definition with
1881  // a zext of the source if the definition is big enough to hold
1882  // all of ZExtSrc's bits.
1883  LLT ZExtSrcTy = MRI.getType(ZExtSrcReg);
1884  return ZExtSrcTy.getSizeInBits() <= Dst0Ty.getSizeInBits();
1885}
1886
1887void CombinerHelper::applyCombineUnmergeZExtToZExt(MachineInstr &MI) {
1888  assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
1889         "Expected an unmerge");
1890
1891  Register Dst0Reg = MI.getOperand(0).getReg();
1892
1893  MachineInstr *ZExtInstr =
1894      MRI.getVRegDef(MI.getOperand(MI.getNumDefs()).getReg());
1895  assert(ZExtInstr && ZExtInstr->getOpcode() == TargetOpcode::G_ZEXT &&
1896         "Expecting a G_ZEXT");
1897
1898  Register ZExtSrcReg = ZExtInstr->getOperand(1).getReg();
1899  LLT Dst0Ty = MRI.getType(Dst0Reg);
1900  LLT ZExtSrcTy = MRI.getType(ZExtSrcReg);
1901
1902  Builder.setInstrAndDebugLoc(MI);
1903
1904  if (Dst0Ty.getSizeInBits() > ZExtSrcTy.getSizeInBits()) {
1905    Builder.buildZExt(Dst0Reg, ZExtSrcReg);
1906  } else {
1907    assert(Dst0Ty.getSizeInBits() == ZExtSrcTy.getSizeInBits() &&
1908           "ZExt src doesn't fit in destination");
1909    replaceRegWith(MRI, Dst0Reg, ZExtSrcReg);
1910  }
1911
1912  Register ZeroReg;
1913  for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
1914    if (!ZeroReg)
1915      ZeroReg = Builder.buildConstant(Dst0Ty, 0).getReg(0);
1916    replaceRegWith(MRI, MI.getOperand(Idx).getReg(), ZeroReg);
1917  }
1918  MI.eraseFromParent();
1919}
1920
1921bool CombinerHelper::matchCombineShiftToUnmerge(MachineInstr &MI,
1922                                                unsigned TargetShiftSize,
1923                                                unsigned &ShiftVal) {
1924  assert((MI.getOpcode() == TargetOpcode::G_SHL ||
1925          MI.getOpcode() == TargetOpcode::G_LSHR ||
1926          MI.getOpcode() == TargetOpcode::G_ASHR) && "Expected a shift");
1927
1928  LLT Ty = MRI.getType(MI.getOperand(0).getReg());
1929  if (Ty.isVector()) // TODO: Handle vector types.
1930    return false;
1931
1932  // Don't narrow further than the requested size.
1933  unsigned Size = Ty.getSizeInBits();
1934  if (Size <= TargetShiftSize)
1935    return false;
1936
1937  auto MaybeImmVal =
1938      getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
1939  if (!MaybeImmVal)
1940    return false;
1941
1942  ShiftVal = MaybeImmVal->Value.getSExtValue();
1943  return ShiftVal >= Size / 2 && ShiftVal < Size;
1944}
1945
1946void CombinerHelper::applyCombineShiftToUnmerge(MachineInstr &MI,
1947                                                const unsigned &ShiftVal) {
1948  Register DstReg = MI.getOperand(0).getReg();
1949  Register SrcReg = MI.getOperand(1).getReg();
1950  LLT Ty = MRI.getType(SrcReg);
1951  unsigned Size = Ty.getSizeInBits();
1952  unsigned HalfSize = Size / 2;
1953  assert(ShiftVal >= HalfSize);
1954
1955  LLT HalfTy = LLT::scalar(HalfSize);
1956
1957  Builder.setInstr(MI);
1958  auto Unmerge = Builder.buildUnmerge(HalfTy, SrcReg);
1959  unsigned NarrowShiftAmt = ShiftVal - HalfSize;
1960
1961  if (MI.getOpcode() == TargetOpcode::G_LSHR) {
1962    Register Narrowed = Unmerge.getReg(1);
1963
1964    //  dst = G_LSHR s64:x, C for C >= 32
1965    // =>
1966    //   lo, hi = G_UNMERGE_VALUES x
1967    //   dst = G_MERGE_VALUES (G_LSHR hi, C - 32), 0
1968
1969    if (NarrowShiftAmt != 0) {
1970      Narrowed = Builder.buildLShr(HalfTy, Narrowed,
1971        Builder.buildConstant(HalfTy, NarrowShiftAmt)).getReg(0);
1972    }
1973
1974    auto Zero = Builder.buildConstant(HalfTy, 0);
1975    Builder.buildMergeLikeInstr(DstReg, {Narrowed, Zero});
1976  } else if (MI.getOpcode() == TargetOpcode::G_SHL) {
1977    Register Narrowed = Unmerge.getReg(0);
1978    //  dst = G_SHL s64:x, C for C >= 32
1979    // =>
1980    //   lo, hi = G_UNMERGE_VALUES x
1981    //   dst = G_MERGE_VALUES 0, (G_SHL hi, C - 32)
1982    if (NarrowShiftAmt != 0) {
1983      Narrowed = Builder.buildShl(HalfTy, Narrowed,
1984        Builder.buildConstant(HalfTy, NarrowShiftAmt)).getReg(0);
1985    }
1986
1987    auto Zero = Builder.buildConstant(HalfTy, 0);
1988    Builder.buildMergeLikeInstr(DstReg, {Zero, Narrowed});
1989  } else {
1990    assert(MI.getOpcode() == TargetOpcode::G_ASHR);
1991    auto Hi = Builder.buildAShr(
1992      HalfTy, Unmerge.getReg(1),
1993      Builder.buildConstant(HalfTy, HalfSize - 1));
1994
1995    if (ShiftVal == HalfSize) {
1996      // (G_ASHR i64:x, 32) ->
1997      //   G_MERGE_VALUES hi_32(x), (G_ASHR hi_32(x), 31)
1998      Builder.buildMergeLikeInstr(DstReg, {Unmerge.getReg(1), Hi});
1999    } else if (ShiftVal == Size - 1) {
2000      // Don't need a second shift.
2001      // (G_ASHR i64:x, 63) ->
2002      //   %narrowed = (G_ASHR hi_32(x), 31)
2003      //   G_MERGE_VALUES %narrowed, %narrowed
2004      Builder.buildMergeLikeInstr(DstReg, {Hi, Hi});
2005    } else {
2006      auto Lo = Builder.buildAShr(
2007        HalfTy, Unmerge.getReg(1),
2008        Builder.buildConstant(HalfTy, ShiftVal - HalfSize));
2009
2010      // (G_ASHR i64:x, C), for C >= 32 ->
2011      //   G_MERGE_VALUES (G_ASHR hi_32(x), C - 32), (G_ASHR hi_32(x), 31)
2012      Builder.buildMergeLikeInstr(DstReg, {Lo, Hi});
2013    }
2014  }
2015
2016  MI.eraseFromParent();
2017}
2018
2019bool CombinerHelper::tryCombineShiftToUnmerge(MachineInstr &MI,
2020                                              unsigned TargetShiftAmount) {
2021  unsigned ShiftAmt;
2022  if (matchCombineShiftToUnmerge(MI, TargetShiftAmount, ShiftAmt)) {
2023    applyCombineShiftToUnmerge(MI, ShiftAmt);
2024    return true;
2025  }
2026
2027  return false;
2028}
2029
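// Fold G_INTTOPTR (G_PTRTOINT %ptr) back to a copy of %ptr when %ptr already
// has the destination pointer type.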
2030bool CombinerHelper::matchCombineI2PToP2I(MachineInstr &MI, Register &Reg) {
2031  assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2032  Register DstReg = MI.getOperand(0).getReg();
2033  LLT DstTy = MRI.getType(DstReg);
2034  Register SrcReg = MI.getOperand(1).getReg();
2035  return mi_match(SrcReg, MRI,
2036                  m_GPtrToInt(m_all_of(m_SpecificType(DstTy), m_Reg(Reg))));
2037}
2038
2039void CombinerHelper::applyCombineI2PToP2I(MachineInstr &MI, Register &Reg) {
2040  assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2041  Register DstReg = MI.getOperand(0).getReg();
2042  Builder.setInstr(MI);
2043  Builder.buildCopy(DstReg, Reg);
2044  MI.eraseFromParent();
2045}
2046
2047void CombinerHelper::applyCombineP2IToI2P(MachineInstr &MI, Register &Reg) {
2048  assert(MI.getOpcode() == TargetOpcode::G_PTRTOINT && "Expected a G_PTRTOINT");
2049  Register DstReg = MI.getOperand(0).getReg();
2050  Builder.setInstr(MI);
2051  Builder.buildZExtOrTrunc(DstReg, Reg);
2052  MI.eraseFromParent();
2053}
2054
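// Turn an integer add of a pointer-to-int back into pointer arithmetic:
//   %i = G_PTRTOINT %ptr
//   %root = G_ADD %i, %y      (in either operand order)
// -->
//   %sum = G_PTR_ADD %ptr, %y
//   %root = G_PTRTOINT %sum
// PtrReg records the pointer and whether the operands need to be commuted.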
2055bool CombinerHelper::matchCombineAddP2IToPtrAdd(
2056    MachineInstr &MI, std::pair<Register, bool> &PtrReg) {
2057  assert(MI.getOpcode() == TargetOpcode::G_ADD);
2058  Register LHS = MI.getOperand(1).getReg();
2059  Register RHS = MI.getOperand(2).getReg();
2060  LLT IntTy = MRI.getType(LHS);
2061
2062  // G_PTR_ADD always has the pointer in the LHS, so we may need to commute the
2063  // instruction.
2064  PtrReg.second = false;
2065  for (Register SrcReg : {LHS, RHS}) {
2066    if (mi_match(SrcReg, MRI, m_GPtrToInt(m_Reg(PtrReg.first)))) {
2067      // Don't handle cases where the integer is implicitly converted to the
2068      // pointer width.
2069      LLT PtrTy = MRI.getType(PtrReg.first);
2070      if (PtrTy.getScalarSizeInBits() == IntTy.getScalarSizeInBits())
2071        return true;
2072    }
2073
2074    PtrReg.second = true;
2075  }
2076
2077  return false;
2078}
2079
2080void CombinerHelper::applyCombineAddP2IToPtrAdd(
2081    MachineInstr &MI, std::pair<Register, bool> &PtrReg) {
2082  Register Dst = MI.getOperand(0).getReg();
2083  Register LHS = MI.getOperand(1).getReg();
2084  Register RHS = MI.getOperand(2).getReg();
2085
2086  const bool DoCommute = PtrReg.second;
2087  if (DoCommute)
2088    std::swap(LHS, RHS);
2089  LHS = PtrReg.first;
2090
2091  LLT PtrTy = MRI.getType(LHS);
2092
2093  Builder.setInstrAndDebugLoc(MI);
2094  auto PtrAdd = Builder.buildPtrAdd(PtrTy, LHS, RHS);
2095  Builder.buildPtrToInt(Dst, PtrAdd);
2096  MI.eraseFromParent();
2097}
2098
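// Fold a G_PTR_ADD of a constant G_INTTOPTR base and a constant offset into a
// single constant. Since G_INTTOPTR zero-extends its input, the folded value
// is zext(C1) + sext(C2) at the width of the result.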
2099bool CombinerHelper::matchCombineConstPtrAddToI2P(MachineInstr &MI,
2100                                                  APInt &NewCst) {
2101  auto &PtrAdd = cast<GPtrAdd>(MI);
2102  Register LHS = PtrAdd.getBaseReg();
2103  Register RHS = PtrAdd.getOffsetReg();
2104  MachineRegisterInfo &MRI = Builder.getMF().getRegInfo();
2105
2106  if (auto RHSCst = getIConstantVRegVal(RHS, MRI)) {
2107    APInt Cst;
2108    if (mi_match(LHS, MRI, m_GIntToPtr(m_ICst(Cst)))) {
2109      auto DstTy = MRI.getType(PtrAdd.getReg(0));
2110      // G_INTTOPTR uses zero-extension
2111      NewCst = Cst.zextOrTrunc(DstTy.getSizeInBits());
2112      NewCst += RHSCst->sextOrTrunc(DstTy.getSizeInBits());
2113      return true;
2114    }
2115  }
2116
2117  return false;
2118}
2119
2120void CombinerHelper::applyCombineConstPtrAddToI2P(MachineInstr &MI,
2121                                                  APInt &NewCst) {
2122  auto &PtrAdd = cast<GPtrAdd>(MI);
2123  Register Dst = PtrAdd.getReg(0);
2124
2125  Builder.setInstrAndDebugLoc(MI);
2126  Builder.buildConstant(Dst, NewCst);
2127  PtrAdd.eraseFromParent();
2128}
2129
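// Fold G_ANYEXT (G_TRUNC %x) back to %x when %x already has the anyext's
// destination type.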
2130bool CombinerHelper::matchCombineAnyExtTrunc(MachineInstr &MI, Register &Reg) {
2131  assert(MI.getOpcode() == TargetOpcode::G_ANYEXT && "Expected a G_ANYEXT");
2132  Register DstReg = MI.getOperand(0).getReg();
2133  Register SrcReg = MI.getOperand(1).getReg();
2134  LLT DstTy = MRI.getType(DstReg);
2135  return mi_match(SrcReg, MRI,
2136                  m_GTrunc(m_all_of(m_Reg(Reg), m_SpecificType(DstTy))));
2137}
2138
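// Fold G_ZEXT (G_TRUNC %x) back to %x when the bits dropped by the truncate
// are already known to be zero in %x.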
2139bool CombinerHelper::matchCombineZextTrunc(MachineInstr &MI, Register &Reg) {
2140  assert(MI.getOpcode() == TargetOpcode::G_ZEXT && "Expected a G_ZEXT");
2141  Register DstReg = MI.getOperand(0).getReg();
2142  Register SrcReg = MI.getOperand(1).getReg();
2143  LLT DstTy = MRI.getType(DstReg);
2144  if (mi_match(SrcReg, MRI,
2145               m_GTrunc(m_all_of(m_Reg(Reg), m_SpecificType(DstTy))))) {
2146    unsigned DstSize = DstTy.getScalarSizeInBits();
2147    unsigned SrcSize = MRI.getType(SrcReg).getScalarSizeInBits();
2148    return KB->getKnownBits(Reg).countMinLeadingZeros() >= DstSize - SrcSize;
2149  }
2150  return false;
2151}
2152
2153bool CombinerHelper::matchCombineExtOfExt(
2154    MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) {
2155  assert((MI.getOpcode() == TargetOpcode::G_ANYEXT ||
2156          MI.getOpcode() == TargetOpcode::G_SEXT ||
2157          MI.getOpcode() == TargetOpcode::G_ZEXT) &&
2158         "Expected a G_[ASZ]EXT");
2159  Register SrcReg = MI.getOperand(1).getReg();
2160  MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
2161  // Match exts with the same opcode, anyext([sz]ext) and sext(zext).
2162  unsigned Opc = MI.getOpcode();
2163  unsigned SrcOpc = SrcMI->getOpcode();
2164  if (Opc == SrcOpc ||
2165      (Opc == TargetOpcode::G_ANYEXT &&
2166       (SrcOpc == TargetOpcode::G_SEXT || SrcOpc == TargetOpcode::G_ZEXT)) ||
2167      (Opc == TargetOpcode::G_SEXT && SrcOpc == TargetOpcode::G_ZEXT)) {
2168    MatchInfo = std::make_tuple(SrcMI->getOperand(1).getReg(), SrcOpc);
2169    return true;
2170  }
2171  return false;
2172}
2173
2174void CombinerHelper::applyCombineExtOfExt(
2175    MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) {
2176  assert((MI.getOpcode() == TargetOpcode::G_ANYEXT ||
2177          MI.getOpcode() == TargetOpcode::G_SEXT ||
2178          MI.getOpcode() == TargetOpcode::G_ZEXT) &&
2179         "Expected a G_[ASZ]EXT");
2180
2181  Register Reg = std::get<0>(MatchInfo);
2182  unsigned SrcExtOp = std::get<1>(MatchInfo);
2183
2184  // Combine exts with the same opcode.
2185  if (MI.getOpcode() == SrcExtOp) {
2186    Observer.changingInstr(MI);
2187    MI.getOperand(1).setReg(Reg);
2188    Observer.changedInstr(MI);
2189    return;
2190  }
2191
2192  // Combine:
2193  // - anyext([sz]ext x) to [sz]ext x
2194  // - sext(zext x) to zext x
2195  if (MI.getOpcode() == TargetOpcode::G_ANYEXT ||
2196      (MI.getOpcode() == TargetOpcode::G_SEXT &&
2197       SrcExtOp == TargetOpcode::G_ZEXT)) {
2198    Register DstReg = MI.getOperand(0).getReg();
2199    Builder.setInstrAndDebugLoc(MI);
2200    Builder.buildInstr(SrcExtOp, {DstReg}, {Reg});
2201    MI.eraseFromParent();
2202  }
2203}
2204
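// Replace G_MUL %x, -1 with G_SUB 0, %x, i.e. a plain negation.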
2205void CombinerHelper::applyCombineMulByNegativeOne(MachineInstr &MI) {
2206  assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
2207  Register DstReg = MI.getOperand(0).getReg();
2208  Register SrcReg = MI.getOperand(1).getReg();
2209  LLT DstTy = MRI.getType(DstReg);
2210
2211  Builder.setInstrAndDebugLoc(MI);
2212  Builder.buildSub(DstReg, Builder.buildConstant(DstTy, 0), SrcReg,
2213                   MI.getFlags());
2214  MI.eraseFromParent();
2215}
2216
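// Fold G_FABS (G_FNEG %x) to G_FABS %x: the inner negation cannot change the
// absolute value.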
2217bool CombinerHelper::matchCombineFAbsOfFNeg(MachineInstr &MI,
2218                                            BuildFnTy &MatchInfo) {
2219  assert(MI.getOpcode() == TargetOpcode::G_FABS && "Expected a G_FABS");
2220  Register Src = MI.getOperand(1).getReg();
2221  Register NegSrc;
2222
2223  if (!mi_match(Src, MRI, m_GFNeg(m_Reg(NegSrc))))
2224    return false;
2225
2226  MatchInfo = [=, &MI](MachineIRBuilder &B) {
2227    Observer.changingInstr(MI);
2228    MI.getOperand(1).setReg(NegSrc);
2229    Observer.changedInstr(MI);
2230  };
2231  return true;
2232}
2233
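// Fold G_TRUNC ([asz]ext %x). Depending on the relative sizes, the pair
// collapses to %x itself, to a narrower extension of %x, or to a plain
// truncate of %x.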
2234bool CombinerHelper::matchCombineTruncOfExt(
2235    MachineInstr &MI, std::pair<Register, unsigned> &MatchInfo) {
2236  assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
2237  Register SrcReg = MI.getOperand(1).getReg();
2238  MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
2239  unsigned SrcOpc = SrcMI->getOpcode();
2240  if (SrcOpc == TargetOpcode::G_ANYEXT || SrcOpc == TargetOpcode::G_SEXT ||
2241      SrcOpc == TargetOpcode::G_ZEXT) {
2242    MatchInfo = std::make_pair(SrcMI->getOperand(1).getReg(), SrcOpc);
2243    return true;
2244  }
2245  return false;
2246}
2247
2248void CombinerHelper::applyCombineTruncOfExt(
2249    MachineInstr &MI, std::pair<Register, unsigned> &MatchInfo) {
2250  assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
2251  Register SrcReg = MatchInfo.first;
2252  unsigned SrcExtOp = MatchInfo.second;
2253  Register DstReg = MI.getOperand(0).getReg();
2254  LLT SrcTy = MRI.getType(SrcReg);
2255  LLT DstTy = MRI.getType(DstReg);
2256  if (SrcTy == DstTy) {
2257    MI.eraseFromParent();
2258    replaceRegWith(MRI, DstReg, SrcReg);
2259    return;
2260  }
2261  Builder.setInstrAndDebugLoc(MI);
2262  if (SrcTy.getSizeInBits() < DstTy.getSizeInBits())
2263    Builder.buildInstr(SrcExtOp, {DstReg}, {SrcReg});
2264  else
2265    Builder.buildTrunc(DstReg, SrcReg);
2266  MI.eraseFromParent();
2267}
2268
2269static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) {
2270  const unsigned ShiftSize = ShiftTy.getScalarSizeInBits();
2271  const unsigned TruncSize = TruncTy.getScalarSizeInBits();
2272
2273  // ShiftTy > 32 > TruncTy -> 32
2274  if (ShiftSize > 32 && TruncSize < 32)
2275    return ShiftTy.changeElementSize(32);
2276
2277  // TODO: We could also reduce to 16 bits, but that's more target-dependent.
2278  //  Some targets like it, some don't, some only like it under certain
2279  //  conditions/processor versions, etc.
2280  //  A TL hook might be needed for this.
2281
2282  // Don't combine
2283  return ShiftTy;
2284}
2285
2286bool CombinerHelper::matchCombineTruncOfShift(
2287    MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) {
2288  assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
2289  Register DstReg = MI.getOperand(0).getReg();
2290  Register SrcReg = MI.getOperand(1).getReg();
2291
2292  if (!MRI.hasOneNonDBGUse(SrcReg))
2293    return false;
2294
2295  LLT SrcTy = MRI.getType(SrcReg);
2296  LLT DstTy = MRI.getType(DstReg);
2297
2298  MachineInstr *SrcMI = getDefIgnoringCopies(SrcReg, MRI);
2299  const auto &TL = getTargetLowering();
2300
2301  LLT NewShiftTy;
2302  switch (SrcMI->getOpcode()) {
2303  default:
2304    return false;
2305  case TargetOpcode::G_SHL: {
2306    NewShiftTy = DstTy;
2307
2308    // Make sure new shift amount is legal.
2309    KnownBits Known = KB->getKnownBits(SrcMI->getOperand(2).getReg());
2310    if (Known.getMaxValue().uge(NewShiftTy.getScalarSizeInBits()))
2311      return false;
2312    break;
2313  }
2314  case TargetOpcode::G_LSHR:
2315  case TargetOpcode::G_ASHR: {
2316    // For right shifts, we conservatively do not do the transform if the TRUNC
2317    // has any STORE users. The reason is that if we change the type of the
2318    // shift, we may break the truncstore combine.
2319    //
2320    // TODO: Fix truncstore combine to handle (trunc(lshr (trunc x), k)).
2321    for (auto &User : MRI.use_instructions(DstReg))
2322      if (User.getOpcode() == TargetOpcode::G_STORE)
2323        return false;
2324
2325    NewShiftTy = getMidVTForTruncRightShiftCombine(SrcTy, DstTy);
2326    if (NewShiftTy == SrcTy)
2327      return false;
2328
2329    // Make sure we won't lose information by truncating the high bits.
2330    KnownBits Known = KB->getKnownBits(SrcMI->getOperand(2).getReg());
2331    if (Known.getMaxValue().ugt(NewShiftTy.getScalarSizeInBits() -
2332                                DstTy.getScalarSizeInBits()))
2333      return false;
2334    break;
2335  }
2336  }
2337
2338  if (!isLegalOrBeforeLegalizer(
2339          {SrcMI->getOpcode(),
2340           {NewShiftTy, TL.getPreferredShiftAmountTy(NewShiftTy)}}))
2341    return false;
2342
2343  MatchInfo = std::make_pair(SrcMI, NewShiftTy);
2344  return true;
2345}
2346
2347void CombinerHelper::applyCombineTruncOfShift(
2348    MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) {
2349  Builder.setInstrAndDebugLoc(MI);
2350
2351  MachineInstr *ShiftMI = MatchInfo.first;
2352  LLT NewShiftTy = MatchInfo.second;
2353
2354  Register Dst = MI.getOperand(0).getReg();
2355  LLT DstTy = MRI.getType(Dst);
2356
2357  Register ShiftAmt = ShiftMI->getOperand(2).getReg();
2358  Register ShiftSrc = ShiftMI->getOperand(1).getReg();
2359  ShiftSrc = Builder.buildTrunc(NewShiftTy, ShiftSrc).getReg(0);
2360
2361  Register NewShift =
2362      Builder
2363          .buildInstr(ShiftMI->getOpcode(), {NewShiftTy}, {ShiftSrc, ShiftAmt})
2364          .getReg(0);
2365
2366  if (NewShiftTy == DstTy)
2367    replaceRegWith(MRI, Dst, NewShift);
2368  else
2369    Builder.buildTrunc(Dst, NewShift);
2370
2371  eraseInst(MI);
2372}
2373
2374bool CombinerHelper::matchAnyExplicitUseIsUndef(MachineInstr &MI) {
2375  return any_of(MI.explicit_uses(), [this](const MachineOperand &MO) {
2376    return MO.isReg() &&
2377           getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MO.getReg(), MRI);
2378  });
2379}
2380
2381bool CombinerHelper::matchAllExplicitUsesAreUndef(MachineInstr &MI) {
2382  return all_of(MI.explicit_uses(), [this](const MachineOperand &MO) {
2383    return !MO.isReg() ||
2384           getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MO.getReg(), MRI);
2385  });
2386}
2387
2388bool CombinerHelper::matchUndefShuffleVectorMask(MachineInstr &MI) {
2389  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
2390  ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
2391  return all_of(Mask, [](int Elt) { return Elt < 0; });
2392}
2393
2394bool CombinerHelper::matchUndefStore(MachineInstr &MI) {
2395  assert(MI.getOpcode() == TargetOpcode::G_STORE);
2396  return getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MI.getOperand(0).getReg(),
2397                      MRI);
2398}
2399
2400bool CombinerHelper::matchUndefSelectCmp(MachineInstr &MI) {
2401  assert(MI.getOpcode() == TargetOpcode::G_SELECT);
2402  return getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MI.getOperand(1).getReg(),
2403                      MRI);
2404}
2405
2406bool CombinerHelper::matchInsertExtractVecEltOutOfBounds(MachineInstr &MI) {
2407  assert((MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT ||
2408          MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) &&
2409         "Expected an insert/extract element op");
2410  LLT VecTy = MRI.getType(MI.getOperand(1).getReg());
2411  unsigned IdxIdx =
2412      MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3;
2413  auto Idx = getIConstantVRegVal(MI.getOperand(IdxIdx).getReg(), MRI);
2414  if (!Idx)
2415    return false;
2416  return Idx->getZExtValue() >= VecTy.getNumElements();
2417}
2418
2419bool CombinerHelper::matchConstantSelectCmp(MachineInstr &MI, unsigned &OpIdx) {
2420  GSelect &SelMI = cast<GSelect>(MI);
2421  auto Cst =
2422      isConstantOrConstantSplatVector(*MRI.getVRegDef(SelMI.getCondReg()), MRI);
2423  if (!Cst)
2424    return false;
2425  OpIdx = Cst->isZero() ? 3 : 2;
2426  return true;
2427}
2428
2429bool CombinerHelper::eraseInst(MachineInstr &MI) {
2430  MI.eraseFromParent();
2431  return true;
2432}
2433
2434bool CombinerHelper::matchEqualDefs(const MachineOperand &MOP1,
2435                                    const MachineOperand &MOP2) {
2436  if (!MOP1.isReg() || !MOP2.isReg())
2437    return false;
2438  auto InstAndDef1 = getDefSrcRegIgnoringCopies(MOP1.getReg(), MRI);
2439  if (!InstAndDef1)
2440    return false;
2441  auto InstAndDef2 = getDefSrcRegIgnoringCopies(MOP2.getReg(), MRI);
2442  if (!InstAndDef2)
2443    return false;
2444  MachineInstr *I1 = InstAndDef1->MI;
2445  MachineInstr *I2 = InstAndDef2->MI;
2446
2447  // Handle a case like this:
2448  //
2449  // %0:_(s64), %1:_(s64) = G_UNMERGE_VALUES %2:_(<2 x s64>)
2450  //
2451  // Even though %0 and %1 are produced by the same instruction they are not
2452  // the same values.
2453  if (I1 == I2)
2454    return MOP1.getReg() == MOP2.getReg();
2455
2456  // If we have an instruction which loads or stores, we can't guarantee that
2457  // it is identical.
2458  //
2459  // For example, we may have
2460  //
2461  // %x1 = G_LOAD %addr (load N from @somewhere)
2462  // ...
2463  // call @foo
2464  // ...
2465  // %x2 = G_LOAD %addr (load N from @somewhere)
2466  // ...
2467  // %or = G_OR %x1, %x2
2468  //
2469  // It's possible that @foo will modify whatever lives at the address we're
2470  // loading from. To be safe, let's just assume that all loads and stores
2471  // are different (unless we have something which is guaranteed to not
2472  // change.)
2473  if (I1->mayLoadOrStore() && !I1->isDereferenceableInvariantLoad())
2474    return false;
2475
2476  // If both instructions are loads or stores, they are equal only if both
2477  // are dereferenceable invariant loads with the same number of bits.
2478  if (I1->mayLoadOrStore() && I2->mayLoadOrStore()) {
2479    GLoadStore *LS1 = dyn_cast<GLoadStore>(I1);
2480    GLoadStore *LS2 = dyn_cast<GLoadStore>(I2);
2481    if (!LS1 || !LS2)
2482      return false;
2483
2484    if (!I2->isDereferenceableInvariantLoad() ||
2485        (LS1->getMemSizeInBits() != LS2->getMemSizeInBits()))
2486      return false;
2487  }
2488
2489  // Check for physical registers on the instructions first to avoid cases
2490  // like this:
2491  //
2492  // %a = COPY $physreg
2493  // ...
2494  // SOMETHING implicit-def $physreg
2495  // ...
2496  // %b = COPY $physreg
2497  //
2498  // These copies are not equivalent.
2499  if (any_of(I1->uses(), [](const MachineOperand &MO) {
2500        return MO.isReg() && MO.getReg().isPhysical();
2501      })) {
2502    // Check if we have a case like this:
2503    //
2504    // %a = COPY $physreg
2505    // %b = COPY %a
2506    //
2507    // In this case, I1 and I2 will both be equal to %a = COPY $physreg.
2508    // From that, we know that they must have the same value, since they must
2509    // have come from the same COPY.
2510    return I1->isIdenticalTo(*I2);
2511  }
2512
2513  // We don't have any physical registers, so we don't necessarily need the
2514  // same vreg defs.
2515  //
2516  // On the off-chance that there's some target instruction feeding into the
2517  // instruction, let's use produceSameValue instead of isIdenticalTo.
2518  if (Builder.getTII().produceSameValue(*I1, *I2, &MRI)) {
2519    // Handle instructions with multiple defs that produce the same values. The
2520    // values are the same for operands with the same index.
2521    // %0:_(s8), %1:_(s8), %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2522    // %5:_(s8), %6:_(s8), %7:_(s8), %8:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2523    // Here I1 and I2 are different instructions that produce the same values:
2524    // %1 and %6 are the same, but %1 and %7 are not the same value.
2525    return I1->findRegisterDefOperandIdx(InstAndDef1->Reg) ==
2526           I2->findRegisterDefOperandIdx(InstAndDef2->Reg);
2527  }
2528  return false;
2529}
2530
2531bool CombinerHelper::matchConstantOp(const MachineOperand &MOP, int64_t C) {
2532  if (!MOP.isReg())
2533    return false;
2534  auto *MI = MRI.getVRegDef(MOP.getReg());
2535  auto MaybeCst = isConstantOrConstantSplatVector(*MI, MRI);
2536  return MaybeCst && MaybeCst->getBitWidth() <= 64 &&
2537         MaybeCst->getSExtValue() == C;
2538}
2539
2540bool CombinerHelper::replaceSingleDefInstWithOperand(MachineInstr &MI,
2541                                                     unsigned OpIdx) {
2542  assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2543  Register OldReg = MI.getOperand(0).getReg();
2544  Register Replacement = MI.getOperand(OpIdx).getReg();
2545  assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
2546  MI.eraseFromParent();
2547  replaceRegWith(MRI, OldReg, Replacement);
2548  return true;
2549}
2550
2551bool CombinerHelper::replaceSingleDefInstWithReg(MachineInstr &MI,
2552                                                 Register Replacement) {
2553  assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2554  Register OldReg = MI.getOperand(0).getReg();
2555  assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
2556  MI.eraseFromParent();
2557  replaceRegWith(MRI, OldReg, Replacement);
2558  return true;
2559}
2560
2561bool CombinerHelper::matchSelectSameVal(MachineInstr &MI) {
2562  assert(MI.getOpcode() == TargetOpcode::G_SELECT);
2563  // Match (cond ? x : x)
2564  return matchEqualDefs(MI.getOperand(2), MI.getOperand(3)) &&
2565         canReplaceReg(MI.getOperand(0).getReg(), MI.getOperand(2).getReg(),
2566                       MRI);
2567}
2568
2569bool CombinerHelper::matchBinOpSameVal(MachineInstr &MI) {
2570  return matchEqualDefs(MI.getOperand(1), MI.getOperand(2)) &&
2571         canReplaceReg(MI.getOperand(0).getReg(), MI.getOperand(1).getReg(),
2572                       MRI);
2573}
2574
2575bool CombinerHelper::matchOperandIsZero(MachineInstr &MI, unsigned OpIdx) {
2576  return matchConstantOp(MI.getOperand(OpIdx), 0) &&
2577         canReplaceReg(MI.getOperand(0).getReg(), MI.getOperand(OpIdx).getReg(),
2578                       MRI);
2579}
2580
2581bool CombinerHelper::matchOperandIsUndef(MachineInstr &MI, unsigned OpIdx) {
2582  MachineOperand &MO = MI.getOperand(OpIdx);
2583  return MO.isReg() &&
2584         getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MO.getReg(), MRI);
2585}
2586
2587bool CombinerHelper::matchOperandIsKnownToBeAPowerOfTwo(MachineInstr &MI,
2588                                                        unsigned OpIdx) {
2589  MachineOperand &MO = MI.getOperand(OpIdx);
2590  return isKnownToBeAPowerOfTwo(MO.getReg(), MRI, KB);
2591}
2592
2593bool CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, double C) {
2594  assert(MI.getNumDefs() == 1 && "Expected only one def?");
2595  Builder.setInstr(MI);
2596  Builder.buildFConstant(MI.getOperand(0), C);
2597  MI.eraseFromParent();
2598  return true;
2599}
2600
2601bool CombinerHelper::replaceInstWithConstant(MachineInstr &MI, int64_t C) {
2602  assert(MI.getNumDefs() == 1 && "Expected only one def?");
2603  Builder.setInstr(MI);
2604  Builder.buildConstant(MI.getOperand(0), C);
2605  MI.eraseFromParent();
2606  return true;
2607}
2608
2609bool CombinerHelper::replaceInstWithConstant(MachineInstr &MI, APInt C) {
2610  assert(MI.getNumDefs() == 1 && "Expected only one def?");
2611  Builder.setInstr(MI);
2612  Builder.buildConstant(MI.getOperand(0), C);
2613  MI.eraseFromParent();
2614  return true;
2615}
2616
2617bool CombinerHelper::replaceInstWithUndef(MachineInstr &MI) {
2618  assert(MI.getNumDefs() == 1 && "Expected only one def?");
2619  Builder.setInstr(MI);
2620  Builder.buildUndef(MI.getOperand(0));
2621  MI.eraseFromParent();
2622  return true;
2623}
2624
2625bool CombinerHelper::matchSimplifyAddToSub(
2626    MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) {
2627  Register LHS = MI.getOperand(1).getReg();
2628  Register RHS = MI.getOperand(2).getReg();
2629  Register &NewLHS = std::get<0>(MatchInfo);
2630  Register &NewRHS = std::get<1>(MatchInfo);
2631
2632  // Helper lambda to check for opportunities for
2633  // ((0-A) + B) -> B - A
2634  // (A + (0-B)) -> A - B
2635  auto CheckFold = [&](Register &MaybeSub, Register &MaybeNewLHS) {
2636    if (!mi_match(MaybeSub, MRI, m_Neg(m_Reg(NewRHS))))
2637      return false;
2638    NewLHS = MaybeNewLHS;
2639    return true;
2640  };
2641
2642  return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
2643}
2644
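// Collapse a chain of G_INSERT_VECTOR_ELTs with constant indices, rooted in a
// G_BUILD_VECTOR or G_IMPLICIT_DEF, into a single G_BUILD_VECTOR. Each lane
// takes the value of the latest insert at that index; lanes that were never
// written keep the original build-vector element or become undef.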
2645bool CombinerHelper::matchCombineInsertVecElts(
2646    MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) {
2647  assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT &&
2648         "Invalid opcode");
2649  Register DstReg = MI.getOperand(0).getReg();
2650  LLT DstTy = MRI.getType(DstReg);
2651  assert(DstTy.isVector() && "Invalid G_INSERT_VECTOR_ELT?");
2652  unsigned NumElts = DstTy.getNumElements();
2653  // If this MI is part of a sequence of insert_vec_elts, then
2654  // don't do the combine in the middle of the sequence.
2655  if (MRI.hasOneUse(DstReg) && MRI.use_instr_begin(DstReg)->getOpcode() ==
2656                                   TargetOpcode::G_INSERT_VECTOR_ELT)
2657    return false;
2658  MachineInstr *CurrInst = &MI;
2659  MachineInstr *TmpInst;
2660  int64_t IntImm;
2661  Register TmpReg;
2662  MatchInfo.resize(NumElts);
2663  while (mi_match(
2664      CurrInst->getOperand(0).getReg(), MRI,
2665      m_GInsertVecElt(m_MInstr(TmpInst), m_Reg(TmpReg), m_ICst(IntImm)))) {
2666    if (IntImm >= NumElts || IntImm < 0)
2667      return false;
2668    if (!MatchInfo[IntImm])
2669      MatchInfo[IntImm] = TmpReg;
2670    CurrInst = TmpInst;
2671  }
2672  // Variable index.
2673  if (CurrInst->getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT)
2674    return false;
2675  if (TmpInst->getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
2676    for (unsigned I = 1; I < TmpInst->getNumOperands(); ++I) {
2677      if (!MatchInfo[I - 1].isValid())
2678        MatchInfo[I - 1] = TmpInst->getOperand(I).getReg();
2679    }
2680    return true;
2681  }
2682  // If we didn't end in a G_IMPLICIT_DEF, bail out.
2683  return TmpInst->getOpcode() == TargetOpcode::G_IMPLICIT_DEF;
2684}
2685
2686void CombinerHelper::applyCombineInsertVecElts(
2687    MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) {
2688  Builder.setInstr(MI);
2689  Register UndefReg;
2690  auto GetUndef = [&]() {
2691    if (UndefReg)
2692      return UndefReg;
2693    LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
2694    UndefReg = Builder.buildUndef(DstTy.getScalarType()).getReg(0);
2695    return UndefReg;
2696  };
2697  for (unsigned I = 0; I < MatchInfo.size(); ++I) {
2698    if (!MatchInfo[I])
2699      MatchInfo[I] = GetUndef();
2700  }
2701  Builder.buildBuildVector(MI.getOperand(0).getReg(), MatchInfo);
2702  MI.eraseFromParent();
2703}
2704
2705void CombinerHelper::applySimplifyAddToSub(
2706    MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) {
2707  Builder.setInstr(MI);
2708  Register SubLHS, SubRHS;
2709  std::tie(SubLHS, SubRHS) = MatchInfo;
2710  Builder.buildSub(MI.getOperand(0).getReg(), SubLHS, SubRHS);
2711  MI.eraseFromParent();
2712}
2713
2714bool CombinerHelper::matchHoistLogicOpWithSameOpcodeHands(
2715    MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) {
2716  // Matches: logic (hand x, ...), (hand y, ...) -> hand (logic x, y), ...
2717  //
2718  // Creates the new hand + logic instructions (but does not insert them).
2719  //
2720  // On success, MatchInfo is populated with the new instructions. These are
2721  // inserted in applyHoistLogicOpWithSameOpcodeHands.
2722  unsigned LogicOpcode = MI.getOpcode();
2723  assert(LogicOpcode == TargetOpcode::G_AND ||
2724         LogicOpcode == TargetOpcode::G_OR ||
2725         LogicOpcode == TargetOpcode::G_XOR);
2726  MachineIRBuilder MIB(MI);
2727  Register Dst = MI.getOperand(0).getReg();
2728  Register LHSReg = MI.getOperand(1).getReg();
2729  Register RHSReg = MI.getOperand(2).getReg();
2730
2731  // Don't recompute anything.
2732  if (!MRI.hasOneNonDBGUse(LHSReg) || !MRI.hasOneNonDBGUse(RHSReg))
2733    return false;
2734
2735  // Make sure we have (hand x, ...), (hand y, ...)
2736  MachineInstr *LeftHandInst = getDefIgnoringCopies(LHSReg, MRI);
2737  MachineInstr *RightHandInst = getDefIgnoringCopies(RHSReg, MRI);
2738  if (!LeftHandInst || !RightHandInst)
2739    return false;
2740  unsigned HandOpcode = LeftHandInst->getOpcode();
2741  if (HandOpcode != RightHandInst->getOpcode())
2742    return false;
2743  if (!LeftHandInst->getOperand(1).isReg() ||
2744      !RightHandInst->getOperand(1).isReg())
2745    return false;
2746
2747  // Make sure the types match up, and if we're doing this post-legalization,
2748  // we end up with legal types.
2749  Register X = LeftHandInst->getOperand(1).getReg();
2750  Register Y = RightHandInst->getOperand(1).getReg();
2751  LLT XTy = MRI.getType(X);
2752  LLT YTy = MRI.getType(Y);
2753  if (XTy != YTy)
2754    return false;
2755  if (!isLegalOrBeforeLegalizer({LogicOpcode, {XTy, YTy}}))
2756    return false;
2757
2758  // Optional extra source register.
2759  Register ExtraHandOpSrcReg;
2760  switch (HandOpcode) {
2761  default:
2762    return false;
2763  case TargetOpcode::G_ANYEXT:
2764  case TargetOpcode::G_SEXT:
2765  case TargetOpcode::G_ZEXT: {
2766    // Match: logic (ext X), (ext Y) --> ext (logic X, Y)
2767    break;
2768  }
2769  case TargetOpcode::G_AND:
2770  case TargetOpcode::G_ASHR:
2771  case TargetOpcode::G_LSHR:
2772  case TargetOpcode::G_SHL: {
2773    // Match: logic (binop x, z), (binop y, z) -> binop (logic x, y), z
2774    MachineOperand &ZOp = LeftHandInst->getOperand(2);
2775    if (!matchEqualDefs(ZOp, RightHandInst->getOperand(2)))
2776      return false;
2777    ExtraHandOpSrcReg = ZOp.getReg();
2778    break;
2779  }
2780  }
2781
2782  // Record the steps to build the new instructions.
2783  //
2784  // Steps to build (logic x, y)
2785  auto NewLogicDst = MRI.createGenericVirtualRegister(XTy);
2786  OperandBuildSteps LogicBuildSteps = {
2787      [=](MachineInstrBuilder &MIB) { MIB.addDef(NewLogicDst); },
2788      [=](MachineInstrBuilder &MIB) { MIB.addReg(X); },
2789      [=](MachineInstrBuilder &MIB) { MIB.addReg(Y); }};
2790  InstructionBuildSteps LogicSteps(LogicOpcode, LogicBuildSteps);
2791
2792  // Steps to build hand (logic x, y), ...z
2793  OperandBuildSteps HandBuildSteps = {
2794      [=](MachineInstrBuilder &MIB) { MIB.addDef(Dst); },
2795      [=](MachineInstrBuilder &MIB) { MIB.addReg(NewLogicDst); }};
2796  if (ExtraHandOpSrcReg.isValid())
2797    HandBuildSteps.push_back(
2798        [=](MachineInstrBuilder &MIB) { MIB.addReg(ExtraHandOpSrcReg); });
2799  InstructionBuildSteps HandSteps(HandOpcode, HandBuildSteps);
2800
2801  MatchInfo = InstructionStepsMatchInfo({LogicSteps, HandSteps});
2802  return true;
2803}
2804
2805void CombinerHelper::applyBuildInstructionSteps(
2806    MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) {
2807  assert(MatchInfo.InstrsToBuild.size() &&
2808         "Expected at least one instr to build?");
2809  Builder.setInstr(MI);
2810  for (auto &InstrToBuild : MatchInfo.InstrsToBuild) {
2811    assert(InstrToBuild.Opcode && "Expected a valid opcode?");
2812    assert(InstrToBuild.OperandFns.size() && "Expected at least one operand?");
2813    MachineInstrBuilder Instr = Builder.buildInstr(InstrToBuild.Opcode);
2814    for (auto &OperandFn : InstrToBuild.OperandFns)
2815      OperandFn(Instr);
2816  }
2817  MI.eraseFromParent();
2818}
2819
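// Fold (G_ASHR (G_SHL %x, C), C) into G_SEXT_INREG %x, (width - C): the shift
// pair sign-extends the low (width - C) bits of %x.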
2820bool CombinerHelper::matchAshrShlToSextInreg(
2821    MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) {
2822  assert(MI.getOpcode() == TargetOpcode::G_ASHR);
2823  int64_t ShlCst, AshrCst;
2824  Register Src;
2825  if (!mi_match(MI.getOperand(0).getReg(), MRI,
2826                m_GAShr(m_GShl(m_Reg(Src), m_ICstOrSplat(ShlCst)),
2827                        m_ICstOrSplat(AshrCst))))
2828    return false;
2829  if (ShlCst != AshrCst)
2830    return false;
2831  if (!isLegalOrBeforeLegalizer(
2832          {TargetOpcode::G_SEXT_INREG, {MRI.getType(Src)}}))
2833    return false;
2834  MatchInfo = std::make_tuple(Src, ShlCst);
2835  return true;
2836}
2837
2838void CombinerHelper::applyAshShlToSextInreg(
2839    MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) {
2840  assert(MI.getOpcode() == TargetOpcode::G_ASHR);
2841  Register Src;
2842  int64_t ShiftAmt;
2843  std::tie(Src, ShiftAmt) = MatchInfo;
2844  unsigned Size = MRI.getType(Src).getScalarSizeInBits();
2845  Builder.setInstrAndDebugLoc(MI);
2846  Builder.buildSExtInReg(MI.getOperand(0).getReg(), Src, Size - ShiftAmt);
2847  MI.eraseFromParent();
2848}
2849
2850/// and(and(x, C1), C2) -> C1&C2 ? and(x, C1&C2) : 0
2851bool CombinerHelper::matchOverlappingAnd(
2852    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
2853  assert(MI.getOpcode() == TargetOpcode::G_AND);
2854
2855  Register Dst = MI.getOperand(0).getReg();
2856  LLT Ty = MRI.getType(Dst);
2857
2858  Register R;
2859  int64_t C1;
2860  int64_t C2;
2861  if (!mi_match(
2862          Dst, MRI,
2863          m_GAnd(m_GAnd(m_Reg(R), m_ICst(C1)), m_ICst(C2))))
2864    return false;
2865
2866  MatchInfo = [=](MachineIRBuilder &B) {
2867    if (C1 & C2) {
2868      B.buildAnd(Dst, R, B.buildConstant(Ty, C1 & C2));
2869      return;
2870    }
2871    auto Zero = B.buildConstant(Ty, 0);
2872    replaceRegWith(MRI, Dst, Zero->getOperand(0).getReg());
2873  };
2874  return true;
2875}
2876
2877bool CombinerHelper::matchRedundantAnd(MachineInstr &MI,
2878                                       Register &Replacement) {
2879  // Given
2880  //
2881  // %y:_(sN) = G_SOMETHING
2882  // %x:_(sN) = G_SOMETHING
2883  // %res:_(sN) = G_AND %x, %y
2884  //
2885  // Eliminate the G_AND when it is known that x & y == x or x & y == y.
2886  //
2887  // Patterns like this can appear as a result of legalization. E.g.
2888  //
2889  // %cmp:_(s32) = G_ICMP intpred(pred), %x(s32), %y
2890  // %one:_(s32) = G_CONSTANT i32 1
2891  // %and:_(s32) = G_AND %cmp, %one
2892  //
2893  // In this case, G_ICMP only produces a single bit, so x & 1 == x.
2894  assert(MI.getOpcode() == TargetOpcode::G_AND);
2895  if (!KB)
2896    return false;
2897
2898  Register AndDst = MI.getOperand(0).getReg();
2899  Register LHS = MI.getOperand(1).getReg();
2900  Register RHS = MI.getOperand(2).getReg();
2901  KnownBits LHSBits = KB->getKnownBits(LHS);
2902  KnownBits RHSBits = KB->getKnownBits(RHS);
2903
2904  // Check that x & Mask == x.
2905  // x & 1 == x, always
2906  // x & 0 == x, only if x is also 0
2907  // Meaning Mask has no effect if every bit is either one in Mask or zero in x.
2908  //
2909  // Check if we can replace AndDst with the LHS of the G_AND
2910  if (canReplaceReg(AndDst, LHS, MRI) &&
2911      (LHSBits.Zero | RHSBits.One).isAllOnes()) {
2912    Replacement = LHS;
2913    return true;
2914  }
2915
2916  // Check if we can replace AndDst with the RHS of the G_AND
2917  if (canReplaceReg(AndDst, RHS, MRI) &&
2918      (LHSBits.One | RHSBits.Zero).isAllOnes()) {
2919    Replacement = RHS;
2920    return true;
2921  }
2922
2923  return false;
2924}
2925
2926bool CombinerHelper::matchRedundantOr(MachineInstr &MI, Register &Replacement) {
2927  // Given
2928  //
2929  // %y:_(sN) = G_SOMETHING
2930  // %x:_(sN) = G_SOMETHING
2931  // %res:_(sN) = G_OR %x, %y
2932  //
2933  // Eliminate the G_OR when it is known that x | y == x or x | y == y.
2934  assert(MI.getOpcode() == TargetOpcode::G_OR);
2935  if (!KB)
2936    return false;
2937
2938  Register OrDst = MI.getOperand(0).getReg();
2939  Register LHS = MI.getOperand(1).getReg();
2940  Register RHS = MI.getOperand(2).getReg();
2941  KnownBits LHSBits = KB->getKnownBits(LHS);
2942  KnownBits RHSBits = KB->getKnownBits(RHS);
2943
2944  // Check that x | Mask == x.
2945  // x | 0 == x, always
2946  // x | 1 == x, only if x is also 1
2947  // Meaning Mask has no effect if every bit is either zero in Mask or one in x.
2948  //
2949  // Check if we can replace OrDst with the LHS of the G_OR
2950  if (canReplaceReg(OrDst, LHS, MRI) &&
2951      (LHSBits.One | RHSBits.Zero).isAllOnes()) {
2952    Replacement = LHS;
2953    return true;
2954  }
2955
2956  // Check if we can replace OrDst with the RHS of the G_OR
2957  if (canReplaceReg(OrDst, RHS, MRI) &&
2958      (LHSBits.Zero | RHSBits.One).isAllOnes()) {
2959    Replacement = RHS;
2960    return true;
2961  }
2962
2963  return false;
2964}
2965
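/// Match a G_SEXT_INREG whose source is already known, via the sign-bit
/// analysis in GISelKnownBits, to carry at least as many sign bits as the
/// extension would produce; such an extension is a no-op and can be dropped.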
2966bool CombinerHelper::matchRedundantSExtInReg(MachineInstr &MI) {
2967  // If the input is already sign extended, just drop the extension.
2968  Register Src = MI.getOperand(1).getReg();
2969  unsigned ExtBits = MI.getOperand(2).getImm();
2970  unsigned TypeSize = MRI.getType(Src).getScalarSizeInBits();
2971  return KB->computeNumSignBits(Src) >= (TypeSize - ExtBits + 1);
2972}
2973
2974static bool isConstValidTrue(const TargetLowering &TLI, unsigned ScalarSizeBits,
2975                             int64_t Cst, bool IsVector, bool IsFP) {
2976  // For i1, Cst will always be -1 regardless of boolean contents.
2977  return (ScalarSizeBits == 1 && Cst == -1) ||
2978         isConstTrueVal(TLI, Cst, IsVector, IsFP);
2979}
2980
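/// Match xor (cmps, true), where "cmps" is a single-use tree of G_ICMP/G_FCMP
/// results combined with G_AND/G_OR. The xor can be removed by inverting each
/// comparison's predicate and applying De Morgan's laws to the ANDs and ORs;
/// \p RegsToNegate collects every node that needs to be rewritten.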
2981bool CombinerHelper::matchNotCmp(MachineInstr &MI,
2982                                 SmallVectorImpl<Register> &RegsToNegate) {
2983  assert(MI.getOpcode() == TargetOpcode::G_XOR);
2984  LLT Ty = MRI.getType(MI.getOperand(0).getReg());
2985  const auto &TLI = *Builder.getMF().getSubtarget().getTargetLowering();
2986  Register XorSrc;
2987  Register CstReg;
2988  // We match xor(src, true) here.
2989  if (!mi_match(MI.getOperand(0).getReg(), MRI,
2990                m_GXor(m_Reg(XorSrc), m_Reg(CstReg))))
2991    return false;
2992
2993  if (!MRI.hasOneNonDBGUse(XorSrc))
2994    return false;
2995
2996  // Check that XorSrc is the root of a tree of comparisons combined with ANDs
2997  // and ORs. The suffix of RegsToNegate starting from index I is used as a
2998  // work list of tree nodes to visit.
2999  RegsToNegate.push_back(XorSrc);
3000  // Remember whether the comparisons are all integer or all floating point.
3001  bool IsInt = false;
3002  bool IsFP = false;
3003  for (unsigned I = 0; I < RegsToNegate.size(); ++I) {
3004    Register Reg = RegsToNegate[I];
3005    if (!MRI.hasOneNonDBGUse(Reg))
3006      return false;
3007    MachineInstr *Def = MRI.getVRegDef(Reg);
3008    switch (Def->getOpcode()) {
3009    default:
3010      // Don't match if the tree contains anything other than ANDs, ORs and
3011      // comparisons.
3012      return false;
3013    case TargetOpcode::G_ICMP:
3014      if (IsFP)
3015        return false;
3016      IsInt = true;
3017      // When we apply the combine we will invert the predicate.
3018      break;
3019    case TargetOpcode::G_FCMP:
3020      if (IsInt)
3021        return false;
3022      IsFP = true;
3023      // When we apply the combine we will invert the predicate.
3024      break;
3025    case TargetOpcode::G_AND:
3026    case TargetOpcode::G_OR:
3027      // Implement De Morgan's laws:
3028      // ~(x & y) -> ~x | ~y
3029      // ~(x | y) -> ~x & ~y
3030      // When we apply the combine we will change the opcode and recursively
3031      // negate the operands.
3032      RegsToNegate.push_back(Def->getOperand(1).getReg());
3033      RegsToNegate.push_back(Def->getOperand(2).getReg());
3034      break;
3035    }
3036  }
3037
3038  // Now that we know whether the comparisons are integer or floating point,
3039  // check the constant in the xor.
3040  int64_t Cst;
3041  if (Ty.isVector()) {
3042    MachineInstr *CstDef = MRI.getVRegDef(CstReg);
3043    auto MaybeCst = getIConstantSplatSExtVal(*CstDef, MRI);
3044    if (!MaybeCst)
3045      return false;
3046    if (!isConstValidTrue(TLI, Ty.getScalarSizeInBits(), *MaybeCst, true, IsFP))
3047      return false;
3048  } else {
3049    if (!mi_match(CstReg, MRI, m_ICst(Cst)))
3050      return false;
3051    if (!isConstValidTrue(TLI, Ty.getSizeInBits(), Cst, false, IsFP))
3052      return false;
3053  }
3054
3055  return true;
3056}
3057
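/// Apply the rewrite planned by matchNotCmp: invert every recorded
/// comparison's predicate, swap the recorded G_ANDs and G_ORs, and replace
/// the xor with its (now negated) input.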
3058void CombinerHelper::applyNotCmp(MachineInstr &MI,
3059                                 SmallVectorImpl<Register> &RegsToNegate) {
3060  for (Register Reg : RegsToNegate) {
3061    MachineInstr *Def = MRI.getVRegDef(Reg);
3062    Observer.changingInstr(*Def);
3063    // For each comparison, invert the predicate. For each AND and OR, swap the
3064    // opcode.
3065    switch (Def->getOpcode()) {
3066    default:
3067      llvm_unreachable("Unexpected opcode");
3068    case TargetOpcode::G_ICMP:
3069    case TargetOpcode::G_FCMP: {
3070      MachineOperand &PredOp = Def->getOperand(1);
3071      CmpInst::Predicate NewP = CmpInst::getInversePredicate(
3072          (CmpInst::Predicate)PredOp.getPredicate());
3073      PredOp.setPredicate(NewP);
3074      break;
3075    }
3076    case TargetOpcode::G_AND:
3077      Def->setDesc(Builder.getTII().get(TargetOpcode::G_OR));
3078      break;
3079    case TargetOpcode::G_OR:
3080      Def->setDesc(Builder.getTII().get(TargetOpcode::G_AND));
3081      break;
3082    }
3083    Observer.changedInstr(*Def);
3084  }
3085
3086  replaceRegWith(MRI, MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
3087  MI.eraseFromParent();
3088}
3089
3090bool CombinerHelper::matchXorOfAndWithSameReg(
3091    MachineInstr &MI, std::pair<Register, Register> &MatchInfo) {
3092  // Match (xor (and x, y), y) (or any of its commuted cases)
3093  assert(MI.getOpcode() == TargetOpcode::G_XOR);
3094  Register &X = MatchInfo.first;
3095  Register &Y = MatchInfo.second;
3096  Register AndReg = MI.getOperand(1).getReg();
3097  Register SharedReg = MI.getOperand(2).getReg();
3098
3099  // Find a G_AND on either side of the G_XOR.
3100  // Look for one of
3101  //
3102  // (xor (and x, y), SharedReg)
3103  // (xor SharedReg, (and x, y))
3104  if (!mi_match(AndReg, MRI, m_GAnd(m_Reg(X), m_Reg(Y)))) {
3105    std::swap(AndReg, SharedReg);
3106    if (!mi_match(AndReg, MRI, m_GAnd(m_Reg(X), m_Reg(Y))))
3107      return false;
3108  }
3109
3110  // Only do this if we'll eliminate the G_AND.
3111  if (!MRI.hasOneNonDBGUse(AndReg))
3112    return false;
3113
3114  // We can combine if SharedReg is the same as either the LHS or RHS of the
3115  // G_AND.
3116  if (Y != SharedReg)
3117    std::swap(X, Y);
3118  return Y == SharedReg;
3119}
3120
3121void CombinerHelper::applyXorOfAndWithSameReg(
3122    MachineInstr &MI, std::pair<Register, Register> &MatchInfo) {
3123  // Fold (xor (and x, y), y) -> (and (not x), y)
3124  Builder.setInstrAndDebugLoc(MI);
3125  Register X, Y;
3126  std::tie(X, Y) = MatchInfo;
3127  auto Not = Builder.buildNot(MRI.getType(X), X);
3128  Observer.changingInstr(MI);
3129  MI.setDesc(Builder.getTII().get(TargetOpcode::G_AND));
3130  MI.getOperand(1).setReg(Not->getOperand(0).getReg());
3131  MI.getOperand(2).setReg(Y);
3132  Observer.changedInstr(MI);
3133}
3134
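/// Match G_PTR_ADD null, %offset, which can be replaced by a G_INTTOPTR of
/// the offset. This is skipped for non-integral address spaces, where pointer
/// values may not be reinterpreted as plain integers.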
3135bool CombinerHelper::matchPtrAddZero(MachineInstr &MI) {
3136  auto &PtrAdd = cast<GPtrAdd>(MI);
3137  Register DstReg = PtrAdd.getReg(0);
3138  LLT Ty = MRI.getType(DstReg);
3139  const DataLayout &DL = Builder.getMF().getDataLayout();
3140
3141  if (DL.isNonIntegralAddressSpace(Ty.getScalarType().getAddressSpace()))
3142    return false;
3143
3144  if (Ty.isPointer()) {
3145    auto ConstVal = getIConstantVRegVal(PtrAdd.getBaseReg(), MRI);
3146    return ConstVal && *ConstVal == 0;
3147  }
3148
3149  assert(Ty.isVector() && "Expecting a vector type");
3150  const MachineInstr *VecMI = MRI.getVRegDef(PtrAdd.getBaseReg());
3151  return isBuildVectorAllZeros(*VecMI, MRI);
3152}
3153
3154void CombinerHelper::applyPtrAddZero(MachineInstr &MI) {
3155  auto &PtrAdd = cast<GPtrAdd>(MI);
3156  Builder.setInstrAndDebugLoc(PtrAdd);
3157  Builder.buildIntToPtr(PtrAdd.getReg(0), PtrAdd.getOffsetReg());
3158  PtrAdd.eraseFromParent();
3159}
3160
3161/// The second source operand is known to be a power of 2.
3162void CombinerHelper::applySimplifyURemByPow2(MachineInstr &MI) {
3163  Register DstReg = MI.getOperand(0).getReg();
3164  Register Src0 = MI.getOperand(1).getReg();
3165  Register Pow2Src1 = MI.getOperand(2).getReg();
3166  LLT Ty = MRI.getType(DstReg);
3167  Builder.setInstrAndDebugLoc(MI);
3168
3169  // Fold (urem x, pow2) -> (and x, pow2-1)
3170  auto NegOne = Builder.buildConstant(Ty, -1);
3171  auto Add = Builder.buildAdd(Ty, Pow2Src1, NegOne);
3172  Builder.buildAnd(DstReg, Src0, Add);
3173  MI.eraseFromParent();
3174}
3175
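/// Match a binary operator whose LHS or RHS is a single-use G_SELECT between
/// constants and whose other operand is also constant (or, for G_AND/G_OR
/// with all-zeros/all-ones select arms, possibly a variable), e.g.:
///
///   %sel:_(s32) = G_SELECT %cond, %c1, %c2
///   %dst:_(s32) = G_ADD %sel, %c3
///
/// \p SelectOpNo records which operand of the binop is the select.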
3176bool CombinerHelper::matchFoldBinOpIntoSelect(MachineInstr &MI,
3177                                              unsigned &SelectOpNo) {
3178  Register LHS = MI.getOperand(1).getReg();
3179  Register RHS = MI.getOperand(2).getReg();
3180
3181  Register OtherOperandReg = RHS;
3182  SelectOpNo = 1;
3183  MachineInstr *Select = MRI.getVRegDef(LHS);
3184
3185  // Don't do this unless the old select is going away. We want to eliminate the
3186  // binary operator, not replace a binop with a select.
3187  if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3188      !MRI.hasOneNonDBGUse(LHS)) {
3189    OtherOperandReg = LHS;
3190    SelectOpNo = 2;
3191    Select = MRI.getVRegDef(RHS);
3192    if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3193        !MRI.hasOneNonDBGUse(RHS))
3194      return false;
3195  }
3196
3197  MachineInstr *SelectLHS = MRI.getVRegDef(Select->getOperand(2).getReg());
3198  MachineInstr *SelectRHS = MRI.getVRegDef(Select->getOperand(3).getReg());
3199
3200  if (!isConstantOrConstantVector(*SelectLHS, MRI,
3201                                  /*AllowFP*/ true,
3202                                  /*AllowOpaqueConstants*/ false))
3203    return false;
3204  if (!isConstantOrConstantVector(*SelectRHS, MRI,
3205                                  /*AllowFP*/ true,
3206                                  /*AllowOpaqueConstants*/ false))
3207    return false;
3208
3209  unsigned BinOpcode = MI.getOpcode();
3210
3211  // We now know one of the operands is a select of constants. Now verify that
3212  // the other binary operator operand is either a constant, or we can handle a
3213  // variable.
3214  bool CanFoldNonConst =
3215      (BinOpcode == TargetOpcode::G_AND || BinOpcode == TargetOpcode::G_OR) &&
3216      (isNullOrNullSplat(*SelectLHS, MRI) ||
3217       isAllOnesOrAllOnesSplat(*SelectLHS, MRI)) &&
3218      (isNullOrNullSplat(*SelectRHS, MRI) ||
3219       isAllOnesOrAllOnesSplat(*SelectRHS, MRI));
3220  if (CanFoldNonConst)
3221    return true;
3222
3223  return isConstantOrConstantVector(*MRI.getVRegDef(OtherOperandReg), MRI,
3224                                    /*AllowFP*/ true,
3225                                    /*AllowOpaqueConstants*/ false);
3226}
3227
3228/// \p SelectOperand is the operand in binary operator \p MI that is the select
3229/// to fold.
3230bool CombinerHelper::applyFoldBinOpIntoSelect(MachineInstr &MI,
3231                                              const unsigned &SelectOperand) {
3232  Builder.setInstrAndDebugLoc(MI);
3233
3234  Register Dst = MI.getOperand(0).getReg();
3235  Register LHS = MI.getOperand(1).getReg();
3236  Register RHS = MI.getOperand(2).getReg();
3237  MachineInstr *Select = MRI.getVRegDef(MI.getOperand(SelectOperand).getReg());
3238
3239  Register SelectCond = Select->getOperand(1).getReg();
3240  Register SelectTrue = Select->getOperand(2).getReg();
3241  Register SelectFalse = Select->getOperand(3).getReg();
3242
3243  LLT Ty = MRI.getType(Dst);
3244  unsigned BinOpcode = MI.getOpcode();
3245
3246  Register FoldTrue, FoldFalse;
3247
3248  // We have a select-of-constants followed by a binary operator with a
3249  // constant. Eliminate the binop by pulling the constant math into the select.
3250  // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
3251  if (SelectOperand == 1) {
3252    // TODO: SelectionDAG verifies this actually constant folds before
3253    // committing to the combine.
3254
3255    FoldTrue = Builder.buildInstr(BinOpcode, {Ty}, {SelectTrue, RHS}).getReg(0);
3256    FoldFalse =
3257        Builder.buildInstr(BinOpcode, {Ty}, {SelectFalse, RHS}).getReg(0);
3258  } else {
3259    FoldTrue = Builder.buildInstr(BinOpcode, {Ty}, {LHS, SelectTrue}).getReg(0);
3260    FoldFalse =
3261        Builder.buildInstr(BinOpcode, {Ty}, {LHS, SelectFalse}).getReg(0);
3262  }
3263
3264  Builder.buildSelect(Dst, SelectCond, FoldTrue, FoldFalse, MI.getFlags());
3265  MI.eraseFromParent();
3266
3267  return true;
3268}
3269
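/// Walk the single-use tree of G_ORs rooted at \p Root and collect its leaf
/// registers, each of which may be a (possibly shifted) narrow load. Returns
/// std::nullopt if any node has extra uses or the leaf count is zero or odd.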
3270std::optional<SmallVector<Register, 8>>
3271CombinerHelper::findCandidatesForLoadOrCombine(const MachineInstr *Root) const {
3272  assert(Root->getOpcode() == TargetOpcode::G_OR && "Expected G_OR only!");
3273  // We want to detect if Root is part of a tree which represents a bunch
3274  // of loads being merged into a larger load. We'll try to recognize patterns
3275  // like, for example:
3276  //
3277  //  Reg   Reg
3278  //   \    /
3279  //    OR_1   Reg
3280  //     \    /
3281  //      OR_2
3282  //        \     Reg
3283  //         .. /
3284  //        Root
3285  //
3286  //  Reg   Reg   Reg   Reg
3287  //     \ /       \   /
3288  //     OR_1      OR_2
3289  //       \       /
3290  //        \    /
3291  //         ...
3292  //         Root
3293  //
3294  // Each "Reg" may have been produced by a load + some arithmetic. This
3295  // function will save each of them.
3296  SmallVector<Register, 8> RegsToVisit;
3297  SmallVector<const MachineInstr *, 7> Ors = {Root};
3298
3299  // In the "worst" case, we're dealing with a load for each byte. So, there
3300  // are at most #bytes - 1 ORs.
3301  const unsigned MaxIter =
3302      MRI.getType(Root->getOperand(0).getReg()).getSizeInBytes() - 1;
3303  for (unsigned Iter = 0; Iter < MaxIter; ++Iter) {
3304    if (Ors.empty())
3305      break;
3306    const MachineInstr *Curr = Ors.pop_back_val();
3307    Register OrLHS = Curr->getOperand(1).getReg();
3308    Register OrRHS = Curr->getOperand(2).getReg();
3309
3310    // In the combine, we want to eliminate the entire tree.
3311    if (!MRI.hasOneNonDBGUse(OrLHS) || !MRI.hasOneNonDBGUse(OrRHS))
3312      return std::nullopt;
3313
3314    // If it's a G_OR, save it and continue to walk. If it's not, then it's
3315    // something that may be a load + arithmetic.
3316    if (const MachineInstr *Or = getOpcodeDef(TargetOpcode::G_OR, OrLHS, MRI))
3317      Ors.push_back(Or);
3318    else
3319      RegsToVisit.push_back(OrLHS);
3320    if (const MachineInstr *Or = getOpcodeDef(TargetOpcode::G_OR, OrRHS, MRI))
3321      Ors.push_back(Or);
3322    else
3323      RegsToVisit.push_back(OrRHS);
3324  }
3325
3326  // We're going to try and merge each register into a wider power-of-2 type,
3327  // so we ought to have an even number of registers.
3328  if (RegsToVisit.empty() || RegsToVisit.size() % 2 != 0)
3329    return std::nullopt;
3330  return RegsToVisit;
3331}
3332
3333/// Helper function for findLoadOffsetsForLoadOrCombine.
3334///
3335/// Check if \p Reg is the result of loading a \p MemSizeInBits wide value,
3336/// and then moving that value into a specific byte offset.
3337///
3338/// e.g. x[i] << 24
3339///
3340/// \returns The load instruction and the byte offset it is moved into.
3341static std::optional<std::pair<GZExtLoad *, int64_t>>
3342matchLoadAndBytePosition(Register Reg, unsigned MemSizeInBits,
3343                         const MachineRegisterInfo &MRI) {
3344  assert(MRI.hasOneNonDBGUse(Reg) &&
3345         "Expected Reg to only have one non-debug use?");
3346  Register MaybeLoad;
3347  int64_t Shift;
3348  if (!mi_match(Reg, MRI,
3349                m_OneNonDBGUse(m_GShl(m_Reg(MaybeLoad), m_ICst(Shift))))) {
3350    Shift = 0;
3351    MaybeLoad = Reg;
3352  }
3353
3354  if (Shift % MemSizeInBits != 0)
3355    return std::nullopt;
3356
3357  // TODO: Handle other types of loads.
3358  auto *Load = getOpcodeDef<GZExtLoad>(MaybeLoad, MRI);
3359  if (!Load)
3360    return std::nullopt;
3361
3362  if (!Load->isUnordered() || Load->getMemSizeInBits() != MemSizeInBits)
3363    return std::nullopt;
3364
3365  return std::make_pair(Load, Shift / MemSizeInBits);
3366}
3367
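/// For each register in \p RegsToVisit, find the narrow load feeding it and
/// the byte position its value lands in, recording the mapping in
/// \p MemOffset2Idx. All loads must share one base pointer, live in the same
/// basic block, use the same address space, and have no load-fold barriers
/// between them. Returns the load with the lowest index, that index, and the
/// latest load in program order.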
3368std::optional<std::tuple<GZExtLoad *, int64_t, GZExtLoad *>>
3369CombinerHelper::findLoadOffsetsForLoadOrCombine(
3370    SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
3371    const SmallVector<Register, 8> &RegsToVisit, const unsigned MemSizeInBits) {
3372
3373  // Each load found for the pattern. There should be one for each RegsToVisit.
3374  SmallSetVector<const MachineInstr *, 8> Loads;
3375
3376  // The lowest index used in any load. (The lowest "i" for each x[i].)
3377  int64_t LowestIdx = INT64_MAX;
3378
3379  // The load which uses the lowest index.
3380  GZExtLoad *LowestIdxLoad = nullptr;
3381
3382  // Keeps track of the load indices we see. We shouldn't see any indices twice.
3383  SmallSet<int64_t, 8> SeenIdx;
3384
3385  // Ensure each load is in the same MBB.
3386  // TODO: Support multiple MachineBasicBlocks.
3387  MachineBasicBlock *MBB = nullptr;
3388  const MachineMemOperand *MMO = nullptr;
3389
3390  // Earliest instruction-order load in the pattern.
3391  GZExtLoad *EarliestLoad = nullptr;
3392
3393  // Latest instruction-order load in the pattern.
3394  GZExtLoad *LatestLoad = nullptr;
3395
3396  // Base pointer which every load should share.
3397  Register BasePtr;
3398
3399  // We want to find a load for each register. Each load should have some
3400  // appropriate bit twiddling arithmetic. During this loop, we will also keep
3401  // track of the load which uses the lowest index. Later, we will check if we
3402  // can use its pointer in the final, combined load.
3403  for (auto Reg : RegsToVisit) {
3404    // Find the load, and find the position that it will end up in (e.g. as a
3405    // shifted value).
3406    auto LoadAndPos = matchLoadAndBytePosition(Reg, MemSizeInBits, MRI);
3407    if (!LoadAndPos)
3408      return std::nullopt;
3409    GZExtLoad *Load;
3410    int64_t DstPos;
3411    std::tie(Load, DstPos) = *LoadAndPos;
3412
3413    // TODO: Handle multiple MachineBasicBlocks. Currently not handled because
3414    // it is difficult to check for stores/calls/etc between loads.
3415    MachineBasicBlock *LoadMBB = Load->getParent();
3416    if (!MBB)
3417      MBB = LoadMBB;
3418    if (LoadMBB != MBB)
3419      return std::nullopt;
3420
3421    // Make sure that the MachineMemOperands of every seen load are compatible.
3422    auto &LoadMMO = Load->getMMO();
3423    if (!MMO)
3424      MMO = &LoadMMO;
3425    if (MMO->getAddrSpace() != LoadMMO.getAddrSpace())
3426      return std::nullopt;
3427
3428    // Find out what the base pointer and index for the load is.
3429    Register LoadPtr;
3430    int64_t Idx;
3431    if (!mi_match(Load->getOperand(1).getReg(), MRI,
3432                  m_GPtrAdd(m_Reg(LoadPtr), m_ICst(Idx)))) {
3433      LoadPtr = Load->getOperand(1).getReg();
3434      Idx = 0;
3435    }
3436
3437    // Don't combine things like a[i], a[i] -> a bigger load.
3438    if (!SeenIdx.insert(Idx).second)
3439      return std::nullopt;
3440
3441    // Every load must share the same base pointer; don't combine things like:
3442    //
3443    // a[i], b[i + 1] -> a bigger load.
3444    if (!BasePtr.isValid())
3445      BasePtr = LoadPtr;
3446    if (BasePtr != LoadPtr)
3447      return std::nullopt;
3448
3449    if (Idx < LowestIdx) {
3450      LowestIdx = Idx;
3451      LowestIdxLoad = Load;
3452    }
3453
3454    // Keep track of the byte offset that this load ends up at. If we have seen
3455    // the byte offset, then stop here. We do not want to combine:
3456    //
3457    // a[i] << 16, a[i + k] << 16 -> a bigger load.
3458    if (!MemOffset2Idx.try_emplace(DstPos, Idx).second)
3459      return std::nullopt;
3460    Loads.insert(Load);
3461
3462    // Keep track of the position of the earliest/latest loads in the pattern.
3463    // We will check that there are no load fold barriers between them later
3464    // on.
3465    //
3466    // FIXME: Is there a better way to check for load fold barriers?
3467    if (!EarliestLoad || dominates(*Load, *EarliestLoad))
3468      EarliestLoad = Load;
3469    if (!LatestLoad || dominates(*LatestLoad, *Load))
3470      LatestLoad = Load;
3471  }
3472
3473  // We found a load for each register. Let's check if each load satisfies the
3474  // pattern.
3475  assert(Loads.size() == RegsToVisit.size() &&
3476         "Expected to find a load for each register?");
3477  assert(EarliestLoad != LatestLoad && EarliestLoad &&
3478         LatestLoad && "Expected at least two loads?");
3479
3480  // Check if there are any stores, calls, etc. between any of the loads. If
3481  // there are, then we can't safely perform the combine.
3482  //
3483  // MaxIter is chosen based off the (worst case) number of iterations it
3484  // typically takes to succeed in the LLVM test suite plus some padding.
3485  //
3486  // FIXME: Is there a better way to check for load fold barriers?
3487  const unsigned MaxIter = 20;
3488  unsigned Iter = 0;
3489  for (const auto &MI : instructionsWithoutDebug(EarliestLoad->getIterator(),
3490                                                 LatestLoad->getIterator())) {
3491    if (Loads.count(&MI))
3492      continue;
3493    if (MI.isLoadFoldBarrier())
3494      return std::nullopt;
3495    if (Iter++ == MaxIter)
3496      return std::nullopt;
3497  }
3498
3499  return std::make_tuple(LowestIdxLoad, LowestIdx, LatestLoad);
3500}
3501
3502bool CombinerHelper::matchLoadOrCombine(
3503    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
3504  assert(MI.getOpcode() == TargetOpcode::G_OR);
3505  MachineFunction &MF = *MI.getMF();
3506  // Assuming a little-endian target, transform:
3507  //  s8 *a = ...
3508  //  s32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
3509  // =>
3510  //  s32 val = *((i32)a)
3511  //
3512  //  s8 *a = ...
3513  //  s32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
3514  // =>
3515  //  s32 val = BSWAP(*((s32)a))
3516  Register Dst = MI.getOperand(0).getReg();
3517  LLT Ty = MRI.getType(Dst);
3518  if (Ty.isVector())
3519    return false;
3520
3521  // We need to combine at least two loads into this type. Since the smallest
3522  // possible load is into a byte, we need at least a 16-bit wide type.
3523  const unsigned WideMemSizeInBits = Ty.getSizeInBits();
3524  if (WideMemSizeInBits < 16 || WideMemSizeInBits % 8 != 0)
3525    return false;
3526
3527  // Match a collection of non-OR instructions in the pattern.
3528  auto RegsToVisit = findCandidatesForLoadOrCombine(&MI);
3529  if (!RegsToVisit)
3530    return false;
3531
3532  // We have a collection of non-OR instructions. Figure out how wide each of
3533  // the small loads should be based off of the number of potential loads we
3534  // found.
3535  const unsigned NarrowMemSizeInBits = WideMemSizeInBits / RegsToVisit->size();
3536  if (NarrowMemSizeInBits % 8 != 0)
3537    return false;
3538
3539  // Check if each register feeding into each OR is a load from the same
3540  // base pointer + some arithmetic.
3541  //
3542  // e.g. a[0], a[1] << 8, a[2] << 16, etc.
3543  //
3544  // Also verify that each of these ends up putting a[i] into the same memory
3545  // offset as a load into a wide type would.
3546  SmallDenseMap<int64_t, int64_t, 8> MemOffset2Idx;
3547  GZExtLoad *LowestIdxLoad, *LatestLoad;
3548  int64_t LowestIdx;
3549  auto MaybeLoadInfo = findLoadOffsetsForLoadOrCombine(
3550      MemOffset2Idx, *RegsToVisit, NarrowMemSizeInBits);
3551  if (!MaybeLoadInfo)
3552    return false;
3553  std::tie(LowestIdxLoad, LowestIdx, LatestLoad) = *MaybeLoadInfo;
3554
3555  // We have a bunch of loads being OR'd together. Using the addresses + offsets
3556  // we found before, check if this corresponds to a big or little endian byte
3557  // pattern. If it does, then we can represent it using a load + possibly a
3558  // BSWAP.
3559  bool IsBigEndianTarget = MF.getDataLayout().isBigEndian();
3560  std::optional<bool> IsBigEndian = isBigEndian(MemOffset2Idx, LowestIdx);
3561  if (!IsBigEndian)
3562    return false;
3563  bool NeedsBSwap = IsBigEndianTarget != *IsBigEndian;
3564  if (NeedsBSwap && !isLegalOrBeforeLegalizer({TargetOpcode::G_BSWAP, {Ty}}))
3565    return false;
3566
3567  // Make sure that the load from the lowest index produces offset 0 in the
3568  // final value.
3569  //
3570  // This ensures that we won't combine something like this:
3571  //
3572  // load x[i] -> byte 2
3573  // load x[i+1] -> byte 0 ---> wide_load x[i]
3574  // load x[i+2] -> byte 1
3575  const unsigned NumLoadsInTy = WideMemSizeInBits / NarrowMemSizeInBits;
3576  const unsigned ZeroByteOffset =
3577      *IsBigEndian
3578          ? bigEndianByteAt(NumLoadsInTy, 0)
3579          : littleEndianByteAt(NumLoadsInTy, 0);
3580  auto ZeroOffsetIdx = MemOffset2Idx.find(ZeroByteOffset);
3581  if (ZeroOffsetIdx == MemOffset2Idx.end() ||
3582      ZeroOffsetIdx->second != LowestIdx)
3583    return false;
3584
3585  // We will reuse the pointer from the load which ends up at byte offset 0. It
3586  // may not use index 0.
3587  Register Ptr = LowestIdxLoad->getPointerReg();
3588  const MachineMemOperand &MMO = LowestIdxLoad->getMMO();
3589  LegalityQuery::MemDesc MMDesc(MMO);
3590  MMDesc.MemoryTy = Ty;
3591  if (!isLegalOrBeforeLegalizer(
3592          {TargetOpcode::G_LOAD, {Ty, MRI.getType(Ptr)}, {MMDesc}}))
3593    return false;
3594  auto PtrInfo = MMO.getPointerInfo();
3595  auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, WideMemSizeInBits / 8);
3596
3597  // Load must be allowed and fast on the target.
3598  LLVMContext &C = MF.getFunction().getContext();
3599  auto &DL = MF.getDataLayout();
3600  unsigned Fast = 0;
3601  if (!getTargetLowering().allowsMemoryAccess(C, DL, Ty, *NewMMO, &Fast) ||
3602      !Fast)
3603    return false;
3604
3605  MatchInfo = [=](MachineIRBuilder &MIB) {
3606    MIB.setInstrAndDebugLoc(*LatestLoad);
3607    Register LoadDst = NeedsBSwap ? MRI.cloneVirtualRegister(Dst) : Dst;
3608    MIB.buildLoad(LoadDst, Ptr, *NewMMO);
3609    if (NeedsBSwap)
3610      MIB.buildBSwap(Dst, LoadDst);
3611  };
3612  return true;
3613}
3614
3615/// Check if the store \p Store is a truncstore that can be merged. That is,
3616/// it's a store of a shifted value of \p SrcVal. If \p SrcVal is an empty
3617/// Register then it does not need to match and SrcVal is set to the source
3618/// value found.
3619/// On match, returns the start byte offset of the \p SrcVal that is being
3620/// stored.
3621static std::optional<int64_t>
3622getTruncStoreByteOffset(GStore &Store, Register &SrcVal,
3623                        MachineRegisterInfo &MRI) {
3624  Register TruncVal;
3625  if (!mi_match(Store.getValueReg(), MRI, m_GTrunc(m_Reg(TruncVal))))
3626    return std::nullopt;
3627
3628  // The shift amount must be a constant multiple of the narrow type.
3629  // It is translated to the offset address in the wide source value "y".
3630  //
3631  // x = G_LSHR y, ShiftAmtC
3632  // s8 z = G_TRUNC x
3633  // store z, ...
3634  Register FoundSrcVal;
3635  int64_t ShiftAmt;
3636  if (!mi_match(TruncVal, MRI,
3637                m_any_of(m_GLShr(m_Reg(FoundSrcVal), m_ICst(ShiftAmt)),
3638                         m_GAShr(m_Reg(FoundSrcVal), m_ICst(ShiftAmt))))) {
3639    if (!SrcVal.isValid() || TruncVal == SrcVal) {
3640      if (!SrcVal.isValid())
3641        SrcVal = TruncVal;
3642      return 0; // If it's the lowest index store.
3643    }
3644    return std::nullopt;
3645  }
3646
3647  unsigned NarrowBits = Store.getMMO().getMemoryType().getScalarSizeInBits();
3648  if (ShiftAmt % NarrowBits != 0)
3649    return std::nullopt;
3650  const unsigned Offset = ShiftAmt / NarrowBits;
3651
3652  if (SrcVal.isValid() && FoundSrcVal != SrcVal)
3653    return std::nullopt;
3654
3655  if (!SrcVal.isValid())
3656    SrcVal = FoundSrcVal;
3657  else if (MRI.getType(SrcVal) != MRI.getType(FoundSrcVal))
3658    return std::nullopt;
3659  return Offset;
3660}
3661
3662/// Match a pattern where a wide type scalar value is stored by several narrow
3663/// stores. Fold it into a single store or a BSWAP and a store if the target
3664/// supports it.
3665///
3666/// Assuming little endian target:
3667///  i8 *p = ...
3668///  i32 val = ...
3669///  p[0] = (val >> 0) & 0xFF;
3670///  p[1] = (val >> 8) & 0xFF;
3671///  p[2] = (val >> 16) & 0xFF;
3672///  p[3] = (val >> 24) & 0xFF;
3673/// =>
3674///  *((i32)p) = val;
3675///
3676///  i8 *p = ...
3677///  i32 val = ...
3678///  p[0] = (val >> 24) & 0xFF;
3679///  p[1] = (val >> 16) & 0xFF;
3680///  p[2] = (val >> 8) & 0xFF;
3681///  p[3] = (val >> 0) & 0xFF;
3682/// =>
3683///  *((i32)p) = BSWAP(val);
3684bool CombinerHelper::matchTruncStoreMerge(MachineInstr &MI,
3685                                          MergeTruncStoresInfo &MatchInfo) {
3686  auto &StoreMI = cast<GStore>(MI);
3687  LLT MemTy = StoreMI.getMMO().getMemoryType();
3688
3689  // We only handle merging simple stores of 1-4 bytes.
3690  if (!MemTy.isScalar())
3691    return false;
3692  switch (MemTy.getSizeInBits()) {
3693  case 8:
3694  case 16:
3695  case 32:
3696    break;
3697  default:
3698    return false;
3699  }
3700  if (!StoreMI.isSimple())
3701    return false;
3702
3703  // We do a simple search for mergeable stores prior to this one.
3704  // Any potential alias hazard along the way terminates the search.
3705  SmallVector<GStore *> FoundStores;
3706
3707  // We're looking for:
3708  // 1) a (store(trunc(...)))
3709  // 2) of an LSHR/ASHR of a single wide value, by the appropriate shift to get
3710  //    the partial value stored.
3711  // 3) where the offsets form either a little or big-endian sequence.
3712
3713  auto &LastStore = StoreMI;
3714
3715  // The single base pointer that all stores must use.
3716  Register BaseReg;
3717  int64_t LastOffset;
3718  if (!mi_match(LastStore.getPointerReg(), MRI,
3719                m_GPtrAdd(m_Reg(BaseReg), m_ICst(LastOffset)))) {
3720    BaseReg = LastStore.getPointerReg();
3721    LastOffset = 0;
3722  }
3723
3724  GStore *LowestIdxStore = &LastStore;
3725  int64_t LowestIdxOffset = LastOffset;
3726
3727  Register WideSrcVal;
3728  auto LowestShiftAmt = getTruncStoreByteOffset(LastStore, WideSrcVal, MRI);
3729  if (!LowestShiftAmt)
3730    return false; // Didn't match a trunc.
3731  assert(WideSrcVal.isValid());
3732
3733  LLT WideStoreTy = MRI.getType(WideSrcVal);
3734  // The wide type might not be a multiple of the memory type, e.g. s48 and s32.
3735  if (WideStoreTy.getSizeInBits() % MemTy.getSizeInBits() != 0)
3736    return false;
3737  const unsigned NumStoresRequired =
3738      WideStoreTy.getSizeInBits() / MemTy.getSizeInBits();
3739
3740  SmallVector<int64_t, 8> OffsetMap(NumStoresRequired, INT64_MAX);
3741  OffsetMap[*LowestShiftAmt] = LastOffset;
3742  FoundStores.emplace_back(&LastStore);
3743
3744  // Search backwards through the block for more stores.
3745  // We use a search threshold of 10 instructions here because the combiner
3746  // works top-down within a block, and we don't want to search an unbounded
3747  // number of predecessor instructions trying to find matching stores.
3748  // If we moved this optimization into a separate pass then we could probably
3749  // use a more efficient search without having a hard-coded threshold.
3750  const int MaxInstsToCheck = 10;
3751  int NumInstsChecked = 0;
3752  for (auto II = ++LastStore.getReverseIterator();
3753       II != LastStore.getParent()->rend() && NumInstsChecked < MaxInstsToCheck;
3754       ++II) {
3755    NumInstsChecked++;
3756    GStore *NewStore;
3757    if ((NewStore = dyn_cast<GStore>(&*II))) {
3758      if (NewStore->getMMO().getMemoryType() != MemTy || !NewStore->isSimple())
3759        break;
3760    } else if (II->isLoadFoldBarrier() || II->mayLoad()) {
3761      break;
3762    } else {
3763      continue; // This is a safe instruction we can look past.
3764    }
3765
3766    Register NewBaseReg;
3767    int64_t MemOffset;
3768    // Check we're storing to the same base + some offset.
3769    if (!mi_match(NewStore->getPointerReg(), MRI,
3770                  m_GPtrAdd(m_Reg(NewBaseReg), m_ICst(MemOffset)))) {
3771      NewBaseReg = NewStore->getPointerReg();
3772      MemOffset = 0;
3773    }
3774    if (BaseReg != NewBaseReg)
3775      break;
3776
3777    auto ShiftByteOffset = getTruncStoreByteOffset(*NewStore, WideSrcVal, MRI);
3778    if (!ShiftByteOffset)
3779      break;
3780    if (MemOffset < LowestIdxOffset) {
3781      LowestIdxOffset = MemOffset;
3782      LowestIdxStore = NewStore;
3783    }
3784
3785    // Map the position in the combined value to the store's memory offset, and
3786    // stop the search if that position has already been filled.
3787    if (*ShiftByteOffset < 0 || *ShiftByteOffset >= NumStoresRequired ||
3788        OffsetMap[*ShiftByteOffset] != INT64_MAX)
3789      break;
3790    OffsetMap[*ShiftByteOffset] = MemOffset;
3791
3792    FoundStores.emplace_back(NewStore);
3793    // Reset counter since we've found a matching inst.
3794    NumInstsChecked = 0;
3795    if (FoundStores.size() == NumStoresRequired)
3796      break;
3797  }
3798
3799  if (FoundStores.size() != NumStoresRequired) {
3800    return false;
3801  }
3802
3803  const auto &DL = LastStore.getMF()->getDataLayout();
3804  auto &C = LastStore.getMF()->getFunction().getContext();
3805  // Check that a store of the wide type is both allowed and fast on the target
3806  unsigned Fast = 0;
3807  bool Allowed = getTargetLowering().allowsMemoryAccess(
3808      C, DL, WideStoreTy, LowestIdxStore->getMMO(), &Fast);
3809  if (!Allowed || !Fast)
3810    return false;
3811
3812  // Check if the pieces of the value are going to the expected places in memory
3813  // to merge the stores.
3814  unsigned NarrowBits = MemTy.getScalarSizeInBits();
3815  auto checkOffsets = [&](bool MatchLittleEndian) {
3816    if (MatchLittleEndian) {
3817      for (unsigned i = 0; i != NumStoresRequired; ++i)
3818        if (OffsetMap[i] != i * (NarrowBits / 8) + LowestIdxOffset)
3819          return false;
3820    } else { // MatchBigEndian by reversing loop counter.
3821      for (unsigned i = 0, j = NumStoresRequired - 1; i != NumStoresRequired;
3822           ++i, --j)
3823        if (OffsetMap[j] != i * (NarrowBits / 8) + LowestIdxOffset)
3824          return false;
3825    }
3826    return true;
3827  };
3828
3829  // Check if the offsets line up for the native data layout of this target.
3830  bool NeedBswap = false;
3831  bool NeedRotate = false;
3832  if (!checkOffsets(DL.isLittleEndian())) {
3833    // Special-case: check if byte offsets line up for the opposite endian.
3834    if (NarrowBits == 8 && checkOffsets(DL.isBigEndian()))
3835      NeedBswap = true;
3836    else if (NumStoresRequired == 2 && checkOffsets(DL.isBigEndian()))
3837      NeedRotate = true;
3838    else
3839      return false;
3840  }
3841
3842  if (NeedBswap &&
3843      !isLegalOrBeforeLegalizer({TargetOpcode::G_BSWAP, {WideStoreTy}}))
3844    return false;
3845  if (NeedRotate &&
3846      !isLegalOrBeforeLegalizer({TargetOpcode::G_ROTR, {WideStoreTy}}))
3847    return false;
3848
3849  MatchInfo.NeedBSwap = NeedBswap;
3850  MatchInfo.NeedRotate = NeedRotate;
3851  MatchInfo.LowestIdxStore = LowestIdxStore;
3852  MatchInfo.WideSrcVal = WideSrcVal;
3853  MatchInfo.FoundStores = std::move(FoundStores);
3854  return true;
3855}
3856
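/// Replace the narrow stores recorded in \p MatchInfo with a single wide
/// store of the source value, first byte-swapping it or rotating it by half
/// its width if the byte order requires it.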
3857void CombinerHelper::applyTruncStoreMerge(MachineInstr &MI,
3858                                          MergeTruncStoresInfo &MatchInfo) {
3859
3860  Builder.setInstrAndDebugLoc(MI);
3861  Register WideSrcVal = MatchInfo.WideSrcVal;
3862  LLT WideStoreTy = MRI.getType(WideSrcVal);
3863
3864  if (MatchInfo.NeedBSwap) {
3865    WideSrcVal = Builder.buildBSwap(WideStoreTy, WideSrcVal).getReg(0);
3866  } else if (MatchInfo.NeedRotate) {
3867    assert(WideStoreTy.getSizeInBits() % 2 == 0 &&
3868           "Unexpected type for rotate");
3869    auto RotAmt =
3870        Builder.buildConstant(WideStoreTy, WideStoreTy.getSizeInBits() / 2);
3871    WideSrcVal =
3872        Builder.buildRotateRight(WideStoreTy, WideSrcVal, RotAmt).getReg(0);
3873  }
3874
3875  Builder.buildStore(WideSrcVal, MatchInfo.LowestIdxStore->getPointerReg(),
3876                     MatchInfo.LowestIdxStore->getMMO().getPointerInfo(),
3877                     MatchInfo.LowestIdxStore->getMMO().getAlign());
3878
3879  // Erase the old stores.
3880  for (auto *ST : MatchInfo.FoundStores)
3881    ST->eraseFromParent();
3882}
3883
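/// Match a scalar G_PHI whose only user is a G_SEXT/G_ZEXT (a G_ANYEXT is
/// accepted immediately since it is usually free), so that the extend can be
/// pushed into the predecessor blocks and applied to the incoming values
/// instead. Only a small number of distinct incoming defs are accepted, to
/// avoid growing code size.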
3884bool CombinerHelper::matchExtendThroughPhis(MachineInstr &MI,
3885                                            MachineInstr *&ExtMI) {
3886  assert(MI.getOpcode() == TargetOpcode::G_PHI);
3887
3888  Register DstReg = MI.getOperand(0).getReg();
3889
3890  // TODO: Extending a vector may be expensive; don't do this until the
3891  // heuristics are better.
3892  if (MRI.getType(DstReg).isVector())
3893    return false;
3894
3895  // Try to match a phi, whose only use is an extend.
3896  if (!MRI.hasOneNonDBGUse(DstReg))
3897    return false;
3898  ExtMI = &*MRI.use_instr_nodbg_begin(DstReg);
3899  switch (ExtMI->getOpcode()) {
3900  case TargetOpcode::G_ANYEXT:
3901    return true; // G_ANYEXT is usually free.
3902  case TargetOpcode::G_ZEXT:
3903  case TargetOpcode::G_SEXT:
3904    break;
3905  default:
3906    return false;
3907  }
3908
3909  // If the target is likely to fold this extend away, don't propagate.
3910  if (Builder.getTII().isExtendLikelyToBeFolded(*ExtMI, MRI))
3911    return false;
3912
3913  // We don't want to propagate the extends unless there's a good chance that
3914  // they'll be optimized in some way.
3915  // Collect the unique incoming values.
3916  SmallPtrSet<MachineInstr *, 4> InSrcs;
3917  for (unsigned Idx = 1; Idx < MI.getNumOperands(); Idx += 2) {
3918    auto *DefMI = getDefIgnoringCopies(MI.getOperand(Idx).getReg(), MRI);
3919    switch (DefMI->getOpcode()) {
3920    case TargetOpcode::G_LOAD:
3921    case TargetOpcode::G_TRUNC:
3922    case TargetOpcode::G_SEXT:
3923    case TargetOpcode::G_ZEXT:
3924    case TargetOpcode::G_ANYEXT:
3925    case TargetOpcode::G_CONSTANT:
3926      InSrcs.insert(getDefIgnoringCopies(MI.getOperand(Idx).getReg(), MRI));
3927      // Don't try to propagate if there are too many places to create new
3928      // extends, chances are it'll increase code size.
3929      if (InSrcs.size() > 2)
3930        return false;
3931      break;
3932    default:
3933      return false;
3934    }
3935  }
3936  return true;
3937}
3938
3939void CombinerHelper::applyExtendThroughPhis(MachineInstr &MI,
3940                                            MachineInstr *&ExtMI) {
3941  assert(MI.getOpcode() == TargetOpcode::G_PHI);
3942  Register DstReg = ExtMI->getOperand(0).getReg();
3943  LLT ExtTy = MRI.getType(DstReg);
3944
3945  // Propagate the extension into each incoming register's block.
3946  // Use a SetVector here because PHIs can have duplicate edges, and we want
3947  // deterministic iteration order.
3948  SmallSetVector<MachineInstr *, 8> SrcMIs;
3949  SmallDenseMap<MachineInstr *, MachineInstr *, 8> OldToNewSrcMap;
3950  for (unsigned SrcIdx = 1; SrcIdx < MI.getNumOperands(); SrcIdx += 2) {
3951    auto *SrcMI = MRI.getVRegDef(MI.getOperand(SrcIdx).getReg());
3952    if (!SrcMIs.insert(SrcMI))
3953      continue;
3954
3955    // Build an extend after each src inst.
3956    auto *MBB = SrcMI->getParent();
3957    MachineBasicBlock::iterator InsertPt = ++SrcMI->getIterator();
3958    if (InsertPt != MBB->end() && InsertPt->isPHI())
3959      InsertPt = MBB->getFirstNonPHI();
3960
3961    Builder.setInsertPt(*SrcMI->getParent(), InsertPt);
3962    Builder.setDebugLoc(MI.getDebugLoc());
3963    auto NewExt = Builder.buildExtOrTrunc(ExtMI->getOpcode(), ExtTy,
3964                                          SrcMI->getOperand(0).getReg());
3965    OldToNewSrcMap[SrcMI] = NewExt;
3966  }
3967
3968  // Create a new phi with the extended inputs.
3969  Builder.setInstrAndDebugLoc(MI);
3970  auto NewPhi = Builder.buildInstrNoInsert(TargetOpcode::G_PHI);
3971  NewPhi.addDef(DstReg);
3972  for (const MachineOperand &MO : llvm::drop_begin(MI.operands())) {
3973    if (!MO.isReg()) {
3974      NewPhi.addMBB(MO.getMBB());
3975      continue;
3976    }
3977    auto *NewSrc = OldToNewSrcMap[MRI.getVRegDef(MO.getReg())];
3978    NewPhi.addUse(NewSrc->getOperand(0).getReg());
3979  }
3980  Builder.insertInstr(NewPhi);
3981  ExtMI->eraseFromParent();
3982}
3983
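/// Match extract_vector_elt (build_vector ...), constant-index and record the
/// build_vector source register selected by the index, e.g.:
///
///   %vec:_(<2 x s32>) = G_BUILD_VECTOR %a, %b
///   %elt:_(s32) = G_EXTRACT_VECTOR_ELT %vec, 1
///     -->
///   use %b in place of %elt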
3984bool CombinerHelper::matchExtractVecEltBuildVec(MachineInstr &MI,
3985                                                Register &Reg) {
3986  assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);
3987  // If we have a constant index, look for a G_BUILD_VECTOR source
3988  // and find the source register that the index maps to.
3989  Register SrcVec = MI.getOperand(1).getReg();
3990  LLT SrcTy = MRI.getType(SrcVec);
3991
3992  auto Cst = getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
3993  if (!Cst || Cst->Value.getZExtValue() >= SrcTy.getNumElements())
3994    return false;
3995
3996  unsigned VecIdx = Cst->Value.getZExtValue();
3997
3998  // Check if we have a build_vector or build_vector_trunc with an optional
3999  // trunc in front.
4000  MachineInstr *SrcVecMI = MRI.getVRegDef(SrcVec);
4001  if (SrcVecMI->getOpcode() == TargetOpcode::G_TRUNC) {
4002    SrcVecMI = MRI.getVRegDef(SrcVecMI->getOperand(1).getReg());
4003  }
4004
4005  if (SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR &&
4006      SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR_TRUNC)
4007    return false;
4008
4009  EVT Ty(getMVTForLLT(SrcTy));
4010  if (!MRI.hasOneNonDBGUse(SrcVec) &&
4011      !getTargetLowering().aggressivelyPreferBuildVectorSources(Ty))
4012    return false;
4013
4014  Reg = SrcVecMI->getOperand(VecIdx + 1).getReg();
4015  return true;
4016}
4017
4018void CombinerHelper::applyExtractVecEltBuildVec(MachineInstr &MI,
4019                                                Register &Reg) {
4020  // Check the type of the register, since it may have come from a
4021  // G_BUILD_VECTOR_TRUNC.
4022  LLT ScalarTy = MRI.getType(Reg);
4023  Register DstReg = MI.getOperand(0).getReg();
4024  LLT DstTy = MRI.getType(DstReg);
4025
4026  Builder.setInstrAndDebugLoc(MI);
4027  if (ScalarTy != DstTy) {
4028    assert(ScalarTy.getSizeInBits() > DstTy.getSizeInBits());
4029    Builder.buildTrunc(DstReg, Reg);
4030    MI.eraseFromParent();
4031    return;
4032  }
4033  replaceSingleDefInstWithReg(MI, Reg);
4034}
4035
4036bool CombinerHelper::matchExtractAllEltsFromBuildVector(
4037    MachineInstr &MI,
4038    SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) {
4039  assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
4040  // This combine tries to find build_vector's which have every source element
4041  // extracted using G_EXTRACT_VECTOR_ELT. This can happen when transforms like
4042  // the masked load scalarization is run late in the pipeline. There's already
4043  // a combine for a similar pattern starting from the extract, but that
4044  // doesn't attempt to do it if there are multiple uses of the build_vector,
4045  // which in this case is true. Starting the combine from the build_vector
4046  // feels more natural than trying to find sibling nodes of extracts.
4047  // E.g.
4048  //  %vec(<4 x s32>) = G_BUILD_VECTOR %s1(s32), %s2, %s3, %s4
4049  //  %ext1 = G_EXTRACT_VECTOR_ELT %vec, 0
4050  //  %ext2 = G_EXTRACT_VECTOR_ELT %vec, 1
4051  //  %ext3 = G_EXTRACT_VECTOR_ELT %vec, 2
4052  //  %ext4 = G_EXTRACT_VECTOR_ELT %vec, 3
4053  // ==>
4054  // replace ext{1,2,3,4} with %s{1,2,3,4}
4055
4056  Register DstReg = MI.getOperand(0).getReg();
4057  LLT DstTy = MRI.getType(DstReg);
4058  unsigned NumElts = DstTy.getNumElements();
4059
4060  SmallBitVector ExtractedElts(NumElts);
4061  for (MachineInstr &II : MRI.use_nodbg_instructions(DstReg)) {
4062    if (II.getOpcode() != TargetOpcode::G_EXTRACT_VECTOR_ELT)
4063      return false;
4064    auto Cst = getIConstantVRegVal(II.getOperand(2).getReg(), MRI);
4065    if (!Cst)
4066      return false;
4067    unsigned Idx = Cst->getZExtValue();
4068    if (Idx >= NumElts)
4069      return false; // Out of range.
4070    ExtractedElts.set(Idx);
4071    SrcDstPairs.emplace_back(
4072        std::make_pair(MI.getOperand(Idx + 1).getReg(), &II));
4073  }
4074  // Match if every element was extracted.
4075  return ExtractedElts.all();
4076}
4077
4078void CombinerHelper::applyExtractAllEltsFromBuildVector(
4079    MachineInstr &MI,
4080    SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) {
4081  assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
4082  for (auto &Pair : SrcDstPairs) {
4083    auto *ExtMI = Pair.second;
4084    replaceRegWith(MRI, ExtMI->getOperand(0).getReg(), Pair.first);
4085    ExtMI->eraseFromParent();
4086  }
4087  MI.eraseFromParent();
4088}
4089
4090void CombinerHelper::applyBuildFn(
4091    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4092  Builder.setInstrAndDebugLoc(MI);
4093  MatchInfo(Builder);
4094  MI.eraseFromParent();
4095}
4096
4097void CombinerHelper::applyBuildFnNoErase(
4098    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4099  Builder.setInstrAndDebugLoc(MI);
4100  MatchInfo(Builder);
4101}
4102
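/// Match (or (shl x, C0), (lshr y, C1)) where C0 + C1 equals the bit width,
/// or the equivalent forms that shift by amt and (sub bitwidth, amt), and
/// rewrite the G_OR as a G_FSHL/G_FSHR funnel shift.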
4103bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI,
4104                                               BuildFnTy &MatchInfo) {
4105  assert(MI.getOpcode() == TargetOpcode::G_OR);
4106
4107  Register Dst = MI.getOperand(0).getReg();
4108  LLT Ty = MRI.getType(Dst);
4109  unsigned BitWidth = Ty.getScalarSizeInBits();
4110
4111  Register ShlSrc, ShlAmt, LShrSrc, LShrAmt, Amt;
4112  unsigned FshOpc = 0;
4113
4114  // Match (or (shl ...), (lshr ...)).
4115  if (!mi_match(Dst, MRI,
4116                // m_GOr() handles the commuted version as well.
4117                m_GOr(m_GShl(m_Reg(ShlSrc), m_Reg(ShlAmt)),
4118                      m_GLShr(m_Reg(LShrSrc), m_Reg(LShrAmt)))))
4119    return false;
4120
4121  // Given constants C0 and C1 such that C0 + C1 is bit-width:
4122  // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1)
4123  int64_t CstShlAmt, CstLShrAmt;
4124  if (mi_match(ShlAmt, MRI, m_ICstOrSplat(CstShlAmt)) &&
4125      mi_match(LShrAmt, MRI, m_ICstOrSplat(CstLShrAmt)) &&
4126      CstShlAmt + CstLShrAmt == BitWidth) {
4127    FshOpc = TargetOpcode::G_FSHR;
4128    Amt = LShrAmt;
4129
4130  } else if (mi_match(LShrAmt, MRI,
4131                      m_GSub(m_SpecificICstOrSplat(BitWidth), m_Reg(Amt))) &&
4132             ShlAmt == Amt) {
4133    // (or (shl x, amt), (lshr y, (sub bw, amt))) -> (fshl x, y, amt)
4134    FshOpc = TargetOpcode::G_FSHL;
4135
4136  } else if (mi_match(ShlAmt, MRI,
4137                      m_GSub(m_SpecificICstOrSplat(BitWidth), m_Reg(Amt))) &&
4138             LShrAmt == Amt) {
4139    // (or (shl x, (sub bw, amt)), (lshr y, amt)) -> (fshr x, y, amt)
4140    FshOpc = TargetOpcode::G_FSHR;
4141
4142  } else {
4143    return false;
4144  }
4145
4146  LLT AmtTy = MRI.getType(Amt);
4147  if (!isLegalOrBeforeLegalizer({FshOpc, {Ty, AmtTy}}))
4148    return false;
4149
4150  MatchInfo = [=](MachineIRBuilder &B) {
4151    B.buildInstr(FshOpc, {Dst}, {ShlSrc, LShrSrc, Amt});
4152  };
4153  return true;
4154}
4155
4156/// Match an FSHL or FSHR that can be combined to a ROTR or ROTL rotate.
4157bool CombinerHelper::matchFunnelShiftToRotate(MachineInstr &MI) {
4158  unsigned Opc = MI.getOpcode();
4159  assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
4160  Register X = MI.getOperand(1).getReg();
4161  Register Y = MI.getOperand(2).getReg();
4162  if (X != Y)
4163    return false;
4164  unsigned RotateOpc =
4165      Opc == TargetOpcode::G_FSHL ? TargetOpcode::G_ROTL : TargetOpcode::G_ROTR;
4166  return isLegalOrBeforeLegalizer({RotateOpc, {MRI.getType(X), MRI.getType(Y)}});
4167}
4168
4169void CombinerHelper::applyFunnelShiftToRotate(MachineInstr &MI) {
4170  unsigned Opc = MI.getOpcode();
4171  assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
4172  bool IsFSHL = Opc == TargetOpcode::G_FSHL;
4173  Observer.changingInstr(MI);
4174  MI.setDesc(Builder.getTII().get(IsFSHL ? TargetOpcode::G_ROTL
4175                                         : TargetOpcode::G_ROTR));
4176  MI.removeOperand(2);
4177  Observer.changedInstr(MI);
4178}
4179
4180// Fold (rot x, c) -> (rot x, c % BitSize)
4181bool CombinerHelper::matchRotateOutOfRange(MachineInstr &MI) {
4182  assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
4183         MI.getOpcode() == TargetOpcode::G_ROTR);
4184  unsigned Bitsize =
4185      MRI.getType(MI.getOperand(0).getReg()).getScalarSizeInBits();
4186  Register AmtReg = MI.getOperand(2).getReg();
4187  bool OutOfRange = false;
4188  auto MatchOutOfRange = [Bitsize, &OutOfRange](const Constant *C) {
4189    if (auto *CI = dyn_cast<ConstantInt>(C))
4190      OutOfRange |= CI->getValue().uge(Bitsize);
4191    return true;
4192  };
4193  return matchUnaryPredicate(MRI, AmtReg, MatchOutOfRange) && OutOfRange;
4194}
4195
4196void CombinerHelper::applyRotateOutOfRange(MachineInstr &MI) {
4197  assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
4198         MI.getOpcode() == TargetOpcode::G_ROTR);
4199  unsigned Bitsize =
4200      MRI.getType(MI.getOperand(0).getReg()).getScalarSizeInBits();
4201  Builder.setInstrAndDebugLoc(MI);
4202  Register Amt = MI.getOperand(2).getReg();
4203  LLT AmtTy = MRI.getType(Amt);
4204  auto Bits = Builder.buildConstant(AmtTy, Bitsize);
4205  Amt = Builder.buildURem(AmtTy, MI.getOperand(2).getReg(), Bits).getReg(0);
4206  Observer.changingInstr(MI);
4207  MI.getOperand(2).setReg(Amt);
4208  Observer.changedInstr(MI);
4209}
4210
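/// Try to evaluate a G_ICMP using known bits. If the predicate is provably
/// true or false for every possible operand value, \p MatchInfo is set to the
/// target's "true" value or to 0 respectively so the compare can be replaced
/// by a constant.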
4211bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI,
4212                                                   int64_t &MatchInfo) {
4213  assert(MI.getOpcode() == TargetOpcode::G_ICMP);
4214  auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
4215  auto KnownLHS = KB->getKnownBits(MI.getOperand(2).getReg());
4216  auto KnownRHS = KB->getKnownBits(MI.getOperand(3).getReg());
4217  std::optional<bool> KnownVal;
4218  switch (Pred) {
4219  default:
4220    llvm_unreachable("Unexpected G_ICMP predicate?");
4221  case CmpInst::ICMP_EQ:
4222    KnownVal = KnownBits::eq(KnownLHS, KnownRHS);
4223    break;
4224  case CmpInst::ICMP_NE:
4225    KnownVal = KnownBits::ne(KnownLHS, KnownRHS);
4226    break;
4227  case CmpInst::ICMP_SGE:
4228    KnownVal = KnownBits::sge(KnownLHS, KnownRHS);
4229    break;
4230  case CmpInst::ICMP_SGT:
4231    KnownVal = KnownBits::sgt(KnownLHS, KnownRHS);
4232    break;
4233  case CmpInst::ICMP_SLE:
4234    KnownVal = KnownBits::sle(KnownLHS, KnownRHS);
4235    break;
4236  case CmpInst::ICMP_SLT:
4237    KnownVal = KnownBits::slt(KnownLHS, KnownRHS);
4238    break;
4239  case CmpInst::ICMP_UGE:
4240    KnownVal = KnownBits::uge(KnownLHS, KnownRHS);
4241    break;
4242  case CmpInst::ICMP_UGT:
4243    KnownVal = KnownBits::ugt(KnownLHS, KnownRHS);
4244    break;
4245  case CmpInst::ICMP_ULE:
4246    KnownVal = KnownBits::ule(KnownLHS, KnownRHS);
4247    break;
4248  case CmpInst::ICMP_ULT:
4249    KnownVal = KnownBits::ult(KnownLHS, KnownRHS);
4250    break;
4251  }
4252  if (!KnownVal)
4253    return false;
4254  MatchInfo =
4255      *KnownVal
4256          ? getICmpTrueVal(getTargetLowering(),
4257                           /*IsVector = */
4258                           MRI.getType(MI.getOperand(0).getReg()).isVector(),
4259                           /* IsFP = */ false)
4260          : 0;
4261  return true;
4262}
4263
4264bool CombinerHelper::matchICmpToLHSKnownBits(
4265    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4266  assert(MI.getOpcode() == TargetOpcode::G_ICMP);
4267  // Given:
4268  //
4269  // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
4270  // %cmp = G_ICMP ne %x, 0
4271  //
4272  // Or:
4273  //
4274  // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
4275  // %cmp = G_ICMP eq %x, 1
4276  //
4277  // We can replace %cmp with %x assuming true is 1 on the target.
4278  auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
4279  if (!CmpInst::isEquality(Pred))
4280    return false;
4281  Register Dst = MI.getOperand(0).getReg();
4282  LLT DstTy = MRI.getType(Dst);
4283  if (getICmpTrueVal(getTargetLowering(), DstTy.isVector(),
4284                     /* IsFP = */ false) != 1)
4285    return false;
4286  int64_t OneOrZero = Pred == CmpInst::ICMP_EQ;
4287  if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICst(OneOrZero)))
4288    return false;
4289  Register LHS = MI.getOperand(2).getReg();
4290  auto KnownLHS = KB->getKnownBits(LHS);
4291  if (KnownLHS.getMinValue() != 0 || KnownLHS.getMaxValue() != 1)
4292    return false;
4293  // Make sure replacing Dst with the LHS is a legal operation.
4294  LLT LHSTy = MRI.getType(LHS);
4295  unsigned LHSSize = LHSTy.getSizeInBits();
4296  unsigned DstSize = DstTy.getSizeInBits();
4297  unsigned Op = TargetOpcode::COPY;
4298  if (DstSize != LHSSize)
4299    Op = DstSize < LHSSize ? TargetOpcode::G_TRUNC : TargetOpcode::G_ZEXT;
4300  if (!isLegalOrBeforeLegalizer({Op, {DstTy, LHSTy}}))
4301    return false;
4302  MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Op, {Dst}, {LHS}); };
4303  return true;
4304}
4305
4306// Replace (and (or x, c1), c2) with (and x, c2) iff c1 & c2 == 0
4307bool CombinerHelper::matchAndOrDisjointMask(
4308    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4309  assert(MI.getOpcode() == TargetOpcode::G_AND);
4310
4311  // Ignore vector types to simplify matching the two constants.
4312  // TODO: do this for vectors and scalars via a demanded bits analysis.
4313  LLT Ty = MRI.getType(MI.getOperand(0).getReg());
4314  if (Ty.isVector())
4315    return false;
4316
4317  Register Src;
4318  Register AndMaskReg;
4319  int64_t AndMaskBits;
4320  int64_t OrMaskBits;
4321  if (!mi_match(MI, MRI,
4322                m_GAnd(m_GOr(m_Reg(Src), m_ICst(OrMaskBits)),
4323                       m_all_of(m_ICst(AndMaskBits), m_Reg(AndMaskReg)))))
4324    return false;
4325
  // The OR can be dropped only if it cannot turn on any bit that survives the
  // AND mask.
4327  if (AndMaskBits & OrMaskBits)
4328    return false;
4329
4330  MatchInfo = [=, &MI](MachineIRBuilder &B) {
4331    Observer.changingInstr(MI);
4332    // Canonicalize the result to have the constant on the RHS.
4333    if (MI.getOperand(1).getReg() == AndMaskReg)
4334      MI.getOperand(2).setReg(AndMaskReg);
4335    MI.getOperand(1).setReg(Src);
4336    Observer.changedInstr(MI);
4337  };
4338  return true;
4339}
4340
4341/// Form a G_SBFX from a G_SEXT_INREG fed by a right shift.
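/// For example (illustrative MIR; a 32-bit scalar type and the register names
/// are assumed):
///   %sh = G_LSHR %x, 4
///   %d  = G_SEXT_INREG %sh, 8
/// becomes
///   %d  = G_SBFX %x, 4, 8   ; lsb = 4, width = 8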
4342bool CombinerHelper::matchBitfieldExtractFromSExtInReg(
4343    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4344  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
4345  Register Dst = MI.getOperand(0).getReg();
4346  Register Src = MI.getOperand(1).getReg();
4347  LLT Ty = MRI.getType(Src);
4348  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
4349  if (!LI || !LI->isLegalOrCustom({TargetOpcode::G_SBFX, {Ty, ExtractTy}}))
4350    return false;
4351  int64_t Width = MI.getOperand(2).getImm();
4352  Register ShiftSrc;
4353  int64_t ShiftImm;
4354  if (!mi_match(
4355          Src, MRI,
4356          m_OneNonDBGUse(m_any_of(m_GAShr(m_Reg(ShiftSrc), m_ICst(ShiftImm)),
4357                                  m_GLShr(m_Reg(ShiftSrc), m_ICst(ShiftImm))))))
4358    return false;
4359  if (ShiftImm < 0 || ShiftImm + Width > Ty.getScalarSizeInBits())
4360    return false;
4361
4362  MatchInfo = [=](MachineIRBuilder &B) {
4363    auto Cst1 = B.buildConstant(ExtractTy, ShiftImm);
4364    auto Cst2 = B.buildConstant(ExtractTy, Width);
4365    B.buildSbfx(Dst, ShiftSrc, Cst1, Cst2);
4366  };
4367  return true;
4368}
4369
4370/// Form a G_UBFX from "(a srl b) & mask", where b and mask are constants.
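/// For example (illustrative MIR; a 32-bit scalar type is assumed):
///   %sh = G_LSHR %a, 8
///   %d  = G_AND %sh, 255
/// becomes
///   %d  = G_UBFX %a, 8, 8   ; lsb = 8, width = 8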
4371bool CombinerHelper::matchBitfieldExtractFromAnd(
4372    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4373  assert(MI.getOpcode() == TargetOpcode::G_AND);
4374  Register Dst = MI.getOperand(0).getReg();
4375  LLT Ty = MRI.getType(Dst);
4376  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
4377  if (!getTargetLowering().isConstantUnsignedBitfieldExtractLegal(
4378          TargetOpcode::G_UBFX, Ty, ExtractTy))
4379    return false;
4380
4381  int64_t AndImm, LSBImm;
4382  Register ShiftSrc;
4383  const unsigned Size = Ty.getScalarSizeInBits();
4384  if (!mi_match(MI.getOperand(0).getReg(), MRI,
4385                m_GAnd(m_OneNonDBGUse(m_GLShr(m_Reg(ShiftSrc), m_ICst(LSBImm))),
4386                       m_ICst(AndImm))))
4387    return false;
4388
4389  // The mask is a mask of the low bits iff imm & (imm+1) == 0.
4390  auto MaybeMask = static_cast<uint64_t>(AndImm);
4391  if (MaybeMask & (MaybeMask + 1))
4392    return false;
4393
4394  // LSB must fit within the register.
4395  if (static_cast<uint64_t>(LSBImm) >= Size)
4396    return false;
4397
4398  uint64_t Width = APInt(Size, AndImm).countTrailingOnes();
4399  MatchInfo = [=](MachineIRBuilder &B) {
4400    auto WidthCst = B.buildConstant(ExtractTy, Width);
4401    auto LSBCst = B.buildConstant(ExtractTy, LSBImm);
4402    B.buildInstr(TargetOpcode::G_UBFX, {Dst}, {ShiftSrc, LSBCst, WidthCst});
4403  };
4404  return true;
4405}
4406
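/// Form a G_SBFX/G_UBFX from "shr (shl x, c1), c2", where c1 and c2 are
/// constants. For example (illustrative MIR; a 32-bit scalar type is assumed):
///   %sh = G_SHL %x, 4
///   %d  = G_LSHR %sh, 8
/// becomes
///   %d  = G_UBFX %x, 4, 24   ; lsb = c2 - c1, width = 32 - c2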
4407bool CombinerHelper::matchBitfieldExtractFromShr(
4408    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4409  const unsigned Opcode = MI.getOpcode();
4410  assert(Opcode == TargetOpcode::G_ASHR || Opcode == TargetOpcode::G_LSHR);
4411
4412  const Register Dst = MI.getOperand(0).getReg();
4413
4414  const unsigned ExtrOpcode = Opcode == TargetOpcode::G_ASHR
4415                                  ? TargetOpcode::G_SBFX
4416                                  : TargetOpcode::G_UBFX;
4417
4418  // Check if the type we would use for the extract is legal
4419  LLT Ty = MRI.getType(Dst);
4420  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
4421  if (!LI || !LI->isLegalOrCustom({ExtrOpcode, {Ty, ExtractTy}}))
4422    return false;
4423
4424  Register ShlSrc;
4425  int64_t ShrAmt;
4426  int64_t ShlAmt;
4427  const unsigned Size = Ty.getScalarSizeInBits();
4428
4429  // Try to match shr (shl x, c1), c2
4430  if (!mi_match(Dst, MRI,
4431                m_BinOp(Opcode,
4432                        m_OneNonDBGUse(m_GShl(m_Reg(ShlSrc), m_ICst(ShlAmt))),
4433                        m_ICst(ShrAmt))))
4434    return false;
4435
4436  // Make sure that the shift sizes can fit a bitfield extract
4437  if (ShlAmt < 0 || ShlAmt > ShrAmt || ShrAmt >= Size)
4438    return false;
4439
4440  // Skip this combine if the G_SEXT_INREG combine could handle it
4441  if (Opcode == TargetOpcode::G_ASHR && ShlAmt == ShrAmt)
4442    return false;
4443
4444  // Calculate start position and width of the extract
4445  const int64_t Pos = ShrAmt - ShlAmt;
4446  const int64_t Width = Size - ShrAmt;
4447
4448  MatchInfo = [=](MachineIRBuilder &B) {
4449    auto WidthCst = B.buildConstant(ExtractTy, Width);
4450    auto PosCst = B.buildConstant(ExtractTy, Pos);
4451    B.buildInstr(ExtrOpcode, {Dst}, {ShlSrc, PosCst, WidthCst});
4452  };
4453  return true;
4454}
4455
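/// Form a G_UBFX from "shr (and x, mask), c", where mask and c are constants
/// and the mask bits that survive the shift form a contiguous run of low bits
/// in the result. For example (illustrative MIR; a 32-bit scalar type is
/// assumed):
///   %a = G_AND %x, 0xFF0
///   %d = G_LSHR %a, 4
/// becomes
///   %d = G_UBFX %x, 4, 8   ; lsb = 4, width = 8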
4456bool CombinerHelper::matchBitfieldExtractFromShrAnd(
4457    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4458  const unsigned Opcode = MI.getOpcode();
4459  assert(Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_ASHR);
4460
4461  const Register Dst = MI.getOperand(0).getReg();
4462  LLT Ty = MRI.getType(Dst);
4463  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
4464  if (!getTargetLowering().isConstantUnsignedBitfieldExtractLegal(
4465          TargetOpcode::G_UBFX, Ty, ExtractTy))
4466    return false;
4467
4468  // Try to match shr (and x, c1), c2
4469  Register AndSrc;
4470  int64_t ShrAmt;
4471  int64_t SMask;
4472  if (!mi_match(Dst, MRI,
4473                m_BinOp(Opcode,
4474                        m_OneNonDBGUse(m_GAnd(m_Reg(AndSrc), m_ICst(SMask))),
4475                        m_ICst(ShrAmt))))
4476    return false;
4477
4478  const unsigned Size = Ty.getScalarSizeInBits();
4479  if (ShrAmt < 0 || ShrAmt >= Size)
4480    return false;
4481
4482  // If the shift subsumes the mask, emit the 0 directly.
4483  if (0 == (SMask >> ShrAmt)) {
4484    MatchInfo = [=](MachineIRBuilder &B) {
4485      B.buildConstant(Dst, 0);
4486    };
4487    return true;
4488  }
4489
4490  // Check that ubfx can do the extraction, with no holes in the mask.
4491  uint64_t UMask = SMask;
4492  UMask |= maskTrailingOnes<uint64_t>(ShrAmt);
4493  UMask &= maskTrailingOnes<uint64_t>(Size);
4494  if (!isMask_64(UMask))
4495    return false;
4496
4497  // Calculate start position and width of the extract.
4498  const int64_t Pos = ShrAmt;
4499  const int64_t Width = countTrailingOnes(UMask) - ShrAmt;
4500
4501  // It's preferable to keep the shift, rather than form G_SBFX.
4502  // TODO: remove the G_AND via demanded bits analysis.
4503  if (Opcode == TargetOpcode::G_ASHR && Width + ShrAmt == Size)
4504    return false;
4505
4506  MatchInfo = [=](MachineIRBuilder &B) {
4507    auto WidthCst = B.buildConstant(ExtractTy, Width);
4508    auto PosCst = B.buildConstant(ExtractTy, Pos);
4509    B.buildInstr(TargetOpcode::G_UBFX, {Dst}, {AndSrc, PosCst, WidthCst});
4510  };
4511  return true;
4512}
4513
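/// Return true if reassociating the constant offsets feeding \p PtrAdd (a
/// G_PTR_ADD whose base is itself a G_PTR_ADD with a constant offset) could
/// replace an offset that the target can currently fold into a load/store
/// addressing mode with a combined offset that it cannot.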
4514bool CombinerHelper::reassociationCanBreakAddressingModePattern(
4515    MachineInstr &PtrAdd) {
4516  assert(PtrAdd.getOpcode() == TargetOpcode::G_PTR_ADD);
4517
4518  Register Src1Reg = PtrAdd.getOperand(1).getReg();
4519  MachineInstr *Src1Def = getOpcodeDef(TargetOpcode::G_PTR_ADD, Src1Reg, MRI);
4520  if (!Src1Def)
4521    return false;
4522
4523  Register Src2Reg = PtrAdd.getOperand(2).getReg();
4524
4525  if (MRI.hasOneNonDBGUse(Src1Reg))
4526    return false;
4527
4528  auto C1 = getIConstantVRegVal(Src1Def->getOperand(2).getReg(), MRI);
4529  if (!C1)
4530    return false;
4531  auto C2 = getIConstantVRegVal(Src2Reg, MRI);
4532  if (!C2)
4533    return false;
4534
4535  const APInt &C1APIntVal = *C1;
4536  const APInt &C2APIntVal = *C2;
4537  const int64_t CombinedValue = (C1APIntVal + C2APIntVal).getSExtValue();
4538
4539  for (auto &UseMI : MRI.use_nodbg_instructions(Src1Reg)) {
4540    // This combine may end up running before ptrtoint/inttoptr combines
4541    // manage to eliminate redundant conversions, so try to look through them.
4542    MachineInstr *ConvUseMI = &UseMI;
4543    unsigned ConvUseOpc = ConvUseMI->getOpcode();
4544    while (ConvUseOpc == TargetOpcode::G_INTTOPTR ||
4545           ConvUseOpc == TargetOpcode::G_PTRTOINT) {
4546      Register DefReg = ConvUseMI->getOperand(0).getReg();
4547      if (!MRI.hasOneNonDBGUse(DefReg))
4548        break;
4549      ConvUseMI = &*MRI.use_instr_nodbg_begin(DefReg);
4550      ConvUseOpc = ConvUseMI->getOpcode();
4551    }
4552    auto LoadStore = ConvUseOpc == TargetOpcode::G_LOAD ||
4553                     ConvUseOpc == TargetOpcode::G_STORE;
4554    if (!LoadStore)
4555      continue;
4556    // Is x[offset2] already not a legal addressing mode? If so then
4557    // reassociating the constants breaks nothing (we test offset2 because
4558    // that's the one we hope to fold into the load or store).
4559    TargetLoweringBase::AddrMode AM;
4560    AM.HasBaseReg = true;
4561    AM.BaseOffs = C2APIntVal.getSExtValue();
4562    unsigned AS =
4563        MRI.getType(ConvUseMI->getOperand(1).getReg()).getAddressSpace();
4564    Type *AccessTy =
4565        getTypeForLLT(MRI.getType(ConvUseMI->getOperand(0).getReg()),
4566                      PtrAdd.getMF()->getFunction().getContext());
4567    const auto &TLI = *PtrAdd.getMF()->getSubtarget().getTargetLowering();
4568    if (!TLI.isLegalAddressingMode(PtrAdd.getMF()->getDataLayout(), AM,
4569                                   AccessTy, AS))
4570      continue;
4571
4572    // Would x[offset1+offset2] still be a legal addressing mode?
4573    AM.BaseOffs = CombinedValue;
4574    if (!TLI.isLegalAddressingMode(PtrAdd.getMF()->getDataLayout(), AM,
4575                                   AccessTy, AS))
4576      return true;
4577  }
4578
4579  return false;
4580}
4581
4582bool CombinerHelper::matchReassocConstantInnerRHS(GPtrAdd &MI,
4583                                                  MachineInstr *RHS,
4584                                                  BuildFnTy &MatchInfo) {
4585  // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
4586  Register Src1Reg = MI.getOperand(1).getReg();
4587  if (RHS->getOpcode() != TargetOpcode::G_ADD)
4588    return false;
4589  auto C2 = getIConstantVRegVal(RHS->getOperand(2).getReg(), MRI);
4590  if (!C2)
4591    return false;
4592
4593  MatchInfo = [=, &MI](MachineIRBuilder &B) {
4594    LLT PtrTy = MRI.getType(MI.getOperand(0).getReg());
4595
4596    auto NewBase =
4597        Builder.buildPtrAdd(PtrTy, Src1Reg, RHS->getOperand(1).getReg());
4598    Observer.changingInstr(MI);
4599    MI.getOperand(1).setReg(NewBase.getReg(0));
4600    MI.getOperand(2).setReg(RHS->getOperand(2).getReg());
4601    Observer.changedInstr(MI);
4602  };
4603  return !reassociationCanBreakAddressingModePattern(MI);
4604}
4605
4606bool CombinerHelper::matchReassocConstantInnerLHS(GPtrAdd &MI,
4607                                                  MachineInstr *LHS,
4608                                                  MachineInstr *RHS,
4609                                                  BuildFnTy &MatchInfo) {
  // G_PTR_ADD(G_PTR_ADD(X, C), Y) -> G_PTR_ADD(G_PTR_ADD(X, Y), C)
4611  // if and only if (G_PTR_ADD X, C) has one use.
4612  Register LHSBase;
4613  std::optional<ValueAndVReg> LHSCstOff;
4614  if (!mi_match(MI.getBaseReg(), MRI,
4615                m_OneNonDBGUse(m_GPtrAdd(m_Reg(LHSBase), m_GCst(LHSCstOff)))))
4616    return false;
4617
4618  auto *LHSPtrAdd = cast<GPtrAdd>(LHS);
4619  MatchInfo = [=, &MI](MachineIRBuilder &B) {
    // When we change LHSPtrAdd's offset register we might cause it to use a
    // register before its def. Sink LHSPtrAdd right before the outer PTR_ADD
    // to ensure this doesn't happen.
4623    LHSPtrAdd->moveBefore(&MI);
4624    Register RHSReg = MI.getOffsetReg();
    // Reusing the constant vreg directly could cause a type mismatch if it
    // came through an extend/trunc, so build a fresh constant of the offset
    // register's type.
4626    auto NewCst = B.buildConstant(MRI.getType(RHSReg), LHSCstOff->Value);
4627    Observer.changingInstr(MI);
4628    MI.getOperand(2).setReg(NewCst.getReg(0));
4629    Observer.changedInstr(MI);
4630    Observer.changingInstr(*LHSPtrAdd);
4631    LHSPtrAdd->getOperand(2).setReg(RHSReg);
4632    Observer.changedInstr(*LHSPtrAdd);
4633  };
4634  return !reassociationCanBreakAddressingModePattern(MI);
4635}
4636
4637bool CombinerHelper::matchReassocFoldConstantsInSubTree(GPtrAdd &MI,
4638                                                        MachineInstr *LHS,
4639                                                        MachineInstr *RHS,
4640                                                        BuildFnTy &MatchInfo) {
4641  // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
4642  auto *LHSPtrAdd = dyn_cast<GPtrAdd>(LHS);
4643  if (!LHSPtrAdd)
4644    return false;
4645
4646  Register Src2Reg = MI.getOperand(2).getReg();
4647  Register LHSSrc1 = LHSPtrAdd->getBaseReg();
4648  Register LHSSrc2 = LHSPtrAdd->getOffsetReg();
4649  auto C1 = getIConstantVRegVal(LHSSrc2, MRI);
4650  if (!C1)
4651    return false;
4652  auto C2 = getIConstantVRegVal(Src2Reg, MRI);
4653  if (!C2)
4654    return false;
4655
4656  MatchInfo = [=, &MI](MachineIRBuilder &B) {
4657    auto NewCst = B.buildConstant(MRI.getType(Src2Reg), *C1 + *C2);
4658    Observer.changingInstr(MI);
4659    MI.getOperand(1).setReg(LHSSrc1);
4660    MI.getOperand(2).setReg(NewCst.getReg(0));
4661    Observer.changedInstr(MI);
4662  };
4663  return !reassociationCanBreakAddressingModePattern(MI);
4664}
4665
4666bool CombinerHelper::matchReassocPtrAdd(MachineInstr &MI,
4667                                        BuildFnTy &MatchInfo) {
4668  auto &PtrAdd = cast<GPtrAdd>(MI);
4669  // We're trying to match a few pointer computation patterns here for
4670  // re-association opportunities.
4671  // 1) Isolating a constant operand to be on the RHS, e.g.:
4672  // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
4673  //
4674  // 2) Folding two constants in each sub-tree as long as such folding
4675  // doesn't break a legal addressing mode.
4676  // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
4677  //
4678  // 3) Move a constant from the LHS of an inner op to the RHS of the outer.
  // G_PTR_ADD(G_PTR_ADD(X, C), Y) -> G_PTR_ADD(G_PTR_ADD(X, Y), C)
  // iff (G_PTR_ADD X, C) has one use.
4681  MachineInstr *LHS = MRI.getVRegDef(PtrAdd.getBaseReg());
4682  MachineInstr *RHS = MRI.getVRegDef(PtrAdd.getOffsetReg());
4683
4684  // Try to match example 2.
4685  if (matchReassocFoldConstantsInSubTree(PtrAdd, LHS, RHS, MatchInfo))
4686    return true;
4687
4688  // Try to match example 3.
4689  if (matchReassocConstantInnerLHS(PtrAdd, LHS, RHS, MatchInfo))
4690    return true;
4691
4692  // Try to match example 1.
4693  if (matchReassocConstantInnerRHS(PtrAdd, RHS, MatchInfo))
4694    return true;
4695
4696  return false;
4697}
4698
4699bool CombinerHelper::matchConstantFold(MachineInstr &MI, APInt &MatchInfo) {
4700  Register Op1 = MI.getOperand(1).getReg();
4701  Register Op2 = MI.getOperand(2).getReg();
4702  auto MaybeCst = ConstantFoldBinOp(MI.getOpcode(), Op1, Op2, MRI);
4703  if (!MaybeCst)
4704    return false;
4705  MatchInfo = *MaybeCst;
4706  return true;
4707}
4708
4709bool CombinerHelper::matchNarrowBinopFeedingAnd(
4710    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4711  // Look for a binop feeding into an AND with a mask:
4712  //
4713  // %add = G_ADD %lhs, %rhs
4714  // %and = G_AND %add, 000...11111111
4715  //
4716  // Check if it's possible to perform the binop at a narrower width and zext
4717  // back to the original width like so:
4718  //
4719  // %narrow_lhs = G_TRUNC %lhs
4720  // %narrow_rhs = G_TRUNC %rhs
4721  // %narrow_add = G_ADD %narrow_lhs, %narrow_rhs
4722  // %new_add = G_ZEXT %narrow_add
4723  // %and = G_AND %new_add, 000...11111111
4724  //
4725  // This can allow later combines to eliminate the G_AND if it turns out
4726  // that the mask is irrelevant.
4727  assert(MI.getOpcode() == TargetOpcode::G_AND);
4728  Register Dst = MI.getOperand(0).getReg();
4729  Register AndLHS = MI.getOperand(1).getReg();
4730  Register AndRHS = MI.getOperand(2).getReg();
4731  LLT WideTy = MRI.getType(Dst);
4732
4733  // If the potential binop has more than one use, then it's possible that one
4734  // of those uses will need its full width.
4735  if (!WideTy.isScalar() || !MRI.hasOneNonDBGUse(AndLHS))
4736    return false;
4737
4738  // Check if the LHS feeding the AND is impacted by the high bits that we're
4739  // masking out.
4740  //
4741  // e.g. for 64-bit x, y:
4742  //
4743  // add_64(x, y) & 65535 == zext(add_16(trunc(x), trunc(y))) & 65535
4744  MachineInstr *LHSInst = getDefIgnoringCopies(AndLHS, MRI);
4745  if (!LHSInst)
4746    return false;
4747  unsigned LHSOpc = LHSInst->getOpcode();
4748  switch (LHSOpc) {
4749  default:
4750    return false;
4751  case TargetOpcode::G_ADD:
4752  case TargetOpcode::G_SUB:
4753  case TargetOpcode::G_MUL:
4754  case TargetOpcode::G_AND:
4755  case TargetOpcode::G_OR:
4756  case TargetOpcode::G_XOR:
4757    break;
4758  }
4759
4760  // Find the mask on the RHS.
4761  auto Cst = getIConstantVRegValWithLookThrough(AndRHS, MRI);
4762  if (!Cst)
4763    return false;
4764  auto Mask = Cst->Value;
4765  if (!Mask.isMask())
4766    return false;
4767
4768  // No point in combining if there's nothing to truncate.
4769  unsigned NarrowWidth = Mask.countTrailingOnes();
4770  if (NarrowWidth == WideTy.getSizeInBits())
4771    return false;
4772  LLT NarrowTy = LLT::scalar(NarrowWidth);
4773
4774  // Check if adding the zext + truncates could be harmful.
4775  auto &MF = *MI.getMF();
4776  const auto &TLI = getTargetLowering();
4777  LLVMContext &Ctx = MF.getFunction().getContext();
4778  auto &DL = MF.getDataLayout();
4779  if (!TLI.isTruncateFree(WideTy, NarrowTy, DL, Ctx) ||
4780      !TLI.isZExtFree(NarrowTy, WideTy, DL, Ctx))
4781    return false;
4782  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {NarrowTy, WideTy}}) ||
4783      !isLegalOrBeforeLegalizer({TargetOpcode::G_ZEXT, {WideTy, NarrowTy}}))
4784    return false;
4785  Register BinOpLHS = LHSInst->getOperand(1).getReg();
4786  Register BinOpRHS = LHSInst->getOperand(2).getReg();
4787  MatchInfo = [=, &MI](MachineIRBuilder &B) {
4788    auto NarrowLHS = Builder.buildTrunc(NarrowTy, BinOpLHS);
4789    auto NarrowRHS = Builder.buildTrunc(NarrowTy, BinOpRHS);
4790    auto NarrowBinOp =
4791        Builder.buildInstr(LHSOpc, {NarrowTy}, {NarrowLHS, NarrowRHS});
4792    auto Ext = Builder.buildZExt(WideTy, NarrowBinOp);
4793    Observer.changingInstr(MI);
4794    MI.getOperand(1).setReg(Ext.getReg(0));
4795    Observer.changedInstr(MI);
4796  };
4797  return true;
4798}
4799
4800bool CombinerHelper::matchMulOBy2(MachineInstr &MI, BuildFnTy &MatchInfo) {
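  // (G_*MULO x, 2) -> (G_*ADDO x, x)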
4801  unsigned Opc = MI.getOpcode();
4802  assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO);
4803
4804  if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(2)))
4805    return false;
4806
4807  MatchInfo = [=, &MI](MachineIRBuilder &B) {
4808    Observer.changingInstr(MI);
4809    unsigned NewOpc = Opc == TargetOpcode::G_UMULO ? TargetOpcode::G_UADDO
4810                                                   : TargetOpcode::G_SADDO;
4811    MI.setDesc(Builder.getTII().get(NewOpc));
4812    MI.getOperand(3).setReg(MI.getOperand(2).getReg());
4813    Observer.changedInstr(MI);
4814  };
4815  return true;
4816}
4817
4818bool CombinerHelper::matchMulOBy0(MachineInstr &MI, BuildFnTy &MatchInfo) {
4819  // (G_*MULO x, 0) -> 0 + no carry out
4820  assert(MI.getOpcode() == TargetOpcode::G_UMULO ||
4821         MI.getOpcode() == TargetOpcode::G_SMULO);
4822  if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(0)))
4823    return false;
4824  Register Dst = MI.getOperand(0).getReg();
4825  Register Carry = MI.getOperand(1).getReg();
4826  if (!isConstantLegalOrBeforeLegalizer(MRI.getType(Dst)) ||
4827      !isConstantLegalOrBeforeLegalizer(MRI.getType(Carry)))
4828    return false;
4829  MatchInfo = [=](MachineIRBuilder &B) {
4830    B.buildConstant(Dst, 0);
4831    B.buildConstant(Carry, 0);
4832  };
4833  return true;
4834}
4835
4836bool CombinerHelper::matchAddOBy0(MachineInstr &MI, BuildFnTy &MatchInfo) {
4837  // (G_*ADDO x, 0) -> x + no carry out
4838  assert(MI.getOpcode() == TargetOpcode::G_UADDO ||
4839         MI.getOpcode() == TargetOpcode::G_SADDO);
4840  if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(0)))
4841    return false;
4842  Register Carry = MI.getOperand(1).getReg();
4843  if (!isConstantLegalOrBeforeLegalizer(MRI.getType(Carry)))
4844    return false;
4845  Register Dst = MI.getOperand(0).getReg();
4846  Register LHS = MI.getOperand(2).getReg();
4847  MatchInfo = [=](MachineIRBuilder &B) {
4848    B.buildCopy(Dst, LHS);
4849    B.buildConstant(Carry, 0);
4850  };
4851  return true;
4852}
4853
4854bool CombinerHelper::matchAddEToAddO(MachineInstr &MI, BuildFnTy &MatchInfo) {
4855  // (G_*ADDE x, y, 0) -> (G_*ADDO x, y)
4856  // (G_*SUBE x, y, 0) -> (G_*SUBO x, y)
4857  assert(MI.getOpcode() == TargetOpcode::G_UADDE ||
4858         MI.getOpcode() == TargetOpcode::G_SADDE ||
4859         MI.getOpcode() == TargetOpcode::G_USUBE ||
4860         MI.getOpcode() == TargetOpcode::G_SSUBE);
4861  if (!mi_match(MI.getOperand(4).getReg(), MRI, m_SpecificICstOrSplat(0)))
4862    return false;
4863  MatchInfo = [&](MachineIRBuilder &B) {
4864    unsigned NewOpcode;
4865    switch (MI.getOpcode()) {
4866    case TargetOpcode::G_UADDE:
4867      NewOpcode = TargetOpcode::G_UADDO;
4868      break;
4869    case TargetOpcode::G_SADDE:
4870      NewOpcode = TargetOpcode::G_SADDO;
4871      break;
4872    case TargetOpcode::G_USUBE:
4873      NewOpcode = TargetOpcode::G_USUBO;
4874      break;
4875    case TargetOpcode::G_SSUBE:
4876      NewOpcode = TargetOpcode::G_SSUBO;
4877      break;
4878    }
4879    Observer.changingInstr(MI);
4880    MI.setDesc(B.getTII().get(NewOpcode));
4881    MI.removeOperand(4);
4882    Observer.changedInstr(MI);
4883  };
4884  return true;
4885}
4886
4887bool CombinerHelper::matchSubAddSameReg(MachineInstr &MI,
4888                                        BuildFnTy &MatchInfo) {
4889  assert(MI.getOpcode() == TargetOpcode::G_SUB);
4890  Register Dst = MI.getOperand(0).getReg();
4891  // (x + y) - z -> x (if y == z)
4892  // (x + y) - z -> y (if x == z)
4893  Register X, Y, Z;
4894  if (mi_match(Dst, MRI, m_GSub(m_GAdd(m_Reg(X), m_Reg(Y)), m_Reg(Z)))) {
4895    Register ReplaceReg;
4896    int64_t CstX, CstY;
4897    if (Y == Z || (mi_match(Y, MRI, m_ICstOrSplat(CstY)) &&
4898                   mi_match(Z, MRI, m_SpecificICstOrSplat(CstY))))
4899      ReplaceReg = X;
4900    else if (X == Z || (mi_match(X, MRI, m_ICstOrSplat(CstX)) &&
4901                        mi_match(Z, MRI, m_SpecificICstOrSplat(CstX))))
4902      ReplaceReg = Y;
4903    if (ReplaceReg) {
4904      MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, ReplaceReg); };
4905      return true;
4906    }
4907  }
4908
4909  // x - (y + z) -> 0 - y (if x == z)
4910  // x - (y + z) -> 0 - z (if x == y)
4911  if (mi_match(Dst, MRI, m_GSub(m_Reg(X), m_GAdd(m_Reg(Y), m_Reg(Z))))) {
4912    Register ReplaceReg;
4913    int64_t CstX;
4914    if (X == Z || (mi_match(X, MRI, m_ICstOrSplat(CstX)) &&
4915                   mi_match(Z, MRI, m_SpecificICstOrSplat(CstX))))
4916      ReplaceReg = Y;
4917    else if (X == Y || (mi_match(X, MRI, m_ICstOrSplat(CstX)) &&
4918                        mi_match(Y, MRI, m_SpecificICstOrSplat(CstX))))
4919      ReplaceReg = Z;
4920    if (ReplaceReg) {
4921      MatchInfo = [=](MachineIRBuilder &B) {
4922        auto Zero = B.buildConstant(MRI.getType(Dst), 0);
4923        B.buildSub(Dst, Zero, ReplaceReg);
4924      };
4925      return true;
4926    }
4927  }
4928  return false;
4929}
4930
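/// Expand a G_UDIV by constant into a multiply by a "magic" factor plus
/// shifts (and, when required, a fixup add), using
/// UnsignedDivisionByConstantInfo to compute the magic values. Division by 1
/// is handled by the trailing select. Returns the instruction producing the
/// replacement value.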
4931MachineInstr *CombinerHelper::buildUDivUsingMul(MachineInstr &MI) {
4932  assert(MI.getOpcode() == TargetOpcode::G_UDIV);
4933  auto &UDiv = cast<GenericMachineInstr>(MI);
4934  Register Dst = UDiv.getReg(0);
4935  Register LHS = UDiv.getReg(1);
4936  Register RHS = UDiv.getReg(2);
4937  LLT Ty = MRI.getType(Dst);
4938  LLT ScalarTy = Ty.getScalarType();
4939  const unsigned EltBits = ScalarTy.getScalarSizeInBits();
4940  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
4941  LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
4942  auto &MIB = Builder;
4943  MIB.setInstrAndDebugLoc(MI);
4944
4945  bool UseNPQ = false;
4946  SmallVector<Register, 16> PreShifts, PostShifts, MagicFactors, NPQFactors;
4947
4948  auto BuildUDIVPattern = [&](const Constant *C) {
4949    auto *CI = cast<ConstantInt>(C);
4950    const APInt &Divisor = CI->getValue();
4951
4952    bool SelNPQ = false;
4953    APInt Magic(Divisor.getBitWidth(), 0);
4954    unsigned PreShift = 0, PostShift = 0;
4955
4956    // Magic algorithm doesn't work for division by 1. We need to emit a select
4957    // at the end.
4958    // TODO: Use undef values for divisor of 1.
4959    if (!Divisor.isOneValue()) {
4960      UnsignedDivisionByConstantInfo magics =
4961          UnsignedDivisionByConstantInfo::get(Divisor);
4962
4963      Magic = std::move(magics.Magic);
4964
4965      assert(magics.PreShift < Divisor.getBitWidth() &&
4966             "We shouldn't generate an undefined shift!");
4967      assert(magics.PostShift < Divisor.getBitWidth() &&
4968             "We shouldn't generate an undefined shift!");
4969      assert((!magics.IsAdd || magics.PreShift == 0) && "Unexpected pre-shift");
4970      PreShift = magics.PreShift;
4971      PostShift = magics.PostShift;
4972      SelNPQ = magics.IsAdd;
4973    }
4974
4975    PreShifts.push_back(
4976        MIB.buildConstant(ScalarShiftAmtTy, PreShift).getReg(0));
4977    MagicFactors.push_back(MIB.buildConstant(ScalarTy, Magic).getReg(0));
4978    NPQFactors.push_back(
4979        MIB.buildConstant(ScalarTy,
4980                          SelNPQ ? APInt::getOneBitSet(EltBits, EltBits - 1)
4981                                 : APInt::getZero(EltBits))
4982            .getReg(0));
4983    PostShifts.push_back(
4984        MIB.buildConstant(ScalarShiftAmtTy, PostShift).getReg(0));
4985    UseNPQ |= SelNPQ;
4986    return true;
4987  };
4988
4989  // Collect the shifts/magic values from each element.
4990  bool Matched = matchUnaryPredicate(MRI, RHS, BuildUDIVPattern);
4991  (void)Matched;
4992  assert(Matched && "Expected unary predicate match to succeed");
4993
4994  Register PreShift, PostShift, MagicFactor, NPQFactor;
4995  auto *RHSDef = getOpcodeDef<GBuildVector>(RHS, MRI);
4996  if (RHSDef) {
4997    PreShift = MIB.buildBuildVector(ShiftAmtTy, PreShifts).getReg(0);
4998    MagicFactor = MIB.buildBuildVector(Ty, MagicFactors).getReg(0);
4999    NPQFactor = MIB.buildBuildVector(Ty, NPQFactors).getReg(0);
5000    PostShift = MIB.buildBuildVector(ShiftAmtTy, PostShifts).getReg(0);
5001  } else {
5002    assert(MRI.getType(RHS).isScalar() &&
5003           "Non-build_vector operation should have been a scalar");
5004    PreShift = PreShifts[0];
5005    MagicFactor = MagicFactors[0];
5006    PostShift = PostShifts[0];
5007  }
5008
5009  Register Q = LHS;
5010  Q = MIB.buildLShr(Ty, Q, PreShift).getReg(0);
5011
5012  // Multiply the numerator (operand 0) by the magic value.
5013  Q = MIB.buildUMulH(Ty, Q, MagicFactor).getReg(0);
5014
5015  if (UseNPQ) {
5016    Register NPQ = MIB.buildSub(Ty, LHS, Q).getReg(0);
5017
5018    // For vectors we might have a mix of non-NPQ/NPQ paths, so use
5019    // G_UMULH to act as a SRL-by-1 for NPQ, else multiply by zero.
5020    if (Ty.isVector())
5021      NPQ = MIB.buildUMulH(Ty, NPQ, NPQFactor).getReg(0);
5022    else
5023      NPQ = MIB.buildLShr(Ty, NPQ, MIB.buildConstant(ShiftAmtTy, 1)).getReg(0);
5024
5025    Q = MIB.buildAdd(Ty, NPQ, Q).getReg(0);
5026  }
5027
5028  Q = MIB.buildLShr(Ty, Q, PostShift).getReg(0);
5029  auto One = MIB.buildConstant(Ty, 1);
5030  auto IsOne = MIB.buildICmp(
5031      CmpInst::Predicate::ICMP_EQ,
5032      Ty.isScalar() ? LLT::scalar(1) : Ty.changeElementSize(1), RHS, One);
5033  return MIB.buildSelect(Ty, IsOne, LHS, Q);
5034}
5035
5036bool CombinerHelper::matchUDivByConst(MachineInstr &MI) {
5037  assert(MI.getOpcode() == TargetOpcode::G_UDIV);
5038  Register Dst = MI.getOperand(0).getReg();
5039  Register RHS = MI.getOperand(2).getReg();
5040  LLT DstTy = MRI.getType(Dst);
5041  auto *RHSDef = MRI.getVRegDef(RHS);
5042  if (!isConstantOrConstantVector(*RHSDef, MRI))
5043    return false;
5044
5045  auto &MF = *MI.getMF();
5046  AttributeList Attr = MF.getFunction().getAttributes();
5047  const auto &TLI = getTargetLowering();
5048  LLVMContext &Ctx = MF.getFunction().getContext();
5049  auto &DL = MF.getDataLayout();
5050  if (TLI.isIntDivCheap(getApproximateEVTForLLT(DstTy, DL, Ctx), Attr))
5051    return false;
5052
5053  // Don't do this for minsize because the instruction sequence is usually
5054  // larger.
5055  if (MF.getFunction().hasMinSize())
5056    return false;
5057
5058  // Don't do this if the types are not going to be legal.
5059  if (LI) {
5060    if (!isLegalOrBeforeLegalizer({TargetOpcode::G_MUL, {DstTy, DstTy}}))
5061      return false;
5062    if (!isLegalOrBeforeLegalizer({TargetOpcode::G_UMULH, {DstTy}}))
5063      return false;
5064    if (!isLegalOrBeforeLegalizer(
5065            {TargetOpcode::G_ICMP,
5066             {DstTy.isVector() ? DstTy.changeElementSize(1) : LLT::scalar(1),
5067              DstTy}}))
5068      return false;
5069  }
5070
5071  auto CheckEltValue = [&](const Constant *C) {
5072    if (auto *CI = dyn_cast_or_null<ConstantInt>(C))
5073      return !CI->isZero();
5074    return false;
5075  };
5076  return matchUnaryPredicate(MRI, RHS, CheckEltValue);
5077}
5078
5079void CombinerHelper::applyUDivByConst(MachineInstr &MI) {
5080  auto *NewMI = buildUDivUsingMul(MI);
5081  replaceSingleDefInstWithReg(MI, NewMI->getOperand(0).getReg());
5082}
5083
5084bool CombinerHelper::matchSDivByConst(MachineInstr &MI) {
5085  assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5086  Register Dst = MI.getOperand(0).getReg();
5087  Register RHS = MI.getOperand(2).getReg();
5088  LLT DstTy = MRI.getType(Dst);
5089
5090  auto &MF = *MI.getMF();
5091  AttributeList Attr = MF.getFunction().getAttributes();
5092  const auto &TLI = getTargetLowering();
5093  LLVMContext &Ctx = MF.getFunction().getContext();
5094  auto &DL = MF.getDataLayout();
5095  if (TLI.isIntDivCheap(getApproximateEVTForLLT(DstTy, DL, Ctx), Attr))
5096    return false;
5097
5098  // Don't do this for minsize because the instruction sequence is usually
5099  // larger.
5100  if (MF.getFunction().hasMinSize())
5101    return false;
5102
5103  // If the sdiv has an 'exact' flag we can use a simpler lowering.
5104  if (MI.getFlag(MachineInstr::MIFlag::IsExact)) {
5105    return matchUnaryPredicate(
5106        MRI, RHS, [](const Constant *C) { return C && !C->isZeroValue(); });
5107  }
5108
5109  // Don't support the general case for now.
5110  return false;
5111}
5112
5113void CombinerHelper::applySDivByConst(MachineInstr &MI) {
5114  auto *NewMI = buildSDivUsingMul(MI);
5115  replaceSingleDefInstWithReg(MI, NewMI->getOperand(0).getReg());
5116}
5117
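/// Expand an exact G_SDIV by constant: arithmetic-shift out the divisor's
/// trailing zero bits, then multiply by the divisor's multiplicative inverse
/// modulo 2^W. Returns the instruction producing the replacement value.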
5118MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) {
5119  assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5120  auto &SDiv = cast<GenericMachineInstr>(MI);
5121  Register Dst = SDiv.getReg(0);
5122  Register LHS = SDiv.getReg(1);
5123  Register RHS = SDiv.getReg(2);
5124  LLT Ty = MRI.getType(Dst);
5125  LLT ScalarTy = Ty.getScalarType();
5126  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5127  LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
5128  auto &MIB = Builder;
5129  MIB.setInstrAndDebugLoc(MI);
5130
5131  bool UseSRA = false;
5132  SmallVector<Register, 16> Shifts, Factors;
5133
5134  auto *RHSDef = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI));
5135  bool IsSplat = getIConstantSplatVal(*RHSDef, MRI).has_value();
5136
5137  auto BuildSDIVPattern = [&](const Constant *C) {
5138    // Don't recompute inverses for each splat element.
5139    if (IsSplat && !Factors.empty()) {
5140      Shifts.push_back(Shifts[0]);
5141      Factors.push_back(Factors[0]);
5142      return true;
5143    }
5144
5145    auto *CI = cast<ConstantInt>(C);
5146    APInt Divisor = CI->getValue();
5147    unsigned Shift = Divisor.countTrailingZeros();
5148    if (Shift) {
5149      Divisor.ashrInPlace(Shift);
5150      UseSRA = true;
5151    }
5152
5153    // Calculate the multiplicative inverse modulo BW.
5154    // 2^W requires W + 1 bits, so we have to extend and then truncate.
5155    unsigned W = Divisor.getBitWidth();
5156    APInt Factor = Divisor.zext(W + 1)
5157                       .multiplicativeInverse(APInt::getSignedMinValue(W + 1))
5158                       .trunc(W);
5159    Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
5160    Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
5161    return true;
5162  };
5163
5164  // Collect all magic values from the build vector.
5165  bool Matched = matchUnaryPredicate(MRI, RHS, BuildSDIVPattern);
5166  (void)Matched;
5167  assert(Matched && "Expected unary predicate match to succeed");
5168
5169  Register Shift, Factor;
5170  if (Ty.isVector()) {
5171    Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0);
5172    Factor = MIB.buildBuildVector(Ty, Factors).getReg(0);
5173  } else {
5174    Shift = Shifts[0];
5175    Factor = Factors[0];
5176  }
5177
5178  Register Res = LHS;
5179
5180  if (UseSRA)
5181    Res = MIB.buildAShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0);
5182
5183  return MIB.buildMul(Ty, Res, Factor);
5184}
5185
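// Match a G_UMULH whose RHS is a power of two other than 1; the high half of
// x * 2^c is simply x >> (bitwidth - c), so it can be lowered to a G_LSHR.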
5186bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) {
5187  assert(MI.getOpcode() == TargetOpcode::G_UMULH);
5188  Register RHS = MI.getOperand(2).getReg();
5189  Register Dst = MI.getOperand(0).getReg();
5190  LLT Ty = MRI.getType(Dst);
5191  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5192  auto MatchPow2ExceptOne = [&](const Constant *C) {
5193    if (auto *CI = dyn_cast<ConstantInt>(C))
5194      return CI->getValue().isPowerOf2() && !CI->getValue().isOne();
5195    return false;
5196  };
5197  if (!matchUnaryPredicate(MRI, RHS, MatchPow2ExceptOne, false))
5198    return false;
5199  return isLegalOrBeforeLegalizer({TargetOpcode::G_LSHR, {Ty, ShiftAmtTy}});
5200}
5201
5202void CombinerHelper::applyUMulHToLShr(MachineInstr &MI) {
5203  Register LHS = MI.getOperand(1).getReg();
5204  Register RHS = MI.getOperand(2).getReg();
5205  Register Dst = MI.getOperand(0).getReg();
5206  LLT Ty = MRI.getType(Dst);
5207  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5208  unsigned NumEltBits = Ty.getScalarSizeInBits();
5209
5210  Builder.setInstrAndDebugLoc(MI);
5211  auto LogBase2 = buildLogBase2(RHS, Builder);
5212  auto ShiftAmt =
5213      Builder.buildSub(Ty, Builder.buildConstant(Ty, NumEltBits), LogBase2);
5214  auto Trunc = Builder.buildZExtOrTrunc(ShiftAmtTy, ShiftAmt);
5215  Builder.buildLShr(Dst, LHS, Trunc);
5216  MI.eraseFromParent();
5217}
5218
5219bool CombinerHelper::matchRedundantNegOperands(MachineInstr &MI,
5220                                               BuildFnTy &MatchInfo) {
5221  unsigned Opc = MI.getOpcode();
5222  assert(Opc == TargetOpcode::G_FADD || Opc == TargetOpcode::G_FSUB ||
5223         Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
5224         Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA);
5225
5226  Register Dst = MI.getOperand(0).getReg();
5227  Register X = MI.getOperand(1).getReg();
5228  Register Y = MI.getOperand(2).getReg();
5229  LLT Type = MRI.getType(Dst);
5230
5231  // fold (fadd x, fneg(y)) -> (fsub x, y)
5232  // fold (fadd fneg(y), x) -> (fsub x, y)
  // G_FADD is commutative so both cases are checked by m_GFAdd.
5234  if (mi_match(Dst, MRI, m_GFAdd(m_Reg(X), m_GFNeg(m_Reg(Y)))) &&
5235      isLegalOrBeforeLegalizer({TargetOpcode::G_FSUB, {Type}})) {
5236    Opc = TargetOpcode::G_FSUB;
5237  }
  // fold (fsub x, fneg(y)) -> (fadd x, y)
5239  else if (mi_match(Dst, MRI, m_GFSub(m_Reg(X), m_GFNeg(m_Reg(Y)))) &&
5240           isLegalOrBeforeLegalizer({TargetOpcode::G_FADD, {Type}})) {
5241    Opc = TargetOpcode::G_FADD;
5242  }
5243  // fold (fmul fneg(x), fneg(y)) -> (fmul x, y)
5244  // fold (fdiv fneg(x), fneg(y)) -> (fdiv x, y)
5245  // fold (fmad fneg(x), fneg(y), z) -> (fmad x, y, z)
5246  // fold (fma fneg(x), fneg(y), z) -> (fma x, y, z)
5247  else if ((Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
5248            Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA) &&
5249           mi_match(X, MRI, m_GFNeg(m_Reg(X))) &&
5250           mi_match(Y, MRI, m_GFNeg(m_Reg(Y)))) {
5251    // no opcode change
5252  } else
5253    return false;
5254
5255  MatchInfo = [=, &MI](MachineIRBuilder &B) {
5256    Observer.changingInstr(MI);
5257    MI.setDesc(B.getTII().get(Opc));
5258    MI.getOperand(1).setReg(X);
5259    MI.getOperand(2).setReg(Y);
5260    Observer.changedInstr(MI);
5261  };
5262  return true;
5263}
5264
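// Fold (fsub ±0.0, x) -> (fneg (fcanonicalize x)); +0.0 is only allowed when
// the instruction has the nsz flag.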
5265bool CombinerHelper::matchFsubToFneg(MachineInstr &MI, Register &MatchInfo) {
5266  assert(MI.getOpcode() == TargetOpcode::G_FSUB);
5267
5268  Register LHS = MI.getOperand(1).getReg();
5269  MatchInfo = MI.getOperand(2).getReg();
5270  LLT Ty = MRI.getType(MI.getOperand(0).getReg());
5271
5272  const auto LHSCst = Ty.isVector()
5273                          ? getFConstantSplat(LHS, MRI, /* allowUndef */ true)
5274                          : getFConstantVRegValWithLookThrough(LHS, MRI);
5275  if (!LHSCst)
5276    return false;
5277
5278  // -0.0 is always allowed
5279  if (LHSCst->Value.isNegZero())
5280    return true;
5281
5282  // +0.0 is only allowed if nsz is set.
5283  if (LHSCst->Value.isPosZero())
5284    return MI.getFlag(MachineInstr::FmNsz);
5285
5286  return false;
5287}
5288
5289void CombinerHelper::applyFsubToFneg(MachineInstr &MI, Register &MatchInfo) {
5290  Builder.setInstrAndDebugLoc(MI);
5291  Register Dst = MI.getOperand(0).getReg();
5292  Builder.buildFNeg(
5293      Dst, Builder.buildFCanonicalize(MRI.getType(Dst), MatchInfo).getReg(0));
5294  eraseInst(MI);
5295}
5296
5297/// Checks if \p MI is TargetOpcode::G_FMUL and contractable either
5298/// due to global flags or MachineInstr flags.
5299static bool isContractableFMul(MachineInstr &MI, bool AllowFusionGlobally) {
5300  if (MI.getOpcode() != TargetOpcode::G_FMUL)
5301    return false;
5302  return AllowFusionGlobally || MI.getFlag(MachineInstr::MIFlag::FmContract);
5303}
5304
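/// Return true if \p MI0's result has more non-debug uses than \p MI1's.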
5305static bool hasMoreUses(const MachineInstr &MI0, const MachineInstr &MI1,
5306                        const MachineRegisterInfo &MRI) {
5307  return std::distance(MRI.use_instr_nodbg_begin(MI0.getOperand(0).getReg()),
5308                       MRI.use_instr_nodbg_end()) >
5309         std::distance(MRI.use_instr_nodbg_begin(MI1.getOperand(0).getReg()),
5310                       MRI.use_instr_nodbg_end());
5311}
5312
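/// Check whether a G_FADD/G_FSUB at \p MI may be fused into G_FMA or G_FMAD.
/// Reports whether fusion is allowed globally (fast FP-op fusion or unsafe FP
/// math), whether G_FMAD is usable, and whether the target enables aggressive
/// FMA fusion for the destination type. If \p CanReassociate is set, the
/// reassociation flags are also required.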
5313bool CombinerHelper::canCombineFMadOrFMA(MachineInstr &MI,
5314                                         bool &AllowFusionGlobally,
5315                                         bool &HasFMAD, bool &Aggressive,
5316                                         bool CanReassociate) {
5317
5318  auto *MF = MI.getMF();
5319  const auto &TLI = *MF->getSubtarget().getTargetLowering();
5320  const TargetOptions &Options = MF->getTarget().Options;
5321  LLT DstType = MRI.getType(MI.getOperand(0).getReg());
5322
5323  if (CanReassociate &&
5324      !(Options.UnsafeFPMath || MI.getFlag(MachineInstr::MIFlag::FmReassoc)))
5325    return false;
5326
5327  // Floating-point multiply-add with intermediate rounding.
5328  HasFMAD = (!isPreLegalize() && TLI.isFMADLegal(MI, DstType));
5329  // Floating-point multiply-add without intermediate rounding.
5330  bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(*MF, DstType) &&
5331                isLegalOrBeforeLegalizer({TargetOpcode::G_FMA, {DstType}});
5332  // No valid opcode, do not combine.
5333  if (!HasFMAD && !HasFMA)
5334    return false;
5335
5336  AllowFusionGlobally = Options.AllowFPOpFusion == FPOpFusion::Fast ||
5337                        Options.UnsafeFPMath || HasFMAD;
5338  // If the addition is not contractable, do not combine.
5339  if (!AllowFusionGlobally && !MI.getFlag(MachineInstr::MIFlag::FmContract))
5340    return false;
5341
5342  Aggressive = TLI.enableAggressiveFMAFusion(DstType);
5343  return true;
5344}
5345
5346bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA(
5347    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5348  assert(MI.getOpcode() == TargetOpcode::G_FADD);
5349
5350  bool AllowFusionGlobally, HasFMAD, Aggressive;
5351  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5352    return false;
5353
5354  Register Op1 = MI.getOperand(1).getReg();
5355  Register Op2 = MI.getOperand(2).getReg();
5356  DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
5357  DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
5358  unsigned PreferredFusedOpcode =
5359      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5360
5361  // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5362  // prefer to fold the multiply with fewer uses.
5363  if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5364      isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
5365    if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
5366      std::swap(LHS, RHS);
5367  }
5368
5369  // fold (fadd (fmul x, y), z) -> (fma x, y, z)
5370  if (isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5371      (Aggressive || MRI.hasOneNonDBGUse(LHS.Reg))) {
5372    MatchInfo = [=, &MI](MachineIRBuilder &B) {
5373      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5374                   {LHS.MI->getOperand(1).getReg(),
5375                    LHS.MI->getOperand(2).getReg(), RHS.Reg});
5376    };
5377    return true;
5378  }
5379
5380  // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
5381  if (isContractableFMul(*RHS.MI, AllowFusionGlobally) &&
5382      (Aggressive || MRI.hasOneNonDBGUse(RHS.Reg))) {
5383    MatchInfo = [=, &MI](MachineIRBuilder &B) {
5384      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5385                   {RHS.MI->getOperand(1).getReg(),
5386                    RHS.MI->getOperand(2).getReg(), LHS.Reg});
5387    };
5388    return true;
5389  }
5390
5391  return false;
5392}
5393
5394bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA(
5395    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5396  assert(MI.getOpcode() == TargetOpcode::G_FADD);
5397
5398  bool AllowFusionGlobally, HasFMAD, Aggressive;
5399  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5400    return false;
5401
5402  const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
5403  Register Op1 = MI.getOperand(1).getReg();
5404  Register Op2 = MI.getOperand(2).getReg();
5405  DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
5406  DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
5407  LLT DstType = MRI.getType(MI.getOperand(0).getReg());
5408
5409  unsigned PreferredFusedOpcode =
5410      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5411
5412  // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5413  // prefer to fold the multiply with fewer uses.
5414  if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5415      isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
5416    if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
5417      std::swap(LHS, RHS);
5418  }
5419
5420  // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
5421  MachineInstr *FpExtSrc;
5422  if (mi_match(LHS.Reg, MRI, m_GFPExt(m_MInstr(FpExtSrc))) &&
5423      isContractableFMul(*FpExtSrc, AllowFusionGlobally) &&
5424      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
5425                          MRI.getType(FpExtSrc->getOperand(1).getReg()))) {
5426    MatchInfo = [=, &MI](MachineIRBuilder &B) {
5427      auto FpExtX = B.buildFPExt(DstType, FpExtSrc->getOperand(1).getReg());
5428      auto FpExtY = B.buildFPExt(DstType, FpExtSrc->getOperand(2).getReg());
5429      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5430                   {FpExtX.getReg(0), FpExtY.getReg(0), RHS.Reg});
5431    };
5432    return true;
5433  }
5434
5435  // fold (fadd z, (fpext (fmul x, y))) -> (fma (fpext x), (fpext y), z)
5436  // Note: Commutes FADD operands.
5437  if (mi_match(RHS.Reg, MRI, m_GFPExt(m_MInstr(FpExtSrc))) &&
5438      isContractableFMul(*FpExtSrc, AllowFusionGlobally) &&
5439      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
5440                          MRI.getType(FpExtSrc->getOperand(1).getReg()))) {
5441    MatchInfo = [=, &MI](MachineIRBuilder &B) {
5442      auto FpExtX = B.buildFPExt(DstType, FpExtSrc->getOperand(1).getReg());
5443      auto FpExtY = B.buildFPExt(DstType, FpExtSrc->getOperand(2).getReg());
5444      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5445                   {FpExtX.getReg(0), FpExtY.getReg(0), LHS.Reg});
5446    };
5447    return true;
5448  }
5449
5450  return false;
5451}
5452
5453bool CombinerHelper::matchCombineFAddFMAFMulToFMadOrFMA(
5454    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5455  assert(MI.getOpcode() == TargetOpcode::G_FADD);
5456
5457  bool AllowFusionGlobally, HasFMAD, Aggressive;
5458  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive, true))
5459    return false;
5460
5461  Register Op1 = MI.getOperand(1).getReg();
5462  Register Op2 = MI.getOperand(2).getReg();
5463  DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
5464  DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
5465  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
5466
5467  unsigned PreferredFusedOpcode =
5468      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5469
5470  // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5471  // prefer to fold the multiply with fewer uses.
5472  if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5473      isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
5474    if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
5475      std::swap(LHS, RHS);
5476  }
5477
5478  MachineInstr *FMA = nullptr;
5479  Register Z;
5480  // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z))
5481  if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
5482      (MRI.getVRegDef(LHS.MI->getOperand(3).getReg())->getOpcode() ==
5483       TargetOpcode::G_FMUL) &&
5484      MRI.hasOneNonDBGUse(LHS.MI->getOperand(0).getReg()) &&
5485      MRI.hasOneNonDBGUse(LHS.MI->getOperand(3).getReg())) {
5486    FMA = LHS.MI;
5487    Z = RHS.Reg;
5488  }
5489  // fold (fadd z, (fma x, y, (fmul u, v))) -> (fma x, y, (fma u, v, z))
5490  else if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
5491           (MRI.getVRegDef(RHS.MI->getOperand(3).getReg())->getOpcode() ==
5492            TargetOpcode::G_FMUL) &&
5493           MRI.hasOneNonDBGUse(RHS.MI->getOperand(0).getReg()) &&
5494           MRI.hasOneNonDBGUse(RHS.MI->getOperand(3).getReg())) {
5495    Z = LHS.Reg;
5496    FMA = RHS.MI;
5497  }
5498
5499  if (FMA) {
5500    MachineInstr *FMulMI = MRI.getVRegDef(FMA->getOperand(3).getReg());
5501    Register X = FMA->getOperand(1).getReg();
5502    Register Y = FMA->getOperand(2).getReg();
5503    Register U = FMulMI->getOperand(1).getReg();
5504    Register V = FMulMI->getOperand(2).getReg();
5505
5506    MatchInfo = [=, &MI](MachineIRBuilder &B) {
5507      Register InnerFMA = MRI.createGenericVirtualRegister(DstTy);
5508      B.buildInstr(PreferredFusedOpcode, {InnerFMA}, {U, V, Z});
5509      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5510                   {X, Y, InnerFMA});
5511    };
5512    return true;
5513  }
5514
5515  return false;
5516}
5517
5518bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive(
5519    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5520  assert(MI.getOpcode() == TargetOpcode::G_FADD);
5521
5522  bool AllowFusionGlobally, HasFMAD, Aggressive;
5523  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5524    return false;
5525
5526  if (!Aggressive)
5527    return false;
5528
5529  const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
5530  LLT DstType = MRI.getType(MI.getOperand(0).getReg());
5531  Register Op1 = MI.getOperand(1).getReg();
5532  Register Op2 = MI.getOperand(2).getReg();
5533  DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
5534  DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
5535
5536  unsigned PreferredFusedOpcode =
5537      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5538
5539  // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5540  // prefer to fold the multiply with fewer uses.
5541  if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5542      isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
5543    if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
5544      std::swap(LHS, RHS);
5545  }
5546
5547  // Builds: (fma x, y, (fma (fpext u), (fpext v), z))
5548  auto buildMatchInfo = [=, &MI](Register U, Register V, Register Z, Register X,
5549                                 Register Y, MachineIRBuilder &B) {
5550    Register FpExtU = B.buildFPExt(DstType, U).getReg(0);
5551    Register FpExtV = B.buildFPExt(DstType, V).getReg(0);
5552    Register InnerFMA =
5553        B.buildInstr(PreferredFusedOpcode, {DstType}, {FpExtU, FpExtV, Z})
5554            .getReg(0);
5555    B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5556                 {X, Y, InnerFMA});
5557  };
5558
5559  MachineInstr *FMulMI, *FMAMI;
5560  // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
5561  //   -> (fma x, y, (fma (fpext u), (fpext v), z))
5562  if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
5563      mi_match(LHS.MI->getOperand(3).getReg(), MRI,
5564               m_GFPExt(m_MInstr(FMulMI))) &&
5565      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
5566      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
5567                          MRI.getType(FMulMI->getOperand(0).getReg()))) {
5568    MatchInfo = [=](MachineIRBuilder &B) {
5569      buildMatchInfo(FMulMI->getOperand(1).getReg(),
5570                     FMulMI->getOperand(2).getReg(), RHS.Reg,
5571                     LHS.MI->getOperand(1).getReg(),
5572                     LHS.MI->getOperand(2).getReg(), B);
5573    };
5574    return true;
5575  }
5576
5577  // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
5578  //   -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
5579  // FIXME: This turns two single-precision and one double-precision
5580  // operation into two double-precision operations, which might not be
5581  // interesting for all targets, especially GPUs.
5582  if (mi_match(LHS.Reg, MRI, m_GFPExt(m_MInstr(FMAMI))) &&
5583      FMAMI->getOpcode() == PreferredFusedOpcode) {
5584    MachineInstr *FMulMI = MRI.getVRegDef(FMAMI->getOperand(3).getReg());
5585    if (isContractableFMul(*FMulMI, AllowFusionGlobally) &&
5586        TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
5587                            MRI.getType(FMAMI->getOperand(0).getReg()))) {
5588      MatchInfo = [=](MachineIRBuilder &B) {
5589        Register X = FMAMI->getOperand(1).getReg();
5590        Register Y = FMAMI->getOperand(2).getReg();
5591        X = B.buildFPExt(DstType, X).getReg(0);
5592        Y = B.buildFPExt(DstType, Y).getReg(0);
5593        buildMatchInfo(FMulMI->getOperand(1).getReg(),
5594                       FMulMI->getOperand(2).getReg(), RHS.Reg, X, Y, B);
5595      };
5596
5597      return true;
5598    }
5599  }
5600
5601  // fold (fadd z, (fma x, y, (fpext (fmul u, v)))
5602  //   -> (fma x, y, (fma (fpext u), (fpext v), z))
5603  if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
5604      mi_match(RHS.MI->getOperand(3).getReg(), MRI,
5605               m_GFPExt(m_MInstr(FMulMI))) &&
5606      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
5607      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
5608                          MRI.getType(FMulMI->getOperand(0).getReg()))) {
5609    MatchInfo = [=](MachineIRBuilder &B) {
5610      buildMatchInfo(FMulMI->getOperand(1).getReg(),
5611                     FMulMI->getOperand(2).getReg(), LHS.Reg,
5612                     RHS.MI->getOperand(1).getReg(),
5613                     RHS.MI->getOperand(2).getReg(), B);
5614    };
5615    return true;
5616  }
5617
5618  // fold (fadd z, (fpext (fma x, y, (fmul u, v)))
5619  //   -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
5620  // FIXME: This turns two single-precision and one double-precision
5621  // operation into two double-precision operations, which might not be
5622  // interesting for all targets, especially GPUs.
5623  if (mi_match(RHS.Reg, MRI, m_GFPExt(m_MInstr(FMAMI))) &&
5624      FMAMI->getOpcode() == PreferredFusedOpcode) {
5625    MachineInstr *FMulMI = MRI.getVRegDef(FMAMI->getOperand(3).getReg());
5626    if (isContractableFMul(*FMulMI, AllowFusionGlobally) &&
5627        TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
5628                            MRI.getType(FMAMI->getOperand(0).getReg()))) {
5629      MatchInfo = [=](MachineIRBuilder &B) {
5630        Register X = FMAMI->getOperand(1).getReg();
5631        Register Y = FMAMI->getOperand(2).getReg();
5632        X = B.buildFPExt(DstType, X).getReg(0);
5633        Y = B.buildFPExt(DstType, Y).getReg(0);
5634        buildMatchInfo(FMulMI->getOperand(1).getReg(),
5635                       FMulMI->getOperand(2).getReg(), LHS.Reg, X, Y, B);
5636      };
5637      return true;
5638    }
5639  }
5640
5641  return false;
5642}
5643
5644bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA(
5645    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5646  assert(MI.getOpcode() == TargetOpcode::G_FSUB);
5647
5648  bool AllowFusionGlobally, HasFMAD, Aggressive;
5649  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5650    return false;
5651
5652  Register Op1 = MI.getOperand(1).getReg();
5653  Register Op2 = MI.getOperand(2).getReg();
5654  DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
5655  DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
5656  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
5657
  // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
  // prefer to fold the multiply with fewer uses.
  bool FirstMulHasFewerUses = true;
5661  if (isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5662      isContractableFMul(*RHS.MI, AllowFusionGlobally) &&
5663      hasMoreUses(*LHS.MI, *RHS.MI, MRI))
5664    FirstMulHasFewerUses = false;
5665
5666  unsigned PreferredFusedOpcode =
5667      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5668
5669  // fold (fsub (fmul x, y), z) -> (fma x, y, -z)
  if (FirstMulHasFewerUses &&
      (isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
       (Aggressive || MRI.hasOneNonDBGUse(LHS.Reg)))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      Register NegZ = B.buildFNeg(DstTy, RHS.Reg).getReg(0);
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {LHS.MI->getOperand(1).getReg(),
                    LHS.MI->getOperand(2).getReg(), NegZ});
    };
    return true;
  }
  // fold (fsub x, (fmul y, z)) -> (fma -y, z, x)
  else if ((isContractableFMul(*RHS.MI, AllowFusionGlobally) &&
            (Aggressive || MRI.hasOneNonDBGUse(RHS.Reg)))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      Register NegY =
          B.buildFNeg(DstTy, RHS.MI->getOperand(1).getReg()).getReg(0);
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {NegY, RHS.MI->getOperand(2).getReg(), LHS.Reg});
    };
    return true;
  }

  return false;
}

bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA(
    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_FSUB);

  bool AllowFusionGlobally, HasFMAD, Aggressive;
  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
    return false;

  Register LHSReg = MI.getOperand(1).getReg();
  Register RHSReg = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  unsigned PreferredFusedOpcode =
      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;

  MachineInstr *FMulMI;
  // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
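  // Illustrative example (hypothetical MIR; register names and types are for
  // exposition only):
  //   %mul:_(s32) = G_FMUL %x, %y
  //   %neg:_(s32) = G_FNEG %mul
  //   %sub:_(s32) = G_FSUB %neg, %z
  // becomes
  //   %negx:_(s32) = G_FNEG %x
  //   %negz:_(s32) = G_FNEG %z
  //   %sub:_(s32) = G_FMA %negx, %y, %negz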
  if (mi_match(LHSReg, MRI, m_GFNeg(m_MInstr(FMulMI))) &&
      (Aggressive || (MRI.hasOneNonDBGUse(LHSReg) &&
                      MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg()))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally)) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      Register NegX =
          B.buildFNeg(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
      Register NegZ = B.buildFNeg(DstTy, RHSReg).getReg(0);
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {NegX, FMulMI->getOperand(2).getReg(), NegZ});
    };
    return true;
  }

  // fold (fsub x, (fneg (fmul y, z))) -> (fma y, z, x)
  if (mi_match(RHSReg, MRI, m_GFNeg(m_MInstr(FMulMI))) &&
      (Aggressive || (MRI.hasOneNonDBGUse(RHSReg) &&
                      MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg()))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally)) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {FMulMI->getOperand(1).getReg(),
                    FMulMI->getOperand(2).getReg(), LHSReg});
    };
    return true;
  }

  return false;
}

bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA(
    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_FSUB);

  bool AllowFusionGlobally, HasFMAD, Aggressive;
  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
    return false;

  Register LHSReg = MI.getOperand(1).getReg();
  Register RHSReg = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  unsigned PreferredFusedOpcode =
      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;

  MachineInstr *FMulMI;
  // fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z))
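  // Illustrative example (hypothetical MIR; register names and types are for
  // exposition only, assuming an f32 -> f64 extension):
  //   %mul:_(s32) = G_FMUL %x, %y
  //   %ext:_(s64) = G_FPEXT %mul
  //   %sub:_(s64) = G_FSUB %ext, %z
  // becomes
  //   %xe:_(s64) = G_FPEXT %x
  //   %ye:_(s64) = G_FPEXT %y
  //   %negz:_(s64) = G_FNEG %z
  //   %sub:_(s64) = G_FMA %xe, %ye, %negz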
  if (mi_match(LHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
      (Aggressive || MRI.hasOneNonDBGUse(LHSReg))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      Register FpExtX =
          B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
      Register FpExtY =
          B.buildFPExt(DstTy, FMulMI->getOperand(2).getReg()).getReg(0);
      Register NegZ = B.buildFNeg(DstTy, RHSReg).getReg(0);
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {FpExtX, FpExtY, NegZ});
    };
    return true;
  }

  // fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x)
  if (mi_match(RHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
      (Aggressive || MRI.hasOneNonDBGUse(RHSReg))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      Register FpExtY =
          B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
      Register NegY = B.buildFNeg(DstTy, FpExtY).getReg(0);
      Register FpExtZ =
          B.buildFPExt(DstTy, FMulMI->getOperand(2).getReg()).getReg(0);
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {NegY, FpExtZ, LHSReg});
    };
    return true;
  }

  return false;
}

bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA(
    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_FSUB);

  bool AllowFusionGlobally, HasFMAD, Aggressive;
  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
    return false;

  const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  Register LHSReg = MI.getOperand(1).getReg();
  Register RHSReg = MI.getOperand(2).getReg();

  unsigned PreferredFusedOpcode =
      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;

  auto buildMatchInfo = [=](Register Dst, Register X, Register Y, Register Z,
                            MachineIRBuilder &B) {
    Register FpExtX = B.buildFPExt(DstTy, X).getReg(0);
    Register FpExtY = B.buildFPExt(DstTy, Y).getReg(0);
    B.buildInstr(PreferredFusedOpcode, {Dst}, {FpExtX, FpExtY, Z});
  };

  MachineInstr *FMulMI;
  // fold (fsub (fpext (fneg (fmul x, y))), z) ->
  //      (fneg (fma (fpext x), (fpext y), z))
  // fold (fsub (fneg (fpext (fmul x, y))), z) ->
  //      (fneg (fma (fpext x), (fpext y), z))
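  // Illustrative example of the first pattern (hypothetical MIR; register
  // names and types are for exposition only):
  //   %mul:_(s32) = G_FMUL %x, %y
  //   %neg:_(s32) = G_FNEG %mul
  //   %ext:_(s64) = G_FPEXT %neg
  //   %sub:_(s64) = G_FSUB %ext, %z
  // becomes
  //   %xe:_(s64) = G_FPEXT %x
  //   %ye:_(s64) = G_FPEXT %y
  //   %fma:_(s64) = G_FMA %xe, %ye, %z
  //   %sub:_(s64) = G_FNEG %fma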
  if ((mi_match(LHSReg, MRI, m_GFPExt(m_GFNeg(m_MInstr(FMulMI)))) ||
       mi_match(LHSReg, MRI, m_GFNeg(m_GFPExt(m_MInstr(FMulMI))))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstTy,
                          MRI.getType(FMulMI->getOperand(0).getReg()))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      Register FMAReg = MRI.createGenericVirtualRegister(DstTy);
      buildMatchInfo(FMAReg, FMulMI->getOperand(1).getReg(),
                     FMulMI->getOperand(2).getReg(), RHSReg, B);
      B.buildFNeg(MI.getOperand(0).getReg(), FMAReg);
    };
    return true;
  }

  // fold (fsub x, (fpext (fneg (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
  // fold (fsub x, (fneg (fpext (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
  if ((mi_match(RHSReg, MRI, m_GFPExt(m_GFNeg(m_MInstr(FMulMI)))) ||
       mi_match(RHSReg, MRI, m_GFNeg(m_GFPExt(m_MInstr(FMulMI))))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstTy,
                          MRI.getType(FMulMI->getOperand(0).getReg()))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      buildMatchInfo(MI.getOperand(0).getReg(), FMulMI->getOperand(1).getReg(),
                     FMulMI->getOperand(2).getReg(), LHSReg, B);
    };
    return true;
  }

  return false;
}

bool CombinerHelper::matchSelectToLogical(MachineInstr &MI,
                                          BuildFnTy &MatchInfo) {
  GSelect &Sel = cast<GSelect>(MI);
  Register DstReg = Sel.getReg(0);
  Register Cond = Sel.getCondReg();
  Register TrueReg = Sel.getTrueReg();
  Register FalseReg = Sel.getFalseReg();

  auto *TrueDef = getDefIgnoringCopies(TrueReg, MRI);
  auto *FalseDef = getDefIgnoringCopies(FalseReg, MRI);

  const LLT CondTy = MRI.getType(Cond);
  const LLT OpTy = MRI.getType(TrueReg);
  if (CondTy != OpTy || OpTy.getScalarSizeInBits() != 1)
    return false;

  // We have a boolean select.

  // select Cond, Cond, F --> or Cond, F
  // select Cond, 1, F    --> or Cond, F
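  // Illustrative example of the constant-true case (hypothetical MIR; all
  // values are s1 booleans):
  //   %one:_(s1) = G_CONSTANT i1 true
  //   %sel:_(s1) = G_SELECT %cond(s1), %one, %f
  // becomes
  //   %sel:_(s1) = G_OR %cond, %f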
  auto MaybeCstTrue = isConstantOrConstantSplatVector(*TrueDef, MRI);
  if (Cond == TrueReg || (MaybeCstTrue && MaybeCstTrue->isOne())) {
    MatchInfo = [=](MachineIRBuilder &MIB) {
      MIB.buildOr(DstReg, Cond, FalseReg);
    };
    return true;
  }

  // select Cond, T, Cond --> and Cond, T
  // select Cond, T, 0    --> and Cond, T
  auto MaybeCstFalse = isConstantOrConstantSplatVector(*FalseDef, MRI);
  if (Cond == FalseReg || (MaybeCstFalse && MaybeCstFalse->isZero())) {
    MatchInfo = [=](MachineIRBuilder &MIB) {
      MIB.buildAnd(DstReg, Cond, TrueReg);
    };
    return true;
  }

  // select Cond, T, 1 --> or (not Cond), T
  if (MaybeCstFalse && MaybeCstFalse->isOne()) {
    MatchInfo = [=](MachineIRBuilder &MIB) {
      MIB.buildOr(DstReg, MIB.buildNot(OpTy, Cond), TrueReg);
    };
    return true;
  }

  // select Cond, 0, F --> and (not Cond), F
  if (MaybeCstTrue && MaybeCstTrue->isZero()) {
    MatchInfo = [=](MachineIRBuilder &MIB) {
      MIB.buildAnd(DstReg, MIB.buildNot(OpTy, Cond), FalseReg);
    };
    return true;
  }
  return false;
}

bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI,
                                            unsigned &IdxToPropagate) {
  bool PropagateNaN;
  switch (MI.getOpcode()) {
  default:
    return false;
  case TargetOpcode::G_FMINNUM:
  case TargetOpcode::G_FMAXNUM:
    PropagateNaN = false;
    break;
  case TargetOpcode::G_FMINIMUM:
  case TargetOpcode::G_FMAXIMUM:
    PropagateNaN = true;
    break;
  }

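  // For G_FMINNUM/G_FMAXNUM a constant NaN operand is ignored and the other
  // operand is returned, while G_FMINIMUM/G_FMAXIMUM propagate the NaN
  // operand itself.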
  auto MatchNaN = [&](unsigned Idx) {
    Register MaybeNaNReg = MI.getOperand(Idx).getReg();
    const ConstantFP *MaybeCst = getConstantFPVRegVal(MaybeNaNReg, MRI);
    if (!MaybeCst || !MaybeCst->getValueAPF().isNaN())
      return false;
    IdxToPropagate = PropagateNaN ? Idx : (Idx == 1 ? 2 : 1);
    return true;
  };

  return MatchNaN(1) || MatchNaN(2);
}

bool CombinerHelper::matchAddSubSameReg(MachineInstr &MI, Register &Src) {
  assert(MI.getOpcode() == TargetOpcode::G_ADD && "Expected a G_ADD");
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();

  // Helper lambda to check for opportunities for
  // A + (B - A) -> B
  // (B - A) + A -> B
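  // Illustrative example (hypothetical MIR; register names and types are for
  // exposition only):
  //   %sub:_(s32) = G_SUB %b, %a
  //   %add:_(s32) = G_ADD %a, %sub
  // can be replaced by %b.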
  auto CheckFold = [&](Register MaybeSub, Register MaybeSameReg) {
    Register Reg;
    return mi_match(MaybeSub, MRI, m_GSub(m_Reg(Src), m_Reg(Reg))) &&
           Reg == MaybeSameReg;
  };
  return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
}

bool CombinerHelper::matchBuildVectorIdentityFold(MachineInstr &MI,
                                                  Register &MatchInfo) {
  // This combine folds the following patterns:
  //
  //  G_BUILD_VECTOR_TRUNC (G_BITCAST(x), G_LSHR(G_BITCAST(x), k))
  //  G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), G_TRUNC(G_LSHR(G_BITCAST(x), k)))
  //    into
  //      x
  //    if
  //      k == bit width of an element of type(dst)
  //           (i.e. half the bit width of the bitcast scalar)
  //      type(x) == type(dst)
  //
  //  G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), undef)
  //    into
  //      x
  //    if
  //      type(x) == type(dst)
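  // Illustrative example of the first pattern (hypothetical MIR; register
  // names and types are for exposition only):
  //   %cast:_(s64) = G_BITCAST %x(<2 x s32>)
  //   %k:_(s64) = G_CONSTANT i64 32
  //   %hi:_(s64) = G_LSHR %cast, %k
  //   %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %cast(s64), %hi(s64)
  // is replaced by %x when the destination type is also <2 x s32>.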

  LLT DstVecTy = MRI.getType(MI.getOperand(0).getReg());
  LLT DstEltTy = DstVecTy.getElementType();

  Register Lo, Hi;

  if (mi_match(
          MI, MRI,
          m_GBuildVector(m_GTrunc(m_GBitcast(m_Reg(Lo))), m_GImplicitDef()))) {
    MatchInfo = Lo;
    return MRI.getType(MatchInfo) == DstVecTy;
  }

  std::optional<ValueAndVReg> ShiftAmount;
  const auto LoPattern = m_GBitcast(m_Reg(Lo));
  const auto HiPattern = m_GLShr(m_GBitcast(m_Reg(Hi)), m_GCst(ShiftAmount));
  if (mi_match(
          MI, MRI,
          m_any_of(m_GBuildVectorTrunc(LoPattern, HiPattern),
                   m_GBuildVector(m_GTrunc(LoPattern), m_GTrunc(HiPattern))))) {
    if (Lo == Hi && ShiftAmount->Value == DstEltTy.getSizeInBits()) {
      MatchInfo = Lo;
      return MRI.getType(MatchInfo) == DstVecTy;
    }
  }

  return false;
}

bool CombinerHelper::matchTruncBuildVectorFold(MachineInstr &MI,
                                               Register &MatchInfo) {
  // Replace (G_TRUNC (G_BITCAST (G_BUILD_VECTOR x, y))) with just x
  // if type(x) == type(G_TRUNC)
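  // Illustrative example (hypothetical MIR; register names and types are for
  // exposition only):
  //   %bv:_(<2 x s32>) = G_BUILD_VECTOR %x(s32), %y(s32)
  //   %cast:_(s64) = G_BITCAST %bv
  //   %trunc:_(s32) = G_TRUNC %cast
  // is replaced by %x.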
  if (!mi_match(MI.getOperand(1).getReg(), MRI,
                m_GBitcast(m_GBuildVector(m_Reg(MatchInfo), m_Reg()))))
    return false;

  return MRI.getType(MatchInfo) == MRI.getType(MI.getOperand(0).getReg());
}

bool CombinerHelper::matchTruncLshrBuildVectorFold(MachineInstr &MI,
                                                   Register &MatchInfo) {
  // Replace (G_TRUNC (G_LSHR (G_BITCAST (G_BUILD_VECTOR x, y)), K)) with
  //    y if K == size of vector element type
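  // Illustrative example (hypothetical MIR; register names and types are for
  // exposition only):
  //   %bv:_(<2 x s32>) = G_BUILD_VECTOR %x(s32), %y(s32)
  //   %cast:_(s64) = G_BITCAST %bv
  //   %k:_(s64) = G_CONSTANT i64 32
  //   %shr:_(s64) = G_LSHR %cast, %k
  //   %trunc:_(s32) = G_TRUNC %shr
  // is replaced by %y, since K matches the 32-bit element width.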
  std::optional<ValueAndVReg> ShiftAmt;
  if (!mi_match(MI.getOperand(1).getReg(), MRI,
                m_GLShr(m_GBitcast(m_GBuildVector(m_Reg(), m_Reg(MatchInfo))),
                        m_GCst(ShiftAmt))))
    return false;

  LLT MatchTy = MRI.getType(MatchInfo);
  return ShiftAmt->Value.getZExtValue() == MatchTy.getSizeInBits() &&
         MatchTy == MRI.getType(MI.getOperand(0).getReg());
}

unsigned CombinerHelper::getFPMinMaxOpcForSelect(
    CmpInst::Predicate Pred, LLT DstTy,
    SelectPatternNaNBehaviour VsNaNRetVal) const {
  assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE &&
         "Expected a NaN behaviour?");
  // Choose an opcode based on legality or on the behaviour when one of the
  // LHS/RHS may be NaN.
  switch (Pred) {
  default:
    return 0;
  case CmpInst::FCMP_UGT:
  case CmpInst::FCMP_UGE:
  case CmpInst::FCMP_OGT:
  case CmpInst::FCMP_OGE:
    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
      return TargetOpcode::G_FMAXNUM;
    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
      return TargetOpcode::G_FMAXIMUM;
    if (isLegal({TargetOpcode::G_FMAXNUM, {DstTy}}))
      return TargetOpcode::G_FMAXNUM;
    if (isLegal({TargetOpcode::G_FMAXIMUM, {DstTy}}))
      return TargetOpcode::G_FMAXIMUM;
    return 0;
  case CmpInst::FCMP_ULT:
  case CmpInst::FCMP_ULE:
  case CmpInst::FCMP_OLT:
  case CmpInst::FCMP_OLE:
    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
      return TargetOpcode::G_FMINNUM;
    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
      return TargetOpcode::G_FMINIMUM;
    if (isLegal({TargetOpcode::G_FMINNUM, {DstTy}}))
      return TargetOpcode::G_FMINNUM;
    if (isLegal({TargetOpcode::G_FMINIMUM, {DstTy}}))
      return TargetOpcode::G_FMINIMUM;
    return 0;
  }
}

CombinerHelper::SelectPatternNaNBehaviour
CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS,
                                        bool IsOrderedComparison) const {
  bool LHSSafe = isKnownNeverNaN(LHS, MRI);
  bool RHSSafe = isKnownNeverNaN(RHS, MRI);
  // Completely unsafe.
  if (!LHSSafe && !RHSSafe)
    return SelectPatternNaNBehaviour::NOT_APPLICABLE;
  if (LHSSafe && RHSSafe)
    return SelectPatternNaNBehaviour::RETURNS_ANY;
  // An ordered comparison will return false when given a NaN, so it
  // returns the RHS.
  if (IsOrderedComparison)
    return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN
                   : SelectPatternNaNBehaviour::RETURNS_OTHER;
  // An unordered comparison will return true when given a NaN, so it
  // returns the LHS.
  return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER
                 : SelectPatternNaNBehaviour::RETURNS_NAN;
}

bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond,
                                           Register TrueVal, Register FalseVal,
                                           BuildFnTy &MatchInfo) {
  // Match: select (fcmp cond x, y) x, y
  //        select (fcmp cond x, y) y, x
  // And turn it into fminnum/fmaxnum or fminimum/fmaximum based on the
  // predicate and the known NaN behaviour.
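  // Illustrative example (hypothetical MIR; register names and types are for
  // exposition only):
  //   %cmp:_(s1) = G_FCMP floatpred(ogt), %x(s32), %y
  //   %sel:_(s32) = G_SELECT %cmp(s1), %x, %y
  // may become
  //   %sel:_(s32) = G_FMAXNUM %x, %y
  // (or G_FMAXIMUM, depending on the known NaN behaviour and on legality).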
  LLT DstTy = MRI.getType(Dst);
  // Bail out early on pointers, since we'll never want to fold to a min/max.
  if (DstTy.isPointer())
    return false;
  // Match a floating point compare with a less-than/greater-than predicate.
  // TODO: Allow multiple users of the compare if they are all selects.
  CmpInst::Predicate Pred;
  Register CmpLHS, CmpRHS;
  if (!mi_match(Cond, MRI,
                m_OneNonDBGUse(
                    m_GFCmp(m_Pred(Pred), m_Reg(CmpLHS), m_Reg(CmpRHS)))) ||
      CmpInst::isEquality(Pred))
    return false;
  SelectPatternNaNBehaviour ResWithKnownNaNInfo =
      computeRetValAgainstNaN(CmpLHS, CmpRHS, CmpInst::isOrdered(Pred));
  if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE)
    return false;
  if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
    std::swap(CmpLHS, CmpRHS);
    Pred = CmpInst::getSwappedPredicate(Pred);
    if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN)
      ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER;
    else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER)
      ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN;
  }
  if (TrueVal != CmpLHS || FalseVal != CmpRHS)
    return false;
  // Decide what type of max/min this should be based on the predicate.
  unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, ResWithKnownNaNInfo);
  if (!Opc || !isLegal({Opc, {DstTy}}))
    return false;
  // Comparisons between signed zero and zero may have different results...
  // unless we have fmaximum/fminimum. In that case, we know -0 < 0.
  if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) {
    // We don't know if a comparison between two 0s will give us a consistent
    // result. Be conservative and only proceed if at least one side is
    // non-zero.
    auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpLHS, MRI);
    if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) {
      KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpRHS, MRI);
      if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero())
        return false;
    }
  }
  MatchInfo = [=](MachineIRBuilder &B) {
    B.buildInstr(Opc, {Dst}, {CmpLHS, CmpRHS});
  };
  return true;
}

bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI,
                                                 BuildFnTy &MatchInfo) {
  // TODO: Handle integer cases.
  assert(MI.getOpcode() == TargetOpcode::G_SELECT);
  // Condition may be fed by a truncated compare.
  Register Cond = MI.getOperand(1).getReg();
  Register MaybeTrunc;
  if (mi_match(Cond, MRI, m_OneNonDBGUse(m_GTrunc(m_Reg(MaybeTrunc)))))
    Cond = MaybeTrunc;
  Register Dst = MI.getOperand(0).getReg();
  Register TrueVal = MI.getOperand(2).getReg();
  Register FalseVal = MI.getOperand(3).getReg();
  return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo);
}

bool CombinerHelper::matchRedundantBinOpInEquality(MachineInstr &MI,
                                                   BuildFnTy &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_ICMP);
  // (X + Y) == X --> Y == 0
  // (X + Y) != X --> Y != 0
  // (X - Y) == X --> Y == 0
  // (X - Y) != X --> Y != 0
  // (X ^ Y) == X --> Y == 0
  // (X ^ Y) != X --> Y != 0
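  // Illustrative example (hypothetical MIR; register names and types are for
  // exposition only):
  //   %add:_(s32) = G_ADD %x, %y
  //   %cmp:_(s1) = G_ICMP intpred(eq), %add(s32), %x
  // becomes
  //   %zero:_(s32) = G_CONSTANT i32 0
  //   %cmp:_(s1) = G_ICMP intpred(eq), %y(s32), %zero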
  Register Dst = MI.getOperand(0).getReg();
  CmpInst::Predicate Pred;
  Register X, Y, OpLHS, OpRHS;
  bool MatchedSub = mi_match(
      Dst, MRI,
      m_c_GICmp(m_Pred(Pred), m_Reg(X), m_GSub(m_Reg(OpLHS), m_Reg(Y))));
  if (MatchedSub && X != OpLHS)
    return false;
  if (!MatchedSub) {
    if (!mi_match(Dst, MRI,
                  m_c_GICmp(m_Pred(Pred), m_Reg(X),
                            m_any_of(m_GAdd(m_Reg(OpLHS), m_Reg(OpRHS)),
                                     m_GXor(m_Reg(OpLHS), m_Reg(OpRHS))))))
      return false;
    Y = X == OpLHS ? OpRHS : X == OpRHS ? OpLHS : Register();
  }
  MatchInfo = [=](MachineIRBuilder &B) {
    auto Zero = B.buildConstant(MRI.getType(Y), 0);
    B.buildICmp(Pred, Dst, Y, Zero);
  };
  return CmpInst::isEquality(Pred) && Y.isValid();
}

bool CombinerHelper::tryCombine(MachineInstr &MI) {
  if (tryCombineCopy(MI))
    return true;
  if (tryCombineExtendingLoads(MI))
    return true;
  if (tryCombineIndexedLoadStore(MI))
    return true;
  return false;
}
