1# SPDX-License-Identifier: GPL-2.0
2
3import os
4from pathlib import Path
5from lib.py import KsftSkipEx, KsftXfailEx
6from lib.py import cmd, ip
7from lib.py import NetNS, NetdevSimDev
8from .remote import Remote
9
10
11def _load_env_file(src_path):
12    env = os.environ.copy()
13
14    src_dir = Path(src_path).parent.resolve()
15    if not (src_dir / "net.config").exists():
16        return env
17
18    with open((src_dir / "net.config").as_posix(), 'r') as fp:
19        for line in fp.readlines():
20            full_file = line
21            # Strip comments
22            pos = line.find("#")
23            if pos >= 0:
24                line = line[:pos]
25            line = line.strip()
26            if not line:
27                continue
28            pair = line.split('=', maxsplit=1)
29            if len(pair) != 2:
30                raise Exception("Can't parse configuration line:", full_file)
31            env[pair[0]] = pair[1]
32    return env
33
34
class NetDrvEnv:
    """
    Class for a single NIC / host env, with no remote end

    If NETIF is set (in the environment or in net.config next to the test)
    that existing device is used. Otherwise a netdevsim device is created,
    and it is removed again by __exit__() / __del__().

    Attributes:
      env     - environment variables (os.environ overlaid with net.config)
      dev     - device description dict (has at least 'ifname' and 'ifindex')
      ifname  - name of the device under test
      ifindex - ifindex of the device under test
    """
    def __init__(self, src_path, **kwargs):
        """
        src_path - path of the test script; net.config is looked up next to it
        kwargs   - passed to NetdevSimDev() when no NETIF is configured
        """
        # Set before anything that can raise, so __del__() is always safe
        self._ns = None

        self.env = _load_env_file(src_path)

        if 'NETIF' in self.env:
            self.dev = ip("link show dev " + self.env['NETIF'], json=True)[0]
        else:
            self._ns = NetdevSimDev(**kwargs)
            self.dev = self._ns.nsims[0].dev
        # Expose ifname like NetDrvEpEnv does, so tests can use either env
        self.ifname = self.dev['ifname']
        self.ifindex = self.dev['ifindex']

    def __enter__(self):
        # Make sure the device under test is up for the duration of the test
        ip(f"link set dev {self.ifname} up")

        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        """
        __exit__ gets called at the end of a "with" block.
        """
        self.__del__()

    def __del__(self):
        # Only a netdevsim device we created ourselves gets removed;
        # clearing _ns makes repeated cleanup calls a no-op.
        if self._ns:
            self._ns.remove()
            self._ns = None
