1//===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines some vectorizer utilities.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef LLVM_ANALYSIS_VECTORUTILS_H
14#define LLVM_ANALYSIS_VECTORUTILS_H
15
16#include "llvm/ADT/MapVector.h"
17#include "llvm/ADT/SmallVector.h"
18#include "llvm/Analysis/LoopAccessAnalysis.h"
19#include "llvm/IR/VFABIDemangler.h"
20#include "llvm/Support/CheckedArithmetic.h"
21
22namespace llvm {
23class TargetLibraryInfo;
24
25/// The Vector Function Database.
26///
27/// Helper class used to find the vector functions associated to a
28/// scalar CallInst.
29class VFDatabase {
30  /// The Module of the CallInst CI.
31  const Module *M;
32  /// The CallInst instance being queried for scalar to vector mappings.
33  const CallInst &CI;
34  /// List of vector functions descriptors associated to the call
35  /// instruction.
36  const SmallVector<VFInfo, 8> ScalarToVectorMappings;
37
38  /// Retrieve the scalar-to-vector mappings associated to the rule of
39  /// a vector Function ABI.
40  static void getVFABIMappings(const CallInst &CI,
41                               SmallVectorImpl<VFInfo> &Mappings) {
42    if (!CI.getCalledFunction())
43      return;
44
45    const StringRef ScalarName = CI.getCalledFunction()->getName();
46
47    SmallVector<std::string, 8> ListOfStrings;
48    // The check for the vector-function-abi-variant attribute is done when
49    // retrieving the vector variant names here.
50    VFABI::getVectorVariantNames(CI, ListOfStrings);
51    if (ListOfStrings.empty())
52      return;
53    for (const auto &MangledName : ListOfStrings) {
54      const std::optional<VFInfo> Shape =
55          VFABI::tryDemangleForVFABI(MangledName, CI.getFunctionType());
56      // A match is found via scalar and vector names, and also by
57      // ensuring that the variant described in the attribute has a
58      // corresponding definition or declaration of the vector
59      // function in the Module M.
60      if (Shape && (Shape->ScalarName == ScalarName)) {
61        assert(CI.getModule()->getFunction(Shape->VectorName) &&
62               "Vector function is missing.");
63        Mappings.push_back(*Shape);
64      }
65    }
66  }
67
68public:
69  /// Retrieve all the VFInfo instances associated to the CallInst CI.
70  static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
71    SmallVector<VFInfo, 8> Ret;
72
73    // Get mappings from the Vector Function ABI variants.
74    getVFABIMappings(CI, Ret);
75
76    // Other non-VFABI variants should be retrieved here.
77
78    return Ret;
79  }
80
81  static bool hasMaskedVariant(const CallInst &CI,
82                               std::optional<ElementCount> VF = std::nullopt) {
83    // Check whether we have at least one masked vector version of a scalar
84    // function. If no VF is specified then we check for any masked variant,
85    // otherwise we look for one that matches the supplied VF.
86    auto Mappings = VFDatabase::getMappings(CI);
87    for (VFInfo Info : Mappings)
88      if (!VF || Info.Shape.VF == *VF)
89        if (Info.isMasked())
90          return true;
91
92    return false;
93  }
94
95  /// Constructor, requires a CallInst instance.
96  VFDatabase(CallInst &CI)
97      : M(CI.getModule()), CI(CI),
98        ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}
99  /// \defgroup VFDatabase query interface.
100  ///
101  /// @{
102  /// Retrieve the Function with VFShape \p Shape.
103  Function *getVectorizedFunction(const VFShape &Shape) const {
104    if (Shape == VFShape::getScalarShape(CI.getFunctionType()))
105      return CI.getCalledFunction();
106
107    for (const auto &Info : ScalarToVectorMappings)
108      if (Info.Shape == Shape)
109        return M->getFunction(Info.VectorName);
110
111    return nullptr;
112  }
113  /// @}
114};
115
// Forward declarations of entities named by the declarations in this header.
class DemandedBits;
class IRBuilderBase;
class Loop;
class ScalarEvolution;
class TargetTransformInfo;
class Type;
class Value;
template <typename T> class ArrayRef;
template <typename InstTy> class InterleaveGroup;

