1//===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements regularization of LLVM IR for SPIR-V. The prototype of
10// the pass was taken from SPIRV-LLVM translator.
11//
12//===----------------------------------------------------------------------===//
13
14#include "SPIRV.h"
15#include "SPIRVTargetMachine.h"
16#include "llvm/Demangle/Demangle.h"
17#include "llvm/IR/InstIterator.h"
18#include "llvm/IR/InstVisitor.h"
19#include "llvm/IR/PassManager.h"
20#include "llvm/Transforms/Utils/Cloning.h"
21
22#include <list>
23
24#define DEBUG_TYPE "spirv-regularizer"
25
26using namespace llvm;
27
namespace llvm {
// Forward declaration of the pass-initialization routine that the
// SPIRVRegularizer constructor calls with the global PassRegistry.
// NOTE(review): the definition is presumably emitted by the INITIALIZE_PASS
// macro further down in this file — confirm against PassSupport.h.
void initializeSPIRVRegularizerPass(PassRegistry &);
}
31
32namespace {
33struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {
34  DenseMap<Function *, Function *> Old2NewFuncs;
35
36public:
37  static char ID;
38  SPIRVRegularizer() : FunctionPass(ID) {
39    initializeSPIRVRegularizerPass(*PassRegistry::getPassRegistry());
40  }
41  bool runOnFunction(Function &F) override;
42  StringRef getPassName() const override { return "SPIR-V Regularizer"; }
43
44  void getAnalysisUsage(AnalysisUsage &AU) const override {
45    FunctionPass::getAnalysisUsage(AU);
46  }
47  void visitCallInst(CallInst &CI);
48
49private:
50  void visitCallScalToVec(CallInst *CI, StringRef MangledName,
51                          StringRef DemangledName);
52  void runLowerConstExpr(Function &F);
53};
54} // namespace
55
// Storage for the static pass identifier that the SPIRVRegularizer
// constructor hands to the FunctionPass base class.
char SPIRVRegularizer::ID = 0;

// Registers the pass under the DEBUG_TYPE argument ("spirv-regularizer") with
// the human-readable name "SPIR-V Regularizer".
// NOTE(review): the two trailing `false` arguments are presumably the
// "CFG only" and "is analysis" flags of INITIALIZE_PASS — confirm against
// PassSupport.h.
INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false,
                false)
