1#!/usr/bin/env python3
2import os
3import socket
4import struct
5import subprocess
6import sys
7from ctypes import c_byte
8from ctypes import c_char
9from ctypes import c_int
10from ctypes import c_long
11from ctypes import c_uint32
12from ctypes import c_uint8
13from ctypes import c_ulong
14from ctypes import c_ushort
15from ctypes import sizeof
16from ctypes import Structure
17from enum import Enum
18from typing import Any
19from typing import Dict
20from typing import List
21from typing import NamedTuple
22from typing import Optional
23from typing import Union
24
25import pytest
26from atf_python.sys.netpfil.ipfw.insns import BaseInsn
27from atf_python.sys.netpfil.ipfw.insns import insn_attrs
28from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTableLookupType
29from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTlvType
30from atf_python.sys.netpfil.ipfw.ioctl_headers import Op3CmdType
31from atf_python.sys.netpfil.ipfw.utils import align8
32from atf_python.sys.netpfil.ipfw.utils import AttrDescr
33from atf_python.sys.netpfil.ipfw.utils import enum_from_int
34from atf_python.sys.netpfil.ipfw.utils import prepare_attrs_map
35
36
class IpFw3OpHeader(Structure):
    """Fixed-size header placed at the start of every op3 message.

    ``BaseIpFwMessage.__bytes__`` emits it with only ``opcode`` populated and
    ``BaseIpFwMessage.parse_header`` reads it back from the front of a buffer.
    """

    _fields_ = [
        ("opcode", c_ushort),  # message type; written from the message's obj_type
        ("version", c_ushort),  # never set by the code in this file (stays zero)
        ("reserved1", c_ushort),
        ("reserved2", c_ushort),
    ]
44
45
class IpFwObjTlv(Structure):
    """Generic TLV header that prefixes every TLV object handled here."""

    _fields_ = [
        ("n_type", c_ushort),  # TLV type; the lookup key into attr_map when parsing
        ("flags", c_ushort),
        ("length", c_uint32),  # total TLV size in bytes, this header included
    ]
52
53
class BaseTlv(object):
    """Shared behaviour of TLV objects: type bookkeeping, printing and parsing.

    Concrete subclasses implement ``__bytes__`` and, where their wire layout
    differs from the plain TLV header, override ``_validate`` and ``_parse``.
    """

    obj_enum_class = IpFwTlvType

    def __init__(self, obj_type):
        """Record the TLV type, given either as an ``Enum`` member or an int."""
        if isinstance(obj_type, Enum):
            self.obj_type = obj_type.value
            self._enum = obj_type
        else:
            self.obj_type = obj_type
            # NOTE(review): enum_from_int presumably yields None for values
            # missing from the enum; obj_name relies on that -- confirm.
            self._enum = enum_from_int(self.obj_enum_class, obj_type)
        self.obj_list = []

    @staticmethod
    def _hexdump(data):
        """Format *data* as space-separated ``xNN`` byte tokens."""
        return " ".join(["x{:02X}".format(b) for b in data])

    def add_obj(self, obj):
        """Append a child object to this TLV."""
        self.obj_list.append(obj)

    @property
    def len(self):
        """Size in bytes of the serialized TLV."""
        return len(bytes(self))

    @property
    def obj_name(self):
        """Enum member name of the TLV type, or ``tlv#<n>`` when unknown."""
        if self._enum is not None:
            return self._enum.name
        else:
            return "tlv#{}".format(self.obj_type)

    def print_hdr(self, prepend=""):
        """Print a one-line summary: length, type and subclass-specific value."""
        print(
            "{}len={} type={}({}){}".format(
                prepend, self.len, self.obj_name, self.obj_type, self._print_obj_value()
            )
        )

    def print_obj(self, prepend=""):
        """Print this TLV and then its children, indented two spaces deeper."""
        self.print_hdr(prepend)
        prepend = "  " + prepend
        for obj in self.obj_list:
            obj.print_obj(prepend)

    def print_obj_hex(self, prepend=""):
        """Print *prepend*, an empty line, then the serialized TLV as hex."""
        print(prepend)
        print()
        print(self._hexdump(bytes(self)))

    @classmethod
    def _validate(cls, data):
        """Check that *data* holds exactly one TLV.

        Raises:
            ValueError: *data* is shorter than the TLV header, or its size
                differs from the length recorded in the header.
        """
        if len(data) < sizeof(IpFwObjTlv):
            raise ValueError("TLV too short")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        """Build an instance from already-validated *data* (header only)."""
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        return cls(hdr.n_type)

    @classmethod
    def from_bytes(cls, data, attr_map=None):
        """Validate *data* and parse it into a TLV object.

        *attr_map* is forwarded to ``_parse`` for subclasses that decode
        nested TLVs.  Raises ``ValueError`` when validation fails.
        """
        cls._validate(data)
        obj = cls._parse(data, attr_map)
        return obj

    def __bytes__(self):
        """Serialize the TLV; must be provided by subclasses."""
        raise NotImplementedError()

    def _print_obj_value(self):
        """Return the payload (everything after the TLV header) as hex.

        The payload is taken from the serialized form instead of a ``_data``
        attribute, so this default works for every subclass that implements
        ``__bytes__``, not only for those that keep their raw bytes around.
        """
        return " " + self._hexdump(bytes(self)[sizeof(IpFwObjTlv) :])

    def as_hexdump(self):
        """Return the serialized TLV as space-separated ``xNN`` tokens."""
        return self._hexdump(bytes(self))