namespace Intrinsic {
/// Intrinsic identifiers are plain unsigned integers.
using ID = unsigned;
} // namespace Intrinsic
129
130/// A helper function for converting Scalar types to vector types. If
131/// the incoming type is void, we return void. If the EC represents a
132/// scalar, we return the scalar type.
133inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
134  if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
135    return Scalar;
136  return VectorType::get(Scalar, EC);
137}
138
139inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
140  return ToVectorTy(Scalar, ElementCount::getFixed(VF));
141}
142
/// Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
/// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic.
bool isTriviallyVectorizable(Intrinsic::ID ID);

/// Identifies if the vector form of the intrinsic has a scalar operand at
/// index \p ScalarOpdIdx.
bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                        unsigned ScalarOpdIdx);

/// Identifies if the vector form of the intrinsic is overloaded on the type of
/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);

/// Returns intrinsic ID for call.
/// For the input call instruction \p CI, finds the matching intrinsic and
/// returns its intrinsic ID. If no such intrinsic is found, returns
/// not_intrinsic. \p TLI is used to recognize library calls.
Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
                                          const TargetLibraryInfo *TLI);
162
/// Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then extracted
/// from the vector.
/// NOTE(review): presumably returns nullptr when no such value is found —
/// confirm against the definition.
Value *findScalarElement(Value *V, unsigned EltNo);

/// If all non-negative \p Mask elements are the same value, return that value.
/// If all elements are negative (undefined) or \p Mask contains different
/// non-negative values, return -1.
int getSplatIndex(ArrayRef<int> Mask);

/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constants vector or from
/// a sequence of instructions that broadcast a single value into a vector.
Value *getSplatValue(const Value *V);

/// Return true if each element of the vector value \p V is poisoned or equal to
/// every other non-poisoned element. If an index element is specified, either
/// every element of the vector is poisoned or the element at that index is not
/// poisoned and equal to every other non-poisoned element.
/// This may be more powerful than the related getSplatValue() because it is
/// not limited by finding a scalar source value to a splatted vector.
/// \p Index defaults to -1 (no specific element). NOTE(review): \p Depth is
/// presumably a recursion-depth counter — confirm against the definition.
bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
185
/// Transform a shuffle mask's output demanded element mask into demanded
/// element masks for the 2 operands, returns false if the mask isn't valid.
/// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth].
/// \p AllowUndefElts permits "-1" indices to be treated as undef.
bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
                            const APInt &DemandedElts, APInt &DemandedLHS,
                            APInt &DemandedRHS, bool AllowUndefElts = false);

/// Replace each shuffle mask index with the scaled sequential indices for an
/// equivalent mask of narrowed elements. Mask elements that are less than 0
/// (sentinel values) are repeated in the output mask.
///
/// Example with Scale = 4:
///   <4 x i32> <3, 2, 0, -1> -->
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
///
/// This is the reverse process of widening shuffle mask elements, but it always
/// succeeds because the indexes can always be multiplied (scaled up) to map to
/// narrower vector elements.
void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                           SmallVectorImpl<int> &ScaledMask);

/// Try to transform a shuffle mask by replacing elements with the scaled index
/// for an equivalent mask of widened elements. If all mask elements that would
/// map to a wider element of the new mask are the same negative number
/// (sentinel value), that element of the new mask is the same value. If any
/// element in a given slice is negative and some other element in that slice is
/// not the same value, return false (partial matches with sentinel values are
/// not allowed).
///
/// Example with Scale = 4:
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
///   <4 x i32> <3, 2, 0, -1>
///
/// This is the reverse process of narrowing shuffle mask elements if it
/// succeeds. This transform is not always possible because indexes may not
/// divide evenly (scale down) to map to wider vector elements.
bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);

/// Repetitively apply `widenShuffleMaskElts()` for as long as it succeeds,
/// to get the shuffle mask with widest possible elements.
void getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
                                  SmallVectorImpl<int> &ScaledMask);

