//===-- AMDGPULowerKernelArguments.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass replaces accesses to kernel arguments with loads at
/// offsets from the kernarg base pointer.
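///
/// As an illustration only (a hand-written sketch, not taken from any test):
/// for a kernel with a single `i32 %x` argument on an HSA target where the
/// explicit argument base offset is 0, the rewritten entry block looks
/// roughly like
///
///   %kern.kernarg.segment = call i8 addrspace(4)*
///       @llvm.amdgcn.kernarg.segment.ptr()
///   %x.kernarg.offset = getelementptr inbounds i8,
///       i8 addrspace(4)* %kern.kernarg.segment, i64 0
///   %x.kernarg.offset.cast = bitcast i8 addrspace(4)* %x.kernarg.offset
///       to i32 addrspace(4)*
///   %x.load = load i32, i32 addrspace(4)* %x.kernarg.offset.cast,
///       align 16, !invariant.load !0
///
/// after which every use of %x is replaced with %x.load.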
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUSubtarget.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"

#define DEBUG_TYPE "amdgpu-lower-kernel-arguments"

using namespace llvm;

namespace {

class AMDGPULowerKernelArguments : public FunctionPass {
public:
  static char ID;

  AMDGPULowerKernelArguments() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.setPreservesAll();
  }
};

} // end anonymous namespace

// skip allocas
static BasicBlock::iterator getInsertPt(BasicBlock &BB) {
  BasicBlock::iterator InsPt = BB.getFirstInsertionPt();
  for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) {
    AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt);

    // If this is a dynamic alloca, the value may depend on the loaded kernargs,
    // so loads will need to be inserted before it.
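    //
    // A hypothetical example (not from the source): in
    //   %buf = alloca i32, i32 %n, addrspace(5)
    // where %n is a kernel argument, the load that replaces %n has to be
    // emitted before the alloca, so we stop scanning here.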
    if (!AI || !AI->isStaticAlloca())
      break;
  }

  return InsPt;
}

bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
  CallingConv::ID CC = F.getCallingConv();
  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();

  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
  LLVMContext &Ctx = F.getParent()->getContext();
  const DataLayout &DL = F.getParent()->getDataLayout();
  BasicBlock &EntryBlock = *F.begin();
  IRBuilder<> Builder(&*getInsertPt(EntryBlock));

  const Align KernArgBaseAlign(16); // FIXME: Increase if necessary
  const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);

  Align MaxAlign;
  // FIXME: Alignment is broken with explicit arg offset.
  const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
  if (TotalKernArgSize == 0)
    return false;

  CallInst *KernArgSegment =
      Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
                              nullptr, F.getName() + ".kernarg.segment");

  KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
  KernArgSegment->addAttribute(AttributeList::ReturnIndex,
    Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));
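  // Together with the alignment attribute added after the loop below, the
  // segment pointer ends up annotated roughly like this illustrative sketch,
  // where N is the total kernarg segment size:
  //   %f.kernarg.segment = call nonnull align 16 dereferenceable(N)
  //       i8 addrspace(4)* @llvm.amdgcn.kernarg.segment.ptr()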

  unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
  uint64_t ExplicitArgOffset = 0;

  for (Argument &Arg : F.args()) {
    Type *ArgTy = Arg.getType();
    Align ABITypeAlign = DL.getABITypeAlign(ArgTy);
    unsigned Size = DL.getTypeSizeInBits(ArgTy);
    unsigned AllocSize = DL.getTypeAllocSize(ArgTy);

    uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset;
    ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize;
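    // As an illustration (assuming BaseOffset == 0): for arguments (i8, i64),
    // the i8 is placed at EltOffset 0 and bumps ExplicitArgOffset to 1, and
    // the i64 is then ABI-aligned up to EltOffset 8, leaving ExplicitArgOffset
    // at 16.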

    if (Arg.use_empty())
      continue;

    if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
      // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
      // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
      // can't represent this with range metadata because it's only allowed for
      // integer types.
      if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS ||
           PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
          !ST.hasUsableDSOffset())
        continue;

      // FIXME: We can replace this with equivalent alias.scope/noalias
      // metadata, but this appears to be a lot of work.
      if (Arg.hasNoAliasAttr())
        continue;
    }

    auto *VT = dyn_cast<FixedVectorType>(ArgTy);
    bool IsV3 = VT && VT->getNumElements() == 3;
    bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();

    VectorType *V4Ty = nullptr;

    int64_t AlignDownOffset = alignDown(EltOffset, 4);
    int64_t OffsetDiff = EltOffset - AlignDownOffset;
    Align AdjustedAlign = commonAlignment(
        KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset);

    Value *ArgPtr;
    Type *AdjustedArgTy;
    if (DoShiftOpt) { // FIXME: Handle aggregate types
      // Since we don't have sub-dword scalar loads, avoid doing an extload by
      // loading earlier than the argument address, and extracting the relevant
      // bits.
      //
      // Additionally widen any sub-dword load to i32 even if suitably aligned,
      // so that CSE between different argument loads works easily.
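      //
      // Illustrative sketch (hypothetical IR): an i16 argument at byte
      // offset 2 becomes, roughly,
      //   %dword = load i32, i32 addrspace(4)* %ptr.align.down, align 16
      //   %shifted = lshr i32 %dword, 16
      //   %arg.load = trunc i32 %shifted to i16
      // where %ptr.align.down points at kernarg offset 0.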
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, AlignDownOffset,
          Arg.getName() + ".kernarg.offset.align.down");
      AdjustedArgTy = Builder.getInt32Ty();
    } else {
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, EltOffset,
          Arg.getName() + ".kernarg.offset");
      AdjustedArgTy = ArgTy;
    }

    if (IsV3 && Size >= 32) {
      V4Ty = FixedVectorType::get(VT->getElementType(), 4);
      // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
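      // As a hypothetical illustration, a <3 x i32> argument is loaded as
      //   %wide = load <4 x i32>, <4 x i32> addrspace(4)* %ptr, align 16
      // and narrowed back by the shufflevector emitted further below:
      //   %arg.load = shufflevector <4 x i32> %wide, <4 x i32> undef,
      //                             <3 x i32> <i32 0, i32 1, i32 2>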
      AdjustedArgTy = V4Ty;
    }

    ArgPtr = Builder.CreateBitCast(ArgPtr, AdjustedArgTy->getPointerTo(AS),
                                   ArgPtr->getName() + ".cast");
    LoadInst *Load =
        Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign);
    Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));

    MDBuilder MDB(Ctx);

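    // Carry known pointee properties of pointer arguments over onto the load
    // as metadata. As an illustrative example (hypothetical values), an
    // argument attributed nonnull dereferenceable(64) yields
    //   !nonnull !0, !dereferenceable !1
    // on the load, with !1 = !{i64 64}.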
    if (isa<PointerType>(ArgTy)) {
      if (Arg.hasNonNullAttr())
        Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));

      uint64_t DerefBytes = Arg.getDereferenceableBytes();
      if (DerefBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable,
          MDNode::get(Ctx,
                      MDB.createConstant(
                        ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
      }

      uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
      if (DerefOrNullBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable_or_null,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          DerefOrNullBytes))));
      }

      unsigned ParamAlign = Arg.getParamAlignment();
      if (ParamAlign != 0) {
        Load->setMetadata(
          LLVMContext::MD_align,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          ParamAlign))));
      }
    }

    // TODO: Convert noalias arg to !noalias

    if (DoShiftOpt) {
      Value *ExtractBits = OffsetDiff == 0 ?
        Load : Builder.CreateLShr(Load, OffsetDiff * 8);

      IntegerType *ArgIntTy = Builder.getIntNTy(Size);
      Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
      Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
                                            Arg.getName() + ".load");
      Arg.replaceAllUsesWith(NewVal);
    } else if (IsV3) {
      Value *Shuf = Builder.CreateShuffleVector(Load, UndefValue::get(V4Ty),
                                                ArrayRef<int>{0, 1, 2},
                                                Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Shuf);
    } else {
      Load->setName(Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Load);
    }
  }

  KernArgSegment->addAttribute(
      AttributeList::ReturnIndex,
      Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
                      "AMDGPU Lower Kernel Arguments", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE,
                    "AMDGPU Lower Kernel Arguments", false, false)

char AMDGPULowerKernelArguments::ID = 0;

FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
  return new AMDGPULowerKernelArguments();
}