127
128
class UnknownTlv(BaseTlv):
    """TLV of a type with no dedicated class; its bytes are kept verbatim."""

    def __init__(self, obj_type, data):
        """Store the TLV type and the complete wire image (header included)."""
        super().__init__(obj_type)
        self._data = data

    @classmethod
    def _validate(cls, data):
        """Check that *data* holds exactly one TLV.

        Only the generic TLV header is required: an unknown TLV carries no
        name block, so demanding the (much larger) named-TLV size would
        reject perfectly valid short TLVs.

        Raises:
            ValueError: *data* is shorter than the TLV header, or its size
                differs from the length recorded in the header.
        """
        if len(data) < sizeof(IpFwObjTlv):
            raise ValueError("TLV size is too short")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        """Wrap validated *data* as-is; *attr_map* is not used."""
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        self = cls(hdr.n_type, data)
        return self

    def __bytes__(self):
        """Return the bytes this TLV was created from, unchanged."""
        return self._data
150
151
class Tlv(BaseTlv):
    """TLV base class that can split a buffer into consecutive child TLVs."""

    @staticmethod
    def parse_tlvs(data, attr_map):
        """Decode back-to-back TLVs from *data* and return them as a list.

        Each TLV type is looked up in *attr_map*; a hit selects the class
        (``entry["ad"].cls``) and the nested map (``entry["child"]``), a miss
        falls back to ``UnknownTlv`` with an empty nested map.  Trailing
        bytes too short to hold a TLV header are left unparsed.

        Raises:
            ValueError: a TLV claims to extend past the end of *data*, or
                the selected class rejects its bytes.
        """
        hdr_size = sizeof(IpFwObjTlv)
        total = len(data)
        parsed = []
        pos = 0
        while pos + hdr_size <= total:
            tlv_hdr = IpFwObjTlv.from_buffer_copy(data[pos : pos + hdr_size])
            end = pos + tlv_hdr.length
            if end > total:
                raise ValueError("TLV size do not match")
            descr = attr_map.get(tlv_hdr.n_type, None)
            if descr is not None:
                tlv_cls = descr["ad"].cls
                nested_map = descr.get("child", {})
            else:
                # No dedicated class registered: keep the raw bytes.
                tlv_cls = UnknownTlv
                nested_map = {}
            parsed.append(tlv_cls.from_bytes(data[pos:end], nested_map))
            pos = end
        return parsed
