1//===- InstructionCost.h ----------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8/// \file
9/// This file defines an InstructionCost class that is used when calculating
10/// the cost of an instruction, or a group of instructions. In addition to a
11/// numeric value representing the cost the class also contains a state that
12/// can be used to encode particular properties, i.e. a cost being invalid or
13/// unknown.
14///
15//===----------------------------------------------------------------------===//
16
17#ifndef LLVM_SUPPORT_INSTRUCTIONCOST_H
18#define LLVM_SUPPORT_INSTRUCTIONCOST_H
19
#include "llvm/ADT/Optional.h"
#include <limits>
21
22namespace llvm {
23
24class raw_ostream;
25
26class InstructionCost {
27public:
28  using CostType = int;
29
30  /// These states can currently be used to indicate whether a cost is valid or
31  /// invalid. Examples of an invalid cost might be where the cost is
32  /// prohibitively expensive and the user wants to prevent certain
33  /// optimizations being performed. Or perhaps the cost is simply unknown
34  /// because the operation makes no sense in certain circumstances. These
35  /// states can be expanded in future to support other cases if necessary.
36  enum CostState { Valid, Invalid };
37
38private:
39  CostType Value = 0;
40  CostState State = Valid;
41
42  void propagateState(const InstructionCost &RHS) {
43    if (RHS.State == Invalid)
44      State = Invalid;
45  }
46
47public:
48  // A default constructed InstructionCost is a valid zero cost
49  InstructionCost() = default;
50
51  InstructionCost(CostState) = delete;
52  InstructionCost(CostType Val) : Value(Val), State(Valid) {}
53
54  static InstructionCost getInvalid(CostType Val = 0) {
55    InstructionCost Tmp(Val);
56    Tmp.setInvalid();
57    return Tmp;
58  }
59
60  bool isValid() const { return State == Valid; }
61  void setValid() { State = Valid; }
62  void setInvalid() { State = Invalid; }
63  CostState getState() const { return State; }
64
65  /// This function is intended to be used as sparingly as possible, since the
66  /// class provides the full range of operator support required for arithmetic
67  /// and comparisons.
68  Optional<CostType> getValue() const {
69    if (isValid())
70      return Value;
71    return None;
72  }
73
74  /// For all of the arithmetic operators provided here any invalid state is
75  /// perpetuated and cannot be removed. Once a cost becomes invalid it stays
76  /// invalid, and it also inherits any invalid state from the RHS. Regardless
77  /// of the state, arithmetic and comparisons work on the actual values in the
78  /// same way as they would on a basic type, such as integer.
79
80  InstructionCost &operator+=(const InstructionCost &RHS) {
81    propagateState(RHS);
82    Value += RHS.Value;
83    return *this;
84  }
85
86  InstructionCost &operator+=(const CostType RHS) {
87    InstructionCost RHS2(RHS);
88    *this += RHS2;
89    return *this;
90  }
91
92  InstructionCost &operator-=(const InstructionCost &RHS) {
93    propagateState(RHS);
94    Value -= RHS.Value;
95    return *this;
96  }
97
98  InstructionCost &operator-=(const CostType RHS) {
99    InstructionCost RHS2(RHS);
100    *this -= RHS2;
101    return *this;
102  }
103
104  InstructionCost &operator*=(const InstructionCost &RHS) {
105    propagateState(RHS);
106    Value *= RHS.Value;
107    return *this;
108  }
109
110  InstructionCost &operator*=(const CostType RHS) {
111    InstructionCost RHS2(RHS);
112    *this *= RHS2;
113    return *this;
114  }
115
116  InstructionCost &operator/=(const InstructionCost &RHS) {
117    propagateState(RHS);
118    Value /= RHS.Value;
119    return *this;
120  }
121
122  InstructionCost &operator/=(const CostType RHS) {
123    InstructionCost RHS2(RHS);
124    *this /= RHS2;
125    return *this;
126  }
127
128  InstructionCost &operator++() {
129    *this += 1;
130    return *this;
131  }
132
133  InstructionCost operator++(int) {
134    InstructionCost Copy = *this;
135    ++*this;
136    return Copy;
137  }
138
139  InstructionCost &operator--() {
140    *this -= 1;
141    return *this;
142  }
143
144  InstructionCost operator--(int) {
145    InstructionCost Copy = *this;
146    --*this;
147    return Copy;
148  }
149
150  /// For the comparison operators we have chosen to use lexicographical
151  /// ordering where valid costs are always considered to be less than invalid
152  /// costs. This avoids having to add asserts to the comparison operators that
153  /// the states are valid and users can test for validity of the cost
154  /// explicitly.
155  bool operator<(const InstructionCost &RHS) const {
156    if (State != RHS.State)
157      return State < RHS.State;
158    return Value < RHS.Value;
159  }
160
161  // Implement in terms of operator< to ensure that the two comparisons stay in
162  // sync
163  bool operator==(const InstructionCost &RHS) const {
164    return !(*this < RHS) && !(RHS < *this);
165  }
166
167  bool operator!=(const InstructionCost &RHS) const { return !(*this == RHS); }
168
169  bool operator==(const CostType RHS) const {
170    InstructionCost RHS2(RHS);
171    return *this == RHS2;
172  }
173
174  bool operator!=(const CostType RHS) const { return !(*this == RHS); }
175
176  bool operator>(const InstructionCost &RHS) const { return RHS < *this; }
177
178  bool operator<=(const InstructionCost &RHS) const { return !(RHS < *this); }
179
180  bool operator>=(const InstructionCost &RHS) const { return !(*this < RHS); }
181
182  bool operator<(const CostType RHS) const {
183    InstructionCost RHS2(RHS);
184    return *this < RHS2;
185  }
186
187  bool operator>(const CostType RHS) const {
188    InstructionCost RHS2(RHS);
189    return *this > RHS2;
190  }
191
192  bool operator<=(const CostType RHS) const {
193    InstructionCost RHS2(RHS);
194    return *this <= RHS2;
195  }
196
197  bool operator>=(const CostType RHS) const {
198    InstructionCost RHS2(RHS);
199    return *this >= RHS2;
200  }
201
202  void print(raw_ostream &OS) const;
203
204  template <class Function>
205  auto map(const Function &F) const -> InstructionCost {
206    if (isValid())
207      return F(*getValue());
208    return getInvalid();
209  }
210};
211
212inline InstructionCost operator+(const InstructionCost &LHS,
213                                 const InstructionCost &RHS) {
214  InstructionCost LHS2(LHS);
215  LHS2 += RHS;
216  return LHS2;
217}
218
219inline InstructionCost operator-(const InstructionCost &LHS,
220                                 const InstructionCost &RHS) {
221  InstructionCost LHS2(LHS);
222  LHS2 -= RHS;
223  return LHS2;
224}
225
226inline InstructionCost operator*(const InstructionCost &LHS,
227                                 const InstructionCost &RHS) {
228  InstructionCost LHS2(LHS);
229  LHS2 *= RHS;
230  return LHS2;
231}
232
233inline InstructionCost operator/(const InstructionCost &LHS,
234                                 const InstructionCost &RHS) {
235  InstructionCost LHS2(LHS);
236  LHS2 /= RHS;
237  return LHS2;
238}
239
/// Stream insertion for InstructionCost. Forwards to InstructionCost::print
/// and returns \p OS so that insertions can be chained.
inline raw_ostream &operator<<(raw_ostream &OS, const InstructionCost &V) {
  V.print(OS);
  return OS;
}
244
245} // namespace llvm
246
247#endif
248