1#
2# Copyright 2020 Haiku, Inc. All rights reserved.
3# Distributed under the terms of the MIT License.
4#
5# Authors:
6#  Kyle Ambroff-Kao, kyle@ambroffkao.com
7#
8
9"""
10Transparent HTTP proxy.
11"""
12
13import http.client
14import http.server
15import optparse
16import socket
17import sys
18import urllib.parse
19
20
class RequestHandler(http.server.BaseHTTPRequestHandler):
    """
    Implement the basic requirements for a transparent HTTP proxy as defined
    by RFC 7230. Enough of the functionality is implemented to support the
    integration tests in HttpTest that use the HTTP proxy feature.

    There are many error conditions and failure modes which are not handled.
    Those cases can be added as the test suite expands to handle more error
    cases. Known gaps include CONNECT tunnelling and chunked request bodies:
    only Content-Length delimited request bodies are read from the client.
    """

    # Client request headers which are not copied verbatim to the downstream
    # server; names are compared case-insensitively. http.client generates
    # its own Host and Content-Length, and X-Forwarded-For is rebuilt so that
    # this hop can be appended to it.
    _NOT_FORWARDED_REQUEST_HEADERS = frozenset(
        ('host', 'content-length', 'x-forwarded-for'))

    def __init__(self, *args, **kwargs):
        # Persistent connections to the downstream servers, keyed by the
        # netloc of the request target (e.g. "host:port") => HTTPConnection.
        #
        # The base class constructor serves the whole client connection
        # (possibly several keep-alive requests) before returning, so this
        # has to exist before it is called. A handler instance serves a
        # single client connection, so no locking is needed.
        self._connections = {}

        super().__init__(*args, **kwargs)

    def finish(self):
        """
        Finish the client connection and close every downstream connection
        still cached for it, instead of leaving the sockets to the garbage
        collector.
        """
        try:
            super().finish()
        finally:
            for netloc in list(self._connections):
                self._discard_connection(netloc)

    def _discard_connection(self, netloc):
        """Close and forget the cached connection to netloc, if there is one."""
        conn = self._connections.pop(netloc, None)
        if conn is not None:
            conn.close()

    def _build_downstream_headers(self, client_address):
        """
        Return the dict of headers to send to the downstream server.

        Every client header except those in _NOT_FORWARDED_REQUEST_HEADERS is
        copied, keeping the spelling of its first occurrence. Repeated headers
        are combined into one comma separated value (RFC 7230 section 3.2.2),
        since a dict holds one value per name. X-Forwarded-For is extended
        with client_address ("ip:port" of the client).
        """
        headers = {}
        copied = set()  # lowercased names already copied
        for name in self.headers.keys():
            key = name.lower()
            if key in self._NOT_FORWARDED_REQUEST_HEADERS or key in copied:
                continue
            copied.add(key)
            # get_all() matches names case-insensitively, so this collects
            # every spelling of the header in the order received.
            headers[name] = ', '.join(self.headers.get_all(name))

        forwarded_for = self.headers.get_all('X-Forwarded-For', [])
        headers['X-Forwarded-For'] = ', '.join(
            list(forwarded_for) + [client_address])
        return headers

    def _response_may_have_body(self, status):
        """
        Return whether a response with this status code, to the current
        request method, can carry a message body (RFC 7230 section 3.3.3).
        """
        if self.command == 'HEAD':
            return False
        return not (100 <= status < 200 or status in (204, 304))

    def _write_response_body(self, response):
        """
        Copy the body of a downstream response to the client.

        The response headers are echoed verbatim, so the body must be written
        with the framing those headers announce. http.client removes the
        chunked transfer-coding while reading, so a chunked body is
        re-encoded here as one chunk followed by the last-chunk (trailers are
        not forwarded). Content-Length and close-delimited bodies are written
        unchanged.
        """
        # http.client handles Content-Length, chunked and close-delimited
        # framing, and returns b'' for HEAD requests and bodiless statuses.
        body = response.read()

        # Same test http.client applies when deciding to de-chunk.
        transfer_encoding = response.headers.get('Transfer-Encoding', '')
        if (transfer_encoding.lower() == 'chunked'
                and self._response_may_have_body(response.status)):
            framed = b''
            if body:
                framed = (
                    '{:x}\r\n'.format(len(body)).encode('ascii')
                    + body + b'\r\n')
            body = framed + b'0\r\n\r\n'

        if body:
            self.wfile.write(body)

    def _proxy_request(self):
        """
        Forward the current request to the server named by its absolute-form
        request target and relay the response back to the client.

        Responds with 400 when the target does not name a server, and with
        502 when forwarding fails (e.g. the downstream server cannot be
        reached or sends an invalid response).
        """
        # Extract the downstream server from the request path.
        #
        # Note that no attempt is made to prevent message forwarding loops
        # here. This doesn't need to be a complete proxy implementation, just
        # enough of one for integration tests. RFC 7230 section 5.7 says if
        # this were a complete implementation, it would have to make sure that
        # the target system was not this process to avoid a loop.
        #
        # urlsplit() rather than urlparse() so that ";params" stay part of
        # the path and are forwarded unchanged.
        target = urllib.parse.urlsplit(self.path)
        if not target.netloc:
            # An origin-form target such as "/index.html" names no server.
            self.send_error(
                400, explain='Request target must be an absolute URI')
            return

        client_address = '{}:{}'.format(*self.client_address)
        downstream_headers = self._build_downstream_headers(client_address)

        # Read the request body from the client.
        request_body_length = int(self.headers.get('Content-Length', '0'))
        request_body = self.rfile.read(request_body_length)

        target_path = target.path
        if target.query:
            target_path += '?' + target.query

        # If Connection: close wasn't used, then we may still have a
        # connection to this downstream server handy.
        conn = self._connections.get(target.netloc)
        try:
            if conn is None:
                conn = http.client.HTTPConnection(target.netloc)
                # Track it right away so finish() closes it even if relaying
                # the response fails part way through.
                self._connections[target.netloc] = conn
            conn.request(
                self.command, target_path, request_body, downstream_headers)
            response = conn.getresponse()
        except (OSError, http.client.HTTPException) as error:
            self._discard_connection(target.netloc)
            self.log_error(
                'Failed to forward request to %s: %s', target.netloc, error)
            self.send_error(502)
            return

        # Echo the response to the client.
        self.send_response_only(response.status, response.reason)
        for header_name, header_value in response.headers.items():
            self.send_header(header_name, header_value)
        self.end_headers()

        self._write_response_body(response)

        # Cleanup: keep the downstream connection cached for later requests
        # unless either side asked for it to be closed. Connection is a
        # case-insensitive, comma separated token list that may be repeated.
        connection_tokens = ','.join(
            self.headers.get_all('Connection', [])).split(',')
        client_wants_close = any(
            token.strip().lower() == 'close' for token in connection_tokens)
        if response.will_close or client_wants_close:
            self._discard_connection(target.netloc)
            self.close_connection = True

        self.log_message(
            'Proxied request from %s to %s',
            client_address,
            self.path)

    def do_GET(self):
        self._proxy_request()

    def do_HEAD(self):
        self._proxy_request()

    def do_POST(self):
        self._proxy_request()

    def do_PUT(self):
        self._proxy_request()

    def do_DELETE(self):
        self._proxy_request()

    def do_PATCH(self):
        self._proxy_request()

    def do_OPTIONS(self):
        self._proxy_request()
