Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use from __future__ import annotations #962

Merged
merged 4 commits into from
Feb 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/continuous-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ on:
- "1.0"
- "2.0"
pull_request:
merge_group:

jobs:
test:
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
### Changed

- Switch to pytest ([#939](https://github.com/stac-utils/pystac/pull/939))
- Use `from __future__ import annotations` for type signatures ([#962](https://github.com/stac-utils/pystac/pull/962))

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion pystac/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
]

import os
from typing import Any, AnyStr, Dict, Optional, Union
from typing import Any, Dict, Optional
gadomski marked this conversation as resolved.
Show resolved Hide resolved

from pystac.errors import (
STACError,
Expand Down
18 changes: 10 additions & 8 deletions pystac/asset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from copy import copy, deepcopy
from html import escape
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
Expand All @@ -6,9 +8,9 @@
from pystac.html.jinja_env import get_jinja_env

if TYPE_CHECKING:
from pystac.collection import Collection as Collection_Type
from pystac.common_metadata import CommonMetadata as CommonMetadata_Type
from pystac.item import Item as Item_Type
from pystac.collection import Collection
from pystac.common_metadata import CommonMetadata
from pystac.item import Item


class Asset:
Expand Down Expand Up @@ -50,7 +52,7 @@ class Asset:
"""Optional, Semantic roles (i.e. thumbnail, overview, data, metadata) of the
asset."""

owner: Optional[Union["Item_Type", "Collection_Type"]]
owner: Optional[Union[Item, Collection]]
"""The :class:`~pystac.Item` or :class:`~pystac.Collection` that this asset belongs
to, or ``None`` if it has no owner."""

Expand All @@ -77,7 +79,7 @@ def __init__(
# The Item which owns this Asset.
self.owner = None

def set_owner(self, obj: Union["Collection_Type", "Item_Type"]) -> None:
def set_owner(self, obj: Union[Collection, Item]) -> None:
"""Sets the owning item of this Asset.

The owning item will be used to resolve relative HREFs of this asset.
Expand Down Expand Up @@ -134,7 +136,7 @@ def to_dict(self) -> Dict[str, Any]:

return d

def clone(self) -> "Asset":
def clone(self) -> Asset:
"""Clones this asset. Makes a ``deepcopy`` of the
:attr:`~pystac.Asset.extra_fields`.

Expand Down Expand Up @@ -166,7 +168,7 @@ def has_role(self, role: str) -> bool:
return role in self.roles

@property
def common_metadata(self) -> "CommonMetadata_Type":
def common_metadata(self) -> CommonMetadata:
"""Access the asset's common metadata fields as a
:class:`~pystac.CommonMetadata` object."""
return common_metadata.CommonMetadata(self)
Expand All @@ -183,7 +185,7 @@ def _repr_html_(self) -> str:
return escape(repr(self))

@classmethod
def from_dict(cls, d: Dict[str, Any]) -> "Asset":
def from_dict(cls, d: Dict[str, Any]) -> Asset:
"""Constructs an Asset from a dict.

Returns:
Expand Down
76 changes: 33 additions & 43 deletions pystac/cache.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from __future__ import annotations

from collections import ChainMap
from copy import copy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast

import pystac

if TYPE_CHECKING:
from pystac.collection import Collection as Collection_Type
from pystac.stac_object import STACObject as STACObject_Type
from pystac.collection import Collection
from pystac.stac_object import STACObject


def get_cache_key(stac_object: "STACObject_Type") -> Tuple[str, bool]:
def get_cache_key(stac_object: STACObject) -> Tuple[str, bool]:
"""Produce a cache key for the given STAC object.

If a self href is set, use that as the cache key.
Expand Down Expand Up @@ -59,31 +61,31 @@ class ResolvedObjectCache:
to collections.
"""

id_keys_to_objects: Dict[str, "STACObject_Type"]
id_keys_to_objects: Dict[str, STACObject]
"""Existing cache of a key made up of the STACObject and it's parents IDs mapped
to the cached STACObject."""

hrefs_to_objects: Dict[str, "STACObject_Type"]
hrefs_to_objects: Dict[str, STACObject]
"""STAC Object HREFs matched to their cached object."""

ids_to_collections: Dict[str, "Collection_Type"]
ids_to_collections: Dict[str, Collection]
"""Map of collection IDs to collections."""

_collection_cache: Optional["ResolvedObjectCollectionCache"]

def __init__(
self,
id_keys_to_objects: Optional[Dict[str, "STACObject_Type"]] = None,
hrefs_to_objects: Optional[Dict[str, "STACObject_Type"]] = None,
ids_to_collections: Optional[Dict[str, "Collection_Type"]] = None,
id_keys_to_objects: Optional[Dict[str, STACObject]] = None,
hrefs_to_objects: Optional[Dict[str, STACObject]] = None,
ids_to_collections: Optional[Dict[str, Collection]] = None,
):
self.id_keys_to_objects = id_keys_to_objects or {}
self.hrefs_to_objects = hrefs_to_objects or {}
self.ids_to_collections = ids_to_collections or {}

self._collection_cache = None

def get_or_cache(self, obj: "STACObject_Type") -> "STACObject_Type":
def get_or_cache(self, obj: STACObject) -> STACObject:
"""Gets the STACObject that is the cached version of the given STACObject; or,
if none exists, sets the cached object to the given object.

Expand All @@ -109,7 +111,7 @@ def get_or_cache(self, obj: "STACObject_Type") -> "STACObject_Type":
self.cache(obj)
return obj

def get(self, obj: "STACObject_Type") -> Optional["STACObject_Type"]:
def get(self, obj: STACObject) -> Optional[STACObject]:
"""Get the cached object that has the same cache key as the given object.

Args:
Expand All @@ -126,7 +128,7 @@ def get(self, obj: "STACObject_Type") -> Optional["STACObject_Type"]:
else:
return self.id_keys_to_objects.get(key)

def get_by_href(self, href: str) -> Optional["STACObject_Type"]:
def get_by_href(self, href: str) -> Optional[STACObject]:
"""Gets the cached object at href.

Args:
Expand All @@ -137,7 +139,7 @@ def get_by_href(self, href: str) -> Optional["STACObject_Type"]:
"""
return self.hrefs_to_objects.get(href)

def get_collection_by_id(self, id: str) -> Optional["Collection_Type"]:
def get_collection_by_id(self, id: str) -> Optional[Collection]:
"""Retrieved a cached Collection by its ID.

Args:
Expand All @@ -149,7 +151,7 @@ def get_collection_by_id(self, id: str) -> Optional["Collection_Type"]:
"""
return self.ids_to_collections.get(id)

def cache(self, obj: "STACObject_Type") -> None:
def cache(self, obj: STACObject) -> None:
"""Set the given object into the cache.

Args:
Expand All @@ -164,7 +166,7 @@ def cache(self, obj: "STACObject_Type") -> None:
if isinstance(obj, pystac.Collection):
self.ids_to_collections[obj.id] = obj

def remove(self, obj: "STACObject_Type") -> None:
def remove(self, obj: STACObject) -> None:
"""Removes any cached object that matches the given object's cache key.

Args:
Expand All @@ -180,7 +182,7 @@ def remove(self, obj: "STACObject_Type") -> None:
if obj.STAC_OBJECT_TYPE == pystac.STACObjectType.COLLECTION:
self.id_keys_to_objects.pop(obj.id, None)

def __contains__(self, obj: "STACObject_Type") -> bool:
def __contains__(self, obj: STACObject) -> bool:
key, is_href = get_cache_key(obj)
return (
key in self.hrefs_to_objects if is_href else key in self.id_keys_to_objects
Expand All @@ -190,15 +192,15 @@ def contains_collection_id(self, collection_id: str) -> bool:
"""Returns True if there is a collection with given collection ID is cached."""
return collection_id in self.ids_to_collections

def as_collection_cache(self) -> "CollectionCache":
def as_collection_cache(self) -> CollectionCache:
if self._collection_cache is None:
self._collection_cache = ResolvedObjectCollectionCache(self)
return self._collection_cache

@staticmethod
def merge(
first: "ResolvedObjectCache", second: "ResolvedObjectCache"
) -> "ResolvedObjectCache":
) -> ResolvedObjectCache:
"""Merges two ResolvedObjectCache.

The merged cache will give preference to the first argument; that is, if there
Expand Down Expand Up @@ -245,37 +247,31 @@ class CollectionCache:
in common properties.
"""

cached_ids: Dict[str, Union["Collection_Type", Dict[str, Any]]]
cached_hrefs: Dict[str, Union["Collection_Type", Dict[str, Any]]]
cached_ids: Dict[str, Union[Collection, Dict[str, Any]]]
cached_hrefs: Dict[str, Union[Collection, Dict[str, Any]]]

def __init__(
self,
cached_ids: Optional[
Dict[str, Union["Collection_Type", Dict[str, Any]]]
] = None,
cached_hrefs: Optional[
Dict[str, Union["Collection_Type", Dict[str, Any]]]
] = None,
cached_ids: Optional[Dict[str, Union[Collection, Dict[str, Any]]]] = None,
cached_hrefs: Optional[Dict[str, Union[Collection, Dict[str, Any]]]] = None,
):
self.cached_ids = cached_ids or {}
self.cached_hrefs = cached_hrefs or {}

def get_by_id(
self, collection_id: str
) -> Optional[Union["Collection_Type", Dict[str, Any]]]:
) -> Optional[Union[Collection, Dict[str, Any]]]:
return self.cached_ids.get(collection_id)

def get_by_href(
self, href: str
) -> Optional[Union["Collection_Type", Dict[str, Any]]]:
def get_by_href(self, href: str) -> Optional[Union[Collection, Dict[str, Any]]]:
return self.cached_hrefs.get(href)

def contains_id(self, collection_id: str) -> bool:
return collection_id in self.cached_ids

def cache(
self,
collection: Union["Collection_Type", Dict[str, Any]],
collection: Union[Collection, Dict[str, Any]],
href: Optional[str] = None,
) -> None:
"""Caches a collection JSON."""
Expand All @@ -294,28 +290,22 @@ class ResolvedObjectCollectionCache(CollectionCache):
def __init__(
self,
resolved_object_cache: ResolvedObjectCache,
cached_ids: Optional[
Dict[str, Union["Collection_Type", Dict[str, Any]]]
] = None,
cached_hrefs: Optional[
Dict[str, Union["Collection_Type", Dict[str, Any]]]
] = None,
cached_ids: Optional[Dict[str, Union[Collection, Dict[str, Any]]]] = None,
cached_hrefs: Optional[Dict[str, Union[Collection, Dict[str, Any]]]] = None,
):
super().__init__(cached_ids, cached_hrefs)
self.resolved_object_cache = resolved_object_cache

def get_by_id(
self, collection_id: str
) -> Optional[Union["Collection_Type", Dict[str, Any]]]:
) -> Optional[Union[Collection, Dict[str, Any]]]:
result = self.resolved_object_cache.get_collection_by_id(collection_id)
if result is None:
return super().get_by_id(collection_id)
else:
return result

def get_by_href(
self, href: str
) -> Optional[Union["Collection_Type", Dict[str, Any]]]:
def get_by_href(self, href: str) -> Optional[Union[Collection, Dict[str, Any]]]:
result = self.resolved_object_cache.get_by_href(href)
if result is None:
return super().get_by_href(href)
Expand All @@ -329,7 +319,7 @@ def contains_id(self, collection_id: str) -> bool:

def cache(
self,
collection: Union["Collection_Type", Dict[str, Any]],
collection: Union[Collection, Dict[str, Any]],
href: Optional[str] = None,
) -> None:
super().cache(collection, href)
Expand All @@ -339,7 +329,7 @@ def merge(
resolved_object_cache: ResolvedObjectCache,
first: Optional["ResolvedObjectCollectionCache"],
second: Optional["ResolvedObjectCollectionCache"],
) -> "ResolvedObjectCollectionCache":
) -> ResolvedObjectCollectionCache:
first_cached_ids = {}
if first is not None:
first_cached_ids = copy(first.cached_ids)
Expand Down
Loading