1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import os
8import re
9import shutil
10import tempfile
11import yaml
12
13from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
14
15
def c_upper(name):
    """Turn a spec identifier such as 'foo-bar' into 'FOO_BAR'."""
    underscored = name.replace('-', '_')
    return underscored.upper()
18
19
def c_lower(name):
    """Turn a spec identifier such as 'Foo-Bar' into 'foo_bar'."""
    underscored = name.replace('-', '_')
    return underscored.lower()
22
23
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value.

    The name must have the form '<u|s><width>-<min|max>', e.g. 'u8-min' or
    's16-max'.  Unsigned limits are 0 and 2**width - 1; signed limits are
    -2**(width - 1) and 2**(width - 1) - 1.

    Raises ValueError when the name does not follow that form, instead of
    guessing a value from a malformed suffix.
    """
    match = re.fullmatch(r'([us])(\d+)-(min|max)', name)
    if match is None:
        raise ValueError(f'Invalid limit name "{name}"')
    sign, width, bound = match.groups()
    width = int(width)

    if sign == 'u':
        return 0 if bound == 'min' else (1 << width) - 1

    # Signed: one bit of the width is the sign bit.
    if bound == 'min':
        return -(1 << (width - 1))
    return (1 << (width - 1)) - 1
37
38
class BaseNlLib:
    """Provides C expressions that generated code uses to reach library state."""

    def get_family_id(self):
        """Return the C expression yielding the resolved family id."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
42
43
class Type(SpecAttr):
    """Code-generation wrapper around one attribute of an attribute set.

    The base class holds naming and "presence" bookkeeping shared by every
    attribute type.  Subclasses override hooks such as ``presence_type``,
    ``_complex_member_type``, ``_attr_typol``, ``_attr_get``,
    ``_setter_lines`` and ``attr_put`` to emit type-specific C code.

    presence_type() selects how presence is tracked in the generated struct:
    'bit' uses a ``_present.<name>`` bit field, 'len' uses a
    ``_present.<name>_len`` length, and 'count' uses an ``n_<name>`` counter.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.name}")
            else:
                self.nested_render_name = c_lower(f"{family.name}_{self.nested_attrs}")

            # Nest names that also appear in family.consts get a trailing '_'
            # on their struct type name.
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        self.c_name = c_lower(self.name)
        # _C_KW is presumably the set of reserved C keywords; a clashing
        # member name gets a trailing underscore.
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def get_limit(self, limit, default=None):
        """Return the value of check ``limit`` (e.g. 'min', 'max', 'max-len').

        Returns ``default`` (None unless given) when the check is absent.
        Names of family constants become their upper-cased C macro name (a
        str).  Integers are returned unchanged, and other strings such as
        'u32-max' are converted with limit_to_number().
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        elif value in self.family.consts:
            return c_upper(f"{self.family['name']}-{value}")
        if not isinstance(value, int):
            value = limit_to_number(value)
        return value

    def resolve(self):
        """Compute the C enum name, honouring a per-attribute 'name-prefix'."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        """Whether the member is an array with an n_ counter (falsy here)."""
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        """True for recursive attrs when ``ri.op`` is not set.

        Such members are stored and passed by pointer.
        """
        return self.is_recursive() and not ri.op

    def presence_type(self):
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the ``_present`` member declaration for this attribute.

        A declaration is produced only when presence_type() equals
        ``type_filter``.  A 'bit' type gives a 1-bit field and a 'len' type
        gives a length field.  ``space == 'user'`` selects the ``__u32``
        spelling.  Otherwise returns None.
        """
        presence = self.presence_type()
        if presence != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if presence == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if presence == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """C type of a struct/array member, or None for simple members."""
        return None

    def free_needs_iter(self):
        """Whether freeing this member needs a loop index variable."""
        return False

    def free(self, ri, var, ref):
        """Emit code releasing heap memory owned by this member, if any."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the setter argument declarations for this attribute."""
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attribute."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the initializer of the kernel nla_policy entry."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's line in the kernel policy table."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the type-specific part of the user-space policy entry."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's line in the user-space policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit ``line``, guarded by the presence check when one applies."""
        presence = self.presence_type()
        if presence == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif presence == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a ynl_attr_put_<put_type>() call for this member."""
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return (lines, init_lines, local_vars) for the parser branch."""
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the parser branch for this attribute.

        ``first`` selects ``if`` versus ``else if``.  _attr_get() supplies
        the body lines, optional init lines and optional local declarations.
        A plain str is accepted wherever a one-item list is expected.
        Always returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        # Multi-value attributes are only counted in this pass: no
        # validation and no presence bit here.
        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return the type-specific body lines of the setter."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a ``static inline`` setter for this attribute.

        ``ref`` is the list of enclosing nest member names (None or empty
        for a top-level attribute).  The present bit of every enclosing nest
        is set.  This attribute's own presence is set here only when it uses
        'bit' presence; other kinds are expected to be handled by
        _setter_lines().  Setters whose body frees memory but does not
        allocate get a '__' name prefix.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        for i, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == len(ref) - 1 and self.presence_type() != 'bit':
                continue
            code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in line for line in code)
        alloc = any('alloc(' in line for line in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
245
246
class TypeUnused(Type):
    """Attribute slot marked 'unused' in the spec.

    attr_policy, attr_put, attr_get and setter emit nothing for it.  The
    user-space type policy marks it YNL_PT_REJECT.
    """

    def presence_type(self):
        # No presence tracking.
        return ''

    def arg_member(self, ri):
        return []

    def _attr_get(self, ri, var):
        reject = ['return YNL_PARSE_CB_ERROR;']
        return reject, None, None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def attr_policy(self, cw):
        """No kernel policy entry."""
        return None

    def attr_put(self, ri, var):
        """Nothing to serialize."""
        return None

    def attr_get(self, ri, var, first):
        """No parser branch is emitted."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """No setter is emitted."""
        return None
271
272
class TypePad(Type):
    """Padding attribute.

    attr_policy, attr_put, attr_get and setter emit nothing for it.  The
    user-space type policy marks it YNL_PT_IGNORE.
    """

    def presence_type(self):
        # No presence tracking.
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def attr_policy(self, cw):
        """No kernel policy entry."""
        return None

    def attr_put(self, ri, var):
        """Nothing to serialize."""
        return None

    def attr_get(self, ri, var, first):
        """No parser branch is emitted."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """No setter is emitted."""
        return None
294
295
class TypeScalar(Type):
    """Integer scalar attribute (any type listed in ``scalars``).

    Besides the plain value, this class derives the kernel policy checks
    ('min', 'max', 'range', 'full-range') from the spec's checks and, when
    present, from the value range of the referenced enum.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Default the limits from the enum's value range when the spec does
        # not give them explicitly.  An unsigned type with a 0 low bound
        # needs no 'min' check.
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if 'min' not in self.checks:
                if low != 0 or self.type[0] == 's':
                    self.checks['min'] = low
            if 'max' not in self.checks:
                self.checks['max'] = high

        # get_limit() returns a C macro name (a str) for limits that refer to
        # family constants.  Their numeric value is not known here, so only
        # integer limits take part in the checks below.  Comparing a str with
        # an int would raise TypeError.
        min_lim = self.get_limit('min', 0)
        max_lim = self.get_limit('max', 0)

        if 'min' in self.checks and 'max' in self.checks:
            if isinstance(min_lim, int) and isinstance(max_lim, int) and min_lim > max_lim:
                raise Exception(f'Invalid limit for "{self.name}" min: {min_lim} max: {max_lim}')
            self.checks['range'] = True

        known = [lim for lim in (min_lim, max_lim) if isinstance(lim, int)]
        if known:
            low, high = min(known), max(known)
            if low < 0 and self.type[0] == 'u':
                raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
            # Limits outside the s16 range are emitted through
            # NLA_POLICY_FULL_RANGE (see _attr_policy).
            if low < -32768 or high > 32767:
                self.checks['full-range'] = True

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Resolve the enum name, whether the value is a bitfield, and the C type."""
        self.resolve_up(super())

        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            # is_auto_scalar attributes use the 64-bit type of the same signedness.
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _attr_policy(self, policy):
        """Pick the most specific NLA_POLICY_* form for the checks in effect."""
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                flags = self.family.consts[self.checks['flags-mask']]
                flag_cnt = len(flags['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit('min')}, {self.get_limit('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # The user-space policy uses the unsigned YNL_PT_U* type of the same width.
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
382
383
class TypeFlag(Type):
    """Flag attribute: it carries no payload; only its presence matters."""

    def arg_member(self, ri):
        # The setter takes no value argument; it only sets the present bit.
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        no_lines = []
        return no_lines
399
400
class TypeString(Type):
    """String attribute held as a heap-allocated, NUL-terminated char *.

    The string length is tracked in ``_present.<name>_len``.
    """

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.checks['exact-len']})"
        parts = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            parts.append(f", .len = {self.get_limit('max-len')}")
        parts.append(', }')
        return ''.join(parts)

    def attr_policy(self, cw):
        terminated = not self.checks.get('unterminated-ok', False)
        policy = 'NLA_NUL_STRING' if terminated else 'NLA_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        member = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [f"{len_mem} = len;",
                     f"{member} = malloc(len + 1);",
                     f"memcpy({member}, ynl_attr_get_str(attr), len);",
                     f"{member}[len] = 0;"]
        # strnlen bounded by the payload length of the attribute.
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{length} = strlen({self.c_name});",
                f"{member} = malloc({length} + 1);",
                f'memcpy({member}, {self.c_name}, {length});',
                f'{member}[{length}] = 0;']
451
452
class TypeBinary(Type):
    """Opaque binary attribute held as a heap-allocated void *.

    The payload length is tracked in ``_present.<name>_len``.
    """

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.checks['exact-len']})"
        # Apart from exact-len, only a lone 'min-len' check (or no checks) is supported.
        if not self.checks:
            body = '.type = NLA_BINARY'
        elif len(self.checks) == 1 and 'min-len' in self.checks:
            body = f".len = {self.get_limit('min-len')}"
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        return '{ ' + body + ', }'

    def attr_put(self, ri, var):
        member = f"{var}->{self.c_name}"
        length = f"{var}->_present.{self.c_name}_len"
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, {member}, {length})")

    def _attr_get(self, ri, var):
        member = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [f"{len_mem} = len;",
                     f"{member} = malloc(len);",
                     f"memcpy({member}, ynl_attr_data(attr), len);"]
        init_lines = ['len = ynl_attr_data_len(attr);']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{length} = len;",
                f"{member} = malloc({length});",
                f'memcpy({member}, {self.c_name}, {length});']
497
498
class TypeBitfield32(Type):
    """Attribute stored as a ``struct nla_bitfield32``.

    The kernel policy mask comes from the enum the spec must reference.
    """

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        size = "sizeof(struct nla_bitfield32)"
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, {size})")

    def _attr_get(self, ri, var):
        size = "sizeof(struct nla_bitfield32)"
        return f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), {size});", None, None

    def _setter_lines(self, ri, member, presence):
        size = "sizeof(struct nla_bitfield32)"
        return [f"memcpy(&{member}, {self.c_name}, {size});"]
522
523
class TypeNest(Type):
    """Attribute containing a nested attribute set.

    The nest has its own generated struct and its own _parse/_put/_free
    helpers.
    """

    def is_recursive(self):
        nested = self.family.pure_nested_structs[self.nested_attrs]
        return nested.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def free(self, ri, var, ref):
        member = f"{var}->{ref}{self.c_name}"
        if self.is_recursive_for_op(ri):
            # Held by pointer: guard against NULL before freeing.
            ri.cw.p(f'if ({member})')
            ri.cw.p(f'{self.nested_render_name}_free({member});')
        else:
            ri.cw.p(f'{self.nested_render_name}_free(&{member});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        member = f"{var}->{self.c_name}"
        if not self.is_recursive_for_op(ri):
            member = '&' + member
        self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, {self.enum_name}, {member})")

    def _attr_get(self, ri, var):
        nest = self.nested_render_name
        get_lines = [f"if ({nest}_parse(&parg, attr))",
                     "return YNL_PARSE_CB_ERROR;"]
        init_lines = [f"parg.rsp_policy = &{nest}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit setters for every non-recursive member of the nest."""
        path = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member_attr in nested.member_list():
            if not member_attr.is_recursive():
                member_attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
