1# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
2#
3# SPDX-License-Identifier: MPL-2.0
4#
5# This Source Code Form is subject to the terms of the Mozilla Public
6# License, v. 2.0.  If a copy of the MPL was not distributed with this
7# file, you can obtain one at https://mozilla.org/MPL/2.0/.
8#
9# See the COPYRIGHT file distributed with this work for additional
10# information regarding copyright ownership.
11
from datetime import datetime, timedelta, timezone
from collections import defaultdict
from time import sleep
import os

import dns.message
import dns.query
import dns.rcode

import isctest
22
23
# ISO 8601 UTC timestamp format without fractional seconds
fmt = "%Y-%m-%dT%H:%M:%SZ"

# Upper bounds for the refresh and expire timers; the constants were taken
# from BIND 9 source code (lib/dns/zone.c)
max_refresh = timedelta(seconds=2419200)  # 4 weeks
max_expires = timedelta(seconds=14515200)  # 24 weeks

# The Unix epoch as a naive UTC datetime, used by zone_mtime() as the
# "zone file does not exist" sentinel.  Built directly instead of through
# datetime.utcfromtimestamp(0), which is deprecated since Python 3.12.
dayzero = datetime(1970, 1, 1)

# Maximum number of seconds to wait for the secondary zone files to appear
# so that their mtime can be extracted; also the tolerance applied to the
# zone "loaded" timestamp by check_loaded().
max_secondary_zone_waittime_sec = 5
34
35
36# Generic helper functions
def check_expires(expires, min_time, max_time):
    """Assert that the zone expiry time lies in [min_time, max_time]."""
    assert min_time <= expires <= max_time
40
41
def check_refresh(refresh, min_time, max_time):
    """Assert that the zone refresh time lies in [min_time, max_time]."""
    assert min_time <= refresh <= max_time
45
46
def check_loaded(loaded, expected, now):
    """Sanity-check a zone's "loaded" timestamp.

    ``loaded`` must be less than ``max_secondary_zone_waittime_sec`` seconds
    after ``expected``. Only this upper bound is enforced, so a ``loaded``
    earlier than ``expected`` is accepted. ``loaded`` must also not be later
    than ``now``.
    """
    lag = loaded - expected
    assert lag.total_seconds() < max_secondary_zone_waittime_sec
    assert now >= loaded
51
52
def check_zone_timers(loaded, expires, refresh, loaded_exp):
    """Sanity-check the timers reported for a single zone.

    ``expires`` and ``refresh`` are each skipped when None; otherwise they
    must lie between the current time and the current time plus
    ``max_expires`` / ``max_refresh`` respectively.  ``loaded`` must be less
    than ``max_secondary_zone_waittime_sec`` seconds after ``loaded_exp`` and
    not in the future.

    All arguments are compared against a naive UTC "now", so callers are
    expected to pass naive UTC datetimes (aware ones would raise TypeError).
    """
    # datetime.utcnow() is deprecated since Python 3.12: take an aware UTC
    # timestamp and drop the tzinfo to keep the same naive-UTC value.
    now = datetime.now(timezone.utc).replace(tzinfo=None, microsecond=0)
    if expires is not None:
        check_expires(expires, now, now + max_expires)
    if refresh is not None:
        check_refresh(refresh, now, now + max_refresh)
    check_loaded(loaded, loaded_exp, now)
61
62
63#
64# The output is gibberish, but at least make sure it does not crash.
65#
def check_manykeys(name, zone=None):
    """Assert that ``name`` is the "manykeys" zone.

    ``zone`` is accepted for call compatibility but is not inspected.
    """
    # pylint: disable=unused-argument
    wanted = "manykeys"
    assert name == wanted
69
70
def zone_mtime(zonedir, name):
    """Return the modification time of the zone file ``<zonedir>/<name>.db``.

    The result is a naive UTC datetime truncated to whole seconds.  If the
    file does not exist, the ``dayzero`` sentinel is returned instead.
    """
    path = os.path.join(zonedir, "{}.db".format(name))
    try:
        si = os.stat(path)
    except FileNotFoundError:
        return dayzero

    # datetime.utcfromtimestamp() is deprecated since Python 3.12: convert
    # via an aware UTC datetime, then drop tzinfo so the value stays naive
    # and comparable with the other naive UTC timestamps.
    mtime = datetime.fromtimestamp(si.st_mtime, tz=timezone.utc)
    return mtime.replace(tzinfo=None, microsecond=0)
80
81
def test_zone_timers_primary(fetch_zones, load_timers, **kwargs):
    """Check the timers of every zone on a primary server.

    Zones are fetched from ``statsip``:``statsport``. Each zone's "loaded"
    time is checked against the mtime of ``<zonedir>/<name>.db``. The
    ``kwargs`` must provide ``statsip``, ``statsport`` and ``zonedir``.
    """
    statsip, statsport, zonedir = (
        kwargs["statsip"],
        kwargs["statsport"],
        kwargs["zonedir"],
    )

    for zone in fetch_zones(statsip, statsport):
        name, loaded, expires, refresh = load_timers(zone, True)
        check_zone_timers(loaded, expires, refresh, zone_mtime(zonedir, name))
