Skip to content

Commit

Permalink
Improve worker coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed Aug 25, 2016
1 parent 471df5c commit 18fae34
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 35 deletions.
9 changes: 7 additions & 2 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def connection_lost(self, exc):

@asyncio.coroutine
def handle_request(self, message, payload):
self._manager._requests_count += 1
if self.access_log:
now = self._loop.time()

Expand Down Expand Up @@ -118,7 +119,12 @@ def __init__(self, app, router, *,
self._secure_proxy_ssl_header = secure_proxy_ssl_header
self._kwargs = kwargs
self._kwargs.setdefault('logger', app.logger)
self.num_connections = 0
self._requests_count = 0

@property
def requests_count(self):
"""Number of processed requests."""
return self._requests_count

@property
def secure_proxy_ssl_header(self):
Expand All @@ -142,7 +148,6 @@ def finish_connections(self, timeout=None):
self._connections.clear()

def __call__(self):
self.num_connections += 1
return self._handler(
self, self._app, self._router, loop=self._loop,
secure_proxy_ssl_header=self._secure_proxy_ssl_header,
Expand Down
15 changes: 7 additions & 8 deletions aiohttp/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,18 @@ def _run(self):
while self.alive:
self.notify()

if pid == os.getpid() and self.ppid != os.getppid():
cnt = sum(handler.requests_count
for handler in self.servers.values())
if self.cfg.max_requests and cnt > self.cfg.max_requests:
self.alive = False
self.log.info("Max requests, shutting down: %s", self)

elif pid == os.getpid() and self.ppid != os.getppid():
self.alive = False
self.log.info("Parent changed, shutting down: %s", self)
else:
yield from asyncio.sleep(1.0, loop=self.loop)

if self.cfg.max_requests and self.servers:
connections = 0
for _, handler in self.servers.items():
connections += handler.num_connections
if connections > self.cfg.max_requests:
self.alive = False
self.log.info("Max requests, shutting down: %s", self)
except BaseException:
pass

Expand Down
133 changes: 108 additions & 25 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for aiohttp/worker.py"""
import asyncio
import pathlib
import ssl
from unittest import mock

import pytest
Expand All @@ -18,7 +20,7 @@
class BaseTestWorker:

def __init__(self):
self.servers = []
self.servers = {}
self.exit_code = 0
self.cfg = mock.Mock()
self.cfg.graceful_timeout = 100
Expand All @@ -38,7 +40,9 @@ class UvloopWorker(BaseTestWorker, base_worker.GunicornUVLoopWebWorker):

@pytest.fixture(params=PARAMS)
def worker(request):
return request.param()
ret = request.param()
ret.notify = mock.Mock()
return ret


def test_init_process(worker):
Expand Down Expand Up @@ -95,6 +99,7 @@ def test_make_handler(worker):
assert f is worker.wsgi.make_handler.return_value


@asyncio.coroutine
def test__run_ok(worker, loop):
worker.ppid = 1
worker.alive = True
Expand All @@ -103,17 +108,11 @@ def test__run_ok(worker, loop):
sock.cfg_addr = ('localhost', 8080)
worker.sockets = [sock]
worker.wsgi = mock.Mock()
worker.close = mock.Mock()
worker.close.return_value = helpers.create_future(loop)
worker.close.return_value.set_result(())
worker.close = make_mocked_coro(None)
worker.log = mock.Mock()
worker.notify = mock.Mock()
worker.loop = loop
ret = helpers.create_future(loop)
loop.create_server = mock.Mock(
wraps=asyncio.coroutine(lambda *a, **kw: ret))
ret.set_result(sock)
worker.wsgi.make_handler.return_value.num_connections = 1
loop.create_server = make_mocked_coro(sock)
worker.wsgi.make_handler.return_value.requests_count = 1
worker.cfg.max_requests = 100
worker.cfg.is_ssl = True

Expand All @@ -122,46 +121,50 @@ def test__run_ok(worker, loop):
with mock.patch('aiohttp.worker.asyncio') as m_asyncio:
m_asyncio.sleep = mock.Mock(
wraps=asyncio.coroutine(lambda *a, **kw: None))
loop.run_until_complete(worker._run())
yield from worker._run()

assert worker.notify.called
assert worker.log.info.called
worker.notify.assert_called_with()
worker.log.info.assert_called_with("Parent changed, shutting down: %s",
worker)

args, kwargs = loop.create_server.call_args
assert 'ssl' in kwargs
ctx = kwargs['ssl']
assert ctx is ssl_context


@asyncio.coroutine
def test__run_exc(worker, loop):
with mock.patch('aiohttp.worker.os') as m_os:
m_os.getpid.return_value = 1
m_os.getppid.return_value = 1

worker.servers = [mock.Mock()]
handler = mock.Mock()
handler.requests_count = 0
worker.servers = {mock.Mock(): handler}
worker.ppid = 1
worker.alive = True
worker.sockets = []
worker.log = mock.Mock()
worker.loop = mock.Mock()
worker.notify = mock.Mock()
worker.loop = loop
worker.cfg.is_ssl = False
worker.cfg.max_redirects = 0
worker.cfg.max_requests = 100

with mock.patch('aiohttp.worker.asyncio.sleep') as m_sleep:
slp = helpers.create_future(loop)
slp.set_exception(KeyboardInterrupt)
m_sleep.return_value = slp

worker.close = mock.Mock()
worker.close.return_value = helpers.create_future(loop)
worker.close.return_value.set_result(1)
worker.close = make_mocked_coro(None)

loop.run_until_complete(worker._run())
yield from worker._run()

assert m_sleep.called
assert worker.close.called
m_sleep.assert_called_with(1.0, loop=loop)
worker.close.assert_called_with()


@asyncio.coroutine
def test_close(worker, loop):
srv = mock.Mock()
srv.wait_closed = make_mocked_coro(None)
Expand All @@ -178,11 +181,91 @@ def test_close(worker, loop):
app.shutdown.return_value = helpers.create_future(loop)
app.shutdown.return_value.set_result(None)

loop.run_until_complete(worker.close())
yield from worker.close()
app.shutdown.assert_called_with()
app.cleanup.assert_called_with()
handler.finish_connections.assert_called_with(timeout=95.0)
srv.close.assert_called_with()
assert worker.servers is None

loop.run_until_complete(worker.close())
yield from worker.close()


@asyncio.coroutine
def test__run_ok_no_max_requests(worker, loop):
worker.ppid = 1
worker.alive = True
worker.servers = {}
sock = mock.Mock()
sock.cfg_addr = ('localhost', 8080)
worker.sockets = [sock]
worker.wsgi = mock.Mock()
worker.close = make_mocked_coro(None)
worker.log = mock.Mock()
worker.loop = loop
loop.create_server = make_mocked_coro(sock)
worker.wsgi.make_handler.return_value.requests_count = 1
worker.cfg.max_requests = 0
worker.cfg.is_ssl = True

ssl_context = mock.Mock()
with mock.patch('ssl.SSLContext', return_value=ssl_context):
with mock.patch('aiohttp.worker.asyncio') as m_asyncio:
m_asyncio.sleep = mock.Mock(
wraps=asyncio.coroutine(lambda *a, **kw: None))
yield from worker._run()

worker.notify.assert_called_with()
worker.log.info.assert_called_with("Parent changed, shutting down: %s",
worker)

args, kwargs = loop.create_server.call_args
assert 'ssl' in kwargs
ctx = kwargs['ssl']
assert ctx is ssl_context


@asyncio.coroutine
def test__run_ok_max_requests_exceeded(worker, loop):
worker.ppid = 1
worker.alive = True
worker.servers = {}
sock = mock.Mock()
sock.cfg_addr = ('localhost', 8080)
worker.sockets = [sock]
worker.wsgi = mock.Mock()
worker.close = make_mocked_coro(None)
worker.log = mock.Mock()
worker.loop = loop
loop.create_server = make_mocked_coro(sock)
worker.wsgi.make_handler.return_value.requests_count = 15
worker.cfg.max_requests = 10
worker.cfg.is_ssl = True

ssl_context = mock.Mock()
with mock.patch('ssl.SSLContext', return_value=ssl_context):
with mock.patch('aiohttp.worker.asyncio') as m_asyncio:
m_asyncio.sleep = mock.Mock(
wraps=asyncio.coroutine(lambda *a, **kw: None))
yield from worker._run()

worker.notify.assert_called_with()
worker.log.info.assert_called_with("Max requests, shutting down: %s",
worker)

args, kwargs = loop.create_server.call_args
assert 'ssl' in kwargs
ctx = kwargs['ssl']
assert ctx is ssl_context


def test__create_ssl_context_without_certs_and_ciphers(worker):
here = pathlib.Path(__file__).parent
worker.cfg.ssl_version = ssl.PROTOCOL_SSLv23
worker.cfg.cert_reqs = ssl.CERT_OPTIONAL
worker.cfg.certfile = str(here / 'sample.crt')
worker.cfg.keyfile = str(here / 'sample.key')
worker.cfg.ca_certs = None
worker.cfg.ciphers = None
crt = worker._create_ssl_context(worker.cfg)
assert isinstance(crt, ssl.SSLContext)

0 comments on commit 18fae34

Please sign in to comment.