1//===-- X86FastTileConfig.cpp - Fast Tile Register Configure---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
/// \file Pass to configure the shape of AMX physical registers.
/// AMX registers need to be configured before use. The ldtilecfg instruction
/// is inserted before the FastRegAllocation pass; however, at that time we
/// don't know the shape of each physical tile register, because register
/// allocation has not been done yet. This pass runs after the register
/// allocation pass. It collects the shape information of each physical tile
/// register and stores the shape in the stack slot that is allocated for
/// loading the config into the tile config register.
17//
18//===----------------------------------------------------------------------===//
19
20#include "X86.h"
21#include "X86InstrBuilder.h"
22#include "X86MachineFunctionInfo.h"
23#include "X86RegisterInfo.h"
24#include "X86Subtarget.h"
25#include "llvm/CodeGen/MachineFrameInfo.h"
26#include "llvm/CodeGen/MachineFunctionPass.h"
27#include "llvm/CodeGen/MachineInstr.h"
28#include "llvm/CodeGen/MachineRegisterInfo.h"
29#include "llvm/CodeGen/Passes.h"
30#include "llvm/CodeGen/TargetInstrInfo.h"
31#include "llvm/CodeGen/TargetRegisterInfo.h"
32#include "llvm/InitializePasses.h"
33
34using namespace llvm;
35
36#define DEBUG_TYPE "fasttileconfig"
37
38namespace {
39
40class X86FastTileConfig : public MachineFunctionPass {
41  // context
42  MachineFunction *MF = nullptr;
43  const X86Subtarget *ST = nullptr;
44  const TargetRegisterInfo *TRI = nullptr;
45  const TargetInstrInfo *TII = nullptr;
46  MachineRegisterInfo *MRI = nullptr;
47
48  MachineInstr *getTileConfigPoint();
49  void tileConfig();
50
51public:
52  X86FastTileConfig() : MachineFunctionPass(ID) {}
53
54  bool fastTileConfig();
55  bool isTileLoad(MachineInstr &MI);
56  bool isTileStore(MachineInstr &MI);
57  bool isAMXInstr(MachineInstr &MI);
58  void getTileStoreShape(MachineInstr &MI,
59                         SmallVector<MachineOperand *> &ShapedTiles);
60
61  MachineInstr *getKeyAMXInstr(MachineInstr *MI);
62  void getTileShapesCfg(MachineInstr *MI,
63                        SmallVector<MachineOperand *> &ShapedTiles);
64  void getShapeCfgInstrs(MachineInstr *MI,
65                         std::map<unsigned, MachineInstr *> &RowCfgs,
66                         std::map<unsigned, MachineInstr *> &ColCfgs);
67
68  /// Return the pass name.
69  StringRef getPassName() const override {
70    return "Fast Tile Register Configure";
71  }
72
73  void materializeTileCfg(MachineInstr *MI);
74
75  void rewriteTileCfg(SmallVector<MachineOperand *> &ShapedTiles,
76                      std::map<unsigned, MachineInstr *> &RowCfgs,
77                      std::map<unsigned, MachineInstr *> &ColCfgs);
78
79  /// Perform register allocation.
80  bool runOnMachineFunction(MachineFunction &MFunc) override;
81
82  MachineFunctionProperties getRequiredProperties() const override {
83    return MachineFunctionProperties().set(
84        MachineFunctionProperties::Property::NoPHIs);
85  }
86
87  static char ID;
88};
89
90} // end anonymous namespace
91
92char X86FastTileConfig::ID = 0;
93
94INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE,
95                      "Fast Tile Register Configure", false, false)
96INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE,
97                    "Fast Tile Register Configure", false, false)
98
99static bool isTilePhysReg(MachineOperand &Op) {
100  if (!Op.isReg())
101    return false;
102
103  Register Reg = Op.getReg();
104  if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
105    return true;
106  return false;
107}
108
109static unsigned getTilePhysRegIdx(MachineOperand *Op) {
110  assert(isTilePhysReg(*Op) && "Tile Operand is invalid");
111  return Op->getReg() - X86::TMM0;
112}
113
114static inline void adjustRowCfg(unsigned TIdx, MachineInstr *MI) {
115  unsigned Offset = 48 + TIdx;
116  MI->getOperand(3).ChangeToImmediate(Offset);
117}
118
119static inline void adjustColCfg(unsigned TIdx, MachineInstr *MI) {
120  unsigned Offset = 16 + TIdx * 2;
121  MI->getOperand(3).ChangeToImmediate(Offset);
122}
123
124bool X86FastTileConfig::isTileLoad(MachineInstr &MI) {
125  return MI.getOpcode() == X86::PTILELOADDV;
126}
127bool X86FastTileConfig::isTileStore(MachineInstr &MI) {
128  return MI.getOpcode() == X86::PTILESTOREDV;
129}
130bool X86FastTileConfig::isAMXInstr(MachineInstr &MI) {
131  // TODO: May need to handle some special nontile amx instrucion.
132  if (MI.getOpcode() == X86::PLDTILECFGV || MI.isDebugInstr())
133    return false;
134
135  for (MachineOperand &MO : MI.operands())
136    if (isTilePhysReg(MO))
137      return true;
138
139  return false;
140}
141
142MachineInstr *X86FastTileConfig::getKeyAMXInstr(MachineInstr *MI) {
143  auto Cfg = MachineBasicBlock::iterator(MI);
144  MachineBasicBlock *MBB = MI->getParent();
145  MachineInstr *KeyMI = nullptr;
146  int KeyAMXNum = 0;
147
148  for (auto II = Cfg; II != MBB->end(); II++) {
149    if (isTileLoad(*II)) {
150      KeyMI = &*II;
151      continue;
152    }
153
154    if (isTileStore(*II)) {
155      assert(KeyMI && "Key AMX Should be found before!");
156      break;
157    }
158
159    if (isAMXInstr(*II)) {
160      assert((KeyAMXNum == 0) && "Too many Key AMX instruction!");
161      KeyAMXNum++;
162      KeyMI = &*II;
163    }
164  }
165  assert(KeyMI && "There must be an AMX instruction.");
166  return KeyMI;
167}
168
169// Orderly get the tiles in key amx instruction, uses before defs.
170void X86FastTileConfig::getTileShapesCfg(
171    MachineInstr *CfgMI, SmallVector<MachineOperand *> &ShapedTiles) {
172  MachineInstr *KeyMI = getKeyAMXInstr(CfgMI);
173
174  SmallVector<MachineOperand *> DefTiles;
175  for (MachineOperand &MO : KeyMI->operands()) {
176    if (!isTilePhysReg(MO))
177      continue;
178    if (MO.isDef())
179      DefTiles.push_back(&MO);
180    else
181      ShapedTiles.push_back(&MO);
182  }
183  ShapedTiles.append(DefTiles);
184}
185
186// We pre-config the shapes at position named with "amx.tmm.N.shape.row* and
187// amx.shape.N.col*" at pass "Pre AMX Tile Config".
188// The 'N' implies the order of tiles in key amx intrinsic.
189void X86FastTileConfig::getShapeCfgInstrs(
190    MachineInstr *MI, std::map<unsigned, MachineInstr *> &RowCfgs,
191    std::map<unsigned, MachineInstr *> &ColCfgs) {
192  auto Cfg = MachineBasicBlock::iterator(MI);
193  MachineBasicBlock *MBB = MI->getParent();
194
195  for (auto II = Cfg; II != MBB->begin(); II--) {
196    if (isAMXInstr(*II) || II->isTerminator() || II->isCall())
197      break;
198    if (!II->mayStore() || !II->hasOneMemOperand())
199      continue;
200    const Value *MemPtr = II->memoperands()[0]->getValue();
201    if (!MemPtr)
202      continue;
203
204    StringRef Name = MemPtr->getName();
205    if (!Name.startswith("amx.tmm."))
206      continue;
207
208    // Get the 'N'th tile shape config in key amx instruction.
209    auto N = Name.find(".shape");
210    StringRef STileIdx = Name.slice(8, N);
211    unsigned Idx;
212    STileIdx.getAsInteger(10, Idx);
213
214    // And related them with their store instructions.
215    if (Name.contains("row"))
216      RowCfgs[Idx] = &*II;
217    else if (Name.contains("col"))
218      ColCfgs[Idx] = &*II;
219    else
220      llvm_unreachable("Invalid tile shape info!");
221  }
222  assert((RowCfgs.size() == ColCfgs.size()) &&
223         "The number of tile row and col must be equal!");
224}
225
226// Here is the data format for the tile config.
227// 0      palette   = 1 now.
228// 1      start_row = 0 now.
229// 2-15   reserved, must be zero
230// 16-17  tile0.colsb Tile 0 bytes per row.
231// 18-19  tile1.colsb Tile 1 bytes per row.
232// 20-21  tile2.colsb Tile 2 bytes per row.
233// ... (sequence continues)
234// 30-31  tile7.colsb Tile 7 bytes per row.
235// 32-47  reserved, must be zero
236// 48     tile0.rows Tile 0 rows.
237// 49     tile1.rows Tile 1 rows.
238// 50     tile2.rows Tile 2 rows.
239// ... (sequence continues)
240// 55     tile7.rows Tile 7 rows.
241// 56-63  reserved, must be zero
242void X86FastTileConfig::rewriteTileCfg(
243    SmallVector<MachineOperand *> &ShapedTiles,
244    std::map<unsigned, MachineInstr *> &RowCfgs,
245    std::map<unsigned, MachineInstr *> &ColCfgs) {
246  assert((RowCfgs.size() == ShapedTiles.size()) &&
247         "The number of tile shapes not equal with the number of tiles!");
248
249  // Orderly get the tiles and adjust the shape config.
250  for (unsigned I = 0, E = ShapedTiles.size(); I < E; I++) {
251    MachineOperand *MO = ShapedTiles[I];
252    unsigned TmmIdx = getTilePhysRegIdx(MO);
253    if (I == TmmIdx)
254      continue;
255    adjustRowCfg(TmmIdx, RowCfgs[I]);
256    adjustColCfg(TmmIdx, ColCfgs[I]);
257  }
258}
259
260// We have already preconfig the shapes before fast register allocation at
261// X86PreAMXConfig::preWriteTileCfg(). Now, we have done fast register
262// allocation, the shapes pre-written before may not rightly corresponding
263// to the correct tmm registers, so we need adjust them.
264void X86FastTileConfig::materializeTileCfg(MachineInstr *CfgMI) {
265  SmallVector<MachineOperand *> ShapedTiles;
266  std::map<unsigned, MachineInstr *> RowCfgs;
267  std::map<unsigned, MachineInstr *> ColCfgs;
268
269  // Orderly keep the tile uses and def in ShapedTiles;
270  getTileShapesCfg(CfgMI, ShapedTiles);
271  assert(ShapedTiles.size() && "Not find shapes config!");
272
273  getShapeCfgInstrs(CfgMI, RowCfgs, ColCfgs);
274
275  rewriteTileCfg(ShapedTiles, RowCfgs, ColCfgs);
276}
277
278bool X86FastTileConfig::fastTileConfig() {
279  bool Changed = false;
280
281  for (MachineBasicBlock &MBB : *MF) {
282    SmallVector<MachineInstr *, 2> CFGs;
283    for (MachineInstr &MI : MBB)
284      if (MI.getOpcode() == X86::PLDTILECFGV)
285        CFGs.push_back(&MI);
286    for (auto *MI : CFGs)
287      materializeTileCfg(MI);
288    if (!CFGs.empty())
289      Changed = true;
290  }
291  return Changed;
292}
293
294bool X86FastTileConfig::runOnMachineFunction(MachineFunction &MFunc) {
295  MF = &MFunc;
296  MRI = &MFunc.getRegInfo();
297  ST = &MFunc.getSubtarget<X86Subtarget>();
298  TRI = ST->getRegisterInfo();
299  TII = MFunc.getSubtarget().getInstrInfo();
300
301  return fastTileConfig();
302}
303
304FunctionPass *llvm::createX86FastTileConfigPass() {
305  return new X86FastTileConfig();
306}
307