1//===-- X86InstrFMA3Info.cpp - X86 FMA3 Instruction Information -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the classes providing information
10// about existing X86 FMA3 opcodes, classifying and grouping them.
11//
12//===----------------------------------------------------------------------===//
13
14#include "X86InstrFMA3Info.h"
15#include "X86InstrInfo.h"
16#include "llvm/Support/ManagedStatic.h"
17#include "llvm/Support/Threading.h"
18#include <cassert>
19#include <cstdint>
20
21using namespace llvm;
22
// Expands to one X86InstrFMA3Group initializer (with a trailing comma): the
// X86 opcodes Name132Suf, Name213Suf and Name231Suf, in that order, followed
// by the attribute bits Attrs.
#define FMA3GROUP(Name, Suf, Attrs) \
  { { X86::Name##132##Suf, X86::Name##213##Suf, X86::Name##231##Suf }, Attrs },
25
// Expands to three groups for suffix Suf: the plain form with Attrs, the
// `k`-suffixed form with Attrs | KMergeMasked, and the `kz`-suffixed form with
// Attrs | KZeroMasked.
#define FMA3GROUP_MASKED(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf##k, Attrs | X86InstrFMA3Group::KMergeMasked) \
  FMA3GROUP(Name, Suf##kz, Attrs | X86InstrFMA3Group::KZeroMasked)
30
// Expands to every group for one packed type suffix Suf (PD or PS at the use
// site): unmasked Ym/Yr/m/r forms and masked Z128m/Z128r/Z256m/Z256r/Zm/Zr
// forms. Do not reorder the lines: verifyTables() asserts that the tables
// built from these macros are sorted.
#define FMA3GROUP_PACKED_WIDTHS(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf##Ym, Attrs) \
  FMA3GROUP(Name, Suf##Yr, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z128m, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z128r, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z256m, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z256r, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Zm, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Zr, Attrs) \
  FMA3GROUP(Name, Suf##m, Attrs) \
  FMA3GROUP(Name, Suf##r, Attrs)
42
// Expands to the packed groups of Name for the PD suffix followed by the PS
// suffix.
#define FMA3GROUP_PACKED(Name, Attrs) \
  FMA3GROUP_PACKED_WIDTHS(Name, PD, Attrs) \
  FMA3GROUP_PACKED_WIDTHS(Name, PS, Attrs)
46
// Expands to every group for one scalar type suffix Suf (SD or SS at the use
// site): Zm/Zr/m/r forms plus their `_Int` counterparts, which additionally
// carry the Intrinsic attribute. Only the Zm_Int/Zr_Int forms get the masked
// (k/kz) variants. Do not reorder the lines: verifyTables() asserts that the
// tables built from these macros are sorted.
#define FMA3GROUP_SCALAR_WIDTHS(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf##Zm, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Zm_Int, Attrs | X86InstrFMA3Group::Intrinsic) \
  FMA3GROUP(Name, Suf##Zr, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Zr_Int, Attrs | X86InstrFMA3Group::Intrinsic) \
  FMA3GROUP(Name, Suf##m, Attrs) \
  FMA3GROUP(Name, Suf##m_Int, Attrs | X86InstrFMA3Group::Intrinsic) \
  FMA3GROUP(Name, Suf##r, Attrs) \
  FMA3GROUP(Name, Suf##r_Int, Attrs | X86InstrFMA3Group::Intrinsic)
56
// Expands to the scalar groups of Name for the SD suffix followed by the SS
// suffix.
#define FMA3GROUP_SCALAR(Name, Attrs) \
  FMA3GROUP_SCALAR_WIDTHS(Name, SD, Attrs) \
  FMA3GROUP_SCALAR_WIDTHS(Name, SS, Attrs)
60
// Expands to all packed groups of Name followed by all scalar groups of Name.
#define FMA3GROUP_FULL(Name, Attrs) \
  FMA3GROUP_PACKED(Name, Attrs) \
  FMA3GROUP_SCALAR(Name, Attrs)
64
// Default lookup table: getFMA3Group() searches it when TSFlags has neither
// X86II::EVEX_RC nor X86II::EVEX_B set. VFMADDSUB and VFMSUBADD contribute
// packed groups only; the other operations contribute packed and scalar
// groups. The table must stay sorted (checked by verifyTables()) because it is
// searched with partition_point.
static const X86InstrFMA3Group Groups[] = {
  FMA3GROUP_FULL(VFMADD, 0)
  FMA3GROUP_PACKED(VFMADDSUB, 0)
  FMA3GROUP_FULL(VFMSUB, 0)
  FMA3GROUP_PACKED(VFMSUBADD, 0)
  FMA3GROUP_FULL(VFNMADD, 0)
  FMA3GROUP_FULL(VFNMSUB, 0)
};
73
// Expands to the masked groups (plain, k, kz) of Name for packed type Type at
// the Z128, Z256 and Z widths, each with the trailing suffix Suf.
#define FMA3GROUP_PACKED_AVX512_WIDTHS(Name, Type, Suf, Attrs) \
  FMA3GROUP_MASKED(Name, Type##Z128##Suf, Attrs) \
  FMA3GROUP_MASKED(Name, Type##Z256##Suf, Attrs) \
  FMA3GROUP_MASKED(Name, Type##Z##Suf, Attrs)
78
// Expands to the Z128/Z256/Z masked groups of Name with suffix Suf, for the PD
// type followed by the PS type.
#define FMA3GROUP_PACKED_AVX512(Name, Suf, Attrs) \
  FMA3GROUP_PACKED_AVX512_WIDTHS(Name, PD, Suf, Attrs) \
  FMA3GROUP_PACKED_AVX512_WIDTHS(Name, PS, Suf, Attrs)
82
// Expands to the masked groups (plain, k, kz) of Name for the PDZ and PSZ
// forms with suffix Suf. Unlike FMA3GROUP_PACKED_AVX512, no Z128/Z256 forms
// are generated.
#define FMA3GROUP_PACKED_AVX512_ROUND(Name, Suf, Attrs) \
  FMA3GROUP_MASKED(Name, PDZ##Suf, Attrs) \
  FMA3GROUP_MASKED(Name, PSZ##Suf, Attrs)
86
// Expands to the scalar groups of Name for the SDZ and SSZ forms with suffix
// Suf: one unmasked group per type, followed by the masked (plain, k, kz)
// groups of the corresponding `_Int` form. Attrs is applied unchanged to every
// group; this macro does not add the Intrinsic attribute itself.
#define FMA3GROUP_SCALAR_AVX512_ROUND(Name, Suf, Attrs) \
  FMA3GROUP(Name, SDZ##Suf, Attrs) \
  FMA3GROUP_MASKED(Name, SDZ##Suf##_Int, Attrs) \
  FMA3GROUP(Name, SSZ##Suf, Attrs) \
  FMA3GROUP_MASKED(Name, SSZ##Suf##_Int, Attrs)
92
// Lookup table of the `mb`-suffixed packed opcodes: getFMA3Group() searches it
// when TSFlags has X86II::EVEX_B set and X86II::EVEX_RC clear. The table must
// stay sorted (checked by verifyTables()) because it is searched with
// partition_point.
static const X86InstrFMA3Group BroadcastGroups[] = {
  FMA3GROUP_PACKED_AVX512(VFMADD, mb, 0)
  FMA3GROUP_PACKED_AVX512(VFMADDSUB, mb, 0)
  FMA3GROUP_PACKED_AVX512(VFMSUB, mb, 0)
  FMA3GROUP_PACKED_AVX512(VFMSUBADD, mb, 0)
  FMA3GROUP_PACKED_AVX512(VFNMADD, mb, 0)
  FMA3GROUP_PACKED_AVX512(VFNMSUB, mb, 0)
};
101
// Lookup table of the `rb`-suffixed opcodes: getFMA3Group() searches it when
// TSFlags has X86II::EVEX_RC set. All scalar entries here are tagged
// Intrinsic, including the non-`_Int` forms. VFMADDSUB and VFMSUBADD have
// packed entries only. The table must stay sorted (checked by verifyTables())
// because it is searched with partition_point.
static const X86InstrFMA3Group RoundGroups[] = {
  FMA3GROUP_PACKED_AVX512_ROUND(VFMADD, rb, 0)
  FMA3GROUP_SCALAR_AVX512_ROUND(VFMADD, rb, X86InstrFMA3Group::Intrinsic)
  FMA3GROUP_PACKED_AVX512_ROUND(VFMADDSUB, rb, 0)
  FMA3GROUP_PACKED_AVX512_ROUND(VFMSUB, rb, 0)
  FMA3GROUP_SCALAR_AVX512_ROUND(VFMSUB, rb, X86InstrFMA3Group::Intrinsic)
  FMA3GROUP_PACKED_AVX512_ROUND(VFMSUBADD, rb, 0)
  FMA3GROUP_PACKED_AVX512_ROUND(VFNMADD, rb, 0)
  FMA3GROUP_SCALAR_AVX512_ROUND(VFNMADD, rb, X86InstrFMA3Group::Intrinsic)
  FMA3GROUP_PACKED_AVX512_ROUND(VFNMSUB, rb, 0)
  FMA3GROUP_SCALAR_AVX512_ROUND(VFNMSUB, rb, X86InstrFMA3Group::Intrinsic)
};
114
115static void verifyTables() {
116#ifndef NDEBUG
117  static std::atomic<bool> TableChecked(false);
118  if (!TableChecked.load(std::memory_order_relaxed)) {
119    assert(llvm::is_sorted(Groups) && llvm::is_sorted(RoundGroups) &&
120           llvm::is_sorted(BroadcastGroups) && "FMA3 tables not sorted!");
121    TableChecked.store(true, std::memory_order_relaxed);
122  }
123#endif
124}
125
126/// Returns a reference to a group of FMA3 opcodes to where the given
127/// \p Opcode is included. If the given \p Opcode is not recognized as FMA3
128/// and not included into any FMA3 group, then nullptr is returned.
129const X86InstrFMA3Group *llvm::getFMA3Group(unsigned Opcode, uint64_t TSFlags) {
130
131  // FMA3 instructions have a well defined encoding pattern we can exploit.
132  uint8_t BaseOpcode = X86II::getBaseOpcodeFor(TSFlags);
133  bool IsFMA3 = ((TSFlags & X86II::EncodingMask) == X86II::VEX ||
134                 (TSFlags & X86II::EncodingMask) == X86II::EVEX) &&
135                (TSFlags & X86II::OpMapMask) == X86II::T8 &&
136                (TSFlags & X86II::OpPrefixMask) == X86II::PD &&
137                ((BaseOpcode >= 0x96 && BaseOpcode <= 0x9F) ||
138                 (BaseOpcode >= 0xA6 && BaseOpcode <= 0xAF) ||
139                 (BaseOpcode >= 0xB6 && BaseOpcode <= 0xBF));
140  if (!IsFMA3)
141    return nullptr;
142
143  verifyTables();
144
145  ArrayRef<X86InstrFMA3Group> Table;
146  if (TSFlags & X86II::EVEX_RC)
147    Table = makeArrayRef(RoundGroups);
148  else if (TSFlags & X86II::EVEX_B)
149    Table = makeArrayRef(BroadcastGroups);
150  else
151    Table = makeArrayRef(Groups);
152
153  // FMA 132 instructions have an opcode of 0x96-0x9F
154  // FMA 213 instructions have an opcode of 0xA6-0xAF
155  // FMA 231 instructions have an opcode of 0xB6-0xBF
156  unsigned FormIndex = ((BaseOpcode - 0x90) >> 4) & 0x3;
157
158  auto I = partition_point(Table, [=](const X86InstrFMA3Group &Group) {
159    return Group.Opcodes[FormIndex] < Opcode;
160  });
161  assert(I != Table.end() && I->Opcodes[FormIndex] == Opcode &&
162         "Couldn't find FMA3 opcode!");
163  return I;
164}
165