SITypeRewriter.cpp revision 263508
1//===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// 10/// \file 11/// This pass removes performs the following type substitution on all 12/// non-compute shaders: 13/// 14/// v16i8 => i128 15/// - v16i8 is used for constant memory resource descriptors. This type is 16/// legal for some compute APIs, and we don't want to declare it as legal 17/// in the backend, because we want the legalizer to expand all v16i8 18/// operations. 19/// v1* => * 20/// - Having v1* types complicates the legalizer and we can easily replace 21/// - them with the element type. 22//===----------------------------------------------------------------------===// 23 24#include "AMDGPU.h" 25 26#include "llvm/IR/IRBuilder.h" 27#include "llvm/InstVisitor.h" 28 29using namespace llvm; 30 31namespace { 32 33class SITypeRewriter : public FunctionPass, 34 public InstVisitor<SITypeRewriter> { 35 36 static char ID; 37 Module *Mod; 38 Type *v16i8; 39 Type *i128; 40 41public: 42 SITypeRewriter() : FunctionPass(ID) { } 43 virtual bool doInitialization(Module &M); 44 virtual bool runOnFunction(Function &F); 45 virtual const char *getPassName() const { 46 return "SI Type Rewriter"; 47 } 48 void visitLoadInst(LoadInst &I); 49 void visitCallInst(CallInst &I); 50 void visitBitCast(BitCastInst &I); 51}; 52 53} // End anonymous namespace 54 55char SITypeRewriter::ID = 0; 56 57bool SITypeRewriter::doInitialization(Module &M) { 58 Mod = &M; 59 v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16); 60 i128 = Type::getIntNTy(M.getContext(), 128); 61 return false; 62} 63 64bool SITypeRewriter::runOnFunction(Function &F) { 65 AttributeSet Set = F.getAttributes(); 66 Attribute A = Set.getAttribute(AttributeSet::FunctionIndex, "ShaderType"); 67 68 unsigned ShaderType = ShaderType::COMPUTE; 69 if (A.isStringAttribute()) { 70 StringRef Str = A.getValueAsString(); 71 Str.getAsInteger(0, ShaderType); 72 } 73 if (ShaderType != ShaderType::COMPUTE) { 74 visit(F); 75 } 76 77 visit(F); 78 79 return false; 80} 81 82void SITypeRewriter::visitLoadInst(LoadInst &I) { 83 Value *Ptr = I.getPointerOperand(); 84 Type *PtrTy = Ptr->getType(); 85 Type *ElemTy = PtrTy->getPointerElementType(); 86 IRBuilder<> Builder(&I); 87 if (ElemTy == v16i8) { 88 Value *BitCast = Builder.CreateBitCast(Ptr, Type::getIntNPtrTy(I.getContext(), 128, 2)); 89 LoadInst *Load = Builder.CreateLoad(BitCast); 90 SmallVector <std::pair<unsigned, MDNode*>, 8> MD; 91 I.getAllMetadataOtherThanDebugLoc(MD); 92 for (unsigned i = 0, e = MD.size(); i != e; ++i) { 93 Load->setMetadata(MD[i].first, MD[i].second); 94 } 95 Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType()); 96 I.replaceAllUsesWith(BitCastLoad); 97 I.eraseFromParent(); 98 } 99} 100 101void SITypeRewriter::visitCallInst(CallInst &I) { 102 IRBuilder<> Builder(&I); 103 SmallVector <Value*, 8> Args; 104 SmallVector <Type*, 8> Types; 105 bool NeedToReplace = false; 106 Function *F = I.getCalledFunction(); 107 std::string Name = F->getName().str(); 108 for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) { 109 Value *Arg = I.getArgOperand(i); 110 if (Arg->getType() == v16i8) { 111 Args.push_back(Builder.CreateBitCast(Arg, i128)); 112 Types.push_back(i128); 113 NeedToReplace = true; 114 Name = Name + ".i128"; 115 } else if (Arg->getType()->isVectorTy() && 116 Arg->getType()->getVectorNumElements() == 1 && 117 Arg->getType()->getVectorElementType() == 118 Type::getInt32Ty(I.getContext())){ 119 Type *ElementTy = Arg->getType()->getVectorElementType(); 120 std::string TypeName = "i32"; 121 InsertElementInst *Def = dyn_cast<InsertElementInst>(Arg); 122 assert(Def); 123 Args.push_back(Def->getOperand(1)); 124 Types.push_back(ElementTy); 125 std::string VecTypeName = "v1" + TypeName; 126 Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName); 127 NeedToReplace = true; 128 } else { 129 Args.push_back(Arg); 130 Types.push_back(Arg->getType()); 131 } 132 } 133 134 if (!NeedToReplace) { 135 return; 136 } 137 Function *NewF = Mod->getFunction(Name); 138 if (!NewF) { 139 NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod); 140 NewF->setAttributes(F->getAttributes()); 141 } 142 I.replaceAllUsesWith(Builder.CreateCall(NewF, Args)); 143 I.eraseFromParent(); 144} 145 146void SITypeRewriter::visitBitCast(BitCastInst &I) { 147 IRBuilder<> Builder(&I); 148 if (I.getDestTy() != i128) { 149 return; 150 } 151 152 if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) { 153 if (Op->getSrcTy() == i128) { 154 I.replaceAllUsesWith(Op->getOperand(0)); 155 I.eraseFromParent(); 156 } 157 } 158} 159 160FunctionPass *llvm::createSITypeRewriter() { 161 return new SITypeRewriter(); 162} 163