Skip to content

Commit

Permalink
refactor: change TestClient from function to class
Browse files Browse the repository at this point in the history
  • Loading branch information
waketzheng committed Aug 22, 2024
1 parent 4f244c6 commit da1c668
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 79 deletions.
22 changes: 8 additions & 14 deletions examples/normal/test_main.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
import contextlib

import pytest
from httpx import ASGITransport, AsyncClient
from httpx import AsyncClient
from main import app

from fastapi_cdn_host.utils import TestClient, TestClientType


@pytest.fixture(scope="session")
def anyio_backend():
def anyio_backend() -> str:
return "asyncio"


@contextlib.asynccontextmanager
async def TestClient(app, base_url="http://test", **kw):
async with AsyncClient(transport=ASGITransport(app), base_url=base_url, **kw) as c:
yield c


@pytest.fixture(scope="session")
async def client():
async def client() -> TestClientType:
async with TestClient(app) as c:
yield c


@pytest.mark.anyio
async def test_home(client: AsyncClient):
async def test_home(client: AsyncClient) -> None:
response = await client.get("/")
response2 = await client.get("")
assert response.status_code == 307
Expand All @@ -37,14 +31,14 @@ async def test_home(client: AsyncClient):


@pytest.mark.anyio
async def test_get(client: AsyncClient):
async def test_get(client: AsyncClient) -> None:
response = await client.get("/app")
assert response.status_code == 200
assert isinstance(response.json()["routes"], str)


@pytest.mark.anyio
async def test_post(client: AsyncClient):
async def test_post(client: AsyncClient) -> None:
response = await client.post("/test", json={"a": 1})
assert response.status_code == 200
data = response.json()
Expand Down
29 changes: 7 additions & 22 deletions examples/offline/test_main.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,23 @@
import sys
from typing import AsyncGenerator

import pytest
from httpx import AsyncClient
from main import app

if sys.version_info >= (3, 10):
from asynctor.utils import client_manager
else:
from contextlib import asynccontextmanager

from httpx import ASGITransport

@asynccontextmanager
async def client_manager(app, base_url="http://test", **kw):
kw.pop("mount_lifespan", None)
kw.update(transport=ASGITransport(app), base_url=base_url)
async with AsyncClient(**kw) as c:
yield c
from fastapi_cdn_host.utils import TestClient, TestClientType


@pytest.fixture(scope="session")
def anyio_backend():
def anyio_backend() -> str:
return "asyncio"


@pytest.fixture(scope="session")
async def client() -> AsyncGenerator[AsyncClient, None]:
async with client_manager(app, mount_lifespan=False) as c:
async def client() -> TestClientType:
async with TestClient(app) as c:
yield c


@pytest.mark.anyio
async def test_home(client: AsyncClient):
async def test_home(client: AsyncClient) -> None:
response = await client.get("/")
response2 = await client.get("")
assert response.status_code == 307
Expand All @@ -46,14 +31,14 @@ async def test_home(client: AsyncClient):


@pytest.mark.anyio
async def test_get(client: AsyncClient):
async def test_get(client: AsyncClient) -> None:
response = await client.get("/app")
assert response.status_code == 200
assert isinstance(response.json()["routes"], str)


@pytest.mark.anyio
async def test_post(client: AsyncClient):
async def test_post(client: AsyncClient) -> None:
response = await client.post("/test", json={"a": "1", "b": 2})
assert response.status_code == 200
data = response.json()
Expand Down
69 changes: 44 additions & 25 deletions fastapi_cdn_host/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime
from pathlib import Path
from typing import AsyncGenerator, Generator, Union
from typing import AsyncGenerator, Generator, Optional, Union

import anyio
import asyncer
import typer
from rich import print
from rich.progress import Progress, SpinnerColumn
Expand All @@ -22,15 +23,17 @@ def run_shell(cmd: str) -> None:
print(f"--> {cmd}")
command = shlex.split(cmd)
cmd_env = None
index = 0
for i, c in enumerate(command):
if "=" not in c:
index = i
break
name, value = c.split("=")
if cmd_env is None:
cmd_env = os.environ.copy()
cmd_env[name] = value
if i != 0:
command = command[i:]
if cmd_env is not None:
command = command[index:]
subprocess.run(command, env=cmd_env) # nosec:B603


