From 4ab3ff7cf2e3ff089509b921f103dd2fe57ddfda Mon Sep 17 00:00:00 2001 From: "Alan D. Snow" Date: Thu, 7 May 2020 20:37:21 -0500 Subject: [PATCH] ENH: Support objects with __array__ method (#625) --- .isort.cfg | 2 +- docs/history.rst | 1 + pyproj/utils.py | 11 +++++++++-- requirements-dev.txt | 3 ++- test/test_utils.py | 16 +++++++++++++--- 5 files changed, 26 insertions(+), 7 deletions(-) diff --git a/.isort.cfg b/.isort.cfg index 746d310c2..73fba75f5 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,6 +1,6 @@ [settings] line_length=88 multi_line_output=3 -known_third_party=mock,numpy,pkg_resources,pytest,setuptools +known_third_party=mock,numpy,pkg_resources,pytest,setuptools,xarray,pandas known_first_party=pyproj,test include_trailing_comma=true diff --git a/docs/history.rst b/docs/history.rst index a4cfde205..e4fa3c002 100644 --- a/docs/history.rst +++ b/docs/history.rst @@ -5,6 +5,7 @@ Change Log ----- * Minimum supported Python version 3.6 (issue #499) * Refactor Proj to inherit from Transformer (issue #624) +* ENH: Support obects with '__array__' method (pandas.Series, xarray.DataArray, dask.array.Array) (issue #573) 2.6.1 ~~~~~ diff --git a/pyproj/utils.py b/pyproj/utils.py index eb54e83cd..b8f2a7a6c 100644 --- a/pyproj/utils.py +++ b/pyproj/utils.py @@ -30,7 +30,11 @@ def _copytobuffer(xx: Any) -> Tuple[Any, bool, bool, bool]: isfloat = False islist = False istuple = False - # first, if it's a numpy array scalar convert to float + # check for pandas.Series, xarray.DataArray or dask.array.Array + if hasattr(xx, "__array__") and callable(xx.__array__): + xx = xx.__array__() + + # if it's a numpy array scalar convert to float # (array scalars don't support buffer API) if hasattr(xx, "shape"): if xx.shape == (): @@ -56,7 +60,10 @@ def _copytobuffer(xx: Any) -> Tuple[Any, bool, bool, bool]: # inx,isfloat,islist,istuple return inx, False, False, False except Exception: - raise TypeError("input must be an array, list, tuple or scalar") + raise TypeError( + "input must be an array, list, tuple, scalar, " + "or have the __array__ method." + ) else: # perhaps they are regular python arrays? if hasattr(xx, "typecode"): diff --git a/requirements-dev.txt b/requirements-dev.txt index f3efb3d57..e8bd27356 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,5 @@ cython>=0.28.4 -black; python_version >= '3.6' +black flake8 mock mypy @@ -9,3 +9,4 @@ pytest>3.6 pytest-cov shapely pre-commit +xarray diff --git a/test/test_utils.py b/test/test_utils.py index 248813d0b..51b4c65a2 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -2,6 +2,8 @@ import numpy import pytest +from pandas import Series +from xarray import DataArray from pyproj.utils import _copytobuffer, _copytobuffer_return_scalar @@ -20,6 +22,7 @@ def test__copytobuffer_return_scalar__invalid(): "in_data, is_float, is_list, is_tuple", [ (numpy.array(1), True, False, False), + (DataArray(numpy.array(1)), True, False, False), (1, True, False, False), ([1], False, True, False), ((1,), False, False, True), @@ -29,9 +32,16 @@ def test__copytobuffer(in_data, is_float, is_list, is_tuple): assert _copytobuffer(in_data) == (array("d", [1]), is_float, is_list, is_tuple) -def test__copytobuffer__numpy_array(): - in_arr = numpy.array([1]) - assert _copytobuffer(in_arr) == (in_arr.astype("d"), False, False, False) +@pytest.mark.parametrize( + "in_arr", [numpy.array([1]), DataArray(numpy.array([1])), Series(numpy.array([1]))], +) +def test__copytobuffer__numpy_array(in_arr): + assert _copytobuffer(in_arr) == ( + in_arr.astype("d").__array__(), + False, + False, + False, + ) def test__copytobuffer__invalid():