//===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
//
// Arguments to kernel and device functions are passed via param space,
// which imposes certain restrictions:
// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
//
// Kernel parameters are read-only and accessible only via ld.param
// instruction, directly or via a pointer. Pointers to kernel
// arguments can't be converted to generic address space.
//
// Device function parameters are directly accessible via
// ld.param/st.param, but taking the address of one returns a pointer
// to a copy created in local space which *can't* be used with
// ld.param/st.param.
//
// Copying a byval struct into local memory in IR allows us to enforce
// the param space restrictions, gives the rest of IR a pointer w/o
// param space restrictions, and gives us an opportunity to eliminate
// the copy.
//
// Pointer arguments to kernel functions need more work to be lowered:
//
// 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the
//    global address space. This allows later optimizations to emit
//    ld.global.*/st.global.* for accessing these pointer arguments. For
//    example,
//
//    define void @foo(float* %input) {
//      %v = load float, float* %input, align 4
//      ...
//    }
//
//    becomes
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %input3 = addrspacecast float addrspace(1)* %input2 to float*
//      %v = load float, float* %input3, align 4
//      ...
//    }
//
//    Later, NVPTXInferAddressSpaces will optimize it to
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %v = load float, float addrspace(1)* %input2, align 4
//      ...
//    }
//
// 2. Convert pointers in a byval kernel parameter to pointers in the global
//    address space. Like #1, it allows NVPTX to emit more ld/st.global. E.g.,
//
//    struct S {
//      int *x;
//      int *y;
//    };
//    __global__ void foo(S s) {
//      int *b = s.y;
//      // use b
//    }
//
//    "b" points to the global address space. In the IR level,
//
//    define void @foo({i32*, i32*}* byval %input) {
//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
//      %b = load i32*, i32** %b_ptr
//      ; use %b
//    }
//
//    becomes
//
//    define void @foo({i32*, i32*}* byval %input) {
//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
//      %b = load i32*, i32** %b_ptr
//      %b_global = addrspacecast i32* %b to i32 addrspace(1)*
//      %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
//      ; use %b_generic
//    }
//
// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
// cancel the addrspacecast pair this pass emits.
//===----------------------------------------------------------------------===//
90311116Sdim
91311116Sdim#include "NVPTX.h"
92321369Sdim#include "NVPTXTargetMachine.h"
93311116Sdim#include "NVPTXUtilities.h"
94353358Sdim#include "MCTargetDesc/NVPTXBaseInfo.h"
95311116Sdim#include "llvm/Analysis/ValueTracking.h"
96311116Sdim#include "llvm/IR/Function.h"
97311116Sdim#include "llvm/IR/Instructions.h"
98311116Sdim#include "llvm/IR/Module.h"
99311116Sdim#include "llvm/IR/Type.h"
100311116Sdim#include "llvm/Pass.h"
101311116Sdim
102311116Sdimusing namespace llvm;
103311116Sdim
104311116Sdimnamespace llvm {
105311116Sdimvoid initializeNVPTXLowerArgsPass(PassRegistry &);
106311116Sdim}
107311116Sdim
namespace {
// Legacy-PM function pass that lowers kernel/device function arguments as
// described in the file header: byval aggregates are copied into a local
// alloca, and (for CUDA kernels) pointer arguments are routed through a
// global<->generic addrspacecast pair.
class NVPTXLowerArgs : public FunctionPass {
  bool runOnFunction(Function &F) override;

  // Kernels and device functions need different lowering; runOnFunction
  // dispatches to the appropriate one.
  bool runOnKernelFunction(Function &F);
  bool runOnDeviceFunction(Function &F);

  // handle byval parameters
  void handleByValParam(Argument *Arg);
  // Knowing Ptr must point to the global address space, this function
  // addrspacecasts Ptr to global and then back to generic. This allows
  // NVPTXInferAddressSpaces to fold the global-to-generic cast into
  // loads/stores that appear later.
  void markPointerAsGlobal(Value *Ptr);

public:
  static char ID; // Pass identification, replacement for typeid
  NVPTXLowerArgs(const NVPTXTargetMachine *TM = nullptr)
      : FunctionPass(ID), TM(TM) {}
  StringRef getPassName() const override {
    return "Lower pointer arguments of CUDA kernels";
  }

private:
  // Used to query the driver interface (CUDA vs. not). May be null, in which
  // case the CUDA-only pointer-argument lowering is skipped.
  const NVPTXTargetMachine *TM;
};
} // namespace
135311116Sdim
// Static pass ID; its address identifies the pass to the pass manager.
char NVPTXLowerArgs::ID = 1;

// Register the pass with the legacy pass registry under "nvptx-lower-args".
INITIALIZE_PASS(NVPTXLowerArgs, "nvptx-lower-args",
                "Lower arguments (NVPTX)", false, false)
