AMDGPUMetadataVerifier.cpp revision 344779
//===- AMDGPUMetadataVerifier.cpp - HSA metadata verifier -------*- C++ -*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10/// \file
11/// Implements a verifier for AMDGPU HSA metadata.
12//
13//===----------------------------------------------------------------------===//
14
15#include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
16#include "llvm/Support/AMDGPUMetadata.h"
17
18namespace llvm {
19namespace AMDGPU {
20namespace HSAMD {
21namespace V3 {
22
23bool MetadataVerifier::verifyScalar(
24    msgpack::Node &Node, msgpack::ScalarNode::ScalarKind SKind,
25    function_ref<bool(msgpack::ScalarNode &)> verifyValue) {
26  auto ScalarPtr = dyn_cast<msgpack::ScalarNode>(&Node);
27  if (!ScalarPtr)
28    return false;
29  auto &Scalar = *ScalarPtr;
30  // Do not output extraneous tags for types we know from the spec.
31  Scalar.IgnoreTag = true;
32  if (Scalar.getScalarKind() != SKind) {
33    if (Strict)
34      return false;
35    // If we are not strict, we interpret string values as "implicitly typed"
36    // and attempt to coerce them to the expected type here.
37    if (Scalar.getScalarKind() != msgpack::ScalarNode::SK_String)
38      return false;
39    std::string StringValue = Scalar.getString();
40    Scalar.setScalarKind(SKind);
41    if (Scalar.inputYAML(StringValue) != StringRef())
42      return false;
43  }
44  if (verifyValue)
45    return verifyValue(Scalar);
46  return true;
47}
48
49bool MetadataVerifier::verifyInteger(msgpack::Node &Node) {
50  if (!verifyScalar(Node, msgpack::ScalarNode::SK_UInt))
51    if (!verifyScalar(Node, msgpack::ScalarNode::SK_Int))
52      return false;
53  return true;
54}
55
56bool MetadataVerifier::verifyArray(
57    msgpack::Node &Node, function_ref<bool(msgpack::Node &)> verifyNode,
58    Optional<size_t> Size) {
59  auto ArrayPtr = dyn_cast<msgpack::ArrayNode>(&Node);
60  if (!ArrayPtr)
61    return false;
62  auto &Array = *ArrayPtr;
63  if (Size && Array.size() != *Size)
64    return false;
65  for (auto &Item : Array)
66    if (!verifyNode(*Item.get()))
67      return false;
68
69  return true;
70}
71
72bool MetadataVerifier::verifyEntry(
73    msgpack::MapNode &MapNode, StringRef Key, bool Required,
74    function_ref<bool(msgpack::Node &)> verifyNode) {
75  auto Entry = MapNode.find(Key);
76  if (Entry == MapNode.end())
77    return !Required;
78  return verifyNode(*Entry->second.get());
79}
80
81bool MetadataVerifier::verifyScalarEntry(
82    msgpack::MapNode &MapNode, StringRef Key, bool Required,
83    msgpack::ScalarNode::ScalarKind SKind,
84    function_ref<bool(msgpack::ScalarNode &)> verifyValue) {
85  return verifyEntry(MapNode, Key, Required, [=](msgpack::Node &Node) {
86    return verifyScalar(Node, SKind, verifyValue);
87  });
88}
89
90bool MetadataVerifier::verifyIntegerEntry(msgpack::MapNode &MapNode,
91                                          StringRef Key, bool Required) {
92  return verifyEntry(MapNode, Key, Required, [this](msgpack::Node &Node) {
93    return verifyInteger(Node);
94  });
95}
96
97bool MetadataVerifier::verifyKernelArgs(msgpack::Node &Node) {
98  auto ArgsMapPtr = dyn_cast<msgpack::MapNode>(&Node);
99  if (!ArgsMapPtr)
100    return false;
101  auto &ArgsMap = *ArgsMapPtr;
102
103  if (!verifyScalarEntry(ArgsMap, ".name", false,
104                         msgpack::ScalarNode::SK_String))
105    return false;
106  if (!verifyScalarEntry(ArgsMap, ".type_name", false,
107                         msgpack::ScalarNode::SK_String))
108    return false;
109  if (!verifyIntegerEntry(ArgsMap, ".size", true))
110    return false;
111  if (!verifyIntegerEntry(ArgsMap, ".offset", true))
112    return false;
113  if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
114                         msgpack::ScalarNode::SK_String,
115                         [](msgpack::ScalarNode &SNode) {
116                           return StringSwitch<bool>(SNode.getString())
117                               .Case("by_value", true)
118                               .Case("global_buffer", true)
119                               .Case("dynamic_shared_pointer", true)
120                               .Case("sampler", true)
121                               .Case("image", true)
122                               .Case("pipe", true)
123                               .Case("queue", true)
124                               .Case("hidden_global_offset_x", true)
125                               .Case("hidden_global_offset_y", true)
126                               .Case("hidden_global_offset_z", true)
127                               .Case("hidden_none", true)
128                               .Case("hidden_printf_buffer", true)
129                               .Case("hidden_default_queue", true)
130                               .Case("hidden_completion_action", true)
131                               .Default(false);
132                         }))
133    return false;
134  if (!verifyScalarEntry(ArgsMap, ".value_type", true,
135                         msgpack::ScalarNode::SK_String,
136                         [](msgpack::ScalarNode &SNode) {
137                           return StringSwitch<bool>(SNode.getString())
138                               .Case("struct", true)
139                               .Case("i8", true)
140                               .Case("u8", true)
141                               .Case("i16", true)
142                               .Case("u16", true)
143                               .Case("f16", true)
144                               .Case("i32", true)
145                               .Case("u32", true)
146                               .Case("f32", true)
147                               .Case("i64", true)
148                               .Case("u64", true)
149                               .Case("f64", true)
150                               .Default(false);
151                         }))
152    return false;
153  if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
154    return false;
155  if (!verifyScalarEntry(ArgsMap, ".address_space", false,
156                         msgpack::ScalarNode::SK_String,
157                         [](msgpack::ScalarNode &SNode) {
158                           return StringSwitch<bool>(SNode.getString())
159                               .Case("private", true)
160                               .Case("global", true)
161                               .Case("constant", true)
162                               .Case("local", true)
163                               .Case("generic", true)
164                               .Case("region", true)
165                               .Default(false);
166                         }))
167    return false;
168  if (!verifyScalarEntry(ArgsMap, ".access", false,
169                         msgpack::ScalarNode::SK_String,
170                         [](msgpack::ScalarNode &SNode) {
171                           return StringSwitch<bool>(SNode.getString())
172                               .Case("read_only", true)
173                               .Case("write_only", true)
174                               .Case("read_write", true)
175                               .Default(false);
176                         }))
177    return false;
178  if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
179                         msgpack::ScalarNode::SK_String,
180                         [](msgpack::ScalarNode &SNode) {
181                           return StringSwitch<bool>(SNode.getString())
182                               .Case("read_only", true)
183                               .Case("write_only", true)
184                               .Case("read_write", true)
185                               .Default(false);
186                         }))
187    return false;
188  if (!verifyScalarEntry(ArgsMap, ".is_const", false,
189                         msgpack::ScalarNode::SK_Boolean))
190    return false;
191  if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
192                         msgpack::ScalarNode::SK_Boolean))
193    return false;
194  if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
195                         msgpack::ScalarNode::SK_Boolean))
196    return false;
197  if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
198                         msgpack::ScalarNode::SK_Boolean))
199    return false;
200
201  return true;
202}
203
204bool MetadataVerifier::verifyKernel(msgpack::Node &Node) {
205  auto KernelMapPtr = dyn_cast<msgpack::MapNode>(&Node);
206  if (!KernelMapPtr)
207    return false;
208  auto &KernelMap = *KernelMapPtr;
209
210  if (!verifyScalarEntry(KernelMap, ".name", true,
211                         msgpack::ScalarNode::SK_String))
212    return false;
213  if (!verifyScalarEntry(KernelMap, ".symbol", true,
214                         msgpack::ScalarNode::SK_String))
215    return false;
216  if (!verifyScalarEntry(KernelMap, ".language", false,
217                         msgpack::ScalarNode::SK_String,
218                         [](msgpack::ScalarNode &SNode) {
219                           return StringSwitch<bool>(SNode.getString())
220                               .Case("OpenCL C", true)
221                               .Case("OpenCL C++", true)
222                               .Case("HCC", true)
223                               .Case("HIP", true)
224                               .Case("OpenMP", true)
225                               .Case("Assembler", true)
226                               .Default(false);
227                         }))
228    return false;
229  if (!verifyEntry(
230          KernelMap, ".language_version", false, [this](msgpack::Node &Node) {
231            return verifyArray(
232                Node,
233                [this](msgpack::Node &Node) { return verifyInteger(Node); }, 2);
234          }))
235    return false;
236  if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::Node &Node) {
237        return verifyArray(Node, [this](msgpack::Node &Node) {
238          return verifyKernelArgs(Node);
239        });
240      }))
241    return false;
242  if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
243                   [this](msgpack::Node &Node) {
244                     return verifyArray(Node,
245                                        [this](msgpack::Node &Node) {
246                                          return verifyInteger(Node);
247                                        },
248                                        3);
249                   }))
250    return false;
251  if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
252                   [this](msgpack::Node &Node) {
253                     return verifyArray(Node,
254                                        [this](msgpack::Node &Node) {
255                                          return verifyInteger(Node);
256                                        },
257                                        3);
258                   }))
259    return false;
260  if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
261                         msgpack::ScalarNode::SK_String))
262    return false;
263  if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
264                         msgpack::ScalarNode::SK_String))
265    return false;
266  if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
267    return false;
268  if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
269    return false;
270  if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
271    return false;
272  if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
273    return false;
274  if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
275    return false;
276  if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
277    return false;
278  if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
279    return false;
280  if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
281    return false;
282  if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
283    return false;
284  if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
285    return false;
286
287  return true;
288}
289
290bool MetadataVerifier::verify(msgpack::Node &HSAMetadataRoot) {
291  auto RootMapPtr = dyn_cast<msgpack::MapNode>(&HSAMetadataRoot);
292  if (!RootMapPtr)
293    return false;
294  auto &RootMap = *RootMapPtr;
295
296  if (!verifyEntry(
297          RootMap, "amdhsa.version", true, [this](msgpack::Node &Node) {
298            return verifyArray(
299                Node,
300                [this](msgpack::Node &Node) { return verifyInteger(Node); }, 2);
301          }))
302    return false;
303  if (!verifyEntry(
304          RootMap, "amdhsa.printf", false, [this](msgpack::Node &Node) {
305            return verifyArray(Node, [this](msgpack::Node &Node) {
306              return verifyScalar(Node, msgpack::ScalarNode::SK_String);
307            });
308          }))
309    return false;
310  if (!verifyEntry(RootMap, "amdhsa.kernels", true,
311                   [this](msgpack::Node &Node) {
312                     return verifyArray(Node, [this](msgpack::Node &Node) {
313                       return verifyKernel(Node);
314                     });
315                   }))
316    return false;
317
318  return true;
319}
320
321} // end namespace V3
322} // end namespace HSAMD
323} // end namespace AMDGPU
324} // end namespace llvm
325