NVPTXAsmPrinter.h revision 341825
1//===-- NVPTXAsmPrinter.h - NVPTX LLVM assembly writer ----------*- C++ -*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This file contains a printer that converts from our internal representation
11// of machine-dependent LLVM code to NVPTX assembly language.
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H
16#define LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H
17
18#include "NVPTX.h"
19#include "NVPTXSubtarget.h"
20#include "NVPTXTargetMachine.h"
21#include "llvm/ADT/DenseMap.h"
22#include "llvm/ADT/SmallVector.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/CodeGen/AsmPrinter.h"
25#include "llvm/CodeGen/MachineFunction.h"
26#include "llvm/CodeGen/MachineLoopInfo.h"
27#include "llvm/IR/Constants.h"
28#include "llvm/IR/DebugLoc.h"
29#include "llvm/IR/DerivedTypes.h"
30#include "llvm/IR/Function.h"
31#include "llvm/IR/GlobalValue.h"
32#include "llvm/IR/Value.h"
33#include "llvm/MC/MCExpr.h"
34#include "llvm/MC/MCStreamer.h"
35#include "llvm/MC/MCSymbol.h"
36#include "llvm/PassAnalysisSupport.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/Compiler.h"
39#include "llvm/Support/ErrorHandling.h"
40#include "llvm/Support/raw_ostream.h"
41#include "llvm/Target/TargetMachine.h"
42#include <algorithm>
43#include <cassert>
44#include <map>
45#include <memory>
46#include <string>
47#include <vector>
48
// The PTX syntax and format are very different from what is usually seen in
// a .s file, so we are not able to use the MCAsmStreamer interface here.
//
// Instead, the output is handcrafted here.
//
// A better approach would be to clone MCAsmStreamer into an MCPTXAsmStreamer
// (a subclass of MCStreamer).
57
58namespace llvm {
59
60class MCOperand;
61
62class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
63
64  class AggBuffer {
65    // Used to buffer the emitted string for initializing global
66    // aggregates.
67    //
68    // Normally an aggregate (array, vector or structure) is emitted
69    // as a u8[]. However, if one element/field of the aggregate
70    // is a non-NULL address, then the aggregate is emitted as u32[]
71    // or u64[].
72    //
73    // We first layout the aggregate in 'buffer' in bytes, except for
74    // those symbol addresses. For the i-th symbol address in the
75    //aggregate, its corresponding 4-byte or 8-byte elements in 'buffer'
76    // are filled with 0s. symbolPosInBuffer[i-1] records its position
77    // in 'buffer', and Symbols[i-1] records the Value*.
78    //
79    // Once we have this AggBuffer setup, we can choose how to print
80    // it out.
81  public:
82    unsigned numSymbols;   // number of symbol addresses
83
84  private:
85    const unsigned size;   // size of the buffer in bytes
86    std::vector<unsigned char> buffer; // the buffer
87    SmallVector<unsigned, 4> symbolPosInBuffer;
88    SmallVector<const Value *, 4> Symbols;
89    // SymbolsBeforeStripping[i] is the original form of Symbols[i] before
90    // stripping pointer casts, i.e.,
91    // Symbols[i] == SymbolsBeforeStripping[i]->stripPointerCasts().
92    //
93    // We need to keep these values because AggBuffer::print decides whether to
94    // emit a "generic()" cast for Symbols[i] depending on the address space of
95    // SymbolsBeforeStripping[i].
96    SmallVector<const Value *, 4> SymbolsBeforeStripping;
97    unsigned curpos;
98    raw_ostream &O;
99    NVPTXAsmPrinter &AP;
100    bool EmitGeneric;
101
102  public:
103    AggBuffer(unsigned size, raw_ostream &O, NVPTXAsmPrinter &AP)
104        : size(size), buffer(size), O(O), AP(AP) {
105      curpos = 0;
106      numSymbols = 0;
107      EmitGeneric = AP.EmitGeneric;
108    }
109
110    unsigned addBytes(unsigned char *Ptr, int Num, int Bytes) {
111      assert((curpos + Num) <= size);
112      assert((curpos + Bytes) <= size);
113      for (int i = 0; i < Num; ++i) {
114        buffer[curpos] = Ptr[i];
115        curpos++;
116      }
117      for (int i = Num; i < Bytes; ++i) {
118        buffer[curpos] = 0;
119        curpos++;
120      }
121      return curpos;
122    }
123
124    unsigned addZeros(int Num) {
125      assert((curpos + Num) <= size);
126      for (int i = 0; i < Num; ++i) {
127        buffer[curpos] = 0;
128        curpos++;
129      }
130      return curpos;
131    }
132
133    void addSymbol(const Value *GVar, const Value *GVarBeforeStripping) {
134      symbolPosInBuffer.push_back(curpos);
135      Symbols.push_back(GVar);
136      SymbolsBeforeStripping.push_back(GVarBeforeStripping);
137      numSymbols++;
138    }
139
140    void print() {
141      if (numSymbols == 0) {
142        // print out in bytes
143        for (unsigned i = 0; i < size; i++) {
144          if (i)
145            O << ", ";
146          O << (unsigned int) buffer[i];
147        }
148      } else {
149        // print out in 4-bytes or 8-bytes
150        unsigned int pos = 0;
151        unsigned int nSym = 0;
152        unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
153        unsigned int nBytes = 4;
154        if (static_cast<const NVPTXTargetMachine &>(AP.TM).is64Bit())
155          nBytes = 8;
156        for (pos = 0; pos < size; pos += nBytes) {
157          if (pos)
158            O << ", ";
159          if (pos == nextSymbolPos) {
160            const Value *v = Symbols[nSym];
161            const Value *v0 = SymbolsBeforeStripping[nSym];
162            if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
163              MCSymbol *Name = AP.getSymbol(GVar);
164              PointerType *PTy = dyn_cast<PointerType>(v0->getType());
165              bool IsNonGenericPointer = false; // Is v0 a non-generic pointer?
166              if (PTy && PTy->getAddressSpace() != 0) {
167                IsNonGenericPointer = true;
168              }
169              if (EmitGeneric && !isa<Function>(v) && !IsNonGenericPointer) {
170                O << "generic(";
171                Name->print(O, AP.MAI);
172                O << ")";
173              } else {
174                Name->print(O, AP.MAI);
175              }
176            } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
177              const MCExpr *Expr =
178                AP.lowerConstantForGV(cast<Constant>(CExpr), false);
179              AP.printMCExpr(*Expr, O);
180            } else
181              llvm_unreachable("symbol type unknown");
182            nSym++;
183            if (nSym >= numSymbols)
184              nextSymbolPos = size + 1;
185            else
186              nextSymbolPos = symbolPosInBuffer[nSym];
187          } else if (nBytes == 4)
188            O << *(unsigned int *)(&buffer[pos]);
189          else
190            O << *(unsigned long long *)(&buffer[pos]);
191        }
192      }
193    }
194  };
195
196  friend class AggBuffer;
197
198private:
199  StringRef getPassName() const override { return "NVPTX Assembly Printer"; }
200
201  const Function *F;
202  std::string CurrentFnName;
203
204  void EmitBasicBlockStart(const MachineBasicBlock &MBB) const override;
205  void EmitFunctionEntryLabel() override;
206  void EmitFunctionBodyStart() override;
207  void EmitFunctionBodyEnd() override;
208  void emitImplicitDef(const MachineInstr *MI) const override;
209
210  void EmitInstruction(const MachineInstr *) override;
211  void lowerToMCInst(const MachineInstr *MI, MCInst &OutMI);
212  bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp);
213  MCOperand GetSymbolRef(const MCSymbol *Symbol);
214  unsigned encodeVirtualRegister(unsigned Reg);
215
216  void printVecModifiedImmediate(const MachineOperand &MO, const char *Modifier,
217                                 raw_ostream &O);
218  void printMemOperand(const MachineInstr *MI, int opNum, raw_ostream &O,
219                       const char *Modifier = nullptr);
220  void printModuleLevelGV(const GlobalVariable *GVar, raw_ostream &O,
221                          bool = false);
222  void printParamName(Function::const_arg_iterator I, int paramIndex,
223                      raw_ostream &O);
224  void emitGlobals(const Module &M);
225  void emitHeader(Module &M, raw_ostream &O, const NVPTXSubtarget &STI);
226  void emitKernelFunctionDirectives(const Function &F, raw_ostream &O) const;
227  void emitVirtualRegister(unsigned int vr, raw_ostream &);
228  void emitFunctionParamList(const Function *, raw_ostream &O);
229  void emitFunctionParamList(const MachineFunction &MF, raw_ostream &O);
230  void setAndEmitFunctionVirtualRegisters(const MachineFunction &MF);
231  void printReturnValStr(const Function *, raw_ostream &O);
232  void printReturnValStr(const MachineFunction &MF, raw_ostream &O);
233  bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
234                       unsigned AsmVariant, const char *ExtraCode,
235                       raw_ostream &) override;
236  void printOperand(const MachineInstr *MI, int opNum, raw_ostream &O,
237                    const char *Modifier = nullptr);
238  bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNo,
239                             unsigned AsmVariant, const char *ExtraCode,
240                             raw_ostream &) override;
241
242  const MCExpr *lowerConstantForGV(const Constant *CV, bool ProcessingGeneric);
243  void printMCExpr(const MCExpr &Expr, raw_ostream &OS);
244
245protected:
246  bool doInitialization(Module &M) override;
247  bool doFinalization(Module &M) override;
248
249private:
250  bool GlobalsEmitted;
251
252  // This is specific per MachineFunction.
253  const MachineRegisterInfo *MRI;
254  // The contents are specific for each
255  // MachineFunction. But the size of the
256  // array is not.
257  typedef DenseMap<unsigned, unsigned> VRegMap;
258  typedef DenseMap<const TargetRegisterClass *, VRegMap> VRegRCMap;
259  VRegRCMap VRegMapping;
260
261  // Cache the subtarget here.
262  const NVPTXSubtarget *nvptxSubtarget;
263
264  // List of variables demoted to a function scope.
265  std::map<const Function *, std::vector<const GlobalVariable *>> localDecls;
266
267  void emitPTXGlobalVariable(const GlobalVariable *GVar, raw_ostream &O);
268  void emitPTXAddressSpace(unsigned int AddressSpace, raw_ostream &O) const;
269  std::string getPTXFundamentalTypeStr(Type *Ty, bool = true) const;
270  void printScalarConstant(const Constant *CPV, raw_ostream &O);
271  void printFPConstant(const ConstantFP *Fp, raw_ostream &O);
272  void bufferLEByte(const Constant *CPV, int Bytes, AggBuffer *aggBuffer);
273  void bufferAggregateConstant(const Constant *CV, AggBuffer *aggBuffer);
274
275  void emitLinkageDirective(const GlobalValue *V, raw_ostream &O);
276  void emitDeclarations(const Module &, raw_ostream &O);
277  void emitDeclaration(const Function *, raw_ostream &O);
278  void emitDemotedVars(const Function *, raw_ostream &);
279
280  bool lowerImageHandleOperand(const MachineInstr *MI, unsigned OpNo,
281                               MCOperand &MCOp);
282  void lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp);
283
284  bool isLoopHeaderOfNoUnroll(const MachineBasicBlock &MBB) const;
285
286  // Used to control the need to emit .generic() in the initializer of
287  // module scope variables.
288  // Although ptx supports the hybrid mode like the following,
289  //    .global .u32 a;
290  //    .global .u32 b;
291  //    .global .u32 addr[] = {a, generic(b)}
292  // we have difficulty representing the difference in the NVVM IR.
293  //
294  // Since the address value should always be generic in CUDA C and always
295  // be specific in OpenCL, we use this simple control here.
296  //
297  bool EmitGeneric;
298
299public:
300  NVPTXAsmPrinter(TargetMachine &TM, std::unique_ptr<MCStreamer> Streamer)
301      : AsmPrinter(TM, std::move(Streamer)),
302        EmitGeneric(static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() ==
303                    NVPTX::CUDA) {}
304
305  bool runOnMachineFunction(MachineFunction &F) override;
306
307  void getAnalysisUsage(AnalysisUsage &AU) const override {
308    AU.addRequired<MachineLoopInfo>();
309    AsmPrinter::getAnalysisUsage(AU);
310  }
311
312  std::string getVirtualRegisterName(unsigned) const;
313
314  const MCSymbol *getFunctionFrameSymbol() const override;
315};
316
317} // end namespace llvm
318
319#endif // LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H
320