140311116Sdim
141311116Sdim// =============================================================================
142311116Sdim// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
143311116Sdim// then add the following instructions to the first basic block:
144311116Sdim//
145311116Sdim// %temp = alloca %struct.x, align 8
146311116Sdim// %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
147311116Sdim// %tv = load %struct.x addrspace(101)* %tempd
148311116Sdim// store %struct.x %tv, %struct.x* %temp, align 8
149311116Sdim//
150311116Sdim// The above code allocates some space in the stack and copies the incoming
151311116Sdim// struct from param space to local space.
152311116Sdim// Then replace all occurrences of %d by %temp.
153311116Sdim// =============================================================================
154311116Sdimvoid NVPTXLowerArgs::handleByValParam(Argument *Arg) {
155311116Sdim  Function *Func = Arg->getParent();
156311116Sdim  Instruction *FirstInst = &(Func->getEntryBlock().front());
157311116Sdim  PointerType *PType = dyn_cast<PointerType>(Arg->getType());
158311116Sdim
159311116Sdim  assert(PType && "Expecting pointer type in handleByValParam");
160311116Sdim
161311116Sdim  Type *StructType = PType->getElementType();
162321369Sdim  unsigned AS = Func->getParent()->getDataLayout().getAllocaAddrSpace();
163321369Sdim  AllocaInst *AllocA = new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
164311116Sdim  // Set the alignment to alignment of the byval parameter. This is because,
165311116Sdim  // later load/stores assume that alignment, and we are going to replace
166311116Sdim  // the use of the byval parameter with this alloca instruction.
167360784Sdim  AllocA->setAlignment(MaybeAlign(Func->getParamAlignment(Arg->getArgNo())));
168311116Sdim  Arg->replaceAllUsesWith(AllocA);
169311116Sdim
170311116Sdim  Value *ArgInParam = new AddrSpaceCastInst(
171311116Sdim      Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
172311116Sdim      FirstInst);
173353358Sdim  LoadInst *LI =
174353358Sdim      new LoadInst(StructType, ArgInParam, Arg->getName(), FirstInst);
175311116Sdim  new StoreInst(LI, AllocA, FirstInst);
176311116Sdim}
177311116Sdim
178311116Sdimvoid NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
179311116Sdim  if (Ptr->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL)
180311116Sdim    return;
181311116Sdim
182311116Sdim  // Deciding where to emit the addrspacecast pair.
183311116Sdim  BasicBlock::iterator InsertPt;
184311116Sdim  if (Argument *Arg = dyn_cast<Argument>(Ptr)) {
185311116Sdim    // Insert at the functon entry if Ptr is an argument.
186311116Sdim    InsertPt = Arg->getParent()->getEntryBlock().begin();
187311116Sdim  } else {
188311116Sdim    // Insert right after Ptr if Ptr is an instruction.
189311116Sdim    InsertPt = ++cast<Instruction>(Ptr)->getIterator();
190311116Sdim    assert(InsertPt != InsertPt->getParent()->end() &&
191311116Sdim           "We don't call this function with Ptr being a terminator.");
192311116Sdim  }
193311116Sdim
194311116Sdim  Instruction *PtrInGlobal = new AddrSpaceCastInst(
195311116Sdim      Ptr, PointerType::get(Ptr->getType()->getPointerElementType(),
196311116Sdim                            ADDRESS_SPACE_GLOBAL),
197311116Sdim      Ptr->getName(), &*InsertPt);
198311116Sdim  Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
199311116Sdim                                              Ptr->getName(), &*InsertPt);
200311116Sdim  // Replace with PtrInGeneric all uses of Ptr except PtrInGlobal.
201311116Sdim  Ptr->replaceAllUsesWith(PtrInGeneric);
202311116Sdim  PtrInGlobal->setOperand(0, Ptr);
203311116Sdim}
204311116Sdim
205311116Sdim// =============================================================================
206311116Sdim// Main function for this pass.
207311116Sdim// =============================================================================
208311116Sdimbool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
209311116Sdim  if (TM && TM->getDrvInterface() == NVPTX::CUDA) {
210311116Sdim    // Mark pointers in byval structs as global.
211311116Sdim    for (auto &B : F) {
212311116Sdim      for (auto &I : B) {
213311116Sdim        if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
214311116Sdim          if (LI->getType()->isPointerTy()) {
215311116Sdim            Value *UO = GetUnderlyingObject(LI->getPointerOperand(),
216311116Sdim                                            F.getParent()->getDataLayout());
217311116Sdim            if (Argument *Arg = dyn_cast<Argument>(UO)) {
218311116Sdim              if (Arg->hasByValAttr()) {
219311116Sdim                // LI is a load from a pointer within a byval kernel parameter.
220311116Sdim                markPointerAsGlobal(LI);
221311116Sdim              }
222311116Sdim            }
223311116Sdim          }
224311116Sdim        }
225311116Sdim      }
226311116Sdim    }
227311116Sdim  }
228311116Sdim
229311116Sdim  for (Argument &Arg : F.args()) {
230311116Sdim    if (Arg.getType()->isPointerTy()) {
231311116Sdim      if (Arg.hasByValAttr())
232311116Sdim        handleByValParam(&Arg);
233311116Sdim      else if (TM && TM->getDrvInterface() == NVPTX::CUDA)
234311116Sdim        markPointerAsGlobal(&Arg);
235311116Sdim    }
236311116Sdim  }
237311116Sdim  return true;
238311116Sdim}
239311116Sdim
240311116Sdim// Device functions only need to copy byval args into local memory.
241311116Sdimbool NVPTXLowerArgs::runOnDeviceFunction(Function &F) {
242311116Sdim  for (Argument &Arg : F.args())
243311116Sdim    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
244311116Sdim      handleByValParam(&Arg);
245311116Sdim  return true;
246311116Sdim}
247311116Sdim
248311116Sdimbool NVPTXLowerArgs::runOnFunction(Function &F) {
249311116Sdim  return isKernelFunction(F) ? runOnKernelFunction(F) : runOnDeviceFunction(F);
250311116Sdim}
251311116Sdim
252311116SdimFunctionPass *
253311116Sdimllvm::createNVPTXLowerArgsPass(const NVPTXTargetMachine *TM) {
254311116Sdim  return new NVPTXLowerArgs(TM);
255311116Sdim}
256