//===-- AMDGPULowerKernelArguments.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass replaces accesses to kernel arguments with loads from
/// offsets from the kernarg base pointer.
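///
/// For example, a use of a 32-bit kernel argument %x is (roughly; the value
/// names, offset, and alignment below are illustrative only) rewritten into a
/// load of the form:
///
///   %seg = call i8 addrspace(4)* @llvm.amdgcn.kernarg.segment.ptr()
///   %x.kernarg.offset = getelementptr inbounds i8, i8 addrspace(4)* %seg, i64 36
///   %x.kernarg.offset.cast = bitcast i8 addrspace(4)* %x.kernarg.offset
///                                to i32 addrspace(4)*
///   %x.load = load i32, i32 addrspace(4)* %x.kernarg.offset.cast, align 4,
///             !invariant.load !0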
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUSubtarget.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"

#define DEBUG_TYPE "amdgpu-lower-kernel-arguments"

using namespace llvm;

namespace {

class AMDGPULowerKernelArguments : public FunctionPass {
public:
  static char ID;

  AMDGPULowerKernelArguments() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.setPreservesAll();
  }
};

} // end anonymous namespace

bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
  CallingConv::ID CC = F.getCallingConv();
  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();

  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
  LLVMContext &Ctx = F.getParent()->getContext();
  const DataLayout &DL = F.getParent()->getDataLayout();
  BasicBlock &EntryBlock = *F.begin();
  IRBuilder<> Builder(&*EntryBlock.begin());

  const Align KernArgBaseAlign(16); // FIXME: Increase if necessary
  const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);

  Align MaxAlign;
  // FIXME: Alignment is broken with explicit arg offset.
  const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
  if (TotalKernArgSize == 0)
    return false;

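  // Every lowered argument access below is expressed as an offset from the
  // kernarg segment pointer, materialized once at the top of the entry block.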
  CallInst *KernArgSegment =
      Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
                              nullptr, F.getName() + ".kernarg.segment");

  KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
  KernArgSegment->addAttribute(AttributeList::ReturnIndex,
    Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));

  unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
  uint64_t ExplicitArgOffset = 0;

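  // Walk the explicit arguments in order, tracking each argument's
  // ABI-aligned offset within the kernarg segment.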
  for (Argument &Arg : F.args()) {
    Type *ArgTy = Arg.getType();
    unsigned ABITypeAlign = DL.getABITypeAlignment(ArgTy);
    unsigned Size = DL.getTypeSizeInBits(ArgTy);
    unsigned AllocSize = DL.getTypeAllocSize(ArgTy);

    uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset;
    ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize;

    if (Arg.use_empty())
      continue;

    if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
      // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
      // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
      // can't represent this with range metadata because it's only allowed for
      // integer types.
      if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS ||
           PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
          !ST.hasUsableDSOffset())
        continue;

      // FIXME: We can replace this with equivalent alias.scope/noalias
      // metadata, but this appears to be a lot of work.
      if (Arg.hasNoAliasAttr())
        continue;
    }

    VectorType *VT = dyn_cast<VectorType>(ArgTy);
    bool IsV3 = VT && VT->getNumElements() == 3;
    bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();

    VectorType *V4Ty = nullptr;

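    // Sub-dword arguments are loaded from the containing dword-aligned
    // offset; OffsetDiff is how far into that dword the argument starts.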
    int64_t AlignDownOffset = alignDown(EltOffset, 4);
    int64_t OffsetDiff = EltOffset - AlignDownOffset;
    Align AdjustedAlign = commonAlignment(
        KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset);

    Value *ArgPtr;
    Type *AdjustedArgTy;
    if (DoShiftOpt) { // FIXME: Handle aggregate types
      // Since we don't have sub-dword scalar loads, avoid doing an extload by
      // loading earlier than the argument address, and extracting the relevant
      // bits.
      //
      // Additionally widen any sub-dword load to i32 even if suitably aligned,
      // so that CSE between different argument loads works easily.
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, AlignDownOffset,
          Arg.getName() + ".kernarg.offset.align.down");
      AdjustedArgTy = Builder.getInt32Ty();
    } else {
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, EltOffset,
          Arg.getName() + ".kernarg.offset");
      AdjustedArgTy = ArgTy;
    }

    if (IsV3 && Size >= 32) {
      V4Ty = VectorType::get(VT->getVectorElementType(), 4);
      // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
      AdjustedArgTy = V4Ty;
    }

    ArgPtr = Builder.CreateBitCast(ArgPtr, AdjustedArgTy->getPointerTo(AS),
                                   ArgPtr->getName() + ".cast");
    LoadInst *Load =
        Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign.value());
    Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));

    MDBuilder MDB(Ctx);

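    // Propagate pointer-argument attributes (nonnull, dereferenceable,
    // dereferenceable_or_null, align) onto the load as metadata so the
    // information survives the rewrite.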
    if (isa<PointerType>(ArgTy)) {
      if (Arg.hasNonNullAttr())
        Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));

      uint64_t DerefBytes = Arg.getDereferenceableBytes();
      if (DerefBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable,
          MDNode::get(Ctx,
                      MDB.createConstant(
                        ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
      }

      uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
      if (DerefOrNullBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable_or_null,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          DerefOrNullBytes))));
      }

      unsigned ParamAlign = Arg.getParamAlignment();
      if (ParamAlign != 0) {
        Load->setMetadata(
          LLVMContext::MD_align,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          ParamAlign))));
      }
    }

    // TODO: Convert noalias arg to !noalias

    if (DoShiftOpt) {
      Value *ExtractBits = OffsetDiff == 0 ?
        Load : Builder.CreateLShr(Load, OffsetDiff * 8);

      IntegerType *ArgIntTy = Builder.getIntNTy(Size);
      Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
      Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
                                            Arg.getName() + ".load");
      Arg.replaceAllUsesWith(NewVal);
    } else if (IsV3) {
      Value *Shuf = Builder.CreateShuffleVector(Load, UndefValue::get(V4Ty),
                                                {0, 1, 2},
                                                Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Shuf);
    } else {
      Load->setName(Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Load);
    }
  }

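  // Record the final kernarg segment alignment on the returned pointer.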
  KernArgSegment->addAttribute(
      AttributeList::ReturnIndex,
      Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
                      "AMDGPU Lower Kernel Arguments", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE,
                    "AMDGPU Lower Kernel Arguments", false, false)

char AMDGPULowerKernelArguments::ID = 0;

FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
  return new AMDGPULowerKernelArguments();
}