93
94
def test_zone_timers_secondary(fetch_zones, load_timers, **kwargs):
    """Check the timers of every zone on a secondary server.

    A zone file may not exist yet, in which case ``zone_mtime()`` returns
    ``dayzero``. When that happens, the zones are re-fetched and re-checked
    after a one-second pause. This repeats up to
    ``max_secondary_zone_waittime_sec`` times. After that, the check runs
    even for missing files. The ``kwargs`` must provide ``statsip``,
    ``statsport`` and ``zonedir``.
    """
    statsip = kwargs["statsip"]
    statsport = kwargs["statsport"]
    zonedir = kwargs["zonedir"]

    retries_left = max_secondary_zone_waittime_sec
    while retries_left >= 0:
        for zone in fetch_zones(statsip, statsport):
            name, loaded, expires, refresh = load_timers(zone, False)
            mtime = zone_mtime(zonedir, name)
            if mtime == dayzero and retries_left != 0:
                # Zone file not there yet and retries remain: start over.
                retries_left -= 1
                break
            # The mtime is known, or no retries are left: check anyway.
            check_zone_timers(loaded, expires, refresh, mtime)
        else:
            # Every zone was checked without needing a retry.
            break
        sleep(1)
120
121
def test_zone_with_many_keys(fetch_zones, load_zone, **kwargs):
    """Run ``check_manykeys()`` on the "manykeys" zone.

    Every zone fetched from ``statsip``:``statsport`` is passed through
    ``load_zone``, whose return value is used as the zone name. The
    ``kwargs`` must provide ``statsip`` and ``statsport``.
    """
    zones = fetch_zones(kwargs["statsip"], kwargs["statsport"])
    for name in map(load_zone, zones):
        if name != "manykeys":
            continue
        check_manykeys(name)
132
133
def create_msg(qname, qtype):
    """Build a query for ``qname``/``qtype`` with DNSSEC records requested.

    The query uses EDNS version 0 and advertises a 4096-byte payload size.
    """
    return dns.message.make_query(
        qname,
        qtype,
        want_dnssec=True,
        use_edns=0,
        payload=4096,
    )
140
141
def create_expected(data):
    """Build the expected traffic-size histograms, seeded from ``data``.

    The result maps each of the eight known traffic counters to a
    ``defaultdict(int)`` of bucket -> count. Each histogram is pre-filled
    with the counts from ``data``, a mapping of counter -> {bucket: count}.
    A non-empty entry in ``data`` for any other counter raises KeyError.
    """
    counters = (
        "dns-tcp-requests-sizes-received-ipv4",
        "dns-tcp-responses-sizes-sent-ipv4",
        "dns-tcp-requests-sizes-received-ipv6",
        "dns-tcp-responses-sizes-sent-ipv6",
        "dns-udp-requests-sizes-received-ipv4",
        "dns-udp-requests-sizes-received-ipv6",
        "dns-udp-responses-sizes-sent-ipv4",
        "dns-udp-responses-sizes-sent-ipv6",
    )
    expected = {counter: defaultdict(int) for counter in counters}

    for counter, buckets in data.items():
        for bucket, count in buckets.items():
            # Looked up per bucket so empty entries never hit an unknown key.
            expected[counter][bucket] += count

    return expected
159
160
def update_expected(expected, key, msg):
    """Count ``msg`` in its 16-byte size bucket of histogram ``expected[key]``.

    The bucket label is "<low>-<high>". Here low is the wire length of
    ``msg`` rounded down to a multiple of 16, and high is low + 15.
    """
    wire_len = len(msg.to_wire())
    low = wire_len - wire_len % 16
    label = "{}-{}".format(low, low + 15)
    expected[key][label] += 1
167
168
def check_traffic(data, expected):
    """Assert that the fetched traffic histograms match the expected ones.

    Both arguments are normalised recursively before comparison: dicts
    become sorted lists of (key, value) pairs and lists are sorted, so
    ordering does not matter.  Each must contain exactly eight counters.
    """

    def ordered(obj):
        # Recursively replace dicts/lists with sorted lists so the final
        # comparison is independent of iteration order.
        if isinstance(obj, dict):
            return sorted((k, ordered(v)) for k, v in obj.items())
        if isinstance(obj, list):
            return sorted(ordered(x) for x in obj)
        return obj

    ordered_data = ordered(data)
    ordered_expected = ordered(expected)

    # One entry per traffic counter (the eight built by create_expected()).
    # ordered() preserves length, so checking the normalised forms suffices.
    assert len(ordered_data) == 8
    assert len(ordered_expected) == 8

    assert ordered_data == ordered_expected
186
187
def test_traffic(fetch_traffic, **kwargs):
    """Check the traffic-size histograms reported by the server.

    The current histograms are fetched first and used as the baseline. Then
    a short and a long TXT query are sent to ``statsip``, over UDP and then
    over TCP. After each exchange, the sizes of the query and the answer are
    added to the matching IPv4 request/response counters. The expectation is
    then compared with freshly fetched histograms. The ``kwargs`` must
    provide ``statsip`` and ``statsport``.
    """
    statsip = kwargs["statsip"]
    statsport = kwargs["statsport"]

    exp = create_expected(fetch_traffic(statsip, statsport))

    # (query function, request-size counter, response-size counter)
    transports = (
        (
            isctest.query.udp,
            "dns-udp-requests-sizes-received-ipv4",
            "dns-udp-responses-sizes-sent-ipv4",
        ),
        (
            isctest.query.tcp,
            "dns-tcp-requests-sizes-received-ipv4",
            "dns-tcp-responses-sizes-sent-ipv4",
        ),
    )

    for query, request_key, response_key in transports:
        for qname in ("short.example.", "long.example."):
            msg = create_msg(qname, "TXT")
            update_expected(exp, request_key, msg)
            ans = query(msg, statsip)
            isctest.check.noerror(ans)
            update_expected(exp, response_key, ans)
            # Re-fetch after every exchange so a mismatch points at the
            # query that caused it.
            data = fetch_traffic(statsip, statsport)
            check_traffic(data, exp)
230