1303231Sdim//===- NVVMIntrRange.cpp - Set !range metadata for NVVM intrinsics --------===//
2303231Sdim//
3353358Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4353358Sdim// See https://llvm.org/LICENSE.txt for license information.
5353358Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6303231Sdim//
7303231Sdim//===----------------------------------------------------------------------===//
8303231Sdim//
9303231Sdim// This pass adds appropriate !range metadata for calls to NVVM
10303231Sdim// intrinsics that return a limited range of values.
11303231Sdim//
12303231Sdim//===----------------------------------------------------------------------===//
13303231Sdim
14303231Sdim#include "NVPTX.h"
15303231Sdim#include "llvm/IR/Constants.h"
16303231Sdim#include "llvm/IR/InstIterator.h"
17321369Sdim#include "llvm/IR/Instructions.h"
18303231Sdim#include "llvm/IR/Intrinsics.h"
19360784Sdim#include "llvm/IR/IntrinsicsNVPTX.h"
20360784Sdim#include "llvm/Support/CommandLine.h"
21303231Sdim
22303231Sdimusing namespace llvm;
23303231Sdim
24303231Sdim#define DEBUG_TYPE "nvvm-intr-range"
25303231Sdim
26303231Sdimnamespace llvm { void initializeNVVMIntrRangePass(PassRegistry &); }
27303231Sdim
28303231Sdim// Add !range metadata based on limits of given SM variant.
29303231Sdimstatic cl::opt<unsigned> NVVMIntrRangeSM("nvvm-intr-range-sm", cl::init(20),
30303231Sdim                                         cl::Hidden, cl::desc("SM variant"));
31303231Sdim
32303231Sdimnamespace {
33303231Sdimclass NVVMIntrRange : public FunctionPass {
34303231Sdim private:
35303231Sdim   struct {
36303231Sdim     unsigned x, y, z;
37303231Sdim   } MaxBlockSize, MaxGridSize;
38303231Sdim
39303231Sdim public:
40303231Sdim   static char ID;
41303231Sdim   NVVMIntrRange() : NVVMIntrRange(NVVMIntrRangeSM) {}
42303231Sdim   NVVMIntrRange(unsigned int SmVersion) : FunctionPass(ID) {
43303231Sdim     MaxBlockSize.x = 1024;
44303231Sdim     MaxBlockSize.y = 1024;
45303231Sdim     MaxBlockSize.z = 64;
46303231Sdim
47303231Sdim     MaxGridSize.x = SmVersion >= 30 ? 0x7fffffff : 0xffff;
48303231Sdim     MaxGridSize.y = 0xffff;
49303231Sdim     MaxGridSize.z = 0xffff;
50303231Sdim
51303231Sdim     initializeNVVMIntrRangePass(*PassRegistry::getPassRegistry());
52303231Sdim   }
53303231Sdim
54303231Sdim   bool runOnFunction(Function &) override;
55303231Sdim};
56303231Sdim}
57303231Sdim
58303231SdimFunctionPass *llvm::createNVVMIntrRangePass(unsigned int SmVersion) {
59303231Sdim  return new NVVMIntrRange(SmVersion);
60303231Sdim}
61303231Sdim
62303231Sdimchar NVVMIntrRange::ID = 0;
63303231SdimINITIALIZE_PASS(NVVMIntrRange, "nvvm-intr-range",
64303231Sdim                "Add !range metadata to NVVM intrinsics.", false, false)
65303231Sdim
66303231Sdim// Adds the passed-in [Low,High) range information as metadata to the
67303231Sdim// passed-in call instruction.
68303231Sdimstatic bool addRangeMetadata(uint64_t Low, uint64_t High, CallInst *C) {
69314564Sdim  // This call already has range metadata, nothing to do.
70314564Sdim  if (C->getMetadata(LLVMContext::MD_range))
71314564Sdim    return false;
72314564Sdim
73303231Sdim  LLVMContext &Context = C->getParent()->getContext();
74303231Sdim  IntegerType *Int32Ty = Type::getInt32Ty(Context);
75303231Sdim  Metadata *LowAndHigh[] = {
76303231Sdim      ConstantAsMetadata::get(ConstantInt::get(Int32Ty, Low)),
77303231Sdim      ConstantAsMetadata::get(ConstantInt::get(Int32Ty, High))};
78303231Sdim  C->setMetadata(LLVMContext::MD_range, MDNode::get(Context, LowAndHigh));
79303231Sdim  return true;
80303231Sdim}
81303231Sdim
82303231Sdimbool NVVMIntrRange::runOnFunction(Function &F) {
83303231Sdim  // Go through the calls in this function.
84303231Sdim  bool Changed = false;
85303231Sdim  for (Instruction &I : instructions(F)) {
86303231Sdim    CallInst *Call = dyn_cast<CallInst>(&I);
87303231Sdim    if (!Call)
88303231Sdim      continue;
89303231Sdim
90303231Sdim    if (Function *Callee = Call->getCalledFunction()) {
91303231Sdim      switch (Callee->getIntrinsicID()) {
92303231Sdim      // Index within block
93303231Sdim      case Intrinsic::nvvm_read_ptx_sreg_tid_x:
94303231Sdim        Changed |= addRangeMetadata(0, MaxBlockSize.x, Call);
95303231Sdim        break;
96303231Sdim      case Intrinsic::nvvm_read_ptx_sreg_tid_y:
97303231Sdim        Changed |= addRangeMetadata(0, MaxBlockSize.y, Call);
98303231Sdim        break;
99303231Sdim      case Intrinsic::nvvm_read_ptx_sreg_tid_z:
100303231Sdim        Changed |= addRangeMetadata(0, MaxBlockSize.z, Call);
101303231Sdim        break;
102303231Sdim
103303231Sdim      // Block size
104303231Sdim      case Intrinsic::nvvm_read_ptx_sreg_ntid_x:
105303231Sdim        Changed |= addRangeMetadata(1, MaxBlockSize.x+1, Call);
106303231Sdim        break;
107303231Sdim      case Intrinsic::nvvm_read_ptx_sreg_ntid_y:
108303231Sdim        Changed |= addRangeMetadata(1, MaxBlockSize.y+1, Call);
109303231Sdim        break;
110303231Sdim      case Intrinsic::nvvm_read_ptx_sreg_ntid_z:
111303231Sdim        Changed |= addRangeMetadata(1, MaxBlockSize.z+1, Call);
112303231Sdim        break;
113303231Sdim
114303231Sdim      // Index within grid
115303231Sdim      case Intrinsic::nvvm_read_ptx_sreg_ctaid_x:
116303231Sdim        Changed |= addRangeMetadata(0, MaxGridSize.x, Call);
117303231Sdim        break;
118303231Sdim      case Intrinsic::nvvm_read_ptx_sreg_ctaid_y:
119303231Sdim        Changed |= addRangeMetadata(0, MaxGridSize.y, Call);
120303231Sdim        break;
121303231Sdim      case Intrinsic::nvvm_read_ptx_sreg_ctaid_z:
122303231Sdim        Changed |= addRangeMetadata(0, MaxGridSize.z, Call);
123303231Sdim        break;
124303231Sdim
125303231Sdim      // Grid size
126303231Sdim      case Intrinsic::nvvm_read_ptx_sreg_nctaid_x:
127303231Sdim        Changed |= addRangeMetadata(1, MaxGridSize.x+1, Call);
128303231Sdim        break;
129303231Sdim      case Intrinsic::nvvm_read_ptx_sreg_nctaid_y:
130303231Sdim        Changed |= addRangeMetadata(1, MaxGridSize.y+1, Call);
131303231Sdim        break;
132303231Sdim      case Intrinsic::nvvm_read_ptx_sreg_nctaid_z:
133303231Sdim        Changed |= addRangeMetadata(1, MaxGridSize.z+1, Call);
134303231Sdim        break;
135303231Sdim
136303231Sdim      // warp size is constant 32.
137303231Sdim      case Intrinsic::nvvm_read_ptx_sreg_warpsize:
138303231Sdim        Changed |= addRangeMetadata(32, 32+1, Call);
139303231Sdim        break;
140303231Sdim
141303231Sdim      // Lane ID is [0..warpsize)
142303231Sdim      case Intrinsic::nvvm_read_ptx_sreg_laneid:
143303231Sdim        Changed |= addRangeMetadata(0, 32, Call);
144303231Sdim        break;
145303231Sdim
146303231Sdim      default:
147303231Sdim        break;
148303231Sdim      }
149303231Sdim    }
150303231Sdim  }
151303231Sdim
152303231Sdim  return Changed;
153303231Sdim}
154