1/*===--------- avx512vlbf16intrin.h - AVX512_BF16 intrinsics ---------------===
2 *
3 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 * See https://llvm.org/LICENSE.txt for license information.
5 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 *
7 *===-----------------------------------------------------------------------===
8 */
9#ifndef __IMMINTRIN_H
10#error "Never use <avx512vlbf16intrin.h> directly; include <immintrin.h> instead."
11#endif
12
13#ifndef __AVX512VLBF16INTRIN_H
14#define __AVX512VLBF16INTRIN_H
15
/// 128-bit vector type used by this header to carry [8 x bfloat] data. It is
/// declared as a 16-byte vector of short lanes with 16-byte alignment.
typedef short __m128bh
    __attribute__((__vector_size__(16), __aligned__(16)));
17
/* Default attributes for the 128-bit intrinsics in this header: always
 * inlined, no debug info, compiled for the avx512vl and avx512bf16 target
 * features, with a minimum vector width of 128 bits. */
#define __DEFAULT_FN_ATTRS128                                                  \
  __attribute__((__always_inline__, __nodebug__,                               \
                 __target__("avx512vl, avx512bf16"),                           \
                 __min_vector_width__(128)))
/* Same as above for the 256-bit intrinsics (minimum vector width 256). */
#define __DEFAULT_FN_ATTRS256                                                  \
  __attribute__((__always_inline__, __nodebug__,                               \
                 __target__("avx512vl, avx512bf16"),                           \
                 __min_vector_width__(256)))
