1/*===--------- avx512vlbf16intrin.h - AVX512_BF16 intrinsics ---------------===
2 *
3 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 * See https://llvm.org/LICENSE.txt for license information.
5 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 *
7 *===-----------------------------------------------------------------------===
8 */
9#ifndef __IMMINTRIN_H
10#error "Never use <avx512vlbf16intrin.h> directly; include <immintrin.h> instead."
11#endif
12
13#ifdef __SSE2__
14
15#ifndef __AVX512VLBF16INTRIN_H
16#define __AVX512VLBF16INTRIN_H
17
/* Attribute bundles shared by every intrinsic in this header: always inline,
 * omit from debug info, compile the body with the "avx512vl" and
 * "avx512bf16" target features enabled, and carry a __min_vector_width__
 * of 128 or 256 bits for the 128-bit and 256-bit variants respectively. */
#define __DEFAULT_FN_ATTRS128 \
  __attribute__((__always_inline__, __nodebug__, \
                 __target__("avx512vl, avx512bf16"), __min_vector_width__(128)))
#define __DEFAULT_FN_ATTRS256 \
  __attribute__((__always_inline__, __nodebug__, \
                 __target__("avx512vl, avx512bf16"), __min_vector_width__(256)))
24
25/// Convert Two Packed Single Data to One Packed BF16 Data.
26///
27/// \headerfile <x86intrin.h>
28///
29/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
30///
31/// \param __A
32///    A 128-bit vector of [4 x float].
33/// \param __B
34///    A 128-bit vector of [4 x float].
35/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
36///    conversion of __B, and higher 64 bits come from conversion of __A.
37static __inline__ __m128bh __DEFAULT_FN_ATTRS128
38_mm_cvtne2ps_pbh(__m128 __A, __m128 __B) {
39  return (__m128bh)__builtin_ia32_cvtne2ps2bf16_128((__v4sf) __A,
40                                                    (__v4sf) __B);
41}
42
43/// Convert Two Packed Single Data to One Packed BF16 Data.
44///
45/// \headerfile <x86intrin.h>
46///
47/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
48///
49/// \param __A
50///    A 128-bit vector of [4 x float].
51/// \param __B
52///    A 128-bit vector of [4 x float].
53/// \param __W
54///    A 128-bit vector of [8 x bfloat].
55/// \param __U
56///    A 8-bit mask value specifying what is chosen for each element.
57///    A 1 means conversion of __A or __B. A 0 means element from __W.
58/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
59///    conversion of __B, and higher 64 bits come from conversion of __A.
60static __inline__ __m128bh __DEFAULT_FN_ATTRS128
61_mm_mask_cvtne2ps_pbh(__m128bh __W, __mmask8 __U, __m128 __A, __m128 __B) {
62  return (__m128bh)__builtin_ia32_selectpbf_128((__mmask8)__U,
63                                             (__v8bf)_mm_cvtne2ps_pbh(__A, __B),
64                                             (__v8bf)__W);
65}
66
67/// Convert Two Packed Single Data to One Packed BF16 Data.
68///
69/// \headerfile <x86intrin.h>
70///
71/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
72///
73/// \param __A
74///    A 128-bit vector of [4 x float].
75/// \param __B
76///    A 128-bit vector of [4 x float].
77/// \param __U
78///    A 8-bit mask value specifying what is chosen for each element.
79///    A 1 means conversion of __A or __B. A 0 means element is zero.
80/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
81///    conversion of __B, and higher 64 bits come from conversion of __A.
82static __inline__ __m128bh __DEFAULT_FN_ATTRS128
83_mm_maskz_cvtne2ps_pbh(__mmask8 __U, __m128 __A, __m128 __B) {
84  return (__m128bh)__builtin_ia32_selectpbf_128((__mmask8)__U,
85                                             (__v8bf)_mm_cvtne2ps_pbh(__A, __B),
86                                             (__v8bf)_mm_setzero_si128());
87}
88
89/// Convert Two Packed Single Data to One Packed BF16 Data.
90///
91/// \headerfile <x86intrin.h>
92///
93/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
94///
95/// \param __A
96///    A 256-bit vector of [8 x float].
97/// \param __B
98///    A 256-bit vector of [8 x float].
99/// \returns A 256-bit vector of [16 x bfloat] whose lower 128 bits come from
100///    conversion of __B, and higher 128 bits come from conversion of __A.
101static __inline__ __m256bh __DEFAULT_FN_ATTRS256
102_mm256_cvtne2ps_pbh(__m256 __A, __m256 __B) {
103  return (__m256bh)__builtin_ia32_cvtne2ps2bf16_256((__v8sf) __A,
104                                                    (__v8sf) __B);
105}
106
107/// Convert Two Packed Single Data to One Packed BF16 Data.
108///
109/// \headerfile <x86intrin.h>
110///
111/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
112///
113/// \param __A
114///    A 256-bit vector of [8 x float].
115/// \param __B
116///    A 256-bit vector of [8 x float].
117/// \param __W
118///    A 256-bit vector of [16 x bfloat].
119/// \param __U
120///    A 16-bit mask value specifying what is chosen for each element.
121///    A 1 means conversion of __A or __B. A 0 means element from __W.
122/// \returns A 256-bit vector of [16 x bfloat] whose lower 128 bits come from
123///    conversion of __B, and higher 128 bits come from conversion of __A.
124static __inline__ __m256bh __DEFAULT_FN_ATTRS256
125_mm256_mask_cvtne2ps_pbh(__m256bh __W, __mmask16 __U, __m256 __A, __m256 __B) {
126  return (__m256bh)__builtin_ia32_selectpbf_256((__mmask16)__U,
127                                         (__v16bf)_mm256_cvtne2ps_pbh(__A, __B),
128                                         (__v16bf)__W);
129}
130
131/// Convert Two Packed Single Data to One Packed BF16 Data.
132///
133/// \headerfile <x86intrin.h>
134///
135/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
136///
137/// \param __A
138///    A 256-bit vector of [8 x float].
139/// \param __B
140///    A 256-bit vector of [8 x float].
141/// \param __U
142///    A 16-bit mask value specifying what is chosen for each element.
143///    A 1 means conversion of __A or __B. A 0 means element is zero.
144/// \returns A 256-bit vector of [16 x bfloat] whose lower 128 bits come from
145///    conversion of __B, and higher 128 bits come from conversion of __A.
146static __inline__ __m256bh __DEFAULT_FN_ATTRS256
147_mm256_maskz_cvtne2ps_pbh(__mmask16 __U, __m256 __A, __m256 __B) {
148  return (__m256bh)__builtin_ia32_selectpbf_256((__mmask16)__U,
149                                         (__v16bf)_mm256_cvtne2ps_pbh(__A, __B),
150                                         (__v16bf)_mm256_setzero_si256());
151}
152
/// Convert Packed Single Data to Packed BF16 Data.
///
/// \headerfile <x86intrin.h>
///
/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instruction.
/// It is defined as a function-like macro; \a A is evaluated exactly once.
///
/// \param A
///    A 128-bit vector of [4 x float].
/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
///    conversion of \a A, and higher 64 bits are 0.
#define _mm_cvtneps_pbh(A)                                                     \
  ((__m128bh)__builtin_ia32_vcvtneps2bf16128((__v4sf)(A)))
