1#!/usr/bin/env python3
2# ex: set filetype=python:
3
4"""Generate code for an RPC program's procedures"""
5
6from jinja2 import Environment
7
8from generators import SourceGenerator, create_jinja2_environment
9from xdr_ast import _RpcProgram, _RpcVersion, excluded_apis
10
11
def emit_version_definitions(
    environment: Environment, program: str, version: _RpcVersion
) -> None:
    """Print the procedure-number definitions for one RPC version.

    Output is produced in three steps: the "definition/open.j2"
    template rendered with the upper-cased program name, one
    "definition/procedure.j2" rendering (name and number) for every
    procedure whose name is not in excluded_apis, and finally the
    "definition/close.j2" template.
    """
    opening = environment.get_template("definition/open.j2")
    print(opening.render(program=program.upper()))

    entry = environment.get_template("definition/procedure.j2")
    for proc in version.procedures:
        # Excluded procedures get no definition at all.
        if proc.name in excluded_apis:
            continue
        print(entry.render(name=proc.name, value=proc.number))

    closing = environment.get_template("definition/close.j2")
    print(closing.render())
31
32
def emit_version_declarations(
    environment: Environment, program: str, version: _RpcVersion
) -> None:
    """Print argument and result declarations for one RPC version.

    Procedures whose name is in excluded_apis are ignored.  Each
    distinct argument type name is rendered once with
    "declaration/argument.j2", then each distinct result type name is
    rendered once with "declaration/result.j2", both in order of first
    appearance.  A non-empty group is preceded by an empty line; an
    empty group prints nothing.
    """
    # dict.fromkeys() is used as an insertion-ordered set: duplicates
    # collapse while the first-seen order is kept.
    argument_types = dict.fromkeys(
        proc.argument.type_name
        for proc in version.procedures
        if proc.name not in excluded_apis
    )
    if argument_types:
        print("")
        declare_argument = environment.get_template("declaration/argument.j2")
        for type_name in argument_types:
            print(declare_argument.render(program=program, argument=type_name))

    result_types = dict.fromkeys(
        proc.result.type_name
        for proc in version.procedures
        if proc.name not in excluded_apis
    )
    if result_types:
        print("")
        declare_result = environment.get_template("declaration/result.j2")
        for type_name in result_types:
            print(declare_result.render(program=program, result=type_name))
56
57
def emit_version_argument_decoders(
    environment: Environment, program: str, version: _RpcVersion
) -> None:
    """Print server-side argument decoders for one RPC version.

    Renders "decoder/argument.j2" once per distinct argument type name
    used by the version's procedures, in order of first appearance.
    Procedures whose name is in excluded_apis are ignored.
    """
    # dict.fromkeys() acts as an insertion-ordered set of type names.
    type_names = dict.fromkeys(
        proc.argument.type_name
        for proc in version.procedures
        if proc.name not in excluded_apis
    )

    decoder = environment.get_template("decoder/argument.j2")
    for type_name in type_names:
        print(decoder.render(program=program, argument=type_name))
70
71
def emit_version_result_decoders(
    environment: Environment, program: str, version: _RpcVersion
) -> None:
    """Print client-side result decoders for one RPC version.

    Renders "decoder/result.j2" once per distinct result type name
    used by the version's procedures, in order of first appearance.
    Procedures whose name is in excluded_apis are ignored.
    """
    # A dict keeps one entry per type name and remembers the order in
    # which each name was first seen.
    type_names = {
        proc.result.type_name: None
        for proc in version.procedures
        if proc.name not in excluded_apis
    }

    decoder = environment.get_template("decoder/result.j2")
    for type_name in type_names:
        print(decoder.render(program=program, result=type_name))
84
85
def emit_version_argument_encoders(
    environment: Environment, program: str, version: _RpcVersion
) -> None:
    """Print client-side argument encoders for one RPC version.

    Renders "encoder/argument.j2" once per distinct argument type name
    used by the version's procedures, in order of first appearance.
    Procedures whose name is in excluded_apis are ignored.
    """
    used_type_names = [
        proc.argument.type_name
        for proc in version.procedures
        if proc.name not in excluded_apis
    ]
    # Collapse repeats while preserving first-seen order.
    unique_type_names = dict.fromkeys(used_type_names)

    encoder = environment.get_template("encoder/argument.j2")
    for type_name in unique_type_names:
        print(encoder.render(program=program, argument=type_name))
98
99
def emit_version_result_encoders(
    environment: Environment, program: str, version: _RpcVersion
) -> None:
    """Print server-side result encoders for one RPC version.

    Renders "encoder/result.j2" once per distinct result type name
    used by the version's procedures, in order of first appearance.
    Procedures whose name is in excluded_apis are ignored.
    """
    # The dict serves as an insertion-ordered set of type names.
    type_names = {}
    for proc in version.procedures:
        if proc.name in excluded_apis:
            continue
        type_names.setdefault(proc.result.type_name, None)

    encoder = environment.get_template("encoder/result.j2")
    for type_name in type_names:
        print(encoder.render(program=program, result=type_name))
112
113
114class XdrProgramGenerator(SourceGenerator):
115    """Generate source code for an RPC program's procedures"""
116
117    def __init__(self, language: str, peer: str):
118        """Initialize an instance of this class"""
119        self.environment = create_jinja2_environment(language, "program")
120        self.peer = peer
121
122    def emit_definition(self, node: _RpcProgram) -> None:
123        """Emit procedure numbers for each of an RPC programs's procedures"""
124        raw_name = node.name
125        program = raw_name.lower().removesuffix("_program").removesuffix("_prog")
126
127        for version in node.versions:
128            emit_version_definitions(self.environment, program, version)
129
130    def emit_declaration(self, node: _RpcProgram) -> None:
131        """Emit a declaration pair for each of an RPC programs's procedures"""
132        raw_name = node.name
133        program = raw_name.lower().removesuffix("_program").removesuffix("_prog")
134
135        for version in node.versions:
136            emit_version_declarations(self.environment, program, version)
137
138    def emit_decoder(self, node: _RpcProgram) -> None:
139        """Emit all decoder functions for an RPC program's procedures"""
140        raw_name = node.name
141        program = raw_name.lower().removesuffix("_program").removesuffix("_prog")
142        match self.peer:
143            case "server":
144                for version in node.versions:
145                    emit_version_argument_decoders(
146                        self.environment, program, version,
147                    )
148            case "client":
149                for version in node.versions:
150                    emit_version_result_decoders(
151                        self.environment, program, version,
152                    )
153
154    def emit_encoder(self, node: _RpcProgram) -> None:
155        """Emit all encoder functions for an RPC program's procedures"""
156        raw_name = node.name
157        program = raw_name.lower().removesuffix("_program").removesuffix("_prog")
158        match self.peer:
159            case "server":
160                for version in node.versions:
161                    emit_version_result_encoders(
162                        self.environment, program, version,
163                    )
164            case "client":
165                for version in node.versions:
166                    emit_version_argument_encoders(
167                        self.environment, program, version,
168                    )
169