1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import os
8import re
9import shutil
10import tempfile
11import yaml
12
13from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
14
15
def c_upper(name):
    """Return name upper-cased, with every dash replaced by an underscore."""
    underscored = name.replace('-', '_')
    return underscored.upper()
18
19
def c_lower(name):
    """Return name lower-cased, with every dash replaced by an underscore."""
    underscored = name.replace('-', '_')
    return underscored.lower()
22
23
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The name must have the form <u|s><width>-<min|max>, e.g. u8-min,
    u32-max, s16-max or s64-min.  Unsigned limits span 0 .. 2^width - 1,
    signed limits span -2^(width - 1) .. 2^(width - 1) - 1.

    Raises ValueError if the name does not follow that form; previously
    a name with an unknown suffix was silently treated as "-max".
    """
    match = re.fullmatch(r'([us])([0-9]+)-(min|max)', name)
    if match is None:
        raise ValueError(f'Invalid limit "{name}", expected a name like u32-max or s64-min')

    signed = match.group(1) == 's'
    width = int(match.group(2))
    is_min = match.group(3) == 'min'

    if not signed:
        return 0 if is_min else (1 << width) - 1

    if width == 0:
        raise ValueError(f'Invalid limit "{name}", signed types need a non-zero width')
    # One bit is taken by the sign
    value = (1 << (width - 1)) - 1
    return -value - 1 if is_min else value
37
38
class BaseNlLib:
    """Helper providing C expressions used by the generated code."""

    def get_family_id(self):
        """Return the C expression that evaluates to the family id."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
42
43
class Type(SpecAttr):
    """
    Base class describing how one attribute of an attribute set is rendered.

    The methods here are the hooks the per-type subclasses override
    (presence_type(), arg_member(), _attr_policy(), _attr_typol(),
    attr_put(), _attr_get(), _setter_lines(), ...).  Hooks which have no
    sensible default raise an Exception naming the attribute type.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
            self.nested_attrs = nested
            # A nest named like the family itself gets no extra suffix
            if nested == family.name:
                render_name = f"{family.ident_name}"
            else:
                render_name = f"{family.ident_name}_{nested}"
            self.nested_render_name = c_lower(render_name)

            # Struct name gets a trailing underscore if a const has the same name
            suffix = '_' if nested in self.family.consts else ''
            self.nested_struct_type = 'struct ' + self.nested_render_name + suffix

        # Names listed in _C_KW get a trailing underscore appended
        c_name = c_lower(self.name)
        self.c_name = c_name + '_' if c_name in _C_KW else c_name

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def get_limit(self, limit, default=None):
        """
        Return the value of the check called limit, or default if not present.

        None is passed through, a value naming one of the family consts is
        turned into the upper case C name, integers are returned as is and
        anything else goes through limit_to_number().
        """
        bound = self.checks.get(limit, default)
        if bound is None:
            return None
        if bound in self.family.consts:
            return c_upper(f"{self.family['name']}-{bound}")
        return bound if isinstance(bound, int) else limit_to_number(bound)

    def resolve(self):
        """Set enum_name from the name prefix of the attribute or of its set."""
        if 'name-prefix' in self.attr:
            prefix = self.attr['name-prefix']
        else:
            prefix = self.attr_set.name_prefix
        self.enum_name = c_upper(f"{prefix}{self.name}")

    def is_multi_val(self):
        """Falsy by default, subclasses holding multiple values return True."""
        return None

    def is_scalar(self):
        """Tell whether the attribute type is one of the listed integer types."""
        scalar_types = {'u8', 'u16', 'u32', 'u64', 's32', 's64'}
        return self.type in scalar_types

    def is_recursive(self):
        """Not recursive by default."""
        return False

    def is_recursive_for_op(self, ri):
        """Truthy if the attribute is recursive and ri.op is not set."""
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Kind of presence tracking used for the member, 'bit' by default."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """
        Return the presence member declaration, or None.

        None is returned when the presence type differs from type_filter
        or is neither 'bit' nor 'len'.  The type gets a '__' prefix when
        space is 'user'.
        """
        kind = self.presence_type()
        if kind != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if kind == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if kind == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """C type of the member for complex types, None by default."""
        return None

    def free_needs_iter(self):
        """Whether free() output needs an iterator variable, False by default."""
        return False

    def free(self, ri, var, ref):
        """Emit free() of the member if it is multi-valued or has 'len' presence."""
        needs_free = self.is_multi_val() or self.presence_type() == 'len'
        if needs_free:
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """
        Return the list of C argument declarations needed to pass the member.

        Raises an Exception if the type has no complex member type (simple
        types are expected to override this method).
        """
        member = self._complex_member_type(ri)
        if not member:
            raise Exception(f"Struct member not implemented for class type {self.type}")

        args = [member + ' *' + self.c_name]
        if self.presence_type() == 'count':
            args.append('unsigned int n_' + self.c_name)
        return args

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for the attribute."""
        multi = self.is_multi_val()
        if multi:
            ri.cw.p(f"unsigned int n_{self.c_name};")

        member = self._complex_member_type(ri)
        if not member:
            # Simple types are declared the same way as they are passed
            for decl in self.arg_member(ri):
                ri.cw.p(decl + ';')
            return

        recursive = self.is_recursive_for_op(ri)
        ptr = '*' if multi or recursive else ''
        ri.cw.p(f"{member} {ptr}{self.c_name};")

    def _attr_policy(self, policy):
        """Return the policy initializer, by default only the type is set."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the policy array entry for the attribute."""
        nla_type = c_upper('nla-' + self.attr['type'])
        entry = self._attr_policy(nla_type)
        cw.p(f"\t[{self.enum_name}] = {entry},")

    def _attr_typol(self):
        """Return the type policy fields, must be provided by the subclass."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the type policy array entry for the attribute."""
        fields = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {fields}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit line, guarded by a presence check for 'bit' and 'len' presence."""
        kind = self.presence_type()
        if kind == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif kind == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a ynl_attr_put_<put_type>() call for the member."""
        call = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, call)

    def attr_put(self, ri, var):
        """Emit code putting the attribute, must be provided by the subclass."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """
        Return (lines, init_lines, local_vars) used by attr_get(), must be
        provided by the subclass.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit the "if (type == ...)" block parsing the attribute.

        The first block of a chain is opened with 'if', the following ones
        with 'else if'.  Always returns True.
        """
        get_lines, init_lines, local_vars = self._attr_get(ri, var)
        # Subclasses may return a single line as a bare string
        if isinstance(get_lines, str):
            get_lines = [get_lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        keyword = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{keyword} (type == {self.enum_name})")
        if local_vars:
            for decl in local_vars:
                ri.cw.p(decl)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for stmt in init_lines:
                ri.cw.p(stmt)

        for stmt in get_lines:
            ri.cw.p(stmt)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return the lines storing the value, must be provided by the subclass."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit the static inline setter function for the attribute.

        ref lists the names of the enclosing nests, outermost first.  The
        function name gets a '__' prefix if its body contains 'free(' but
        no 'alloc('.
        """
        ref = (ref or []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        last = len(ref) - 1
        for idx, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:idx] + [''])}_present.{name}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if idx != last or self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        frees = any('free(' in stmt for stmt in code)
        allocs = any('alloc(' in stmt for stmt in code)
        if frees and not allocs:
            func_name = '__' + func_name

        func_args = [f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri)
        ri.cw.write_func('static inline void', func_name, body=code, args=func_args)
