1/*===--------------- amxintrin.h - AMX intrinsics -*- C/C++ -*---------------===
2 *
3 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 * See https://llvm.org/LICENSE.txt for license information.
5 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 *
7 *===------------------------------------------------------------------------===
8 */
9
10#ifndef __IMMINTRIN_H
11#error "Never use <amxintrin.h> directly; include <immintrin.h> instead."
12#endif /* __IMMINTRIN_H */
13
14#ifndef __AMXINTRIN_H
15#define __AMXINTRIN_H
16#ifdef __x86_64__
17
/* Define the default attributes for the functions in this file. All three
 * variants mark a function __always_inline__ and __nodebug__; they differ only
 * in the target feature named in __target__ ("amx-tile", "amx-int8" or
 * "amx-bf16"). */
#define __DEFAULT_FN_ATTRS_TILE                                                \
  __attribute__((__always_inline__, __nodebug__, __target__("amx-tile")))
#define __DEFAULT_FN_ATTRS_INT8                                                \
  __attribute__((__always_inline__, __nodebug__, __target__("amx-int8")))
#define __DEFAULT_FN_ATTRS_BF16                                                \
  __attribute__((__always_inline__, __nodebug__, __target__("amx-bf16")))
25
/// Load tile configuration from a 64-byte memory location specified by
/// "__config". The tile configuration includes the tile type palette, the
/// number of bytes per row, and the number of rows. If the specified
/// palette_id is zero, that signifies the init state for both the tile
/// config and the tile data, and the tiles are zeroed. Any invalid
/// configuration will result in a #GP fault.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> LDTILECFG </c> instruction.
///
/// \param __config
///    A pointer to the 512-bit (64-byte) configuration to load. It is only
///    read, never written.
static __inline__ void __DEFAULT_FN_ATTRS_TILE
_tile_loadconfig(const void *__config) {
  __builtin_ia32_tile_loadconfig(__config);
}
43
/// Store the current tile configuration to a 64-byte memory location
/// specified by "__config". The tile configuration includes the tile type
/// palette, the number of bytes per row, and the number of rows. If tiles
/// are not configured, all zeroes will be stored to memory.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> STTILECFG </c> instruction.
///
/// \param __config
///    A pointer to the 512-bit (64-byte) buffer that receives the
///    configuration.
static __inline__ void __DEFAULT_FN_ATTRS_TILE
_tile_storeconfig(void *__config) {
  __builtin_ia32_tile_storeconfig(__config);
}
59
/// Release the tile configuration to return to the init state, which
/// releases all storage it currently holds. Takes no arguments and returns
/// nothing.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TILERELEASE </c> instruction.
static __inline__ void __DEFAULT_FN_ATTRS_TILE _tile_release(void) {
  __builtin_ia32_tilerelease();
}
69
/// Load tile rows from memory specified by "base" address and "stride" into
/// destination tile "dst" using the tile configuration previously configured
/// via "_tile_loadconfig".
///
/// This is a macro: "base" is cast to <c> const void * </c> and "stride" to
/// <c> __SIZE_TYPE__ </c> before being handed to the builtin.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TILELOADD </c> instruction.
///
/// \param dst
///    A destination tile. Max size is 1024 Bytes.
/// \param base
///    A pointer to base address.
/// \param stride
///    The stride between the rows' data to be loaded in memory.
#define _tile_loadd(dst, base, stride)                                         \
  __builtin_ia32_tileloadd64((dst), ((const void *)(base)),                    \
                             (__SIZE_TYPE__)(stride))
87
/// Load tile rows from memory specified by "base" address and "stride" into
/// destination tile "dst" using the tile configuration previously configured
/// via "_tile_loadconfig". This intrinsic provides a hint to the implementation
/// that the data will likely not be reused in the near future and the data
/// caching can be optimized accordingly.
///
/// This is a macro: "base" is cast to <c> const void * </c> and "stride" to
/// <c> __SIZE_TYPE__ </c> before being handed to the builtin.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TILELOADDT1 </c> instruction.
///
/// \param dst
///    A destination tile. Max size is 1024 Bytes.
/// \param base
///    A pointer to base address.
/// \param stride
///    The stride between the rows' data to be loaded in memory.
#define _tile_stream_loadd(dst, base, stride)                                  \
  __builtin_ia32_tileloaddt164((dst), ((const void *)(base)),                  \
                               (__SIZE_TYPE__)(stride))
