1//===- lib/CodeGen/GlobalISel/GISelKnownBits.cpp --------------*- C++ *-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// Provides analysis for querying information about KnownBits during GISel
10/// passes.
11//
12//===------------------
13#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
14#include "llvm/Analysis/ValueTracking.h"
15#include "llvm/CodeGen/GlobalISel/Utils.h"
16#include "llvm/CodeGen/MachineFrameInfo.h"
17#include "llvm/CodeGen/MachineRegisterInfo.h"
18#include "llvm/CodeGen/TargetLowering.h"
19#include "llvm/CodeGen/TargetOpcodes.h"
20
21#define DEBUG_TYPE "gisel-known-bits"
22
23using namespace llvm;
24
// Pass identifier for GISelKnownBitsAnalysis, used by the legacy pass
// registration below.
char llvm::GISelKnownBitsAnalysis::ID = 0;

// Register the analysis under the command-line name DEBUG_TYPE
// ("gisel-known-bits"). NOTE(review): per PassSupport.h the two trailing
// booleans are presumably (CFG-only = false, is-analysis = true) — confirm.
INITIALIZE_PASS_BEGIN(GISelKnownBitsAnalysis, DEBUG_TYPE,
                      "Analysis for ComputingKnownBits", false, true)
INITIALIZE_PASS_END(GISelKnownBitsAnalysis, DEBUG_TYPE,
                    "Analysis for ComputingKnownBits", false, true)
