From 15b0d4f7b9b6fa92edf0f8c40606231838c85fd3 Mon Sep 17 00:00:00 2001 From: Pedro Crespo-Valero <32402063+pcrespov@users.noreply.github.com> Date: Thu, 14 Nov 2024 16:55:15 +0100 Subject: [PATCH] base rest --- .../src/models_library/rest_base.py | 19 ++++++++++++++++ .../src/models_library/rest_ordering.py | 22 ++++++++----------- .../src/models_library/rest_pagination.py | 3 ++- .../servicelib/aiohttp/requests_validation.py | 13 +---------- .../tests/aiohttp/test_requests_validation.py | 21 ++++++++++-------- .../folders/_models.py | 6 ++--- .../products/_handlers.py | 11 ++++------ .../simcore_service_webserver/tags/schemas.py | 6 ++--- .../wallets/_handlers.py | 8 +++---- .../workspaces/_models.py | 6 ++--- 10 files changed, 59 insertions(+), 56 deletions(-) create mode 100644 packages/models-library/src/models_library/rest_base.py diff --git a/packages/models-library/src/models_library/rest_base.py b/packages/models-library/src/models_library/rest_base.py new file mode 100644 index 000000000000..a6b24ef63824 --- /dev/null +++ b/packages/models-library/src/models_library/rest_base.py @@ -0,0 +1,19 @@ +from pydantic import BaseModel, Extra + + +class RequestParameters(BaseModel): + """ + Base model for any type of request parameters, + i.e. context, path, query, headers + """ + + def as_params(self, **export_options) -> dict[str, str]: + data = self.dict(**export_options) + return {k: f"{v}" for k, v in data.items()} + + +class StrictRequestParameters(RequestParameters): + """Use a base class for context, path and query parameters""" + + class Config: + extra = Extra.forbid # strict diff --git a/packages/models-library/src/models_library/rest_ordering.py b/packages/models-library/src/models_library/rest_ordering.py index ab71b84f986b..d0358059e614 100644 --- a/packages/models-library/src/models_library/rest_ordering.py +++ b/packages/models-library/src/models_library/rest_ordering.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, Json, validator from .basic_types import IDStr +from .rest_base import RequestParameters class OrderDirection(str, Enum): @@ -30,7 +31,7 @@ class Config: } -class BaseOrderByQueryParams(BaseModel): +class _BaseOrderByQueryParams(RequestParameters): order_by: OrderBy | None = None @@ -39,7 +40,7 @@ def create_order_by_query_model_classes( sortable_fields: set[str], default_order_by: OrderBy, override_direction_default: bool = False, -) -> tuple[type[BaseOrderByQueryParams], type[BaseModel]]: +) -> tuple[type[_BaseOrderByQueryParams], type[BaseModel]]: """ Factory to create an uniform model used as ordering parameters in a query @@ -52,7 +53,7 @@ def create_order_by_query_model_classes( msg_direction_options = "|".join(sorted(OrderDirection)) order_by_example: dict[str, Any] = OrderBy.Config.schema_extra["example"] - class _JsonOrderBy(OrderBy): + class _OrderByJsonable(OrderBy): direction: OrderDirection = Field( default=default_order_by.direction if override_direction_default @@ -87,9 +88,9 @@ def _check_if_sortable_field(cls, v): f"The default sorting order is '{default_order_by.direction.value}' on '{default_order_by.field}'." ) - class _RequestValidatorModel(BaseOrderByQueryParams): + class _RequestValidatorModel(_BaseOrderByQueryParams): # Used in rest handler for verification - order_by: _JsonOrderBy = Field( + order_by: _OrderByJsonable = Field( default=default_order_by, description=description, ) @@ -109,14 +110,9 @@ class _OpenapiModel(BaseModel): description=description, ) - @validator("order_by", allow_reuse=True) - @classmethod - def _validate_json_content(cls, v): - if v: - _RequestValidatorModel(order_by=v) - return v - class Config: - schema_extra: ClassVar[dict[str, Any]] = {"title": "Order By Parameters"} + schema_extra: ClassVar[dict[str, Any]] = { + "title": "Order By Parameters", + } return _RequestValidatorModel, _OpenapiModel diff --git a/packages/models-library/src/models_library/rest_pagination.py b/packages/models-library/src/models_library/rest_pagination.py index 89c90cb1c2d3..0213fb4f8a53 100644 --- a/packages/models-library/src/models_library/rest_pagination.py +++ b/packages/models-library/src/models_library/rest_pagination.py @@ -13,6 +13,7 @@ ) from pydantic.generics import GenericModel +from .rest_base import RequestParameters from .utils.common_validators import none_to_empty_list_pre_validator # Default limit values @@ -29,7 +30,7 @@ class PageLimitInt(ConstrainedInt): DEFAULT_NUMBER_OF_ITEMS_PER_PAGE: Final[PageLimitInt] = parse_obj_as(PageLimitInt, 20) -class PageQueryParameters(BaseModel): +class PageQueryParameters(RequestParameters): """Use as pagination options in query parameters""" limit: PageLimitInt = Field( diff --git a/packages/service-library/src/servicelib/aiohttp/requests_validation.py b/packages/service-library/src/servicelib/aiohttp/requests_validation.py index 085243c5d26d..e5cef8ecd960 100644 --- a/packages/service-library/src/servicelib/aiohttp/requests_validation.py +++ b/packages/service-library/src/servicelib/aiohttp/requests_validation.py @@ -14,7 +14,7 @@ from aiohttp import web from models_library.utils.json_serialization import json_dumps -from pydantic import BaseModel, Extra, ValidationError, parse_obj_as +from pydantic import BaseModel, ValidationError, parse_obj_as from ..mimetype_constants import MIMETYPE_APPLICATION_JSON from . import status @@ -24,17 +24,6 @@ UnionOfModelTypes: TypeAlias = Union[type[ModelClass], type[ModelClass]] # noqa: UP007 -class RequestParams(BaseModel): - ... - - -class StrictRequestParams(BaseModel): - """Use a base class for context, path and query parameters""" - - class Config: - extra = Extra.forbid # strict - - @contextmanager def handle_validation_as_http_error( *, error_msg_template: str, resource_name: str, use_error_v1: bool diff --git a/packages/service-library/tests/aiohttp/test_requests_validation.py b/packages/service-library/tests/aiohttp/test_requests_validation.py index 31b423f51ad4..bef0a6e327ad 100644 --- a/packages/service-library/tests/aiohttp/test_requests_validation.py +++ b/packages/service-library/tests/aiohttp/test_requests_validation.py @@ -10,6 +10,7 @@ from aiohttp import web from aiohttp.test_utils import TestClient, make_mocked_request from faker import Faker +from models_library.rest_base import RequestParameters, StrictRequestParameters from models_library.rest_ordering import ( OrderBy, OrderDirection, @@ -19,8 +20,6 @@ from pydantic import BaseModel, Field from servicelib.aiohttp import status from servicelib.aiohttp.requests_validation import ( - RequestParams, - StrictRequestParams, parse_request_body_as, parse_request_headers_as, parse_request_path_parameters_as, @@ -38,7 +37,7 @@ def jsonable_encoder(data): return json.loads(json_dumps(data)) -class MyRequestContext(RequestParams): +class MyRequestContext(RequestParameters): user_id: int = Field(alias=RQT_USERID_KEY) secret: str = Field(alias=APP_SECRET_KEY) @@ -47,7 +46,7 @@ def create_fake(cls, faker: Faker): return cls(user_id=faker.pyint(), secret=faker.password()) -class MyRequestPathParams(StrictRequestParams): +class MyRequestPathParams(StrictRequestParameters): project_uuid: UUID @classmethod @@ -55,7 +54,7 @@ def create_fake(cls, faker: Faker): return cls(project_uuid=faker.uuid4()) -class MyRequestQueryParams(RequestParams): +class MyRequestQueryParams(RequestParameters): is_ok: bool = True label: str @@ -64,7 +63,7 @@ def create_fake(cls, faker: Faker): return cls(is_ok=faker.pybool(), label=faker.word()) -class MyRequestHeadersParams(RequestParams): +class MyRequestHeadersParams(RequestParameters): user_agent: str = Field(alias="X-Simcore-User-Agent") optional_header: str | None = Field(default=None, alias="X-Simcore-Optional-Header") @@ -364,7 +363,7 @@ async def test_parse_request_with_invalid_headers_params( def test_parse_request_query_parameters_as_with_order_by_query_models(): - OrderByModel, _ = create_order_by_query_model_classes( + OrderByModel, OrderByModelOAS = create_order_by_query_model_classes( sortable_fields={"modified", "name"}, default_order_by=OrderBy(field="name") ) @@ -377,5 +376,9 @@ def test_parse_request_query_parameters_as_with_order_by_query_models(): query_params = parse_request_query_parameters_as(OrderByModel, request) assert query_params.order_by == expected - assert OrderByModel.schema()["properties"]["order_by"]["type"] == "string" - assert OrderByModel.schema()["properties"]["order_by"]["format"] == "json-string" + expected_schema = {"type": "string", "format": "json-string"} + assert { + k: v + for k, v in OrderByModel.schema()["properties"]["order_by"] + if k in expected + } == expected_schema diff --git a/services/web/server/src/simcore_service_webserver/folders/_models.py b/services/web/server/src/simcore_service_webserver/folders/_models.py index 532010a1f063..ffce2d72f9b3 100644 --- a/services/web/server/src/simcore_service_webserver/folders/_models.py +++ b/services/web/server/src/simcore_service_webserver/folders/_models.py @@ -2,6 +2,7 @@ from models_library.basic_types import IDStr from models_library.folders import FolderID +from models_library.rest_base import RequestParameters, StrictRequestParameters from models_library.rest_filters import Filters, FiltersQueryParameters from models_library.rest_ordering import OrderBy, OrderDirection from models_library.rest_pagination import PageQueryParameters @@ -13,7 +14,6 @@ ) from models_library.workspaces import WorkspaceID from pydantic import BaseModel, Extra, Field, Json, validator -from servicelib.aiohttp.requests_validation import RequestParams, StrictRequestParams from servicelib.request_keys import RQT_USERID_KEY from .._constants import RQ_PRODUCT_KEY @@ -21,12 +21,12 @@ _logger = logging.getLogger(__name__) -class FoldersRequestContext(RequestParams): +class FoldersRequestContext(RequestParameters): user_id: UserID = Field(..., alias=RQT_USERID_KEY) # type: ignore[literal-required] product_name: str = Field(..., alias=RQ_PRODUCT_KEY) # type: ignore[literal-required] -class FoldersPathParams(StrictRequestParams): +class FoldersPathParams(StrictRequestParameters): folder_id: FolderID diff --git a/services/web/server/src/simcore_service_webserver/products/_handlers.py b/services/web/server/src/simcore_service_webserver/products/_handlers.py index bfdabef6d6f1..1d7e4e4bc573 100644 --- a/services/web/server/src/simcore_service_webserver/products/_handlers.py +++ b/services/web/server/src/simcore_service_webserver/products/_handlers.py @@ -4,13 +4,10 @@ from aiohttp import web from models_library.api_schemas_webserver.product import GetCreditPrice, GetProduct from models_library.basic_types import IDStr +from models_library.rest_base import RequestParameters, StrictRequestParameters from models_library.users import UserID from pydantic import Extra, Field -from servicelib.aiohttp.requests_validation import ( - RequestParams, - StrictRequestParams, - parse_request_path_parameters_as, -) +from servicelib.aiohttp.requests_validation import parse_request_path_parameters_as from servicelib.request_keys import RQT_USERID_KEY from simcore_service_webserver.utils_aiohttp import envelope_json_response @@ -27,7 +24,7 @@ _logger = logging.getLogger(__name__) -class _ProductsRequestContext(RequestParams): +class _ProductsRequestContext(RequestParameters): user_id: UserID = Field(..., alias=RQT_USERID_KEY) # type: ignore[literal-required] product_name: str = Field(..., alias=RQ_PRODUCT_KEY) # type: ignore[literal-required] @@ -49,7 +46,7 @@ async def _get_current_product_price(request: web.Request): return envelope_json_response(credit_price) -class _ProductsRequestParams(StrictRequestParams): +class _ProductsRequestParams(StrictRequestParameters): product_name: IDStr | Literal["current"] diff --git a/services/web/server/src/simcore_service_webserver/tags/schemas.py b/services/web/server/src/simcore_service_webserver/tags/schemas.py index 01663e0d3371..c9d4a9d90a1e 100644 --- a/services/web/server/src/simcore_service_webserver/tags/schemas.py +++ b/services/web/server/src/simcore_service_webserver/tags/schemas.py @@ -2,18 +2,18 @@ from datetime import datetime from models_library.api_schemas_webserver._base import InputSchema, OutputSchema +from models_library.rest_base import RequestParameters, StrictRequestParameters from models_library.users import GroupID, UserID from pydantic import ConstrainedStr, Field, PositiveInt -from servicelib.aiohttp.requests_validation import RequestParams, StrictRequestParams from servicelib.request_keys import RQT_USERID_KEY from simcore_postgres_database.utils_tags import TagDict -class TagRequestContext(RequestParams): +class TagRequestContext(RequestParameters): user_id: UserID = Field(..., alias=RQT_USERID_KEY) # type: ignore[literal-required] -class TagPathParams(StrictRequestParams): +class TagPathParams(StrictRequestParameters): tag_id: PositiveInt diff --git a/services/web/server/src/simcore_service_webserver/wallets/_handlers.py b/services/web/server/src/simcore_service_webserver/wallets/_handlers.py index dc6855f2c019..954ed6b263b1 100644 --- a/services/web/server/src/simcore_service_webserver/wallets/_handlers.py +++ b/services/web/server/src/simcore_service_webserver/wallets/_handlers.py @@ -9,12 +9,11 @@ WalletGetWithAvailableCredits, ) from models_library.error_codes import create_error_code +from models_library.rest_base import RequestParameters, StrictRequestParameters from models_library.users import UserID from models_library.wallets import WalletID from pydantic import Field from servicelib.aiohttp.requests_validation import ( - RequestParams, - StrictRequestParams, parse_request_body_as, parse_request_path_parameters_as, ) @@ -106,19 +105,18 @@ async def wrapper(request: web.Request) -> web.StreamResponse: return wrapper -# # wallets COLLECTION ------------------------- # routes = web.RouteTableDef() -class WalletsRequestContext(RequestParams): +class WalletsRequestContext(RequestParameters): user_id: UserID = Field(..., alias=RQT_USERID_KEY) # type: ignore[literal-required] product_name: str = Field(..., alias=RQ_PRODUCT_KEY) # type: ignore[literal-required] -class WalletsPathParams(StrictRequestParams): +class WalletsPathParams(StrictRequestParameters): wallet_id: WalletID diff --git a/services/web/server/src/simcore_service_webserver/workspaces/_models.py b/services/web/server/src/simcore_service_webserver/workspaces/_models.py index 1cd037e34bd6..5a80bc678f8d 100644 --- a/services/web/server/src/simcore_service_webserver/workspaces/_models.py +++ b/services/web/server/src/simcore_service_webserver/workspaces/_models.py @@ -1,6 +1,7 @@ import logging from models_library.basic_types import IDStr +from models_library.rest_base import RequestParameters, StrictRequestParameters from models_library.rest_filters import Filters, FiltersQueryParameters from models_library.rest_ordering import ( OrderBy, @@ -12,7 +13,6 @@ from models_library.users import GroupID, UserID from models_library.workspaces import WorkspaceID from pydantic import BaseModel, Extra, Field -from servicelib.aiohttp.requests_validation import RequestParams, StrictRequestParams from servicelib.request_keys import RQT_USERID_KEY from .._constants import RQ_PRODUCT_KEY @@ -20,12 +20,12 @@ _logger = logging.getLogger(__name__) -class WorkspacesRequestContext(RequestParams): +class WorkspacesRequestContext(RequestParameters): user_id: UserID = Field(..., alias=RQT_USERID_KEY) # type: ignore[literal-required] product_name: str = Field(..., alias=RQ_PRODUCT_KEY) # type: ignore[literal-required] -class WorkspacesPathParams(StrictRequestParams): +class WorkspacesPathParams(StrictRequestParameters): workspace_id: WorkspaceID