From 2abb2762875d0ec484b59d41129c553cb23b4655 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Thu, 21 Dec 2023 20:01:10 -0800 Subject: [PATCH] Add xarray support Resolves #816. --- DOCS.md | 15 +++-- __coconut__/__init__.pyi | 2 + _coconut/__init__.pyi | 4 +- coconut/compiler/header.py | 6 +- coconut/compiler/templates/header.py_template | 65 ++++++++++++------- coconut/constants.py | 10 ++- coconut/root.py | 2 +- coconut/tests/src/extras.coco | 33 ++++++++-- 8 files changed, 98 insertions(+), 39 deletions(-) diff --git a/DOCS.md b/DOCS.md index 5b5150cfe..9c1671d6c 100644 --- a/DOCS.md +++ b/DOCS.md @@ -487,7 +487,7 @@ To allow for better use of [`numpy`](https://numpy.org/) objects in Coconut, all - `numpy` objects are allowed seamlessly in Coconut's [implicit coefficient syntax](#implicit-function-application-and-coefficients), allowing the use of e.g. `A B**2` shorthand for `A * B**2` when `A` and `B` are `numpy` arrays (note: **not** `A @ B**2`). - Coconut supports `@` for matrix multiplication of `numpy` arrays on all Python versions, as well as supplying the `(@)` [operator function](#operator-functions). -Additionally, Coconut provides the exact same support for [`pandas`](https://pandas.pydata.org/), [`pytorch`](https://pytorch.org/), and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html), including using `pandas`/`jax`-specific methods over `numpy` methods when given `pandas`/`jax` objects. +Additionally, Coconut provides the exact same support for [`pandas`](https://pandas.pydata.org/), [`xarray`](https://docs.xarray.dev/en/stable/), [`pytorch`](https://pytorch.org/), and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects. #### `xonsh` Support @@ -3383,14 +3383,8 @@ In Haskell, `fmap(func, obj)` takes a data type `obj` and returns a new data typ `fmap` can also be used on the built-in objects `str`, `dict`, `list`, `tuple`, `set`, `frozenset`, and `dict` as a variant of `map` that returns back an object of the same type. -The behavior of `fmap` for a given object can be overridden by defining an `__fmap__(self, func)` magic method that will be called whenever `fmap` is invoked on that object. Note that `__fmap__` implementations should always satisfy the [Functor Laws](https://wiki.haskell.org/Functor). - For `dict`, or any other `collections.abc.Mapping`, `fmap` will map over the mapping's `.items()` instead of the default iteration through its `.keys()`, with the new mapping reconstructed from the mapped over items. _Deprecated: `fmap$(starmap_over_mappings=True)` will `starmap` over the `.items()` instead of `map` over them._ -For [`numpy`](#numpy-integration) objects, `fmap` will use [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) to produce the result. - -For [`pandas`](https://pandas.pydata.org/) objects, `fmap` will use [`.apply`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.apply.html) along the last axis (so row-wise for `DataFrame`'s, element-wise for `Series`'s). - For asynchronous iterables, `fmap` will map asynchronously, making `fmap` equivalent in that case to ```coconut_python async def fmap_over_async_iters(func, async_iter): @@ -3399,6 +3393,13 @@ async def fmap_over_async_iters(func, async_iter): ``` such that `fmap` can effectively be used as an async map. +Some objects from external libraries are also given special support: +* For [`numpy`](#numpy-integration) objects, `fmap` will use [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) to produce the result. +* For [`pandas`](https://pandas.pydata.org/) objects, `fmap` will use [`.apply`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.apply.html) along the last axis (so row-wise for `DataFrame`'s, element-wise for `Series`'s). +* For [`xarray`](https://docs.xarray.dev/en/stable/) objects, `fmap` will first convert them into `pandas` objects, apply `fmap`, then convert them back. + +The behavior of `fmap` for a given object can be overridden by defining an `__fmap__(self, func)` magic method that will be called whenever `fmap` is invoked on that object. Note that `__fmap__` implementations should always satisfy the [Functor Laws](https://wiki.haskell.org/Functor). + _Deprecated: `fmap(func, obj, fallback_to_init=True)` will fall back to `obj.__class__(map(func, obj))` if no `fmap` implementation is available rather than raise `TypeError`._ ##### Example diff --git a/__coconut__/__init__.pyi b/__coconut__/__init__.pyi index 2cba5f7c7..007cdcfab 100644 --- a/__coconut__/__init__.pyi +++ b/__coconut__/__init__.pyi @@ -1466,6 +1466,8 @@ def fmap(func: _t.Callable[[_T, _U], _t.Tuple[_V, _W]], obj: _t.Mapping[_T, _U], """ ... +_coconut_fmap = fmap + def _coconut_handle_cls_kwargs(**kwargs: _t.Dict[_t.Text, _t.Any]) -> _t.Callable[[_T], _T]: ... diff --git a/_coconut/__init__.pyi b/_coconut/__init__.pyi index 31d9fd411..82d320478 100644 --- a/_coconut/__init__.pyi +++ b/_coconut/__init__.pyi @@ -109,8 +109,10 @@ npt = _npt # Fake, like typing zip_longest = _zip_longest numpy_modules: _t.Any = ... -pandas_numpy_modules: _t.Any = ... +xarray_modules: _t.Any = ... +pandas_modules: _t.Any = ... jax_numpy_modules: _t.Any = ... + tee_type: _t.Any = ... reiterables: _t.Any = ... fmappables: _t.Any = ... diff --git a/coconut/compiler/header.py b/coconut/compiler/header.py index 1306fb2a2..2d14cbc88 100644 --- a/coconut/compiler/header.py +++ b/coconut/compiler/header.py @@ -33,8 +33,9 @@ justify_len, report_this_text, numpy_modules, - pandas_numpy_modules, + pandas_modules, jax_numpy_modules, + xarray_modules, self_match_types, is_data_var, data_defaults_var, @@ -291,7 +292,8 @@ def process_header_args(which, use_hash, target, no_tco, strict, no_wrap): from_None=" from None" if target.startswith("3") else "", process_="process_" if target_info >= (3, 13) else "", numpy_modules=tuple_str_of(numpy_modules, add_quotes=True), - pandas_numpy_modules=tuple_str_of(pandas_numpy_modules, add_quotes=True), + xarray_modules=tuple_str_of(xarray_modules, add_quotes=True), + pandas_modules=tuple_str_of(pandas_modules, add_quotes=True), jax_numpy_modules=tuple_str_of(jax_numpy_modules, add_quotes=True), self_match_types=tuple_str_of(self_match_types), comma_bytearray=", bytearray" if not target.startswith("3") else "", diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index 1e61620a6..f67cd339d 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -54,7 +54,8 @@ class _coconut{object}:{COMMENT.EVERYTHING_HERE_MUST_BE_COPIED_TO_STUB_FILE} else: abc.Sequence.register(numpy.ndarray) numpy_modules = {numpy_modules} - pandas_numpy_modules = {pandas_numpy_modules} + xarray_modules = {xarray_modules} + pandas_modules = {pandas_modules} jax_numpy_modules = {jax_numpy_modules} tee_type = type(itertools.tee((), 1)[0]) reiterables = abc.Sequence, abc.Mapping, abc.Set @@ -121,6 +122,20 @@ class _coconut_Sentinel(_coconut_baseclass): _coconut_sentinel = _coconut_Sentinel() def _coconut_get_base_module(obj): return obj.__class__.__module__.split(".", 1)[0] +def _coconut_xarray_to_pandas(obj): + import xarray + if isinstance(obj, xarray.Dataset): + return obj.to_dataframe() + elif isinstance(obj, xarray.DataArray): + return obj.to_series() + else: + return obj.to_pandas() +def _coconut_xarray_to_numpy(obj): + import xarray + if isinstance(obj, xarray.Dataset): + return obj.to_dataframe().to_numpy() + else: + return obj.to_numpy() class MatchError(_coconut_baseclass, Exception): """Pattern-matching error. Has attributes .pattern, .value, and .message."""{COMMENT.no_slots_to_allow_setattr_below} max_val_repr_len = 500 @@ -752,8 +767,10 @@ Additionally supports Cartesian products of numpy arrays.""" if iterables: it_modules = [_coconut_get_base_module(it) for it in iterables] if _coconut.all(mod in _coconut.numpy_modules for mod in it_modules): - if _coconut.any(mod in _coconut.pandas_numpy_modules for mod in it_modules): - iterables = tuple((it.to_numpy() if _coconut_get_base_module(it) in _coconut.pandas_numpy_modules else it) for it in iterables) + if _coconut.any(mod in _coconut.xarray_modules for mod in it_modules): + iterables = tuple((_coconut_xarray_to_numpy(it) if mod in _coconut.xarray_modules else it) for it, mod in _coconut.zip(iterables, it_modules)) + if _coconut.any(mod in _coconut.pandas_modules for mod in it_modules): + iterables = tuple((it.to_numpy() if mod in _coconut.pandas_modules else it) for it, mod in _coconut.zip(iterables, it_modules)) if _coconut.any(mod in _coconut.jax_numpy_modules for mod in it_modules): from jax import numpy else: @@ -1605,7 +1622,9 @@ def fmap(func, obj, **kwargs): if result is not _coconut.NotImplemented: return result obj_module = _coconut_get_base_module(obj) - if obj_module in _coconut.pandas_numpy_modules: + if obj_module in _coconut.xarray_modules: + return {_coconut_}fmap(func, _coconut_xarray_to_pandas(obj)).to_xarray() + if obj_module in _coconut.pandas_modules: if obj.ndim <= 1: return obj.apply(func) return obj.apply(func, axis=obj.ndim-1) @@ -1941,7 +1960,9 @@ def all_equal(iterable): """ iterable_module = _coconut_get_base_module(iterable) if iterable_module in _coconut.numpy_modules: - if iterable_module in _coconut.pandas_numpy_modules: + if iterable_module in _coconut.xarray_modules: + iterable = _coconut_xarray_to_numpy(iterable) + elif iterable_module in _coconut.pandas_modules: iterable = iterable.to_numpy() return not _coconut.len(iterable) or (iterable == iterable[0]).all() first_item = _coconut_sentinel @@ -2014,8 +2035,11 @@ def _coconut_mk_anon_namedtuple(fields, types=None, of_kwargs=None): return NT return NT(**of_kwargs) def _coconut_ndim(arr): - if (_coconut_get_base_module(arr) in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "ndim"): + arr_mod = _coconut_get_base_module(arr) + if (arr_mod in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "ndim"): return arr.ndim + if arr_mod in _coconut.xarray_modules:{COMMENT.if_we_got_here_its_a_Dataset_not_a_DataArray} + return 2 if not _coconut.isinstance(arr, _coconut.abc.Sequence) or _coconut.isinstance(arr, (_coconut.str, _coconut.bytes)): return 0 if _coconut.len(arr) == 0: @@ -2040,23 +2064,20 @@ def _coconut_expand_arr(arr, new_dims): arr = [arr] return arr def _coconut_concatenate(arrs, axis): - matconcat = None for a in arrs: if _coconut.hasattr(a.__class__, "__matconcat__"): - matconcat = a.__class__.__matconcat__ - break - a_module = _coconut_get_base_module(a) - if a_module in _coconut.pandas_numpy_modules: - from pandas import concat as matconcat - break - if a_module in _coconut.jax_numpy_modules: - from jax.numpy import concatenate as matconcat - break - if a_module in _coconut.numpy_modules: - matconcat = _coconut.numpy.concatenate - break - if matconcat is not None: - return matconcat(arrs, axis=axis) + return a.__class__.__matconcat__(arrs, axis=axis) + arr_modules = [_coconut_get_base_module(a) for a in arrs] + if any(mod in _coconut.xarray_modules for mod in arr_modules): + return _coconut_concatenate([(_coconut_xarray_to_pandas(a) if mod in _coconut.xarray_modules else a) for a, mod in _coconut.zip(arrs, arr_modules)], axis).to_xarray() + if any(mod in _coconut.pandas_modules for mod in arr_modules): + import pandas + return pandas.concat(arrs, axis=axis) + if any(mod in _coconut.jax_numpy_modules for mod in arr_modules): + import jax.numpy + return jax.numpy.concatenate(arrs, axis=axis) + if any(mod in _coconut.numpy_modules for mod in arr_modules): + return _coconut.numpy.concatenate(arrs, axis=axis) if not axis: return _coconut.list(_coconut.itertools.chain.from_iterable(arrs)) return [_coconut_concatenate(rows, axis - 1) for rows in _coconut.zip(*arrs)] @@ -2209,4 +2230,4 @@ class _coconut_SupportsInv(_coconut.typing.Protocol): {def_async_map} {def_aliases} _coconut_self_match_types = {self_match_types} -_coconut_Expected, _coconut_MatchError, _coconut_cartesian_product, _coconut_count, _coconut_cycle, _coconut_enumerate, _coconut_flatten, _coconut_filter, _coconut_groupsof, _coconut_ident, _coconut_lift, _coconut_map, _coconut_mapreduce, _coconut_multiset, _coconut_range, _coconut_reiterable, _coconut_reversed, _coconut_scan, _coconut_starmap, _coconut_tee, _coconut_windowsof, _coconut_zip, _coconut_zip_longest, TYPE_CHECKING, reduce, takewhile, dropwhile = Expected, MatchError, cartesian_product, count, cycle, enumerate, flatten, filter, groupsof, ident, lift, map, mapreduce, multiset, range, reiterable, reversed, scan, starmap, tee, windowsof, zip, zip_longest, False, _coconut.functools.reduce, _coconut.itertools.takewhile, _coconut.itertools.dropwhile{COMMENT.anything_added_here_should_be_copied_to_stub_file} +_coconut_Expected, _coconut_MatchError, _coconut_cartesian_product, _coconut_count, _coconut_cycle, _coconut_enumerate, _coconut_flatten, _coconut_fmap, _coconut_filter, _coconut_groupsof, _coconut_ident, _coconut_lift, _coconut_map, _coconut_mapreduce, _coconut_multiset, _coconut_range, _coconut_reiterable, _coconut_reversed, _coconut_scan, _coconut_starmap, _coconut_tee, _coconut_windowsof, _coconut_zip, _coconut_zip_longest, TYPE_CHECKING, reduce, takewhile, dropwhile = Expected, MatchError, cartesian_product, count, cycle, enumerate, flatten, fmap, filter, groupsof, ident, lift, map, mapreduce, multiset, range, reiterable, reversed, scan, starmap, tee, windowsof, zip, zip_longest, False, _coconut.functools.reduce, _coconut.itertools.takewhile, _coconut.itertools.dropwhile{COMMENT.anything_added_here_should_be_copied_to_stub_file} diff --git a/coconut/constants.py b/coconut/constants.py index 133e8dda5..a3268f3b3 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -180,7 +180,10 @@ def get_path_env_var(env_var, default): sys.setrecursionlimit(default_recursion_limit) # modules that numpy-like arrays can live in -pandas_numpy_modules = ( +xarray_modules = ( + "xarray", +) +pandas_modules = ( "pandas", ) jax_numpy_modules = ( @@ -190,7 +193,8 @@ def get_path_env_var(env_var, default): "numpy", "torch", ) + ( - pandas_numpy_modules + xarray_modules + + pandas_modules + jax_numpy_modules ) @@ -999,6 +1003,7 @@ def get_path_env_var(env_var, default): ("numpy", "py34;py<39"), ("numpy", "py39"), ("pandas", "py36"), + ("xarray", "py39"), ), "tests": ( ("pytest", "py<36"), @@ -1021,6 +1026,7 @@ def get_path_env_var(env_var, default): ("trollius", "py<3;cpy"): (2, 2), "requests": (2, 31), ("numpy", "py39"): (1, 26), + ("xarray", "py39"): (2023,), ("dataclasses", "py==36"): (0, 8), ("aenum", "py<34"): (3, 1, 15), "pydata-sphinx-theme": (0, 14), diff --git a/coconut/root.py b/coconut/root.py index 0e1a95ebf..123449d7c 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 9 +DEVELOP = 10 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index 6ade7f0c8..1c5fbd7a1 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -10,6 +10,7 @@ from coconut.constants import ( PY34, PY35, PY36, + PY39, PYPY, ) # type: ignore from coconut._pyparsing import USE_COMPUTATION_GRAPH # type: ignore @@ -664,22 +665,46 @@ def test_pandas() -> bool: return True +def test_xarray() -> bool: + import xarray as xr + import numpy as np + def ds1 `dataset_equal` ds2 = (ds1 == ds2).all().values() |> all + da = xr.DataArray([10, 11;; 12, 13], dims=["x", "y"]) + ds = xr.Dataset({"a": da, "b": da + 10}) + assert ds$[0] == "a" + ds_ = [da; da + 10] + assert ds `dataset_equal` ds_ # type: ignore + ds__ = [da; da |> fmap$(.+10)] + assert ds `dataset_equal` ds__ # type: ignore + assert ds `dataset_equal` (ds |> fmap$(ident)) + assert da.to_numpy() `np.array_equal` (da |> fmap$(ident) |> .to_numpy()) + assert (ds |> fmap$(r -> r["a"] + r["b"]) |> .to_numpy()) `np.array_equal` np.array([30; 32;; 34; 36]) + assert not all_equal(da) + assert not all_equal(ds) + assert multi_enumerate(da) |> list == [((0, 0), 10), ((0, 1), 11), ((1, 0), 12), ((1, 1), 13)] + assert cartesian_product(da.sel(x=0), da.sel(x=1)) `np.array_equal` np.array([10; 12;; 10; 13;; 11; 12;; 11; 13]) # type: ignore + return True + + def test_extras() -> bool: if not PYPY and (PY2 or PY34): assert test_numpy() is True print(".", end="") if not PYPY and PY36: assert test_pandas() is True # . + print(".", end="") + if not PYPY and PY39: + assert test_xarray() is True # .. print(".") # newline bc we print stuff after this - assert test_setup_none() is True # .. + assert test_setup_none() is True # ... print(".") # ditto - assert test_convenience() is True # ... + assert test_convenience() is True # .... # everything after here uses incremental parsing, so it must come last print(".", end="") - assert test_incremental() is True # .... + assert test_incremental() is True # ..... if IPY: print(".", end="") - assert test_kernel() is True # ..... + assert test_kernel() is True # ...... return True