31
32GISelKnownBits::GISelKnownBits(MachineFunction &MF)
33    : MF(MF), MRI(MF.getRegInfo()), TL(*MF.getSubtarget().getTargetLowering()),
34      DL(MF.getFunction().getParent()->getDataLayout()) {}
35
36Align GISelKnownBits::inferAlignmentForFrameIdx(int FrameIdx, int Offset,
37                                                const MachineFunction &MF) {
38  const MachineFrameInfo &MFI = MF.getFrameInfo();
39  return commonAlignment(Align(MFI.getObjectAlignment(FrameIdx)), Offset);
40  // TODO: How to handle cases with Base + Offset?
41}
42
43MaybeAlign GISelKnownBits::inferPtrAlignment(const MachineInstr &MI) {
44  if (MI.getOpcode() == TargetOpcode::G_FRAME_INDEX) {
45    int FrameIdx = MI.getOperand(1).getIndex();
46    return inferAlignmentForFrameIdx(FrameIdx, 0, *MI.getMF());
47  }
48  return None;
49}
50
51void GISelKnownBits::computeKnownBitsForFrameIndex(Register R, KnownBits &Known,
52                                                   const APInt &DemandedElts,
53                                                   unsigned Depth) {
54  const MachineInstr &MI = *MRI.getVRegDef(R);
55  computeKnownBitsForAlignment(Known, inferPtrAlignment(MI));
56}
57
58void GISelKnownBits::computeKnownBitsForAlignment(KnownBits &Known,
59                                                  MaybeAlign Alignment) {
60  if (Alignment)
61    // The low bits are known zero if the pointer is aligned.
62    Known.Zero.setLowBits(Log2(Alignment));
63}
64
65KnownBits GISelKnownBits::getKnownBits(MachineInstr &MI) {
66  return getKnownBits(MI.getOperand(0).getReg());
67}
68
69KnownBits GISelKnownBits::getKnownBits(Register R) {
70  KnownBits Known;
71  LLT Ty = MRI.getType(R);
72  APInt DemandedElts =
73      Ty.isVector() ? APInt::getAllOnesValue(Ty.getNumElements()) : APInt(1, 1);
74  computeKnownBitsImpl(R, Known, DemandedElts);
75  return Known;
76}
77
78bool GISelKnownBits::signBitIsZero(Register R) {
79  LLT Ty = MRI.getType(R);
80  unsigned BitWidth = Ty.getScalarSizeInBits();
81  return maskedValueIsZero(R, APInt::getSignMask(BitWidth));
82}
83
84APInt GISelKnownBits::getKnownZeroes(Register R) {
85  return getKnownBits(R).Zero;
86}
87
88APInt GISelKnownBits::getKnownOnes(Register R) { return getKnownBits(R).One; }
89
90void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
91                                          const APInt &DemandedElts,
92                                          unsigned Depth) {
93  MachineInstr &MI = *MRI.getVRegDef(R);
94  unsigned Opcode = MI.getOpcode();
95  LLT DstTy = MRI.getType(R);
96
97  // Handle the case where this is called on a register that does not have a
98  // type constraint (i.e. it has a register class constraint instead). This is
99  // unlikely to occur except by looking through copies but it is possible for
100  // the initial register being queried to be in this state.
101  if (!DstTy.isValid()) {
102    Known = KnownBits();
103    return;
104  }
105
106  unsigned BitWidth = DstTy.getSizeInBits();
107  Known = KnownBits(BitWidth); // Don't know anything
108
109  if (DstTy.isVector())
110    return; // TODO: Handle vectors.
111
112  if (Depth == getMaxDepth())
113    return;
114
115  if (!DemandedElts)
116    return; // No demanded elts, better to assume we don't know anything.
117
118  KnownBits Known2;
119
120  switch (Opcode) {
121  default:
122    TL.computeKnownBitsForTargetInstr(*this, R, Known, DemandedElts, MRI,
123                                      Depth);
124    break;
125  case TargetOpcode::COPY: {
126    MachineOperand Dst = MI.getOperand(0);
127    MachineOperand Src = MI.getOperand(1);
128    // Look through trivial copies but don't look through trivial copies of the
129    // form `%1:(s32) = OP %0:gpr32` known-bits analysis is currently unable to
130    // determine the bit width of a register class.
131    //
132    // We can't use NoSubRegister by name as it's defined by each target but
133    // it's always defined to be 0 by tablegen.
134    if (Dst.getSubReg() == 0 /*NoSubRegister*/ && Src.getReg().isVirtual() &&
135        Src.getSubReg() == 0 /*NoSubRegister*/ &&
136        MRI.getType(Src.getReg()).isValid()) {
137      // Don't increment Depth for this one since we didn't do any work.
138      computeKnownBitsImpl(Src.getReg(), Known, DemandedElts, Depth);
139    }
140    break;
141  }
142  case TargetOpcode::G_CONSTANT: {
143    auto CstVal = getConstantVRegVal(R, MRI);
144    if (!CstVal)
145      break;
146    Known.One = *CstVal;
147    Known.Zero = ~Known.One;
148    break;
149  }
150  case TargetOpcode::G_FRAME_INDEX: {
151    computeKnownBitsForFrameIndex(R, Known, DemandedElts);
152    break;
153  }
154  case TargetOpcode::G_SUB: {
155    // If low bits are known to be zero in both operands, then we know they are
156    // going to be 0 in the result. Both addition and complement operations
157    // preserve the low zero bits.
158    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
159                         Depth + 1);
160    unsigned KnownZeroLow = Known2.countMinTrailingZeros();
161    if (KnownZeroLow == 0)
162      break;
163    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
164                         Depth + 1);
165    KnownZeroLow = std::min(KnownZeroLow, Known2.countMinTrailingZeros());
166    Known.Zero.setLowBits(KnownZeroLow);
167    break;
168  }
169  case TargetOpcode::G_XOR: {
170    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
171                         Depth + 1);
172    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
173                         Depth + 1);
174
175    // Output known-0 bits are known if clear or set in both the LHS & RHS.
176    APInt KnownZeroOut = (Known.Zero & Known2.Zero) | (Known.One & Known2.One);
177    // Output known-1 are known to be set if set in only one of the LHS, RHS.
178    Known.One = (Known.Zero & Known2.One) | (Known.One & Known2.Zero);
179    Known.Zero = KnownZeroOut;
180    break;
181  }
182  case TargetOpcode::G_PTR_ADD: {
183    // G_PTR_ADD is like G_ADD. FIXME: Is this true for all targets?
184    LLT Ty = MRI.getType(MI.getOperand(1).getReg());
185    if (DL.isNonIntegralAddressSpace(Ty.getAddressSpace()))
186      break;
187    LLVM_FALLTHROUGH;
188  }
189  case TargetOpcode::G_ADD: {
190    // Output known-0 bits are known if clear or set in both the low clear bits
191    // common to both LHS & RHS.  For example, 8+(X<<3) is known to have the
192    // low 3 bits clear.
193    // Output known-0 bits are also known if the top bits of each input are
194    // known to be clear. For example, if one input has the top 10 bits clear
195    // and the other has the top 8 bits clear, we know the top 7 bits of the
196    // output must be clear.
197    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
198                         Depth + 1);
199    unsigned KnownZeroHigh = Known2.countMinLeadingZeros();
200    unsigned KnownZeroLow = Known2.countMinTrailingZeros();
201    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
202                         Depth + 1);
203    KnownZeroHigh = std::min(KnownZeroHigh, Known2.countMinLeadingZeros());
204    KnownZeroLow = std::min(KnownZeroLow, Known2.countMinTrailingZeros());
205    Known.Zero.setLowBits(KnownZeroLow);
206    if (KnownZeroHigh > 1)
207      Known.Zero.setHighBits(KnownZeroHigh - 1);
208    break;
209  }
210  case TargetOpcode::G_AND: {
211    // If either the LHS or the RHS are Zero, the result is zero.
212    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
213                         Depth + 1);
214    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
215                         Depth + 1);
216
217    // Output known-1 bits are only known if set in both the LHS & RHS.
218    Known.One &= Known2.One;
219    // Output known-0 are known to be clear if zero in either the LHS | RHS.
220    Known.Zero |= Known2.Zero;
221    break;
222  }
223  case TargetOpcode::G_OR: {
224    // If either the LHS or the RHS are Zero, the result is zero.
225    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
226                         Depth + 1);
227    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
228                         Depth + 1);
229
230    // Output known-0 bits are only known if clear in both the LHS & RHS.
231    Known.Zero &= Known2.Zero;
232    // Output known-1 are known to be set if set in either the LHS | RHS.
233    Known.One |= Known2.One;
234    break;
235  }
236  case TargetOpcode::G_MUL: {
237    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
238                         Depth + 1);
239    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
240                         Depth + 1);
241    // If low bits are zero in either operand, output low known-0 bits.
242    // Also compute a conservative estimate for high known-0 bits.
243    // More trickiness is possible, but this is sufficient for the
244    // interesting case of alignment computation.
245    unsigned TrailZ =
246        Known.countMinTrailingZeros() + Known2.countMinTrailingZeros();
247    unsigned LeadZ =
248        std::max(Known.countMinLeadingZeros() + Known2.countMinLeadingZeros(),
249                 BitWidth) -
250        BitWidth;
251
252    Known.resetAll();
253    Known.Zero.setLowBits(std::min(TrailZ, BitWidth));
254    Known.Zero.setHighBits(std::min(LeadZ, BitWidth));
255    break;
256  }
257  case TargetOpcode::G_SELECT: {
258    computeKnownBitsImpl(MI.getOperand(3).getReg(), Known, DemandedElts,
259                         Depth + 1);
260    // If we don't know any bits, early out.
261    if (Known.isUnknown())
262      break;
263    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
264                         Depth + 1);
265    // Only known if known in both the LHS and RHS.
266    Known.One &= Known2.One;
267    Known.Zero &= Known2.Zero;
268    break;
269  }
270  case TargetOpcode::G_FCMP:
271  case TargetOpcode::G_ICMP: {
272    if (TL.getBooleanContents(DstTy.isVector(),
273                              Opcode == TargetOpcode::G_FCMP) ==
274            TargetLowering::ZeroOrOneBooleanContent &&
275        BitWidth > 1)
276      Known.Zero.setBitsFrom(1);
277    break;
278  }
279  case TargetOpcode::G_SEXT: {
280    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
281                         Depth + 1);
282    // If the sign bit is known to be zero or one, then sext will extend
283    // it to the top bits, else it will just zext.
284    Known = Known.sext(BitWidth);
285    break;
286  }
287  case TargetOpcode::G_ANYEXT: {
288    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
289                         Depth + 1);
290    Known = Known.zext(BitWidth, true /* ExtendedBitsAreKnownZero */);
291    break;
292  }
293  case TargetOpcode::G_LOAD: {
294    if (MI.hasOneMemOperand()) {
295      const MachineMemOperand *MMO = *MI.memoperands_begin();
296      if (const MDNode *Ranges = MMO->getRanges()) {
297        computeKnownBitsFromRangeMetadata(*Ranges, Known);
298      }
299    }
300    break;
301  }
302  case TargetOpcode::G_ZEXTLOAD: {
303    // Everything above the retrieved bits is zero
304    if (MI.hasOneMemOperand())
305      Known.Zero.setBitsFrom((*MI.memoperands_begin())->getSizeInBits());
306    break;
307  }
308  case TargetOpcode::G_ASHR:
309  case TargetOpcode::G_LSHR:
310  case TargetOpcode::G_SHL: {
311    KnownBits RHSKnown;
312    computeKnownBitsImpl(MI.getOperand(2).getReg(), RHSKnown, DemandedElts,
313                         Depth + 1);
314    if (!RHSKnown.isConstant()) {
315      LLVM_DEBUG(
316          MachineInstr *RHSMI = MRI.getVRegDef(MI.getOperand(2).getReg());
317          dbgs() << '[' << Depth << "] Shift not known constant: " << *RHSMI);
318      break;
319    }
320    uint64_t Shift = RHSKnown.getConstant().getZExtValue();
321    LLVM_DEBUG(dbgs() << '[' << Depth << "] Shift is " << Shift << '\n');
322
323    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
324                         Depth + 1);
325
326    switch (Opcode) {
327    case TargetOpcode::G_ASHR:
328      Known.Zero = Known.Zero.ashr(Shift);
329      Known.One = Known.One.ashr(Shift);
330      break;
331    case TargetOpcode::G_LSHR:
332      Known.Zero = Known.Zero.lshr(Shift);
333      Known.One = Known.One.lshr(Shift);
334      Known.Zero.setBitsFrom(Known.Zero.getBitWidth() - Shift);
335      break;
336    case TargetOpcode::G_SHL:
337      Known.Zero = Known.Zero.shl(Shift);
338      Known.One = Known.One.shl(Shift);
339      Known.Zero.setBits(0, Shift);
340      break;
341    }
342    break;
343  }
344  case TargetOpcode::G_INTTOPTR:
345  case TargetOpcode::G_PTRTOINT:
346    // Fall through and handle them the same as zext/trunc.
347    LLVM_FALLTHROUGH;
348  case TargetOpcode::G_ZEXT:
349  case TargetOpcode::G_TRUNC: {
350    Register SrcReg = MI.getOperand(1).getReg();
351    LLT SrcTy = MRI.getType(SrcReg);
352    unsigned SrcBitWidth = SrcTy.isPointer()
353                               ? DL.getIndexSizeInBits(SrcTy.getAddressSpace())
354                               : SrcTy.getSizeInBits();
355    assert(SrcBitWidth && "SrcBitWidth can't be zero");
356    Known = Known.zextOrTrunc(SrcBitWidth, true);
357    computeKnownBitsImpl(SrcReg, Known, DemandedElts, Depth + 1);
358    Known = Known.zextOrTrunc(BitWidth, true);
359    if (BitWidth > SrcBitWidth)
360      Known.Zero.setBitsFrom(SrcBitWidth);
361    break;
362  }
363  }
364
365  assert(!Known.hasConflict() && "Bits known to be one AND zero?");
366  LLVM_DEBUG(dbgs() << "[" << Depth << "] Compute known bits: " << MI << "["
367                    << Depth << "] Computed for: " << MI << "[" << Depth
368                    << "] Known: 0x"
369                    << (Known.Zero | Known.One).toString(16, false) << "\n"
370                    << "[" << Depth << "] Zero: 0x"
371                    << Known.Zero.toString(16, false) << "\n"
372                    << "[" << Depth << "] One:  0x"
373                    << Known.One.toString(16, false) << "\n");
374}
375
376unsigned GISelKnownBits::computeNumSignBits(Register R,
377                                            const APInt &DemandedElts,
378                                            unsigned Depth) {
379  MachineInstr &MI = *MRI.getVRegDef(R);
380  unsigned Opcode = MI.getOpcode();
381
382  if (Opcode == TargetOpcode::G_CONSTANT)
383    return MI.getOperand(1).getCImm()->getValue().getNumSignBits();
384
385  if (Depth == getMaxDepth())
386    return 1;
387
388  if (!DemandedElts)
389    return 1; // No demanded elts, better to assume we don't know anything.
390
391  LLT DstTy = MRI.getType(R);
392
393  // Handle the case where this is called on a register that does not have a
394  // type constraint. This is unlikely to occur except by looking through copies
395  // but it is possible for the initial register being queried to be in this
396  // state.
397  if (!DstTy.isValid())
398    return 1;
399
400  switch (Opcode) {
401  case TargetOpcode::COPY: {
402    MachineOperand &Src = MI.getOperand(1);
403    if (Src.getReg().isVirtual() && Src.getSubReg() == 0 &&
404        MRI.getType(Src.getReg()).isValid()) {
405      // Don't increment Depth for this one since we didn't do any work.
406      return computeNumSignBits(Src.getReg(), DemandedElts, Depth);
407    }
408
409    return 1;
410  }
411  case TargetOpcode::G_SEXT: {
412    Register Src = MI.getOperand(1).getReg();
413    LLT SrcTy = MRI.getType(Src);
414    unsigned Tmp = DstTy.getScalarSizeInBits() - SrcTy.getScalarSizeInBits();
415    return computeNumSignBits(Src, DemandedElts, Depth + 1) + Tmp;
416  }
417  case TargetOpcode::G_TRUNC: {
418    Register Src = MI.getOperand(1).getReg();
419    LLT SrcTy = MRI.getType(Src);
420
421    // Check if the sign bits of source go down as far as the truncated value.
422    unsigned DstTyBits = DstTy.getScalarSizeInBits();
423    unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
424    unsigned NumSrcSignBits = computeNumSignBits(Src, DemandedElts, Depth + 1);
425    if (NumSrcSignBits > (NumSrcBits - DstTyBits))
426      return NumSrcSignBits - (NumSrcBits - DstTyBits);
427    break;
428  }
429  default:
430    break;
431  }
432
433  // TODO: Handle target instructions
434  // TODO: Fall back to known bits
435  return 1;
436}
437
438unsigned GISelKnownBits::computeNumSignBits(Register R, unsigned Depth) {
439  LLT Ty = MRI.getType(R);
440  APInt DemandedElts = Ty.isVector()
441                           ? APInt::getAllOnesValue(Ty.getNumElements())
442                           : APInt(1, 1);
443  return computeNumSignBits(R, DemandedElts, Depth);
444}
445
/// Declares that running this analysis pass preserves all other analyses,
/// then defers to the MachineFunctionPass base implementation.
void GISelKnownBitsAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  MachineFunctionPass::getAnalysisUsage(AU);
}
450
/// Does no work when the pass is run. Returns false, i.e. reports that \p MF
/// was not modified. NOTE(review): known-bits results appear to be computed
/// on demand through GISelKnownBits queries rather than here — confirm
/// against the header.
bool GISelKnownBitsAnalysis::runOnMachineFunction(MachineFunction &MF) {
  return false;
}
454