1//=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Replaces LLVM IR instructions with vector operands (i.e., the frem
// instruction or calls to LLVM intrinsics) with matching calls to functions
// from a vector library (e.g., libmvec, SVML) using the TargetLibraryInfo
// interface.
12//
13//===----------------------------------------------------------------------===//
14
15#include "llvm/CodeGen/ReplaceWithVeclib.h"
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/ADT/Statistic.h"
18#include "llvm/ADT/StringRef.h"
19#include "llvm/Analysis/DemandedBits.h"
20#include "llvm/Analysis/GlobalsModRef.h"
21#include "llvm/Analysis/OptimizationRemarkEmitter.h"
22#include "llvm/Analysis/TargetLibraryInfo.h"
23#include "llvm/Analysis/VectorUtils.h"
24#include "llvm/CodeGen/Passes.h"
25#include "llvm/IR/DerivedTypes.h"
26#include "llvm/IR/IRBuilder.h"
27#include "llvm/IR/InstIterator.h"
28#include "llvm/IR/VFABIDemangler.h"
29#include "llvm/Support/TypeSize.h"
30#include "llvm/Transforms/Utils/ModuleUtils.h"
31
32using namespace llvm;
33
34#define DEBUG_TYPE "replace-with-veclib"
35
36STATISTIC(NumCallsReplaced,
37          "Number of calls to intrinsics that have been replaced.");
38
39STATISTIC(NumTLIFuncDeclAdded,
40          "Number of vector library function declarations added.");
41
42STATISTIC(NumFuncUsedAdded,
43          "Number of functions added to `llvm.compiler.used`");
44
45/// Returns a vector Function that it adds to the Module \p M. When an \p
46/// ScalarFunc is not null, it copies its attributes to the newly created
47/// Function.
48Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
49                         const StringRef TLIName,
50                         Function *ScalarFunc = nullptr) {
51  Function *TLIFunc = M->getFunction(TLIName);
52  if (!TLIFunc) {
53    TLIFunc =
54        Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M);
55    if (ScalarFunc)
56      TLIFunc->copyAttributesFrom(ScalarFunc);
57
58    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
59                      << TLIName << "` of type `" << *(TLIFunc->getType())
60                      << "` to module.\n");
61
62    ++NumTLIFuncDeclAdded;
63    // Add the freshly created function to llvm.compiler.used, similar to as it
64    // is done in InjectTLIMappings.
65    appendToCompilerUsed(*M, {TLIFunc});
66    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
67                      << "` to `@llvm.compiler.used`.\n");
68    ++NumFuncUsedAdded;
69  }
70  return TLIFunc;
71}
72
73/// Replace the instruction \p I with a call to the corresponding function from
74/// the vector library (\p TLIVecFunc).
75static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
76                                   Function *TLIVecFunc) {
77  IRBuilder<> IRBuilder(&I);
78  auto *CI = dyn_cast<CallInst>(&I);
79  SmallVector<Value *> Args(CI ? CI->args() : I.operands());
80  if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
81    auto *MaskTy =
82        VectorType::get(Type::getInt1Ty(I.getContext()), Info.Shape.VF);
83    Args.insert(Args.begin() + OptMaskpos.value(),
84                Constant::getAllOnesValue(MaskTy));
85  }
86
87  // If it is a call instruction, preserve the operand bundles.
88  SmallVector<OperandBundleDef, 1> OpBundles;
89  if (CI)
90    CI->getOperandBundlesAsDefs(OpBundles);
91
92  auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
93  I.replaceAllUsesWith(Replacement);
94  // Preserve fast math flags for FP math.
95  if (isa<FPMathOperator>(Replacement))
96    Replacement->copyFastMathFlags(&I);
97}
98
99/// Returns true when successfully replaced \p I with a suitable function taking
100/// vector arguments, based on available mappings in the \p TLI. Currently only
101/// works when \p I is a call to vectorized intrinsic or the frem instruction.
102static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
103                                    Instruction &I) {
104  // At the moment VFABI assumes the return type is always widened unless it is
105  // a void type.
106  auto *VTy = dyn_cast<VectorType>(I.getType());
107  ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
108
109  // Compute the argument types of the corresponding scalar call and the scalar
110  // function name. For calls, it additionally finds the function to replace
111  // and checks that all vector operands match the previously found EC.
112  SmallVector<Type *, 8> ScalarArgTypes;
113  std::string ScalarName;
114  Function *FuncToReplace = nullptr;
115  auto *CI = dyn_cast<CallInst>(&I);
116  if (CI) {
117    FuncToReplace = CI->getCalledFunction();
118    Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
119    assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
120    for (auto Arg : enumerate(CI->args())) {
121      auto *ArgTy = Arg.value()->getType();
122      if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
123        ScalarArgTypes.push_back(ArgTy);
124      } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
125        ScalarArgTypes.push_back(VectorArgTy->getElementType());
126        // When return type is void, set EC to the first vector argument, and
127        // disallow vector arguments with different ECs.
128        if (EC.isZero())
129          EC = VectorArgTy->getElementCount();
130        else if (EC != VectorArgTy->getElementCount())
131          return false;
132      } else
133        // Exit when it is supposed to be a vector argument but it isn't.
134        return false;
135    }
136    // Try to reconstruct the name for the scalar version of the instruction,
137    // using scalar argument types.
138    ScalarName = Intrinsic::isOverloaded(IID)
139                     ? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
140                     : Intrinsic::getName(IID).str();
141  } else {
142    assert(VTy && "Return type must be a vector");
143    auto *ScalarTy = VTy->getScalarType();
144    LibFunc Func;
145    if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
146      return false;
147    ScalarName = TLI.getName(Func);
148    ScalarArgTypes = {ScalarTy, ScalarTy};
149  }
150
151  // Try to find the mapping for the scalar version of this intrinsic and the
152  // exact vector width of the call operands in the TargetLibraryInfo. First,
153  // check with a non-masked variant, and if that fails try with a masked one.
154  const VecDesc *VD =
155      TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ false);
156  if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ true)))
157    return false;
158
159  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName
160                    << "` and vector width " << EC << " to: `"
161                    << VD->getVectorFnName() << "`.\n");
162
163  // Replace the call to the intrinsic with a call to the vector library
164  // function.
165  Type *ScalarRetTy = I.getType()->getScalarType();
166  FunctionType *ScalarFTy =
167      FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false);
168  const std::string MangledName = VD->getVectorFunctionABIVariantString();
169  auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
170  if (!OptInfo)
171    return false;
172
173  // There is no guarantee that the vectorized instructions followed the VFABI
174  // specification when being created, this is why we need to add extra check to
175  // make sure that the operands of the vector function obtained via VFABI match
176  // the operands of the original vector instruction.
177  if (CI) {
178    for (auto VFParam : OptInfo->Shape.Parameters) {
179      if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
180        continue;
181
182      // tryDemangleForVFABI must return valid ParamPos, otherwise it could be
183      // a bug in the VFABI parser.
184      assert(VFParam.ParamPos < CI->arg_size() &&
185             "ParamPos has invalid range.");
186      Type *OrigTy = CI->getArgOperand(VFParam.ParamPos)->getType();
187      if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
188        LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName
189                          << ". Wrong type at index " << VFParam.ParamPos
190                          << ": " << *OrigTy << "\n");
191        return false;
192      }
193    }
194  }
195
196  FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
197  if (!VectorFTy)
198    return false;
199
200  Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
201                                     VD->getVectorFnName(), FuncToReplace);
202
203  replaceWithTLIFunction(I, *OptInfo, TLIFunc);
204  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
205                    << "` with call to `" << TLIFunc->getName() << "`.\n");
206  ++NumCallsReplaced;
207  return true;
208}
209
210/// Supported instruction \p I must be a vectorized frem or a call to an
211/// intrinsic that returns either void or a vector.
212static bool isSupportedInstruction(Instruction *I) {
213  Type *Ty = I->getType();
214  if (auto *CI = dyn_cast<CallInst>(I))
215    return (Ty->isVectorTy() || Ty->isVoidTy()) && CI->getCalledFunction() &&
216           CI->getCalledFunction()->getIntrinsicID() !=
217               Intrinsic::not_intrinsic;
218  if (I->getOpcode() == Instruction::FRem && Ty->isVectorTy())
219    return true;
220  return false;
221}
222
223static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
224  bool Changed = false;
225  SmallVector<Instruction *> ReplacedCalls;
226  for (auto &I : instructions(F)) {
227    if (!isSupportedInstruction(&I))
228      continue;
229    if (replaceWithCallToVeclib(TLI, I)) {
230      ReplacedCalls.push_back(&I);
231      Changed = true;
232    }
233  }
234  // Erase the calls to the intrinsics that have been replaced
235  // with calls to the vector library.
236  for (auto *CI : ReplacedCalls)
237    CI->eraseFromParent();
238  return Changed;
239}
240
241////////////////////////////////////////////////////////////////////////////////
242// New pass manager implementation.
243////////////////////////////////////////////////////////////////////////////////
244PreservedAnalyses ReplaceWithVeclib::run(Function &F,
245                                         FunctionAnalysisManager &AM) {
246  const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
247  auto Changed = runImpl(TLI, F);
248  if (Changed) {
249    LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
250                      << NumCallsReplaced << "\n");
251
252    PreservedAnalyses PA;
253    PA.preserveSet<CFGAnalyses>();
254    PA.preserve<TargetLibraryAnalysis>();
255    PA.preserve<ScalarEvolutionAnalysis>();
256    PA.preserve<LoopAccessAnalysis>();
257    PA.preserve<DemandedBitsAnalysis>();
258    PA.preserve<OptimizationRemarkEmitterAnalysis>();
259    return PA;
260  }
261
262  // The pass did not replace any calls, hence it preserves all analyses.
263  return PreservedAnalyses::all();
264}
265
266////////////////////////////////////////////////////////////////////////////////
267// Legacy PM Implementation.
268////////////////////////////////////////////////////////////////////////////////
269bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
270  const TargetLibraryInfo &TLI =
271      getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
272  return runImpl(TLI, F);
273}
274
275void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
276  AU.setPreservesCFG();
277  AU.addRequired<TargetLibraryInfoWrapperPass>();
278  AU.addPreserved<TargetLibraryInfoWrapperPass>();
279  AU.addPreserved<ScalarEvolutionWrapperPass>();
280  AU.addPreserved<AAResultsWrapperPass>();
281  AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
282  AU.addPreserved<GlobalsAAWrapperPass>();
283}
284
285////////////////////////////////////////////////////////////////////////////////
286// Legacy Pass manager initialization
287////////////////////////////////////////////////////////////////////////////////
288char ReplaceWithVeclibLegacy::ID = 0;
289
290INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
291                      "Replace intrinsics with calls to vector library", false,
292                      false)
293INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
294INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
295                    "Replace intrinsics with calls to vector library", false,
296                    false)
297
298FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
299  return new ReplaceWithVeclibLegacy();
300}
301