Expand Down Expand Up @@ -73,21 +76,19 @@ def patch_app(path: Union[str, Path], remove=True) -> Generator[Path, None, None


@asynccontextmanager
async def percentbar(
msg: str, seconds=5, color: str = "cyan", transient=False
) -> AsyncGenerator[None, None]:
async def percentbar(msg: str, **kwargs) -> AsyncGenerator[None, None]:
"""Progressbar with custom font color
:param msg: prompt message.
:param seconds: max seconds of tasks.
:param color: font color,e.g.: 'blue'.
:param transient: whether clean progressbar after finished.
"""
seconds = kwargs.pop("seconds", 5)
color = kwargs.pop("color", "")
total = seconds * 100

async def play(progress, task, expected=1 / 2, thod=0.8) -> None:
async def play(progress, task) -> None:
expected, threshold = 1 / 2, 0.8
cost = seconds * expected
quick = int(total * thod)
quick = int(total * threshold)
delay = cost / quick
for i in range(quick):
await anyio.sleep(delay)
Expand All @@ -99,25 +100,34 @@ async def play(progress, task, expected=1 / 2, thod=0.8) -> None:
await anyio.sleep(delay)
progress.advance(task)

with Progress(transient=transient) as progress:
task = progress.add_task(f"[{color}]{msg}:", total=total)
async with anyio.create_task_group() as tg:
tg.start_soon(play, progress, task)
with Progress(**kwargs) as p:
if color:
t = p.add_task(f"[{color}]{msg}:", total=total)
else:
t = p.add_task(f"{msg}:", total=total)
async with asyncer.create_task_group() as tg:
tg.soonify(play)(p, t)
yield
tg.cancel_scope.cancel()
progress.update(task, completed=total)
p.update(t, completed=total)


@contextmanager
def spinnerbar(msg, color="yellow", transient=True) -> Generator[None, None, None]:
def spinnerbar(
msg, color: Optional[str] = None, **kwargs
) -> Generator[None, None, None]:
kwargs.setdefault("transient", True)
with Progress(
SpinnerColumn(), *Progress.get_default_columns(), transient=transient
SpinnerColumn(), *Progress.get_default_columns(), **kwargs
) as progress:
progress.add_task(f"[{color}]{msg}...", total=None)
if color:
progress.add_task(f"[{color}]{msg}...", total=None)
else:
progress.add_task(f"{msg}...", total=None)
yield


async def download_offline_assets(dirname="static") -> None:
async def download_offline_assets(dirname: str) -> None:
cwd = await anyio.Path.cwd()
static_root = cwd / dirname
if not await static_root.exists():
Expand All @@ -131,7 +141,7 @@ async def download_offline_assets(dirname="static") -> None:
async with percentbar("Comparing cdn hosts response speed"):
urls = await CdnHostBuilder.sniff_the_fastest()
print("Result:", urls)
with spinnerbar("Fetching files from cdn"):
with spinnerbar("Fetching files from cdn", color="yellow"):
url_list = [urls.js, urls.css, urls.redoc]
contents = await HttpSniff.bulk_fetch(url_list, get_content=True)
for url, content in zip(url_list, contents):
Expand All @@ -149,13 +159,23 @@ def dev(
path: Annotated[
Path,
typer.Argument(
help="A path to a Python file or package directory (with [blue]__init__.py[/blue] file) containing a [bold]FastAPI[/bold] app. If not provided, a default set of paths will be tried."
help=(
"A path to a Python file or package directory"
" (with [blue]__init__.py[/blue] file)"
" containing a [bold]FastAPI[/bold] app."
" If not provided, a default set of paths will be tried."
)
),
],
port: Annotated[
Union[int, None],
typer.Option(
help="The port to serve on. You would normally have a termination proxy on top (another program) handling HTTPS on port [blue]443[/blue] and HTTP on port [blue]80[/blue], transferring the communication to your app."
help=(
"The port to serve on."
" You would normally have a termination proxy on top (another program)"
" handling HTTPS on port [blue]443[/blue] and HTTP on port [blue]80[/blue],"
" transferring the communication to your app."
)
),
] = None,
remove: Annotated[
Expand All @@ -170,8 +190,7 @@ def dev(
] = False,
):
if str(path) == "offline":
# TODO: download assets to local
anyio.run(download_offline_assets)
asyncer.runnify(download_offline_assets)(dirname="static")
return
with patch_app(path, remove) as file:
mode = "run" if prod else "dev"
Expand Down
41 changes: 24 additions & 17 deletions fastapi_cdn_host/utils.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,36 @@
import calendar
import sys
from contextlib import asynccontextmanager
from datetime import datetime
from typing import AsyncGenerator, Optional
from typing import AsyncGenerator, Optional, TypeAlias

from fastapi import HTTPException, Request, status
from fastapi import FastAPI, HTTPException, Request, status
from httpx import ASGITransport, AsyncClient

TestClientType: TypeAlias = AsyncGenerator[AsyncClient, None]

@asynccontextmanager
async def TestClient(
app, base_url="http://test", **kw
) -> AsyncGenerator[AsyncClient, None]:

class TestClient(AsyncClient):
"""Async test client
Usage::
>>> from main import app
>>> @pytest.fixture(scope='session')
... async def client() -> AsyncClient:
... async with TestClient(app) as c:
... yield c
Example::
```py
from main import app
assert isinstance(app, FastAPI)
@pytest.fixture(scope='session')
async def client() -> TestClientType:
async with TestClient(app) as c:
yield c
```
"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url=base_url, **kw) as c:
yield c

__test__ = False

def __init__(self, app: FastAPI, base_url="http://test", **kw) -> None:
transport = ASGITransport(app=app) # type:ignore[arg-type]
super().__init__(transport=transport, base_url=base_url, **kw)


class ParamLock:
Expand Down Expand Up @@ -95,7 +102,7 @@ def weekday_lock(request: Request, name="day", exclude_localhost=True) -> None:
Usage::
>>> import fastapi_cdn_host
>>> from fastapi import FastAPI
>>> app = FastAPI(openapi_url='/v1/api.json')
>>> app = FastAPI(openapi_url="/v1/api.json")
>>> fastapi_cdn_host.patch_docs(app, lock=fastapi_cdn_host.weekday_lock)
"""
WeekdayLock(name, exclude_localhost)(request)
Expand Down
17 changes: 16 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ python = "^3.8"
fastapi = {version = ">=0.100"}
httpx = {version = ">=0.23"}
fastapi-cli = {version = "^0.0.5", optional = true}
asyncer = "^0.0.7"

[tool.poetry.extras]
all = ["fastapi-cli"]
Expand Down

0 comments on commit da1c668

Please sign in to comment.