177
178
class IpFwObjNTlv(Structure):
    """Wire layout of a named-object TLV: generic header, ids and a name."""

    _fields_ = [
        ("head", IpFwObjTlv),  # generic TLV header (type and total length)
        ("idx", c_ushort),
        ("n_set", c_uint8),
        ("n_type", c_uint8),
        ("spare", c_uint32),
        ("name", c_char * 64),  # fixed 64-byte field; NTlv zero-pads shorter names
    ]
188
189
class NTlv(Tlv):
    """Named-object TLV: an index, set, type and a name of up to 64 bytes."""

    def __init__(self, obj_type, idx=0, n_set=0, n_type=0, name=None):
        """Store the TLV type and the named-object attributes."""
        super().__init__(obj_type)
        self.n_idx = idx
        self.n_set = n_set
        self.n_type = n_type
        self.n_name = name

    @classmethod
    def _validate(cls, data):
        """Check that *data* is exactly one named-object TLV.

        Raises:
            ValueError: the size is not that of ``IpFwObjNTlv`` or differs
                from the length recorded in the TLV header.
        """
        if len(data) != sizeof(IpFwObjNTlv):
            raise ValueError("TLV size is not correct")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        """Decode validated *data*; *attr_map* is not used."""
        hdr = IpFwObjNTlv.from_buffer_copy(data[: sizeof(IpFwObjNTlv)])
        name = hdr.name.decode("utf-8")
        self = cls(hdr.head.n_type, hdr.idx, hdr.n_set, hdr.n_type, name)
        return self

    def __bytes__(self):
        """Serialize to the fixed-size ``IpFwObjNTlv`` layout.

        A name of ``None`` (the constructor default) is written as an empty
        name rather than failing; longer names are truncated to 64 bytes.
        """
        raw_name = (self.n_name or "").encode("utf-8")
        hdr = IpFwObjNTlv(
            head=IpFwObjTlv(n_type=self.obj_type, length=sizeof(IpFwObjNTlv)),
            idx=self.n_idx,
            n_set=self.n_set,
            n_type=self.n_type,
            # Zero-pad, then clip, so the value always fills the field exactly.
            name=raw_name.ljust(64, b"\0")[:64],
        )
        return bytes(hdr)

    def _print_obj_value(self):
        """Return the named-object attributes for ``print_hdr``."""
        return " " + "type={} set={} idx={} name={}".format(
            self.n_type, self.n_set, self.n_idx, self.n_name
        )
230
231
class IpFwObjCTlv(Structure):
    """Header of a container TLV: generic TLV header plus child bookkeeping."""

    _fields_ = [
        ("head", IpFwObjTlv),  # generic TLV header (type and total length)
        ("count", c_uint32),  # number of child objects held by the container
        ("objsize", c_ushort),  # CTlv stores the serialized size of its first child
        ("version", c_uint8),
        ("flags", c_uint8),
    ]
240
241
class CTlv(Tlv):
    """Container TLV: an ``IpFwObjCTlv`` header followed by child TLVs."""

    def __init__(self, obj_type, obj_list=None):
        """Store the TLV type and copy any initial children from *obj_list*."""
        super().__init__(obj_type)
        if obj_list:
            self.obj_list.extend(obj_list)

    @classmethod
    def _validate(cls, data):
        """Check that *data* holds exactly one container TLV.

        Raises:
            ValueError: *data* is shorter than the container header, or its
                size differs from the length recorded in the header.
        """
        if len(data) < sizeof(IpFwObjCTlv):
            raise ValueError("TLV too short")
        hdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        if len(data) != hdr.head.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        """Decode the children that follow the header using *attr_map*.

        Raises:
            ValueError: the number of decoded children differs from the
                header's ``count``.
        """
        hdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        tlv_list = cls.parse_tlvs(data[sizeof(IpFwObjCTlv) :], attr_map)
        if len(tlv_list) != hdr.count:
            raise ValueError("wrong number of objects")
        self = cls(hdr.head.n_type, obj_list=tlv_list)
        return self

    def __bytes__(self):
        """Serialize the header followed by every child.

        The header length covers the header and the whole payload, which is
        what ``_validate`` checks when the bytes are parsed back.
        """
        parts = [bytes(obj) for obj in self.obj_list]
        payload = b"".join(parts)
        # objsize mirrors the size of the first child (0 for an empty container).
        objsize = len(parts[0]) if parts else 0
        hdr = IpFwObjCTlv(
            head=IpFwObjTlv(
                n_type=self.obj_type, length=sizeof(IpFwObjCTlv) + len(payload)
            ),
            count=len(parts),
            objsize=objsize,
        )
        return bytes(hdr) + payload

    def _print_obj_value(self):
        """Containers have no value of their own; the children are printed."""
        return ""
