1//===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file
10/// This pass resolves calls to OpenCL image attribute, image resource ID and
11/// sampler resource ID getter functions.
12///
13/// Image attributes (size and format) are expected to be passed to the kernel
14/// as kernel arguments immediately following the image argument itself,
15/// therefore this pass adds image size and format arguments to the kernel
16/// functions in the module. The kernel functions with image arguments are
/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
19/// Note: this pass may invalidate pointers to functions.
20///
21/// Resource IDs of read-only images, write-only images and samplers are
22/// defined to be their index among the kernel arguments of the same
23/// type and access qualifier.
24//
25//===----------------------------------------------------------------------===//
26
27#include "AMDGPU.h"
28#include "llvm/ADT/SmallVector.h"
29#include "llvm/ADT/StringRef.h"
30#include "llvm/ADT/Twine.h"
31#include "llvm/IR/Argument.h"
32#include "llvm/IR/DerivedTypes.h"
33#include "llvm/IR/Constants.h"
34#include "llvm/IR/Function.h"
35#include "llvm/IR/Instruction.h"
36#include "llvm/IR/Instructions.h"
37#include "llvm/IR/Metadata.h"
38#include "llvm/IR/Module.h"
39#include "llvm/IR/Type.h"
40#include "llvm/IR/Use.h"
41#include "llvm/IR/User.h"
42#include "llvm/Pass.h"
43#include "llvm/Support/Casting.h"
44#include "llvm/Support/ErrorHandling.h"
45#include "llvm/Transforms/Utils/Cloning.h"
46#include "llvm/Transforms/Utils/ValueMapper.h"
47#include <cassert>
48#include <cstddef>
49#include <cstdint>
50#include <tuple>
51
52using namespace llvm;
53
54static StringRef GetImageSizeFunc =         "llvm.OpenCL.image.get.size";
55static StringRef GetImageFormatFunc =       "llvm.OpenCL.image.get.format";
56static StringRef GetImageResourceIDFunc =   "llvm.OpenCL.image.get.resource.id";
57static StringRef GetSamplerResourceIDFunc =
58    "llvm.OpenCL.sampler.get.resource.id";
59
60static StringRef ImageSizeArgMDType =   "__llvm_image_size";
61static StringRef ImageFormatArgMDType = "__llvm_image_format";
62
63static StringRef KernelsMDNodeName = "opencl.kernels";
64static StringRef KernelArgMDNodeNames[] = {
65  "kernel_arg_addr_space",
66  "kernel_arg_access_qual",
67  "kernel_arg_type",
68  "kernel_arg_base_type",
69  "kernel_arg_type_qual"};
70static const unsigned NumKernelArgMDNodes = 5;
71
72namespace {
73
74using MDVector = SmallVector<Metadata *, 8>;
75struct KernelArgMD {
76  MDVector ArgVector[NumKernelArgMDNodes];
77};
78
79} // end anonymous namespace
80
81static inline bool
82IsImageType(StringRef TypeString) {
83  return TypeString == "image2d_t" || TypeString == "image3d_t";
84}
85
86static inline bool
87IsSamplerType(StringRef TypeString) {
88  return TypeString == "sampler_t";
89}
90
91static Function *
92GetFunctionFromMDNode(MDNode *Node) {
93  if (!Node)
94    return nullptr;
95
96  size_t NumOps = Node->getNumOperands();
97  if (NumOps != NumKernelArgMDNodes + 1)
98    return nullptr;
99
100  auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
101  if (!F)
102    return nullptr;
103
104  // Sanity checks.
105  size_t ExpectNumArgNodeOps = F->arg_size() + 1;
106  for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
107    MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
108    if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
109      return nullptr;
110    if (!ArgNode->getOperand(0))
111      return nullptr;
112
113    // FIXME: It should be possible to do image lowering when some metadata
114    // args missing or not in the expected order.
115    MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
116    if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
117      return nullptr;
118  }
119
120  return F;
121}
122
123static StringRef
124AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
125  MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
126  return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
127}
128
129static StringRef
130ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
131  MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
132  return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
133}
134
135static MDVector
136GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
137  MDVector Res;
138  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
139    MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
140    Res.push_back(Node->getOperand(OpIdx));
141  }
142  return Res;
143}
144
145static void
146PushArgMD(KernelArgMD &MD, const MDVector &V) {
147  assert(V.size() == NumKernelArgMDNodes);
148  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
149    MD.ArgVector[i].push_back(V[i]);
150  }
151}
152
153namespace {
154
155class R600OpenCLImageTypeLoweringPass : public ModulePass {
156  static char ID;
157
158  LLVMContext *Context;
159  Type *Int32Type;
160  Type *ImageSizeType;
161  Type *ImageFormatType;
162  SmallVector<Instruction *, 4> InstsToErase;
163
164  bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
165                        Argument &ImageSizeArg,
166                        Argument &ImageFormatArg) {
167    bool Modified = false;
168
169    for (auto &Use : ImageArg.uses()) {
170      auto Inst = dyn_cast<CallInst>(Use.getUser());
171      if (!Inst) {
172        continue;
173      }
174
175      Function *F = Inst->getCalledFunction();
176      if (!F)
177        continue;
178
179      Value *Replacement = nullptr;
180      StringRef Name = F->getName();
181      if (Name.startswith(GetImageResourceIDFunc)) {
182        Replacement = ConstantInt::get(Int32Type, ResourceID);
183      } else if (Name.startswith(GetImageSizeFunc)) {
184        Replacement = &ImageSizeArg;
185      } else if (Name.startswith(GetImageFormatFunc)) {
186        Replacement = &ImageFormatArg;
187      } else {
188        continue;
189      }
190
191      Inst->replaceAllUsesWith(Replacement);
192      InstsToErase.push_back(Inst);
193      Modified = true;
194    }
195
196    return Modified;
197  }
198
199  bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
200    bool Modified = false;
201
202    for (const auto &Use : SamplerArg.uses()) {
203      auto Inst = dyn_cast<CallInst>(Use.getUser());
204      if (!Inst) {
205        continue;
206      }
207
208      Function *F = Inst->getCalledFunction();
209      if (!F)
210        continue;
211
212      Value *Replacement = nullptr;
213      StringRef Name = F->getName();
214      if (Name == GetSamplerResourceIDFunc) {
215        Replacement = ConstantInt::get(Int32Type, ResourceID);
216      } else {
217        continue;
218      }
219
220      Inst->replaceAllUsesWith(Replacement);
221      InstsToErase.push_back(Inst);
222      Modified = true;
223    }
224
225    return Modified;
226  }
227
228  bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
229    uint32_t NumReadOnlyImageArgs = 0;
230    uint32_t NumWriteOnlyImageArgs = 0;
231    uint32_t NumSamplerArgs = 0;
232
233    bool Modified = false;
234    InstsToErase.clear();
235    for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
236      Argument &Arg = *ArgI;
237      StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
238
239      // Handle image types.
240      if (IsImageType(Type)) {
241        StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
242        uint32_t ResourceID;
243        if (AccessQual == "read_only") {
244          ResourceID = NumReadOnlyImageArgs++;
245        } else if (AccessQual == "write_only") {
246          ResourceID = NumWriteOnlyImageArgs++;
247        } else {
248          llvm_unreachable("Wrong image access qualifier.");
249        }
250
251        Argument &SizeArg = *(++ArgI);
252        Argument &FormatArg = *(++ArgI);
253        Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
254
255      // Handle sampler type.
256      } else if (IsSamplerType(Type)) {
257        uint32_t ResourceID = NumSamplerArgs++;
258        Modified |= replaceSamplerUses(Arg, ResourceID);
259      }
260    }
261    for (unsigned i = 0; i < InstsToErase.size(); ++i) {
262      InstsToErase[i]->eraseFromParent();
263    }
264
265    return Modified;
266  }
267
268  std::tuple<Function *, MDNode *>
269  addImplicitArgs(Function *F, MDNode *KernelMDNode) {
270    bool Modified = false;
271
272    FunctionType *FT = F->getFunctionType();
273    SmallVector<Type *, 8> ArgTypes;
274
275    // Metadata operands for new MDNode.
276    KernelArgMD NewArgMDs;
277    PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
278
279    // Add implicit arguments to the signature.
280    for (unsigned i = 0; i < FT->getNumParams(); ++i) {
281      ArgTypes.push_back(FT->getParamType(i));
282      MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
283      PushArgMD(NewArgMDs, ArgMD);
284
285      if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
286        continue;
287
288      // Add size implicit argument.
289      ArgTypes.push_back(ImageSizeType);
290      ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
291      PushArgMD(NewArgMDs, ArgMD);
292
293      // Add format implicit argument.
294      ArgTypes.push_back(ImageFormatType);
295      ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
296      PushArgMD(NewArgMDs, ArgMD);
297
298      Modified = true;
299    }
300    if (!Modified) {
301      return std::make_tuple(nullptr, nullptr);
302    }
303
304    // Create function with new signature and clone the old body into it.
305    auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
306    auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
307    ValueToValueMapTy VMap;
308    auto NewFArgIt = NewF->arg_begin();
309    for (auto &Arg: F->args()) {
310      auto ArgName = Arg.getName();
311      NewFArgIt->setName(ArgName);
312      VMap[&Arg] = &(*NewFArgIt++);
313      if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
314        (NewFArgIt++)->setName(Twine("__size_") + ArgName);
315        (NewFArgIt++)->setName(Twine("__format_") + ArgName);
316      }
317    }
318    SmallVector<ReturnInst*, 8> Returns;
319    CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
320
321    // Build new MDNode.
322    SmallVector<Metadata *, 6> KernelMDArgs;
323    KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
324    for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
325      KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
326    MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
327
328    return std::make_tuple(NewF, NewMDNode);
329  }
330
331  bool transformKernels(Module &M) {
332    NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
333    if (!KernelsMDNode)
334      return false;
335
336    bool Modified = false;
337    for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
338      MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
339      Function *F = GetFunctionFromMDNode(KernelMDNode);
340      if (!F)
341        continue;
342
343      Function *NewF;
344      MDNode *NewMDNode;
345      std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
346      if (NewF) {
347        // Replace old function and metadata with new ones.
348        F->eraseFromParent();
349        M.getFunctionList().push_back(NewF);
350        M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
351                              NewF->getAttributes());
352        KernelsMDNode->setOperand(i, NewMDNode);
353
354        F = NewF;
355        KernelMDNode = NewMDNode;
356        Modified = true;
357      }
358
359      Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
360    }
361
362    return Modified;
363  }
364
365public:
366  R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
367
368  bool runOnModule(Module &M) override {
369    Context = &M.getContext();
370    Int32Type = Type::getInt32Ty(M.getContext());
371    ImageSizeType = ArrayType::get(Int32Type, 3);
372    ImageFormatType = ArrayType::get(Int32Type, 2);
373
374    return transformKernels(M);
375  }
376
377  StringRef getPassName() const override {
378    return "R600 OpenCL Image Type Pass";
379  }
380};
381
382} // end anonymous namespace
383
384char R600OpenCLImageTypeLoweringPass::ID = 0;
385
386ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() {
387  return new R600OpenCLImageTypeLoweringPass();
388}
389