Math.h revision 360784
1//===- Math.h - PBQP Vector and Matrix classes ------------------*- C++ -*-===// 2// 3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4// See https://llvm.org/LICENSE.txt for license information. 5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6// 7//===----------------------------------------------------------------------===// 8 9#ifndef LLVM_CODEGEN_PBQP_MATH_H 10#define LLVM_CODEGEN_PBQP_MATH_H 11 12#include "llvm/ADT/Hashing.h" 13#include "llvm/ADT/STLExtras.h" 14#include <algorithm> 15#include <cassert> 16#include <functional> 17#include <memory> 18 19namespace llvm { 20namespace PBQP { 21 22using PBQPNum = float; 23 24/// PBQP Vector class. 25class Vector { 26 friend hash_code hash_value(const Vector &); 27 28public: 29 /// Construct a PBQP vector of the given size. 30 explicit Vector(unsigned Length) 31 : Length(Length), Data(std::make_unique<PBQPNum []>(Length)) {} 32 33 /// Construct a PBQP vector with initializer. 34 Vector(unsigned Length, PBQPNum InitVal) 35 : Length(Length), Data(std::make_unique<PBQPNum []>(Length)) { 36 std::fill(Data.get(), Data.get() + Length, InitVal); 37 } 38 39 /// Copy construct a PBQP vector. 40 Vector(const Vector &V) 41 : Length(V.Length), Data(std::make_unique<PBQPNum []>(Length)) { 42 std::copy(V.Data.get(), V.Data.get() + Length, Data.get()); 43 } 44 45 /// Move construct a PBQP vector. 46 Vector(Vector &&V) 47 : Length(V.Length), Data(std::move(V.Data)) { 48 V.Length = 0; 49 } 50 51 /// Comparison operator. 52 bool operator==(const Vector &V) const { 53 assert(Length != 0 && Data && "Invalid vector"); 54 if (Length != V.Length) 55 return false; 56 return std::equal(Data.get(), Data.get() + Length, V.Data.get()); 57 } 58 59 /// Return the length of the vector 60 unsigned getLength() const { 61 assert(Length != 0 && Data && "Invalid vector"); 62 return Length; 63 } 64 65 /// Element access. 66 PBQPNum& operator[](unsigned Index) { 67 assert(Length != 0 && Data && "Invalid vector"); 68 assert(Index < Length && "Vector element access out of bounds."); 69 return Data[Index]; 70 } 71 72 /// Const element access. 73 const PBQPNum& operator[](unsigned Index) const { 74 assert(Length != 0 && Data && "Invalid vector"); 75 assert(Index < Length && "Vector element access out of bounds."); 76 return Data[Index]; 77 } 78 79 /// Add another vector to this one. 80 Vector& operator+=(const Vector &V) { 81 assert(Length != 0 && Data && "Invalid vector"); 82 assert(Length == V.Length && "Vector length mismatch."); 83 std::transform(Data.get(), Data.get() + Length, V.Data.get(), Data.get(), 84 std::plus<PBQPNum>()); 85 return *this; 86 } 87 88 /// Returns the index of the minimum value in this vector 89 unsigned minIndex() const { 90 assert(Length != 0 && Data && "Invalid vector"); 91 return std::min_element(Data.get(), Data.get() + Length) - Data.get(); 92 } 93 94private: 95 unsigned Length; 96 std::unique_ptr<PBQPNum []> Data; 97}; 98 99/// Return a hash_value for the given vector. 100inline hash_code hash_value(const Vector &V) { 101 unsigned *VBegin = reinterpret_cast<unsigned*>(V.Data.get()); 102 unsigned *VEnd = reinterpret_cast<unsigned*>(V.Data.get() + V.Length); 103 return hash_combine(V.Length, hash_combine_range(VBegin, VEnd)); 104} 105 106/// Output a textual representation of the given vector on the given 107/// output stream. 108template <typename OStream> 109OStream& operator<<(OStream &OS, const Vector &V) { 110 assert((V.getLength() != 0) && "Zero-length vector badness."); 111 112 OS << "[ " << V[0]; 113 for (unsigned i = 1; i < V.getLength(); ++i) 114 OS << ", " << V[i]; 115 OS << " ]"; 116 117 return OS; 118} 119 120/// PBQP Matrix class 121class Matrix { 122private: 123 friend hash_code hash_value(const Matrix &); 124 125public: 126 /// Construct a PBQP Matrix with the given dimensions. 127 Matrix(unsigned Rows, unsigned Cols) : 128 Rows(Rows), Cols(Cols), Data(std::make_unique<PBQPNum []>(Rows * Cols)) { 129 } 130 131 /// Construct a PBQP Matrix with the given dimensions and initial 132 /// value. 133 Matrix(unsigned Rows, unsigned Cols, PBQPNum InitVal) 134 : Rows(Rows), Cols(Cols), 135 Data(std::make_unique<PBQPNum []>(Rows * Cols)) { 136 std::fill(Data.get(), Data.get() + (Rows * Cols), InitVal); 137 } 138 139 /// Copy construct a PBQP matrix. 140 Matrix(const Matrix &M) 141 : Rows(M.Rows), Cols(M.Cols), 142 Data(std::make_unique<PBQPNum []>(Rows * Cols)) { 143 std::copy(M.Data.get(), M.Data.get() + (Rows * Cols), Data.get()); 144 } 145 146 /// Move construct a PBQP matrix. 147 Matrix(Matrix &&M) 148 : Rows(M.Rows), Cols(M.Cols), Data(std::move(M.Data)) { 149 M.Rows = M.Cols = 0; 150 } 151 152 /// Comparison operator. 153 bool operator==(const Matrix &M) const { 154 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 155 if (Rows != M.Rows || Cols != M.Cols) 156 return false; 157 return std::equal(Data.get(), Data.get() + (Rows * Cols), M.Data.get()); 158 } 159 160 /// Return the number of rows in this matrix. 161 unsigned getRows() const { 162 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 163 return Rows; 164 } 165 166 /// Return the number of cols in this matrix. 167 unsigned getCols() const { 168 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 169 return Cols; 170 } 171 172 /// Matrix element access. 173 PBQPNum* operator[](unsigned R) { 174 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 175 assert(R < Rows && "Row out of bounds."); 176 return Data.get() + (R * Cols); 177 } 178 179 /// Matrix element access. 180 const PBQPNum* operator[](unsigned R) const { 181 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 182 assert(R < Rows && "Row out of bounds."); 183 return Data.get() + (R * Cols); 184 } 185 186 /// Returns the given row as a vector. 187 Vector getRowAsVector(unsigned R) const { 188 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 189 Vector V(Cols); 190 for (unsigned C = 0; C < Cols; ++C) 191 V[C] = (*this)[R][C]; 192 return V; 193 } 194 195 /// Returns the given column as a vector. 196 Vector getColAsVector(unsigned C) const { 197 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 198 Vector V(Rows); 199 for (unsigned R = 0; R < Rows; ++R) 200 V[R] = (*this)[R][C]; 201 return V; 202 } 203 204 /// Matrix transpose. 205 Matrix transpose() const { 206 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 207 Matrix M(Cols, Rows); 208 for (unsigned r = 0; r < Rows; ++r) 209 for (unsigned c = 0; c < Cols; ++c) 210 M[c][r] = (*this)[r][c]; 211 return M; 212 } 213 214 /// Add the given matrix to this one. 215 Matrix& operator+=(const Matrix &M) { 216 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 217 assert(Rows == M.Rows && Cols == M.Cols && 218 "Matrix dimensions mismatch."); 219 std::transform(Data.get(), Data.get() + (Rows * Cols), M.Data.get(), 220 Data.get(), std::plus<PBQPNum>()); 221 return *this; 222 } 223 224 Matrix operator+(const Matrix &M) { 225 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 226 Matrix Tmp(*this); 227 Tmp += M; 228 return Tmp; 229 } 230 231private: 232 unsigned Rows, Cols; 233 std::unique_ptr<PBQPNum []> Data; 234}; 235 236/// Return a hash_code for the given matrix. 237inline hash_code hash_value(const Matrix &M) { 238 unsigned *MBegin = reinterpret_cast<unsigned*>(M.Data.get()); 239 unsigned *MEnd = 240 reinterpret_cast<unsigned*>(M.Data.get() + (M.Rows * M.Cols)); 241 return hash_combine(M.Rows, M.Cols, hash_combine_range(MBegin, MEnd)); 242} 243 244/// Output a textual representation of the given matrix on the given 245/// output stream. 246template <typename OStream> 247OStream& operator<<(OStream &OS, const Matrix &M) { 248 assert((M.getRows() != 0) && "Zero-row matrix badness."); 249 for (unsigned i = 0; i < M.getRows(); ++i) 250 OS << M.getRowAsVector(i) << "\n"; 251 return OS; 252} 253 254template <typename Metadata> 255class MDVector : public Vector { 256public: 257 MDVector(const Vector &v) : Vector(v), md(*this) {} 258 MDVector(Vector &&v) : Vector(std::move(v)), md(*this) { } 259 260 const Metadata& getMetadata() const { return md; } 261 262private: 263 Metadata md; 264}; 265 266template <typename Metadata> 267inline hash_code hash_value(const MDVector<Metadata> &V) { 268 return hash_value(static_cast<const Vector&>(V)); 269} 270 271template <typename Metadata> 272class MDMatrix : public Matrix { 273public: 274 MDMatrix(const Matrix &m) : Matrix(m), md(*this) {} 275 MDMatrix(Matrix &&m) : Matrix(std::move(m)), md(*this) { } 276 277 const Metadata& getMetadata() const { return md; } 278 279private: 280 Metadata md; 281}; 282 283template <typename Metadata> 284inline hash_code hash_value(const MDMatrix<Metadata> &M) { 285 return hash_value(static_cast<const Matrix&>(M)); 286} 287 288} // end namespace PBQP 289} // end namespace llvm 290 291#endif // LLVM_CODEGEN_PBQP_MATH_H 292