563
564
class TypeMultiAttr(Type):
    """Attribute that may appear several times in one message.

    Values are stored as an array with an ``n_<name>`` counter.
    ``base_type`` is the Type describing a single occurrence; it supplies
    the policy and the typol entries.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)
        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        sub_type = self.attr.get('type', 'nest')
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def free_needs_iter(self):
        return self.attr.get('type', 'nest') == 'nest'

    def free(self, ri, var, ref):
        sub_type = self.attr['type']
        member = f"{var}->{ref}{self.c_name}"
        if sub_type in scalars:
            ri.cw.p(f"free({member});")
        elif sub_type == 'nest':
            # Release every nested element before the array itself.
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{member}[i]);')
            ri.cw.p(f"free({member});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {sub_type} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Occurrences are only counted in this parser pass.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        sub_type = self.attr['type']
        loop = f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)"
        if sub_type in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif sub_type == 'nest':
            ri.cw.p(loop)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        else:
            raise Exception(f"Put of MultiAttr sub-type {sub_type} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # Multi-attr tracks a count rather than presence: replace the trailing
        # "_present.<name>" of the presence expression with "n_<name>".
        tail = len('_present.') + len(self.c_name)
        count = presence[:-tail] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{count} = n_{self.c_name};"]
626
627
class TypeArrayNest(Type):
    """Nest whose children form an array, counted into ``n_<name>``."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        sub_type = self.attr.get('sub-type', 'nest')
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the nest attribute in attr_<name> and count its children.
        get_lines = [
            f'attr_{self.c_name} = attr;',
            'ynl_attr_for_each_nested(attr2, attr)',
            f'\t{var}->n_{self.c_name}++;',
        ]
        local_vars = ['const struct nlattr *attr2;']
        return get_lines, None, local_vars
