From 7bae7b99e4b26b070705397c327406de8b5c6c25 Mon Sep 17 00:00:00 2001 From: rsmith013 Date: Tue, 10 Aug 2021 11:56:48 +0100 Subject: [PATCH 1/9] refactoring to create dynamic request model --- stac_fastapi/api/stac_fastapi/api/app.py | 45 +++- stac_fastapi/api/stac_fastapi/api/models.py | 165 ++++++------- .../extensions/core/fields/__init__.py | 6 + .../extensions/core/{ => fields}/fields.py | 6 + .../extensions/core/fields/request.py | 71 ++++++ .../extensions/core/query/__init__.py | 5 + .../extensions/core/{ => query}/query.py | 5 + .../extensions/core/query/request.py | 21 ++ .../extensions/core/sort/__init__.py | 5 + .../extensions/core/sort/request.py | 23 ++ .../extensions/core/{ => sort}/sort.py | 5 + .../types/stac_fastapi/types/extension.py | 12 + .../types/stac_fastapi/types/search.py | 225 +++++++++++++----- 13 files changed, 438 insertions(+), 156 deletions(-) create mode 100644 stac_fastapi/extensions/stac_fastapi/extensions/core/fields/__init__.py rename stac_fastapi/extensions/stac_fastapi/extensions/core/{ => fields}/fields.py (90%) create mode 100644 stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py create mode 100644 stac_fastapi/extensions/stac_fastapi/extensions/core/query/__init__.py rename stac_fastapi/extensions/stac_fastapi/extensions/core/{ => query}/query.py (83%) create mode 100644 stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py create mode 100644 stac_fastapi/extensions/stac_fastapi/extensions/core/sort/__init__.py create mode 100644 stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py rename stac_fastapi/extensions/stac_fastapi/extensions/core/{ => sort}/sort.py (82%) diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index 919d66813..0e0993343 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -7,7 +7,7 @@ from fastapi.openapi.utils import get_openapi from pydantic import BaseModel from stac_pydantic import Collection, Item, ItemCollection -from stac_pydantic.api import ConformanceClasses, LandingPage, Search +from stac_pydantic.api import ConformanceClasses, LandingPage from stac_pydantic.api.collections import Collections from stac_pydantic.version import STAC_VERSION from starlette.responses import JSONResponse, Response @@ -17,9 +17,10 @@ APIRequest, CollectionUri, EmptyRequest, + GETTokenPagination, ItemCollectionUri, ItemUri, - SearchGetRequest, + POSTTokenPagination, _create_request_model, ) from stac_fastapi.api.routes import create_async_endpoint, create_sync_endpoint @@ -29,7 +30,7 @@ from stac_fastapi.types.config import ApiSettings, Settings from stac_fastapi.types.core import AsyncBaseCoreClient, BaseCoreClient from stac_fastapi.types.extension import ApiExtension -from stac_fastapi.types.search import STACSearch +from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest @attr.s @@ -68,7 +69,14 @@ class StacApi: api_version: str = attr.ib(default="0.1") stac_version: str = attr.ib(default=STAC_VERSION) description: str = attr.ib(default="stac-fastapi") - search_request_model: Type[Search] = attr.ib(default=STACSearch) + search_get_request_model: Type[BaseSearchGetRequest] = attr.ib( + default=BaseSearchGetRequest + ) + search_post_request_model: Type[BaseSearchPostRequest] = attr.ib( + default=BaseSearchPostRequest + ) + get_pagination_model: Type[APIRequest] = attr.ib(default=GETTokenPagination) + post_pagination_model: Type[BaseModel] = attr.ib(default=POSTTokenPagination) response_class: Type[Response] = attr.ib(default=JSONResponse) def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]: @@ -160,7 +168,13 @@ def register_post_search(self): Returns: None """ - search_request_model = _create_request_model(self.search_request_model) + search_request_model = _create_request_model( + "SearchPostRequest", + base_model=self.search_post_request_model, + extensions=self.extensions, + mixins=[self.post_pagination_model], + request_type="POST", + ) fields_ext = self.get_extension(FieldsExtension) self.router.add_api_route( name="Search", @@ -183,6 +197,14 @@ def register_get_search(self): Returns: None """ + search_request_model = _create_request_model( + "SearchGetRequest", + self.search_get_request_model, + self.extensions, + mixins=[self.get_pagination_model], + request_type="GET", + ) + fields_ext = self.get_extension(FieldsExtension) self.router.add_api_route( name="Search", @@ -194,7 +216,9 @@ def register_get_search(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=self._create_endpoint(self.client.get_search, SearchGetRequest), + endpoint=self._create_endpoint( + self.client.get_search, search_request_model + ), ) def register_get_collections(self): @@ -239,6 +263,11 @@ def register_get_item_collection(self): Returns: None """ + request_model = _create_request_model( + "ItemCollectionURI", + base_model=ItemCollectionUri, + mixins=[GETTokenPagination], + ) self.router.add_api_route( name="Get ItemCollection", path="/collections/{collectionId}/items", @@ -249,9 +278,7 @@ def register_get_item_collection(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=self._create_endpoint( - self.client.item_collection, ItemCollectionUri - ), + endpoint=self._create_endpoint(self.client.item_collection, request_model), ) def register_core(self): diff --git a/stac_fastapi/api/stac_fastapi/api/models.py b/stac_fastapi/api/stac_fastapi/api/models.py index f44ae0b38..613594258 100644 --- a/stac_fastapi/api/stac_fastapi/api/models.py +++ b/stac_fastapi/api/stac_fastapi/api/models.py @@ -1,54 +1,71 @@ """api request/response models.""" -import abc -from typing import Dict, Optional, Type, Union +from typing import Optional, Type, Union import attr from fastapi import Body, Path from pydantic import BaseModel, create_model from pydantic.fields import UndefinedType +from stac_fastapi.types.extension import ApiExtension +from stac_fastapi.types.search import APIRequest, BaseSearchGetRequest -def _create_request_model(model: Type[BaseModel]) -> Type[BaseModel]: + +def _create_request_model( + model_name="SearchGetRequest", + base_model: Union[Type[BaseModel], APIRequest] = BaseSearchGetRequest, + extensions: Optional[ApiExtension] = None, + mixins: Optional[Union[BaseModel, APIRequest]] = None, + request_type: Optional[str] = "GET", +) -> Union[Type[BaseModel], APIRequest]: """Create a pydantic model for validating request bodies.""" fields = {} - for (k, v) in model.__fields__.items(): - # TODO: Filter out fields based on which extensions are present - field_info = v.field_info - body = Body( - None - if isinstance(field_info.default, UndefinedType) - else field_info.default, - default_factory=field_info.default_factory, - alias=field_info.alias, - alias_priority=field_info.alias_priority, - title=field_info.title, - description=field_info.description, - const=field_info.const, - gt=field_info.gt, - ge=field_info.ge, - lt=field_info.lt, - le=field_info.le, - multiple_of=field_info.multiple_of, - min_items=field_info.min_items, - max_items=field_info.max_items, - min_length=field_info.min_length, - max_length=field_info.max_length, - regex=field_info.regex, - extra=field_info.extra, - ) - fields[k] = (v.outer_type_, body) - return create_model(model.__name__, **fields, __base__=model) - - -@attr.s # type:ignore -class APIRequest(abc.ABC): - """Generic API Request base class.""" - - @abc.abstractmethod - def kwargs(self) -> Dict: - """Transform api request params into format which matches the signature of the endpoint.""" - ... + extension_models = [] + + # Check extensions for additional parameters to search + for extension in extensions or []: + if extension_model := extension.get_request_model(request_type): + extension_models.append(extension_model) + + mixins = mixins or [] + + models = [base_model] + extension_models + mixins + + # Handle GET requests + if all([issubclass(m, APIRequest) for m in models]): + return attr.make_class(model_name, attrs={}, bases=tuple(models)) + + # Handle POST requests + elif all([issubclass(m, BaseModel) for m in models]): + for model in models: + for (k, v) in model.__fields__.items(): + field_info = v.field_info + body = Body( + None + if isinstance(field_info.default, UndefinedType) + else field_info.default, + default_factory=field_info.default_factory, + alias=field_info.alias, + alias_priority=field_info.alias_priority, + title=field_info.title, + description=field_info.description, + const=field_info.const, + gt=field_info.gt, + ge=field_info.ge, + lt=field_info.lt, + le=field_info.le, + multiple_of=field_info.multiple_of, + min_items=field_info.min_items, + max_items=field_info.max_items, + min_length=field_info.min_length, + max_length=field_info.max_length, + regex=field_info.regex, + extra=field_info.extra, + ) + fields[k] = (v.outer_type_, body) + return create_model(model_name, **fields, __base__=base_model) + + raise TypeError("Mixed Request Model types. Check extension request types.") @attr.s # type:ignore @@ -57,10 +74,6 @@ class CollectionUri(APIRequest): collectionId: str = attr.ib(default=Path(..., description="Collection ID")) - def kwargs(self) -> Dict: - """kwargs.""" - return {"id": self.collectionId} - @attr.s class ItemUri(CollectionUri): @@ -68,18 +81,12 @@ class ItemUri(CollectionUri): itemId: str = attr.ib(default=Path(..., description="Item ID")) - def kwargs(self) -> Dict: - """kwargs.""" - return {"collection_id": self.collectionId, "item_id": self.itemId} - @attr.s class EmptyRequest(APIRequest): """Empty request.""" - def kwargs(self) -> Dict: - """kwargs.""" - return {} + ... @attr.s @@ -87,43 +94,29 @@ class ItemCollectionUri(CollectionUri): """Get item collection.""" limit: int = attr.ib(default=10) - token: str = attr.ib(default=None) - def kwargs(self) -> Dict: - """kwargs.""" - return { - "id": self.collectionId, - "limit": self.limit, - "token": self.token, - } + +class POSTTokenPagination(BaseModel): + """Token pagination model for POST requests.""" + + token: Optional[str] = None @attr.s -class SearchGetRequest(APIRequest): - """GET search request.""" - - collections: Optional[str] = attr.ib(default=None) - ids: Optional[str] = attr.ib(default=None) - bbox: Optional[str] = attr.ib(default=None) - datetime: Optional[Union[str]] = attr.ib(default=None) - limit: Optional[int] = attr.ib(default=10) - query: Optional[str] = attr.ib(default=None) +class GETTokenPagination(APIRequest): + """Token pagination for GET requests.""" + token: Optional[str] = attr.ib(default=None) - fields: Optional[str] = attr.ib(default=None) - sortby: Optional[str] = attr.ib(default=None) - - def kwargs(self) -> Dict: - """kwargs.""" - return { - "collections": self.collections.split(",") - if self.collections - else self.collections, - "ids": self.ids.split(",") if self.ids else self.ids, - "bbox": self.bbox.split(",") if self.bbox else self.bbox, - "datetime": self.datetime, - "limit": self.limit, - "query": self.query, - "token": self.token, - "fields": self.fields.split(",") if self.fields else self.fields, - "sortby": self.sortby.split(",") if self.sortby else self.sortby, - } + + +class POSTPagination(BaseModel): + """Page based pagination for POST requests.""" + + page: Optional[str] = None + + +@attr.s +class GETPagination(APIRequest): + """Page based pagination for GET requests.""" + + page: Optional[str] = attr.ib(default=None) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/__init__.py new file mode 100644 index 000000000..b9a246b63 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/__init__.py @@ -0,0 +1,6 @@ +"""Fields extension module.""" + + +from .fields import FieldsExtension + +__all__ = ["FieldsExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py similarity index 90% rename from stac_fastapi/extensions/stac_fastapi/extensions/core/fields.py rename to stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py index 17db33f37..aaea4787f 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py @@ -1,4 +1,5 @@ """fields extension.""" + from typing import Set import attr @@ -6,6 +7,8 @@ from stac_fastapi.types.extension import ApiExtension +from .request import FieldsExtensionGetRequest, FieldsExtensionPostRequest + @attr.s class FieldsExtension(ApiExtension): @@ -23,6 +26,9 @@ class FieldsExtension(ApiExtension): """ + GET = FieldsExtensionGetRequest + POST = FieldsExtensionPostRequest + default_includes: Set[str] = attr.ib( default=attr.Factory( lambda: { diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py new file mode 100644 index 000000000..52ea3af2c --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py @@ -0,0 +1,71 @@ +"""Request models for the fields extension.""" + +from typing import Dict, Optional, Set + +import attr +from pydantic import BaseModel, Field + +from stac_fastapi.types.config import Settings +from stac_fastapi.types.search import APIRequest, str2list + + +class PostFieldsExtension(BaseModel): + """FieldsExtension. + + Attributes: + include: set of fields to include. + exclude: set of fields to exclude. + """ + + include: Optional[Set[str]] = set() + exclude: Optional[Set[str]] = set() + + @staticmethod + def _get_field_dict(fields: Optional[Set[str]]) -> Dict: + """Pydantic include/excludes notation. + + Internal method to create a dictionary for advanced include or exclude of pydantic fields on model export + Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude + """ + field_dict = {} + for field in fields or []: + if "." in field: + parent, key = field.split(".") + if parent not in field_dict: + field_dict[parent] = {key} + else: + field_dict[parent].add(key) + else: + field_dict[field] = ... # type:ignore + return field_dict + + @property + def filter_fields(self) -> Dict: + """Create pydantic include/exclude expression. + + Create dictionary of fields to include/exclude on model export based on the included and excluded fields passed + to the API + Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude + """ + # Always include default_includes, even if they + # exist in the exclude list. + include = (self.include or set()) - (self.exclude or set()) + include |= Settings.get().default_includes or set() + + return { + "include": self._get_field_dict(include), + "exclude": self._get_field_dict(self.exclude), + } + + +@attr.s +class FieldsExtensionGetRequest(APIRequest): + """Additional fields for the GET request.""" + + fields: Optional[str] = attr.ib(default=None, converter=str2list) + + +class FieldsExtensionPostRequest(BaseModel): + """Additional fields and schema for the POST request.""" + + fields: Optional[PostFieldsExtension] = Field(PostFieldsExtension()) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/__init__.py new file mode 100644 index 000000000..5bbe70595 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/__init__.py @@ -0,0 +1,5 @@ +"""Query extension module.""" + +from .query import QueryExtension + +__all__ = ["QueryExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/query.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/query.py similarity index 83% rename from stac_fastapi/extensions/stac_fastapi/extensions/core/query.py rename to stac_fastapi/extensions/stac_fastapi/extensions/core/query/query.py index 4960a6f01..7cf177acf 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/query.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/query.py @@ -4,6 +4,8 @@ from stac_fastapi.types.extension import ApiExtension +from .request import QueryExtensionGetRequest, QueryExtensionPostRequest + @attr.s class QueryExtension(ApiExtension): @@ -15,6 +17,9 @@ class QueryExtension(ApiExtension): https://github.com/radiantearth/stac-api-spec/blob/master/item-search/README.md#query """ + GET = QueryExtensionGetRequest + POST = QueryExtensionPostRequest + def register(self, app: FastAPI) -> None: """Register the extension with a FastAPI application. diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py new file mode 100644 index 000000000..84da48116 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py @@ -0,0 +1,21 @@ +"""Request model for the Query extension.""" + +from typing import Optional + +import attr +from pydantic import BaseModel + +from stac_fastapi.types.search import APIRequest + + +@attr.s +class QueryExtensionGetRequest(APIRequest): + """Query Extension GET request model.""" + + query: Optional[str] = attr.ib(default=None) + + +class QueryExtensionPostRequest(BaseModel): + """Query Extension POST request model.""" + + query: Optional[str] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/__init__.py new file mode 100644 index 000000000..b6996b018 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/__init__.py @@ -0,0 +1,5 @@ +"""Sort extension module.""" + +from .sort import SortExtension + +__all__ = ["SortExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py new file mode 100644 index 000000000..c19f40dba --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py @@ -0,0 +1,23 @@ +# encoding: utf-8 +"""Request model for the Sort Extension.""" + +from typing import List, Optional + +import attr +from pydantic import BaseModel +from stac_pydantic.api.extensions.sort import SortExtension as PostSortModel + +from stac_fastapi.types.search import APIRequest, str2list + + +@attr.s +class SortExtensionGetRequest(APIRequest): + """Sortby Parameter for GET requests.""" + + sortby: Optional[str] = attr.ib(default=None, converter=str2list) + + +class SortExtensionPostRequest(BaseModel): + """Sortby parameter for POST requests.""" + + sortby: Optional[List[PostSortModel]] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/sort.py similarity index 82% rename from stac_fastapi/extensions/stac_fastapi/extensions/core/sort.py rename to stac_fastapi/extensions/stac_fastapi/extensions/core/sort/sort.py index 1eb067b8c..13ad28628 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/sort.py @@ -4,6 +4,8 @@ from stac_fastapi.types.extension import ApiExtension +from .request import SortExtensionGetRequest, SortExtensionPostRequest + @attr.s class SortExtension(ApiExtension): @@ -15,6 +17,9 @@ class SortExtension(ApiExtension): https://github.com/radiantearth/stac-api-spec/blob/master/item-search/README.md#sort """ + GET = SortExtensionGetRequest + POST = SortExtensionPostRequest + def register(self, app: FastAPI) -> None: """Register the extension with a FastAPI application. diff --git a/stac_fastapi/types/stac_fastapi/types/extension.py b/stac_fastapi/types/stac_fastapi/types/extension.py index 144790999..86ce5cddb 100644 --- a/stac_fastapi/types/stac_fastapi/types/extension.py +++ b/stac_fastapi/types/stac_fastapi/types/extension.py @@ -1,14 +1,26 @@ """base api extension.""" import abc +from typing import Optional import attr from fastapi import FastAPI +from pydantic import BaseModel @attr.s class ApiExtension(abc.ABC): """Abstract base class for defining API extensions.""" + GET = None + POST = None + + def get_request_model(self, verb: Optional[str] = "GET") -> Optional[BaseModel]: + """Return the request model for the extension.method. + + The model can differ based on HTTP verb + """ + return getattr(self, verb) + @abc.abstractmethod def register(self, app: FastAPI) -> None: """Register the extension with a FastAPI application. diff --git a/stac_fastapi/types/stac_fastapi/types/search.py b/stac_fastapi/types/stac_fastapi/types/search.py index d9e8d4c0e..c55173ffa 100644 --- a/stac_fastapi/types/stac_fastapi/types/search.py +++ b/stac_fastapi/types/stac_fastapi/types/search.py @@ -3,18 +3,28 @@ # TODO: replace with stac-pydantic """ +import abc import operator +from datetime import datetime from enum import auto from types import DynamicClassAttribute -from typing import Any, Callable, Dict, List, Optional, Set, Union - -from pydantic import Field, root_validator -from stac_pydantic.api import Search -from stac_pydantic.api.extensions.fields import FieldsExtension as FieldsBase +from typing import Any, Callable, Dict, List, Optional, Union + +import attr +from geojson_pydantic.geometries import ( + LineString, + MultiLineString, + MultiPoint, + MultiPolygon, + Point, + Polygon, + _GeometryBase, +) +from pydantic import BaseModel, validator +from pydantic.datetime_parse import parse_datetime +from stac_pydantic.shared import BBox from stac_pydantic.utils import AutoValueEnum -from stac_fastapi.types.config import Settings - # Be careful: https://github.com/samuelcolvin/pydantic/issues/1423#issuecomment-642797287 NumType = Union[float, int] @@ -40,67 +50,160 @@ def operator(self) -> Callable[[Any, Any], bool]: return getattr(operator, self._value_) -class FieldsExtension(FieldsBase): - """FieldsExtension. +def str2list(x: str) -> Optional[List]: + """Convert string to list base on , delimiter.""" + if x: + return x.split(",") + + +@attr.s # type:ignore +class APIRequest(abc.ABC): + """Generic API Request base class.""" + + def kwargs(self) -> Dict: + """Transform api request params into format which matches the signature of the endpoint.""" + return self.__dict__ + + +@attr.s +class BaseSearchGetRequest(APIRequest): + """Base arguments for GET Request.""" - Attributes: - include: set of fields to include. - exclude: set of fields to exclude. + collections: Optional[str] = attr.ib(default=None, converter=str2list) + ids: Optional[str] = attr.ib(default=None, converter=str2list) + bbox: Optional[str] = attr.ib(default=None, converter=str2list) + intersects: Optional[str] = attr.ib(default=None, converter=str2list) + datetime: Optional[Union[str]] = attr.ib(default=None) + limit: Optional[int] = attr.ib(default=10) + + +class BaseSearchPostRequest(BaseModel): + """Search model. + + Replace base model in STAC-pydantic as it includes additional fields, + not in the core model. + https://github.com/radiantearth/stac-api-spec/tree/master/item-search#query-parameter-table + + PR to fix this: + https://github.com/stac-utils/stac-pydantic/pull/100 """ - include: Optional[Set[str]] = set() - exclude: Optional[Set[str]] = set() + collections: Optional[List[str]] + ids: Optional[List[str]] + bbox: Optional[BBox] + intersects: Optional[ + Union[Point, MultiPoint, LineString, MultiLineString, Polygon, MultiPolygon] + ] + datetime: Optional[str] + limit: int = 10 - @staticmethod - def _get_field_dict(fields: Optional[Set[str]]) -> Dict: - """Pydantic include/excludes notation. + @property + def start_date(self) -> Optional[datetime]: + """Extract the start date from the datetime string.""" + if not self.datetime: + return + + values = self.datetime.split("/") + if len(values) == 1: + return None + if values[0] == "..": + return None + return parse_datetime(values[0]) - Internal method to create a dictionary for advanced include or exclude of pydantic fields on model export - Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude - """ - field_dict = {} - for field in fields or []: - if "." in field: - parent, key = field.split(".") - if parent not in field_dict: - field_dict[parent] = {key} - else: - field_dict[parent].add(key) + @property + def end_date(self) -> Optional[datetime]: + """Extract the end date from the datetime string.""" + if not self.datetime: + return + + values = self.datetime.split("/") + if len(values) == 1: + return parse_datetime(values[0]) + if values[1] == "..": + return None + return parse_datetime(values[1]) + + @validator("intersects") + def validate_spatial(cls, v, values): + """Check bbox and intersects are not both supplied.""" + if v and values["bbox"]: + raise ValueError("intersects and bbox parameters are mutually exclusive") + return v + + @validator("bbox") + def validate_bbox(cls, v: BBox): + """Check order of supplied bbox coordinates.""" + if v: + # Validate order + if len(v) == 4: + xmin, ymin, xmax, ymax = v else: - field_dict[field] = ... # type:ignore - return field_dict + xmin, ymin, min_elev, xmax, ymax, max_elev = v + if max_elev < min_elev: + raise ValueError( + "Maximum elevation must greater than minimum elevation" + ) + + if xmax < xmin: + raise ValueError( + "Maximum longitude must be greater than minimum longitude" + ) + + if ymax < ymin: + raise ValueError( + "Maximum longitude must be greater than minimum longitude" + ) + + # Validate against WGS84 + if xmin < -180 or ymin < -90 or xmax > 180 or ymax > 90: + raise ValueError("Bounding box must be within (-180, -90, 180, 90)") + + return v + + @validator("datetime") + def validate_datetime(cls, v): + """Validate datetime.""" + if "/" in v: + values = v.split("/") + else: + # Single date is interpreted as end date + values = ["..", v] + + dates = [] + for value in values: + if value == "..": + dates.append(value) + continue + + parse_datetime(value) + dates.append(value) + + if ".." not in dates: + if parse_datetime(dates[0]) > parse_datetime(dates[1]): + raise ValueError( + "Invalid datetime range, must match format (begin_date, end_date)" + ) + + return v @property - def filter_fields(self) -> Dict: - """Create pydantic include/exclude expression. + def spatial_filter(self) -> Optional[_GeometryBase]: + """Return a geojson-pydantic object representing the spatial filter for the search request. - Create dictionary of fields to include/exclude on model export based on the included and excluded fields passed - to the API - Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude + Check for both because the ``bbox`` and ``intersects`` parameters are mutually exclusive. """ - # Always include default_includes, even if they - # exist in the exclude list. - include = (self.include or set()) - (self.exclude or set()) - include |= Settings.get().default_includes or set() - - return { - "include": self._get_field_dict(include), - "exclude": self._get_field_dict(self.exclude), - } - - -class STACSearch(Search): - """Search model.""" - - # Make collections optional, default to searching all collections if none are provided - collections: Optional[List[str]] = None - # Override default field extension to include default fields and pydantic includes/excludes factory - field: FieldsExtension = Field(FieldsExtension(), alias="fields") - # Override query extension with supported operators - query: Optional[Dict[str, Dict[Operator, Any]]] - token: Optional[str] = None - - @root_validator(pre=True) - def validate_query_fields(cls, values: Dict) -> Dict: - """Validate query fields (placeholder).""" - return values + if self.bbox: + return Polygon( + coordinates=[ + [ + [self.bbox[0], self.bbox[3]], + [self.bbox[2], self.bbox[3]], + [self.bbox[2], self.bbox[1]], + [self.bbox[0], self.bbox[1]], + [self.bbox[0], self.bbox[3]], + ] + ] + ) + if self.intersects: + return self.intersects + return From f9423ee4eadfa9cf6068c5f3f08e4d532ca5a1ed Mon Sep 17 00:00:00 2001 From: rsmith013 Date: Wed, 11 Aug 2021 15:19:58 +0100 Subject: [PATCH 2/9] Correcting missed reference to configurable pagination class and updating conformance classes. --- stac_fastapi/api/stac_fastapi/api/app.py | 2 +- stac_fastapi/types/stac_fastapi/types/core.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index 0e0993343..d7a11be9a 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -266,7 +266,7 @@ def register_get_item_collection(self): request_model = _create_request_model( "ItemCollectionURI", base_model=ItemCollectionUri, - mixins=[GETTokenPagination], + mixins=[self.get_pagination_model], ) self.router.add_api_route( name="Get ItemCollection", diff --git a/stac_fastapi/types/stac_fastapi/types/core.py b/stac_fastapi/types/stac_fastapi/types/core.py index fd8955dbd..df963a8a3 100644 --- a/stac_fastapi/types/stac_fastapi/types/core.py +++ b/stac_fastapi/types/stac_fastapi/types/core.py @@ -238,8 +238,9 @@ def _landing_page(self, base_url: str) -> stac_types.LandingPage: description=self.description, stac_version=self.stac_version, conformsTo=[ - "https://stacspec.org/STAC-api.html", - "http://docs.opengeospatial.org/is/17-069r3/17-069r3.html#ats_geojson", + "https://api.stacspec.org/v1.0.0-beta.2/core", + "https://api.stacspec.org/v1.0.0-beta.2/ogcapi-features", + "https://api.stacspec.org/v1.0.0-beta.2/item-search", ], links=[ { @@ -262,7 +263,7 @@ def _landing_page(self, base_url: str) -> stac_types.LandingPage: "rel": Relations.conformance.value, "type": MimeTypes.json, "title": "STAC/WFS3 conformance classes implemented by this server", - "href": base_url, + "href": urljoin(base_url, "conformance"), }, { "rel": Relations.search.value, From 1d8a6982a3e93ce7489f5f55588beb7acce23ba9 Mon Sep 17 00:00:00 2001 From: rsmith013 Date: Mon, 8 Nov 2021 12:14:13 +0000 Subject: [PATCH 3/9] Changing pydantic limit model to have max/min [skip ci] --- stac_fastapi/types/stac_fastapi/types/search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stac_fastapi/types/stac_fastapi/types/search.py b/stac_fastapi/types/stac_fastapi/types/search.py index db7e3e09f..ac6c5a91c 100644 --- a/stac_fastapi/types/stac_fastapi/types/search.py +++ b/stac_fastapi/types/stac_fastapi/types/search.py @@ -20,7 +20,7 @@ Polygon, _GeometryBase, ) -from pydantic import BaseModel, validator +from pydantic import BaseModel, conint, validator from pydantic.datetime_parse import parse_datetime from stac_pydantic.shared import BBox from stac_pydantic.utils import AutoValueEnum @@ -95,7 +95,7 @@ class BaseSearchPostRequest(BaseModel): Union[Point, MultiPoint, LineString, MultiLineString, Polygon, MultiPolygon] ] datetime: Optional[str] - limit: int = 10 + limit: Optional[conint(ge=0, le=10000)] = 10 @property def start_date(self) -> Optional[datetime]: From f0a32af4399918125d1971123e478f0e616dd614 Mon Sep 17 00:00:00 2001 From: rsmith013 Date: Tue, 9 Nov 2021 15:00:38 +0000 Subject: [PATCH 4/9] fixing against test suite for sqlalchemy --- stac_fastapi/api/stac_fastapi/api/app.py | 33 ++------ stac_fastapi/api/stac_fastapi/api/models.py | 2 +- .../stac_fastapi/extensions/core/__init__.py | 3 + .../extensions/core/pagination/__init__.py | 6 ++ .../extensions/core/pagination/pagination.py | 37 +++++++++ .../core/pagination/token_pagination.py | 37 +++++++++ .../extensions/core/query/request.py | 4 +- .../third_party/bulk_transactions.py | 4 +- .../sqlalchemy/stac_fastapi/sqlalchemy/app.py | 46 +++++++--- .../stac_fastapi/sqlalchemy/core.py | 35 ++++---- .../sqlalchemy/extensions/__init__.py | 5 ++ .../{types/search.py => extensions/query.py} | 83 +++++-------------- .../sqlalchemy/models/database.py | 2 +- .../stac_fastapi/sqlalchemy/transactions.py | 14 ++-- .../sqlalchemy/tests/clients/test_postgres.py | 4 +- stac_fastapi/sqlalchemy/tests/conftest.py | 50 ++++++++--- stac_fastapi/types/stac_fastapi/types/core.py | 3 + 17 files changed, 221 insertions(+), 147 deletions(-) create mode 100644 stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/__init__.py create mode 100644 stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/pagination.py create mode 100644 stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/token_pagination.py create mode 100644 stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/extensions/__init__.py rename stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/{types/search.py => extensions/query.py} (53%) diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index e30aea1e7..1a2bf206e 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -17,16 +17,14 @@ APIRequest, CollectionUri, EmptyRequest, - GETTokenPagination, ItemCollectionUri, ItemUri, - POSTTokenPagination, - _create_request_model, + create_request_model, ) from stac_fastapi.api.routes import create_async_endpoint, create_sync_endpoint # TODO: make this module not depend on `stac_fastapi.extensions` -from stac_fastapi.extensions.core import FieldsExtension +from stac_fastapi.extensions.core import FieldsExtension, TokenPaginationExtension from stac_fastapi.types.config import ApiSettings, Settings from stac_fastapi.types.core import AsyncBaseCoreClient, BaseCoreClient from stac_fastapi.types.extension import ApiExtension @@ -79,8 +77,7 @@ class StacApi: search_post_request_model: Type[BaseSearchPostRequest] = attr.ib( default=BaseSearchPostRequest ) - get_pagination_model: Type[APIRequest] = attr.ib(default=GETTokenPagination) - post_pagination_model: Type[BaseModel] = attr.ib(default=POSTTokenPagination) + pagination_extension = attr.ib(default=TokenPaginationExtension) response_class: Type[Response] = attr.ib(default=JSONResponse) middlewares: List = attr.ib(default=attr.Factory(lambda: [BrotliMiddleware])) @@ -173,13 +170,6 @@ def register_post_search(self): Returns: None """ - search_request_model = _create_request_model( - "SearchPostRequest", - base_model=self.search_post_request_model, - extensions=self.extensions, - mixins=[self.post_pagination_model], - request_type="POST", - ) fields_ext = self.get_extension(FieldsExtension) self.router.add_api_route( name="Search", @@ -192,7 +182,7 @@ def register_post_search(self): response_model_exclude_none=True, methods=["POST"], endpoint=self._create_endpoint( - self.client.post_search, search_request_model + self.client.post_search, self.search_post_request_model ), ) @@ -202,14 +192,6 @@ def register_get_search(self): Returns: None """ - search_request_model = _create_request_model( - "SearchGetRequest", - self.search_get_request_model, - self.extensions, - mixins=[self.get_pagination_model], - request_type="GET", - ) - fields_ext = self.get_extension(FieldsExtension) self.router.add_api_route( name="Search", @@ -222,7 +204,7 @@ def register_get_search(self): response_model_exclude_none=True, methods=["GET"], endpoint=self._create_endpoint( - self.client.get_search, search_request_model + self.client.get_search, self.search_get_request_model ), ) @@ -268,10 +250,11 @@ def register_get_item_collection(self): Returns: None """ - request_model = _create_request_model( + get_pagination_model = self.get_extension(self.pagination_extension).GET + request_model = create_request_model( "ItemCollectionURI", base_model=ItemCollectionUri, - mixins=[self.get_pagination_model], + mixins=[get_pagination_model], ) self.router.add_api_route( name="Get ItemCollection", diff --git a/stac_fastapi/api/stac_fastapi/api/models.py b/stac_fastapi/api/stac_fastapi/api/models.py index 613594258..2a58a84d1 100644 --- a/stac_fastapi/api/stac_fastapi/api/models.py +++ b/stac_fastapi/api/stac_fastapi/api/models.py @@ -11,7 +11,7 @@ from stac_fastapi.types.search import APIRequest, BaseSearchGetRequest -def _create_request_model( +def create_request_model( model_name="SearchGetRequest", base_model: Union[Type[BaseModel], APIRequest] = BaseSearchGetRequest, extensions: Optional[ApiExtension] = None, diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py index beb3e41d1..d720a6377 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py @@ -4,6 +4,7 @@ from .context import ContextExtension from .fields import FieldsExtension from .filter import FilterExtension +from .pagination import PaginationExtension, TokenPaginationExtension from .query import QueryExtension from .sort import SortExtension from .transaction import TransactionExtension @@ -12,8 +13,10 @@ "ContextExtension", "FieldsExtension", "FilterExtension", + "PaginationExtension", "QueryExtension", "SortExtension", "TilesExtension", + "TokenPaginationExtension", "TransactionExtension", ) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/__init__.py new file mode 100644 index 000000000..255701226 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/__init__.py @@ -0,0 +1,6 @@ +"""pagination classes as extensions.""" + +from .pagination import PaginationExtension +from .token_pagination import TokenPaginationExtension + +__all__ = ["PaginationExtension", "TokenPaginationExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/pagination.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/pagination.py new file mode 100644 index 000000000..5e834ed38 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/pagination.py @@ -0,0 +1,37 @@ +"""Pagination API extension.""" + +from typing import List, Optional + +import attr +from fastapi import FastAPI + +from stac_fastapi.api.models import GETPagination, POSTPagination +from stac_fastapi.types.extension import ApiExtension + + +@attr.s +class PaginationExtension(ApiExtension): + """Token Pagination. + + Though not strictly an extension, the chosen pagination will modify the + form of the request object. By making pagination an extension class, we can + use create_request_model to dynamically add the correct pagination parameter + to the request model for OpenAPI generation. + """ + + GET = GETPagination + POST = POSTPagination + + conformance_classes: List[str] = attr.ib(factory=list) + schema_href: Optional[str] = attr.ib(default=None) + + def register(self, app: FastAPI) -> None: + """Register the extension with a FastAPI application. + + Args: + app: target FastAPI application. + + Returns: + None + """ + pass diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/token_pagination.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/token_pagination.py new file mode 100644 index 000000000..1e1399971 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/token_pagination.py @@ -0,0 +1,37 @@ +"""Token pagination API extension.""" + +from typing import List, Optional + +import attr +from fastapi import FastAPI + +from stac_fastapi.api.models import GETTokenPagination, POSTTokenPagination +from stac_fastapi.types.extension import ApiExtension + + +@attr.s +class TokenPaginationExtension(ApiExtension): + """Token Pagination. + + Though not strictly an extension, the chosen pagination will modify the + form of the request object. By making pagination an extension class, we can + use create_request_model to dynamically add the correct pagination parameter + to the request model for OpenAPI generation. + """ + + GET = GETTokenPagination + POST = POSTTokenPagination + + conformance_classes: List[str] = attr.ib(factory=list) + schema_href: Optional[str] = attr.ib(default=None) + + def register(self, app: FastAPI) -> None: + """Register the extension with a FastAPI application. + + Args: + app: target FastAPI application. + + Returns: + None + """ + pass diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py index 84da48116..8b282884a 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py @@ -1,6 +1,6 @@ """Request model for the Query extension.""" -from typing import Optional +from typing import Any, Dict, Optional import attr from pydantic import BaseModel @@ -18,4 +18,4 @@ class QueryExtensionGetRequest(APIRequest): class QueryExtensionPostRequest(BaseModel): """Query Extension POST request model.""" - query: Optional[str] + query: Optional[Dict[str, Dict[str, Any]]] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py index cc40eb453..a6e139292 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py @@ -6,7 +6,7 @@ from fastapi import APIRouter, FastAPI from pydantic import BaseModel -from stac_fastapi.api.models import _create_request_model +from stac_fastapi.api.models import create_request_model from stac_fastapi.api.routes import create_sync_endpoint from stac_fastapi.types.extension import ApiExtension @@ -72,7 +72,7 @@ def register(self, app: FastAPI) -> None: Returns: None """ - items_request_model = _create_request_model(Items) + items_request_model = create_request_model("Items", base_model=Items) router = APIRouter() router.add_api_route( diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py index c70231171..ba6ccd016 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py @@ -1,36 +1,56 @@ """FastAPI application.""" from stac_fastapi.api.app import StacApi +from stac_fastapi.api.models import create_request_model from stac_fastapi.extensions.core import ( FieldsExtension, - QueryExtension, SortExtension, + TokenPaginationExtension, TransactionExtension, ) from stac_fastapi.extensions.third_party import BulkTransactionExtension from stac_fastapi.sqlalchemy.config import SqlalchemySettings from stac_fastapi.sqlalchemy.core import CoreCrudClient +from stac_fastapi.sqlalchemy.extensions import QueryExtension from stac_fastapi.sqlalchemy.session import Session from stac_fastapi.sqlalchemy.transactions import ( BulkTransactionsClient, TransactionsClient, ) -from stac_fastapi.sqlalchemy.types.search import SQLAlchemySTACSearch +from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest settings = SqlalchemySettings() session = Session.create_from_settings(settings) +extensions = [ + TransactionExtension(client=TransactionsClient(session=session), settings=settings), + BulkTransactionExtension(client=BulkTransactionsClient(session=session)), + FieldsExtension(), + QueryExtension(), + SortExtension(), + TokenPaginationExtension(), +] + +GET_REQUEST_MODEL = create_request_model( + "SearchGetRequest", + base_model=BaseSearchGetRequest, + extensions=extensions, + request_type="GET", +) + +POST_REQUEST_MODEL = create_request_model( + "SearchPostRequest", + base_model=BaseSearchPostRequest, + extensions=extensions, + request_type="POST", +) + api = StacApi( settings=settings, - extensions=[ - TransactionExtension( - client=TransactionsClient(session=session), settings=settings - ), - BulkTransactionExtension(client=BulkTransactionsClient(session=session)), - FieldsExtension(), - QueryExtension(), - SortExtension(), - ], - client=CoreCrudClient(session=session), - search_request_model=SQLAlchemySTACSearch, + extensions=extensions, + client=CoreCrudClient( + session=session, extensions=extensions, post_request_model=POST_REQUEST_MODEL + ), + search_get_request_model=GET_REQUEST_MODEL, + search_post_request_model=POST_REQUEST_MODEL, ) app = api.app diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py index 3102b4f6c..ad7e7849c 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py @@ -21,13 +21,14 @@ from stac_pydantic.shared import MimeTypes from stac_fastapi.sqlalchemy import serializers +from stac_fastapi.sqlalchemy.extensions.query import Operator from stac_fastapi.sqlalchemy.models import database from stac_fastapi.sqlalchemy.session import Session from stac_fastapi.sqlalchemy.tokens import PaginationTokenClient -from stac_fastapi.sqlalchemy.types.search import Operator, SQLAlchemySTACSearch from stac_fastapi.types.config import Settings from stac_fastapi.types.core import BaseCoreClient from stac_fastapi.types.errors import NotFoundError +from stac_fastapi.types.search import BaseSearchPostRequest from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection logger = logging.getLogger(__name__) @@ -90,15 +91,15 @@ def all_collections(self, **kwargs) -> Collections: ) return collection_list - def get_collection(self, id: str, **kwargs) -> Collection: + def get_collection(self, collectionId: str, **kwargs) -> Collection: """Get collection by id.""" base_url = str(kwargs["request"].base_url) with self.session.reader.context_session() as session: - collection = self._lookup_id(id, self.collection_table, session) + collection = self._lookup_id(collectionId, self.collection_table, session) return self.collection_serializer.db_to_stac(collection, base_url) def item_collection( - self, id: str, limit: int = 10, token: str = None, **kwargs + self, collectionId: str, limit: int = 10, token: str = None, **kwargs ) -> ItemCollection: """Read an item collection from the database.""" base_url = str(kwargs["request"].base_url) @@ -106,7 +107,7 @@ def item_collection( collection_children = ( session.query(self.item_table) .join(self.collection_table) - .filter(self.collection_table.id == id) + .filter(self.collection_table.id == collectionId) .order_by(self.item_table.datetime.desc(), self.item_table.id) ) count = None @@ -135,7 +136,7 @@ def item_collection( { "rel": Relations.next.value, "type": "application/geo+json", - "href": f"{kwargs['request'].base_url}collections/{id}/items?token={page.next}&limit={limit}", + "href": f"{kwargs['request'].base_url}collections/{collectionId}/items?token={page.next}&limit={limit}", "method": "GET", } ) @@ -144,7 +145,7 @@ def item_collection( { "rel": Relations.previous.value, "type": "application/geo+json", - "href": f"{kwargs['request'].base_url}collections/{id}/items?token={page.previous}&limit={limit}", + "href": f"{kwargs['request'].base_url}collections/{collectionId}/items?token={page.previous}&limit={limit}", "method": "GET", } ) @@ -170,16 +171,16 @@ def item_collection( context=context_obj, ) - def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: + def get_item(self, itemId: str, collectionId: str, **kwargs) -> Item: """Get item by id.""" base_url = str(kwargs["request"].base_url) with self.session.reader.context_session() as session: db_query = session.query(self.item_table) - db_query = db_query.filter(self.item_table.collection_id == collection_id) - db_query = db_query.filter(self.item_table.id == item_id) + db_query = db_query.filter(self.item_table.collection_id == collectionId) + db_query = db_query.filter(self.item_table.id == itemId) item = db_query.first() if not item: - raise NotFoundError(f"{self.item_table.__name__} {id} not found") + raise NotFoundError(f"{self.item_table.__name__} {itemId} not found") return self.item_serializer.db_to_stac(item, base_url=base_url) def get_search( @@ -233,7 +234,7 @@ def get_search( # Do the request try: - search_request = SQLAlchemySTACSearch(**base_args) + search_request = self.post_request_model(**base_args) except ValidationError: raise HTTPException(status_code=400, detail="Invalid parameters provided") resp = self.post_search(search_request, request=kwargs["request"]) @@ -256,7 +257,7 @@ def get_search( return resp def post_search( - self, search_request: SQLAlchemySTACSearch, **kwargs + self, search_request: BaseSearchPostRequest, **kwargs ) -> ItemCollection: """POST search catalog.""" base_url = str(kwargs["request"].base_url) @@ -428,12 +429,12 @@ def post_search( for k in search_request.query.keys() ] ) - if not search_request.field.include: - search_request.field.include = query_include + if not search_request.fields.include: + search_request.fields.include = query_include else: - search_request.field.include.union(query_include) + search_request.fields.include.union(query_include) - filter_kwargs = search_request.field.filter_fields + filter_kwargs = search_request.fields.filter_fields # Need to pass through `.json()` for proper serialization # of datetime response_features = [ diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/extensions/__init__.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/extensions/__init__.py new file mode 100644 index 000000000..d97a001cd --- /dev/null +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/extensions/__init__.py @@ -0,0 +1,5 @@ +"""sqlalchemy extensions modifications.""" + +from .query import Operator, QueryableTypes, QueryExtension + +__all__ = ["Operator", "QueryableTypes", "QueryExtension"] diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/types/search.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/extensions/query.py similarity index 53% rename from stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/types/search.py rename to stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/extensions/query.py index acfc152af..36f7a7710 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/types/search.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/extensions/query.py @@ -1,4 +1,4 @@ -"""stac_fastapi.types.search module. +"""STAC SQLAlchemy specific query search model. # TODO: replace with stac-pydantic """ @@ -8,16 +8,14 @@ from dataclasses import dataclass from enum import auto from types import DynamicClassAttribute -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, Optional, Union import sqlalchemy as sa -from pydantic import Field, ValidationError, conint, root_validator +from pydantic import BaseModel, ValidationError, root_validator from pydantic.error_wrappers import ErrorWrapper -from stac_pydantic.api import Search -from stac_pydantic.api.extensions.fields import FieldsExtension as FieldsBase from stac_pydantic.utils import AutoValueEnum -from stac_fastapi.types.config import Settings +from stac_fastapi.extensions.core.query import QueryExtension as QueryExtensionBase logger = logging.getLogger("uvicorn") logger.setLevel(logging.INFO) @@ -34,6 +32,7 @@ class Operator(str, AutoValueEnum): lte = auto() gt = auto() gte = auto() + # TODO: These are defined in the spec but aren't currently implemented by the api # startsWith = auto() # endsWith = auto() @@ -86,66 +85,14 @@ class QueryableTypes: dtype = sa.String -class FieldsExtension(FieldsBase): - """FieldsExtension. +class QueryExtensionPostRequest(BaseModel): + """Queryable validation. - Attributes: - include: set of fields to include. - exclude: set of fields to exclude. + Add queryables validation to the POST request + to raise errors for unsupported querys. """ - include: Optional[Set[str]] = set() - exclude: Optional[Set[str]] = set() - - @staticmethod - def _get_field_dict(fields: Optional[Set[str]]) -> Dict: - """Pydantic include/excludes notation. - - Internal method to create a dictionary for advanced include or exclude of pydantic fields on model export - Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude - """ - field_dict = {} - for field in fields or []: - if "." in field: - parent, key = field.split(".") - if parent not in field_dict: - field_dict[parent] = {key} - else: - field_dict[parent].add(key) - else: - field_dict[field] = ... # type:ignore - return field_dict - - @property - def filter_fields(self) -> Dict: - """Create pydantic include/exclude expression. - - Create dictionary of fields to include/exclude on model export based on the included and excluded fields passed - to the API - Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude - """ - # Always include default_includes, even if they - # exist in the exclude list. - include = (self.include or set()) - (self.exclude or set()) - include |= Settings.get().default_includes or set() - - return { - "include": self._get_field_dict(include), - "exclude": self._get_field_dict(self.exclude), - } - - -class SQLAlchemySTACSearch(Search): - """Search model.""" - - # Make collections optional, default to searching all collections if none are provided - collections: Optional[List[str]] = None - # Override default field extension to include default fields and pydantic includes/excludes factory - field: FieldsExtension = Field(FieldsExtension(), alias="fields") - # Override query extension with supported operators query: Optional[Dict[Queryables, Dict[Operator, Any]]] - token: Optional[str] = None - limit: Optional[conint(ge=0, le=10000)] = 10 @root_validator(pre=True) def validate_query_fields(cls, values: Dict) -> Dict: @@ -162,6 +109,16 @@ def validate_query_fields(cls, values: Dict) -> Dict: "STACSearch", ) ], - SQLAlchemySTACSearch, + QueryExtensionPostRequest, ) return values + + +class QueryExtension(QueryExtensionBase): + """Query Extenson. + + Override the POST request model to add validation against + supported fields + """ + + POST = QueryExtensionPostRequest diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/models/database.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/models/database.py index 128a145c7..e521d453f 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/models/database.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/models/database.py @@ -8,7 +8,7 @@ from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.declarative import declarative_base -from stac_fastapi.sqlalchemy.types.search import Queryables, QueryableTypes +from stac_fastapi.sqlalchemy.extensions.query import Queryables, QueryableTypes BaseModel = declarative_base() diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py index 555458869..12a97c789 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py @@ -85,29 +85,27 @@ def update_collection( return self.collection_serializer.db_to_stac(db_model, base_url) - def delete_item( - self, item_id: str, collection_id: str, **kwargs - ) -> stac_types.Item: + def delete_item(self, itemId: str, collectionId: str, **kwargs) -> stac_types.Item: """Delete item.""" base_url = str(kwargs["request"].base_url) with self.session.writer.context_session() as session: - query = session.query(self.item_table).filter(self.item_table.id == item_id) + query = session.query(self.item_table).filter(self.item_table.id == itemId) data = query.first() if not data: - raise NotFoundError(f"Item {item_id} not found") + raise NotFoundError(f"Item {itemId} not found") query.delete() return self.item_serializer.db_to_stac(data, base_url=base_url) - def delete_collection(self, id: str, **kwargs) -> stac_types.Collection: + def delete_collection(self, collectionId: str, **kwargs) -> stac_types.Collection: """Delete collection.""" base_url = str(kwargs["request"].base_url) with self.session.writer.context_session() as session: query = session.query(self.collection_table).filter( - self.collection_table.id == id + self.collection_table.id == collectionId ) data = query.first() if not data: - raise NotFoundError(f"Collection {id} not found") + raise NotFoundError(f"Collection {collectionId} not found") query.delete() return self.collection_serializer.db_to_stac(data, base_url=base_url) diff --git a/stac_fastapi/sqlalchemy/tests/clients/test_postgres.py b/stac_fastapi/sqlalchemy/tests/clients/test_postgres.py index d20f24489..bfeea8358 100644 --- a/stac_fastapi/sqlalchemy/tests/clients/test_postgres.py +++ b/stac_fastapi/sqlalchemy/tests/clients/test_postgres.py @@ -97,8 +97,8 @@ def test_get_item( data = load_test_data("test_item.json") postgres_transactions.create_item(data, request=MockStarletteRequest) coll = postgres_core.get_item( - item_id=data["id"], - collection_id=data["collection"], + itemId=data["id"], + collectionId=data["collection"], request=MockStarletteRequest, ) assert coll["id"] == data["id"] diff --git a/stac_fastapi/sqlalchemy/tests/conftest.py b/stac_fastapi/sqlalchemy/tests/conftest.py index 795868e2e..7abd9150f 100644 --- a/stac_fastapi/sqlalchemy/tests/conftest.py +++ b/stac_fastapi/sqlalchemy/tests/conftest.py @@ -6,23 +6,25 @@ from starlette.testclient import TestClient from stac_fastapi.api.app import StacApi +from stac_fastapi.api.models import create_request_model from stac_fastapi.extensions.core import ( ContextExtension, FieldsExtension, - QueryExtension, SortExtension, + TokenPaginationExtension, TransactionExtension, ) from stac_fastapi.sqlalchemy.config import SqlalchemySettings from stac_fastapi.sqlalchemy.core import CoreCrudClient +from stac_fastapi.sqlalchemy.extensions import QueryExtension from stac_fastapi.sqlalchemy.models import database from stac_fastapi.sqlalchemy.session import Session from stac_fastapi.sqlalchemy.transactions import ( BulkTransactionsClient, TransactionsClient, ) -from stac_fastapi.sqlalchemy.types.search import SQLAlchemySTACSearch from stac_fastapi.types.config import Settings +from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest DATA_DIR = os.path.join(os.path.dirname(__file__), "data") @@ -105,19 +107,41 @@ def postgres_bulk_transactions(db_session): @pytest.fixture def api_client(db_session): settings = SqlalchemySettings() + extensions = [ + TransactionExtension( + client=TransactionsClient(session=db_session), settings=settings + ), + ContextExtension(), + SortExtension(), + FieldsExtension(), + QueryExtension(), + TokenPaginationExtension(), + ] + + get_request_model = create_request_model( + "SearchGetRequest", + base_model=BaseSearchGetRequest, + extensions=extensions, + request_type="GET", + ) + + post_request_model = create_request_model( + "SearchPostRequest", + base_model=BaseSearchPostRequest, + extensions=extensions, + request_type="POST", + ) + return StacApi( settings=settings, - client=CoreCrudClient(session=db_session), - extensions=[ - TransactionExtension( - client=TransactionsClient(session=db_session), settings=settings - ), - ContextExtension(), - SortExtension(), - FieldsExtension(), - QueryExtension(), - ], - search_request_model=SQLAlchemySTACSearch, + client=CoreCrudClient( + session=db_session, + extensions=extensions, + post_request_model=post_request_model, + ), + extensions=extensions, + search_get_request_model=get_request_model, + search_post_request_model=post_request_model, ) diff --git a/stac_fastapi/types/stac_fastapi/types/core.py b/stac_fastapi/types/stac_fastapi/types/core.py index 2fdb092a8..7fdafec4a 100644 --- a/stac_fastapi/types/stac_fastapi/types/core.py +++ b/stac_fastapi/types/stac_fastapi/types/core.py @@ -14,6 +14,7 @@ from stac_fastapi.types import stac as stac_types from stac_fastapi.types.conformance import BASE_CONFORMANCE_CLASSES from stac_fastapi.types.extension import ApiExtension +from stac_fastapi.types.search import BaseSearchPostRequest from stac_fastapi.types.stac import Conformance NumType = Union[float, int] @@ -306,6 +307,7 @@ class BaseCoreClient(LandingPageMixin, abc.ABC): factory=lambda: BASE_CONFORMANCE_CLASSES ) extensions: List[ApiExtension] = attr.ib(default=attr.Factory(list)) + post_request_model = attr.ib(default=BaseSearchPostRequest) def conformance_classes(self) -> List[str]: """Generate conformance classes by adding extension conformance to base conformance classes.""" @@ -495,6 +497,7 @@ class AsyncBaseCoreClient(LandingPageMixin, abc.ABC): factory=lambda: BASE_CONFORMANCE_CLASSES ) extensions: List[ApiExtension] = attr.ib(default=attr.Factory(list)) + post_request_model = attr.ib(default=BaseSearchPostRequest) def conformance_classes(self) -> List[str]: """Generate conformance classes by adding extension conformance to base conformance classes.""" From 3e9685fd3365bc7c49507897690b6585989c6dc7 Mon Sep 17 00:00:00 2001 From: rsmith013 Date: Thu, 11 Nov 2021 12:06:34 +0000 Subject: [PATCH 5/9] wrapping function calls to make get and post request models --- stac_fastapi/api/stac_fastapi/api/models.py | 26 ++++++++++++++++++- .../sqlalchemy/stac_fastapi/sqlalchemy/app.py | 23 ++++------------ 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/stac_fastapi/api/stac_fastapi/api/models.py b/stac_fastapi/api/stac_fastapi/api/models.py index 2a58a84d1..f1a2bba4a 100644 --- a/stac_fastapi/api/stac_fastapi/api/models.py +++ b/stac_fastapi/api/stac_fastapi/api/models.py @@ -8,7 +8,11 @@ from pydantic.fields import UndefinedType from stac_fastapi.types.extension import ApiExtension -from stac_fastapi.types.search import APIRequest, BaseSearchGetRequest +from stac_fastapi.types.search import ( + APIRequest, + BaseSearchGetRequest, + BaseSearchPostRequest, +) def create_request_model( @@ -68,6 +72,26 @@ def create_request_model( raise TypeError("Mixed Request Model types. Check extension request types.") +def create_get_request_model(extensions): + """Wrap create_request_model to create the GET request model.""" + return create_request_model( + "SearchGetRequest", + base_model=BaseSearchGetRequest, + extensions=extensions, + request_type="GET", + ) + + +def create_post_requst_model(extensions): + """Wrap create_request_model to create the POST request model.""" + return create_request_model( + "SearchPostRequest", + base_model=BaseSearchPostRequest, + extensions=extensions, + request_type="POST", + ) + + @attr.s # type:ignore class CollectionUri(APIRequest): """Delete collection.""" diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py index ba6ccd016..09fdf2e73 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py @@ -1,6 +1,6 @@ """FastAPI application.""" from stac_fastapi.api.app import StacApi -from stac_fastapi.api.models import create_request_model +from stac_fastapi.api.models import create_get_request_model, create_post_request_model from stac_fastapi.extensions.core import ( FieldsExtension, SortExtension, @@ -16,7 +16,6 @@ BulkTransactionsClient, TransactionsClient, ) -from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest settings = SqlalchemySettings() session = Session.create_from_settings(settings) @@ -29,28 +28,16 @@ TokenPaginationExtension(), ] -GET_REQUEST_MODEL = create_request_model( - "SearchGetRequest", - base_model=BaseSearchGetRequest, - extensions=extensions, - request_type="GET", -) - -POST_REQUEST_MODEL = create_request_model( - "SearchPostRequest", - base_model=BaseSearchPostRequest, - extensions=extensions, - request_type="POST", -) +post_request_model = create_post_request_model(extensions) api = StacApi( settings=settings, extensions=extensions, client=CoreCrudClient( - session=session, extensions=extensions, post_request_model=POST_REQUEST_MODEL + session=session, extensions=extensions, post_request_model=post_request_model ), - search_get_request_model=GET_REQUEST_MODEL, - search_post_request_model=POST_REQUEST_MODEL, + search_get_request_model=create_get_request_model(extensions), + search_post_request_model=post_request_model, ) app = api.app From c1d1fbac22701035be21dc06c5e31cfab85dd39b Mon Sep 17 00:00:00 2001 From: rsmith013 Date: Thu, 11 Nov 2021 14:09:36 +0000 Subject: [PATCH 6/9] working to pass tests for pgstac --- stac_fastapi/api/stac_fastapi/api/models.py | 8 +- .../pgstac/stac_fastapi/pgstac/app.py | 33 +++--- .../pgstac/stac_fastapi/pgstac/core.py | 36 +++--- .../pgstac/extensions/__init__.py | 5 + .../stac_fastapi/pgstac/extensions/query.py | 48 ++++++++ .../stac_fastapi/pgstac/transactions.py | 12 +- .../stac_fastapi/pgstac/types/search.py | 104 +----------------- stac_fastapi/pgstac/tests/conftest.py | 28 +++-- 8 files changed, 129 insertions(+), 145 deletions(-) create mode 100644 stac_fastapi/pgstac/stac_fastapi/pgstac/extensions/__init__.py create mode 100644 stac_fastapi/pgstac/stac_fastapi/pgstac/extensions/query.py diff --git a/stac_fastapi/api/stac_fastapi/api/models.py b/stac_fastapi/api/stac_fastapi/api/models.py index f1a2bba4a..fe527d671 100644 --- a/stac_fastapi/api/stac_fastapi/api/models.py +++ b/stac_fastapi/api/stac_fastapi/api/models.py @@ -72,7 +72,9 @@ def create_request_model( raise TypeError("Mixed Request Model types. Check extension request types.") -def create_get_request_model(extensions): +def create_get_request_model( + extensions, base_model: BaseSearchGetRequest = BaseSearchGetRequest +): """Wrap create_request_model to create the GET request model.""" return create_request_model( "SearchGetRequest", @@ -82,7 +84,9 @@ def create_get_request_model(extensions): ) -def create_post_requst_model(extensions): +def create_post_request_model( + extensions, base_model: BaseSearchPostRequest = BaseSearchGetRequest +): """Wrap create_request_model to create the POST request model.""" return create_request_model( "SearchPostRequest", diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/app.py index 96ab724a0..80eeccedf 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/app.py @@ -2,10 +2,12 @@ from fastapi.responses import ORJSONResponse from stac_fastapi.api.app import StacApi +from stac_fastapi.api.models import create_get_request_model, create_post_request_model +from stac_fastapi.extensions import QueryExtension from stac_fastapi.extensions.core import ( FieldsExtension, - QueryExtension, SortExtension, + TokenPaginationExtension, TransactionExtension, ) from stac_fastapi.pgstac.config import Settings @@ -15,22 +17,27 @@ from stac_fastapi.pgstac.types.search import PgstacSearch settings = Settings() +extensions = [ + TransactionExtension( + client=TransactionsClient(), + settings=settings, + response_class=ORJSONResponse, + ), + QueryExtension(), + SortExtension(), + FieldsExtension(), + TokenPaginationExtension(), +] + +post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) api = StacApi( settings=settings, - extensions=[ - TransactionExtension( - client=TransactionsClient(), - settings=settings, - response_class=ORJSONResponse, - ), - QueryExtension(), - SortExtension(), - FieldsExtension(), - ], - client=CoreCrudClient(), - search_request_model=PgstacSearch, + extensions=extensions, + client=CoreCrudClient(post_request_model=post_request_model), response_class=ORJSONResponse, + search_get_request_model=create_get_request_model(extensions), + search_post_request_model=post_request_model, ) app = api.app diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py index f3a5b7e4f..363c55e34 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py @@ -1,7 +1,7 @@ """Item crud client.""" import re from datetime import datetime -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Union from urllib.parse import urljoin import attr @@ -27,8 +27,6 @@ class CoreCrudClient(AsyncBaseCoreClient): """Client for core endpoints defined by stac.""" - search_request_model: Type[PgstacSearch] = attr.ib(init=False, default=PgstacSearch) - async def all_collections(self, **kwargs) -> Collections: """Read all collections from the database.""" request: Request = kwargs["request"] @@ -71,7 +69,7 @@ async def all_collections(self, **kwargs) -> Collections: collection_list = Collections(collections=linked_collections or [], links=links) return collection_list - async def get_collection(self, id: str, **kwargs) -> Collection: + async def get_collection(self, collectionId: str, **kwargs) -> Collection: """Get collection by id. Called with `GET /collections/{collectionId}`. @@ -91,14 +89,14 @@ async def get_collection(self, id: str, **kwargs) -> Collection: """ SELECT * FROM get_collection(:id::text); """, - id=id, + id=collectionId, ) collection = await conn.fetchval(q, *p) if collection is None: raise NotFoundError(f"Collection {id} does not exist.") collection["links"] = await CollectionLinks( - collection_id=id, request=request + collection_id=collectionId, request=request ).get_links(extra_links=collection.get("links")) return Collection(**collection) @@ -175,7 +173,11 @@ async def _search_base( return collection async def item_collection( - self, id: str, limit: Optional[int] = None, token: str = None, **kwargs + self, + collectionId: str, + limit: Optional[int] = None, + token: str = None, + **kwargs, ) -> ItemCollection: """Get all items from a specific collection. @@ -190,17 +192,19 @@ async def item_collection( An ItemCollection. """ # If collection does not exist, NotFoundError wil be raised - await self.get_collection(id, **kwargs) + await self.get_collection(collectionId, **kwargs) - req = self.search_request_model(collections=[id], limit=limit, token=token) + req = self.post_request_model( + collections=[collectionId], limit=limit, token=token + ) item_collection = await self._search_base(req, **kwargs) links = await CollectionLinks( - collection_id=id, request=kwargs["request"] + collection_id=collectionId, request=kwargs["request"] ).get_links(extra_links=item_collection["links"]) item_collection["links"] = links return item_collection - async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: + async def get_item(self, itemId: str, collectionId: str, **kwargs) -> Item: """Get item by id. Called with `GET /collections/{collectionId}/items/{itemId}`. @@ -212,15 +216,13 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: Item. """ # If collection does not exist, NotFoundError wil be raised - await self.get_collection(collection_id, **kwargs) + await self.get_collection(collectionId, **kwargs) - req = self.search_request_model( - ids=[item_id], collections=[collection_id], limit=1 - ) + req = self.post_request_model(ids=[itemId], collections=[collectionId], limit=1) item_collection = await self._search_base(req, **kwargs) if not item_collection["features"]: raise NotFoundError( - f"Item {item_id} in Collection {collection_id} does not exist." + f"Item {itemId} in Collection {collectionId} does not exist." ) return Item(**item_collection["features"][0]) @@ -301,7 +303,7 @@ async def get_search( # Do the request try: - search_request = self.search_request_model(**base_args) + search_request = self.post_request_model(**base_args) except ValidationError: raise HTTPException(status_code=400, detail="Invalid parameters provided") return await self.post_search(search_request, request=kwargs["request"]) diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/extensions/__init__.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/extensions/__init__.py new file mode 100644 index 000000000..410bc63f1 --- /dev/null +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/extensions/__init__.py @@ -0,0 +1,5 @@ +"""pgstac extension customisations.""" + +from .query import QueryExtension + +__all__ = ["QueryExtension"] diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/extensions/query.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/extensions/query.py new file mode 100644 index 000000000..91df8539d --- /dev/null +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/extensions/query.py @@ -0,0 +1,48 @@ +"""Pgstac query customisation.""" + +import operator +from enum import auto +from types import DynamicClassAttribute +from typing import Any, Callable, Dict, Optional + +from pydantic import BaseModel +from stac_pydantic.utils import AutoValueEnum + +from stac_fastapi.extensions.core.query import QueryExtension as QueryExtensionBase + + +class Operator(str, AutoValueEnum): + """Defines the set of operators supported by the API.""" + + eq = auto() + ne = auto() + lt = auto() + lte = auto() + gt = auto() + gte = auto() + # TODO: These are defined in the spec but aren't currently implemented by the api + # startsWith = auto() + # endsWith = auto() + # contains = auto() + # in = auto() + + @DynamicClassAttribute + def operator(self) -> Callable[[Any, Any], bool]: + """Return python operator.""" + return getattr(operator, self._value_) + + +class QueryExtensionPostRequest(BaseModel): + """Query Extension POST request model.""" + + query: Optional[Dict[str, Dict[Operator, Any]]] + + +class QueryExtension(QueryExtensionBase): + """Query Extension. + + Override the POST request model to add validation against + supported fields + """ + + POST = QueryExtensionPostRequest diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py index ca835a7bd..539b4f302 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py @@ -49,16 +49,16 @@ async def update_collection( await dbfunc(pool, "update_collection", collection) return collection - async def delete_item(self, item_id: str, collection_id: str, **kwargs) -> Dict: + async def delete_item(self, itemId: str, collectionId: str, **kwargs) -> Dict: """Delete collection.""" request = kwargs["request"] pool = request.app.state.writepool - await dbfunc(pool, "delete_item", item_id) - return {"deleted item": item_id} + await dbfunc(pool, "delete_item", itemId) + return {"deleted item": itemId} - async def delete_collection(self, id: str, **kwargs) -> Dict: + async def delete_collection(self, collectionId: str, **kwargs) -> Dict: """Delete collection.""" request = kwargs["request"] pool = request.app.state.writepool - await dbfunc(pool, "delete_collection", id) - return {"deleted collection": id} + await dbfunc(pool, "delete_collection", collectionId) + return {"deleted collection": collectionId} diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/types/search.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/types/search.py index b32fbfa54..a13126f1e 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/types/search.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/types/search.py @@ -1,111 +1,19 @@ """stac_fastapi.types.search module.""" -import operator -from enum import auto -from types import DynamicClassAttribute -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Optional -from pydantic import Field, conint, root_validator, validator -from stac_pydantic.api import Search -from stac_pydantic.api.extensions.fields import FieldsExtension as FieldsBase -from stac_pydantic.utils import AutoValueEnum +from pydantic import validator -from stac_fastapi.types.config import Settings +from stac_fastapi.types.search import BaseSearchPostRequest -# Be careful: https://github.com/samuelcolvin/pydantic/issues/1423#issuecomment-642797287 -NumType = Union[float, int] +class PgstacSearch(BaseSearchPostRequest): + """Search model. -class Operator(str, AutoValueEnum): - """Defines the set of operators supported by the API.""" - - eq = auto() - ne = auto() - lt = auto() - lte = auto() - gt = auto() - gte = auto() - # TODO: These are defined in the spec but aren't currently implemented by the api - # startsWith = auto() - # endsWith = auto() - # contains = auto() - # in = auto() - - @DynamicClassAttribute - def operator(self) -> Callable[[Any, Any], bool]: - """Return python operator.""" - return getattr(operator, self._value_) - - -class FieldsExtension(FieldsBase): - """FieldsExtension. - - Attributes: - include: set of fields to include. - exclude: set of fields to exclude. + Overrides the validation for datetime from the base request model. """ - include: Optional[Set[str]] = set() - exclude: Optional[Set[str]] = set() - - @staticmethod - def _get_field_dict(fields: Optional[Set[str]]) -> Dict: - """Pydantic include/excludes notation. - - Internal method to create a dictionary for advanced include or exclude of pydantic fields on model export - Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude - """ - field_dict = {} - for field in fields or []: - if "." in field: - parent, key = field.split(".") - if parent not in field_dict: - field_dict[parent] = {key} - else: - field_dict[parent].add(key) - else: - field_dict[field] = ... # type:ignore - return field_dict - - @property - def filter_fields(self) -> Dict: - """Create pydantic include/exclude expression. - - Create dictionary of fields to include/exclude on model export based on the included and excluded fields passed - to the API - Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude - """ - # Always include default_includes, even if they - # exist in the exclude list. - include = (self.include or set()) - (self.exclude or set()) - include |= Settings.get().default_includes or set() - - return { - "include": self._get_field_dict(include), - "exclude": self._get_field_dict(self.exclude), - } - - -class PgstacSearch(Search): - """Search model.""" - - # Make collections optional, default to searching all collections if none are provided - collections: Optional[List[str]] = None - ids: Optional[List[str]] = None - # Override default field extension to include default fields and pydantic includes/excludes factory - fields: FieldsExtension = Field(FieldsExtension()) - # Override query extension with supported operators - query: Optional[Dict[str, Dict[Operator, Any]]] - filter: Optional[Dict] - token: Optional[str] = None datetime: Optional[str] = None - sortby: Any - limit: Optional[conint(ge=0, le=10000)] = 10 - - @root_validator(pre=True) - def validate_query_fields(cls, values: Dict) -> Dict: - """Pgstac does not require the base validator for query fields.""" - return values @validator("datetime") def validate_datetime(cls, v): diff --git a/stac_fastapi/pgstac/tests/conftest.py b/stac_fastapi/pgstac/tests/conftest.py index a82af074a..12f8274f2 100644 --- a/stac_fastapi/pgstac/tests/conftest.py +++ b/stac_fastapi/pgstac/tests/conftest.py @@ -12,15 +12,18 @@ from stac_pydantic import Collection, Item from stac_fastapi.api.app import StacApi +from stac_fastapi.api.models import create_get_request_model, create_post_request_model from stac_fastapi.extensions.core import ( FieldsExtension, - QueryExtension, + FilterExtension, SortExtension, + TokenPaginationExtension, TransactionExtension, ) from stac_fastapi.pgstac.config import Settings from stac_fastapi.pgstac.core import CoreCrudClient from stac_fastapi.pgstac.db import close_db_connection, connect_to_db +from stac_fastapi.pgstac.extensions import QueryExtension from stac_fastapi.pgstac.transactions import TransactionsClient from stac_fastapi.pgstac.types.search import PgstacSearch @@ -82,16 +85,23 @@ async def pgstac(pg): @pytest.fixture(scope="session") def api_client(pg): print("creating client with settings") + + extensions = [ + TransactionExtension(client=TransactionsClient(), settings=settings), + QueryExtension(), + FilterExtension(), + SortExtension(), + FieldsExtension(), + TokenPaginationExtension(), + ] + post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) + api = StacApi( settings=settings, - extensions=[ - TransactionExtension(client=TransactionsClient(), settings=settings), - QueryExtension(), - SortExtension(), - FieldsExtension(), - ], - client=CoreCrudClient(), - search_request_model=PgstacSearch, + extensions=extensions, + client=CoreCrudClient(post_request_model=post_request_model), + search_get_request_model=create_get_request_model(extensions), + search_post_request_model=post_request_model, response_class=ORJSONResponse, ) From c94f5f3db7b91fb819dcc57a1cbbfbb101fd327c Mon Sep 17 00:00:00 2001 From: rsmith013 Date: Fri, 26 Nov 2021 17:05:34 +0000 Subject: [PATCH 7/9] changed constraints on limit parameter to pass test --- stac_fastapi/types/stac_fastapi/types/search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stac_fastapi/types/stac_fastapi/types/search.py b/stac_fastapi/types/stac_fastapi/types/search.py index ac6c5a91c..3ef9a80c1 100644 --- a/stac_fastapi/types/stac_fastapi/types/search.py +++ b/stac_fastapi/types/stac_fastapi/types/search.py @@ -95,7 +95,7 @@ class BaseSearchPostRequest(BaseModel): Union[Point, MultiPoint, LineString, MultiLineString, Polygon, MultiPolygon] ] datetime: Optional[str] - limit: Optional[conint(ge=0, le=10000)] = 10 + limit: Optional[conint(gt=0, le=10000)] = 10 @property def start_date(self) -> Optional[datetime]: From 512538842762e7dc16b51d3b6bddf8cb1ecc7090 Mon Sep 17 00:00:00 2001 From: rsmith013 Date: Thu, 2 Dec 2021 13:11:30 +0000 Subject: [PATCH 8/9] fixing against pgstac tests --- .../pgstac/stac_fastapi/pgstac/core.py | 24 ++++++++++--------- .../stac_fastapi/pgstac/transactions.py | 12 +++++----- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py index 1b028bd9a..c0ba8e183 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py @@ -69,7 +69,7 @@ async def all_collections(self, **kwargs) -> Collections: collection_list = Collections(collections=linked_collections or [], links=links) return collection_list - async def get_collection(self, collectionId: str, **kwargs) -> Collection: + async def get_collection(self, collection_id: str, **kwargs) -> Collection: """Get collection by id. Called with `GET /collections/{collection_id}`. @@ -89,14 +89,14 @@ async def get_collection(self, collectionId: str, **kwargs) -> Collection: """ SELECT * FROM get_collection(:id::text); """, - id=collectionId, + id=collection_id, ) collection = await conn.fetchval(q, *p) if collection is None: raise NotFoundError(f"Collection {id} does not exist.") collection["links"] = await CollectionLinks( - collection_id=collectionId, request=request + collection_id=collection_id, request=request ).get_links(extra_links=collection.get("links")) return Collection(**collection) @@ -174,7 +174,7 @@ async def _search_base( async def item_collection( self, - collectionId: str, + collection_id: str, limit: Optional[int] = None, token: str = None, **kwargs, @@ -192,19 +192,19 @@ async def item_collection( An ItemCollection. """ # If collection does not exist, NotFoundError wil be raised - await self.get_collection(collectionId, **kwargs) + await self.get_collection(collection_id, **kwargs) req = self.post_request_model( - collections=[collectionId], limit=limit, token=token + collections=[collection_id], limit=limit, token=token ) item_collection = await self._search_base(req, **kwargs) links = await CollectionLinks( - collection_id=collectionId, request=kwargs["request"] + collection_id=collection_id, request=kwargs["request"] ).get_links(extra_links=item_collection["links"]) item_collection["links"] = links return item_collection - async def get_item(self, itemId: str, collectionId: str, **kwargs) -> Item: + async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: """Get item by id. Called with `GET /collections/{collection_id}/items/{item_id}`. @@ -216,13 +216,15 @@ async def get_item(self, itemId: str, collectionId: str, **kwargs) -> Item: Item. """ # If collection does not exist, NotFoundError wil be raised - await self.get_collection(collectionId, **kwargs) + await self.get_collection(collection_id, **kwargs) - req = self.post_request_model(ids=[itemId], collections=[collectionId], limit=1) + req = self.post_request_model( + ids=[item_id], collections=[collection_id], limit=1 + ) item_collection = await self._search_base(req, **kwargs) if not item_collection["features"]: raise NotFoundError( - f"Item {itemId} in Collection {collectionId} does not exist." + f"Item {item_id} in Collection {collection_id} does not exist." ) return Item(**item_collection["features"][0]) diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py index 539b4f302..4a06a928f 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py @@ -49,16 +49,16 @@ async def update_collection( await dbfunc(pool, "update_collection", collection) return collection - async def delete_item(self, itemId: str, collectionId: str, **kwargs) -> Dict: + async def delete_item(self, item_id: str, collection_id: str, **kwargs) -> Dict: """Delete collection.""" request = kwargs["request"] pool = request.app.state.writepool - await dbfunc(pool, "delete_item", itemId) - return {"deleted item": itemId} + await dbfunc(pool, "delete_item", item_id) + return {"deleted item": item_id} - async def delete_collection(self, collectionId: str, **kwargs) -> Dict: + async def delete_collection(self, collection_id: str, **kwargs) -> Dict: """Delete collection.""" request = kwargs["request"] pool = request.app.state.writepool - await dbfunc(pool, "delete_collection", collectionId) - return {"deleted collection": collectionId} + await dbfunc(pool, "delete_collection", collection_id) + return {"deleted collection": collection_id} From e30e39461ea17e1d2a568caf0c4c69c9401010df Mon Sep 17 00:00:00 2001 From: rsmith013 Date: Thu, 2 Dec 2021 13:17:15 +0000 Subject: [PATCH 9/9] adding changelog entry --- CHANGES.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGES.md b/CHANGES.md index 02e616a60..e409d7d02 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -6,6 +6,9 @@ ### Changed +* Refactor to remove hardcoded search request models. Request models are now dynamically created based on the enabled extensions. + ([#213](https://github.com/stac-utils/stac-fastapi/pull/213)) + ### Removed ### Fixed