//===- InstCombineSimplifyDemanded.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains logic for simplifying instructions based on information
// about how they are used.
//
//===----------------------------------------------------------------------===//

#include "InstCombineInternal.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "instcombine"

namespace {

struct AMDGPUImageDMaskIntrinsic {
  unsigned Intr;
};

#define GET_AMDGPUImageDMaskIntrinsicTable_IMPL
#include "InstCombineTables.inc"

} // end anonymous namespace

/// Check to see if the specified operand of the specified instruction is a
/// constant integer. If so, check to see if there are any bits set in the
/// constant that are not demanded. If so, shrink the constant and return true.
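/// For example, with DemandedMask = 0xFF, the constant in 'and i32 %x, 511'
/// has bit 8 set even though that bit is never observed, so it is shrunk to
/// 255.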
static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
                                   const APInt &Demanded) {
  assert(I && "No instruction?");
  assert(OpNo < I->getNumOperands() && "Operand index too large");

  // The operand must be a constant integer or splat integer.
  Value *Op = I->getOperand(OpNo);
  const APInt *C;
  if (!match(Op, m_APInt(C)))
    return false;

  // If there are no bits set that aren't demanded, nothing to do.
  if (C->isSubsetOf(Demanded))
    return false;

  // This instruction is producing bits that are not demanded. Shrink the RHS.
  I->setOperand(OpNo, ConstantInt::get(Op->getType(), *C & Demanded));

  return true;
}
/// Inst is an integer instruction that SimplifyDemandedBits knows about. See
/// if the instruction has any properties that allow us to simplify its
/// operands.
bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) {
  unsigned BitWidth = Inst.getType()->getScalarSizeInBits();
  KnownBits Known(BitWidth);
  APInt DemandedMask(APInt::getAllOnesValue(BitWidth));

  Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known,
                                     0, &Inst);
  if (!V) return false;
  if (V == &Inst) return true;
  replaceInstUsesWith(Inst, V);
  return true;
}

/// This form of SimplifyDemandedBits simplifies the specified instruction
/// operand if possible, updating it in place. It returns true if it made any
/// change and false otherwise.
bool InstCombiner::SimplifyDemandedBits(Instruction *I, unsigned OpNo,
                                        const APInt &DemandedMask,
                                        KnownBits &Known,
                                        unsigned Depth) {
  Use &U = I->getOperandUse(OpNo);
  Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, Known,
                                          Depth, I);
  if (!NewVal) return false;
  if (Instruction *OpInst = dyn_cast<Instruction>(U))
    salvageDebugInfo(*OpInst);

  replaceUse(U, NewVal);
  return true;
}

/// This function attempts to replace V with a simpler value based on the
/// demanded bits. When this function is called, it is known that only the bits
/// set in DemandedMask of the result of V are ever used downstream.
/// Consequently, depending on the mask and V, it may be possible to replace V
/// with a constant or one of its operands. In such cases, this function does
/// the replacement and returns the simplified value. In all other cases, it
/// returns null after analyzing the expression: Known.One then contains all
/// the bits that are known to be one in the expression, and Known.Zero
/// contains all the bits that are known to be zero. These are provided to
/// potentially allow the caller (which might recursively be
/// SimplifyDemandedBits itself) to simplify the expression.
/// Known.One and Known.Zero always follow the invariant that:
///   Known.One & Known.Zero == 0.
/// That is, a bit can't be both 1 and 0. Note that the bits in Known.One and
/// Known.Zero may only be accurate for those bits set in DemandedMask. Note
/// also that the bitwidth of V, DemandedMask, Known.Zero and Known.One must
/// all be the same.
///
/// This returns null if it did not change anything and it permits no
/// simplification.  This returns V itself if it did some simplification of V's
/// operands based on the information about what bits are demanded. This
/// returns some other non-null value if it found out that V is equal to
/// another value in the context where the specified bits are demanded, but not
/// for all users.
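///
/// For example, with DemandedMask = 0x00FF, 'or i32 %x, 0xFF00' simplifies to
/// %x: the 'or' only sets bits that are not demanded, so every demanded bit
/// of the result is known to equal the corresponding bit of %x.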
Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
                                             KnownBits &Known, unsigned Depth,
                                             Instruction *CxtI) {
  assert(V != nullptr && "Null pointer of Value???");
  assert(Depth <= 6 && "Limit Search Depth");
  uint32_t BitWidth = DemandedMask.getBitWidth();
  Type *VTy = V->getType();
  assert(
      (!VTy->isIntOrIntVectorTy() || VTy->getScalarSizeInBits() == BitWidth) &&
      Known.getBitWidth() == BitWidth &&
      "Value *V, DemandedMask and Known must have same BitWidth");

  if (isa<Constant>(V)) {
    computeKnownBits(V, Known, Depth, CxtI);
    return nullptr;
  }

  Known.resetAll();
  if (DemandedMask.isNullValue())     // Not demanding any bits from V.
    return UndefValue::get(VTy);

  if (Depth == 6)        // Limit search depth.
    return nullptr;

  Instruction *I = dyn_cast<Instruction>(V);
  if (!I) {
    computeKnownBits(V, Known, Depth, CxtI);
    return nullptr;        // Only analyze instructions.
  }

  // If there are multiple uses of this value and we aren't at the root, then
  // we can't do any simplifications of the operands, because DemandedMask
  // only reflects the bits demanded by *one* of the users.
  if (Depth != 0 && !I->hasOneUse())
    return SimplifyMultipleUseDemandedBits(I, DemandedMask, Known, Depth, CxtI);

  KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth);

  // If this is the root being simplified, allow it to have multiple uses,
  // just set the DemandedMask to all bits so that we can try to simplify the
  // operands.  This allows visitTruncInst (for example) to simplify the
  // operand of a trunc without duplicating all the logic below.
  if (Depth == 0 && !V->hasOneUse())
    DemandedMask.setAllBits();

  switch (I->getOpcode()) {
  default:
    computeKnownBits(I, Known, Depth, CxtI);
    break;
  case Instruction::And: {
    // If either the LHS or the RHS is zero, the result is zero.
    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) ||
        SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.Zero, LHSKnown,
                             Depth + 1))
      return I;
    assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
    assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");

    Known = LHSKnown & RHSKnown;

    // If the client is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(VTy, Known.One);

    // If all of the demanded bits are known 1 on one side, return the other.
    // These bits cannot contribute to the result of the 'and'.
    if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
      return I->getOperand(0);
    if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
      return I->getOperand(1);

    // If the RHS is a constant, see if we can simplify it.
    if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnown.Zero))
      return I;

    break;
  }
  case Instruction::Or: {
    // If either the LHS or the RHS is one, the result is one.
    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) ||
        SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.One, LHSKnown,
                             Depth + 1))
      return I;
    assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
    assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");

    Known = LHSKnown | RHSKnown;

    // If the client is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(VTy, Known.One);

    // If all of the demanded bits are known zero on one side, return the other.
    // These bits cannot contribute to the result of the 'or'.
    if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
      return I->getOperand(0);
    if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
      return I->getOperand(1);

    // If the RHS is a constant, see if we can simplify it.
    if (ShrinkDemandedConstant(I, 1, DemandedMask))
      return I;

    break;
  }
  case Instruction::Xor: {
    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) ||
        SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1))
      return I;
    assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
    assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");

    Known = LHSKnown ^ RHSKnown;

    // If the client is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(VTy, Known.One);

    // If all of the demanded bits are known zero on one side, return the other.
    // These bits cannot contribute to the result of the 'xor'.
    if (DemandedMask.isSubsetOf(RHSKnown.Zero))
      return I->getOperand(0);
    if (DemandedMask.isSubsetOf(LHSKnown.Zero))
      return I->getOperand(1);

    // If all of the demanded bits are known to be zero on one side or the
    // other, turn this into an *inclusive* or.
    //    e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0
    if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.Zero)) {
      Instruction *Or =
        BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1),
                                 I->getName());
      return InsertNewInstWith(Or, *I);
    }

    // If all of the demanded bits on one side are known, and all of the set
    // bits on that side are also known to be set on the other side, turn this
    // into an AND, as we know the bits will be cleared.
    //    e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2
    if (DemandedMask.isSubsetOf(RHSKnown.Zero|RHSKnown.One) &&
        RHSKnown.One.isSubsetOf(LHSKnown.One)) {
      Constant *AndC = Constant::getIntegerValue(VTy,
                                                 ~RHSKnown.One & DemandedMask);
      Instruction *And = BinaryOperator::CreateAnd(I->getOperand(0), AndC);
      return InsertNewInstWith(And, *I);
    }

    // If the RHS is a constant, see if we can simplify it.
    // FIXME: for XOR, we prefer to force bits to 1 if they will make a -1.
    if (ShrinkDemandedConstant(I, 1, DemandedMask))
      return I;

    // If our LHS is an 'and' and if it has one use, and if any of the bits we
    // are flipping are known to be set, then the xor is just resetting those
    // bits to zero.  We can just knock out bits from the 'and' and the 'xor',
    // simplifying both of them.
    if (Instruction *LHSInst = dyn_cast<Instruction>(I->getOperand(0)))
      if (LHSInst->getOpcode() == Instruction::And && LHSInst->hasOneUse() &&
          isa<ConstantInt>(I->getOperand(1)) &&
          isa<ConstantInt>(LHSInst->getOperand(1)) &&
          (LHSKnown.One & RHSKnown.One & DemandedMask) != 0) {
        ConstantInt *AndRHS = cast<ConstantInt>(LHSInst->getOperand(1));
        ConstantInt *XorRHS = cast<ConstantInt>(I->getOperand(1));
        APInt NewMask = ~(LHSKnown.One & RHSKnown.One & DemandedMask);

        Constant *AndC =
          ConstantInt::get(I->getType(), NewMask & AndRHS->getValue());
        Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC);
        InsertNewInstWith(NewAnd, *I);

        Constant *XorC =
          ConstantInt::get(I->getType(), NewMask & XorRHS->getValue());
        Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC);
        return InsertNewInstWith(NewXor, *I);
      }

    break;
  }
  case Instruction::Select: {
    Value *LHS, *RHS;
    SelectPatternFlavor SPF = matchSelectPattern(I, LHS, RHS).Flavor;
    if (SPF == SPF_UMAX) {
      // UMax(A, C) == A if ...
      // The lowest non-zero bit of DemandedMask is higher than the highest
      // non-zero bit of C.
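      // e.g. umax(A, 3) can only change bits 0 and 1 of A, so if none of the
      // demanded bits are below bit 2, the result is interchangeable with A.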
      const APInt *C;
      unsigned CTZ = DemandedMask.countTrailingZeros();
      if (match(RHS, m_APInt(C)) && CTZ >= C->getActiveBits())
        return LHS;
    } else if (SPF == SPF_UMIN) {
      // UMin(A, C) == A if ...
      // The lowest non-zero bit of DemandedMask is higher than the highest
      // non-one bit of C.
      // This comes from applying De Morgan's laws to the umax example above.
      const APInt *C;
      unsigned CTZ = DemandedMask.countTrailingZeros();
      if (match(RHS, m_APInt(C)) &&
          CTZ >= C->getBitWidth() - C->countLeadingOnes())
        return LHS;
    }

    // If this is a select as part of any other min/max pattern, don't simplify
    // any further in case we break the structure.
    if (SPF != SPF_UNKNOWN)
      return nullptr;

    if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Depth + 1) ||
        SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Depth + 1))
      return I;
    assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
    assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");

    // If the operands are constants, see if we can simplify them.
    // This is similar to ShrinkDemandedConstant, but for a select we want to
    // try to keep the selected constants the same as icmp value constants, if
    // we can. This helps not break apart (or helps put back together)
    // canonical patterns like min and max.
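    // For example, if only the low 4 bits are demanded and the icmp compares
    // against 6 (0b0110), a select arm of 22 (0b10110) agrees with 6 on all
    // demanded bits and is rewritten to 6, keeping the constants paired.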
    auto CanonicalizeSelectConstant = [](Instruction *I, unsigned OpNo,
                                         APInt DemandedMask) {
      const APInt *SelC;
      if (!match(I->getOperand(OpNo), m_APInt(SelC)))
        return false;

      // Get the constant out of the ICmp, if there is one.
      const APInt *CmpC;
      ICmpInst::Predicate Pred;
      if (!match(I->getOperand(0), m_c_ICmp(Pred, m_APInt(CmpC), m_Value())) ||
          CmpC->getBitWidth() != SelC->getBitWidth())
        return ShrinkDemandedConstant(I, OpNo, DemandedMask);

      // If the constant is already the same as the ICmp, leave it as-is.
      if (*CmpC == *SelC)
        return false;
      // If the constants are not already the same, but can be with the demand
      // mask, use the constant value from the ICmp.
      if ((*CmpC & DemandedMask) == (*SelC & DemandedMask)) {
        I->setOperand(OpNo, ConstantInt::get(I->getType(), *CmpC));
        return true;
      }
      return ShrinkDemandedConstant(I, OpNo, DemandedMask);
    };
    if (CanonicalizeSelectConstant(I, 1, DemandedMask) ||
        CanonicalizeSelectConstant(I, 2, DemandedMask))
      return I;

    // Only known if known in both the LHS and RHS.
    Known.One = RHSKnown.One & LHSKnown.One;
    Known.Zero = RHSKnown.Zero & LHSKnown.Zero;
    break;
  }
  case Instruction::ZExt:
  case Instruction::Trunc: {
    unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();

    APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth);
    KnownBits InputKnown(SrcBitWidth);
    if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1))
      return I;
    assert(InputKnown.getBitWidth() == SrcBitWidth && "Src width changed?");
    Known = InputKnown.zextOrTrunc(BitWidth);
    assert(!Known.hasConflict() && "Bits known to be one AND zero?");
    break;
  }
  case Instruction::BitCast:
    if (!I->getOperand(0)->getType()->isIntOrIntVectorTy())
      return nullptr;  // vector->int or fp->int?

    if (VectorType *DstVTy = dyn_cast<VectorType>(I->getType())) {
      if (VectorType *SrcVTy =
            dyn_cast<VectorType>(I->getOperand(0)->getType())) {
        if (DstVTy->getNumElements() != SrcVTy->getNumElements())
          // Don't touch a bitcast between vectors of different element counts.
          return nullptr;
      } else
        // Don't touch a scalar-to-vector bitcast.
        return nullptr;
    } else if (I->getOperand(0)->getType()->isVectorTy())
      // Don't touch a vector-to-scalar bitcast.
      return nullptr;

    if (SimplifyDemandedBits(I, 0, DemandedMask, Known, Depth + 1))
      return I;
    assert(!Known.hasConflict() && "Bits known to be one AND zero?");
    break;
  case Instruction::SExt: {
    // Compute the bits in the result that are not present in the input.
    unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();

    APInt InputDemandedBits = DemandedMask.trunc(SrcBitWidth);

    // If any of the sign extended bits are demanded, we know that the sign
    // bit is demanded.
    if (DemandedMask.getActiveBits() > SrcBitWidth)
      InputDemandedBits.setBit(SrcBitWidth-1);
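    // e.g. for 'sext i8 %x to i32' with DemandedMask = 0x1FF00, bits 8 and
    // above are copies of the input sign bit, so bit 7 of %x is demanded.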
418
419    KnownBits InputKnown(SrcBitWidth);
420    if (SimplifyDemandedBits(I, 0, InputDemandedBits, InputKnown, Depth + 1))
421      return I;
422
423    // If the input sign bit is known zero, or if the NewBits are not demanded
424    // convert this into a zero extension.
425    if (InputKnown.isNonNegative() ||
426        DemandedMask.getActiveBits() <= SrcBitWidth) {
427      // Convert to ZExt cast.
428      CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy, I->getName());
429      return InsertNewInstWith(NewCast, *I);
430     }

    // If the sign bit of the input is known set or clear, then we know the
    // top bits of the result.
    Known = InputKnown.sext(BitWidth);
    assert(!Known.hasConflict() && "Bits known to be one AND zero?");
    break;
  }
  case Instruction::Add:
    if ((DemandedMask & 1) == 0) {
      // If we do not need the low bit, try to convert bool math to logic:
      // add iN (zext i1 X), (sext i1 Y) --> sext (~X & Y) to iN
      Value *X, *Y;
      if (match(I, m_c_Add(m_OneUse(m_ZExt(m_Value(X))),
                           m_OneUse(m_SExt(m_Value(Y))))) &&
          X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType()) {
        // Truth table for inputs and output signbits:
        //       X:0 | X:1
        //      ----------
        // Y:0  |  0 | 0 |
        // Y:1  | -1 | 0 |
        //      ----------
        IRBuilderBase::InsertPointGuard Guard(Builder);
        Builder.SetInsertPoint(I);
        Value *AndNot = Builder.CreateAnd(Builder.CreateNot(X), Y);
        return Builder.CreateSExt(AndNot, VTy);
      }

      // add iN (sext i1 X), (sext i1 Y) --> sext (X | Y) to iN
      // TODO: Relax the one-use checks because we are removing an instruction?
      if (match(I, m_Add(m_OneUse(m_SExt(m_Value(X))),
                         m_OneUse(m_SExt(m_Value(Y))))) &&
          X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType()) {
        // Truth table for inputs and output signbits:
        //       X:0 | X:1
        //      -----------
        // Y:0  | -1 | -1 |
        // Y:1  | -1 |  0 |
        //      -----------
        IRBuilderBase::InsertPointGuard Guard(Builder);
        Builder.SetInsertPoint(I);
        Value *Or = Builder.CreateOr(X, Y);
        return Builder.CreateSExt(Or, VTy);
      }
    }
    LLVM_FALLTHROUGH;
  case Instruction::Sub: {
    // If the high bits of an ADD/SUB are not demanded, then we do not care
    // about the high bits of the operands.
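    // e.g. if only the low byte of 'add i32 %a, %b' is demanded, carries can
    // only propagate upward out of the low byte, so only the low byte of each
    // operand is demanded and DemandedFromOps becomes 0xFF.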
    unsigned NLZ = DemandedMask.countLeadingZeros();
    // Right fill the mask of bits for this ADD/SUB to demand the most
    // significant bit and all those below it.
    APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ));
    if (ShrinkDemandedConstant(I, 0, DemandedFromOps) ||
        SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) ||
        ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
        SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) {
      if (NLZ > 0) {
        // Disable the nsw and nuw flags here: We can no longer guarantee that
        // we won't wrap after simplification. Removing the nsw/nuw flags is
        // legal here because the top bit is not demanded.
        BinaryOperator &BinOP = *cast<BinaryOperator>(I);
        BinOP.setHasNoSignedWrap(false);
        BinOP.setHasNoUnsignedWrap(false);
      }
      return I;
    }

    // If we are known to be adding/subtracting zeros to every bit below
    // the highest demanded bit, we just return the other side.
    if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
      return I->getOperand(0);
    // We can't do this with the LHS for subtraction, unless we are only
    // demanding the LSB.
    if ((I->getOpcode() == Instruction::Add ||
         DemandedFromOps.isOneValue()) &&
        DemandedFromOps.isSubsetOf(LHSKnown.Zero))
      return I->getOperand(1);

    // Otherwise just compute the known bits of the result.
    bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
    Known = KnownBits::computeForAddSub(I->getOpcode() == Instruction::Add,
                                        NSW, LHSKnown, RHSKnown);
    break;
  }
  case Instruction::Shl: {
    const APInt *SA;
    if (match(I->getOperand(1), m_APInt(SA))) {
      const APInt *ShrAmt;
      if (match(I->getOperand(0), m_Shr(m_Value(), m_APInt(ShrAmt))))
        if (Instruction *Shr = dyn_cast<Instruction>(I->getOperand(0)))
          if (Value *R = simplifyShrShlDemandedBits(Shr, *ShrAmt, I, *SA,
                                                    DemandedMask, Known))
            return R;

      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
      APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt));

      // If the shift is NUW/NSW, then it does demand the high bits.
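      // e.g. for 'shl nuw i32 %x, 3', the three bits shifted out of the top
      // must be zero for the result not to wrap, so they stay demanded in the
      // operand; with nsw one extra bit (the new sign bit) is demanded.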
      ShlOperator *IOp = cast<ShlOperator>(I);
      if (IOp->hasNoSignedWrap())
        DemandedMaskIn.setHighBits(ShiftAmt+1);
      else if (IOp->hasNoUnsignedWrap())
        DemandedMaskIn.setHighBits(ShiftAmt);

      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1))
        return I;
      assert(!Known.hasConflict() && "Bits known to be one AND zero?");

      bool SignBitZero = Known.Zero.isSignBitSet();
      bool SignBitOne = Known.One.isSignBitSet();
      Known.Zero <<= ShiftAmt;
      Known.One  <<= ShiftAmt;
      // low bits known zero.
      if (ShiftAmt)
        Known.Zero.setLowBits(ShiftAmt);

      // If this shift has the "nsw" flag, then the result is either a poison
      // value or has the same sign bit as the first operand.
      if (IOp->hasNoSignedWrap()) {
        if (SignBitZero)
          Known.Zero.setSignBit();
        else if (SignBitOne)
          Known.One.setSignBit();
        if (Known.hasConflict())
          return UndefValue::get(I->getType());
      }
    } else {
      computeKnownBits(I, Known, Depth, CxtI);
    }
    break;
  }
  case Instruction::LShr: {
    const APInt *SA;
    if (match(I->getOperand(1), m_APInt(SA))) {
      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);

      // Unsigned shift right.
      APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));

      // If the shift is exact, then it does demand the low bits (and knows that
      // they are zero).
      if (cast<LShrOperator>(I)->isExact())
        DemandedMaskIn.setLowBits(ShiftAmt);

      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1))
        return I;
      assert(!Known.hasConflict() && "Bits known to be one AND zero?");
      Known.Zero.lshrInPlace(ShiftAmt);
      Known.One.lshrInPlace(ShiftAmt);
      if (ShiftAmt)
        Known.Zero.setHighBits(ShiftAmt);  // high bits known zero.
    } else {
      computeKnownBits(I, Known, Depth, CxtI);
    }
    break;
  }
  case Instruction::AShr: {
    // If this is an arithmetic shift right and only the low-bit is set, we can
    // always convert this into a logical shr, even if the shift amount is
    // variable.  The low bit of the shift cannot be an input sign bit unless
    // the shift amount is >= the size of the datatype, which is undefined.
    if (DemandedMask.isOneValue()) {
      // Perform the logical shift right.
      Instruction *NewVal = BinaryOperator::CreateLShr(
                        I->getOperand(0), I->getOperand(1), I->getName());
      return InsertNewInstWith(NewVal, *I);
    }

    // If the sign bit is the only bit demanded by this ashr, then there is no
    // need to do it; the shift doesn't change the high bit.
    if (DemandedMask.isSignMask())
      return I->getOperand(0);

    const APInt *SA;
    if (match(I->getOperand(1), m_APInt(SA))) {
      uint32_t ShiftAmt = SA->getLimitedValue(BitWidth-1);

      // Signed shift right.
      APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
      // If any of the high bits are demanded, we should set the sign bit as
      // demanded.
      if (DemandedMask.countLeadingZeros() <= ShiftAmt)
        DemandedMaskIn.setSignBit();

      // If the shift is exact, then it does demand the low bits (and knows that
      // they are zero).
      if (cast<AShrOperator>(I)->isExact())
        DemandedMaskIn.setLowBits(ShiftAmt);

      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1))
        return I;

      unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);

      assert(!Known.hasConflict() && "Bits known to be one AND zero?");
      // Compute the new bits that are at the top now plus sign bits.
      APInt HighBits(APInt::getHighBitsSet(
          BitWidth, std::min(SignBits + ShiftAmt - 1, BitWidth)));
      Known.Zero.lshrInPlace(ShiftAmt);
      Known.One.lshrInPlace(ShiftAmt);

      // If the input sign bit is known to be zero, or if none of the top bits
      // are demanded, turn this into an unsigned shift right.
      assert(BitWidth > ShiftAmt && "Shift amount not saturated?");
      if (Known.Zero[BitWidth-ShiftAmt-1] ||
          !DemandedMask.intersects(HighBits)) {
        BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0),
                                                          I->getOperand(1));
        LShr->setIsExact(cast<BinaryOperator>(I)->isExact());
        return InsertNewInstWith(LShr, *I);
      } else if (Known.One[BitWidth-ShiftAmt-1]) { // New bits are known one.
        Known.One |= HighBits;
      }
    } else {
      computeKnownBits(I, Known, Depth, CxtI);
    }
    break;
  }
  case Instruction::UDiv: {
    // UDiv doesn't demand low bits that are zero in the divisor.
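    // e.g. a divisor of 12 has two trailing zero bits, and changing the low
    // two bits of the dividend can never move it across a multiple of 12, so
    // those dividend bits cannot affect the quotient.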
    const APInt *SA;
    if (match(I->getOperand(1), m_APInt(SA))) {
      // If the division is exact, then it does demand the low bits.
      if (cast<UDivOperator>(I)->isExact())
        break;

      // FIXME: Take the demanded mask of the result into account.
      unsigned RHSTrailingZeros = SA->countTrailingZeros();
      APInt DemandedMaskIn =
          APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros);
      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1))
        return I;

      // Propagate zero bits from the input.
      Known.Zero.setHighBits(std::min(
          BitWidth, LHSKnown.Zero.countLeadingOnes() + RHSTrailingZeros));
    } else {
      computeKnownBits(I, Known, Depth, CxtI);
    }
    break;
  }
  case Instruction::SRem:
    if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
      // X % -1 demands all the bits because we don't want to introduce
      // INT_MIN % -1 (== undef) by accident.
      if (Rem->isMinusOne())
        break;
      APInt RA = Rem->getValue().abs();
      if (RA.isPowerOf2()) {
        if (DemandedMask.ult(RA))    // srem won't affect demanded bits
          return I->getOperand(0);

        APInt LowBits = RA - 1;
        APInt Mask2 = LowBits | APInt::getSignMask(BitWidth);
        if (SimplifyDemandedBits(I, 0, Mask2, LHSKnown, Depth + 1))
          return I;

        // The low bits of LHS are unchanged by the srem.
        Known.Zero = LHSKnown.Zero & LowBits;
        Known.One = LHSKnown.One & LowBits;

        // If LHS is non-negative or has all low bits zero, then the upper bits
        // are all zero.
        if (LHSKnown.isNonNegative() || LowBits.isSubsetOf(LHSKnown.Zero))
          Known.Zero |= ~LowBits;

        // If LHS is negative and not all low bits are zero, then the upper bits
        // are all one.
        if (LHSKnown.isNegative() && LowBits.intersects(LHSKnown.One))
          Known.One |= ~LowBits;

        assert(!Known.hasConflict() && "Bits known to be one AND zero?");
        break;
      }
    }

    // The sign bit is the LHS's sign bit, except when the result of the
    // remainder is zero.
    if (DemandedMask.isSignBitSet()) {
      computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
      // If it's known zero, our sign bit is also zero.
      if (LHSKnown.isNonNegative())
        Known.makeNonNegative();
    }
    break;
  case Instruction::URem: {
    KnownBits Known2(BitWidth);
    APInt AllOnes = APInt::getAllOnesValue(BitWidth);
    if (SimplifyDemandedBits(I, 0, AllOnes, Known2, Depth + 1) ||
        SimplifyDemandedBits(I, 1, AllOnes, Known2, Depth + 1))
      return I;

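    // The remainder is always strictly smaller than the divisor (a zero
    // divisor is UB), so the remainder shares the divisor's known leading
    // zero bits.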
    unsigned Leaders = Known2.countMinLeadingZeros();
    Known.Zero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask;
    break;
  }
  case Instruction::Call: {
    bool KnownBitsComputed = false;
    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
      switch (II->getIntrinsicID()) {
      default: break;
      case Intrinsic::bswap: {
        // If the only bits demanded come from one byte of the bswap result,
        // just shift the input byte into position to eliminate the bswap.
        unsigned NLZ = DemandedMask.countLeadingZeros();
        unsigned NTZ = DemandedMask.countTrailingZeros();

        // Round NTZ down to the nearest byte boundary. If we have 11 trailing
        // zeros, then we need all the bits down to bit 8. Likewise, round NLZ:
        // if we have 14 leading zeros, round down to 8.
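        // e.g. for an i32 bswap with DemandedMask == 0x0000FF00, we have
        // NLZ == 16 and NTZ == 8: the single demanded byte comes from input
        // bits 16..23, so the bswap becomes a right shift by 8.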
        NLZ &= ~7;
        NTZ &= ~7;
        // If we need exactly one byte, we can do this transformation.
        if (BitWidth-NLZ-NTZ == 8) {
          unsigned ResultBit = NTZ;
          unsigned InputBit = BitWidth-NTZ-8;

          // Replace this with either a left or right shift to get the byte into
          // the right place.
          Instruction *NewVal;
          if (InputBit > ResultBit)
            NewVal = BinaryOperator::CreateLShr(II->getArgOperand(0),
                    ConstantInt::get(I->getType(), InputBit-ResultBit));
          else
            NewVal = BinaryOperator::CreateShl(II->getArgOperand(0),
                    ConstantInt::get(I->getType(), ResultBit-InputBit));
          NewVal->takeName(I);
          return InsertNewInstWith(NewVal, *I);
        }
        break;
      }
      case Intrinsic::fshr:
      case Intrinsic::fshl: {
        const APInt *SA;
        if (!match(I->getOperand(2), m_APInt(SA)))
          break;

        // Normalize to funnel shift left. APInt shifts of BitWidth are well-
        // defined, so no need to special-case zero shifts here.
        uint64_t ShiftAmt = SA->urem(BitWidth);
        if (II->getIntrinsicID() == Intrinsic::fshr)
          ShiftAmt = BitWidth - ShiftAmt;
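        // e.g. on i32, fshr(%a, %b, 3) == (%a << 29) | (%b lshr 3), which is
        // exactly fshl(%a, %b, 29).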

        APInt DemandedMaskLHS(DemandedMask.lshr(ShiftAmt));
        APInt DemandedMaskRHS(DemandedMask.shl(BitWidth - ShiftAmt));
        if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown, Depth + 1) ||
            SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1))
          return I;

        Known.Zero = LHSKnown.Zero.shl(ShiftAmt) |
                     RHSKnown.Zero.lshr(BitWidth - ShiftAmt);
        Known.One = LHSKnown.One.shl(ShiftAmt) |
                    RHSKnown.One.lshr(BitWidth - ShiftAmt);
        KnownBitsComputed = true;
        break;
      }
      case Intrinsic::x86_mmx_pmovmskb:
      case Intrinsic::x86_sse_movmsk_ps:
      case Intrinsic::x86_sse2_movmsk_pd:
      case Intrinsic::x86_sse2_pmovmskb_128:
      case Intrinsic::x86_avx_movmsk_ps_256:
      case Intrinsic::x86_avx_movmsk_pd_256:
      case Intrinsic::x86_avx2_pmovmskb: {
        // MOVMSK copies the vector elements' sign bits to the low bits
        // and zeros the high bits.
        unsigned ArgWidth;
        if (II->getIntrinsicID() == Intrinsic::x86_mmx_pmovmskb) {
          ArgWidth = 8; // Arg is x86_mmx, but treated as <8 x i8>.
        } else {
          auto Arg = II->getArgOperand(0);
          auto ArgType = cast<VectorType>(Arg->getType());
          ArgWidth = ArgType->getNumElements();
        }

        // If we don't demand any of the low bits, return zero; we already
        // know that DemandedMask is non-zero.
        APInt DemandedElts = DemandedMask.zextOrTrunc(ArgWidth);
        if (DemandedElts.isNullValue())
          return ConstantInt::getNullValue(VTy);

        // We know that the upper bits are set to zero.
        Known.Zero.setBitsFrom(ArgWidth);
        KnownBitsComputed = true;
        break;
      }
      case Intrinsic::x86_sse42_crc32_64_64:
        Known.Zero.setBitsFrom(32);
        KnownBitsComputed = true;
        break;
      }
    }

    if (!KnownBitsComputed)
      computeKnownBits(V, Known, Depth, CxtI);
    break;
  }
  }

  // If the client is only demanding bits that we know, return the known
  // constant.
  if (DemandedMask.isSubsetOf(Known.Zero|Known.One))
    return Constant::getIntegerValue(VTy, Known.One);
  return nullptr;
}

/// Helper routine of SimplifyDemandedUseBits. It computes the Known bits. It
/// also tries to handle simplifications that can be done based on
/// DemandedMask, but without modifying the Instruction.
Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I,
                                                     const APInt &DemandedMask,
                                                     KnownBits &Known,
                                                     unsigned Depth,
                                                     Instruction *CxtI) {
  unsigned BitWidth = DemandedMask.getBitWidth();
  Type *ITy = I->getType();

  KnownBits LHSKnown(BitWidth);
  KnownBits RHSKnown(BitWidth);

  // Although we can't simplify this instruction in every user's context, we
  // can at least compute the known bits, and we can do simplifications that
  // apply to *just* the one user if we know that this instruction has a
  // simpler value in that context.
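  // For example, even if (X | Y) has other users demanding all bits, a user
  // that only demands bits known to be zero in Y can use X directly in its
  // own context.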
  switch (I->getOpcode()) {
  case Instruction::And: {
    // If either the LHS or the RHS is zero, the result is zero.
    computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
    computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);

    Known = LHSKnown & RHSKnown;

    // If the client is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(ITy, Known.One);

    // If all of the demanded bits are known 1 on one side, return the other.
    // These bits cannot contribute to the result of the 'and' in this
    // context.
    if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
      return I->getOperand(0);
    if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
      return I->getOperand(1);

    break;
  }
  case Instruction::Or: {
    // We can simplify (X|Y) -> X or Y in the user's context if we know that
    // only bits from X or Y are demanded.

    // If either the LHS or the RHS is one, the result is one.
    computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
    computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);

    Known = LHSKnown | RHSKnown;

    // If the client is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(ITy, Known.One);

    // If all of the demanded bits are known zero on one side, return the
    // other.  These bits cannot contribute to the result of the 'or' in this
    // context.
    if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
      return I->getOperand(0);
    if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
      return I->getOperand(1);

    break;
  }
  case Instruction::Xor: {
    // We can simplify (X^Y) -> X or Y in the user's context if we know that
    // only bits from X or Y are demanded.

    computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
    computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);

    Known = LHSKnown ^ RHSKnown;

    // If the client is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(ITy, Known.One);

    // If all of the demanded bits are known zero on one side, return the
    // other.
    if (DemandedMask.isSubsetOf(RHSKnown.Zero))
      return I->getOperand(0);
    if (DemandedMask.isSubsetOf(LHSKnown.Zero))
      return I->getOperand(1);

    break;
  }
  default:
    // Compute the Known bits to simplify things downstream.
    computeKnownBits(I, Known, Depth, CxtI);

    // If this user is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero|Known.One))
      return Constant::getIntegerValue(ITy, Known.One);

    break;
  }

  return nullptr;
}

/// Helper routine of SimplifyDemandedUseBits. It tries to simplify
/// "E1 = (X lsr C1) << C2", where C1 and C2 are constants, into
/// "E2 = X << (C2 - C1)" or "E2 = X >> (C1 - C2)", depending on the sign
/// of "C2 - C1".
///
/// Suppose E1 and E2 are generally different in bits S = {bm, bm+1, ..., bn},
/// without considering the specific value X is holding.
/// This transformation is legal iff one of the following conditions holds:
///  1) All the bits in S are 0, in which case E1 == E2.
///  2) We don't care about those bits in S, per the input DemandedMask.
///  3) A combination of 1) and 2): some bits in S are 0, and we don't care
///     about the rest.
///
/// Currently we only test condition 2).
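///
/// For example, "(X lshr 1) shl 3" and "X shl 2" agree in every bit except
/// bit 2, which holds bit 0 of X in the latter but is always 0 in the former;
/// the rewrite is therefore legal whenever bit 2 is not demanded.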
///
/// As with SimplifyDemandedUseBits, it returns NULL if the simplification was
/// not successful.
Value *
InstCombiner::simplifyShrShlDemandedBits(Instruction *Shr, const APInt &ShrOp1,
                                         Instruction *Shl, const APInt &ShlOp1,
                                         const APInt &DemandedMask,
                                         KnownBits &Known) {
  if (!ShlOp1 || !ShrOp1)
    return nullptr; // No-op.

  Value *VarX = Shr->getOperand(0);
  Type *Ty = VarX->getType();
  unsigned BitWidth = Ty->getScalarSizeInBits();
  if (ShlOp1.uge(BitWidth) || ShrOp1.uge(BitWidth))
    return nullptr; // Undef.

  unsigned ShlAmt = ShlOp1.getZExtValue();
  unsigned ShrAmt = ShrOp1.getZExtValue();

  Known.One.clearAllBits();
  Known.Zero.setLowBits(ShlAmt - 1);
  Known.Zero &= DemandedMask;

  APInt BitMask1(APInt::getAllOnesValue(BitWidth));
  APInt BitMask2(APInt::getAllOnesValue(BitWidth));

  bool isLshr = (Shr->getOpcode() == Instruction::LShr);
  BitMask1 = isLshr ? (BitMask1.lshr(ShrAmt) << ShlAmt) :
                      (BitMask1.ashr(ShrAmt) << ShlAmt);

  if (ShrAmt <= ShlAmt) {
    BitMask2 <<= (ShlAmt - ShrAmt);
  } else {
    BitMask2 = isLshr ? BitMask2.lshr(ShrAmt - ShlAmt):
                        BitMask2.ashr(ShrAmt - ShlAmt);
  }

  // Check if condition 2) (see the comment to this function) is satisfied.
  if ((BitMask1 & DemandedMask) == (BitMask2 & DemandedMask)) {
    if (ShrAmt == ShlAmt)
      return VarX;

    if (!Shr->hasOneUse())
      return nullptr;

    BinaryOperator *New;
    if (ShrAmt < ShlAmt) {
      Constant *Amt = ConstantInt::get(VarX->getType(), ShlAmt - ShrAmt);
      New = BinaryOperator::CreateShl(VarX, Amt);
      BinaryOperator *Orig = cast<BinaryOperator>(Shl);
      New->setHasNoSignedWrap(Orig->hasNoSignedWrap());
      New->setHasNoUnsignedWrap(Orig->hasNoUnsignedWrap());
    } else {
      Constant *Amt = ConstantInt::get(VarX->getType(), ShrAmt - ShlAmt);
      New = isLshr ? BinaryOperator::CreateLShr(VarX, Amt) :
                     BinaryOperator::CreateAShr(VarX, Amt);
      if (cast<BinaryOperator>(Shr)->isExact())
        New->setIsExact(true);
    }

    return InsertNewInstWith(New, *Shl);
  }

  return nullptr;
}

/// Implement SimplifyDemandedVectorElts for amdgcn buffer and image intrinsics.
///
/// Note: This only supports non-TFE/LWE image intrinsic calls; those have
///       struct returns.
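///
/// For example, if only the first two elements of an image load returning
/// <4 x float> with dmask 0xf are demanded, the dmask is trimmed to 0x3, the
/// call is rebuilt to return <2 x float>, and a shufflevector restores the
/// original vector type.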
Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II,
                                                           APInt DemandedElts,
                                                           int DMaskIdx) {

  auto *IIVTy = cast<VectorType>(II->getType());
  unsigned VWidth = IIVTy->getNumElements();
  if (VWidth == 1)
    return nullptr;

  IRBuilderBase::InsertPointGuard Guard(Builder);
  Builder.SetInsertPoint(II);

  // Assume the arguments are unchanged and later override them, if needed.
  SmallVector<Value *, 16> Args(II->arg_begin(), II->arg_end());

  if (DMaskIdx < 0) {
    // Buffer case.

    const unsigned ActiveBits = DemandedElts.getActiveBits();
    const unsigned UnusedComponentsAtFront = DemandedElts.countTrailingZeros();

    // Start assuming the prefix of elements is demanded, but possibly clear
    // some other bits if there are trailing zeros (unused components at front)
    // and update offset.
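    // e.g. DemandedElts == 0b1100 for a <4 x i32> raw buffer load yields a
    // <2 x i32> load whose byte offset is advanced by the size of the two
    // skipped elements (8 bytes).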
    DemandedElts = (1 << ActiveBits) - 1;

    if (UnusedComponentsAtFront > 0) {
      static const unsigned InvalidOffsetIdx = 0xf;

      unsigned OffsetIdx;
      switch (II->getIntrinsicID()) {
      case Intrinsic::amdgcn_raw_buffer_load:
        OffsetIdx = 1;
        break;
      case Intrinsic::amdgcn_s_buffer_load:
        // If the resulting type is vec3, there is no point in trimming the
        // load with an updated offset, as the vec3 would most likely be
        // widened to vec4 anyway during lowering.
        if (ActiveBits == 4 && UnusedComponentsAtFront == 1)
          OffsetIdx = InvalidOffsetIdx;
        else
          OffsetIdx = 1;
        break;
      case Intrinsic::amdgcn_struct_buffer_load:
        OffsetIdx = 2;
        break;
      default:
        // TODO: handle tbuffer* intrinsics.
        OffsetIdx = InvalidOffsetIdx;
        break;
      }

      if (OffsetIdx != InvalidOffsetIdx) {
        // Clear demanded bits and update the offset.
        DemandedElts &= ~((1 << UnusedComponentsAtFront) - 1);
        auto *Offset = II->getArgOperand(OffsetIdx);
        unsigned SingleComponentSizeInBits =
            getDataLayout().getTypeSizeInBits(II->getType()->getScalarType());
        unsigned OffsetAdd =
            UnusedComponentsAtFront * SingleComponentSizeInBits / 8;
        auto *OffsetAddVal = ConstantInt::get(Offset->getType(), OffsetAdd);
        Args[OffsetIdx] = Builder.CreateAdd(Offset, OffsetAddVal);
      }
    }
  } else {
    // Image case.

    ConstantInt *DMask = cast<ConstantInt>(II->getArgOperand(DMaskIdx));
    unsigned DMaskVal = DMask->getZExtValue() & 0xf;

    // Mask off values that are undefined because the dmask doesn't cover them.
    DemandedElts &= (1 << countPopulation(DMaskVal)) - 1;

    unsigned NewDMaskVal = 0;
    unsigned OrigLoadIdx = 0;
    for (unsigned SrcIdx = 0; SrcIdx < 4; ++SrcIdx) {
      const unsigned Bit = 1 << SrcIdx;
      if (!!(DMaskVal & Bit)) {
        if (!!DemandedElts[OrigLoadIdx])
          NewDMaskVal |= Bit;
        OrigLoadIdx++;
      }
    }

    if (DMaskVal != NewDMaskVal)
      Args[DMaskIdx] = ConstantInt::get(DMask->getType(), NewDMaskVal);
  }

  unsigned NewNumElts = DemandedElts.countPopulation();
  if (!NewNumElts)
    return UndefValue::get(II->getType());

  // FIXME: Allow v3i16/v3f16 in buffer and image intrinsics when the types are
  // fully supported.
  if (II->getType()->getScalarSizeInBits() == 16 && NewNumElts == 3)
    return nullptr;

  if (NewNumElts >= VWidth && DemandedElts.isMask()) {
    if (DMaskIdx >= 0)
      II->setArgOperand(DMaskIdx, Args[DMaskIdx]);
    return nullptr;
  }

  // Validate function argument and return types, extracting overloaded types
  // along the way.
  SmallVector<Type *, 6> OverloadTys;
  if (!Intrinsic::getIntrinsicSignature(II->getCalledFunction(), OverloadTys))
    return nullptr;

  Module *M = II->getParent()->getParent()->getParent();
  Type *EltTy = IIVTy->getElementType();
  Type *NewTy =
      (NewNumElts == 1) ? EltTy : FixedVectorType::get(EltTy, NewNumElts);

  OverloadTys[0] = NewTy;
  Function *NewIntrin =
      Intrinsic::getDeclaration(M, II->getIntrinsicID(), OverloadTys);

  CallInst *NewCall = Builder.CreateCall(NewIntrin, Args);
  NewCall->takeName(II);
  NewCall->copyMetadata(*II);

  if (NewNumElts == 1) {
    return Builder.CreateInsertElement(UndefValue::get(II->getType()), NewCall,
                                       DemandedElts.countTrailingZeros());
  }

  SmallVector<int, 8> EltMask;
  unsigned NewLoadIdx = 0;
  for (unsigned OrigLoadIdx = 0; OrigLoadIdx < VWidth; ++OrigLoadIdx) {
    if (!!DemandedElts[OrigLoadIdx])
      EltMask.push_back(NewLoadIdx++);
    else
      EltMask.push_back(NewNumElts);
  }

  Value *Shuffle =
      Builder.CreateShuffleVector(NewCall, UndefValue::get(NewTy), EltMask);

  return Shuffle;
}

/// The specified value produces a vector with any number of elements.
/// This method analyzes which elements of the operand are undef and returns
/// that information in UndefElts.
///
/// DemandedElts contains the set of elements that are actually used by the
/// caller, and by default (AllowMultipleUsers equals false) the value is
/// simplified only if it has a single caller. If AllowMultipleUsers is set
/// to true, DemandedElts refers to the union of sets of elements that are
/// used by all callers.
///
/// If the information about the demanded elements can be used to simplify the
/// operation, the operation is simplified and the resulting value is returned.
/// This returns null if no change was made.
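///
/// For example, for a <2 x i32> shufflevector with mask <1, 0> of which only
/// element 0 is demanded, only element 1 of the first source operand is
/// demanded in turn, and element 1 of the result is reported in UndefElts.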
Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
                                                APInt &UndefElts,
                                                unsigned Depth,
                                                bool AllowMultipleUsers) {
  // Cannot analyze scalable type. The number of vector elements is not a
  // compile-time constant.
  if (isa<ScalableVectorType>(V->getType()))
    return nullptr;

  unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements();
  APInt EltMask(APInt::getAllOnesValue(VWidth));
  assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!");

  if (isa<UndefValue>(V)) {
    // If the entire vector is undefined, just return this info.
    UndefElts = EltMask;
    return nullptr;
  }

  if (DemandedElts.isNullValue()) { // If nothing is demanded, provide undef.
    UndefElts = EltMask;
    return UndefValue::get(V->getType());
  }

  UndefElts = 0;

  if (auto *C = dyn_cast<Constant>(V)) {
    // If all elements are demanded, this is the identity case; return null
    // since we are not simplifying anything.
    if (DemandedElts.isAllOnesValue())
      return nullptr;
1215
1216    Type *EltTy = cast<VectorType>(V->getType())->getElementType();
1217    Constant *Undef = UndefValue::get(EltTy);
1218    SmallVector<Constant*, 16> Elts;
1219    for (unsigned i = 0; i != VWidth; ++i) {
1220      if (!DemandedElts[i]) {   // If not demanded, set to undef.
1221        Elts.push_back(Undef);
1222        UndefElts.setBit(i);
1223        continue;
1224      }
1225
1226      Constant *Elt = C->getAggregateElement(i);
1227      if (!Elt) return nullptr;
1228
1229      if (isa<UndefValue>(Elt)) {   // Already undef.
1230        Elts.push_back(Undef);
1231        UndefElts.setBit(i);
1232      } else {                               // Otherwise, defined.
1233        Elts.push_back(Elt);
1234      }
1235    }
1236
1237    // If we changed the constant, return it.
1238    Constant *NewCV = ConstantVector::get(Elts);
1239    return NewCV != C ? NewCV : nullptr;
1240  }
1241
1242  // Limit search depth.
1243  if (Depth == 10)
1244    return nullptr;
1245
1246  if (!AllowMultipleUsers) {
1247    // If multiple users are using the root value, proceed with
1248    // simplification conservatively assuming that all elements
1249    // are needed.
1250    if (!V->hasOneUse()) {
1251      // Quit if we find multiple users of a non-root value though.
1252      // They'll be handled when it's their turn to be visited by
1253      // the main instcombine process.
1254      if (Depth != 0)
1255        // TODO: Just compute the UndefElts information recursively.
1256        return nullptr;
1257
1258      // Conservatively assume that all elements are needed.
1259      DemandedElts = EltMask;
1260    }
1261  }
1262
1263  Instruction *I = dyn_cast<Instruction>(V);
1264  if (!I) return nullptr;        // Only analyze instructions.
1265
1266  bool MadeChange = false;
1267  auto simplifyAndSetOp = [&](Instruction *Inst, unsigned OpNum,
1268                              APInt Demanded, APInt &Undef) {
1269    auto *II = dyn_cast<IntrinsicInst>(Inst);
1270    Value *Op = II ? II->getArgOperand(OpNum) : Inst->getOperand(OpNum);
1271    if (Value *V = SimplifyDemandedVectorElts(Op, Demanded, Undef, Depth + 1)) {
1272      replaceOperand(*Inst, OpNum, V);
1273      MadeChange = true;
1274    }
1275  };
1276
1277  APInt UndefElts2(VWidth, 0);
1278  APInt UndefElts3(VWidth, 0);
1279  switch (I->getOpcode()) {
1280  default: break;
1281
1282  case Instruction::GetElementPtr: {
1283    // The LangRef requires that struct geps have all constant indices.  As
1284    // such, we can't convert any operand to partial undef.
1285    auto mayIndexStructType = [](GetElementPtrInst &GEP) {
1286      for (auto I = gep_type_begin(GEP), E = gep_type_end(GEP);
1287           I != E; I++)
1288        if (I.isStruct())
          return true;
      return false;
    };
    if (mayIndexStructType(cast<GetElementPtrInst>(*I)))
      break;

    // Conservatively track the demanded elements back through any vector
    // operands we may have.  We know there must be at least one, or we
    // wouldn't have a vector result to get here. Note that we intentionally
    // merge the undef bits here since gepping with either an undef base or
    // index results in undef.
    for (unsigned i = 0; i < I->getNumOperands(); i++) {
      if (isa<UndefValue>(I->getOperand(i))) {
        // If the entire vector is undefined, just return this info.
        UndefElts = EltMask;
        return nullptr;
      }
      if (I->getOperand(i)->getType()->isVectorTy()) {
        APInt UndefEltsOp(VWidth, 0);
        simplifyAndSetOp(I, i, DemandedElts, UndefEltsOp);
        UndefElts |= UndefEltsOp;
      }
    }

    break;
  }
  case Instruction::InsertElement: {
    // If this is a variable index, we don't know which element it overwrites,
    // so demand exactly the same input as we produce.
1318    ConstantInt *Idx = dyn_cast<ConstantInt>(I->getOperand(2));
1319    if (!Idx) {
1320      // Note that we can't propagate undef elt info, because we don't know
1321      // which elt is getting updated.
1322      simplifyAndSetOp(I, 0, DemandedElts, UndefElts2);
1323      break;
1324    }
1325
1326    // The element inserted overwrites whatever was there, so the input demanded
1327    // set is simpler than the output set.
1328    unsigned IdxNo = Idx->getZExtValue();
1329    APInt PreInsertDemandedElts = DemandedElts;
1330    if (IdxNo < VWidth)
1331      PreInsertDemandedElts.clearBit(IdxNo);
1332
1333    simplifyAndSetOp(I, 0, PreInsertDemandedElts, UndefElts);
1334
1335    // If this is inserting an element that isn't demanded, remove this
1336    // insertelement.
    if (IdxNo >= VWidth || !DemandedElts[IdxNo]) {
      Worklist.push(I);
      return I->getOperand(0);
    }

    // The inserted element is defined.
    UndefElts.clearBit(IdxNo);
    break;
  }
  case Instruction::ShuffleVector: {
    auto *Shuffle = cast<ShuffleVectorInst>(I);
    assert(Shuffle->getOperand(0)->getType() ==
           Shuffle->getOperand(1)->getType() &&
           "Expected shuffle operands to have same type");
    unsigned OpWidth =
        cast<VectorType>(Shuffle->getOperand(0)->getType())->getNumElements();
    // Handle trivial case of a splat. Only check the first element of LHS
    // operand.
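    // e.g. a mask of <0, 0, 0, 0> reads only element 0 of the LHS, so the RHS
    // operand can be replaced with undef.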
    if (all_of(Shuffle->getShuffleMask(), [](int Elt) { return Elt == 0; }) &&
        DemandedElts.isAllOnesValue()) {
      if (!isa<UndefValue>(I->getOperand(1))) {
        I->setOperand(1, UndefValue::get(I->getOperand(1)->getType()));
        MadeChange = true;
      }
      APInt LeftDemanded(OpWidth, 1);
      APInt LHSUndefElts(OpWidth, 0);
      simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts);
      if (LHSUndefElts[0])
        UndefElts = EltMask;
      else
        UndefElts.clearAllBits();
      break;
    }

    APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0);
    for (unsigned i = 0; i < VWidth; i++) {
      if (DemandedElts[i]) {
        unsigned MaskVal = Shuffle->getMaskValue(i);
        if (MaskVal != -1u) {
          assert(MaskVal < OpWidth * 2 &&
                 "shufflevector mask index out of range!");
          if (MaskVal < OpWidth)
            LeftDemanded.setBit(MaskVal);
          else
            RightDemanded.setBit(MaskVal - OpWidth);
        }
      }
    }

    APInt LHSUndefElts(OpWidth, 0);
    simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts);

    APInt RHSUndefElts(OpWidth, 0);
    simplifyAndSetOp(I, 1, RightDemanded, RHSUndefElts);

    // If this shuffle does not change the vector length and the elements
    // demanded by this shuffle are an identity mask, then this shuffle is
    // unnecessary.
    //
    // We are assuming canonical form for the mask, so the source vector is
    // operand 0 and operand 1 is not used.
    //
    // Note that if an element is demanded and this shuffle mask is undefined
    // for that element, then the shuffle is not considered an identity
    // operation. The shuffle prevents poison from the operand vector from
    // leaking to the result by replacing poison with an undefined value.
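    // e.g. (illustrative)
    //   %s = shufflevector <4 x i32> %x, <4 x i32> undef,
    //                      <4 x i32> <i32 0, i32 1, i32 2, i32 3>
    // is an identity over all demanded lanes and can be replaced by %x.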
    if (VWidth == OpWidth) {
      bool IsIdentityShuffle = true;
      for (unsigned i = 0; i < VWidth; i++) {
        unsigned MaskVal = Shuffle->getMaskValue(i);
        if (DemandedElts[i] && i != MaskVal) {
          IsIdentityShuffle = false;
          break;
        }
      }
      if (IsIdentityShuffle)
        return Shuffle->getOperand(0);
    }

    bool NewUndefElts = false;
    unsigned LHSIdx = -1u, LHSValIdx = -1u;
    unsigned RHSIdx = -1u, RHSValIdx = -1u;
    bool LHSUniform = true;
    bool RHSUniform = true;
    for (unsigned i = 0; i < VWidth; i++) {
      unsigned MaskVal = Shuffle->getMaskValue(i);
      if (MaskVal == -1u) {
        UndefElts.setBit(i);
      } else if (!DemandedElts[i]) {
        NewUndefElts = true;
        UndefElts.setBit(i);
      } else if (MaskVal < OpWidth) {
        if (LHSUndefElts[MaskVal]) {
          NewUndefElts = true;
          UndefElts.setBit(i);
        } else {
          LHSIdx = LHSIdx == -1u ? i : OpWidth;
          LHSValIdx = LHSValIdx == -1u ? MaskVal : OpWidth;
          LHSUniform = LHSUniform && (MaskVal == i);
        }
      } else {
        if (RHSUndefElts[MaskVal - OpWidth]) {
          NewUndefElts = true;
          UndefElts.setBit(i);
        } else {
          RHSIdx = RHSIdx == -1u ? i : OpWidth;
          RHSValIdx = RHSValIdx == -1u ? MaskVal - OpWidth : OpWidth;
          RHSUniform = RHSUniform && (MaskVal - OpWidth == i);
        }
      }
    }

    // Try to transform a shuffle that takes exactly one element from a
    // constant vector into a single insertelement instruction:
    // shufflevector V, C, <v1, v2, .., ci, .., vm> ->
    // insertelement V, C[ci], ci-n
    if (OpWidth == Shuffle->getType()->getNumElements()) {
      Value *Op = nullptr;
      Constant *Value = nullptr;
      unsigned Idx = -1u;

      // Find a constant vector with a single element used by the shuffle
      // (LHS or RHS).
      if (LHSIdx < OpWidth && RHSUniform) {
        if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(0))) {
          Op = Shuffle->getOperand(1);
          Value = CV->getOperand(LHSValIdx);
          Idx = LHSIdx;
        }
      }
      if (RHSIdx < OpWidth && LHSUniform) {
        if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(1))) {
          Op = Shuffle->getOperand(0);
          Value = CV->getOperand(RHSValIdx);
          Idx = RHSIdx;
        }
      }
      // Found a constant vector with a single used element - convert to
      // insertelement.
      if (Op && Value) {
        Instruction *New = InsertElementInst::Create(
            Op, Value, ConstantInt::get(Type::getInt32Ty(I->getContext()), Idx),
            Shuffle->getName());
        InsertNewInstWith(New, *Shuffle);
        return New;
      }
    }
    if (NewUndefElts) {
      // Add additional discovered undefs.
      SmallVector<int, 16> Elts;
      for (unsigned i = 0; i < VWidth; ++i) {
        if (UndefElts[i])
          Elts.push_back(UndefMaskElem);
        else
          Elts.push_back(Shuffle->getMaskValue(i));
      }
      Shuffle->setShuffleMask(Elts);
      MadeChange = true;
    }
    break;
  }
  case Instruction::Select: {
    // If this is a vector select, try to transform the select condition based
    // on the current demanded elements.
    SelectInst *Sel = cast<SelectInst>(I);
    if (Sel->getCondition()->getType()->isVectorTy()) {
      // TODO: We are not doing anything with UndefElts based on this call.
      // It is overwritten below based on the other select operands. If an
      // element of the select condition is known undef, then we are free to
      // choose the output value from either arm of the select. If we know that
      // one of those values is undef, then the output can be undef.
      simplifyAndSetOp(I, 0, DemandedElts, UndefElts);
    }

    // Next, see if we can transform the arms of the select.
    APInt DemandedLHS(DemandedElts), DemandedRHS(DemandedElts);
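    // e.g. with a constant condition <i1 true, i1 false>, lane 1 of the true
    // arm and lane 0 of the false arm are never selected, so those lanes need
    // not be demanded.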
    if (auto *CV = dyn_cast<ConstantVector>(Sel->getCondition())) {
      for (unsigned i = 0; i < VWidth; i++) {
        // isNullValue() always returns false when called on a ConstantExpr.
        // Skip constant expressions to avoid propagating incorrect information.
        Constant *CElt = CV->getAggregateElement(i);
        if (isa<ConstantExpr>(CElt))
          continue;
        // TODO: If a select condition element is undef, we can demand from
        // either side. If one side is known undef, choosing that side would
        // propagate undef.
        if (CElt->isNullValue())
          DemandedLHS.clearBit(i);
        else
          DemandedRHS.clearBit(i);
      }
    }

    simplifyAndSetOp(I, 1, DemandedLHS, UndefElts2);
    simplifyAndSetOp(I, 2, DemandedRHS, UndefElts3);

    // Output elements are undefined if the element from each arm is undefined.
    // TODO: This can be improved. See comment in select condition handling.
    UndefElts = UndefElts2 & UndefElts3;
    break;
  }
  case Instruction::BitCast: {
    // Vector->vector casts only.
    VectorType *VTy = dyn_cast<VectorType>(I->getOperand(0)->getType());
    if (!VTy) break;
    unsigned InVWidth = VTy->getNumElements();
    APInt InputDemandedElts(InVWidth, 0);
    UndefElts2 = APInt(InVWidth, 0);
    unsigned Ratio;

    if (VWidth == InVWidth) {
      // If we are converting from <4 x i32> -> <4 x f32>, we demand the same
      // elements as are demanded of us.
      Ratio = 1;
      InputDemandedElts = DemandedElts;
    } else if ((VWidth % InVWidth) == 0) {
      // If the number of elements in the output is a multiple of the number of
      // elements in the input then an input element is live if any of the
      // corresponding output elements are live.
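      // e.g. bitcast <2 x i64> to <4 x i32>: output lanes 2*i and 2*i+1 both
      // come from input lane i, so demanding either output lane demands input
      // lane i.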
      Ratio = VWidth / InVWidth;
      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx)
        if (DemandedElts[OutIdx])
          InputDemandedElts.setBit(OutIdx / Ratio);
    } else if ((InVWidth % VWidth) == 0) {
      // If the number of elements in the input is a multiple of the number of
      // elements in the output then an input element is live if the
      // corresponding output element is live.
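      // e.g. bitcast <4 x i32> to <2 x i64>: input lanes 2*i and 2*i+1 feed
      // output lane i, so both are demanded whenever output lane i is.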
      Ratio = InVWidth / VWidth;
      for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx)
        if (DemandedElts[InIdx / Ratio])
          InputDemandedElts.setBit(InIdx);
    } else {
      // Unsupported so far.
      break;
    }

    simplifyAndSetOp(I, 0, InputDemandedElts, UndefElts2);

    if (VWidth == InVWidth) {
      UndefElts = UndefElts2;
    } else if ((VWidth % InVWidth) == 0) {
      // If the number of elements in the output is a multiple of the number of
      // elements in the input then an output element is undef if the
      // corresponding input element is undef.
      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx)
        if (UndefElts2[OutIdx / Ratio])
          UndefElts.setBit(OutIdx);
    } else if ((InVWidth % VWidth) == 0) {
      // If the number of elements in the input is a multiple of the number of
      // elements in the output then an output element is undef if all of the
      // corresponding input elements are undef.
      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) {
        APInt SubUndef = UndefElts2.lshr(OutIdx * Ratio).zextOrTrunc(Ratio);
        if (SubUndef.countPopulation() == Ratio)
          UndefElts.setBit(OutIdx);
      }
    } else {
      llvm_unreachable("Unimp");
    }
    break;
  }
  case Instruction::FPTrunc:
  case Instruction::FPExt:
    simplifyAndSetOp(I, 0, DemandedElts, UndefElts);
    break;

  case Instruction::Call: {
    IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
    if (!II) break;
    switch (II->getIntrinsicID()) {
    case Intrinsic::masked_gather: // fallthrough
    case Intrinsic::masked_load: {
      // Subtlety: If we load from a pointer, the pointer must be valid
      // regardless of whether the element is demanded.  Doing otherwise risks
      // segfaults which didn't exist in the original program.
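      // Thus a pointer lane stays demanded even when the matching result lane
      // is not, unless the constant mask proves that lane is never loaded.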
      APInt DemandedPtrs(APInt::getAllOnesValue(VWidth)),
        DemandedPassThrough(DemandedElts);
      if (auto *CV = dyn_cast<ConstantVector>(II->getOperand(2)))
        for (unsigned i = 0; i < VWidth; i++) {
          Constant *CElt = CV->getAggregateElement(i);
          if (CElt->isNullValue())
            DemandedPtrs.clearBit(i);
          else if (CElt->isAllOnesValue())
            DemandedPassThrough.clearBit(i);
        }
      if (II->getIntrinsicID() == Intrinsic::masked_gather)
        simplifyAndSetOp(II, 0, DemandedPtrs, UndefElts2);
      simplifyAndSetOp(II, 3, DemandedPassThrough, UndefElts3);

      // Output elements are undefined if the elements from both sources are.
      // TODO: can strengthen via mask as well.
      UndefElts = UndefElts2 & UndefElts3;
      break;
    }
    case Intrinsic::x86_xop_vfrcz_ss:
    case Intrinsic::x86_xop_vfrcz_sd:
      // The instructions for these intrinsics are specified to zero the upper
      // bits, not pass them through like other scalar intrinsics. So we
      // shouldn't just use Arg0 if DemandedElts[0] is clear like we do for
      // other intrinsics. Instead we should return a zero vector.
      if (!DemandedElts[0]) {
        Worklist.push(II);
        return ConstantAggregateZero::get(II->getType());
      }

      // Only the lower element is used.
      DemandedElts = 1;
      simplifyAndSetOp(II, 0, DemandedElts, UndefElts);

      // Only the lower element is undefined. The high elements are zero.
      UndefElts = UndefElts[0];
      break;

    // Unary scalar-as-vector operations that work column-wise.
    case Intrinsic::x86_sse_rcp_ss:
    case Intrinsic::x86_sse_rsqrt_ss:
      simplifyAndSetOp(II, 0, DemandedElts, UndefElts);

      // If the lowest element of a scalar op isn't used then use Arg0.
      if (!DemandedElts[0]) {
        Worklist.push(II);
        return II->getArgOperand(0);
      }
      // TODO: If only the low elt is demanded, lower SQRT to FSQRT (with
      // rounding/exception checks).
      break;

    // Binary scalar-as-vector operations that work column-wise. The high
    // elements come from operand 0. The low element is a function of both
    // operands.
    case Intrinsic::x86_sse_min_ss:
    case Intrinsic::x86_sse_max_ss:
    case Intrinsic::x86_sse_cmp_ss:
    case Intrinsic::x86_sse2_min_sd:
    case Intrinsic::x86_sse2_max_sd:
    case Intrinsic::x86_sse2_cmp_sd: {
      simplifyAndSetOp(II, 0, DemandedElts, UndefElts);

      // If the lowest element of a scalar op isn't used then use Arg0.
      if (!DemandedElts[0]) {
        Worklist.push(II);
        return II->getArgOperand(0);
      }

      // Only the lower element is used for operand 1.
      DemandedElts = 1;
      simplifyAndSetOp(II, 1, DemandedElts, UndefElts2);

      // The lower element is undefined only if both lower elements are
      // undefined. Consider things like undef & 0. The result is known zero,
      // not undef.
      if (!UndefElts2[0])
        UndefElts.clearBit(0);

      break;
    }

    // Binary scalar-as-vector operations that work column-wise. The high
    // elements come from operand 0 and the low element comes from operand 1.
    case Intrinsic::x86_sse41_round_ss:
    case Intrinsic::x86_sse41_round_sd: {
      // Don't use the low element of operand 0.
      APInt DemandedElts2 = DemandedElts;
      DemandedElts2.clearBit(0);
      simplifyAndSetOp(II, 0, DemandedElts2, UndefElts);

      // If the lowest element of a scalar op isn't used then use Arg0.
      if (!DemandedElts[0]) {
        Worklist.push(II);
        return II->getArgOperand(0);
      }

      // Only the lower element is used for operand 1.
      DemandedElts = 1;
      simplifyAndSetOp(II, 1, DemandedElts, UndefElts2);

      // Take the high undef elements from operand 0 and take the lower element
      // from operand 1.
      UndefElts.clearBit(0);
      UndefElts |= UndefElts2[0];
      break;
    }

    // Three input scalar-as-vector operations that work column-wise. The high
    // elements come from operand 0 and the low element is a function of all
    // three inputs.
    case Intrinsic::x86_avx512_mask_add_ss_round:
    case Intrinsic::x86_avx512_mask_div_ss_round:
    case Intrinsic::x86_avx512_mask_mul_ss_round:
    case Intrinsic::x86_avx512_mask_sub_ss_round:
    case Intrinsic::x86_avx512_mask_max_ss_round:
    case Intrinsic::x86_avx512_mask_min_ss_round:
    case Intrinsic::x86_avx512_mask_add_sd_round:
    case Intrinsic::x86_avx512_mask_div_sd_round:
    case Intrinsic::x86_avx512_mask_mul_sd_round:
    case Intrinsic::x86_avx512_mask_sub_sd_round:
    case Intrinsic::x86_avx512_mask_max_sd_round:
    case Intrinsic::x86_avx512_mask_min_sd_round:
      simplifyAndSetOp(II, 0, DemandedElts, UndefElts);

      // If the lowest element of a scalar op isn't used then use Arg0.
      if (!DemandedElts[0]) {
        Worklist.push(II);
        return II->getArgOperand(0);
      }

      // Only the lower element is used for operands 1 and 2.
      DemandedElts = 1;
      simplifyAndSetOp(II, 1, DemandedElts, UndefElts2);
      simplifyAndSetOp(II, 2, DemandedElts, UndefElts3);

      // The lower element is undefined only if all three lower elements are
      // undefined. Consider things like undef & 0. The result is known zero,
      // not undef.
      if (!UndefElts2[0] || !UndefElts3[0])
        UndefElts.clearBit(0);

      break;

    case Intrinsic::x86_sse2_packssdw_128:
    case Intrinsic::x86_sse2_packsswb_128:
    case Intrinsic::x86_sse2_packuswb_128:
    case Intrinsic::x86_sse41_packusdw:
    case Intrinsic::x86_avx2_packssdw:
    case Intrinsic::x86_avx2_packsswb:
    case Intrinsic::x86_avx2_packusdw:
    case Intrinsic::x86_avx2_packuswb:
    case Intrinsic::x86_avx512_packssdw_512:
    case Intrinsic::x86_avx512_packsswb_512:
    case Intrinsic::x86_avx512_packusdw_512:
    case Intrinsic::x86_avx512_packuswb_512: {
      auto *Ty0 = II->getArgOperand(0)->getType();
      unsigned InnerVWidth = cast<VectorType>(Ty0)->getNumElements();
      assert(VWidth == (InnerVWidth * 2) && "Unexpected input size");

      unsigned NumLanes = Ty0->getPrimitiveSizeInBits() / 128;
      unsigned VWidthPerLane = VWidth / NumLanes;
      unsigned InnerVWidthPerLane = InnerVWidth / NumLanes;

      // Per lane, pack the elements of the first input and then the second.
      // e.g.
      // v8i16 PACK(v4i32 X, v4i32 Y) - (X[0..3],Y[0..3])
      // v32i8 PACK(v16i16 X, v16i16 Y) - (X[0..7],Y[0..7]),(X[8..15],Y[8..15])
      for (int OpNum = 0; OpNum != 2; ++OpNum) {
        APInt OpDemandedElts(InnerVWidth, 0);
        for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
          unsigned LaneIdx = Lane * VWidthPerLane;
          for (unsigned Elt = 0; Elt != InnerVWidthPerLane; ++Elt) {
            unsigned Idx = LaneIdx + Elt + InnerVWidthPerLane * OpNum;
            if (DemandedElts[Idx])
              OpDemandedElts.setBit((Lane * InnerVWidthPerLane) + Elt);
          }
        }

        // Demand elements from the operand.
        APInt OpUndefElts(InnerVWidth, 0);
        simplifyAndSetOp(II, OpNum, OpDemandedElts, OpUndefElts);

        // Pack the operand's UNDEF elements, one lane at a time.
        OpUndefElts = OpUndefElts.zext(VWidth);
        for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
          APInt LaneElts = OpUndefElts.lshr(InnerVWidthPerLane * Lane);
          LaneElts = LaneElts.getLoBits(InnerVWidthPerLane);
          LaneElts <<= InnerVWidthPerLane * (2 * Lane + OpNum);
          UndefElts |= LaneElts;
        }
      }
      break;
    }

    // PSHUFB
    case Intrinsic::x86_ssse3_pshuf_b_128:
    case Intrinsic::x86_avx2_pshuf_b:
    case Intrinsic::x86_avx512_pshuf_b_512:
    // PERMILVAR
    case Intrinsic::x86_avx_vpermilvar_ps:
    case Intrinsic::x86_avx_vpermilvar_ps_256:
    case Intrinsic::x86_avx512_vpermilvar_ps_512:
    case Intrinsic::x86_avx_vpermilvar_pd:
    case Intrinsic::x86_avx_vpermilvar_pd_256:
    case Intrinsic::x86_avx512_vpermilvar_pd_512:
    // PERMV
    case Intrinsic::x86_avx2_permd:
    case Intrinsic::x86_avx2_permps: {
      simplifyAndSetOp(II, 1, DemandedElts, UndefElts);
      break;
    }

    // SSE4A instructions leave the upper 64 bits of the 128-bit result
    // in an undefined state.
    case Intrinsic::x86_sse4a_extrq:
    case Intrinsic::x86_sse4a_extrqi:
    case Intrinsic::x86_sse4a_insertq:
    case Intrinsic::x86_sse4a_insertqi:
      UndefElts.setHighBits(VWidth / 2);
      break;
    case Intrinsic::amdgcn_buffer_load:
    case Intrinsic::amdgcn_buffer_load_format:
    case Intrinsic::amdgcn_raw_buffer_load:
    case Intrinsic::amdgcn_raw_buffer_load_format:
    case Intrinsic::amdgcn_raw_tbuffer_load:
    case Intrinsic::amdgcn_s_buffer_load:
    case Intrinsic::amdgcn_struct_buffer_load:
    case Intrinsic::amdgcn_struct_buffer_load_format:
    case Intrinsic::amdgcn_struct_tbuffer_load:
    case Intrinsic::amdgcn_tbuffer_load:
      return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts);
    default: {
      if (getAMDGPUImageDMaskIntrinsic(II->getIntrinsicID()))
        return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts, 0);

      break;
    }
    } // switch on IntrinsicID
    break;
  } // case Call
  } // switch on Opcode

  // TODO: We bail completely on integer div/rem and shifts because they have
  // UB/poison potential, but that should be refined.
  BinaryOperator *BO;
  if (match(I, m_BinOp(BO)) && !BO->isIntDivRem() && !BO->isShift()) {
    simplifyAndSetOp(I, 0, DemandedElts, UndefElts);
    simplifyAndSetOp(I, 1, DemandedElts, UndefElts2);

    // Any change to an instruction with potential poison must clear those
    // flags because we cannot guarantee those constraints now. Other analyses
    // may determine that it is safe to re-apply the flags.
    if (MadeChange)
      BO->dropPoisonGeneratingFlags();

    // Output elements are undefined if both are undefined. Consider things
    // like undef & 0. The result is known zero, not undef.
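    // e.g. (illustrative) in %r = and <2 x i32> %a, zeroinitializer, each lane
    // of %r is a defined zero even when the matching lane of %a is undef, so a
    // lane is marked undef only when both inputs' lanes are undef.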
    UndefElts &= UndefElts2;
  }

  // If we've proven all of the lanes undef, return an undef value.
  // TODO: Intersect w/demanded lanes
  if (UndefElts.isAllOnesValue())
    return UndefValue::get(I->getType());

  return MadeChange ? I : nullptr;
}
