//===- ScalarEvolutionDivision.cpp - Divide SCEV expressions --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements SCEVDivision, which divides one SCEV expression by
// another, producing a quotient and a remainder.
10//
11//===----------------------------------------------------------------------===//
12
13#include "llvm/Analysis/ScalarEvolutionDivision.h"
14#include "llvm/ADT/APInt.h"
15#include "llvm/ADT/DenseMap.h"
16#include "llvm/ADT/SmallVector.h"
17#include "llvm/Analysis/ScalarEvolution.h"
18#include "llvm/Support/Casting.h"
19#include <cassert>
20#include <cstdint>
21
22namespace llvm {
23class Type;
24}
25
26using namespace llvm;
27
28namespace {
29
30static inline int sizeOfSCEV(const SCEV *S) {
31  struct FindSCEVSize {
32    int Size = 0;
33
34    FindSCEVSize() = default;
35
36    bool follow(const SCEV *S) {
37      ++Size;
38      // Keep looking at all operands of S.
39      return true;
40    }
41
42    bool isDone() const { return false; }
43  };
44
45  FindSCEVSize F;
46  SCEVTraversal<FindSCEVSize> ST(F);
47  ST.visitAll(S);
48  return F.Size;
49}
50
51} // namespace
52
53// Computes the Quotient and Remainder of the division of Numerator by
54// Denominator.
55void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
56                          const SCEV *Denominator, const SCEV **Quotient,
57                          const SCEV **Remainder) {
58  assert(Numerator && Denominator && "Uninitialized SCEV");
59
60  SCEVDivision D(SE, Numerator, Denominator);
61
62  // Check for the trivial case here to avoid having to check for it in the
63  // rest of the code.
64  if (Numerator == Denominator) {
65    *Quotient = D.One;
66    *Remainder = D.Zero;
67    return;
68  }
69
70  if (Numerator->isZero()) {
71    *Quotient = D.Zero;
72    *Remainder = D.Zero;
73    return;
74  }
75
76  // A simple case when N/1. The quotient is N.
77  if (Denominator->isOne()) {
78    *Quotient = Numerator;
79    *Remainder = D.Zero;
80    return;
81  }
82
83  // Split the Denominator when it is a product.
84  if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
85    const SCEV *Q, *R;
86    *Quotient = Numerator;
87    for (const SCEV *Op : T->operands()) {
88      divide(SE, *Quotient, Op, &Q, &R);
89      *Quotient = Q;
90
91      // Bail out when the Numerator is not divisible by one of the terms of
92      // the Denominator.
93      if (!R->isZero()) {
94        *Quotient = D.Zero;
95        *Remainder = Numerator;
96        return;
97      }
98    }
99    *Remainder = D.Zero;
100    return;
101  }
102
103  D.visit(Numerator);
104  *Quotient = D.Quotient;
105  *Remainder = D.Remainder;
106}
107
108void SCEVDivision::visitConstant(const SCEVConstant *Numerator) {
109  if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
110    APInt NumeratorVal = Numerator->getAPInt();
111    APInt DenominatorVal = D->getAPInt();
112    uint32_t NumeratorBW = NumeratorVal.getBitWidth();
113    uint32_t DenominatorBW = DenominatorVal.getBitWidth();
114
115    if (NumeratorBW > DenominatorBW)
116      DenominatorVal = DenominatorVal.sext(NumeratorBW);
117    else if (NumeratorBW < DenominatorBW)
118      NumeratorVal = NumeratorVal.sext(DenominatorBW);
119
120    APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
121    APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
122    APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
123    Quotient = SE.getConstant(QuotientVal);
124    Remainder = SE.getConstant(RemainderVal);
125    return;
126  }
127}
128
129void SCEVDivision::visitVScale(const SCEVVScale *Numerator) {
130  return cannotDivide(Numerator);
131}
132
133void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
134  const SCEV *StartQ, *StartR, *StepQ, *StepR;
135  if (!Numerator->isAffine())
136    return cannotDivide(Numerator);
137  divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
138  divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
139  // Bail out if the types do not match.
140  Type *Ty = Denominator->getType();
141  if (Ty != StartQ->getType() || Ty != StartR->getType() ||
142      Ty != StepQ->getType() || Ty != StepR->getType())
143    return cannotDivide(Numerator);
144  Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
145                              Numerator->getNoWrapFlags());
146  Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
147                               Numerator->getNoWrapFlags());
148}
149
150void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) {
151  SmallVector<const SCEV *, 2> Qs, Rs;
152  Type *Ty = Denominator->getType();
153
154  for (const SCEV *Op : Numerator->operands()) {
155    const SCEV *Q, *R;
156    divide(SE, Op, Denominator, &Q, &R);
157
158    // Bail out if types do not match.
159    if (Ty != Q->getType() || Ty != R->getType())
160      return cannotDivide(Numerator);
161
162    Qs.push_back(Q);
163    Rs.push_back(R);
164  }
165
166  if (Qs.size() == 1) {
167    Quotient = Qs[0];
168    Remainder = Rs[0];
169    return;
170  }
171
172  Quotient = SE.getAddExpr(Qs);
173  Remainder = SE.getAddExpr(Rs);
174}
175
176void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
177  SmallVector<const SCEV *, 2> Qs;
178  Type *Ty = Denominator->getType();
179
180  bool FoundDenominatorTerm = false;
181  for (const SCEV *Op : Numerator->operands()) {
182    // Bail out if types do not match.
183    if (Ty != Op->getType())
184      return cannotDivide(Numerator);
185
186    if (FoundDenominatorTerm) {
187      Qs.push_back(Op);
188      continue;
189    }
190
191    // Check whether Denominator divides one of the product operands.
192    const SCEV *Q, *R;
193    divide(SE, Op, Denominator, &Q, &R);
194    if (!R->isZero()) {
195      Qs.push_back(Op);
196      continue;
197    }
198
199    // Bail out if types do not match.
200    if (Ty != Q->getType())
201      return cannotDivide(Numerator);
202
203    FoundDenominatorTerm = true;
204    Qs.push_back(Q);
205  }
206
207  if (FoundDenominatorTerm) {
208    Remainder = Zero;
209    if (Qs.size() == 1)
210      Quotient = Qs[0];
211    else
212      Quotient = SE.getMulExpr(Qs);
213    return;
214  }
215
216  if (!isa<SCEVUnknown>(Denominator))
217    return cannotDivide(Numerator);
218
219  // The Remainder is obtained by replacing Denominator by 0 in Numerator.
220  ValueToSCEVMapTy RewriteMap;
221  RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero;
222  Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
223
224  if (Remainder->isZero()) {
225    // The Quotient is obtained by replacing Denominator by 1 in Numerator.
226    RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One;
227    Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
228    return;
229  }
230
231  // Quotient is (Numerator - Remainder) divided by Denominator.
232  const SCEV *Q, *R;
233  const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
234  // This SCEV does not seem to simplify: fail the division here.
235  if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
236    return cannotDivide(Numerator);
237  divide(SE, Diff, Denominator, &Q, &R);
238  if (R != Zero)
239    return cannotDivide(Numerator);
240  Quotient = Q;
241}
242
243SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
244                           const SCEV *Denominator)
245    : SE(S), Denominator(Denominator) {
246  Zero = SE.getZero(Denominator->getType());
247  One = SE.getOne(Denominator->getType());
248
249  // We generally do not know how to divide Expr by Denominator. We initialize
250  // the division to a "cannot divide" state to simplify the rest of the code.
251  cannotDivide(Numerator);
252}
253
254// Convenience function for giving up on the division. We set the quotient to
255// be equal to zero and the remainder to be equal to the numerator.
256void SCEVDivision::cannotDivide(const SCEV *Numerator) {
257  Quotient = Zero;
258  Remainder = Numerator;
259}
260