653
654
class TypeNestTypeValue(Type):
    """Nest whose 'type-value' levels carry information in attribute types.

    For each level listed in 'type-value', the parser descends one nested
    attribute and records its type in a __u32 local.  Those values are
    passed as extra arguments to the nest's _parse helper.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        cursor = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(level) for level in self.attr["type-value"]]
            local_vars = [f'const struct nlattr *attr_{", *attr_".join(levels)};',
                          f'__u32 {", ".join(levels)};']
            for level in levels:
                get_lines.append(f'attr_{level} = ynl_attr_data({cursor});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                cursor = f'attr_{level}'
            extra_args = f", {', '.join(levels)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
683
684
class Struct:
    """C struct generated for an attribute set, or for a subset of it.

    When ``type_list`` is given, only those attributes become members.
    When it is None, the struct is a nest covering every attribute of the
    set (``self.nested``).
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            self.render_name = c_lower(family.name)
        else:
            self.render_name = c_lower(family.name + '-' + space_name)
        self.struct_name = 'struct ' + self.render_name
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False

        member_names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in member_names]
        self.attrs = dict(self.attr_list)

        # Attribute with the highest value; on ties the later one wins.
        self.attr_max_val = None
        max_val = 0
        for _, member in self.attr_list:
            if member.value >= max_val:
                max_val = member.value
                self.attr_max_val = member

    def __iter__(self):
        """Iterate over member attribute names."""
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record the inherited members; they must match those given at init."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(name) for name in sorted(self._inherited)]
740
741
class EnumEntry(SpecEnumEntry):
    """Enum/flags entry together with its C constant name."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # value_change is True when the value does not simply follow the
        # previous entry (or does not start at 0), and always for 'flags'
        # sets.  Presumably this decides whether the C definition needs an
        # explicit "= value"; confirm with the emitter.
        expected = prev.value + 1 if prev else 0
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
760
761
class EnumSet(SpecEnumSet):
    """Enum or flags definition with its C naming.

    The C enum name is 'enum <family>-<name>' by default, 'enum
    <enum-name>' when 'enum-name' is set, and absent when 'enum-name' is
    empty.  In the last case the user-space type falls back to 'int'.
    """

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.name + '-' + yaml['name'])

        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (lowest, highest) value; the values must be contiguous."""
        values = [entry.value for entry in self.entries.values()]
        low, high = min(values), max(values)

        if high - low + 1 != len(self.entries):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
795
796
class AttrSet(SpecAttrSet):
    """Attribute set with the C names of its attribute enum.

    A subset-of set reuses the prefix and the max/cnt names of its parent
    set.
    """

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        c_name = c_lower(self.name)
        if c_name in _C_KW:
            c_name += '_'
        # A set named like the family itself gets an empty C name.
        self.c_name = '' if c_name == self.family.c_name else c_name

    def new_attr(self, elem, value):
        """Build the Type subclass matching elem['type'].

        The result is wrapped in TypeMultiAttr when 'multi-attr' is set.
        """
        attr_type = elem['type']
        if attr_type in scalars:
            type_cls = TypeScalar
        else:
            type_cls = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'bitfield32': TypeBitfield32,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }.get(attr_type)
            if type_cls is None:
                raise Exception(f"No typed class for type {elem['type']}")

        t = type_cls(self.family, self, elem, value)

        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
855
856
class Operation(SpecOperation):
    """SpecOperation extended with the names and flags the C generator uses."""
    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.name + '_' + self.name)

        # True only when both the 'do' and the 'dump' section define a request.
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Declared here for readers, removed again so that it only exists
        # once resolve() has assigned it.
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Resolve the base spec, then pick the C enum name for this op."""
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that some other op names this one as its 'notify' target."""
        self.has_ntf = True
882
883
class Family(SpecFamily):
    """Spec family extended with everything the C code generator needs.

    resolve() derives C names and op prefixes, collects pre/post hooks, and
    classifies attribute sets into message roots (root_sets) and sets that
    are only ever used as nests (pure_nested_structs).
    """
    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        # (each attribute is declared for readers, then removed so it only
        # exists once resolve() assigns it)
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.name}.h"
        # "linux/<name>.h" -> "<name>"; any other header falls back to the family name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.name

    def resolve(self):
        """Resolve the base spec, then compute all codegen-specific state.

        Raises Exception if the protocol is not one of the genetlink flavours.
        """
        self.resolve_up(super())

        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] -> {'set': names seen, 'list': names in first-seen order}
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct, for attr sets that are only ever nested
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def _mark_notify(self):
        """Flag every op that another message names in its 'notify' key."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Gather, per attribute set used directly by a message, the
        attributes appearing in requests and in replies/events."""
        for op in self.msgs.values():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so each struct follows the non-recursive
        nests it contains. Bounded to len(structs) ** 2 rounds."""
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                    if self.pure_nested_structs[nested].recursive:
                        continue
                    if nested not in pns_key_seen:
                        # Dicts are sorted, this will make struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_sets(self):
        """Discover every attribute set reachable via nesting from the roots.

        Creates a Struct for each such set, records inherited members
        ('type-value' names, or 'idx' for array-nest), marks request/reply
        use, propagates that use and recursion to children, then sorts.
        Raises Exception if a root set is also used as a nest.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                if attr_set in struct.child_nests:
                    struct.recursive = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark each attribute as used in requests and/or replies."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.request = True
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.reply = True

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.request = True
                if attr in rs_members['reply']:
                    spec.reply = True

    def _load_global_policy(self):
        """Build the single policy shared by all ops ('kernel-policy: global').

        Collects the request attributes of every op, in attribute-set order.
        Raises Exception if the ops use different attribute sets or if no op
        declares an attribute set at all.
        """
        global_set = set()
        attr_set_name = None
        for op in self.ops.values():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        # Without this, the lookup below fails with an opaque KeyError: None
        if attr_set_name is None:
            raise Exception('For a global policy at least one op must declare an attribute-set')

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique pre/post hook names per op mode, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1145
1146
class RenderInfo:
    """Context for rendering one op in one mode, or a bare attribute set.

    Holds the CodeWriter, the family, the attribute-set name, the C type
    name stem, and Struct objects for the request and reply attributes.
    """
    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        has_fixed_hdr = op and op.fixed_header
        self.fixed_hdr = 'struct ' + c_lower(op.fixed_header) if has_fixed_hdr else None

        # Whether 'do' and 'dump' have identical replies (both absent, or equal)
        self.type_consistent = True
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                self.type_consistent = False
            else:
                do_has_reply = 'reply' in op['do']
                if do_has_reply != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif do_has_reply and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False

        self.attr_set = attr_set if attr_set else op['attribute-set']

        if op:
            self.type_name = c_lower(op.name)
            self.type_name_conflict = False
        else:
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        self.cw = cw

        # Notifications are rendered from the op's 'do' section
        struct_mode = 'do' if op_mode == 'notify' else op_mode
        self.struct = dict()
        if op:
            for op_dir in ('request', 'reply'):
                msg = op[struct_mode]
                type_list = msg[op_dir]['attributes'] if op_dir in msg else []
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        if struct_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1195
1196
class CodeWriter:
    """Emit C source with automatic indentation and brace handling.

    Output goes to stdout, or, when *out_file* is given, to a temporary file
    that close_out_file() later copies to *out_file*.
    """
    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False
        self._block_end = False
        self._silent_block = False
        self._ind = 0
        self._ifdef_block = None
        # Recorded unconditionally (None means "writing to stdout") so that
        # close_out_file() never depends on the attribute being present.
        self._out_file = out_file
        if out_file is None:
            self._out = os.sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')

    def __del__(self):
        # __init__ may have failed before _out was assigned
        if hasattr(self, '_out'):
            self.close_out_file()

    def close_out_file(self):
        """Commit buffered output to the target file and revert to stdout.

        With overwrite disabled, a target whose contents already match is
        left untouched (so its mtime does not change). In every case the
        temporary buffer is closed, so repeated calls are no-ops.
        """
        if self._out_file is None or self._out is os.sys.stdout:
            return
        self._out.flush()
        # Avoid modifying the file if contents didn't change
        unchanged = (not self._overwrite and os.path.isfile(self._out_file) and
                     filecmp.cmp(self._out.name, self._out_file, shallow=False))
        if not unchanged:
            with open(self._out_file, 'w+') as out_file:
                self._out.seek(0)
                shutil.copyfileobj(self._out, out_file)
        self._out.close()
        self._out = os.sys.stdout

    @classmethod
    def _is_cond(cls, line):
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write one line at the current indentation level.

        A closing brace deferred by block_end() is flushed first. It is merged
        into '} else ...' when *line* starts with 'else'. A pending blank line
        from nl() is flushed next. Lines ending in ':' are outdented one level,
        and lines starting with '#' go to column 0. The line after a
        brace-less if/while/for condition gets one extra level.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next emitted line."""
        self._nl = True

    def block_start(self, line=''):
        """Emit '<line> {' and increase indentation."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Decrease indentation and close the block.

        With no *line*, the '}' is deferred so that a following 'else' can
        be joined onto it. Otherwise '}' is emitted immediately followed by
        *line* (no space before ';' or ',').
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Emit *doc* as ' *'-prefixed comment lines wrapped before 79 columns."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Emit a function prototype, wrapping arguments when >= 80 columns."""
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation lines align with the opening parenthesis
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Emit local variable declarations, longest first, then a blank line.

        *local_vars* may be one declaration string or a list of them; the
        caller's list is left unmodified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Emit a complete function: prototype, '{', locals, body lines, '}'."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)

        self.block_start()
        # Locals must be declared inside the function's braces
        self.write_func_lvar(local_vars=local_vars)
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Emit '#define NAME value' lines with tab-aligned values.

        Integer values are written as-is and string values double-quoted.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Emit designated initializers '.name = value,' with aligned '='."""
        longest = max((len(x[0]) for x in members), default=0)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the active '#ifdef CONFIG_<config>' region (None/'' closes it)."""
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1392
1393
# Attribute types rendered as plain C integers.
scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64', 'uint', 'sint'}