245
246
class TypeUnused(Type):
    """
    Attribute of type 'unused'.

    Has no presence tracking and no arguments; policy, put, get and
    setter generation emit nothing.  Its type policy is YNL_PT_REJECT.
    """

    def presence_type(self):
        """No presence is tracked."""
        return ''

    def arg_member(self, ri):
        """No arguments are needed."""
        return []

    def _attr_get(self, ri, var):
        """Return a single line failing the parse, no init lines or locals."""
        get_lines = ['return YNL_PARSE_CB_ERROR;']
        return get_lines, None, None

    def _attr_typol(self):
        """Type policy fields for the attribute."""
        return '.type = YNL_PT_REJECT, '

    def attr_policy(self, cw):
        """Emit nothing."""

    def attr_put(self, ri, var):
        """Emit nothing."""

    def attr_get(self, ri, var, first):
        """Emit nothing."""

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit nothing."""
271
272
class TypePad(Type):
    """
    Attribute of type 'pad'.

    Has no presence tracking and no arguments; policy, put, get and
    setter generation emit nothing.  Its type policy is YNL_PT_IGNORE.
    """

    def presence_type(self):
        """No presence is tracked."""
        return ''

    def arg_member(self, ri):
        """No arguments are needed."""
        return []

    def _attr_typol(self):
        """Type policy fields for the attribute."""
        return '.type = YNL_PT_IGNORE, '

    def attr_put(self, ri, var):
        """Emit nothing."""

    def attr_get(self, ri, var, first):
        """Emit nothing."""

    def attr_policy(self, cw):
        """Emit nothing."""

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit nothing."""
294
295
class TypeScalar(Type):
    """
    Attribute holding a single scalar value.

    The constructor completes and validates the 'min' / 'max' checks,
    resolve() works out whether the value is a bitfield and which C type
    is used to carry it.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"
        else:
            self.byte_order_comment = ''

        if 'enum' in self.attr:
            # Limits not given explicitly are taken from the enum's value range
            low, high = self.family.consts[self.attr['enum']].value_range()
            if 'min' not in self.checks and (low != 0 or self.type[0] == 's'):
                self.checks['min'] = low
            if 'max' not in self.checks:
                self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            lim_min = self.get_limit('min')
            lim_max = self.get_limit('max')
            if lim_min > lim_max:
                raise Exception(f'Invalid limit for "{self.name}" min: {lim_min} max: {lim_max}')
            self.checks['range'] = True

        # A missing limit counts as 0 for the checks below
        lim_min = self.get_limit('min', 0)
        lim_max = self.get_limit('max', 0)
        low = min(lim_min, lim_max)
        high = max(lim_min, lim_max)
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside of -32768..32767 are flagged as 'full-range'
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Resolve the parent, then set is_bitfield and type_name."""
        self.resolve_up(super())

        has_enum = 'enum' in self.attr
        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
            self.is_bitfield = True
        elif has_enum:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if has_enum and not self.is_bitfield:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            # Auto scalars are carried as 64 bit values of the same signedness
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _attr_policy(self, policy):
        """
        Return the policy initializer for the attribute.

        A mask policy is used for bitfields and 'flags-mask' checks,
        otherwise the range checks are tried in the order 'full-range',
        'range', 'min', 'max' before falling back to the plain type policy.
        """
        checks = self.checks
        if 'flags-mask' in checks or self.is_bitfield:
            if self.is_bitfield:
                mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
            else:
                # One bit for each entry of the flags definition
                flag_set = self.family.consts[checks['flags-mask']]
                mask = (1 << len(flag_set['entries'])) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'full-range' in checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        if 'range' in checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit('min')}, {self.get_limit('max')})"
        if 'min' in checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit('min')})"
        if 'max' in checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        """Type policy fields, built from the type name without its first character."""
        width = c_upper(self.type[1:])
        return f'.type = YNL_PT_U{width}, '

    def arg_member(self, ri):
        """Single argument of the resolved C type, with the byte order comment."""
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        """Emit the put helper call matching the attribute type."""
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        """Return the line reading the value with the matching get helper."""
        get_line = f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);"
        return get_line, None, None

    def _setter_lines(self, ri, member, presence):
        """Plain assignment of the argument to the member."""
        return [f"{member} = {self.c_name};"]
382
383
class TypeFlag(Type):
    """Attribute of type 'flag', put without payload and parsed without get lines."""

    def arg_member(self, ri):
        """No arguments are needed."""
        return []

    def _attr_typol(self):
        """Type policy fields for the attribute."""
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        """Emit a put of the attribute with a NULL payload of length 0."""
        call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        """No get lines, init lines or locals are needed."""
        return [], None, None

    def _setter_lines(self, ri, member, presence):
        """No lines are added to the setter."""
        return []
399
400
class TypeString(Type):
    """Attribute of type 'string', kept as a char pointer with 'len' presence."""

    def arg_member(self, ri):
        """The string is passed as a const char pointer."""
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        """Presence is tracked as a length."""
        return 'len'

    def struct_member(self, ri):
        """Emit the char pointer member."""
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        """Type policy fields for the attribute."""
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        """
        Return the policy initializer, an exact length policy if the
        'exact-len' check is present, otherwise the type with an optional
        length taken from the 'max-len' check.
        """
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + str(self.get_limit('exact-len')) + ')'

        fields = ['.type = ' + policy]
        if 'max-len' in self.checks:
            fields.append('.len = ' + str(self.get_limit('max-len')))
        return '{ ' + ', '.join(fields) + ', }'

    def attr_policy(self, cw):
        """Emit the policy array entry, NLA_STRING is used only if 'unterminated-ok' is set."""
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'

        entry = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {entry},")

    def attr_put(self, ri, var):
        """Emit the string put helper call."""
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        """Return the lines copying the string into a freshly allocated buffer."""
        member = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [f"{len_mem} = len;",
                     f"{member} = malloc(len + 1);",
                     f"memcpy({member}, ynl_attr_get_str(attr), len);",
                     f"{member}[len] = 0;"]
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """Return the lines replacing the member with a copy of the argument."""
        len_mem = f"{presence}_len"
        return [f"free({member});",
                f"{len_mem} = strlen({self.c_name});",
                f"{member} = malloc({len_mem} + 1);",
                f'memcpy({member}, {self.c_name}, {len_mem});',
                f'{member}[{len_mem}] = 0;']
451
452
class TypeBinary(Type):
    """Attribute of type 'binary', kept as a void pointer with 'len' presence."""

    def arg_member(self, ri):
        """The data is passed as a pointer and a length."""
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        """Presence is tracked as a length."""
        return 'len'

    def struct_member(self, ri):
        """Emit the void pointer member."""
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        """Type policy fields for the attribute."""
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """
        Return the policy initializer.

        Only 'exact-len', 'min-len' on its own, or no checks at all are
        supported, any other combination of checks raises an Exception.
        """
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + str(self.get_limit('exact-len')) + ')'

        check_cnt = len(self.checks)
        if check_cnt == 1 and 'min-len' in self.checks:
            field = '.len = ' + str(self.get_limit('min-len'))
        elif check_cnt == 0:
            field = '.type = NLA_BINARY'
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        return '{ ' + field + ', }'

    def attr_put(self, ri, var):
        """Emit a put of the member using the length kept in the presence struct."""
        call = (f"ynl_attr_put(nlh, {self.enum_name}, " +
                f"{var}->{self.c_name}, {var}->_present.{self.c_name}_len)")
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        """Return the lines copying the payload into a freshly allocated buffer."""
        member = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [f"{len_mem} = len;",
                     f"{member} = malloc(len);",
                     f"memcpy({member}, ynl_attr_data(attr), len);"]
        init_lines = ['len = ynl_attr_data_len(attr);']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """Return the lines replacing the member with a copy of the argument."""
        len_mem = f"{presence}_len"
        return [f"free({member});",
                f"{len_mem} = len;",
                f"{member} = malloc({len_mem});",
                f'memcpy({member}, {self.c_name}, {len_mem});']
