Skip to content

Commit

Permalink
ENH: Support objects with __array__ method (#625)
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 authored May 8, 2020
1 parent 8eb145e commit 4ab3ff7
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[settings]
line_length=88
multi_line_output=3
known_third_party=mock,numpy,pkg_resources,pytest,setuptools
known_third_party=mock,numpy,pkg_resources,pytest,setuptools,xarray,pandas
known_first_party=pyproj,test
include_trailing_comma=true
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Change Log
-----
* Minimum supported Python version 3.6 (issue #499)
* Refactor Proj to inherit from Transformer (issue #624)
* ENH: Support obects with '__array__' method (pandas.Series, xarray.DataArray, dask.array.Array) (issue #573)

2.6.1
~~~~~
Expand Down
11 changes: 9 additions & 2 deletions pyproj/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ def _copytobuffer(xx: Any) -> Tuple[Any, bool, bool, bool]:
isfloat = False
islist = False
istuple = False
# first, if it's a numpy array scalar convert to float
# check for pandas.Series, xarray.DataArray or dask.array.Array
if hasattr(xx, "__array__") and callable(xx.__array__):
xx = xx.__array__()

# if it's a numpy array scalar convert to float
# (array scalars don't support buffer API)
if hasattr(xx, "shape"):
if xx.shape == ():
Expand All @@ -56,7 +60,10 @@ def _copytobuffer(xx: Any) -> Tuple[Any, bool, bool, bool]:
# inx,isfloat,islist,istuple
return inx, False, False, False
except Exception:
raise TypeError("input must be an array, list, tuple or scalar")
raise TypeError(
"input must be an array, list, tuple, scalar, "
"or have the __array__ method."
)
else:
# perhaps they are regular python arrays?
if hasattr(xx, "typecode"):
Expand Down
3 changes: 2 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cython>=0.28.4
black; python_version >= '3.6'
black
flake8
mock
mypy
Expand All @@ -9,3 +9,4 @@ pytest>3.6
pytest-cov
shapely
pre-commit
xarray
16 changes: 13 additions & 3 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import numpy
import pytest
from pandas import Series
from xarray import DataArray

from pyproj.utils import _copytobuffer, _copytobuffer_return_scalar

Expand All @@ -20,6 +22,7 @@ def test__copytobuffer_return_scalar__invalid():
"in_data, is_float, is_list, is_tuple",
[
(numpy.array(1), True, False, False),
(DataArray(numpy.array(1)), True, False, False),
(1, True, False, False),
([1], False, True, False),
((1,), False, False, True),
Expand All @@ -29,9 +32,16 @@ def test__copytobuffer(in_data, is_float, is_list, is_tuple):
assert _copytobuffer(in_data) == (array("d", [1]), is_float, is_list, is_tuple)


def test__copytobuffer__numpy_array():
in_arr = numpy.array([1])
assert _copytobuffer(in_arr) == (in_arr.astype("d"), False, False, False)
@pytest.mark.parametrize(
"in_arr", [numpy.array([1]), DataArray(numpy.array([1])), Series(numpy.array([1]))],
)
def test__copytobuffer__numpy_array(in_arr):
assert _copytobuffer(in_arr) == (
in_arr.astype("d").__array__(),
False,
False,
False,
)


def test__copytobuffer__invalid():
Expand Down

0 comments on commit 4ab3ff7

Please sign in to comment.