From ef7c9ca722046e02308bdf256bdce2277e5da6c4 Mon Sep 17 00:00:00 2001 From: panpann Date: Wed, 19 Dec 2018 20:56:31 +0100 Subject: [PATCH] feat: test_client() for aiohttp_app --- connexion/apps/aiohttp_app.py | 56 ++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/connexion/apps/aiohttp_app.py b/connexion/apps/aiohttp_app.py index b5acffbd4..da4cf4967 100644 --- a/connexion/apps/aiohttp_app.py +++ b/connexion/apps/aiohttp_app.py @@ -1,12 +1,17 @@ +import asyncio import logging import os.path import pkgutil import sys from aiohttp import web +from aiohttp.test_utils import TestClient, TestServer from ..apis.aiohttp_api import AioHttpApi from ..exceptions import ConnexionException +from ..lifecycle import ConnexionResponse +from ..tests import AbstractClient +from ..utils import is_json_mimetype from .abstract import AbstractApp logger = logging.getLogger('connexion.aiohttp_app') @@ -14,8 +19,10 @@ class AioHttpApp(AbstractApp): + api_cls = AioHttpApi + def __init__(self, import_name, only_one_api=False, **kwargs): - super(AioHttpApp, self).__init__(import_name, AioHttpApi, server='aiohttp', **kwargs) + super(AioHttpApp, self).__init__(import_name, self.api_cls, server='aiohttp', **kwargs) self._only_one_api = only_one_api self._api_added = False @@ -96,3 +103,50 @@ def run(self, port=None, server=None, debug=None, host=None, **options): web.run_app(self.app, port=self.port, host=self.host, access_log=access_log) else: raise Exception('Server {} not recognized'.format(self.server)) + + def test_client(self): + """Return a flask's test_client compatible.""" + return AioHttpClient.from_app(self) + + +class AioHttpClient(AbstractClient): + """ A specific test client for aiohttp framework.""" + + def _request( + self, + method, + url, + **kwargs + ): + # code inspired from https://github.com/aio-libs/aiohttp/blob/v3.4.4/aiohttp/pytest_plugin.py#L286 + # set the loop in the app, + # and use only this one to avoid loop conflicts + self.app.app._set_loop(None) + loop = self.app.app.loop + client = TestClient(TestServer(self.app.app, loop=loop), loop=loop) + loop.run_until_complete(client.start_server()) + + @asyncio.coroutine + def _async_request(): + nonlocal client + content_type = kwargs.get("content_type") + if content_type: + headers = kwargs.setdefault("headers", {}) + if "Content-Type" not in headers: + headers["Content-Type"] = content_type + kwargs["params"] = kwargs.get("query_string") + res = yield from client.request(method.upper(), url, **kwargs) + body = yield from res.read() + print(res.content_type) + if is_json_mimetype(res.content_type): + body = body + b"\n" + print("test_client", body) + return ConnexionResponse( + status_code=res.status, + headers=res.headers, + body=body + ) + + response = loop.run_until_complete(_async_request()) + loop.run_until_complete(client.close()) + return response