1# SPDX-License-Identifier: GPL-2.0
2
3import os
4import time
5from pathlib import Path
6from lib.py import KsftSkipEx, KsftXfailEx
7from lib.py import ksft_setup
8from lib.py import cmd, ethtool, ip
9from lib.py import NetNS, NetdevSimDev
10from .remote import Remote
11
12
def _load_env_file(src_path):
    """
    Build the test environment from os.environ plus the optional
    net.config file located in the same directory as src_path.

    net.config holds one KEY=VALUE assignment per line. Text after '#'
    is a comment (so values cannot contain '#'), and blank lines are
    ignored. Whitespace around the key and the value is stripped, so
    "NETIF = eth0" sets NETIF to "eth0". Values from the file override
    inherited environment variables.

    Returns the result of ksft_setup() applied to the resulting dict.
    Raises Exception if a non-blank line has no '=' or has an empty key.
    """
    env = os.environ.copy()

    config = Path(src_path).parent.resolve() / "net.config"
    if config.exists():
        with open(config, 'r', encoding="utf-8") as fp:
            for raw_line in fp:
                # Strip comments and surrounding whitespace
                line = raw_line.split("#", 1)[0].strip()
                if not line:
                    continue
                key, sep, value = line.partition('=')
                key = key.strip()
                if not sep or not key:
                    raise Exception("Can't parse configuration line:", raw_line)
                env[key] = value.strip()
    return ksft_setup(env)
35
36
class NetDrvEnv:
    """
    Class for a single NIC / host env, with no remote end.

    If NETIF is set (in the environment or net.config) that interface is
    used; otherwise a netdevsim device is created and removed again on
    teardown.
    """
    def __init__(self, src_path, **kwargs):
        """
        src_path: path of the test script; net.config is looked up in its
                  directory.
        kwargs:   forwarded to NetdevSimDev when no NETIF is configured.
        """
        # Set before anything that can raise so __del__ always finds it
        self._ns = None

        self.env = _load_env_file(src_path)

        if 'NETIF' in self.env:
            self.dev = ip("link show dev " + self.env['NETIF'], json=True)[0]
        else:
            self._ns = NetdevSimDev(**kwargs)
            self.dev = self._ns.nsims[0].dev
        # Same convenience attributes NetDrvEpEnv exposes
        self.ifname = self.dev['ifname']
        self.ifindex = self.dev['ifindex']

    def __enter__(self):
        """Bring the device up and return the environment."""
        ip(f"link set dev {self.dev['ifname']} up")

        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        """
        __exit__ gets called at the end of a "with" block.
        """
        self.__del__()

    def __del__(self):
        """Remove the netdevsim device if we created one; safe to repeat."""
        if self._ns:
            self._ns.remove()
            self._ns = None
