#!/usr/bin/env python3
# ex: set filetype=python:

"""Define and implement the Abstract Syntax Tree for the XDR language."""
5
6import sys
7from typing import List
8from dataclasses import dataclass
9
10from lark import ast_utils, Transformer
11from lark.tree import Meta
12
# Handle on this module object; handed to lark's ast_utils when the
# transformer is created further down in this file.
this_module = sys.modules[__name__]

# Registries filled in as a side effect of transforming a parse tree.
# Names appended by "exclude" pragma directives.
excluded_apis: list = []
# Replaced by a "header" pragma directive; stays "none" otherwise.
header_name: str = "none"
# Names appended by "public" pragma directives.
public_apis: list = []
# Names of enum types seen so far.
enums: set = set()
# Names recorded by struct, union, and optional-data productions.
structs: set = set()
# Names recorded by struct, union, optional-data, and qualifying typedefs.
pass_by_reference: set = set()
21
22
@dataclass
class _XdrAst(ast_utils.Ast):
    """Common ancestor of every node in the XDR abstract syntax tree"""

    # No fields of its own; deriving from ast_utils.Ast marks subclasses
    # as AST node classes for lark's ast_utils helpers.
26
27
@dataclass
class _XdrIdentifier(_XdrAst):
    """AST node for the 'identifier' production of the XDR grammar"""

    # Text of the identifier, taken from the matched token's value
    symbol: str
33
34
@dataclass
class _XdrValue(_XdrAst):
    """AST node for the 'value' production of the XDR grammar"""

    # Either an identifier's symbol or the text of the first token found
    # inside the value's subtree
    value: str
40
41
@dataclass
class _XdrConstantValue(_XdrAst):
    """AST node for the 'constant' production of the XDR grammar"""

    # Numeric value parsed from decimal, hexadecimal, or octal text
    value: int
47
48
@dataclass
class _XdrTypeSpecifier(_XdrAst):
    """AST node for the 'type_specifier' production of the XDR grammar"""

    # Name of the type being referred to
    type_name: str
    # "enum ", "struct ", or "" -- note the trailing space in the first two
    c_classifier: str
55
56
@dataclass
class _XdrDefinedType(_XdrTypeSpecifier):
    """Type specifier naming a type that the input specification defines"""

    # Adds no fields; the subclass only distinguishes identifier-named
    # types from built-in ones.
60
61
@dataclass
class _XdrBuiltInType(_XdrTypeSpecifier):
    """Type specifier naming one of XDR's built-in types"""

    # Adds no fields; used whenever the type is not given as an identifier.
65
66
@dataclass
class _XdrDeclaration(_XdrAst):
    """Common ancestor of the XDR type declaration nodes"""

    # Adds no fields; concrete declaration kinds derive from this class.
70
71
@dataclass
class _XdrFixedLengthOpaque(_XdrDeclaration):
    """Declaration of an opaque field with a fixed length"""

    # Declared identifier
    name: str
    # Length, kept as the text of the size value
    size: str
    # Presumably selects the code-generation template -- TODO confirm
    template: str = "fixed_length_opaque"
79
80
@dataclass
class _XdrVariableLengthOpaque(_XdrDeclaration):
    """Declaration of an opaque field with a variable length"""

    # Declared identifier
    name: str
    # Maximum length as text; "0" when the declaration gives no maximum
    maxsize: str
    # Presumably selects the code-generation template -- TODO confirm
    template: str = "variable_length_opaque"
88
89
@dataclass
class _XdrVariableLengthString(_XdrDeclaration):
    """Declaration of a (NUL-terminated) string with a variable length"""

    # Declared identifier
    name: str
    # Maximum length as text; "0" when the declaration gives no maximum
    maxsize: str
    # Presumably selects the code-generation template -- TODO confirm
    template: str = "variable_length_string"
97
98
@dataclass
class _XdrFixedLengthArray(_XdrDeclaration):
    """Declaration of an array with a fixed number of elements"""

    # Declared identifier
    name: str
    # Type specifier of the array's elements
    spec: _XdrTypeSpecifier
    # Element count, kept as the text of the size value
    size: str
    # Presumably selects the code-generation template -- TODO confirm
    template: str = "fixed_length_array"
