//===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// Insert a tilecfg for each area of key AMX intrinsics.
/// Every tile operand of a key AMX intrinsic must come from a tile load, and
/// the tile defined by a key AMX intrinsic must be stored by a tile store.
/// Take tdpbssd for example:
/// --------------------------------------------------------------------------
/// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...)                key
/// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...)                 |
/// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...)                amx
/// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3)         |
/// call void @llvm.x86.tilestored64.internal(... td)                     area
/// --------------------------------------------------------------------------
/// This pass inserts a tilecfg before every key AMX area, like so:
/// --------------------------------------------------------------------------
/// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
/// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
/// ...
/// ... pre-config shape of %t1                                 *
/// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
/// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
/// ...                                                         *
/// ... pre-config shape of %t2                                 * shapes
/// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     *
/// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
/// ...
/// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * tile config
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "pre-amx-config"

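// Return true if II either takes an x86_amx operand or defines an x86_amx
// value.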
static bool isAMXIntrinsic(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return true;
  return II->getType()->isX86_AMXTy();
}

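// Match both flavors of AMX tile load: tileloadd64 and the cache-hint
// variant tileloaddt164.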
static bool isTileLoad(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal ||
         II->getIntrinsicID() == Intrinsic::x86_tileloaddt164_internal;
}

static bool isTileStore(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal;
}

#ifndef NDEBUG
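// Return true if II defines a tile without consuming any tile operands, i.e.
// it is a pure tile def such as a tile load.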
static bool onlyTileDef(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return false;
  return II->getType()->isX86_AMXTy();
}

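// Conservatively treat non-intrinsic calls and terminators as instructions
// that may break the volatile AMX model this pass relies on.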
static bool brokenVolatile(Instruction *I) {
  // TODO: Identifying a normal call this way is fragile.
  if ((isa<CallInst>(I) && !isa<IntrinsicInst>(I)) || I->isTerminator())
    return true;
  return false;
}
#endif

namespace {
class X86PreAMXConfig {
  using PosAndShapesMap = MapVector<Instruction *, SmallVector<Value *, 8>>;

  Function &F;

public:
  X86PreAMXConfig(Function &Func) : F(Func) {}
  bool preTileConfig();
  void addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
  bool findConfigShapes(PosAndShapesMap &PosAndShapes);
  bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
  void preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder,
                       SmallVector<Value *, 8> &Shapes);
  BasicBlock::iterator
  getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                           SmallVector<Value *, 8> &Shapes);
  bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
                          IntrinsicInst *KeyAMX);
};

// Write the shapes into tilecfg's memory in order. This layout may not be
// final: the first shape may not correspond to the first tmm register, so it
// is fixed up in X86FastTileConfig::materializeTileCfg() after register
// allocation.
// For example:
// --------------------------------------------------------------------------
// zeroinitialize tilecfg's mem (of ldtilecfg)
// --------------------------------------------------------------------------
// ... pre-config shape of %t1                                 *
// %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48   *
// %amx.tmm.0.shape.col = getelementptr i8, i8* %mem, i64 16   *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
// ...                                                         *
// ... pre-config shape of %t2                                 *
// %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49   *
// %amx.tmm.1.shape.col = getelementptr i8, i8* %mem, i64 18   *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
// ...                                                         *
// ... pre-config shape of %t3                                 * of
// %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50   *
// %amx.tmm.2.shape.col = getelementptr i8, i8* %mem, i64 20   *
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
// ...                                                         * tiles
// ... pre-config shape of %td                                 *
// %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51   *
// %amx.tmm.3.shape.col = getelementptr i8, i8* %mem, i64 22   *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
// (each col pointer is bitcast to i16* before the store)
// --------------------------------------------------------------------------
// call void @llvm.x86.ldtilecfg(i8* %mem)                     * tile config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
void X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder,
                                      SmallVector<Value *, 8> &Shapes) {
  LLVMContext &Ctx = Builder.getContext();
  Type *I8Ty = Type::getInt8Ty(Ctx);
  Type *I16Ty = Type::getInt16Ty(Ctx);

  // TODO: We currently default to Palette = 1; other values may be assigned
  // in the future.
  Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0);
  Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
  Value *PalettePos = Builder.CreateGEP(I8Ty, I8Ptr, PaletteOffset);
  Builder.CreateStore(PaletteValue, PalettePos);

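  // In the ldtilecfg memory layout, the row count (i8) of tile I lives at
  // byte 48 + I and its column width in bytes (i16) at bytes 16 + 2 * I.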
  for (int I = 0, E = Shapes.size() / 2; I < E; I++) {
    Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I);
    Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2);
    const std::string ShapeName = "amx.tmm." + itostr(I);
    Value *RowPos = Builder.CreateGEP(I8Ty, I8Ptr, RowOffset,
                                      ShapeName + ".shape.row");
    Value *ColPos = Builder.CreateGEP(I8Ty, I8Ptr, ColOffset);
    ColPos = Builder.CreateBitCast(ColPos, PointerType::get(I16Ty, 0),
                                   ShapeName + ".shape.col");
    Value *Row = Shapes[I * 2];
    Value *Col = Shapes[I * 2 + 1];
    Row = Builder.CreateTrunc(Row, I8Ty);
    Builder.CreateStore(Row, RowPos);
    Builder.CreateStore(Col, ColPos);
  }
}

void X86PreAMXConfig::addTileConfig(Instruction *ModelStart,
                                    SmallVector<Value *, 8> &Shapes) {
  Module *M = F.getParent();
  IRBuilder<> Builder(ModelStart);
  const DataLayout &DL = M->getDataLayout();
  unsigned AddrSpace = DL.getAllocaAddrSpace();
  LLVMContext &Ctx = Builder.getContext();
  Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false);
  Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx));

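  // Create the config memory in the entry block so the alloca dominates every
  // use of the configuration.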
  AllocaInst *Addr =
      new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front());
  Addr->setAlignment(Alignment);
  Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy());

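  // Zero-initialize all 64 bytes of the config memory before writing the
  // palette and the shapes.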
  Builder.CreateAlignedStore(Constant::getNullValue(V512Ty), Addr, Alignment);

  preWriteTileCfg(I8Ptr, Builder, Shapes);

  Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, std::nullopt,
                          {I8Ptr});
}

// TODO: We may need to handle the "more than one store" case in the future.
bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
                                         IntrinsicInst *Store,
                                         IntrinsicInst *KeyAMX) {
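  // Operand 4 of tilestored64.internal is the x86_amx value being stored.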
  Value *ST = Store->getOperand(4);

  // The area has only tile loads and a tile store.
  if (!KeyAMX)
    return (Loads.size() == 1) && Loads.contains(ST);

  // All Loads should be operands of KeyAMX.
  // All tile operands of KeyAMX should come from Loads.
  for (Value *Op : KeyAMX->operands()) {
    if (Op->getType()->isX86_AMXTy())
      if (!Loads.erase(Op))
        return false;
  }

  // The def of KeyAMX should be stored into mem.
  // TODO: can a key AMX intrinsic have no tile def?
  return Loads.empty() && (ST == cast<Value>(KeyAMX));
}

bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
                                      SmallVector<Value *, 8> &Shapes) {
  for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) {
    Value *Op = KeyAMX->getOperand(I);
    if (!Op->getType()->isX86_AMXTy())
      continue;
    IntrinsicInst *TileDef = dyn_cast<IntrinsicInst>(Op);
    assert((TileDef && isTileLoad(TileDef)) &&
           "All of KeyAMX's tile definitions should come from TileLoad!");
    Shapes.push_back(TileDef->getOperand(0));
    Shapes.push_back(TileDef->getOperand(1));
  }
  if (!isTileStore(KeyAMX)) {
    Shapes.push_back(KeyAMX->getOperand(0));
    Shapes.push_back(KeyAMX->getOperand(1));
  }
  return !Shapes.empty();
}

// Collect the shapes and skip past the area of the current key AMX intrinsic.
//
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)  record (m,k)
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)  record (k,n)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)  record (m,n)
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(m, n,... td) <--PosEnd record (m,n)
// --------------------------------------------------------------------------
BasicBlock::iterator
X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                                          SmallVector<Value *, 8> &Shapes) {
  IntrinsicInst *KeyAMX = nullptr;
  BasicBlock *BB = Iter->getParent();
  BasicBlock::iterator PosEnd = BB->end();
  SmallSet<Value *, 4> Loads;

  // Treat the TileStore as the "config position end" and check the volatile
  // model.
  for (auto I = Iter, E = BB->end(); I != E; ++I) {
    assert(!brokenVolatile(&*I) && "Did not reach the tile store!");
    IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
    if (!II || !isAMXIntrinsic(II))
      continue;

    if (isTileLoad(II)) {
      Loads.insert(II);
    } else if (isTileStore(II)) {
      if (!checkVolatileModel(Loads, II, KeyAMX))
        report_fatal_error("Not Volatile AMX Model!");
      PosEnd = I;
      break;
    } else {
      assert(!KeyAMX && "Too many key AMX intrinsics!");
      KeyAMX = II;
    }
  }
  assert(PosEnd != BB->end() && "Did not find the TileStore!");

  // Treat the TileStore as the KeyAMX if the area has only TileLoads and a
  // TileStore.
  if (!KeyAMX)
    KeyAMX = dyn_cast<IntrinsicInst>(&*PosEnd);

  // Get the shapes in order.
  assert(Shapes.empty() && "Shapes should be empty.");
  getKeyAMXShapes(KeyAMX, Shapes);

  return PosEnd;
}

// Record each key AMX area's shapes together with its position. The first
// tile load serves as the position.
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)   <--  pos
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)        /
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)     shapes:
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)    (m,k)(k,n)
// call void @llvm.x86.tilestored64.internal(m, n,... td)          (m,n)(m,n)
// --------------------------------------------------------------------------
bool X86PreAMXConfig::findConfigShapes(PosAndShapesMap &PosAndShapes) {
  bool Find = false;
  for (BasicBlock &BB : F) {
    for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
      if (!II)
        continue;
      if (!isAMXIntrinsic(II))
        continue;
      assert(onlyTileDef(II) && "Not volatile model for AMX at O0!");

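      // Record the shapes keyed by the area's first AMX intrinsic, then skip
      // the iterator past the area's tile store.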
      I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]);
      Find = true;
    }
  }
  return Find;
}

// Insert ldtilecfg and preconfigure the shapes for each area of key AMX
// intrinsics, e.g. (key amx = tdpbssd):
// --------------------------------------------------------------------------
// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
// ...
// ... pre-config shape of %t1                                 *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
// ...                                                         *
// ... pre-config shape of %t2                                 *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
// ...                                                         *
// ... pre-config shape of %t3                                 * of
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
// ...                                                         * tiles
// ... pre-config shape of %td                                 *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
//
// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * tile config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
bool X86PreAMXConfig::preTileConfig() {
  PosAndShapesMap PosAndShapes;
  bool NeedCfg = findConfigShapes(PosAndShapes);
  if (!NeedCfg)
    return false;
  for (auto &IPAndShapes : PosAndShapes)
    addTileConfig(IPAndShapes.first, IPAndShapes.second);

  return true;
}
} // anonymous namespace

namespace {

class X86PreAMXConfigPass : public FunctionPass {
public:
  static char ID;

  X86PreAMXConfigPass() : FunctionPass(ID) {
    initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    bool C = false;

    // Prepare for fast register allocation at O0.
    if (TM->getOptLevel() == CodeGenOpt::None) {

      // We pre-configure each key AMX intrinsic at O0.
      // In theory, one tile config could cover several AMX intrinsics, but
      // it is very difficult to classify the tile shapes at O0. So we keep
      // things simple and pre-configure every key AMX intrinsic.
      X86PreAMXConfig PCFG(F);
      C = PCFG.preTileConfig();
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }
};

} // anonymous namespace

static const char PassName[] = "Pre AMX Tile Config";
char X86PreAMXConfigPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)

FunctionPass *llvm::createX86PreAMXConfigPass() {
  return new X86PreAMXConfigPass();
}