/// Splits and processes shuffle mask depending on the number of input and
/// output registers. The function does 2 main things: 1) splits the
/// source/destination vectors into real registers; 2) do the mask analysis to
/// identify which real registers are permuted. Then the function processes
/// resulting registers mask using provided action items. If no input register
/// is defined, \p NoInputAction action is used. If only 1 input register is
/// used, \p SingleInputAction is used, otherwise (two or more input registers)
/// \p ManyInputsAction is used to process the input registers and masks.
/// \param Mask Original shuffle mask.
/// \param NumOfSrcRegs Number of source registers.
/// \param NumOfDestRegs Number of destination registers.
/// \param NumOfUsedRegs Number of actually used destination registers.
/// \param NoInputAction Invoked when no input register is defined.
/// \param SingleInputAction Invoked when exactly one input register is used.
/// \param ManyInputsAction Invoked when two or more input registers are used.
/// NOTE(review): the two `unsigned` callback arguments are presumably register
/// indices — confirm against the definition.
void processShuffleMasks(
    ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
    unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction);
248
/// Compute a map of integer instructions to their minimum legal type
/// size.
///
/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
/// type (e.g. i32) whenever arithmetic is performed on them.
///
/// For targets with native i8 or i16 operations, usually InstCombine can shrink
/// the arithmetic type down again. However InstCombine refuses to create
/// illegal types, so for targets without i8 or i16 registers, the lengthening
/// and shrinking remains.
///
/// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
/// their scalar equivalents do not, so during vectorization it is important to
/// remove these lengthens and truncates when deciding the profitability of
/// vectorization.
///
/// This function analyzes the given range of instructions and determines the
/// minimum type size each can be converted to. It attempts to remove or
/// minimize type size changes across each def-use chain, so for example in the
/// following code:
///
///   %1 = load i8, i8*
///   %2 = add i8 %1, 2
///   %3 = load i16, i16*
///   %4 = zext i8 %2 to i32
///   %5 = zext i16 %3 to i32
///   %6 = add i32 %4, %5
///   %7 = trunc i32 %6 to i16
///
/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
///
/// If the optional TargetTransformInfo is provided, this function tries harder
/// to do less work by only looking at illegal types.
MapVector<Instruction*, uint64_t>
computeMinimumValueSizes(ArrayRef<BasicBlock*> Blocks,
                         DemandedBits &DB,
                         const TargetTransformInfo *TTI=nullptr);

/// Compute the union of two access-group lists.
///
/// If the resulting list contains just one access group, it is returned
/// directly. If the resulting list is empty, returns nullptr.
MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);

/// Compute the access-group list of access groups that @p Inst1 and @p Inst2
/// are both in. If either instruction does not access memory at all, it is
/// considered to be in every list.
///
/// If the resulting list contains just one access group, it is returned
/// directly. If the resulting list is empty, returns nullptr.
MDNode *intersectAccessGroups(const Instruction *Inst1,
                              const Instruction *Inst2);

/// Propagate metadata from the values in \p VL onto the instruction \p I.
///
/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
/// MD_nontemporal, MD_access_group].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
///
/// This function always sets a (possibly null) value for each K in Kinds.
/// NOTE(review): the returned Instruction is presumably \p I itself —
/// confirm against the definition.
Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);
312
/// Create a mask that filters the members of an interleave group where there
/// are gaps.
///
/// For example, the mask for \p Group with interleave-factor 3
/// and \p VF 4, that has only its first member present is:
///
///   <1,0,0,1,0,0,1,0,0,1,0,0>
///
/// Note: The result is a mask of 0's and 1's, as opposed to the other
/// create[*]Mask() utilities which create a shuffle mask (mask that
/// consists of indices).
Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
                               const InterleaveGroup<Instruction> &Group);

