//===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
//
// Arguments to kernel and device functions are passed via param space,
// which imposes certain restrictions:
// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
//
// Kernel parameters are read-only and accessible only via ld.param
// instruction, directly or via a pointer. Pointers to kernel
// arguments can't be converted to generic address space.
//
// Device function parameters are directly accessible via
// ld.param/st.param, but taking the address of one returns a pointer
// to a copy created in local space which *can't* be used with
// ld.param/st.param.
//
// Copying a byval struct into local memory in IR allows us to enforce
// the param space restrictions, gives the rest of IR a pointer w/o
// param space restrictions, and gives us an opportunity to eliminate
// the copy.
//
// Pointer arguments to kernel functions need more work to be lowered:
//
// 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the
//    global address space. This allows later optimizations to emit
//    ld.global.*/st.global.* for accessing these pointer arguments. For
//    example,
//
//    define void @foo(float* %input) {
//      %v = load float, float* %input, align 4
//      ...
//    }
//
//    becomes
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %input3 = addrspacecast float addrspace(1)* %input2 to float*
//      %v = load float, float* %input3, align 4
//      ...
//    }
//
//    Later, NVPTXInferAddressSpaces will optimize it to
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %v = load float, float addrspace(1)* %input2, align 4
//      ...
//    }
//
// 2. Convert pointers in a byval kernel parameter to pointers in the global
//    address space. As #1, it allows NVPTX to emit more ld/st.global. E.g.,
//
//    struct S {
//      int *x;
//      int *y;
//    };
//    __global__ void foo(S s) {
//      int *b = s.y;
//      // use b
//    }
//
//    "b" points to the global address space.
In the IR level, 70311116Sdim// 71311116Sdim// define void @foo({i32*, i32*}* byval %input) { 72311116Sdim// %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1 73311116Sdim// %b = load i32*, i32** %b_ptr 74311116Sdim// ; use %b 75311116Sdim// } 76311116Sdim// 77311116Sdim// becomes 78311116Sdim// 79311116Sdim// define void @foo({i32*, i32*}* byval %input) { 80311116Sdim// %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1 81311116Sdim// %b = load i32*, i32** %b_ptr 82311116Sdim// %b_global = addrspacecast i32* %b to i32 addrspace(1)* 83311116Sdim// %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32* 84311116Sdim// ; use %b_generic 85311116Sdim// } 86311116Sdim// 87311116Sdim// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't 88311116Sdim// cancel the addrspacecast pair this pass emits. 89311116Sdim//===----------------------------------------------------------------------===// 90311116Sdim 91311116Sdim#include "NVPTX.h" 92321369Sdim#include "NVPTXTargetMachine.h" 93311116Sdim#include "NVPTXUtilities.h" 94353358Sdim#include "MCTargetDesc/NVPTXBaseInfo.h" 95311116Sdim#include "llvm/Analysis/ValueTracking.h" 96311116Sdim#include "llvm/IR/Function.h" 97311116Sdim#include "llvm/IR/Instructions.h" 98311116Sdim#include "llvm/IR/Module.h" 99311116Sdim#include "llvm/IR/Type.h" 100311116Sdim#include "llvm/Pass.h" 101311116Sdim 102311116Sdimusing namespace llvm; 103311116Sdim 104311116Sdimnamespace llvm { 105311116Sdimvoid initializeNVPTXLowerArgsPass(PassRegistry &); 106311116Sdim} 107311116Sdim 108311116Sdimnamespace { 109311116Sdimclass NVPTXLowerArgs : public FunctionPass { 110311116Sdim bool runOnFunction(Function &F) override; 111311116Sdim 112311116Sdim bool runOnKernelFunction(Function &F); 113311116Sdim bool runOnDeviceFunction(Function &F); 114311116Sdim 115311116Sdim // handle byval parameters 116311116Sdim void handleByValParam(Argument *Arg); 117311116Sdim // Knowing Ptr must point 
to the global address space, this function 118311116Sdim // addrspacecasts Ptr to global and then back to generic. This allows 119311116Sdim // NVPTXInferAddressSpaces to fold the global-to-generic cast into 120311116Sdim // loads/stores that appear later. 121311116Sdim void markPointerAsGlobal(Value *Ptr); 122311116Sdim 123311116Sdimpublic: 124311116Sdim static char ID; // Pass identification, replacement for typeid 125311116Sdim NVPTXLowerArgs(const NVPTXTargetMachine *TM = nullptr) 126311116Sdim : FunctionPass(ID), TM(TM) {} 127311116Sdim StringRef getPassName() const override { 128311116Sdim return "Lower pointer arguments of CUDA kernels"; 129311116Sdim } 130311116Sdim 131311116Sdimprivate: 132311116Sdim const NVPTXTargetMachine *TM; 133311116Sdim}; 134311116Sdim} // namespace 135311116Sdim 136311116Sdimchar NVPTXLowerArgs::ID = 1; 137311116Sdim 138311116SdimINITIALIZE_PASS(NVPTXLowerArgs, "nvptx-lower-args", 139311116Sdim "Lower arguments (NVPTX)", false, false) 140311116Sdim 141311116Sdim// ============================================================================= 142311116Sdim// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d), 143311116Sdim// then add the following instructions to the first basic block: 144311116Sdim// 145311116Sdim// %temp = alloca %struct.x, align 8 146311116Sdim// %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)* 147311116Sdim// %tv = load %struct.x addrspace(101)* %tempd 148311116Sdim// store %struct.x %tv, %struct.x* %temp, align 8 149311116Sdim// 150311116Sdim// The above code allocates some space in the stack and copies the incoming 151311116Sdim// struct from param space to local space. 152311116Sdim// Then replace all occurrences of %d by %temp. 
153311116Sdim// ============================================================================= 154311116Sdimvoid NVPTXLowerArgs::handleByValParam(Argument *Arg) { 155311116Sdim Function *Func = Arg->getParent(); 156311116Sdim Instruction *FirstInst = &(Func->getEntryBlock().front()); 157311116Sdim PointerType *PType = dyn_cast<PointerType>(Arg->getType()); 158311116Sdim 159311116Sdim assert(PType && "Expecting pointer type in handleByValParam"); 160311116Sdim 161311116Sdim Type *StructType = PType->getElementType(); 162321369Sdim unsigned AS = Func->getParent()->getDataLayout().getAllocaAddrSpace(); 163321369Sdim AllocaInst *AllocA = new AllocaInst(StructType, AS, Arg->getName(), FirstInst); 164311116Sdim // Set the alignment to alignment of the byval parameter. This is because, 165311116Sdim // later load/stores assume that alignment, and we are going to replace 166311116Sdim // the use of the byval parameter with this alloca instruction. 167360784Sdim AllocA->setAlignment(MaybeAlign(Func->getParamAlignment(Arg->getArgNo()))); 168311116Sdim Arg->replaceAllUsesWith(AllocA); 169311116Sdim 170311116Sdim Value *ArgInParam = new AddrSpaceCastInst( 171311116Sdim Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(), 172311116Sdim FirstInst); 173353358Sdim LoadInst *LI = 174353358Sdim new LoadInst(StructType, ArgInParam, Arg->getName(), FirstInst); 175311116Sdim new StoreInst(LI, AllocA, FirstInst); 176311116Sdim} 177311116Sdim 178311116Sdimvoid NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) { 179311116Sdim if (Ptr->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL) 180311116Sdim return; 181311116Sdim 182311116Sdim // Deciding where to emit the addrspacecast pair. 183311116Sdim BasicBlock::iterator InsertPt; 184311116Sdim if (Argument *Arg = dyn_cast<Argument>(Ptr)) { 185311116Sdim // Insert at the functon entry if Ptr is an argument. 
186311116Sdim InsertPt = Arg->getParent()->getEntryBlock().begin(); 187311116Sdim } else { 188311116Sdim // Insert right after Ptr if Ptr is an instruction. 189311116Sdim InsertPt = ++cast<Instruction>(Ptr)->getIterator(); 190311116Sdim assert(InsertPt != InsertPt->getParent()->end() && 191311116Sdim "We don't call this function with Ptr being a terminator."); 192311116Sdim } 193311116Sdim 194311116Sdim Instruction *PtrInGlobal = new AddrSpaceCastInst( 195311116Sdim Ptr, PointerType::get(Ptr->getType()->getPointerElementType(), 196311116Sdim ADDRESS_SPACE_GLOBAL), 197311116Sdim Ptr->getName(), &*InsertPt); 198311116Sdim Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(), 199311116Sdim Ptr->getName(), &*InsertPt); 200311116Sdim // Replace with PtrInGeneric all uses of Ptr except PtrInGlobal. 201311116Sdim Ptr->replaceAllUsesWith(PtrInGeneric); 202311116Sdim PtrInGlobal->setOperand(0, Ptr); 203311116Sdim} 204311116Sdim 205311116Sdim// ============================================================================= 206311116Sdim// Main function for this pass. 207311116Sdim// ============================================================================= 208311116Sdimbool NVPTXLowerArgs::runOnKernelFunction(Function &F) { 209311116Sdim if (TM && TM->getDrvInterface() == NVPTX::CUDA) { 210311116Sdim // Mark pointers in byval structs as global. 211311116Sdim for (auto &B : F) { 212311116Sdim for (auto &I : B) { 213311116Sdim if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { 214311116Sdim if (LI->getType()->isPointerTy()) { 215311116Sdim Value *UO = GetUnderlyingObject(LI->getPointerOperand(), 216311116Sdim F.getParent()->getDataLayout()); 217311116Sdim if (Argument *Arg = dyn_cast<Argument>(UO)) { 218311116Sdim if (Arg->hasByValAttr()) { 219311116Sdim // LI is a load from a pointer within a byval kernel parameter. 
220311116Sdim markPointerAsGlobal(LI); 221311116Sdim } 222311116Sdim } 223311116Sdim } 224311116Sdim } 225311116Sdim } 226311116Sdim } 227311116Sdim } 228311116Sdim 229311116Sdim for (Argument &Arg : F.args()) { 230311116Sdim if (Arg.getType()->isPointerTy()) { 231311116Sdim if (Arg.hasByValAttr()) 232311116Sdim handleByValParam(&Arg); 233311116Sdim else if (TM && TM->getDrvInterface() == NVPTX::CUDA) 234311116Sdim markPointerAsGlobal(&Arg); 235311116Sdim } 236311116Sdim } 237311116Sdim return true; 238311116Sdim} 239311116Sdim 240311116Sdim// Device functions only need to copy byval args into local memory. 241311116Sdimbool NVPTXLowerArgs::runOnDeviceFunction(Function &F) { 242311116Sdim for (Argument &Arg : F.args()) 243311116Sdim if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) 244311116Sdim handleByValParam(&Arg); 245311116Sdim return true; 246311116Sdim} 247311116Sdim 248311116Sdimbool NVPTXLowerArgs::runOnFunction(Function &F) { 249311116Sdim return isKernelFunction(F) ? runOnKernelFunction(F) : runOnDeviceFunction(F); 250311116Sdim} 251311116Sdim 252311116SdimFunctionPass * 253311116Sdimllvm::createNVPTXLowerArgsPass(const NVPTXTargetMachine *TM) { 254311116Sdim return new NVPTXLowerArgs(TM); 255311116Sdim} 256