283
284
class IpFwRule(Structure):
    """Fixed-size rule header; the rule's instructions follow it directly."""

    _fields_ = [
        ("act_ofs", c_ushort),  # offset of the first action instruction, in 4-byte words
        ("cmd_len", c_ushort),  # total length of the instructions, in 4-byte words
        ("spare", c_ushort),
        ("n_set", c_uint8),
        ("flags", c_uint8),
        ("rulenum", c_uint32),
        ("n_id", c_uint32),
    ]
295
296
class RawRule(Tlv):
    """A single rule: an ``IpFwRule`` header followed by its instructions.

    ``obj_list`` holds the instruction objects; each must serialize with
    ``bytes()`` and expose an ``is_action`` attribute.
    """

    def __init__(self, obj_type=0, n_set=0, rulenum=0, obj_list=None):
        """Store the rule attributes and copy any instructions from *obj_list*."""
        super().__init__(obj_type)
        self.n_set = n_set
        self.rulenum = rulenum
        if obj_list:
            self.obj_list.extend(obj_list)

    @classmethod
    def _validate(cls, data):
        """Check that *data* is a rule header plus exactly ``cmd_len`` words.

        Raises:
            ValueError: *data* is shorter than the rule header, or its size
                does not match the header's ``cmd_len``.
        """
        min_size = sizeof(IpFwRule)
        if len(data) < min_size:
            raise ValueError("rule TLV too short")
        rule = IpFwRule.from_buffer_copy(data[:min_size])
        if len(data) != min_size + rule.cmd_len * 4:
            raise ValueError("rule TLV cmd_len incorrect")

    @classmethod
    def _parse(cls, data, attr_map):
        """Decode validated *data*.

        Instructions are always parsed with ``insn_attrs``; *attr_map* is
        accepted for interface compatibility only.
        """
        hdr = IpFwRule.from_buffer_copy(data[: sizeof(IpFwRule)])
        self = cls(
            n_set=hdr.n_set,
            rulenum=hdr.rulenum,
            obj_list=BaseInsn.parse_insns(data[sizeof(IpFwRule) :], insn_attrs),
        )
        return self

    def __bytes__(self):
        """Serialize the header followed by every instruction.

        ``cmd_len`` and ``act_ofs`` are counted in 4-byte words.
        """
        # None means "no action seen yet".  Zero cannot serve as that marker
        # because it is also the valid offset of an action that comes first;
        # using it would let a later action overwrite the real offset.
        act_ofs = None
        cmd_len = 0
        chunks = []
        for obj in self.obj_list:
            if act_ofs is None and obj.is_action:
                act_ofs = cmd_len
            obj_bytes = bytes(obj)
            cmd_len += len(obj_bytes) // 4
            chunks.append(obj_bytes)

        hdr = IpFwRule(
            act_ofs=0 if act_ofs is None else act_ofs,
            cmd_len=cmd_len,
            n_set=self.n_set,
            rulenum=self.rulenum,
        )
        return bytes(hdr) + b"".join(chunks)

    @property
    def obj_name(self):
        """Display name based on the rule number."""
        return "rule#{}".format(self.rulenum)

    def _print_obj_value(self):
        """Return the set and the instruction length (in words) for printing."""
        cmd_len = sum([len(bytes(obj)) for obj in self.obj_list]) // 4
        return " set={} cmd_len={}".format(self.n_set, cmd_len)
