1//===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a printer that converts from our internal representation
10// of machine-dependent LLVM code to NVPTX assembly language.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXAsmPrinter.h"
15#include "MCTargetDesc/NVPTXBaseInfo.h"
16#include "MCTargetDesc/NVPTXInstPrinter.h"
17#include "MCTargetDesc/NVPTXMCAsmInfo.h"
18#include "MCTargetDesc/NVPTXTargetStreamer.h"
19#include "NVPTX.h"
20#include "NVPTXMCExpr.h"
21#include "NVPTXMachineFunctionInfo.h"
22#include "NVPTXRegisterInfo.h"
23#include "NVPTXSubtarget.h"
24#include "NVPTXTargetMachine.h"
25#include "NVPTXUtilities.h"
26#include "TargetInfo/NVPTXTargetInfo.h"
27#include "cl_common_defines.h"
28#include "llvm/ADT/APFloat.h"
29#include "llvm/ADT/APInt.h"
30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/DenseSet.h"
32#include "llvm/ADT/SmallString.h"
33#include "llvm/ADT/SmallVector.h"
34#include "llvm/ADT/StringExtras.h"
35#include "llvm/ADT/StringRef.h"
36#include "llvm/ADT/Triple.h"
37#include "llvm/ADT/Twine.h"
38#include "llvm/Analysis/ConstantFolding.h"
39#include "llvm/CodeGen/Analysis.h"
40#include "llvm/CodeGen/MachineBasicBlock.h"
41#include "llvm/CodeGen/MachineFrameInfo.h"
42#include "llvm/CodeGen/MachineFunction.h"
43#include "llvm/CodeGen/MachineInstr.h"
44#include "llvm/CodeGen/MachineLoopInfo.h"
45#include "llvm/CodeGen/MachineModuleInfo.h"
46#include "llvm/CodeGen/MachineOperand.h"
47#include "llvm/CodeGen/MachineRegisterInfo.h"
48#include "llvm/CodeGen/TargetRegisterInfo.h"
49#include "llvm/CodeGen/ValueTypes.h"
50#include "llvm/IR/Attributes.h"
51#include "llvm/IR/BasicBlock.h"
52#include "llvm/IR/Constant.h"
53#include "llvm/IR/Constants.h"
54#include "llvm/IR/DataLayout.h"
55#include "llvm/IR/DebugInfo.h"
56#include "llvm/IR/DebugInfoMetadata.h"
57#include "llvm/IR/DebugLoc.h"
58#include "llvm/IR/DerivedTypes.h"
59#include "llvm/IR/Function.h"
60#include "llvm/IR/GlobalValue.h"
61#include "llvm/IR/GlobalVariable.h"
62#include "llvm/IR/Instruction.h"
63#include "llvm/IR/LLVMContext.h"
64#include "llvm/IR/Module.h"
65#include "llvm/IR/Operator.h"
66#include "llvm/IR/Type.h"
67#include "llvm/IR/User.h"
68#include "llvm/MC/MCExpr.h"
69#include "llvm/MC/MCInst.h"
70#include "llvm/MC/MCInstrDesc.h"
71#include "llvm/MC/MCStreamer.h"
72#include "llvm/MC/MCSymbol.h"
73#include "llvm/MC/TargetRegistry.h"
74#include "llvm/Support/Casting.h"
75#include "llvm/Support/CommandLine.h"
76#include "llvm/Support/Endian.h"
77#include "llvm/Support/ErrorHandling.h"
78#include "llvm/Support/MachineValueType.h"
79#include "llvm/Support/NativeFormatting.h"
80#include "llvm/Support/Path.h"
81#include "llvm/Support/raw_ostream.h"
82#include "llvm/Target/TargetLoweringObjectFile.h"
83#include "llvm/Target/TargetMachine.h"
84#include "llvm/Transforms/Utils/UnrollLoop.h"
85#include <cassert>
86#include <cstdint>
87#include <cstring>
88#include <new>
89#include <string>
90#include <utility>
91#include <vector>
92
93using namespace llvm;
94
95#define DEPOTNAME "__local_depot"
96
97/// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
98/// depends.
99static void
100DiscoverDependentGlobals(const Value *V,
101                         DenseSet<const GlobalVariable *> &Globals) {
102  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
103    Globals.insert(GV);
104  else {
105    if (const User *U = dyn_cast<User>(V)) {
106      for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
107        DiscoverDependentGlobals(U->getOperand(i), Globals);
108      }
109    }
110  }
111}
112
113/// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
114/// instances to be emitted, but only after any dependents have been added
115/// first.s
116static void
117VisitGlobalVariableForEmission(const GlobalVariable *GV,
118                               SmallVectorImpl<const GlobalVariable *> &Order,
119                               DenseSet<const GlobalVariable *> &Visited,
120                               DenseSet<const GlobalVariable *> &Visiting) {
121  // Have we already visited this one?
122  if (Visited.count(GV))
123    return;
124
125  // Do we have a circular dependency?
126  if (!Visiting.insert(GV).second)
127    report_fatal_error("Circular dependency found in global variable set");
128
129  // Make sure we visit all dependents first
130  DenseSet<const GlobalVariable *> Others;
131  for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
132    DiscoverDependentGlobals(GV->getOperand(i), Others);
133
134  for (const GlobalVariable *GV : Others)
135    VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
136
137  // Now we can visit ourself
138  Order.push_back(GV);
139  Visited.insert(GV);
140  Visiting.erase(GV);
141}
142
143void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
144  NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
145                                        getSubtargetInfo().getFeatureBits());
146
147  MCInst Inst;
148  lowerToMCInst(MI, Inst);
149  EmitToStreamer(*OutStreamer, Inst);
150}
151
152// Handle symbol backtracking for targets that do not support image handles
153bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
154                                           unsigned OpNo, MCOperand &MCOp) {
155  const MachineOperand &MO = MI->getOperand(OpNo);
156  const MCInstrDesc &MCID = MI->getDesc();
157
158  if (MCID.TSFlags & NVPTXII::IsTexFlag) {
159    // This is a texture fetch, so operand 4 is a texref and operand 5 is
160    // a samplerref
161    if (OpNo == 4 && MO.isImm()) {
162      lowerImageHandleSymbol(MO.getImm(), MCOp);
163      return true;
164    }
165    if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
166      lowerImageHandleSymbol(MO.getImm(), MCOp);
167      return true;
168    }
169
170    return false;
171  } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
172    unsigned VecSize =
173      1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
174
175    // For a surface load of vector size N, the Nth operand will be the surfref
176    if (OpNo == VecSize && MO.isImm()) {
177      lowerImageHandleSymbol(MO.getImm(), MCOp);
178      return true;
179    }
180
181    return false;
182  } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
183    // This is a surface store, so operand 0 is a surfref
184    if (OpNo == 0 && MO.isImm()) {
185      lowerImageHandleSymbol(MO.getImm(), MCOp);
186      return true;
187    }
188
189    return false;
190  } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
191    // This is a query, so operand 1 is a surfref/texref
192    if (OpNo == 1 && MO.isImm()) {
193      lowerImageHandleSymbol(MO.getImm(), MCOp);
194      return true;
195    }
196
197    return false;
198  }
199
200  return false;
201}
202
203void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
204  // Ewwww
205  LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());
206  NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);
207  const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
208  const char *Sym = MFI->getImageHandleSymbol(Index);
209  StringRef SymName = nvTM.getStrPool().save(Sym);
210  MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
211}
212
213void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
214  OutMI.setOpcode(MI->getOpcode());
215  // Special: Do not mangle symbol operand of CALL_PROTOTYPE
216  if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
217    const MachineOperand &MO = MI->getOperand(0);
218    OutMI.addOperand(GetSymbolRef(
219      OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
220    return;
221  }
222
223  const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
224  for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
225    const MachineOperand &MO = MI->getOperand(i);
226
227    MCOperand MCOp;
228    if (!STI.hasImageHandles()) {
229      if (lowerImageHandleOperand(MI, i, MCOp)) {
230        OutMI.addOperand(MCOp);
231        continue;
232      }
233    }
234
235    if (lowerOperand(MO, MCOp))
236      OutMI.addOperand(MCOp);
237  }
238}
239
// Lower a single MachineOperand MO to an MCOperand. Returns true once MCOp
// has been populated; unknown operand kinds hit llvm_unreachable.
bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
                                   MCOperand &MCOp) {
  switch (MO.getType()) {
  default: llvm_unreachable("unknown operand type");
  case MachineOperand::MO_Register:
    // Registers are re-encoded so the register class rides in the upper four
    // bits; see encodeVirtualRegister.
    MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
    break;
  case MachineOperand::MO_Immediate:
    MCOp = MCOperand::createImm(MO.getImm());
    break;
  case MachineOperand::MO_MachineBasicBlock:
    // Basic blocks lower to references to their MC label symbols.
    MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
        MO.getMBB()->getSymbol(), OutContext));
    break;
  case MachineOperand::MO_ExternalSymbol:
    MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
    break;
  case MachineOperand::MO_GlobalAddress:
    MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
    break;
  case MachineOperand::MO_FPImmediate: {
    // FP immediates are wrapped in NVPTX-specific MCExprs so the exact bit
    // pattern is preserved in the output; half, float and double only.
    const ConstantFP *Cnt = MO.getFPImm();
    const APFloat &Val = Cnt->getValueAPF();

    switch (Cnt->getType()->getTypeID()) {
    default: report_fatal_error("Unsupported FP type"); break;
    case Type::HalfTyID:
      MCOp = MCOperand::createExpr(
        NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
      break;
    case Type::FloatTyID:
      MCOp = MCOperand::createExpr(
        NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
      break;
    case Type::DoubleTyID:
      MCOp = MCOperand::createExpr(
        NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
      break;
    }
    break;
  }
  }
  return true;
}
284
285unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
286  if (Register::isVirtualRegister(Reg)) {
287    const TargetRegisterClass *RC = MRI->getRegClass(Reg);
288
289    DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
290    unsigned RegNum = RegMap[Reg];
291
292    // Encode the register class in the upper 4 bits
293    // Must be kept in sync with NVPTXInstPrinter::printRegName
294    unsigned Ret = 0;
295    if (RC == &NVPTX::Int1RegsRegClass) {
296      Ret = (1 << 28);
297    } else if (RC == &NVPTX::Int16RegsRegClass) {
298      Ret = (2 << 28);
299    } else if (RC == &NVPTX::Int32RegsRegClass) {
300      Ret = (3 << 28);
301    } else if (RC == &NVPTX::Int64RegsRegClass) {
302      Ret = (4 << 28);
303    } else if (RC == &NVPTX::Float32RegsRegClass) {
304      Ret = (5 << 28);
305    } else if (RC == &NVPTX::Float64RegsRegClass) {
306      Ret = (6 << 28);
307    } else if (RC == &NVPTX::Float16RegsRegClass) {
308      Ret = (7 << 28);
309    } else if (RC == &NVPTX::Float16x2RegsRegClass) {
310      Ret = (8 << 28);
311    } else {
312      report_fatal_error("Bad register class");
313    }
314
315    // Insert the vreg number
316    Ret |= (RegNum & 0x0FFFFFFF);
317    return Ret;
318  } else {
319    // Some special-use registers are actually physical registers.
320    // Encode this as the register class ID of 0 and the real register ID.
321    return Reg & 0x0FFFFFFF;
322  }
323}
324
325MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
326  const MCExpr *Expr;
327  Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
328                                 OutContext);
329  return MCOperand::createExpr(Expr);
330}
331
// Print the PTX return-value descriptor for F, e.g. " (.param .b32
// func_retval0) ". Prints nothing for a void return type.
void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
  const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());

  Type *Ty = F->getReturnType();

  // sm_20+ follows the PTX ABI (.param returns); older targets use .reg.
  bool isABI = (STI.getSmVersion() >= 20);

  if (Ty->getTypeID() == Type::VoidTyID)
    return;

  O << " (";

  if (isABI) {
    if (Ty->isFloatingPointTy() || (Ty->isIntegerTy() && !Ty->isIntegerTy(128))) {
      unsigned size = 0;
      if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
        size = ITy->getBitWidth();
      } else {
        assert(Ty->isFloatingPointTy() && "Floating point type expected here");
        size = Ty->getPrimitiveSizeInBits();
      }
      // PTX ABI requires all scalar return values to be at least 32
      // bits in size.  fp16 normally uses .b16 as its storage type in
      // PTX, so its size must be adjusted here, too.
      size = promoteScalarArgumentSize(size);

      O << ".param .b" << size << " func_retval0";
    } else if (isa<PointerType>(Ty)) {
      // Pointers return as a target-pointer-sized untyped value.
      O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
        << " func_retval0";
    } else if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
      // Aggregates, vectors and i128 return via an aligned byte array.
      unsigned totalsz = DL.getTypeAllocSize(Ty);
      unsigned retAlignment = 0;
      // Prefer the IR-specified return alignment; fall back to the target's
      // optimized parameter alignment.
      if (!getAlign(*F, 0, retAlignment))
        retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value();
      O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
        << "]";
    } else
      llvm_unreachable("Unknown return type");
  } else {
    // Pre-ABI path: flatten the return type into value parts and emit one
    // .reg per (possibly promoted) scalar element.
    SmallVector<EVT, 16> vtparts;
    ComputeValueVTs(*TLI, DL, Ty, vtparts);
    unsigned idx = 0;
    for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
      unsigned elems = 1;
      EVT elemtype = vtparts[i];
      if (vtparts[i].isVector()) {
        elems = vtparts[i].getVectorNumElements();
        elemtype = vtparts[i].getVectorElementType();
      }

      for (unsigned j = 0, je = elems; j != je; ++j) {
        unsigned sz = elemtype.getSizeInBits();
        // Integer parts are widened to at least 32 bits, matching the ABI
        // promotion above.
        if (elemtype.isInteger())
          sz = promoteScalarArgumentSize(sz);
        O << ".reg .b" << sz << " func_retval" << idx;
        if (j < je - 1)
          O << ", ";
        ++idx;
      }
      if (i < e - 1)
        O << ", ";
    }
  }
  O << ") ";
}
400
401void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
402                                        raw_ostream &O) {
403  const Function &F = MF.getFunction();
404  printReturnValStr(&F, O);
405}
406
407// Return true if MBB is the header of a loop marked with
408// llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
409bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
410    const MachineBasicBlock &MBB) const {
411  MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>();
412  // We insert .pragma "nounroll" only to the loop header.
413  if (!LI.isLoopHeader(&MBB))
414    return false;
415
416  // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
417  // we iterate through each back edge of the loop with header MBB, and check
418  // whether its metadata contains llvm.loop.unroll.disable.
419  for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
420    if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
421      // Edges from other loops to MBB are not back edges.
422      continue;
423    }
424    if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
425      if (MDNode *LoopID =
426              PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
427        if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
428          return true;
429        if (MDNode *UnrollCountMD =
430                GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
431          if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1))
432                  ->getZExtValue() == 1)
433            return true;
434        }
435      }
436    }
437  }
438  return false;
439}
440
void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
  AsmPrinter::emitBasicBlockStart(MBB);
  // Tell ptxas not to unroll loops whose IR carried no-unroll metadata.
  if (isLoopHeaderOfNoUnroll(MBB))
    OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
}
446
// Emit the PTX function signature (.entry/.func line, params, kernel
// directives) followed by the opening brace of the body.
void NVPTXAsmPrinter::emitFunctionEntryLabel() {
  SmallString<128> Str;
  raw_svector_ostream O(Str);

  // Module-level globals must appear before the first function body; emit
  // them lazily here if they have not been emitted yet.
  if (!GlobalsEmitted) {
    emitGlobals(*MF->getFunction().getParent());
    GlobalsEmitted = true;
  }

  // Set up
  MRI = &MF->getRegInfo();
  F = &MF->getFunction();
  emitLinkageDirective(F, O);
  // Kernels use .entry; device functions use .func plus an optional
  // return-value descriptor.
  if (isKernelFunction(*F))
    O << ".entry ";
  else {
    O << ".func ";
    printReturnValStr(*MF, O);
  }

  CurrentFnSym->print(O, MAI);

  emitFunctionParamList(*MF, O);

  // Launch-bound directives (.reqntid/.maxntid/...) apply only to kernels.
  if (isKernelFunction(*F))
    emitKernelFunctionDirectives(*F, O);

  if (shouldEmitPTXNoReturn(F, TM))
    O << ".noreturn";

  OutStreamer->emitRawText(O.str());

  // Start a fresh virtual-register numbering for this function.
  VRegMapping.clear();
  // Emit open brace for function body.
  OutStreamer->emitRawText(StringRef("{\n"));
  setAndEmitFunctionVirtualRegisters(*MF);
  // Emit initial .loc debug directive for correct relocation symbol data.
  if (MMI && MMI->hasDebugInfo())
    emitInitialRawDwarfLocDirective(*MF);
}
487
488bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
489  bool Result = AsmPrinter::runOnMachineFunction(F);
490  // Emit closing brace for the body of function F.
491  // The closing brace must be emitted here because we need to emit additional
492  // debug labels/data after the last basic block.
493  // We need to emit the closing brace here because we don't have function that
494  // finished emission of the function body.
495  OutStreamer->emitRawText(StringRef("}\n"));
496  return Result;
497}
498
499void NVPTXAsmPrinter::emitFunctionBodyStart() {
500  SmallString<128> Str;
501  raw_svector_ostream O(Str);
502  emitDemotedVars(&MF->getFunction(), O);
503  OutStreamer->emitRawText(O.str());
504}
505
void NVPTXAsmPrinter::emitFunctionBodyEnd() {
  // Drop the per-function virtual-register numbering; the next function
  // rebuilds it from scratch in emitFunctionEntryLabel.
  VRegMapping.clear();
}
509
510const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
511    SmallString<128> Str;
512    raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
513    return OutContext.getOrCreateSymbol(Str);
514}
515
516void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
517  Register RegNo = MI->getOperand(0).getReg();
518  if (RegNo.isVirtual()) {
519    OutStreamer->AddComment(Twine("implicit-def: ") +
520                            getVirtualRegisterName(RegNo));
521  } else {
522    const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
523    OutStreamer->AddComment(Twine("implicit-def: ") +
524                            STI.getRegisterInfo()->getName(RegNo));
525  }
526  OutStreamer->addBlankLine();
527}
528
529void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
530                                                   raw_ostream &O) const {
531  // If the NVVM IR has some of reqntid* specified, then output
532  // the reqntid directive, and set the unspecified ones to 1.
533  // If none of reqntid* is specified, don't output reqntid directive.
534  unsigned reqntidx, reqntidy, reqntidz;
535  bool specified = false;
536  if (!getReqNTIDx(F, reqntidx))
537    reqntidx = 1;
538  else
539    specified = true;
540  if (!getReqNTIDy(F, reqntidy))
541    reqntidy = 1;
542  else
543    specified = true;
544  if (!getReqNTIDz(F, reqntidz))
545    reqntidz = 1;
546  else
547    specified = true;
548
549  if (specified)
550    O << ".reqntid " << reqntidx << ", " << reqntidy << ", " << reqntidz
551      << "\n";
552
553  // If the NVVM IR has some of maxntid* specified, then output
554  // the maxntid directive, and set the unspecified ones to 1.
555  // If none of maxntid* is specified, don't output maxntid directive.
556  unsigned maxntidx, maxntidy, maxntidz;
557  specified = false;
558  if (!getMaxNTIDx(F, maxntidx))
559    maxntidx = 1;
560  else
561    specified = true;
562  if (!getMaxNTIDy(F, maxntidy))
563    maxntidy = 1;
564  else
565    specified = true;
566  if (!getMaxNTIDz(F, maxntidz))
567    maxntidz = 1;
568  else
569    specified = true;
570
571  if (specified)
572    O << ".maxntid " << maxntidx << ", " << maxntidy << ", " << maxntidz
573      << "\n";
574
575  unsigned mincta;
576  if (getMinCTASm(F, mincta))
577    O << ".minnctapersm " << mincta << "\n";
578
579  unsigned maxnreg;
580  if (getMaxNReg(F, maxnreg))
581    O << ".maxnreg " << maxnreg << "\n";
582}
583
584std::string
585NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
586  const TargetRegisterClass *RC = MRI->getRegClass(Reg);
587
588  std::string Name;
589  raw_string_ostream NameStr(Name);
590
591  VRegRCMap::const_iterator I = VRegMapping.find(RC);
592  assert(I != VRegMapping.end() && "Bad register class");
593  const DenseMap<unsigned, unsigned> &RegMap = I->second;
594
595  VRegMap::const_iterator VI = RegMap.find(Reg);
596  assert(VI != RegMap.end() && "Bad virtual register");
597  unsigned MappedVR = VI->second;
598
599  NameStr << getNVPTXRegClassStr(RC) << MappedVR;
600
601  NameStr.flush();
602  return Name;
603}
604
// Print the PTX name assigned to virtual register \p vr.
void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
                                          raw_ostream &O) {
  O << getVirtualRegisterName(vr);
}
609
610void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
611  emitLinkageDirective(F, O);
612  if (isKernelFunction(*F))
613    O << ".entry ";
614  else
615    O << ".func ";
616  printReturnValStr(F, O);
617  getSymbol(F)->print(O, MAI);
618  O << "\n";
619  emitFunctionParamList(F, O);
620  if (shouldEmitPTXNoReturn(F, TM))
621    O << ".noreturn";
622  O << ";\n";
623}
624
625static bool usedInGlobalVarDef(const Constant *C) {
626  if (!C)
627    return false;
628
629  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
630    return GV->getName() != "llvm.used";
631  }
632
633  for (const User *U : C->users())
634    if (const Constant *C = dyn_cast<Constant>(U))
635      if (usedInGlobalVarDef(C))
636        return true;
637
638  return false;
639}
640
641static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
642  if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
643    if (othergv->getName() == "llvm.used")
644      return true;
645  }
646
647  if (const Instruction *instr = dyn_cast<Instruction>(U)) {
648    if (instr->getParent() && instr->getParent()->getParent()) {
649      const Function *curFunc = instr->getParent()->getParent();
650      if (oneFunc && (curFunc != oneFunc))
651        return false;
652      oneFunc = curFunc;
653      return true;
654    } else
655      return false;
656  }
657
658  for (const User *UU : U->users())
659    if (!usedInOneFunc(UU, oneFunc))
660      return false;
661
662  return true;
663}
664
665/* Find out if a global variable can be demoted to local scope.
666 * Currently, this is valid for CUDA shared variables, which have local
667 * scope and global lifetime. So the conditions to check are :
668 * 1. Is the global variable in shared address space?
669 * 2. Does it have internal linkage?
670 * 3. Is the global variable referenced only in one function?
671 */
672static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
673  if (!gv->hasInternalLinkage())
674    return false;
675  PointerType *Pty = gv->getType();
676  if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
677    return false;
678
679  const Function *oneFunc = nullptr;
680
681  bool flag = usedInOneFunc(gv, oneFunc);
682  if (!flag)
683    return false;
684  if (!oneFunc)
685    return false;
686  f = oneFunc;
687  return true;
688}
689
690static bool useFuncSeen(const Constant *C,
691                        DenseMap<const Function *, bool> &seenMap) {
692  for (const User *U : C->users()) {
693    if (const Constant *cu = dyn_cast<Constant>(U)) {
694      if (useFuncSeen(cu, seenMap))
695        return true;
696    } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
697      const BasicBlock *bb = I->getParent();
698      if (!bb)
699        continue;
700      const Function *caller = bb->getParent();
701      if (!caller)
702        continue;
703      if (seenMap.find(caller) != seenMap.end())
704        return true;
705    }
706  }
707  return false;
708}
709
// Emit forward declarations for functions that are referenced before their
// definitions would appear in the output (ptxas does not allow forward
// references).
void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
  DenseMap<const Function *, bool> seenMap;
  for (const Function &F : M) {
    // Functions tagged as libcall callees are always declared up front.
    if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
      emitDeclaration(&F, O);
      continue;
    }

    // External declarations: emit only when used and not an intrinsic.
    if (F.isDeclaration()) {
      if (F.use_empty())
        continue;
      if (F.getIntrinsicID())
        continue;
      emitDeclaration(&F, O);
      continue;
    }
    for (const User *U : F.users()) {
      if (const Constant *C = dyn_cast<Constant>(U)) {
        if (usedInGlobalVarDef(C)) {
          // The use is in the initialization of a global variable
          // that is a function pointer, so print a declaration
          // for the original function
          emitDeclaration(&F, O);
          break;
        }
        // Emit a declaration of this function if the function that
        // uses this constant expr has already been seen.
        if (useFuncSeen(C, seenMap)) {
          emitDeclaration(&F, O);
          break;
        }
      }

      if (!isa<Instruction>(U))
        continue;
      const Instruction *instr = cast<Instruction>(U);
      const BasicBlock *bb = instr->getParent();
      if (!bb)
        continue;
      const Function *caller = bb->getParent();
      if (!caller)
        continue;

      // If a caller has already been seen, then the caller is
      // appearing in the module before the callee. so print out
      // a declaration for the callee.
      if (seenMap.find(caller) != seenMap.end()) {
        emitDeclaration(&F, O);
        break;
      }
    }
    // Record that F's definition has now "appeared"; later functions that
    // reference it need no separate declaration.
    seenMap[&F] = true;
  }
}
764
765static bool isEmptyXXStructor(GlobalVariable *GV) {
766  if (!GV) return true;
767  const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
768  if (!InitList) return true;  // Not an array; we don't know how to parse.
769  return InitList->getNumOperands() == 0;
770}
771
772void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
773  // Construct a default subtarget off of the TargetMachine defaults. The
774  // rest of NVPTX isn't friendly to change subtargets per function and
775  // so the default TargetMachine will have all of the options.
776  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
777  const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
778  SmallString<128> Str1;
779  raw_svector_ostream OS1(Str1);
780
781  // Emit header before any dwarf directives are emitted below.
782  emitHeader(M, OS1, *STI);
783  OutStreamer->emitRawText(OS1.str());
784}
785
786bool NVPTXAsmPrinter::doInitialization(Module &M) {
787  if (M.alias_size()) {
788    report_fatal_error("Module has aliases, which NVPTX does not support.");
789    return true; // error
790  }
791  if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors"))) {
792    report_fatal_error(
793        "Module has a nontrivial global ctor, which NVPTX does not support.");
794    return true;  // error
795  }
796  if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors"))) {
797    report_fatal_error(
798        "Module has a nontrivial global dtor, which NVPTX does not support.");
799    return true;  // error
800  }
801
802  // We need to call the parent's one explicitly.
803  bool Result = AsmPrinter::doInitialization(M);
804
805  GlobalsEmitted = false;
806
807  return Result;
808}
809
810void NVPTXAsmPrinter::emitGlobals(const Module &M) {
811  SmallString<128> Str2;
812  raw_svector_ostream OS2(Str2);
813
814  emitDeclarations(M, OS2);
815
816  // As ptxas does not support forward references of globals, we need to first
817  // sort the list of module-level globals in def-use order. We visit each
818  // global variable in order, and ensure that we emit it *after* its dependent
819  // globals. We use a little extra memory maintaining both a set and a list to
820  // have fast searches while maintaining a strict ordering.
821  SmallVector<const GlobalVariable *, 8> Globals;
822  DenseSet<const GlobalVariable *> GVVisited;
823  DenseSet<const GlobalVariable *> GVVisiting;
824
825  // Visit each global variable, in order
826  for (const GlobalVariable &I : M.globals())
827    VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);
828
829  assert(GVVisited.size() == M.getGlobalList().size() &&
830         "Missed a global variable");
831  assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
832
833  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
834  const NVPTXSubtarget &STI =
835      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
836
837  // Print out module-level global variables in proper order
838  for (unsigned i = 0, e = Globals.size(); i != e; ++i)
839    printModuleLevelGV(Globals[i], OS2, /*processDemoted=*/false, STI);
840
841  OS2 << '\n';
842
843  OutStreamer->emitRawText(OS2.str());
844}
845
846void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
847                                 const NVPTXSubtarget &STI) {
848  O << "//\n";
849  O << "// Generated by LLVM NVPTX Back-End\n";
850  O << "//\n";
851  O << "\n";
852
853  unsigned PTXVersion = STI.getPTXVersion();
854  O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
855
856  O << ".target ";
857  O << STI.getTargetName();
858
859  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
860  if (NTM.getDrvInterface() == NVPTX::NVCL)
861    O << ", texmode_independent";
862
863  bool HasFullDebugInfo = false;
864  for (DICompileUnit *CU : M.debug_compile_units()) {
865    switch(CU->getEmissionKind()) {
866    case DICompileUnit::NoDebug:
867    case DICompileUnit::DebugDirectivesOnly:
868      break;
869    case DICompileUnit::LineTablesOnly:
870    case DICompileUnit::FullDebug:
871      HasFullDebugInfo = true;
872      break;
873    }
874    if (HasFullDebugInfo)
875      break;
876  }
877  if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
878    O << ", debug";
879
880  O << "\n";
881
882  O << ".address_size ";
883  if (NTM.is64Bit())
884    O << "64";
885  else
886    O << "32";
887  O << "\n";
888
889  O << "\n";
890}
891
bool NVPTXAsmPrinter::doFinalization(Module &M) {
  bool HasDebugInfo = MMI && MMI->hasDebugInfo();

  // If we did not emit any functions, then the global declarations have not
  // yet been emitted.
  if (!GlobalsEmitted) {
    emitGlobals(M);
    GlobalsEmitted = true;
  }

  // call doFinalization
  bool ret = AsmPrinter::doFinalization(M);

  // Annotation data is cached per module; drop this module's entry.
  clearAnnotationCache(&M);

  auto *TS =
      static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
  // Close the last emitted section
  if (HasDebugInfo) {
    TS->closeLastSection();
    // Emit empty .debug_loc section for better support of the empty files.
    OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");
  }

  // Output last DWARF .file directives, if any.
  TS->outputDwarfFileDirectives();

  return ret;
}
921
922// This function emits appropriate linkage directives for
923// functions and global variables.
924//
925// extern function declaration            -> .extern
926// extern function definition             -> .visible
927// external global variable with init     -> .visible
928// external without init                  -> .extern
929// appending                              -> not allowed, assert.
930// for any linkage other than
931// internal, private, linker_private,
932// linker_private_weak, linker_private_weak_def_auto,
933// we emit                                -> .weak.
934
935void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
936                                           raw_ostream &O) {
937  if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
938    if (V->hasExternalLinkage()) {
939      if (isa<GlobalVariable>(V)) {
940        const GlobalVariable *GVar = cast<GlobalVariable>(V);
941        if (GVar) {
942          if (GVar->hasInitializer())
943            O << ".visible ";
944          else
945            O << ".extern ";
946        }
947      } else if (V->isDeclaration())
948        O << ".extern ";
949      else
950        O << ".visible ";
951    } else if (V->hasAppendingLinkage()) {
952      std::string msg;
953      msg.append("Error: ");
954      msg.append("Symbol ");
955      if (V->hasName())
956        msg.append(std::string(V->getName()));
957      msg.append("has unsupported appending linkage type");
958      llvm_unreachable(msg.c_str());
959    } else if (!V->hasInternalLinkage() &&
960               !V->hasPrivateLinkage()) {
961      O << ".weak ";
962    }
963  }
964}
965
// Print one module-level global variable to O in PTX syntax: linkage
// directive, address space, alignment, type, symbol name, and (only for the
// global/const state spaces) an optional initializer. Texture, surface and
// sampler globals get their own dedicated forms. If the variable can be
// demoted to function-local scope it is recorded in localDecls instead of
// being printed (unless processDemoted is set, i.e. we are emitting the
// demoted declarations themselves).
void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
                                         raw_ostream &O, bool processDemoted,
                                         const NVPTXSubtarget &STI) {
  // Skip meta data
  if (GVar->hasSection()) {
    if (GVar->getSection() == "llvm.metadata")
      return;
  }

  // Skip LLVM intrinsic global variables
  if (GVar->getName().startswith("llvm.") ||
      GVar->getName().startswith("nvvm."))
    return;

  const DataLayout &DL = getDataLayout();

  // GlobalVariables are always constant pointers themselves.
  PointerType *PTy = GVar->getType();
  Type *ETy = GVar->getValueType();

  // Linkage directive: definitions are .visible, declarations .extern, and
  // the weak-ish linkage kinds all become .weak (see emitLinkageDirective).
  if (GVar->hasExternalLinkage()) {
    if (GVar->hasInitializer())
      O << ".visible ";
    else
      O << ".extern ";
  } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
             GVar->hasAvailableExternallyLinkage() ||
             GVar->hasCommonLinkage()) {
    O << ".weak ";
  }

  // Textures and surfaces have fixed forms with no alignment/type/init.
  if (isTexture(*GVar)) {
    O << ".global .texref " << getTextureName(*GVar) << ";\n";
    return;
  }

  if (isSurface(*GVar)) {
    O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
    return;
  }

  if (GVar->isDeclaration()) {
    // (extern) declarations, no definition or initializer
    // Currently the only known declaration is for an automatic __local
    // (.shared) promoted to global.
    emitPTXGlobalVariable(GVar, O, STI);
    O << ";\n";
    return;
  }

  if (isSampler(*GVar)) {
    O << ".global .samplerref " << getSamplerName(*GVar);

    const Constant *Initializer = nullptr;
    if (GVar->hasInitializer())
      Initializer = GVar->getInitializer();
    const ConstantInt *CI = nullptr;
    if (Initializer)
      CI = dyn_cast<ConstantInt>(Initializer);
    if (CI) {
      // Decode the packed OpenCL sampler word (see cl_common_defines.h)
      // into PTX samplerref fields.
      unsigned sample = CI->getZExtValue();

      O << " = { ";

      for (int i = 0,
               addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
           i < 3; i++) {
        O << "addr_mode_" << i << " = ";
        switch (addr) {
        case 0:
          O << "wrap";
          break;
        case 1:
          O << "clamp_to_border";
          break;
        case 2:
          O << "clamp_to_edge";
          break;
        case 3:
          O << "wrap";
          break;
        case 4:
          O << "mirror";
          break;
        }
        O << ", ";
      }
      O << "filter_mode = ";
      switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
      case 0:
        O << "nearest";
        break;
      case 1:
        O << "linear";
        break;
      case 2:
        llvm_unreachable("Anisotropic filtering is not supported");
      default:
        O << "nearest";
        break;
      }
      if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
        O << ", force_unnormalized_coords = 1";
      }
      O << " }";
    }

    O << ";\n";
    return;
  }

  if (GVar->hasPrivateLinkage()) {
    // Compiler-generated private globals that must not appear in the output.
    if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
      return;

    // FIXME - need better way (e.g. Metadata) to avoid generating this global
    if (strncmp(GVar->getName().data(), "filename", 8) == 0)
      return;
    if (GVar->use_empty())
      return;
  }

  // If the variable is only used by one function, defer it: record it in
  // localDecls so emitDemotedVars() prints it inside that function instead.
  const Function *demotedFunc = nullptr;
  if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
    O << "// " << GVar->getName() << " has been demoted\n";
    if (localDecls.find(demotedFunc) != localDecls.end())
      localDecls[demotedFunc].push_back(GVar);
    else {
      std::vector<const GlobalVariable *> temp;
      temp.push_back(GVar);
      localDecls[demotedFunc] = temp;
    }
    return;
  }

  O << ".";
  emitPTXAddressSpace(PTy->getAddressSpace(), O);

  if (isManaged(*GVar)) {
    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
      report_fatal_error(
          ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
    }
    O << " .attribute(.managed)";
  }

  // Explicit alignment if present, otherwise the preferred ABI alignment.
  if (MaybeAlign A = GVar->getAlign())
    O << " .align " << A->value();
  else
    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();

  // Scalar types (FP, pointer, integers up to 64 bits) get a fundamental
  // PTX type; everything else falls through to the byte-array form below.
  if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
      (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
    O << " .";
    // Special case: ABI requires that we use .u8 for predicates
    if (ETy->isIntegerTy(1))
      O << "u8";
    else
      O << getPTXFundamentalTypeStr(ETy, false);
    O << " ";
    getSymbol(GVar)->print(O, MAI);

    // Ptx allows variable initilization only for constant and global state
    // spaces.
    if (GVar->hasInitializer()) {
      if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
          (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) {
        const Constant *Initializer = GVar->getInitializer();
        // 'undef' is treated as there is no value specified.
        if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
          O << " = ";
          printScalarConstant(Initializer, O);
        }
      } else {
        // The frontend adds zero-initializer to device and constant variables
        // that don't have an initial value, and UndefValue to shared
        // variables, so skip warning for this case.
        if (!GVar->getInitializer()->isNullValue() &&
            !isa<UndefValue>(GVar->getInitializer())) {
          report_fatal_error("initial value of '" + GVar->getName() +
                             "' is not allowed in addrspace(" +
                             Twine(PTy->getAddressSpace()) + ")");
        }
      }
    }
  } else {
    unsigned int ElementSize = 0;

    // Although PTX has direct support for struct type and array type and
    // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
    // targets that support these high level field accesses. Structs, arrays
    // and vectors are lowered into arrays of bytes.
    switch (ETy->getTypeID()) {
    case Type::IntegerTyID: // Integers larger than 64 bits
    case Type::StructTyID:
    case Type::ArrayTyID:
    case Type::FixedVectorTyID:
      ElementSize = DL.getTypeStoreSize(ETy);
      // Ptx allows variable initilization only for constant and
      // global state spaces.
      if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
           (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
          GVar->hasInitializer()) {
        const Constant *Initializer = GVar->getInitializer();
        if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
          AggBuffer aggBuffer(ElementSize, *this);
          bufferAggregateConstant(Initializer, &aggBuffer);
          if (aggBuffer.numSymbols()) {
            // The initializer contains symbol addresses. If everything is
            // pointer-size aligned, print as an array of words; otherwise
            // fall back to bytes with the mask() operator (PTX >= 7.1).
            unsigned int ptrSize = MAI->getCodePointerSize();
            if (ElementSize % ptrSize ||
                !aggBuffer.allSymbolsAligned(ptrSize)) {
              // Print in bytes and use the mask() operator for pointers.
              if (!STI.hasMaskOperator())
                report_fatal_error(
                    "initialized packed aggregate with pointers '" +
                    GVar->getName() +
                    "' requires at least PTX ISA version 7.1");
              O << " .u8 ";
              getSymbol(GVar)->print(O, MAI);
              O << "[" << ElementSize << "] = {";
              aggBuffer.printBytes(O);
              O << "}";
            } else {
              O << " .u" << ptrSize * 8 << " ";
              getSymbol(GVar)->print(O, MAI);
              O << "[" << ElementSize / ptrSize << "] = {";
              aggBuffer.printWords(O);
              O << "}";
            }
          } else {
            // Pure data initializer: emit as a plain byte array.
            O << " .b8 ";
            getSymbol(GVar)->print(O, MAI);
            O << "[" << ElementSize << "] = {";
            aggBuffer.printBytes(O);
            O << "}";
          }
        } else {
          // Zero/undef initializer: size only, no initializer list.
          O << " .b8 ";
          getSymbol(GVar)->print(O, MAI);
          if (ElementSize) {
            O << "[";
            O << ElementSize;
            O << "]";
          }
        }
      } else {
        // No initializer allowed/present in this state space.
        O << " .b8 ";
        getSymbol(GVar)->print(O, MAI);
        if (ElementSize) {
          O << "[";
          O << ElementSize;
          O << "]";
        }
      }
      break;
    default:
      llvm_unreachable("type not supported yet");
    }
  }
  O << ";\n";
}
1227
1228void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1229  const Value *v = Symbols[nSym];
1230  const Value *v0 = SymbolsBeforeStripping[nSym];
1231  if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1232    MCSymbol *Name = AP.getSymbol(GVar);
1233    PointerType *PTy = dyn_cast<PointerType>(v0->getType());
1234    // Is v0 a generic pointer?
1235    bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1236    if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
1237      os << "generic(";
1238      Name->print(os, AP.MAI);
1239      os << ")";
1240    } else {
1241      Name->print(os, AP.MAI);
1242    }
1243  } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
1244    const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
1245    AP.printMCExpr(*Expr, os);
1246  } else
1247    llvm_unreachable("symbol type unknown");
1248}
1249
// Print the buffered aggregate initializer as a comma-separated byte list.
// Plain data bytes are printed as decimal values; each recorded symbol is
// expanded into ptrSize masked-byte entries (see the mask() comment below).
void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
  unsigned int ptrSize = AP.MAI->getCodePointerSize();
  // Push a sentinel position (== size) so the nextSymbolPos lookup below is
  // always valid even after the last real symbol.
  symbolPosInBuffer.push_back(size);
  unsigned int nSym = 0;
  unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
  for (unsigned int pos = 0; pos < size;) {
    if (pos)
      os << ", ";
    if (pos != nextSymbolPos) {
      // Ordinary data byte.
      os << (unsigned int)buffer[pos];
      ++pos;
      continue;
    }
    // Generate a per-byte mask() operator for the symbol, which looks like:
    //   .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
    // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
    std::string symText;
    llvm::raw_string_ostream oss(symText);
    printSymbol(nSym, oss);
    for (unsigned i = 0; i < ptrSize; ++i) {
      if (i)
        os << ", ";
      llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);
      os << "(" << symText << ")";
    }
    // A symbol occupies ptrSize bytes of the buffer; advance past it.
    pos += ptrSize;
    nextSymbolPos = symbolPosInBuffer[++nSym];
    assert(nextSymbolPos >= pos);
  }
}
1280
// Print the buffered aggregate initializer as a comma-separated list of
// pointer-sized words. Callers guarantee (checked by printModuleLevelGV)
// that the buffer size and every symbol position are ptrSize-aligned, so
// each word is either a whole symbol or a little-endian data word.
void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
  unsigned int ptrSize = AP.MAI->getCodePointerSize();
  // Sentinel position (== size) keeps the nextSymbolPos lookup in bounds
  // after the last real symbol.
  symbolPosInBuffer.push_back(size);
  unsigned int nSym = 0;
  unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
  assert(nextSymbolPos % ptrSize == 0);
  for (unsigned int pos = 0; pos < size; pos += ptrSize) {
    if (pos)
      os << ", ";
    if (pos == nextSymbolPos) {
      printSymbol(nSym, os);
      nextSymbolPos = symbolPosInBuffer[++nSym];
      assert(nextSymbolPos % ptrSize == 0);
      assert(nextSymbolPos >= pos + ptrSize);
    } else if (ptrSize == 4)
      os << support::endian::read32le(&buffer[pos]);
    else
      os << support::endian::read64le(&buffer[pos]);
  }
}
1301
1302void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1303  if (localDecls.find(f) == localDecls.end())
1304    return;
1305
1306  std::vector<const GlobalVariable *> &gvars = localDecls[f];
1307
1308  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1309  const NVPTXSubtarget &STI =
1310      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
1311
1312  for (const GlobalVariable *GV : gvars) {
1313    O << "\t// demoted variable\n\t";
1314    printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
1315  }
1316}
1317
1318void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1319                                          raw_ostream &O) const {
1320  switch (AddressSpace) {
1321  case ADDRESS_SPACE_LOCAL:
1322    O << "local";
1323    break;
1324  case ADDRESS_SPACE_GLOBAL:
1325    O << "global";
1326    break;
1327  case ADDRESS_SPACE_CONST:
1328    O << "const";
1329    break;
1330  case ADDRESS_SPACE_SHARED:
1331    O << "shared";
1332    break;
1333  default:
1334    report_fatal_error("Bad address space found while emitting PTX: " +
1335                       llvm::Twine(AddressSpace));
1336    break;
1337  }
1338}
1339
1340std::string
1341NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1342  switch (Ty->getTypeID()) {
1343  case Type::IntegerTyID: {
1344    unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1345    if (NumBits == 1)
1346      return "pred";
1347    else if (NumBits <= 64) {
1348      std::string name = "u";
1349      return name + utostr(NumBits);
1350    } else {
1351      llvm_unreachable("Integer too large");
1352      break;
1353    }
1354    break;
1355  }
1356  case Type::HalfTyID:
1357    // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
1358    return "b16";
1359  case Type::FloatTyID:
1360    return "f32";
1361  case Type::DoubleTyID:
1362    return "f64";
1363  case Type::PointerTyID: {
1364    unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace());
1365    assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1366
1367    if (PtrSize == 64)
1368      if (useB4PTR)
1369        return "b64";
1370      else
1371        return "u64";
1372    else if (useB4PTR)
1373      return "b32";
1374    else
1375      return "u32";
1376  }
1377  default:
1378    break;
1379  }
1380  llvm_unreachable("unexpected type");
1381}
1382
// Print a PTX global variable *declaration* (no initializer, no trailing
// semicolon): state space, optional .managed attribute, alignment, type and
// symbol name. Aggregates are declared as byte arrays.
void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
                                            raw_ostream &O,
                                            const NVPTXSubtarget &STI) {
  const DataLayout &DL = getDataLayout();

  // GlobalVariables are always constant pointers themselves.
  Type *ETy = GVar->getValueType();

  O << ".";
  emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
  if (isManaged(*GVar)) {
    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
      report_fatal_error(
          ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
    }
    O << " .attribute(.managed)";
  }
  // Explicit alignment if present, otherwise the preferred ABI alignment.
  if (MaybeAlign A = GVar->getAlign())
    O << " .align " << A->value();
  else
    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();

  // Special case for i128
  if (ETy->isIntegerTy(128)) {
    O << " .b8 ";
    getSymbol(GVar)->print(O, MAI);
    O << "[16]";
    return;
  }

  // Scalars use a fundamental PTX type.
  if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
    O << " .";
    O << getPTXFundamentalTypeStr(ETy);
    O << " ";
    getSymbol(GVar)->print(O, MAI);
    return;
  }

  int64_t ElementSize = 0;

  // Although PTX has direct support for struct type and array type and LLVM IR
  // is very similar to PTX, the LLVM CodeGen does not support for targets that
  // support these high level field accesses. Structs and arrays are lowered
  // into arrays of bytes.
  switch (ETy->getTypeID()) {
  case Type::StructTyID:
  case Type::ArrayTyID:
  case Type::FixedVectorTyID:
    ElementSize = DL.getTypeStoreSize(ETy);
    O << " .b8 ";
    getSymbol(GVar)->print(O, MAI);
    O << "[";
    if (ElementSize) {
      O << ElementSize;
    }
    O << "]";
    break;
  default:
    llvm_unreachable("type not supported yet");
  }
}
1444
1445void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I,
1446                                     int paramIndex, raw_ostream &O) {
1447  getSymbol(I->getParent())->print(O, MAI);
1448  O << "_param_" << paramIndex;
1449}
1450
// Print the parenthesized PTX parameter list for function F. Handles four
// families of parameters: image/sampler kernel arguments, non-byval scalars
// and aggregates, byval aggregates, and the trailing vararg array. Kernel
// (.entry) and device (.func) functions use different forms, as do ABI
// (sm >= 20) and pre-ABI targets.
void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const AttributeList &PAL = F->getAttributes();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
  const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());

  Function::const_arg_iterator I, E;
  unsigned paramIndex = 0;
  bool first = true;
  bool isKernelFunc = isKernelFunction(*F);
  bool isABI = (STI.getSmVersion() >= 20);
  bool hasImageHandles = STI.hasImageHandles();

  if (F->arg_empty() && !F->isVarArg()) {
    O << "()\n";
    return;
  }

  O << "(\n";

  for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
    Type *Ty = I->getType();

    if (!first)
      O << ",\n";

    first = false;

    // Handle image/sampler parameters
    if (isKernelFunction(*F)) {
      if (isSampler(*I) || isImage(*I)) {
        if (isImage(*I)) {
          std::string sname = std::string(I->getName());
          if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
            // Writable images are surfaces; with image handles they are
            // passed as a .u64 pointer to the surfref.
            if (hasImageHandles)
              O << "\t.param .u64 .ptr .surfref ";
            else
              O << "\t.param .surfref ";
            CurrentFnSym->print(O, MAI);
            O << "_param_" << paramIndex;
          }
          else { // Default image is read_only
            if (hasImageHandles)
              O << "\t.param .u64 .ptr .texref ";
            else
              O << "\t.param .texref ";
            CurrentFnSym->print(O, MAI);
            O << "_param_" << paramIndex;
          }
        } else {
          // Sampler parameter.
          if (hasImageHandles)
            O << "\t.param .u64 .ptr .samplerref ";
          else
            O << "\t.param .samplerref ";
          CurrentFnSym->print(O, MAI);
          O << "_param_" << paramIndex;
        }
        continue;
      }
    }

    // Best alignment for this parameter: the optimized type alignment,
    // never smaller than an explicit align attribute.
    auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
                                    paramIndex](Type *Ty) -> Align {
      Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
      MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
      return std::max(TypeAlign, ParamAlign.valueOrOne());
    };

    if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
      if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
        // Just print .param .align <a> .b8 .param[size];
        // <a>  = optimal alignment for the element type; always multiple of
        //        PAL.getParamAlignment
        // size = typeallocsize of element type
        Align OptimalAlign = getOptimalAlignForParam(Ty);

        O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
        printParamName(I, paramIndex, O);
        O << "[" << DL.getTypeAllocSize(Ty) << "]";

        continue;
      }
      // Just a scalar
      auto *PTy = dyn_cast<PointerType>(Ty);
      unsigned PTySizeInBits = 0;
      if (PTy) {
        PTySizeInBits =
            TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
        assert(PTySizeInBits && "Invalid pointer size");
      }

      if (isKernelFunc) {
        if (PTy) {
          // Special handling for pointer arguments to kernel
          O << "\t.param .u" << PTySizeInBits << " ";

          // Only the non-CUDA (OpenCL/NVCL) interface annotates pointer
          // state space and alignment.
          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
              NVPTX::CUDA) {
            int addrSpace = PTy->getAddressSpace();
            switch (addrSpace) {
            default:
              O << ".ptr ";
              break;
            case ADDRESS_SPACE_CONST:
              O << ".ptr .const ";
              break;
            case ADDRESS_SPACE_SHARED:
              O << ".ptr .shared ";
              break;
            case ADDRESS_SPACE_GLOBAL:
              O << ".ptr .global ";
              break;
            }
            Align ParamAlign = I->getParamAlign().valueOrOne();
            O << ".align " << ParamAlign.value() << " ";
          }
          printParamName(I, paramIndex, O);
          continue;
        }

        // non-pointer scalar to kernel func
        O << "\t.param .";
        // Special case: predicate operands become .u8 types
        if (Ty->isIntegerTy(1))
          O << "u8";
        else
          O << getPTXFundamentalTypeStr(Ty);
        O << " ";
        printParamName(I, paramIndex, O);
        continue;
      }
      // Non-kernel function, just print .param .b<size> for ABI
      // and .reg .b<size> for non-ABI
      unsigned sz = 0;
      if (isa<IntegerType>(Ty)) {
        sz = cast<IntegerType>(Ty)->getBitWidth();
        sz = promoteScalarArgumentSize(sz);
      } else if (PTy) {
        assert(PTySizeInBits && "Invalid pointer size");
        sz = PTySizeInBits;
      } else if (Ty->isHalfTy())
        // PTX ABI requires all scalar parameters to be at least 32
        // bits in size.  fp16 normally uses .b16 as its storage type
        // in PTX, so its size must be adjusted here, too.
        sz = 32;
      else
        sz = Ty->getPrimitiveSizeInBits();
      if (isABI)
        O << "\t.param .b" << sz << " ";
      else
        O << "\t.reg .b" << sz << " ";
      printParamName(I, paramIndex, O);
      continue;
    }

    // param has byVal attribute.
    Type *ETy = PAL.getParamByValType(paramIndex);
    assert(ETy && "Param should have byval type");

    if (isABI || isKernelFunc) {
      // Just print .param .align <a> .b8 .param[size];
      // <a>  = optimal alignment for the element type; always multiple of
      //        PAL.getParamAlignment
      // size = typeallocsize of element type
      Align OptimalAlign =
          isKernelFunc
              ? getOptimalAlignForParam(ETy)
              : TLI->getFunctionByValParamAlign(
                    F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);

      unsigned sz = DL.getTypeAllocSize(ETy);
      O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
      printParamName(I, paramIndex, O);
      O << "[" << sz << "]";
      continue;
    } else {
      // Split the ETy into constituent parts and
      // print .param .b<size> <name> for each part.
      // Further, if a part is vector, print the above for
      // each vector element.
      SmallVector<EVT, 16> vtparts;
      ComputeValueVTs(*TLI, DL, ETy, vtparts);
      for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
        unsigned elems = 1;
        EVT elemtype = vtparts[i];
        if (vtparts[i].isVector()) {
          elems = vtparts[i].getVectorNumElements();
          elemtype = vtparts[i].getVectorElementType();
        }

        for (unsigned j = 0, je = elems; j != je; ++j) {
          unsigned sz = elemtype.getSizeInBits();
          if (elemtype.isInteger())
            sz = promoteScalarArgumentSize(sz);
          O << "\t.reg .b" << sz << " ";
          printParamName(I, paramIndex, O);
          if (j < je - 1)
            O << ",\n";
          ++paramIndex;
        }
        if (i < e - 1)
          O << ",\n";
      }
      // The loop header's paramIndex++ will re-add the last increment.
      --paramIndex;
      continue;
    }
  }

  if (F->isVarArg()) {
    if (!first)
      O << ",\n";
    O << "\t.param .align " << STI.getMaxRequiredAlignment();
    O << " .b8 ";
    getSymbol(F)->print(O, MAI);
    O << "_vararg[]";
  }

  O << "\n)\n";
}
1670
1671void NVPTXAsmPrinter::emitFunctionParamList(const MachineFunction &MF,
1672                                            raw_ostream &O) {
1673  const Function &F = MF.getFunction();
1674  emitFunctionParamList(&F, O);
1675}
1676
// Build the per-register-class virtual register numbering (VRegMapping) for
// MF and emit the corresponding PTX declarations: the local stack depot (if
// any), the stack pointer registers, and one .reg declaration per register
// class that is actually used.
void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
    const MachineFunction &MF) {
  SmallString<128> Str;
  raw_svector_ostream O(Str);

  // Map the global virtual register number to a register class specific
  // virtual register number starting from 1 with that class.
  const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
  //unsigned numRegClasses = TRI->getNumRegClasses();

  // Emit the Fake Stack Object
  const MachineFrameInfo &MFI = MF.getFrameInfo();
  int NumBytes = (int) MFI.getStackSize();
  if (NumBytes) {
    // A non-empty frame becomes a .local byte array plus the %SP/%SPL
    // pseudo stack pointers sized to the target pointer width.
    O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
      << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
    if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
      O << "\t.reg .b64 \t%SP;\n";
      O << "\t.reg .b64 \t%SPL;\n";
    } else {
      O << "\t.reg .b32 \t%SP;\n";
      O << "\t.reg .b32 \t%SPL;\n";
    }
  }

  // Go through all virtual registers to establish the mapping between the
  // global virtual
  // register number and the per class virtual register number.
  // We use the per class virtual register number in the ptx output.
  unsigned int numVRs = MRI->getNumVirtRegs();
  for (unsigned i = 0; i < numVRs; i++) {
    Register vr = Register::index2VirtReg(i);
    const TargetRegisterClass *RC = MRI->getRegClass(vr);
    DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
    // Per-class numbers start at 1 (current map size + 1).
    int n = regmap.size();
    regmap.insert(std::make_pair(vr, n + 1));
  }

  // Emit register declarations
  // @TODO: Extract out the real register usage
  // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";

  // Emit declaration of the virtual registers or 'physical' registers for
  // each register class
  for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
    const TargetRegisterClass *RC = TRI->getRegClass(i);
    DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
    std::string rcname = getNVPTXRegClassName(RC);
    std::string rcStr = getNVPTXRegClassStr(RC);
    int n = regmap.size();

    // Only declare those registers that may be used.
    if (n) {
       O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
         << ">;\n";
    }
  }

  OutStreamer->emitRawText(O.str());
}
1743
1744void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1745  APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1746  bool ignored;
1747  unsigned int numHex;
1748  const char *lead;
1749
1750  if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1751    numHex = 8;
1752    lead = "0f";
1753    APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1754  } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1755    numHex = 16;
1756    lead = "0d";
1757    APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1758  } else
1759    llvm_unreachable("unsupported fp type");
1760
1761  APInt API = APF.bitcastToAPInt();
1762  O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1763}
1764
1765void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1766  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1767    O << CI->getValue();
1768    return;
1769  }
1770  if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1771    printFPConstant(CFP, O);
1772    return;
1773  }
1774  if (isa<ConstantPointerNull>(CPV)) {
1775    O << "0";
1776    return;
1777  }
1778  if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1779    bool IsNonGenericPointer = false;
1780    if (GVar->getType()->getAddressSpace() != 0) {
1781      IsNonGenericPointer = true;
1782    }
1783    if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1784      O << "generic(";
1785      getSymbol(GVar)->print(O, MAI);
1786      O << ")";
1787    } else {
1788      getSymbol(GVar)->print(O, MAI);
1789    }
1790    return;
1791  }
1792  if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1793    const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false);
1794    printMCExpr(*E, O);
1795    return;
1796  }
1797  llvm_unreachable("Not scalar type found in printScalarConstant()");
1798}
1799
1800void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1801                                   AggBuffer *AggBuffer) {
1802  const DataLayout &DL = getDataLayout();
1803  int AllocSize = DL.getTypeAllocSize(CPV->getType());
1804  if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
1805    // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1806    // only the space allocated by CPV.
1807    AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
1808    return;
1809  }
1810
1811  // Helper for filling AggBuffer with APInts.
1812  auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
1813    size_t NumBytes = (Val.getBitWidth() + 7) / 8;
1814    SmallVector<unsigned char, 16> Buf(NumBytes);
1815    for (unsigned I = 0; I < NumBytes; ++I) {
1816      Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
1817    }
1818    AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
1819  };
1820
1821  switch (CPV->getType()->getTypeID()) {
1822  case Type::IntegerTyID:
1823    if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
1824      AddIntToBuffer(CI->getValue());
1825      break;
1826    }
1827    if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1828      if (const auto *CI =
1829              dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
1830        AddIntToBuffer(CI->getValue());
1831        break;
1832      }
1833      if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1834        Value *V = Cexpr->getOperand(0)->stripPointerCasts();
1835        AggBuffer->addSymbol(V, Cexpr->getOperand(0));
1836        AggBuffer->addZeros(AllocSize);
1837        break;
1838      }
1839    }
1840    llvm_unreachable("unsupported integer const type");
1841    break;
1842
1843  case Type::HalfTyID:
1844  case Type::BFloatTyID:
1845  case Type::FloatTyID:
1846  case Type::DoubleTyID:
1847    AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
1848    break;
1849
1850  case Type::PointerTyID: {
1851    if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1852      AggBuffer->addSymbol(GVar, GVar);
1853    } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1854      const Value *v = Cexpr->stripPointerCasts();
1855      AggBuffer->addSymbol(v, Cexpr);
1856    }
1857    AggBuffer->addZeros(AllocSize);
1858    break;
1859  }
1860
1861  case Type::ArrayTyID:
1862  case Type::FixedVectorTyID:
1863  case Type::StructTyID: {
1864    if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
1865      bufferAggregateConstant(CPV, AggBuffer);
1866      if (Bytes > AllocSize)
1867        AggBuffer->addZeros(Bytes - AllocSize);
1868    } else if (isa<ConstantAggregateZero>(CPV))
1869      AggBuffer->addZeros(Bytes);
1870    else
1871      llvm_unreachable("Unexpected Constant type");
1872    break;
1873  }
1874
1875  default:
1876    llvm_unreachable("unsupported type");
1877  }
1878}
1879
/// Append the little-endian byte image of an aggregate (or wide-integer)
/// constant to \p aggBuffer. Elements are emitted in order; for structs, each
/// field is emitted with a byte span covering the field plus its padding so
/// that gaps in the layout are zero-filled by bufferLEByte.
void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
                                              AggBuffer *aggBuffer) {
  const DataLayout &DL = getDataLayout();
  int Bytes;

  // Integers of arbitrary width
  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
    APInt Val = CI->getValue();
    // Emit one byte at a time, least-significant byte first, covering the
    // full alloc size of the type.
    for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
      uint8_t Byte = Val.getLoBits(8).getZExtValue();
      aggBuffer->addBytes(&Byte, 1, 1);
      Val.lshrInPlace(8);
    }
    return;
  }

  // Old constants
  if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
    if (CPV->getNumOperands())
      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
        bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
    return;
  }

  // Packed data forms (ConstantDataArray/ConstantDataVector) expose their
  // elements by index rather than as operands.
  if (const ConstantDataSequential *CDS =
          dyn_cast<ConstantDataSequential>(CPV)) {
    if (CDS->getNumElements())
      for (unsigned i = 0; i < CDS->getNumElements(); ++i)
        bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
                     aggBuffer);
    return;
  }

  if (isa<ConstantStruct>(CPV)) {
    if (CPV->getNumOperands()) {
      StructType *ST = cast<StructType>(CPV->getType());
      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
        // Bytes = span from this field's offset up to the next field's
        // offset, or up to the end of the struct for the last field, so
        // inter-field and tail padding are zero-filled by bufferLEByte.
        if (i == (e - 1))
          Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
                  DL.getTypeAllocSize(ST) -
                  DL.getStructLayout(ST)->getElementOffset(i);
        else
          Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
                  DL.getStructLayout(ST)->getElementOffset(i);
        bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
      }
    }
    return;
  }
  llvm_unreachable("unsupported constant type in printAggregateConstant()");
}
1931
/// lowerConstantForGV - Return an MCExpr for the given Constant.  This is mostly
/// a copy from AsmPrinter::lowerConstant, except customized to only handle
/// expressions that are representable in PTX and create
/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
/// \p ProcessingGeneric is true when we are below an addrspacecast-to-generic,
/// in which case symbol references are wrapped so they print as generic(sym).
const MCExpr *
NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
  MCContext &Ctx = OutContext;

  // Null and undef both lower to the constant 0.
  if (CV->isNullValue() || isa<UndefValue>(CV))
    return MCConstantExpr::create(0, Ctx);

  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
    return MCConstantExpr::create(CI->getZExtValue(), Ctx);

  if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
    const MCSymbolRefExpr *Expr =
      MCSymbolRefExpr::create(getSymbol(GV), Ctx);
    if (ProcessingGeneric) {
      // Under an addrspacecast-to-generic, wrap the symbol so it is printed
      // as generic(sym).
      return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
    } else {
      return Expr;
    }
  }

  // Everything below handles constant expressions.
  const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
  if (!CE) {
    llvm_unreachable("Unknown constant value to lower!");
  }

  switch (CE->getOpcode()) {
  default: {
    // If the code isn't optimized, there may be outstanding folding
    // opportunities. Attempt to fold the expression using DataLayout as a
    // last resort before giving up.
    Constant *C = ConstantFoldConstant(CE, getDataLayout());
    if (C != CE)
      return lowerConstantForGV(C, ProcessingGeneric);

    // Otherwise report the problem to the user.
    std::string S;
    raw_string_ostream OS(S);
    OS << "Unsupported expression in static initializer: ";
    CE->printAsOperand(OS, /*PrintType=*/false,
                   !MF ? nullptr : MF->getFunction().getParent());
    report_fatal_error(Twine(OS.str()));
  }

  case Instruction::AddrSpaceCast: {
    // Strip the addrspacecast and pass along the operand
    PointerType *DstTy = cast<PointerType>(CE->getType());
    if (DstTy->getAddressSpace() == 0) {
      // Cast to generic: recurse with ProcessingGeneric set so the inner
      // symbol reference is wrapped in generic().
      return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
    }
    // Casts between specific address spaces are not representable in a PTX
    // static initializer.
    std::string S;
    raw_string_ostream OS(S);
    OS << "Unsupported expression in static initializer: ";
    CE->printAsOperand(OS, /*PrintType=*/ false,
                       !MF ? nullptr : MF->getFunction().getParent());
    report_fatal_error(Twine(OS.str()));
  }

  case Instruction::GetElementPtr: {
    const DataLayout &DL = getDataLayout();

    // Generate a symbolic expression for the byte address
    APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
    cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);

    const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
                                            ProcessingGeneric);
    if (!OffsetAI)
      return Base;

    // Non-zero offset: emit base + offset.
    int64_t Offset = OffsetAI.getSExtValue();
    return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
                                   Ctx);
  }

  case Instruction::Trunc:
    // We emit the value and depend on the assembler to truncate the generated
    // expression properly.  This is important for differences between
    // blockaddress labels.  Since the two labels are in the same function, it
    // is reasonable to treat their delta as a 32-bit value.
    [[fallthrough]];
  case Instruction::BitCast:
    return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);

  case Instruction::IntToPtr: {
    const DataLayout &DL = getDataLayout();

    // Handle casts to pointers by changing them into casts to the appropriate
    // integer type.  This promotes constant folding and simplifies this code.
    Constant *Op = CE->getOperand(0);
    Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()),
                                      false/*ZExt*/);
    return lowerConstantForGV(Op, ProcessingGeneric);
  }

  case Instruction::PtrToInt: {
    const DataLayout &DL = getDataLayout();

    // Support only foldable casts to/from pointers that can be eliminated by
    // changing the pointer to the appropriately sized integer type.
    Constant *Op = CE->getOperand(0);
    Type *Ty = CE->getType();

    const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);

    // We can emit the pointer value into this slot if the slot is an
    // integer slot equal to the size of the pointer.
    if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
      return OpExpr;

    // Otherwise the pointer is smaller than the resultant integer, mask off
    // the high bits so we are sure to get a proper truncation if the input is
    // a constant expr.
    unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
    const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
    return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
  }

  // The MC library also has a right-shift operator, but it isn't consistently
  // signed or unsigned between different targets.
  case Instruction::Add: {
    const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
    const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
    switch (CE->getOpcode()) {
    default: llvm_unreachable("Unknown binary operator constant cast expr");
    case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
    }
  }
  }
}
2065
// Copy of MCExpr::print customized for NVPTX
// Recursively prints \p Expr to \p OS in a form accepted by ptxas: only
// symbol references, constants, unary operators, and '+' binary expressions
// are expected here (see lowerConstantForGV, which produces these shapes).
void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
  switch (Expr.getKind()) {
  case MCExpr::Target:
    // Target expressions (e.g. NVPTXGenericMCSymbolRefExpr) print themselves.
    return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
  case MCExpr::Constant:
    OS << cast<MCConstantExpr>(Expr).getValue();
    return;

  case MCExpr::SymbolRef: {
    const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
    const MCSymbol &Sym = SRE.getSymbol();
    Sym.print(OS, MAI);
    return;
  }

  case MCExpr::Unary: {
    const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
    switch (UE.getOpcode()) {
    case MCUnaryExpr::LNot:  OS << '!'; break;
    case MCUnaryExpr::Minus: OS << '-'; break;
    case MCUnaryExpr::Not:   OS << '~'; break;
    case MCUnaryExpr::Plus:  OS << '+'; break;
    }
    printMCExpr(*UE.getSubExpr(), OS);
    return;
  }

  case MCExpr::Binary: {
    const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);

    // Only print parens around the LHS if it is non-trivial.
    if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
        isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
      printMCExpr(*BE.getLHS(), OS);
    } else {
      OS << '(';
      printMCExpr(*BE.getLHS(), OS);
      OS<< ')';
    }

    switch (BE.getOpcode()) {
    case MCBinaryExpr::Add:
      // Print "X-42" instead of "X+-42".
      if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
        if (RHSC->getValue() < 0) {
          OS << RHSC->getValue();
          return;
        }
      }

      OS <<  '+';
      break;
    default: llvm_unreachable("Unhandled binary operator");
    }

    // Only print parens around the LHS if it is non-trivial.
    // NOTE(review): unlike the LHS check above, this one does not treat
    // NVPTXGenericMCSymbolRefExpr as trivial, so a generic(sym) on the RHS
    // is printed inside parens — presumably harmless; confirm if intentional.
    if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
      printMCExpr(*BE.getRHS(), OS);
    } else {
      OS << '(';
      printMCExpr(*BE.getRHS(), OS);
      OS << ')';
    }
    return;
  }
  }

  llvm_unreachable("Invalid expression kind!");
}
2136
2137/// PrintAsmOperand - Print out an operand for an inline asm expression.
2138///
2139bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2140                                      const char *ExtraCode, raw_ostream &O) {
2141  if (ExtraCode && ExtraCode[0]) {
2142    if (ExtraCode[1] != 0)
2143      return true; // Unknown modifier.
2144
2145    switch (ExtraCode[0]) {
2146    default:
2147      // See if this is a generic print operand
2148      return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2149    case 'r':
2150      break;
2151    }
2152  }
2153
2154  printOperand(MI, OpNo, O);
2155
2156  return false;
2157}
2158
2159bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2160                                            unsigned OpNo,
2161                                            const char *ExtraCode,
2162                                            raw_ostream &O) {
2163  if (ExtraCode && ExtraCode[0])
2164    return true; // Unknown modifier
2165
2166  O << '[';
2167  printMemOperand(MI, OpNo, O);
2168  O << ']';
2169
2170  return false;
2171}
2172
2173void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
2174                                   raw_ostream &O) {
2175  const MachineOperand &MO = MI->getOperand(opNum);
2176  switch (MO.getType()) {
2177  case MachineOperand::MO_Register:
2178    if (MO.getReg().isPhysical()) {
2179      if (MO.getReg() == NVPTX::VRDepot)
2180        O << DEPOTNAME << getFunctionNumber();
2181      else
2182        O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2183    } else {
2184      emitVirtualRegister(MO.getReg(), O);
2185    }
2186    break;
2187
2188  case MachineOperand::MO_Immediate:
2189    O << MO.getImm();
2190    break;
2191
2192  case MachineOperand::MO_FPImmediate:
2193    printFPConstant(MO.getFPImm(), O);
2194    break;
2195
2196  case MachineOperand::MO_GlobalAddress:
2197    PrintSymbolOperand(MO, O);
2198    break;
2199
2200  case MachineOperand::MO_MachineBasicBlock:
2201    MO.getMBB()->getSymbol()->print(O, MAI);
2202    break;
2203
2204  default:
2205    llvm_unreachable("Operand type not supported.");
2206  }
2207}
2208
2209void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
2210                                      raw_ostream &O, const char *Modifier) {
2211  printOperand(MI, opNum, O);
2212
2213  if (Modifier && strcmp(Modifier, "add") == 0) {
2214    O << ", ";
2215    printOperand(MI, opNum + 1, O);
2216  } else {
2217    if (MI->getOperand(opNum + 1).isImm() &&
2218        MI->getOperand(opNum + 1).getImm() == 0)
2219      return; // don't print ',0' or '+0'
2220    O << "+";
2221    printOperand(MI, opNum + 1, O);
2222  }
2223}
2224
2225// Force static initialization.
2226extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2227  RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2228  RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2229}
2230