1//===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===// 2// 3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4// See https://llvm.org/LICENSE.txt for license information. 5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6// 7//===----------------------------------------------------------------------===// 8// 9// This file defines some vectorizer utilities. 10// 11//===----------------------------------------------------------------------===// 12 13#ifndef LLVM_ANALYSIS_VECTORUTILS_H 14#define LLVM_ANALYSIS_VECTORUTILS_H 15 16#include "llvm/ADT/MapVector.h" 17#include "llvm/ADT/SmallVector.h" 18#include "llvm/Analysis/LoopAccessAnalysis.h" 19#include "llvm/IR/VFABIDemangler.h" 20#include "llvm/Support/CheckedArithmetic.h" 21 22namespace llvm { 23class TargetLibraryInfo; 24 25/// The Vector Function Database. 26/// 27/// Helper class used to find the vector functions associated to a 28/// scalar CallInst. 29class VFDatabase { 30 /// The Module of the CallInst CI. 31 const Module *M; 32 /// The CallInst instance being queried for scalar to vector mappings. 33 const CallInst &CI; 34 /// List of vector functions descriptors associated to the call 35 /// instruction. 36 const SmallVector<VFInfo, 8> ScalarToVectorMappings; 37 38 /// Retrieve the scalar-to-vector mappings associated to the rule of 39 /// a vector Function ABI. 40 static void getVFABIMappings(const CallInst &CI, 41 SmallVectorImpl<VFInfo> &Mappings) { 42 if (!CI.getCalledFunction()) 43 return; 44 45 const StringRef ScalarName = CI.getCalledFunction()->getName(); 46 47 SmallVector<std::string, 8> ListOfStrings; 48 // The check for the vector-function-abi-variant attribute is done when 49 // retrieving the vector variant names here. 50 VFABI::getVectorVariantNames(CI, ListOfStrings); 51 if (ListOfStrings.empty()) 52 return; 53 for (const auto &MangledName : ListOfStrings) { 54 const std::optional<VFInfo> Shape = 55 VFABI::tryDemangleForVFABI(MangledName, CI.getFunctionType()); 56 // A match is found via scalar and vector names, and also by 57 // ensuring that the variant described in the attribute has a 58 // corresponding definition or declaration of the vector 59 // function in the Module M. 60 if (Shape && (Shape->ScalarName == ScalarName)) { 61 assert(CI.getModule()->getFunction(Shape->VectorName) && 62 "Vector function is missing."); 63 Mappings.push_back(*Shape); 64 } 65 } 66 } 67 68public: 69 /// Retrieve all the VFInfo instances associated to the CallInst CI. 70 static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) { 71 SmallVector<VFInfo, 8> Ret; 72 73 // Get mappings from the Vector Function ABI variants. 74 getVFABIMappings(CI, Ret); 75 76 // Other non-VFABI variants should be retrieved here. 77 78 return Ret; 79 } 80 81 static bool hasMaskedVariant(const CallInst &CI, 82 std::optional<ElementCount> VF = std::nullopt) { 83 // Check whether we have at least one masked vector version of a scalar 84 // function. If no VF is specified then we check for any masked variant, 85 // otherwise we look for one that matches the supplied VF. 86 auto Mappings = VFDatabase::getMappings(CI); 87 for (VFInfo Info : Mappings) 88 if (!VF || Info.Shape.VF == *VF) 89 if (Info.isMasked()) 90 return true; 91 92 return false; 93 } 94 95 /// Constructor, requires a CallInst instance. 96 VFDatabase(CallInst &CI) 97 : M(CI.getModule()), CI(CI), 98 ScalarToVectorMappings(VFDatabase::getMappings(CI)) {} 99 /// \defgroup VFDatabase query interface. 100 /// 101 /// @{ 102 /// Retrieve the Function with VFShape \p Shape. 103 Function *getVectorizedFunction(const VFShape &Shape) const { 104 if (Shape == VFShape::getScalarShape(CI.getFunctionType())) 105 return CI.getCalledFunction(); 106 107 for (const auto &Info : ScalarToVectorMappings) 108 if (Info.Shape == Shape) 109 return M->getFunction(Info.VectorName); 110 111 return nullptr; 112 } 113 /// @} 114}; 115 116template <typename T> class ArrayRef; 117class DemandedBits; 118template <typename InstTy> class InterleaveGroup; 119class IRBuilderBase; 120class Loop; 121class ScalarEvolution; 122class TargetTransformInfo; 123class Type; 124class Value; 125 126namespace Intrinsic { 127typedef unsigned ID; 128} 129 130/// A helper function for converting Scalar types to vector types. If 131/// the incoming type is void, we return void. If the EC represents a 132/// scalar, we return the scalar type. 133inline Type *ToVectorTy(Type *Scalar, ElementCount EC) { 134 if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar()) 135 return Scalar; 136 return VectorType::get(Scalar, EC); 137} 138 139inline Type *ToVectorTy(Type *Scalar, unsigned VF) { 140 return ToVectorTy(Scalar, ElementCount::getFixed(VF)); 141} 142 143/// Identify if the intrinsic is trivially vectorizable. 144/// This method returns true if the intrinsic's argument types are all scalars 145/// for the scalar form of the intrinsic and all vectors (or scalars handled by 146/// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic. 147bool isTriviallyVectorizable(Intrinsic::ID ID); 148 149/// Identifies if the vector form of the intrinsic has a scalar operand. 150bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, 151 unsigned ScalarOpdIdx); 152 153/// Identifies if the vector form of the intrinsic is overloaded on the type of 154/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1. 155bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx); 156 157/// Returns intrinsic ID for call. 158/// For the input call instruction it finds mapping intrinsic and returns 159/// its intrinsic ID, in case it does not found it return not_intrinsic. 160Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI, 161 const TargetLibraryInfo *TLI); 162 163/// Given a vector and an element number, see if the scalar value is 164/// already around as a register, for example if it were inserted then extracted 165/// from the vector. 166Value *findScalarElement(Value *V, unsigned EltNo); 167 168/// If all non-negative \p Mask elements are the same value, return that value. 169/// If all elements are negative (undefined) or \p Mask contains different 170/// non-negative values, return -1. 171int getSplatIndex(ArrayRef<int> Mask); 172 173/// Get splat value if the input is a splat vector or return nullptr. 174/// The value may be extracted from a splat constants vector or from 175/// a sequence of instructions that broadcast a single value into a vector. 176Value *getSplatValue(const Value *V); 177 178/// Return true if each element of the vector value \p V is poisoned or equal to 179/// every other non-poisoned element. If an index element is specified, either 180/// every element of the vector is poisoned or the element at that index is not 181/// poisoned and equal to every other non-poisoned element. 182/// This may be more powerful than the related getSplatValue() because it is 183/// not limited by finding a scalar source value to a splatted vector. 184bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0); 185 186/// Transform a shuffle mask's output demanded element mask into demanded 187/// element masks for the 2 operands, returns false if the mask isn't valid. 188/// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth]. 189/// \p AllowUndefElts permits "-1" indices to be treated as undef. 190bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask, 191 const APInt &DemandedElts, APInt &DemandedLHS, 192 APInt &DemandedRHS, bool AllowUndefElts = false); 193 194/// Replace each shuffle mask index with the scaled sequential indices for an 195/// equivalent mask of narrowed elements. Mask elements that are less than 0 196/// (sentinel values) are repeated in the output mask. 197/// 198/// Example with Scale = 4: 199/// <4 x i32> <3, 2, 0, -1> --> 200/// <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> 201/// 202/// This is the reverse process of widening shuffle mask elements, but it always 203/// succeeds because the indexes can always be multiplied (scaled up) to map to 204/// narrower vector elements. 205void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask, 206 SmallVectorImpl<int> &ScaledMask); 207 208/// Try to transform a shuffle mask by replacing elements with the scaled index 209/// for an equivalent mask of widened elements. If all mask elements that would 210/// map to a wider element of the new mask are the same negative number 211/// (sentinel value), that element of the new mask is the same value. If any 212/// element in a given slice is negative and some other element in that slice is 213/// not the same value, return false (partial matches with sentinel values are 214/// not allowed). 215/// 216/// Example with Scale = 4: 217/// <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> --> 218/// <4 x i32> <3, 2, 0, -1> 219/// 220/// This is the reverse process of narrowing shuffle mask elements if it 221/// succeeds. This transform is not always possible because indexes may not 222/// divide evenly (scale down) to map to wider vector elements. 223bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask, 224 SmallVectorImpl<int> &ScaledMask); 225 226/// Repetitively apply `widenShuffleMaskElts()` for as long as it succeeds, 227/// to get the shuffle mask with widest possible elements. 228void getShuffleMaskWithWidestElts(ArrayRef<int> Mask, 229 SmallVectorImpl<int> &ScaledMask); 230 231/// Splits and processes shuffle mask depending on the number of input and 232/// output registers. The function does 2 main things: 1) splits the 233/// source/destination vectors into real registers; 2) do the mask analysis to 234/// identify which real registers are permuted. Then the function processes 235/// resulting registers mask using provided action items. If no input register 236/// is defined, \p NoInputAction action is used. If only 1 input register is 237/// used, \p SingleInputAction is used, otherwise \p ManyInputsAction is used to 238/// process > 2 input registers and masks. 239/// \param Mask Original shuffle mask. 240/// \param NumOfSrcRegs Number of source registers. 241/// \param NumOfDestRegs Number of destination registers. 242/// \param NumOfUsedRegs Number of actually used destination registers. 243void processShuffleMasks( 244 ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs, 245 unsigned NumOfUsedRegs, function_ref<void()> NoInputAction, 246 function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction, 247 function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction); 248 249/// Compute a map of integer instructions to their minimum legal type 250/// size. 251/// 252/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int 253/// type (e.g. i32) whenever arithmetic is performed on them. 254/// 255/// For targets with native i8 or i16 operations, usually InstCombine can shrink 256/// the arithmetic type down again. However InstCombine refuses to create 257/// illegal types, so for targets without i8 or i16 registers, the lengthening 258/// and shrinking remains. 259/// 260/// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when 261/// their scalar equivalents do not, so during vectorization it is important to 262/// remove these lengthens and truncates when deciding the profitability of 263/// vectorization. 264/// 265/// This function analyzes the given range of instructions and determines the 266/// minimum type size each can be converted to. It attempts to remove or 267/// minimize type size changes across each def-use chain, so for example in the 268/// following code: 269/// 270/// %1 = load i8, i8* 271/// %2 = add i8 %1, 2 272/// %3 = load i16, i16* 273/// %4 = zext i8 %2 to i32 274/// %5 = zext i16 %3 to i32 275/// %6 = add i32 %4, %5 276/// %7 = trunc i32 %6 to i16 277/// 278/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes 279/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}. 280/// 281/// If the optional TargetTransformInfo is provided, this function tries harder 282/// to do less work by only looking at illegal types. 283MapVector<Instruction*, uint64_t> 284computeMinimumValueSizes(ArrayRef<BasicBlock*> Blocks, 285 DemandedBits &DB, 286 const TargetTransformInfo *TTI=nullptr); 287 288/// Compute the union of two access-group lists. 289/// 290/// If the list contains just one access group, it is returned directly. If the 291/// list is empty, returns nullptr. 292MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2); 293 294/// Compute the access-group list of access groups that @p Inst1 and @p Inst2 295/// are both in. If either instruction does not access memory at all, it is 296/// considered to be in every list. 297/// 298/// If the list contains just one access group, it is returned directly. If the 299/// list is empty, returns nullptr. 300MDNode *intersectAccessGroups(const Instruction *Inst1, 301 const Instruction *Inst2); 302 303/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath, 304/// MD_nontemporal, MD_access_group]. 305/// For K in Kinds, we get the MDNode for K from each of the 306/// elements of VL, compute their "intersection" (i.e., the most generic 307/// metadata value that covers all of the individual values), and set I's 308/// metadata for M equal to the intersection value. 309/// 310/// This function always sets a (possibly null) value for each K in Kinds. 311Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL); 312 313/// Create a mask that filters the members of an interleave group where there 314/// are gaps. 315/// 316/// For example, the mask for \p Group with interleave-factor 3 317/// and \p VF 4, that has only its first member present is: 318/// 319/// <1,0,0,1,0,0,1,0,0,1,0,0> 320/// 321/// Note: The result is a mask of 0's and 1's, as opposed to the other 322/// create[*]Mask() utilities which create a shuffle mask (mask that 323/// consists of indices). 324Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF, 325 const InterleaveGroup<Instruction> &Group); 326 327/// Create a mask with replicated elements. 328/// 329/// This function creates a shuffle mask for replicating each of the \p VF 330/// elements in a vector \p ReplicationFactor times. It can be used to 331/// transform a mask of \p VF elements into a mask of 332/// \p VF * \p ReplicationFactor elements used by a predicated 333/// interleaved-group of loads/stores whose Interleaved-factor == 334/// \p ReplicationFactor. 335/// 336/// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is: 337/// 338/// <0,0,0,1,1,1,2,2,2,3,3,3> 339llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor, 340 unsigned VF); 341 342/// Create an interleave shuffle mask. 343/// 344/// This function creates a shuffle mask for interleaving \p NumVecs vectors of 345/// vectorization factor \p VF into a single wide vector. The mask is of the 346/// form: 347/// 348/// <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...> 349/// 350/// For example, the mask for VF = 4 and NumVecs = 2 is: 351/// 352/// <0, 4, 1, 5, 2, 6, 3, 7>. 353llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs); 354 355/// Create a stride shuffle mask. 356/// 357/// This function creates a shuffle mask whose elements begin at \p Start and 358/// are incremented by \p Stride. The mask can be used to deinterleave an 359/// interleaved vector into separate vectors of vectorization factor \p VF. The 360/// mask is of the form: 361/// 362/// <Start, Start + Stride, ..., Start + Stride * (VF - 1)> 363/// 364/// For example, the mask for Start = 0, Stride = 2, and VF = 4 is: 365/// 366/// <0, 2, 4, 6> 367llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride, 368 unsigned VF); 369 370/// Create a sequential shuffle mask. 371/// 372/// This function creates shuffle mask whose elements are sequential and begin 373/// at \p Start. The mask contains \p NumInts integers and is padded with \p 374/// NumUndefs undef values. The mask is of the form: 375/// 376/// <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs> 377/// 378/// For example, the mask for Start = 0, NumInsts = 4, and NumUndefs = 4 is: 379/// 380/// <0, 1, 2, 3, undef, undef, undef, undef> 381llvm::SmallVector<int, 16> 382createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs); 383 384/// Given a shuffle mask for a binary shuffle, create the equivalent shuffle 385/// mask assuming both operands are identical. This assumes that the unary 386/// shuffle will use elements from operand 0 (operand 1 will be unused). 387llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask, 388 unsigned NumElts); 389 390/// Concatenate a list of vectors. 391/// 392/// This function generates code that concatenate the vectors in \p Vecs into a 393/// single large vector. The number of vectors should be greater than one, and 394/// their element types should be the same. The number of elements in the 395/// vectors should also be the same; however, if the last vector has fewer 396/// elements, it will be padded with undefs. 397Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs); 398 399/// Given a mask vector of i1, Return true if all of the elements of this 400/// predicate mask are known to be false or undef. That is, return true if all 401/// lanes can be assumed inactive. 402bool maskIsAllZeroOrUndef(Value *Mask); 403 404/// Given a mask vector of i1, Return true if all of the elements of this 405/// predicate mask are known to be true or undef. That is, return true if all 406/// lanes can be assumed active. 407bool maskIsAllOneOrUndef(Value *Mask); 408 409/// Given a mask vector of i1, Return true if any of the elements of this 410/// predicate mask are known to be true or undef. That is, return true if at 411/// least one lane can be assumed active. 412bool maskContainsAllOneOrUndef(Value *Mask); 413 414/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y) 415/// for each lane which may be active. 416APInt possiblyDemandedEltsInMask(Value *Mask); 417 418/// The group of interleaved loads/stores sharing the same stride and 419/// close to each other. 420/// 421/// Each member in this group has an index starting from 0, and the largest 422/// index should be less than interleaved factor, which is equal to the absolute 423/// value of the access's stride. 424/// 425/// E.g. An interleaved load group of factor 4: 426/// for (unsigned i = 0; i < 1024; i+=4) { 427/// a = A[i]; // Member of index 0 428/// b = A[i+1]; // Member of index 1 429/// d = A[i+3]; // Member of index 3 430/// ... 431/// } 432/// 433/// An interleaved store group of factor 4: 434/// for (unsigned i = 0; i < 1024; i+=4) { 435/// ... 436/// A[i] = a; // Member of index 0 437/// A[i+1] = b; // Member of index 1 438/// A[i+2] = c; // Member of index 2 439/// A[i+3] = d; // Member of index 3 440/// } 441/// 442/// Note: the interleaved load group could have gaps (missing members), but 443/// the interleaved store group doesn't allow gaps. 444template <typename InstTy> class InterleaveGroup { 445public: 446 InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment) 447 : Factor(Factor), Reverse(Reverse), Alignment(Alignment), 448 InsertPos(nullptr) {} 449 450 InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment) 451 : Alignment(Alignment), InsertPos(Instr) { 452 Factor = std::abs(Stride); 453 assert(Factor > 1 && "Invalid interleave factor"); 454 455 Reverse = Stride < 0; 456 Members[0] = Instr; 457 } 458 459 bool isReverse() const { return Reverse; } 460 uint32_t getFactor() const { return Factor; } 461 Align getAlign() const { return Alignment; } 462 uint32_t getNumMembers() const { return Members.size(); } 463 464 /// Try to insert a new member \p Instr with index \p Index and 465 /// alignment \p NewAlign. The index is related to the leader and it could be 466 /// negative if it is the new leader. 467 /// 468 /// \returns false if the instruction doesn't belong to the group. 469 bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) { 470 // Make sure the key fits in an int32_t. 471 std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey); 472 if (!MaybeKey) 473 return false; 474 int32_t Key = *MaybeKey; 475 476 // Skip if the key is used for either the tombstone or empty special values. 477 if (DenseMapInfo<int32_t>::getTombstoneKey() == Key || 478 DenseMapInfo<int32_t>::getEmptyKey() == Key) 479 return false; 480 481 // Skip if there is already a member with the same index. 482 if (Members.contains(Key)) 483 return false; 484 485 if (Key > LargestKey) { 486 // The largest index is always less than the interleave factor. 487 if (Index >= static_cast<int32_t>(Factor)) 488 return false; 489 490 LargestKey = Key; 491 } else if (Key < SmallestKey) { 492 493 // Make sure the largest index fits in an int32_t. 494 std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key); 495 if (!MaybeLargestIndex) 496 return false; 497 498 // The largest index is always less than the interleave factor. 499 if (*MaybeLargestIndex >= static_cast<int64_t>(Factor)) 500 return false; 501 502 SmallestKey = Key; 503 } 504 505 // It's always safe to select the minimum alignment. 506 Alignment = std::min(Alignment, NewAlign); 507 Members[Key] = Instr; 508 return true; 509 } 510 511 /// Get the member with the given index \p Index 512 /// 513 /// \returns nullptr if contains no such member. 514 InstTy *getMember(uint32_t Index) const { 515 int32_t Key = SmallestKey + Index; 516 return Members.lookup(Key); 517 } 518 519 /// Get the index for the given member. Unlike the key in the member 520 /// map, the index starts from 0. 521 uint32_t getIndex(const InstTy *Instr) const { 522 for (auto I : Members) { 523 if (I.second == Instr) 524 return I.first - SmallestKey; 525 } 526 527 llvm_unreachable("InterleaveGroup contains no such member"); 528 } 529 530 InstTy *getInsertPos() const { return InsertPos; } 531 void setInsertPos(InstTy *Inst) { InsertPos = Inst; } 532 533 /// Add metadata (e.g. alias info) from the instructions in this group to \p 534 /// NewInst. 535 /// 536 /// FIXME: this function currently does not add noalias metadata a'la 537 /// addNewMedata. To do that we need to compute the intersection of the 538 /// noalias info from all members. 539 void addMetadata(InstTy *NewInst) const; 540 541 /// Returns true if this Group requires a scalar iteration to handle gaps. 542 bool requiresScalarEpilogue() const { 543 // If the last member of the Group exists, then a scalar epilog is not 544 // needed for this group. 545 if (getMember(getFactor() - 1)) 546 return false; 547 548 // We have a group with gaps. It therefore can't be a reversed access, 549 // because such groups get invalidated (TODO). 550 assert(!isReverse() && "Group should have been invalidated"); 551 552 // This is a group of loads, with gaps, and without a last-member 553 return true; 554 } 555 556private: 557 uint32_t Factor; // Interleave Factor. 558 bool Reverse; 559 Align Alignment; 560 DenseMap<int32_t, InstTy *> Members; 561 int32_t SmallestKey = 0; 562 int32_t LargestKey = 0; 563 564 // To avoid breaking dependences, vectorized instructions of an interleave 565 // group should be inserted at either the first load or the last store in 566 // program order. 567 // 568 // E.g. %even = load i32 // Insert Position 569 // %add = add i32 %even // Use of %even 570 // %odd = load i32 571 // 572 // store i32 %even 573 // %odd = add i32 // Def of %odd 574 // store i32 %odd // Insert Position 575 InstTy *InsertPos; 576}; 577 578/// Drive the analysis of interleaved memory accesses in the loop. 579/// 580/// Use this class to analyze interleaved accesses only when we can vectorize 581/// a loop. Otherwise it's meaningless to do analysis as the vectorization 582/// on interleaved accesses is unsafe. 583/// 584/// The analysis collects interleave groups and records the relationships 585/// between the member and the group in a map. 586class InterleavedAccessInfo { 587public: 588 InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L, 589 DominatorTree *DT, LoopInfo *LI, 590 const LoopAccessInfo *LAI) 591 : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {} 592 593 ~InterleavedAccessInfo() { invalidateGroups(); } 594 595 /// Analyze the interleaved accesses and collect them in interleave 596 /// groups. Substitute symbolic strides using \p Strides. 597 /// Consider also predicated loads/stores in the analysis if 598 /// \p EnableMaskedInterleavedGroup is true. 599 void analyzeInterleaving(bool EnableMaskedInterleavedGroup); 600 601 /// Invalidate groups, e.g., in case all blocks in loop will be predicated 602 /// contrary to original assumption. Although we currently prevent group 603 /// formation for predicated accesses, we may be able to relax this limitation 604 /// in the future once we handle more complicated blocks. Returns true if any 605 /// groups were invalidated. 606 bool invalidateGroups() { 607 if (InterleaveGroups.empty()) { 608 assert( 609 !RequiresScalarEpilogue && 610 "RequiresScalarEpilog should not be set without interleave groups"); 611 return false; 612 } 613 614 InterleaveGroupMap.clear(); 615 for (auto *Ptr : InterleaveGroups) 616 delete Ptr; 617 InterleaveGroups.clear(); 618 RequiresScalarEpilogue = false; 619 return true; 620 } 621 622 /// Check if \p Instr belongs to any interleave group. 623 bool isInterleaved(Instruction *Instr) const { 624 return InterleaveGroupMap.contains(Instr); 625 } 626 627 /// Get the interleave group that \p Instr belongs to. 628 /// 629 /// \returns nullptr if doesn't have such group. 630 InterleaveGroup<Instruction> * 631 getInterleaveGroup(const Instruction *Instr) const { 632 return InterleaveGroupMap.lookup(Instr); 633 } 634 635 iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>> 636 getInterleaveGroups() { 637 return make_range(InterleaveGroups.begin(), InterleaveGroups.end()); 638 } 639 640 /// Returns true if an interleaved group that may access memory 641 /// out-of-bounds requires a scalar epilogue iteration for correctness. 642 bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; } 643 644 /// Invalidate groups that require a scalar epilogue (due to gaps). This can 645 /// happen when optimizing for size forbids a scalar epilogue, and the gap 646 /// cannot be filtered by masking the load/store. 647 void invalidateGroupsRequiringScalarEpilogue(); 648 649 /// Returns true if we have any interleave groups. 650 bool hasGroups() const { return !InterleaveGroups.empty(); } 651 652private: 653 /// A wrapper around ScalarEvolution, used to add runtime SCEV checks. 654 /// Simplifies SCEV expressions in the context of existing SCEV assumptions. 655 /// The interleaved access analysis can also add new predicates (for example 656 /// by versioning strides of pointers). 657 PredicatedScalarEvolution &PSE; 658 659 Loop *TheLoop; 660 DominatorTree *DT; 661 LoopInfo *LI; 662 const LoopAccessInfo *LAI; 663 664 /// True if the loop may contain non-reversed interleaved groups with 665 /// out-of-bounds accesses. We ensure we don't speculatively access memory 666 /// out-of-bounds by executing at least one scalar epilogue iteration. 667 bool RequiresScalarEpilogue = false; 668 669 /// Holds the relationships between the members and the interleave group. 670 DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap; 671 672 SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups; 673 674 /// Holds dependences among the memory accesses in the loop. It maps a source 675 /// access to a set of dependent sink accesses. 676 DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences; 677 678 /// The descriptor for a strided memory access. 679 struct StrideDescriptor { 680 StrideDescriptor() = default; 681 StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size, 682 Align Alignment) 683 : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {} 684 685 // The access's stride. It is negative for a reverse access. 686 int64_t Stride = 0; 687 688 // The scalar expression of this access. 689 const SCEV *Scev = nullptr; 690 691 // The size of the memory object. 692 uint64_t Size = 0; 693 694 // The alignment of this access. 695 Align Alignment; 696 }; 697 698 /// A type for holding instructions and their stride descriptors. 699 using StrideEntry = std::pair<Instruction *, StrideDescriptor>; 700 701 /// Create a new interleave group with the given instruction \p Instr, 702 /// stride \p Stride and alignment \p Align. 703 /// 704 /// \returns the newly created interleave group. 705 InterleaveGroup<Instruction> * 706 createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) { 707 assert(!InterleaveGroupMap.count(Instr) && 708 "Already in an interleaved access group"); 709 InterleaveGroupMap[Instr] = 710 new InterleaveGroup<Instruction>(Instr, Stride, Alignment); 711 InterleaveGroups.insert(InterleaveGroupMap[Instr]); 712 return InterleaveGroupMap[Instr]; 713 } 714 715 /// Release the group and remove all the relationships. 716 void releaseGroup(InterleaveGroup<Instruction> *Group) { 717 for (unsigned i = 0; i < Group->getFactor(); i++) 718 if (Instruction *Member = Group->getMember(i)) 719 InterleaveGroupMap.erase(Member); 720 721 InterleaveGroups.erase(Group); 722 delete Group; 723 } 724 725 /// Collect all the accesses with a constant stride in program order. 726 void collectConstStrideAccesses( 727 MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo, 728 const DenseMap<Value *, const SCEV *> &Strides); 729 730 /// Returns true if \p Stride is allowed in an interleaved group. 731 static bool isStrided(int Stride); 732 733 /// Returns true if \p BB is a predicated block. 734 bool isPredicated(BasicBlock *BB) const { 735 return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT); 736 } 737 738 /// Returns true if LoopAccessInfo can be used for dependence queries. 739 bool areDependencesValid() const { 740 return LAI && LAI->getDepChecker().getDependences(); 741 } 742 743 /// Returns true if memory accesses \p A and \p B can be reordered, if 744 /// necessary, when constructing interleaved groups. 745 /// 746 /// \p A must precede \p B in program order. We return false if reordering is 747 /// not necessary or is prevented because \p A and \p B may be dependent. 748 bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A, 749 StrideEntry *B) const { 750 // Code motion for interleaved accesses can potentially hoist strided loads 751 // and sink strided stores. The code below checks the legality of the 752 // following two conditions: 753 // 754 // 1. Potentially moving a strided load (B) before any store (A) that 755 // precedes B, or 756 // 757 // 2. Potentially moving a strided store (A) after any load or store (B) 758 // that A precedes. 759 // 760 // It's legal to reorder A and B if we know there isn't a dependence from A 761 // to B. Note that this determination is conservative since some 762 // dependences could potentially be reordered safely. 763 764 // A is potentially the source of a dependence. 765 auto *Src = A->first; 766 auto SrcDes = A->second; 767 768 // B is potentially the sink of a dependence. 769 auto *Sink = B->first; 770 auto SinkDes = B->second; 771 772 // Code motion for interleaved accesses can't violate WAR dependences. 773 // Thus, reordering is legal if the source isn't a write. 774 if (!Src->mayWriteToMemory()) 775 return true; 776 777 // At least one of the accesses must be strided. 778 if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride)) 779 return true; 780 781 // If dependence information is not available from LoopAccessInfo, 782 // conservatively assume the instructions can't be reordered. 783 if (!areDependencesValid()) 784 return false; 785 786 // If we know there is a dependence from source to sink, assume the 787 // instructions can't be reordered. Otherwise, reordering is legal. 788 return !Dependences.contains(Src) || !Dependences.lookup(Src).count(Sink); 789 } 790 791 /// Collect the dependences from LoopAccessInfo. 792 /// 793 /// We process the dependences once during the interleaved access analysis to 794 /// enable constant-time dependence queries. 795 void collectDependences() { 796 if (!areDependencesValid()) 797 return; 798 auto *Deps = LAI->getDepChecker().getDependences(); 799 for (auto Dep : *Deps) 800 Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI)); 801 } 802}; 803 804} // llvm namespace 805 806#endif 807