1//===-- SPIRVDuplicatesTracker.h - SPIR-V Duplicates Tracker ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// General infrastructure for keeping track of the values that according to
10// the SPIR-V binary layout should be global to the whole module.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
15#define LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
16
17#include "MCTargetDesc/SPIRVBaseInfo.h"
18#include "MCTargetDesc/SPIRVMCTargetDesc.h"
19#include "llvm/ADT/DenseMap.h"
20#include "llvm/ADT/MapVector.h"
21#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
22#include "llvm/CodeGen/MachineModuleInfo.h"
23
24#include <type_traits>
25
26namespace llvm {
27namespace SPIRV {
28// NOTE: using MapVector instead of DenseMap because it helps getting
29// everything ordered in a stable manner for a price of extra (NumKeys)*PtrSize
30// memory and expensive removals which do not happen anyway.
31class DTSortableEntry : public MapVector<const MachineFunction *, Register> {
32  SmallVector<DTSortableEntry *, 2> Deps;
33
34  struct FlagsTy {
35    unsigned IsFunc : 1;
36    unsigned IsGV : 1;
37    // NOTE: bit-field default init is a C++20 feature.
38    FlagsTy() : IsFunc(0), IsGV(0) {}
39  };
40  FlagsTy Flags;
41
42public:
43  // Common hoisting utility doesn't support function, because their hoisting
44  // require hoisting of params as well.
45  bool getIsFunc() const { return Flags.IsFunc; }
46  bool getIsGV() const { return Flags.IsGV; }
47  void setIsFunc(bool V) { Flags.IsFunc = V; }
48  void setIsGV(bool V) { Flags.IsGV = V; }
49
50  const SmallVector<DTSortableEntry *, 2> &getDeps() const { return Deps; }
51  void addDep(DTSortableEntry *E) { Deps.push_back(E); }
52};
53
// Base descriptor for SPIR-V types that are tracked by their attributes rather
// than by an IR pointer. Derived descriptors fold their attributes into Hash;
// trackers copy descriptors by value (slicing to this base), so everything a
// lookup needs must live in Kind and Hash.
struct SpecialTypeDescriptor {
  enum SpecialTypeKind {
    STK_Empty = 0,
    STK_Image,
    STK_SampledImage,
    STK_Sampler,
    STK_Pipe,
    STK_DeviceEvent,
    STK_Last = -1
  };
  SpecialTypeKind Kind;

  unsigned Hash;

  SpecialTypeDescriptor() = delete;
  // Hash defaults to the kind tag; derived constructors refine it. It is set
  // in the mem-initializer so it is never observable uninitialized.
  SpecialTypeDescriptor(SpecialTypeKind K)
      : Kind(K), Hash(static_cast<unsigned>(K)) {}

  // Descriptors are copied by value (DenseMapInfo, MapVector keys). Spell the
  // copy operations out: with a user-declared destructor their implicit
  // generation is deprecated.
  SpecialTypeDescriptor(const SpecialTypeDescriptor &) = default;
  SpecialTypeDescriptor &operator=(const SpecialTypeDescriptor &) = default;

  unsigned getHash() const { return Hash; }

  virtual ~SpecialTypeDescriptor() = default;
};
75
76struct ImageTypeDescriptor : public SpecialTypeDescriptor {
77  union ImageAttrs {
78    struct BitFlags {
79      unsigned Dim : 3;
80      unsigned Depth : 2;
81      unsigned Arrayed : 1;
82      unsigned MS : 1;
83      unsigned Sampled : 2;
84      unsigned ImageFormat : 6;
85      unsigned AQ : 2;
86    } Flags;
87    unsigned Val;
88  };
89
90  ImageTypeDescriptor(const Type *SampledTy, unsigned Dim, unsigned Depth,
91                      unsigned Arrayed, unsigned MS, unsigned Sampled,
92                      unsigned ImageFormat, unsigned AQ = 0)
93      : SpecialTypeDescriptor(SpecialTypeKind::STK_Image) {
94    ImageAttrs Attrs;
95    Attrs.Val = 0;
96    Attrs.Flags.Dim = Dim;
97    Attrs.Flags.Depth = Depth;
98    Attrs.Flags.Arrayed = Arrayed;
99    Attrs.Flags.MS = MS;
100    Attrs.Flags.Sampled = Sampled;
101    Attrs.Flags.ImageFormat = ImageFormat;
102    Attrs.Flags.AQ = AQ;
103    Hash = (DenseMapInfo<Type *>().getHashValue(SampledTy) & 0xffff) ^
104           ((Attrs.Val << 8) | Kind);
105  }
106
107  static bool classof(const SpecialTypeDescriptor *TD) {
108    return TD->Kind == SpecialTypeKind::STK_Image;
109  }
110};
111
112struct SampledImageTypeDescriptor : public SpecialTypeDescriptor {
113  SampledImageTypeDescriptor(const Type *SampledTy, const MachineInstr *ImageTy)
114      : SpecialTypeDescriptor(SpecialTypeKind::STK_SampledImage) {
115    assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
116    ImageTypeDescriptor TD(
117        SampledTy, ImageTy->getOperand(2).getImm(),
118        ImageTy->getOperand(3).getImm(), ImageTy->getOperand(4).getImm(),
119        ImageTy->getOperand(5).getImm(), ImageTy->getOperand(6).getImm(),
120        ImageTy->getOperand(7).getImm(), ImageTy->getOperand(8).getImm());
121    Hash = TD.getHash() ^ Kind;
122  }
123
124  static bool classof(const SpecialTypeDescriptor *TD) {
125    return TD->Kind == SpecialTypeKind::STK_SampledImage;
126  }
127};
128
129struct SamplerTypeDescriptor : public SpecialTypeDescriptor {
130  SamplerTypeDescriptor()
131      : SpecialTypeDescriptor(SpecialTypeKind::STK_Sampler) {
132    Hash = Kind;
133  }
134
135  static bool classof(const SpecialTypeDescriptor *TD) {
136    return TD->Kind == SpecialTypeKind::STK_Sampler;
137  }
138};
139
140struct PipeTypeDescriptor : public SpecialTypeDescriptor {
141
142  PipeTypeDescriptor(uint8_t AQ)
143      : SpecialTypeDescriptor(SpecialTypeKind::STK_Pipe) {
144    Hash = (AQ << 8) | Kind;
145  }
146
147  static bool classof(const SpecialTypeDescriptor *TD) {
148    return TD->Kind == SpecialTypeKind::STK_Pipe;
149  }
150};
151
152struct DeviceEventTypeDescriptor : public SpecialTypeDescriptor {
153
154  DeviceEventTypeDescriptor()
155      : SpecialTypeDescriptor(SpecialTypeKind::STK_DeviceEvent) {
156    Hash = Kind;
157  }
158
159  static bool classof(const SpecialTypeDescriptor *TD) {
160    return TD->Kind == SpecialTypeKind::STK_DeviceEvent;
161  }
162};
163} // namespace SPIRV
164
165template <> struct DenseMapInfo<SPIRV::SpecialTypeDescriptor> {
166  static inline SPIRV::SpecialTypeDescriptor getEmptyKey() {
167    return SPIRV::SpecialTypeDescriptor(
168        SPIRV::SpecialTypeDescriptor::STK_Empty);
169  }
170  static inline SPIRV::SpecialTypeDescriptor getTombstoneKey() {
171    return SPIRV::SpecialTypeDescriptor(SPIRV::SpecialTypeDescriptor::STK_Last);
172  }
173  static unsigned getHashValue(SPIRV::SpecialTypeDescriptor Val) {
174    return Val.getHash();
175  }
176  static bool isEqual(SPIRV::SpecialTypeDescriptor LHS,
177                      SPIRV::SpecialTypeDescriptor RHS) {
178    return getHashValue(LHS) == getHashValue(RHS);
179  }
180};
181
182template <typename KeyTy> class SPIRVDuplicatesTrackerBase {
183public:
184  // NOTE: using MapVector instead of DenseMap helps getting everything ordered
185  // in a stable manner for a price of extra (NumKeys)*PtrSize memory and
186  // expensive removals which don't happen anyway.
187  using StorageTy = MapVector<KeyTy, SPIRV::DTSortableEntry>;
188
189private:
190  StorageTy Storage;
191
192public:
193  void add(KeyTy V, const MachineFunction *MF, Register R) {
194    if (find(V, MF).isValid())
195      return;
196
197    Storage[V][MF] = R;
198    if (std::is_same<Function,
199                     typename std::remove_const<
200                         typename std::remove_pointer<KeyTy>::type>::type>() ||
201        std::is_same<Argument,
202                     typename std::remove_const<
203                         typename std::remove_pointer<KeyTy>::type>::type>())
204      Storage[V].setIsFunc(true);
205    if (std::is_same<GlobalVariable,
206                     typename std::remove_const<
207                         typename std::remove_pointer<KeyTy>::type>::type>())
208      Storage[V].setIsGV(true);
209  }
210
211  Register find(KeyTy V, const MachineFunction *MF) const {
212    auto iter = Storage.find(V);
213    if (iter != Storage.end()) {
214      auto Map = iter->second;
215      auto iter2 = Map.find(MF);
216      if (iter2 != Map.end())
217        return iter2->second;
218    }
219    return Register();
220  }
221
222  const StorageTy &getAllUses() const { return Storage; }
223
224private:
225  StorageTy &getAllUses() { return Storage; }
226
227  // The friend class needs to have access to the internal storage
228  // to be able to build dependency graph, can't declare only one
229  // function a 'friend' due to the incomplete declaration at this point
230  // and mutual dependency problems.
231  friend class SPIRVGeneralDuplicatesTracker;
232};
233
234template <typename T>
235class SPIRVDuplicatesTracker : public SPIRVDuplicatesTrackerBase<const T *> {};
236
237template <>
238class SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor>
239    : public SPIRVDuplicatesTrackerBase<SPIRV::SpecialTypeDescriptor> {};
240
241class SPIRVGeneralDuplicatesTracker {
242  SPIRVDuplicatesTracker<Type> TT;
243  SPIRVDuplicatesTracker<Constant> CT;
244  SPIRVDuplicatesTracker<GlobalVariable> GT;
245  SPIRVDuplicatesTracker<Function> FT;
246  SPIRVDuplicatesTracker<Argument> AT;
247  SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
248
249  // NOTE: using MOs instead of regs to get rid of MF dependency to be able
250  // to use flat data structure.
251  // NOTE: replacing DenseMap with MapVector doesn't affect overall correctness
252  // but makes LITs more stable, should prefer DenseMap still due to
253  // significant perf difference.
254  using SPIRVReg2EntryTy =
255      MapVector<MachineOperand *, SPIRV::DTSortableEntry *>;
256
257  template <typename T>
258  void prebuildReg2Entry(SPIRVDuplicatesTracker<T> &DT,
259                         SPIRVReg2EntryTy &Reg2Entry);
260
261public:
262  void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph,
263                      MachineModuleInfo *MMI);
264
265  void add(const Type *T, const MachineFunction *MF, Register R) {
266    TT.add(T, MF, R);
267  }
268
269  void add(const Constant *C, const MachineFunction *MF, Register R) {
270    CT.add(C, MF, R);
271  }
272
273  void add(const GlobalVariable *GV, const MachineFunction *MF, Register R) {
274    GT.add(GV, MF, R);
275  }
276
277  void add(const Function *F, const MachineFunction *MF, Register R) {
278    FT.add(F, MF, R);
279  }
280
281  void add(const Argument *Arg, const MachineFunction *MF, Register R) {
282    AT.add(Arg, MF, R);
283  }
284
285  void add(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF,
286           Register R) {
287    ST.add(TD, MF, R);
288  }
289
290  Register find(const Type *T, const MachineFunction *MF) {
291    return TT.find(const_cast<Type *>(T), MF);
292  }
293
294  Register find(const Constant *C, const MachineFunction *MF) {
295    return CT.find(const_cast<Constant *>(C), MF);
296  }
297
298  Register find(const GlobalVariable *GV, const MachineFunction *MF) {
299    return GT.find(const_cast<GlobalVariable *>(GV), MF);
300  }
301
302  Register find(const Function *F, const MachineFunction *MF) {
303    return FT.find(const_cast<Function *>(F), MF);
304  }
305
306  Register find(const Argument *Arg, const MachineFunction *MF) {
307    return AT.find(const_cast<Argument *>(Arg), MF);
308  }
309
310  Register find(const SPIRV::SpecialTypeDescriptor &TD,
311                const MachineFunction *MF) {
312    return ST.find(TD, MF);
313  }
314
315  const SPIRVDuplicatesTracker<Type> *getTypes() { return &TT; }
316};
317} // namespace llvm
318#endif // LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
319