1353940Sdim//===- MVETailPredication.cpp - MVE Tail Predication ----------------------===// 2353940Sdim// 3353940Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4353940Sdim// See https://llvm.org/LICENSE.txt for license information. 5353940Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6353940Sdim// 7353940Sdim//===----------------------------------------------------------------------===// 8353940Sdim// 9353940Sdim/// \file 10353940Sdim/// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead 11353940Sdim/// branches to help accelerate DSP applications. These two extensions can be 12353940Sdim/// combined to provide implicit vector predication within a low-overhead loop. 13353940Sdim/// The HardwareLoops pass inserts intrinsics identifying loops that the 14353940Sdim/// backend will attempt to convert into a low-overhead loop. The vectorizer is 15353940Sdim/// responsible for generating a vectorized loop in which the lanes are 16353940Sdim/// predicated upon the iteration counter. This pass looks at these predicated 17353940Sdim/// vector loops, that are targets for low-overhead loops, and prepares it for 18353940Sdim/// code generation. Once the vectorizer has produced a masked loop, there's a 19353940Sdim/// couple of final forms: 20353940Sdim/// - A tail-predicated loop, with implicit predication. 21353940Sdim/// - A loop containing multiple VCPT instructions, predicating multiple VPT 22353940Sdim/// blocks of instructions operating on different vector types. 23357095Sdim/// 24357095Sdim/// This pass inserts the inserts the VCTP intrinsic to represent the effect of 25357095Sdim/// tail predication. This will be picked up by the ARM Low-overhead loop pass, 26357095Sdim/// which performs the final transformation to a DLSTP or WLSTP tail-predicated 27357095Sdim/// loop. 28353940Sdim 29357095Sdim#include "ARM.h" 30357095Sdim#include "ARMSubtarget.h" 31353940Sdim#include "llvm/Analysis/LoopInfo.h" 32353940Sdim#include "llvm/Analysis/LoopPass.h" 33353940Sdim#include "llvm/Analysis/ScalarEvolution.h" 34353940Sdim#include "llvm/Analysis/ScalarEvolutionExpander.h" 35353940Sdim#include "llvm/Analysis/ScalarEvolutionExpressions.h" 36353940Sdim#include "llvm/Analysis/TargetTransformInfo.h" 37353940Sdim#include "llvm/CodeGen/TargetPassConfig.h" 38357095Sdim#include "llvm/IR/IRBuilder.h" 39353940Sdim#include "llvm/IR/Instructions.h" 40357095Sdim#include "llvm/IR/IntrinsicsARM.h" 41353940Sdim#include "llvm/IR/PatternMatch.h" 42353940Sdim#include "llvm/Support/Debug.h" 43353940Sdim#include "llvm/Transforms/Utils/BasicBlockUtils.h" 44353940Sdim 45353940Sdimusing namespace llvm; 46353940Sdim 47353940Sdim#define DEBUG_TYPE "mve-tail-predication" 48353940Sdim#define DESC "Transform predicated vector loops to use MVE tail predication" 49353940Sdim 50357095Sdimcl::opt<bool> 51353940SdimDisableTailPredication("disable-mve-tail-predication", cl::Hidden, 52353940Sdim cl::init(true), 53353940Sdim cl::desc("Disable MVE Tail Predication")); 54353940Sdimnamespace { 55353940Sdim 56353940Sdimclass MVETailPredication : public LoopPass { 57353940Sdim SmallVector<IntrinsicInst*, 4> MaskedInsts; 58353940Sdim Loop *L = nullptr; 59353940Sdim ScalarEvolution *SE = nullptr; 60353940Sdim TargetTransformInfo *TTI = nullptr; 61353940Sdim 62353940Sdimpublic: 63353940Sdim static char ID; 64353940Sdim 65353940Sdim MVETailPredication() : LoopPass(ID) { } 66353940Sdim 67353940Sdim void getAnalysisUsage(AnalysisUsage &AU) const override { 68353940Sdim AU.addRequired<ScalarEvolutionWrapperPass>(); 69353940Sdim AU.addRequired<LoopInfoWrapperPass>(); 70353940Sdim AU.addRequired<TargetPassConfig>(); 71353940Sdim AU.addRequired<TargetTransformInfoWrapperPass>(); 72353940Sdim AU.addPreserved<LoopInfoWrapperPass>(); 73353940Sdim AU.setPreservesCFG(); 74353940Sdim } 75353940Sdim 76353940Sdim bool runOnLoop(Loop *L, LPPassManager&) override; 77353940Sdim 78353940Sdimprivate: 79353940Sdim 80353940Sdim /// Perform the relevant checks on the loop and convert if possible. 81353940Sdim bool TryConvert(Value *TripCount); 82353940Sdim 83353940Sdim /// Return whether this is a vectorized loop, that contains masked 84353940Sdim /// load/stores. 85353940Sdim bool IsPredicatedVectorLoop(); 86353940Sdim 87353940Sdim /// Compute a value for the total number of elements that the predicated 88353940Sdim /// loop will process. 89353940Sdim Value *ComputeElements(Value *TripCount, VectorType *VecTy); 90353940Sdim 91353940Sdim /// Is the icmp that generates an i1 vector, based upon a loop counter 92353940Sdim /// and a limit that is defined outside the loop. 93353940Sdim bool isTailPredicate(Instruction *Predicate, Value *NumElements); 94357095Sdim 95357095Sdim /// Insert the intrinsic to represent the effect of tail predication. 96357095Sdim void InsertVCTPIntrinsic(Instruction *Predicate, 97357095Sdim DenseMap<Instruction*, Instruction*> &NewPredicates, 98357095Sdim VectorType *VecTy, 99357095Sdim Value *NumElements); 100353940Sdim}; 101353940Sdim 102353940Sdim} // end namespace 103353940Sdim 104353940Sdimstatic bool IsDecrement(Instruction &I) { 105353940Sdim auto *Call = dyn_cast<IntrinsicInst>(&I); 106353940Sdim if (!Call) 107353940Sdim return false; 108353940Sdim 109353940Sdim Intrinsic::ID ID = Call->getIntrinsicID(); 110353940Sdim return ID == Intrinsic::loop_decrement_reg; 111353940Sdim} 112353940Sdim 113353940Sdimstatic bool IsMasked(Instruction *I) { 114353940Sdim auto *Call = dyn_cast<IntrinsicInst>(I); 115353940Sdim if (!Call) 116353940Sdim return false; 117353940Sdim 118353940Sdim Intrinsic::ID ID = Call->getIntrinsicID(); 119353940Sdim // TODO: Support gather/scatter expand/compress operations. 120353940Sdim return ID == Intrinsic::masked_store || ID == Intrinsic::masked_load; 121353940Sdim} 122353940Sdim 123353940Sdimbool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) { 124353940Sdim if (skipLoop(L) || DisableTailPredication) 125353940Sdim return false; 126353940Sdim 127353940Sdim Function &F = *L->getHeader()->getParent(); 128353940Sdim auto &TPC = getAnalysis<TargetPassConfig>(); 129353940Sdim auto &TM = TPC.getTM<TargetMachine>(); 130353940Sdim auto *ST = &TM.getSubtarget<ARMSubtarget>(F); 131353940Sdim TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 132353940Sdim SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 133353940Sdim this->L = L; 134353940Sdim 135353940Sdim // The MVE and LOB extensions are combined to enable tail-predication, but 136353940Sdim // there's nothing preventing us from generating VCTP instructions for v8.1m. 137353940Sdim if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) { 138357095Sdim LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n"); 139353940Sdim return false; 140353940Sdim } 141353940Sdim 142353940Sdim BasicBlock *Preheader = L->getLoopPreheader(); 143353940Sdim if (!Preheader) 144353940Sdim return false; 145353940Sdim 146353940Sdim auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* { 147353940Sdim for (auto &I : *BB) { 148353940Sdim auto *Call = dyn_cast<IntrinsicInst>(&I); 149353940Sdim if (!Call) 150353940Sdim continue; 151353940Sdim 152353940Sdim Intrinsic::ID ID = Call->getIntrinsicID(); 153353940Sdim if (ID == Intrinsic::set_loop_iterations || 154353940Sdim ID == Intrinsic::test_set_loop_iterations) 155353940Sdim return cast<IntrinsicInst>(&I); 156353940Sdim } 157353940Sdim return nullptr; 158353940Sdim }; 159353940Sdim 160353940Sdim // Look for the hardware loop intrinsic that sets the iteration count. 161353940Sdim IntrinsicInst *Setup = FindLoopIterations(Preheader); 162353940Sdim 163357095Sdim // The test.set iteration could live in the pre-preheader. 164353940Sdim if (!Setup) { 165353940Sdim if (!Preheader->getSinglePredecessor()) 166353940Sdim return false; 167353940Sdim Setup = FindLoopIterations(Preheader->getSinglePredecessor()); 168353940Sdim if (!Setup) 169353940Sdim return false; 170353940Sdim } 171353940Sdim 172353940Sdim // Search for the hardware loop intrinic that decrements the loop counter. 173353940Sdim IntrinsicInst *Decrement = nullptr; 174353940Sdim for (auto *BB : L->getBlocks()) { 175353940Sdim for (auto &I : *BB) { 176353940Sdim if (IsDecrement(I)) { 177353940Sdim Decrement = cast<IntrinsicInst>(&I); 178353940Sdim break; 179353940Sdim } 180353940Sdim } 181353940Sdim } 182353940Sdim 183353940Sdim if (!Decrement) 184353940Sdim return false; 185353940Sdim 186357095Sdim LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n" 187353940Sdim << *Decrement << "\n"); 188357095Sdim return TryConvert(Setup->getArgOperand(0)); 189353940Sdim} 190353940Sdim 191353940Sdimbool MVETailPredication::isTailPredicate(Instruction *I, Value *NumElements) { 192353940Sdim // Look for the following: 193353940Sdim 194353940Sdim // %trip.count.minus.1 = add i32 %N, -1 195353940Sdim // %broadcast.splatinsert10 = insertelement <4 x i32> undef, 196353940Sdim // i32 %trip.count.minus.1, i32 0 197353940Sdim // %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10, 198353940Sdim // <4 x i32> undef, 199353940Sdim // <4 x i32> zeroinitializer 200353940Sdim // ... 201353940Sdim // ... 202353940Sdim // %index = phi i32 203353940Sdim // %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0 204353940Sdim // %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert, 205353940Sdim // <4 x i32> undef, 206353940Sdim // <4 x i32> zeroinitializer 207353940Sdim // %induction = add <4 x i32> %broadcast.splat, <i32 0, i32 1, i32 2, i32 3> 208353940Sdim // %pred = icmp ule <4 x i32> %induction, %broadcast.splat11 209353940Sdim 210353940Sdim // And return whether V == %pred. 211353940Sdim 212353940Sdim using namespace PatternMatch; 213353940Sdim 214353940Sdim CmpInst::Predicate Pred; 215353940Sdim Instruction *Shuffle = nullptr; 216353940Sdim Instruction *Induction = nullptr; 217353940Sdim 218353940Sdim // The vector icmp 219353940Sdim if (!match(I, m_ICmp(Pred, m_Instruction(Induction), 220353940Sdim m_Instruction(Shuffle))) || 221357095Sdim Pred != ICmpInst::ICMP_ULE) 222353940Sdim return false; 223353940Sdim 224353940Sdim // First find the stuff outside the loop which is setting up the limit 225353940Sdim // vector.... 226353940Sdim // The invariant shuffle that broadcast the limit into a vector. 227353940Sdim Instruction *Insert = nullptr; 228353940Sdim if (!match(Shuffle, m_ShuffleVector(m_Instruction(Insert), m_Undef(), 229353940Sdim m_Zero()))) 230353940Sdim return false; 231353940Sdim 232353940Sdim // Insert the limit into a vector. 233353940Sdim Instruction *BECount = nullptr; 234353940Sdim if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(BECount), 235353940Sdim m_Zero()))) 236353940Sdim return false; 237353940Sdim 238353940Sdim // The limit calculation, backedge count. 239353940Sdim Value *TripCount = nullptr; 240353940Sdim if (!match(BECount, m_Add(m_Value(TripCount), m_AllOnes()))) 241353940Sdim return false; 242353940Sdim 243357095Sdim if (TripCount != NumElements || !L->isLoopInvariant(BECount)) 244353940Sdim return false; 245353940Sdim 246353940Sdim // Now back to searching inside the loop body... 247357095Sdim // Find the add with takes the index iv and adds a constant vector to it. 248353940Sdim Instruction *BroadcastSplat = nullptr; 249353940Sdim Constant *Const = nullptr; 250353940Sdim if (!match(Induction, m_Add(m_Instruction(BroadcastSplat), 251353940Sdim m_Constant(Const)))) 252353940Sdim return false; 253353940Sdim 254353940Sdim // Check that we're adding <0, 1, 2, 3... 255353940Sdim if (auto *CDS = dyn_cast<ConstantDataSequential>(Const)) { 256353940Sdim for (unsigned i = 0; i < CDS->getNumElements(); ++i) { 257353940Sdim if (CDS->getElementAsInteger(i) != i) 258353940Sdim return false; 259353940Sdim } 260353940Sdim } else 261353940Sdim return false; 262353940Sdim 263353940Sdim // The shuffle which broadcasts the index iv into a vector. 264353940Sdim if (!match(BroadcastSplat, m_ShuffleVector(m_Instruction(Insert), m_Undef(), 265353940Sdim m_Zero()))) 266353940Sdim return false; 267353940Sdim 268353940Sdim // The insert element which initialises a vector with the index iv. 269353940Sdim Instruction *IV = nullptr; 270353940Sdim if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(IV), m_Zero()))) 271353940Sdim return false; 272353940Sdim 273353940Sdim // The index iv. 274353940Sdim auto *Phi = dyn_cast<PHINode>(IV); 275353940Sdim if (!Phi) 276353940Sdim return false; 277353940Sdim 278353940Sdim // TODO: Don't think we need to check the entry value. 279353940Sdim Value *OnEntry = Phi->getIncomingValueForBlock(L->getLoopPreheader()); 280353940Sdim if (!match(OnEntry, m_Zero())) 281353940Sdim return false; 282357095Sdim 283353940Sdim Value *InLoop = Phi->getIncomingValueForBlock(L->getLoopLatch()); 284353940Sdim unsigned Lanes = cast<VectorType>(Insert->getType())->getNumElements(); 285353940Sdim 286353940Sdim Instruction *LHS = nullptr; 287353940Sdim if (!match(InLoop, m_Add(m_Instruction(LHS), m_SpecificInt(Lanes)))) 288353940Sdim return false; 289357095Sdim 290353940Sdim return LHS == Phi; 291353940Sdim} 292353940Sdim 293353940Sdimstatic VectorType* getVectorType(IntrinsicInst *I) { 294353940Sdim unsigned TypeOp = I->getIntrinsicID() == Intrinsic::masked_load ? 0 : 1; 295353940Sdim auto *PtrTy = cast<PointerType>(I->getOperand(TypeOp)->getType()); 296353940Sdim return cast<VectorType>(PtrTy->getElementType()); 297353940Sdim} 298353940Sdim 299353940Sdimbool MVETailPredication::IsPredicatedVectorLoop() { 300353940Sdim // Check that the loop contains at least one masked load/store intrinsic. 301353940Sdim // We only support 'normal' vector instructions - other than masked 302353940Sdim // load/stores. 303353940Sdim for (auto *BB : L->getBlocks()) { 304353940Sdim for (auto &I : *BB) { 305353940Sdim if (IsMasked(&I)) { 306353940Sdim VectorType *VecTy = getVectorType(cast<IntrinsicInst>(&I)); 307353940Sdim unsigned Lanes = VecTy->getNumElements(); 308353940Sdim unsigned ElementWidth = VecTy->getScalarSizeInBits(); 309353940Sdim // MVE vectors are 128-bit, but don't support 128 x i1. 310353940Sdim // TODO: Can we support vectors larger than 128-bits? 311357095Sdim unsigned MaxWidth = TTI->getRegisterBitWidth(true); 312357095Sdim if (Lanes * ElementWidth > MaxWidth || Lanes == MaxWidth) 313353940Sdim return false; 314353940Sdim MaskedInsts.push_back(cast<IntrinsicInst>(&I)); 315353940Sdim } else if (auto *Int = dyn_cast<IntrinsicInst>(&I)) { 316353940Sdim for (auto &U : Int->args()) { 317353940Sdim if (isa<VectorType>(U->getType())) 318353940Sdim return false; 319353940Sdim } 320353940Sdim } 321353940Sdim } 322353940Sdim } 323353940Sdim 324353940Sdim return !MaskedInsts.empty(); 325353940Sdim} 326353940Sdim 327353940SdimValue* MVETailPredication::ComputeElements(Value *TripCount, 328353940Sdim VectorType *VecTy) { 329353940Sdim const SCEV *TripCountSE = SE->getSCEV(TripCount); 330353940Sdim ConstantInt *VF = ConstantInt::get(cast<IntegerType>(TripCount->getType()), 331353940Sdim VecTy->getNumElements()); 332353940Sdim 333353940Sdim if (VF->equalsInt(1)) 334353940Sdim return nullptr; 335353940Sdim 336353940Sdim // TODO: Support constant trip counts. 337353940Sdim auto VisitAdd = [&](const SCEVAddExpr *S) -> const SCEVMulExpr* { 338353940Sdim if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) { 339353940Sdim if (Const->getAPInt() != -VF->getValue()) 340353940Sdim return nullptr; 341353940Sdim } else 342353940Sdim return nullptr; 343353940Sdim return dyn_cast<SCEVMulExpr>(S->getOperand(1)); 344353940Sdim }; 345353940Sdim 346353940Sdim auto VisitMul = [&](const SCEVMulExpr *S) -> const SCEVUDivExpr* { 347353940Sdim if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) { 348353940Sdim if (Const->getValue() != VF) 349353940Sdim return nullptr; 350353940Sdim } else 351353940Sdim return nullptr; 352353940Sdim return dyn_cast<SCEVUDivExpr>(S->getOperand(1)); 353353940Sdim }; 354353940Sdim 355353940Sdim auto VisitDiv = [&](const SCEVUDivExpr *S) -> const SCEV* { 356353940Sdim if (auto *Const = dyn_cast<SCEVConstant>(S->getRHS())) { 357353940Sdim if (Const->getValue() != VF) 358353940Sdim return nullptr; 359353940Sdim } else 360353940Sdim return nullptr; 361353940Sdim 362353940Sdim if (auto *RoundUp = dyn_cast<SCEVAddExpr>(S->getLHS())) { 363353940Sdim if (auto *Const = dyn_cast<SCEVConstant>(RoundUp->getOperand(0))) { 364353940Sdim if (Const->getAPInt() != (VF->getValue() - 1)) 365353940Sdim return nullptr; 366353940Sdim } else 367353940Sdim return nullptr; 368353940Sdim 369353940Sdim return RoundUp->getOperand(1); 370353940Sdim } 371353940Sdim return nullptr; 372353940Sdim }; 373353940Sdim 374353940Sdim // TODO: Can we use SCEV helpers, such as findArrayDimensions, and friends to 375353940Sdim // determine the numbers of elements instead? Looks like this is what is used 376353940Sdim // for delinearization, but I'm not sure if it can be applied to the 377353940Sdim // vectorized form - at least not without a bit more work than I feel 378353940Sdim // comfortable with. 379353940Sdim 380353940Sdim // Search for Elems in the following SCEV: 381353940Sdim // (1 + ((-VF + (VF * (((VF - 1) + %Elems) /u VF))<nuw>) /u VF))<nuw><nsw> 382353940Sdim const SCEV *Elems = nullptr; 383353940Sdim if (auto *TC = dyn_cast<SCEVAddExpr>(TripCountSE)) 384353940Sdim if (auto *Div = dyn_cast<SCEVUDivExpr>(TC->getOperand(1))) 385353940Sdim if (auto *Add = dyn_cast<SCEVAddExpr>(Div->getLHS())) 386353940Sdim if (auto *Mul = VisitAdd(Add)) 387353940Sdim if (auto *Div = VisitMul(Mul)) 388353940Sdim if (auto *Res = VisitDiv(Div)) 389353940Sdim Elems = Res; 390353940Sdim 391353940Sdim if (!Elems) 392353940Sdim return nullptr; 393353940Sdim 394353940Sdim Instruction *InsertPt = L->getLoopPreheader()->getTerminator(); 395353940Sdim if (!isSafeToExpandAt(Elems, InsertPt, *SE)) 396353940Sdim return nullptr; 397353940Sdim 398353940Sdim auto DL = L->getHeader()->getModule()->getDataLayout(); 399353940Sdim SCEVExpander Expander(*SE, DL, "elements"); 400353940Sdim return Expander.expandCodeFor(Elems, Elems->getType(), InsertPt); 401353940Sdim} 402353940Sdim 403353940Sdim// Look through the exit block to see whether there's a duplicate predicate 404353940Sdim// instruction. This can happen when we need to perform a select on values 405353940Sdim// from the last and previous iteration. Instead of doing a straight 406353940Sdim// replacement of that predicate with the vctp, clone the vctp and place it 407353940Sdim// in the block. This means that the VPR doesn't have to be live into the 408353940Sdim// exit block which should make it easier to convert this loop into a proper 409353940Sdim// tail predicated loop. 410353940Sdimstatic void Cleanup(DenseMap<Instruction*, Instruction*> &NewPredicates, 411353940Sdim SetVector<Instruction*> &MaybeDead, Loop *L) { 412357095Sdim BasicBlock *Exit = L->getUniqueExitBlock(); 413357095Sdim if (!Exit) { 414357095Sdim LLVM_DEBUG(dbgs() << "ARM TP: can't find loop exit block\n"); 415357095Sdim return; 416357095Sdim } 417353940Sdim 418357095Sdim for (auto &Pair : NewPredicates) { 419357095Sdim Instruction *OldPred = Pair.first; 420357095Sdim Instruction *NewPred = Pair.second; 421357095Sdim 422357095Sdim for (auto &I : *Exit) { 423357095Sdim if (I.isSameOperationAs(OldPred)) { 424357095Sdim Instruction *PredClone = NewPred->clone(); 425357095Sdim PredClone->insertBefore(&I); 426357095Sdim I.replaceAllUsesWith(PredClone); 427357095Sdim MaybeDead.insert(&I); 428357095Sdim LLVM_DEBUG(dbgs() << "ARM TP: replacing: "; I.dump(); 429357095Sdim dbgs() << "ARM TP: with: "; PredClone->dump()); 430357095Sdim break; 431353940Sdim } 432353940Sdim } 433353940Sdim } 434353940Sdim 435353940Sdim // Drop references and add operands to check for dead. 436353940Sdim SmallPtrSet<Instruction*, 4> Dead; 437353940Sdim while (!MaybeDead.empty()) { 438353940Sdim auto *I = MaybeDead.front(); 439353940Sdim MaybeDead.remove(I); 440353940Sdim if (I->hasNUsesOrMore(1)) 441353940Sdim continue; 442353940Sdim 443353940Sdim for (auto &U : I->operands()) { 444353940Sdim if (auto *OpI = dyn_cast<Instruction>(U)) 445353940Sdim MaybeDead.insert(OpI); 446353940Sdim } 447353940Sdim I->dropAllReferences(); 448353940Sdim Dead.insert(I); 449353940Sdim } 450353940Sdim 451357095Sdim for (auto *I : Dead) { 452357095Sdim LLVM_DEBUG(dbgs() << "ARM TP: removing dead insn: "; I->dump()); 453353940Sdim I->eraseFromParent(); 454357095Sdim } 455353940Sdim 456353940Sdim for (auto I : L->blocks()) 457353940Sdim DeleteDeadPHIs(I); 458353940Sdim} 459353940Sdim 460357095Sdimvoid MVETailPredication::InsertVCTPIntrinsic(Instruction *Predicate, 461357095Sdim DenseMap<Instruction*, Instruction*> &NewPredicates, 462357095Sdim VectorType *VecTy, Value *NumElements) { 463357095Sdim IRBuilder<> Builder(L->getHeader()->getFirstNonPHI()); 464357095Sdim Module *M = L->getHeader()->getModule(); 465357095Sdim Type *Ty = IntegerType::get(M->getContext(), 32); 466357095Sdim 467357095Sdim // Insert a phi to count the number of elements processed by the loop. 468357095Sdim PHINode *Processed = Builder.CreatePHI(Ty, 2); 469357095Sdim Processed->addIncoming(NumElements, L->getLoopPreheader()); 470357095Sdim 471357095Sdim // Insert the intrinsic to represent the effect of tail predication. 472357095Sdim Builder.SetInsertPoint(cast<Instruction>(Predicate)); 473357095Sdim ConstantInt *Factor = 474357095Sdim ConstantInt::get(cast<IntegerType>(Ty), VecTy->getNumElements()); 475357095Sdim 476357095Sdim Intrinsic::ID VCTPID; 477357095Sdim switch (VecTy->getNumElements()) { 478357095Sdim default: 479357095Sdim llvm_unreachable("unexpected number of lanes"); 480357095Sdim case 4: VCTPID = Intrinsic::arm_mve_vctp32; break; 481357095Sdim case 8: VCTPID = Intrinsic::arm_mve_vctp16; break; 482357095Sdim case 16: VCTPID = Intrinsic::arm_mve_vctp8; break; 483357095Sdim 484357095Sdim // FIXME: vctp64 currently not supported because the predicate 485357095Sdim // vector wants to be <2 x i1>, but v2i1 is not a legal MVE 486357095Sdim // type, so problems happen at isel time. 487357095Sdim // Intrinsic::arm_mve_vctp64 exists for ACLE intrinsics 488357095Sdim // purposes, but takes a v4i1 instead of a v2i1. 489357095Sdim } 490357095Sdim Function *VCTP = Intrinsic::getDeclaration(M, VCTPID); 491357095Sdim Value *TailPredicate = Builder.CreateCall(VCTP, Processed); 492357095Sdim Predicate->replaceAllUsesWith(TailPredicate); 493357095Sdim NewPredicates[Predicate] = cast<Instruction>(TailPredicate); 494357095Sdim 495357095Sdim // Add the incoming value to the new phi. 496357095Sdim // TODO: This add likely already exists in the loop. 497357095Sdim Value *Remaining = Builder.CreateSub(Processed, Factor); 498357095Sdim Processed->addIncoming(Remaining, L->getLoopLatch()); 499357095Sdim LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: " 500357095Sdim << *Processed << "\n" 501357095Sdim << "ARM TP: Inserted VCTP: " << *TailPredicate << "\n"); 502357095Sdim} 503357095Sdim 504353940Sdimbool MVETailPredication::TryConvert(Value *TripCount) { 505357095Sdim if (!IsPredicatedVectorLoop()) { 506357095Sdim LLVM_DEBUG(dbgs() << "ARM TP: no masked instructions in loop"); 507353940Sdim return false; 508357095Sdim } 509353940Sdim 510357095Sdim LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n"); 511353940Sdim 512353940Sdim // Walk through the masked intrinsics and try to find whether the predicate 513353940Sdim // operand is generated from an induction variable. 514353940Sdim SetVector<Instruction*> Predicates; 515353940Sdim DenseMap<Instruction*, Instruction*> NewPredicates; 516353940Sdim 517353940Sdim for (auto *I : MaskedInsts) { 518353940Sdim Intrinsic::ID ID = I->getIntrinsicID(); 519353940Sdim unsigned PredOp = ID == Intrinsic::masked_load ? 2 : 3; 520353940Sdim auto *Predicate = dyn_cast<Instruction>(I->getArgOperand(PredOp)); 521353940Sdim if (!Predicate || Predicates.count(Predicate)) 522353940Sdim continue; 523353940Sdim 524353940Sdim VectorType *VecTy = getVectorType(I); 525353940Sdim Value *NumElements = ComputeElements(TripCount, VecTy); 526353940Sdim if (!NumElements) 527353940Sdim continue; 528353940Sdim 529353940Sdim if (!isTailPredicate(Predicate, NumElements)) { 530357095Sdim LLVM_DEBUG(dbgs() << "ARM TP: Not tail predicate: " << *Predicate << "\n"); 531353940Sdim continue; 532353940Sdim } 533353940Sdim 534357095Sdim LLVM_DEBUG(dbgs() << "ARM TP: Found tail predicate: " << *Predicate << "\n"); 535353940Sdim Predicates.insert(Predicate); 536353940Sdim 537357095Sdim InsertVCTPIntrinsic(Predicate, NewPredicates, VecTy, NumElements); 538353940Sdim } 539353940Sdim 540353940Sdim // Now clean up. 541353940Sdim Cleanup(NewPredicates, Predicates, L); 542353940Sdim return true; 543353940Sdim} 544353940Sdim 545353940SdimPass *llvm::createMVETailPredicationPass() { 546353940Sdim return new MVETailPredication(); 547353940Sdim} 548353940Sdim 549353940Sdimchar MVETailPredication::ID = 0; 550353940Sdim 551353940SdimINITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false) 552353940SdimINITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false) 553