66
67
class NetDrvEpEnv:
    """
    Class for an environment with a local device and "remote endpoint"
    which can be used to send traffic in.

    When NETIF is configured, the local device and the remote endpoint are
    described by the environment / net.config (see _check_env() for the
    required variables). Otherwise, for local testing, it creates a new
    network namespace for the "remote" side and a pair of linked netdevsim
    devices: one in the current namespace and one in the new namespace.
    """

    # Network prefixes used for local tests
    nsim_v4_pfx = "192.0.2."
    nsim_v6_pfx = "2001:db8::"

    def __init__(self, src_path, nsim_test=None):
        """
        src_path  - path of the test script; net.config is looked up next to it
        nsim_test - True if the test only works on netdevsim, False if it
                    does not work on netdevsim, None if it works on both.
                    A mismatch with the detected setup raises KsftXfailEx.
        """
        # Things we try to destroy. These must be initialized before anything
        # that can raise (including loading net.config), otherwise __del__()
        # on a partially constructed object fails with AttributeError.
        self.remote = None
        # These are for local testing state
        self._netns = None
        self._ns = None
        self._ns_peer = None
        # Per-command cache of availability checks, see _require_cmd()
        self._required_cmd = {}

        self.env = _load_env_file(src_path)

        if "NETIF" in self.env:
            if nsim_test is True:
                raise KsftXfailEx("Test only works on netdevsim")
            self._check_env()

            self.dev = ip("link show dev " + self.env['NETIF'], json=True)[0]

            self.v4 = self.env.get("LOCAL_V4")
            self.v6 = self.env.get("LOCAL_V6")
            self.remote_v4 = self.env.get("REMOTE_V4")
            self.remote_v6 = self.env.get("REMOTE_V6")
            kind = self.env["REMOTE_TYPE"]
            args = self.env["REMOTE_ARGS"]
        else:
            if nsim_test is False:
                raise KsftXfailEx("Test does not work on netdevsim")

            self.create_local()

            self.dev = self._ns.nsims[0].dev

            self.v4 = self.nsim_v4_pfx + "1"
            self.v6 = self.nsim_v6_pfx + "1"
            self.remote_v4 = self.nsim_v4_pfx + "2"
            self.remote_v6 = self.nsim_v6_pfx + "2"
            kind = "netns"
            args = self._netns.name

        self.remote = Remote(kind, args, src_path)

        # Prefer IPv6 when it is configured
        self.addr = self.v6 if self.v6 else self.v4
        self.remote_addr = self.remote_v6 if self.remote_v6 else self.remote_v4

        self.addr_ipver = "6" if self.v6 else "4"
        # Bracketed addresses, some commands need IPv6 to be inside []
        self.baddr = f"[{self.v6}]" if self.v6 else self.v4
        self.remote_baddr = f"[{self.remote_v6}]" if self.remote_v6 else self.remote_v4

        self.ifname = self.dev['ifname']
        self.ifindex = self.dev['ifindex']

    def create_local(self):
        """
        Build the netdevsim-based local setup: a new netns, one netdevsim
        device in the current netns and one in the new netns, link the two
        devices, assign <pfx>1 (local) / <pfx>2 (peer) addresses and bring
        both links up.
        """
        self._netns = NetNS()
        self._ns = NetdevSimDev()
        self._ns_peer = NetdevSimDev(ns=self._netns)

        # The link request identifies each device by "<netns fd>:<ifindex>"
        with open("/proc/self/ns/net") as nsfd0, \
             open("/var/run/netns/" + self._netns.name) as nsfd1:
            ifi0 = self._ns.nsims[0].ifindex
            ifi1 = self._ns_peer.nsims[0].ifindex
            NetdevSimDev.ctrl_write('link_device',
                                    f'{nsfd0.fileno()}:{ifi0} {nsfd1.fileno()}:{ifi1}')

        ip(f"   addr add dev {self._ns.nsims[0].ifname} {self.nsim_v4_pfx}1/24")
        ip(f"-6 addr add dev {self._ns.nsims[0].ifname} {self.nsim_v6_pfx}1/64 nodad")
        ip(f"   link set dev {self._ns.nsims[0].ifname} up")

        ip(f"   addr add dev {self._ns_peer.nsims[0].ifname} {self.nsim_v4_pfx}2/24", ns=self._netns)
        ip(f"-6 addr add dev {self._ns_peer.nsims[0].ifname} {self.nsim_v6_pfx}2/64 nodad", ns=self._netns)
        ip(f"   link set dev {self._ns_peer.nsims[0].ifname} up", ns=self._netns)

    def _check_env(self):
        """
        Validate the configuration of a real (NETIF) setup.

        Requires at least one of LOCAL_V4 / LOCAL_V6, at least one of
        REMOTE_V4 / REMOTE_V6, plus REMOTE_TYPE and REMOTE_ARGS; every IP
        version configured on one side must be configured on the other too.
        Raises Exception listing what is missing.
        """
        vars_needed = [
            ["LOCAL_V4", "LOCAL_V6"],
            ["REMOTE_V4", "REMOTE_V6"],
            ["REMOTE_TYPE"],
            ["REMOTE_ARGS"]
        ]
        missing = []

        for choice in vars_needed:
            for entry in choice:
                if entry in self.env:
                    break
            else:
                # None of the alternatives in this group is set
                missing.append(choice)
        # Make sure v4 / v6 configs are symmetric
        if ("LOCAL_V6" in self.env) != ("REMOTE_V6" in self.env):
            missing.append(["LOCAL_V6", "REMOTE_V6"])
        if ("LOCAL_V4" in self.env) != ("REMOTE_V4" in self.env):
            missing.append(["LOCAL_V4", "REMOTE_V4"])
        if missing:
            raise Exception("Invalid environment, missing configuration:", missing,
                            "Please see tools/testing/selftests/drivers/net/README.rst")

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        """
        __exit__ gets called at the end of a "with" block.
        """
        self.__del__()

    def __del__(self):
        # Tear down in reverse dependency order: netdevsim devices first,
        # then the netns holding the peer, then the remote handle.
        # Each attribute is cleared so repeated calls are no-ops.
        if self._ns:
            self._ns.remove()
            self._ns = None
        if self._ns_peer:
            self._ns_peer.remove()
            self._ns_peer = None
        if self._netns:
            del self._netns
            self._netns = None
        if self.remote:
            del self.remote
            self.remote = None

    def require_v4(self):
        """Skip the test unless both ends have an IPv4 address."""
        if not self.v4 or not self.remote_v4:
            raise KsftSkipEx("Test requires IPv4 connectivity")

    def require_v6(self):
        """Skip the test unless both ends have an IPv6 address."""
        if not self.v6 or not self.remote_v6:
            raise KsftSkipEx("Test requires IPv6 connectivity")

    def _require_cmd(self, comm, key, host=None):
        """
        Return True if comm is available, checked via `command -v` on host
        (locally when host is None). Results are cached per (comm, key).
        """
        cached = self._required_cmd.get(comm, {})
        if cached.get(key) is None:
            cached[key] = cmd("command -v -- " + comm, fail=False,
                              shell=True, host=host).ret == 0
        self._required_cmd[comm] = cached
        return cached[key]

    def require_cmd(self, comm, local=True, remote=False):
        """
        Skip the test (KsftSkipEx) unless comm is available locally (when
        local is set) and on the remote endpoint (when remote is set).
        """
        if local:
            if not self._require_cmd(comm, "local"):
                raise KsftSkipEx("Test requires command: " + comm)
        if remote:
            # The check must run on the remote endpoint, not on this host
            if not self._require_cmd(comm, "remote", host=self.remote):
                raise KsftSkipEx("Test requires (remote) command: " + comm)