# Type-name suffix for each message direction.
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': '',
}

# Reply type-name suffix for each op mode.
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# Names treated as C keywords; generated identifiers matching one of these
# get a trailing underscore.
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1445
1446
def rdir(direction):
    """Return the opposite direction ('request' <-> 'reply'); others pass through."""
    if direction not in ('reply', 'request'):
        return direction
    return 'request' if direction == 'reply' else 'reply'
1453
1454
def op_prefix(ri, direction, deref=False):
    """Build the C identifier stem '<family>_<type_name><suffix>' for *ri*.

    The suffix depends on the op mode, on *direction*, on whether 'do' and
    'dump' replies are identical (ri.type_consistent), and on *deref*.
    """
    mode = ri.op_mode
    if not mode or mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp' + ('_dump' if deref else '_list')
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[mode]

    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1474
1475
def type_name(ri, direction, deref=False):
    """Return the C struct type ('struct <op_prefix>') for *ri* and *direction*."""
    return 'struct ' + op_prefix(ri, direction, deref=deref)
1478
1479
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the user-facing request function prototype for ri.op / ri.op_mode.

    The function takes the socket plus a request struct if the mode has a
    request. It returns a reply struct pointer if the mode has a reply, and
    int otherwise. A ';' is appended when *terminate* is set.
    """
    msg = ri.op[ri.op_mode]
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    args = ['struct ynl_sock *ys']
    if 'request' in msg:
        param = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{param}")

    ret = f"{type_name(ri, rdir(direction))} *" if 'reply' in msg else 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=';' if terminate else '')
1496
1497
def print_req_prototype(ri):
    """Emit the terminated request prototype, preceded by the op's doc text."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1500
1501
def print_dump_prototype(ri):
    """Emit the terminated request prototype without a doc comment."""
    print_prototype(ri, direction="request")
1504
1505
def put_typol_fwd(cw, struct):
    """Emit the extern declaration of *struct*'s ynl_policy_nest object."""
    nest_name = struct.render_name + '_nest'
    cw.p(f'extern struct ynl_policy_nest {nest_name};')
1508
1509
def put_typol(cw, struct):
    """Emit *struct*'s ynl_policy_attr table and the ynl_policy_nest wrapping it."""
    prefix = struct.render_name
    max_attr = struct.attr_set.max_name

    cw.block_start(line=f'struct ynl_policy_attr {prefix}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {prefix}_nest =')
    for field, val in (('max_attr', max_attr), ('table', f'{prefix}_policy')):
        cw.p(f'.{field} = {val},')
    cw.block_end(line=';')
    cw.nl()
1525
1526
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit '<render_name>_str()', a bounds-checked lookup into *map_name*.

    The argument is an int unless *enum* is given, in which case its
    user_type is used. For flag enums the bit is first converted to its index.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    body = (
        f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
        'return NULL;',
        f'return {map_name}[{arg_name}];',
    )
    for stmt in body:
        cw.p(stmt)
    cw.block_end()
    cw.nl()
1540
1541
def put_op_name_fwd(family, cw):
    """Declare '<family>_op_str(int op)'."""
    fname = family.c_name + '_op_str'
    cw.write_func_prot('const char *', fname, ['int op'], suffix=';')
1544
1545
def put_op_name(family, cw):
    """Emit the op-value -> name string table and its '<family>_op_str()' lookup."""
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Legacy families may have several commands producing the same reply
        # value; only the op registered in rsp_by_value gets the slot.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1565
1566
def put_enum_to_str_fwd(family, cw, enum):
    """Declare '<enum>_str(<user_type> value)'."""
    fname = enum.render_name + '_str'
    cw.write_func_prot('const char *', fname, [f'{enum.user_type} value'], suffix=';')