107
108
@dataclass
class _XdrVariableLengthArray(_XdrDeclaration):
    """Declaration of an array with a variable number of elements"""

    # Declared identifier
    name: str
    # Type specifier of the array's elements
    spec: _XdrTypeSpecifier
    # Maximum element count as text; "0" when no maximum is given
    maxsize: str
    # Presumably selects the code-generation template -- TODO confirm
    template: str = "variable_length_array"
117
118
@dataclass
class _XdrOptionalData(_XdrDeclaration):
    """Declaration matching the 'optional_data' production"""

    # Declared identifier
    name: str
    # Type specifier of the optional item
    spec: _XdrTypeSpecifier
    # Presumably selects the code-generation template -- TODO confirm
    template: str = "optional_data"
126
127
@dataclass
class _XdrBasic(_XdrDeclaration):
    """Declaration matching the 'basic' production: a type and a name"""

    # Declared identifier
    name: str
    # Type specifier of the declared item
    spec: _XdrTypeSpecifier
    # Presumably selects the code-generation template -- TODO confirm
    template: str = "basic"
135
136
@dataclass
class _XdrVoid(_XdrDeclaration):
    """Declaration of 'void': carries neither a name nor a type"""

    # Presumably selects the code-generation template -- TODO confirm
    template: str = "void"
142
143
@dataclass
class _XdrConstant(_XdrAst):
    """AST node for the 'constant_def' production of the grammar"""

    # Identifier the constant is bound to
    name: str
    # The `value` attribute of the definition's second child
    value: str
150
151
@dataclass
class _XdrEnumerator(_XdrAst):
    """One 'identifier = value' member of an enum"""

    # Identifier on the left-hand side
    name: str
    # Value on the right-hand side
    value: str
158
159
@dataclass
class _XdrEnum(_XdrAst):
    """Definition of an XDR enum type"""

    # Name of the enum type
    name: str
    # Range bounds; ParseToAst.enum currently always supplies 0 for both
    minimum: int
    maximum: int
    # Members in the order they appear in the enum body
    enumerators: List[_XdrEnumerator]
168
169
@dataclass
class _XdrStruct(_XdrAst):
    """Definition of an XDR struct type"""

    # Name of the struct type
    name: str
    # Member declarations taken from the struct body
    fields: List[_XdrDeclaration]
176
177
@dataclass
class _XdrPointer(_XdrAst):
    """Definition of an XDR pointer type"""

    # ParseToAst.struct builds this instead of _XdrStruct when the last
    # field is optional data whose type name equals the struct's own name.
    # Name of the type
    name: str
    # Member declarations taken from the struct body
    fields: List[_XdrDeclaration]
184
185
@dataclass
class _XdrTypedef(_XdrAst):
    """Definition of an XDR typedef"""

    # The declaration that introduces the new type name
    declaration: _XdrDeclaration
191
192
@dataclass
class _XdrCaseSpec(_XdrAst):
    """A single case arm of an XDR union, with the values that select it"""

    # Case values that lead to this arm
    values: List[str]
    # Declaration carried by the arm
    arm: _XdrDeclaration
    # Presumably selects the code-generation template -- TODO confirm
    template: str = "case_spec"
200
201
@dataclass
class _XdrDefaultSpec(_XdrAst):
    """The default arm of an XDR union"""

    # Declaration carried by the default arm
    arm: _XdrDeclaration
    # Presumably selects the code-generation template -- TODO confirm
    template: str = "default_spec"
208
209
@dataclass
class _XdrUnion(_XdrAst):
    """Definition of an XDR union type"""

    # Name of the union type
    name: str
    # Declaration of the item that is switched on
    discriminant: _XdrDeclaration
    # Case arms from the union body, in order
    cases: List[_XdrCaseSpec]
    # NOTE(review): ParseToAst.union stores the union body's last child
    # here, which looks like a _XdrDefaultSpec rather than a bare
    # declaration -- confirm and tighten the annotation.
    default: _XdrDeclaration
218
219
@dataclass
class _RpcProcedure(_XdrAst):
    """Definition of one RPC procedure"""

    # Procedure identifier
    name: str
    # Procedure number, kept as text
    number: str
    # Type specifier of the procedure's argument
    argument: _XdrTypeSpecifier
    # Type specifier of the procedure's result
    result: _XdrTypeSpecifier
228
229
@dataclass
class _RpcVersion(_XdrAst):
    """Definition of one RPC version"""

    # Version identifier
    name: str
    # Version number, kept as text
    number: str
    # Procedures that belong to this version
    procedures: List[_RpcProcedure]
