1# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
2#
3# SPDX-License-Identifier: MPL-2.0
4#
5# This Source Code Form is subject to the terms of the Mozilla Public
6# License, v. 2.0.  If a copy of the MPL was not distributed with this
7# file, you can obtain one at https://mozilla.org/MPL/2.0/.
8#
9# See the COPYRIGHT file distributed with this work for additional
10# information regarding copyright ownership.
11
12import os
13import select
14import signal
15import socket
16import sys
17import time
18
import dns.exception
import dns.flags
import dns.message
21
22
def port():
    """Return the listening port: the PORT environment variable as an int.

    Falls back to 5300 when PORT is not set.  A PORT value that is not a
    valid integer raises ValueError.
    """
    return int(os.getenv("PORT", "5300"))
31
32
def udp_listen(port):
    """Open an IPv4 UDP socket bound to 10.53.0.3:*port* and return it."""
    address = ("10.53.0.3", port)
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.bind(address)
    return sock
38
39
def tcp_listen(port):
    """Open an IPv4 TCP socket listening on 10.53.0.3:*port* and return it.

    SO_REUSEADDR is enabled before binding (presumably so a restarted
    server can rebind the port straight away), and the accept backlog is
    100.
    """
    address = ("10.53.0.3", port)
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    listener.bind(address)
    listener.listen(100)
    return listener
47
48
def udp_tc_once(udp):
    """Answer one UDP query with a truncated (TC=1) response.

    Reads a single datagram from *udp* and builds a response to it with
    dns.message.make_response.  No records are added here.  The TC bit is
    set, which signals the client to retry over TCP, and the response is
    sent back to the datagram's source address.

    A datagram that dnspython cannot parse as a DNS message is dropped
    without a reply.  Letting the parse error propagate would end the
    server loop, which only guards select(), and leave ans.pid behind.
    """
    qrybytes, clientaddr = udp.recvfrom(65535)
    try:
        qry = dns.message.from_wire(qrybytes)
    except dns.exception.DNSException:
        # Malformed packet: nothing sensible to answer, keep serving.
        return
    answ = dns.message.make_response(qry)
    answ.flags |= dns.flags.TC
    udp.sendto(answ.to_wire(), clientaddr)
56
57
def tcp_once(tcp, delay=5):
    """Accept one TCP connection, hold it for *delay* seconds, then close it.

    Nothing is read from or written to the client, so the client's TCP
    connection stalls and is then closed without an answer.  The wait
    happens in the calling thread, so the caller handles no other sockets
    while it lasts.

    :param tcp: listening TCP socket to accept from.
    :param delay: seconds to keep the accepted connection open (default 5).
    """
    csock, _clientaddr = tcp.accept()
    # ``with`` guarantees the client socket is closed even if the sleep is
    # interrupted, e.g. by SystemExit raised from the SIGTERM handler.
    with csock:
        time.sleep(delay)
63
def sigterm(signum, frame):
    """Remove ans.pid from the current directory and exit with status 0.

    Used as the SIGTERM handler and also called directly once the main
    loop ends.  *signum* and *frame* are the standard signal-handler
    arguments and are ignored.

    An already-missing ans.pid is not treated as an error.  The only goal
    is that the file is gone when the process exits, and raising here
    would turn a clean shutdown into a traceback and non-zero status.
    """
    try:
        os.remove("ans.pid")
    except FileNotFoundError:
        pass
    sys.exit(0)
67
68
def write_pid():
    """Write this process's PID, as decimal text, to ans.pid in the cwd."""
    pid_text = str(os.getpid())
    with open("ans.pid", "w") as pid_file:
        pid_file.write(pid_text)
73
74
signal.signal(signal.SIGTERM, sigterm)
write_pid()

udp = udp_listen(port())
tcp = tcp_listen(port())

# Sockets watched by select().  Not called ``input`` so the builtin
# input() is not shadowed.
sockets = [udp, tcp]

# Serve until select() fails or the user interrupts.  In Python 3 both
# select.error and socket.error are aliases of OSError, so one clause
# covers what the two separate handlers did.
while True:
    try:
        readable, _, _ = select.select(sockets, [], [])
    except (OSError, KeyboardInterrupt):
        break

    for s in readable:
        if s is udp:
            udp_tc_once(udp)
        elif s is tcp:
            tcp_once(tcp)

# Normal shutdown: remove the PID file and exit 0, same as on SIGTERM.
sigterm(signal.SIGTERM, None)
100