165
166/// Convert Packed Single Data to Packed BF16 Data.
167///
168/// \headerfile <x86intrin.h>
169///
170/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
171///
172/// \param __A
173///    A 128-bit vector of [4 x float].
174/// \param __W
175///    A 128-bit vector of [8 x bfloat].
176/// \param __U
177///    A 4-bit mask value specifying what is chosen for each element.
178///    A 1 means conversion of __A. A 0 means element from __W.
179/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
180///    conversion of __A, and higher 64 bits are 0.
181static __inline__ __m128bh __DEFAULT_FN_ATTRS128
182_mm_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m128 __A) {
183  return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
184                                                        (__v8bf)__W,
185                                                        (__mmask8)__U);
186}
187
188/// Convert Packed Single Data to Packed BF16 Data.
189///
190/// \headerfile <x86intrin.h>
191///
192/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
193///
194/// \param __A
195///    A 128-bit vector of [4 x float].
196/// \param __U
197///    A 4-bit mask value specifying what is chosen for each element.
198///    A 1 means conversion of __A. A 0 means element is zero.
199/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
200///    conversion of __A, and higher 64 bits are 0.
201static __inline__ __m128bh __DEFAULT_FN_ATTRS128
202_mm_maskz_cvtneps_pbh(__mmask8 __U, __m128 __A) {
203  return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
204                                                    (__v8bf)_mm_setzero_si128(),
205                                                    (__mmask8)__U);
206}
207
/// Convert Packed Single Data to Packed BF16 Data.
///
/// \headerfile <x86intrin.h>
///
/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instruction.
/// It is defined as a function-like macro; \a A is evaluated exactly once.
///
/// \param A
///    A 256-bit vector of [8 x float].
/// \returns A 128-bit vector of [8 x bfloat] holding the conversion of
///    \a A.
#define _mm256_cvtneps_pbh(A)                                                  \
  ((__m128bh)__builtin_ia32_vcvtneps2bf16256((__v8sf)(A)))