350
351
class CTlvRule(CTlv):
    """Container TLV whose payload is a sequence of 8-byte-aligned rules."""

    def __init__(self, obj_type=IpFwTlvType.IPFW_TLV_RULE_LIST, obj_list=[]):
        super().__init__(obj_type, obj_list)

    @classmethod
    def _parse(cls, data, attr_map):
        """Decode the rules that follow the container header.

        Every rule occupies its header plus ``cmd_len`` words, rounded up to
        a multiple of eight bytes.  *attr_map* is not used.

        Raises:
            ValueError: a rule extends past the end of *data*, or the rules
                do not end exactly at the end of *data*.
        """
        container = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        rule_hdr_size = sizeof(IpFwRule)
        total = len(data)
        rules = []
        pos = sizeof(IpFwObjCTlv)
        while pos + rule_hdr_size <= total:
            rule_hdr = IpFwRule.from_buffer_copy(data[pos : pos + rule_hdr_size])
            size = rule_hdr_size + rule_hdr.cmd_len * 4
            end = pos + size
            if end > total:
                raise ValueError("wrong rule size")
            rules.append(RawRule.from_bytes(data[pos:end]))
            # The next rule starts on the following 8-byte boundary.
            pos += align8(size)
        if pos != total:
            raise ValueError("rule bytes left: off={} len={}".format(pos, total))
        return cls(container.head.n_type, obj_list=rules)

    # XXX: no dedicated _validate yet; the one from CTlv is inherited.

    def __bytes__(self):
        """Serialize the header followed by each rule zero-padded to 8 bytes."""
        padded = []
        for rule in self.obj_list:
            raw = bytes(rule)
            # -n % 8 is the number of bytes missing to the next multiple of 8.
            padded.append(raw + b"\0" * (-len(raw) % 8))
        body = b"".join(padded)
        head = IpFwObjTlv(n_type=self.obj_type, length=len(body) + sizeof(IpFwObjCTlv))
        hdr = IpFwObjCTlv(head=head, count=len(self.obj_list))
        return bytes(hdr) + body