237
238
@dataclass
class _RpcProgram(_XdrAst):
    """Definition of one RPC program"""

    # Program identifier
    name: str
    # Program number, kept as text
    number: str
    # Versions that belong to this program
    versions: List[_RpcVersion]
246
247
@dataclass
class _Pragma(_XdrAst):
    """Placeholder node standing in for a pragma directive"""

    # Carries no data: ParseToAst.pragma_def records the directive's
    # effect in module-level state and returns an empty instance.
251
252
@dataclass
class Definition(_XdrAst, ast_utils.WithMeta):
    """AST node for the 'definition' production of the grammar"""

    # Parse-tree metadata; the WithMeta marker asks lark to supply it
    meta: Meta
    # The AST node that this definition wraps
    value: _XdrAst
259
260
@dataclass
class Specification(_XdrAst, ast_utils.AsList):
    """AST node for the 'specification' production of the grammar"""

    # All definitions, in order; the AsList marker asks lark to pass the
    # children as one list rather than as separate arguments
    definitions: List[Definition]
266
267
268class ParseToAst(Transformer):
269    """Functions that transform productions into AST nodes"""
270
271    def identifier(self, children):
272        """Instantiate one _XdrIdentifier object"""
273        return _XdrIdentifier(children[0].value)
274
275    def value(self, children):
276        """Instantiate one _XdrValue object"""
277        if isinstance(children[0], _XdrIdentifier):
278            return _XdrValue(children[0].symbol)
279        return _XdrValue(children[0].children[0].value)
280
281    def constant(self, children):
282        """Instantiate one _XdrConstantValue object"""
283        match children[0].data:
284            case "decimal_constant":
285                value = int(children[0].children[0].value, base=10)
286            case "hexadecimal_constant":
287                value = int(children[0].children[0].value, base=16)
288            case "octal_constant":
289                value = int(children[0].children[0].value, base=8)
290        return _XdrConstantValue(value)
291
292    def type_specifier(self, children):
293        """Instantiate one type_specifier object"""
294        c_classifier = ""
295        if isinstance(children[0], _XdrIdentifier):
296            name = children[0].symbol
297            if name in enums:
298                c_classifier = "enum "
299            if name in structs:
300                c_classifier = "struct "
301            return _XdrDefinedType(
302                type_name=name,
303                c_classifier=c_classifier,
304            )
305
306        token = children[0].data
307        return _XdrBuiltInType(
308            type_name=token.value,
309            c_classifier=c_classifier,
310        )
311
312    def constant_def(self, children):
313        """Instantiate one _XdrConstant object"""
314        name = children[0].symbol
315        value = children[1].value
316        return _XdrConstant(name, value)
317
318    # cel: Python can compute a min() and max() for the enumerator values
319    #      so that the generated code can perform proper range checking.
320    def enum(self, children):
321        """Instantiate one _XdrEnum object"""
322        enum_name = children[0].symbol
323        enums.add(enum_name)
324
325        i = 0
326        enumerators = []
327        body = children[1]
328        while i < len(body.children):
329            name = body.children[i].symbol
330            value = body.children[i + 1].value
331            enumerators.append(_XdrEnumerator(name, value))
332            i = i + 2
333
334        return _XdrEnum(enum_name, 0, 0, enumerators)
335
336    def fixed_length_opaque(self, children):
337        """Instantiate one _XdrFixedLengthOpaque declaration object"""
338        name = children[0].symbol
339        size = children[1].value
340
341        return _XdrFixedLengthOpaque(name, size)
342
343    def variable_length_opaque(self, children):
344        """Instantiate one _XdrVariableLengthOpaque declaration object"""
345        name = children[0].symbol
346        if children[1] is not None:
347            maxsize = children[1].value
348        else:
349            maxsize = "0"
350
351        return _XdrVariableLengthOpaque(name, maxsize)
352
353    def variable_length_string(self, children):
354        """Instantiate one _XdrVariableLengthString declaration object"""
355        name = children[0].symbol
356        if children[1] is not None:
357            maxsize = children[1].value
358        else:
359            maxsize = "0"
360
361        return _XdrVariableLengthString(name, maxsize)
362
363    def fixed_length_array(self, children):
364        """Instantiate one _XdrFixedLengthArray declaration object"""
365        spec = children[0]
366        name = children[1].symbol
367        size = children[2].value
368
369        return _XdrFixedLengthArray(name, spec, size)
370
371    def variable_length_array(self, children):
372        """Instantiate one _XdrVariableLengthArray declaration object"""
373        spec = children[0]
374        name = children[1].symbol
375        if children[2] is not None:
376            maxsize = children[2].value
377        else:
378            maxsize = "0"
379
380        return _XdrVariableLengthArray(name, spec, maxsize)
381
382    def optional_data(self, children):
383        """Instantiate one _XdrOptionalData declaration object"""
384        spec = children[0]
385        name = children[1].symbol
386        structs.add(name)
387        pass_by_reference.add(name)
388
389        return _XdrOptionalData(name, spec)
390
391    def basic(self, children):
392        """Instantiate one _XdrBasic object"""
393        spec = children[0]
394        name = children[1].symbol
395
396        return _XdrBasic(name, spec)
397
398    def void(self, children):
399        """Instantiate one _XdrVoid declaration object"""
400
401        return _XdrVoid()
402
403    def struct(self, children):
404        """Instantiate one _XdrStruct object"""
405        name = children[0].symbol
406        structs.add(name)
407        pass_by_reference.add(name)
408        fields = children[1].children
409
410        last_field = fields[-1]
411        if (
412            isinstance(last_field, _XdrOptionalData)
413            and name == last_field.spec.type_name
414        ):
415            return _XdrPointer(name, fields)
416
417        return _XdrStruct(name, fields)
418
419    def typedef(self, children):
420        """Instantiate one _XdrTypedef object"""
421        new_type = children[0]
422        if isinstance(new_type, _XdrBasic) and isinstance(
423            new_type.spec, _XdrDefinedType
424        ):
425            if new_type.spec.type_name in pass_by_reference:
426                pass_by_reference.add(new_type.name)
427
428        return _XdrTypedef(new_type)
429
430    def case_spec(self, children):
431        """Instantiate one _XdrCaseSpec object"""
432        values = []
433        for item in children[0:-1]:
434            values.append(item.value)
435        arm = children[-1]
436
437        return _XdrCaseSpec(values, arm)
438
439    def default_spec(self, children):
440        """Instantiate one _XdrDefaultSpec object"""
441        arm = children[0]
442
443        return _XdrDefaultSpec(arm)
444
445    def union(self, children):
446        """Instantiate one _XdrUnion object"""
447        name = children[0].symbol
448        structs.add(name)
449        pass_by_reference.add(name)
450
451        body = children[1]
452        discriminant = body.children[0].children[0]
453        cases = body.children[1:-1]
454        default = body.children[-1]
455
456        return _XdrUnion(name, discriminant, cases, default)
457
458    def procedure_def(self, children):
459        """Instantiate one _RpcProcedure object"""
460        result = children[0]
461        name = children[1].symbol
462        argument = children[2]
463        number = children[3].value
464
465        return _RpcProcedure(name, number, argument, result)
466
467    def version_def(self, children):
468        """Instantiate one _RpcVersion object"""
469        name = children[0].symbol
470        number = children[-1].value
471        procedures = children[1:-1]
472
473        return _RpcVersion(name, number, procedures)
474
475    def program_def(self, children):
476        """Instantiate one _RpcProgram object"""
477        name = children[0].symbol
478        number = children[-1].value
479        versions = children[1:-1]
480
481        return _RpcProgram(name, number, versions)
482
483    def pragma_def(self, children):
484        """Instantiate one _Pragma object"""
485        directive = children[0].children[0].data
486        match directive:
487            case "exclude_directive":
488                excluded_apis.append(children[1].symbol)
489            case "header_directive":
490                global header_name
491                header_name = children[1].symbol
492            case "public_directive":
493                public_apis.append(children[1].symbol)
494            case _:
495                raise NotImplementedError("Directive not supported")
496        return _Pragma()
497
498
# Module-wide transformer: lark's ast_utils combines the AST node classes
# found in this module with the hand-written ParseToAst callbacks.
transformer = ast_utils.create_transformer(
    this_module,
    ParseToAst(),
)
500
501
def transform_parse_tree(parse_tree):
    """Convert a parse tree into this module's abstract syntax tree"""
    # Delegates to the module-level transformer instance.
    ast = transformer.transform(parse_tree)
    return ast
506
507
def get_header_name() -> str:
    """Report the header name recorded by a pragma header directive"""
    # Reads the module-level value, which starts out as "none".
    current = header_name
    return current
511