1//===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===// 2// 3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4// See https://llvm.org/LICENSE.txt for license information. 5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6// 7//===----------------------------------------------------------------------===// 8// 9/// Insert tilecfg for each area of key AMX intrinsic. 10/// All the key AMX intrinsic's tile operand must come from tileload. And the 11/// def tile of key AMX intrinsic must be tilestored. 12/// take tdpbssd for example: 13/// -------------------------------------------------------------------------- 14/// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...) key 15/// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...) | 16/// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...) amx 17/// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3) | 18/// call void @llvm.x86.tilestored64.internal(... td) area 19/// -------------------------------------------------------------------------- 20/// This pass will insert tilecfg before every key-amx-area, some like: 21/// -------------------------------------------------------------------------- 22/// %cfgmem = alloca <16 x i32>, align 4 * allocate mem 23/// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem * zero init 24/// ... 25/// ... pre-config shape of %t1 * 26/// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1 * 27/// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2 * pre-config 28/// ... * 29/// ... pre-config shape of %t2 * shapes 30/// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1 * 31/// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2 * 32/// ... 
/// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                    * tile config
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "pre-amx-config"

// Returns true if \p II is an AMX intrinsic: it either produces an x86_amx
// value or consumes one through an operand.
static bool isAMXIntrinsic(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return true;
  return II->getType()->isX86_AMXTy();
}

// Returns true if \p II is one of the tile-load intrinsics (the plain load or
// the T1-hint variant).
static bool isTileLoad(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal ||
         II->getIntrinsicID() == Intrinsic::x86_tileloaddt164_internal;
}

// Returns true if \p II is the tile-store intrinsic.
static bool isTileStore(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal;
}

#ifndef NDEBUG
// In the volatile AMX model expected at O0, an AMX intrinsic defines a tile
// without taking any x86_amx operand (tiles always round-trip through
// memory). Used only in assertions.
static bool onlyTileDef(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return false;
  return II->getType()->isX86_AMXTy();
}

// Returns true if \p I would break the straight-line key-amx area this pass
// scans (a non-intrinsic call or a terminator). Used only in assertions.
static bool brokenVolatile(Instruction *I) {
  // Todo: it is weak to identify a normal call here.
  if ((isa<CallInst>(I) && !isa<IntrinsicInst>(I)) || I->isTerminator())
    return true;
  return false;
}
#endif

namespace {
// Inserts an ldtilecfg (and the shape pre-writes it needs) in front of every
// key-amx area of function \p F. Only meaningful at O0, where every tile
// value comes straight from a tile load and is immediately tile stored.
class X86PreAMXConfig {
  // Maps the first instruction of a key-amx area to the (row, col) shape
  // values that must be pre-configured for that area.
  using PosAndShapesMap = MapVector<Instruction *, SmallVector<Value *, 8>>;

  Function &F;

public:
  X86PreAMXConfig(Function &Func) : F(Func) {}
  bool preTileConfig();
  void addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
  bool findConfigShapes(PosAndShapesMap &PosAndShapes);
  bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
  void preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder,
                       SmallVector<Value *, 8> &Shapes);
  BasicBlock::iterator
  getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                           SmallVector<Value *, 8> &Shapes);
  bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
                          IntrinsicInst *KeyAMX);
};