107
/// Store the tile specified by "dst" to memory specified by "base" address and
/// "stride" using the tile configuration previously configured via
/// "_tile_loadconfig".
///
/// This is a macro: "base" is cast to <c> void * </c> and "stride" to
/// <c> __SIZE_TYPE__ </c> before being handed to the builtin.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TILESTORED </c> instruction.
///
/// \param dst
///    The tile to be stored. Max size is 1024 Bytes.
///    NOTE(review): despite its name, this argument appears to be the source
///    of the data written to memory -- confirm before renaming.
/// \param base
///    A pointer to base address.
/// \param stride
///    The stride between the rows' data to be stored in memory.
#define _tile_stored(dst, base, stride)                                        \
  __builtin_ia32_tilestored64((dst), ((void *)(base)), (__SIZE_TYPE__)(stride))
124
/// Zero the tile specified by "tile".
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TILEZERO </c> instruction.
///
/// \param tile
///    The destination tile to be zeroed. Max size is 1024 Bytes.
#define _tile_zero(tile) __builtin_ia32_tilezero((tile))
134
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in "src0" with
/// corresponding signed 8-bit integers in "src1", producing 4 intermediate
/// 32-bit results. Sum these 4 results with the corresponding 32-bit integer in
/// "dst", and store the 32-bit result back to tile "dst".
///
/// This is a macro that forwards (dst, src0, src1), in that order, to
/// __builtin_ia32_tdpbssd.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TDPBSSD </c> instruction.
///
/// \param dst
///    The destination tile. Max size is 1024 Bytes.
/// \param src0
///    The 1st source tile. Max size is 1024 Bytes.
/// \param src1
///    The 2nd source tile. Max size is 1024 Bytes.
#define _tile_dpbssd(dst, src0, src1)                                          \
  __builtin_ia32_tdpbssd((dst), (src0), (src1))
153
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in "src0" with
/// corresponding unsigned 8-bit integers in "src1", producing 4 intermediate
/// 32-bit results. Sum these 4 results with the corresponding 32-bit integer
/// in "dst", and store the 32-bit result back to tile "dst".
///
/// This is a macro that forwards (dst, src0, src1), in that order, to
/// __builtin_ia32_tdpbsud.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TDPBSUD </c> instruction.
///
/// \param dst
///    The destination tile. Max size is 1024 Bytes.
/// \param src0
///    The 1st source tile. Max size is 1024 Bytes.
/// \param src1
///    The 2nd source tile. Max size is 1024 Bytes.
#define _tile_dpbsud(dst, src0, src1)                                          \
  __builtin_ia32_tdpbsud((dst), (src0), (src1))
172
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in "src0"
/// with corresponding signed 8-bit integers in "src1", producing 4 intermediate
/// 32-bit results. Sum these 4 results with the corresponding 32-bit integer in
/// "dst", and store the 32-bit result back to tile "dst".
///
/// This is a macro that forwards (dst, src0, src1), in that order, to
/// __builtin_ia32_tdpbusd.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TDPBUSD </c> instruction.
///
/// \param dst
///    The destination tile. Max size is 1024 Bytes.
/// \param src0
///    The 1st source tile. Max size is 1024 Bytes.
/// \param src1
///    The 2nd source tile. Max size is 1024 Bytes.
#define _tile_dpbusd(dst, src0, src1)                                          \
  __builtin_ia32_tdpbusd((dst), (src0), (src1))
191
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in "src0"
/// with corresponding unsigned 8-bit integers in "src1", producing 4
/// intermediate 32-bit results. Sum these 4 results with the corresponding
/// 32-bit integer in "dst", and store the 32-bit result back to tile "dst".
///
/// This is a macro that forwards (dst, src0, src1), in that order, to
/// __builtin_ia32_tdpbuud.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TDPBUUD </c> instruction.
///
/// \param dst
///    The destination tile. Max size is 1024 Bytes.
/// \param src0
///    The 1st source tile. Max size is 1024 Bytes.
/// \param src1
///    The 2nd source tile. Max size is 1024 Bytes.
#define _tile_dpbuud(dst, src0, src1)                                          \
  __builtin_ia32_tdpbuud((dst), (src0), (src1))