1570
1571
def put_enum_to_str(family, cw, enum):
    """Emit the value -> name string table for *enum* and its lookup function."""
    table = f'{enum.render_name}_strmap'
    cw.block_start(line=f"static const char * const {table}[] =")
    for entry in enum.entries.values():
        cw.p(f'[{entry.value}] = "{entry.name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, table, 'value', enum=enum)
1581
1582
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of '<struct>_put(nlh, attr_type, obj)'."""
    args = ['struct nlmsghdr *nlh', 'unsigned int attr_type', struct.ptr_name + 'obj']
    ri.cw.write_func_prot('int', struct.render_name + '_put', args, suffix=suffix)
1590
1591
def put_req_nested(ri, struct):
    """Emit '<struct>_put()', which writes *struct*'s members inside a nest."""
    cw = ri.cw
    put_req_nested_prototype(ri, struct, suffix='')
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1608
1609
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a parse function for *struct*.

    First walks the attributes (of the nest, or of the whole message when
    struct is not nested), delegating each member to its attr_get().
    Then, for array-nest and multi-attr members, a second pass allocates
    dst arrays sized by the n_<member> counters and fills them.

    ri -- render info; provides cw and fixed_hdr
    struct -- Struct whose members are parsed
    init_lines -- statements emitted right after the locals; extended in
                  place when a ynl_parse_arg is needed
    local_vars -- local declarations; extended in place with what the
                  generated code needs

    Raises Exception for multi-attr members that are neither nests nor scalars.
    """
    if struct.nested:
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        if ri.fixed_hdr:
            local_vars += ['void *hdr;']
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"

    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'array-nest':
            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
            array_nests.add(arg)
        # Require a truthy value, matching how new_attr() decides to wrap
        # the type in TypeMultiAttr ('multi-attr: false' is not multi).
        if 'multi-attr' in aspec and aspec['multi-attr']:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if ri.fixed_hdr:
        ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    # Arrays are allocated once, after the walk; a pre-filled one is an error
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
        ri.cw.p('return YNL_PARSE_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1718
1719
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of '<struct>_parse()'; inherited members become __u32 args."""
    args = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    args += [f'__u32 {name}' for name in struct.inherited]
    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', args, suffix=suffix)
1728
1729
def parse_rsp_nested(ri, struct):
    """Write the full definition of the C parser for a nested attribute set.

    The signature comes from parse_rsp_nested_prototype(); the body is
    produced by the shared _multi_parse() generator with the nested-parser
    locals (``attr`` cursor and ``dst`` taken from ``yarg->data``).
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    decls = ['const struct nlattr *attr;']
    decls.append(f'{struct.ptr_name}dst = yarg->data;')
    _multi_parse(ri, struct, [], decls)
1738
1739
def parse_rsp_msg(ri, deref=False):
    """Write the C callback that parses a whole reply/event message.

    Nothing is written unless the current op mode has a reply or is an
    event. A reply struct without members gets a trivial body returning
    YNL_PARSE_CB_OK; otherwise _multi_parse() generates the body.

    Args:
        ri: render context; output goes to ``ri.cw``.
        deref: forwarded to type_name()/op_prefix() when naming the types.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    args = ['const struct nlmsghdr *nlh',
            'struct ynl_parse_arg *yarg']
    reply_type = type_name(ri, "reply", deref=deref)
    decls = [f'{reply_type} *dst;', 'const struct nlattr *attr;']
    setup = ['dst = yarg->data;']

    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse', args)

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply: nothing to walk, just accept the message.
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, setup, decls)
1761
1762
def print_req(ri):
    """Write the definition of a user-space request helper for ``ri.op``.

    The generated C function:
      * starts a request with ynl_gemsg_start_req(),
      * copies in the optional fixed header,
      * puts every request attribute,
      * runs the exchange with ynl_exec().

    When the current op mode declares a reply, the function calloc()s a
    reply struct, parses into it and returns it (NULL on error, after
    freeing it). Otherwise it returns 0 on success and -1 on error.

    Args:
        ri: render context for the op. Provides cw, op, op_mode, struct,
            fixed_hdr and nl, as used below.
    """
    ret_ok = '0'
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    # Ops with a reply return the parsed reply object rather than 0/-1.
    # NOTE: ret_err is only emitted on the reply path below; the no-reply
    # path hard-codes 'return -1;'.
    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    # Install the request policy and, if there is a reply, the reply policy.
    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    # The fixed header (if any) is copied verbatim from the request struct.
    if ri.fixed_hdr:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    # Wire up reply parsing: destination object, parse callback and the
    # command the reply is expected to carry.
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        if ri.op.value is not None:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('goto err_free;')
    else:
        ri.cw.p('return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    # Error path: release the partially filled reply before returning NULL.
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('err_free:')
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
1827
1828
def print_dump(ri):
    """Write the definition of a user-space dump helper for ``ri.op``.

    The generated C function:
      * fills a ``struct ynl_dump_state`` (reply policy, element size,
        parse callback, expected reply command),
      * starts a dump with ynl_gemsg_start_dump(),
      * adds the optional fixed header and any request attributes,
      * runs it with ynl_exec_dump().

    It returns ``yds.first`` (the first collected element) on success. On
    error it frees that list and returns NULL.

    Args:
        ri: render context for the op. Provides cw, op, op_mode, struct,
            fixed_hdr and nl, as used below.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    ri.cw.write_func_lvar(local_vars)

    # Dump state: each reply message is parsed into a freshly allocated
    # element of alloc_sz bytes by the (dereferenced) reply parser.
    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    # Ops without their own value are matched on rsp_value instead of enum_name.
    if ri.op.value is not None:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    ri.cw.nl()
    ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if ri.fixed_hdr:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    # Dump requests may carry filter attributes; only then is a request
    # policy installed and attributes put.
    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    # Error path: free whatever elements were collected before the failure.
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
1879
1880
def call_free(ri, direction, var):
    """Return the C statement ``<op_prefix>_free(<var>);`` for *direction*."""
    free_fn = op_prefix(ri, direction) + '_free'
    return '{}({});'.format(free_fn, var)
1883
1884
def free_arg_name(direction):
    """Return the parameter name used by a generated _free() helper.

    With a direction, this is the direction's struct suffix minus its first
    character. With an empty/false direction it is 'obj'.
    """
    return direction_to_suffix[direction][1:] if direction else 'obj'
1889
1890
def print_alloc_wrapper(ri, direction):
    """Write a ``static inline`` <name>_alloc() that returns a zeroed struct."""
    struct_name = op_prefix(ri, direction)
    ri.cw.write_func_prot(f'static inline struct {struct_name} *',
                          struct_name + "_alloc", ["void"])
    ri.cw.block_start()
    ri.cw.p('return calloc(1, sizeof(struct %s));' % struct_name)
    ri.cw.block_end()
1897
1898
def print_free_prototype(ri, direction, suffix=';'):
    """Write the prototype of ``void <op_prefix>_free(struct ... *arg)``.

    When ``ri.type_name_conflict`` is set, the struct tag carries a trailing
    underscore while the function name does not.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name
    arg = free_arg_name(direction)
    ri.cw.write_func_prot('void', name + "_free",
                          [f"struct {struct_name} *{arg}"], suffix=suffix)
1906
1907
def _print_type(ri, direction, struct):
    """Write the C ``struct`` definition holding *struct*'s attributes.

    Member order in the emitted struct:
      1. ``_hdr`` - the fixed header, when the op/family has one;
      2. ``_present`` - an anonymous sub-struct collecting the presence
         members each attribute reports for the 'len' and 'bit' filters
         (omitted if none do);
      3. one ``__u32`` per inherited attribute;
      4. each attribute's own member.

    Args:
        ri: render context; output goes to ``ri.cw``.
        direction: 'request'/'reply' or '' for a direction-less type.
        struct: the attribute struct to render.
    """
    # Type name: <family>_<type_name><direction suffix>[_][_dump].
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'

    if ri.op_mode == 'dump':
        suffix += '_dump'

    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.fixed_hdr:
        ri.cw.p(ri.fixed_hdr + ' _hdr;')
        ri.cw.nl()

    # The _present sub-struct is opened lazily, only once some attribute
    # actually contributes a presence member.
    meta_started = False
    for _, attr in struct.member_list():
        for type_filter in ['len', 'bit']:
            line = attr.presence_member(ri.ku_space, type_filter)
            if line:
                if not meta_started:
                    ri.cw.block_start(line=f"struct")
                    meta_started = True
                ri.cw.p(line)
    if meta_started:
        ri.cw.block_end(line='_present;')
        ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f"__u32 {arg};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
1943
1944
def print_type(ri, direction):
    """Write the C struct for the *direction* side of ri's op."""
    dir_struct = ri.struct[direction]
    _print_type(ri, direction, dir_struct)
1947
1948
def print_type_full(ri, struct):
    """Write a direction-less C struct for *struct*."""
    _print_type(ri, direction="", struct=struct)
1951
1952
def print_type_helpers(ri, direction, deref=False):
    """Write the free() prototype and, for user-space requests, the setters.

    Args:
        ri: render context; output goes to ``ri.cw``.
        direction: 'request' or 'reply'.
        deref: forwarded to each attribute's setter().
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    # Attribute setters only make sense for requests built in user space.
    emit_setters = ri.ku_space == 'user' and direction == 'request'
    if emit_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1961
1962
def print_req_type_helpers(ri):
    """Write alloc/free/setter helpers for the request, if it has attributes."""
    if ri.struct["request"].attr_list:
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
1968
1969
def print_rsp_type_helpers(ri):
    """Write the reply helpers when the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1974
1975
def print_parse_prototype(ri, direction, terminate=True):
    """Write ``void <op>_{rsp,req}_parse(tb, struct <op>_{rsp,req} *req)``.

    Args:
        ri: render context; output goes to ``ri.cw``.
        direction: 'reply' selects the _rsp variant, anything else _req.
        terminate: append ';' (declaration) when true, nothing otherwise.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    base = ri.op.render_name + kind
    params = ['const struct nlattr **tb', 'struct ' + base + ' *req']
    ri.cw.write_func_prot('void', base + "_parse", params,
                          suffix=';' if terminate else '')
1984
1985
def print_req_type(ri):
    """Write the request struct, unless the request has no attributes."""
    if ri.struct["request"].attr_list:
        print_type(ri, "request")
1990
1991
def print_req_free(ri):
    """Write the request free() definition when the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
1996
1997
def print_rsp_type(ri):
    """Write the reply struct for do/dump ops with a reply, and for events."""
    mode = ri.op_mode
    if mode in ('do', 'dump'):
        wanted = 'reply' in ri.op[mode]
    else:
        wanted = mode == 'event'

    if wanted:
        print_type(ri, 'reply')
2006
2007
def print_wrapped_type(ri):
    """Write the wrapper struct around a reply object, plus its free prototype.

    Dump wrappers get a ``next`` pointer of the wrapper type. notify/event
    wrappers get family, cmd, a ``struct ynl_ntf_base_type *next`` and a
    ``free`` callback. All wrappers end with the reply ``obj`` member,
    aligned to 8 bytes.
    """
    ri.cw.block_start(line=f"{type_name(ri, 'reply')}")
    if ri.op_mode == 'dump':
        ri.cw.p(f"{type_name(ri, 'reply')} *next;")
    elif ri.op_mode == 'notify' or ri.op_mode == 'event':
        ri.cw.p('__u16 family;')
        ri.cw.p('__u8 cmd;')
        ri.cw.p('struct ynl_ntf_base_type *next;')
        ri.cw.p(f"void (*free)({type_name(ri, 'reply')} *ntf);")
    ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    ri.cw.block_end(line=';')
    ri.cw.nl()
    print_free_prototype(ri, 'reply')
    ri.cw.nl()
2022
2023
2024def _free_type_members_iter(ri, struct):
2025    for _, attr in struct.member_list():
2026        if attr.free_needs_iter():
2027            ri.cw.p('unsigned int i;')
2028            ri.cw.nl()
2029            break
2030
2031
2032def _free_type_members(ri, var, struct, ref=''):
2033    for _, attr in struct.member_list():
2034        attr.free(ri, var, ref)
2035
2036
def _free_type(ri, direction, struct):
    """Write the definition of the _free() helper for *struct*.

    The generated body frees every member. For a non-empty *direction* it
    also calls free() on the object pointer itself.
    """
    arg_name = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg_name, struct)
    if direction:
        ri.cw.p('free(%s);' % arg_name)
    ri.cw.block_end()
    ri.cw.nl()
