1//===- NVPTXUtilities.cpp - Utility Functions -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file contains miscellaneous utility functions for the NVPTX backend,
// chiefly accessors for the "nvvm.annotations" module metadata.
10//
11//===----------------------------------------------------------------------===//
12
13#include "NVPTXUtilities.h"
14#include "NVPTX.h"
15#include "llvm/IR/Constants.h"
16#include "llvm/IR/Function.h"
17#include "llvm/IR/GlobalVariable.h"
18#include "llvm/IR/InstIterator.h"
19#include "llvm/IR/Module.h"
20#include "llvm/IR/Operator.h"
21#include "llvm/Support/ManagedStatic.h"
22#include "llvm/Support/Mutex.h"
23#include <algorithm>
24#include <cstring>
25#include <map>
26#include <mutex>
27#include <string>
28#include <vector>
29
30namespace llvm {
31
32namespace {
33typedef std::map<std::string, std::vector<unsigned> > key_val_pair_t;
34typedef std::map<const GlobalValue *, key_val_pair_t> global_val_annot_t;
35typedef std::map<const Module *, global_val_annot_t> per_module_annot_t;
36} // anonymous namespace
37
// Cache of parsed "nvvm.annotations" entries, keyed by module and then by
// global. It is filled on demand when a lookup misses, and a module's entry is
// dropped by clearAnnotationCache.
static ManagedStatic<per_module_annot_t> annotationCache;
// Guards every read and write of annotationCache in this file.
static sys::Mutex Lock;
40
41void clearAnnotationCache(const Module *Mod) {
42  std::lock_guard<sys::Mutex> Guard(Lock);
43  annotationCache->erase(Mod);
44}
45
46static void cacheAnnotationFromMD(const MDNode *md, key_val_pair_t &retval) {
47  std::lock_guard<sys::Mutex> Guard(Lock);
48  assert(md && "Invalid mdnode for annotation");
49  assert((md->getNumOperands() % 2) == 1 && "Invalid number of operands");
50  // start index = 1, to skip the global variable key
51  // increment = 2, to skip the value for each property-value pairs
52  for (unsigned i = 1, e = md->getNumOperands(); i != e; i += 2) {
53    // property
54    const MDString *prop = dyn_cast<MDString>(md->getOperand(i));
55    assert(prop && "Annotation property not a string");
56
57    // value
58    ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(md->getOperand(i + 1));
59    assert(Val && "Value operand not a constant int");
60
61    std::string keyname = prop->getString().str();
62    if (retval.find(keyname) != retval.end())
63      retval[keyname].push_back(Val->getZExtValue());
64    else {
65      std::vector<unsigned> tmp;
66      tmp.push_back(Val->getZExtValue());
67      retval[keyname] = tmp;
68    }
69  }
70}
71
72static void cacheAnnotationFromMD(const Module *m, const GlobalValue *gv) {
73  std::lock_guard<sys::Mutex> Guard(Lock);
74  NamedMDNode *NMD = m->getNamedMetadata("nvvm.annotations");
75  if (!NMD)
76    return;
77  key_val_pair_t tmp;
78  for (unsigned i = 0, e = NMD->getNumOperands(); i != e; ++i) {
79    const MDNode *elem = NMD->getOperand(i);
80
81    GlobalValue *entity =
82        mdconst::dyn_extract_or_null<GlobalValue>(elem->getOperand(0));
83    // entity may be null due to DCE
84    if (!entity)
85      continue;
86    if (entity != gv)
87      continue;
88
89    // accumulate annotations for entity in tmp
90    cacheAnnotationFromMD(elem, tmp);
91  }
92
93  if (tmp.empty()) // no annotations for this gv
94    return;
95
96  if ((*annotationCache).find(m) != (*annotationCache).end())
97    (*annotationCache)[m][gv] = std::move(tmp);
98  else {
99    global_val_annot_t tmp1;
100    tmp1[gv] = std::move(tmp);
101    (*annotationCache)[m] = std::move(tmp1);
102  }
103}
104
105bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
106                           unsigned &retval) {
107  std::lock_guard<sys::Mutex> Guard(Lock);
108  const Module *m = gv->getParent();
109  if ((*annotationCache).find(m) == (*annotationCache).end())
110    cacheAnnotationFromMD(m, gv);
111  else if ((*annotationCache)[m].find(gv) == (*annotationCache)[m].end())
112    cacheAnnotationFromMD(m, gv);
113  if ((*annotationCache)[m][gv].find(prop) == (*annotationCache)[m][gv].end())
114    return false;
115  retval = (*annotationCache)[m][gv][prop][0];
116  return true;
117}
118
119bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
120                           std::vector<unsigned> &retval) {
121  std::lock_guard<sys::Mutex> Guard(Lock);
122  const Module *m = gv->getParent();
123  if ((*annotationCache).find(m) == (*annotationCache).end())
124    cacheAnnotationFromMD(m, gv);
125  else if ((*annotationCache)[m].find(gv) == (*annotationCache)[m].end())
126    cacheAnnotationFromMD(m, gv);
127  if ((*annotationCache)[m][gv].find(prop) == (*annotationCache)[m][gv].end())
128    return false;
129  retval = (*annotationCache)[m][gv][prop];
130  return true;
131}
132
133bool isTexture(const Value &val) {
134  if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
135    unsigned annot;
136    if (findOneNVVMAnnotation(gv, "texture", annot)) {
137      assert((annot == 1) && "Unexpected annotation on a texture symbol");
138      return true;
139    }
140  }
141  return false;
142}
143
144bool isSurface(const Value &val) {
145  if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
146    unsigned annot;
147    if (findOneNVVMAnnotation(gv, "surface", annot)) {
148      assert((annot == 1) && "Unexpected annotation on a surface symbol");
149      return true;
150    }
151  }
152  return false;
153}
154
155bool isSampler(const Value &val) {
156  const char *AnnotationName = "sampler";
157
158  if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
159    unsigned annot;
160    if (findOneNVVMAnnotation(gv, AnnotationName, annot)) {
161      assert((annot == 1) && "Unexpected annotation on a sampler symbol");
162      return true;
163    }
164  }
165  if (const Argument *arg = dyn_cast<Argument>(&val)) {
166    const Function *func = arg->getParent();
167    std::vector<unsigned> annot;
168    if (findAllNVVMAnnotation(func, AnnotationName, annot)) {
169      if (is_contained(annot, arg->getArgNo()))
170        return true;
171    }
172  }
173  return false;
174}
175
176bool isImageReadOnly(const Value &val) {
177  if (const Argument *arg = dyn_cast<Argument>(&val)) {
178    const Function *func = arg->getParent();
179    std::vector<unsigned> annot;
180    if (findAllNVVMAnnotation(func, "rdoimage", annot)) {
181      if (is_contained(annot, arg->getArgNo()))
182        return true;
183    }
184  }
185  return false;
186}
187
188bool isImageWriteOnly(const Value &val) {
189  if (const Argument *arg = dyn_cast<Argument>(&val)) {
190    const Function *func = arg->getParent();
191    std::vector<unsigned> annot;
192    if (findAllNVVMAnnotation(func, "wroimage", annot)) {
193      if (is_contained(annot, arg->getArgNo()))
194        return true;
195    }
196  }
197  return false;
198}
199
200bool isImageReadWrite(const Value &val) {
201  if (const Argument *arg = dyn_cast<Argument>(&val)) {
202    const Function *func = arg->getParent();
203    std::vector<unsigned> annot;
204    if (findAllNVVMAnnotation(func, "rdwrimage", annot)) {
205      if (is_contained(annot, arg->getArgNo()))
206        return true;
207    }
208  }
209  return false;
210}
211
212bool isImage(const Value &val) {
213  return isImageReadOnly(val) || isImageWriteOnly(val) || isImageReadWrite(val);
214}
215
216bool isManaged(const Value &val) {
217  if(const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
218    unsigned annot;
219    if (findOneNVVMAnnotation(gv, "managed", annot)) {
220      assert((annot == 1) && "Unexpected annotation on a managed symbol");
221      return true;
222    }
223  }
224  return false;
225}
226
227std::string getTextureName(const Value &val) {
228  assert(val.hasName() && "Found texture variable with no name");
229  return val.getName();
230}
231
232std::string getSurfaceName(const Value &val) {
233  assert(val.hasName() && "Found surface variable with no name");
234  return val.getName();
235}
236
237std::string getSamplerName(const Value &val) {
238  assert(val.hasName() && "Found sampler variable with no name");
239  return val.getName();
240}
241
242bool getMaxNTIDx(const Function &F, unsigned &x) {
243  return findOneNVVMAnnotation(&F, "maxntidx", x);
244}
245
246bool getMaxNTIDy(const Function &F, unsigned &y) {
247  return findOneNVVMAnnotation(&F, "maxntidy", y);
248}
249
250bool getMaxNTIDz(const Function &F, unsigned &z) {
251  return findOneNVVMAnnotation(&F, "maxntidz", z);
252}
253
254bool getReqNTIDx(const Function &F, unsigned &x) {
255  return findOneNVVMAnnotation(&F, "reqntidx", x);
256}
257
258bool getReqNTIDy(const Function &F, unsigned &y) {
259  return findOneNVVMAnnotation(&F, "reqntidy", y);
260}
261
262bool getReqNTIDz(const Function &F, unsigned &z) {
263  return findOneNVVMAnnotation(&F, "reqntidz", z);
264}
265
266bool getMinCTASm(const Function &F, unsigned &x) {
267  return findOneNVVMAnnotation(&F, "minctasm", x);
268}
269
270bool getMaxNReg(const Function &F, unsigned &x) {
271  return findOneNVVMAnnotation(&F, "maxnreg", x);
272}
273
274bool isKernelFunction(const Function &F) {
275  unsigned x = 0;
276  bool retval = findOneNVVMAnnotation(&F, "kernel", x);
277  if (!retval) {
278    // There is no NVVM metadata, check the calling convention
279    return F.getCallingConv() == CallingConv::PTX_Kernel;
280  }
281  return (x == 1);
282}
283
284bool getAlign(const Function &F, unsigned index, unsigned &align) {
285  std::vector<unsigned> Vs;
286  bool retval = findAllNVVMAnnotation(&F, "align", Vs);
287  if (!retval)
288    return false;
289  for (int i = 0, e = Vs.size(); i < e; i++) {
290    unsigned v = Vs[i];
291    if ((v >> 16) == index) {
292      align = v & 0xFFFF;
293      return true;
294    }
295  }
296  return false;
297}
298
299bool getAlign(const CallInst &I, unsigned index, unsigned &align) {
300  if (MDNode *alignNode = I.getMetadata("callalign")) {
301    for (int i = 0, n = alignNode->getNumOperands(); i < n; i++) {
302      if (const ConstantInt *CI =
303              mdconst::dyn_extract<ConstantInt>(alignNode->getOperand(i))) {
304        unsigned v = CI->getZExtValue();
305        if ((v >> 16) == index) {
306          align = v & 0xFFFF;
307          return true;
308        }
309        if ((v >> 16) > index) {
310          return false;
311        }
312      }
313    }
314  }
315  return false;
316}
317
318} // namespace llvm
319