1//===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file
10/// This pass resolves calls to OpenCL image attribute, image resource ID and
11/// sampler resource ID getter functions.
12///
13/// Image attributes (size and format) are expected to be passed to the kernel
14/// as kernel arguments immediately following the image argument itself,
15/// therefore this pass adds image size and format arguments to the kernel
16/// functions in the module. The kernel functions with image arguments are
17/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
19/// Note: this pass may invalidate pointers to functions.
20///
21/// Resource IDs of read-only images, write-only images and samplers are
22/// defined to be their index among the kernel arguments of the same
23/// type and access qualifier.
24//
25//===----------------------------------------------------------------------===//
26
#include "AMDGPU.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <iterator>
36
37using namespace llvm;
38
39static StringRef GetImageSizeFunc =         "llvm.OpenCL.image.get.size";
40static StringRef GetImageFormatFunc =       "llvm.OpenCL.image.get.format";
41static StringRef GetImageResourceIDFunc =   "llvm.OpenCL.image.get.resource.id";
42static StringRef GetSamplerResourceIDFunc =
43    "llvm.OpenCL.sampler.get.resource.id";
44
45static StringRef ImageSizeArgMDType =   "__llvm_image_size";
46static StringRef ImageFormatArgMDType = "__llvm_image_format";
47
48static StringRef KernelsMDNodeName = "opencl.kernels";
49static StringRef KernelArgMDNodeNames[] = {
50  "kernel_arg_addr_space",
51  "kernel_arg_access_qual",
52  "kernel_arg_type",
53  "kernel_arg_base_type",
54  "kernel_arg_type_qual"};
55static const unsigned NumKernelArgMDNodes = 5;
56
namespace {

/// A small list of metadata operands (used both for one argument's entries
/// across all per-argument nodes and for the operand list of one node).
using MDVector = SmallVector<Metadata *, 8>;
/// Operand lists being accumulated for the NumKernelArgMDNodes per-argument
/// metadata nodes of a kernel; ArgVector[i] becomes the node for
/// KernelArgMDNodeNames[i].
struct KernelArgMD {
  MDVector ArgVector[NumKernelArgMDNodes];
};

} // end anonymous namespace
65
66static inline bool
67IsImageType(StringRef TypeString) {
68  return TypeString == "image2d_t" || TypeString == "image3d_t";
69}
70
71static inline bool
72IsSamplerType(StringRef TypeString) {
73  return TypeString == "sampler_t";
74}
75
76static Function *
77GetFunctionFromMDNode(MDNode *Node) {
78  if (!Node)
79    return nullptr;
80
81  size_t NumOps = Node->getNumOperands();
82  if (NumOps != NumKernelArgMDNodes + 1)
83    return nullptr;
84
85  auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
86  if (!F)
87    return nullptr;
88
89  // Sanity checks.
90  size_t ExpectNumArgNodeOps = F->arg_size() + 1;
91  for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
92    MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
93    if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
94      return nullptr;
95    if (!ArgNode->getOperand(0))
96      return nullptr;
97
98    // FIXME: It should be possible to do image lowering when some metadata
99    // args missing or not in the expected order.
100    MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
101    if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
102      return nullptr;
103  }
104
105  return F;
106}
107
108static StringRef
109AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
110  MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
111  return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
112}
113
114static StringRef
115ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
116  MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
117  return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
118}
119
120static MDVector
121GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
122  MDVector Res;
123  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
124    MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
125    Res.push_back(Node->getOperand(OpIdx));
126  }
127  return Res;
128}
129
130static void
131PushArgMD(KernelArgMD &MD, const MDVector &V) {
132  assert(V.size() == NumKernelArgMDNodes);
133  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
134    MD.ArgVector[i].push_back(V[i]);
135  }
136}
137
138namespace {
139
140class R600OpenCLImageTypeLoweringPass : public ModulePass {
141  static char ID;
142
143  LLVMContext *Context;
144  Type *Int32Type;
145  Type *ImageSizeType;
146  Type *ImageFormatType;
147  SmallVector<Instruction *, 4> InstsToErase;
148
149  bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
150                        Argument &ImageSizeArg,
151                        Argument &ImageFormatArg) {
152    bool Modified = false;
153
154    for (auto &Use : ImageArg.uses()) {
155      auto Inst = dyn_cast<CallInst>(Use.getUser());
156      if (!Inst) {
157        continue;
158      }
159
160      Function *F = Inst->getCalledFunction();
161      if (!F)
162        continue;
163
164      Value *Replacement = nullptr;
165      StringRef Name = F->getName();
166      if (Name.startswith(GetImageResourceIDFunc)) {
167        Replacement = ConstantInt::get(Int32Type, ResourceID);
168      } else if (Name.startswith(GetImageSizeFunc)) {
169        Replacement = &ImageSizeArg;
170      } else if (Name.startswith(GetImageFormatFunc)) {
171        Replacement = &ImageFormatArg;
172      } else {
173        continue;
174      }
175
176      Inst->replaceAllUsesWith(Replacement);
177      InstsToErase.push_back(Inst);
178      Modified = true;
179    }
180
181    return Modified;
182  }
183
184  bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
185    bool Modified = false;
186
187    for (const auto &Use : SamplerArg.uses()) {
188      auto Inst = dyn_cast<CallInst>(Use.getUser());
189      if (!Inst) {
190        continue;
191      }
192
193      Function *F = Inst->getCalledFunction();
194      if (!F)
195        continue;
196
197      Value *Replacement = nullptr;
198      StringRef Name = F->getName();
199      if (Name == GetSamplerResourceIDFunc) {
200        Replacement = ConstantInt::get(Int32Type, ResourceID);
201      } else {
202        continue;
203      }
204
205      Inst->replaceAllUsesWith(Replacement);
206      InstsToErase.push_back(Inst);
207      Modified = true;
208    }
209
210    return Modified;
211  }
212
213  bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
214    uint32_t NumReadOnlyImageArgs = 0;
215    uint32_t NumWriteOnlyImageArgs = 0;
216    uint32_t NumSamplerArgs = 0;
217
218    bool Modified = false;
219    InstsToErase.clear();
220    for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
221      Argument &Arg = *ArgI;
222      StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
223
224      // Handle image types.
225      if (IsImageType(Type)) {
226        StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
227        uint32_t ResourceID;
228        if (AccessQual == "read_only") {
229          ResourceID = NumReadOnlyImageArgs++;
230        } else if (AccessQual == "write_only") {
231          ResourceID = NumWriteOnlyImageArgs++;
232        } else {
233          llvm_unreachable("Wrong image access qualifier.");
234        }
235
236        Argument &SizeArg = *(++ArgI);
237        Argument &FormatArg = *(++ArgI);
238        Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
239
240      // Handle sampler type.
241      } else if (IsSamplerType(Type)) {
242        uint32_t ResourceID = NumSamplerArgs++;
243        Modified |= replaceSamplerUses(Arg, ResourceID);
244      }
245    }
246    for (unsigned i = 0; i < InstsToErase.size(); ++i) {
247      InstsToErase[i]->eraseFromParent();
248    }
249
250    return Modified;
251  }
252
253  std::tuple<Function *, MDNode *>
254  addImplicitArgs(Function *F, MDNode *KernelMDNode) {
255    bool Modified = false;
256
257    FunctionType *FT = F->getFunctionType();
258    SmallVector<Type *, 8> ArgTypes;
259
260    // Metadata operands for new MDNode.
261    KernelArgMD NewArgMDs;
262    PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
263
264    // Add implicit arguments to the signature.
265    for (unsigned i = 0; i < FT->getNumParams(); ++i) {
266      ArgTypes.push_back(FT->getParamType(i));
267      MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
268      PushArgMD(NewArgMDs, ArgMD);
269
270      if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
271        continue;
272
273      // Add size implicit argument.
274      ArgTypes.push_back(ImageSizeType);
275      ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
276      PushArgMD(NewArgMDs, ArgMD);
277
278      // Add format implicit argument.
279      ArgTypes.push_back(ImageFormatType);
280      ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
281      PushArgMD(NewArgMDs, ArgMD);
282
283      Modified = true;
284    }
285    if (!Modified) {
286      return std::make_tuple(nullptr, nullptr);
287    }
288
289    // Create function with new signature and clone the old body into it.
290    auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
291    auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
292    ValueToValueMapTy VMap;
293    auto NewFArgIt = NewF->arg_begin();
294    for (auto &Arg: F->args()) {
295      auto ArgName = Arg.getName();
296      NewFArgIt->setName(ArgName);
297      VMap[&Arg] = &(*NewFArgIt++);
298      if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
299        (NewFArgIt++)->setName(Twine("__size_") + ArgName);
300        (NewFArgIt++)->setName(Twine("__format_") + ArgName);
301      }
302    }
303    SmallVector<ReturnInst*, 8> Returns;
304    CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
305                      Returns);
306
307    // Build new MDNode.
308    SmallVector<Metadata *, 6> KernelMDArgs;
309    KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
310    for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
311      KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
312    MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
313
314    return std::make_tuple(NewF, NewMDNode);
315  }
316
317  bool transformKernels(Module &M) {
318    NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
319    if (!KernelsMDNode)
320      return false;
321
322    bool Modified = false;
323    for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
324      MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
325      Function *F = GetFunctionFromMDNode(KernelMDNode);
326      if (!F)
327        continue;
328
329      Function *NewF;
330      MDNode *NewMDNode;
331      std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
332      if (NewF) {
333        // Replace old function and metadata with new ones.
334        F->eraseFromParent();
335        M.getFunctionList().push_back(NewF);
336        M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
337                              NewF->getAttributes());
338        KernelsMDNode->setOperand(i, NewMDNode);
339
340        F = NewF;
341        KernelMDNode = NewMDNode;
342        Modified = true;
343      }
344
345      Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
346    }
347
348    return Modified;
349  }
350
351public:
352  R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
353
354  bool runOnModule(Module &M) override {
355    Context = &M.getContext();
356    Int32Type = Type::getInt32Ty(M.getContext());
357    ImageSizeType = ArrayType::get(Int32Type, 3);
358    ImageFormatType = ArrayType::get(Int32Type, 2);
359
360    return transformKernels(M);
361  }
362
363  StringRef getPassName() const override {
364    return "R600 OpenCL Image Type Pass";
365  }
366};
367
368} // end anonymous namespace
369
370char R600OpenCLImageTypeLoweringPass::ID = 0;
371
372ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() {
373  return new R600OpenCLImageTypeLoweringPass();
374}
375