//===- MVELaneInterleaving.cpp - Interleave for MVE instructions ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass interleaves around sext/zext/trunc instructions. MVE does not have
10// a single sext/zext or trunc instruction that takes the bottom half of a
11// vector and extends to a full width, like NEON has with MOVL. Instead it is
12// expected that this happens through top/bottom instructions. So the MVE
13// equivalent VMOVLT/B instructions take either the even or odd elements of the
14// input and extend them to the larger type, producing a vector with half the
15// number of elements each of double the bitwidth. As there is no simple
16// instruction, we often have to turn sext/zext/trunc into a series of lane
17// moves (or stack loads/stores, which we do not do yet).
18//
// This pass takes vector code that starts at truncs (or add reductions), looks
// for interconnected blobs of operations that end with sext/zext (or
// constants/splats) of the form:
22//   %sa = sext v8i16 %a to v8i32
23//   %sb = sext v8i16 %b to v8i32
24//   %add = add v8i32 %sa, %sb
25//   %r = trunc %add to v8i16
// And adds shuffles to allow the use of VMOVL/VMOVN instructions:
27//   %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7>
28//   %sa = sext v8i16 %sha to v8i32
29//   %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7>
30//   %sb = sext v8i16 %shb to v8i32
31//   %add = add v8i32 %sa, %sb
32//   %r = trunc %add to v8i16
33//   %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7>
34// Which can then be split and lowered to MVE instructions efficiently:
35//   %sa_b = VMOVLB.s16 %a
36//   %sa_t = VMOVLT.s16 %a
37//   %sb_b = VMOVLB.s16 %b
38//   %sb_t = VMOVLT.s16 %b
39//   %add_b = VADD.i32 %sa_b, %sb_b
40//   %add_t = VADD.i32 %sa_t, %sb_t
41//   %r = VMOVNT.i16 %add_b, %add_t
42//
43//===----------------------------------------------------------------------===//
44
45#include "ARM.h"
46#include "ARMBaseInstrInfo.h"
47#include "ARMSubtarget.h"
48#include "llvm/ADT/SetVector.h"
49#include "llvm/Analysis/TargetTransformInfo.h"
50#include "llvm/CodeGen/TargetLowering.h"
51#include "llvm/CodeGen/TargetPassConfig.h"
52#include "llvm/CodeGen/TargetSubtargetInfo.h"
53#include "llvm/IR/BasicBlock.h"
54#include "llvm/IR/Constant.h"
55#include "llvm/IR/Constants.h"
56#include "llvm/IR/DerivedTypes.h"
57#include "llvm/IR/Function.h"
58#include "llvm/IR/IRBuilder.h"
59#include "llvm/IR/InstIterator.h"
60#include "llvm/IR/InstrTypes.h"
61#include "llvm/IR/Instruction.h"
62#include "llvm/IR/Instructions.h"
63#include "llvm/IR/IntrinsicInst.h"
64#include "llvm/IR/Intrinsics.h"
65#include "llvm/IR/IntrinsicsARM.h"
66#include "llvm/IR/PatternMatch.h"
67#include "llvm/IR/Type.h"
68#include "llvm/IR/Value.h"
69#include "llvm/InitializePasses.h"
70#include "llvm/Pass.h"
71#include "llvm/Support/Casting.h"
72#include <algorithm>
73#include <cassert>
74
75using namespace llvm;
76
77#define DEBUG_TYPE "mve-laneinterleave"
78
79cl::opt<bool> EnableInterleave(
80    "enable-mve-interleave", cl::Hidden, cl::init(true),
81    cl::desc("Enable interleave MVE vector operation lowering"));
82
83namespace {
84
85class MVELaneInterleaving : public FunctionPass {
86public:
87  static char ID; // Pass identification, replacement for typeid
88
89  explicit MVELaneInterleaving() : FunctionPass(ID) {
90    initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry());
91  }
92
93  bool runOnFunction(Function &F) override;
94
95  StringRef getPassName() const override { return "MVE lane interleaving"; }
96
97  void getAnalysisUsage(AnalysisUsage &AU) const override {
98    AU.setPreservesCFG();
99    AU.addRequired<TargetPassConfig>();
100    FunctionPass::getAnalysisUsage(AU);
101  }
102};
103
104} // end anonymous namespace
105
106char MVELaneInterleaving::ID = 0;
107
108INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
109                false)
110
111Pass *llvm::createMVELaneInterleavingPass() {
112  return new MVELaneInterleaving();
113}
114
115static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts,
116                                     SmallSetVector<Instruction *, 4> &Truncs) {
117  // This is not always beneficial to transform. Exts can be incorporated into
118  // loads, Truncs can be folded into stores.
119  // Truncs are usually the same number of instructions,
120  //  VSTRH.32(A);VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving
121  // Exts are unfortunately more instructions in the general case:
122  //  A=VLDRH.32; B=VLDRH.32;
123  // vs with interleaving:
124  //  T=VLDRH.16; A=VMOVNB T; B=VMOVNT T
125  // But those VMOVL may be folded into a VMULL.
126
127  // But expensive extends/truncs are always good to remove. FPExts always
128  // involve extra VCVT's so are always considered to be beneficial to convert.
129  for (auto *E : Exts) {
130    if (isa<FPExtInst>(E) || !isa<LoadInst>(E->getOperand(0))) {
131      LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n");
132      return true;
133    }
134  }
135  for (auto *T : Truncs) {
136    if (T->hasOneUse() && !isa<StoreInst>(*T->user_begin())) {
137      LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n");
138      return true;
139    }
140  }
141
142  // Otherwise, we know we have a load(ext), see if any of the Extends are a
143  // vmull. This is a simple heuristic and certainly not perfect.
144  for (auto *E : Exts) {
145    if (!E->hasOneUse() ||
146        cast<Instruction>(*E->user_begin())->getOpcode() != Instruction::Mul) {
147      LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n");
148      return false;
149    }
150  }
151  return true;
152}
153
154static bool tryInterleave(Instruction *Start,
155                          SmallPtrSetImpl<Instruction *> &Visited) {
156  LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");
157
158  if (!isa<Instruction>(Start->getOperand(0)))
159    return false;
160
161  // Look for connected operations starting from Ext's, terminating at Truncs.
162  std::vector<Instruction *> Worklist;
163  Worklist.push_back(Start);
164  Worklist.push_back(cast<Instruction>(Start->getOperand(0)));
165
166  SmallSetVector<Instruction *, 4> Truncs;
167  SmallSetVector<Instruction *, 4> Reducts;
168  SmallSetVector<Instruction *, 4> Exts;
169  SmallSetVector<Use *, 4> OtherLeafs;
170  SmallSetVector<Instruction *, 4> Ops;
171
172  while (!Worklist.empty()) {
173    Instruction *I = Worklist.back();
174    Worklist.pop_back();
175
176    switch (I->getOpcode()) {
177    // Truncs
178    case Instruction::Trunc:
179    case Instruction::FPTrunc:
180      if (!Truncs.insert(I))
181        continue;
182      Visited.insert(I);
183      break;
184
185    // Extend leafs
186    case Instruction::SExt:
187    case Instruction::ZExt:
188    case Instruction::FPExt:
189      if (Exts.count(I))
190        continue;
191      for (auto *Use : I->users())
192        Worklist.push_back(cast<Instruction>(Use));
193      Exts.insert(I);
194      break;
195
196    case Instruction::Call: {
197      IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
198      if (!II)
199        return false;
200
201      if (II->getIntrinsicID() == Intrinsic::vector_reduce_add) {
202        if (!Reducts.insert(I))
203          continue;
204        Visited.insert(I);
205        break;
206      }
207
208      switch (II->getIntrinsicID()) {
209      case Intrinsic::abs:
210      case Intrinsic::smin:
211      case Intrinsic::smax:
212      case Intrinsic::umin:
213      case Intrinsic::umax:
214      case Intrinsic::sadd_sat:
215      case Intrinsic::ssub_sat:
216      case Intrinsic::uadd_sat:
217      case Intrinsic::usub_sat:
218      case Intrinsic::minnum:
219      case Intrinsic::maxnum:
220      case Intrinsic::fabs:
221      case Intrinsic::fma:
222      case Intrinsic::ceil:
223      case Intrinsic::floor:
224      case Intrinsic::rint:
225      case Intrinsic::round:
226      case Intrinsic::trunc:
227        break;
228      default:
229        return false;
230      }
231      [[fallthrough]]; // Fall through to treating these like an operator below.
232    }
233    // Binary/tertiary ops
234    case Instruction::Add:
235    case Instruction::Sub:
236    case Instruction::Mul:
237    case Instruction::AShr:
238    case Instruction::LShr:
239    case Instruction::Shl:
240    case Instruction::ICmp:
241    case Instruction::FCmp:
242    case Instruction::FAdd:
243    case Instruction::FMul:
244    case Instruction::Select:
245      if (!Ops.insert(I))
246        continue;
247
248      for (Use &Op : I->operands()) {
249        if (!isa<FixedVectorType>(Op->getType()))
250          continue;
251        if (isa<Instruction>(Op))
252          Worklist.push_back(cast<Instruction>(&Op));
253        else
254          OtherLeafs.insert(&Op);
255      }
256
257      for (auto *Use : I->users())
258        Worklist.push_back(cast<Instruction>(Use));
259      break;
260
261    case Instruction::ShuffleVector:
262      // A shuffle of a splat is a splat.
263      if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
264        continue;
265      [[fallthrough]];
266
267    default:
268      LLVM_DEBUG(dbgs() << "  Unhandled instruction: " << *I << "\n");
269      return false;
270    }
271  }
272
273  if (Exts.empty() && OtherLeafs.empty())
274    return false;
275
276  LLVM_DEBUG({
277    dbgs() << "Found group:\n  Exts:\n";
278    for (auto *I : Exts)
279      dbgs() << "  " << *I << "\n";
280    dbgs() << "  Ops:\n";
281    for (auto *I : Ops)
282      dbgs() << "  " << *I << "\n";
283    dbgs() << "  OtherLeafs:\n";
284    for (auto *I : OtherLeafs)
285      dbgs() << "  " << *I->get() << " of " << *I->getUser() << "\n";
286    dbgs() << "  Truncs:\n";
287    for (auto *I : Truncs)
288      dbgs() << "  " << *I << "\n";
289    dbgs() << "  Reducts:\n";
290    for (auto *I : Reducts)
291      dbgs() << "  " << *I << "\n";
292  });
293
294  assert((!Truncs.empty() || !Reducts.empty()) &&
295         "Expected some truncs or reductions");
296  if (Truncs.empty() && Exts.empty())
297    return false;
298
299  auto *VT = !Truncs.empty()
300                 ? cast<FixedVectorType>(Truncs[0]->getType())
301                 : cast<FixedVectorType>(Exts[0]->getOperand(0)->getType());
302  LLVM_DEBUG(dbgs() << "Using VT:" << *VT << "\n");
303
304  // Check types
305  unsigned NumElts = VT->getNumElements();
306  unsigned BaseElts = VT->getScalarSizeInBits() == 16
307                          ? 8
308                          : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
309  if (BaseElts == 0 || NumElts % BaseElts != 0) {
310    LLVM_DEBUG(dbgs() << "  Type is unsupported\n");
311    return false;
312  }
313  if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
314      VT->getScalarSizeInBits() * 2) {
315    LLVM_DEBUG(dbgs() << "  Type not double sized\n");
316    return false;
317  }
318  for (Instruction *I : Exts)
319    if (I->getOperand(0)->getType() != VT) {
320      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
321      return false;
322    }
323  for (Instruction *I : Truncs)
324    if (I->getType() != VT) {
325      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
326      return false;
327    }
328
329  // Check that it looks beneficial
330  if (!isProfitableToInterleave(Exts, Truncs))
331    return false;
332  if (!Reducts.empty() && (Ops.empty() || all_of(Ops, [](Instruction *I) {
333                             return I->getOpcode() == Instruction::Mul ||
334                                    I->getOpcode() == Instruction::Select ||
335                                    I->getOpcode() == Instruction::ICmp;
336                           }))) {
337    LLVM_DEBUG(dbgs() << "Reduction does not look profitable\n");
338    return false;
339  }
340
341  // Create new shuffles around the extends / truncs / other leaves.
342  IRBuilder<> Builder(Start);
343
344  SmallVector<int, 16> LeafMask;
345  SmallVector<int, 16> TruncMask;
346  // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7   8, 10, 12, 14,  9, 11, 13, 15
347  // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7   8, 12,  9, 13, 10, 14, 11, 15
348  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
349    for (unsigned i = 0; i < BaseElts / 2; i++)
350      LeafMask.push_back(Base + i * 2);
351    for (unsigned i = 0; i < BaseElts / 2; i++)
352      LeafMask.push_back(Base + i * 2 + 1);
353  }
354  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
355    for (unsigned i = 0; i < BaseElts / 2; i++) {
356      TruncMask.push_back(Base + i);
357      TruncMask.push_back(Base + i + BaseElts / 2);
358    }
359  }
360
361  for (Instruction *I : Exts) {
362    LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
363    Builder.SetInsertPoint(I);
364    Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
365    bool FPext = isa<FPExtInst>(I);
366    bool Sext = isa<SExtInst>(I);
367    Value *Ext = FPext ? Builder.CreateFPExt(Shuffle, I->getType())
368                       : Sext ? Builder.CreateSExt(Shuffle, I->getType())
369                              : Builder.CreateZExt(Shuffle, I->getType());
370    I->replaceAllUsesWith(Ext);
371    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
372  }
373
374  for (Use *I : OtherLeafs) {
375    LLVM_DEBUG(dbgs() << "Replacing leaf " << *I << "\n");
376    Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
377    Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
378    I->getUser()->setOperand(I->getOperandNo(), Shuffle);
379    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
380  }
381
382  for (Instruction *I : Truncs) {
383    LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");
384
385    Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
386    Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
387    I->replaceAllUsesWith(Shuf);
388    cast<Instruction>(Shuf)->setOperand(0, I);
389
390    LLVM_DEBUG(dbgs() << "  with " << *Shuf << "\n");
391  }
392
393  return true;
394}
395
396// Add reductions are fairly common and associative, meaning we can start the
397// interleaving from them and don't need to emit a shuffle.
398static bool isAddReduction(Instruction &I) {
399  if (auto *II = dyn_cast<IntrinsicInst>(&I))
400    return II->getIntrinsicID() == Intrinsic::vector_reduce_add;
401  return false;
402}
403
404bool MVELaneInterleaving::runOnFunction(Function &F) {
405  if (!EnableInterleave)
406    return false;
407  auto &TPC = getAnalysis<TargetPassConfig>();
408  auto &TM = TPC.getTM<TargetMachine>();
409  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
410  if (!ST->hasMVEIntegerOps())
411    return false;
412
413  bool Changed = false;
414
415  SmallPtrSet<Instruction *, 16> Visited;
416  for (Instruction &I : reverse(instructions(F))) {
417    if (((I.getType()->isVectorTy() &&
418          (isa<TruncInst>(I) || isa<FPTruncInst>(I))) ||
419         isAddReduction(I)) &&
420        !Visited.count(&I))
421      Changed |= tryInterleave(&I, Visited);
422  }
423
424  return Changed;
425}
426