391
392
class BaseIpFwMessage(object):
    """Base class for op3 messages: a fixed header followed by TLV objects.

    Subclasses set ``messages`` (the opcodes they stand for) and ``attr_map``
    (the TLV map handed to ``Tlv.parse_tlvs`` by ``from_bytes``).
    """

    messages = []
    # With no known TLV types every TLV is decoded as an UnknownTlv.
    attr_map = {}

    def __init__(self, msg_type, obj_list=None):
        """Record the opcode (``Enum`` member or int) and copy *obj_list*."""
        if isinstance(msg_type, Enum):
            self.obj_type = msg_type.value
            self._enum = msg_type
        else:
            self.obj_type = msg_type
            self._enum = enum_from_int(self.messages, self.obj_type)
        self.obj_list = []
        if obj_list:
            self.obj_list.extend(obj_list)

    @staticmethod
    def _raw_type(value):
        """Return the integer behind *value*, which may be an ``Enum`` member."""
        return value.value if isinstance(value, Enum) else value

    def add_obj(self, obj):
        """Append a TLV object to the message."""
        self.obj_list.append(obj)

    def get_obj(self, obj_type):
        """Return the first object of *obj_type* (enum or int), or ``None``."""
        obj_type_raw = self._raw_type(obj_type)
        for obj in self.obj_list:
            if obj.obj_type == obj_type_raw:
                return obj
        return None

    @staticmethod
    def parse_header(data: bytes):
        """Decode the op3 header at the start of *data*.

        Returns:
            A ``(header, header_size)`` tuple.

        Raises:
            ValueError: *data* is shorter than the header.
        """
        if len(data) < sizeof(IpFw3OpHeader):
            raise ValueError("length less than op3 message header")
        return IpFw3OpHeader.from_buffer_copy(data), sizeof(IpFw3OpHeader)

    def parse_obj_list(self, data: bytes):
        """Split *data* into raw TLVs and append each as an ``UnknownTlv``.

        Raises:
            ValueError: a TLV header is truncated, a TLV is shorter than its
                own header (which would stall the scan), or a TLV extends
                past the end of *data*.
        """
        hdr_size = sizeof(IpFwObjTlv)
        off = 0
        while off < len(data):
            if off + hdr_size > len(data):
                raise ValueError("TLV header truncated")
            hdr = IpFwObjTlv.from_buffer_copy(data[off : off + hdr_size])
            if hdr.length < hdr_size:
                raise ValueError("TLV too short")
            if hdr.length + off > len(data):
                raise ValueError("TLV too big")
            self.add_obj(UnknownTlv(hdr.n_type, data[off : off + hdr.length]))
            off += hdr.length

    def is_type(self, msg_type):
        """Tell whether this message has opcode *msg_type* (enum or int)."""
        return self._raw_type(msg_type) == self.obj_type

    @property
    def obj_name(self):
        """Enum member name of the opcode, or ``msg#<n>`` when unknown."""
        if self._enum is not None:
            return self._enum.name
        else:
            return "msg#{}".format(self.obj_type)

    def print_hdr(self, prepend=""):
        """Print the serialized length and the message name."""
        print("{}len={}, type={}".format(prepend, len(bytes(self)), self.obj_name))

    @classmethod
    def from_bytes(cls, data):
        """Parse a complete message: the op3 header, then TLVs via ``attr_map``.

        Raises:
            ValueError: the header or one of the TLVs is malformed.  On a
                header failure the raw bytes are dumped before re-raising.
        """
        try:
            hdr, hdrlen = cls.parse_header(data)
            self = cls(hdr.opcode)
            self._orig_data = data
        except ValueError as e:
            print("Failed to parse op3 header: {}".format(e))
            cls.print_as_bytes(data, "invalid op3 message")
            raise
        tlv_list = Tlv.parse_tlvs(data[hdrlen:], self.attr_map)
        self.obj_list.extend(tlv_list)
        return self

    def __bytes__(self):
        """Serialize the op3 header followed by every contained object."""
        header = bytes(IpFw3OpHeader(opcode=self.obj_type))
        return header + b"".join(bytes(obj) for obj in self.obj_list)

    def print_obj(self):
        """Print the message header line and then every contained object."""
        self.print_hdr()
        for obj in self.obj_list:
            obj.print_obj("  ")

    @staticmethod
    def print_as_bytes(data: bytes, descr: str):
        """Print *data* as rows of 16 hex bytes under a banner naming *descr*."""
        print("===vv {} (len:{:3d}) vv===".format(descr, len(data)))
        step = 16
        for off in range(0, len(data), step):
            print("".join(" {:02X}".format(b) for b in data[off : off + step]))
        print("--------------------")
485
486
# TLV parsing map for rule messages: a table-name list container whose
# children are named TLVs, and a rule-list container.
rule_attrs = prepare_attrs_map(
    [
        AttrDescr(
            IpFwTlvType.IPFW_TLV_TBLNAME_LIST,
            CTlv,
            # All three name-carrying child types share the NTlv layout.
            [
                AttrDescr(name_type, NTlv)
                for name_type in (
                    IpFwTlvType.IPFW_TLV_TBL_NAME,
                    IpFwTlvType.IPFW_TLV_STATE_NAME,
                    IpFwTlvType.IPFW_TLV_EACTION,
                )
            ],
            True,
        ),
        AttrDescr(IpFwTlvType.IPFW_TLV_RULE_LIST, CTlvRule),
    ]
)
502
503
class IpFwXRule(BaseIpFwMessage):
    """Op3 message for the IP_FW_XADD opcode, decoded with ``rule_attrs``."""

    # TLV map that from_bytes hands to Tlv.parse_tlvs.
    attr_map = rule_attrs
    # Opcodes this message class stands for.
    messages = [Op3CmdType.IP_FW_XADD]
507
508
# Message-class registries; presumably consumed by code outside this file.
# Only the get3 group has an entry so far.
legacy_classes, set3_classes = [], []
get3_classes = [IpFwXRule]
512