136
137
def main():
    """
    Start the transparent HTTP proxy and serve requests until interrupted.

    The listening socket is either inherited through --fd or created by
    binding --bind-addr and --port. When --port is omitted, port 0 is bound
    so the OS chooses one. The port actually being listened on is printed
    to stderr.
    """
    options = parse_args(sys.argv)

    bind_addr = (
        options.bind_addr,
        0 if options.port is None else options.port)

    server = http.server.HTTPServer(
        bind_addr,
        RequestHandler,
        bind_and_activate=False)

    # Compare with None: 0 is a valid file descriptor.
    if options.server_socket_fd is not None:
        # Replace the still unbound socket made by the constructor with the
        # inherited one. It is neither bound nor listen()ed here, so it must
        # already be listening. fromfd() duplicates the descriptor.
        server.socket.close()
        server.socket = socket.fromfd(
            options.server_socket_fd,
            socket.AF_INET,
            socket.SOCK_STREAM)
        server.server_address = server.socket.getsockname()
    else:
        # server_bind() records the bound address in server_address.
        server.server_bind()
        server.server_activate()

    # Read the port back from the socket itself. It is only known after
    # binding (the OS picks it when --port is omitted), and with --fd it is
    # whatever the inherited socket is bound to.
    server.server_port = server.server_address[1]

    print(
        'Transparent HTTP proxy listening on port',
        server.server_port,
        file=sys.stderr)
    try:
        server.serve_forever(0.01)
    except KeyboardInterrupt:
        pass
    finally:
        # Close the listening socket however serving ended.
        server.server_close()
171
172
def parse_args(argv):
    """
    Parse the proxy's command line options.

    argv is the full argument vector *including* the program name, as in
    sys.argv, so argv[0] is the one positional argument tolerated. Anything
    after it is reported through parser.error(), which prints the usage and
    exits the process.

    Returns the optparse values object with the attributes bind_addr (str),
    port (int or None) and server_socket_fd (int or None).
    """
    parser = optparse.OptionParser(
        usage='Usage: %prog [OPTIONS]',
        description=__doc__)

    # (flag, destination attribute, default, optparse type, help text).
    # A type of None leaves optparse storing the raw string.
    option_table = (
        ('--bind-addr', 'bind_addr', '127.0.0.1', None,
         'By default only bind to loopback'),
        ('--port', 'port', None, 'int',
         'If not specified a random port will be used.'),
        ('--fd', 'server_socket_fd', None, 'int',
         'A socket FD to use for accept() instead of binding a new one.'),
    )
    for flag, attribute, default, value_type, help_text in option_table:
        settings = {'dest': attribute, 'default': default, 'help': help_text}
        if value_type is not None:
            settings['type'] = value_type
        parser.add_option(flag, **settings)

    options, positional = parser.parse_args(argv)
    unexpected = positional[1:]
    if unexpected:
        message = 'Unexpected arguments: {}'.format(', '.join(unexpected))
        parser.error(message)
    return options
198
199
# Start the proxy only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
    main()
202