From 1f8e6ff367eed7842f2edf86f0ad664b1d2dfc83 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Sun, 7 Jan 2024 22:05:08 +0100 Subject: [PATCH] Update for arrow --- pandas/core/arrays/interval.py | 20 ++++++++++++------- .../indexes/interval/test_constructors.py | 10 ++++++++-- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 475e91443a423..2dbc2a663c8a8 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -80,6 +80,7 @@ unique, value_counts_internal as value_counts, ) +from pandas.core.arrays import ArrowExtensionArray from pandas.core.arrays.base import ( ExtensionArray, _extension_array_shared_docs, @@ -369,13 +370,18 @@ def _ensure_simple_new_inputs( right = ensure_wrapped_if_datetimelike(right) right = extract_array(right, extract_numpy=True) - lbase = getattr(left, "_ndarray", left) - lbase = getattr(lbase, "_data", lbase).base - rbase = getattr(right, "_ndarray", right) - rbase = getattr(rbase, "_data", rbase).base - if lbase is not None and lbase is rbase: - # If these share data, then setitem could corrupt our IA - right = right.copy() + if isinstance(left, ArrowExtensionArray) or isinstance( + right, ArrowExtensionArray + ): + pass + else: + lbase = getattr(left, "_ndarray", left) + lbase = getattr(lbase, "_data", lbase).base + rbase = getattr(right, "_ndarray", right) + rbase = getattr(rbase, "_data", rbase).base + if lbase is not None and lbase is rbase: + # If these share data, then setitem could corrupt our IA + right = right.copy() dtype = IntervalDtype(left.dtype, closed=closed) diff --git a/pandas/tests/indexes/interval/test_constructors.py b/pandas/tests/indexes/interval/test_constructors.py index b8b0d154a07f0..b0289ded55604 100644 --- a/pandas/tests/indexes/interval/test_constructors.py +++ b/pandas/tests/indexes/interval/test_constructors.py @@ -3,6 +3,8 @@ import numpy as np import pytest +import pandas.util._test_decorators as td + from pandas.core.dtypes.common import is_unsigned_integer_dtype from pandas.core.dtypes.dtypes import IntervalDtype @@ -510,10 +512,14 @@ def test_dtype_closed_mismatch(): IntervalArray([], dtype=dtype, closed="neither") -def test_masked_dtype(): +@pytest.mark.parametrize( + "dtype", + ["Float64", pytest.param("float64[pyarrow]", marks=td.skip_if_no("pyarrow"))], +) +def test_ea_dtype(dtype): # GH#56765 bins = [(0.0, 0.4), (0.4, 0.6)] - interval_dtype = IntervalDtype(subtype="Float64", closed="left") + interval_dtype = IntervalDtype(subtype=dtype, closed="left") result = IntervalIndex.from_tuples(bins, closed="left", dtype=interval_dtype) assert result.dtype == interval_dtype expected = IntervalIndex.from_tuples(bins, closed="left").astype(interval_dtype)