// Math.h revision 360784
//===- Math.h - PBQP Vector and Matrix classes ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_CODEGEN_PBQP_MATH_H
10#define LLVM_CODEGEN_PBQP_MATH_H
11
12#include "llvm/ADT/Hashing.h"
13#include "llvm/ADT/STLExtras.h"
14#include <algorithm>
15#include <cassert>
16#include <functional>
17#include <memory>
18
19namespace llvm {
20namespace PBQP {
21
/// Scalar type of every cost entry stored in a PBQP Vector or Matrix.
using PBQPNum = float;
23
24/// PBQP Vector class.
25class Vector {
26  friend hash_code hash_value(const Vector &);
27
28public:
29  /// Construct a PBQP vector of the given size.
30  explicit Vector(unsigned Length)
31    : Length(Length), Data(std::make_unique<PBQPNum []>(Length)) {}
32
33  /// Construct a PBQP vector with initializer.
34  Vector(unsigned Length, PBQPNum InitVal)
35    : Length(Length), Data(std::make_unique<PBQPNum []>(Length)) {
36    std::fill(Data.get(), Data.get() + Length, InitVal);
37  }
38
39  /// Copy construct a PBQP vector.
40  Vector(const Vector &V)
41    : Length(V.Length), Data(std::make_unique<PBQPNum []>(Length)) {
42    std::copy(V.Data.get(), V.Data.get() + Length, Data.get());
43  }
44
45  /// Move construct a PBQP vector.
46  Vector(Vector &&V)
47    : Length(V.Length), Data(std::move(V.Data)) {
48    V.Length = 0;
49  }
50
51  /// Comparison operator.
52  bool operator==(const Vector &V) const {
53    assert(Length != 0 && Data && "Invalid vector");
54    if (Length != V.Length)
55      return false;
56    return std::equal(Data.get(), Data.get() + Length, V.Data.get());
57  }
58
59  /// Return the length of the vector
60  unsigned getLength() const {
61    assert(Length != 0 && Data && "Invalid vector");
62    return Length;
63  }
64
65  /// Element access.
66  PBQPNum& operator[](unsigned Index) {
67    assert(Length != 0 && Data && "Invalid vector");
68    assert(Index < Length && "Vector element access out of bounds.");
69    return Data[Index];
70  }
71
72  /// Const element access.
73  const PBQPNum& operator[](unsigned Index) const {
74    assert(Length != 0 && Data && "Invalid vector");
75    assert(Index < Length && "Vector element access out of bounds.");
76    return Data[Index];
77  }
78
79  /// Add another vector to this one.
80  Vector& operator+=(const Vector &V) {
81    assert(Length != 0 && Data && "Invalid vector");
82    assert(Length == V.Length && "Vector length mismatch.");
83    std::transform(Data.get(), Data.get() + Length, V.Data.get(), Data.get(),
84                   std::plus<PBQPNum>());
85    return *this;
86  }
87
88  /// Returns the index of the minimum value in this vector
89  unsigned minIndex() const {
90    assert(Length != 0 && Data && "Invalid vector");
91    return std::min_element(Data.get(), Data.get() + Length) - Data.get();
92  }
93
94private:
95  unsigned Length;
96  std::unique_ptr<PBQPNum []> Data;
97};
98
99/// Return a hash_value for the given vector.
100inline hash_code hash_value(const Vector &V) {
101  unsigned *VBegin = reinterpret_cast<unsigned*>(V.Data.get());
102  unsigned *VEnd = reinterpret_cast<unsigned*>(V.Data.get() + V.Length);
103  return hash_combine(V.Length, hash_combine_range(VBegin, VEnd));
104}
105
106/// Output a textual representation of the given vector on the given
107///        output stream.
108template <typename OStream>
109OStream& operator<<(OStream &OS, const Vector &V) {
110  assert((V.getLength() != 0) && "Zero-length vector badness.");
111
112  OS << "[ " << V[0];
113  for (unsigned i = 1; i < V.getLength(); ++i)
114    OS << ", " << V[i];
115  OS << " ]";
116
117  return OS;
118}
119
120/// PBQP Matrix class
121class Matrix {
122private:
123  friend hash_code hash_value(const Matrix &);
124
125public:
126  /// Construct a PBQP Matrix with the given dimensions.
127  Matrix(unsigned Rows, unsigned Cols) :
128    Rows(Rows), Cols(Cols), Data(std::make_unique<PBQPNum []>(Rows * Cols)) {
129  }
130
131  /// Construct a PBQP Matrix with the given dimensions and initial
132  /// value.
133  Matrix(unsigned Rows, unsigned Cols, PBQPNum InitVal)
134    : Rows(Rows), Cols(Cols),
135      Data(std::make_unique<PBQPNum []>(Rows * Cols)) {
136    std::fill(Data.get(), Data.get() + (Rows * Cols), InitVal);
137  }
138
139  /// Copy construct a PBQP matrix.
140  Matrix(const Matrix &M)
141    : Rows(M.Rows), Cols(M.Cols),
142      Data(std::make_unique<PBQPNum []>(Rows * Cols)) {
143    std::copy(M.Data.get(), M.Data.get() + (Rows * Cols), Data.get());
144  }
145
146  /// Move construct a PBQP matrix.
147  Matrix(Matrix &&M)
148    : Rows(M.Rows), Cols(M.Cols), Data(std::move(M.Data)) {
149    M.Rows = M.Cols = 0;
150  }
151
152  /// Comparison operator.
153  bool operator==(const Matrix &M) const {
154    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
155    if (Rows != M.Rows || Cols != M.Cols)
156      return false;
157    return std::equal(Data.get(), Data.get() + (Rows * Cols), M.Data.get());
158  }
159
160  /// Return the number of rows in this matrix.
161  unsigned getRows() const {
162    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
163    return Rows;
164  }
165
166  /// Return the number of cols in this matrix.
167  unsigned getCols() const {
168    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
169    return Cols;
170  }
171
172  /// Matrix element access.
173  PBQPNum* operator[](unsigned R) {
174    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
175    assert(R < Rows && "Row out of bounds.");
176    return Data.get() + (R * Cols);
177  }
178
179  /// Matrix element access.
180  const PBQPNum* operator[](unsigned R) const {
181    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
182    assert(R < Rows && "Row out of bounds.");
183    return Data.get() + (R * Cols);
184  }
185
186  /// Returns the given row as a vector.
187  Vector getRowAsVector(unsigned R) const {
188    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
189    Vector V(Cols);
190    for (unsigned C = 0; C < Cols; ++C)
191      V[C] = (*this)[R][C];
192    return V;
193  }
194
195  /// Returns the given column as a vector.
196  Vector getColAsVector(unsigned C) const {
197    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
198    Vector V(Rows);
199    for (unsigned R = 0; R < Rows; ++R)
200      V[R] = (*this)[R][C];
201    return V;
202  }
203
204  /// Matrix transpose.
205  Matrix transpose() const {
206    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
207    Matrix M(Cols, Rows);
208    for (unsigned r = 0; r < Rows; ++r)
209      for (unsigned c = 0; c < Cols; ++c)
210        M[c][r] = (*this)[r][c];
211    return M;
212  }
213
214  /// Add the given matrix to this one.
215  Matrix& operator+=(const Matrix &M) {
216    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
217    assert(Rows == M.Rows && Cols == M.Cols &&
218           "Matrix dimensions mismatch.");
219    std::transform(Data.get(), Data.get() + (Rows * Cols), M.Data.get(),
220                   Data.get(), std::plus<PBQPNum>());
221    return *this;
222  }
223
224  Matrix operator+(const Matrix &M) {
225    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
226    Matrix Tmp(*this);
227    Tmp += M;
228    return Tmp;
229  }
230
231private:
232  unsigned Rows, Cols;
233  std::unique_ptr<PBQPNum []> Data;
234};
235
236/// Return a hash_code for the given matrix.
237inline hash_code hash_value(const Matrix &M) {
238  unsigned *MBegin = reinterpret_cast<unsigned*>(M.Data.get());
239  unsigned *MEnd =
240    reinterpret_cast<unsigned*>(M.Data.get() + (M.Rows * M.Cols));
241  return hash_combine(M.Rows, M.Cols, hash_combine_range(MBegin, MEnd));
242}
243
244/// Output a textual representation of the given matrix on the given
245///        output stream.
246template <typename OStream>
247OStream& operator<<(OStream &OS, const Matrix &M) {
248  assert((M.getRows() != 0) && "Zero-row matrix badness.");
249  for (unsigned i = 0; i < M.getRows(); ++i)
250    OS << M.getRowAsVector(i) << "\n";
251  return OS;
252}
253
254template <typename Metadata>
255class MDVector : public Vector {
256public:
257  MDVector(const Vector &v) : Vector(v), md(*this) {}
258  MDVector(Vector &&v) : Vector(std::move(v)), md(*this) { }
259
260  const Metadata& getMetadata() const { return md; }
261
262private:
263  Metadata md;
264};
265
266template <typename Metadata>
267inline hash_code hash_value(const MDVector<Metadata> &V) {
268  return hash_value(static_cast<const Vector&>(V));
269}
270
271template <typename Metadata>
272class MDMatrix : public Matrix {
273public:
274  MDMatrix(const Matrix &m) : Matrix(m), md(*this) {}
275  MDMatrix(Matrix &&m) : Matrix(std::move(m)), md(*this) { }
276
277  const Metadata& getMetadata() const { return md; }
278
279private:
280  Metadata md;
281};
282
283template <typename Metadata>
284inline hash_code hash_value(const MDMatrix<Metadata> &M) {
285  return hash_value(static_cast<const Matrix&>(M));
286}
287
288} // end namespace PBQP
289} // end namespace llvm
290
291#endif // LLVM_CODEGEN_PBQP_MATH_H
292