1//===- llvm/Support/KnownBits.h - Stores known zeros/ones -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a class for representing known zeros and ones used by
10// computeKnownBits.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_SUPPORT_KNOWNBITS_H
15#define LLVM_SUPPORT_KNOWNBITS_H
16
17#include "llvm/ADT/APInt.h"
18#include <optional>
19
20namespace llvm {
21
22// Struct for tracking the known zeros and ones of a value.
23struct KnownBits {
24  APInt Zero;
25  APInt One;
26
27private:
28  // Internal constructor for creating a KnownBits from two APInts.
29  KnownBits(APInt Zero, APInt One)
30      : Zero(std::move(Zero)), One(std::move(One)) {}
31
32public:
33  // Default construct Zero and One.
34  KnownBits() = default;
35
36  /// Create a known bits object of BitWidth bits initialized to unknown.
37  KnownBits(unsigned BitWidth) : Zero(BitWidth, 0), One(BitWidth, 0) {}
38
39  /// Get the bit width of this value.
40  unsigned getBitWidth() const {
41    assert(Zero.getBitWidth() == One.getBitWidth() &&
42           "Zero and One should have the same width!");
43    return Zero.getBitWidth();
44  }
45
46  /// Returns true if there is conflicting information.
47  bool hasConflict() const { return Zero.intersects(One); }
48
49  /// Returns true if we know the value of all bits.
50  bool isConstant() const {
51    assert(!hasConflict() && "KnownBits conflict!");
52    return Zero.popcount() + One.popcount() == getBitWidth();
53  }
54
55  /// Returns the value when all bits have a known value. This just returns One
56  /// with a protective assertion.
57  const APInt &getConstant() const {
58    assert(isConstant() && "Can only get value when all bits are known");
59    return One;
60  }
61
62  /// Returns true if we don't know any bits.
63  bool isUnknown() const { return Zero.isZero() && One.isZero(); }
64
65  /// Resets the known state of all bits.
66  void resetAll() {
67    Zero.clearAllBits();
68    One.clearAllBits();
69  }
70
71  /// Returns true if value is all zero.
72  bool isZero() const {
73    assert(!hasConflict() && "KnownBits conflict!");
74    return Zero.isAllOnes();
75  }
76
77  /// Returns true if value is all one bits.
78  bool isAllOnes() const {
79    assert(!hasConflict() && "KnownBits conflict!");
80    return One.isAllOnes();
81  }
82
83  /// Make all bits known to be zero and discard any previous information.
84  void setAllZero() {
85    Zero.setAllBits();
86    One.clearAllBits();
87  }
88
89  /// Make all bits known to be one and discard any previous information.
90  void setAllOnes() {
91    Zero.clearAllBits();
92    One.setAllBits();
93  }
94
95  /// Returns true if this value is known to be negative.
96  bool isNegative() const { return One.isSignBitSet(); }
97
98  /// Returns true if this value is known to be non-negative.
99  bool isNonNegative() const { return Zero.isSignBitSet(); }
100
101  /// Returns true if this value is known to be non-zero.
102  bool isNonZero() const { return !One.isZero(); }
103
104  /// Returns true if this value is known to be positive.
105  bool isStrictlyPositive() const {
106    return Zero.isSignBitSet() && !One.isZero();
107  }
108
109  /// Make this value negative.
110  void makeNegative() {
111    One.setSignBit();
112  }
113
114  /// Make this value non-negative.
115  void makeNonNegative() {
116    Zero.setSignBit();
117  }
118
119  /// Return the minimal unsigned value possible given these KnownBits.
120  APInt getMinValue() const {
121    // Assume that all bits that aren't known-ones are zeros.
122    return One;
123  }
124
125  /// Return the minimal signed value possible given these KnownBits.
126  APInt getSignedMinValue() const {
127    // Assume that all bits that aren't known-ones are zeros.
128    APInt Min = One;
129    // Sign bit is unknown.
130    if (Zero.isSignBitClear())
131      Min.setSignBit();
132    return Min;
133  }
134
135  /// Return the maximal unsigned value possible given these KnownBits.
136  APInt getMaxValue() const {
137    // Assume that all bits that aren't known-zeros are ones.
138    return ~Zero;
139  }
140
141  /// Return the maximal signed value possible given these KnownBits.
142  APInt getSignedMaxValue() const {
143    // Assume that all bits that aren't known-zeros are ones.
144    APInt Max = ~Zero;
145    // Sign bit is unknown.
146    if (One.isSignBitClear())
147      Max.clearSignBit();
148    return Max;
149  }
150
151  /// Return known bits for a truncation of the value we're tracking.
152  KnownBits trunc(unsigned BitWidth) const {
153    return KnownBits(Zero.trunc(BitWidth), One.trunc(BitWidth));
154  }
155
156  /// Return known bits for an "any" extension of the value we're tracking,
157  /// where we don't know anything about the extended bits.
158  KnownBits anyext(unsigned BitWidth) const {
159    return KnownBits(Zero.zext(BitWidth), One.zext(BitWidth));
160  }
161
162  /// Return known bits for a zero extension of the value we're tracking.
163  KnownBits zext(unsigned BitWidth) const {
164    unsigned OldBitWidth = getBitWidth();
165    APInt NewZero = Zero.zext(BitWidth);
166    NewZero.setBitsFrom(OldBitWidth);
167    return KnownBits(NewZero, One.zext(BitWidth));
168  }
169
170  /// Return known bits for a sign extension of the value we're tracking.
171  KnownBits sext(unsigned BitWidth) const {
172    return KnownBits(Zero.sext(BitWidth), One.sext(BitWidth));
173  }
174
175  /// Return known bits for an "any" extension or truncation of the value we're
176  /// tracking.
177  KnownBits anyextOrTrunc(unsigned BitWidth) const {
178    if (BitWidth > getBitWidth())
179      return anyext(BitWidth);
180    if (BitWidth < getBitWidth())
181      return trunc(BitWidth);
182    return *this;
183  }
184
185  /// Return known bits for a zero extension or truncation of the value we're
186  /// tracking.
187  KnownBits zextOrTrunc(unsigned BitWidth) const {
188    if (BitWidth > getBitWidth())
189      return zext(BitWidth);
190    if (BitWidth < getBitWidth())
191      return trunc(BitWidth);
192    return *this;
193  }
194
195  /// Return known bits for a sign extension or truncation of the value we're
196  /// tracking.
197  KnownBits sextOrTrunc(unsigned BitWidth) const {
198    if (BitWidth > getBitWidth())
199      return sext(BitWidth);
200    if (BitWidth < getBitWidth())
201      return trunc(BitWidth);
202    return *this;
203  }
204
205  /// Return known bits for a in-register sign extension of the value we're
206  /// tracking.
207  KnownBits sextInReg(unsigned SrcBitWidth) const;
208
209  /// Insert the bits from a smaller known bits starting at bitPosition.
210  void insertBits(const KnownBits &SubBits, unsigned BitPosition) {
211    Zero.insertBits(SubBits.Zero, BitPosition);
212    One.insertBits(SubBits.One, BitPosition);
213  }
214
215  /// Return a subset of the known bits from [bitPosition,bitPosition+numBits).
216  KnownBits extractBits(unsigned NumBits, unsigned BitPosition) const {
217    return KnownBits(Zero.extractBits(NumBits, BitPosition),
218                     One.extractBits(NumBits, BitPosition));
219  }
220
221  /// Concatenate the bits from \p Lo onto the bottom of *this.  This is
222  /// equivalent to:
223  ///   (this->zext(NewWidth) << Lo.getBitWidth()) | Lo.zext(NewWidth)
224  KnownBits concat(const KnownBits &Lo) const {
225    return KnownBits(Zero.concat(Lo.Zero), One.concat(Lo.One));
226  }
227
228  /// Return KnownBits based on this, but updated given that the underlying
229  /// value is known to be greater than or equal to Val.
230  KnownBits makeGE(const APInt &Val) const;
231
232  /// Returns the minimum number of trailing zero bits.
233  unsigned countMinTrailingZeros() const { return Zero.countr_one(); }
234
235  /// Returns the minimum number of trailing one bits.
236  unsigned countMinTrailingOnes() const { return One.countr_one(); }
237
238  /// Returns the minimum number of leading zero bits.
239  unsigned countMinLeadingZeros() const { return Zero.countl_one(); }
240
241  /// Returns the minimum number of leading one bits.
242  unsigned countMinLeadingOnes() const { return One.countl_one(); }
243
244  /// Returns the number of times the sign bit is replicated into the other
245  /// bits.
246  unsigned countMinSignBits() const {
247    if (isNonNegative())
248      return countMinLeadingZeros();
249    if (isNegative())
250      return countMinLeadingOnes();
251    // Every value has at least 1 sign bit.
252    return 1;
253  }
254
255  /// Returns the maximum number of bits needed to represent all possible
256  /// signed values with these known bits. This is the inverse of the minimum
257  /// number of known sign bits. Examples for bitwidth 5:
258  /// 110?? --> 4
259  /// 0000? --> 2
260  unsigned countMaxSignificantBits() const {
261    return getBitWidth() - countMinSignBits() + 1;
262  }
263
264  /// Returns the maximum number of trailing zero bits possible.
265  unsigned countMaxTrailingZeros() const { return One.countr_zero(); }
266
267  /// Returns the maximum number of trailing one bits possible.
268  unsigned countMaxTrailingOnes() const { return Zero.countr_zero(); }
269
270  /// Returns the maximum number of leading zero bits possible.
271  unsigned countMaxLeadingZeros() const { return One.countl_zero(); }
272
273  /// Returns the maximum number of leading one bits possible.
274  unsigned countMaxLeadingOnes() const { return Zero.countl_zero(); }
275
276  /// Returns the number of bits known to be one.
277  unsigned countMinPopulation() const { return One.popcount(); }
278
279  /// Returns the maximum number of bits that could be one.
280  unsigned countMaxPopulation() const {
281    return getBitWidth() - Zero.popcount();
282  }
283
284  /// Returns the maximum number of bits needed to represent all possible
285  /// unsigned values with these known bits. This is the inverse of the
286  /// minimum number of leading zeros.
287  unsigned countMaxActiveBits() const {
288    return getBitWidth() - countMinLeadingZeros();
289  }
290
291  /// Create known bits from a known constant.
292  static KnownBits makeConstant(const APInt &C) {
293    return KnownBits(~C, C);
294  }
295
296  /// Returns KnownBits information that is known to be true for both this and
297  /// RHS.
298  ///
299  /// When an operation is known to return one of its operands, this can be used
300  /// to combine information about the known bits of the operands to get the
301  /// information that must be true about the result.
302  KnownBits intersectWith(const KnownBits &RHS) const {
303    return KnownBits(Zero & RHS.Zero, One & RHS.One);
304  }
305
306  /// Returns KnownBits information that is known to be true for either this or
307  /// RHS or both.
308  ///
309  /// This can be used to combine different sources of information about the
310  /// known bits of a single value, e.g. information about the low bits and the
311  /// high bits of the result of a multiplication.
312  KnownBits unionWith(const KnownBits &RHS) const {
313    return KnownBits(Zero | RHS.Zero, One | RHS.One);
314  }
315
316  /// Compute known bits common to LHS and RHS.
317  LLVM_DEPRECATED("use intersectWith instead", "intersectWith")
318  static KnownBits commonBits(const KnownBits &LHS, const KnownBits &RHS) {
319    return LHS.intersectWith(RHS);
320  }
321
322  /// Return true if LHS and RHS have no common bits set.
323  static bool haveNoCommonBitsSet(const KnownBits &LHS, const KnownBits &RHS) {
324    return (LHS.Zero | RHS.Zero).isAllOnes();
325  }
326
327  /// Compute known bits resulting from adding LHS, RHS and a 1-bit Carry.
328  static KnownBits computeForAddCarry(
329      const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry);
330
331  /// Compute known bits resulting from adding LHS and RHS.
332  static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
333                                    KnownBits RHS);
334
335  /// Compute known bits results from subtracting RHS from LHS with 1-bit
336  /// Borrow.
337  static KnownBits computeForSubBorrow(const KnownBits &LHS, KnownBits RHS,
338                                       const KnownBits &Borrow);
339
340  /// Compute knownbits resulting from llvm.sadd.sat(LHS, RHS)
341  static KnownBits sadd_sat(const KnownBits &LHS, const KnownBits &RHS);
342
343  /// Compute knownbits resulting from llvm.uadd.sat(LHS, RHS)
344  static KnownBits uadd_sat(const KnownBits &LHS, const KnownBits &RHS);
345
346  /// Compute knownbits resulting from llvm.ssub.sat(LHS, RHS)
347  static KnownBits ssub_sat(const KnownBits &LHS, const KnownBits &RHS);
348
349  /// Compute knownbits resulting from llvm.usub.sat(LHS, RHS)
350  static KnownBits usub_sat(const KnownBits &LHS, const KnownBits &RHS);
351
352  /// Compute known bits resulting from multiplying LHS and RHS.
353  static KnownBits mul(const KnownBits &LHS, const KnownBits &RHS,
354                       bool NoUndefSelfMultiply = false);
355
356  /// Compute known bits from sign-extended multiply-hi.
357  static KnownBits mulhs(const KnownBits &LHS, const KnownBits &RHS);
358
359  /// Compute known bits from zero-extended multiply-hi.
360  static KnownBits mulhu(const KnownBits &LHS, const KnownBits &RHS);
361
362  /// Compute known bits for sdiv(LHS, RHS).
363  static KnownBits sdiv(const KnownBits &LHS, const KnownBits &RHS,
364                        bool Exact = false);
365
366  /// Compute known bits for udiv(LHS, RHS).
367  static KnownBits udiv(const KnownBits &LHS, const KnownBits &RHS,
368                        bool Exact = false);
369
370  /// Compute known bits for urem(LHS, RHS).
371  static KnownBits urem(const KnownBits &LHS, const KnownBits &RHS);
372
373  /// Compute known bits for srem(LHS, RHS).
374  static KnownBits srem(const KnownBits &LHS, const KnownBits &RHS);
375
376  /// Compute known bits for umax(LHS, RHS).
377  static KnownBits umax(const KnownBits &LHS, const KnownBits &RHS);
378
379  /// Compute known bits for umin(LHS, RHS).
380  static KnownBits umin(const KnownBits &LHS, const KnownBits &RHS);
381
382  /// Compute known bits for smax(LHS, RHS).
383  static KnownBits smax(const KnownBits &LHS, const KnownBits &RHS);
384
385  /// Compute known bits for smin(LHS, RHS).
386  static KnownBits smin(const KnownBits &LHS, const KnownBits &RHS);
387
388  /// Compute known bits for shl(LHS, RHS).
389  /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
390  static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS,
391                       bool NUW = false, bool NSW = false,
392                       bool ShAmtNonZero = false);
393
394  /// Compute known bits for lshr(LHS, RHS).
395  /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
396  static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS,
397                        bool ShAmtNonZero = false);
398
399  /// Compute known bits for ashr(LHS, RHS).
400  /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
401  static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS,
402                        bool ShAmtNonZero = false);
403
404  /// Determine if these known bits always give the same ICMP_EQ result.
405  static std::optional<bool> eq(const KnownBits &LHS, const KnownBits &RHS);
406
407  /// Determine if these known bits always give the same ICMP_NE result.
408  static std::optional<bool> ne(const KnownBits &LHS, const KnownBits &RHS);
409
410  /// Determine if these known bits always give the same ICMP_UGT result.
411  static std::optional<bool> ugt(const KnownBits &LHS, const KnownBits &RHS);
412
413  /// Determine if these known bits always give the same ICMP_UGE result.
414  static std::optional<bool> uge(const KnownBits &LHS, const KnownBits &RHS);
415
416  /// Determine if these known bits always give the same ICMP_ULT result.
417  static std::optional<bool> ult(const KnownBits &LHS, const KnownBits &RHS);
418
419  /// Determine if these known bits always give the same ICMP_ULE result.
420  static std::optional<bool> ule(const KnownBits &LHS, const KnownBits &RHS);
421
422  /// Determine if these known bits always give the same ICMP_SGT result.
423  static std::optional<bool> sgt(const KnownBits &LHS, const KnownBits &RHS);
424
425  /// Determine if these known bits always give the same ICMP_SGE result.
426  static std::optional<bool> sge(const KnownBits &LHS, const KnownBits &RHS);
427
428  /// Determine if these known bits always give the same ICMP_SLT result.
429  static std::optional<bool> slt(const KnownBits &LHS, const KnownBits &RHS);
430
431  /// Determine if these known bits always give the same ICMP_SLE result.
432  static std::optional<bool> sle(const KnownBits &LHS, const KnownBits &RHS);
433
434  /// Update known bits based on ANDing with RHS.
435  KnownBits &operator&=(const KnownBits &RHS);
436
437  /// Update known bits based on ORing with RHS.
438  KnownBits &operator|=(const KnownBits &RHS);
439
440  /// Update known bits based on XORing with RHS.
441  KnownBits &operator^=(const KnownBits &RHS);
442
443  /// Compute known bits for the absolute value.
444  KnownBits abs(bool IntMinIsPoison = false) const;
445
446  KnownBits byteSwap() const {
447    return KnownBits(Zero.byteSwap(), One.byteSwap());
448  }
449
450  KnownBits reverseBits() const {
451    return KnownBits(Zero.reverseBits(), One.reverseBits());
452  }
453
454  /// Compute known bits for X & -X, which has only the lowest bit set of X set.
455  /// The name comes from the X86 BMI instruction
456  KnownBits blsi() const;
457
458  /// Compute known bits for X ^ (X - 1), which has all bits up to and including
459  /// the lowest set bit of X set. The name comes from the X86 BMI instruction.
460  KnownBits blsmsk() const;
461
462  bool operator==(const KnownBits &Other) const {
463    return Zero == Other.Zero && One == Other.One;
464  }
465
466  bool operator!=(const KnownBits &Other) const { return !(*this == Other); }
467
468  void print(raw_ostream &OS) const;
469  void dump() const;
470
471private:
472  // Internal helper for getting the initial KnownBits for an `srem` or `urem`
473  // operation with the low-bits set.
474  static KnownBits remGetLowBits(const KnownBits &LHS, const KnownBits &RHS);
475};
476
/// Compute known bits for LHS & RHS. LHS is taken by value, so the result is
/// computed in place in that copy, which is then returned.
inline KnownBits operator&(KnownBits LHS, const KnownBits &RHS) {
  LHS &= RHS;
  return LHS;
}
481
/// Compute known bits for LHS & RHS when RHS is an rvalue. AND is commutative,
/// so the result is accumulated into RHS's storage and moved out, avoiding a
/// copy of LHS.
inline KnownBits operator&(const KnownBits &LHS, KnownBits &&RHS) {
  RHS &= LHS;
  // RHS is an lvalue inside this function and, before C++20, is not
  // implicitly moved on return, so the explicit std::move avoids a copy.
  return std::move(RHS);
}
486
/// Compute known bits for LHS | RHS. LHS is taken by value, so the result is
/// computed in place in that copy, which is then returned.
inline KnownBits operator|(KnownBits LHS, const KnownBits &RHS) {
  LHS |= RHS;
  return LHS;
}
491
/// Compute known bits for LHS | RHS when RHS is an rvalue. OR is commutative,
/// so the result is accumulated into RHS's storage and moved out, avoiding a
/// copy of LHS.
inline KnownBits operator|(const KnownBits &LHS, KnownBits &&RHS) {
  RHS |= LHS;
  // RHS is an lvalue inside this function and, before C++20, is not
  // implicitly moved on return, so the explicit std::move avoids a copy.
  return std::move(RHS);
}
496
/// Compute known bits for LHS ^ RHS. LHS is taken by value, so the result is
/// computed in place in that copy, which is then returned.
inline KnownBits operator^(KnownBits LHS, const KnownBits &RHS) {
  LHS ^= RHS;
  return LHS;
}
501
/// Compute known bits for LHS ^ RHS when RHS is an rvalue. XOR is commutative,
/// so the result is accumulated into RHS's storage and moved out, avoiding a
/// copy of LHS.
inline KnownBits operator^(const KnownBits &LHS, KnownBits &&RHS) {
  RHS ^= LHS;
  // RHS is an lvalue inside this function and, before C++20, is not
  // implicitly moved on return, so the explicit std::move avoids a copy.
  return std::move(RHS);
}
506
/// Stream \p Known to \p OS via KnownBits::print, enabling `OS << Known`.
/// Returns \p OS so insertions can be chained.
inline raw_ostream &operator<<(raw_ostream &OS, const KnownBits &Known) {
  Known.print(OS);
  return OS;
}
511
512} // end namespace llvm
513
514#endif
515