219
220/// Convert Packed Single Data to Packed BF16 Data.
221///
222/// \headerfile <x86intrin.h>
223///
224/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
225///
226/// \param __A
227///    A 256-bit vector of [8 x float].
228/// \param __W
229///    A 256-bit vector of [8 x bfloat].
230/// \param __U
231///    A 8-bit mask value specifying what is chosen for each element.
232///    A 1 means conversion of __A. A 0 means element from __W.
233/// \returns A 128-bit vector of [8 x bfloat] comes from conversion of __A.
234static __inline__ __m128bh __DEFAULT_FN_ATTRS256
235_mm256_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m256 __A) {
236  return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
237                                                        (__v8bf)__W,
238                                                        (__mmask8)__U);
239}
240
241/// Convert Packed Single Data to Packed BF16 Data.
242///
243/// \headerfile <x86intrin.h>
244///
245/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
246///
247/// \param __A
248///    A 256-bit vector of [8 x float].
249/// \param __U
250///    A 8-bit mask value specifying what is chosen for each element.
251///    A 1 means conversion of __A. A 0 means element is zero.
252/// \returns A 128-bit vector of [8 x bfloat] comes from conversion of __A.
253static __inline__ __m128bh __DEFAULT_FN_ATTRS256
254_mm256_maskz_cvtneps_pbh(__mmask8 __U, __m256 __A) {
255  return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
256                                                    (__v8bf)_mm_setzero_si128(),
257                                                    (__mmask8)__U);
258}
259
260/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
261///
262/// \headerfile <x86intrin.h>
263///
264/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
265///
266/// \param __A
267///    A 128-bit vector of [8 x bfloat].
268/// \param __B
269///    A 128-bit vector of [8 x bfloat].
270/// \param __D
271///    A 128-bit vector of [4 x float].
272/// \returns A 128-bit vector of [4 x float] comes from  Dot Product of
273///  __A, __B and __D
274static __inline__ __m128 __DEFAULT_FN_ATTRS128
275_mm_dpbf16_ps(__m128 __D, __m128bh __A, __m128bh __B) {
276  return (__m128)__builtin_ia32_dpbf16ps_128((__v4sf)__D,
277                                             (__v8bf)__A,
278                                             (__v8bf)__B);
279}
280
281/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
282///
283/// \headerfile <x86intrin.h>
284///
285/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
286///
287/// \param __A
288///    A 128-bit vector of [8 x bfloat].
289/// \param __B
290///    A 128-bit vector of [8 x bfloat].
291/// \param __D
292///    A 128-bit vector of [4 x float].
293/// \param __U
294///    A 8-bit mask value specifying what is chosen for each element.
295///    A 1 means __A and __B's dot product accumulated with __D. A 0 means __D.
296/// \returns A 128-bit vector of [4 x float] comes from  Dot Product of
297///  __A, __B and __D
298static __inline__ __m128 __DEFAULT_FN_ATTRS128
299_mm_mask_dpbf16_ps(__m128 __D, __mmask8 __U, __m128bh __A, __m128bh __B) {
300  return (__m128)__builtin_ia32_selectps_128((__mmask8)__U,
301                                           (__v4sf)_mm_dpbf16_ps(__D, __A, __B),
302                                           (__v4sf)__D);
303}
304
305/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
306///
307/// \headerfile <x86intrin.h>
308///
309/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
310///
311/// \param __A
312///    A 128-bit vector of [8 x bfloat].
313/// \param __B
314///    A 128-bit vector of [8 x bfloat].
315/// \param __D
316///    A 128-bit vector of [4 x float].
317/// \param __U
318///    A 8-bit mask value specifying what is chosen for each element.
319///    A 1 means __A and __B's dot product accumulated with __D. A 0 means 0.
320/// \returns A 128-bit vector of [4 x float] comes from  Dot Product of
321///  __A, __B and __D
322static __inline__ __m128 __DEFAULT_FN_ATTRS128
323_mm_maskz_dpbf16_ps(__mmask8 __U, __m128 __D, __m128bh __A, __m128bh __B) {
324  return (__m128)__builtin_ia32_selectps_128((__mmask8)__U,
325                                           (__v4sf)_mm_dpbf16_ps(__D, __A, __B),
326                                           (__v4sf)_mm_setzero_si128());
327}
328
329/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
330///
331/// \headerfile <x86intrin.h>
332///
333/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
334///
335/// \param __A
336///    A 256-bit vector of [16 x bfloat].
337/// \param __B
338///    A 256-bit vector of [16 x bfloat].
339/// \param __D
340///    A 256-bit vector of [8 x float].
341/// \returns A 256-bit vector of [8 x float] comes from  Dot Product of
342///  __A, __B and __D
343static __inline__ __m256 __DEFAULT_FN_ATTRS256
344_mm256_dpbf16_ps(__m256 __D, __m256bh __A, __m256bh __B) {
345  return (__m256)__builtin_ia32_dpbf16ps_256((__v8sf)__D,
346                                             (__v16bf)__A,
347                                             (__v16bf)__B);
348}
349
350/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
351///
352/// \headerfile <x86intrin.h>
353///
354/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
355///
356/// \param __A
357///    A 256-bit vector of [16 x bfloat].
358/// \param __B
359///    A 256-bit vector of [16 x bfloat].
360/// \param __D
361///    A 256-bit vector of [8 x float].
362/// \param __U
363///    A 16-bit mask value specifying what is chosen for each element.
364///    A 1 means __A and __B's dot product accumulated with __D. A 0 means __D.
365/// \returns A 256-bit vector of [8 x float] comes from  Dot Product of
366///  __A, __B and __D
367static __inline__ __m256 __DEFAULT_FN_ATTRS256
368_mm256_mask_dpbf16_ps(__m256 __D, __mmask8 __U, __m256bh __A, __m256bh __B) {
369  return (__m256)__builtin_ia32_selectps_256((__mmask8)__U,
370                                        (__v8sf)_mm256_dpbf16_ps(__D, __A, __B),
371                                        (__v8sf)__D);
372}
373
374/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
375///
376/// \headerfile <x86intrin.h>
377///
378/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
379///
380/// \param __A
381///    A 256-bit vector of [16 x bfloat].
382/// \param __B
383///    A 256-bit vector of [16 x bfloat].
384/// \param __D
385///    A 256-bit vector of [8 x float].
386/// \param __U
387///    A 8-bit mask value specifying what is chosen for each element.
388///    A 1 means __A and __B's dot product accumulated with __D. A 0 means 0.
389/// \returns A 256-bit vector of [8 x float] comes from  Dot Product of
390///  __A, __B and __D
391static __inline__ __m256 __DEFAULT_FN_ATTRS256
392_mm256_maskz_dpbf16_ps(__mmask8 __U, __m256 __D, __m256bh __A, __m256bh __B) {
393  return (__m256)__builtin_ia32_selectps_256((__mmask8)__U,
394                                        (__v8sf)_mm256_dpbf16_ps(__D, __A, __B),
395                                        (__v8sf)_mm256_setzero_si256());
396}
397
398/// Convert One Single float Data to One BF16 Data.
399///
400/// \headerfile <x86intrin.h>
401///
402/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
403///
404/// \param __A
405///    A float data.
406/// \returns A bf16 data whose sign field and exponent field keep unchanged,
407///    and fraction field is truncated to 7 bits.
408static __inline__ __bf16 __DEFAULT_FN_ATTRS128 _mm_cvtness_sbh(float __A) {
409  __v4sf __V = {__A, 0, 0, 0};
410  __v8bf __R = __builtin_ia32_cvtneps2bf16_128_mask(
411      (__v4sf)__V, (__v8bf)_mm_undefined_si128(), (__mmask8)-1);
412  return (__bf16)__R[0];
413}
414
415/// Convert Packed BF16 Data to Packed float Data.
416///
417/// \headerfile <x86intrin.h>
418///
419/// \param __A
420///    A 128-bit vector of [4 x bfloat].
421/// \returns A 128-bit vector of [4 x float] come from conversion of __A
422static __inline__ __m128 __DEFAULT_FN_ATTRS128 _mm_cvtpbh_ps(__m128bh __A) {
423  return _mm_castsi128_ps(
424      (__m128i)_mm_slli_epi32((__m128i)_mm_cvtepi16_epi32((__m128i)__A), 16));
425}
426
427/// Convert Packed BF16 Data to Packed float Data.
428///
429/// \headerfile <x86intrin.h>
430///
431/// \param __A
432///    A 128-bit vector of [8 x bfloat].
433/// \returns A 256-bit vector of [8 x float] come from conversion of __A
434static __inline__ __m256 __DEFAULT_FN_ATTRS256 _mm256_cvtpbh_ps(__m128bh __A) {
435  return _mm256_castsi256_ps((__m256i)_mm256_slli_epi32(
436      (__m256i)_mm256_cvtepi16_epi32((__m128i)__A), 16));
437}
438
439/// Convert Packed BF16 Data to Packed float Data using zeroing mask.
440///
441/// \headerfile <x86intrin.h>
442///
443/// \param __U
444///    A 4-bit mask. Elements are zeroed out when the corresponding mask
445///    bit is not set.
446/// \param __A
447///    A 128-bit vector of [4 x bfloat].
448/// \returns A 128-bit vector of [4 x float] come from conversion of __A
449static __inline__ __m128 __DEFAULT_FN_ATTRS128
450_mm_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
451  return _mm_castsi128_ps((__m128i)_mm_slli_epi32(
452      (__m128i)_mm_maskz_cvtepi16_epi32((__mmask8)__U, (__m128i)__A), 16));
453}
454
455/// Convert Packed BF16 Data to Packed float Data using zeroing mask.
456///
457/// \headerfile <x86intrin.h>
458///
459/// \param __U
460///    A 8-bit mask. Elements are zeroed out when the corresponding mask
461///    bit is not set.
462/// \param __A
463///    A 128-bit vector of [8 x bfloat].
464/// \returns A 256-bit vector of [8 x float] come from conversion of __A
465static __inline__ __m256 __DEFAULT_FN_ATTRS256
466_mm256_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
467  return _mm256_castsi256_ps((__m256i)_mm256_slli_epi32(
468      (__m256i)_mm256_maskz_cvtepi16_epi32((__mmask8)__U, (__m128i)__A), 16));
469}
470
471/// Convert Packed BF16 Data to Packed float Data using merging mask.
472///
473/// \headerfile <x86intrin.h>
474///
475/// \param __S
476///    A 128-bit vector of [4 x float]. Elements are copied from __S when
477///     the corresponding mask bit is not set.
478/// \param __U
479///    A 4-bit mask. Elements are zeroed out when the corresponding mask
480///    bit is not set.
481/// \param __A
482///    A 128-bit vector of [4 x bfloat].
483/// \returns A 128-bit vector of [4 x float] come from conversion of __A
484static __inline__ __m128 __DEFAULT_FN_ATTRS128
485_mm_mask_cvtpbh_ps(__m128 __S, __mmask8 __U, __m128bh __A) {
486  return _mm_castsi128_ps((__m128i)_mm_mask_slli_epi32(
487      (__m128i)__S, (__mmask8)__U, (__m128i)_mm_cvtepi16_epi32((__m128i)__A),
488      16));
489}
490
491/// Convert Packed BF16 Data to Packed float Data using merging mask.
492///
493/// \headerfile <x86intrin.h>
494///
495/// \param __S
496///    A 256-bit vector of [8 x float]. Elements are copied from __S when
497///     the corresponding mask bit is not set.
498/// \param __U
499///    A 8-bit mask. Elements are zeroed out when the corresponding mask
500///    bit is not set.
501/// \param __A
502///    A 128-bit vector of [8 x bfloat].
503/// \returns A 256-bit vector of [8 x float] come from conversion of __A
504static __inline__ __m256 __DEFAULT_FN_ATTRS256
505_mm256_mask_cvtpbh_ps(__m256 __S, __mmask8 __U, __m128bh __A) {
506  return _mm256_castsi256_ps((__m256i)_mm256_mask_slli_epi32(
507      (__m256i)__S, (__mmask8)__U, (__m256i)_mm256_cvtepi16_epi32((__m128i)__A),
508      16));
509}
510
/* Drop the header-local attribute macros so they do not leak into code that
 * includes this header. */
#undef __DEFAULT_FN_ATTRS128
#undef __DEFAULT_FN_ATTRS256
513
514#endif
515#endif
516