Skip to content

Commit

Permalink
Type hint tweaks and refactors from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Johan Bergsma <[email protected]>
  • Loading branch information
ml-evs and JPBergsma committed Nov 28, 2022
1 parent 089eff5 commit b9185cc
Show file tree
Hide file tree
Showing 10 changed files with 35 additions and 38 deletions.
11 changes: 6 additions & 5 deletions optimade/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class OptimadeClient:

def __init__(
self,
base_urls: Optional[Union[str, List[str]]] = None,
base_urls: Optional[Union[str, Iterable[str]]] = None,
max_results_per_provider: int = 1000,
headers: Optional[Dict] = None,
http_timeout: int = 10,
Expand All @@ -111,14 +111,15 @@ def __init__(
"""

if not base_urls:
base_urls = get_all_databases() # type: ignore[assignment]

self.max_results_per_provider = max_results_per_provider
if self.max_results_per_provider in (-1, 0):
self.max_results_per_provider = None

self.base_urls = base_urls # type: ignore[assignment]
if not base_urls:
self.base_urls = get_all_databases()
else:
self.base_urls = base_urls

if isinstance(self.base_urls, str):
self.base_urls = [self.base_urls]
self.base_urls = list(self.base_urls)
Expand Down
17 changes: 9 additions & 8 deletions optimade/filterparser/lark_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""

from collections import defaultdict
from pathlib import Path
from typing import Dict, Optional, Tuple

Expand All @@ -21,7 +20,7 @@ class ParserError(Exception):
"""


def get_versions() -> Dict[Tuple[int, int, int], Dict[str, str]]:
def get_versions() -> Dict[Tuple[int, int, int], Dict[str, Path]]:
"""Find grammar files within this package's grammar directory,
returning a dictionary broken down by scraped grammar version
(major, minor, patch) and variant (a string tag).
Expand All @@ -30,13 +29,15 @@ def get_versions() -> Dict[Tuple[int, int, int], Dict[str, str]]:
A mapping from version, variant to grammar file name.
"""
dct: Dict[Tuple[int, int, int], Dict[str, Path]] = defaultdict(dict)
dct: Dict[Tuple[int, int, int], Dict[str, Path]] = {}
for filename in Path(__file__).parent.joinpath("../grammar").glob("*.lark"):
tags = filename.stem.lstrip("v").split(".")
version = tuple(map(int, tags[:3])) # ignore: type[index]
variant = "default" if len(tags) == 3 else tags[-1]
dct[version][variant] = filename # type: ignore[index]
return dict(dct) # type: ignore[arg-type]
version: Tuple[int, int, int] = (int(tags[0]), int(tags[1]), int(tags[2]))
variant: str = "default" if len(tags) == 3 else str(tags[-1])
if version not in dct:
dct[version] = {}
dct[version][variant] = filename
return dct


AVAILABLE_PARSERS = get_versions()
Expand Down Expand Up @@ -98,7 +99,7 @@ def parse(self, filter_: str) -> Tree:
"""
try:
self.tree = self.lark.parse(filter_)
self.filter = filter_ # type: ignore[assignment]
self.filter = filter_
return self.tree
except Exception as exc:
raise BadRequest(
Expand Down
4 changes: 2 additions & 2 deletions optimade/models/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
EPS = 2**-23


Vector3D = conlist(float, min_items=3, max_items=3) # type: ignore[valid-type]
Vector3D_unknown = conlist(Union[float, None], min_items=3, max_items=3) # type: ignore[valid-type]
Vector3D = conlist(float, min_items=3, max_items=3)
Vector3D_unknown = conlist(Union[float, None], min_items=3, max_items=3)


class Periodicity(IntEnum):
Expand Down
4 changes: 2 additions & 2 deletions optimade/server/entry_collections/entry_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
self.provider_prefix = CONFIG.provider.prefix
self.provider_fields = [
field if isinstance(field, str) else field["name"]
for field in CONFIG.provider_fields.get(resource_mapper.ENDPOINT, []) # type: ignore[call-overload]
for field in CONFIG.provider_fields.get(resource_mapper.ENDPOINT, [])
]

self._all_fields: Set[str] = set()
Expand Down Expand Up @@ -376,7 +376,7 @@ def parse_sort_params(self, sort_params: str) -> Iterable[Tuple[str, int]]:
BadRequest: if an invalid sort is requested.
Returns:
A tuple of tuples containing the aliased field name and
A list of tuples containing the aliased field name and
sort direction encoded as 1 (ascending) or -1 (descending).
"""
Expand Down
2 changes: 1 addition & 1 deletion optimade/server/mappers/entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def all_length_aliases(cls) -> Tuple[Tuple[str, str], ...]:
from optimade.server.config import CONFIG

return cls.LENGTH_ALIASES + tuple(
CONFIG.length_aliases.get(cls.ENDPOINT, {}).items() # type: ignore[call-overload]
CONFIG.length_aliases.get(cls.ENDPOINT, {}).items()
)

@classmethod
Expand Down
16 changes: 7 additions & 9 deletions optimade/server/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,6 @@ def handle_api_hint(api_hint: List[str]) -> Union[None, str]:
major_api_hint = int(re.findall(r"/v([0-9]+)", api_hint_str)[0])
major_implementation = int(BASE_URL_PREFIXES["major"][len("/v") :])

if major_api_hint > major_implementation:
# Let's not try to handle a request for a newer major version
raise VersionNotSupported(
detail=(
f"The provided `api_hint` ({api_hint_str[1:]!r}) is not supported by this implementation. "
f"Supported versions include: {', '.join(BASE_URL_PREFIXES.values())}"
)
)
if major_api_hint <= major_implementation:
# If less than:
# Use the current implementation in hope that it can still handle older requests
Expand All @@ -192,7 +184,13 @@ def handle_api_hint(api_hint: List[str]) -> Union[None, str]:
# Go to /v<MAJOR>, since this should point to the latest available
return BASE_URL_PREFIXES["major"]

return None
# Let's not try to handle a request for a newer major version
raise VersionNotSupported(
detail=(
f"The provided `api_hint` ({api_hint_str[1:]!r}) is not supported by this implementation. "
f"Supported versions include: {', '.join(BASE_URL_PREFIXES.values())}"
)
)

@staticmethod
def is_versioned_base_url(url: str) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion optimade/server/routers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def handle_response_fields(

new_results = []
while results:
new_entry = results.pop(0).dict(exclude_unset=True, by_alias=True) # type: ignore[union-attr]
new_entry = results.pop(0).dict(exclude_unset=True, by_alias=True)

# Remove fields excluded by their omission in `response_fields`
for field in exclude_fields:
Expand Down
13 changes: 6 additions & 7 deletions optimade/validator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,13 @@ def get(self, request: str):
while retries < self.max_retries:
retries += 1
try:
self.response = requests.get( # type: ignore[assignment]
self.last_request, # type: ignore[arg-type]
self.response = requests.get(
self.last_request,
headers=self.headers,
timeout=(self.timeout, self.read_timeout),
)

status_code = self.response.status_code # type: ignore[attr-defined]
status_code = self.response.status_code
# If we hit a 429 Too Many Requests status, then try again in 1 second
if status_code != 429:
return self.response
Expand Down Expand Up @@ -369,19 +369,18 @@ def wrapper(
if not isinstance(result, ValidationError):
message += traceback.split("\n")

message = "\n".join(message) # type: ignore[assignment]

failure_type = None
if isinstance(result, InternalError):
summary = (
f"{request} - {test_fn.__name__} - failed with internal error"
)
failure_type = "internal"
else:
summary = f"{request} - {test_fn.__name__} - failed with error"
failure_type = "optional" if optional else None # type: ignore[assignment]
failure_type = "optional" if optional else None

validator.results.add_failure(
summary, message, failure_type=failure_type
summary, "\n".join(message), failure_type=failure_type
)

# set failure result to None as this is expected by other functions
Expand Down
2 changes: 0 additions & 2 deletions tests/server/test_server_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def test_versioned_base_urls(client, index_client, server: str):
This depends on the routers for each kind of server.
"""
import json

from optimade.server.routers.utils import BASE_URL_PREFIXES

Expand Down Expand Up @@ -129,7 +128,6 @@ def test_meta_schema_value_obeys_index(client, index_client, server: str):
"""Test that the reported `meta->schema` is correct for index/non-index
servers.
"""
import json

from optimade.server.config import CONFIG
from optimade.server.routers.utils import BASE_URL_PREFIXES
Expand Down
2 changes: 1 addition & 1 deletion tests/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
version = f"/v{__api_version__.split('.')[0]}"
self.version = version

def request( # type: ignore[override] # pylint: disable=too-many-locals
def request( # pylint: disable=too-many-locals
self,
method: str,
url: str,
Expand Down

0 comments on commit b9185cc

Please sign in to comment.