1//===- InstructionCost.h ----------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8/// \file
/// This file defines an InstructionCost class that is used when calculating
/// the cost of an instruction, or a group of instructions. In addition to a
/// numeric value representing the cost, the class also contains a state that
/// can be used to encode particular properties, such as a cost being invalid.
/// Operations on InstructionCost implement saturation arithmetic, so that
/// accumulating costs on large cost values does not overflow.
15///
16//===----------------------------------------------------------------------===//
17
18#ifndef LLVM_SUPPORT_INSTRUCTIONCOST_H
19#define LLVM_SUPPORT_INSTRUCTIONCOST_H
20
21#include "llvm/Support/MathExtras.h"
22#include <limits>
23#include <optional>
24
25namespace llvm {
26
27class raw_ostream;
28
29class InstructionCost {
30public:
31  using CostType = int64_t;
32
33  /// CostState describes the state of a cost.
34  enum CostState {
35    Valid,  /// < The cost value represents a valid cost, even when the
36            /// cost-value is large.
37    Invalid /// < Invalid indicates there is no way to represent the cost as a
38            /// numeric value. This state exists to represent a possible issue,
39            /// e.g. if the cost-model knows the operation cannot be expanded
40            /// into a valid code-sequence by the code-generator.  While some
41            /// passes may assert that the calculated cost must be valid, it is
42            /// up to individual passes how to interpret an Invalid cost. For
43            /// example, a transformation pass could choose not to perform a
44            /// transformation if the resulting cost would end up Invalid.
45            /// Because some passes may assert a cost is Valid, it is not
46            /// recommended to use Invalid costs to model 'Unknown'.
47            /// Note that Invalid is semantically different from a (very) high,
48            /// but valid cost, which intentionally indicates no issue, but
49            /// rather a strong preference not to select a certain operation.
50  };
51
52private:
53  CostType Value = 0;
54  CostState State = Valid;
55
56  void propagateState(const InstructionCost &RHS) {
57    if (RHS.State == Invalid)
58      State = Invalid;
59  }
60
61  static CostType getMaxValue() { return std::numeric_limits<CostType>::max(); }
62  static CostType getMinValue() { return std::numeric_limits<CostType>::min(); }
63
64public:
65  // A default constructed InstructionCost is a valid zero cost
66  InstructionCost() = default;
67
68  InstructionCost(CostState) = delete;
69  InstructionCost(CostType Val) : Value(Val), State(Valid) {}
70
71  static InstructionCost getMax() { return getMaxValue(); }
72  static InstructionCost getMin() { return getMinValue(); }
73  static InstructionCost getInvalid(CostType Val = 0) {
74    InstructionCost Tmp(Val);
75    Tmp.setInvalid();
76    return Tmp;
77  }
78
79  bool isValid() const { return State == Valid; }
80  void setValid() { State = Valid; }
81  void setInvalid() { State = Invalid; }
82  CostState getState() const { return State; }
83
84  /// This function is intended to be used as sparingly as possible, since the
85  /// class provides the full range of operator support required for arithmetic
86  /// and comparisons.
87  std::optional<CostType> getValue() const {
88    if (isValid())
89      return Value;
90    return std::nullopt;
91  }
92
93  /// For all of the arithmetic operators provided here any invalid state is
94  /// perpetuated and cannot be removed. Once a cost becomes invalid it stays
95  /// invalid, and it also inherits any invalid state from the RHS.
96  /// Arithmetic work on the actual values is implemented with saturation,
97  /// to avoid overflow when using more extreme cost values.
98
99  InstructionCost &operator+=(const InstructionCost &RHS) {
100    propagateState(RHS);
101
102    // Saturating addition.
103    InstructionCost::CostType Result;
104    if (AddOverflow(Value, RHS.Value, Result))
105      Result = RHS.Value > 0 ? getMaxValue() : getMinValue();
106
107    Value = Result;
108    return *this;
109  }
110
111  InstructionCost &operator+=(const CostType RHS) {
112    InstructionCost RHS2(RHS);
113    *this += RHS2;
114    return *this;
115  }
116
117  InstructionCost &operator-=(const InstructionCost &RHS) {
118    propagateState(RHS);
119
120    // Saturating subtract.
121    InstructionCost::CostType Result;
122    if (SubOverflow(Value, RHS.Value, Result))
123      Result = RHS.Value > 0 ? getMinValue() : getMaxValue();
124    Value = Result;
125    return *this;
126  }
127
128  InstructionCost &operator-=(const CostType RHS) {
129    InstructionCost RHS2(RHS);
130    *this -= RHS2;
131    return *this;
132  }
133
134  InstructionCost &operator*=(const InstructionCost &RHS) {
135    propagateState(RHS);
136
137    // Saturating multiply.
138    InstructionCost::CostType Result;
139    if (MulOverflow(Value, RHS.Value, Result)) {
140      if ((Value > 0 && RHS.Value > 0) || (Value < 0 && RHS.Value < 0))
141        Result = getMaxValue();
142      else
143        Result = getMinValue();
144    }
145
146    Value = Result;
147    return *this;
148  }
149
150  InstructionCost &operator*=(const CostType RHS) {
151    InstructionCost RHS2(RHS);
152    *this *= RHS2;
153    return *this;
154  }
155
156  InstructionCost &operator/=(const InstructionCost &RHS) {
157    propagateState(RHS);
158    Value /= RHS.Value;
159    return *this;
160  }
161
162  InstructionCost &operator/=(const CostType RHS) {
163    InstructionCost RHS2(RHS);
164    *this /= RHS2;
165    return *this;
166  }
167
168  InstructionCost &operator++() {
169    *this += 1;
170    return *this;
171  }
172
173  InstructionCost operator++(int) {
174    InstructionCost Copy = *this;
175    ++*this;
176    return Copy;
177  }
178
179  InstructionCost &operator--() {
180    *this -= 1;
181    return *this;
182  }
183
184  InstructionCost operator--(int) {
185    InstructionCost Copy = *this;
186    --*this;
187    return Copy;
188  }
189
190  /// For the comparison operators we have chosen to use lexicographical
191  /// ordering where valid costs are always considered to be less than invalid
192  /// costs. This avoids having to add asserts to the comparison operators that
193  /// the states are valid and users can test for validity of the cost
194  /// explicitly.
195  bool operator<(const InstructionCost &RHS) const {
196    if (State != RHS.State)
197      return State < RHS.State;
198    return Value < RHS.Value;
199  }
200
201  // Implement in terms of operator< to ensure that the two comparisons stay in
202  // sync
203  bool operator==(const InstructionCost &RHS) const {
204    return !(*this < RHS) && !(RHS < *this);
205  }
206
207  bool operator!=(const InstructionCost &RHS) const { return !(*this == RHS); }
208
209  bool operator==(const CostType RHS) const {
210    InstructionCost RHS2(RHS);
211    return *this == RHS2;
212  }
213
214  bool operator!=(const CostType RHS) const { return !(*this == RHS); }
215
216  bool operator>(const InstructionCost &RHS) const { return RHS < *this; }
217
218  bool operator<=(const InstructionCost &RHS) const { return !(RHS < *this); }
219
220  bool operator>=(const InstructionCost &RHS) const { return !(*this < RHS); }
221
222  bool operator<(const CostType RHS) const {
223    InstructionCost RHS2(RHS);
224    return *this < RHS2;
225  }
226
227  bool operator>(const CostType RHS) const {
228    InstructionCost RHS2(RHS);
229    return *this > RHS2;
230  }
231
232  bool operator<=(const CostType RHS) const {
233    InstructionCost RHS2(RHS);
234    return *this <= RHS2;
235  }
236
237  bool operator>=(const CostType RHS) const {
238    InstructionCost RHS2(RHS);
239    return *this >= RHS2;
240  }
241
242  void print(raw_ostream &OS) const;
243
244  template <class Function>
245  auto map(const Function &F) const -> InstructionCost {
246    if (isValid())
247      return F(Value);
248    return getInvalid();
249  }
250};
251
252inline InstructionCost operator+(const InstructionCost &LHS,
253                                 const InstructionCost &RHS) {
254  InstructionCost LHS2(LHS);
255  LHS2 += RHS;
256  return LHS2;
257}
258
259inline InstructionCost operator-(const InstructionCost &LHS,
260                                 const InstructionCost &RHS) {
261  InstructionCost LHS2(LHS);
262  LHS2 -= RHS;
263  return LHS2;
264}
265
266inline InstructionCost operator*(const InstructionCost &LHS,
267                                 const InstructionCost &RHS) {
268  InstructionCost LHS2(LHS);
269  LHS2 *= RHS;
270  return LHS2;
271}
272
273inline InstructionCost operator/(const InstructionCost &LHS,
274                                 const InstructionCost &RHS) {
275  InstructionCost LHS2(LHS);
276  LHS2 /= RHS;
277  return LHS2;
278}
279
280inline raw_ostream &operator<<(raw_ostream &OS, const InstructionCost &V) {
281  V.print(OS);
282  return OS;
283}
284
285} // namespace llvm
286
287#endif
288