497
498
class TypeBitfield32(Type):
    """Attribute of type 'bitfield32', carried as struct nla_bitfield32."""

    def _complex_member_type(self, ri):
        """C type of the member."""
        return "struct nla_bitfield32"

    def _attr_typol(self):
        """Type policy fields for the attribute."""
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """
        Return the bitfield32 policy with the mask of the attribute's enum.

        Raises an Exception if the attribute has no 'enum'.
        """
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        """Emit a put of the whole struct."""
        call = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))"
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        """Return the line copying the payload into the member."""
        get_line = f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));"
        return get_line, None, None

    def _setter_lines(self, ri, member, presence):
        """Return the line copying the argument into the member."""
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
522
523
class TypeNest(Type):
    """Attribute of type 'nest', rendered using the helpers of the nested struct."""

    def is_recursive(self):
        """Take the recursive marking from the family's pure nested struct."""
        nested_struct = self.family.pure_nested_structs[self.nested_attrs]
        return nested_struct.recursive

    def _complex_member_type(self, ri):
        """C type of the member is the nested struct."""
        return self.nested_struct_type

    def free(self, ri, var, ref):
        """
        Emit the call freeing the nested struct.

        When is_recursive_for_op() holds the member is passed as is, after
        a NULL check, otherwise its address is passed.
        """
        member = f'{var}->{ref}{self.c_name}'
        if self.is_recursive_for_op(ri):
            ri.cw.p(f'if ({member})')
            arg = member
        else:
            arg = '&' + member
        ri.cw.p(f'{self.nested_render_name}_free({arg});')

    def _attr_typol(self):
        """Type policy fields, pointing at the nested type policy."""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        """Return the nested policy initializer."""
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        """Emit the call to the put helper of the nested struct."""
        addr_of = '' if self.is_recursive_for_op(ri) else '&'
        call = (f"{self.nested_render_name}_put(nlh, " +
                f"{self.enum_name}, {addr_of}{var}->{self.c_name})")
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        """Return the lines setting up parg and calling the nested parser."""
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
                     "return YNL_PARSE_CB_ERROR;"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit the setters of all non-recursive members of the nested struct."""
        ref = (ref or []) + [self.c_name]

        nested_struct = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested_struct.member_list():
            if not member.is_recursive():
                member.setter(ri, self.nested_attrs, direction, deref=deref, ref=ref)
563
564
class TypeMultiAttr(Type):
    """
    Attribute which may appear multiple times, kept as an array with a count.

    Policy and type policy are delegated to base_type, the type object
    created for a single instance of the attribute.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def is_multi_val(self):
        """The attribute holds multiple values."""
        return True

    def presence_type(self):
        """Presence is tracked as an element count."""
        return 'count'

    def _complex_member_type(self, ri):
        """
        Return the element type, the nested struct or a scalar type.

        Scalar types get a '__' prefix when ri.ku_space is 'user'.
        Raises an Exception for any other attribute type.
        """
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type

        sub_type = self.attr['type']
        if sub_type not in scalars:
            raise Exception(f"Sub-type {sub_type} not supported yet")
        scalar_pfx = '__' if ri.ku_space == 'user' else ''
        return scalar_pfx + sub_type

    def free_needs_iter(self):
        """An iterator is needed to free arrays of nests."""
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def free(self, ri, var, ref):
        """Emit the code freeing the array, and each element for nests."""
        sub_type = self.attr['type']
        member = f"{var}->{ref}{self.c_name}"
        if sub_type in scalars:
            ri.cw.p(f"free({member});")
        elif sub_type == 'nest':
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{member}[i]);')
            ri.cw.p(f"free({member});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {sub_type} not supported yet")

    def _attr_policy(self, policy):
        """Use the policy of the single attribute type."""
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        """Use the type policy of the single attribute type."""
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        """Return the line incrementing the n_<name> counter."""
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        """Emit a loop putting every element of the array."""
        sub_type = self.attr['type']
        loop = f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)"
        if sub_type in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif sub_type == 'nest':
            ri.cw.p(loop)
            call = (f"{self.nested_render_name}_put(nlh, " +
                    f"{self.enum_name}, &{var}->{self.c_name}[i])")
            self._attr_put_line(ri, var, call)
        else:
            raise Exception(f"Put of MultiAttr sub-type {sub_type} not supported yet")

    def _setter_lines(self, ri, member, presence):
        """Return the lines replacing the array pointer and the element count."""
        # For multi-attr we have a count, not presence, hack up the presence
        tail_len = len('_present.') + len(self.c_name)
        count_mem = presence[:-tail_len] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{count_mem} = n_{self.c_name};"]
626
627
class TypeArrayNest(Type):
    """Attribute holding an array of elements inside one nest, kept with a count."""

    def is_multi_val(self):
        """The attribute holds multiple values."""
        return True

    def presence_type(self):
        """Presence is tracked as an element count."""
        return 'count'

    def _complex_member_type(self, ri):
        """
        Return the element type, the nested struct or a scalar type.

        Scalar types get a '__' prefix when ri.ku_space is 'user'.
        Raises an Exception for any other sub-type.
        """
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type

        sub_type = self.attr['sub-type']
        if sub_type not in scalars:
            raise Exception(f"Sub-type {sub_type} not supported yet")
        scalar_pfx = '__' if ri.ku_space == 'user' else ''
        return scalar_pfx + sub_type

    def _attr_typol(self):
        """Type policy fields, pointing at the nested type policy."""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Return the lines saving the attribute and counting the nested elements."""
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr)',
                     f'\t{var}->n_{self.c_name}++;']
        local_vars = ['const struct nlattr *attr2;']
        return get_lines, None, local_vars
653
654
class TypeNestTypeValue(Type):
    """Attribute of type 'nest-type-value', a nest wrapped in 'type-value' levels."""

    def _complex_member_type(self, ri):
        """C type of the member is the nested struct."""
        return self.nested_struct_type

    def _attr_typol(self):
        """Type policy fields, pointing at the nested type policy."""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """
        Return the lines unwrapping the 'type-value' levels and calling the
        nested parser with the innermost attribute and the collected types.
        """
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        innermost = 'attr'
        tv_args = ''
        if 'type-value' in self.attr:
            tv_names = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(tv_names)};')
            local_vars.append(f'__u32 {", ".join(tv_names)};')
            # Each level is the payload of the previous one
            for level in tv_names:
                get_lines.append(f'attr_{level} = ynl_attr_data({innermost});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                innermost = 'attr_' + level

            tv_args = f", {', '.join(tv_names)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {innermost}{tv_args});")
        return get_lines, init_lines, local_vars
683
684
class Struct:
    """
    C struct built from the attributes of one attribute set.

    If type_list is given only the attributes it names are included and
    the struct is not considered nested, otherwise all attributes of the
    set are included and the struct is marked as nested.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        # A space named like the family itself gets no extra suffix
        if family.name == c_lower(space_name):
            render_name = family.ident_name
        else:
            render_name = family.ident_name + '-' + space_name
        self.render_name = c_lower(render_name)

        struct_name = 'struct ' + self.render_name
        # Nested struct gets a trailing underscore if a const has the same name
        if self.nested and space_name in family.consts:
            struct_name += '_'
        self.struct_name = struct_name
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]
        self.attrs = dict()

        # Remember the attribute with the highest value, last one wins on a tie
        highest = 0
        self.attr_max_val = None
        for name, attr in self.attr_list:
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        """Iterate over the names of the attributes."""
        yield from self.attrs

    def __getitem__(self, key):
        """Look up an attribute by name."""
        return self.attrs[key]

    def member_list(self):
        """Return the list of (name, attribute) tuples."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """
        Record the inherited members as a sorted list of C names.

        Raises an Exception if new_inherited differs from the members
        passed to the constructor.
        """
        if new_inherited != self._inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(name) for name in sorted(self._inherited)]
740
741
class EnumEntry(SpecEnumEntry):
    """Single entry of an enum, with the information needed to render it in C."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # The value "changes" if it does not follow the previous entry
        # (or 0 for the first entry), and always for flags
        expected = prev.value + 1 if prev else 0
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Resolve the parent, then set c_name from the set's value prefix."""
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
760
761
class EnumSet(SpecEnumSet):
    """Enum or flags definition, with the names used to render it in C."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # 'enum-name' overrides the default name, an empty one means
        # that the enum has no name
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        # Values of enums without a name are carried as int
        self.user_type = self.enum_name if self.enum_name else 'int'

        default_pfx = f"{family.ident_name}-{yaml['name']}-"
        self.value_pfx = yaml.get('name-prefix', default_pfx)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        """Create the EnumEntry for one entry of the definition."""
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """
        Return the (lowest, highest) values of the entries.

        Raises an Exception if the number of entries does not match the
        size of that range.
        """
        values = [entry.value for entry in self.entries.values()]
        low = min(values)
        high = max(values)

        if high - low + 1 != len(self.entries):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
795
796
class AttrSet(SpecAttrSet):
    """Attribute set, with the names used to render it in C."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets share the names of the set they are a subset of
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """
        Set c_name, the C name of the set.

        Names listed in _C_KW get a trailing underscore, and the name is
        left empty if it ends up equal to the family's c_name.
        """
        c_name = c_lower(self.name)
        if c_name in _C_KW:
            c_name += '_'
        self.c_name = c_name
        if self.c_name == self.family.c_name:
            self.c_name = ''

    def new_attr(self, elem, value):
        """
        Create the type object for one attribute of the set.

        Attributes marked as 'multi-attr' are wrapped in TypeMultiAttr.
        Raises an Exception for attribute types and sub-types which are
        not supported.
        """
        simple_types = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'binary': TypeBinary,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
        }

        kind = elem['type']
        if kind in scalars:
            type_class = TypeScalar
        elif kind in simple_types:
            type_class = simple_types[kind]
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] != 'nest':
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            type_class = TypeArrayNest
        else:
            raise Exception(f"No typed class for type {elem['type']}")
        t = type_class(self.family, self, elem, value)

        if 'multi-attr' in elem and elem['multi-attr']:
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
858
859
class Operation(SpecOperation):
    """
    Operation from the spec, extended with the names needed for rendering.

    render_name  C-lowered "<family ident>_<op name>"
    dual_policy  True when both 'do' and 'dump' carry a 'request'
    has_ntf      set by mark_has_ntf()
    enum_name    only present after resolve() has run
    """

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower('_'.join((family.ident_name, self.name)))

        # Both modes need a request of their own for the policy to be dual
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        # (assigned and deleted right away, so reading it early still fails)
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Run the parent's resolve and compute the op's enum entry name."""
        self.resolve_up(super())

        if self.is_async:
            prefix = self.family.async_op_prefix
        else:
            prefix = self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that this operation has a notification."""
        self.has_ntf = True
