//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one by one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};

} // end anonymous namespace

char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}

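// Check whether the mask is a compile-time constant: every lane must be a
// ConstantInt so that each element's condition can be evaluated statically.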
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = Mask->getType()->getVectorNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one by one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, storing the elements one by one if
// the appropriate mask bit is set
//
//   %1 = bitcast i8* %addr to i32*
//   %2 = extractelement <16 x i1> %mask, i32 0
//   br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//   %3 = extractelement <16 x i32> %val, i32 0
//   %4 = getelementptr i32* %1, i32 0
//   store i32 %3, i32* %4
//   br label %else
//
// else:                                             ; preds = %0, %cond.store
//   %5 = extractelement <16 x i1> %mask, i32 1
//   br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//   %6 = extractelement <16 x i32> %val, i32 1
//   %7 = getelementptr i32* %1, i32 1
//   store i32 %6, i32* %7
//   br label %else2
//   . . .
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                       <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one by one if
// the appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = cast<VectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, storing the elements one by one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
//   . . .
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

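// Translate a masked expandload intrinsic like
// <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
//                                           <16 x i32> %passthru)
// to a chain of basic blocks that load the elements from consecutive memory
// locations one by one, advancing the pointer only past lanes whose mask bit
// is set, e.g.
//
// %Mask0 = extractelement <16 x i1> %mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Load0 = load i32, i32* %ptr, align 1
// %Res0 = insertelement <16 x i32> %passthru, i32 %Load0, i32 0
// %NextPtr = getelementptr i32, i32* %ptr, i32 1
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ %passthru, %0 ]
// %ptr.phi.else = phi i32* [ %NextPtr, %cond.load ], [ %ptr, %0 ]
// %Mask1 = extractelement <16 x i1> %mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
// . . .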
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, NewPtr, 1, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
      ++MemIndex;
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr = nullptr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

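// Translate a masked compressstore intrinsic, like
// void @llvm.masked.compressstore.v16i32(<16 x i32> %value, i32* %ptr,
//                                        <16 x i1> %mask)
// to a chain of basic blocks that store the selected elements to consecutive
// memory locations one by one, advancing the pointer only past lanes whose
// mask bit is set, e.g.
//
// %Mask0 = extractelement <16 x i1> %mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %value, i32 0
// store i32 %Elt0, i32* %ptr, align 1
// %NextPtr = getelementptr i32, i32* %ptr, i32 1
// br label %else
//
// else:
// %ptr.phi.else = phi i32* [ %NextPtr, %cond.store ], [ %ptr, %0 ]
// %Mask1 = extractelement <16 x i1> %mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
// . . .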
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getVectorElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, 1);
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, 1);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr = nullptr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}

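// Walk the block and try to scalarize each masked memory intrinsic call.
// Returns true if the block was modified.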
bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

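// If CI is a masked memory intrinsic that the target cannot handle natively,
// replace it with equivalent scalarized IR and return true. ModifiedDT is set
// by the scalarize helpers when new basic blocks were inserted.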
bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    unsigned Alignment;
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load: {
      // Scalarize unsupported vector masked load
      Alignment = cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      if (TTI->isLegalMaskedLoad(CI->getType(), MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_store: {
      Alignment = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType(),
                                  MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_gather:
      Alignment = cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      if (TTI->isLegalMaskedGather(CI->getType(), MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_scatter:
      Alignment = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType(),
                                    MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_expandload:
      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}