//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcasted from vector
/// <256 x i32> and only AMX intrinsics can operate on the type, we need to
/// transform load/store <256 x i32> instructions to AMX load/store. If the
/// bitcast cannot be combined with a load/store, we transform the bitcast to
/// an AMX load/store and a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end does (e.g.
/// "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX
/// data is volatile, because that is necessary for AMX fast register
/// allocation. (In fast register allocation, registers are allocated before
/// spill/reload, so there is no additional register for AMX to identify the
/// step in spill.)
/// The volatileTileData() function handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will be transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"

#include <map>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}

static bool isAMXIntrinsic(Value *I) {
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (!II)
    return false;
  if (isAMXCast(II))
    return false;
  // Check if the return type or any parameter is x86_amx. If so, the
  // intrinsic must be an x86 AMX intrinsic.
  if (II->getType()->isX86_AMXTy())
    return true;
  for (Value *V : II->args()) {
    if (V->getType()->isX86_AMXTy())
      return true;
  }

  return false;
}

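// Create an alloca of type Ty in the entry block of the function containing
// BB, aligned to the preferred alignment of x86_amx so the slot can back AMX
// tile loads and stores.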
static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}

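// Return the {Row, Col} shape of the tile used as operand OpNo of the AMX
// intrinsic II. For the tile load/store intrinsics the shape is simply the
// first two i16 arguments. For the dot-product style intrinsics (a * b + c,
// with operands 3/4/5 being c/a/b), the shape is derived from the first
// three i16 arguments: operand 3 uses {arg0, arg1}, operand 4 uses
// {arg0, arg2}, and operand 5 uses {arg2 / 4, arg1}; the division is
// materialized at a point that dominates its prospective users.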
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  IRBuilder<> Builder(II);
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand is queried.
  case Intrinsic::x86_tcmmimfp16ps_internal:
  case Intrinsic::x86_tcmmrlfp16ps_internal:
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal:
  case Intrinsic::x86_tdpfp16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      if (isa<ConstantInt>(II->getArgOperand(2)))
        Row = Builder.getInt16(
            (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
      else if (isa<Instruction>(II->getArgOperand(2))) {
        // When it is not a const value and it is not a function argument, we
        // create Row after the definition of II->getOperand(2) instead of
        // before II. For example, II is %118 and we try to get the shape for
        // %117:
        //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
        //   i32> %115).
        //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
        //   %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
        //   %117).
        // If we created %row = udiv i16 %106, 4 before %118 (aka II), then its
        // definition would come after its user (the new tileload for %117).
        // So the best choice is to create %row right after the definition of
        // %106.
        Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
        Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
      } else {
        // When it is not a const value and it is a function argument, we
        // create Row at the entry bb.
        IRBuilder<> NewBuilder(
            getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
        Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
      }
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

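// Try to deduce the shape of the tile bound to Phi by following the chain of
// first users through AMX casts and PHI nodes until an AMX intrinsic is
// reached. Returns {nullptr, nullptr} if no shape can be found.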
static std::pair<Value *, Value *> getShape(PHINode *Phi) {
  Use &U = *(Phi->use_begin());
  unsigned OpNo = U.getOperandNo();
  User *V = U.getUser();
  // TODO: We don't traverse all users. To keep the algorithm simple, we only
  // follow the first user. If we can find the shape, return it; otherwise
  // return nullptr and the optimization for undef/zero will be abandoned.
  while (V) {
    if (isAMXCast(dyn_cast<Instruction>(V))) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      OpNo = U.getOperandNo();
      V = U.getUser();
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<PHINode>(V)) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      V = U.getUser();
    } else {
      break;
    }
  }

  return std::make_pair(nullptr, nullptr);
}

namespace {
class X86LowerAMXType {
  Function &Func;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
};

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
// i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as the stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = LD->getOperand(0);
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as the stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = ST->getOperand(1);
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform a bitcast into <store, load> instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = AllocaAddr;
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Bitcast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, duplicate the vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has a single user, the load will be eliminated in DAG
        // ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (Use &U : Bitcast->uses()) {
          ST = dyn_cast<StoreInst>(U.getUser());
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store into
        // an AMX store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete the user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

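// Allocate a <256 x i32> stack slot in the entry block of the function that
// contains BB and return its address. This slot backs the volatile tile
// stores and loads created below.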
static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
  return I8Ptr;
}

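// Emit a tilestored64 right after the tile definition TileDef, storing the
// tile to Ptr with the shape taken from the defining intrinsic.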
static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore = Builder.CreateIntrinsic(
      Intrinsic::x86_tilestored64_internal, std::nullopt, Args);
  return TileStore;
}

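// Rewrite the use U so that it reads the tile from Ptr via a tileloadd64
// inserted right before the user. If IsPHI is set, the used value is a PHI
// node and the shape is taken from its first incoming value.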
static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                            std::nullopt, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

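// Return true if I is used as an incoming value of any PHI node.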
static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Make all AMX tile data volatile, which shortens the live range of each
// tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *PHI);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored memory.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tileload.
// 2) The PHI incoming values should be tilestored right after their defs.
// 3) The memory of these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                  *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
//   use %t0`                                               *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before each use.
// None of its users is a PHI.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored memory.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All uses of tile data come from a tileload right before the use.
// 2) All defs of tile data are tilestored to memory immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call, or other AMX instructions appear in the key AMX
//    area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-PHI-related AMX intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXCast {
  Function &Func;
  std::unique_ptr<DominatorTree> DT;

public:
  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
  bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
  bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
  bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
  bool combineAMXcast(TargetLibraryInfo *TLI);
  bool transformAMXCast(IntrinsicInst *AMXCast);
  bool transformAllAMXCast();
  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
                              SmallSetVector<Instruction *, 16> &DeadInst);
};

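// Erase I if it is trivially dead, and queue any of its operands that become
// trivially dead as a result onto WorkList for later deletion.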
static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {
  if (isInstructionTriviallyDead(I, TLI)) {
    salvageDebugInfo(*I);
    salvageKnowledge(I);

    // Null out all of the instruction's operands to see if any operand becomes
    // dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);

      if (!OpV->use_empty() || I == OpV)
        continue;

      // If the operand is an instruction that became dead as we nulled out the
      // operand, and if it is 'trivially' dead, delete it in a future loop
      // iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        if (isInstructionTriviallyDead(OpI, TLI)) {
          WorkList.insert(OpI);
        }
      }
    }
    I->eraseFromParent();
    return true;
  }
  return false;
}

/// This function handles the following case:
///
///     A  ->  B    amxcast
///     PHI
///     B  ->  A    amxcast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
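///
/// e.g. (illustrative IR; value names are made up):
///   bb1: %v1 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
///   bb2: %v2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t2)
///   bb3: %p = phi <256 x i32> [ %v1, %bb1 ], [ %v2, %bb2 ]
///        %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %p)
/// After the transformation, a new PHI over x86_amx replaces the cast:
///   bb3: %p2 = phi x86_amx [ %t1, %bb1 ], [ %t2, %bb2 ]
/// and the casts that became dead are queued into \p DeadInst.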
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType(); // Type B
  Type *DestTy = CI->getType(); // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes; before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      // TODO: Currently we only handle incoming constants that are undef or
      // zero. In the future, we might support other constants.
      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
          return false;
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);
        // TODO: If they are not constants, Row and Col must dominate the
        // tilezero that we are going to create.
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;
        // Create tilezero at the end of incoming block.
        auto *Block = OldPN->getIncomingBlock(I);
        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
        Instruction *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, std::nullopt, {Row, Col});
        NewInst->moveBefore(&*Iter);
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        NewInst->moveBefore(&*Iter);
        // Replace IncValue with the new value.
        OldPN->setIncomingValue(I, NewInst);
        IncValue = NewInst;
      }

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a B->A cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        // example:
        //   bb.0:
        //      %0 = amxcast ...
        //   bb.1:
        //      %1 = amxcast ...
        //   bb.2:
        //      %goodphi = phi %0, %1
        //      %3 = amxcast %goodphi
        //   bb.3:
        //      %goodphi2 = phi %0, %goodphi
        //      %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else
        return false;
    }
  }

  // For each old PHI node, create a corresponding new PHI node with type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXcast from a constant.
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Traverse all accumulated PHI nodes and process their users, which are
  // stores and bitcasts. Without this processing, new PHI nodes could be
  // replicated and could lead to extra moves generated after DeSSA.
  // If there is a store with type B, change it to type A.

  // Replace users of bitcast B->A with NewPHI. This will help later to get
  // rid of a closure formed by the old PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void)TyA;
        (void)TyB;
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push the PHINode into DeadInst since it is an
        // operand of rootPN; DCE can safely delete rootPN's operands if
        // rootPN is dead.
        assert(OldPhiNodes.contains(PHI));
        (void)PHI;
      } else
        llvm_unreachable("all uses should be handled");
    }
  }
  return true;
}

// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
// store <256 x i32> %43, <256 x i32>* %p, align 64
// -->
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                           i64 64, x86_amx %42)
bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
  Value *Tile = Cast->getOperand(0);
  // TODO: If it is cast intrinsic or phi node, we can propagate the
  // shape information through def-use chain.
  if (!isAMXIntrinsic(Tile))
    return false;
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Stride should be equal to Col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  return true;
}

// %65 = load <256 x i32>, <256 x i32>* %p, align 64
// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
// -->
// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                   i8* %p, i64 64)
bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
  bool EraseLoad = true;
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Cast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  // TODO: If it is cast intrinsic or phi node, we can propagate the
  // shape information through def-use chain.
  if (!isAMXIntrinsic(II))
    return false;
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(LD);
  // Stride should be equal to Col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr;

  // To save compile time, we create the dominator tree only when it is really
  // needed.
  if (!DT)
    DT.reset(new DominatorTree(Func));
  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
    // Store the value to the stack and reload it from the stack before the
    // cast.
    auto *AllocaAddr =
        createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
    Builder.SetInsertPoint(&*std::next(LD->getIterator()));
    Builder.CreateStore(LD, AllocaAddr);

    Builder.SetInsertPoint(Cast);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    EraseLoad = false;
  } else {
    I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
  }
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Cast->replaceAllUsesWith(NewInst);

  return EraseLoad;
}

bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
  bool Change = false;
  for (auto *Cast : Casts) {
    auto *II = cast<IntrinsicInst>(Cast);
    // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
    // store <256 x i32> %43, <256 x i32>* %p, align 64
    // -->
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
    //                                           i64 64, x86_amx %42)
    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
      SmallVector<Instruction *, 2> DeadStores;
      for (User *U : Cast->users()) {
        StoreInst *Store = dyn_cast<StoreInst>(U);
        if (!Store)
          continue;
        if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
          DeadStores.push_back(Store);
          Change = true;
        }
      }
      for (auto *Store : DeadStores)
        Store->eraseFromParent();
    } else { // x86_cast_vector_to_tile
      SmallVector<Instruction *, 2> DeadLoads;
      auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
      if (!Load || !Load->hasOneUse())
        continue;
      // %65 = load <256 x i32>, <256 x i32>* %p, align 64
      // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
      // -->
      // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
      //                                                   i8* %p, i64 64)
      if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
        // Set the operand to null so that the load instruction can be erased.
        Cast->setOperand(0, nullptr);
        Load->eraseFromParent();
      }
    }
  }
  return Change;
}

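// Combine AMX cast intrinsics: fold vec2tile/tile2vec round trips, erase
// casts that became dead, merge the remaining casts with adjacent vector
// loads/stores, optimize casts whose source is a PHI node, and finally run a
// small DCE over everything that became dead in the process.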
bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      for (User *U : Inst->users()) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
        if (!II || II->getIntrinsicID() != IID)
          continue;
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V2
        // -->
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V0
        II->replaceAllUsesWith(Inst->getOperand(0));
        Change = true;
      }
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  SmallVector<Instruction *, 8> LiveCasts;
  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      } else {
        LiveCasts.push_back(Inst);
      }
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "Vec2Tile and Tile2Vec:\n";
             Func.dump());
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "AMXCast and load/store:\n";
             Func.dump());

  // Handle the A->B->A cast when there is an intervening PHI node.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // Skip dead AMXcasts.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Since we create new PHIs and merge AMXcasts, some old PHIs and AMXcasts
  // might have no uses. We do some dead code elimination for them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
                       "optimizeAMXCastFromPhi:\n";
             Func.dump());
  return Change;
}

// There might be remaining AMXcasts after combineAMXcast and they should be
// handled elegantly.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // %2 = amxcast <225 x i32> %src to x86_amx
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    // -->
    // %addr = alloca <225 x i32>, align 64
    // store <225 x i32> %src, <225 x i32>* %addr, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
    //                                                  i8* %addr2,
    //                                                  i64 60)
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    if (AMXCast->use_empty()) {
      AMXCast->eraseFromParent();
      return true;
    }
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // %2 = amxcast x86_amx %src to <225 x i32>
    // -->
    // %addr = alloca <225 x i32>, align 64
    // %addr2 = bitcast <225 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(AMXCast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }

  return true;
}

bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    bool C = false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMXcasts after combineAMXcast and they should
    // be handled elegantly.
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);
    C |= LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of the AMX code, not
    // just Attribute::OptimizeNone and CodeGenOptLevel::None.
    if (TM->getOptLevel() == CodeGenOptLevel::None) {
      // If the front end does not use O0 but the mid/back end does (e.g.
      // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure
      // the AMX data is volatile; that is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}