SITypeRewriter.cpp revision 284734
129415Sjmg//===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
2119853Scg//
339899Sluigi//                     The LLVM Compiler Infrastructure
429415Sjmg//
529415Sjmg// This file is distributed under the University of Illinois Open Source
629415Sjmg// License. See LICENSE.TXT for details.
729415Sjmg//
850723Scg//===----------------------------------------------------------------------===//
950723Scg//
1029415Sjmg/// \file
/// This pass performs the following type substitutions on all
/// non-compute shaders:
///
/// v16i8 => v4i32 (loads and call arguments)
///   - v16i8 is used for constant memory resource descriptors.  This type is
///     legal for some compute APIs, and we don't want to declare it as legal
///     in the backend, because we want the legalizer to expand all v16i8
///     operations.
/// v1i32 => i32 (call arguments)
///   - Having v1* types complicates the legalizer and we can easily replace
///     them with the element type.
2250723Scg//===----------------------------------------------------------------------===//
2350723Scg
2450723Scg#include "AMDGPU.h"
2550723Scg#include "llvm/IR/IRBuilder.h"
2650723Scg#include "llvm/IR/InstVisitor.h"
2750723Scg
2850723Scgusing namespace llvm;
2950723Scg
3029415Sjmgnamespace {
3129415Sjmg
3253465Scgclass SITypeRewriter : public FunctionPass,
3329415Sjmg                       public InstVisitor<SITypeRewriter> {
3453465Scg
3553553Stanimura  static char ID;
3629415Sjmg  Module *Mod;
37110499Snyan  Type *v16i8;
38110499Snyan  Type *v4i32;
3970134Scg
4070134Scgpublic:
4182180Scg  SITypeRewriter() : FunctionPass(ID) { }
4282180Scg  bool doInitialization(Module &M) override;
4367803Scg  bool runOnFunction(Function &F) override;
4455706Scg  const char *getPassName() const override {
4555254Scg    return "SI Type Rewriter";
4667803Scg  }
4750723Scg  void visitLoadInst(LoadInst &I);
4850723Scg  void visitCallInst(CallInst &I);
4964881Scg  void visitBitCast(BitCastInst &I);
5050723Scg};
5174763Scg
5229415Sjmg} // End anonymous namespace
5367803Scg
5464881Scgchar SITypeRewriter::ID = 0;
5550723Scg
5664881Scgbool SITypeRewriter::doInitialization(Module &M) {
5750723Scg  Mod = &M;
5874763Scg  v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
5929415Sjmg  v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4);
6064881Scg  return false;
6164881Scg}
6264881Scg
6364881Scgbool SITypeRewriter::runOnFunction(Function &F) {
6464881Scg  Attribute A = F.getFnAttribute("ShaderType");
6564881Scg
6654462Scg  unsigned ShaderType = ShaderType::COMPUTE;
6774763Scg  if (A.isStringAttribute()) {
6854462Scg    StringRef Str = A.getValueAsString();
6950723Scg    Str.getAsInteger(0, ShaderType);
7029415Sjmg  }
7150723Scg  if (ShaderType == ShaderType::COMPUTE)
7250723Scg    return false;
7374763Scg
7474763Scg  visit(F);
7567803Scg  visit(F);
7667803Scg
7750723Scg  return false;
7829415Sjmg}
7950723Scg
8050723Scgvoid SITypeRewriter::visitLoadInst(LoadInst &I) {
8150723Scg  Value *Ptr = I.getPointerOperand();
8254462Scg  Type *PtrTy = Ptr->getType();
8354462Scg  Type *ElemTy = PtrTy->getPointerElementType();
8465644Scg  IRBuilder<> Builder(&I);
8555706Scg  if (ElemTy == v16i8)  {
8629415Sjmg    Value *BitCast = Builder.CreateBitCast(Ptr,
8784111Scg        PointerType::get(v4i32,PtrTy->getPointerAddressSpace()));
8850723Scg    LoadInst *Load = Builder.CreateLoad(BitCast);
8950723Scg    SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
9070291Scg    I.getAllMetadataOtherThanDebugLoc(MD);
9150723Scg    for (unsigned i = 0, e = MD.size(); i != e; ++i) {
9274763Scg      Load->setMetadata(MD[i].first, MD[i].second);
9350723Scg    }
9450723Scg    Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType());
9584111Scg    I.replaceAllUsesWith(BitCastLoad);
9674763Scg    I.eraseFromParent();
9774763Scg  }
9850723Scg}
9950723Scg
10050723Scgvoid SITypeRewriter::visitCallInst(CallInst &I) {
10167803Scg  IRBuilder<> Builder(&I);
10250723Scg
10350723Scg  SmallVector <Value*, 8> Args;
10450723Scg  SmallVector <Type*, 8> Types;
10550723Scg  bool NeedToReplace = false;
10654462Scg  Function *F = I.getCalledFunction();
10729415Sjmg  std::string Name = F->getName();
10850723Scg  for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
10984111Scg    Value *Arg = I.getArgOperand(i);
11050723Scg    if (Arg->getType() == v16i8) {
11129415Sjmg      Args.push_back(Builder.CreateBitCast(Arg, v4i32));
11250723Scg      Types.push_back(v4i32);
11329415Sjmg      NeedToReplace = true;
11450723Scg      Name = Name + ".v4i32";
11550723Scg    } else if (Arg->getType()->isVectorTy() &&
11650723Scg               Arg->getType()->getVectorNumElements() == 1 &&
11750723Scg               Arg->getType()->getVectorElementType() ==
11829415Sjmg                                              Type::getInt32Ty(I.getContext())){
11929415Sjmg      Type *ElementTy = Arg->getType()->getVectorElementType();
12084111Scg      std::string TypeName = "i32";
12184111Scg      InsertElementInst *Def = cast<InsertElementInst>(Arg);
12284111Scg      Args.push_back(Def->getOperand(1));
12384111Scg      Types.push_back(ElementTy);
12484111Scg      std::string VecTypeName = "v1" + TypeName;
12584111Scg      Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName);
12684111Scg      NeedToReplace = true;
127129180Struckman    } else {
128129180Struckman      Args.push_back(Arg);
129129180Struckman      Types.push_back(Arg->getType());
130129180Struckman    }
131129180Struckman  }
132129180Struckman
13384111Scg  if (!NeedToReplace) {
13484111Scg    return;
13584111Scg  }
13684111Scg  Function *NewF = Mod->getFunction(Name);
13784111Scg  if (!NewF) {
13829415Sjmg    NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod);
13950723Scg    NewF->setAttributes(F->getAttributes());
14029415Sjmg  }
14167803Scg  I.replaceAllUsesWith(Builder.CreateCall(NewF, Args));
14250723Scg  I.eraseFromParent();
14329415Sjmg}
14450723Scg
14550723Scgvoid SITypeRewriter::visitBitCast(BitCastInst &I) {
14650723Scg  IRBuilder<> Builder(&I);
147108064Ssemenu  if (I.getDestTy() != v4i32) {
14829415Sjmg    return;
14929415Sjmg  }
15029415Sjmg
15150723Scg  if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
15229415Sjmg    if (Op->getSrcTy() == v4i32) {
15350723Scg      I.replaceAllUsesWith(Op->getOperand(0));
15450723Scg      I.eraseFromParent();
15529415Sjmg    }
15650723Scg  }
15750723Scg}
15850723Scg
15950723ScgFunctionPass *llvm::createSITypeRewriter() {
16029415Sjmg  return new SITypeRewriter();
16129415Sjmg}
16250723Scg