From 76a57fde6f612e47db7ca9d8d11ec7ea63ecae12 Mon Sep 17 00:00:00 2001 From: Caio Fontes <38675540+Caiofcas@users.noreply.github.com> Date: Tue, 21 May 2024 21:04:52 +0200 Subject: [PATCH] Add Type hints to `function.py` (#1827) * feat: add first type annotations * feat: add _JSONSafeTypes annotation to to_dict methods * feat: add typing for SF function * chore: fix linting * rename _JSONSafeTypes to JSONType * format code --------- Co-authored-by: Miguel Grinberg --- elasticsearch_dsl/function.py | 62 ++++++++++++++++++++++++----------- elasticsearch_dsl/utils.py | 12 ++++--- noxfile.py | 1 + 3 files changed, 50 insertions(+), 25 deletions(-) diff --git a/elasticsearch_dsl/function.py b/elasticsearch_dsl/function.py index ef77ce8e..635e049b 100644 --- a/elasticsearch_dsl/function.py +++ b/elasticsearch_dsl/function.py @@ -16,38 +16,55 @@ # under the License. import collections.abc -from typing import Dict +from copy import deepcopy +from typing import Any, ClassVar, Dict, MutableMapping, Optional, Union, overload -from .utils import DslBase +from .utils import DslBase, JSONType -# Incomplete annotation to not break query.py tests -def SF(name_or_sf, **params) -> "ScoreFunction": +@overload +def SF(name_or_sf: MutableMapping[str, Any]) -> "ScoreFunction": ... + + +@overload +def SF(name_or_sf: "ScoreFunction") -> "ScoreFunction": ... + + +@overload +def SF(name_or_sf: str, **params: Any) -> "ScoreFunction": ... + + +def SF( + name_or_sf: Union[str, "ScoreFunction", MutableMapping[str, Any]], + **params: Any, +) -> "ScoreFunction": # {"script_score": {"script": "_score"}, "filter": {}} - if isinstance(name_or_sf, collections.abc.Mapping): + if isinstance(name_or_sf, collections.abc.MutableMapping): if params: raise ValueError("SF() cannot accept parameters when passing in a dict.") - kwargs = {} - sf = name_or_sf.copy() + + kwargs: Dict[str, Any] = {} + sf = deepcopy(name_or_sf) for k in ScoreFunction._param_defs: if k in name_or_sf: kwargs[k] = sf.pop(k) # not sf, so just filter+weight, which used to be boost factor + sf_params = params if not sf: name = "boost_factor" # {'FUNCTION': {...}} elif len(sf) == 1: - name, params = sf.popitem() + name, sf_params = sf.popitem() else: raise ValueError(f"SF() got an unexpected fields in the dictionary: {sf!r}") # boost factor special case, see elasticsearch #6343 - if not isinstance(params, collections.abc.Mapping): - params = {"value": params} + if not isinstance(sf_params, collections.abc.Mapping): + sf_params = {"value": sf_params} # mix known params (from _param_defs) and from inside the function - kwargs.update(params) + kwargs.update(sf_params) return ScoreFunction.get_dsl_class(name)(**kwargs) # ScriptScore(script="_score", filter=Q()) @@ -70,14 +87,16 @@ class ScoreFunction(DslBase): "filter": {"type": "query"}, "weight": {}, } - name = None + name: ClassVar[Optional[str]] = None - def to_dict(self): + def to_dict(self) -> Dict[str, JSONType]: d = super().to_dict() # filter and query dicts should be at the same level as us for k in self._param_defs: - if k in d[self.name]: - d[k] = d[self.name].pop(k) + if self.name is not None: + val = d[self.name] + if isinstance(val, dict) and k in val: + d[k] = val.pop(k) return d @@ -88,12 +107,15 @@ class ScriptScore(ScoreFunction): class BoostFactor(ScoreFunction): name = "boost_factor" - def to_dict(self) -> Dict[str, int]: + def to_dict(self) -> Dict[str, JSONType]: d = super().to_dict() - if "value" in d[self.name]: - d[self.name] = d[self.name].pop("value") - else: - del d[self.name] + if self.name is not None: + val = d[self.name] + if isinstance(val, dict): + if "value" in val: + d[self.name] = val.pop("value") + else: + del d[self.name] return d diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index da6d4fa7..6e311316 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -18,12 +18,14 @@ import collections.abc from copy import copy -from typing import Any, Dict, Optional, Type +from typing import Any, ClassVar, Dict, List, Optional, Type, Union from typing_extensions import Self from .exceptions import UnknownDslObject, ValidationException +JSONType = Union[int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]] + SKIP_VALUES = ("", None) EXPAND__TO_DOT = True @@ -210,7 +212,7 @@ class DslMeta(type): For typical use see `QueryMeta` and `Query` in `elasticsearch_dsl.query`. """ - _types = {} + _types: ClassVar[Dict[str, Type["DslBase"]]] = {} def __init__(cls, name, bases, attrs): super().__init__(name, bases, attrs) @@ -251,7 +253,8 @@ class DslBase(metaclass=DslMeta): all values in the `must` attribute into Query objects) """ - _param_defs = {} + _type_name: ClassVar[str] + _param_defs: ClassVar[Dict[str, Dict[str, Union[str, bool]]]] = {} @classmethod def get_dsl_class( @@ -356,8 +359,7 @@ def __getattr__(self, name): return AttrDict(value) return value - # TODO: This type annotation can probably be made tighter - def to_dict(self) -> Dict[str, Dict[str, Any]]: + def to_dict(self) -> Dict[str, JSONType]: """ Serialize the DSL object to plain dict """ diff --git a/noxfile.py b/noxfile.py index 4ebbe717..f90f22f0 100644 --- a/noxfile.py +++ b/noxfile.py @@ -30,6 +30,7 @@ ) TYPED_FILES = ( + "elasticsearch_dsl/function.py", "elasticsearch_dsl/query.py", "tests/test_query.py", )