Skip to content

Commit

Permalink
(fix): equality check against singleton PandasExtensionArray (#9032)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold authored May 22, 2024
1 parent 12123be commit 9e240c5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
2 changes: 0 additions & 2 deletions xarray/core/extension_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ def __setitem__(self, key, val):
self.array[key] = val

def __eq__(self, other):
if np.isscalar(other):
other = type(self)(type(self.array)([other]))
if isinstance(other, PandasExtensionArray):
return self.array == other.array
return self.array == other
Expand Down
13 changes: 9 additions & 4 deletions xarray/tests/test_duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def test_concatenate_extension_duck_array(self, categorical1, categorical2):
).all()

@requires_pyarrow
def test_duck_extension_array_pyarrow_concatenate(self, arrow1, arrow2):
def test_extension_array_pyarrow_concatenate(self, arrow1, arrow2):
concatenated = concatenate(
(PandasExtensionArray(arrow1), PandasExtensionArray(arrow2))
)
Expand Down Expand Up @@ -1024,19 +1024,24 @@ def test_push_dask():
np.testing.assert_equal(actual, expected)


def test_duck_extension_array_equality(categorical1, int1):
def test_extension_array_equality(categorical1, int1):
int_duck_array = PandasExtensionArray(int1)
categorical_duck_array = PandasExtensionArray(categorical1)
assert (int_duck_array != categorical_duck_array).all()
assert (categorical_duck_array == categorical1).all()
assert (int_duck_array[0:2] == int1[0:2]).all()


def test_duck_extension_array_repr(int1):
def test_extension_array_singleton_equality(categorical1):
categorical_duck_array = PandasExtensionArray(categorical1)
assert (categorical_duck_array != "cat3").all()


def test_extension_array_repr(int1):
int_duck_array = PandasExtensionArray(int1)
assert repr(int1) in repr(int_duck_array)


def test_duck_extension_array_attr(int1):
def test_extension_array_attr(int1):
int_duck_array = PandasExtensionArray(int1)
assert (~int_duck_array.fillna(10)).all()

0 comments on commit 9e240c5

Please sign in to comment.