diff --git a/Makefile b/Makefile index 5dc56f9406..bb44f08a1d 100644 --- a/Makefile +++ b/Makefile @@ -39,7 +39,7 @@ release: package twine upload dist/* coverage: - coverage3 run --source=proxy tests.py + coverage3 run --source=proxy,plugin_examples tests.py coverage3 html open htmlcov/index.html diff --git a/plugin_examples.py b/plugin_examples.py index 5a1f123e1c..7ae2d73d31 100644 --- a/plugin_examples.py +++ b/plugin_examples.py @@ -49,7 +49,12 @@ class ProposedRestApiPlugin(proxy.HttpProxyBasePlugin): Used to test and develop client side applications without need of an actual upstream REST API server. - Returns proposed REST API mock responses to the client.""" + Returns proposed REST API mock responses to the client + without establishing upstream connection. + + Note: This plugin won't work if your client is making + HTTPS connection to api.example.com. + """ API_SERVER = b'api.example.com' @@ -76,6 +81,11 @@ class ProposedRestApiPlugin(proxy.HttpProxyBasePlugin): } def before_upstream_connection(self, request: proxy.HttpParser) -> Optional[proxy.HttpParser]: + # Return None to disable establishing connection to upstream + # Most likely our api.example.com won't even exist under development scenario + return None + + def handle_client_request(self, request: proxy.HttpParser) -> Optional[proxy.HttpParser]: if request.host != self.API_SERVER: return request assert request.path @@ -94,9 +104,6 @@ def before_upstream_connection(self, request: proxy.HttpParser) -> Optional[prox )) return None - def handle_client_request(self, request: proxy.HttpParser) -> Optional[proxy.HttpParser]: - return request - def handle_upstream_chunk(self, chunk: bytes) -> bytes: return chunk @@ -107,15 +114,16 @@ def on_upstream_connection_close(self) -> None: class RedirectToCustomServerPlugin(proxy.HttpProxyBasePlugin): """Modifies client request to redirect all incoming requests to a fixed server address.""" - UPSTREAM_SERVER = b'http://localhost:8899' + UPSTREAM_SERVER = b'http://localhost:8899/' def before_upstream_connection(self, request: proxy.HttpParser) -> Optional[proxy.HttpParser]: # Redirect all non-https requests to inbuilt WebServer. if request.method != proxy.httpMethods.CONNECT: - request.url = urlparse.urlsplit(self.UPSTREAM_SERVER) - # This command will re-parse modified url and - # update host, port, path fields - request.set_line_attributes() + request.set_url(self.UPSTREAM_SERVER) + # Update Host header too, otherwise upstream can reject our request + if request.has_header(b'Host'): + request.del_header(b'Host') + request.add_header(b'Host', urlparse.urlsplit(self.UPSTREAM_SERVER).netloc) return request def handle_client_request(self, request: proxy.HttpParser) -> Optional[proxy.HttpParser]: diff --git a/proxy.py b/proxy.py index fc2064c5b7..b0e2a05a13 100755 --- a/proxy.py +++ b/proxy.py @@ -598,6 +598,10 @@ def del_headers(self, headers: List[bytes]) -> None: for key in headers: self.del_header(key.lower()) + def set_url(self, url: bytes) -> None: + self.url = urlparse.urlsplit(url) + self.set_line_attributes() + def set_line_attributes(self) -> None: if self.type == httpParserTypes.REQUEST_PARSER: if self.method == httpMethods.CONNECT and self.url: @@ -700,13 +704,12 @@ def process_line(self, raw: bytes) -> None: line = raw.split(WHITESPACE) if self.type == httpParserTypes.REQUEST_PARSER: self.method = line[0].upper() - self.url = urlparse.urlsplit(line[1]) + self.set_url(line[1]) self.version = line[2] else: self.version = line[0] self.code = line[1] self.reason = WHITESPACE.join(line[2:]) - self.set_line_attributes() def process_header(self, raw: bytes) -> None: parts = raw.split(COLON) @@ -1226,7 +1229,7 @@ def get_descriptors( def write_to_descriptors(self, w: List[Union[int, _HasFileno]]) -> bool: if self.request.has_upstream_server() and \ self.server and not self.server.closed and \ - self.server.buffer_size() > 0 and \ + self.server.has_buffer() and \ self.server.connection in w: logger.debug('Server is write ready, flushing buffer') try: @@ -1378,16 +1381,20 @@ def on_request_complete(self) -> Union[socket.socket, bool]: if not self.request.has_upstream_server(): return False + self.authenticate() + # Note: can raise HttpRequestRejected exception # Invoke plugin.before_upstream_connection + do_connect = True for plugin in self.plugins.values(): r = plugin.before_upstream_connection(self.request) if r is None: - return False + do_connect = False + break self.request = r - self.authenticate() - self.connect_upstream() + if do_connect: + self.connect_upstream() for plugin in self.plugins.values(): assert self.request is not None @@ -1444,7 +1451,7 @@ def on_request_complete(self) -> Union[socket.socket, bool]: # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection # connection headers are meant for communication between client and # first intercepting proxy. - self.request.add_headers([(b'Via', b'1.1 proxy.py v%s' % version)]) + self.request.add_headers([(b'Via', b'1.1 %s' % PROXY_AGENT_HEADER_VALUE)]) # Disable args.disable_headers before dispatching to upstream self.server.queue( self.request.build( @@ -2095,7 +2102,8 @@ def initialize(self) -> None: """Optionally upgrades connection to HTTPS, set conn in non-blocking mode and initializes plugins.""" conn = self.optionally_wrap_socket(self.client.connection) conn.setblocking(False) - self.client = TcpClientConnection(conn=conn, addr=self.addr) + if self.config.encryption_enabled(): + self.client = TcpClientConnection(conn=conn, addr=self.addr) if b'ProtocolHandlerPlugin' in self.config.plugins: for klass in self.config.plugins[b'ProtocolHandlerPlugin']: instance = klass(self.config, self.client, self.request) diff --git a/tests.py b/tests.py index e009f3da3f..8111f0f718 100644 --- a/tests.py +++ b/tests.py @@ -10,6 +10,7 @@ import base64 import errno import ipaddress +import json import logging import multiprocessing import os @@ -20,9 +21,11 @@ import unittest import uuid from contextlib import closing -from typing import Dict, Optional, Tuple, Union, Any, cast +from typing import Dict, Optional, Tuple, Union, Any, cast, Type from unittest import mock +from urllib import parse as urlparse +import plugin_examples import proxy if os.name != 'nt': @@ -44,6 +47,23 @@ def get_available_port() -> int: return int(port) +def get_plugin_by_test_name(test_name: str) -> Type[proxy.HttpProxyBasePlugin]: + plugin: Type[proxy.HttpProxyBasePlugin] = plugin_examples.ModifyPostDataPlugin + if test_name == 'test_modify_post_data_plugin': + plugin = plugin_examples.ModifyPostDataPlugin + elif test_name == 'test_proposed_rest_api_plugin': + plugin = plugin_examples.ProposedRestApiPlugin + elif test_name == 'test_redirect_to_custom_server_plugin': + plugin = plugin_examples.RedirectToCustomServerPlugin + elif test_name == 'test_filter_by_upstream_host_plugin': + plugin = plugin_examples.FilterByUpstreamHostPlugin + elif test_name == 'test_cache_responses_plugin': + plugin = plugin_examples.CacheResponsesPlugin + elif test_name == 'test_man_in_the_middle_plugin': + plugin = plugin_examples.ManInTheMiddlePlugin + return plugin + + class TestTextBytes(unittest.TestCase): def test_text(self) -> None: @@ -1048,8 +1068,11 @@ def assert_tunnel_response( def test_http_tunnel(self, mock_server_connection: mock.Mock) -> None: server = mock_server_connection.return_value server.connect.return_value = True - server.buffer_size.return_value = 0 - server.has_buffer.side_effect = [False, False, False, True] + + def has_buffer() -> bool: + return cast(bool, server.queue.called) + + server.has_buffer.side_effect = has_buffer self.mock_selector.return_value.select.side_effect = [ [(selectors.SelectorKey( fileobj=self._conn, @@ -1570,6 +1593,233 @@ def test_proxy_plugin_before_upstream_connection_can_teardown( mock_server_conn.assert_not_called() +class TestHttpProxyPluginExamples(unittest.TestCase): + + @mock.patch('selectors.DefaultSelector') + @mock.patch('socket.fromfd') + def setUp(self, + mock_fromfd: mock.Mock, + mock_selector: mock.Mock) -> None: + self.fileno = 10 + self._addr = ('127.0.0.1', 54382) + self.config = proxy.ProtocolConfig() + self.plugin = mock.MagicMock() + + self.mock_fromfd = mock_fromfd + self.mock_selector = mock_selector + + plugin = get_plugin_by_test_name(self._testMethodName) + + self.config.plugins = { + b'ProtocolHandlerPlugin': [proxy.HttpProxyPlugin], + b'HttpProxyBasePlugin': [plugin], + } + self._conn = mock_fromfd.return_value + self.proxy = proxy.ProtocolHandler( + self.fileno, self._addr, config=self.config) + self.proxy.initialize() + + @mock.patch('proxy.TcpServerConnection') + def test_modify_post_data_plugin(self, mock_server_conn: mock.Mock) -> None: + original = b'{"key": "value"}' + modified = b'{"key": "modified"}' + + self._conn.recv.return_value = proxy.build_http_request( + b'POST', b'http://httpbin.org/post', + headers={ + b'Host': b'httpbin.org', + b'Content-Type': b'application/x-www-form-urlencoded', + b'Content-Length': proxy.bytes_(len(original)), + }, + body=original + ) + self.mock_selector.return_value.select.side_effect = [ + [(selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None), selectors.EVENT_READ)], ] + + self.proxy.run_once() + mock_server_conn.assert_called_with('httpbin.org', 80) + mock_server_conn.return_value.queue.assert_called_with( + proxy.build_http_request( + b'POST', b'/post', + headers={ + b'Host': b'httpbin.org', + b'Content-Length': proxy.bytes_(len(modified)), + b'Content-Type': b'application/json', + b'Via': b'1.1 %s' % proxy.PROXY_AGENT_HEADER_VALUE, + }, + body=modified + ) + ) + + @mock.patch('proxy.TcpServerConnection') + def test_proposed_rest_api_plugin( + self, mock_server_conn: mock.Mock) -> None: + path = b'/v1/users/' + self._conn.recv.return_value = proxy.build_http_request( + b'GET', b'http://%s%s' % (plugin_examples.ProposedRestApiPlugin.API_SERVER, path), + headers={ + b'Host': plugin_examples.ProposedRestApiPlugin.API_SERVER, + } + ) + self.mock_selector.return_value.select.side_effect = [ + [(selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None), selectors.EVENT_READ)], ] + self.proxy.run_once() + + mock_server_conn.assert_not_called() + self.assertEqual( + self.proxy.client.buffer, + proxy.build_http_response( + proxy.httpStatusCodes.OK, reason=b'OK', + headers={b'Content-Type': b'application/json'}, + body=proxy.bytes_(json.dumps(plugin_examples.ProposedRestApiPlugin.REST_API_SPEC[path])) + )) + + @mock.patch('proxy.TcpServerConnection') + def test_redirect_to_custom_server_plugin( + self, mock_server_conn: mock.Mock) -> None: + request = proxy.build_http_request( + b'GET', b'http://example.org/get', + headers={ + b'Host': b'example.org', + } + ) + self._conn.recv.return_value = request + self.mock_selector.return_value.select.side_effect = [ + [(selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None), selectors.EVENT_READ)], ] + self.proxy.run_once() + + upstream = urlparse.urlsplit( + plugin_examples.RedirectToCustomServerPlugin.UPSTREAM_SERVER) + mock_server_conn.assert_called_with('localhost', 8899) + mock_server_conn.return_value.queue.assert_called_with( + proxy.build_http_request( + b'GET', upstream.path, + headers={ + b'Host': upstream.netloc, + b'Via': b'1.1 %s' % proxy.PROXY_AGENT_HEADER_VALUE, + } + ) + ) + + @mock.patch('proxy.TcpServerConnection') + def test_filter_by_upstream_host_plugin( + self, mock_server_conn: mock.Mock) -> None: + request = proxy.build_http_request( + b'GET', b'http://google.com/', + headers={ + b'Host': b'google.com', + } + ) + self._conn.recv.return_value = request + self.mock_selector.return_value.select.side_effect = [ + [(selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None), selectors.EVENT_READ)], ] + self.proxy.run_once() + + mock_server_conn.assert_not_called() + self.assertEqual( + self.proxy.client.buffer, + proxy.build_http_response( + proxy.httpStatusCodes.I_AM_A_TEAPOT, + reason=b'I\'m a tea pot', + headers={ + proxy.PROXY_AGENT_HEADER_KEY: proxy.PROXY_AGENT_HEADER_VALUE + }, + ) + ) + + @mock.patch('proxy.TcpServerConnection') + def test_cache_responses_plugin( + self, mock_server_conn: mock.Mock) -> None: + pass + + @mock.patch('proxy.TcpServerConnection') + def test_man_in_the_middle_plugin( + self, mock_server_conn: mock.Mock) -> None: + request = proxy.build_http_request( + b'GET', b'http://super.secure/', + headers={ + b'Host': b'super.secure', + } + ) + self._conn.recv.return_value = request + + server = mock_server_conn.return_value + server.connect.return_value = True + + def has_buffer() -> bool: + return cast(bool, server.queue.called) + + def closed() -> bool: + return not server.connect.called + + server.has_buffer.side_effect = has_buffer + type(server).closed = mock.PropertyMock(side_effect=closed) + + self.mock_selector.return_value.select.side_effect = [ + [(selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None), selectors.EVENT_READ)], + [(selectors.SelectorKey( + fileobj=server.connection, + fd=server.connection.fileno, + events=selectors.EVENT_WRITE, + data=None), selectors.EVENT_WRITE)], + [(selectors.SelectorKey( + fileobj=server.connection, + fd=server.connection.fileno, + events=selectors.EVENT_READ, + data=None), selectors.EVENT_READ)], ] + + # Client read + self.proxy.run_once() + mock_server_conn.assert_called_with('super.secure', 80) + server.connect.assert_called_once() + queued_request = \ + proxy.build_http_request( + b'GET', b'/', + headers={ + b'Host': b'super.secure', + b'Via': b'1.1 %s' % proxy.PROXY_AGENT_HEADER_VALUE + } + ) + server.queue.assert_called_once_with(queued_request) + + # Server write + self.proxy.run_once() + server.flush.assert_called_once() + + # Server read + server.recv.return_value = \ + proxy.build_http_response( + proxy.httpStatusCodes.OK, + reason=b'OK', body=b'Original Response From Upstream') + self.proxy.run_once() + self.assertEqual( + self.proxy.client.buffer, + proxy.build_http_response( + proxy.httpStatusCodes.OK, + reason=b'OK', body=b'Hello from man in the middle') + ) + + class TestHttpProxyTlsInterception(unittest.TestCase): @mock.patch('ssl.wrap_socket') @@ -1701,6 +1951,175 @@ def mock_connection() -> Any: self.assertEqual(self.proxy_plugin.return_value.client._conn, self.mock_ssl_wrap.return_value) +class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase): + + @mock.patch('ssl.wrap_socket') + @mock.patch('ssl.create_default_context') + @mock.patch('proxy.TcpServerConnection') + @mock.patch('subprocess.Popen') + @mock.patch('selectors.DefaultSelector') + @mock.patch('socket.fromfd') + def setUp(self, + mock_fromfd: mock.Mock, + mock_selector: mock.Mock, + mock_popen: mock.Mock, + mock_server_conn: mock.Mock, + mock_ssl_context: mock.Mock, + mock_ssl_wrap: mock.Mock) -> None: + self.mock_fromfd = mock_fromfd + self.mock_selector = mock_selector + self.mock_popen = mock_popen + self.mock_server_conn = mock_server_conn + self.mock_ssl_context = mock_ssl_context + self.mock_ssl_wrap = mock_ssl_wrap + + self.fileno = 10 + self._addr = ('127.0.0.1', 54382) + self.config = proxy.ProtocolConfig( + ca_cert_file='ca-cert.pem', + ca_key_file='ca-key.pem', + ca_signing_key_file='ca-signing-key.pem',) + self.plugin = mock.MagicMock() + + plugin = get_plugin_by_test_name(self._testMethodName) + + self.config.plugins = { + b'ProtocolHandlerPlugin': [proxy.HttpProxyPlugin], + b'HttpProxyBasePlugin': [plugin], + } + self._conn = mock.MagicMock(spec=socket.socket) + mock_fromfd.return_value = self._conn + self.proxy = proxy.ProtocolHandler( + self.fileno, self._addr, config=self.config) + self.proxy.initialize() + + self.server = self.mock_server_conn.return_value + + self.server_ssl_connection = mock.MagicMock(spec=ssl.SSLSocket) + self.mock_ssl_context.return_value.wrap_socket.return_value = self.server_ssl_connection + self.client_ssl_connection = mock.MagicMock(spec=ssl.SSLSocket) + self.mock_ssl_wrap.return_value = self.client_ssl_connection + + def has_buffer() -> bool: + return cast(bool, self.server.queue.called) + + def closed() -> bool: + return not self.server.connect.called + + def mock_connection() -> Any: + if self.mock_ssl_context.return_value.wrap_socket.called: + return self.server_ssl_connection + return self._conn + + self.server.has_buffer.side_effect = has_buffer + type(self.server).closed = mock.PropertyMock(side_effect=closed) + type(self.server).connection = mock.PropertyMock(side_effect=mock_connection) + + self.mock_selector.return_value.select.side_effect = [ + [(selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None), selectors.EVENT_READ)], + [(selectors.SelectorKey( + fileobj=self.client_ssl_connection, + fd=self.client_ssl_connection.fileno, + events=selectors.EVENT_READ, + data=None), selectors.EVENT_READ)], + [(selectors.SelectorKey( + fileobj=self.server_ssl_connection, + fd=self.server_ssl_connection.fileno, + events=selectors.EVENT_WRITE, + data=None), selectors.EVENT_WRITE)], + [(selectors.SelectorKey( + fileobj=self.server_ssl_connection, + fd=self.server_ssl_connection.fileno, + events=selectors.EVENT_READ, + data=None), selectors.EVENT_READ)], ] + + # Connect + def send(raw: bytes) -> int: + return len(raw) + + self._conn.send.side_effect = send + self._conn.recv.return_value = proxy.build_http_request( + proxy.httpMethods.CONNECT, b'uni.corn:443' + ) + self.proxy.run_once() + + self.mock_popen.assert_called() + self.mock_server_conn.assert_called_once_with('uni.corn', 443) + self.server.connect.assert_called() + self.assertEqual(self.proxy.client.connection, self.client_ssl_connection) + self.assertEqual(self.server.connection, self.server_ssl_connection) + self._conn.send.assert_called_with( + proxy.HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT + ) + self.assertEqual(self.proxy.client.buffer, b'') + + def test_modify_post_data_plugin(self) -> None: + original = b'{"key": "value"}' + modified = b'{"key": "modified"}' + self.client_ssl_connection.recv.return_value = proxy.build_http_request( + b'POST', b'/', + headers={ + b'Host': b'uni.corn', + b'Content-Type': b'application/x-www-form-urlencoded', + b'Content-Length': proxy.bytes_(len(original)), + }, + body=original + ) + self.proxy.run_once() + self.server.queue.assert_called_with( + proxy.build_http_request( + b'POST', b'/', + headers={ + b'Host': b'uni.corn', + b'Content-Length': proxy.bytes_(len(modified)), + b'Content-Type': b'application/json', + }, + body=modified + ) + ) + + @mock.patch('proxy.TcpServerConnection') + def test_cache_responses_plugin( + self, mock_server_conn: mock.Mock) -> None: + pass + + @mock.patch('proxy.TcpServerConnection') + def test_man_in_the_middle_plugin( + self, mock_server_conn: mock.Mock) -> None: + request = proxy.build_http_request( + b'GET', b'/', + headers={ + b'Host': b'uni.corn', + } + ) + self.client_ssl_connection.recv.return_value = request + + # Client read + self.proxy.run_once() + self.server.queue.assert_called_once_with(request) + + # Server write + self.proxy.run_once() + self.server.flush.assert_called_once() + + # Server read + self.server.recv.return_value = \ + proxy.build_http_response( + proxy.httpStatusCodes.OK, + reason=b'OK', body=b'Original Response From Upstream') + self.proxy.run_once() + self.assertEqual( + self.proxy.client.buffer, + proxy.build_http_response( + proxy.httpStatusCodes.OK, + reason=b'OK', body=b'Hello from man in the middle') + ) + + class TestHttpRequestRejected(unittest.TestCase): def setUp(self) -> None: