SITypeRewriter.cpp revision 303975
1//===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
/// \file
/// This pass performs the following type substitutions on all non-compute
/// shaders:
///
/// v16i8 => v4i32
///   - v16i8 is used for constant memory resource descriptors.  This type is
///     legal for some compute APIs, and we don't want to declare it as legal
///     in the backend, because we want the legalizer to expand all v16i8
///     operations.
/// v1i32 => i32
///   - Having v1* types complicates the legalizer and we can easily replace
///     them with the element type.  Currently only the i32 element type is
///     handled.
22//===----------------------------------------------------------------------===//
23
24#include "AMDGPU.h"
25#include "Utils/AMDGPUBaseInfo.h"
26#include "llvm/IR/IRBuilder.h"
27#include "llvm/IR/InstVisitor.h"
28
29using namespace llvm;
30
31namespace {
32
33class SITypeRewriter : public FunctionPass,
34                       public InstVisitor<SITypeRewriter> {
35
36  static char ID;
37  Module *Mod;
38  Type *v16i8;
39  Type *v4i32;
40
41public:
42  SITypeRewriter() : FunctionPass(ID) { }
43  bool doInitialization(Module &M) override;
44  bool runOnFunction(Function &F) override;
45  const char *getPassName() const override {
46    return "SI Type Rewriter";
47  }
48  void visitLoadInst(LoadInst &I);
49  void visitCallInst(CallInst &I);
50  void visitBitCast(BitCastInst &I);
51};
52
53} // End anonymous namespace
54
55char SITypeRewriter::ID = 0;
56
57bool SITypeRewriter::doInitialization(Module &M) {
58  Mod = &M;
59  v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
60  v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4);
61  return false;
62}
63
64bool SITypeRewriter::runOnFunction(Function &F) {
65  if (AMDGPU::getShaderType(F) == ShaderType::COMPUTE)
66    return false;
67
68  visit(F);
69  visit(F);
70
71  return false;
72}
73
74void SITypeRewriter::visitLoadInst(LoadInst &I) {
75  Value *Ptr = I.getPointerOperand();
76  Type *PtrTy = Ptr->getType();
77  Type *ElemTy = PtrTy->getPointerElementType();
78  IRBuilder<> Builder(&I);
79  if (ElemTy == v16i8)  {
80    Value *BitCast = Builder.CreateBitCast(Ptr,
81        PointerType::get(v4i32,PtrTy->getPointerAddressSpace()));
82    LoadInst *Load = Builder.CreateLoad(BitCast);
83    SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
84    I.getAllMetadataOtherThanDebugLoc(MD);
85    for (unsigned i = 0, e = MD.size(); i != e; ++i) {
86      Load->setMetadata(MD[i].first, MD[i].second);
87    }
88    Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType());
89    I.replaceAllUsesWith(BitCastLoad);
90    I.eraseFromParent();
91  }
92}
93
94void SITypeRewriter::visitCallInst(CallInst &I) {
95  IRBuilder<> Builder(&I);
96
97  SmallVector <Value*, 8> Args;
98  SmallVector <Type*, 8> Types;
99  bool NeedToReplace = false;
100  Function *F = I.getCalledFunction();
101  if (!F)
102    return;
103
104  std::string Name = F->getName();
105  for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
106    Value *Arg = I.getArgOperand(i);
107    if (Arg->getType() == v16i8) {
108      Args.push_back(Builder.CreateBitCast(Arg, v4i32));
109      Types.push_back(v4i32);
110      NeedToReplace = true;
111      Name = Name + ".v4i32";
112    } else if (Arg->getType()->isVectorTy() &&
113               Arg->getType()->getVectorNumElements() == 1 &&
114               Arg->getType()->getVectorElementType() ==
115                                              Type::getInt32Ty(I.getContext())){
116      Type *ElementTy = Arg->getType()->getVectorElementType();
117      std::string TypeName = "i32";
118      InsertElementInst *Def = cast<InsertElementInst>(Arg);
119      Args.push_back(Def->getOperand(1));
120      Types.push_back(ElementTy);
121      std::string VecTypeName = "v1" + TypeName;
122      Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName);
123      NeedToReplace = true;
124    } else {
125      Args.push_back(Arg);
126      Types.push_back(Arg->getType());
127    }
128  }
129
130  if (!NeedToReplace) {
131    return;
132  }
133  Function *NewF = Mod->getFunction(Name);
134  if (!NewF) {
135    NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod);
136    NewF->setAttributes(F->getAttributes());
137  }
138  I.replaceAllUsesWith(Builder.CreateCall(NewF, Args));
139  I.eraseFromParent();
140}
141
142void SITypeRewriter::visitBitCast(BitCastInst &I) {
143  IRBuilder<> Builder(&I);
144  if (I.getDestTy() != v4i32) {
145    return;
146  }
147
148  if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
149    if (Op->getSrcTy() == v4i32) {
150      I.replaceAllUsesWith(Op->getOperand(0));
151      I.eraseFromParent();
152    }
153  }
154}
155
156FunctionPass *llvm::createSITypeRewriter() {
157  return new SITypeRewriter();
158}
159