1#!/usr/bin/python3
2
3# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
4#
5# SPDX-License-Identifier: MPL-2.0
6#
7# This Source Code Form is subject to the terms of the Mozilla Public
8# License, v. 2.0.  If a copy of the MPL was not distributed with this
9# file, you can obtain one at https://mozilla.org/MPL/2.0/.
10#
11# See the COPYRIGHT file distributed with this work for additional
12# information regarding copyright ownership.
13
14# pylint: disable=unused-variable
15
16import socket
17import struct
18import time
19
20import pytest
21
22pytest.importorskip("dns", minversion="2.0.0")
23import dns.message
24import dns.query
25
26
# Per-operation budget in seconds. timeout() adds it to the current time to
# form the absolute deadline passed to dnspython's send_tcp/receive_tcp.
TIMEOUT = 10
28
29
def create_msg(qname, qtype, edns=-1):
    """Build a DNS query message for *qname* / *qtype*.

    *edns* is forwarded unchanged as ``use_edns`` to
    ``dns.message.make_query``. In dnspython, the default of -1 means
    "do not use EDNS".
    """
    return dns.message.make_query(qname, qtype, use_edns=edns)
33
34
def timeout():
    """Return an absolute deadline TIMEOUT seconds from now.

    The result is a ``time.time()``-based timestamp. That is the form of
    the *expiration* argument taken by dnspython's send/receive helpers.
    """
    now = time.time()
    return now + TIMEOUT
37
38
def create_socket(host, port):
    """Open a TCP connection to *host*:*port* with TCP_NODELAY enabled.

    The socket timeout uses the module-wide TIMEOUT, so it stays in sync
    with the deadlines produced by timeout(). TCP_NODELAY disables Nagle's
    algorithm, so small writes are not held back waiting to be coalesced.

    Returns the connected socket. The caller owns it and must close it
    (the tests use it as a context manager).

    Raises OSError if the connection cannot be established or configured.
    """
    sock = socket.create_connection((host, port), timeout=TIMEOUT)
    try:
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
    except OSError:
        # Don't leak the already-connected socket if configuring it fails.
        sock.close()
        raise
    return sock
43
44
def test_tcp_garbage(named_port):
    """A TCP DNS frame shorter than a DNS header must make named drop the
    connection."""
    with create_socket("10.53.0.7", named_port) as sock:
        # Normal round-trip first, to show the connection is usable.
        msg = create_msg("a.example.", "A")
        (sbytes, stime) = dns.query.send_tcp(sock, msg, timeout())
        (response, rtime) = dns.query.receive_tcp(sock, timeout())

        wire = msg.to_wire()
        assert len(wire) > 0

        # Send a DNS message shorter than the DNS message header (12 bytes);
        # this should cause the connection to be terminated. The frame is a
        # 2-byte big-endian length prefix followed by the 11-byte payload.
        # (A bare "!s" struct format would pack only the first byte, so the
        # payload is appended directly.) sendall() guards against short
        # writes.
        garbage = b"0123456789a"
        sock.sendall(struct.pack("!H", len(garbage)) + garbage)

        # Any further exchange must fail. Either dnspython sees end-of-stream
        # (EOFError), or the kernel reports a reset/broken pipe
        # (ConnectionError), which is normalized to EOFError.
        with pytest.raises(EOFError):
            try:
                (sbytes, stime) = dns.query.send_tcp(sock, msg, timeout())
                (response, rtime) = dns.query.receive_tcp(sock, timeout())
            except ConnectionError as e:
                raise EOFError from e
65
66
def test_tcp_garbage_response(named_port):
    """Sending a DNS *response* instead of a query over TCP must make named
    close the connection."""
    with create_socket("10.53.0.7", named_port) as conn:
        # Sanity round-trip with an ordinary query.
        query = create_msg("a.example.", "A")
        sent_len, sent_at = dns.query.send_tcp(conn, query, timeout())
        reply, received_at = dns.query.receive_tcp(conn, timeout())

        assert len(query.to_wire()) > 0

        # Now send a response message where a query is expected; this
        # should cause the connection to be terminated.
        bogus = dns.message.make_response(query)
        sent_len, sent_at = dns.query.send_tcp(conn, bogus, timeout())

        # Reading must hit end-of-stream. A connection reset surfaces as
        # ConnectionError and is reported as EOFError too.
        with pytest.raises(EOFError):
            try:
                reply, received_at = dns.query.receive_tcp(conn, timeout())
            except ConnectionError as err:
                raise EOFError from err
87
88
89# Regression test for CVE-2022-0396
def test_close_wait(named_port):
    """Check that named closes a TCP connection after an EDNS0 query is
    dropped, instead of leaving it stuck in CLOSE_WAIT (CVE-2022-0396)."""
    with create_socket("10.53.0.7", named_port) as sock:
        msg = create_msg("a.example.", "A")
        (sbytes, stime) = dns.query.send_tcp(sock, msg, timeout())
        (response, rtime) = dns.query.receive_tcp(sock, timeout())

        # Second query carries EDNS0; no response is read for it.
        msg = dns.message.make_query("a.example.", "A", use_edns=0, payload=1232)
        (sbytes, stime) = dns.query.send_tcp(sock, msg, timeout())

        # Shut down our side of the socket. Because we sent a DNS message
        # with EDNS0, the other side may already have closed the connection,
        # in which case shutdown() can fail; that is expected, so ignore it.
        # ConnectionError is a subclass of OSError, so one clause covers both.
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass

    # BIND allows one TCP client. The part above sends a DNS message with
    # EDNS0 after the first query. Because of ns7/named.dropedns, BIND should
    # react adequately and close the socket, making room for the next
    # request. If the connection gets stuck in CLOSE_WAIT state, there is no
    # connection available for the query below and it will time out.
    with create_socket("10.53.0.7", named_port) as sock:
        msg = create_msg("a.example.", "A")
        (sbytes, stime) = dns.query.send_tcp(sock, msg, timeout())
        (response, rtime) = dns.query.receive_tcp(sock, timeout())
117