Skip to content

Commit

Permalink
Added unit testing with in-memory database and TestClient.
Browse files Browse the repository at this point in the history
  • Loading branch information
BenMillar-MOJ committed Jul 18, 2024
1 parent db163c7 commit bfb4148
Show file tree
Hide file tree
Showing 16 changed files with 93 additions and 74 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ alembic upgrade head
To run the uvicorn server, use the below code:

```shell
uvicorn app.main:case_api --reload
uvicorn app:case_api --reload
```

# Migrations
Expand Down
2 changes: 1 addition & 1 deletion alembic.ini
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
# are written from script.py.mako
# output_encoding = utf-8

sqlalchemy.url = sqlite:///database.db
sqlalchemy.url = None


[post_write_hooks]
Expand Down
3 changes: 3 additions & 0 deletions app/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from app.main import create_app

case_api = create_app()
3 changes: 3 additions & 0 deletions app/db/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from app.db.database import engine

db_engine = engine
8 changes: 6 additions & 2 deletions app/db/database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# One line of FastAPI imports here later 👈
from sqlmodel import create_engine
from sqlmodel import create_engine, Session
import os

cwd = os.getcwd()
Expand All @@ -9,3 +8,8 @@

connect_args = {"check_same_thread": False}
engine = create_engine(sqlite_url, echo=True, connect_args=connect_args)


def get_session():
with Session(engine) as session:
yield session
8 changes: 3 additions & 5 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
from .routers import case_information
import os

cwd = os.getcwd()
os.chdir("./app")


def create_app():
app = FastAPI()
app.include_router(case_information.router)
return app


case_api = create_app()
if __name__ == "__main__":
cwd = os.getcwd()
os.chdir("./app")
5 changes: 4 additions & 1 deletion app/migrations/env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from logging.config import fileConfig

from app.models.cases import Case # noqa: F401
# This imports all the models
from app.models import * # noqa: F401

from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlmodel import SQLModel
Expand All @@ -10,6 +12,7 @@
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
config.set_main_option('sqlalchemy.url', 'sqlite:///database.db')

# Interpret the config file for Python logging.
# This line sets up loggers basically.
Expand Down
39 changes: 0 additions & 39 deletions app/migrations/versions/2c87ae16567d_init.py

This file was deleted.

2 changes: 1 addition & 1 deletion app/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@

from .cases import Case
7 changes: 3 additions & 4 deletions app/models/cases.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# One line of FastAPI imports here later 👈
from sqlmodel import Field, SQLModel
from datetime import datetime
from .categories import Categories
Expand All @@ -10,13 +9,13 @@ class BaseCase(SQLModel):


class Case(BaseCase, table=True):
id: int | None = Field(default=None, primary_key=True)
id: int = Field(default=None, primary_key=True)
time: datetime


class CaseCreate(BaseCase):
class CaseRequest(BaseCase):
pass


class CaseRead(SQLModel):
class CaseLookup(SQLModel):
id: int
37 changes: 17 additions & 20 deletions app/routers/case_information.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from fastapi import APIRouter, HTTPException
from ..models.cases import CaseCreate, Case
from fastapi import APIRouter, HTTPException, Depends
from ..models.cases import CaseRequest, Case
from datetime import datetime
from sqlmodel import Session, select
from app.db.database import engine
from app.db.database import get_session
import random

router = APIRouter(
Expand All @@ -13,30 +13,27 @@


@router.get("/{case_id}", tags=["cases"])
async def read_case(case_id: str):
with Session(engine) as session:
case = session.get(Case, case_id)
if not case:
raise HTTPException(status_code=404, detail="Case not found")
return case
async def read_case(case_id: str, session: Session = Depends(get_session),):
case = session.get(Case, case_id)
if not case:
raise HTTPException(status_code=404, detail="Case not found")
return case


@router.get("/", tags=["cases"])
async def read_all_cases():
with Session(engine) as session:
cases = session.exec(select(Case)).all()
return cases
async def read_all_cases(session: Session = Depends(get_session)):
cases = session.exec(select(Case)).all()
return cases


def generate_id():
return str(random.randint(1, 100000))


@router.post("/", tags=["cases"], response_model=Case)
def create_case(request: CaseCreate):
with Session(engine) as session:
case = Case(category=request.category, time=datetime.now(), name=request.name, id=generate_id())
session.add(case)
session.commit()
session.refresh(case)
return case
def create_case(request: CaseRequest, session: Session = Depends(get_session)):
case = Case(category=request.category, time=datetime.now(), name=request.name, id=generate_id())
session.add(case)
session.commit()
session.refresh(case)
return case
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[tool.pytest.ini_options]
filterwarnings = [
"ignore",
"default:::app",
]
Empty file added tests/__init__.py
Empty file.
Empty file added tests/cases/__init__.py
Empty file.
19 changes: 19 additions & 0 deletions tests/cases/test_cases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from fastapi.testclient import TestClient
from sqlmodel import Session


def test_create_case(client: TestClient, session: Session):
response = client.post(
"/cases/", json={"category": "Housing", "name": "John Doe"})
case = response.json()

assert response.status_code == 200
assert case["category"] == "Housing"
assert case["name"] == "John Doe"
assert case["id"] is not None


def test_read_case(client: TestClient, session: Session):
response = client.post(
"/cases/", json={"category": "Housing", "name": "John Doe"})
case = response.json()
27 changes: 27 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
from sqlmodel import SQLModel, create_engine, Session, StaticPool
from app import case_api
from app.db.database import get_session
from fastapi.testclient import TestClient


@pytest.fixture(name="session")
def session_fixture():
engine = create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
)
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session


@pytest.fixture(name="client")
def client_fixture(session: Session):
def get_session_override():
return session

case_api.dependency_overrides[get_session] = get_session_override

client = TestClient(case_api)
yield client
case_api.dependency_overrides.clear()

0 comments on commit bfb4148

Please sign in to comment.