1#!/usr/bin/python3
2
3# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
4#
5# SPDX-License-Identifier: MPL-2.0
6#
7# This Source Code Form is subject to the terms of the Mozilla Public
8# License, v. 2.0.  If a copy of the MPL was not distributed with this
9# file, you can obtain one at https://mozilla.org/MPL/2.0/.
10#
11# See the COPYRIGHT file distributed with this work for additional
12# information regarding copyright ownership.
13
14import time
15import os
16
17import pytest
18
19pytest.importorskip("dns", minversion="2.0.0")
20import dns.resolver
21
22
def wait_for_transfer(ip, port, client_ip, name, rrtype, attempts=10, delay=1):
    """Poll the server at ``ip:port`` until it answers ``name``/``rrtype``.

    The query is sent from ``client_ip``, so it is answered by whatever
    configuration the server applies to that source address.  While no
    nameserver gives a usable answer (dnspython raises
    ``dns.resolver.NoNameservers``, e.g. on SERVFAIL), the query is retried
    up to ``attempts`` times, sleeping ``delay`` seconds between tries.

    Raises:
        RuntimeError: if every attempt failed with ``NoNameservers``; the
            last resolver error is chained as the cause.  Any other resolver
            exception (e.g. ``NXDOMAIN``) propagates immediately.
    """
    resolver = dns.resolver.Resolver()
    resolver.nameservers = [ip]
    resolver.port = port

    last_error = None
    for attempt in range(attempts):
        # Only wait between attempts, not after the final one.
        if attempt:
            time.sleep(delay)
        try:
            resolver.resolve(name, rrtype, source=client_ip)
        except dns.resolver.NoNameservers as err:
            last_error = err
        else:
            return

    raise RuntimeError(
        "zone transfer failed: "
        f"client {client_ip} got no usable answer for {name} {rrtype} "
        f"from @{ip}:{port} after {attempts} attempts"
    ) from last_error
40
41
def test_rpz_multiple_views(named_port):
    """Check that each client source address gets its own RPZ policy.

    The server at 10.53.0.3 is queried for ``A`` records from several source
    addresses (10.53.0.1 - 10.53.0.5).  Each expectation is either the
    address 10.53.0.2 (the name is let through) or ``None``, meaning the
    policy rewrites the name to ``CNAME .`` and the client must see NXDOMAIN.
    """
    resolver = dns.resolver.Resolver()
    resolver.nameservers = ["10.53.0.3"]
    resolver.port = named_port

    # Wait until rpz-external.local is served to these clients before
    # checking any policy.
    wait_for_transfer("10.53.0.3", named_port, "10.53.0.2", "rpz-external.local", "SOA")
    wait_for_transfer("10.53.0.3", named_port, "10.53.0.5", "rpz-external.local", "SOA")

    passes = "10.53.0.2"  # address returned when the name is let through
    blocked = None  # rewritten to CNAME ., seen by the client as NXDOMAIN

    # (source address, query name, expected A address or None for NXDOMAIN),
    # checked in order.
    expectations = [
        # 10.53.0.1: baddomain. is blocked; gooddomain. and allowed. pass.
        ("10.53.0.1", "baddomain.", blocked),
        ("10.53.0.1", "gooddomain.", passes),
        ("10.53.0.1", "allowed.", passes),
        # 10.53.0.2: allowed. is blocked; baddomain. and gooddomain. pass.
        ("10.53.0.2", "baddomain.", passes),
        ("10.53.0.2", "gooddomain.", passes),
        ("10.53.0.2", "allowed.", blocked),
        # 10.53.0.3: every name passes.
        ("10.53.0.3", "baddomain.", passes),
        ("10.53.0.3", "gooddomain.", passes),
        ("10.53.0.3", "allowed.", passes),
        # 10.53.0.4: baddomain. and gooddomain. are blocked; allowed. passes.
        ("10.53.0.4", "baddomain.", blocked),
        ("10.53.0.4", "gooddomain.", blocked),
        ("10.53.0.4", "allowed.", passes),
        # 10.53.0.5 (any other client): baddomain. passes; gooddomain. and
        # allowed. are blocked.
        ("10.53.0.5", "baddomain.", passes),
        ("10.53.0.5", "gooddomain.", blocked),
        ("10.53.0.5", "allowed.", blocked),
    ]

    for source, qname, expected in expectations:
        where = f"{qname} A queried from {source}"
        try:
            ans = resolver.resolve(qname, "A", source=source)
        except dns.resolver.NXDOMAIN:
            assert expected is None, f"{where}: unexpected NXDOMAIN"
        else:
            assert expected is not None, f"{where}: expected NXDOMAIN, got {ans[0]}"
            assert ans[0].address == expected, where
114
115
def test_rpz_passthru_logging(named_port):
    """Check that RPZ PASSTHRU hits are logged apart from other rewrites.

    A PASSTHRU hit (``allowed.`` queried from 10.53.0.1) must be logged in
    ns3/rpz_passthru.txt and not in ns3/rpz.txt.  The NXDOMAIN rewrite of
    ``baddomain.`` for the same client must be logged in ns3/rpz.txt.
    """
    resolver = dns.resolver.Resolver()
    resolver.nameservers = ["10.53.0.3"]
    resolver.port = named_port

    # Should generate a log entry into rpz_passthru.txt
    ans = resolver.resolve("allowed.", "A", source="10.53.0.1")
    assert ans[0].address == "10.53.0.2"

    # baddomain. isn't allowed for this client (CNAME .), so the query must
    # fail with NXDOMAIN; this should generate a log entry into rpz.txt
    with pytest.raises(dns.resolver.NXDOMAIN):
        resolver.resolve("baddomain.", "A", source="10.53.0.1")

    # Relative paths: assumes the test runs from the system-test directory
    # that contains ns3/ — NOTE(review): confirm against the test runner.
    rpz_passthru_logfile = os.path.join("ns3", "rpz_passthru.txt")
    rpz_logfile = os.path.join("ns3", "rpz.txt")

    assert os.path.isfile(rpz_passthru_logfile)
    assert os.path.isfile(rpz_logfile)

    with open(rpz_passthru_logfile, encoding="utf-8") as log_file:
        passthru_log = log_file.read()
    assert "rpz QNAME PASSTHRU rewrite allowed/A/IN" in passthru_log

    with open(rpz_logfile, encoding="utf-8") as log_file:
        rpz_log = log_file.read()
    # The PASSTHRU hit must go only to the dedicated passthru log.
    assert "rpz QNAME PASSTHRU rewrite allowed/A/IN" not in rpz_log
    assert "rpz QNAME NXDOMAIN rewrite baddomain/A/IN" in rpz_log
144