1336809Sdim//===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===// 2336809Sdim// 3353358Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4353358Sdim// See https://llvm.org/LICENSE.txt for license information. 5353358Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6336809Sdim// 7336809Sdim//===----------------------------------------------------------------------===// 8336809Sdim// 9336809Sdim/// \file 10336809Sdim/// This pass resolves calls to OpenCL image attribute, image resource ID and 11336809Sdim/// sampler resource ID getter functions. 12336809Sdim/// 13336809Sdim/// Image attributes (size and format) are expected to be passed to the kernel 14336809Sdim/// as kernel arguments immediately following the image argument itself, 15336809Sdim/// therefore this pass adds image size and format arguments to the kernel 16336809Sdim/// functions in the module. The kernel functions with image arguments are 17336809Sdim/// re-created using the new signature. The new arguments are added to the 18336809Sdim/// kernel metadata with kernel_arg_type set to "image_size" or "image_format". 19336809Sdim/// Note: this pass may invalidate pointers to functions. 20336809Sdim/// 21336809Sdim/// Resource IDs of read-only images, write-only images and samplers are 22336809Sdim/// defined to be their index among the kernel arguments of the same 23336809Sdim/// type and access qualifier. 24336809Sdim// 25336809Sdim//===----------------------------------------------------------------------===// 26336809Sdim 27336809Sdim#include "AMDGPU.h" 28336809Sdim#include "llvm/ADT/SmallVector.h" 29336809Sdim#include "llvm/ADT/StringRef.h" 30336809Sdim#include "llvm/ADT/Twine.h" 31336809Sdim#include "llvm/IR/Argument.h" 32336809Sdim#include "llvm/IR/DerivedTypes.h" 33336809Sdim#include "llvm/IR/Constants.h" 34336809Sdim#include "llvm/IR/Function.h" 35336809Sdim#include "llvm/IR/Instruction.h" 36336809Sdim#include "llvm/IR/Instructions.h" 37336809Sdim#include "llvm/IR/Metadata.h" 38336809Sdim#include "llvm/IR/Module.h" 39336809Sdim#include "llvm/IR/Type.h" 40336809Sdim#include "llvm/IR/Use.h" 41336809Sdim#include "llvm/IR/User.h" 42336809Sdim#include "llvm/Pass.h" 43336809Sdim#include "llvm/Support/Casting.h" 44336809Sdim#include "llvm/Support/ErrorHandling.h" 45336809Sdim#include "llvm/Transforms/Utils/Cloning.h" 46336809Sdim#include "llvm/Transforms/Utils/ValueMapper.h" 47336809Sdim#include <cassert> 48336809Sdim#include <cstddef> 49336809Sdim#include <cstdint> 50336809Sdim#include <tuple> 51336809Sdim 52336809Sdimusing namespace llvm; 53336809Sdim 54336809Sdimstatic StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size"; 55336809Sdimstatic StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format"; 56336809Sdimstatic StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id"; 57336809Sdimstatic StringRef GetSamplerResourceIDFunc = 58336809Sdim "llvm.OpenCL.sampler.get.resource.id"; 59336809Sdim 60336809Sdimstatic StringRef ImageSizeArgMDType = "__llvm_image_size"; 61336809Sdimstatic StringRef ImageFormatArgMDType = "__llvm_image_format"; 62336809Sdim 63336809Sdimstatic StringRef KernelsMDNodeName = "opencl.kernels"; 64336809Sdimstatic StringRef KernelArgMDNodeNames[] = { 65336809Sdim "kernel_arg_addr_space", 66336809Sdim "kernel_arg_access_qual", 67336809Sdim "kernel_arg_type", 68336809Sdim "kernel_arg_base_type", 69336809Sdim "kernel_arg_type_qual"}; 70336809Sdimstatic const unsigned NumKernelArgMDNodes = 5; 71336809Sdim 72336809Sdimnamespace { 73336809Sdim 74336809Sdimusing MDVector = SmallVector<Metadata *, 8>; 75336809Sdimstruct KernelArgMD { 76336809Sdim MDVector ArgVector[NumKernelArgMDNodes]; 77336809Sdim}; 78336809Sdim 79336809Sdim} // end anonymous namespace 80336809Sdim 81336809Sdimstatic inline bool 82336809SdimIsImageType(StringRef TypeString) { 83336809Sdim return TypeString == "image2d_t" || TypeString == "image3d_t"; 84336809Sdim} 85336809Sdim 86336809Sdimstatic inline bool 87336809SdimIsSamplerType(StringRef TypeString) { 88336809Sdim return TypeString == "sampler_t"; 89336809Sdim} 90336809Sdim 91336809Sdimstatic Function * 92336809SdimGetFunctionFromMDNode(MDNode *Node) { 93336809Sdim if (!Node) 94336809Sdim return nullptr; 95336809Sdim 96336809Sdim size_t NumOps = Node->getNumOperands(); 97336809Sdim if (NumOps != NumKernelArgMDNodes + 1) 98336809Sdim return nullptr; 99336809Sdim 100336809Sdim auto F = mdconst::dyn_extract<Function>(Node->getOperand(0)); 101336809Sdim if (!F) 102336809Sdim return nullptr; 103336809Sdim 104336809Sdim // Sanity checks. 105336809Sdim size_t ExpectNumArgNodeOps = F->arg_size() + 1; 106336809Sdim for (size_t i = 0; i < NumKernelArgMDNodes; ++i) { 107336809Sdim MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1)); 108336809Sdim if (ArgNode->getNumOperands() != ExpectNumArgNodeOps) 109336809Sdim return nullptr; 110336809Sdim if (!ArgNode->getOperand(0)) 111336809Sdim return nullptr; 112336809Sdim 113336809Sdim // FIXME: It should be possible to do image lowering when some metadata 114336809Sdim // args missing or not in the expected order. 115336809Sdim MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0)); 116336809Sdim if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i]) 117336809Sdim return nullptr; 118336809Sdim } 119336809Sdim 120336809Sdim return F; 121336809Sdim} 122336809Sdim 123336809Sdimstatic StringRef 124336809SdimAccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) { 125336809Sdim MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2)); 126336809Sdim return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString(); 127336809Sdim} 128336809Sdim 129336809Sdimstatic StringRef 130336809SdimArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) { 131336809Sdim MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3)); 132336809Sdim return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString(); 133336809Sdim} 134336809Sdim 135336809Sdimstatic MDVector 136336809SdimGetArgMD(MDNode *KernelMDNode, unsigned OpIdx) { 137336809Sdim MDVector Res; 138336809Sdim for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) { 139336809Sdim MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1)); 140336809Sdim Res.push_back(Node->getOperand(OpIdx)); 141336809Sdim } 142336809Sdim return Res; 143336809Sdim} 144336809Sdim 145336809Sdimstatic void 146336809SdimPushArgMD(KernelArgMD &MD, const MDVector &V) { 147336809Sdim assert(V.size() == NumKernelArgMDNodes); 148336809Sdim for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) { 149336809Sdim MD.ArgVector[i].push_back(V[i]); 150336809Sdim } 151336809Sdim} 152336809Sdim 153336809Sdimnamespace { 154336809Sdim 155336809Sdimclass R600OpenCLImageTypeLoweringPass : public ModulePass { 156336809Sdim static char ID; 157336809Sdim 158336809Sdim LLVMContext *Context; 159336809Sdim Type *Int32Type; 160336809Sdim Type *ImageSizeType; 161336809Sdim Type *ImageFormatType; 162336809Sdim SmallVector<Instruction *, 4> InstsToErase; 163336809Sdim 164336809Sdim bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID, 165336809Sdim Argument &ImageSizeArg, 166336809Sdim Argument &ImageFormatArg) { 167336809Sdim bool Modified = false; 168336809Sdim 169336809Sdim for (auto &Use : ImageArg.uses()) { 170336809Sdim auto Inst = dyn_cast<CallInst>(Use.getUser()); 171336809Sdim if (!Inst) { 172336809Sdim continue; 173336809Sdim } 174336809Sdim 175336809Sdim Function *F = Inst->getCalledFunction(); 176336809Sdim if (!F) 177336809Sdim continue; 178336809Sdim 179336809Sdim Value *Replacement = nullptr; 180336809Sdim StringRef Name = F->getName(); 181336809Sdim if (Name.startswith(GetImageResourceIDFunc)) { 182336809Sdim Replacement = ConstantInt::get(Int32Type, ResourceID); 183336809Sdim } else if (Name.startswith(GetImageSizeFunc)) { 184336809Sdim Replacement = &ImageSizeArg; 185336809Sdim } else if (Name.startswith(GetImageFormatFunc)) { 186336809Sdim Replacement = &ImageFormatArg; 187336809Sdim } else { 188336809Sdim continue; 189336809Sdim } 190336809Sdim 191336809Sdim Inst->replaceAllUsesWith(Replacement); 192336809Sdim InstsToErase.push_back(Inst); 193336809Sdim Modified = true; 194336809Sdim } 195336809Sdim 196336809Sdim return Modified; 197336809Sdim } 198336809Sdim 199336809Sdim bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) { 200336809Sdim bool Modified = false; 201336809Sdim 202336809Sdim for (const auto &Use : SamplerArg.uses()) { 203336809Sdim auto Inst = dyn_cast<CallInst>(Use.getUser()); 204336809Sdim if (!Inst) { 205336809Sdim continue; 206336809Sdim } 207336809Sdim 208336809Sdim Function *F = Inst->getCalledFunction(); 209336809Sdim if (!F) 210336809Sdim continue; 211336809Sdim 212336809Sdim Value *Replacement = nullptr; 213336809Sdim StringRef Name = F->getName(); 214336809Sdim if (Name == GetSamplerResourceIDFunc) { 215336809Sdim Replacement = ConstantInt::get(Int32Type, ResourceID); 216336809Sdim } else { 217336809Sdim continue; 218336809Sdim } 219336809Sdim 220336809Sdim Inst->replaceAllUsesWith(Replacement); 221336809Sdim InstsToErase.push_back(Inst); 222336809Sdim Modified = true; 223336809Sdim } 224336809Sdim 225336809Sdim return Modified; 226336809Sdim } 227336809Sdim 228336809Sdim bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) { 229336809Sdim uint32_t NumReadOnlyImageArgs = 0; 230336809Sdim uint32_t NumWriteOnlyImageArgs = 0; 231336809Sdim uint32_t NumSamplerArgs = 0; 232336809Sdim 233336809Sdim bool Modified = false; 234336809Sdim InstsToErase.clear(); 235336809Sdim for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) { 236336809Sdim Argument &Arg = *ArgI; 237336809Sdim StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo()); 238336809Sdim 239336809Sdim // Handle image types. 240336809Sdim if (IsImageType(Type)) { 241336809Sdim StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo()); 242336809Sdim uint32_t ResourceID; 243336809Sdim if (AccessQual == "read_only") { 244336809Sdim ResourceID = NumReadOnlyImageArgs++; 245336809Sdim } else if (AccessQual == "write_only") { 246336809Sdim ResourceID = NumWriteOnlyImageArgs++; 247336809Sdim } else { 248336809Sdim llvm_unreachable("Wrong image access qualifier."); 249336809Sdim } 250336809Sdim 251336809Sdim Argument &SizeArg = *(++ArgI); 252336809Sdim Argument &FormatArg = *(++ArgI); 253336809Sdim Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg); 254336809Sdim 255336809Sdim // Handle sampler type. 256336809Sdim } else if (IsSamplerType(Type)) { 257336809Sdim uint32_t ResourceID = NumSamplerArgs++; 258336809Sdim Modified |= replaceSamplerUses(Arg, ResourceID); 259336809Sdim } 260336809Sdim } 261336809Sdim for (unsigned i = 0; i < InstsToErase.size(); ++i) { 262336809Sdim InstsToErase[i]->eraseFromParent(); 263336809Sdim } 264336809Sdim 265336809Sdim return Modified; 266336809Sdim } 267336809Sdim 268336809Sdim std::tuple<Function *, MDNode *> 269336809Sdim addImplicitArgs(Function *F, MDNode *KernelMDNode) { 270336809Sdim bool Modified = false; 271336809Sdim 272336809Sdim FunctionType *FT = F->getFunctionType(); 273336809Sdim SmallVector<Type *, 8> ArgTypes; 274336809Sdim 275336809Sdim // Metadata operands for new MDNode. 276336809Sdim KernelArgMD NewArgMDs; 277336809Sdim PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0)); 278336809Sdim 279336809Sdim // Add implicit arguments to the signature. 280336809Sdim for (unsigned i = 0; i < FT->getNumParams(); ++i) { 281336809Sdim ArgTypes.push_back(FT->getParamType(i)); 282336809Sdim MDVector ArgMD = GetArgMD(KernelMDNode, i + 1); 283336809Sdim PushArgMD(NewArgMDs, ArgMD); 284336809Sdim 285336809Sdim if (!IsImageType(ArgTypeFromMD(KernelMDNode, i))) 286336809Sdim continue; 287336809Sdim 288336809Sdim // Add size implicit argument. 289336809Sdim ArgTypes.push_back(ImageSizeType); 290336809Sdim ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType); 291336809Sdim PushArgMD(NewArgMDs, ArgMD); 292336809Sdim 293336809Sdim // Add format implicit argument. 294336809Sdim ArgTypes.push_back(ImageFormatType); 295336809Sdim ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType); 296336809Sdim PushArgMD(NewArgMDs, ArgMD); 297336809Sdim 298336809Sdim Modified = true; 299336809Sdim } 300336809Sdim if (!Modified) { 301336809Sdim return std::make_tuple(nullptr, nullptr); 302336809Sdim } 303336809Sdim 304336809Sdim // Create function with new signature and clone the old body into it. 305336809Sdim auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false); 306336809Sdim auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName()); 307336809Sdim ValueToValueMapTy VMap; 308336809Sdim auto NewFArgIt = NewF->arg_begin(); 309336809Sdim for (auto &Arg: F->args()) { 310336809Sdim auto ArgName = Arg.getName(); 311336809Sdim NewFArgIt->setName(ArgName); 312336809Sdim VMap[&Arg] = &(*NewFArgIt++); 313336809Sdim if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) { 314336809Sdim (NewFArgIt++)->setName(Twine("__size_") + ArgName); 315336809Sdim (NewFArgIt++)->setName(Twine("__format_") + ArgName); 316336809Sdim } 317336809Sdim } 318336809Sdim SmallVector<ReturnInst*, 8> Returns; 319336809Sdim CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns); 320336809Sdim 321336809Sdim // Build new MDNode. 322336809Sdim SmallVector<Metadata *, 6> KernelMDArgs; 323336809Sdim KernelMDArgs.push_back(ConstantAsMetadata::get(NewF)); 324336809Sdim for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) 325336809Sdim KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i])); 326336809Sdim MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs); 327336809Sdim 328336809Sdim return std::make_tuple(NewF, NewMDNode); 329336809Sdim } 330336809Sdim 331336809Sdim bool transformKernels(Module &M) { 332336809Sdim NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName); 333336809Sdim if (!KernelsMDNode) 334336809Sdim return false; 335336809Sdim 336336809Sdim bool Modified = false; 337336809Sdim for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) { 338336809Sdim MDNode *KernelMDNode = KernelsMDNode->getOperand(i); 339336809Sdim Function *F = GetFunctionFromMDNode(KernelMDNode); 340336809Sdim if (!F) 341336809Sdim continue; 342336809Sdim 343336809Sdim Function *NewF; 344336809Sdim MDNode *NewMDNode; 345336809Sdim std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode); 346336809Sdim if (NewF) { 347336809Sdim // Replace old function and metadata with new ones. 348336809Sdim F->eraseFromParent(); 349336809Sdim M.getFunctionList().push_back(NewF); 350336809Sdim M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(), 351336809Sdim NewF->getAttributes()); 352336809Sdim KernelsMDNode->setOperand(i, NewMDNode); 353336809Sdim 354336809Sdim F = NewF; 355336809Sdim KernelMDNode = NewMDNode; 356336809Sdim Modified = true; 357336809Sdim } 358336809Sdim 359336809Sdim Modified |= replaceImageAndSamplerUses(F, KernelMDNode); 360336809Sdim } 361336809Sdim 362336809Sdim return Modified; 363336809Sdim } 364336809Sdim 365336809Sdimpublic: 366336809Sdim R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {} 367336809Sdim 368336809Sdim bool runOnModule(Module &M) override { 369336809Sdim Context = &M.getContext(); 370336809Sdim Int32Type = Type::getInt32Ty(M.getContext()); 371336809Sdim ImageSizeType = ArrayType::get(Int32Type, 3); 372336809Sdim ImageFormatType = ArrayType::get(Int32Type, 2); 373336809Sdim 374336809Sdim return transformKernels(M); 375336809Sdim } 376336809Sdim 377336809Sdim StringRef getPassName() const override { 378336809Sdim return "R600 OpenCL Image Type Pass"; 379336809Sdim } 380336809Sdim}; 381336809Sdim 382336809Sdim} // end anonymous namespace 383336809Sdim 384336809Sdimchar R600OpenCLImageTypeLoweringPass::ID = 0; 385336809Sdim 386336809SdimModulePass *llvm::createR600OpenCLImageTypeLoweringPass() { 387336809Sdim return new R600OpenCLImageTypeLoweringPass(); 388336809Sdim} 389