//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that handle the elements one by one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};

} // end anonymous namespace

char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}

static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = Mask->getType()->getVectorNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one by one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                            ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, storing the elements one by one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//  %3 = extractelement <16 x i32> %val, i32 0
//  %4 = getelementptr i32* %1, i32 0
//  store i32 %3, i32* %4
//  br label %else
//
// else:                                             ; preds = %0, %cond.store
//  %5 = extractelement <16 x i1> %mask, i32 1
//  br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//  %6 = extractelement <16 x i32> %val, i32 1
//  %7 = getelementptr i32* %1, i32 1
//  store i32 %6, i32* %7
//  br label %else2
//  . . .
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                       <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one by one if
// the appropriate mask bit is set
//
//  %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  %Load0 = load i32, i32* %Ptr0, align 4
//  %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
//  br label %else
//
// else:
//  %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ undef, %0 ]
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  %Load1 = load i32, i32* %Ptr1, align 4
//  %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
//  br label %else2
//  . . .
//  %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
//  ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = cast<VectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Take a shortcut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, storing the elements one by one if
// the appropriate mask bit is set.
//
//  %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
//  %Elt0 = extractelement <16 x i32> %Src, i32 0
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  store i32 %Elt0, i32* %Ptr0, align 4
//  br label %else
//
// else:
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
//  %Elt1 = extractelement <16 x i32> %Src, i32 1
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  store i32 %Elt1, i32* %Ptr1, align 4
//  br label %else2
//  . . .
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Take a shortcut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

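// Translate a masked expandload intrinsic, like
// <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
//                                           <16 x i32> %passthru)
// to a chain of basic blocks reading from consecutive memory locations: each
// set mask bit loads the next element from %ptr and inserts it at the
// corresponding vector position; disabled lanes keep the %passthru value and
// do not advance the pointer.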
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Take a shortcut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, NewPtr, 1, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
      ++MemIndex;
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr = nullptr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

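// Translate a masked compressstore intrinsic, like
// void @llvm.masked.compressstore.v16i32(<16 x i32> %value, i32* %ptr,
//                                        <16 x i1> %mask)
// to a chain of basic blocks writing the enabled lanes to consecutive memory
// locations: each set mask bit stores the corresponding vector element to
// %ptr and advances the pointer; disabled lanes are skipped and do not
// advance the pointer.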
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getVectorElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Take a shortcut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, 1);
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, 1);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr = nullptr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was changed.
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    unsigned Alignment;
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load: {
      // Scalarize unsupported vector masked load.
      Alignment = cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      if (TTI->isLegalMaskedLoad(CI->getType(), MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_store: {
      Alignment = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType(),
                                  MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_gather:
      Alignment = cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      if (TTI->isLegalMaskedGather(CI->getType(), MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_scatter:
      Alignment = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType(),
                                    MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_expandload:
      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}