2048
2049
def free_rsp_nested_prototype(ri):
    """Write the free() prototype for a direction-less (nested) type."""
    print_free_prototype(ri, direction="")
2052
2053
def free_rsp_nested(ri, struct):
    """Write the free() definition for a direction-less (nested) type."""
    _free_type(ri, direction="", struct=struct)
2056
2057
def print_rsp_free(ri):
    """Write the reply free() definition when the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2062
2063
def print_dump_type_free(ri):
    """Write the free() definition for a dump result list.

    The generated C walks the list through ``->next`` until YNL_LIST_END.
    For each node it frees the reply members (reached via ``obj.``) and
    then the node itself.
    """
    sub_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    ri.cw.p(f"{sub_type} *next = rsp;")
    ri.cw.nl()
    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    # Any loop-counter declaration goes at the top of the loop body.
    _free_type_members_iter(ri, ri.struct['reply'])
    # Advance before freeing so the successor pointer is not read from
    # freed memory.
    ri.cw.p('rsp = next;')
    ri.cw.p('next = rsp->next;')
    ri.cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    ri.cw.p(f'free(rsp);')
    ri.cw.block_end()
    ri.cw.block_end()
    ri.cw.nl()
2082
2083
def print_ntf_type_free(ri):
    """Write the free() definition for a notification object.

    The reply members live under ``rsp->obj``; the wrapper itself is freed
    last.
    """
    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    reply = ri.struct['reply']
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.nl()
2092
2093
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Write a declaration, or the opening line, of an nla_policy array.

    With terminate=True, writes an ``extern`` declaration. This is skipped
    for per-op policies (ri given) of families whose policies are static,
    since those are not visible outside the source file.

    With terminate=False, writes ``... = {`` to open the definition. It is
    prefixed with ``static`` for per-op policies of such families.

    The array is named after the op (``<op>[_<mode>]_nl_policy`` for
    dual-policy ops) when *ri* is given, otherwise after *struct*. It is
    sized by the struct's maximum attribute.
    """
    if terminate and ri and policy_should_be_static(struct.family):
        return

    if terminate:
        prefix = 'extern '
    else:
        if ri and policy_should_be_static(struct.family):
            prefix = 'static '
        else:
            prefix = ''

    suffix = ';' if terminate else ' = {'

    max_attr = struct.attr_max_val
    if ri:
        name = ri.op.render_name
        # Ops with separate do/dump policies get the mode in the name.
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    else:
        name = struct.render_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2116
