SITypeRewriter.cpp revision 284734
//===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass performs the following type substitutions on all
/// non-compute shaders:
///
/// v16i8 => v4i32
///   - v16i8 is used for constant memory resource descriptors.  This type is
///     legal for some compute APIs, and we don't want to declare it as legal
///     in the backend, because we want the legalizer to expand all v16i8
///     operations.
/// v1* => *
///   - Having v1* types complicates the legalizer and we can easily replace
///     them with the element type.
2250723Scg//===----------------------------------------------------------------------===// 2350723Scg 2450723Scg#include "AMDGPU.h" 2550723Scg#include "llvm/IR/IRBuilder.h" 2650723Scg#include "llvm/IR/InstVisitor.h" 2750723Scg 2850723Scgusing namespace llvm; 2950723Scg 3029415Sjmgnamespace { 3129415Sjmg 3253465Scgclass SITypeRewriter : public FunctionPass, 3329415Sjmg public InstVisitor<SITypeRewriter> { 3453465Scg 3553553Stanimura static char ID; 3629415Sjmg Module *Mod; 37110499Snyan Type *v16i8; 38110499Snyan Type *v4i32; 3970134Scg 4070134Scgpublic: 4182180Scg SITypeRewriter() : FunctionPass(ID) { } 4282180Scg bool doInitialization(Module &M) override; 4367803Scg bool runOnFunction(Function &F) override; 4455706Scg const char *getPassName() const override { 4555254Scg return "SI Type Rewriter"; 4667803Scg } 4750723Scg void visitLoadInst(LoadInst &I); 4850723Scg void visitCallInst(CallInst &I); 4964881Scg void visitBitCast(BitCastInst &I); 5050723Scg}; 5174763Scg 5229415Sjmg} // End anonymous namespace 5367803Scg 5464881Scgchar SITypeRewriter::ID = 0; 5550723Scg 5664881Scgbool SITypeRewriter::doInitialization(Module &M) { 5750723Scg Mod = &M; 5874763Scg v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16); 5929415Sjmg v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4); 6064881Scg return false; 6164881Scg} 6264881Scg 6364881Scgbool SITypeRewriter::runOnFunction(Function &F) { 6464881Scg Attribute A = F.getFnAttribute("ShaderType"); 6564881Scg 6654462Scg unsigned ShaderType = ShaderType::COMPUTE; 6774763Scg if (A.isStringAttribute()) { 6854462Scg StringRef Str = A.getValueAsString(); 6950723Scg Str.getAsInteger(0, ShaderType); 7029415Sjmg } 7150723Scg if (ShaderType == ShaderType::COMPUTE) 7250723Scg return false; 7374763Scg 7474763Scg visit(F); 7567803Scg visit(F); 7667803Scg 7750723Scg return false; 7829415Sjmg} 7950723Scg 8050723Scgvoid SITypeRewriter::visitLoadInst(LoadInst &I) { 8150723Scg Value *Ptr = I.getPointerOperand(); 8254462Scg Type 
*PtrTy = Ptr->getType(); 8354462Scg Type *ElemTy = PtrTy->getPointerElementType(); 8465644Scg IRBuilder<> Builder(&I); 8555706Scg if (ElemTy == v16i8) { 8629415Sjmg Value *BitCast = Builder.CreateBitCast(Ptr, 8784111Scg PointerType::get(v4i32,PtrTy->getPointerAddressSpace())); 8850723Scg LoadInst *Load = Builder.CreateLoad(BitCast); 8950723Scg SmallVector<std::pair<unsigned, MDNode *>, 8> MD; 9070291Scg I.getAllMetadataOtherThanDebugLoc(MD); 9150723Scg for (unsigned i = 0, e = MD.size(); i != e; ++i) { 9274763Scg Load->setMetadata(MD[i].first, MD[i].second); 9350723Scg } 9450723Scg Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType()); 9584111Scg I.replaceAllUsesWith(BitCastLoad); 9674763Scg I.eraseFromParent(); 9774763Scg } 9850723Scg} 9950723Scg 10050723Scgvoid SITypeRewriter::visitCallInst(CallInst &I) { 10167803Scg IRBuilder<> Builder(&I); 10250723Scg 10350723Scg SmallVector <Value*, 8> Args; 10450723Scg SmallVector <Type*, 8> Types; 10550723Scg bool NeedToReplace = false; 10654462Scg Function *F = I.getCalledFunction(); 10729415Sjmg std::string Name = F->getName(); 10850723Scg for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) { 10984111Scg Value *Arg = I.getArgOperand(i); 11050723Scg if (Arg->getType() == v16i8) { 11129415Sjmg Args.push_back(Builder.CreateBitCast(Arg, v4i32)); 11250723Scg Types.push_back(v4i32); 11329415Sjmg NeedToReplace = true; 11450723Scg Name = Name + ".v4i32"; 11550723Scg } else if (Arg->getType()->isVectorTy() && 11650723Scg Arg->getType()->getVectorNumElements() == 1 && 11750723Scg Arg->getType()->getVectorElementType() == 11829415Sjmg Type::getInt32Ty(I.getContext())){ 11929415Sjmg Type *ElementTy = Arg->getType()->getVectorElementType(); 12084111Scg std::string TypeName = "i32"; 12184111Scg InsertElementInst *Def = cast<InsertElementInst>(Arg); 12284111Scg Args.push_back(Def->getOperand(1)); 12384111Scg Types.push_back(ElementTy); 12484111Scg std::string VecTypeName = "v1" + TypeName; 12584111Scg Name = 
Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName); 12684111Scg NeedToReplace = true; 127129180Struckman } else { 128129180Struckman Args.push_back(Arg); 129129180Struckman Types.push_back(Arg->getType()); 130129180Struckman } 131129180Struckman } 132129180Struckman 13384111Scg if (!NeedToReplace) { 13484111Scg return; 13584111Scg } 13684111Scg Function *NewF = Mod->getFunction(Name); 13784111Scg if (!NewF) { 13829415Sjmg NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod); 13950723Scg NewF->setAttributes(F->getAttributes()); 14029415Sjmg } 14167803Scg I.replaceAllUsesWith(Builder.CreateCall(NewF, Args)); 14250723Scg I.eraseFromParent(); 14329415Sjmg} 14450723Scg 14550723Scgvoid SITypeRewriter::visitBitCast(BitCastInst &I) { 14650723Scg IRBuilder<> Builder(&I); 147108064Ssemenu if (I.getDestTy() != v4i32) { 14829415Sjmg return; 14929415Sjmg } 15029415Sjmg 15150723Scg if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) { 15229415Sjmg if (Op->getSrcTy() == v4i32) { 15350723Scg I.replaceAllUsesWith(Op->getOperand(0)); 15450723Scg I.eraseFromParent(); 15529415Sjmg } 15650723Scg } 15750723Scg} 15850723Scg 15950723ScgFunctionPass *llvm::createSITypeRewriter() { 16029415Sjmg return new SITypeRewriter(); 16129415Sjmg} 16250723Scg