885
886
class Family(SpecFamily):
    """
    Spec family extended with the information needed by the C code generator.

    __init__() derives the family / version define names and the uAPI header
    location from the spec.  resolve() fills in the remaining members:
    C naming prefixes, multicast groups, pre / post hooks, and the
    request / reply usage of the root and nested attribute sets.
    """
    def __init__(self, file_name, exclude_ops):
        """
        file_name and exclude_ops are handed to the parent constructor
        unchanged.
        """
        # Each member below is assigned and immediately deleted, so it does
        # not exist on the instance until resolve() sets it.
        # NOTE(review): presumably the assignments only declare the names
        # for readers / static checkers - confirm.
        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # Names of the defines for family name and version, overridable
        # from the spec via c-family-name / c-version-name
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        # Make sure the key exists so later code can rely on it
        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # For "linux/<name>.h" keep just <name>, otherwise fall back to
        # the family's identifier name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """
        Resolve the parent class state, then compute everything the
        renderer needs.

        Raises Exception when the spec's protocol is not one of the
        genetlink flavors.
        """
        self.resolve_up(super())

        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        # Async ops share the regular prefix unless the spec gives them one
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] holds both a 'set' (for membership tests)
        # and a 'list' of hook names, see _load_hooks()
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> set('request', 'reply')
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        """Return the EnumSet object for the spec entry elem."""
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        """Return the AttrSet object for the spec entry elem."""
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        """Return the Operation object for the spec entry elem."""
        return Operation(self, elem, req_value, rsp_value)

    def _mark_notify(self):
        """Flag every operation which some message names in its 'notify'."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        """Add a 'do' with a 'reply' section to each event in the raw spec."""
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """
        Fill self.root_sets with the attribute sets used directly by
        messages, and which of their attributes appear in requests and
        in replies (events count as replies).
        """
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            # Several messages may share a set, merge their attributes
            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """
        Reorder self.pure_nested_structs so that a struct comes after the
        nested structs it refers to (recursive ones excepted).

        The number of attempts is bounded, so the order is best effort.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                    if self.pure_nested_structs[nested].recursive:
                        continue
                    if nested not in pns_key_seen:
                        # Dicts are sorted, this will make struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                # Dependency not placed yet, retry this one later
                pns_key_list.append(name)

    def _load_nested_sets(self):
        """
        Find all attribute sets reachable through nests from the root sets
        and record them in self.pure_nested_structs, together with their
        inherited members and request / reply / recursive flags.

        Raises Exception when a set is used both as a root and as a nest.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        # Walk the nest references starting from the root sets
        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                # Filled in below, after the Struct may have been created
                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'indexed-array':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Nests referenced directly from a root set take their direction
        # from how the root attribute is used
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                # A struct found among its own child nests is recursive
                if attr_set in struct.child_nests:
                    struct.recursive = True

        # Sort again, the recursive flags set above change what counts
        # as a dependency
        self._sort_pure_types()

    def _load_attr_use(self):
        """
        Set the request / reply flags on individual attributes, based on
        the flags of the nested struct or the usage of the root set they
        belong to.
        """
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.request = True
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.reply = True

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.request = True
                if attr in rs_members['reply']:
                    spec.reply = True

    def _load_global_policy(self):
        """
        Compute self.global_policy, the attributes used in any 'do' or
        'dump' request, and self.global_policy_set, the name of the one
        attribute set all operations share.

        Raises Exception when operations use different attribute sets.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        # Keep the attributes in the order the attribute set yields them
        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """
        Collect the names of the 'pre' and 'post' hooks of all operations
        into self.hooks, once each, in the order they are first seen.
        """
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1148
1149
class RenderInfo:
    """
    Everything the render helpers need to know about one thing being
    rendered: an operation in a given mode or, when op is empty, a bare
    attribute set.

    fixed_hdr           C struct name of the op's fixed header, or None
    type_consistent     False when 'dump' replies can't share the 'do' type
    attr_set            attr_set argument, or the op's 'attribute-set'
    type_name           C-lowered op name (or attribute set name)
    type_name_conflict  attribute set name is also in family.consts
    struct              dict 'request' / 'reply' -> Struct
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        has_fixed_hdr = op and op.fixed_header
        self.fixed_hdr = 'struct ' + c_lower(op.fixed_header) if has_fixed_hdr else None

        # 'do' and 'dump' response parsing is identical
        consistent = True
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                consistent = False
            else:
                do_has_reply = 'reply' in op['do']
                dump_has_reply = 'reply' in op["dump"]
                if do_has_reply != dump_has_reply:
                    consistent = False
                elif do_has_reply and op["do"]["reply"] != op["dump"]["reply"]:
                    consistent = False
        self.type_consistent = consistent

        self.attr_set = attr_set if attr_set else op['attribute-set']

        if op:
            self.type_name = c_lower(op.name)
            self.type_name_conflict = False
        else:
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        self.cw = cw

        self.struct = dict()
        # Notifications are described by the 'do' section of their op
        if op_mode == 'notify':
            op_mode = 'do'
        if op:
            mode_spec = op[op_mode]
            for op_dir in ('request', 'reply'):
                if op_dir in mode_spec:
                    attrs = mode_spec[op_dir]['attributes']
                else:
                    attrs = []
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=attrs)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1198
1199
class CodeWriter:
    """
    Line based writer for the generated C code.

    Keeps track of the indentation level, delays closing braces so that
    "} else" can be merged, indents the single statement following a
    brace-less if / while / for, and manages #ifdef sections.

    Output goes to stdout, or when out_file is given to a temporary file
    which close_out_file() copies over out_file.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        """
        nlib is only stored (as self.nlib) for the users of the writer.
        With overwrite=False an existing out_file is left untouched when
        the new contents are identical to it.
        """
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # empty line pending
        self._block_end = False     # closing bracket pending
        self._silent_block = False  # previous line was a brace-less condition
        self._ind = 0
        self._ifdef_block = None
        if out_file is None:
            self._out = os.sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """
        Copy the temporary output over the real output file.  Does nothing
        when writing to stdout, including on calls after the copy was done.
        """
        if self._out == os.sys.stdout:
            return
        # Avoid modifying the file if contents didn't change
        self._out.flush()
        if not self._overwrite and os.path.isfile(self._out_file):
            if filecmp.cmp(self._out.name, self._out_file, shallow=False):
                return
        with open(self._out_file, 'w+') as out_file:
            self._out.seek(0)
            shutil.copyfileobj(self._out, out_file)
            self._out.close()
        self._out = os.sys.stdout

    @classmethod
    def _is_cond(cls, line):
        """Check if the line starts with if, while or for."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """
        Print one line at the current indentation, add_ind adjusts the
        indentation of just this line.

        Lines ending with ':' are moved one level out, lines starting with
        '#' are not indented at all, and the line following a brace-less
        condition is moved one level in.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # endswith / startswith rather than indexing, to cope with ''
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request an empty line before the next printed line."""
        self._nl = True

    def block_start(self, line=''):
        """Print line followed by an opening bracket, and indent."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """
        Close a block.  line is text to put after the closing bracket,
        when it's empty the bracket is only printed once the next line
        is known.  Any pending empty line is dropped.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """
        Print doc as lines of a block comment (each starting with ' *'),
        wrapped to stay under 80 columns.  Continuation lines get two
        extra spaces if indent is set.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """
        Print a function prototype, wrapping the arguments when it does
        not fit on one line.

        qual_ret is the return type with qualifiers, args a list of
        argument declarations (no args means 'void'), doc an optional
        comment to put above, suffix is appended after the closing ')'.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        # Too long, longer return types go on a line of their own
        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation lines are aligned with the first argument
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """
        Print local variable declarations, longest first, followed by an
        empty line.  local_vars is a single declaration string or a list
        of them; the caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # sorted() instead of an in-place sort, the list belongs to the caller
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """
        Print a whole function: prototype, local variables and the lines
        of body.
        """
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)

        self.block_start()
        # Declarations must go inside the brackets to be valid C
        self.write_func_lvar(local_vars=local_vars)
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """
        Print #define lines with the values aligned.  defines is a list
        of (name, value) entries; int values are printed as is, str values
        in double quotes, values of other types are left out.
        """
        longest = max((len(define[0]) for define in defines), default=0)
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """
        Print designated initializers (".name = value,") with the '='
        signs aligned.  members is a list of (name, value) entries,
        nothing is printed for an empty list.
        """
        if not members:
            return
        longest = max(len(x[0]) for x in members)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """
        Switch to the #ifdef section of the given config option (without
        the CONFIG_ prefix).  An empty config closes the open section.
        Nothing is printed if the section does not change.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1395