24
25/// Convert Two Packed Single Data to One Packed BF16 Data.
26///
27/// \headerfile <x86intrin.h>
28///
29/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
30///
31/// \param __A
32///    A 128-bit vector of [4 x float].
33/// \param __B
34///    A 128-bit vector of [4 x float].
35/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
36///    conversion of __B, and higher 64 bits come from conversion of __A.
37static __inline__ __m128bh __DEFAULT_FN_ATTRS128
38_mm_cvtne2ps_pbh(__m128 __A, __m128 __B) {
39  return (__m128bh)__builtin_ia32_cvtne2ps2bf16_128((__v4sf) __A,
40                                                    (__v4sf) __B);
41}
42
43/// Convert Two Packed Single Data to One Packed BF16 Data.
44///
45/// \headerfile <x86intrin.h>
46///
47/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
48///
49/// \param __A
50///    A 128-bit vector of [4 x float].
51/// \param __B
52///    A 128-bit vector of [4 x float].
53/// \param __W
54///    A 128-bit vector of [8 x bfloat].
55/// \param __U
56///    A 8-bit mask value specifying what is chosen for each element.
57///    A 1 means conversion of __A or __B. A 0 means element from __W.
58/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
59///    conversion of __B, and higher 64 bits come from conversion of __A.
60static __inline__ __m128bh __DEFAULT_FN_ATTRS128
61_mm_mask_cvtne2ps_pbh(__m128bh __W, __mmask8 __U, __m128 __A, __m128 __B) {
62  return (__m128bh)__builtin_ia32_selectw_128((__mmask8)__U,
63                                             (__v8hi)_mm_cvtne2ps_pbh(__A, __B),
64                                             (__v8hi)__W);
65}
66
67/// Convert Two Packed Single Data to One Packed BF16 Data.
68///
69/// \headerfile <x86intrin.h>
70///
71/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
72///
73/// \param __A
74///    A 128-bit vector of [4 x float].
75/// \param __B
76///    A 128-bit vector of [4 x float].
77/// \param __U
78///    A 8-bit mask value specifying what is chosen for each element.
79///    A 1 means conversion of __A or __B. A 0 means element is zero.
80/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
81///    conversion of __B, and higher 64 bits come from conversion of __A.
82static __inline__ __m128bh __DEFAULT_FN_ATTRS128
83_mm_maskz_cvtne2ps_pbh(__mmask8 __U, __m128 __A, __m128 __B) {
84  return (__m128bh)__builtin_ia32_selectw_128((__mmask8)__U,
85                                             (__v8hi)_mm_cvtne2ps_pbh(__A, __B),
86                                             (__v8hi)_mm_setzero_si128());
87}
88
89/// Convert Two Packed Single Data to One Packed BF16 Data.
90///
91/// \headerfile <x86intrin.h>
92///
93/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
94///
95/// \param __A
96///    A 256-bit vector of [8 x float].
97/// \param __B
98///    A 256-bit vector of [8 x float].
99/// \returns A 256-bit vector of [16 x bfloat] whose lower 128 bits come from
100///    conversion of __B, and higher 128 bits come from conversion of __A.
101static __inline__ __m256bh __DEFAULT_FN_ATTRS256
102_mm256_cvtne2ps_pbh(__m256 __A, __m256 __B) {
103  return (__m256bh)__builtin_ia32_cvtne2ps2bf16_256((__v8sf) __A,
104                                                    (__v8sf) __B);
105}
106
107/// Convert Two Packed Single Data to One Packed BF16 Data.
108///
109/// \headerfile <x86intrin.h>
110///
111/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
112///
113/// \param __A
114///    A 256-bit vector of [8 x float].
115/// \param __B
116///    A 256-bit vector of [8 x float].
117/// \param __W
118///    A 256-bit vector of [16 x bfloat].
119/// \param __U
120///    A 16-bit mask value specifying what is chosen for each element.
121///    A 1 means conversion of __A or __B. A 0 means element from __W.
122/// \returns A 256-bit vector of [16 x bfloat] whose lower 128 bits come from
123///    conversion of __B, and higher 128 bits come from conversion of __A.
124static __inline__ __m256bh __DEFAULT_FN_ATTRS256
125_mm256_mask_cvtne2ps_pbh(__m256bh __W, __mmask16 __U, __m256 __A, __m256 __B) {
126  return (__m256bh)__builtin_ia32_selectw_256((__mmask16)__U,
127                                         (__v16hi)_mm256_cvtne2ps_pbh(__A, __B),
128                                         (__v16hi)__W);
129}
130
131/// Convert Two Packed Single Data to One Packed BF16 Data.
132///
133/// \headerfile <x86intrin.h>
134///
135/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
136///
137/// \param __A
138///    A 256-bit vector of [8 x float].
139/// \param __B
140///    A 256-bit vector of [8 x float].
141/// \param __U
142///    A 16-bit mask value specifying what is chosen for each element.
143///    A 1 means conversion of __A or __B. A 0 means element is zero.
144/// \returns A 256-bit vector of [16 x bfloat] whose lower 128 bits come from
145///    conversion of __B, and higher 128 bits come from conversion of __A.
146static __inline__ __m256bh __DEFAULT_FN_ATTRS256
147_mm256_maskz_cvtne2ps_pbh(__mmask16 __U, __m256 __A, __m256 __B) {
148  return (__m256bh)__builtin_ia32_selectw_256((__mmask16)__U,
149                                         (__v16hi)_mm256_cvtne2ps_pbh(__A, __B),
150                                         (__v16hi)_mm256_setzero_si256());
151}
152
153/// Convert Packed Single Data to Packed BF16 Data.
154///
155/// \headerfile <x86intrin.h>
156///
157/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
158///
159/// \param __A
160///    A 128-bit vector of [4 x float].
161/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
162///    conversion of __A, and higher 64 bits are 0.
163static __inline__ __m128bh __DEFAULT_FN_ATTRS128
164_mm_cvtneps_pbh(__m128 __A) {
165  return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
166                                                  (__v8hi)_mm_undefined_si128(),
167                                                  (__mmask8)-1);
168}
169
170/// Convert Packed Single Data to Packed BF16 Data.
171///
172/// \headerfile <x86intrin.h>
173///
174/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
175///
176/// \param __A
177///    A 128-bit vector of [4 x float].
178/// \param __W
179///    A 128-bit vector of [8 x bfloat].
180/// \param __U
181///    A 4-bit mask value specifying what is chosen for each element.
182///    A 1 means conversion of __A. A 0 means element from __W.
183/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
184///    conversion of __A, and higher 64 bits are 0.
185static __inline__ __m128bh __DEFAULT_FN_ATTRS128
186_mm_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m128 __A) {
187  return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
188                                                        (__v8hi)__W,
189                                                        (__mmask8)__U);
190}
191
192/// Convert Packed Single Data to Packed BF16 Data.
193///
194/// \headerfile <x86intrin.h>
195///
196/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
197///
198/// \param __A
199///    A 128-bit vector of [4 x float].
200/// \param __U
201///    A 4-bit mask value specifying what is chosen for each element.
202///    A 1 means conversion of __A. A 0 means element is zero.
203/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
204///    conversion of __A, and higher 64 bits are 0.
205static __inline__ __m128bh __DEFAULT_FN_ATTRS128
206_mm_maskz_cvtneps_pbh(__mmask8 __U, __m128 __A) {
207  return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
208                                                    (__v8hi)_mm_setzero_si128(),
209                                                    (__mmask8)__U);
210}
211
212/// Convert Packed Single Data to Packed BF16 Data.
213///
214/// \headerfile <x86intrin.h>
215///
216/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
217///
218/// \param __A
219///    A 256-bit vector of [8 x float].
220/// \returns A 128-bit vector of [8 x bfloat] comes from conversion of __A.
221static __inline__ __m128bh __DEFAULT_FN_ATTRS256
222_mm256_cvtneps_pbh(__m256 __A) {
223  return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
224                                                  (__v8hi)_mm_undefined_si128(),
225                                                  (__mmask8)-1);
226}
227
228/// Convert Packed Single Data to Packed BF16 Data.
229///
230/// \headerfile <x86intrin.h>
231///
232/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
233///
234/// \param __A
235///    A 256-bit vector of [8 x float].
236/// \param __W
237///    A 256-bit vector of [8 x bfloat].
238/// \param __U
239///    A 8-bit mask value specifying what is chosen for each element.
240///    A 1 means conversion of __A. A 0 means element from __W.
241/// \returns A 128-bit vector of [8 x bfloat] comes from conversion of __A.
242static __inline__ __m128bh __DEFAULT_FN_ATTRS256
243_mm256_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m256 __A) {
244  return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
245                                                        (__v8hi)__W,
246                                                        (__mmask8)__U);
247}
248
249/// Convert Packed Single Data to Packed BF16 Data.
250///
251/// \headerfile <x86intrin.h>
252///
253/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
254///
255/// \param __A
256///    A 256-bit vector of [8 x float].
257/// \param __U
258///    A 8-bit mask value specifying what is chosen for each element.
259///    A 1 means conversion of __A. A 0 means element is zero.
260/// \returns A 128-bit vector of [8 x bfloat] comes from conversion of __A.
261static __inline__ __m128bh __DEFAULT_FN_ATTRS256
262_mm256_maskz_cvtneps_pbh(__mmask8 __U, __m256 __A) {
263  return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
264                                                    (__v8hi)_mm_setzero_si128(),
265                                                    (__mmask8)__U);
266}
267
268/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
269///
270/// \headerfile <x86intrin.h>
271///
272/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
273///
274/// \param __A
275///    A 128-bit vector of [8 x bfloat].
276/// \param __B
277///    A 128-bit vector of [8 x bfloat].
278/// \param __D
279///    A 128-bit vector of [4 x float].
280/// \returns A 128-bit vector of [4 x float] comes from  Dot Product of
281///  __A, __B and __D
282static __inline__ __m128 __DEFAULT_FN_ATTRS128
283_mm_dpbf16_ps(__m128 __D, __m128bh __A, __m128bh __B) {
284  return (__m128)__builtin_ia32_dpbf16ps_128((__v4sf)__D,
285                                             (__v4si)__A,
286                                             (__v4si)__B);
287}
288
289/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
290///
291/// \headerfile <x86intrin.h>
292///
293/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
294///
295/// \param __A
296///    A 128-bit vector of [8 x bfloat].
297/// \param __B
298///    A 128-bit vector of [8 x bfloat].
299/// \param __D
300///    A 128-bit vector of [4 x float].
301/// \param __U
302///    A 8-bit mask value specifying what is chosen for each element.
303///    A 1 means __A and __B's dot product accumulated with __D. A 0 means __D.
304/// \returns A 128-bit vector of [4 x float] comes from  Dot Product of
305///  __A, __B and __D
306static __inline__ __m128 __DEFAULT_FN_ATTRS128
307_mm_mask_dpbf16_ps(__m128 __D, __mmask8 __U, __m128bh __A, __m128bh __B) {
308  return (__m128)__builtin_ia32_selectps_128((__mmask8)__U,
309                                           (__v4sf)_mm_dpbf16_ps(__D, __A, __B),
310                                           (__v4sf)__D);
311}
312
313/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
314///
315/// \headerfile <x86intrin.h>
316///
317/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
318///
319/// \param __A
320///    A 128-bit vector of [8 x bfloat].
321/// \param __B
322///    A 128-bit vector of [8 x bfloat].
323/// \param __D
324///    A 128-bit vector of [4 x float].
325/// \param __U
326///    A 8-bit mask value specifying what is chosen for each element.
327///    A 1 means __A and __B's dot product accumulated with __D. A 0 means 0.
328/// \returns A 128-bit vector of [4 x float] comes from  Dot Product of
329///  __A, __B and __D
330static __inline__ __m128 __DEFAULT_FN_ATTRS128
331_mm_maskz_dpbf16_ps(__mmask8 __U, __m128 __D, __m128bh __A, __m128bh __B) {
332  return (__m128)__builtin_ia32_selectps_128((__mmask8)__U,
333                                           (__v4sf)_mm_dpbf16_ps(__D, __A, __B),
334                                           (__v4sf)_mm_setzero_si128());
335}
336
337/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
338///
339/// \headerfile <x86intrin.h>
340///
341/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
342///
343/// \param __A
344///    A 256-bit vector of [16 x bfloat].
345/// \param __B
346///    A 256-bit vector of [16 x bfloat].
347/// \param __D
348///    A 256-bit vector of [8 x float].
349/// \returns A 256-bit vector of [8 x float] comes from  Dot Product of
350///  __A, __B and __D
351static __inline__ __m256 __DEFAULT_FN_ATTRS256
352_mm256_dpbf16_ps(__m256 __D, __m256bh __A, __m256bh __B) {
353  return (__m256)__builtin_ia32_dpbf16ps_256((__v8sf)__D,
354                                             (__v8si)__A,
355                                             (__v8si)__B);
356}
357
358/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
359///
360/// \headerfile <x86intrin.h>
361///
362/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
363///
364/// \param __A
365///    A 256-bit vector of [16 x bfloat].
366/// \param __B
367///    A 256-bit vector of [16 x bfloat].
368/// \param __D
369///    A 256-bit vector of [8 x float].
370/// \param __U
371///    A 16-bit mask value specifying what is chosen for each element.
372///    A 1 means __A and __B's dot product accumulated with __D. A 0 means __D.
373/// \returns A 256-bit vector of [8 x float] comes from  Dot Product of
374///  __A, __B and __D
375static __inline__ __m256 __DEFAULT_FN_ATTRS256
376_mm256_mask_dpbf16_ps(__m256 __D, __mmask8 __U, __m256bh __A, __m256bh __B) {
377  return (__m256)__builtin_ia32_selectps_256((__mmask8)__U,
378                                        (__v8sf)_mm256_dpbf16_ps(__D, __A, __B),
379                                        (__v8sf)__D);
380}
381
382/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
383///
384/// \headerfile <x86intrin.h>
385///
386/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
387///
388/// \param __A
389///    A 256-bit vector of [16 x bfloat].
390/// \param __B
391///    A 256-bit vector of [16 x bfloat].
392/// \param __D
393///    A 256-bit vector of [8 x float].
394/// \param __U
395///    A 8-bit mask value specifying what is chosen for each element.
396///    A 1 means __A and __B's dot product accumulated with __D. A 0 means 0.
397/// \returns A 256-bit vector of [8 x float] comes from  Dot Product of
398///  __A, __B and __D
399static __inline__ __m256 __DEFAULT_FN_ATTRS256
400_mm256_maskz_dpbf16_ps(__mmask8 __U, __m256 __D, __m256bh __A, __m256bh __B) {
401  return (__m256)__builtin_ia32_selectps_256((__mmask8)__U,
402                                        (__v8sf)_mm256_dpbf16_ps(__D, __A, __B),
403                                        (__v8sf)_mm256_setzero_si256());
404}
405
406/// Convert One Single float Data to One BF16 Data.
407///
408/// \headerfile <x86intrin.h>
409///
410/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
411///
412/// \param __A
413///    A float data.
414/// \returns A bf16 data whose sign field and exponent field keep unchanged,
415///    and fraction field is truncated to 7 bits.
416static __inline__ __bfloat16 __DEFAULT_FN_ATTRS128 _mm_cvtness_sbh(float __A) {
417  __v4sf __V = {__A, 0, 0, 0};
418  __v8hi __R = __builtin_ia32_cvtneps2bf16_128_mask(
419      (__v4sf)__V, (__v8hi)_mm_undefined_si128(), (__mmask8)-1);
420  return __R[0];
421}
422
423/// Convert Packed BF16 Data to Packed float Data.
424///
425/// \headerfile <x86intrin.h>
426///
427/// \param __A
428///    A 128-bit vector of [8 x bfloat].
429/// \returns A 256-bit vector of [8 x float] come from convertion of __A
430static __inline__ __m256 __DEFAULT_FN_ATTRS256 _mm256_cvtpbh_ps(__m128bh __A) {
431  return _mm256_castsi256_ps((__m256i)_mm256_slli_epi32(
432      (__m256i)_mm256_cvtepi16_epi32((__m128i)__A), 16));
433}
434
435/// Convert Packed BF16 Data to Packed float Data using zeroing mask.
436///
437/// \headerfile <x86intrin.h>
438///
439/// \param __U
440///    A 8-bit mask. Elements are zeroed out when the corresponding mask
441///    bit is not set.
442/// \param __A
443///    A 128-bit vector of [8 x bfloat].
444/// \returns A 256-bit vector of [8 x float] come from convertion of __A
445static __inline__ __m256 __DEFAULT_FN_ATTRS256
446_mm256_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
447  return _mm256_castsi256_ps((__m256i)_mm256_slli_epi32(
448      (__m256i)_mm256_maskz_cvtepi16_epi32((__mmask8)__U, (__m128i)__A), 16));
449}
450
451/// Convert Packed BF16 Data to Packed float Data using merging mask.
452///
453/// \headerfile <x86intrin.h>
454///
455/// \param __S
456///    A 256-bit vector of [8 x float]. Elements are copied from __S when
457///     the corresponding mask bit is not set.
458/// \param __U
459///    A 8-bit mask. Elements are zeroed out when the corresponding mask
460///    bit is not set.
461/// \param __A
462///    A 128-bit vector of [8 x bfloat].
463/// \returns A 256-bit vector of [8 x float] come from convertion of __A
464static __inline__ __m256 __DEFAULT_FN_ATTRS256
465_mm256_mask_cvtpbh_ps(__m256 __S, __mmask8 __U, __m128bh __A) {
466  return _mm256_castsi256_ps((__m256i)_mm256_mask_slli_epi32(
467      (__m256i)__S, (__mmask8)__U, (__m256i)_mm256_cvtepi16_epi32((__m128i)__A),
468      16));
469}
470
/* The attribute helper macros are only needed inside this header. */
#undef __DEFAULT_FN_ATTRS256
#undef __DEFAULT_FN_ATTRS128
473
474#endif
475