// Orderly write the shapes in tilecfg's mem. This may not be right, because
// the first shape may not correspond to the first tmm register, so we need to
// handle it at X86FastTileConfig::materializeTileCfg() after register
// allocation.
// For example:
// --------------------------------------------------------------------------
// zeroinitialize tilecfg's mem (of ldtilecfg)
// --------------------------------------------------------------------------
// ... pre-config shape of %t1                                 *
// %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48   *
// %amx.tmm.0.shape.col = getelementptr i16, i16* %mem, i64 16 *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
// ...                                                         *
// ... pre-config shape of %t2                                 *
pre-config shape of %t2 * 128// %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49 * 129// %amx.tmm.1.shape.col = getelementptr i16, i16* %mem, i64 18 * 130// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1 * shapes 131// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2 * 132// ... * 133// ... pre-config shape of %t3 * of 134// %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50 * 135// %amx.tmm.2.shape.col = getelementptr i16, i16* %mem, i64 20 * 136// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1 * 137// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2 * 138// ... * tiles 139// ... pre-config shape of %td * 140// %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51 * 141// %amx.tmm.3.shape.col = getelementptr i16, i16* %mem, i64 22 * 142// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1 * 143// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2 * 144// -------------------------------------------------------------------------- 145// call void @llvm.x86.ldtilecfg(i8* %mem) * tile config 146// -------------------------------------------------------------------------- 147// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key 148// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) 149// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx 150// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3) 151// call void @llvm.x86.tilestored64.internal(... td) area 152// -------------------------------------------------------------------------- 153void X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder, 154 SmallVector<Value *, 8> &Shapes) { 155 LLVMContext &Ctx = Builder.getContext(); 156 Type *I8Ty = Type::getInt8Ty(Ctx); 157 Type *I16Ty = Type::getInt16Ty(Ctx); 158 159 // TODO: Currently we defaultly set Palette = 1, it may be assigned to 160 // other value in the future. 
161 Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0); 162 Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1); 163 Value *PalettePos = Builder.CreateGEP(I8Ty, I8Ptr, PaletteOffset); 164 Builder.CreateStore(PaletteValue, PalettePos); 165 166 for (int I = 0, E = Shapes.size() / 2; I < E; I++) { 167 Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I); 168 Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2); 169 const std::string ShapeName = "amx.tmm." + itostr(I); 170 Value *RowPos = Builder.CreateGEP(I8Ty, I8Ptr, RowOffset, 171 ShapeName + ".shape.row"); 172 Value *ColPos = Builder.CreateGEP(I8Ty, I8Ptr, ColOffset); 173 ColPos = Builder.CreateBitCast(ColPos, PointerType::get(I16Ty, 0), 174 ShapeName + ".shape.col"); 175 Value *Row = Shapes[I * 2]; 176 Value *Col = Shapes[I * 2 + 1]; 177 Row = Builder.CreateTrunc(Row, I8Ty); 178 Builder.CreateStore(Row, RowPos); 179 Builder.CreateStore(Col, ColPos); 180 } 181} 182 183void X86PreAMXConfig::addTileConfig(Instruction *ModelStart, 184 SmallVector<Value *, 8> &Shapes) { 185 Module *M = F.getParent(); 186 IRBuilder<> Builder(ModelStart); 187 const DataLayout &DL = M->getDataLayout(); 188 unsigned AddrSpace = DL.getAllocaAddrSpace(); 189 LLVMContext &Ctx = Builder.getContext(); 190 Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false); 191 Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx)); 192 193 AllocaInst *Addr = 194 new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front()); 195 Addr->setAlignment(Alignment); 196 Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy()); 197 198 Builder.CreateAlignedStore(Constant::getNullValue(V512Ty), Addr, Alignment); 199 200 preWriteTileCfg(I8Ptr, Builder, Shapes); 201 202 Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, std::nullopt, 203 {I8Ptr}); 204} 205 206// Todo: We may need to handle "more than one store" case in the future. 
// Check that a key-amx area obeys the volatile model: every tile consumed by
// \p KeyAMX comes from one of the tile loads in \p Loads, and the value that
// \p Store writes back is the def of \p KeyAMX itself.
// NOTE: \p Loads is consumed (erased from) during the check.
bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
                                         IntrinsicInst *Store,
                                         IntrinsicInst *KeyAMX) {
  // Operand 4 of tilestored64.internal is the x86_amx value being stored.
  Value *ST = Store->getOperand(4);

  // Only has tileload and tilestore.
  if (!KeyAMX)
    return (Loads.size() == 1) && Loads.contains(ST);

  // All Loads should be operands of KeyAMX.
  // All tile operands of KeyAMX should come from Loads.
  for (Value *Op : KeyAMX->operands()) {
    if (Op->getType()->isX86_AMXTy())
      if (!Loads.erase(Op))
        return false;
  }

  // The def of KeyAMX should be stored into mem.
  // Todo: can the key amx intrinsic have no def?
  return Loads.empty() && (ST == cast<Value>(KeyAMX));
}

// Collect the (row, col) shape operands needed to configure \p KeyAMX's area:
// one pair per tile-load operand of \p KeyAMX, plus KeyAMX's own shape when
// KeyAMX produces a tile (i.e. it is not itself a tile store).
// Returns true if any shape was collected.
bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
                                      SmallVector<Value *, 8> &Shapes) {
  for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) {
    Value *Op = KeyAMX->getOperand(I);
    if (!Op->getType()->isX86_AMXTy())
      continue;
    IntrinsicInst *TileDef = dyn_cast<IntrinsicInst>(Op);
    assert((TileDef && isTileLoad(TileDef)) &&
           "All KeyAMX's tile definiation should comes from TileLoad!");
    // Operands 0 and 1 of a tile load are its row and col shape values.
    Shapes.push_back(TileDef->getOperand(0));
    Shapes.push_back(TileDef->getOperand(1));
  }
  if (!isTileStore(KeyAMX)) {
    Shapes.push_back(KeyAMX->getOperand(0));
    Shapes.push_back(KeyAMX->getOperand(1));
  }
  return Shapes.size() != 0;
}

