//===- lib/CodeGen/GlobalISel/GISelKnownBits.cpp --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// Provides analysis for querying information about KnownBits during GISel
/// passes.
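///
/// A minimal usage sketch (assuming a pass that has declared a dependency on
/// the analysis; `MF` and `SomeReg` are placeholders):
///   GISelKnownBits &KB = getAnalysis<GISelKnownBitsAnalysis>().get(MF);
///   KnownBits Known = KB.getKnownBits(SomeReg);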
//
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetOpcodes.h"

#define DEBUG_TYPE "gisel-known-bits"

using namespace llvm;

char llvm::GISelKnownBitsAnalysis::ID = 0;

INITIALIZE_PASS(GISelKnownBitsAnalysis, DEBUG_TYPE,
                "Analysis for ComputingKnownBits", false, true)

GISelKnownBits::GISelKnownBits(MachineFunction &MF, unsigned MaxDepth)
    : MF(MF), MRI(MF.getRegInfo()), TL(*MF.getSubtarget().getTargetLowering()),
      DL(MF.getFunction().getParent()->getDataLayout()), MaxDepth(MaxDepth) {}

Align GISelKnownBits::computeKnownAlignment(Register R, unsigned Depth) {
  const MachineInstr *MI = MRI.getVRegDef(R);
  switch (MI->getOpcode()) {
  case TargetOpcode::COPY:
    return computeKnownAlignment(MI->getOperand(1).getReg(), Depth);
  case TargetOpcode::G_FRAME_INDEX: {
    int FrameIdx = MI->getOperand(1).getIndex();
    return MF.getFrameInfo().getObjectAlign(FrameIdx);
  }
  case TargetOpcode::G_INTRINSIC:
  case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
  default:
    return TL.computeKnownAlignForTargetInstr(*this, R, MRI, Depth + 1);
  }
}

KnownBits GISelKnownBits::getKnownBits(MachineInstr &MI) {
  assert(MI.getNumExplicitDefs() == 1 &&
         "expected single return generic instruction");
  return getKnownBits(MI.getOperand(0).getReg());
}

KnownBits GISelKnownBits::getKnownBits(Register R) {
  const LLT Ty = MRI.getType(R);
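  // Scalars are handled as single-element "vectors": demand one element so
  // the scalar and vector queries share the same implementation.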
  APInt DemandedElts =
      Ty.isVector() ? APInt::getAllOnesValue(Ty.getNumElements()) : APInt(1, 1);
  return getKnownBits(R, DemandedElts);
}

KnownBits GISelKnownBits::getKnownBits(Register R, const APInt &DemandedElts,
                                       unsigned Depth) {
  // For now, we only maintain the cache during one request.
  assert(ComputeKnownBitsCache.empty() && "Cache should have been cleared");

  KnownBits Known;
  computeKnownBitsImpl(R, Known, DemandedElts);
  ComputeKnownBitsCache.clear();
  return Known;
}

bool GISelKnownBits::signBitIsZero(Register R) {
  LLT Ty = MRI.getType(R);
  unsigned BitWidth = Ty.getScalarSizeInBits();
  return maskedValueIsZero(R, APInt::getSignMask(BitWidth));
}

APInt GISelKnownBits::getKnownZeroes(Register R) {
  return getKnownBits(R).Zero;
}

APInt GISelKnownBits::getKnownOnes(Register R) { return getKnownBits(R).One; }

LLVM_ATTRIBUTE_UNUSED static void
dumpResult(const MachineInstr &MI, const KnownBits &Known, unsigned Depth) {
  dbgs() << "[" << Depth << "] Compute known bits: " << MI << "[" << Depth
         << "] Computed for: " << MI << "[" << Depth << "] Known: 0x"
         << (Known.Zero | Known.One).toString(16, false) << "\n"
         << "[" << Depth << "] Zero: 0x" << Known.Zero.toString(16, false)
         << "\n"
         << "[" << Depth << "] One:  0x" << Known.One.toString(16, false)
         << "\n";
}

void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
                                          const APInt &DemandedElts,
                                          unsigned Depth) {
  MachineInstr &MI = *MRI.getVRegDef(R);
  unsigned Opcode = MI.getOpcode();
  LLT DstTy = MRI.getType(R);

  // Handle the case where this is called on a register that does not have a
  // type constraint (i.e. it has a register class constraint instead). This is
  // unlikely to occur except by looking through copies but it is possible for
  // the initial register being queried to be in this state.
  if (!DstTy.isValid()) {
    Known = KnownBits();
    return;
  }

  unsigned BitWidth = DstTy.getSizeInBits();
  auto CacheEntry = ComputeKnownBitsCache.find(R);
  if (CacheEntry != ComputeKnownBitsCache.end()) {
    Known = CacheEntry->second;
    LLVM_DEBUG(dbgs() << "Cache hit at ");
    LLVM_DEBUG(dumpResult(MI, Known, Depth));
    assert(Known.getBitWidth() == BitWidth && "Cache entry size doesn't match");
    return;
  }
  Known = KnownBits(BitWidth); // Don't know anything

  if (DstTy.isVector())
    return; // TODO: Handle vectors.

  // Depth may get bigger than max depth if it gets passed to a different
  // GISelKnownBits object.
  // This may happen when say a generic part uses a GISelKnownBits object
  // with some max depth, but then we hit TL.computeKnownBitsForTargetInstr
  // which creates a new GISelKnownBits object with a different and smaller
  // depth. If we just check for equality, we would never exit if the depth
  // that is passed down to the target specific GISelKnownBits object is
  // already bigger than its max depth.
  if (Depth >= getMaxDepth())
    return;

  if (!DemandedElts)
    return; // No demanded elts, better to assume we don't know anything.

  KnownBits Known2;

  switch (Opcode) {
  default:
    TL.computeKnownBitsForTargetInstr(*this, R, Known, DemandedElts, MRI,
                                      Depth);
    break;
  case TargetOpcode::COPY:
  case TargetOpcode::G_PHI:
  case TargetOpcode::PHI: {
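    // Seed with everything known (both One and Zero all set); intersecting
    // with each incoming value below clears any bit the operands disagree on.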
    Known.One = APInt::getAllOnesValue(BitWidth);
    Known.Zero = APInt::getAllOnesValue(BitWidth);
    // Destination registers should not have subregisters at this
    // point of the pipeline, otherwise the main live-range will be
    // defined more than once, which is against SSA.
    assert(MI.getOperand(0).getSubReg() == 0 && "Is this code in SSA?");
    // Record in the cache that we know nothing for MI.
    // This will get updated later and in the meantime, if we reach that
    // phi again, because of a loop, we will cut the search thanks to this
    // cache entry.
    // We could actually build up more information on the phi by not cutting
    // the search, but that additional information is more a side effect
    // than an intended choice.
    // Therefore, for now, save on compile time until we find a proper way
    // to derive known bits for PHIs within loops.
    ComputeKnownBitsCache[R] = KnownBits(BitWidth);
    // PHI operands are a mix of registers and basic blocks, interleaved.
    // We only care about the register ones.
    for (unsigned Idx = 1; Idx < MI.getNumOperands(); Idx += 2) {
      const MachineOperand &Src = MI.getOperand(Idx);
      Register SrcReg = Src.getReg();
      // Look through trivial copies and phis, but don't look through ones of
      // the form `%1:(s32) = OP %0:gpr32`: known-bits analysis is currently
      // unable to determine the bit width of a register class.
      //
      // We can't use NoSubRegister by name as it's defined by each target but
      // it's always defined to be 0 by tablegen.
      if (SrcReg.isVirtual() && Src.getSubReg() == 0 /*NoSubRegister*/ &&
          MRI.getType(SrcReg).isValid()) {
        // For COPYs we don't do anything, don't increase the depth.
        computeKnownBitsImpl(SrcReg, Known2, DemandedElts,
                             Depth + (Opcode != TargetOpcode::COPY));
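        // A bit is known for the PHI only if every incoming value agrees on
        // it, so intersect with what has been accumulated so far.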
        Known.One &= Known2.One;
        Known.Zero &= Known2.Zero;
        // If we reach a point where we don't know anything
        // just stop looking through the operands.
        if (Known.One == 0 && Known.Zero == 0)
          break;
      } else {
        // We know nothing.
        Known = KnownBits(BitWidth);
        break;
      }
    }
    break;
  }
  case TargetOpcode::G_CONSTANT: {
    auto CstVal = getConstantVRegVal(R, MRI);
    if (!CstVal)
      break;
    Known.One = *CstVal;
    Known.Zero = ~Known.One;
    break;
  }
  case TargetOpcode::G_FRAME_INDEX: {
    int FrameIdx = MI.getOperand(1).getIndex();
    TL.computeKnownBitsForFrameIndex(FrameIdx, Known, MF);
    break;
  }
  case TargetOpcode::G_SUB: {
    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
                         Depth + 1);
    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
                         Depth + 1);
    Known = KnownBits::computeForAddSub(/*Add*/ false, /*NSW*/ false, Known,
                                        Known2);
    break;
  }
  case TargetOpcode::G_XOR: {
    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
                         Depth + 1);
    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
                         Depth + 1);

    Known ^= Known2;
    break;
  }
  case TargetOpcode::G_PTR_ADD: {
    // G_PTR_ADD is like G_ADD. FIXME: Is this true for all targets?
    LLT Ty = MRI.getType(MI.getOperand(1).getReg());
    if (DL.isNonIntegralAddressSpace(Ty.getAddressSpace()))
      break;
    LLVM_FALLTHROUGH;
  }
  case TargetOpcode::G_ADD: {
    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
                         Depth + 1);
    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
                         Depth + 1);
    Known =
        KnownBits::computeForAddSub(/*Add*/ true, /*NSW*/ false, Known, Known2);
    break;
  }
  case TargetOpcode::G_AND: {
    // If either the LHS or the RHS are Zero, the result is zero.
    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
                         Depth + 1);
    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
                         Depth + 1);

    Known &= Known2;
    break;
  }
  case TargetOpcode::G_OR: {
    // A bit is known one if it is one in either operand, and known zero
    // only if it is zero in both.
    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
                         Depth + 1);
    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
                         Depth + 1);

    Known |= Known2;
    break;
  }
  case TargetOpcode::G_MUL: {
    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
                         Depth + 1);
    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
                         Depth + 1);
    // If low bits are zero in either operand, output low known-0 bits.
    // Also compute a conservative estimate for high known-0 bits.
    // More trickiness is possible, but this is sufficient for the
    // interesting case of alignment computation.
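    // For example, with BitWidth = 8, operands each known to have four
    // leading zeros (values below 16) give a product below 256, so
    // max(4 + 4, 8) - 8 = 0 high bits are known zero; operands with six
    // leading zeros each (values below 4) give max(6 + 6, 8) - 8 = 4
    // known-zero high bits. Trailing zero counts simply add.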
    unsigned TrailZ =
        Known.countMinTrailingZeros() + Known2.countMinTrailingZeros();
    unsigned LeadZ =
        std::max(Known.countMinLeadingZeros() + Known2.countMinLeadingZeros(),
                 BitWidth) -
        BitWidth;

    Known.resetAll();
    Known.Zero.setLowBits(std::min(TrailZ, BitWidth));
    Known.Zero.setHighBits(std::min(LeadZ, BitWidth));
    break;
  }
  case TargetOpcode::G_SELECT: {
    computeKnownBitsImpl(MI.getOperand(3).getReg(), Known, DemandedElts,
                         Depth + 1);
    // If we don't know any bits, early out.
    if (Known.isUnknown())
      break;
    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
                         Depth + 1);
    // Only known if known in both the LHS and RHS.
    Known.One &= Known2.One;
    Known.Zero &= Known2.Zero;
    break;
  }
  case TargetOpcode::G_FCMP:
  case TargetOpcode::G_ICMP: {
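    // If the target produces 0/1 booleans for this comparison, every bit
    // above bit 0 of the result is known zero.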
    if (TL.getBooleanContents(DstTy.isVector(),
                              Opcode == TargetOpcode::G_FCMP) ==
            TargetLowering::ZeroOrOneBooleanContent &&
        BitWidth > 1)
      Known.Zero.setBitsFrom(1);
    break;
  }
  case TargetOpcode::G_SEXT: {
    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
                         Depth + 1);
    // If the sign bit is known to be zero or one, then sext will extend
    // it to the top bits, else it will just zext.
    Known = Known.sext(BitWidth);
    break;
  }
  case TargetOpcode::G_ANYEXT: {
    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
                         Depth + 1);
    // The extended bits have undefined values, so nothing is known about
    // them; using zext here would wrongly mark them as known zero.
    Known = Known.anyext(BitWidth);
    break;
  }
  case TargetOpcode::G_LOAD: {
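    // A load with !range metadata constrains the loaded value; fold that
    // range information into the known bits.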
    if (MI.hasOneMemOperand()) {
      const MachineMemOperand *MMO = *MI.memoperands_begin();
      if (const MDNode *Ranges = MMO->getRanges()) {
        computeKnownBitsFromRangeMetadata(*Ranges, Known);
      }
    }
    break;
  }
  case TargetOpcode::G_ZEXTLOAD: {
    // Everything above the retrieved bits is zero.
    if (MI.hasOneMemOperand())
      Known.Zero.setBitsFrom((*MI.memoperands_begin())->getSizeInBits());
    break;
  }
  case TargetOpcode::G_ASHR:
  case TargetOpcode::G_LSHR:
  case TargetOpcode::G_SHL: {
    KnownBits RHSKnown;
    computeKnownBitsImpl(MI.getOperand(2).getReg(), RHSKnown, DemandedElts,
                         Depth + 1);
    if (!RHSKnown.isConstant()) {
      LLVM_DEBUG(
          MachineInstr *RHSMI = MRI.getVRegDef(MI.getOperand(2).getReg());
          dbgs() << '[' << Depth << "] Shift not known constant: " << *RHSMI);
      break;
    }
    uint64_t Shift = RHSKnown.getConstant().getZExtValue();
    LLVM_DEBUG(dbgs() << '[' << Depth << "] Shift is " << Shift << '\n');

    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
                         Depth + 1);
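    // Known.Zero and Known.One shift in lockstep with the value. For G_SHL
    // the vacated low bits are known zero; for G_LSHR the vacated high bits
    // are known zero; for G_ASHR the high bits replicate whatever is known
    // about the original sign bit.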
    switch (Opcode) {
    case TargetOpcode::G_ASHR:
      Known.Zero = Known.Zero.ashr(Shift);
      Known.One = Known.One.ashr(Shift);
      break;
    case TargetOpcode::G_LSHR:
      Known.Zero = Known.Zero.lshr(Shift);
      Known.One = Known.One.lshr(Shift);
      Known.Zero.setBitsFrom(Known.Zero.getBitWidth() - Shift);
      break;
    case TargetOpcode::G_SHL:
      Known.Zero = Known.Zero.shl(Shift);
      Known.One = Known.One.shl(Shift);
      Known.Zero.setBits(0, Shift);
      break;
    }
    break;
  }
  case TargetOpcode::G_INTTOPTR:
  case TargetOpcode::G_PTRTOINT:
    // Fall through and handle them the same as zext/trunc.
    LLVM_FALLTHROUGH;
  case TargetOpcode::G_ZEXT:
  case TargetOpcode::G_TRUNC: {
    Register SrcReg = MI.getOperand(1).getReg();
    LLT SrcTy = MRI.getType(SrcReg);
    unsigned SrcBitWidth = SrcTy.isPointer()
                               ? DL.getIndexSizeInBits(SrcTy.getAddressSpace())
                               : SrcTy.getSizeInBits();
    assert(SrcBitWidth && "SrcBitWidth can't be zero");
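    // Resize the query to the source width, compute, then resize back to the
    // destination width; for a zext the bits created on top are then pinned
    // to zero below.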
    Known = Known.zextOrTrunc(SrcBitWidth);
    computeKnownBitsImpl(SrcReg, Known, DemandedElts, Depth + 1);
    Known = Known.zextOrTrunc(BitWidth);
    if (BitWidth > SrcBitWidth)
      Known.Zero.setBitsFrom(SrcBitWidth);
    break;
  }
  }

  assert(!Known.hasConflict() && "Bits known to be one AND zero?");
  LLVM_DEBUG(dumpResult(MI, Known, Depth));

  // Update the cache.
  ComputeKnownBitsCache[R] = Known;
}

unsigned GISelKnownBits::computeNumSignBits(Register R,
                                            const APInt &DemandedElts,
                                            unsigned Depth) {
  MachineInstr &MI = *MRI.getVRegDef(R);
  unsigned Opcode = MI.getOpcode();

  if (Opcode == TargetOpcode::G_CONSTANT)
    return MI.getOperand(1).getCImm()->getValue().getNumSignBits();

  // Use >= rather than == so we still terminate when a target-created
  // GISelKnownBits instance has a smaller max depth than the Depth we were
  // handed (see the matching check in computeKnownBitsImpl).
  if (Depth >= getMaxDepth())
    return 1;

  if (!DemandedElts)
    return 1; // No demanded elts, better to assume we don't know anything.

  LLT DstTy = MRI.getType(R);
  const unsigned TyBits = DstTy.getScalarSizeInBits();

  // Handle the case where this is called on a register that does not have a
  // type constraint. This is unlikely to occur except by looking through copies
  // but it is possible for the initial register being queried to be in this
  // state.
  if (!DstTy.isValid())
    return 1;

  unsigned FirstAnswer = 1;
  switch (Opcode) {
  case TargetOpcode::COPY: {
    MachineOperand &Src = MI.getOperand(1);
    if (Src.getReg().isVirtual() && Src.getSubReg() == 0 &&
        MRI.getType(Src.getReg()).isValid()) {
      // Don't increment Depth for this one since we didn't do any work.
      return computeNumSignBits(Src.getReg(), DemandedElts, Depth);
    }

    return 1;
  }
  case TargetOpcode::G_SEXT: {
    Register Src = MI.getOperand(1).getReg();
    LLT SrcTy = MRI.getType(Src);
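    // A sign extension from SrcTy to DstTy adds exactly DstSize - SrcSize
    // additional copies of the sign bit on top of whatever the source
    // already provides.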
    unsigned Tmp = DstTy.getScalarSizeInBits() - SrcTy.getScalarSizeInBits();
    return computeNumSignBits(Src, DemandedElts, Depth + 1) + Tmp;
  }
  case TargetOpcode::G_SEXTLOAD: {
    Register Dst = MI.getOperand(0).getReg();
    LLT Ty = MRI.getType(Dst);
    // TODO: add vector support
    if (Ty.isVector())
      break;
    // e.g. a 16-bit value sign-extend-loaded into 32 bits has 17 known sign
    // bits: bits 31..15 all equal bit 15.
    if (MI.hasOneMemOperand())
      return Ty.getSizeInBits() - (*MI.memoperands_begin())->getSizeInBits() +
             1;
    break;
  }
  case TargetOpcode::G_TRUNC: {
    Register Src = MI.getOperand(1).getReg();
    LLT SrcTy = MRI.getType(Src);

    // Check if the sign bits of source go down as far as the truncated value.
    unsigned DstTyBits = DstTy.getScalarSizeInBits();
    unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
    unsigned NumSrcSignBits = computeNumSignBits(Src, DemandedElts, Depth + 1);
    if (NumSrcSignBits > (NumSrcBits - DstTyBits))
      return NumSrcSignBits - (NumSrcBits - DstTyBits);
    break;
  }
  case TargetOpcode::G_INTRINSIC:
  case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
  default: {
    unsigned NumBits =
      TL.computeNumSignBitsForTargetInstr(*this, R, DemandedElts, MRI, Depth);
    if (NumBits > 1)
      FirstAnswer = std::max(FirstAnswer, NumBits);
    break;
  }
  }

  // Finally, if we can prove that the top bits of the result are 0's or 1's,
  // use this information.
  KnownBits Known = getKnownBits(R, DemandedElts, Depth);
  APInt Mask;
  if (Known.isNonNegative()) { // sign bit is 0
    Mask = Known.Zero;
  } else if (Known.isNegative()) { // sign bit is 1
    Mask = Known.One;
  } else {
    // Nothing known.
    return FirstAnswer;
  }

  // Okay, we know that the sign bit in Mask is set.  Use CLO to determine
  // the number of identical bits in the top of the input value.
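  // For example, with TyBits = 8 and Known.Zero = 0b11110000 (value known
  // non-negative and below 16), the shifted mask has four leading ones, so
  // at least four sign bits are known.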
  Mask <<= Mask.getBitWidth() - TyBits;
  return std::max(FirstAnswer, Mask.countLeadingOnes());
}

unsigned GISelKnownBits::computeNumSignBits(Register R, unsigned Depth) {
  LLT Ty = MRI.getType(R);
  APInt DemandedElts = Ty.isVector()
                           ? APInt::getAllOnesValue(Ty.getNumElements())
                           : APInt(1, 1);
  return computeNumSignBits(R, DemandedElts, Depth);
}

void GISelKnownBitsAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  MachineFunctionPass::getAnalysisUsage(AU);
}

bool GISelKnownBitsAnalysis::runOnMachineFunction(MachineFunction &MF) {
  return false;
}