2117
def print_req_policy(cw, struct, ri=None):
    """Write a complete nla_policy array definition for *struct*.

    When rendering for an op, the array is wrapped in the op's
    'config-cond' #ifdef block (if any).
    """
    has_op = ri and ri.op
    if has_op:
        cw.ifdef_block(ri.op.get('config-cond', None))
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
    cw.ifdef_block(None)
    cw.nl()
2127
2128
def kernel_can_gen_family_struct(family):
    """Return True if the kernel ``genl_family`` struct can be generated.

    Only 'genetlink' families qualify.
    """
    proto = family.proto
    return proto == 'genetlink'
2131
2132
def policy_should_be_static(family):
    """Return True if per-op policies should be file-local (``static``).

    This holds for split-policy families and for families whose genl_family
    struct is generated here.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2135
2136
def print_kernel_policy_ranges(family, cw):
    """Write static range-validation structs for full-range request attributes.

    For every request attribute (outside subset attribute sets) that has a
    'full-range' check, writes one
    ``netlink_range_validation[_signed] <attr>_range`` initializer with its
    min/max limits. Types whose name starts with 'u' use the unsigned
    variant and ULL literals; everything else uses the signed variant and
    LL literals. A heading comment precedes the first struct.

    NOTE(review): if a limit resolves to the most negative 64-bit value, the
    emitted literal '-9223372036854775808LL' is not representable as written
    in C. Confirm how get_limit() handles s64-min.
    """
    first = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request:
                continue
            if 'full-range' not in attr.checks:
                continue

            if first:
                cw.p('/* Integer value ranges */')
                first = False

            sign = '' if attr.type[0] == 'u' else '_signed'
            suffix = 'ULL' if attr.type[0] == 'u' else 'LL'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = []
            if 'min' in attr.checks:
                members.append(('min', str(attr.get_limit('min')) + suffix))
            if 'max' in attr.checks:
                members.append(('max', str(attr.get_limit('max')) + suffix))
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2164
2165
def print_kernel_op_table_fwd(family, cw, terminate):
    """Write the kernel ops-table declaration (header) or opening (source).

    terminate=False (source): opens the ``<family>_nl_ops[] =`` initializer
    and returns; the entries are written by the caller.

    terminate=True (header):
      * writes an ``extern`` declaration of the table, but only when it is
        exported, i.e. when this generator does not also emit the
        genl_family struct;
      * declares the pre/post do/dump hook prototypes;
      * declares doit/dumpit prototypes for every non-async op.

    The table element type follows family.kernel_policy (global ->
    genl_small_ops, per-op -> genl_ops, split -> genl_split_ops). Its
    explicit size is only spelled out for exported tables.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        # Split tables have one entry per (op, do/dump) pair; the others
        # have one entry per op.
        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            cnt = 0
            for op in family.ops.values():
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            cnt = len(family.ops)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2231
2232
def print_kernel_op_table_hdr(family, cw):
    """Write the header-side declarations for the kernel ops table."""
    header_mode = True
    print_kernel_op_table_fwd(family, cw, terminate=header_mode)
2235
2236
def print_kernel_op_table(family, cw):
    """Write the full kernel ops table definition for *family*.

    The table is opened by print_kernel_op_table_fwd(terminate=False).

    For 'global'/'per-op' policies, each synchronous op gets one entry
    with cmd, optional validate flags, doit/dumpit, and (per-op only)
    policy/maxattr plus optional flags.

    For 'split' policy, each (op, do/dump) pair gets its own entry. It has
    mode-filtered validate flags, pre/post hooks, the handler, an optional
    policy, and flags that always include GENL_CMD_CAP_<MODE>.

    Each entry is wrapped in the op's 'config-cond' #ifdef, if any.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            # NOTE(review): per-op policy is taken from the 'do' request, so a
            # dump-only op would raise KeyError here - confirm specs forbid it.
            if family.kernel_policy == 'per-op':
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Spec hook names map onto different genl_split_ops fields per mode.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                # Drop dont-validate flags that do not apply to this mode.
                if 'dont-validate' in op:
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2315
2316
def print_kernel_mcgrp_hdr(family, cw):
    """Write an anonymous enum of multicast group ids (nothing if no groups)."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for group in groups:
        cw.p(c_upper(f"{family.name}-nlgrp-{group['name']}") + ',')
    cw.block_end(';')
    cw.nl()
2327
2328
def print_kernel_mcgrp_src(family, cw):
    """Write the static genl_multicast_group array, indexed by group id."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for group in groups:
        grp_name = group['name']
        enum_id = c_upper(f"{family.name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2340
2341
def print_kernel_family_struct_hdr(family, cw):
    """Write the extern genl_family declaration and sock-priv prototypes.

    Nothing is written unless the family struct is generated here. The
    init/destroy prototypes appear only when the spec sets 'sock-priv'.
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
    cw.nl()
    if 'sock-priv' not in family.kernel_family:
        return

    priv_type = family.kernel_family["sock-priv"]
    for action in ('init', 'destroy'):
        cw.p(f'void {family.c_name}_nl_sock_priv_{action}({priv_type} *priv);')
    cw.nl()
