1//===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a class for representing known zeros and ones used by
10// computeKnownBits.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Support/KnownBits.h"
15#include "llvm/Support/Debug.h"
16#include "llvm/Support/raw_ostream.h"
17#include <cassert>
18
19using namespace llvm;
20
/// Compute the known bits of LHS + RHS + carry-in, where the carry-in state is
/// described by two flags.
///
/// \param LHS, RHS   Known bits of the two addends.
/// \param CarryZero  True if the carry-in is known to be 0.
/// \param CarryOne   True if the carry-in is known to be 1.
///                   Both false means the carry-in is unknown; both true is
///                   rejected by the assert below.
/// \returns Known bits of the sum. A result bit is known only where the LHS
///          bit, the RHS bit and the carry into that position are all known.
static KnownBits computeForAddCarry(
    const KnownBits &LHS, const KnownBits &RHS,
    bool CarryZero, bool CarryOne) {
  assert(!(CarryZero && CarryOne) &&
         "Carry can't be zero and one at the same time");

  // Two extreme sums: PossibleSumZero adds the largest possible operand values
  // plus the carry unless it is known zero; PossibleSumOne adds the smallest
  // possible operand values plus the carry only if it is known one.
  APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero;
  APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne;

  // Compute known bits of the carry.
  // A sum bit is LHS ^ RHS ^ carry-in, so XOR-ing a sum with the operand bits
  // recovers the carry into each position. A carry that is 0 even in the
  // maximal sum is known zero; a carry that is 1 even in the minimal sum is
  // known one.
  APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
  APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;

  // Compute set of known bits (where all three relevant bits are known).
  APInt LHSKnownUnion = LHS.Zero | LHS.One;
  APInt RHSKnownUnion = RHS.Zero | RHS.One;
  APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne;
  APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion;

  // On fully determined positions both extreme sums must agree.
  assert((PossibleSumZero & Known) == (PossibleSumOne & Known) &&
         "known bits of sum differ");

  // Compute known bits of the result.
  // Because the two extreme sums agree on Known, the zero mask can come from
  // one and the one mask from the other; std::move just reuses their storage.
  KnownBits KnownOut;
  KnownOut.Zero = ~std::move(PossibleSumZero) & Known;
  KnownOut.One = std::move(PossibleSumOne) & Known;
  return KnownOut;
}
49
50KnownBits KnownBits::computeForAddCarry(
51    const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) {
52  assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit");
53  return ::computeForAddCarry(
54      LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
55}
56
57KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
58                                      const KnownBits &LHS, KnownBits RHS) {
59  KnownBits KnownOut;
60  if (Add) {
61    // Sum = LHS + RHS + 0
62    KnownOut = ::computeForAddCarry(
63        LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
64  } else {
65    // Sum = LHS + ~RHS + 1
66    std::swap(RHS.Zero, RHS.One);
67    KnownOut = ::computeForAddCarry(
68        LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
69  }
70
71  // Are we still trying to solve for the sign bit?
72  if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
73    if (NSW) {
74      // Adding two non-negative numbers, or subtracting a negative number from
75      // a non-negative one, can't wrap into negative.
76      if (LHS.isNonNegative() && RHS.isNonNegative())
77        KnownOut.makeNonNegative();
78      // Adding two negative numbers, or subtracting a non-negative number from
79      // a negative one, can't wrap into non-negative.
80      else if (LHS.isNegative() && RHS.isNegative())
81        KnownOut.makeNegative();
82    }
83  }
84
85  return KnownOut;
86}
87
88KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const {
89  unsigned BitWidth = getBitWidth();
90  assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth &&
91         "Illegal sext-in-register");
92
93  if (SrcBitWidth == BitWidth)
94    return *this;
95
96  unsigned ExtBits = BitWidth - SrcBitWidth;
97  KnownBits Result;
98  Result.One = One << ExtBits;
99  Result.Zero = Zero << ExtBits;
100  Result.One.ashrInPlace(ExtBits);
101  Result.Zero.ashrInPlace(ExtBits);
102  return Result;
103}
104
105KnownBits KnownBits::makeGE(const APInt &Val) const {
106  // Count the number of leading bit positions where our underlying value is
107  // known to be less than or equal to Val.
108  unsigned N = (Zero | Val).countLeadingOnes();
109
110  // For each of those bit positions, if Val has a 1 in that bit then our
111  // underlying value must also have a 1.
112  APInt MaskedVal(Val);
113  MaskedVal.clearLowBits(getBitWidth() - N);
114  return KnownBits(Zero, One | MaskedVal);
115}
116
117KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
118  // If we can prove that LHS >= RHS then use LHS as the result. Likewise for
119  // RHS. Ideally our caller would already have spotted these cases and
120  // optimized away the umax operation, but we handle them here for
121  // completeness.
122  if (LHS.getMinValue().uge(RHS.getMaxValue()))
123    return LHS;
124  if (RHS.getMinValue().uge(LHS.getMaxValue()))
125    return RHS;
126
127  // If the result of the umax is LHS then it must be greater than or equal to
128  // the minimum possible value of RHS. Likewise for RHS. Any known bits that
129  // are common to these two values are also known in the result.
130  KnownBits L = LHS.makeGE(RHS.getMinValue());
131  KnownBits R = RHS.makeGE(LHS.getMinValue());
132  return KnownBits::commonBits(L, R);
133}
134
135KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
136  // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0]
137  auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); };
138  return Flip(umax(Flip(LHS), Flip(RHS)));
139}
140
141KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
142  // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
143  auto Flip = [](const KnownBits &Val) {
144    unsigned SignBitPosition = Val.getBitWidth() - 1;
145    APInt Zero = Val.Zero;
146    APInt One = Val.One;
147    Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
148    One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
149    return KnownBits(Zero, One);
150  };
151  return Flip(umax(Flip(LHS), Flip(RHS)));
152}
153
154KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
155  // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0]
156  auto Flip = [](const KnownBits &Val) {
157    unsigned SignBitPosition = Val.getBitWidth() - 1;
158    APInt Zero = Val.One;
159    APInt One = Val.Zero;
160    Zero.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
161    One.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
162    return KnownBits(Zero, One);
163  };
164  return Flip(umax(Flip(LHS), Flip(RHS)));
165}
166
/// Compute the known bits of LHS << RHS.
///
/// If RHS is a constant smaller than the bit width, LHS is shifted directly.
/// Otherwise the result combines two things:
///  (a) the low zero bits guaranteed by LHS's trailing zeros plus the minimum
///      shift amount, and
///  (b) when the maximum shift amount is in range, the known bits shared by LHS
///      shifted by each amount consistent with RHS's known bits.
KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
  unsigned BitWidth = LHS.getBitWidth();
  KnownBits Known(BitWidth);

  // If the shift amount is a valid constant then transform LHS directly.
  if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
    unsigned Shift = RHS.getConstant().getZExtValue();
    Known = LHS;
    Known.Zero <<= Shift;
    Known.One <<= Shift;
    // Low bits are known zero.
    Known.Zero.setLowBits(Shift);
    return Known;
  }

  // No matter the shift amount, the trailing zeros will stay zero.
  unsigned MinTrailingZeros = LHS.countMinTrailingZeros();

  // Minimum shift amount low bits are known zero.
  // The total is clamped to BitWidth so setLowBits() below stays in range.
  APInt MinShiftAmount = RHS.getMinValue();
  if (MinShiftAmount.ult(BitWidth)) {
    MinTrailingZeros += MinShiftAmount.getZExtValue();
    MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
  }

  // If the maximum shift is in range, then find the common bits from all
  // possible shifts.
  // This is skipped when LHS is entirely unknown: each shifted copy would then
  // be unknown apart from the vacated low bits, which MinTrailingZeros covers.
  APInt MaxShiftAmount = RHS.getMaxValue();
  if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
    // Bits a feasible shift amount may have set / must have set.
    // NOTE(review): getZExtValue() relies on ~RHS.Zero and RHS.One being
    // bounded by RHS.getMaxValue(), which is < BitWidth here.
    uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
    uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
    // The constant-amount case returned early, so min and max differ here.
    assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
    // Start with every bit marked known so the first commonBits() call keeps
    // exactly the first feasible shift's bits.
    Known.Zero.setAllBits();
    Known.One.setAllBits();
    for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
                  MaxShiftAmt = MaxShiftAmount.getZExtValue();
         ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
      // Skip if the shift amount is impossible.
      // A feasible amount sets no bit RHS knows is zero and keeps every bit
      // RHS knows is one.
      if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
          (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
        continue;
      KnownBits SpecificShift;
      SpecificShift.Zero = LHS.Zero << ShiftAmt;
      SpecificShift.One = LHS.One << ShiftAmt;
      Known = KnownBits::commonBits(Known, SpecificShift);
      // Once nothing is known, further shift amounts cannot help.
      if (Known.isUnknown())
        break;
    }
  }

  // The per-shift masks above leave the vacated low bits unknown (and the loop
  // may have been skipped), so apply the guaranteed low zeros here.
  Known.Zero.setLowBits(MinTrailingZeros);
  return Known;
}
220
/// Compute the known bits of LHS >>u RHS (logical shift right).
///
/// If RHS is a constant smaller than the bit width, LHS is shifted directly.
/// Otherwise the result combines two things:
///  (a) the high zero bits guaranteed by LHS's leading zeros plus the minimum
///      shift amount, and
///  (b) when the maximum shift amount is in range, the known bits shared by LHS
///      shifted by each amount consistent with RHS's known bits.
KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
  unsigned BitWidth = LHS.getBitWidth();
  KnownBits Known(BitWidth);

  if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
    unsigned Shift = RHS.getConstant().getZExtValue();
    Known = LHS;
    Known.Zero.lshrInPlace(Shift);
    Known.One.lshrInPlace(Shift);
    // High bits are known zero.
    Known.Zero.setHighBits(Shift);
    return Known;
  }

  // No matter the shift amount, the leading zeros will stay zero.
  unsigned MinLeadingZeros = LHS.countMinLeadingZeros();

  // Minimum shift amount high bits are known zero.
  // The total is clamped to BitWidth so setHighBits() below stays in range.
  APInt MinShiftAmount = RHS.getMinValue();
  if (MinShiftAmount.ult(BitWidth)) {
    MinLeadingZeros += MinShiftAmount.getZExtValue();
    MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
  }

  // If the maximum shift is in range, then find the common bits from all
  // possible shifts.
  // This is skipped when LHS is entirely unknown: each shifted copy would then
  // be unknown apart from the vacated high bits, which MinLeadingZeros covers.
  APInt MaxShiftAmount = RHS.getMaxValue();
  if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
    // Bits a feasible shift amount may have set / must have set.
    // NOTE(review): getZExtValue() relies on ~RHS.Zero and RHS.One being
    // bounded by RHS.getMaxValue(), which is < BitWidth here.
    uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
    uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
    // The constant-amount case returned early, so min and max differ here.
    assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
    // Start with every bit marked known so the first commonBits() call keeps
    // exactly the first feasible shift's bits.
    Known.Zero.setAllBits();
    Known.One.setAllBits();
    for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
                  MaxShiftAmt = MaxShiftAmount.getZExtValue();
         ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
      // Skip if the shift amount is impossible.
      // A feasible amount sets no bit RHS knows is zero and keeps every bit
      // RHS knows is one.
      if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
          (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
        continue;
      KnownBits SpecificShift = LHS;
      SpecificShift.Zero.lshrInPlace(ShiftAmt);
      SpecificShift.One.lshrInPlace(ShiftAmt);
      Known = KnownBits::commonBits(Known, SpecificShift);
      // Once nothing is known, further shift amounts cannot help.
      if (Known.isUnknown())
        break;
    }
  }

  // The per-shift masks above leave the vacated high bits unknown (and the loop
  // may have been skipped), so apply the guaranteed high zeros here.
  Known.Zero.setHighBits(MinLeadingZeros);
  return Known;
}
273
/// Compute the known bits of LHS >>s RHS (arithmetic shift right).
///
/// If RHS is a constant smaller than the bit width, LHS is shifted directly.
/// Otherwise the result combines two things:
///  (a) the run of copies of a known sign bit, lengthened by the minimum shift
///      amount, and
///  (b) when the maximum shift amount is in range, the known bits shared by LHS
///      shifted by each amount consistent with RHS's known bits.
KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) {
  unsigned BitWidth = LHS.getBitWidth();
  KnownBits Known(BitWidth);

  // Constant in-range shift: ashrInPlace on each mask replicates the sign
  // bit's known state into the vacated high bits.
  if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
    unsigned Shift = RHS.getConstant().getZExtValue();
    Known = LHS;
    Known.Zero.ashrInPlace(Shift);
    Known.One.ashrInPlace(Shift);
    return Known;
  }

  // No matter the shift amount, the leading sign bits will stay.
  unsigned MinLeadingZeros = LHS.countMinLeadingZeros();
  unsigned MinLeadingOnes = LHS.countMinLeadingOnes();

  // Minimum shift amount high bits are known sign bits.
  // Each count is extended only when it is non-zero, i.e. when the sign bit
  // itself is known to be 0 (resp. 1); clamped to BitWidth for setHighBits().
  APInt MinShiftAmount = RHS.getMinValue();
  if (MinShiftAmount.ult(BitWidth)) {
    if (MinLeadingZeros) {
      MinLeadingZeros += MinShiftAmount.getZExtValue();
      MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
    }
    if (MinLeadingOnes) {
      MinLeadingOnes += MinShiftAmount.getZExtValue();
      MinLeadingOnes = std::min(MinLeadingOnes, BitWidth);
    }
  }

  // If the maximum shift is in range, then find the common bits from all
  // possible shifts.
  // This is skipped when LHS is entirely unknown, since no shift of it can
  // produce known bits.
  APInt MaxShiftAmount = RHS.getMaxValue();
  if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
    // Bits a feasible shift amount may have set / must have set.
    // NOTE(review): getZExtValue() relies on ~RHS.Zero and RHS.One being
    // bounded by RHS.getMaxValue(), which is < BitWidth here.
    uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
    uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
    // The constant-amount case returned early, so min and max differ here.
    assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
    // Start with every bit marked known so the first commonBits() call keeps
    // exactly the first feasible shift's bits.
    Known.Zero.setAllBits();
    Known.One.setAllBits();
    for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
                  MaxShiftAmt = MaxShiftAmount.getZExtValue();
         ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
      // Skip if the shift amount is impossible.
      // A feasible amount sets no bit RHS knows is zero and keeps every bit
      // RHS knows is one.
      if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
          (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
        continue;
      KnownBits SpecificShift = LHS;
      SpecificShift.Zero.ashrInPlace(ShiftAmt);
      SpecificShift.One.ashrInPlace(ShiftAmt);
      Known = KnownBits::commonBits(Known, SpecificShift);
      // Once nothing is known, further shift amounts cannot help.
      if (Known.isUnknown())
        break;
    }
  }

  // Apply the guaranteed sign-bit run. When the loop above was skipped this is
  // the only information in the result.
  Known.Zero.setHighBits(MinLeadingZeros);
  Known.One.setHighBits(MinLeadingOnes);
  return Known;
}
332
333std::optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) {
334  if (LHS.isConstant() && RHS.isConstant())
335    return std::optional<bool>(LHS.getConstant() == RHS.getConstant());
336  if (LHS.One.intersects(RHS.Zero) || RHS.One.intersects(LHS.Zero))
337    return std::optional<bool>(false);
338  return std::nullopt;
339}
340
341std::optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) {
342  if (std::optional<bool> KnownEQ = eq(LHS, RHS))
343    return std::optional<bool>(!*KnownEQ);
344  return std::nullopt;
345}
346
347std::optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) {
348  // LHS >u RHS -> false if umax(LHS) <= umax(RHS)
349  if (LHS.getMaxValue().ule(RHS.getMinValue()))
350    return std::optional<bool>(false);
351  // LHS >u RHS -> true if umin(LHS) > umax(RHS)
352  if (LHS.getMinValue().ugt(RHS.getMaxValue()))
353    return std::optional<bool>(true);
354  return std::nullopt;
355}
356
357std::optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) {
358  if (std::optional<bool> IsUGT = ugt(RHS, LHS))
359    return std::optional<bool>(!*IsUGT);
360  return std::nullopt;
361}
362
363std::optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) {
364  return ugt(RHS, LHS);
365}
366
367std::optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) {
368  return uge(RHS, LHS);
369}
370
371std::optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) {
372  // LHS >s RHS -> false if smax(LHS) <= smax(RHS)
373  if (LHS.getSignedMaxValue().sle(RHS.getSignedMinValue()))
374    return std::optional<bool>(false);
375  // LHS >s RHS -> true if smin(LHS) > smax(RHS)
376  if (LHS.getSignedMinValue().sgt(RHS.getSignedMaxValue()))
377    return std::optional<bool>(true);
378  return std::nullopt;
379}
380
381std::optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) {
382  if (std::optional<bool> KnownSGT = sgt(RHS, LHS))
383    return std::optional<bool>(!*KnownSGT);
384  return std::nullopt;
385}
386
387std::optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) {
388  return sgt(RHS, LHS);
389}
390
391std::optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) {
392  return sge(RHS, LHS);
393}
394
395KnownBits KnownBits::abs(bool IntMinIsPoison) const {
396  // If the source's MSB is zero then we know the rest of the bits already.
397  if (isNonNegative())
398    return *this;
399
400  // Absolute value preserves trailing zero count.
401  KnownBits KnownAbs(getBitWidth());
402  KnownAbs.Zero.setLowBits(countMinTrailingZeros());
403
404  // We only know that the absolute values's MSB will be zero if INT_MIN is
405  // poison, or there is a set bit that isn't the sign bit (otherwise it could
406  // be INT_MIN).
407  if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue()))
408    KnownAbs.Zero.setSignBit();
409
410  // FIXME: Handle known negative input?
411  // FIXME: Calculate the negated Known bits and combine them?
412  return KnownAbs;
413}
414
/// Compute the known bits of the product LHS * RHS, truncated to the common
/// bit width.
///
/// \param NoUndefSelfMultiply  Set when both operands are the same value
///        (x * x). The assert below only checks that their known bits match;
///        the "not undef" requirement comes from the parameter name and is not
///        checked here. In this mode bit 1 of the result is known zero.
KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
                         bool NoUndefSelfMultiply) {
  unsigned BitWidth = LHS.getBitWidth();
  assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
         !RHS.hasConflict() && "Operand mismatch");
  assert((!NoUndefSelfMultiply || LHS == RHS) &&
         "Self multiplication knownbits mismatch");

  // Compute the high known-0 bits by multiplying the unsigned max of each side.
  // Conservatively, M active bits * N active bits results in M + N bits in the
  // result. But if we know a value is a power-of-2 for example, then this
  // computes one more leading zero.
  // TODO: This could be generalized to number of sign bits (negative numbers).
  APInt UMaxLHS = LHS.getMaxValue();
  APInt UMaxRHS = RHS.getMaxValue();

  // For leading zeros in the result to be valid, the unsigned max product must
  // fit in the bitwidth (it must not overflow).
  bool HasOverflow;
  APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow);
  unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countLeadingZeros();

  // The result of the bottom bits of an integer multiply can be
  // inferred by looking at the bottom bits of both operands and
  // multiplying them together.
  // We can infer at least the minimum number of known trailing bits
  // of both operands. Depending on number of trailing zeros, we can
  // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
  // a and b are divisible by m and n respectively.
  // We then calculate how many of those bits are inferrable and set
  // the output. For example, the i8 mul:
  //  a = XXXX1100 (12)
  //  b = XXXX1110 (14)
  // We know the bottom 3 bits are zero since the first can be divided by
  // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
  // Applying the multiplication to the trimmed arguments gets:
  //    XX11 (3)
  //    X111 (7)
  // -------
  //    XX11
  //   XX11
  //  XX11
  // XX11
  // -------
  // XXXXX01
  // Which allows us to infer the 2 LSBs. Since we're multiplying the result
  // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
  // The proof for this can be described as:
  // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
  //      (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
  //                    umin(countTrailingZeros(C2), C6) +
  //                    umin(C5 - umin(countTrailingZeros(C1), C5),
  //                         C6 - umin(countTrailingZeros(C2), C6)))) - 1)
  // %aa = shl i8 %a, C5
  // %bb = shl i8 %b, C6
  // %aaa = or i8 %aa, C1
  // %bbb = or i8 %bb, C2
  // %mul = mul i8 %aaa, %bbb
  // %mask = and i8 %mul, C7
  //   =>
  // %mask = i8 ((C1*C2)&C7)
  // Where C5, C6 describe the known bits of %a, %b
  // C1, C2 describe the known bottom bits of %a, %b.
  // C7 describes the mask of the known bits of the result.
  const APInt &Bottom0 = LHS.One;
  const APInt &Bottom1 = RHS.One;

  // How many times we'd be able to divide each argument by 2 (shr by 1).
  // This gives us the number of trailing zeros on the multiplication result.
  // TrailBitsKnown* = length of the fully known low run of each operand;
  // TrailZero* = guaranteed trailing zeros of each operand.
  unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countTrailingOnes();
  unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countTrailingOnes();
  unsigned TrailZero0 = LHS.countMinTrailingZeros();
  unsigned TrailZero1 = RHS.countMinTrailingZeros();
  unsigned TrailZ = TrailZero0 + TrailZero1;

  // Figure out the fewest known-bits operand.
  // (Known low bits above the trailing zeros; the subtractions cannot wrap
  // because the trailing zeros are part of the known low run.)
  unsigned SmallestOperand =
      std::min(TrailBitsKnown0 - TrailZero0, TrailBitsKnown1 - TrailZero1);
  unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth);

  // Product of the fully known low parts of both operands; its low
  // ResultBitsKnown bits equal those of the real product.
  APInt BottomKnown =
      Bottom0.getLoBits(TrailBitsKnown0) * Bottom1.getLoBits(TrailBitsKnown1);

  KnownBits Res(BitWidth);
  Res.Zero.setHighBits(LeadZ);
  Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
  Res.One = BottomKnown.getLoBits(ResultBitsKnown);

  // If we're self-multiplying then bit[1] is guaranteed to be zero.
  // (x * x mod 4 is always 0 or 1.)
  if (NoUndefSelfMultiply && BitWidth > 1) {
    assert(Res.One[1] == 0 &&
           "Self-multiplication failed Quadratic Reciprocity!");
    Res.Zero.setBit(1);
  }

  return Res;
}
512
513KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) {
514  unsigned BitWidth = LHS.getBitWidth();
515  assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
516         !RHS.hasConflict() && "Operand mismatch");
517  KnownBits WideLHS = LHS.sext(2 * BitWidth);
518  KnownBits WideRHS = RHS.sext(2 * BitWidth);
519  return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
520}
521
522KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) {
523  unsigned BitWidth = LHS.getBitWidth();
524  assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
525         !RHS.hasConflict() && "Operand mismatch");
526  KnownBits WideLHS = LHS.zext(2 * BitWidth);
527  KnownBits WideRHS = RHS.zext(2 * BitWidth);
528  return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
529}
530
531KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS) {
532  unsigned BitWidth = LHS.getBitWidth();
533  assert(!LHS.hasConflict() && !RHS.hasConflict());
534  KnownBits Known(BitWidth);
535
536  // For the purposes of computing leading zeros we can conservatively
537  // treat a udiv as a logical right shift by the power of 2 known to
538  // be less than the denominator.
539  unsigned LeadZ = LHS.countMinLeadingZeros();
540  unsigned RHSMaxLeadingZeros = RHS.countMaxLeadingZeros();
541
542  if (RHSMaxLeadingZeros != BitWidth)
543    LeadZ = std::min(BitWidth, LeadZ + BitWidth - RHSMaxLeadingZeros - 1);
544
545  Known.Zero.setHighBits(LeadZ);
546  return Known;
547}
548
549KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) {
550  unsigned BitWidth = LHS.getBitWidth();
551  assert(!LHS.hasConflict() && !RHS.hasConflict());
552  KnownBits Known(BitWidth);
553
554  if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
555    // The upper bits are all zero, the lower ones are unchanged.
556    APInt LowBits = RHS.getConstant() - 1;
557    Known.Zero = LHS.Zero | ~LowBits;
558    Known.One = LHS.One & LowBits;
559    return Known;
560  }
561
562  // Since the result is less than or equal to either operand, any leading
563  // zero bits in either operand must also exist in the result.
564  uint32_t Leaders =
565      std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingZeros());
566  Known.Zero.setHighBits(Leaders);
567  return Known;
568}
569
570KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) {
571  unsigned BitWidth = LHS.getBitWidth();
572  assert(!LHS.hasConflict() && !RHS.hasConflict());
573  KnownBits Known(BitWidth);
574
575  if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
576    // The low bits of the first operand are unchanged by the srem.
577    APInt LowBits = RHS.getConstant() - 1;
578    Known.Zero = LHS.Zero & LowBits;
579    Known.One = LHS.One & LowBits;
580
581    // If the first operand is non-negative or has all low bits zero, then
582    // the upper bits are all zero.
583    if (LHS.isNonNegative() || LowBits.isSubsetOf(LHS.Zero))
584      Known.Zero |= ~LowBits;
585
586    // If the first operand is negative and not all low bits are zero, then
587    // the upper bits are all one.
588    if (LHS.isNegative() && LowBits.intersects(LHS.One))
589      Known.One |= ~LowBits;
590    return Known;
591  }
592
593  // The sign bit is the LHS's sign bit, except when the result of the
594  // remainder is zero. The magnitude of the result should be less than or
595  // equal to the magnitude of the LHS. Therefore any leading zeros that exist
596  // in the left hand side must also exist in the result.
597  Known.Zero.setHighBits(LHS.countMinLeadingZeros());
598  return Known;
599}
600
601KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
602  // Result bit is 0 if either operand bit is 0.
603  Zero |= RHS.Zero;
604  // Result bit is 1 if both operand bits are 1.
605  One &= RHS.One;
606  return *this;
607}
608
609KnownBits &KnownBits::operator|=(const KnownBits &RHS) {
610  // Result bit is 0 if both operand bits are 0.
611  Zero &= RHS.Zero;
612  // Result bit is 1 if either operand bit is 1.
613  One |= RHS.One;
614  return *this;
615}
616
617KnownBits &KnownBits::operator^=(const KnownBits &RHS) {
618  // Result bit is 0 if both operand bits are 0 or both are 1.
619  APInt Z = (Zero & RHS.Zero) | (One & RHS.One);
620  // Result bit is 1 if one operand bit is 0 and the other is 1.
621  One = (Zero & RHS.One) | (One & RHS.Zero);
622  Zero = std::move(Z);
623  return *this;
624}
625
626void KnownBits::print(raw_ostream &OS) const {
627  OS << "{Zero=" << Zero << ", One=" << One << "}";
628}
629void KnownBits::dump() const {
630  print(dbgs());
631  dbgs() << "\n";
632}
633