Skip to content

Commit

Permalink
Add type hints to wrappers.py (#1835)
Browse files Browse the repository at this point in the history
* refactor: add type hints to wrappers.py

* use _SupportsComparison type from typeshed

* escape imported types in quotes

* simplify casts

* fixed linter errors

---------

Co-authored-by: Miguel Grinberg <[email protected]>
(cherry picked from commit 2c79b48)
  • Loading branch information
Caiofcas authored and github-actions[bot] committed Jun 17, 2024
1 parent 2a5d495 commit a535382
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 22 deletions.
26 changes: 20 additions & 6 deletions elasticsearch_dsl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,25 @@

import collections.abc
from copy import copy
from typing import Any, ClassVar, Dict, List, Optional, Type, Union
from typing import Any, ClassVar, Dict, Generic, List, Optional, Type, TypeVar, Union

from typing_extensions import Self
from typing_extensions import Self, TypeAlias

from .exceptions import UnknownDslObject, ValidationException

JSONType = Union[int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]]
# Usefull types

JSONType: TypeAlias = Union[
int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]
]


# Type variables for internals

_KeyT = TypeVar("_KeyT")
_ValT = TypeVar("_ValT")

# Constants

SKIP_VALUES = ("", None)
EXPAND__TO_DOT = True
Expand Down Expand Up @@ -110,18 +122,20 @@ def to_list(self):
return self._l_


class AttrDict:
class AttrDict(Generic[_KeyT, _ValT]):
"""
Helper class to provide attribute like access (read and write) to
dictionaries. Used to provide a convenient way to access both results and
nested dsl dicts.
"""

def __init__(self, d):
_d_: Dict[_KeyT, _ValT]

def __init__(self, d: Dict[_KeyT, _ValT]):
# assign the inner dict manually to prevent __setattr__ from firing
super().__setattr__("_d_", d)

def __contains__(self, key):
def __contains__(self, key: object) -> bool:
return key in self._d_

def __nonzero__(self):
Expand Down
63 changes: 52 additions & 11 deletions elasticsearch_dsl/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,61 @@
# under the License.

import operator
from typing import (
TYPE_CHECKING,
Callable,
ClassVar,
Dict,
Literal,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
cast,
)

if TYPE_CHECKING:
from _operator import _SupportsComparison

from typing_extensions import TypeAlias

from .utils import AttrDict

ComparisonOperators: TypeAlias = Literal["lt", "lte", "gt", "gte"]
RangeValT = TypeVar("RangeValT", bound="_SupportsComparison")

__all__ = ["Range"]


class Range(AttrDict):
OPS = {
class Range(AttrDict[ComparisonOperators, RangeValT]):
OPS: ClassVar[
Mapping[
ComparisonOperators,
Callable[["_SupportsComparison", "_SupportsComparison"], bool],
]
] = {
"lt": operator.lt,
"lte": operator.le,
"gt": operator.gt,
"gte": operator.ge,
}

def __init__(self, *args, **kwargs):
if args and (len(args) > 1 or kwargs or not isinstance(args[0], dict)):
def __init__(
self,
d: Optional[Dict[ComparisonOperators, RangeValT]] = None,
/,
**kwargs: RangeValT,
):
if d is not None and (kwargs or not isinstance(d, dict)):
raise ValueError(
"Range accepts a single dictionary or a set of keyword arguments."
)
data = args[0] if args else kwargs

if d is None:
data = cast(Dict[ComparisonOperators, RangeValT], kwargs)
else:
data = d

for k in data:
if k not in self.OPS:
Expand All @@ -47,30 +82,36 @@ def __init__(self, *args, **kwargs):
if "lt" in data and "lte" in data:
raise ValueError("You cannot specify both lt and lte for Range.")

super().__init__(args[0] if args else kwargs)
super().__init__(data)

def __repr__(self):
def __repr__(self) -> str:
return "Range(%s)" % ", ".join("%s=%r" % op for op in self._d_.items())

def __contains__(self, item):
def __contains__(self, item: object) -> bool:
if isinstance(item, str):
return super().__contains__(item)

item_supports_comp = any(hasattr(item, f"__{op}__") for op in self.OPS)
if not item_supports_comp:
return False

for op in self.OPS:
if op in self._d_ and not self.OPS[op](item, self._d_[op]):
if op in self._d_ and not self.OPS[op](
cast("_SupportsComparison", item), self._d_[op]
):
return False
return True

@property
def upper(self):
def upper(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]:
if "lt" in self._d_:
return self._d_["lt"], False
if "lte" in self._d_:
return self._d_["lte"], True
return None, False

@property
def lower(self):
def lower(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]:
if "gt" in self._d_:
return self._d_["gt"], False
if "gte" in self._d_:
Expand Down
2 changes: 2 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
TYPED_FILES = (
"elasticsearch_dsl/function.py",
"elasticsearch_dsl/query.py",
"elasticsearch_dsl/wrappers.py",
"tests/test_query.py",
"tests/test_wrappers.py",
)


Expand Down
28 changes: 23 additions & 5 deletions tests/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
# under the License.

from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence

if TYPE_CHECKING:
from _operator import _SupportsComparison

import pytest

Expand All @@ -34,7 +38,9 @@
({"gt": datetime.now() - timedelta(seconds=10)}, datetime.now()),
],
)
def test_range_contains(kwargs, item):
def test_range_contains(
kwargs: Mapping[str, "_SupportsComparison"], item: "_SupportsComparison"
) -> None:
assert item in Range(**kwargs)


Expand All @@ -48,7 +54,9 @@ def test_range_contains(kwargs, item):
({"lte": datetime.now() - timedelta(seconds=10)}, datetime.now()),
],
)
def test_range_not_contains(kwargs, item):
def test_range_not_contains(
kwargs: Mapping[str, "_SupportsComparison"], item: "_SupportsComparison"
) -> None:
assert item not in Range(**kwargs)


Expand All @@ -62,7 +70,9 @@ def test_range_not_contains(kwargs, item):
((), {"gt": 1, "gte": 1}),
],
)
def test_range_raises_value_error_on_wrong_params(args, kwargs):
def test_range_raises_value_error_on_wrong_params(
args: Sequence[Any], kwargs: Mapping[str, "_SupportsComparison"]
) -> None:
with pytest.raises(ValueError):
Range(*args, **kwargs)

Expand All @@ -76,7 +86,11 @@ def test_range_raises_value_error_on_wrong_params(args, kwargs):
(Range(lt=42), None, False),
],
)
def test_range_lower(range, lower, inclusive):
def test_range_lower(
range: Range["_SupportsComparison"],
lower: Optional["_SupportsComparison"],
inclusive: bool,
) -> None:
assert (lower, inclusive) == range.lower


Expand All @@ -89,5 +103,9 @@ def test_range_lower(range, lower, inclusive):
(Range(gt=42), None, False),
],
)
def test_range_upper(range, upper, inclusive):
def test_range_upper(
range: Range["_SupportsComparison"],
upper: Optional["_SupportsComparison"],
inclusive: bool,
) -> None:
assert (upper, inclusive) == range.upper

0 comments on commit a535382

Please sign in to comment.