2352
2353
def print_kernel_family_struct_src(family, cw):
    """Write the kernel ``struct genl_family`` definition for *family*.

    Nothing is written unless kernel_can_gen_family_struct() allows it.
    The ops table, multicast groups and sock-priv helpers referenced here
    are the ``{family.c_name}_nl_*`` symbols generated elsewhere in this
    file.

    Args:
        family: parsed spec family.
        cw: code writer receiving the output.
    """
    if not kernel_can_gen_family_struct(family):
        return

    # Use the C identifier form of the family name. family.name may contain
    # dashes, and the extern declaration written by
    # print_kernel_family_struct_hdr() uses family.c_name as well.
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        # Force cast here, actual helpers take pointer to the real type.
        cw.p(f'.sock_priv_init\t= (void *){family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = (void *){family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
2379
2380
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uAPI enum block, choosing its tag from the spec object.

    An explicit *enum_name* key wins: a non-empty value names the enum, and
    an empty value forces an anonymous enum. Otherwise a present *ckey*
    yields ``enum <family>_<value>``. Failing both, the enum is anonymous.
    """
    if enum_name in obj:
        override = obj[enum_name]
        line = f'enum {c_lower(override)}' if override else 'enum'
    elif ckey and ckey in obj:
        line = f'enum {family.c_name}_{c_lower(obj[ckey])}'
    else:
        line = 'enum'
    cw.block_start(line=line)
2389
2390
def render_uapi(family, cw):
    """Render the complete uAPI header for *family* into *cw*.

    Output order:
      1. include guard;
      2. family name/version defines;
      3. the spec's 'definitions': consts as #defines, enums/flags with
         optional kdoc and max/mask markers;
      4. one enum per non-subset attribute set, with its count/max
         entries;
      5. the command enum, plus a separate async enum when
         'async-prefix' is set;
      6. multicast group name defines;
      7. closing guard.

    Args:
        family: parsed spec family.
        cw: code writer receiving the header text.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are batched into one #define group.
    # The batch is flushed whenever a non-const definition interrupts it.
    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # 'c-define-name' is a per-item override (as for multicast
            # groups below), so read it from the const itself; reading it
            # from the family would ignore the const's own override.
            define_name = const.get('c-define-name',
                                    f"{family.name}-{const['name']}")
            defines.append([c_upper(define_name), const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        # Values are left implicit unless the spec value jumps; 'val'
        # tracks the next implicit value.
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        # With a separate async enum, notifications/events go there instead.
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            # Only notifications/events belong in the async enum.
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2534
2535
def _render_user_ntf_entry(ri, op):
    """Write one ``[<op enum>] = { ... },`` entry of the ynl_ntf_info table.

    The entry carries:
      * the allocation size of the event type;
      * the dereferenced reply parser;
      * the reply attribute policy;
      * the notify free helper, cast to ``void *``.
    """
    ri.cw.block_start(line=f"[{op.enum_name}] = ")
    ri.cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
    ri.cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
    ri.cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
    ri.cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
    ri.cw.block_end(line=',')
2543
2544
def render_user_family(family, cw, prototype):
    """Write the user-space ``struct ynl_family`` descriptor for *family*.

    With *prototype* set, only the extern declaration is written. Otherwise,
    when the family has notifications, a static ynl_ntf_info table is
    written first (one entry per notification and per event op). The
    descriptor definition follows it.

    Args:
        family: parsed spec family.
        cw: code writer receiving the output.
        prototype: write only the declaration when true.

    Raises:
        Exception: a notification has neither 'notify' nor 'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    # The table name is a C identifier, so derive it from c_name (like
    # `symbol` above) rather than the raw spec name, which may contain dashes.
    ntf_table = f"{family.c_name}_ntf_info"

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_table}[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Notifications reuse the attributes of the op they mirror.
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for op_name, op in family.ops.items():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_table},")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({ntf_table}),")
    cw.block_end(line=';')
2580
2581
def family_contains_bitfield32(family):
    """
    Tell whether any attribute of ``family`` has type "bitfield32".

    Attribute sets with a truthy ``subset_of`` are not inspected; presumably
    their attributes are covered by the set they subset (confirm). The scan
    stops at the first match.
    """
    return any(
        attr.type == "bitfield32"
        for _set_name, attr_set in family.attr_sets.items()
        if not attr_set.subset_of
        for _attr_name, attr in attr_set.items()
    )
2590
2591
def find_kernel_root(full_path):
    """
    Locate the kernel source root above ``full_path``.

    Walk up the directory hierarchy, starting at the parent of
    ``full_path``, until a directory containing a ``MAINTAINERS`` file is
    found.

    Returns:
        ``(root, sub_path)``. ``root`` is that directory, derived textually
        from ``full_path``, so a relative input yields a relative root.
        ``sub_path`` is ``full_path`` relative to ``root``.

    Raises:
        FileNotFoundError: if the walk reaches the top ('/' for absolute
            paths, '' i.e. the current directory for relative ones) without
            finding MAINTAINERS.
    """
    start_path = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        full_path = os.path.dirname(full_path)
        maintainers = os.path.join(full_path, "MAINTAINERS")
        if os.path.exists(maintainers):
            # sub_path carries a trailing separator from the first join.
            return full_path, sub_path[:-1]
        # dirname() has reached a fixed point. Continuing would re-check the
        # same directory forever, so give up instead of hanging.
        if os.path.dirname(full_path) == full_path:
            raise FileNotFoundError(
                f"No MAINTAINERS file found above '{start_path}'; "
                "is the spec inside a kernel source tree?")
2600
2601
def main() -> None:
    """
    Command-line entry point of the netlink code generator.

    Parses the YAML spec given by --spec and writes generated C through a
    CodeWriter:
      * --mode uapi: the uapi output produced by render_uapi();
      * --mode kernel / user: a header (--header) or source (--source).

    Exits with status 1 in three cases: the spec fails to parse as YAML, its
    license is not the required dual license, or its message enum model is
    not supported for the requested mode.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    # --header and --source share dest='header'. The None default lets us
    # detect that neither flag was given (checked right after parsing).
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    # Each --exclude-op value is compiled as a regular expression.
    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
            print('Spec license:', parsed.license)
            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
            os.sys.exit(1)
    except yaml.YAMLError as exc:
        print(exc)
        os.sys.exit(1)
        # NOTE(review): unreachable, because os.sys.exit() raises SystemExit.
        return

    # 'unified' message IDs are accepted for every mode. 'directional' IDs
    # are only accepted for user and kernel generation.
    supported_models = ['unified']
    if args.mode in ['user', 'kernel']:
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation')
        os.sys.exit(1)

    # --cmp-out turns off overwrite; per its help text, an identical existing
    # output file is then left untouched.
    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    # Spec path relative to the kernel tree root, recorded in the banner.
    _, spec_kernel = find_kernel_root(args.spec)
    # Headers get a /* */ SPDX line, sources a // one.
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    # Echo any --user-header / --exclude-op arguments into the banner.
    if args.exclude_op or args.user_header:
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    # Include guard for generated headers; closed at the very end.
    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                # Companion header name is the output name minus its last two
                # characters; assumes a two-character extension such as ".c".
                cw.p(f'#include "{os.path.basename(args.out_file[:-2])}.h"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            # linux/netlink.h is only pulled in when a bitfield32 attribute
            # exists (presumably for its bitfield32 definition; confirm).
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{parsed.name}-user.h"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    # Spec 'definitions' entries may name an extra header to include.
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            # Emit the section comment once, only if some nested struct is
            # used in requests.
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            # Per-op policies: forward-declare the request policy of every
            # 'do' that is not an event.
            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)

            # Same one-time section comment as in the header branch.
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            # Per-op policies: one for each do/dump that has a 'request'.
            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            # Forward declarations of the op-name and enum-to-string helpers.
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    # The response type is only emitted here when it differs
                    # from the one already printed for 'do'.
                    if not ri.type_consistent:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            # Standalone events declared under the spec's notifications.
            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            # Recursive nests reference their own policy, so their policy
            # tables are forward-declared before any table is emitted.
            has_recursive_nests = False
            cw.p('/* Policies */')
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for name in parsed.pure_nested_structs:
                struct = Struct(parsed, name)
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            # With recursive nests, prototypes of the nested helpers come
            # first so the definitions below can reference each other.
            if has_recursive_nests:
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    # Dump gets its own parser only when its reply type
                    # differs from the 'do' reply type.
                    if not ri.type_consistent:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    # The parser is generated from a "do"-mode RenderInfo;
                    # the free helper from an "event"-mode one.
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    # Close the include guard opened before the #include section.
    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
2906
2907
if __name__ == '__main__':
    # Run the generator only when executed as a script, not on import.
    main()
2910