//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics, when unsupported by the
// target, with a chain of basic blocks that handles the elements one-by-one
// if the appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}

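// Return true if Mask is a fixed-width vector constant whose lanes are all
// ConstantInts, i.e. the mask is known, lane by lane, at compile time.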
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

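// Map a mask element index to the corresponding bit of the mask once it has
// been bitcast to an integer: element 0 of an <N x i1> vector lands in the
// least significant bit on little-endian targets but in the most significant
// bit on big-endian targets, so the index must be flipped there.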
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
                                unsigned Idx) {
  return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
}

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
                                DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, (1 << Idx)
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, storing the elements one-by-one if
// the appropriate mask bit is set
//
//   %1 = bitcast i8* %addr to i32*
//   %2 = extractelement <16 x i1> %mask, i32 0
//   br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//   %3 = extractelement <16 x i32> %val, i32 0
//   %4 = getelementptr i32* %1, i32 0
//   store i32 %3, i32* %4
//   br label %else
//
// else:                                             ; preds = %0, %cond.store
//   %5 = extractelement <16 x i1> %mask, i32 1
//   br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//   %6 = extractelement <16 x i32> %val, i32 1
//   %7 = getelementptr i32* %1, i32 1
//   store i32 %6, i32* %7
//   br label %else2
//   . . .
static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, (1 << Idx)
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                       <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one-by-one if
// the appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, (1 << Idx)
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, storing the elements one-by-one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
//   . . .
static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, (1 << Idx)
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

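// Translate a masked expandload intrinsic, like
// <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
//                                           <16 x i32> %passthru)
// to a chain of basic blocks, like the masked load above, except that the
// pointer only advances past an element after that element has actually been
// loaded, so the enabled lanes read consecutive memory locations. An
// illustrative sketch of the first step (value names are not actual pass
// output):
//
//  %Mask0 = extractelement <16 x i1> %mask, i32 0
//  br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
//  %Load0 = load i32, i32* %ptr, align 1
//  %Res0 = insertelement <16 x i32> %passthru, i32 %Load0, i32 0
//  %NextPtr = getelementptr i32, i32* %ptr, i32 1
//  br label %else
//
// else:
//  %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ %passthru, %0 ]
//  %ptr.phi.else = phi i32* [ %NextPtr, %cond.load ], [ %ptr, %0 ]
//  . . .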
static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Short-cut if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/undefs as necessary, and then
  // shuffle-blend it with the pass-through value.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = UndefValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = UndefValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
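        // The expandload intrinsic carries no alignment operand, so only
        // byte alignment can be assumed for the scalar loads.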
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr = nullptr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

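// Translate a masked compressstore intrinsic, like
// void @llvm.masked.compressstore.v16i32(<16 x i32> %value, i32* %ptr,
//                                        <16 x i1> %mask)
// to a chain of basic blocks, like the masked store above, except that the
// enabled elements are packed into consecutive memory locations: the store
// pointer only advances after a store has actually happened. (Sketch mirrors
// the expandload example above, with loads replaced by stores.)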
static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
                                         DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, Align(1));

    // Move the pointer if there are more blocks to come.
    Value *NewPtr = nullptr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
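  // A DominatorTree is not required: only drive a DomTreeUpdater when the
  // caller actually has one; the lazy strategy batches updates until queried.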
  Optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getParent()->getDataLayout();
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration, TTI, DL,
                                  DTU.hasValue() ? DTU.getPointer() : nullptr);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
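    // Advance the iterator before visiting CI: a successful scalarization
    // erases CI (and, for non-constant masks, splits the block).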
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->arg_operands(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
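      // Scalarize unsupported vector masked store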
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      unsigned AlignmentInt =
          cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      Type *LoadTy = CI->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), LoadTy);
      if (TTI.isLegalMaskedGather(LoadTy, Alignment))
        return false;
      scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      unsigned AlignmentInt =
          cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), StoreTy);
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment))
        return false;
      scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
      return true;
    }
  }

  return false;
}