Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactoring to create dynamic request model #213

Merged
Merged
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

### Changed

* Refactor to remove hardcoded search request models. Request models are now dynamically created based on the enabled extensions.
([#213](https://github.com/stac-utils/stac-fastapi/pull/213))

### Removed

### Fixed
Expand Down
34 changes: 20 additions & 14 deletions stac_fastapi/api/stac_fastapi/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fastapi.openapi.utils import get_openapi
from pydantic import BaseModel
from stac_pydantic import Collection, Item, ItemCollection
from stac_pydantic.api import ConformanceClasses, LandingPage, Search
from stac_pydantic.api import ConformanceClasses, LandingPage
from stac_pydantic.api.collections import Collections
from stac_pydantic.version import STAC_VERSION
from starlette.responses import JSONResponse, Response
Expand All @@ -20,18 +20,17 @@
GeoJSONResponse,
ItemCollectionUri,
ItemUri,
SearchGetRequest,
_create_request_model,
create_request_model,
)
from stac_fastapi.api.openapi import update_openapi
from stac_fastapi.api.routes import create_async_endpoint, create_sync_endpoint

# TODO: make this module not depend on `stac_fastapi.extensions`
from stac_fastapi.extensions.core import FieldsExtension
from stac_fastapi.extensions.core import FieldsExtension, TokenPaginationExtension
from stac_fastapi.types.config import ApiSettings, Settings
from stac_fastapi.types.core import AsyncBaseCoreClient, BaseCoreClient
from stac_fastapi.types.extension import ApiExtension
from stac_fastapi.types.search import STACSearch
from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest


@attr.s
Expand Down Expand Up @@ -76,9 +75,13 @@ class StacApi:
api_version: str = attr.ib(default="0.1")
stac_version: str = attr.ib(default=STAC_VERSION)
description: str = attr.ib(default="stac-fastapi")
search_request_model: Type[Search] = attr.ib(default=STACSearch)
search_get_request: Type[SearchGetRequest] = attr.ib(default=SearchGetRequest)
item_collection_uri: Type[ItemCollectionUri] = attr.ib(default=ItemCollectionUri)
search_get_request_model: Type[BaseSearchGetRequest] = attr.ib(
default=BaseSearchGetRequest
)
search_post_request_model: Type[BaseSearchPostRequest] = attr.ib(
default=BaseSearchPostRequest
)
pagination_extension = attr.ib(default=TokenPaginationExtension)
response_class: Type[Response] = attr.ib(default=JSONResponse)
middlewares: List = attr.ib(default=attr.Factory(lambda: [BrotliMiddleware]))

Expand Down Expand Up @@ -176,7 +179,6 @@ def register_post_search(self):
Returns:
None
"""
search_request_model = _create_request_model(self.search_request_model)
fields_ext = self.get_extension(FieldsExtension)
self.router.add_api_route(
name="Search",
Expand All @@ -189,7 +191,7 @@ def register_post_search(self):
response_model_exclude_none=True,
methods=["POST"],
endpoint=self._create_endpoint(
self.client.post_search, search_request_model, GeoJSONResponse
self.client.post_search, self.search_post_request_model, GeoJSONResponse
),
)

Expand All @@ -211,7 +213,7 @@ def register_get_search(self):
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(
self.client.get_search, self.search_get_request, GeoJSONResponse
self.client.get_search, self.search_get_request_model, GeoJSONResponse
),
)

Expand Down Expand Up @@ -261,6 +263,12 @@ def register_get_item_collection(self):
Returns:
None
"""
get_pagination_model = self.get_extension(self.pagination_extension).GET
request_model = create_request_model(
"ItemCollectionURI",
base_model=ItemCollectionUri,
mixins=[get_pagination_model],
)
self.router.add_api_route(
name="Get ItemCollection",
path="/collections/{collection_id}/items",
Expand All @@ -272,9 +280,7 @@ def register_get_item_collection(self):
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(
self.client.item_collection,
self.item_collection_uri,
self.response_class,
self.client.item_collection, request_model, self.response_class
),
)

Expand Down
195 changes: 108 additions & 87 deletions stac_fastapi/api/stac_fastapi/api/models.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,100 @@
"""api request/response models."""

import abc
import importlib
from typing import Dict, Optional, Type, Union
from typing import Optional, Type, Union

import attr
from fastapi import Body, Path
from pydantic import BaseModel, create_model
from pydantic.fields import UndefinedType


def _create_request_model(model: Type[BaseModel]) -> Type[BaseModel]:
from stac_fastapi.types.extension import ApiExtension
from stac_fastapi.types.search import (
APIRequest,
BaseSearchGetRequest,
BaseSearchPostRequest,
)


def create_request_model(
model_name="SearchGetRequest",
base_model: Union[Type[BaseModel], APIRequest] = BaseSearchGetRequest,
extensions: Optional[ApiExtension] = None,
mixins: Optional[Union[BaseModel, APIRequest]] = None,
request_type: Optional[str] = "GET",
) -> Union[Type[BaseModel], APIRequest]:
"""Create a pydantic model for validating request bodies."""
fields = {}
for (k, v) in model.__fields__.items():
# TODO: Filter out fields based on which extensions are present
field_info = v.field_info
body = Body(
None
if isinstance(field_info.default, UndefinedType)
else field_info.default,
default_factory=field_info.default_factory,
alias=field_info.alias,
alias_priority=field_info.alias_priority,
title=field_info.title,
description=field_info.description,
const=field_info.const,
gt=field_info.gt,
ge=field_info.ge,
lt=field_info.lt,
le=field_info.le,
multiple_of=field_info.multiple_of,
min_items=field_info.min_items,
max_items=field_info.max_items,
min_length=field_info.min_length,
max_length=field_info.max_length,
regex=field_info.regex,
extra=field_info.extra,
)
fields[k] = (v.outer_type_, body)
return create_model(model.__name__, **fields, __base__=model)


@attr.s # type:ignore
class APIRequest(abc.ABC):
"""Generic API Request base class."""

@abc.abstractmethod
def kwargs(self) -> Dict:
"""Transform api request params into format which matches the signature of the endpoint."""
...
extension_models = []

# Check extensions for additional parameters to search
for extension in extensions or []:
if extension_model := extension.get_request_model(request_type):
extension_models.append(extension_model)

mixins = mixins or []

models = [base_model] + extension_models + mixins

# Handle GET requests
if all([issubclass(m, APIRequest) for m in models]):
return attr.make_class(model_name, attrs={}, bases=tuple(models))

# Handle POST requests
elif all([issubclass(m, BaseModel) for m in models]):
for model in models:
for (k, v) in model.__fields__.items():
field_info = v.field_info
body = Body(
None
if isinstance(field_info.default, UndefinedType)
else field_info.default,
default_factory=field_info.default_factory,
alias=field_info.alias,
alias_priority=field_info.alias_priority,
title=field_info.title,
description=field_info.description,
const=field_info.const,
gt=field_info.gt,
ge=field_info.ge,
lt=field_info.lt,
le=field_info.le,
multiple_of=field_info.multiple_of,
min_items=field_info.min_items,
max_items=field_info.max_items,
min_length=field_info.min_length,
max_length=field_info.max_length,
regex=field_info.regex,
extra=field_info.extra,
)
fields[k] = (v.outer_type_, body)
return create_model(model_name, **fields, __base__=base_model)

raise TypeError("Mixed Request Model types. Check extension request types.")


def create_get_request_model(
extensions, base_model: BaseSearchGetRequest = BaseSearchGetRequest
):
"""Wrap create_request_model to create the GET request model."""
return create_request_model(
"SearchGetRequest",
base_model=BaseSearchGetRequest,
extensions=extensions,
request_type="GET",
)


def create_post_request_model(
extensions, base_model: BaseSearchPostRequest = BaseSearchGetRequest
):
"""Wrap create_request_model to create the POST request model."""
return create_request_model(
"SearchPostRequest",
base_model=BaseSearchPostRequest,
extensions=extensions,
request_type="POST",
)


@attr.s # type:ignore
Expand All @@ -58,76 +103,52 @@ class CollectionUri(APIRequest):

collection_id: str = attr.ib(default=Path(..., description="Collection ID"))

def kwargs(self) -> Dict:
"""kwargs."""
return {"id": self.collection_id}


@attr.s
class ItemUri(CollectionUri):
"""Delete item."""

item_id: str = attr.ib(default=Path(..., description="Item ID"))

def kwargs(self) -> Dict:
"""kwargs."""
return {"collection_id": self.collection_id, "item_id": self.item_id}


@attr.s
class EmptyRequest(APIRequest):
"""Empty request."""

def kwargs(self) -> Dict:
"""kwargs."""
return {}
...


@attr.s
class ItemCollectionUri(CollectionUri):
"""Get item collection."""

limit: int = attr.ib(default=10)
token: str = attr.ib(default=None)

def kwargs(self) -> Dict:
"""kwargs."""
return {
"id": self.collection_id,
"limit": self.limit,
"token": self.token,
}

class POSTTokenPagination(BaseModel):
"""Token pagination model for POST requests."""

token: Optional[str] = None


@attr.s
class SearchGetRequest(APIRequest):
"""GET search request."""

collections: Optional[str] = attr.ib(default=None)
ids: Optional[str] = attr.ib(default=None)
bbox: Optional[str] = attr.ib(default=None)
datetime: Optional[Union[str]] = attr.ib(default=None)
limit: Optional[int] = attr.ib(default=10)
query: Optional[str] = attr.ib(default=None)
class GETTokenPagination(APIRequest):
"""Token pagination for GET requests."""

token: Optional[str] = attr.ib(default=None)
fields: Optional[str] = attr.ib(default=None)
sortby: Optional[str] = attr.ib(default=None)

def kwargs(self) -> Dict:
"""kwargs."""
return {
"collections": self.collections.split(",")
if self.collections
else self.collections,
"ids": self.ids.split(",") if self.ids else self.ids,
"bbox": self.bbox.split(",") if self.bbox else self.bbox,
"datetime": self.datetime,
"limit": self.limit,
"query": self.query,
"token": self.token,
"fields": self.fields.split(",") if self.fields else self.fields,
"sortby": self.sortby.split(",") if self.sortby else self.sortby,
}


class POSTPagination(BaseModel):
"""Page based pagination for POST requests."""

page: Optional[str] = None


@attr.s
class GETPagination(APIRequest):
"""Page based pagination for GET requests."""

page: Optional[str] = attr.ib(default=None)


# Test for ORJSON and use it rather than stdlib JSON where supported
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .context import ContextExtension
from .fields import FieldsExtension
from .filter import FilterExtension
from .pagination import PaginationExtension, TokenPaginationExtension
from .query import QueryExtension
from .sort import SortExtension
from .transaction import TransactionExtension
Expand All @@ -12,8 +13,10 @@
"ContextExtension",
"FieldsExtension",
"FilterExtension",
"PaginationExtension",
"QueryExtension",
"SortExtension",
"TilesExtension",
"TokenPaginationExtension",
"TransactionExtension",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Fields extension module."""


from .fields import FieldsExtension

__all__ = ["FieldsExtension"]
Loading