1396
# Attribute types which are handled as scalars
scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64', 'uint', 'sint'}

# Suffix of the C names for each message direction
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': ''
}

# Suffix of the C names for each op mode
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# C keywords, names from the spec which collide with those need adjusting
_C_KW = {
    'auto',
    'bool',
    'break',
    'case',
    'char',
    'const',
    'continue',
    'default',
    'do',
    'double',
    'else',
    'enum',
    'extern',
    'float',
    'for',
    'goto',
    'if',
    'inline',
    'int',
    'long',
    'register',
    'return',
    'short',
    'signed',
    'sizeof',
    'static',
    'struct',
    'switch',
    'typedef',
    'union',
    'unsigned',
    'void',
    'volatile',
    'while'
}
1448
1449
def rdir(direction):
    """
    Return the opposite of direction: 'request' for 'reply' and the other
    way around.  Anything else is returned as is.
    """
    for one_way, other_way in (('reply', 'request'), ('request', 'reply')):
        if direction == one_way:
            return other_way
    return direction
1456
1457
def op_prefix(ri, direction, deref=False):
    """
    Build the C name prefix of the type for the given direction of the
    op described by ri: family name, type name and a suffix depending on
    the op mode and direction.

    deref only matters for non-'do' replies, it selects the name of
    a single element rather than that of the list wrapper.
    """
    stem = f"_{ri.type_name}"

    if not ri.op_mode or ri.op_mode == 'do':
        tail = f"{direction_to_suffix[direction]}"
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        # Reply differs from the one of 'do', it needs types of its own
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = f"{direction_to_suffix[direction]}"
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family.c_name}{stem}{tail}"
1477
1478
def type_name(ri, direction, deref=False):
    """Return the C struct type ("struct <op_prefix>") for the direction."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1481
1482
def print_prototype(ri, direction, terminate=True, doc=None):
    """
    Print the prototype of the function which performs the op of ri.

    The function takes the socket, plus the request struct if the op
    mode has a request.  It returns a pointer to the reply type if the
    op mode has a reply, an int otherwise.  terminate adds the ';'.
    """
    fname = ri.op.render_name
    if ri.op_mode == 'dump':
        fname += '_dump'

    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        req_type = type_name(ri, direction)
        # Argument is named after the direction suffix, minus the '_'
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{req_type} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1499
1500
def print_req_prototype(ri):
    """Print the request function prototype, with the op's doc on top."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1503
1504
def print_dump_prototype(ri):
    """Print the request function prototype, without a doc comment."""
    direction = "request"
    print_prototype(ri, direction)
1507
1508
def put_typol_fwd(cw, struct):
    """Print the extern declaration of the struct's policy nest."""
    nest_sym = f'{struct.render_name}_nest'
    cw.p(f'extern const struct ynl_policy_nest {nest_sym};')
1511
1512
def put_typol(cw, struct):
    """
    Print the policy of a struct: the table of per-attribute policies
    (each member prints its own entry) and the nest which points to it.
    """
    max_attr = struct.attr_set.max_name
    name = struct.render_name

    # Table of attribute policies
    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _name, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    # The nest wrapping the table
    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    for init in (f'.max_attr = {max_attr},', f'.table = {name}_policy,'):
        cw.p(init)
    cw.block_end(line=';')
    cw.nl()
1528
1529
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """
    Print the <render_name>_str() function, which looks its argument up
    in the map_name array of strings and returns NULL when out of range.

    The argument is an int unless enum is given, then it is of the enum's
    user type.  For 'flags' enums the function first converts the value
    to a bit number with ffs().
    """
    if enum:
        proto_args = [enum.user_type + ' ' + arg_name]
    else:
        proto_args = [f'int {arg_name}']
    cw.write_func_prot('const char *', f'{render_name}_str', proto_args)

    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    lookup = (
        f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
        'return NULL;',
        f'return {map_name}[{arg_name}];',
    )
    for line in lookup:
        cw.p(line)
    cw.block_end()
    cw.nl()
1543
1544
def put_op_name_fwd(family, cw):
    """Print the declaration of the family's op to string function."""
    func_name = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', func_name, ['int op'], suffix=';')
1547
1548
def put_op_name(family, cw):
    """
    Print the array mapping message values to op names, and the function
    which does the lookup.  Only messages with a response value get
    an entry.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue

        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue

        # Use the enum name as the index when request and reply match
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1568
1569
def put_enum_to_str_fwd(family, cw, enum):
    """Print the declaration of the enum's value to string function."""
    func_name = f'{enum.render_name}_str'
    value_arg = enum.user_type + ' value'
    cw.write_func_prot('const char *', func_name, [value_arg], suffix=';')