/// Create a mask with replicated elements.
///
/// This function creates a shuffle mask for replicating each of the \p VF
/// elements in a vector \p ReplicationFactor times. It can be used to
/// transform a mask of \p VF elements into a mask of
/// \p VF * \p ReplicationFactor elements used by a predicated
/// interleaved-group of loads/stores whose Interleaved-factor ==
/// \p ReplicationFactor.
///
/// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
///
///   <0,0,0,1,1,1,2,2,2,3,3,3>
llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
                                                unsigned VF);

/// Create an interleave shuffle mask.
///
/// This function creates a shuffle mask for interleaving \p NumVecs vectors of
/// vectorization factor \p VF into a single wide vector. The mask is of the
/// form:
///
///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
///
/// For example, the mask for VF = 4 and NumVecs = 2 is:
///
///   <0, 4, 1, 5, 2, 6, 3, 7>.
llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);

/// Create a stride shuffle mask.
///
/// This function creates a shuffle mask whose elements begin at \p Start and
/// are incremented by \p Stride. The mask can be used to deinterleave an
/// interleaved vector into separate vectors of vectorization factor \p VF. The
/// mask is of the form:
///
///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
///
/// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
///
///   <0, 2, 4, 6>
llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
                                            unsigned VF);

/// Create a sequential shuffle mask.
///
/// This function creates a shuffle mask whose elements are sequential and begin
/// at \p Start. The mask contains \p NumInts integers and is padded with
/// \p NumUndefs undef values. The mask is of the form:
///
///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
///
///   <0, 1, 2, 3, undef, undef, undef, undef>
llvm::SmallVector<int, 16>
createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);

/// Given a shuffle mask for a binary shuffle, create the equivalent shuffle
/// mask assuming both operands are identical. This assumes that the unary
/// shuffle will use elements from operand 0 (operand 1 will be unused).
/// NOTE(review): \p NumElts is presumably the element count of each source
/// operand — confirm against the definition.
llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask,
                                           unsigned NumElts);