60
// Since SPIR-V cannot represent constant expressions, constant expressions
// in LLVM IR need to be lowered to instructions. For each function,
// the constant expressions used by instructions of the function are replaced
// by instructions placed in the entry block since it dominates all other BBs.
// Each constant expression only needs to be lowered once in each function
// and all uses of it by instructions in that function are replaced by
// one instruction.
//
// The function drives a worklist that is seeded with every instruction of F.
// Instructions created while lowering are pushed to the front of the worklist
// so that their own operands are examined as well.
// TODO: remove redundant instructions for common subexpression.
void SPIRVRegularizer::runLowerConstExpr(Function &F) {
  LLVMContext &Ctx = F.getContext();
  std::list<Instruction *> WorkList;
  for (auto &II : instructions(F))
    WorkList.push_back(&II);

  // Iterator to the first basic block of F (the entry block).
  auto FBegin = F.begin();
  while (!WorkList.empty()) {
    // II is the instruction whose operands are lowered in this iteration.
    Instruction *II = WorkList.front();

    // Lowers a single operand. Functions are returned unchanged; anything
    // else must be a ConstantExpr (cast<> asserts otherwise) and is turned
    // into an instruction that replaces the expression in every instruction
    // of F that uses it. Returns the replacement instruction.
    auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * {
      if (isa<Function>(V))
        return V;
      auto *CE = cast<ConstantExpr>(V);
      LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE);
      auto ReplInst = CE->getAsInstruction();
      // Insert right before II when II is in the entry block; otherwise
      // insert before the last instruction of the entry block.
      auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
      ReplInst->insertBefore(InsPoint);
      LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n');
      std::vector<Instruction *> Users;
      // Do not replace use during iteration of use. Do it in another loop.
      for (auto U : CE->users()) {
        LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n');
        auto InstUser = dyn_cast<Instruction>(U);
        // Only replace users in scope of current function.
        if (InstUser && InstUser->getParent()->getParent() == &F)
          Users.push_back(InstUser);
      }
      for (auto &User : Users) {
        // If a user in the same block precedes the replacement, hoist the
        // replacement in front of that user so the definition comes first.
        if (ReplInst->getParent() == User->getParent() &&
            User->comesBefore(ReplInst))
          ReplInst->moveBefore(User);
        User->replaceUsesOfWith(CE, ReplInst);
      }
      return ReplInst;
    };

    // II has been captured above; drop it from the worklist before any new
    // instructions are queued at the front.
    WorkList.pop_front();
    // Lowers a constant vector operand of II. Only vectors whose elements are
    // all ConstantExprs or Functions are handled; for any other vector the
    // lambda returns nullptr and the operand is left alone. NumOfOp is the
    // operand index of Vec in II and, for PHI nodes, is also used as the
    // incoming-block index.
    auto LowerConstantVec = [&II, &LowerOp, &WorkList,
                             &Ctx](ConstantVector *Vec,
                                   unsigned NumOfOp) -> Value * {
      if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
            return isa<ConstantExpr>(V) || isa<Function>(V);
          })) {
        // Expand a vector of constexprs and construct it back with
        // series of insertelement instructions.
        std::list<Value *> OpList;
        std::transform(Vec->op_begin(), Vec->op_end(),
                       std::back_inserter(OpList),
                       [LowerOp](Value *V) { return LowerOp(V); });
        Value *Repl = nullptr;
        unsigned Idx = 0;
        auto *PhiII = dyn_cast<PHINode>(II);
        // For a PHI user the insertelement chain is emitted before the last
        // instruction of the matching incoming block; otherwise before II.
        Instruction *InsPoint =
            PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;
        std::list<Instruction *> ReplList;
        for (auto V : OpList) {
          // Remember lowered elements that became instructions so that they
          // are revisited by the outer worklist loop.
          if (auto *Inst = dyn_cast<Instruction>(V))
            ReplList.push_back(Inst);
          // The chain starts from a poison vector and inserts element Idx.
          Repl = InsertElementInst::Create(
              (Repl ? Repl : PoisonValue::get(Vec->getType())), V,
              ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "", InsPoint);
        }
        WorkList.splice(WorkList.begin(), ReplList);
        return Repl;
      }
      return nullptr;
    };
    // Examine every operand of II. Three operand kinds are lowered: constant
    // vectors, constant expressions, and metadata-wrapped constants.
    for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
      auto *Op = II->getOperand(OI);
      if (auto *Vec = dyn_cast<ConstantVector>(Op)) {
        Value *ReplInst = LowerConstantVec(Vec, OI);
        if (ReplInst)
          II->replaceUsesOfWith(Op, ReplInst);
      } else if (auto CE = dyn_cast<ConstantExpr>(Op)) {
        // LowerOp already rewrote the uses; queue the new instruction so its
        // operands get lowered too.
        WorkList.push_front(cast<Instruction>(LowerOp(CE)));
      } else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {
        // Metadata operand: only constants wrapped as metadata are lowered.
        auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata());
        if (!ConstMD)
          continue;
        Constant *C = ConstMD->getValue();
        Value *ReplInst = nullptr;
        if (auto *Vec = dyn_cast<ConstantVector>(C))
          ReplInst = LowerConstantVec(Vec, OI);
        if (auto *CE = dyn_cast<ConstantExpr>(C))
          ReplInst = LowerOp(CE);
        if (!ReplInst)
          continue;
        // Re-wrap the replacement instruction as metadata and install it as
        // the operand in place of the original constant.
        Metadata *RepMD = ValueAsMetadata::get(ReplInst);
        Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD);
        II->setOperand(OI, RepMDVal);
        WorkList.push_front(cast<Instruction>(ReplInst));
      }
    }
  }
}
165
166// It fixes calls to OCL builtins that accept vector arguments and one of them
167// is actually a scalar splat.
168void SPIRVRegularizer::visitCallInst(CallInst &CI) {
169  auto F = CI.getCalledFunction();
170  if (!F)
171    return;
172
173  auto MangledName = F->getName();
174  char *NameStr = itaniumDemangle(F->getName().data());
175  if (!NameStr)
176    return;
177  StringRef DemangledName(NameStr);
178
179  // TODO: add support for other builtins.
180  if (DemangledName.starts_with("fmin") || DemangledName.starts_with("fmax") ||
181      DemangledName.starts_with("min") || DemangledName.starts_with("max"))
182    visitCallScalToVec(&CI, MangledName, DemangledName);
183  free(NameStr);
184}
185
186void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName,
187                                          StringRef DemangledName) {
188  // Check if all arguments have the same type - it's simple case.
189  auto Uniform = true;
190  Type *Arg0Ty = CI->getOperand(0)->getType();
191  auto IsArg0Vector = isa<VectorType>(Arg0Ty);
192  for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I)
193    Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector;
194  if (Uniform)
195    return;
196
197  auto *OldF = CI->getCalledFunction();
198  Function *NewF = nullptr;
199  if (!Old2NewFuncs.count(OldF)) {
200    AttributeList Attrs = CI->getCalledFunction()->getAttributes();
201    SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty};
202    auto *NewFTy =
203        FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg());
204    NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(),
205                            *OldF->getParent());
206    ValueToValueMapTy VMap;
207    auto NewFArgIt = NewF->arg_begin();
208    for (auto &Arg : OldF->args()) {
209      auto ArgName = Arg.getName();
210      NewFArgIt->setName(ArgName);
211      VMap[&Arg] = &(*NewFArgIt++);
212    }
213    SmallVector<ReturnInst *, 8> Returns;
214    CloneFunctionInto(NewF, OldF, VMap,
215                      CloneFunctionChangeType::LocalChangesOnly, Returns);
216    NewF->setAttributes(Attrs);
217    Old2NewFuncs[OldF] = NewF;
218  } else {
219    NewF = Old2NewFuncs[OldF];
220  }
221  assert(NewF);
222
223  // This produces an instruction sequence that implements a splat of
224  // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst
225  // and ShuffleVectorInst to generate the same code as the SPIR-V translator.
226  // For instance (transcoding/OpMin.ll), this call
227  //   call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5)
228  // is translated to
229  //    %8 = OpUndef %v2uint
230  //   %14 = OpConstantComposite %v2uint %uint_1 %uint_10
231  //   ...
232  //   %10 = OpCompositeInsert %v2uint %uint_5 %8 0
233  //   %11 = OpVectorShuffle %v2uint %10 %8 0 0
234  // %call = OpExtInst %v2uint %1 s_min %14 %11
235  auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0);
236  PoisonValue *PVal = PoisonValue::get(Arg0Ty);
237  Instruction *Inst =
238      InsertElementInst::Create(PVal, CI->getOperand(1), ConstInt, "", CI);
239  ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount();
240  Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt);
241  Value *NewVec = new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI);
242  CI->setOperand(1, NewVec);
243  CI->replaceUsesOfWith(OldF, NewF);
244  CI->mutateFunctionType(NewF->getFunctionType());
245}
246
247bool SPIRVRegularizer::runOnFunction(Function &F) {
248  runLowerConstExpr(F);
249  visit(F);
250  for (auto &OldNew : Old2NewFuncs) {
251    Function *OldF = OldNew.first;
252    Function *NewF = OldNew.second;
253    NewF->takeName(OldF);
254    OldF->eraseFromParent();
255  }
256  return true;
257}
258
// Factory for the pass: returns a newly heap-allocated SPIRVRegularizer.
// NOTE(review): ownership presumably passes to the caller (typically a pass
// manager) — confirm at the call site.
FunctionPass *llvm::createSPIRVRegularizerPass() {
  return new SPIRVRegularizer();
}
262