1//===- VPlanSLP.cpp - SLP Analysis based on VPlan -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8/// This file implements SLP analysis based on VPlan. The analysis is based on
9/// the ideas described in
10///
11///   Look-ahead SLP: auto-vectorization in the presence of commutative
12///   operations, CGO 2018 by Vasileios Porpodas, Rodrigo C. O. Rocha,
///   Luís F. W. Góes
14///
15//===----------------------------------------------------------------------===//
16
17#include "VPlan.h"
18#include "VPlanValue.h"
19#include "llvm/ADT/DenseMap.h"
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/Analysis/VectorUtils.h"
22#include "llvm/IR/Instruction.h"
23#include "llvm/IR/Instructions.h"
24#include "llvm/IR/Type.h"
25#include "llvm/IR/Value.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/Debug.h"
28#include "llvm/Support/ErrorHandling.h"
29#include "llvm/Support/raw_ostream.h"
30#include <algorithm>
31#include <cassert>
32#include <optional>
33#include <utility>
34
35using namespace llvm;
36
37#define DEBUG_TYPE "vplan-slp"
38
// Upper bound (exclusive) on the look-ahead depth used when re-ordering multi
// node operands: getBest tries depths 1 .. LookaheadMaxDepth - 1. This is a
// fixed tuning constant, so it is a compile-time constant rather than a
// mutable file-scope variable.
static constexpr unsigned LookaheadMaxDepth = 5;
41
42VPInstruction *VPlanSlp::markFailed() {
43  // FIXME: Currently this is used to signal we hit instructions we cannot
44  //        trivially SLP'ize.
45  CompletelySLP = false;
46  return nullptr;
47}
48
49void VPlanSlp::addCombined(ArrayRef<VPValue *> Operands, VPInstruction *New) {
50  if (all_of(Operands, [](VPValue *V) {
51        return cast<VPInstruction>(V)->getUnderlyingInstr();
52      })) {
53    unsigned BundleSize = 0;
54    for (VPValue *V : Operands) {
55      Type *T = cast<VPInstruction>(V)->getUnderlyingInstr()->getType();
56      assert(!T->isVectorTy() && "Only scalar types supported for now");
57      BundleSize += T->getScalarSizeInBits();
58    }
59    WidestBundleBits = std::max(WidestBundleBits, BundleSize);
60  }
61
62  auto Res = BundleToCombined.try_emplace(to_vector<4>(Operands), New);
63  assert(Res.second &&
64         "Already created a combined instruction for the operand bundle");
65  (void)Res;
66}
67
68bool VPlanSlp::areVectorizable(ArrayRef<VPValue *> Operands) const {
69  // Currently we only support VPInstructions.
70  if (!all_of(Operands, [](VPValue *Op) {
71        return Op && isa<VPInstruction>(Op) &&
72               cast<VPInstruction>(Op)->getUnderlyingInstr();
73      })) {
74    LLVM_DEBUG(dbgs() << "VPSLP: not all operands are VPInstructions\n");
75    return false;
76  }
77
78  // Check if opcodes and type width agree for all instructions in the bundle.
79  // FIXME: Differing widths/opcodes can be handled by inserting additional
80  //        instructions.
81  // FIXME: Deal with non-primitive types.
82  const Instruction *OriginalInstr =
83      cast<VPInstruction>(Operands[0])->getUnderlyingInstr();
84  unsigned Opcode = OriginalInstr->getOpcode();
85  unsigned Width = OriginalInstr->getType()->getPrimitiveSizeInBits();
86  if (!all_of(Operands, [Opcode, Width](VPValue *Op) {
87        const Instruction *I = cast<VPInstruction>(Op)->getUnderlyingInstr();
88        return I->getOpcode() == Opcode &&
89               I->getType()->getPrimitiveSizeInBits() == Width;
90      })) {
91    LLVM_DEBUG(dbgs() << "VPSLP: Opcodes do not agree \n");
92    return false;
93  }
94
95  // For now, all operands must be defined in the same BB.
96  if (any_of(Operands, [this](VPValue *Op) {
97        return cast<VPInstruction>(Op)->getParent() != &this->BB;
98      })) {
99    LLVM_DEBUG(dbgs() << "VPSLP: operands in different BBs\n");
100    return false;
101  }
102
103  if (any_of(Operands,
104             [](VPValue *Op) { return Op->hasMoreThanOneUniqueUser(); })) {
105    LLVM_DEBUG(dbgs() << "VPSLP: Some operands have multiple users.\n");
106    return false;
107  }
108
109  // For loads, check that there are no instructions writing to memory in
110  // between them.
111  // TODO: we only have to forbid instructions writing to memory that could
112  //       interfere with any of the loads in the bundle
113  if (Opcode == Instruction::Load) {
114    unsigned LoadsSeen = 0;
115    VPBasicBlock *Parent = cast<VPInstruction>(Operands[0])->getParent();
116    for (auto &I : *Parent) {
117      auto *VPI = dyn_cast<VPInstruction>(&I);
118      if (!VPI)
119        break;
120      if (VPI->getOpcode() == Instruction::Load &&
121          llvm::is_contained(Operands, VPI))
122        LoadsSeen++;
123
124      if (LoadsSeen == Operands.size())
125        break;
126      if (LoadsSeen > 0 && VPI->mayWriteToMemory()) {
127        LLVM_DEBUG(
128            dbgs() << "VPSLP: instruction modifying memory between loads\n");
129        return false;
130      }
131    }
132
133    if (!all_of(Operands, [](VPValue *Op) {
134          return cast<LoadInst>(cast<VPInstruction>(Op)->getUnderlyingInstr())
135              ->isSimple();
136        })) {
137      LLVM_DEBUG(dbgs() << "VPSLP: only simple loads are supported.\n");
138      return false;
139    }
140  }
141
142  if (Opcode == Instruction::Store)
143    if (!all_of(Operands, [](VPValue *Op) {
144          return cast<StoreInst>(cast<VPInstruction>(Op)->getUnderlyingInstr())
145              ->isSimple();
146        })) {
147      LLVM_DEBUG(dbgs() << "VPSLP: only simple stores are supported.\n");
148      return false;
149    }
150
151  return true;
152}
153
154static SmallVector<VPValue *, 4> getOperands(ArrayRef<VPValue *> Values,
155                                             unsigned OperandIndex) {
156  SmallVector<VPValue *, 4> Operands;
157  for (VPValue *V : Values) {
158    // Currently we only support VPInstructions.
159    auto *U = cast<VPInstruction>(V);
160    Operands.push_back(U->getOperand(OperandIndex));
161  }
162  return Operands;
163}
164
165static bool areCommutative(ArrayRef<VPValue *> Values) {
166  return Instruction::isCommutative(
167      cast<VPInstruction>(Values[0])->getOpcode());
168}
169
170static SmallVector<SmallVector<VPValue *, 4>, 4>
171getOperands(ArrayRef<VPValue *> Values) {
172  SmallVector<SmallVector<VPValue *, 4>, 4> Result;
173  auto *VPI = cast<VPInstruction>(Values[0]);
174
175  switch (VPI->getOpcode()) {
176  case Instruction::Load:
177    llvm_unreachable("Loads terminate a tree, no need to get operands");
178  case Instruction::Store:
179    Result.push_back(getOperands(Values, 0));
180    break;
181  default:
182    for (unsigned I = 0, NumOps = VPI->getNumOperands(); I < NumOps; ++I)
183      Result.push_back(getOperands(Values, I));
184    break;
185  }
186
187  return Result;
188}
189
190/// Returns the opcode of Values or ~0 if they do not all agree.
191static std::optional<unsigned> getOpcode(ArrayRef<VPValue *> Values) {
192  unsigned Opcode = cast<VPInstruction>(Values[0])->getOpcode();
193  if (any_of(Values, [Opcode](VPValue *V) {
194        return cast<VPInstruction>(V)->getOpcode() != Opcode;
195      }))
196    return std::nullopt;
197  return {Opcode};
198}
199
200/// Returns true if A and B access sequential memory if they are loads or
201/// stores or if they have identical opcodes otherwise.
202static bool areConsecutiveOrMatch(VPInstruction *A, VPInstruction *B,
203                                  VPInterleavedAccessInfo &IAI) {
204  if (A->getOpcode() != B->getOpcode())
205    return false;
206
207  if (A->getOpcode() != Instruction::Load &&
208      A->getOpcode() != Instruction::Store)
209    return true;
210  auto *GA = IAI.getInterleaveGroup(A);
211  auto *GB = IAI.getInterleaveGroup(B);
212
213  return GA && GB && GA == GB && GA->getIndex(A) + 1 == GB->getIndex(B);
214}
215
216/// Implements getLAScore from Listing 7 in the paper.
217/// Traverses and compares operands of V1 and V2 to MaxLevel.
218static unsigned getLAScore(VPValue *V1, VPValue *V2, unsigned MaxLevel,
219                           VPInterleavedAccessInfo &IAI) {
220  auto *I1 = dyn_cast<VPInstruction>(V1);
221  auto *I2 = dyn_cast<VPInstruction>(V2);
222  // Currently we only support VPInstructions.
223  if (!I1 || !I2)
224    return 0;
225
226  if (MaxLevel == 0)
227    return (unsigned)areConsecutiveOrMatch(I1, I2, IAI);
228
229  unsigned Score = 0;
230  for (unsigned I = 0, EV1 = I1->getNumOperands(); I < EV1; ++I)
231    for (unsigned J = 0, EV2 = I2->getNumOperands(); J < EV2; ++J)
232      Score +=
233          getLAScore(I1->getOperand(I), I2->getOperand(J), MaxLevel - 1, IAI);
234  return Score;
235}
236
237std::pair<VPlanSlp::OpMode, VPValue *>
238VPlanSlp::getBest(OpMode Mode, VPValue *Last,
239                  SmallPtrSetImpl<VPValue *> &Candidates,
240                  VPInterleavedAccessInfo &IAI) {
241  assert((Mode == OpMode::Load || Mode == OpMode::Opcode) &&
242         "Currently we only handle load and commutative opcodes");
243  LLVM_DEBUG(dbgs() << "      getBest\n");
244
245  SmallVector<VPValue *, 4> BestCandidates;
246  LLVM_DEBUG(dbgs() << "        Candidates  for "
247                    << *cast<VPInstruction>(Last)->getUnderlyingInstr() << " ");
248  for (auto *Candidate : Candidates) {
249    auto *LastI = cast<VPInstruction>(Last);
250    auto *CandidateI = cast<VPInstruction>(Candidate);
251    if (areConsecutiveOrMatch(LastI, CandidateI, IAI)) {
252      LLVM_DEBUG(dbgs() << *cast<VPInstruction>(Candidate)->getUnderlyingInstr()
253                        << " ");
254      BestCandidates.push_back(Candidate);
255    }
256  }
257  LLVM_DEBUG(dbgs() << "\n");
258
259  if (BestCandidates.empty())
260    return {OpMode::Failed, nullptr};
261
262  if (BestCandidates.size() == 1)
263    return {Mode, BestCandidates[0]};
264
265  VPValue *Best = nullptr;
266  unsigned BestScore = 0;
267  for (unsigned Depth = 1; Depth < LookaheadMaxDepth; Depth++) {
268    unsigned PrevScore = ~0u;
269    bool AllSame = true;
270
271    // FIXME: Avoid visiting the same operands multiple times.
272    for (auto *Candidate : BestCandidates) {
273      unsigned Score = getLAScore(Last, Candidate, Depth, IAI);
274      if (PrevScore == ~0u)
275        PrevScore = Score;
276      if (PrevScore != Score)
277        AllSame = false;
278      PrevScore = Score;
279
280      if (Score > BestScore) {
281        BestScore = Score;
282        Best = Candidate;
283      }
284    }
285    if (!AllSame)
286      break;
287  }
288  LLVM_DEBUG(dbgs() << "Found best "
289                    << *cast<VPInstruction>(Best)->getUnderlyingInstr()
290                    << "\n");
291  Candidates.erase(Best);
292
293  return {Mode, Best};
294}
295
296SmallVector<VPlanSlp::MultiNodeOpTy, 4> VPlanSlp::reorderMultiNodeOps() {
297  SmallVector<MultiNodeOpTy, 4> FinalOrder;
298  SmallVector<OpMode, 4> Mode;
299  FinalOrder.reserve(MultiNodeOps.size());
300  Mode.reserve(MultiNodeOps.size());
301
302  LLVM_DEBUG(dbgs() << "Reordering multinode\n");
303
304  for (auto &Operands : MultiNodeOps) {
305    FinalOrder.push_back({Operands.first, {Operands.second[0]}});
306    if (cast<VPInstruction>(Operands.second[0])->getOpcode() ==
307        Instruction::Load)
308      Mode.push_back(OpMode::Load);
309    else
310      Mode.push_back(OpMode::Opcode);
311  }
312
313  for (unsigned Lane = 1, E = MultiNodeOps[0].second.size(); Lane < E; ++Lane) {
314    LLVM_DEBUG(dbgs() << "  Finding best value for lane " << Lane << "\n");
315    SmallPtrSet<VPValue *, 4> Candidates;
316    LLVM_DEBUG(dbgs() << "  Candidates  ");
317    for (auto Ops : MultiNodeOps) {
318      LLVM_DEBUG(
319          dbgs() << *cast<VPInstruction>(Ops.second[Lane])->getUnderlyingInstr()
320                 << " ");
321      Candidates.insert(Ops.second[Lane]);
322    }
323    LLVM_DEBUG(dbgs() << "\n");
324
325    for (unsigned Op = 0, E = MultiNodeOps.size(); Op < E; ++Op) {
326      LLVM_DEBUG(dbgs() << "  Checking " << Op << "\n");
327      if (Mode[Op] == OpMode::Failed)
328        continue;
329
330      VPValue *Last = FinalOrder[Op].second[Lane - 1];
331      std::pair<OpMode, VPValue *> Res =
332          getBest(Mode[Op], Last, Candidates, IAI);
333      if (Res.second)
334        FinalOrder[Op].second.push_back(Res.second);
335      else
336        // TODO: handle this case
337        FinalOrder[Op].second.push_back(markFailed());
338    }
339  }
340
341  return FinalOrder;
342}
343
344#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
345void VPlanSlp::dumpBundle(ArrayRef<VPValue *> Values) {
346  dbgs() << " Ops: ";
347  for (auto *Op : Values) {
348    if (auto *VPInstr = cast_or_null<VPInstruction>(Op))
349      if (auto *Instr = VPInstr->getUnderlyingInstr()) {
350        dbgs() << *Instr << " | ";
351        continue;
352      }
353    dbgs() << " nullptr | ";
354  }
355  dbgs() << "\n";
356}
357#endif
358
359VPInstruction *VPlanSlp::buildGraph(ArrayRef<VPValue *> Values) {
360  assert(!Values.empty() && "Need some operands!");
361
362  // If we already visited this instruction bundle, re-use the existing node
363  auto I = BundleToCombined.find(to_vector<4>(Values));
364  if (I != BundleToCombined.end()) {
365#ifndef NDEBUG
366    // Check that the resulting graph is a tree. If we re-use a node, this means
367    // its values have multiple users. We only allow this, if all users of each
368    // value are the same instruction.
369    for (auto *V : Values) {
370      auto UI = V->user_begin();
371      auto *FirstUser = *UI++;
372      while (UI != V->user_end()) {
373        assert(*UI == FirstUser && "Currently we only support SLP trees.");
374        UI++;
375      }
376    }
377#endif
378    return I->second;
379  }
380
381  // Dump inputs
382  LLVM_DEBUG({
383    dbgs() << "buildGraph: ";
384    dumpBundle(Values);
385  });
386
387  if (!areVectorizable(Values))
388    return markFailed();
389
390  assert(getOpcode(Values) && "Opcodes for all values must match");
391  unsigned ValuesOpcode = *getOpcode(Values);
392
393  SmallVector<VPValue *, 4> CombinedOperands;
394  if (areCommutative(Values)) {
395    bool MultiNodeRoot = !MultiNodeActive;
396    MultiNodeActive = true;
397    for (auto &Operands : getOperands(Values)) {
398      LLVM_DEBUG({
399        dbgs() << "  Visiting Commutative";
400        dumpBundle(Operands);
401      });
402
403      auto OperandsOpcode = getOpcode(Operands);
404      if (OperandsOpcode && OperandsOpcode == getOpcode(Values)) {
405        LLVM_DEBUG(dbgs() << "    Same opcode, continue building\n");
406        CombinedOperands.push_back(buildGraph(Operands));
407      } else {
408        LLVM_DEBUG(dbgs() << "    Adding multinode Ops\n");
409        // Create dummy VPInstruction, which will we replace later by the
410        // re-ordered operand.
411        VPInstruction *Op = new VPInstruction(0, {});
412        CombinedOperands.push_back(Op);
413        MultiNodeOps.emplace_back(Op, Operands);
414      }
415    }
416
417    if (MultiNodeRoot) {
418      LLVM_DEBUG(dbgs() << "Reorder \n");
419      MultiNodeActive = false;
420
421      auto FinalOrder = reorderMultiNodeOps();
422
423      MultiNodeOps.clear();
424      for (auto &Ops : FinalOrder) {
425        VPInstruction *NewOp = buildGraph(Ops.second);
426        Ops.first->replaceAllUsesWith(NewOp);
427        for (unsigned i = 0; i < CombinedOperands.size(); i++)
428          if (CombinedOperands[i] == Ops.first)
429            CombinedOperands[i] = NewOp;
430        delete Ops.first;
431        Ops.first = NewOp;
432      }
433      LLVM_DEBUG(dbgs() << "Found final order\n");
434    }
435  } else {
436    LLVM_DEBUG(dbgs() << "  NonCommuntative\n");
437    if (ValuesOpcode == Instruction::Load)
438      for (VPValue *V : Values)
439        CombinedOperands.push_back(cast<VPInstruction>(V)->getOperand(0));
440    else
441      for (auto &Operands : getOperands(Values))
442        CombinedOperands.push_back(buildGraph(Operands));
443  }
444
445  unsigned Opcode;
446  switch (ValuesOpcode) {
447  case Instruction::Load:
448    Opcode = VPInstruction::SLPLoad;
449    break;
450  case Instruction::Store:
451    Opcode = VPInstruction::SLPStore;
452    break;
453  default:
454    Opcode = ValuesOpcode;
455    break;
456  }
457
458  if (!CompletelySLP)
459    return markFailed();
460
461  assert(CombinedOperands.size() > 0 && "Need more some operands");
462  auto *Inst = cast<VPInstruction>(Values[0])->getUnderlyingInstr();
463  auto *VPI = new VPInstruction(Opcode, CombinedOperands, Inst->getDebugLoc());
464  VPI->setUnderlyingInstr(Inst);
465
466  LLVM_DEBUG(dbgs() << "Create VPInstruction " << *VPI << " "
467                    << *cast<VPInstruction>(Values[0]) << "\n");
468  addCombined(Values, VPI);
469  return VPI;
470}
471