1//===- InstructionCost.h ----------------------------------------*- C++ -*-===// 2// 3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4// See https://llvm.org/LICENSE.txt for license information. 5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6// 7//===----------------------------------------------------------------------===// 8/// \file 9/// This file defines an InstructionCost class that is used when calculating 10/// the cost of an instruction, or a group of instructions. In addition to a 11/// numeric value representing the cost the class also contains a state that 12/// can be used to encode particular properties, i.e. a cost being invalid or 13/// unknown. 14/// 15//===----------------------------------------------------------------------===// 16 17#ifndef LLVM_SUPPORT_INSTRUCTIONCOST_H 18#define LLVM_SUPPORT_INSTRUCTIONCOST_H 19 20#include "llvm/ADT/Optional.h" 21 22namespace llvm { 23 24class raw_ostream; 25 26class InstructionCost { 27public: 28 using CostType = int; 29 30 /// These states can currently be used to indicate whether a cost is valid or 31 /// invalid. Examples of an invalid cost might be where the cost is 32 /// prohibitively expensive and the user wants to prevent certain 33 /// optimizations being performed. Or perhaps the cost is simply unknown 34 /// because the operation makes no sense in certain circumstances. These 35 /// states can be expanded in future to support other cases if necessary. 36 enum CostState { Valid, Invalid }; 37 38private: 39 CostType Value = 0; 40 CostState State = Valid; 41 42 void propagateState(const InstructionCost &RHS) { 43 if (RHS.State == Invalid) 44 State = Invalid; 45 } 46 47public: 48 // A default constructed InstructionCost is a valid zero cost 49 InstructionCost() = default; 50 51 InstructionCost(CostState) = delete; 52 InstructionCost(CostType Val) : Value(Val), State(Valid) {} 53 54 static InstructionCost getInvalid(CostType Val = 0) { 55 InstructionCost Tmp(Val); 56 Tmp.setInvalid(); 57 return Tmp; 58 } 59 60 bool isValid() const { return State == Valid; } 61 void setValid() { State = Valid; } 62 void setInvalid() { State = Invalid; } 63 CostState getState() const { return State; } 64 65 /// This function is intended to be used as sparingly as possible, since the 66 /// class provides the full range of operator support required for arithmetic 67 /// and comparisons. 68 Optional<CostType> getValue() const { 69 if (isValid()) 70 return Value; 71 return None; 72 } 73 74 /// For all of the arithmetic operators provided here any invalid state is 75 /// perpetuated and cannot be removed. Once a cost becomes invalid it stays 76 /// invalid, and it also inherits any invalid state from the RHS. Regardless 77 /// of the state, arithmetic and comparisons work on the actual values in the 78 /// same way as they would on a basic type, such as integer. 79 80 InstructionCost &operator+=(const InstructionCost &RHS) { 81 propagateState(RHS); 82 Value += RHS.Value; 83 return *this; 84 } 85 86 InstructionCost &operator+=(const CostType RHS) { 87 InstructionCost RHS2(RHS); 88 *this += RHS2; 89 return *this; 90 } 91 92 InstructionCost &operator-=(const InstructionCost &RHS) { 93 propagateState(RHS); 94 Value -= RHS.Value; 95 return *this; 96 } 97 98 InstructionCost &operator-=(const CostType RHS) { 99 InstructionCost RHS2(RHS); 100 *this -= RHS2; 101 return *this; 102 } 103 104 InstructionCost &operator*=(const InstructionCost &RHS) { 105 propagateState(RHS); 106 Value *= RHS.Value; 107 return *this; 108 } 109 110 InstructionCost &operator*=(const CostType RHS) { 111 InstructionCost RHS2(RHS); 112 *this *= RHS2; 113 return *this; 114 } 115 116 InstructionCost &operator/=(const InstructionCost &RHS) { 117 propagateState(RHS); 118 Value /= RHS.Value; 119 return *this; 120 } 121 122 InstructionCost &operator/=(const CostType RHS) { 123 InstructionCost RHS2(RHS); 124 *this /= RHS2; 125 return *this; 126 } 127 128 InstructionCost &operator++() { 129 *this += 1; 130 return *this; 131 } 132 133 InstructionCost operator++(int) { 134 InstructionCost Copy = *this; 135 ++*this; 136 return Copy; 137 } 138 139 InstructionCost &operator--() { 140 *this -= 1; 141 return *this; 142 } 143 144 InstructionCost operator--(int) { 145 InstructionCost Copy = *this; 146 --*this; 147 return Copy; 148 } 149 150 /// For the comparison operators we have chosen to use lexicographical 151 /// ordering where valid costs are always considered to be less than invalid 152 /// costs. This avoids having to add asserts to the comparison operators that 153 /// the states are valid and users can test for validity of the cost 154 /// explicitly. 155 bool operator<(const InstructionCost &RHS) const { 156 if (State != RHS.State) 157 return State < RHS.State; 158 return Value < RHS.Value; 159 } 160 161 // Implement in terms of operator< to ensure that the two comparisons stay in 162 // sync 163 bool operator==(const InstructionCost &RHS) const { 164 return !(*this < RHS) && !(RHS < *this); 165 } 166 167 bool operator!=(const InstructionCost &RHS) const { return !(*this == RHS); } 168 169 bool operator==(const CostType RHS) const { 170 InstructionCost RHS2(RHS); 171 return *this == RHS2; 172 } 173 174 bool operator!=(const CostType RHS) const { return !(*this == RHS); } 175 176 bool operator>(const InstructionCost &RHS) const { return RHS < *this; } 177 178 bool operator<=(const InstructionCost &RHS) const { return !(RHS < *this); } 179 180 bool operator>=(const InstructionCost &RHS) const { return !(*this < RHS); } 181 182 bool operator<(const CostType RHS) const { 183 InstructionCost RHS2(RHS); 184 return *this < RHS2; 185 } 186 187 bool operator>(const CostType RHS) const { 188 InstructionCost RHS2(RHS); 189 return *this > RHS2; 190 } 191 192 bool operator<=(const CostType RHS) const { 193 InstructionCost RHS2(RHS); 194 return *this <= RHS2; 195 } 196 197 bool operator>=(const CostType RHS) const { 198 InstructionCost RHS2(RHS); 199 return *this >= RHS2; 200 } 201 202 void print(raw_ostream &OS) const; 203 204 template <class Function> 205 auto map(const Function &F) const -> InstructionCost { 206 if (isValid()) 207 return F(*getValue()); 208 return getInvalid(); 209 } 210}; 211 212inline InstructionCost operator+(const InstructionCost &LHS, 213 const InstructionCost &RHS) { 214 InstructionCost LHS2(LHS); 215 LHS2 += RHS; 216 return LHS2; 217} 218 219inline InstructionCost operator-(const InstructionCost &LHS, 220 const InstructionCost &RHS) { 221 InstructionCost LHS2(LHS); 222 LHS2 -= RHS; 223 return LHS2; 224} 225 226inline InstructionCost operator*(const InstructionCost &LHS, 227 const InstructionCost &RHS) { 228 InstructionCost LHS2(LHS); 229 LHS2 *= RHS; 230 return LHS2; 231} 232 233inline InstructionCost operator/(const InstructionCost &LHS, 234 const InstructionCost &RHS) { 235 InstructionCost LHS2(LHS); 236 LHS2 /= RHS; 237 return LHS2; 238} 239 240inline raw_ostream &operator<<(raw_ostream &OS, const InstructionCost &V) { 241 V.print(OS); 242 return OS; 243} 244 245} // namespace llvm 246 247#endif 248