Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP: __getitem__ method of EA #37898

Merged
merged 1 commit into from
Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Optional, Sequence, Type, TypeVar
from __future__ import annotations

from typing import Any, Optional, Sequence, Type, TypeVar, Union

import numpy as np

Expand Down Expand Up @@ -212,7 +214,9 @@ def __setitem__(self, key, value):
def _validate_setitem_value(self, value):
return value

def __getitem__(self, key):
def __getitem__(
self: NDArrayBackedExtensionArrayT, key: Union[int, slice, np.ndarray]
) -> Union[NDArrayBackedExtensionArrayT, Any]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @jorisvandenbossche: if a TypeVar is not used here, mypy gives the following error:

pandas\core\arrays\datetimelike.py:280: error: Item "NDArrayBackedExtensionArray" of "Union[NDArrayBackedExtensionArray, Any]" has no attribute "_freq"  [union-attr]

if lib.is_integer(key):
# fast-path
result = self._ndarray[key]
Expand Down
19 changes: 11 additions & 8 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
This is an experimental API and subject to breaking changes
without warning.
"""
from __future__ import annotations

import operator
from typing import (
Any,
Expand Down Expand Up @@ -254,8 +256,9 @@ def _from_factorized(cls, values, original):
# Must be a Sequence
# ------------------------------------------------------------------------

def __getitem__(self, item):
# type (Any) -> Any
def __getitem__(
self, item: Union[int, slice, np.ndarray]
) -> Union[ExtensionArray, Any]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should ExtensionArray be a TypeVar here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it should. Until there is consensus on TypeVar usage, I have only added TypeVars where they are needed to keep mypy green. See #37898 (comment), #37817 (comment), #31728 (comment), #31384 (comment), and others.

"""
Select a subset of self.

Expand Down Expand Up @@ -661,7 +664,7 @@ def dropna(self):
"""
return self[~self.isna()]

def shift(self, periods: int = 1, fill_value: object = None) -> "ExtensionArray":
def shift(self, periods: int = 1, fill_value: object = None) -> ExtensionArray:
"""
Shift values by desired number.

Expand Down Expand Up @@ -831,7 +834,7 @@ def _values_for_factorize(self) -> Tuple[np.ndarray, Any]:
"""
return self.astype(object), np.nan

def factorize(self, na_sentinel: int = -1) -> Tuple[np.ndarray, "ExtensionArray"]:
def factorize(self, na_sentinel: int = -1) -> Tuple[np.ndarray, ExtensionArray]:
"""
Encode the extension array as an enumerated type.

Expand Down Expand Up @@ -940,7 +943,7 @@ def take(
*,
allow_fill: bool = False,
fill_value: Any = None,
) -> "ExtensionArray":
) -> ExtensionArray:
"""
Take elements from an array.

Expand Down Expand Up @@ -1109,7 +1112,7 @@ def _formatter(self, boxed: bool = False) -> Callable[[Any], Optional[str]]:
# Reshaping
# ------------------------------------------------------------------------

def transpose(self, *axes) -> "ExtensionArray":
def transpose(self, *axes) -> ExtensionArray:
"""
Return a transposed view on this array.

Expand All @@ -1119,10 +1122,10 @@ def transpose(self, *axes) -> "ExtensionArray":
return self[:]

@property
def T(self) -> "ExtensionArray":
def T(self) -> ExtensionArray:
return self.transpose()

def ravel(self, order="C") -> "ExtensionArray":
def ravel(self, order="C") -> ExtensionArray:
"""
Return a flattened view on this array.

Expand Down
6 changes: 5 additions & 1 deletion pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from datetime import datetime, timedelta
import operator
from typing import (
Expand Down Expand Up @@ -264,7 +266,9 @@ def __array__(self, dtype=None) -> np.ndarray:
return np.array(list(self), dtype=object)
return self._ndarray

def __getitem__(self, key):
def __getitem__(
self, key: Union[int, slice, np.ndarray]
) -> Union[DatetimeLikeArrayMixin, DTScalarOrNaT]:
"""
This getitem defers to the underlying array, which by-definition can
only handle list-likes, slices, and integer scalars
Expand Down
8 changes: 5 additions & 3 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime, time, timedelta, tzinfo
from typing import Optional, Union
from typing import Optional, Union, cast
import warnings

import numpy as np
Expand Down Expand Up @@ -444,9 +444,11 @@ def _generate_range(
)

if not left_closed and len(index) and index[0] == start:
index = index[1:]
# TODO: overload DatetimeLikeArrayMixin.__getitem__
index = cast(DatetimeArray, index[1:])
if not right_closed and len(index) and index[-1] == end:
index = index[:-1]
# TODO: overload DatetimeLikeArrayMixin.__getitem__
index = cast(DatetimeArray, index[:-1])