68
69
class NetDrvEpEnv:
    """
    Class for an environment with a local device and "remote endpoint"
    which can be used to send traffic in.

    If NETIF is configured, the local device is that real interface and
    the remote end is described by REMOTE_TYPE / REMOTE_ARGS. Otherwise
    (local testing) a netdevsim device in the current network namespace
    is linked to a peer netdevsim device in a newly created namespace,
    and that namespace serves as the remote end.
    """

    # Network prefixes used for local tests
    nsim_v4_pfx = "192.0.2."
    nsim_v6_pfx = "2001:db8::"

    def __init__(self, src_path, nsim_test=None):
        """
        src_path:  path of the test script; net.config is looked up in
                   its directory.
        nsim_test: None if the test runs anywhere, True if it only works
                   on netdevsim, False if it does not work on netdevsim.
                   A mismatch with the detected setup raises KsftXfailEx.
        """
        # Things we try to destroy. Set them before anything that can
        # raise, so __del__ works even if construction fails part-way.
        self.remote = None
        # These are for local testing state
        self._netns = None
        self._ns = None
        self._ns_peer = None

        self.env = _load_env_file(src_path)

        # Computed lazily by wait_hw_stats_settle()
        self._stats_settle_time = None

        if "NETIF" in self.env:
            if nsim_test is True:
                raise KsftXfailEx("Test only works on netdevsim")
            self._check_env()

            self.dev = ip("link show dev " + self.env['NETIF'], json=True)[0]

            self.v4 = self.env.get("LOCAL_V4")
            self.v6 = self.env.get("LOCAL_V6")
            self.remote_v4 = self.env.get("REMOTE_V4")
            self.remote_v6 = self.env.get("REMOTE_V6")
            kind = self.env["REMOTE_TYPE"]
            args = self.env["REMOTE_ARGS"]
        else:
            if nsim_test is False:
                raise KsftXfailEx("Test does not work on netdevsim")

            self.create_local()

            self.dev = self._ns.nsims[0].dev

            self.v4 = self.nsim_v4_pfx + "1"
            self.v6 = self.nsim_v6_pfx + "1"
            self.remote_v4 = self.nsim_v4_pfx + "2"
            self.remote_v6 = self.nsim_v6_pfx + "2"
            # The peer namespace built by create_local() is the remote end
            kind = "netns"
            args = self._netns.name

        self.remote = Remote(kind, args, src_path)

        # Prefer IPv6 whenever it is configured
        self.addr = self.v6 if self.v6 else self.v4
        self.remote_addr = self.remote_v6 if self.remote_v6 else self.remote_v4

        self.addr_ipver = "6" if self.v6 else "4"
        # Bracketed addresses, some commands need IPv6 to be inside []
        self.baddr = f"[{self.v6}]" if self.v6 else self.v4
        self.remote_baddr = f"[{self.remote_v6}]" if self.remote_v6 else self.remote_v4

        self.ifname = self.dev['ifname']
        self.ifindex = self.dev['ifindex']

        # Cache for require_cmd(): {command: {"local"/"remote": bool}}
        self._required_cmd = {}

    def create_local(self):
        """
        Build the local netdevsim setup.

        Creates a new NetNS, a netdevsim device in the current namespace
        and a peer netdevsim device in the new one, and links the two.
        It then assigns <prefix>1 addresses locally and <prefix>2 on the
        peer, and brings both links up.
        """
        self._netns = NetNS()
        self._ns = NetdevSimDev()
        self._ns_peer = NetdevSimDev(ns=self._netns)

        # link_device takes "<netns fd>:<ifindex>" pairs, so both namespace
        # files must stay open while the control file is written.
        with open("/proc/self/ns/net") as nsfd0, \
             open("/var/run/netns/" + self._netns.name) as nsfd1:
            ifi0 = self._ns.nsims[0].ifindex
            ifi1 = self._ns_peer.nsims[0].ifindex
            NetdevSimDev.ctrl_write('link_device',
                                    f'{nsfd0.fileno()}:{ifi0} {nsfd1.fileno()}:{ifi1}')

        ip(f"   addr add dev {self._ns.nsims[0].ifname} {self.nsim_v4_pfx}1/24")
        ip(f"-6 addr add dev {self._ns.nsims[0].ifname} {self.nsim_v6_pfx}1/64 nodad")
        ip(f"   link set dev {self._ns.nsims[0].ifname} up")

        ip(f"   addr add dev {self._ns_peer.nsims[0].ifname} {self.nsim_v4_pfx}2/24", ns=self._netns)
        ip(f"-6 addr add dev {self._ns_peer.nsims[0].ifname} {self.nsim_v6_pfx}2/64 nodad", ns=self._netns)
        ip(f"   link set dev {self._ns_peer.nsims[0].ifname} up", ns=self._netns)

    def _check_env(self):
        """
        Validate the configuration for a real NIC setup.

        Each inner list of vars_needed must have at least one member set.
        LOCAL_* and REMOTE_* must also be present for the same IP versions.
        Raises Exception listing what is missing.
        """
        vars_needed = [
            ["LOCAL_V4", "LOCAL_V6"],
            ["REMOTE_V4", "REMOTE_V6"],
            ["REMOTE_TYPE"],
            ["REMOTE_ARGS"]
        ]
        missing = []

        for choice in vars_needed:
            for entry in choice:
                if entry in self.env:
                    break
            else:
                missing.append(choice)
        # Make sure v4 / v6 configs are symmetric
        if ("LOCAL_V6" in self.env) != ("REMOTE_V6" in self.env):
            missing.append(["LOCAL_V6", "REMOTE_V6"])
        if ("LOCAL_V4" in self.env) != ("REMOTE_V4" in self.env):
            missing.append(["LOCAL_V4", "REMOTE_V4"])
        if missing:
            raise Exception("Invalid environment, missing configuration:", missing,
                            "Please see tools/testing/selftests/drivers/net/README.rst")

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        """
        __exit__ gets called at the end of a "with" block.
        """
        self.__del__()

    def __del__(self):
        """
        Tear down, in order: local netdevsim device, peer device, namespace,
        remote. Each step runs once; calling again is a no-op.
        """
        if self._ns:
            self._ns.remove()
            self._ns = None
        if self._ns_peer:
            self._ns_peer.remove()
            self._ns_peer = None
        if self._netns:
            # Drops our reference; presumably NetNS removes the namespace
            # when it is collected -- TODO confirm in lib.py
            del self._netns
            self._netns = None
        if self.remote:
            del self.remote
            self.remote = None

    def require_v4(self):
        """Skip the test unless both ends have an IPv4 address."""
        if not self.v4 or not self.remote_v4:
            raise KsftSkipEx("Test requires IPv4 connectivity")

    def require_v6(self):
        """Skip the test unless both ends have an IPv6 address."""
        if not self.v6 or not self.remote_v6:
            raise KsftSkipEx("Test requires IPv6 connectivity")

    def _require_cmd(self, comm, key, host=None):
        """
        Return True if `comm` is found by "command -v" on `host` (None
        means locally). The result is cached per command under `key`.
        """
        cached = self._required_cmd.get(comm, {})
        if cached.get(key) is None:
            cached[key] = cmd("command -v -- " + comm, fail=False,
                              shell=True, host=host).ret == 0
        self._required_cmd[comm] = cached
        return cached[key]

    def require_cmd(self, comm, local=True, remote=False):
        """
        Skip the test (KsftSkipEx) unless `comm` exists locally (if
        `local`) and on the remote endpoint (if `remote`).
        """
        if local:
            if not self._require_cmd(comm, "local"):
                raise KsftSkipEx("Test requires command: " + comm)
        if remote:
            # Must probe the remote endpoint, not the local host
            if not self._require_cmd(comm, "remote", host=self.remote):
                raise KsftSkipEx("Test requires (remote) command: " + comm)

    def wait_hw_stats_settle(self):
        """
        Wait for HW stats to become consistent, some devices DMA HW stats
        periodically so events won't be reflected until next sync.
        Good drivers will tell us via ethtool what their sync period is.
        """
        if self._stats_settle_time is None:
            data = ethtool("-c " + self.ifname, json=True)[0]

            # 25 ms base plus the stats-block-usecs period converted to seconds
            self._stats_settle_time = 0.025 + \
                data.get('stats-block-usecs', 0) / 1000 / 1000

        time.sleep(self._stats_settle_time)
243