//===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass resolves calls to OpenCL image attribute, image resource ID and
/// sampler resource ID getter functions.
///
/// Image attributes (size and format) are expected to be passed to the kernel
/// as kernel arguments immediately following the image argument itself,
/// therefore this pass adds image size and format arguments to the kernel
/// functions in the module. The kernel functions with image arguments are
/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
/// Note: this pass may invalidate pointers to functions.
///
/// Resource IDs of read-only images, write-only images and samplers are
/// defined to be their index among the kernel arguments of the same
/// type and access qualifier.
//
//===----------------------------------------------------------------------===//
26336809Sdim
27336809Sdim#include "AMDGPU.h"
28336809Sdim#include "llvm/ADT/SmallVector.h"
29336809Sdim#include "llvm/ADT/StringRef.h"
30336809Sdim#include "llvm/ADT/Twine.h"
31336809Sdim#include "llvm/IR/Argument.h"
32336809Sdim#include "llvm/IR/DerivedTypes.h"
33336809Sdim#include "llvm/IR/Constants.h"
34336809Sdim#include "llvm/IR/Function.h"
35336809Sdim#include "llvm/IR/Instruction.h"
36336809Sdim#include "llvm/IR/Instructions.h"
37336809Sdim#include "llvm/IR/Metadata.h"
38336809Sdim#include "llvm/IR/Module.h"
39336809Sdim#include "llvm/IR/Type.h"
40336809Sdim#include "llvm/IR/Use.h"
41336809Sdim#include "llvm/IR/User.h"
42336809Sdim#include "llvm/Pass.h"
43336809Sdim#include "llvm/Support/Casting.h"
44336809Sdim#include "llvm/Support/ErrorHandling.h"
45336809Sdim#include "llvm/Transforms/Utils/Cloning.h"
46336809Sdim#include "llvm/Transforms/Utils/ValueMapper.h"
47336809Sdim#include <cassert>
48336809Sdim#include <cstddef>
49336809Sdim#include <cstdint>
50336809Sdim#include <tuple>
51336809Sdim
52336809Sdimusing namespace llvm;
53336809Sdim
54336809Sdimstatic StringRef GetImageSizeFunc =         "llvm.OpenCL.image.get.size";
55336809Sdimstatic StringRef GetImageFormatFunc =       "llvm.OpenCL.image.get.format";
56336809Sdimstatic StringRef GetImageResourceIDFunc =   "llvm.OpenCL.image.get.resource.id";
57336809Sdimstatic StringRef GetSamplerResourceIDFunc =
58336809Sdim    "llvm.OpenCL.sampler.get.resource.id";
59336809Sdim
60336809Sdimstatic StringRef ImageSizeArgMDType =   "__llvm_image_size";
61336809Sdimstatic StringRef ImageFormatArgMDType = "__llvm_image_format";
62336809Sdim
63336809Sdimstatic StringRef KernelsMDNodeName = "opencl.kernels";
64336809Sdimstatic StringRef KernelArgMDNodeNames[] = {
65336809Sdim  "kernel_arg_addr_space",
66336809Sdim  "kernel_arg_access_qual",
67336809Sdim  "kernel_arg_type",
68336809Sdim  "kernel_arg_base_type",
69336809Sdim  "kernel_arg_type_qual"};
70336809Sdimstatic const unsigned NumKernelArgMDNodes = 5;
71336809Sdim
72336809Sdimnamespace {
73336809Sdim
74336809Sdimusing MDVector = SmallVector<Metadata *, 8>;
75336809Sdimstruct KernelArgMD {
76336809Sdim  MDVector ArgVector[NumKernelArgMDNodes];
77336809Sdim};
78336809Sdim
79336809Sdim} // end anonymous namespace
80336809Sdim
81336809Sdimstatic inline bool
82336809SdimIsImageType(StringRef TypeString) {
83336809Sdim  return TypeString == "image2d_t" || TypeString == "image3d_t";
84336809Sdim}
85336809Sdim
86336809Sdimstatic inline bool
87336809SdimIsSamplerType(StringRef TypeString) {
88336809Sdim  return TypeString == "sampler_t";
89336809Sdim}
90336809Sdim
91336809Sdimstatic Function *
92336809SdimGetFunctionFromMDNode(MDNode *Node) {
93336809Sdim  if (!Node)
94336809Sdim    return nullptr;
95336809Sdim
96336809Sdim  size_t NumOps = Node->getNumOperands();
97336809Sdim  if (NumOps != NumKernelArgMDNodes + 1)
98336809Sdim    return nullptr;
99336809Sdim
100336809Sdim  auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
101336809Sdim  if (!F)
102336809Sdim    return nullptr;
103336809Sdim
104336809Sdim  // Sanity checks.
105336809Sdim  size_t ExpectNumArgNodeOps = F->arg_size() + 1;
106336809Sdim  for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
107336809Sdim    MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
108336809Sdim    if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
109336809Sdim      return nullptr;
110336809Sdim    if (!ArgNode->getOperand(0))
111336809Sdim      return nullptr;
112336809Sdim
113336809Sdim    // FIXME: It should be possible to do image lowering when some metadata
114336809Sdim    // args missing or not in the expected order.
115336809Sdim    MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
116336809Sdim    if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
117336809Sdim      return nullptr;
118336809Sdim  }
119336809Sdim
120336809Sdim  return F;
121336809Sdim}
122336809Sdim
123336809Sdimstatic StringRef
124336809SdimAccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
125336809Sdim  MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
126336809Sdim  return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
127336809Sdim}
128336809Sdim
129336809Sdimstatic StringRef
130336809SdimArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
131336809Sdim  MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
132336809Sdim  return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
133336809Sdim}
134336809Sdim
135336809Sdimstatic MDVector
136336809SdimGetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
137336809Sdim  MDVector Res;
138336809Sdim  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
139336809Sdim    MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
140336809Sdim    Res.push_back(Node->getOperand(OpIdx));
141336809Sdim  }
142336809Sdim  return Res;
143336809Sdim}
144336809Sdim
145336809Sdimstatic void
146336809SdimPushArgMD(KernelArgMD &MD, const MDVector &V) {
147336809Sdim  assert(V.size() == NumKernelArgMDNodes);
148336809Sdim  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
149336809Sdim    MD.ArgVector[i].push_back(V[i]);
150336809Sdim  }
151336809Sdim}
152336809Sdim
153336809Sdimnamespace {
154336809Sdim
155336809Sdimclass R600OpenCLImageTypeLoweringPass : public ModulePass {
156336809Sdim  static char ID;
157336809Sdim
158336809Sdim  LLVMContext *Context;
159336809Sdim  Type *Int32Type;
160336809Sdim  Type *ImageSizeType;
161336809Sdim  Type *ImageFormatType;
162336809Sdim  SmallVector<Instruction *, 4> InstsToErase;
163336809Sdim
164336809Sdim  bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
165336809Sdim                        Argument &ImageSizeArg,
166336809Sdim                        Argument &ImageFormatArg) {
167336809Sdim    bool Modified = false;
168336809Sdim
169336809Sdim    for (auto &Use : ImageArg.uses()) {
170336809Sdim      auto Inst = dyn_cast<CallInst>(Use.getUser());
171336809Sdim      if (!Inst) {
172336809Sdim        continue;
173336809Sdim      }
174336809Sdim
175336809Sdim      Function *F = Inst->getCalledFunction();
176336809Sdim      if (!F)
177336809Sdim        continue;
178336809Sdim
179336809Sdim      Value *Replacement = nullptr;
180336809Sdim      StringRef Name = F->getName();
181336809Sdim      if (Name.startswith(GetImageResourceIDFunc)) {
182336809Sdim        Replacement = ConstantInt::get(Int32Type, ResourceID);
183336809Sdim      } else if (Name.startswith(GetImageSizeFunc)) {
184336809Sdim        Replacement = &ImageSizeArg;
185336809Sdim      } else if (Name.startswith(GetImageFormatFunc)) {
186336809Sdim        Replacement = &ImageFormatArg;
187336809Sdim      } else {
188336809Sdim        continue;
189336809Sdim      }
190336809Sdim
191336809Sdim      Inst->replaceAllUsesWith(Replacement);
192336809Sdim      InstsToErase.push_back(Inst);
193336809Sdim      Modified = true;
194336809Sdim    }
195336809Sdim
196336809Sdim    return Modified;
197336809Sdim  }
198336809Sdim
199336809Sdim  bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
200336809Sdim    bool Modified = false;
201336809Sdim
202336809Sdim    for (const auto &Use : SamplerArg.uses()) {
203336809Sdim      auto Inst = dyn_cast<CallInst>(Use.getUser());
204336809Sdim      if (!Inst) {
205336809Sdim        continue;
206336809Sdim      }
207336809Sdim
208336809Sdim      Function *F = Inst->getCalledFunction();
209336809Sdim      if (!F)
210336809Sdim        continue;
211336809Sdim
212336809Sdim      Value *Replacement = nullptr;
213336809Sdim      StringRef Name = F->getName();
214336809Sdim      if (Name == GetSamplerResourceIDFunc) {
215336809Sdim        Replacement = ConstantInt::get(Int32Type, ResourceID);
216336809Sdim      } else {
217336809Sdim        continue;
218336809Sdim      }
219336809Sdim
220336809Sdim      Inst->replaceAllUsesWith(Replacement);
221336809Sdim      InstsToErase.push_back(Inst);
222336809Sdim      Modified = true;
223336809Sdim    }
224336809Sdim
225336809Sdim    return Modified;
226336809Sdim  }
227336809Sdim
228336809Sdim  bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
229336809Sdim    uint32_t NumReadOnlyImageArgs = 0;
230336809Sdim    uint32_t NumWriteOnlyImageArgs = 0;
231336809Sdim    uint32_t NumSamplerArgs = 0;
232336809Sdim
233336809Sdim    bool Modified = false;
234336809Sdim    InstsToErase.clear();
235336809Sdim    for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
236336809Sdim      Argument &Arg = *ArgI;
237336809Sdim      StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
238336809Sdim
239336809Sdim      // Handle image types.
240336809Sdim      if (IsImageType(Type)) {
241336809Sdim        StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
242336809Sdim        uint32_t ResourceID;
243336809Sdim        if (AccessQual == "read_only") {
244336809Sdim          ResourceID = NumReadOnlyImageArgs++;
245336809Sdim        } else if (AccessQual == "write_only") {
246336809Sdim          ResourceID = NumWriteOnlyImageArgs++;
247336809Sdim        } else {
248336809Sdim          llvm_unreachable("Wrong image access qualifier.");
249336809Sdim        }
250336809Sdim
251336809Sdim        Argument &SizeArg = *(++ArgI);
252336809Sdim        Argument &FormatArg = *(++ArgI);
253336809Sdim        Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
254336809Sdim
255336809Sdim      // Handle sampler type.
256336809Sdim      } else if (IsSamplerType(Type)) {
257336809Sdim        uint32_t ResourceID = NumSamplerArgs++;
258336809Sdim        Modified |= replaceSamplerUses(Arg, ResourceID);
259336809Sdim      }
260336809Sdim    }
261336809Sdim    for (unsigned i = 0; i < InstsToErase.size(); ++i) {
262336809Sdim      InstsToErase[i]->eraseFromParent();
263336809Sdim    }
264336809Sdim
265336809Sdim    return Modified;
266336809Sdim  }
267336809Sdim
268336809Sdim  std::tuple<Function *, MDNode *>
269336809Sdim  addImplicitArgs(Function *F, MDNode *KernelMDNode) {
270336809Sdim    bool Modified = false;
271336809Sdim
272336809Sdim    FunctionType *FT = F->getFunctionType();
273336809Sdim    SmallVector<Type *, 8> ArgTypes;
274336809Sdim
275336809Sdim    // Metadata operands for new MDNode.
276336809Sdim    KernelArgMD NewArgMDs;
277336809Sdim    PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
278336809Sdim
279336809Sdim    // Add implicit arguments to the signature.
280336809Sdim    for (unsigned i = 0; i < FT->getNumParams(); ++i) {
281336809Sdim      ArgTypes.push_back(FT->getParamType(i));
282336809Sdim      MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
283336809Sdim      PushArgMD(NewArgMDs, ArgMD);
284336809Sdim
285336809Sdim      if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
286336809Sdim        continue;
287336809Sdim
288336809Sdim      // Add size implicit argument.
289336809Sdim      ArgTypes.push_back(ImageSizeType);
290336809Sdim      ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
291336809Sdim      PushArgMD(NewArgMDs, ArgMD);
292336809Sdim
293336809Sdim      // Add format implicit argument.
294336809Sdim      ArgTypes.push_back(ImageFormatType);
295336809Sdim      ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
296336809Sdim      PushArgMD(NewArgMDs, ArgMD);
297336809Sdim
298336809Sdim      Modified = true;
299336809Sdim    }
300336809Sdim    if (!Modified) {
301336809Sdim      return std::make_tuple(nullptr, nullptr);
302336809Sdim    }
303336809Sdim
304336809Sdim    // Create function with new signature and clone the old body into it.
305336809Sdim    auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
306336809Sdim    auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
307336809Sdim    ValueToValueMapTy VMap;
308336809Sdim    auto NewFArgIt = NewF->arg_begin();
309336809Sdim    for (auto &Arg: F->args()) {
310336809Sdim      auto ArgName = Arg.getName();
311336809Sdim      NewFArgIt->setName(ArgName);
312336809Sdim      VMap[&Arg] = &(*NewFArgIt++);
313336809Sdim      if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
314336809Sdim        (NewFArgIt++)->setName(Twine("__size_") + ArgName);
315336809Sdim        (NewFArgIt++)->setName(Twine("__format_") + ArgName);
316336809Sdim      }
317336809Sdim    }
318336809Sdim    SmallVector<ReturnInst*, 8> Returns;
319336809Sdim    CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
320336809Sdim
321336809Sdim    // Build new MDNode.
322336809Sdim    SmallVector<Metadata *, 6> KernelMDArgs;
323336809Sdim    KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
324336809Sdim    for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
325336809Sdim      KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
326336809Sdim    MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
327336809Sdim
328336809Sdim    return std::make_tuple(NewF, NewMDNode);
329336809Sdim  }
330336809Sdim
331336809Sdim  bool transformKernels(Module &M) {
332336809Sdim    NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
333336809Sdim    if (!KernelsMDNode)
334336809Sdim      return false;
335336809Sdim
336336809Sdim    bool Modified = false;
337336809Sdim    for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
338336809Sdim      MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
339336809Sdim      Function *F = GetFunctionFromMDNode(KernelMDNode);
340336809Sdim      if (!F)
341336809Sdim        continue;
342336809Sdim
343336809Sdim      Function *NewF;
344336809Sdim      MDNode *NewMDNode;
345336809Sdim      std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
346336809Sdim      if (NewF) {
347336809Sdim        // Replace old function and metadata with new ones.
348336809Sdim        F->eraseFromParent();
349336809Sdim        M.getFunctionList().push_back(NewF);
350336809Sdim        M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
351336809Sdim                              NewF->getAttributes());
352336809Sdim        KernelsMDNode->setOperand(i, NewMDNode);
353336809Sdim
354336809Sdim        F = NewF;
355336809Sdim        KernelMDNode = NewMDNode;
356336809Sdim        Modified = true;
357336809Sdim      }
358336809Sdim
359336809Sdim      Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
360336809Sdim    }
361336809Sdim
362336809Sdim    return Modified;
363336809Sdim  }
364336809Sdim
365336809Sdimpublic:
366336809Sdim  R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
367336809Sdim
368336809Sdim  bool runOnModule(Module &M) override {
369336809Sdim    Context = &M.getContext();
370336809Sdim    Int32Type = Type::getInt32Ty(M.getContext());
371336809Sdim    ImageSizeType = ArrayType::get(Int32Type, 3);
372336809Sdim    ImageFormatType = ArrayType::get(Int32Type, 2);
373336809Sdim
374336809Sdim    return transformKernels(M);
375336809Sdim  }
376336809Sdim
377336809Sdim  StringRef getPassName() const override {
378336809Sdim    return "R600 OpenCL Image Type Pass";
379336809Sdim  }
380336809Sdim};
381336809Sdim
382336809Sdim} // end anonymous namespace
383336809Sdim
384336809Sdimchar R600OpenCLImageTypeLoweringPass::ID = 0;
385336809Sdim
386336809SdimModulePass *llvm::createR600OpenCLImageTypeLoweringPass() {
387336809Sdim  return new R600OpenCLImageTypeLoweringPass();
388336809Sdim}
389