Skip to content

Commit

Permalink
refactor: refactor code by tools
Browse files Browse the repository at this point in the history
  • Loading branch information
andiserg committed Feb 6, 2024
1 parent 1866f30 commit d1e0a98
Show file tree
Hide file tree
Showing 26 changed files with 204 additions and 69 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/pre-commit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
name: pre-commit

on: [push, pull_request]

jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v5
- uses: pre-commit/[email protected]
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Tests
name: tests

on: [push, pull_request]

Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ __pypackages__/
dmypy.json

.idea
/pytest.ini
/pytest.ini
42 changes: 42 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-added-large-files
- id: check-ast
- id: check-byte-order-marker
- id: check-case-conflict
- id: check-docstring-first
- id: check-executables-have-shebangs
- id: check-json
- id: check-merge-conflict
- id: check-toml
- id: check-yaml
- id: debug-statements
- id: detect-private-key
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.0
hooks:
- id: ruff
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
args: [ "--profile", "black", "--filter-files" ]
- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
- id: flake8
files: src
exclude: "migrations/"
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.8.0'
hooks:
- id: mypy
additional_dependencies: []
- repo: https://github.com/PyCQA/bandit
rev: 1.7.7
hooks:
- id: bandit
18 changes: 13 additions & 5 deletions src/costy/adapters/db/category_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,27 @@ class CategoryGateway(
def __init__(self, session: AsyncSession):
self.session = session

async def get_category(self, category_id: CategoryId) -> Category:
query = select(Category).where(Category.id == category_id)
return await self.session.scalar(query)
async def get_category(self, category_id: CategoryId) -> Category | None:
query = select(Category).where(
Category.id == category_id # type: ignore
)
result: Category | None = await self.session.scalar(query)
return result

async def save_category(self, category: Category) -> None:
self.session.add(category)
await self.session.flush(objects=[category])

async def delete_category(self, category_id: CategoryId) -> None:
query = delete(Category).where(Category.id == category_id)
query = delete(Category).where(
Category.id == category_id # type: ignore
)
await self.session.execute(query)

async def find_categories(self, user_id: UserId) -> list[Category]:
filter_expr = or_(Category.user_id == user_id, Category.user_id == None)
filter_expr = or_(
Category.user_id == user_id, # type: ignore
Category.user_id == None # type: ignore # noqa: E711
)
query = select(Category).where(filter_expr)
return list(await self.session.scalars(query))
22 changes: 16 additions & 6 deletions src/costy/adapters/db/operation_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,34 @@ class OperationGateway(
def __init__(self, session: AsyncSession):
self.session = session

async def get_operation(self, operation_id: OperationId) -> Operation:
query = select(Operation).where(Operation.id == operation_id)
return await self.session.scalar(query)
async def get_operation(
self, operation_id: OperationId
) -> Operation | None:
query = select(Operation).where(
Operation.id == operation_id # type: ignore
)
result: Operation | None = await self.session.scalar(query)
return result

async def save_operation(self, operation: Operation) -> None:
self.session.add(operation)
await self.session.flush(objects=[operation])

async def delete_operation(self, operation_id: OperationId) -> None:
query = delete(Operation).where(Operation.id == operation_id)
query = delete(Operation).where(
Operation.id == operation_id # type: ignore
)
await self.session.execute(query)

async def find_operations_by_user(
self, user_id: UserId, from_time: int, to_time: int
) -> list[Operation]:
query = (
select(Operation)
.where(Operation.user_id == user_id)
.where(Operation.time >= from_time, Operation.time <= to_time)
.where(Operation.user_id == user_id) # type: ignore
.where(
Operation.time >= from_time, # type: ignore
Operation.time <= to_time # type: ignore
)
)
return list(await self.session.scalars(query))
10 changes: 6 additions & 4 deletions src/costy/adapters/db/user_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ async def save_user(self, user: User) -> None:
await self.session.flush(objects=[user])

async def get_user_by_id(self, user_id: UserId) -> User | None:
query = select(User).where(User.id == user_id)
return await self.session.scalar(query)
query = select(User).where(User.id == user_id) # type: ignore
result: User | None = await self.session.scalar(query)
return result

async def get_user_by_email(self, email: str) -> User | None:
query = select(User).where(User.email == email)
return await self.session.scalar(query)
query = select(User).where(User.email == email) # type: ignore
result: User | None = await self.session.scalar(query)
return result
6 changes: 3 additions & 3 deletions src/costy/application/authenticate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ class LoginInputDTO:
password: str


class Authenticate(Interactor[LoginInputDTO, UserId]):
class Authenticate(Interactor[LoginInputDTO, UserId | None]):
def __init__(self, user_gateway: UserReader, uow: UoW):
self.user_gateway = user_gateway
self.uow = uow

async def __call__(self, data: LoginInputDTO) -> UserId:
async def __call__(self, data: LoginInputDTO) -> UserId | None:
user = await self.user_gateway.get_user_by_email(data.email)
# TODO: compare hashed passwords
return user.id
return user.id if user else None
2 changes: 1 addition & 1 deletion src/costy/application/category/create_category.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ async def __call__(self, data: NewCategoryDTO) -> CategoryId:
await self.category_db_gateway.save_category(category)
category_id = category.id
await self.uow.commit()
return category_id
return category_id # type: ignore
2 changes: 1 addition & 1 deletion src/costy/application/common/category_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async def save_category(self, category: Category) -> None:

class CategoryReader(Protocol):
@abstractmethod
async def get_category(self, category_id: CategoryId) -> Category:
async def get_category(self, category_id: CategoryId) -> Category | None:
raise NotImplementedError


Expand Down
2 changes: 1 addition & 1 deletion src/costy/application/common/interactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@


class Interactor(Generic[InputDTO, OutputDTO]):
def __call__(self, data: InputDTO) -> OutputDTO:
async def __call__(self, data: InputDTO) -> OutputDTO:
raise NotImplementedError
4 changes: 3 additions & 1 deletion src/costy/application/common/operation_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

class OperationReader(Protocol):
@abstractmethod
async def get_operation(self, operation_id: OperationId) -> Operation:
async def get_operation(
self, operation_id: OperationId
) -> Operation | None:
raise NotImplementedError


Expand Down
2 changes: 1 addition & 1 deletion src/costy/application/operation/create_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ async def __call__(self, data: NewOperationDTO) -> OperationId:
await self.operation_db_gateway.save_operation(operation)
operation_id = operation.id
await self.uow.commit()
return operation_id
return operation_id # type: ignore
6 changes: 3 additions & 3 deletions src/costy/application/operation/read_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ..common.uow import UoW


class ReadOperation(Interactor[OperationId, Operation]):
class ReadOperation(Interactor[OperationId, Operation | None]):
def __init__(
self,
operation_service: OperationService,
Expand All @@ -20,13 +20,13 @@ def __init__(
self.id_provider = id_provider
self.uow = uow

async def __call__(self, operation_id: OperationId):
async def __call__(self, operation_id: OperationId) -> Operation | None:
user_id = await self.id_provider.get_current_user_id()

operation = await self.operation_db_gateway.get_operation(operation_id)

# TODO: Move to access service
if operation.user_id != user_id:
if operation and operation.user_id != user_id:
raise Exception("User must be the owner of the operation")

return operation
2 changes: 1 addition & 1 deletion src/costy/application/user/create_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ async def __call__(self, data: NewUserDTO) -> UserId:
await self.user_db_gateway.save_user(user)
user_id = user.id
await self.uow.commit()
return user_id
return user_id # type: ignore
2 changes: 1 addition & 1 deletion src/costy/domain/models/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from costy.domain.models.category import CategoryId
from costy.domain.models.user import UserId

OperationId = NewType("OperationID", int)
OperationId = NewType("OperationId", int)


@dataclass(kw_only=True)
Expand Down
7 changes: 6 additions & 1 deletion src/costy/infrastructure/db/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import registry


Expand Down
2 changes: 1 addition & 1 deletion src/costy/infrastructure/db/migrations/README
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Generic single-database configuration.
Generic single-database configuration.
4 changes: 1 addition & 3 deletions src/costy/infrastructure/db/migrations/env.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool

from alembic import context
from sqlalchemy import engine_from_config, pool

from costy.infrastructure.config import get_db_connection_url
from costy.infrastructure.db.main import get_registry
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
"""init tables
Revision ID: f1c4a04700d3
Revises:
Revises:
Create Date: 2024-01-30 22:55:38.115617
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = 'f1c4a04700d3'
Expand Down
16 changes: 13 additions & 3 deletions src/costy/infrastructure/db/orm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
from typing import Type

from sqlalchemy import Column, ForeignKey, Integer, String, Table
Expand All @@ -7,8 +8,10 @@
from costy.domain.models.operation import Operation
from costy.domain.models.user import User

Model = typing.Union[Category, Operation, User]

def create_tables(mapper_registry: registry) -> dict[Type, Table]:

def create_tables(mapper_registry: registry) -> dict[Type[Model], Table]:
return {
User: Table(
"users",
Expand All @@ -25,7 +28,12 @@ def create_tables(mapper_registry: registry) -> dict[Type, Table]:
Column("description", String),
Column("time", Integer, nullable=False),
Column("user_id", Integer, ForeignKey("users.id")),
Column("category_id", Integer, ForeignKey("categories.id"), nullable=True),
Column(
"category_id",
Integer,
ForeignKey("categories.id"),
nullable=True
),
),
Category: Table(
"categories",
Expand All @@ -38,6 +46,8 @@ def create_tables(mapper_registry: registry) -> dict[Type, Table]:
}


def map_tables_to_models(mapper_registry: registry, tables: dict[Type, Table]):
def map_tables_to_models(
mapper_registry: registry, tables: dict[Type[Model], Table]
) -> None:
for model, table in tables.items():
mapper_registry.map_imperatively(model, table)
Loading

0 comments on commit d1e0a98

Please sign in to comment.