210
/// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles "src0"
/// and "src1", accumulating the intermediate single-precision (32-bit)
/// floating-point elements with elements in "dst", and store the 32-bit result
/// back to tile "dst".
///
/// This is a macro that forwards (dst, src0, src1), in that order, to
/// __builtin_ia32_tdpbf16ps.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TDPBF16PS </c> instruction.
///
/// \param dst
///    The destination tile. Max size is 1024 Bytes.
/// \param src0
///    The 1st source tile. Max size is 1024 Bytes.
/// \param src1
///    The 2nd source tile. Max size is 1024 Bytes.
#define _tile_dpbf16ps(dst, src0, src1)                                        \
  __builtin_ia32_tdpbf16ps((dst), (src0), (src1))
228
/// The size of an AMX tile register is configurable; the maximum size is
/// 16x64=1024 bytes. Since there is no 2D type in LLVM IR, a vector type is
/// used to represent a 2D tile, and its fixed size is the maximum AMX tile
/// register size. The type is a 1024-byte vector of int, aligned to 64 bytes.
typedef int _tile1024i __attribute__((__vector_size__(1024), __aligned__(64)));
233
234/// This is internal intrinsic. C/C++ user should avoid calling it directly.
235static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8
236_tile_loadd_internal(unsigned short m, unsigned short n, const void *base,
237                     __SIZE_TYPE__ stride) {
238  return __builtin_ia32_tileloadd64_internal(m, n, base,
239                                             (__SIZE_TYPE__)(stride));
240}
241
/// This is an internal intrinsic. C/C++ users should avoid calling it directly.
///
/// Forwards the shape values "m", "n" and "k" and the tiles "dst", "src1" and
/// "src2" unchanged to __builtin_ia32_tdpbssd_internal and returns its result.
/// Callers in this file pass m = src0.row, n = src1.col and k = src0.col.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8
_tile_dpbssd_internal(unsigned short m, unsigned short n, unsigned short k,
                      _tile1024i dst, _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2);
}
248
/// This is an internal intrinsic. C/C++ users should avoid calling it directly.
///
/// Forwards the shape values "m", "n" and "k" and the tiles "dst", "src1" and
/// "src2" unchanged to __builtin_ia32_tdpbsud_internal and returns its result.
/// Callers in this file pass m = src0.row, n = src1.col and k = src0.col.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8
_tile_dpbsud_internal(unsigned short m, unsigned short n, unsigned short k,
                      _tile1024i dst, _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_tdpbsud_internal(m, n, k, dst, src1, src2);
}
255
/// This is an internal intrinsic. C/C++ users should avoid calling it directly.
///
/// Forwards the shape values "m", "n" and "k" and the tiles "dst", "src1" and
/// "src2" unchanged to __builtin_ia32_tdpbusd_internal and returns its result.
/// Callers in this file pass m = src0.row, n = src1.col and k = src0.col.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8
_tile_dpbusd_internal(unsigned short m, unsigned short n, unsigned short k,
                      _tile1024i dst, _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_tdpbusd_internal(m, n, k, dst, src1, src2);
}
262
/// This is an internal intrinsic. C/C++ users should avoid calling it directly.
///
/// Forwards the shape values "m", "n" and "k" and the tiles "dst", "src1" and
/// "src2" unchanged to __builtin_ia32_tdpbuud_internal and returns its result.
/// Callers in this file pass m = src0.row, n = src1.col and k = src0.col.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8
_tile_dpbuud_internal(unsigned short m, unsigned short n, unsigned short k,
                      _tile1024i dst, _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_tdpbuud_internal(m, n, k, dst, src1, src2);
}
269
270/// This is internal intrinsic. C/C++ user should avoid calling it directly.
271static __inline__ void __DEFAULT_FN_ATTRS_INT8
272_tile_stored_internal(unsigned short m, unsigned short n, void *base,
273                      __SIZE_TYPE__ stride, _tile1024i tile) {
274  return __builtin_ia32_tilestored64_internal(m, n, base,
275                                              (__SIZE_TYPE__)(stride), tile);
276}
277
/// This is an internal intrinsic. C/C++ users should avoid calling it directly.
///
/// Forwards the shape values "m", "n" and "k" and the tiles "dst", "src1" and
/// "src2" unchanged to __builtin_ia32_tdpbf16ps_internal and returns its
/// result. The caller in this file passes m = src0.row, n = src1.col and
/// k = src0.col.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_BF16
_tile_dpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k,
                        _tile1024i dst, _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2);
}
284
/// This struct packs the shape and the tile data together for the user. We
/// suggest initializing the struct as early as possible, because the compiler
/// depends on the shape information to do the tile configuration. Constant
/// values are preferred for the shape, since they help the compiler optimize.
typedef struct __tile1024i_str {
  /// First shape value; const, so it must be set when the struct is
  /// initialized.
  const unsigned short row;
  /// Second shape value; const, so it must be set when the struct is
  /// initialized.
  const unsigned short col;
  /// The tile data.
  _tile1024i tile;
} __tile1024i;
294
295/// Load tile rows from memory specifieid by "base" address and "stride" into
296/// destination tile "dst".
297///
298/// \headerfile <immintrin.h>
299///
300/// This intrinsic corresponds to the <c> TILELOADD </c> instruction.
301///
302/// \param dst
303///    A destination tile. Max size is 1024 Bytes.
304/// \param base
305///    A pointer to base address.
306/// \param stride
307///    The stride between the rows' data to be loaded in memory.
308__DEFAULT_FN_ATTRS_TILE
309static void __tile_loadd(__tile1024i *dst, const void *base,
310                         __SIZE_TYPE__ stride) {
311  dst->tile = _tile_loadd_internal(dst->row, dst->col, base, stride);
312}
313
314/// Compute dot-product of bytes in tiles with a source/destination accumulator.
315/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in src0 with
316/// corresponding signed 8-bit integers in src1, producing 4 intermediate 32-bit
317/// results. Sum these 4 results with the corresponding 32-bit integer in "dst",
318/// and store the 32-bit result back to tile "dst".
319///
320/// \headerfile <immintrin.h>
321///
322/// This intrinsic corresponds to the <c> TDPBSSD </c> instruction.
323///
324/// \param dst
325///    The destination tile. Max size is 1024 Bytes.
326/// \param src0
327///    The 1st source tile. Max size is 1024 Bytes.
328/// \param src1
329///    The 2nd source tile. Max size is 1024 Bytes.
330__DEFAULT_FN_ATTRS_INT8
331static void __tile_dpbssd(__tile1024i *dst, __tile1024i src0,
332                          __tile1024i src1) {
333  dst->tile = _tile_dpbssd_internal(src0.row, src1.col, src0.col, dst->tile,
334                                    src0.tile, src1.tile);
335}
336
337/// Compute dot-product of bytes in tiles with a source/destination accumulator.
338/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in src0 with
339/// corresponding unsigned 8-bit integers in src1, producing 4 intermediate
340/// 32-bit results. Sum these 4 results with the corresponding 32-bit integer
341/// in "dst", and store the 32-bit result back to tile "dst".
342///
343/// \headerfile <immintrin.h>
344///
345/// This intrinsic corresponds to the <c> TDPBSUD </c> instruction.
346///
347/// \param dst
348///    The destination tile. Max size is 1024 Bytes.
349/// \param src0
350///    The 1st source tile. Max size is 1024 Bytes.
351/// \param src1
352///    The 2nd source tile. Max size is 1024 Bytes.
353__DEFAULT_FN_ATTRS_INT8
354static void __tile_dpbsud(__tile1024i *dst, __tile1024i src0,
355                          __tile1024i src1) {
356  dst->tile = _tile_dpbsud_internal(src0.row, src1.col, src0.col, dst->tile,
357                                    src0.tile, src1.tile);
358}
359
360/// Compute dot-product of bytes in tiles with a source/destination accumulator.
361/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in src0 with
362/// corresponding signed 8-bit integers in src1, producing 4 intermediate 32-bit
363/// results. Sum these 4 results with the corresponding 32-bit integer in "dst",
364/// and store the 32-bit result back to tile "dst".
365///
366/// \headerfile <immintrin.h>
367///
368/// This intrinsic corresponds to the <c> TDPBUSD </c> instruction.
369///
370/// \param dst
371///    The destination tile. Max size is 1024 Bytes.
372/// \param src0
373///    The 1st source tile. Max size is 1024 Bytes.
374/// \param src1
375///    The 2nd source tile. Max size is 1024 Bytes.
376__DEFAULT_FN_ATTRS_INT8
377static void __tile_dpbusd(__tile1024i *dst, __tile1024i src0,
378                          __tile1024i src1) {
379  dst->tile = _tile_dpbusd_internal(src0.row, src1.col, src0.col, dst->tile,
380                                    src0.tile, src1.tile);
381}
382
383/// Compute dot-product of bytes in tiles with a source/destination accumulator.
384/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in src0 with
385/// corresponding unsigned 8-bit integers in src1, producing 4 intermediate
386/// 32-bit results. Sum these 4 results with the corresponding 32-bit integer in
387/// "dst", and store the 32-bit result back to tile "dst".
388///
389/// \headerfile <immintrin.h>
390///
391/// This intrinsic corresponds to the <c> TDPBUUD </c> instruction.
392///
393/// \param dst
394///    The destination tile. Max size is 1024 Bytes.
395/// \param src0
396///    The 1st source tile. Max size is 1024 Bytes.
397/// \param src1
398///    The 2nd source tile. Max size is 1024 Bytes.
399__DEFAULT_FN_ATTRS_INT8
400static void __tile_dpbuud(__tile1024i *dst, __tile1024i src0,
401                          __tile1024i src1) {
402  dst->tile = _tile_dpbuud_internal(src0.row, src1.col, src0.col, dst->tile,
403                                    src0.tile, src1.tile);
404}
405
406/// Store the tile specified by "src" to memory specifieid by "base" address and
407/// "stride".
408///
409/// \headerfile <immintrin.h>
410///
411/// This intrinsic corresponds to the <c> TILESTORED </c> instruction.
412///
413/// \param dst
414///    A destination tile. Max size is 1024 Bytes.
415/// \param base
416///    A pointer to base address.
417/// \param stride
418///    The stride between the rows' data to be stored in memory.
419__DEFAULT_FN_ATTRS_TILE
420static void __tile_stored(void *base, __SIZE_TYPE__ stride, __tile1024i src) {
421  _tile_stored_internal(src.row, src.col, base, stride, src.tile);
422}
423
424/// Zero the tile specified by "dst".
425///
426/// \headerfile <immintrin.h>
427///
428/// This intrinsic corresponds to the <c> TILEZERO </c> instruction.
429///
430/// \param dst
431///    The destination tile to be zero. Max size is 1024 Bytes.
432__DEFAULT_FN_ATTRS_TILE
433static void __tile_zero(__tile1024i *dst) {
434  dst->tile = __builtin_ia32_tilezero_internal(dst->row, dst->col);
435}
436
437/// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles src0 and
438/// src1, accumulating the intermediate single-precision (32-bit) floating-point
439/// elements with elements in "dst", and store the 32-bit result back to tile
440/// "dst".
441///
442/// \headerfile <immintrin.h>
443///
444/// This intrinsic corresponds to the <c> TDPBF16PS </c> instruction.
445///
446/// \param dst
447///    The destination tile. Max size is 1024 Bytes.
448/// \param src0
449///    The 1st source tile. Max size is 1024 Bytes.
450/// \param src1
451///    The 2nd source tile. Max size is 1024 Bytes.
452__DEFAULT_FN_ATTRS_BF16
453static void __tile_dpbf16ps(__tile1024i *dst, __tile1024i src0,
454                            __tile1024i src1) {
455  dst->tile = _tile_dpbf16ps_internal(src0.row, src1.col, src0.col, dst->tile,
456                                      src0.tile, src1.tile);
457}
458
/* Undefine the attribute helper macros defined at the top of this file so
 * they are not visible to code that follows this header. */
#undef __DEFAULT_FN_ATTRS_TILE
#undef __DEFAULT_FN_ATTRS_INT8
#undef __DEFAULT_FN_ATTRS_BF16
462
463#endif /* __x86_64__ */
464#endif /* __AMXINTRIN_H */
465