// Collect the shapes and skip the area of current key amx intrinsic.
//
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)  record (m,k)
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)  record (m,k)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)  record (m,k)
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(m, n,... td) <--PosEnd record (m,k)
// --------------------------------------------------------------------------
BasicBlock::iterator
X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                                          SmallVector<Value *, 8> &Shapes) {
  IntrinsicInst *KeyAMX = nullptr;
  BasicBlock *BB = Iter->getParent();
  BasicBlock::iterator PosEnd = BB->end();
  SmallSet<Value *, 4> Loads;

  // See TileStore as "Config Position End" and check volatile model.
  for (auto I = Iter, E = BB->end(); I != E; ++I) {
    assert(!brokenVolatile(&*I) && "Not reach tile store!");
    IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
    if (!II || !isAMXIntrinsic(II))
      continue;

    if (isTileLoad(II)) {
      Loads.insert(II);
    } else if (isTileStore(II)) {
      // The store terminates the area; validate it before stopping.
      if (!checkVolatileModel(Loads, II, KeyAMX))
        report_fatal_error("Not Volatile AMX Model!");
      PosEnd = I;
      break;
    } else {
      // Any AMX intrinsic that is neither load nor store is the single key
      // intrinsic of this area (e.g. tdpbssd).
      assert(!KeyAMX && "Too many key amx intrinsic!");
      KeyAMX = II;
    }
  }
  assert(PosEnd != BB->end() && "Not find TileStore!");

  // See KeyAMX as TileStore if only TileLoad and TileStore.
  if (!KeyAMX)
    KeyAMX = dyn_cast<IntrinsicInst>(&*PosEnd);

  // Get Shapes in order.
  assert(Shapes.empty() && "Shapes should be clean.");
  getKeyAMXShapes(KeyAMX, Shapes);

  return PosEnd;
}

// Record a key amx area's shapes with its position.
// Use the first tileload as its position.
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)   <-- pos
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)        /
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)    shapes:
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)   (m,k)(k,n)
// call void @llvm.x86.tilestored64.internal(m, n,... td)         (m,n)(m,n)
td) (m,n)(m,n) 309// -------------------------------------------------------------------------- 310bool X86PreAMXConfig::findConfigShapes(PosAndShapesMap &PosAndShapes) { 311 bool Find = false; 312 for (BasicBlock &BB : F) { 313 for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) { 314 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I); 315 if (!II) 316 continue; 317 if (!isAMXIntrinsic(II)) 318 continue; 319 assert(onlyTileDef(II) && "Not volatile model for AMX at O0!"); 320 321 I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]); 322 Find = true; 323 } 324 } 325 return Find; 326} 327 328// Insert ldtilecfg and preconfig the shapes for each area of key AMX intrinsic. 329// e.g. (key amx = tdpbssd) 330// -------------------------------------------------------------------------- 331// %cfgmem = alloca <16 x i32>, align 4 * allocate mem 332// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem * zero init 333// ... 334// ... pre-config shape of %t1 * 335// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1 * 336// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2 * pre-config 337// ... * 338// ... pre-config shape of %t2 * 339// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1 * shapes 340// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2 * 341// ... * 342// ... pre-config shape of %t3 * of 343// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1 * 344// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2 * 345// ... * tiles 346// ... pre-config shape of %td * 347// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1 * 348// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2 * 349// 350// call void @llvm.x86.ldtilecfg(i8* %cfgmem) * pre-config 351// -------------------------------------------------------------------------- 352// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key 353// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) 
354// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx 355// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3) 356// call void @llvm.x86.tilestored64.internal(... td) area 357// -------------------------------------------------------------------------- 358bool X86PreAMXConfig::preTileConfig() { 359 PosAndShapesMap PosAndShapes; 360 bool NeedCfg = findConfigShapes(PosAndShapes); 361 if (!NeedCfg) 362 return false; 363 for (auto &IPAndShapes : PosAndShapes) 364 addTileConfig(IPAndShapes.first, IPAndShapes.second); 365 366 return true; 367} 368} // anonymous namespace 369 370namespace { 371 372class X86PreAMXConfigPass : public FunctionPass { 373public: 374 static char ID; 375 376 X86PreAMXConfigPass() : FunctionPass(ID) { 377 initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry()); 378 } 379 380 bool runOnFunction(Function &F) override { 381 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); 382 bool C = false; 383 384 // Prepare for fast register allocation at O0. 385 if (TM->getOptLevel() == CodeGenOpt::None) { 386 387 // We pre-config each key AMX intrinsic at O0. 388 // In theory, one tile config can cover several AMX intrinsics, but 389 // it is very diffcult to classify the tile shapes at O0. So here we 390 // let thing be easy, pre-config every key AMX intrinsic. 
391 X86PreAMXConfig PCFG(F); 392 C = PCFG.preTileConfig(); 393 } 394 395 return C; 396 } 397 398 void getAnalysisUsage(AnalysisUsage &AU) const override { 399 AU.setPreservesCFG(); 400 AU.addRequired<TargetPassConfig>(); 401 } 402}; 403 404} // anonymous namespace 405 406static const char PassName[] = "Pre AMX Tile Config"; 407char X86PreAMXConfigPass::ID = 0; 408INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false) 409INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) 410INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false) 411 412FunctionPass *llvm::createX86PreAMXConfigPass() { 413 return new X86PreAMXConfigPass(); 414} 415