1//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the targeting of the MachineLegalizer class for SPIR-V.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVLegalizerInfo.h"
14#include "SPIRV.h"
15#include "SPIRVGlobalRegistry.h"
16#include "SPIRVSubtarget.h"
17#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19#include "llvm/CodeGen/MachineInstr.h"
20#include "llvm/CodeGen/MachineRegisterInfo.h"
21#include "llvm/CodeGen/TargetOpcodes.h"
22
23using namespace llvm;
24using namespace llvm::LegalizeActions;
25using namespace llvm::LegalityPredicates;
26
27static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28    TargetOpcode::G_ADD,
29    TargetOpcode::G_FADD,
30    TargetOpcode::G_SUB,
31    TargetOpcode::G_FSUB,
32    TargetOpcode::G_MUL,
33    TargetOpcode::G_FMUL,
34    TargetOpcode::G_SDIV,
35    TargetOpcode::G_UDIV,
36    TargetOpcode::G_FDIV,
37    TargetOpcode::G_SREM,
38    TargetOpcode::G_UREM,
39    TargetOpcode::G_FREM,
40    TargetOpcode::G_FNEG,
41    TargetOpcode::G_CONSTANT,
42    TargetOpcode::G_FCONSTANT,
43    TargetOpcode::G_AND,
44    TargetOpcode::G_OR,
45    TargetOpcode::G_XOR,
46    TargetOpcode::G_SHL,
47    TargetOpcode::G_ASHR,
48    TargetOpcode::G_LSHR,
49    TargetOpcode::G_SELECT,
50    TargetOpcode::G_EXTRACT_VECTOR_ELT,
51};
52
53bool isTypeFoldingSupported(unsigned Opcode) {
54  return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55}
56
57SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
58  using namespace TargetOpcode;
59
60  this->ST = &ST;
61  GR = ST.getSPIRVGlobalRegistry();
62
63  const LLT s1 = LLT::scalar(1);
64  const LLT s8 = LLT::scalar(8);
65  const LLT s16 = LLT::scalar(16);
66  const LLT s32 = LLT::scalar(32);
67  const LLT s64 = LLT::scalar(64);
68
69  const LLT v16s64 = LLT::fixed_vector(16, 64);
70  const LLT v16s32 = LLT::fixed_vector(16, 32);
71  const LLT v16s16 = LLT::fixed_vector(16, 16);
72  const LLT v16s8 = LLT::fixed_vector(16, 8);
73  const LLT v16s1 = LLT::fixed_vector(16, 1);
74
75  const LLT v8s64 = LLT::fixed_vector(8, 64);
76  const LLT v8s32 = LLT::fixed_vector(8, 32);
77  const LLT v8s16 = LLT::fixed_vector(8, 16);
78  const LLT v8s8 = LLT::fixed_vector(8, 8);
79  const LLT v8s1 = LLT::fixed_vector(8, 1);
80
81  const LLT v4s64 = LLT::fixed_vector(4, 64);
82  const LLT v4s32 = LLT::fixed_vector(4, 32);
83  const LLT v4s16 = LLT::fixed_vector(4, 16);
84  const LLT v4s8 = LLT::fixed_vector(4, 8);
85  const LLT v4s1 = LLT::fixed_vector(4, 1);
86
87  const LLT v3s64 = LLT::fixed_vector(3, 64);
88  const LLT v3s32 = LLT::fixed_vector(3, 32);
89  const LLT v3s16 = LLT::fixed_vector(3, 16);
90  const LLT v3s8 = LLT::fixed_vector(3, 8);
91  const LLT v3s1 = LLT::fixed_vector(3, 1);
92
93  const LLT v2s64 = LLT::fixed_vector(2, 64);
94  const LLT v2s32 = LLT::fixed_vector(2, 32);
95  const LLT v2s16 = LLT::fixed_vector(2, 16);
96  const LLT v2s8 = LLT::fixed_vector(2, 8);
97  const LLT v2s1 = LLT::fixed_vector(2, 1);
98
99  const unsigned PSize = ST.getPointerSize();
100  const LLT p0 = LLT::pointer(0, PSize); // Function
101  const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102  const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103  const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104  const LLT p4 = LLT::pointer(4, PSize); // Generic
105  const LLT p5 = LLT::pointer(5, PSize); // Input
106
107  // TODO: remove copy-pasting here by using concatenation in some way.
108  auto allPtrsScalarsAndVectors = {
109      p0,    p1,    p2,    p3,    p4,    p5,    s1,     s8,     s16,
110      s32,   s64,   v2s1,  v2s8,  v2s16, v2s32, v2s64,  v3s1,   v3s8,
111      v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16, v4s32,  v4s64,  v8s1,
112      v8s8,  v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
113
114  auto allScalarsAndVectors = {
115      s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
116      v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
117      v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
118
119  auto allIntScalarsAndVectors = {s8,    s16,   s32,   s64,    v2s8,   v2s16,
120                                  v2s32, v2s64, v3s8,  v3s16,  v3s32,  v3s64,
121                                  v4s8,  v4s16, v4s32, v4s64,  v8s8,   v8s16,
122                                  v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
123
124  auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
125
126  auto allIntScalars = {s8, s16, s32, s64};
127
128  auto allFloatScalarsAndVectors = {
129      s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
130      v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
131
132  auto allFloatAndIntScalars = allIntScalars;
133
134  auto allPtrs = {p0, p1, p2, p3, p4, p5};
135  auto allWritablePtrs = {p0, p1, p3, p4};
136
137  for (auto Opc : TypeFoldingSupportingOpcs)
138    getActionDefinitionsBuilder(Opc).custom();
139
140  getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
141
142  // TODO: add proper rules for vectors legalization.
143  getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
144
145  getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
146      .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
147
148  getActionDefinitionsBuilder(G_MEMSET).legalIf(
149      all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars)));
150
151  getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
152      .legalForCartesianProduct(allPtrs, allPtrs);
153
154  getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
155
156  getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors);
157
158  getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
159
160  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
161      .legalForCartesianProduct(allIntScalarsAndVectors,
162                                allFloatScalarsAndVectors);
163
164  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
165      .legalForCartesianProduct(allFloatScalarsAndVectors,
166                                allScalarsAndVectors);
167
168  getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
169      .legalFor(allIntScalarsAndVectors);
170
171  getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
172      allIntScalarsAndVectors, allIntScalarsAndVectors);
173
174  getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
175
176  getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
177      typeInSet(0, allPtrsScalarsAndVectors),
178      typeInSet(1, allPtrsScalarsAndVectors),
179      LegalityPredicate(([=](const LegalityQuery &Query) {
180        return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
181      }))));
182
183  getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal();
184
185  getActionDefinitionsBuilder(G_INTTOPTR)
186      .legalForCartesianProduct(allPtrs, allIntScalars);
187  getActionDefinitionsBuilder(G_PTRTOINT)
188      .legalForCartesianProduct(allIntScalars, allPtrs);
189  getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
190      allPtrs, allIntScalars);
191
192  // ST.canDirectlyComparePointers() for pointer args is supported in
193  // legalizeCustom().
194  getActionDefinitionsBuilder(G_ICMP).customIf(
195      all(typeInSet(0, allBoolScalarsAndVectors),
196          typeInSet(1, allPtrsScalarsAndVectors)));
197
198  getActionDefinitionsBuilder(G_FCMP).legalIf(
199      all(typeInSet(0, allBoolScalarsAndVectors),
200          typeInSet(1, allFloatScalarsAndVectors)));
201
202  getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
203                               G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
204                               G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
205                               G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
206      .legalForCartesianProduct(allIntScalars, allWritablePtrs);
207
208  getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
209      .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
210
211  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
212  // TODO: add proper legalization rules.
213  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
214
215  getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
216      .alwaysLegal();
217
218  // Extensions.
219  getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
220      .legalForCartesianProduct(allScalarsAndVectors);
221
222  // FP conversions.
223  getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
224      .legalForCartesianProduct(allFloatScalarsAndVectors);
225
226  // Pointer-handling.
227  getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
228
229  // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
230  getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
231
232  // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
233  // tighten these requirements. Many of these math functions are only legal on
234  // specific bitwidths, so they are not selectable for
235  // allFloatScalarsAndVectors.
236  getActionDefinitionsBuilder({G_FPOW,
237                               G_FEXP,
238                               G_FEXP2,
239                               G_FLOG,
240                               G_FLOG2,
241                               G_FLOG10,
242                               G_FABS,
243                               G_FMINNUM,
244                               G_FMAXNUM,
245                               G_FCEIL,
246                               G_FCOS,
247                               G_FSIN,
248                               G_FSQRT,
249                               G_FFLOOR,
250                               G_FRINT,
251                               G_FNEARBYINT,
252                               G_INTRINSIC_ROUND,
253                               G_INTRINSIC_TRUNC,
254                               G_FMINIMUM,
255                               G_FMAXIMUM,
256                               G_INTRINSIC_ROUNDEVEN})
257      .legalFor(allFloatScalarsAndVectors);
258
259  getActionDefinitionsBuilder(G_FCOPYSIGN)
260      .legalForCartesianProduct(allFloatScalarsAndVectors,
261                                allFloatScalarsAndVectors);
262
263  getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
264      allFloatScalarsAndVectors, allIntScalarsAndVectors);
265
266  if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
267    getActionDefinitionsBuilder(
268        {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
269        .legalForCartesianProduct(allIntScalarsAndVectors,
270                                  allIntScalarsAndVectors);
271
272    // Struct return types become a single scalar, so cannot easily legalize.
273    getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
274  }
275
276  getLegacyLegalizerInfo().computeTables();
277  verify(*ST.getInstrInfo());
278}
279
280static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
281                                LegalizerHelper &Helper,
282                                MachineRegisterInfo &MRI,
283                                SPIRVGlobalRegistry *GR) {
284  Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
285  GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
286  Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
287      .addDef(ConvReg)
288      .addUse(Reg);
289  return ConvReg;
290}
291
292bool SPIRVLegalizerInfo::legalizeCustom(
293    LegalizerHelper &Helper, MachineInstr &MI,
294    LostDebugLocObserver &LocObserver) const {
295  auto Opc = MI.getOpcode();
296  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
297  if (!isTypeFoldingSupported(Opc)) {
298    assert(Opc == TargetOpcode::G_ICMP);
299    assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
300    auto &Op0 = MI.getOperand(2);
301    auto &Op1 = MI.getOperand(3);
302    Register Reg0 = Op0.getReg();
303    Register Reg1 = Op1.getReg();
304    CmpInst::Predicate Cond =
305        static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
306    if ((!ST->canDirectlyComparePointers() ||
307         (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
308        MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
309      LLT ConvT = LLT::scalar(ST->getPointerSize());
310      Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
311                                      ST->getPointerSize());
312      SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
313      Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
314      Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
315    }
316    return true;
317  }
318  // TODO: implement legalization for other opcodes.
319  return true;
320}
321