dtype = tz_to_dtype(tz)
return cls._simple_new(index.asi8, freq=freq, dtype=dtype)
Expand Down
10 changes: 7 additions & 3 deletions pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Type, TypeVar
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Type, TypeVar, Union

import numpy as np

Expand Down Expand Up @@ -56,7 +58,7 @@ def itemsize(self) -> int:
return self.numpy_dtype.itemsize

@classmethod
def construct_array_type(cls) -> Type["BaseMaskedArray"]:
def construct_array_type(cls) -> Type[BaseMaskedArray]:
"""
Return the array type associated with this dtype.

Expand Down Expand Up @@ -100,7 +102,9 @@ def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False):
def dtype(self) -> BaseMaskedDtype:
raise AbstractMethodError(self)

def __getitem__(self, item):
def __getitem__(
self, item: Union[int, slice, np.ndarray]
) -> Union[BaseMaskedArray, Any]:
if is_integer(item):
if self._mask[item]:
return self.dtype.na_value
Expand Down
12 changes: 6 additions & 6 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1671,10 +1671,10 @@ def first(self, numeric_only: bool = False, min_count: int = -1):
def first_compat(obj: FrameOrSeries, axis: int = 0):
def first(x: Series):
"""Helper function for first item that isn't NA."""
x = x.array[notna(x.array)]
if len(x) == 0:
arr = x.array[notna(x.array)]
if not len(arr):
return np.nan
return x[0]
return arr[0]

if isinstance(obj, DataFrame):
return obj.apply(first, axis=axis)
Expand All @@ -1695,10 +1695,10 @@ def last(self, numeric_only: bool = False, min_count: int = -1):
def last_compat(obj: FrameOrSeries, axis: int = 0):
def last(x: Series):
"""Helper function for last item that isn't NA."""
x = x.array[notna(x.array)]
if len(x) == 0:
arr = x.array[notna(x.array)]
if not len(arr):
return np.nan
return x[-1]
return arr[-1]

if isinstance(obj, DataFrame):
return obj.apply(last, axis=axis)
Expand Down
19 changes: 14 additions & 5 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3212,8 +3212,14 @@ def _get_nearest_indexer(self, target: "Index", limit, tolerance) -> np.ndarray:
right_indexer = self.get_indexer(target, "backfill", limit=limit)

target_values = target._values
left_distances = np.abs(self._values[left_indexer] - target_values)
right_distances = np.abs(self._values[right_indexer] - target_values)
# error: Unsupported left operand type for - ("ExtensionArray")
left_distances = np.abs(
self._values[left_indexer] - target_values # type: ignore[operator]
)
# error: Unsupported left operand type for - ("ExtensionArray")
right_distances = np.abs(
self._values[right_indexer] - target_values # type: ignore[operator]
)

op = operator.lt if self.is_monotonic_increasing else operator.le
indexer = np.where(
Expand All @@ -3232,7 +3238,8 @@ def _filter_indexer_tolerance(
indexer: np.ndarray,
tolerance,
) -> np.ndarray:
distance = abs(self._values[indexer] - target)
# error: Unsupported left operand type for - ("ExtensionArray")
distance = abs(self._values[indexer] - target) # type: ignore[operator]
indexer = np.where(distance <= tolerance, indexer, -1)
return indexer

Expand Down Expand Up @@ -3436,6 +3443,7 @@ def reindex(self, target, method=None, level=None, limit=None, tolerance=None):
target = ensure_has_len(target) # target may be an iterator

if not isinstance(target, Index) and len(target) == 0:
values: Union[range, ExtensionArray, np.ndarray]
if isinstance(self, ABCRangeIndex):
values = range(0)
else:
Expand Down Expand Up @@ -4528,8 +4536,9 @@ def asof_locs(self, where: "Index", mask) -> np.ndarray:

result = np.arange(len(self))[mask].take(locs)

first = mask.argmax()
result[(locs == 0) & (where._values < self._values[first])] = -1
# TODO: overload return type of ExtensionArray.__getitem__
first_value = cast(Any, self._values[mask.argmax()])
result[(locs == 0) & (where._values < first_value)] = -1

return result

Expand Down
7 changes: 5 additions & 2 deletions pandas/core/indexes/period.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime, timedelta
from typing import Any
from typing import Any, cast

import numpy as np

Expand Down Expand Up @@ -694,7 +694,10 @@ def difference(self, other, sort=None):

if self.equals(other):
# pass an empty PeriodArray with the appropriate dtype
return type(self)._simple_new(self._data[:0], name=self.name)

# TODO: overload DatetimeLikeArrayMixin.__getitem__
values = cast(PeriodArray, self._data[:0])
return type(self)._simple_new(values, name=self.name)

if is_object_dtype(other):
return self.astype(object).difference(other).astype(self.dtype)
Expand Down