1573
1574
def put_enum_to_str(family, cw, enum):
    """
    Print the array mapping the enum's entry values to entry names, and
    the function which does the lookup.
    """
    render_name = enum.render_name
    map_name = f'{render_name}_strmap'

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for item in enum.entries.values():
        cw.p(f'[{item.value}] = "{item.name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1584
1585
def put_req_nested_prototype(ri, struct, suffix=';'):
    """
    Print the prototype of the function which puts the nested struct
    into a message, suffix goes after the closing bracket.
    """
    obj_arg = f'{struct.ptr_name}obj'
    func_args = ['struct nlmsghdr *nlh', 'unsigned int attr_type', obj_arg]
    func_name = f'{struct.render_name}_put'

    ri.cw.write_func_prot('int', func_name, func_args, suffix=suffix)
1593
1594
def put_req_nested(ri, struct):
    """
    Print the function which puts the nested struct into a message:
    it starts a nest, has each member put itself, and ends the nest.
    """
    cw = ri.cw

    put_req_nested_prototype(ri, struct, suffix='')
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")
    for _name, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1611
1612
def _multi_parse(ri, struct, init_lines, local_vars):
    """
    Print the body (from the opening to the closing bracket) of a function
    parsing the attributes of struct into 'dst'.

    The attributes get walked once with each member printing its own
    parsing code.  Indexed arrays of nests and multi-attrs are only
    counted in that first walk, then allocated and filled in separate
    loops.

    init_lines are printed right after the local variables.  Both
    init_lines and local_vars may get extended here.
    Raises Exception for sub-types / multi-attr types which are not
    supported.
    """
    if struct.nested:
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        if ri.fixed_hdr:
            local_vars += ['void *hdr;']
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"

    # Find the members which need extra variables and a second pass
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] == 'nest':
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # One counter per member which may have multiple entries
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Inherited members are copied into dst from identically named variables
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if ri.fixed_hdr:
        ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    # Error out if the arrays in dst are already populated
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # Main loop over the attributes
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Allocate and parse the indexed arrays of nests
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
        ri.cw.p('return YNL_PARSE_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Allocate and parse the multi-attrs, walking all the attributes
    # again and picking out those of the right type
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0, message parsers the callback status
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1724
1725
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """
    Write the C prototype of the parse function for a nested struct.

    The function is named after struct.render_name and takes the parse
    argument, the nest attribute and one __u32 per inherited member.
    suffix is handed to the writer unchanged (';' makes a declaration).
    """
    params = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    params.extend('__u32 ' + inherited for inherited in struct.inherited)

    parser_name = f'{struct.render_name}_parse'
    ri.cw.write_func_prot('int', parser_name, params, suffix=suffix)
1734
1735
def parse_rsp_nested(ri, struct):
    """
    Write the definition of the parse function for a nested struct.

    The prototype is written without a terminator and the body is
    produced by _multi_parse(), with no extra init lines.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    declarations = [
        'const struct nlattr *attr;',
        f'{struct.ptr_name}dst = yarg->data;',
    ]
    _multi_parse(ri, struct, [], declarations)
1744
1745
def parse_rsp_msg(ri, deref=False):
    """
    Write the parse callback for the reply (or event) message of an op.

    Nothing is written when the current op mode has no reply and is not
    an event. A reply without members gets a stub body which only
    returns YNL_PARSE_CB_OK.
    """
    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
        return

    params = ['const struct nlmsghdr *nlh', 'struct ynl_parse_arg *yarg']
    declarations = [f'{type_name(ri, "reply", deref=deref)} *dst;',
                    'const struct nlattr *attr;']
    setup = ['dst = yarg->data;']

    parser_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', parser_name, params)

    reply = ri.struct["reply"]
    if reply.member_list():
        _multi_parse(ri, reply, setup, declarations)
        return

    # Empty reply
    ri.cw.block_start()
    ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1767
1768
def print_req(ri):
    """
    Write the C function which performs a request ("do") for an op.

    When the op has a reply the generated function allocates the
    response, returns it on success and frees it and returns NULL on
    error; without a reply it returns 0 / -1.
    """
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]

    declarations = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                    'struct nlmsghdr *nlh;',
                    'int err;']
    if has_reply:
        declarations.append(f'{type_name(ri, rdir(direction))} *rsp;')
    if ri.fixed_hdr:
        declarations.extend(['size_t hdr_len;', 'void *hdr;'])

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(declarations)

    ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    if ri.fixed_hdr:
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            ri.cw.p(stmt)
        ri.cw.nl()

    for _, member in ri.struct["request"].member_list():
        member.attr_put(ri, "req")
    ri.cw.nl()

    if has_reply:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Ops without a value of their own use the response value instead
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        ri.cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        ri.cw.nl()

    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto err_free;' if has_reply else 'return -1;')
    ri.cw.nl()

    ri.cw.p("return rsp;" if has_reply else "return 0;")
    ri.cw.nl()

    if has_reply:
        ri.cw.p('err_free:')
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        ri.cw.p("return NULL;")

    ri.cw.block_end()
1833
1834
def print_dump(ri):
    """
    Write the C function which performs a dump for an op.

    The generated function sets up the dump state, puts the request
    attributes (if the dump takes a request), runs ynl_exec_dump() and
    returns the list head, or frees the list and returns NULL on error.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()

    declarations = ['struct ynl_dump_state yds = {};',
                    'struct nlmsghdr *nlh;',
                    'int err;']
    if ri.fixed_hdr:
        declarations.extend(['size_t hdr_len;', 'void *hdr;'])
    ri.cw.write_func_lvar(declarations)

    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    # Ops without a value of their own use the response value instead
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    ri.cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    ri.cw.nl()
    ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if ri.fixed_hdr:
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            ri.cw.p(stmt)
        ri.cw.nl()

    if "request" in ri.op[ri.op_mode]:
        request = ri.struct['request']
        ri.cw.p(f"ys->req_policy = &{request.render_name}_nest;")
        ri.cw.nl()
        for _, member in request.member_list():
            member.attr_put(ri, "req")
    ri.cw.nl()

    for stmt in ('err = ynl_exec_dump(ys, nlh, &yds);',
                 'if (err < 0)',
                 'goto free_list;'):
        ri.cw.p(stmt)
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
1885
1886
def call_free(ri, direction, var):
    """
    Return the C statement calling the free function of the op's type
    for the given direction on var.
    """
    free_func = f"{op_prefix(ri, direction)}_free"
    return f"{free_func}({var});"
1889
1890
def free_arg_name(direction):
    """
    Return the argument name used by free functions: the direction
    suffix without its first character, or 'obj' when no direction is
    given.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
1895
1896
def print_alloc_wrapper(ri, direction):
    """
    Write a static inline C helper which calloc()s one zeroed instance
    of the op's struct for the given direction.
    """
    struct_name = op_prefix(ri, direction)
    ret_type = f'static inline struct {struct_name} *'

    ri.cw.write_func_prot(ret_type, f"{struct_name}_alloc", ["void"])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof(struct {struct_name}));')
    ri.cw.block_end()
1903
1904
def print_free_prototype(ri, direction, suffix=';'):
    """
    Write the prototype of the free function of the op's type.

    The struct name gets a trailing underscore when the render info
    reports a type name conflict. suffix is handed to the writer
    unchanged (';' makes a declaration).
    """
    func_base = op_prefix(ri, direction)
    struct_name = func_base + '_' if ri.type_name_conflict else func_base
    arg_name = free_arg_name(direction)

    ri.cw.write_func_prot('void', f"{func_base}_free",
                          [f"struct {struct_name} *{arg_name}"],
                          suffix=suffix)
1912
1913
def _print_type(ri, direction, struct):
    """
    Write the C struct definition for struct.

    Layout: optional fixed header, an anonymous "_present" struct with
    the presence members (only if there are any), inherited __u32
    members, then the regular members.
    """
    name_suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        name_suffix += '_'
    if ri.op_mode == 'dump':
        name_suffix += '_dump'

    ri.cw.block_start(line=f"struct {ri.family.c_name}{name_suffix}")

    if ri.fixed_hdr:
        ri.cw.p(ri.fixed_hdr + ' _hdr;')
        ri.cw.nl()

    # Gather the presence members first, the wrapping struct is only
    # written when at least one member needs it
    presence = []
    for _, member in struct.member_list():
        for kind in ['len', 'bit']:
            line = member.presence_member(ri.ku_space, kind)
            if line:
                presence.append(line)
    if presence:
        ri.cw.block_start(line="struct")
        for line in presence:
            ri.cw.p(line)
        ri.cw.block_end(line='_present;')
        ri.cw.nl()

    for inherited in struct.inherited:
        ri.cw.p(f"__u32 {inherited};")

    for _, member in struct.member_list():
        member.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
1949
1950
def print_type(ri, direction):
    """
    Write the C struct for the op's struct of the given direction.
    """
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1953
1954
def print_type_full(ri, struct):
    """
    Write the C struct for struct, without any direction.
    """
    no_direction = ""
    _print_type(ri, no_direction, struct)
1957
1958
def print_type_helpers(ri, direction, deref=False):
    """
    Write the free prototype for the type and, for user space request
    types, the setter of every member.
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    members = ri.struct[direction].member_list() if wants_setters else []
    for _, member in members:
        member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1967
1968
def print_req_type_helpers(ri):
    """
    Write the alloc wrapper and type helpers of the request type,
    unless the request has no attributes.
    """
    request_attrs = ri.struct["request"].attr_list
    if len(request_attrs) != 0:
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
1974
1975
def print_rsp_type_helpers(ri):
    """
    Write the type helpers of the reply type, if the current op mode
    has a reply.
    """
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1980
1981
def print_parse_prototype(ri, direction, terminate=True):
    """
    Write the prototype of the op's parse function taking an attribute
    table and the request / response struct. A ';' is appended when
    terminate is set.
    """
    if direction == "reply":
        suffix = "_rsp"
    else:
        suffix = "_req"
    type_base = f"{ri.op.render_name}{suffix}"

    params = ['const struct nlattr **tb', f"struct {type_base} *req"]
    ri.cw.write_func_prot('void', f"{type_base}_parse", params,
                          suffix=';' if terminate else '')
1990
1991
def print_req_type(ri):
    """
    Write the request struct, unless the request has no attributes.
    """
    request_attrs = ri.struct["request"].attr_list
    if len(request_attrs) != 0:
        print_type(ri, "request")
1996
1997
def print_req_free(ri):
    """
    Write the free function of the request type, if the current op
    mode has a request.
    """
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2002
2003
def print_rsp_type(ri):
    """
    Write the reply struct.

    do and dump ops get one only when they have a reply, events always
    do, any other op mode gets nothing.
    """
    mode = ri.op_mode
    if mode == 'do' or mode == 'dump':
        wants_type = 'reply' in ri.op[mode]
    else:
        wants_type = mode == 'event'

    if wants_type:
        print_type(ri, 'reply')
2012
2013
def print_wrapped_type(ri):
    """
    Write the struct wrapping the reply object, plus its free prototype.

    Dumps get a next pointer of the wrapper type; notifications and
    events get the family / cmd / next / free members. The reply itself
    is always embedded as "obj", aligned to 8 bytes.
    """
    wrapper = type_name(ri, 'reply')

    ri.cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        ri.cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for member in ('__u16 family;',
                       '__u8 cmd;',
                       'struct ynl_ntf_base_type *next;',
                       f"void (*free)({wrapper} *ntf);"):
            ri.cw.p(member)
    ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    ri.cw.block_end(line=';')
    ri.cw.nl()

    print_free_prototype(ri, 'reply')
    ri.cw.nl()
2028
2029
def _free_type_members_iter(ri, struct):
    """
    Declare the loop counter "i" if freeing any member of struct needs
    to iterate.
    """
    if any(member.free_needs_iter() for _, member in struct.member_list()):
        ri.cw.p('unsigned int i;')
        ri.cw.nl()
2036
2037
def _free_type_members(ri, var, struct, ref=''):
    """
    Let every member of struct write the code freeing itself, with var
    and ref passed through to the member.
    """
    members = struct.member_list()
    for _, member in members:
        member.free(ri, var, ref)
2041
2042
def _free_type(ri, direction, struct):
    """
    Write the free function for struct.

    The members are freed first; the object itself is only free()d
    when a direction is given.
    """
    arg_name = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg_name, struct)
    if direction:
        ri.cw.p(f'free({arg_name});')
    ri.cw.block_end()
    ri.cw.nl()
