//===- DXILOpBuilder.cpp - Helper class for building DXILOp functions -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
/// \file This file contains a class to help build DXIL op functions.
10//===----------------------------------------------------------------------===//
11
12#include "DXILOpBuilder.h"
13#include "DXILConstants.h"
14#include "llvm/IR/IRBuilder.h"
15#include "llvm/IR/Module.h"
16#include "llvm/Support/DXILOperationCommon.h"
17#include "llvm/Support/ErrorHandling.h"
18
19using namespace llvm;
20using namespace llvm::dxil;
21
22constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
23
namespace {

// One bit per overload type a DXIL op can be instantiated with. An op's
// OverloadTys field is a mask of these bits, tested against a concrete kind
// when the op function is created.
//
// The scalar kinds (VOID through I64) must keep values below UserDefineType:
// several helpers identify them with `Kind < OverloadKind::UserDefineType`.
enum OverloadKind : uint16_t {
  VOID = 0x0001,
  HALF = 0x0002,
  FLOAT = 0x0004,
  DOUBLE = 0x0008,
  I1 = 0x0010,
  I8 = 0x0020,
  I16 = 0x0040,
  I32 = 0x0080,
  I64 = 0x0100,
  UserDefineType = 0x0200,
  ObjectType = 0x0400,
};

} // namespace
41
42static const char *getOverloadTypeName(OverloadKind Kind) {
43  switch (Kind) {
44  case OverloadKind::HALF:
45    return "f16";
46  case OverloadKind::FLOAT:
47    return "f32";
48  case OverloadKind::DOUBLE:
49    return "f64";
50  case OverloadKind::I1:
51    return "i1";
52  case OverloadKind::I8:
53    return "i8";
54  case OverloadKind::I16:
55    return "i16";
56  case OverloadKind::I32:
57    return "i32";
58  case OverloadKind::I64:
59    return "i64";
60  case OverloadKind::VOID:
61  case OverloadKind::ObjectType:
62  case OverloadKind::UserDefineType:
63    break;
64  }
65  llvm_unreachable("invalid overload type for name");
66  return "void";
67}
68
69static OverloadKind getOverloadKind(Type *Ty) {
70  Type::TypeID T = Ty->getTypeID();
71  switch (T) {
72  case Type::VoidTyID:
73    return OverloadKind::VOID;
74  case Type::HalfTyID:
75    return OverloadKind::HALF;
76  case Type::FloatTyID:
77    return OverloadKind::FLOAT;
78  case Type::DoubleTyID:
79    return OverloadKind::DOUBLE;
80  case Type::IntegerTyID: {
81    IntegerType *ITy = cast<IntegerType>(Ty);
82    unsigned Bits = ITy->getBitWidth();
83    switch (Bits) {
84    case 1:
85      return OverloadKind::I1;
86    case 8:
87      return OverloadKind::I8;
88    case 16:
89      return OverloadKind::I16;
90    case 32:
91      return OverloadKind::I32;
92    case 64:
93      return OverloadKind::I64;
94    default:
95      llvm_unreachable("invalid overload type");
96      return OverloadKind::VOID;
97    }
98  }
99  case Type::PointerTyID:
100    return OverloadKind::UserDefineType;
101  case Type::StructTyID:
102    return OverloadKind::ObjectType;
103  default:
104    llvm_unreachable("invalid overload type");
105    return OverloadKind::VOID;
106  }
107}
108
109static std::string getTypeName(OverloadKind Kind, Type *Ty) {
110  if (Kind < OverloadKind::UserDefineType) {
111    return getOverloadTypeName(Kind);
112  } else if (Kind == OverloadKind::UserDefineType) {
113    StructType *ST = cast<StructType>(Ty);
114    return ST->getStructName().str();
115  } else if (Kind == OverloadKind::ObjectType) {
116    StructType *ST = cast<StructType>(Ty);
117    return ST->getStructName().str();
118  } else {
119    std::string Str;
120    raw_string_ostream OS(Str);
121    Ty->print(OS);
122    return OS.str();
123  }
124}
125
// Static, per-opcode properties of a DXIL operation. Instances come from the
// TableGen-generated DXILOperation.inc (looked up via getOpCodeProperty). The
// generated table presumably initializes these fields positionally, so keep
// the field order in sync with the generator.
struct OpCodeProperty {
  dxil::OpCode OpCode;
  // Offset in DXILOpCodeNameTable.
  unsigned OpCodeNameOffset;
  dxil::OpCodeClass OpCodeClass;
  // Offset in DXILOpCodeClassNameTable.
  unsigned OpCodeClassNameOffset;
  // Mask of OverloadKind bits this op may be instantiated with; a requested
  // overload whose bit is not set is rejected.
  uint16_t OverloadTys;
  // Function attribute associated with the op (not consumed in this file).
  llvm::Attribute::AttrKind FuncAttr;
  // Index into the op's parameter-kind table of the entry whose type selects
  // the overload; entry 0 is the return value.
  // When < 0, there is only one overload type, taken from OverloadTys.
  int OverloadParamIndex;
  // Number of entries in the op's parameter-kind table, including the return
  // value (entry 0).
  unsigned NumOfParameters;
  // Offset in ParameterTable.
  unsigned ParameterTableOffset;
};
141
// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
// getOpCodeParameterKind, which are generated by TableGen.
144#define DXIL_OP_OPERATION_TABLE
145#include "DXILOperation.inc"
146#undef DXIL_OP_OPERATION_TABLE
147
148static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
149                                         const OpCodeProperty &Prop) {
150  if (Kind == OverloadKind::VOID) {
151    return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
152  }
153  return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
154          getTypeName(Kind, Ty))
155      .str();
156}
157
158static std::string constructOverloadTypeName(OverloadKind Kind,
159                                             StringRef TypeName) {
160  if (Kind == OverloadKind::VOID)
161    return TypeName.str();
162
163  assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
164  return (Twine(TypeName) + getOverloadTypeName(Kind)).str();
165}
166
167static StructType *getOrCreateStructType(StringRef Name,
168                                         ArrayRef<Type *> EltTys,
169                                         LLVMContext &Ctx) {
170  StructType *ST = StructType::getTypeByName(Ctx, Name);
171  if (ST)
172    return ST;
173
174  return StructType::create(Ctx, EltTys, Name);
175}
176
177static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
178  OverloadKind Kind = getOverloadKind(OverloadTy);
179  std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
180  Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
181                         Type::getInt32Ty(Ctx)};
182  return getOrCreateStructType(TypeName, FieldTypes, Ctx);
183}
184
185static StructType *getHandleType(LLVMContext &Ctx) {
186  return getOrCreateStructType("dx.types.Handle", PointerType::getUnqual(Ctx),
187                               Ctx);
188}
189
190static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
191  auto &Ctx = OverloadTy->getContext();
192  switch (Kind) {
193  case ParameterKind::VOID:
194    return Type::getVoidTy(Ctx);
195  case ParameterKind::HALF:
196    return Type::getHalfTy(Ctx);
197  case ParameterKind::FLOAT:
198    return Type::getFloatTy(Ctx);
199  case ParameterKind::DOUBLE:
200    return Type::getDoubleTy(Ctx);
201  case ParameterKind::I1:
202    return Type::getInt1Ty(Ctx);
203  case ParameterKind::I8:
204    return Type::getInt8Ty(Ctx);
205  case ParameterKind::I16:
206    return Type::getInt16Ty(Ctx);
207  case ParameterKind::I32:
208    return Type::getInt32Ty(Ctx);
209  case ParameterKind::I64:
210    return Type::getInt64Ty(Ctx);
211  case ParameterKind::OVERLOAD:
212    return OverloadTy;
213  case ParameterKind::RESOURCE_RET:
214    return getResRetType(OverloadTy, Ctx);
215  case ParameterKind::DXIL_HANDLE:
216    return getHandleType(Ctx);
217  default:
218    break;
219  }
220  llvm_unreachable("Invalid parameter kind");
221  return nullptr;
222}
223
224static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
225                                           Type *OverloadTy) {
226  SmallVector<Type *> ArgTys;
227
228  auto ParamKinds = getOpCodeParameterKind(*Prop);
229
230  for (unsigned I = 0; I < Prop->NumOfParameters; ++I) {
231    ParameterKind Kind = ParamKinds[I];
232    ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy));
233  }
234  return FunctionType::get(
235      ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false);
236}
237
238static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
239                                                Type *OverloadTy, Module &M) {
240  const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
241
242  OverloadKind Kind = getOverloadKind(OverloadTy);
243  // FIXME: find the issue and report error in clang instead of check it in
244  // backend.
245  if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
246    llvm_unreachable("invalid overload");
247  }
248
249  std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
250  // Dependent on name to dedup.
251  if (auto *Fn = M.getFunction(FnName))
252    return FunctionCallee(Fn);
253
254  FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy);
255  return M.getOrInsertFunction(FnName, DXILOpFT);
256}
257
258namespace llvm {
259namespace dxil {
260
261CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
262                                          llvm::iterator_range<Use *> Args) {
263  auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M);
264  SmallVector<Value *> FullArgs;
265  FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
266  FullArgs.append(Args.begin(), Args.end());
267  return B.CreateCall(Fn, FullArgs);
268}
269
270Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT,
271                                   bool NoOpCodeParam) {
272
273  const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
274  if (Prop->OverloadParamIndex < 0) {
275    auto &Ctx = FT->getContext();
276    // When only has 1 overload type, just return it.
277    switch (Prop->OverloadTys) {
278    case OverloadKind::VOID:
279      return Type::getVoidTy(Ctx);
280    case OverloadKind::HALF:
281      return Type::getHalfTy(Ctx);
282    case OverloadKind::FLOAT:
283      return Type::getFloatTy(Ctx);
284    case OverloadKind::DOUBLE:
285      return Type::getDoubleTy(Ctx);
286    case OverloadKind::I1:
287      return Type::getInt1Ty(Ctx);
288    case OverloadKind::I8:
289      return Type::getInt8Ty(Ctx);
290    case OverloadKind::I16:
291      return Type::getInt16Ty(Ctx);
292    case OverloadKind::I32:
293      return Type::getInt32Ty(Ctx);
294    case OverloadKind::I64:
295      return Type::getInt64Ty(Ctx);
296    default:
297      llvm_unreachable("invalid overload type");
298      return nullptr;
299    }
300  }
301
302  // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().
303  Type *OverloadType = FT->getReturnType();
304  if (Prop->OverloadParamIndex != 0) {
305    // Skip Return Type and Type for DXIL opcode.
306    const unsigned SkipedParam = NoOpCodeParam ? 2 : 1;
307    OverloadType = FT->getParamType(Prop->OverloadParamIndex - SkipedParam);
308  }
309
310  auto ParamKinds = getOpCodeParameterKind(*Prop);
311  auto Kind = ParamKinds[Prop->OverloadParamIndex];
312  // For ResRet and CBufferRet, OverloadTy is in field of StructType.
313  if (Kind == ParameterKind::CBUFFER_RET ||
314      Kind == ParameterKind::RESOURCE_RET) {
315    auto *ST = cast<StructType>(OverloadType);
316    OverloadType = ST->getElementType(0);
317  }
318  return OverloadType;
319}
320
321const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) {
322  return ::getOpCodeName(DXILOp);
323}
324} // namespace dxil
325} // namespace llvm
326