/// Concatenate a list of vectors.
///
/// This function generates code that concatenates the vectors in \p Vecs into
/// a single large vector. The number of vectors should be greater than one,
/// and their element types should be the same. The number of elements in the
/// vectors should also be the same; however, if the last vector has fewer
/// elements, it will be padded with undefs.
Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);
398
/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be false or undef. That is, return true if all
/// lanes can be assumed inactive.
bool maskIsAllZeroOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be true or undef. That is, return true if all
/// lanes can be assumed active.
bool maskIsAllOneOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if any of the elements of this
/// predicate mask are known to be true or undef. That is, return true if at
/// least one lane can be assumed active.
bool maskContainsAllOneOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
/// with a bit set for each lane which may be active.
APInt possiblyDemandedEltsInMask(Value *Mask);
417
418/// The group of interleaved loads/stores sharing the same stride and
419/// close to each other.
420///
421/// Each member in this group has an index starting from 0, and the largest
422/// index should be less than interleaved factor, which is equal to the absolute
423/// value of the access's stride.
424///
425/// E.g. An interleaved load group of factor 4:
426///        for (unsigned i = 0; i < 1024; i+=4) {
427///          a = A[i];                           // Member of index 0
428///          b = A[i+1];                         // Member of index 1
429///          d = A[i+3];                         // Member of index 3
430///          ...
431///        }
432///
433///      An interleaved store group of factor 4:
434///        for (unsigned i = 0; i < 1024; i+=4) {
435///          ...
436///          A[i]   = a;                         // Member of index 0
437///          A[i+1] = b;                         // Member of index 1
438///          A[i+2] = c;                         // Member of index 2
439///          A[i+3] = d;                         // Member of index 3
440///        }
441///
442/// Note: the interleaved load group could have gaps (missing members), but
443/// the interleaved store group doesn't allow gaps.
444template <typename InstTy> class InterleaveGroup {
445public:
446  InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
447      : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
448        InsertPos(nullptr) {}
449
450  InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
451      : Alignment(Alignment), InsertPos(Instr) {
452    Factor = std::abs(Stride);
453    assert(Factor > 1 && "Invalid interleave factor");
454
455    Reverse = Stride < 0;
456    Members[0] = Instr;
457  }
458
459  bool isReverse() const { return Reverse; }
460  uint32_t getFactor() const { return Factor; }
461  Align getAlign() const { return Alignment; }
462  uint32_t getNumMembers() const { return Members.size(); }
463
464  /// Try to insert a new member \p Instr with index \p Index and
465  /// alignment \p NewAlign. The index is related to the leader and it could be
466  /// negative if it is the new leader.
467  ///
468  /// \returns false if the instruction doesn't belong to the group.
469  bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
470    // Make sure the key fits in an int32_t.
471    std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
472    if (!MaybeKey)
473      return false;
474    int32_t Key = *MaybeKey;
475
476    // Skip if the key is used for either the tombstone or empty special values.
477    if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
478        DenseMapInfo<int32_t>::getEmptyKey() == Key)
479      return false;
480
481    // Skip if there is already a member with the same index.
482    if (Members.contains(Key))
483      return false;
484
485    if (Key > LargestKey) {
486      // The largest index is always less than the interleave factor.
487      if (Index >= static_cast<int32_t>(Factor))
488        return false;
489
490      LargestKey = Key;
491    } else if (Key < SmallestKey) {
492
493      // Make sure the largest index fits in an int32_t.
494      std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
495      if (!MaybeLargestIndex)
496        return false;
497
498      // The largest index is always less than the interleave factor.
499      if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
500        return false;
501
502      SmallestKey = Key;
503    }
504
505    // It's always safe to select the minimum alignment.
506    Alignment = std::min(Alignment, NewAlign);
507    Members[Key] = Instr;
508    return true;
509  }
510
511  /// Get the member with the given index \p Index
512  ///
513  /// \returns nullptr if contains no such member.
514  InstTy *getMember(uint32_t Index) const {
515    int32_t Key = SmallestKey + Index;
516    return Members.lookup(Key);
517  }
518
519  /// Get the index for the given member. Unlike the key in the member
520  /// map, the index starts from 0.
521  uint32_t getIndex(const InstTy *Instr) const {
522    for (auto I : Members) {
523      if (I.second == Instr)
524        return I.first - SmallestKey;
525    }
526
527    llvm_unreachable("InterleaveGroup contains no such member");
528  }
529
530  InstTy *getInsertPos() const { return InsertPos; }
531  void setInsertPos(InstTy *Inst) { InsertPos = Inst; }
532
533  /// Add metadata (e.g. alias info) from the instructions in this group to \p
534  /// NewInst.
535  ///
536  /// FIXME: this function currently does not add noalias metadata a'la
537  /// addNewMedata.  To do that we need to compute the intersection of the
538  /// noalias info from all members.
539  void addMetadata(InstTy *NewInst) const;
540
541  /// Returns true if this Group requires a scalar iteration to handle gaps.
542  bool requiresScalarEpilogue() const {
543    // If the last member of the Group exists, then a scalar epilog is not
544    // needed for this group.
545    if (getMember(getFactor() - 1))
546      return false;
547
548    // We have a group with gaps. It therefore can't be a reversed access,
549    // because such groups get invalidated (TODO).
550    assert(!isReverse() && "Group should have been invalidated");
551
552    // This is a group of loads, with gaps, and without a last-member
553    return true;
554  }
555
556private:
557  uint32_t Factor; // Interleave Factor.
558  bool Reverse;
559  Align Alignment;
560  DenseMap<int32_t, InstTy *> Members;
561  int32_t SmallestKey = 0;
562  int32_t LargestKey = 0;
563
564  // To avoid breaking dependences, vectorized instructions of an interleave
565  // group should be inserted at either the first load or the last store in
566  // program order.
567  //
568  // E.g. %even = load i32             // Insert Position
569  //      %add = add i32 %even         // Use of %even
570  //      %odd = load i32
571  //
572  //      store i32 %even
573  //      %odd = add i32               // Def of %odd
574  //      store i32 %odd               // Insert Position
575  InstTy *InsertPos;
576};
577
578/// Drive the analysis of interleaved memory accesses in the loop.
579///
580/// Use this class to analyze interleaved accesses only when we can vectorize
581/// a loop. Otherwise it's meaningless to do analysis as the vectorization
582/// on interleaved accesses is unsafe.
583///
584/// The analysis collects interleave groups and records the relationships
585/// between the member and the group in a map.
586class InterleavedAccessInfo {
587public:
588  InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
589                        DominatorTree *DT, LoopInfo *LI,
590                        const LoopAccessInfo *LAI)
591      : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}
592
593  ~InterleavedAccessInfo() { invalidateGroups(); }
594
595  /// Analyze the interleaved accesses and collect them in interleave
596  /// groups. Substitute symbolic strides using \p Strides.
597  /// Consider also predicated loads/stores in the analysis if
598  /// \p EnableMaskedInterleavedGroup is true.
599  void analyzeInterleaving(bool EnableMaskedInterleavedGroup);
600
601  /// Invalidate groups, e.g., in case all blocks in loop will be predicated
602  /// contrary to original assumption. Although we currently prevent group
603  /// formation for predicated accesses, we may be able to relax this limitation
604  /// in the future once we handle more complicated blocks. Returns true if any
605  /// groups were invalidated.
606  bool invalidateGroups() {
607    if (InterleaveGroups.empty()) {
608      assert(
609          !RequiresScalarEpilogue &&
610          "RequiresScalarEpilog should not be set without interleave groups");
611      return false;
612    }
613
614    InterleaveGroupMap.clear();
615    for (auto *Ptr : InterleaveGroups)
616      delete Ptr;
617    InterleaveGroups.clear();
618    RequiresScalarEpilogue = false;
619    return true;
620  }
621
622  /// Check if \p Instr belongs to any interleave group.
623  bool isInterleaved(Instruction *Instr) const {
624    return InterleaveGroupMap.contains(Instr);
625  }
626
627  /// Get the interleave group that \p Instr belongs to.
628  ///
629  /// \returns nullptr if doesn't have such group.
630  InterleaveGroup<Instruction> *
631  getInterleaveGroup(const Instruction *Instr) const {
632    return InterleaveGroupMap.lookup(Instr);
633  }
634
635  iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
636  getInterleaveGroups() {
637    return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
638  }
639
640  /// Returns true if an interleaved group that may access memory
641  /// out-of-bounds requires a scalar epilogue iteration for correctness.
642  bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }
643
644  /// Invalidate groups that require a scalar epilogue (due to gaps). This can
645  /// happen when optimizing for size forbids a scalar epilogue, and the gap
646  /// cannot be filtered by masking the load/store.
647  void invalidateGroupsRequiringScalarEpilogue();
648
649  /// Returns true if we have any interleave groups.
650  bool hasGroups() const { return !InterleaveGroups.empty(); }
651
652private:
653  /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
654  /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
655  /// The interleaved access analysis can also add new predicates (for example
656  /// by versioning strides of pointers).
657  PredicatedScalarEvolution &PSE;
658
659  Loop *TheLoop;
660  DominatorTree *DT;
661  LoopInfo *LI;
662  const LoopAccessInfo *LAI;
663
664  /// True if the loop may contain non-reversed interleaved groups with
665  /// out-of-bounds accesses. We ensure we don't speculatively access memory
666  /// out-of-bounds by executing at least one scalar epilogue iteration.
667  bool RequiresScalarEpilogue = false;
668
669  /// Holds the relationships between the members and the interleave group.
670  DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;
671
672  SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;
673
674  /// Holds dependences among the memory accesses in the loop. It maps a source
675  /// access to a set of dependent sink accesses.
676  DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;
677
678  /// The descriptor for a strided memory access.
679  struct StrideDescriptor {
680    StrideDescriptor() = default;
681    StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
682                     Align Alignment)
683        : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}
684
685    // The access's stride. It is negative for a reverse access.
686    int64_t Stride = 0;
687
688    // The scalar expression of this access.
689    const SCEV *Scev = nullptr;
690
691    // The size of the memory object.
692    uint64_t Size = 0;
693
694    // The alignment of this access.
695    Align Alignment;
696  };
697
698  /// A type for holding instructions and their stride descriptors.
699  using StrideEntry = std::pair<Instruction *, StrideDescriptor>;
700
701  /// Create a new interleave group with the given instruction \p Instr,
702  /// stride \p Stride and alignment \p Align.
703  ///
704  /// \returns the newly created interleave group.
705  InterleaveGroup<Instruction> *
706  createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
707    assert(!InterleaveGroupMap.count(Instr) &&
708           "Already in an interleaved access group");
709    InterleaveGroupMap[Instr] =
710        new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
711    InterleaveGroups.insert(InterleaveGroupMap[Instr]);
712    return InterleaveGroupMap[Instr];
713  }
714
715  /// Release the group and remove all the relationships.
716  void releaseGroup(InterleaveGroup<Instruction> *Group) {
717    for (unsigned i = 0; i < Group->getFactor(); i++)
718      if (Instruction *Member = Group->getMember(i))
719        InterleaveGroupMap.erase(Member);
720
721    InterleaveGroups.erase(Group);
722    delete Group;
723  }
724
725  /// Collect all the accesses with a constant stride in program order.
726  void collectConstStrideAccesses(
727      MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
728      const DenseMap<Value *, const SCEV *> &Strides);
729
730  /// Returns true if \p Stride is allowed in an interleaved group.
731  static bool isStrided(int Stride);
732
733  /// Returns true if \p BB is a predicated block.
734  bool isPredicated(BasicBlock *BB) const {
735    return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
736  }
737
738  /// Returns true if LoopAccessInfo can be used for dependence queries.
739  bool areDependencesValid() const {
740    return LAI && LAI->getDepChecker().getDependences();
741  }
742
743  /// Returns true if memory accesses \p A and \p B can be reordered, if
744  /// necessary, when constructing interleaved groups.
745  ///
746  /// \p A must precede \p B in program order. We return false if reordering is
747  /// not necessary or is prevented because \p A and \p B may be dependent.
748  bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
749                                                 StrideEntry *B) const {
750    // Code motion for interleaved accesses can potentially hoist strided loads
751    // and sink strided stores. The code below checks the legality of the
752    // following two conditions:
753    //
754    // 1. Potentially moving a strided load (B) before any store (A) that
755    //    precedes B, or
756    //
757    // 2. Potentially moving a strided store (A) after any load or store (B)
758    //    that A precedes.
759    //
760    // It's legal to reorder A and B if we know there isn't a dependence from A
761    // to B. Note that this determination is conservative since some
762    // dependences could potentially be reordered safely.
763
764    // A is potentially the source of a dependence.
765    auto *Src = A->first;
766    auto SrcDes = A->second;
767
768    // B is potentially the sink of a dependence.
769    auto *Sink = B->first;
770    auto SinkDes = B->second;
771
772    // Code motion for interleaved accesses can't violate WAR dependences.
773    // Thus, reordering is legal if the source isn't a write.
774    if (!Src->mayWriteToMemory())
775      return true;
776
777    // At least one of the accesses must be strided.
778    if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
779      return true;
780
781    // If dependence information is not available from LoopAccessInfo,
782    // conservatively assume the instructions can't be reordered.
783    if (!areDependencesValid())
784      return false;
785
786    // If we know there is a dependence from source to sink, assume the
787    // instructions can't be reordered. Otherwise, reordering is legal.
788    return !Dependences.contains(Src) || !Dependences.lookup(Src).count(Sink);
789  }
790
791  /// Collect the dependences from LoopAccessInfo.
792  ///
793  /// We process the dependences once during the interleaved access analysis to
794  /// enable constant-time dependence queries.
795  void collectDependences() {
796    if (!areDependencesValid())
797      return;
798    auto *Deps = LAI->getDepChecker().getDependences();
799    for (auto Dep : *Deps)
800      Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI));
801  }
802};
803
804} // llvm namespace
805
806#endif
807