2054
2055
def free_rsp_nested_prototype(ri):
    """
    Write the prototype of the free function of a nested type.
    """
    print_free_prototype(ri, "")
2058
2059
def free_rsp_nested(ri, struct):
    """
    Write the free function of a nested type.
    """
    no_direction = ""
    _free_type(ri, no_direction, struct)
2062
2063
def print_rsp_free(ri):
    """
    Write the free function of the reply type, if the current op mode
    has a reply.
    """
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2068
2069
def print_dump_type_free(ri):
    """
    Write the free function for the list returned by a dump.

    The generated code walks the list until YNL_LIST_END, freeing the
    members of each entry's "obj" and then the entry itself.
    """
    list_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    ri.cw.p(f"{list_type} *next = rsp;")
    ri.cw.nl()

    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    reply = ri.struct['reply']
    _free_type_members_iter(ri, reply)
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        ri.cw.p(stmt)
    ri.cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()  # while loop
    ri.cw.block_end()  # function body
    ri.cw.nl()
2088
2089
def print_ntf_type_free(ri):
    """
    Write the free function for a notification: free the members of
    the embedded "obj", then the notification itself.
    """
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.nl()
2098
2099
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """
    Write the first line of a kernel policy array.

    With terminate set this is an extern declaration, which is skipped
    entirely for per-op policies that should be static. Without it,
    this is the opening line of the definition (ending in " = {").
    The policy is named after the op when ri is given, after the
    struct otherwise.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            # static policies have nothing to declare
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2122
2123
def print_req_policy(cw, struct, ri=None):
    """
    Write the definition of a kernel policy array: the opening line,
    one entry per member of struct and the closing brace. For an op
    the array is wrapped in the op's 'config-cond' ifdef block.
    """
    if ri and ri.op:
        config_cond = ri.op.get('config-cond', None)
        cw.ifdef_block(config_cond)

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")

    cw.ifdef_block(None)
    cw.nl()
2133
2134
def kernel_can_gen_family_struct(family):
    """
    Tell whether the family struct is generated: only for families
    using the 'genetlink' protocol.
    """
    protocol = family.proto
    return protocol == 'genetlink'
2137
2138
def policy_should_be_static(family):
    """
    Tell whether policies of the family are static: true for the
    'split' kernel policy or when the family struct is generated.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2141
