From b9185cc79c03039df916102491feeba5aface385 Mon Sep 17 00:00:00 2001 From: Matthew Evans Date: Fri, 18 Nov 2022 11:01:45 +0000 Subject: [PATCH] Type hint tweaks and refactors from code review Co-authored-by: Johan Bergsma --- optimade/client/client.py | 11 ++++++----- optimade/filterparser/lark_parser.py | 17 +++++++++-------- optimade/models/structures.py | 4 ++-- .../entry_collections/entry_collections.py | 4 ++-- optimade/server/mappers/entries.py | 2 +- optimade/server/middleware.py | 16 +++++++--------- optimade/server/routers/utils.py | 2 +- optimade/validator/utils.py | 13 ++++++------- tests/server/test_server_validation.py | 2 -- tests/server/utils.py | 2 +- 10 files changed, 35 insertions(+), 38 deletions(-) diff --git a/optimade/client/client.py b/optimade/client/client.py index 1a8f4cb9b..c173c75db 100644 --- a/optimade/client/client.py +++ b/optimade/client/client.py @@ -91,7 +91,7 @@ class OptimadeClient: def __init__( self, - base_urls: Optional[Union[str, List[str]]] = None, + base_urls: Optional[Union[str, Iterable[str]]] = None, max_results_per_provider: int = 1000, headers: Optional[Dict] = None, http_timeout: int = 10, @@ -111,14 +111,15 @@ def __init__( """ - if not base_urls: - base_urls = get_all_databases() # type: ignore[assignment] - self.max_results_per_provider = max_results_per_provider if self.max_results_per_provider in (-1, 0): self.max_results_per_provider = None - self.base_urls = base_urls # type: ignore[assignment] + if not base_urls: + self.base_urls = get_all_databases() + else: + self.base_urls = base_urls + if isinstance(self.base_urls, str): self.base_urls = [self.base_urls] self.base_urls = list(self.base_urls) diff --git a/optimade/filterparser/lark_parser.py b/optimade/filterparser/lark_parser.py index 21e0b89fc..775f1a315 100644 --- a/optimade/filterparser/lark_parser.py +++ b/optimade/filterparser/lark_parser.py @@ -4,7 +4,6 @@ """ -from collections import defaultdict from pathlib import Path from typing import Dict, Optional, Tuple @@ -21,7 +20,7 @@ class ParserError(Exception): """ -def get_versions() -> Dict[Tuple[int, int, int], Dict[str, str]]: +def get_versions() -> Dict[Tuple[int, int, int], Dict[str, Path]]: """Find grammar files within this package's grammar directory, returning a dictionary broken down by scraped grammar version (major, minor, patch) and variant (a string tag). @@ -30,13 +29,15 @@ def get_versions() -> Dict[Tuple[int, int, int], Dict[str, str]]: A mapping from version, variant to grammar file name. """ - dct: Dict[Tuple[int, int, int], Dict[str, Path]] = defaultdict(dict) + dct: Dict[Tuple[int, int, int], Dict[str, Path]] = {} for filename in Path(__file__).parent.joinpath("../grammar").glob("*.lark"): tags = filename.stem.lstrip("v").split(".") - version = tuple(map(int, tags[:3])) # ignore: type[index] - variant = "default" if len(tags) == 3 else tags[-1] - dct[version][variant] = filename # type: ignore[index] - return dict(dct) # type: ignore[arg-type] + version: Tuple[int, int, int] = (int(tags[0]), int(tags[1]), int(tags[2])) + variant: str = "default" if len(tags) == 3 else str(tags[-1]) + if version not in dct: + dct[version] = {} + dct[version][variant] = filename + return dct AVAILABLE_PARSERS = get_versions() @@ -98,7 +99,7 @@ def parse(self, filter_: str) -> Tree: """ try: self.tree = self.lark.parse(filter_) - self.filter = filter_ # type: ignore[assignment] + self.filter = filter_ return self.tree except Exception as exc: raise BadRequest( diff --git a/optimade/models/structures.py b/optimade/models/structures.py index a9daca7f1..14262ecef 100644 --- a/optimade/models/structures.py +++ b/optimade/models/structures.py @@ -39,8 +39,8 @@ EPS = 2**-23 -Vector3D = conlist(float, min_items=3, max_items=3) # type: ignore[valid-type] -Vector3D_unknown = conlist(Union[float, None], min_items=3, max_items=3) # type: ignore[valid-type] +Vector3D = conlist(float, min_items=3, max_items=3) +Vector3D_unknown = conlist(Union[float, None], min_items=3, max_items=3) class Periodicity(IntEnum): diff --git a/optimade/server/entry_collections/entry_collections.py b/optimade/server/entry_collections/entry_collections.py index 7236a8dd8..de5860a31 100644 --- a/optimade/server/entry_collections/entry_collections.py +++ b/optimade/server/entry_collections/entry_collections.py @@ -91,7 +91,7 @@ def __init__( self.provider_prefix = CONFIG.provider.prefix self.provider_fields = [ field if isinstance(field, str) else field["name"] - for field in CONFIG.provider_fields.get(resource_mapper.ENDPOINT, []) # type: ignore[call-overload] + for field in CONFIG.provider_fields.get(resource_mapper.ENDPOINT, []) ] self._all_fields: Set[str] = set() @@ -376,7 +376,7 @@ def parse_sort_params(self, sort_params: str) -> Iterable[Tuple[str, int]]: BadRequest: if an invalid sort is requested. Returns: - A tuple of tuples containing the aliased field name and + A list of tuples containing the aliased field name and sort direction encoded as 1 (ascending) or -1 (descending). """ diff --git a/optimade/server/mappers/entries.py b/optimade/server/mappers/entries.py index 87b673ce6..fc1953696 100644 --- a/optimade/server/mappers/entries.py +++ b/optimade/server/mappers/entries.py @@ -173,7 +173,7 @@ def all_length_aliases(cls) -> Tuple[Tuple[str, str], ...]: from optimade.server.config import CONFIG return cls.LENGTH_ALIASES + tuple( - CONFIG.length_aliases.get(cls.ENDPOINT, {}).items() # type: ignore[call-overload] + CONFIG.length_aliases.get(cls.ENDPOINT, {}).items() ) @classmethod diff --git a/optimade/server/middleware.py b/optimade/server/middleware.py index 84c02902f..99196715b 100644 --- a/optimade/server/middleware.py +++ b/optimade/server/middleware.py @@ -176,14 +176,6 @@ def handle_api_hint(api_hint: List[str]) -> Union[None, str]: major_api_hint = int(re.findall(r"/v([0-9]+)", api_hint_str)[0]) major_implementation = int(BASE_URL_PREFIXES["major"][len("/v") :]) - if major_api_hint > major_implementation: - # Let's not try to handle a request for a newer major version - raise VersionNotSupported( - detail=( - f"The provided `api_hint` ({api_hint_str[1:]!r}) is not supported by this implementation. " - f"Supported versions include: {', '.join(BASE_URL_PREFIXES.values())}" - ) - ) if major_api_hint <= major_implementation: # If less than: # Use the current implementation in hope that it can still handle older requests @@ -192,7 +184,13 @@ def handle_api_hint(api_hint: List[str]) -> Union[None, str]: # Go to /v, since this should point to the latest available return BASE_URL_PREFIXES["major"] - return None + # Let's not try to handle a request for a newer major version + raise VersionNotSupported( + detail=( + f"The provided `api_hint` ({api_hint_str[1:]!r}) is not supported by this implementation. " + f"Supported versions include: {', '.join(BASE_URL_PREFIXES.values())}" + ) + ) @staticmethod def is_versioned_base_url(url: str) -> bool: diff --git a/optimade/server/routers/utils.py b/optimade/server/routers/utils.py index 74e0f11e7..b4003bdd0 100644 --- a/optimade/server/routers/utils.py +++ b/optimade/server/routers/utils.py @@ -115,7 +115,7 @@ def handle_response_fields( new_results = [] while results: - new_entry = results.pop(0).dict(exclude_unset=True, by_alias=True) # type: ignore[union-attr] + new_entry = results.pop(0).dict(exclude_unset=True, by_alias=True) # Remove fields excluded by their omission in `response_fields` for field in exclude_fields: diff --git a/optimade/validator/utils.py b/optimade/validator/utils.py index 591604111..0c97ce601 100644 --- a/optimade/validator/utils.py +++ b/optimade/validator/utils.py @@ -232,13 +232,13 @@ def get(self, request: str): while retries < self.max_retries: retries += 1 try: - self.response = requests.get( # type: ignore[assignment] - self.last_request, # type: ignore[arg-type] + self.response = requests.get( + self.last_request, headers=self.headers, timeout=(self.timeout, self.read_timeout), ) - status_code = self.response.status_code # type: ignore[attr-defined] + status_code = self.response.status_code # If we hit a 429 Too Many Requests status, then try again in 1 second if status_code != 429: return self.response @@ -369,8 +369,7 @@ def wrapper( if not isinstance(result, ValidationError): message += traceback.split("\n") - message = "\n".join(message) # type: ignore[assignment] - + failure_type = None if isinstance(result, InternalError): summary = ( f"{request} - {test_fn.__name__} - failed with internal error" @@ -378,10 +377,10 @@ def wrapper( failure_type = "internal" else: summary = f"{request} - {test_fn.__name__} - failed with error" - failure_type = "optional" if optional else None # type: ignore[assignment] + failure_type = "optional" if optional else None validator.results.add_failure( - summary, message, failure_type=failure_type + summary, "\n".join(message), failure_type=failure_type ) # set failure result to None as this is expected by other functions diff --git a/tests/server/test_server_validation.py b/tests/server/test_server_validation.py index d570171ec..2d2a38f7c 100644 --- a/tests/server/test_server_validation.py +++ b/tests/server/test_server_validation.py @@ -95,7 +95,6 @@ def test_versioned_base_urls(client, index_client, server: str): This depends on the routers for each kind of server. """ - import json from optimade.server.routers.utils import BASE_URL_PREFIXES @@ -129,7 +128,6 @@ def test_meta_schema_value_obeys_index(client, index_client, server: str): """Test that the reported `meta->schema` is correct for index/non-index servers. """ - import json from optimade.server.config import CONFIG from optimade.server.routers.utils import BASE_URL_PREFIXES diff --git a/tests/server/utils.py b/tests/server/utils.py index b94f0807a..c92075be9 100644 --- a/tests/server/utils.py +++ b/tests/server/utils.py @@ -49,7 +49,7 @@ def __init__( version = f"/v{__api_version__.split('.')[0]}" self.version = version - def request( # type: ignore[override] # pylint: disable=too-many-locals + def request( # pylint: disable=too-many-locals self, method: str, url: str,