Skip to content

Commit

Permalink
feat: test_client() for aiohttp_app
Browse files Browse the repository at this point in the history
  • Loading branch information
ainquel committed Dec 19, 2018
1 parent 10cc011 commit ef7c9ca
Showing 1 changed file with 55 additions and 1 deletion.
56 changes: 55 additions & 1 deletion connexion/apps/aiohttp_app.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
import asyncio
import logging
import os.path
import pkgutil
import sys

from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer

from ..apis.aiohttp_api import AioHttpApi
from ..exceptions import ConnexionException
from ..lifecycle import ConnexionResponse
from ..tests import AbstractClient
from ..utils import is_json_mimetype
from .abstract import AbstractApp

logger = logging.getLogger('connexion.aiohttp_app')


class AioHttpApp(AbstractApp):

api_cls = AioHttpApi

def __init__(self, import_name, only_one_api=False, **kwargs):
super(AioHttpApp, self).__init__(import_name, AioHttpApi, server='aiohttp', **kwargs)
super(AioHttpApp, self).__init__(import_name, self.api_cls, server='aiohttp', **kwargs)
self._only_one_api = only_one_api
self._api_added = False

Expand Down Expand Up @@ -96,3 +103,50 @@ def run(self, port=None, server=None, debug=None, host=None, **options):
web.run_app(self.app, port=self.port, host=self.host, access_log=access_log)
else:
raise Exception('Server {} not recognized'.format(self.server))

def test_client(self):
"""Return a flask's test_client compatible."""
return AioHttpClient.from_app(self)


class AioHttpClient(AbstractClient):
""" A specific test client for aiohttp framework."""

def _request(
self,
method,
url,
**kwargs
):
# code inspired from https://github.com/aio-libs/aiohttp/blob/v3.4.4/aiohttp/pytest_plugin.py#L286
# set the loop in the app,
# and use only this one to avoid loop conflicts
self.app.app._set_loop(None)
loop = self.app.app.loop
client = TestClient(TestServer(self.app.app, loop=loop), loop=loop)
loop.run_until_complete(client.start_server())

@asyncio.coroutine
def _async_request():
nonlocal client
content_type = kwargs.get("content_type")
if content_type:
headers = kwargs.setdefault("headers", {})
if "Content-Type" not in headers:
headers["Content-Type"] = content_type
kwargs["params"] = kwargs.get("query_string")
res = yield from client.request(method.upper(), url, **kwargs)
body = yield from res.read()
print(res.content_type)
if is_json_mimetype(res.content_type):
body = body + b"\n"
print("test_client", body)
return ConnexionResponse(
status_code=res.status,
headers=res.headers,
body=body
)

response = loop.run_until_complete(_async_request())
loop.run_until_complete(client.close())
return response

0 comments on commit ef7c9ca

Please sign in to comment.