2142
def print_kernel_policy_ranges(family, cw):
    """
    Write a netlink_range_validation struct for every request
    attribute which has 'full-range' in its checks.

    Attribute sets which are subsets of another set are skipped. The
    section comment is written once, before the first range.
    """
    header_written = False
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not header_written:
                cw.p('/* Integer value ranges */')
                header_written = True

            # Types starting with 'u' use the unsigned validation struct
            if attr.type[0] == 'u':
                sign, suffix = '', 'ULL'
            else:
                sign, suffix = '_signed', 'LL'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = [(limit, str(attr.get_limit(limit)) + suffix)
                       for limit in ('min', 'max') if limit in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2170
2171
def print_kernel_op_table_fwd(family, cw, terminate):
    """
    Write the declaration (terminate set) or the opening line of the
    definition (terminate clear) of the kernel ops table.

    The table is exported, with an explicit size, when the family
    struct is not generated; otherwise it is static, unsized, and has
    no declaration. With terminate set the prototypes of the pre/post
    hooks and of the doit / dumpit callbacks are written as well.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.ident_name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            # split ops have one entry per do and per dump
            cnt = sum(mode in op
                      for op in family.ops.values()
                      for mode in ('do', 'dump'))
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        decl = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {decl};")
        else:
            cw.block_start(line=decl + ' =')

    if not terminate:
        return

    cw.nl()
    doit_args = ['const struct genl_split_ops *ops',
                 'struct sk_buff *skb', 'struct genl_info *info']
    dump_args = ['struct netlink_callback *cb']
    hook_protos = (('pre', 'do', 'int', doit_args),
                   ('post', 'do', 'void', doit_args),
                   ('pre', 'dump', 'int', dump_args),
                   ('post', 'dump', 'int', dump_args))
    for when, mode, ret_type, args in hook_protos:
        for hook in family.hooks[when][mode]['list']:
            cw.write_func_prot(ret_type, c_lower(hook), list(args), suffix=';')

    cw.nl()

    callbacks = (('do', 'struct genl_info *info'),
                 ('dump', 'struct netlink_callback *cb'))
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        for mode, last_arg in callbacks:
            if mode in op:
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{mode}it")
                cw.write_func_prot('int', name,
                                   ['struct sk_buff *skb', last_arg], suffix=';')
    cw.nl()
2237
2238
def print_kernel_op_table_hdr(family, cw):
    """
    Write the header part of the kernel ops table: its declaration and
    the callback prototypes.
    """
    print_kernel_op_table_fwd(family, cw, True)
2241
2242
def print_kernel_op_table(family, cw):
    """
    Write the definition of the kernel ops table.

    The opening line comes from print_kernel_op_table_fwd(). For the
    'global' and 'per-op' kernel policies there is one entry per
    non-async op; for 'split' there is one entry per non-async op and
    op mode (do / dump). Every entry is wrapped in the ifdef block of
    the op's 'config-cond'.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): the policy is always built from the 'do'
                # request, a dump-only op would raise KeyError here - confirm
                # per-op families always have a 'do' with a request
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Names of the struct members holding the pre / post hooks
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Drop the dump-only flags from do entries and the
                    # 'strict' flag from dump entries
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Ops with dual policy have the op mode in the policy name
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry is flagged with the op mode it serves
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    # Close the ifdef block of the last entry, if any
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2321
2322
def print_kernel_mcgrp_hdr(family, cw):
    """
    Write the enum with one entry per multicast group of the family.
    Nothing is written for families without multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        entry = c_upper(f"{family.ident_name}-nlgrp-{grp['name']}")
        cw.p(entry + ',')
    cw.block_end(';')
    cw.nl()
2333
2334
def print_kernel_mcgrp_src(family, cw):
    """
    Write the genl_multicast_group array of the family, indexed by the
    group enum entries. Nothing is written for families without
    multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        index = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{index}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2346
2347
def print_kernel_family_struct_hdr(family, cw):
    """
    Write the extern declaration of the genl_family struct and, if the
    family has 'sock-priv', the prototypes of the socket private data
    init / destroy helpers. Nothing is written when the family struct
    is not generated.
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
    cw.nl()
    if 'sock-priv' in family.kernel_family:
        priv_type = family.kernel_family["sock-priv"]
        for helper in ('init', 'destroy'):
            cw.p(f'void {family.c_name}_nl_sock_priv_{helper}({priv_type} *priv);')
        cw.nl()
2358
2359
def print_kernel_family_struct_src(family, cw):
    """
    Write the definition of the genl_family struct.

    Nothing is written when the family struct is not generated. The
    ops members are only filled in for the 'per-op' and 'split' kernel
    policies; the mcgrps and sock_priv members only when the family
    has multicast groups / 'sock-priv'.
    """
    if not kernel_can_gen_family_struct(family):
        return

    # Name the struct with c_name, exactly like its extern declaration in
    # the header and like the ops / mcgrps symbols referenced below, so
    # that declaration and definition can never disagree.
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        # Force cast here, actual helpers take pointer to the real type.
        cw.p(f'.sock_priv_init\t= (void *){family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = (void *){family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
2385
2386
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """
    Open an enum block in the uAPI header.

    If obj has the enum_name key its value names the enum (a false
    value makes the enum anonymous). Otherwise the enum is named after
    the family and obj[ckey], when ckey is given and present, and is
    anonymous if not.
    """
    tag = None
    if enum_name in obj:
        explicit = obj[enum_name]
        if explicit:
            tag = c_lower(explicit)
    elif ckey and ckey in obj:
        tag = family.c_name + '_' + c_lower(obj[ckey])

    if tag is None:
        cw.block_start(line='enum')
    else:
        cw.block_start(line='enum ' + tag)
2395
2396
def render_uapi(family, cw):
    """
    Write the whole uAPI header of the family.

    In order: the include guard, the family name and version defines,
    the definitions (enums and flags as enums, consts as defines), one
    enum per attribute set, the command enum(s), the multicast group
    defines and the end of the include guard.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    # Family name and version
    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consts are collected in defines and written out as a group, either
    # when a definition of another type follows or after the loop
    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                # The value is only spelled out where the entry reports
                # a value change
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                # flags get a mask entry, enums a max entry
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # NOTE(review): 'c-define-name' is looked up on the family, not
            # on the const entry - confirm this is intended
            defines.append([c_upper(family.get('c-define-name',
                                               f"{family.ident_name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    # With 'max-by-define' the max values are written as defines after
    # the enum, instead of as enum entries
    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        # val tracks the value the next entry would get implicitly, an
        # explicit value is only written where the attribute differs
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    # With an 'async-prefix' notifications and events are left out of the
    # command enum and written to an enum of their own below
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            # Only entries with a 'value' key get an explicit value here
            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2540
2541
def _render_user_ntf_entry(ri, op):
    """
    Write the entry of the notification info table for op, indexed by
    the op's enum name.
    """
    ri.cw.block_start(line=f"[{op.enum_name}] = ")
    entry_members = (
        f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
        f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
        f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
        f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,",
    )
    for member in entry_members:
        ri.cw.p(member)
    ri.cw.block_end(line=',')
2549
2550
def render_user_family(family, cw, prototype):
    """
    Write the user space ynl_family struct of the family.

    With prototype set only the extern declaration is written.
    Otherwise the notification info table is written first (if the
    family has notifications), followed by the definition of the
    family struct referencing it.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    # The raw family name may contain characters which are not valid in
    # a C identifier (e.g. dashes), always build symbols from c_name.
    ntf_info = f"{family.c_name}_ntf_info"

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_info}[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Notifications are rendered based on the op they notify about
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for op_name, op in family.ops.items():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_info},")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({ntf_info}),")
    cw.block_end(line=';')
2586
2587
def family_contains_bitfield32(family):
    """
    Tell whether any attribute of the family has the type "bitfield32".

    Attribute sets which are a subset of another set are not inspected.
    """
    return any(
        attr.type == "bitfield32"
        for _, attr_set in family.attr_sets.items()
        if not attr_set.subset_of
        for _, attr in attr_set.items()
    )
2596
2597
def find_kernel_root(full_path):
    """
    Locate the directory containing a MAINTAINERS file above full_path.

    Walks up from full_path one path component at a time. Returns a tuple
    (root, sub_path) where root is the first ancestor directory holding a
    MAINTAINERS file and sub_path is full_path relative to that directory.

    Raises FileNotFoundError if no ancestor contains a MAINTAINERS file.
    """
    origin = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # sub_path always ends with a separator left by the initial
            # join with '', strip it.
            return parent, sub_path[:-1]
        if parent == full_path:
            # dirname() no longer shortens the path ('/' or ''), so there
            # is nothing left to search; stop instead of looping forever.
            raise FileNotFoundError(
                f"No MAINTAINERS file found in any parent directory of '{origin}'")
        full_path = parent
2606
2607
def main():
    """
    Command line entry point: load a spec and write the generated C code.

    The --mode argument selects what is generated: 'uapi' renders the uAPI
    header, 'kernel' the kernel policies / op tables / family struct, and
    'user' the user space types, parsers and request helpers. One of
    --header / --source selects between the header and the source file.

    Exits with status 1 if the spec can't be parsed as YAML, if its license
    is not the expected one, or if the message id model of the spec is not
    supported for the selected mode.
    """
    # Imported locally as the module does not import sys at the top level.
    import sys

    expected_license = '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)'

    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    # --header and --source share one destination; None means neither was given
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc)
        sys.exit(1)

    if parsed.license != expected_license:
        print('Spec license:', parsed.license)
        print('License must be: ' + expected_license)
        sys.exit(1)

    supported_models = ['unified']
    if args.mode in ['user', 'kernel']:
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation')
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    # File preamble: license, origin of the file and the generator arguments
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        # Record the extra arguments used to generate the file
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.out_file:
        # Drops the last two characters of the output name - presumably a
        # ".c" / ".h" extension - to build the name of the matching header.
        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
    else:
        # No output file was given, so there is no name to derive the
        # header from; use a placeholder rather than slicing None.
        hdr_file = "generated_header_file.h"

    # Includes
    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            # Only print the section comment if at least one nested struct
            # is used in requests
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)

            # Only print the section comment if at least one nested struct
            # is used in requests
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            has_recursive_nests = False
            cw.p('/* Policies */')
            # Recursive nests get a forward declaration before the policies
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for name in parsed.pure_nested_structs:
                struct = Struct(parsed, name)
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            if has_recursive_nests:
                # Prototypes for all the nested helpers are emitted up
                # front when any of the nests is recursive
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
2914
2915
# Run the generator only when the file is executed as a script.
if __name__ == "__main__":
    main()
2918