From e92360f7a333b16e90470e79fcc9e832d8d6e75b Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 25 Apr 2023 15:05:21 -0700 Subject: [PATCH 01/61] TST: Add tests for proper AxesArray warnings/slices AxesArray needs to warn when being created with a set of axes that is incompatible with the array data. It also needs to handle slices that copy or remove an axis --- test/utils/test_axes.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index b1a38e6f4..2b8b35285 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -5,6 +5,7 @@ from numpy.testing import assert_raises from pysindy import AxesArray +from pysindy.utils.axes import AxesWarning def test_reduce_mean_noinf_recursion(): @@ -137,3 +138,22 @@ def test_n_elements(): with pytest.raises(IndexError): assert arr3.n_coord == 1 assert arr3.n_sample == 1 + + +def test_warn_bad_axes(): + axes = {"ax_time": 1, "ax_coord": 2} + with pytest.warns(AxesWarning): + AxesArray(np.ones(8).reshape((2, 2, 2)), axes) + with pytest.warns(AxesWarning): + AxesArray(np.ones(2), axes) + + +def test_fancy_indexing_modifies_axes(): + axes = {"ax_time": 1, "ax_coord": 2} + arr = AxesArray(np.ones(4).reshape((2, 2)), axes) + slim = arr[1, :] + fat = arr[[[0, 1], [0, 1]]] + assert slim.ax_time is None + assert slim.ax_coord == 1 + assert fat.ax_time == [0, 1] + assert fat.ax_coord == 2 From 47e879325bb160a358afd5b1c612af71a9e1bfba Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 25 Apr 2023 18:36:08 -0700 Subject: [PATCH 02/61] WIP making array slicing consistent Improved tests Remove axes for singleton slices using new _reverse_map attribute Restrict changes to shape --- pysindy/utils/axes.py | 60 ++++++++++++++++++++++++++++++++++++++++- test/utils/test_axes.py | 14 ++++++---- 2 files changed, 68 insertions(+), 6 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index bad10d55c..87bcdadb8 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -1,10 +1,16 @@ +import copy +import warnings +from typing import Collection from typing import List +from typing import Sequence import numpy as np from sklearn.base import TransformerMixin HANDLED_FUNCTIONS = {} +AxesWarning = type("AxesWarning", (SyntaxWarning,), {}) + class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): """A numpy-like array that keeps track of the meaning of its axes. 
@@ -30,14 +36,62 @@ def __new__(cls, input_array, axes): "ax_sample": None, "ax_spatial": [], } + n_axes = sum(1 for k, v in axes.items() if v) if axes is None: return obj - obj.__dict__.update({**defaults, **axes}) + in_ndim = len(input_array.shape) + if n_axes != in_ndim: + warnings.warn( + f"{n_axes} axes labeled for array with {in_ndim} axes", AxesWarning + ) + axes = {**defaults, **axes} + listed_axes = [ + el for k, v in axes.items() if isinstance(v, Collection) for el in v + ] + listed_axes += [ + v + for k, v in axes.items() + if not isinstance(v, Collection) and v is not None + ] + _reverse_map = {} + for axis in listed_axes: + if axis >= in_ndim: + raise ValueError( + f"Assigned definition to axis {axis}, but array only has" + f" {in_ndim} axes" + ) + ax_names = [ax_name for ax_name in axes if axes[ax_name] == axis] + if len(ax_names) > 1: + raise ValueError(f"Assigned multiple definitions to axis {axis}") + _reverse_map[axis] = ax_names[0] + obj.__dict__.update({**axes}) + obj.__dict__["_reverse_map"] = _reverse_map return obj + def __getitem__(self, key, /): + remove_axes = [] + if isinstance(key, int): + remove_axes.append(key) + if isinstance(key, Sequence): + for axis, k in enumerate(key): + if isinstance(k, int): + remove_axes.append(axis) + new_item = super().__getitem__(key) + if not isinstance(new_item, AxesArray): + return new_item + for axis in remove_axes: + ax_name = self._reverse_map[axis] + if isinstance(new_item.__dict__[ax_name], int): + new_item.__dict__[ax_name] = None + else: + new_item.__dict__[ax_name].remove(axis) + new_item._reverse_map.pop(axis) + return new_item + def __array_finalize__(self, obj) -> None: if obj is None: return + self._reverse_map = copy.deepcopy(getattr(obj, "_reverse_map", {})) self.ax_time = getattr(obj, "ax_time", None) self.ax_coord = getattr(obj, "ax_coord", None) self.ax_sample = getattr(obj, "ax_sample", None) @@ -59,6 +113,10 @@ def n_sample(self): def n_coord(self): return self.shape[self.ax_coord] if self.ax_coord is not None else 1 + @property + def shape(self): + return super().shape + def __array_ufunc__( self, ufunc, method, *inputs, out=None, **kwargs ): # this method is called whenever you use a ufunc diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 2b8b35285..e0b89d876 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -140,16 +140,20 @@ def test_n_elements(): assert arr3.n_sample == 1 -def test_warn_bad_axes(): - axes = {"ax_time": 1, "ax_coord": 2} +def test_warn_toofew_axes(): + axes = {"ax_time": 0, "ax_coord": 1} with pytest.warns(AxesWarning): AxesArray(np.ones(8).reshape((2, 2, 2)), axes) - with pytest.warns(AxesWarning): - AxesArray(np.ones(2), axes) + + +def test_toomany_axes(): + axes = {"ax_time": 0, "ax_coord": 2} + with pytest.raises(ValueError): + AxesArray(np.ones(4).reshape((2, 2)), axes) def test_fancy_indexing_modifies_axes(): - axes = {"ax_time": 1, "ax_coord": 2} + axes = {"ax_time": 0, "ax_coord": 1} arr = AxesArray(np.ones(4).reshape((2, 2)), axes) slim = arr[1, :] fat = arr[[[0, 1], [0, 1]]] From 647e6ec6924b80928e7737a19b202e721da468cc Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 30 Apr 2023 14:15:01 -0700 Subject: [PATCH 03/61] WIP: Offload AxesArray construction logic to _AxisMapping --- pysindy/utils/axes.py | 251 +++++++++++++++++++++++++++------------- test/utils/test_axes.py | 60 ++++++++-- 2 files changed, 217 insertions(+), 94 deletions(-) diff --git a/pysindy/utils/axes.py 
b/pysindy/utils/axes.py index 87bcdadb8..f57b6553e 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -1,8 +1,9 @@ import copy import warnings -from typing import Collection from typing import List +from typing import MutableMapping from typing import Sequence +from typing import Union import numpy as np from sklearn.base import TransformerMixin @@ -12,6 +13,87 @@ AxesWarning = type("AxesWarning", (SyntaxWarning,), {}) +class _AxisMapping: + """Convenience wrapper for a two-way map between axis names and + indexes. + """ + + def __init__( + self, + axes: MutableMapping[str, Union[int, Sequence[int]]] = None, + in_ndim: int = 0, + ): + if axes is None: + axes = {} + axes = copy.deepcopy(axes) + self.fwd_map = {} + self.reverse_map = {} + null = object() + + def coerce_sequence(obj): + if isinstance(obj, Sequence): + return sorted(obj) + return [obj] + + for ax_name, ax_ids in axes.items(): + ax_ids = coerce_sequence(ax_ids) + self.fwd_map[ax_name] = ax_ids + for ax_id in ax_ids: + old_name = self.reverse_map.get(ax_id, null) + if old_name is not null: + raise ValueError(f"Assigned multiple definitions to axis {ax_id}") + if ax_id >= in_ndim: + raise ValueError( + f"Assigned definition to axis {ax_id}, but array only has" + f" {in_ndim} axes" + ) + self.reverse_map[ax_id] = ax_name + if len(self.reverse_map) != in_ndim: + warnings.warn( + f"{len(self.reverse_map)} axes labeled for array with {in_ndim} axes", + AxesWarning, + ) + + @staticmethod + def _compat_axes(in_dict: dict[str, Sequence]) -> dict[str, Union[Sequence, int]]: + """Turn single-element axis index lists into ints""" + axes = {} + for k, v in in_dict.items(): + if len(v) == 1: + axes[k] = v[0] + else: + axes[k] = v + return axes + + @property + def compat_axes(self): + return self._compat_axes(self.fwd_map) + + def reduce(self, axis: Union[int, None] = None): + """Create an axes dict from self with specified axis + removed and all greater axes decremented. + + Arguments: + axis: the axis index to remove. By numpy ufunc convention, + axis=None (default) removes _all_ axes. + """ + if axis is None: + return {} + new_axes = copy.deepcopy(self.fwd_map) + in_ndim = len(self.reverse_map) + remove_ax_name = self.reverse_map[axis] + if len(new_axes[remove_ax_name]) == 1: + new_axes.pop(remove_ax_name) + else: + new_axes[remove_ax_name].remove(axis) + decrement_names = set() + for ax_id in range(axis + 1, in_ndim): + decrement_names.add(self.reverse_map[ax_id]) + for dec_name in decrement_names: + new_axes[dec_name] = [ax_id - 1 for ax_id in new_axes[dec_name]] + return self._compat_axes(new_axes) + + class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): """A numpy-like array that keeps track of the meaning of its axes. 
@@ -30,93 +112,85 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): def __new__(cls, input_array, axes): obj = np.asarray(input_array).view(cls) - defaults = { - "ax_time": None, - "ax_coord": None, - "ax_sample": None, - "ax_spatial": [], - } - n_axes = sum(1 for k, v in axes.items() if v) if axes is None: - return obj + axes = {} in_ndim = len(input_array.shape) - if n_axes != in_ndim: - warnings.warn( - f"{n_axes} axes labeled for array with {in_ndim} axes", AxesWarning - ) - axes = {**defaults, **axes} - listed_axes = [ - el for k, v in axes.items() if isinstance(v, Collection) for el in v - ] - listed_axes += [ - v - for k, v in axes.items() - if not isinstance(v, Collection) and v is not None - ] - _reverse_map = {} - for axis in listed_axes: - if axis >= in_ndim: - raise ValueError( - f"Assigned definition to axis {axis}, but array only has" - f" {in_ndim} axes" - ) - ax_names = [ax_name for ax_name in axes if axes[ax_name] == axis] - if len(ax_names) > 1: - raise ValueError(f"Assigned multiple definitions to axis {axis}") - _reverse_map[axis] = ax_names[0] - obj.__dict__.update({**axes}) - obj.__dict__["_reverse_map"] = _reverse_map + obj.__ax_map = _AxisMapping(axes, in_ndim) return obj - def __getitem__(self, key, /): - remove_axes = [] - if isinstance(key, int): - remove_axes.append(key) - if isinstance(key, Sequence): - for axis, k in enumerate(key): - if isinstance(k, int): - remove_axes.append(axis) - new_item = super().__getitem__(key) - if not isinstance(new_item, AxesArray): - return new_item - for axis in remove_axes: - ax_name = self._reverse_map[axis] - if isinstance(new_item.__dict__[ax_name], int): - new_item.__dict__[ax_name] = None - else: - new_item.__dict__[ax_name].remove(axis) - new_item._reverse_map.pop(axis) - return new_item - - def __array_finalize__(self, obj) -> None: - if obj is None: - return - self._reverse_map = copy.deepcopy(getattr(obj, "_reverse_map", {})) - self.ax_time = getattr(obj, "ax_time", None) - self.ax_coord = getattr(obj, "ax_coord", None) - self.ax_sample = getattr(obj, "ax_sample", None) - self.ax_spatial = getattr(obj, "ax_spatial", []) - - @property - def n_spatial(self): - return tuple(self.shape[ax] for ax in self.ax_spatial) - - @property - def n_time(self): - return self.shape[self.ax_time] if self.ax_time is not None else 1 - @property - def n_sample(self): - return self.shape[self.ax_sample] if self.ax_sample is not None else 1 + def axes(self): + return self.__ax_map.compat_axes @property - def n_coord(self): - return self.shape[self.ax_coord] if self.ax_coord is not None else 1 + def _reverse_map(self): + return self.__ax_map.reverse_map @property def shape(self): return super().shape + def __getattr__(self, name): + parts = name.split("_", 1) + if parts[0] == "ax": + return self.axes[name] + if parts[0] == "n": + fwd_map = self.__ax_map.fwd_map + shape = tuple(self.shape[ax_id] for ax_id in fwd_map["ax_" + parts[1]]) + if len(shape) == 1: + return shape[0] + return shape + raise AttributeError(f"'{type(self)}' object has no attribute '{name}'") + + # def __getitem__(self, key, /): + # pass + # return super().__getitem__(self, key) + # def __getitem__(self, key, /): + # remove_axes = [] + # if isinstance(key, int): + # remove_axes.append(key) + # if isinstance(key, Sequence): + # for axis, k in enumerate(key): + # if isinstance(k, int): + # remove_axes.append(axis) + # new_item = super().__getitem__(key) + # if not isinstance(new_item, AxesArray): + # return new_item + # for axis in remove_axes: + # ax_name 
= self._reverse_map[axis] + # if isinstance(new_item.__dict__[ax_name], int): + # new_item.__dict__[ax_name] = None + # else: + # new_item.__dict__[ax_name].remove(axis) + # new_item._reverse_map.pop(axis) + # return new_item + + def __array_wrap__(self, out_arr, context=None): + return super().__array_wrap__(self, out_arr, context) + + def __array_finalize__(self, obj) -> None: + if obj is None: # explicit construction via super().__new__().. not called? + return + # view from numpy array, called in constructor but also tests + if all( + ( + not isinstance(obj, AxesArray), + self.shape == (), + not hasattr(self, "__ax_map"), + ) + ): + self.__ax_map = _AxisMapping({}) + # required by ravel() and view() used in numpy testing. Also for zeros_like... + elif all( + ( + isinstance(obj, AxesArray), + not hasattr(self, "__ax_map"), + self.shape == obj.shape, + ) + ): + self.__ax_map = _AxisMapping(obj.axes, len(obj.shape)) + # maybe add errors for incompatible views? + def __array_ufunc__( self, ufunc, method, *inputs, out=None, **kwargs ): # this method is called whenever you use a ufunc @@ -145,17 +219,30 @@ def __array_ufunc__( return if ufunc.nout == 1: results = (results,) - results = tuple( - (AxesArray(np.asarray(result), self.__dict__) if output is None else output) - for result, output in zip(results, outputs) - ) + if method == "reduce" and ( + "keepdims" not in kwargs.keys() or kwargs["keepdims"] is False + ): + axes = None + if kwargs["axis"] is not None: + axes = self.__ax_map.reduce(axis=kwargs["axis"]) + else: + axes = self.axes + final_results = [] + for result, output in zip(results, outputs): + if output is not None: + final_results.append(output) + elif axes is None: + final_results.append(result) + else: + final_results.append(AxesArray(np.asarray(result), axes)) + results = tuple(final_results) return results[0] if len(results) == 1 else results def __array_function__(self, func, types, args, kwargs): if func not in HANDLED_FUNCTIONS: arr = super(AxesArray, self).__array_function__(func, types, args, kwargs) if isinstance(arr, np.ndarray): - return AxesArray(arr, axes=self.__dict__) + return AxesArray(arr, axes=self.axes) elif arr is not None: return arr return @@ -177,7 +264,7 @@ def decorator(func): @implements(np.concatenate) def concatenate(arrays, axis=0): parents = [np.asarray(obj) for obj in arrays] - ax_list = [obj.__dict__ for obj in arrays if isinstance(obj, AxesArray)] + ax_list = [obj.axes for obj in arrays if isinstance(obj, AxesArray)] for ax1, ax2 in zip(ax_list[:-1], ax_list[1:]): if ax1 != ax2: raise TypeError("Concatenating >1 AxesArray with incompatible axes") diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index e0b89d876..65bb4c63d 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -5,11 +5,12 @@ from numpy.testing import assert_raises from pysindy import AxesArray +from pysindy.utils.axes import _AxisMapping from pysindy.utils.axes import AxesWarning def test_reduce_mean_noinf_recursion(): - arr = AxesArray(np.array([[1]]), {}) + arr = AxesArray(np.array([[1]]), {"ax_a": [0, 1]}) np.mean(arr, axis=0) @@ -26,31 +27,31 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): d = np.arange(5.0) # 1 input, 1 output - a = AxesArray(d, {}) + a = AxesArray(d, {"ax_time": 0}) b = np.sin(a) check = np.sin(d) assert_(np.all(check == b)) b = np.sin(d, out=(a,)) assert_(np.all(check == b)) assert_(b is a) - a = AxesArray(np.arange(5.0), {}) + a = AxesArray(np.arange(5.0), {"ax_time": 0}) b = np.sin(a, out=a) 
assert_(np.all(check == b)) # 1 input, 2 outputs - a = AxesArray(np.arange(5.0), {}) + a = AxesArray(np.arange(5.0), {"ax_time": 0}) b1, b2 = np.modf(a) b1, b2 = np.modf(d, out=(None, a)) assert_(b2 is a) - a = AxesArray(np.arange(5.0), {}) - b = AxesArray(np.arange(5.0), {}) + a = AxesArray(np.arange(5.0), {"ax_time": 0}) + b = AxesArray(np.arange(5.0), {"ax_time": 0}) c1, c2 = np.modf(a, out=(a, b)) assert_(c1 is a) assert_(c2 is b) # 2 input, 1 output - a = AxesArray(np.arange(5.0), {}) - b = AxesArray(np.arange(5.0), {}) + a = AxesArray(np.arange(5.0), {"ax_time": 0}) + b = AxesArray(np.arange(5.0), {"ax_time": 0}) c = np.add(a, b, out=a) assert_(c is a) # some tests with a non-ndarray subclass @@ -59,13 +60,13 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): assert_(a.__array_ufunc__(np.add, "__call__", a, b) is NotImplemented) assert_(b.__array_ufunc__(np.add, "__call__", a, b) is NotImplemented) assert_raises(TypeError, np.add, a, b) - a = AxesArray(a, {}) + a = AxesArray(a, {"ax_time": 0}) assert_(a.__array_ufunc__(np.add, "__call__", a, b) is NotImplemented) assert_(b.__array_ufunc__(np.add, "__call__", a, b) == "A!") assert_(np.add(a, b) == "A!") # regression check for gh-9102 -- tests ufunc.reduce implicitly. d = np.array([[1, 2, 3], [1, 2, 3]]) - a = AxesArray(d, {}) + a = AxesArray(d, {"ax_time": [0, 1]}) c = a.any() check = d.any() assert_equal(c, check) @@ -89,6 +90,11 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): c = np.add.reduce(a, 1, None, b) assert_equal(c, check) assert_(c is b) + + +def test_ufunc_override_accumulate(): + d = np.array([[1, 2, 3], [1, 2, 3]]) + a = AxesArray(d, {"ax_time": [0, 1]}) check = np.add.accumulate(d, axis=0) c = np.add.accumulate(a, axis=0) assert_equal(c, check) @@ -123,14 +129,16 @@ def test_n_elements(): assert arr.n_spatial == (1, 2) assert arr.n_time == 3 assert arr.n_coord == 4 - assert arr.n_sample == 1 arr2 = np.concatenate((arr, arr), axis=arr.ax_time) assert arr2.n_spatial == (1, 2) assert arr2.n_time == 6 assert arr2.n_coord == 4 - assert arr2.n_sample == 1 + +def test_limited_slice(): + arr = np.empty(np.arange(1, 5)) + arr = AxesArray(arr, {"ax_spatial": [0, 1], "ax_time": 2, "ax_coord": 3}) arr3 = arr[..., :2, 0] assert arr3.n_spatial == (1, 2) assert arr3.n_time == 2 @@ -152,6 +160,13 @@ def test_toomany_axes(): AxesArray(np.ones(4).reshape((2, 2)), axes) +def test_conflicting_axes_defn(): + axes = {"ax_time": 0, "ax_coord": 0} + with pytest.raises(ValueError): + AxesArray(np.ones(4), axes) + + +# @pytest.mark.skip("giving error") def test_fancy_indexing_modifies_axes(): axes = {"ax_time": 0, "ax_coord": 1} arr = AxesArray(np.ones(4).reshape((2, 2)), axes) @@ -161,3 +176,24 @@ def test_fancy_indexing_modifies_axes(): assert slim.ax_coord == 1 assert fat.ax_time == [0, 1] assert fat.ax_coord == 2 + + +def test_reduce_AxisMapping(): + ax_map = _AxisMapping( + { + "ax_a": [0, 1], + "ax_b": 2, + "ax_c": 3, + "ax_d": 4, + "ax_e": [5, 6], + }, + 7, + ) + result = ax_map.reduce(3) + expected = { + "ax_a": [0, 1], + "ax_b": 2, + "ax_d": 3, + "ax_e": [4, 5], + } + assert result == expected From 06c5b9ac039f477633eca01a15e46829b90a83a8 Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 30 Apr 2023 14:37:30 -0700 Subject: [PATCH 04/61] WIP Allow __array_function__ to let ufuncs pass through without change. In cases where dimensions change, chances are __array_ufunc__ took care of creating an AxesArray. Return it. 
If there's a case where __array_function__ created an array with different dimensions than self, it will still error. --- pysindy/utils/axes.py | 4 +++- test/utils/test_axes.py | 11 ++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index f57b6553e..13724ad1a 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -241,7 +241,9 @@ def __array_ufunc__( def __array_function__(self, func, types, args, kwargs): if func not in HANDLED_FUNCTIONS: arr = super(AxesArray, self).__array_function__(func, types, args, kwargs) - if isinstance(arr, np.ndarray): + if isinstance(arr, AxesArray): + return arr + elif isinstance(arr, np.ndarray): return AxesArray(arr, axes=self.axes) elif arr is not None: return arr diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 65bb4c63d..bfbc44595 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -14,6 +14,13 @@ def test_reduce_mean_noinf_recursion(): np.mean(arr, axis=0) +def test_repr(): + a = AxesArray(np.arange(5.0), {"ax_time": 0}) + result = a.__repr__() + expected = "AxesArray([0., 1., 2., 3., 4.])" + assert result == expected + + def test_ufunc_override(): # This is largely a clone of test_ufunc_override_with_super() from # numpy/core/tests/test_umath.py @@ -92,6 +99,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): assert_(c is b) +@pytest.mark.skip("Expected error") def test_ufunc_override_accumulate(): d = np.array([[1, 2, 3], [1, 2, 3]]) a = AxesArray(d, {"ax_time": [0, 1]}) @@ -136,6 +144,7 @@ def test_n_elements(): assert arr2.n_coord == 4 +@pytest.mark.skip("Expected error") def test_limited_slice(): arr = np.empty(np.arange(1, 5)) arr = AxesArray(arr, {"ax_spatial": [0, 1], "ax_time": 2, "ax_coord": 3}) @@ -166,7 +175,7 @@ def test_conflicting_axes_defn(): AxesArray(np.ones(4), axes) -# @pytest.mark.skip("giving error") +@pytest.mark.skip("giving error") def test_fancy_indexing_modifies_axes(): axes = {"ax_time": 0, "ax_coord": 1} arr = AxesArray(np.ones(4).reshape((2, 2)), axes) From 218e1f45f110cdcef6d1993bbad76baf99882325 Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 30 Apr 2023 17:03:23 -0700 Subject: [PATCH 05/61] WIP: begin __getitem__ work to id axes --- pysindy/utils/axes.py | 30 +++++++++++++++++++++++++++--- test/utils/test_axes.py | 2 +- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 13724ad1a..b5f31c3c7 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -142,9 +142,33 @@ def __getattr__(self, name): return shape raise AttributeError(f"'{type(self)}' object has no attribute '{name}'") - # def __getitem__(self, key, /): - # pass - # return super().__getitem__(self, key) + def __getitem__(self, key, /): + output = super().__getitem__(key) + # determine axes of output + in_dim = self.shape # noqa + out_dim = output.shape # noqa + remove_dims = [] # noqa + basic_indexer = Union[slice, int, type(Ellipsis), np.newaxis, type(None)] + if any( # basic indexing + isinstance(key, basic_indexer), + isinstance(key, tuple) and all(isinstance(k, basic_indexer) for k in key), + ): + pass + if any( # fancy indexing + isinstance(key, Sequence) and not isinstance(key, tuple), + isinstance(key, np.ndarray), + isinstance(key, tuple) and any(isinstance(k, Sequence) for k in key), + isinstance(key, tuple) and any(isinstance(k, np.ndarray) for k in key), # ? 
+ ): + # check if integer or boolean indexing + # if integer, check which dimensions get broadcast where + # if multiple, axes are merged. If adjacent, merged inplace, otherwise moved to beginning + pass + else: + raise TypeError(f"AxisArray {self} does not know how to slice with {key}") + # mulligan structured arrays, etc. + return output + # def __getitem__(self, key, /): # remove_axes = [] # if isinstance(key, int): diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index bfbc44595..657e9e61a 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -99,7 +99,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): assert_(c is b) -@pytest.mark.skip("Expected error") +# @pytest.mark.skip("Expected error") def test_ufunc_override_accumulate(): d = np.array([[1, 2, 3], [1, 2, 3]]) a = AxesArray(d, {"ax_time": [0, 1]}) From 0d358de2e90846d53a06ff51e3bb599336eca13c Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 18 May 2023 09:16:35 -0700 Subject: [PATCH 06/61] ENH: add function to standardize basic indexing keys --- pysindy/utils/axes.py | 26 +++++++++++++++++++++++--- test/utils/test_axes.py | 12 +++++++++++- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index b5f31c3c7..748f91ddb 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -153,7 +153,9 @@ def __getitem__(self, key, /): isinstance(key, basic_indexer), isinstance(key, tuple) and all(isinstance(k, basic_indexer) for k in key), ): - pass + key = _standardize_basic_indexer(self, key) + + return output if any( # fancy indexing isinstance(key, Sequence) and not isinstance(key, tuple), isinstance(key, np.ndarray), @@ -162,8 +164,9 @@ def __getitem__(self, key, /): ): # check if integer or boolean indexing # if integer, check which dimensions get broadcast where - # if multiple, axes are merged. If adjacent, merged inplace, otherwise moved to beginning - pass + # if multiple, axes are merged. If adjacent, merged inplace, + # otherwise moved to beginning + return output else: raise TypeError(f"AxisArray {self} does not know how to slice with {key}") # mulligan structured arrays, etc. 
@@ -297,6 +300,23 @@ def concatenate(arrays, axis=0): return AxesArray(np.concatenate(parents, axis), axes=ax_list[0]) +def _standardize_basic_indexer(arr: np.ndarray, key): + """Convert to a tuple of slices, ints, and None.""" + if isinstance(key, tuple): + if not any(ax_key is Ellipsis for ax_key in key): + key = (*key, Ellipsis) + slicedim = sum(isinstance(ax_key, slice | int) for ax_key in key) + final_key = [] + for ax_key in key: + inner_iterator = (ax_key,) + if ax_key is Ellipsis: + inner_iterator = (arr.ndim - slicedim) * (slice(None),) + for el in inner_iterator: + final_key.append(el) + return tuple(final_key) + return _standardize_basic_indexer(arr, (key,)) + + def comprehend_axes(x): axes = {} axes["ax_coord"] = len(x.shape) - 1 diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 657e9e61a..9c70394b7 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -5,6 +5,7 @@ from numpy.testing import assert_raises from pysindy import AxesArray +from pysindy.utils import axes from pysindy.utils.axes import _AxisMapping from pysindy.utils.axes import AxesWarning @@ -176,7 +177,7 @@ def test_conflicting_axes_defn(): @pytest.mark.skip("giving error") -def test_fancy_indexing_modifies_axes(): +def test_fancy_getitem_modifies_axes(): axes = {"ax_time": 0, "ax_coord": 1} arr = AxesArray(np.ones(4).reshape((2, 2)), axes) slim = arr[1, :] @@ -187,6 +188,15 @@ def test_fancy_indexing_modifies_axes(): assert fat.ax_coord == 2 +def test_standardize_basic_indexer(): + arr = np.arange(6).reshape(2, 3) + result = axes._standardize_basic_indexer(arr, Ellipsis) + assert result == (slice(None), slice(None)) + + result = axes._standardize_basic_indexer(arr, (np.newaxis, 1, 1, Ellipsis)) + assert result == (None, 1, 1) + + def test_reduce_AxisMapping(): ax_map = _AxisMapping( { From 3393599b0e651199d2350b17330418ed82dd9d34 Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 21 May 2023 13:04:15 -0700 Subject: [PATCH 07/61] ENH: rename AxisMapping.reduce and apply to multiple axes --- pysindy/utils/axes.py | 67 +++++++++++++++++++++++++++-------------- test/utils/test_axes.py | 30 +++++++++++++++--- 2 files changed, 70 insertions(+), 27 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 748f91ddb..22626965e 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -1,5 +1,7 @@ import copy import warnings +from collections import defaultdict +from typing import Collection from typing import List from typing import MutableMapping from typing import Sequence @@ -69,28 +71,37 @@ def _compat_axes(in_dict: dict[str, Sequence]) -> dict[str, Union[Sequence, int] def compat_axes(self): return self._compat_axes(self.fwd_map) - def reduce(self, axis: Union[int, None] = None): - """Create an axes dict from self with specified axis + def remove_axis(self, axis: Union[Collection[int], int, None] = None): + """Create an axes dict from self with specified axis or axes removed and all greater axes decremented. Arguments: - axis: the axis index to remove. By numpy ufunc convention, - axis=None (default) removes _all_ axes. + axis: the axis index or axes indexes to remove. By numpy + ufunc convention, axis=None (default) removes _all_ axes. 
""" if axis is None: return {} new_axes = copy.deepcopy(self.fwd_map) in_ndim = len(self.reverse_map) - remove_ax_name = self.reverse_map[axis] - if len(new_axes[remove_ax_name]) == 1: - new_axes.pop(remove_ax_name) - else: - new_axes[remove_ax_name].remove(axis) - decrement_names = set() - for ax_id in range(axis + 1, in_ndim): - decrement_names.add(self.reverse_map[ax_id]) - for dec_name in decrement_names: - new_axes[dec_name] = [ax_id - 1 for ax_id in new_axes[dec_name]] + decrement_names = defaultdict(lambda: 0) + removal_names = [] + if not isinstance(axis, Collection): + axis = [axis] + for ax in axis: + remove_ax_name = self.reverse_map[ax] + removal_names.append(remove_ax_name) + if len(new_axes[remove_ax_name]) == 1: + new_axes.pop(remove_ax_name) + else: + new_axes[remove_ax_name].remove(ax) + names_beyond_axis = set() + for ax_id in range(ax + 1, in_ndim): + names_beyond_axis.add(self.reverse_map[ax_id]) + for ax_name in names_beyond_axis: + decrement_names[ax_name] += 1 + [decrement_names.pop(name, None) for name in removal_names] + for dec_name, dec_amt in decrement_names.items(): + new_axes[dec_name] = [ax_id - dec_amt for ax_id in new_axes[dec_name]] return self._compat_axes(new_axes) @@ -147,15 +158,24 @@ def __getitem__(self, key, /): # determine axes of output in_dim = self.shape # noqa out_dim = output.shape # noqa - remove_dims = [] # noqa + remove_axes = [] # noqa + new_axes = [] # noqa basic_indexer = Union[slice, int, type(Ellipsis), np.newaxis, type(None)] - if any( # basic indexing - isinstance(key, basic_indexer), - isinstance(key, tuple) and all(isinstance(k, basic_indexer) for k in key), + if any( + ( # basic indexing + isinstance(key, basic_indexer), + isinstance(key, tuple) + and all(isinstance(k, basic_indexer) for k in key), + ) ): key = _standardize_basic_indexer(self, key) - - return output + shift = 0 + for ax_ind, indexer in enumerate(key): + if indexer is None: + new_axes.append(ax_ind - shift) + elif isinstance(indexer, int): + remove_axes.append(ax_ind) + shift += 1 if any( # fancy indexing isinstance(key, Sequence) and not isinstance(key, tuple), isinstance(key, np.ndarray), @@ -166,10 +186,13 @@ def __getitem__(self, key, /): # if integer, check which dimensions get broadcast where # if multiple, axes are merged. If adjacent, merged inplace, # otherwise moved to beginning - return output + pass else: raise TypeError(f"AxisArray {self} does not know how to slice with {key}") # mulligan structured arrays, etc. 
+ new_map = _AxisMapping(self.__ax_map.remove_axis(remove_axes)) + new_map = _AxisMapping(new_map.insert_axes(new_axes)) + output.__ax_map = new_map return output # def __getitem__(self, key, /): @@ -251,7 +274,7 @@ def __array_ufunc__( ): axes = None if kwargs["axis"] is not None: - axes = self.__ax_map.reduce(axis=kwargs["axis"]) + axes = self.__ax_map.remove_axis(axis=kwargs["axis"]) else: axes = self.axes final_results = [] diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 9c70394b7..a3fb56c62 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -176,14 +176,14 @@ def test_conflicting_axes_defn(): AxesArray(np.ones(4), axes) -@pytest.mark.skip("giving error") -def test_fancy_getitem_modifies_axes(): +def test_getitem_modifies_axes(): axes = {"ax_time": 0, "ax_coord": 1} arr = AxesArray(np.ones(4).reshape((2, 2)), axes) - slim = arr[1, :] + slim = arr[1, :, None] fat = arr[[[0, 1], [0, 1]]] assert slim.ax_time is None - assert slim.ax_coord == 1 + assert slim.ax_new == 1 + assert slim.ax_coord == 0 assert fat.ax_time == [0, 1] assert fat.ax_coord == 2 @@ -208,7 +208,7 @@ def test_reduce_AxisMapping(): }, 7, ) - result = ax_map.reduce(3) + result = ax_map.remove_axis(3) expected = { "ax_a": [0, 1], "ax_b": 2, @@ -216,3 +216,23 @@ def test_reduce_AxisMapping(): "ax_e": [4, 5], } assert result == expected + + +def test_reduce_multiple_AxisMapping(): + ax_map = _AxisMapping( + { + "ax_a": [0, 1], + "ax_b": 2, + "ax_c": 3, + "ax_d": 4, + "ax_e": [5, 6], + }, + 7, + ) + result = ax_map.remove_axis([3, 4]) + expected = { + "ax_a": [0, 1], + "ax_b": 2, + "ax_e": [3, 4], + } + assert result == expected From 6f6aec51e7ebed29517ff13b19e6f2a10a5d828a Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 21 May 2023 13:20:25 -0700 Subject: [PATCH 08/61] TST: Add test for inserting axes & mis-ordered axes --- test/utils/test_axes.py | 62 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index a3fb56c62..ae914e5ec 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -236,3 +236,65 @@ def test_reduce_multiple_AxisMapping(): "ax_e": [3, 4], } assert result == expected + + +def test_reduce_twisted_AxisMapping(): + ax_map = _AxisMapping( + { + "ax_a": [0, 6], + "ax_b": 2, + "ax_c": 3, + "ax_d": 4, + "ax_e": [1, 5], + }, + 7, + ) + result = ax_map.remove_axis([3, 4]) + expected = { + "ax_a": [0, 4], + "ax_b": 2, + "ax_e": [1, 3], + } + assert result == expected + + +def test_insert_AxisMapping(): + ax_map = _AxisMapping( + { + "ax_a": [0, 1], + "ax_b": 2, + "ax_c": 3, + "ax_d": [4, 5], + }, + 6, + ) + result = ax_map.insert_axis(3) + expected = { + "ax_a": [0, 1], + "ax_b": 2, + "ax_unk": 3, + "ax_c": 4, + "ax_d": [5, 6], + } + assert result == expected + + +def test_insert_multiple_AxisMapping(): + ax_map = _AxisMapping( + { + "ax_a": [0, 1], + "ax_b": 2, + "ax_c": 3, + "ax_d": [4, 5], + }, + 6, + ) + result = ax_map.insert_axis([1, 4]) + expected = { + "ax_a": [0, 2], + "ax_unk": [1, 4], + "ax_b": 3, + "ax_c": 5, + "ax_d": [6, 7], + } + assert result == expected From a138e059fcda0f7b4951a88b639abe47cf640b6b Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 22 May 2023 10:41:04 -0700 Subject: [PATCH 09/61] BUG: build insert_axis --- pysindy/utils/axes.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff 
--git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 22626965e..09b4f380e 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -28,8 +28,8 @@ def __init__( if axes is None: axes = {} axes = copy.deepcopy(axes) - self.fwd_map = {} - self.reverse_map = {} + self.fwd_map: dict[str, list[int]] = {} + self.reverse_map: dict[int, str] = {} null = object() def coerce_sequence(obj): @@ -104,6 +104,33 @@ def remove_axis(self, axis: Union[Collection[int], int, None] = None): new_axes[dec_name] = [ax_id - dec_amt for ax_id in new_axes[dec_name]] return self._compat_axes(new_axes) + def insert_axis(self, axis: Union[Collection[int], int]): + """Create an axes dict from self with specified axis or axes + added and all greater axes incremented. + + Arguments: + axis: the axis index or axes indexes to add. + + Todo: + May be more efficient to determine final axis-to-axis + mapping, then apply, rather than apply changes after each + axis insert. + """ + new_axes = copy.deepcopy(self.fwd_map) + in_ndim = len(self.reverse_map) + if not isinstance(axis, Collection): + axis = [axis] + for cum_shift, ax in enumerate(axis): + if "ax_unk" in new_axes.keys(): + new_axes["ax_unk"].append(ax) + else: + new_axes["ax_unk"] = [ax] + for ax_id in range(ax, in_ndim + cum_shift): + ax_name = self.reverse_map[ax_id - cum_shift] + new_axes[ax_name].remove(ax_id) + new_axes[ax_name].append(ax_id + 1) + return self._compat_axes(new_axes) + class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): """A numpy-like array that keeps track of the meaning of its axes. From d5e046b972ab3e1789c824769a366c24dee53335 Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 22 May 2023 11:30:08 -0700 Subject: [PATCH 10/61] BUG: Allow reduce_axis to handle twisted axes. twisted axes is when axes are not adjacent but have the same label, e.g. 
arr = AxesArray(np.empty((1,1,1)), {"ax_spatial": [0,2], "ax_time": 1}) --- pysindy/utils/axes.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 09b4f380e..c6b66dcf6 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -1,6 +1,5 @@ import copy import warnings -from collections import defaultdict from typing import Collection from typing import List from typing import MutableMapping @@ -83,25 +82,20 @@ def remove_axis(self, axis: Union[Collection[int], int, None] = None): return {} new_axes = copy.deepcopy(self.fwd_map) in_ndim = len(self.reverse_map) - decrement_names = defaultdict(lambda: 0) - removal_names = [] if not isinstance(axis, Collection): axis = [axis] - for ax in axis: - remove_ax_name = self.reverse_map[ax] - removal_names.append(remove_ax_name) + for cum_shift, orig_ax_remove in enumerate(axis): + remove_ax_name = self.reverse_map[orig_ax_remove] + curr_ax_remove = orig_ax_remove - cum_shift if len(new_axes[remove_ax_name]) == 1: new_axes.pop(remove_ax_name) else: - new_axes[remove_ax_name].remove(ax) - names_beyond_axis = set() - for ax_id in range(ax + 1, in_ndim): - names_beyond_axis.add(self.reverse_map[ax_id]) - for ax_name in names_beyond_axis: - decrement_names[ax_name] += 1 - [decrement_names.pop(name, None) for name in removal_names] - for dec_name, dec_amt in decrement_names.items(): - new_axes[dec_name] = [ax_id - dec_amt for ax_id in new_axes[dec_name]] + new_axes[remove_ax_name].remove(curr_ax_remove) + for old_ax_dec in range(curr_ax_remove + 1, in_ndim - cum_shift): + orig_ax_dec = old_ax_dec + cum_shift + ax_dec_name = self.reverse_map[orig_ax_dec] + new_axes[ax_dec_name].remove(old_ax_dec) + new_axes[ax_dec_name].append(old_ax_dec - 1) return self._compat_axes(new_axes) def insert_axis(self, axis: Union[Collection[int], int]): From 86c2e3d8bb65605985a589fff69a59e2207bdda2 Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 22 May 2023 12:01:41 -0700 Subject: [PATCH 11/61] ENH: Enable basic indexing on AxesArrays Also split apart basic and fancy indexing tests --- pysindy/utils/axes.py | 24 +++++++++++++++++------- test/utils/test_axes.py | 14 ++++++++++---- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index c6b66dcf6..41b320d3c 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -176,6 +176,8 @@ def __getattr__(self, name): def __getitem__(self, key, /): output = super().__getitem__(key) + if not isinstance(output, AxesArray): + return output # determine axes of output in_dim = self.shape # noqa out_dim = output.shape # noqa @@ -197,11 +199,14 @@ def __getitem__(self, key, /): elif isinstance(indexer, int): remove_axes.append(ax_ind) shift += 1 - if any( # fancy indexing - isinstance(key, Sequence) and not isinstance(key, tuple), - isinstance(key, np.ndarray), - isinstance(key, tuple) and any(isinstance(k, Sequence) for k in key), - isinstance(key, tuple) and any(isinstance(k, np.ndarray) for k in key), # ? + elif any( # fancy indexing + ( + isinstance(key, Sequence) and not isinstance(key, tuple), + isinstance(key, np.ndarray), + isinstance(key, tuple) and any(isinstance(k, Sequence) for k in key), + isinstance(key, tuple) + and any(isinstance(k, np.ndarray) for k in key), # ? 
+ ) ): # check if integer or boolean indexing # if integer, check which dimensions get broadcast where @@ -211,8 +216,13 @@ def __getitem__(self, key, /): else: raise TypeError(f"AxisArray {self} does not know how to slice with {key}") # mulligan structured arrays, etc. - new_map = _AxisMapping(self.__ax_map.remove_axis(remove_axes)) - new_map = _AxisMapping(new_map.insert_axes(new_axes)) + new_map = _AxisMapping( + self.__ax_map.remove_axis(remove_axes), len(in_dim) - len(remove_axes) + ) + new_map = _AxisMapping( + new_map.insert_axis(new_axes), + len(in_dim) - len(remove_axes) + len(new_axes), + ) output.__ax_map = new_map return output diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index ae914e5ec..a515c6af0 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -176,14 +176,20 @@ def test_conflicting_axes_defn(): AxesArray(np.ones(4), axes) -def test_getitem_modifies_axes(): +def test_basic_indexing_modifies_axes(): axes = {"ax_time": 0, "ax_coord": 1} arr = AxesArray(np.ones(4).reshape((2, 2)), axes) slim = arr[1, :, None] - fat = arr[[[0, 1], [0, 1]]] - assert slim.ax_time is None - assert slim.ax_new == 1 + with pytest.raises(KeyError): + slim.ax_time + assert slim.ax_unk == 1 assert slim.ax_coord == 0 + + +def test_fancy_indexing_modifies_axes(): + axes = {"ax_time": 0, "ax_coord": 1} + arr = AxesArray(np.ones(4).reshape((2, 2)), axes) + fat = arr[[[0, 1], [0, 1]]] assert fat.ax_time == [0, 1] assert fat.ax_coord == 2 From 478cf5255a7d74697e3425492cb44c52c0c00e46 Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 22 May 2023 12:29:26 -0700 Subject: [PATCH 12/61] TST: Build test for standardizing fancy indexers Also rename _standardize_basic_indexer as _standardize_indexer --- pysindy/utils/axes.py | 35 ++++++++++------------------------- test/utils/test_axes.py | 17 +++++++++++++++-- 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 41b320d3c..576ae2fdb 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -191,7 +191,7 @@ def __getitem__(self, key, /): and all(isinstance(k, basic_indexer) for k in key), ) ): - key = _standardize_basic_indexer(self, key) + key, _ = _standardize_indexer(self, key) shift = 0 for ax_ind, indexer in enumerate(key): if indexer is None: @@ -226,26 +226,6 @@ def __getitem__(self, key, /): output.__ax_map = new_map return output - # def __getitem__(self, key, /): - # remove_axes = [] - # if isinstance(key, int): - # remove_axes.append(key) - # if isinstance(key, Sequence): - # for axis, k in enumerate(key): - # if isinstance(k, int): - # remove_axes.append(axis) - # new_item = super().__getitem__(key) - # if not isinstance(new_item, AxesArray): - # return new_item - # for axis in remove_axes: - # ax_name = self._reverse_map[axis] - # if isinstance(new_item.__dict__[ax_name], int): - # new_item.__dict__[ax_name] = None - # else: - # new_item.__dict__[ax_name].remove(axis) - # new_item._reverse_map.pop(axis) - # return new_item - def __array_wrap__(self, out_arr, context=None): return super().__array_wrap__(self, out_arr, context) @@ -354,8 +334,13 @@ def concatenate(arrays, axis=0): return AxesArray(np.concatenate(parents, axis), axes=ax_list[0]) -def _standardize_basic_indexer(arr: np.ndarray, key): - """Convert to a tuple of slices, ints, and None.""" +def _standardize_indexer(arr: np.ndarray, key): + """Convert to a tuple of slices, ints, None, and ndarrays. 
+ + Returns: + A tuple of the normalized indexer as well as the indexes of + fancy indexers + """ if isinstance(key, tuple): if not any(ax_key is Ellipsis for ax_key in key): key = (*key, Ellipsis) @@ -367,8 +352,8 @@ def _standardize_basic_indexer(arr: np.ndarray, key): inner_iterator = (arr.ndim - slicedim) * (slice(None),) for el in inner_iterator: final_key.append(el) - return tuple(final_key) - return _standardize_basic_indexer(arr, (key,)) + return tuple(final_key), tuple() + return _standardize_indexer(arr, (key,)) def comprehend_axes(x): diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index a515c6af0..6aeea736c 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -196,13 +196,26 @@ def test_fancy_indexing_modifies_axes(): def test_standardize_basic_indexer(): arr = np.arange(6).reshape(2, 3) - result = axes._standardize_basic_indexer(arr, Ellipsis) + result, _ = axes._standardize_indexer(arr, Ellipsis) assert result == (slice(None), slice(None)) - result = axes._standardize_basic_indexer(arr, (np.newaxis, 1, 1, Ellipsis)) + result, _ = axes._standardize_indexer(arr, (np.newaxis, 1, 1, Ellipsis)) assert result == (None, 1, 1) +def test_standardize_fancy_indexer(): + arr = np.arange(6).reshape(2, 3) + result_indexer, result_fancy = axes._standardize_indexer(arr, [1]) + assert result_indexer == (np.ones(1), slice(None)) + assert result_fancy == (0,) + + result_indexer, result_fancy = axes._standardize_indexer( + arr, (np.newaxis, [1], 1, Ellipsis) + ) + assert result_indexer == (None, np.ones(1), 1) + assert result_fancy == (1,) + + def test_reduce_AxisMapping(): ax_map = _AxisMapping( { From beef4a75c022a8ed7854f4b1b1a71c01adde10a6 Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 22 May 2023 13:11:43 -0700 Subject: [PATCH 13/61] ENH: Allow _standardize_indexer to handle fancy indexes --- pysindy/utils/axes.py | 44 ++++++++++++++++++++++++++--------------- test/utils/test_axes.py | 12 +++++++---- 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 576ae2fdb..d8b8a1d96 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -12,6 +12,7 @@ HANDLED_FUNCTIONS = {} AxesWarning = type("AxesWarning", (SyntaxWarning,), {}) +BasicIndexer = Union[slice, int, type(Ellipsis), np.newaxis, type(None)] class _AxisMapping: @@ -183,12 +184,11 @@ def __getitem__(self, key, /): out_dim = output.shape # noqa remove_axes = [] # noqa new_axes = [] # noqa - basic_indexer = Union[slice, int, type(Ellipsis), np.newaxis, type(None)] if any( ( # basic indexing - isinstance(key, basic_indexer), + isinstance(key, BasicIndexer), isinstance(key, tuple) - and all(isinstance(k, basic_indexer) for k in key), + and all(isinstance(k, BasicIndexer) for k in key), ) ): key, _ = _standardize_indexer(self, key) @@ -341,19 +341,31 @@ def _standardize_indexer(arr: np.ndarray, key): A tuple of the normalized indexer as well as the indexes of fancy indexers """ - if isinstance(key, tuple): - if not any(ax_key is Ellipsis for ax_key in key): - key = (*key, Ellipsis) - slicedim = sum(isinstance(ax_key, slice | int) for ax_key in key) - final_key = [] - for ax_key in key: - inner_iterator = (ax_key,) - if ax_key is Ellipsis: - inner_iterator = (arr.ndim - slicedim) * (slice(None),) - for el in inner_iterator: - final_key.append(el) - return tuple(final_key), tuple() - return _standardize_indexer(arr, (key,)) + if not isinstance(key, tuple): + key = (key,) + if 
not any(ax_key is Ellipsis for ax_key in key): + key = (*key, Ellipsis) + new_key = [] + fancy_inds = [] + slicedim = 0 + for indexer_ind, ax_key in enumerate(key): + if not isinstance(ax_key, BasicIndexer): + ax_key = np.array(ax_key) + fancy_inds.append(indexer_ind) + new_key.append(ax_key) + if isinstance(ax_key, slice | int | np.ndarray): + slicedim += 1 + ellipsis_dims = arr.ndim - slicedim + ellind = new_key.index(Ellipsis) + new_key[ellind : ellind + 1] = ellipsis_dims * (slice(None),) + fancy_inds = [ind if ind < ellind else ind + ellind for ind in fancy_inds] + # for ax_key in new_key: + # inner_iterator = (ax_key,) + # if ax_key is Ellipsis: + # inner_iterator = (arr.ndim - slicedim) * (slice(None),) + # for el in inner_iterator: + # final_key.append(el) + return tuple(new_key), tuple(fancy_inds) def comprehend_axes(x): diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 6aeea736c..89613f39b 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -196,11 +196,15 @@ def test_fancy_indexing_modifies_axes(): def test_standardize_basic_indexer(): arr = np.arange(6).reshape(2, 3) - result, _ = axes._standardize_indexer(arr, Ellipsis) - assert result == (slice(None), slice(None)) + result_indexer, result_fancy = axes._standardize_indexer(arr, Ellipsis) + assert result_indexer == (slice(None), slice(None)) + assert result_fancy == () - result, _ = axes._standardize_indexer(arr, (np.newaxis, 1, 1, Ellipsis)) - assert result == (None, 1, 1) + result_indexer, result_fancy = axes._standardize_indexer( + arr, (np.newaxis, 1, 1, Ellipsis) + ) + assert result_indexer == (None, 1, 1) + assert result_fancy == () def test_standardize_fancy_indexer(): From 5397edc4f3be29c6a22bc5946ec3e599df7c9e1a Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 23 May 2023 12:06:51 -0700 Subject: [PATCH 14/61] BUG: make _standardize_indexer handle lists with numpy arrays --- pysindy/utils/axes.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index d8b8a1d96..c6e7d69ad 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -184,6 +184,7 @@ def __getitem__(self, key, /): out_dim = output.shape # noqa remove_axes = [] # noqa new_axes = [] # noqa + key, _ = _standardize_indexer(self, key) if any( ( # basic indexing isinstance(key, BasicIndexer), @@ -191,7 +192,6 @@ def __getitem__(self, key, /): and all(isinstance(k, BasicIndexer) for k in key), ) ): - key, _ = _standardize_indexer(self, key) shift = 0 for ax_ind, indexer in enumerate(key): if indexer is None: @@ -356,15 +356,12 @@ def _standardize_indexer(arr: np.ndarray, key): if isinstance(ax_key, slice | int | np.ndarray): slicedim += 1 ellipsis_dims = arr.ndim - slicedim - ellind = new_key.index(Ellipsis) + # .index(Ellipsis) in case array is present. 
+ for i, v in enumerate(new_key): + if isinstance(v, type(Ellipsis)): + ellind = i new_key[ellind : ellind + 1] = ellipsis_dims * (slice(None),) fancy_inds = [ind if ind < ellind else ind + ellind for ind in fancy_inds] - # for ax_key in new_key: - # inner_iterator = (ax_key,) - # if ax_key is Ellipsis: - # inner_iterator = (arr.ndim - slicedim) * (slice(None),) - # for el in inner_iterator: - # final_key.append(el) return tuple(new_key), tuple(fancy_inds) From 353c9d3ee230700ae1ce8450e00e4a52434b104e Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 23 May 2023 12:44:56 -0700 Subject: [PATCH 15/61] TST: enhance advanced indexing test --- test/utils/test_axes.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 89613f39b..254ef63d7 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -188,8 +188,31 @@ def test_basic_indexing_modifies_axes(): def test_fancy_indexing_modifies_axes(): axes = {"ax_time": 0, "ax_coord": 1} - arr = AxesArray(np.ones(4).reshape((2, 2)), axes) + arr = AxesArray(np.arange(4).reshape((2, 2)), axes) + flat = arr[[0, 1], [0, 1]] + same = arr[[[0], [1]], [0, 1]] + tpose = arr[[0, 1], [[0], [1]]] + assert flat.shape == (2,) + np.testing.assert_array_equal(np.asarray(flat), np.array([0, 3])) + + assert flat.ax__timecoord == 0 + with pytest.raises(AttributeError): + flat.ax_coord + with pytest.raises(AttributeError): + flat.ax_time + + assert same.shape == arr.shape + np.testing.assert_equal(same, arr) + assert same.ax_time == 0 + assert same.ax_coord == 1 + + assert tpose.shape == arr.shape + np.testing.assert_equal(same, arr.T) + assert same.ax_time == 1 + assert same.ax_coord == 0 + fat = arr[[[0, 1], [0, 1]]] + assert fat.shape == (2, 2, 2) assert fat.ax_time == [0, 1] assert fat.ax_coord == 2 From fb9e01a4323409e4a4e6510103b9db6c827428b4 Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sat, 3 Jun 2023 07:40:26 -0700 Subject: [PATCH 16/61] TST: Enhance basic slicing test --- test/utils/test_axes.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 254ef63d7..6e466e394 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -184,6 +184,11 @@ def test_basic_indexing_modifies_axes(): slim.ax_time assert slim.ax_unk == 1 assert slim.ax_coord == 0 + reverse_slim = arr[None, :, 1] + with pytest.raises(KeyError): + reverse_slim.ax_time + assert reverse_slim.ax_unk == 0 + assert reverse_slim.ax_coord == 1 def test_fancy_indexing_modifies_axes(): From 0acda3b9190f4985d771d1b099c91242b59ad759 Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 23 May 2023 13:19:33 -0700 Subject: [PATCH 17/61] WIP: rearrange __getitem__ for advanced --- pysindy/utils/axes.py | 114 +++++++++++++++++++++++++++------------- test/utils/test_axes.py | 14 +++-- 2 files changed, 89 insertions(+), 39 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index c6e7d69ad..b83abac10 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -3,6 +3,7 @@ from typing import Collection from typing import List from typing import MutableMapping +from typing import NewType from typing import Sequence from typing import Union @@ -13,6 +14,10 @@ AxesWarning = type("AxesWarning", (SyntaxWarning,), {}) BasicIndexer = 
Union[slice, int, type(Ellipsis), np.newaxis, type(None)] +Indexer = BasicIndexer | np.ndarray +OldIndex = NewType("OldIndex", int) +KeyIndex = NewType("KeyIndex", int) +NewIndex = NewType("NewIndex", int) class _AxisMapping: @@ -99,7 +104,7 @@ def remove_axis(self, axis: Union[Collection[int], int, None] = None): new_axes[ax_dec_name].append(old_ax_dec - 1) return self._compat_axes(new_axes) - def insert_axis(self, axis: Union[Collection[int], int]): + def insert_axis(self, axis: Union[Collection[int], int], new_name: str = "ax_unk"): """Create an axes dict from self with specified axis or axes added and all greater axes incremented. @@ -116,10 +121,10 @@ def insert_axis(self, axis: Union[Collection[int], int]): if not isinstance(axis, Collection): axis = [axis] for cum_shift, ax in enumerate(axis): - if "ax_unk" in new_axes.keys(): - new_axes["ax_unk"].append(ax) + if new_name in new_axes.keys(): + new_axes[new_name].append(ax) else: - new_axes["ax_unk"] = [ax] + new_axes[new_name] = [ax] for ax_id in range(ax, in_ndim + cum_shift): ax_name = self.reverse_map[ax_id - cum_shift] new_axes[ax_name].remove(ax_id) @@ -179,42 +184,77 @@ def __getitem__(self, key, /): output = super().__getitem__(key) if not isinstance(output, AxesArray): return output - # determine axes of output - in_dim = self.shape # noqa - out_dim = output.shape # noqa - remove_axes = [] # noqa - new_axes = [] # noqa - key, _ = _standardize_indexer(self, key) - if any( - ( # basic indexing - isinstance(key, BasicIndexer), - isinstance(key, tuple) - and all(isinstance(k, BasicIndexer) for k in key), - ) - ): - shift = 0 - for ax_ind, indexer in enumerate(key): - if indexer is None: - new_axes.append(ax_ind - shift) - elif isinstance(indexer, int): - remove_axes.append(ax_ind) - shift += 1 - elif any( # fancy indexing - ( - isinstance(key, Sequence) and not isinstance(key, tuple), - isinstance(key, np.ndarray), - isinstance(key, tuple) and any(isinstance(k, Sequence) for k in key), - isinstance(key, tuple) - and any(isinstance(k, np.ndarray) for k in key), # ? 
- ) - ): + in_dim = self.shape + key, adv_ids = _standardize_indexer(self, key) + remove_axes = [] + new_axes = [] + leftshift = 0 + rightshift = 0 + for key_ind, indexer in enumerate(key): + if indexer is None: + new_axes.append(key_ind - leftshift) + rightshift += 1 + elif isinstance(indexer, int): + remove_axes.append(key_ind - rightshift) + leftshift += 1 + if adv_ids: + adv_ids = sorted(adv_ids) + source_axis = [ # after basic indexing applied # noqa + len([id for id in range(idx_id) if key[id] is not None]) + for idx_id in adv_ids + ] + adv_indexers = [np.array(key[i]) for i in adv_ids] # noqa + bcast_nd = np.broadcast(*adv_indexers).nd + adjacent = all(i + 1 == j for i, j in zip(adv_ids[:-1], adv_ids[1:])) + bcast_start_axis = 0 if not adjacent else min(adv_ids) + adv_map = {} + + def _compare_bcast_shapes(result_ndim, base_shape): + """Identify which broadcast shape axes are due to base_shape""" + return [ + result_ndim - 1 - ax_id + for ax_id, length in enumerate(reversed(base_shape)) + if length > 1 + ] + + for idx_id, idxer in zip(adv_ids, adv_indexers): + base_idxer_ax_name = self._reverse_map[ # count non-None keys + len([id for id in range(idx_id) if key[id] is not None]) + ] + adv_map[base_idxer_ax_name] = [ + bcast_start_axis + shp + for shp in _compare_bcast_shapes(bcast_nd, idxer.shape) + ] + + conflicts = {} + for bcast_ax in range(bcast_nd): + ax_names = [name for name, axes in adv_map.items() if bcast_ax in axes] + if len(ax_names) > 1: + conflicts[bcast_ax] = ax_names + [] + if len(ax_names) == 0: + if "ax_unk" not in adv_map.keys(): + adv_map["ax_unk"] = [bcast_ax + bcast_start_axis] + else: + adv_map["ax_unk"].append(bcast_ax + bcast_start_axis) + + for conflict_axis, conflict_names in conflicts.items(): + new_name = "ax_" + for name in conflict_names: + adv_map[name].remove(conflict_axis) + if not adv_map[name]: + adv_map.pop(name) + new_name += name[3:] + adv_map[new_name] = [conflict_axis] + # check if integer or boolean indexing # if integer, check which dimensions get broadcast where # if multiple, axes are merged. If adjacent, merged inplace, # otherwise moved to beginning + remove_axes.append(adv_map.keys()) # Error: remove_axis takes ints + + out_obj = np.broadcast(np.array(key[i]) for i in adv_ids) # noqa pass - else: - raise TypeError(f"AxisArray {self} does not know how to slice with {key}") # mulligan structured arrays, etc. new_map = _AxisMapping( self.__ax_map.remove_axis(remove_axes), len(in_dim) - len(remove_axes) @@ -334,7 +374,9 @@ def concatenate(arrays, axis=0): return AxesArray(np.concatenate(parents, axis), axes=ax_list[0]) -def _standardize_indexer(arr: np.ndarray, key): +def _standardize_indexer( + arr: np.ndarray, key +) -> tuple[tuple[Indexer], tuple[KeyIndex]]: """Convert to a tuple of slices, ints, None, and ndarrays. 
Returns: diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 6e466e394..34dc3a53d 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -22,6 +22,9 @@ def test_repr(): assert result == expected +@pytest.mark.skip( + "Not until fancy indexing (boolean) either short-circuited or implemented" +) def test_ufunc_override(): # This is largely a clone of test_ufunc_override_with_super() from # numpy/core/tests/test_umath.py @@ -100,7 +103,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): assert_(c is b) -# @pytest.mark.skip("Expected error") +@pytest.mark.skip("Expected error") def test_ufunc_override_accumulate(): d = np.array([[1, 2, 3], [1, 2, 3]]) a = AxesArray(d, {"ax_time": [0, 1]}) @@ -186,9 +189,14 @@ def test_basic_indexing_modifies_axes(): assert slim.ax_coord == 0 reverse_slim = arr[None, :, 1] with pytest.raises(KeyError): - reverse_slim.ax_time + reverse_slim.ax_coord assert reverse_slim.ax_unk == 0 - assert reverse_slim.ax_coord == 1 + assert reverse_slim.ax_time == 1 + almost_new = arr[None, None, 1, 1, None, None] + with pytest.raises(KeyError): + almost_new.ax_time + almost_new.ax_coord + assert set(almost_new.ax_unk) == {0, 1, 2, 3} def test_fancy_indexing_modifies_axes(): From a54e684556a62842d29be3b4ccfae66eca0df66a Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sat, 3 Jun 2023 10:46:05 -0700 Subject: [PATCH 18/61] WIP but not at a stable point --- pysindy/utils/axes.py | 126 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 103 insertions(+), 23 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index b83abac10..234261a2a 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -4,6 +4,7 @@ from typing import List from typing import MutableMapping from typing import NewType +from typing import Optional from typing import Sequence from typing import Union @@ -18,6 +19,11 @@ OldIndex = NewType("OldIndex", int) KeyIndex = NewType("KeyIndex", int) NewIndex = NewType("NewIndex", int) +# ListOrItem = list[T] | T +PartialReIndexer = tuple[KeyIndex, Optional[OldIndex], str] +CompleteReIndexer = tuple[ + list[KeyIndex], Optional[list[OldIndex]], Optional[list[NewIndex]] +] class _AxisMapping: @@ -180,12 +186,48 @@ def __getattr__(self, name): return shape raise AttributeError(f"'{type(self)}' object has no attribute '{name}'") - def __getitem__(self, key, /): + def __getitem__(self, key: Indexer | Sequence[Indexer], /): output = super().__getitem__(key) if not isinstance(output, AxesArray): return output in_dim = self.shape - key, adv_ids = _standardize_indexer(self, key) + key, adv_inds = _standardize_indexer(self, key) + if adv_inds: + adjacent, bcast_nd, bcast_start_axis = _determine_adv_broadcasting(adv_inds) + else: + adjacent, bcast_nd, bcast_start_axis = True, 0, 0 + old_index = OldIndex(0) + pindexers: list[PartialReIndexer | list[PartialReIndexer]] = [] + for key_ind, indexer in enumerate(key): + if isinstance(indexer, int | slice | np.ndarray): + pindexers.append((key_ind, old_index, indexer)) + old_index += 1 + elif indexer is None: + pindexers.append((key_ind, [None], None)) + else: + raise TypeError( + f"AxesArray indexer of type {type(indexer)} not understood" + ) + if not adjacent: + _move_idxs_to_front(key, adv_inds) + adv_inds = range(len(adv_inds)) + pindexers = _squeeze_to_sublist(pindexers, adv_inds) + cindexers: list[CompleteReIndexer] = [] + curr_axis = 0 + for pindexer in enumerate(pindexers): + if 
isinstance(pindexer, list): # advanced indexing bundle + bcast_idxers = _adv_broadcast_magic(key, adv_inds, pindexer) + cindexers += bcast_idxers + curr_axis += bcast_nd + elif pindexer[-1] is None: + cindexers.append((*pindexer[:-1], curr_axis)) + curr_axis += 1 + elif isinstance(pindexer[-1], int): + cindexers.append((*pindexer[:-1], None)) + elif isinstance(pindexer[-1], slice): + cindexers.append((*pindexer[:-1], curr_axis)) + curr_axis += 1 + remove_axes = [] new_axes = [] leftshift = 0 @@ -197,27 +239,19 @@ def __getitem__(self, key, /): elif isinstance(indexer, int): remove_axes.append(key_ind - rightshift) leftshift += 1 - if adv_ids: - adv_ids = sorted(adv_ids) + if adv_inds: + adv_inds = sorted(adv_inds) source_axis = [ # after basic indexing applied # noqa len([id for id in range(idx_id) if key[id] is not None]) - for idx_id in adv_ids + for idx_id in adv_inds ] - adv_indexers = [np.array(key[i]) for i in adv_ids] # noqa + adv_indexers = [np.array(key[i]) for i in adv_inds] # noqa bcast_nd = np.broadcast(*adv_indexers).nd - adjacent = all(i + 1 == j for i, j in zip(adv_ids[:-1], adv_ids[1:])) - bcast_start_axis = 0 if not adjacent else min(adv_ids) + adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:])) + bcast_start_axis = 0 if not adjacent else min(adv_inds) adv_map = {} - def _compare_bcast_shapes(result_ndim, base_shape): - """Identify which broadcast shape axes are due to base_shape""" - return [ - result_ndim - 1 - ax_id - for ax_id, length in enumerate(reversed(base_shape)) - if length > 1 - ] - - for idx_id, idxer in zip(adv_ids, adv_indexers): + for idx_id, idxer in zip(adv_inds, adv_indexers): base_idxer_ax_name = self._reverse_map[ # count non-None keys len([id for id in range(idx_id) if key[id] is not None]) ] @@ -253,7 +287,7 @@ def _compare_bcast_shapes(result_ndim, base_shape): # otherwise moved to beginning remove_axes.append(adv_map.keys()) # Error: remove_axis takes ints - out_obj = np.broadcast(np.array(key[i]) for i in adv_ids) # noqa + out_obj = np.broadcast(np.array(key[i]) for i in adv_inds) # noqa pass # mulligan structured arrays, etc. 
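A minimal sketch of the plain-numpy rule this block is bookkeeping for (plain numpy, no AxesArray): when the advanced indexers are adjacent, the broadcast result stays in their place; when a slice separates them, it moves to the front of the output.

import numpy as np

arr = np.arange(24).reshape(2, 3, 4)
adjacent = arr[:, [0, 1], [0, 1]]    # advanced indexers on axes 1 and 2, adjacent
separated = arr[[0, 1], :, [0, 1]]   # a slice sits between the advanced indexers
assert adjacent.shape == (2, 2)      # broadcast axis stays where the indexers were
assert separated.shape == (2, 3)     # broadcast axis moves to the front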
new_map = _AxisMapping( @@ -381,19 +415,19 @@ def _standardize_indexer( Returns: A tuple of the normalized indexer as well as the indexes of - fancy indexers + advanced indexers """ if not isinstance(key, tuple): key = (key,) if not any(ax_key is Ellipsis for ax_key in key): key = (*key, Ellipsis) new_key = [] - fancy_inds = [] + adv_inds = [] slicedim = 0 for indexer_ind, ax_key in enumerate(key): if not isinstance(ax_key, BasicIndexer): ax_key = np.array(ax_key) - fancy_inds.append(indexer_ind) + adv_inds.append(indexer_ind) new_key.append(ax_key) if isinstance(ax_key, slice | int | np.ndarray): slicedim += 1 @@ -403,8 +437,54 @@ def _standardize_indexer( if isinstance(v, type(Ellipsis)): ellind = i new_key[ellind : ellind + 1] = ellipsis_dims * (slice(None),) - fancy_inds = [ind if ind < ellind else ind + ellind for ind in fancy_inds] - return tuple(new_key), tuple(fancy_inds) + adv_inds = [ind if ind < ellind else ind + ellind for ind in adv_inds] + return tuple(new_key), tuple(adv_inds) + + +def _adv_broadcast_magic(*args): + raise NotImplementedError + + +def _compare_bcast_shapes(result_ndim: int, base_shape: tuple[int]) -> list[int]: + """Identify which broadcast shape axes are due to base_shape + + Args: + result_ndim: number of dimensions broadcast shape has + base_shape: shape of one element of broadcasting + + Result: + tuple of axes in broadcast result that come from base shape + """ + return [ + result_ndim - 1 - ax_id + for ax_id, length in enumerate(reversed(base_shape)) + if length > 1 + ] + + +def _move_idxs_to_front(li: list, idxs: Sequence) -> None: + """Move all items at indexes specified to the front of a list""" + front = [] + for idx in reversed(idxs): + obj = li.pop(idx) + front.insert(0, obj) + li = front + li + + +def _determine_adv_broadcasting( + key: Indexer | Sequence[Indexer], adv_inds: Sequence[OldIndex] +) -> tuple: + """Calculate the shape and location for the result of advanced indexing""" + adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:])) + adv_indexers = [np.array(key[i]) for i in adv_inds] + bcast_nd = np.broadcast(*adv_indexers).nd + bcast_start_axis = 0 if not adjacent else min(adv_inds) + return adjacent, bcast_nd, bcast_start_axis + + +def _squeeze_to_sublist(li: list, idxs: Sequence) -> list: + "Turn contiguous elements of a list into a sub-list in the same position" + return li[: min(idxs)] + [li[idx] for idx in idxs] + li[max(idxs) :] def comprehend_axes(x): From 6ccba030c095c84ec92b47124d1c31e73e242b52 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 4 Jan 2024 05:52:29 +0000 Subject: [PATCH 19/61] CLN: Make AxesArray syntax a little clearer --- pysindy/utils/axes.py | 21 ++++++++++++--------- test/utils/test_axes.py | 2 +- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 234261a2a..411fb4f76 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -2,7 +2,6 @@ import warnings from typing import Collection from typing import List -from typing import MutableMapping from typing import NewType from typing import Optional from typing import Sequence @@ -31,17 +30,18 @@ class _AxisMapping: indexes. 
""" + fwd_map: dict[str, list[int]] + reverse_map: dict[int, str] + def __init__( self, - axes: MutableMapping[str, Union[int, Sequence[int]]] = None, + axes: dict[str, Union[int, Sequence[int]]] = None, in_ndim: int = 0, ): if axes is None: axes = {} - axes = copy.deepcopy(axes) - self.fwd_map: dict[str, list[int]] = {} - self.reverse_map: dict[int, str] = {} - null = object() + self.fwd_map = {} + self.reverse_map = {} def coerce_sequence(obj): if isinstance(obj, Sequence): @@ -52,8 +52,8 @@ def coerce_sequence(obj): ax_ids = coerce_sequence(ax_ids) self.fwd_map[ax_name] = ax_ids for ax_id in ax_ids: - old_name = self.reverse_map.get(ax_id, null) - if old_name is not null: + old_name = self.reverse_map.get(ax_id) + if old_name is not None: raise ValueError(f"Assigned multiple definitions to axis {ax_id}") if ax_id >= in_ndim: raise ValueError( @@ -68,7 +68,9 @@ def coerce_sequence(obj): ) @staticmethod - def _compat_axes(in_dict: dict[str, Sequence]) -> dict[str, Union[Sequence, int]]: + def _compat_axes( + in_dict: dict[str, Sequence[int]] + ) -> dict[str, Union[Sequence[int], int]]: """Turn single-element axis index lists into ints""" axes = {} for k, v in in_dict.items(): @@ -110,6 +112,7 @@ def remove_axis(self, axis: Union[Collection[int], int, None] = None): new_axes[ax_dec_name].append(old_ax_dec - 1) return self._compat_axes(new_axes) + # TODO: delete default kwarg value def insert_axis(self, axis: Union[Collection[int], int], new_name: str = "ax_unk"): """Create an axes dict from self with specified axis or axes added and all greater axes incremented. diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 34dc3a53d..715fcc318 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -327,7 +327,7 @@ def test_insert_AxisMapping(): }, 6, ) - result = ax_map.insert_axis(3) + result = ax_map.insert_axis(3, "ax_unk") expected = { "ax_a": [0, 1], "ax_b": 2, From b8c8739e4275fefbcb67107e5e1def67a4bb844c Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 4 Jan 2024 05:53:27 +0000 Subject: [PATCH 20/61] BUG: Sort axis argument when inserting or removing axes --- pysindy/utils/axes.py | 4 ++-- test/utils/test_axes.py | 48 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 411fb4f76..3b752011b 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -98,7 +98,7 @@ def remove_axis(self, axis: Union[Collection[int], int, None] = None): in_ndim = len(self.reverse_map) if not isinstance(axis, Collection): axis = [axis] - for cum_shift, orig_ax_remove in enumerate(axis): + for cum_shift, orig_ax_remove in enumerate(sorted(axis)): remove_ax_name = self.reverse_map[orig_ax_remove] curr_ax_remove = orig_ax_remove - cum_shift if len(new_axes[remove_ax_name]) == 1: @@ -129,7 +129,7 @@ def insert_axis(self, axis: Union[Collection[int], int], new_name: str = "ax_unk in_ndim = len(self.reverse_map) if not isinstance(axis, Collection): axis = [axis] - for cum_shift, ax in enumerate(axis): + for cum_shift, ax in enumerate(sorted(axis)): if new_name in new_axes.keys(): new_axes[new_name].append(ax) else: diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 715fcc318..20e6b0b06 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -317,6 +317,13 @@ def test_reduce_twisted_AxisMapping(): assert result == expected +def test_reduce_misordered_AxisMapping(): + ax_map = _AxisMapping({"ax_a": 
[0, 1], "ax_b": 2, "ax_c": 3}, 7) + result = ax_map.remove_axis([2, 1]) + expected = {"ax_a": 0, "ax_c": 1} + assert result == expected + + def test_insert_AxisMapping(): ax_map = _AxisMapping( { @@ -338,6 +345,26 @@ def test_insert_AxisMapping(): assert result == expected +def test_insert_existing_AxisMapping(): + ax_map = _AxisMapping( + { + "ax_a": [0, 1], + "ax_b": 2, + "ax_c": 3, + "ax_d": [4, 5], + }, + 6, + ) + result = ax_map.insert_axis(3, "ax_b") + expected = { + "ax_a": [0, 1], + "ax_b": [2, 3], + "ax_c": 4, + "ax_d": [5, 6], + } + assert result == expected + + def test_insert_multiple_AxisMapping(): ax_map = _AxisMapping( { @@ -357,3 +384,24 @@ def test_insert_multiple_AxisMapping(): "ax_d": [6, 7], } assert result == expected + + +def test_insert_misordered_AxisMapping(): + ax_map = _AxisMapping( + { + "ax_a": [0, 1], + "ax_b": 2, + "ax_c": 3, + "ax_d": [4, 5], + }, + 6, + ) + result = ax_map.insert_axis([4, 1]) + expected = { + "ax_a": [0, 2], + "ax_unk": [1, 4], + "ax_b": 3, + "ax_c": 5, + "ax_d": [6, 7], + } + assert result == expected From 17842dafd8fe88db5d4d1aa021f0652ebe13405d Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 4 Jan 2024 07:35:13 +0000 Subject: [PATCH 21/61] BUG: Fix everything about _squeeze_to_sublist with tests --- pysindy/utils/axes.py | 19 ++++++++++++++----- test/utils/test_axes.py | 12 ++++++++++++ 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 3b752011b..b3d778a68 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -13,7 +13,7 @@ HANDLED_FUNCTIONS = {} AxesWarning = type("AxesWarning", (SyntaxWarning,), {}) -BasicIndexer = Union[slice, int, type(Ellipsis), np.newaxis, type(None)] +BasicIndexer = Union[slice, int, type(Ellipsis), type(None)] Indexer = BasicIndexer | np.ndarray OldIndex = NewType("OldIndex", int) KeyIndex = NewType("KeyIndex", int) @@ -71,7 +71,7 @@ def coerce_sequence(obj): def _compat_axes( in_dict: dict[str, Sequence[int]] ) -> dict[str, Union[Sequence[int], int]]: - """Turn single-element axis index lists into ints""" + """Like fwd_map, but unpack single-element axis lists""" axes = {} for k, v in in_dict.items(): if len(v) == 1: @@ -211,6 +211,7 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): raise TypeError( f"AxesArray indexer of type {type(indexer)} not understood" ) + # Advanced indexing can move axes if they are not adjacent if not adjacent: _move_idxs_to_front(key, adv_inds) adv_inds = range(len(adv_inds)) @@ -485,9 +486,17 @@ def _determine_adv_broadcasting( return adjacent, bcast_nd, bcast_start_axis -def _squeeze_to_sublist(li: list, idxs: Sequence) -> list: - "Turn contiguous elements of a list into a sub-list in the same position" - return li[: min(idxs)] + [li[idx] for idx in idxs] + li[max(idxs) :] +def _squeeze_to_sublist(li: list, idxs: Sequence[int]) -> list: + """Turn contiguous elements of a list into a sub-list in the same position + + e.g. 
_squeeze_to_sublist(["a", "b", "c", "d"], [1,2]) = ["a", ["b", "c"], "d"] + """ + for left, right in zip(idxs[:-1], idxs[1:]): + if left + 1 != right: + raise ValueError("Indexes to squeeze must be contiguous") + if not idxs: + return li + return li[: min(idxs)] + [[li[idx] for idx in idxs]] + li[max(idxs) + 1 :] def comprehend_axes(x): diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 20e6b0b06..8893b5417 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -405,3 +405,15 @@ def test_insert_misordered_AxisMapping(): "ax_d": [6, 7], } assert result == expected + + +def test_squeeze_to_sublist(): + li = ["a", "b", "c", "d"] + result = axes._squeeze_to_sublist(li, [1, 2]) + assert result == ["a", ["b", "c"], "d"] + + result = axes._squeeze_to_sublist(li, []) + assert result == li + + with pytest.raises(ValueError, match="Indexes to squeeze"): + axes._squeeze_to_sublist(li, [0, 2]) From efbfac2bf4901f71c1ebb39d10705464c8549d68 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 4 Jan 2024 07:53:25 +0000 Subject: [PATCH 22/61] CLN: explain _standardize_indexer --- pysindy/utils/axes.py | 52 ++++++++++++++++++++++++++--------------- test/utils/test_axes.py | 8 +++---- 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index b3d778a68..fd0a61a56 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -194,7 +194,7 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): if not isinstance(output, AxesArray): return output in_dim = self.shape - key, adv_inds = _standardize_indexer(self, key) + key, adv_inds = standardize_indexer(self, key) if adv_inds: adjacent, bcast_nd, bcast_start_axis = _determine_adv_broadcasting(adv_inds) else: @@ -412,39 +412,53 @@ def concatenate(arrays, axis=0): return AxesArray(np.concatenate(parents, axis), axes=ax_list[0]) -def _standardize_indexer( - arr: np.ndarray, key +def standardize_indexer( + arr: np.ndarray, key: Indexer | Sequence[Indexer] ) -> tuple[tuple[Indexer], tuple[KeyIndex]]: - """Convert to a tuple of slices, ints, None, and ndarrays. + """Convert any legal numpy indexer to a "standard" form. + Standard form involves creating an equivalent indexer that is a tuple with + one element per index of the original axis. All advanced indexer elements + are converted to numpy arrays Returns: A tuple of the normalized indexer as well as the indexes of advanced indexers """ - if not isinstance(key, tuple): - key = (key,) + if isinstance(key, tuple): + key = list(key) + else: + key = [ + key, + ] if not any(ax_key is Ellipsis for ax_key in key): - key = (*key, Ellipsis) - new_key = [] - adv_inds = [] - slicedim = 0 + key = [*key, Ellipsis] + + _expand_indexer_ellipsis(key, arr.ndim) + + new_key: list[Indexer] = [] + adv_inds: list[int] = [] for indexer_ind, ax_key in enumerate(key): if not isinstance(ax_key, BasicIndexer): ax_key = np.array(ax_key) adv_inds.append(indexer_ind) new_key.append(ax_key) - if isinstance(ax_key, slice | int | np.ndarray): - slicedim += 1 - ellipsis_dims = arr.ndim - slicedim - # .index(Ellipsis) in case array is present. 
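A standalone illustration (plain numpy, independent of the helper being rewritten here) of the Ellipsis expansion rule: an Ellipsis stands for however many full slices are needed to cover the remaining axes, and None keys do not count against the array's dimensions.

import numpy as np

arr = np.empty((2, 3, 4))
assert arr[..., 0].shape == arr[:, :, 0].shape == (2, 3)
assert arr[None, ..., 0].shape == arr[None, :, :, 0].shape == (1, 2, 3)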
- for i, v in enumerate(new_key): - if isinstance(v, type(Ellipsis)): - ellind = i - new_key[ellind : ellind + 1] = ellipsis_dims * (slice(None),) - adv_inds = [ind if ind < ellind else ind + ellind for ind in adv_inds] return tuple(new_key), tuple(adv_inds) +def _expand_indexer_ellipsis(indexers: list[Indexer], ndim: int) -> None: + """Replace ellipsis in indexers with the appropriate amount of slice(None) + + Mutates indexers + """ + try: + ellind = indexers.index(Ellipsis) + except ValueError: + return + n_new_dims = sum(k is None for k in indexers) + n_ellipsis_dims = ndim - (len(indexers) - n_new_dims - 1) + indexers[ellind : ellind + 1] = n_ellipsis_dims * (slice(None),) + + def _adv_broadcast_magic(*args): raise NotImplementedError diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 8893b5417..ece0d32ad 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -232,11 +232,11 @@ def test_fancy_indexing_modifies_axes(): def test_standardize_basic_indexer(): arr = np.arange(6).reshape(2, 3) - result_indexer, result_fancy = axes._standardize_indexer(arr, Ellipsis) + result_indexer, result_fancy = axes.standardize_indexer(arr, Ellipsis) assert result_indexer == (slice(None), slice(None)) assert result_fancy == () - result_indexer, result_fancy = axes._standardize_indexer( + result_indexer, result_fancy = axes.standardize_indexer( arr, (np.newaxis, 1, 1, Ellipsis) ) assert result_indexer == (None, 1, 1) @@ -245,11 +245,11 @@ def test_standardize_basic_indexer(): def test_standardize_fancy_indexer(): arr = np.arange(6).reshape(2, 3) - result_indexer, result_fancy = axes._standardize_indexer(arr, [1]) + result_indexer, result_fancy = axes.standardize_indexer(arr, [1]) assert result_indexer == (np.ones(1), slice(None)) assert result_fancy == (0,) - result_indexer, result_fancy = axes._standardize_indexer( + result_indexer, result_fancy = axes.standardize_indexer( arr, (np.newaxis, [1], 1, Ellipsis) ) assert result_indexer == (None, np.ones(1), 1) From c101a5a1df280059e0cef07993c295aa46e26631 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 4 Jan 2024 08:28:39 +0000 Subject: [PATCH 23/61] TST: Test _determine_adv_broadcasting --- pysindy/utils/axes.py | 11 ++++++----- test/utils/test_axes.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index fd0a61a56..fae7e2933 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -8,17 +8,18 @@ from typing import Union import numpy as np +from numpy.typing import NDArray from sklearn.base import TransformerMixin HANDLED_FUNCTIONS = {} AxesWarning = type("AxesWarning", (SyntaxWarning,), {}) BasicIndexer = Union[slice, int, type(Ellipsis), type(None)] -Indexer = BasicIndexer | np.ndarray -OldIndex = NewType("OldIndex", int) +Indexer = BasicIndexer | NDArray +StandardIndexer = Union[slice, int, type(None), NDArray] +OldIndex = NewType("OldIndex", int) # Before moving advanced axes adajent KeyIndex = NewType("KeyIndex", int) NewIndex = NewType("NewIndex", int) -# ListOrItem = list[T] | T PartialReIndexer = tuple[KeyIndex, Optional[OldIndex], str] CompleteReIndexer = tuple[ list[KeyIndex], Optional[list[OldIndex]], Optional[list[NewIndex]] @@ -414,7 +415,7 @@ def concatenate(arrays, axis=0): def standardize_indexer( arr: np.ndarray, key: Indexer | Sequence[Indexer] -) -> tuple[tuple[Indexer], tuple[KeyIndex]]: +) -> tuple[tuple[StandardIndexer], tuple[KeyIndex]]: """Convert any legal 
numpy indexer to a "standard" form. Standard form involves creating an equivalent indexer that is a tuple with @@ -490,7 +491,7 @@ def _move_idxs_to_front(li: list, idxs: Sequence) -> None: def _determine_adv_broadcasting( - key: Indexer | Sequence[Indexer], adv_inds: Sequence[OldIndex] + key: StandardIndexer | Sequence[StandardIndexer], adv_inds: Sequence[OldIndex] ) -> tuple: """Calculate the shape and location for the result of advanced indexing""" adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:])) diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index ece0d32ad..a32879d97 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -417,3 +417,17 @@ def test_squeeze_to_sublist(): with pytest.raises(ValueError, match="Indexes to squeeze"): axes._squeeze_to_sublist(li, [0, 2]) + + +def test_determine_adv_broadcasting(): + indexers = (np.ones(1), np.ones((4, 1)), np.ones(3)) + res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [0, 1, 2]) + assert res_adj is True + assert res_nd == 2 + assert res_start == 0 + + indexers = (None, np.ones(1), 2, np.ones(3)) + res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 3]) + assert res_adj is False + assert res_nd == 1 + assert res_start == 0 From 4ac044f26b95eac55ad252516a6bad1a9814db45 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 4 Jan 2024 10:22:28 +0000 Subject: [PATCH 24/61] CLN: Remove name mangling from AxesArray _ax_map I added it before I knew what name mangling was for --- pysindy/utils/axes.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index fae7e2933..b30b20d45 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -163,16 +163,16 @@ def __new__(cls, input_array, axes): if axes is None: axes = {} in_ndim = len(input_array.shape) - obj.__ax_map = _AxisMapping(axes, in_ndim) + obj._ax_map = _AxisMapping(axes, in_ndim) return obj @property def axes(self): - return self.__ax_map.compat_axes + return self._ax_map.compat_axes @property def _reverse_map(self): - return self.__ax_map.reverse_map + return self._ax_map.reverse_map @property def shape(self): @@ -183,7 +183,7 @@ def __getattr__(self, name): if parts[0] == "ax": return self.axes[name] if parts[0] == "n": - fwd_map = self.__ax_map.fwd_map + fwd_map = self._ax_map.fwd_map shape = tuple(self.shape[ax_id] for ax_id in fwd_map["ax_" + parts[1]]) if len(shape) == 1: return shape[0] @@ -296,13 +296,13 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): pass # mulligan structured arrays, etc. new_map = _AxisMapping( - self.__ax_map.remove_axis(remove_axes), len(in_dim) - len(remove_axes) + self._ax_map.remove_axis(remove_axes), len(in_dim) - len(remove_axes) ) new_map = _AxisMapping( new_map.insert_axis(new_axes), len(in_dim) - len(remove_axes) + len(new_axes), ) - output.__ax_map = new_map + output._ax_map = new_map return output def __array_wrap__(self, out_arr, context=None): @@ -319,7 +319,7 @@ def __array_finalize__(self, obj) -> None: not hasattr(self, "__ax_map"), ) ): - self.__ax_map = _AxisMapping({}) + self._ax_map = _AxisMapping({}) # required by ravel() and view() used in numpy testing. Also for zeros_like... 
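For context on this commit, a short standard-Python reminder (a toy class, not AxesArray) of what name mangling does: a double-underscore attribute set inside a class body is stored under the mangled name _ClassName__attr, which is why a guard like hasattr(self, "__ax_map") could never find the attribute.

class Demo:
    def __init__(self):
        self.__attr = 1              # actually stored as _Demo__attr

d = Demo()
assert not hasattr(d, "__attr")      # the plain name was never set
assert hasattr(d, "_Demo__attr")     # the mangled name was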
elif all( ( @@ -328,7 +328,7 @@ def __array_finalize__(self, obj) -> None: self.shape == obj.shape, ) ): - self.__ax_map = _AxisMapping(obj.axes, len(obj.shape)) + self._ax_map = _AxisMapping(obj.axes, len(obj.shape)) # maybe add errors for incompatible views? def __array_ufunc__( @@ -364,7 +364,7 @@ def __array_ufunc__( ): axes = None if kwargs["axis"] is not None: - axes = self.__ax_map.remove_axis(axis=kwargs["axis"]) + axes = self._ax_map.remove_axis(axis=kwargs["axis"]) else: axes = self.axes final_results = [] From 48634dff44ea0bde1d27b96418e2a21f27eef8c7 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 4 Jan 2024 10:30:05 +0000 Subject: [PATCH 25/61] CLN: Simplify advanced indexing broadcast calculation --- pysindy/utils/axes.py | 18 ++++++++---------- test/utils/test_axes.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index b30b20d45..f5347146e 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -196,10 +196,8 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): return output in_dim = self.shape key, adv_inds = standardize_indexer(self, key) - if adv_inds: - adjacent, bcast_nd, bcast_start_axis = _determine_adv_broadcasting(adv_inds) - else: - adjacent, bcast_nd, bcast_start_axis = True, 0, 0 + adjacent, bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds) + # Handle moving around non-adjacent advanced axes old_index = OldIndex(0) pindexers: list[PartialReIndexer | list[PartialReIndexer]] = [] for key_ind, indexer in enumerate(key): @@ -253,7 +251,7 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): adv_indexers = [np.array(key[i]) for i in adv_inds] # noqa bcast_nd = np.broadcast(*adv_indexers).nd adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:])) - bcast_start_axis = 0 if not adjacent else min(adv_inds) + bcast_start_ax = 0 if not adjacent else min(adv_inds) adv_map = {} for idx_id, idxer in zip(adv_inds, adv_indexers): @@ -261,7 +259,7 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): len([id for id in range(idx_id) if key[id] is not None]) ] adv_map[base_idxer_ax_name] = [ - bcast_start_axis + shp + bcast_start_ax + shp for shp in _compare_bcast_shapes(bcast_nd, idxer.shape) ] @@ -273,9 +271,9 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): [] if len(ax_names) == 0: if "ax_unk" not in adv_map.keys(): - adv_map["ax_unk"] = [bcast_ax + bcast_start_axis] + adv_map["ax_unk"] = [bcast_ax + bcast_start_ax] else: - adv_map["ax_unk"].append(bcast_ax + bcast_start_axis) + adv_map["ax_unk"].append(bcast_ax + bcast_start_ax) for conflict_axis, conflict_names in conflicts.items(): new_name = "ax_" @@ -493,11 +491,11 @@ def _move_idxs_to_front(li: list, idxs: Sequence) -> None: def _determine_adv_broadcasting( key: StandardIndexer | Sequence[StandardIndexer], adv_inds: Sequence[OldIndex] ) -> tuple: - """Calculate the shape and location for the result of advanced indexing""" + """Calculate the shape and location for the result of advanced indexing.""" adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:])) adv_indexers = [np.array(key[i]) for i in adv_inds] bcast_nd = np.broadcast(*adv_indexers).nd - bcast_start_axis = 0 if not adjacent else min(adv_inds) + bcast_start_axis = 0 if not adjacent else min(adv_inds) if adv_inds else None return adjacent, bcast_nd, bcast_start_axis diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py 
index a32879d97..b37b20b83 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -420,14 +420,19 @@ def test_squeeze_to_sublist(): def test_determine_adv_broadcasting(): - indexers = (np.ones(1), np.ones((4, 1)), np.ones(3)) - res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [0, 1, 2]) + indexers = (1, np.ones(1), np.ones((4, 1)), np.ones(3)) + res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 2, 3]) assert res_adj is True assert res_nd == 2 - assert res_start == 0 + assert res_start == 1 indexers = (None, np.ones(1), 2, np.ones(3)) res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 3]) assert res_adj is False assert res_nd == 1 assert res_start == 0 + + res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, []) + assert res_adj is True + assert res_nd == 0 + assert res_start is None From 91064dddde9af02d45a5e1c87bdd1c77b1fd6cdb Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 4 Jan 2024 11:44:17 +0000 Subject: [PATCH 26/61] CLN: Extract function on the basic indexing TBD: should this fully return a new _AxisMapping and maybe a new indexer? --- pysindy/utils/axes.py | 35 ++++++++++++++++++++++++----------- test/utils/test_axes.py | 6 +++--- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index f5347146e..bc16e8a7f 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -197,6 +197,8 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): in_dim = self.shape key, adv_inds = standardize_indexer(self, key) adjacent, bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds) + remove_axes, new_axes = _apply_basic_indexing(key) + # Handle moving around non-adjacent advanced axes old_index = OldIndex(0) pindexers: list[PartialReIndexer | list[PartialReIndexer]] = [] @@ -231,17 +233,6 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): cindexers.append((*pindexer[:-1], curr_axis)) curr_axis += 1 - remove_axes = [] - new_axes = [] - leftshift = 0 - rightshift = 0 - for key_ind, indexer in enumerate(key): - if indexer is None: - new_axes.append(key_ind - leftshift) - rightshift += 1 - elif isinstance(indexer, int): - remove_axes.append(key_ind - rightshift) - leftshift += 1 if adv_inds: adv_inds = sorted(adv_inds) source_axis = [ # after basic indexing applied # noqa @@ -512,6 +503,28 @@ def _squeeze_to_sublist(li: list, idxs: Sequence[int]) -> list: return li[: min(idxs)] + [[li[idx] for idx in idxs]] + li[max(idxs) + 1 :] +def _apply_basic_indexing(key: tuple[StandardIndexer]) -> tuple[list[int], list[int]]: + """Determine where axes should be removed and added + + Only considers the basic indexers in key. 
Numpy arrays are treated as + slices, in that they don't affect the final dimensions of the output + """ + remove_axes = [] + new_axes = [] + deleted_to_left = 0 + added_to_left = 0 + for key_ind, indexer in enumerate(key): + if isinstance(indexer, int): + orig_arr_axis = key_ind - added_to_left + remove_axes.append(orig_arr_axis) + deleted_to_left += 1 + elif indexer is None: + new_arr_axis = key_ind - deleted_to_left + new_axes.append(new_arr_axis) + added_to_left += 1 + return remove_axes, new_axes + + def comprehend_axes(x): axes = {} axes["ax_coord"] = len(x.shape) - 1 diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index b37b20b83..8892837df 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -192,11 +192,11 @@ def test_basic_indexing_modifies_axes(): reverse_slim.ax_coord assert reverse_slim.ax_unk == 0 assert reverse_slim.ax_time == 1 - almost_new = arr[None, None, 1, 1, None, None] + almost_new = arr[None, None, 1, :, None, None] with pytest.raises(KeyError): almost_new.ax_time - almost_new.ax_coord - assert set(almost_new.ax_unk) == {0, 1, 2, 3} + assert almost_new.ax_coord == 2 + assert set(almost_new.ax_unk) == {0, 1, 3, 4} def test_fancy_indexing_modifies_axes(): From 700521a69168115bbead6280dbf1e5f9c7c9317c Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 4 Jan 2024 11:46:48 +0000 Subject: [PATCH 27/61] CLN: Type the return of determine_adv_broadcasting --- pysindy/utils/axes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index bc16e8a7f..d6a8d7046 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -481,7 +481,7 @@ def _move_idxs_to_front(li: list, idxs: Sequence) -> None: def _determine_adv_broadcasting( key: StandardIndexer | Sequence[StandardIndexer], adv_inds: Sequence[OldIndex] -) -> tuple: +) -> tuple[bool, int, Optional[int]]: """Calculate the shape and location for the result of advanced indexing.""" adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:])) adv_indexers = [np.array(key[i]) for i in adv_inds] From 204223f105fb383de083ee6b7b54c98b53f2dad4 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 4 Jan 2024 17:36:33 +0000 Subject: [PATCH 28/61] ENH: Enable fancy indexing in AxesArray Involves processing the keys several times, with increasing standardization --- pysindy/utils/axes.py | 234 +++++++++++++++------------------------- test/utils/test_axes.py | 33 ++---- 2 files changed, 96 insertions(+), 171 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index d6a8d7046..8a0f9dd96 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -1,7 +1,9 @@ import copy import warnings +from enum import Enum from typing import Collection from typing import List +from typing import Literal from typing import NewType from typing import Optional from typing import Sequence @@ -26,6 +28,14 @@ ] +class Sentinels(Enum): + ADV_NAME = object() + ADV_REMOVE = object() + + +Literal[Sentinels.ADV_NAME] + + class _AxisMapping: """Convenience wrapper for a two-way map between axis names and indexes. 
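A standalone rendering (lightly adapted from this patch, with a toy input) of how the sentinels are meant to be used: each advanced indexer is marked for removal, and one ADV_NAME placeholder per broadcast dimension is inserted where the broadcast block lands, so the later axis bookkeeping can proceed purely by position.

from enum import Enum
import numpy as np

class Sentinels(Enum):
    ADV_NAME = object()     # stands in for each axis produced by broadcasting
    ADV_REMOVE = object()   # marks an original axis consumed by an advanced indexer

def replace_adv_indexers(key, adv_inds, bcast_start_ax, bcast_nd):
    key = list(key)
    for adv_ind in adv_inds:
        key[adv_ind] = Sentinels.ADV_REMOVE
    return key[:bcast_start_ax] + bcast_nd * [Sentinels.ADV_NAME] + key[bcast_start_ax:]

key = [slice(None), np.array([0, 1]), 2]   # one advanced indexer, at key position 1
print(replace_adv_indexers(key, [1], bcast_start_ax=1, bcast_nd=1))
# [slice(None, None, None), <Sentinels.ADV_NAME: ...>, <Sentinels.ADV_REMOVE: ...>, 2]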
@@ -181,7 +191,10 @@ def shape(self): def __getattr__(self, name): parts = name.split("_", 1) if parts[0] == "ax": - return self.axes[name] + try: + return self.axes[name] + except KeyError: + raise AttributeError(f"AxesArray has no axis '{name}'") if parts[0] == "n": fwd_map = self._ax_map.fwd_map shape = tuple(self.shape[ax_id] for ax_id in fwd_map["ax_" + parts[1]]) @@ -193,104 +206,22 @@ def __getattr__(self, name): def __getitem__(self, key: Indexer | Sequence[Indexer], /): output = super().__getitem__(key) if not isinstance(output, AxesArray): - return output + return output # why? in_dim = self.shape key, adv_inds = standardize_indexer(self, key) - adjacent, bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds) - remove_axes, new_axes = _apply_basic_indexing(key) - - # Handle moving around non-adjacent advanced axes - old_index = OldIndex(0) - pindexers: list[PartialReIndexer | list[PartialReIndexer]] = [] - for key_ind, indexer in enumerate(key): - if isinstance(indexer, int | slice | np.ndarray): - pindexers.append((key_ind, old_index, indexer)) - old_index += 1 - elif indexer is None: - pindexers.append((key_ind, [None], None)) - else: - raise TypeError( - f"AxesArray indexer of type {type(indexer)} not understood" - ) - # Advanced indexing can move axes if they are not adjacent - if not adjacent: - _move_idxs_to_front(key, adv_inds) - adv_inds = range(len(adv_inds)) - pindexers = _squeeze_to_sublist(pindexers, adv_inds) - cindexers: list[CompleteReIndexer] = [] - curr_axis = 0 - for pindexer in enumerate(pindexers): - if isinstance(pindexer, list): # advanced indexing bundle - bcast_idxers = _adv_broadcast_magic(key, adv_inds, pindexer) - cindexers += bcast_idxers - curr_axis += bcast_nd - elif pindexer[-1] is None: - cindexers.append((*pindexer[:-1], curr_axis)) - curr_axis += 1 - elif isinstance(pindexer[-1], int): - cindexers.append((*pindexer[:-1], None)) - elif isinstance(pindexer[-1], slice): - cindexers.append((*pindexer[:-1], curr_axis)) - curr_axis += 1 - + bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds) if adv_inds: - adv_inds = sorted(adv_inds) - source_axis = [ # after basic indexing applied # noqa - len([id for id in range(idx_id) if key[id] is not None]) - for idx_id in adv_inds - ] - adv_indexers = [np.array(key[i]) for i in adv_inds] # noqa - bcast_nd = np.broadcast(*adv_indexers).nd - adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:])) - bcast_start_ax = 0 if not adjacent else min(adv_inds) - adv_map = {} - - for idx_id, idxer in zip(adv_inds, adv_indexers): - base_idxer_ax_name = self._reverse_map[ # count non-None keys - len([id for id in range(idx_id) if key[id] is not None]) - ] - adv_map[base_idxer_ax_name] = [ - bcast_start_ax + shp - for shp in _compare_bcast_shapes(bcast_nd, idxer.shape) - ] - - conflicts = {} - for bcast_ax in range(bcast_nd): - ax_names = [name for name, axes in adv_map.items() if bcast_ax in axes] - if len(ax_names) > 1: - conflicts[bcast_ax] = ax_names - [] - if len(ax_names) == 0: - if "ax_unk" not in adv_map.keys(): - adv_map["ax_unk"] = [bcast_ax + bcast_start_ax] - else: - adv_map["ax_unk"].append(bcast_ax + bcast_start_ax) - - for conflict_axis, conflict_names in conflicts.items(): - new_name = "ax_" - for name in conflict_names: - adv_map[name].remove(conflict_axis) - if not adv_map[name]: - adv_map.pop(name) - new_name += name[3:] - adv_map[new_name] = [conflict_axis] - - # check if integer or boolean indexing - # if integer, check which dimensions get broadcast where - 
# if multiple, axes are merged. If adjacent, merged inplace, - # otherwise moved to beginning - remove_axes.append(adv_map.keys()) # Error: remove_axis takes ints - - out_obj = np.broadcast(np.array(key[i]) for i in adv_inds) # noqa - pass - # mulligan structured arrays, etc. + key = replace_adv_indexers(key, adv_inds, bcast_start_ax, bcast_nd) + remove_axes, new_axes, adv_names = _apply_indexing(key, self._reverse_map) + new_axes = _rename_broadcast_axes(new_axes, adv_names) new_map = _AxisMapping( self._ax_map.remove_axis(remove_axes), len(in_dim) - len(remove_axes) ) - new_map = _AxisMapping( - new_map.insert_axis(new_axes), - len(in_dim) - len(remove_axes) + len(new_axes), - ) + for new_ax_ind, new_ax_name in new_axes: + new_map = _AxisMapping( + new_map.insert_axis(new_ax_ind, new_ax_name), + len(in_dim) - len(remove_axes) + len(new_axes), + ) output._ax_map = new_map return output @@ -404,7 +335,7 @@ def concatenate(arrays, axis=0): def standardize_indexer( arr: np.ndarray, key: Indexer | Sequence[Indexer] -) -> tuple[tuple[StandardIndexer], tuple[KeyIndex]]: +) -> tuple[list[StandardIndexer], tuple[KeyIndex]]: """Convert any legal numpy indexer to a "standard" form. Standard form involves creating an equivalent indexer that is a tuple with @@ -432,7 +363,7 @@ def standardize_indexer( ax_key = np.array(ax_key) adv_inds.append(indexer_ind) new_key.append(ax_key) - return tuple(new_key), tuple(adv_inds) + return new_key, tuple(adv_inds) def _expand_indexer_ellipsis(indexers: list[Indexer], ndim: int) -> None: @@ -449,61 +380,63 @@ def _expand_indexer_ellipsis(indexers: list[Indexer], ndim: int) -> None: indexers[ellind : ellind + 1] = n_ellipsis_dims * (slice(None),) -def _adv_broadcast_magic(*args): - raise NotImplementedError - - -def _compare_bcast_shapes(result_ndim: int, base_shape: tuple[int]) -> list[int]: - """Identify which broadcast shape axes are due to base_shape - - Args: - result_ndim: number of dimensions broadcast shape has - base_shape: shape of one element of broadcasting - - Result: - tuple of axes in broadcast result that come from base shape - """ - return [ - result_ndim - 1 - ax_id - for ax_id, length in enumerate(reversed(base_shape)) - if length > 1 - ] - - -def _move_idxs_to_front(li: list, idxs: Sequence) -> None: - """Move all items at indexes specified to the front of a list""" - front = [] - for idx in reversed(idxs): - obj = li.pop(idx) - front.insert(0, obj) - li = front + li - - def _determine_adv_broadcasting( - key: StandardIndexer | Sequence[StandardIndexer], adv_inds: Sequence[OldIndex] -) -> tuple[bool, int, Optional[int]]: + key: Sequence[StandardIndexer], adv_inds: Sequence[OldIndex] +) -> tuple[int, Optional[KeyIndex]]: """Calculate the shape and location for the result of advanced indexing.""" adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:])) adv_indexers = [np.array(key[i]) for i in adv_inds] bcast_nd = np.broadcast(*adv_indexers).nd bcast_start_axis = 0 if not adjacent else min(adv_inds) if adv_inds else None - return adjacent, bcast_nd, bcast_start_axis - - -def _squeeze_to_sublist(li: list, idxs: Sequence[int]) -> list: - """Turn contiguous elements of a list into a sub-list in the same position - - e.g. 
_squeeze_to_sublist(["a", "b", "c", "d"], [1,2]) = ["a", ["b", "c"], "d"] - """ - for left, right in zip(idxs[:-1], idxs[1:]): - if left + 1 != right: - raise ValueError("Indexes to squeeze must be contiguous") - if not idxs: - return li - return li[: min(idxs)] + [[li[idx] for idx in idxs]] + li[max(idxs) + 1 :] - - -def _apply_basic_indexing(key: tuple[StandardIndexer]) -> tuple[list[int], list[int]]: + return bcast_nd, KeyIndex(bcast_start_axis) + + +def _rename_broadcast_axes( + new_axes: list[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]], + adv_names: list[str], +) -> list[tuple[int, str]]: + """Normalize sentinel and NoneType names""" + + def _calc_bcast_name(*names: str) -> str: + if not names: + return "" + if all(a == b for a, b in zip(names[1:], names[:-1])): + return names[0] + names = [name[3:] for name in dict.fromkeys(names)] # ordered deduplication + return "ax_" + "_".join(names) + + bcast_name = _calc_bcast_name(*adv_names) + renamed_axes = [] + for ax_ind, ax_name in new_axes: + if ax_name is None: + renamed_axes.append((ax_ind, "ax_unk")) + elif ax_name is Sentinels.ADV_NAME: + renamed_axes.append((ax_ind, bcast_name)) + else: + renamed_axes.append((ax_ind, ax_name)) + return renamed_axes + + +def replace_adv_indexers( + key: Sequence[StandardIndexer], + adv_inds: list[int], + bcast_start_ax: int, + bcast_nd: int, +) -> tuple[ + Union[None, str, int, Literal[Sentinels.ADV_NAME], Literal[Sentinels.ADV_REMOVE]], + ..., +]: + for adv_ind in adv_inds: + key[adv_ind] = Sentinels.ADV_REMOVE + key = key[:bcast_start_ax] + bcast_nd * [Sentinels.ADV_NAME] + key[bcast_start_ax:] + return key + + +def _apply_indexing( + key: tuple[StandardIndexer], reverse_map: dict[int, str] +) -> tuple[ + list[int], list[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]], list[str] +]: """Determine where axes should be removed and added Only considers the basic indexers in key. 
Numpy arrays are treated as @@ -511,18 +444,23 @@ def _apply_basic_indexing(key: tuple[StandardIndexer]) -> tuple[list[int], list[ """ remove_axes = [] new_axes = [] + adv_names = [] deleted_to_left = 0 added_to_left = 0 for key_ind, indexer in enumerate(key): - if isinstance(indexer, int): + if isinstance(indexer, int) or indexer is Sentinels.ADV_REMOVE: orig_arr_axis = key_ind - added_to_left + if indexer is Sentinels.ADV_REMOVE: + adv_names.append(reverse_map[orig_arr_axis]) remove_axes.append(orig_arr_axis) deleted_to_left += 1 - elif indexer is None: + elif ( + indexer is None or indexer is Sentinels.ADV_NAME or isinstance(indexer, str) + ): new_arr_axis = key_ind - deleted_to_left - new_axes.append(new_arr_axis) + new_axes.append((new_arr_axis, indexer)) added_to_left += 1 - return remove_axes, new_axes + return remove_axes, new_axes, adv_names def comprehend_axes(x): diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 8892837df..ebe703f47 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -22,9 +22,6 @@ def test_repr(): assert result == expected -@pytest.mark.skip( - "Not until fancy indexing (boolean) either short-circuited or implemented" -) def test_ufunc_override(): # This is largely a clone of test_ufunc_override_with_super() from # numpy/core/tests/test_umath.py @@ -199,7 +196,7 @@ def test_basic_indexing_modifies_axes(): assert set(almost_new.ax_unk) == {0, 1, 3, 4} -def test_fancy_indexing_modifies_axes(): +def test_adv_indexing_modifies_axes(): axes = {"ax_time": 0, "ax_coord": 1} arr = AxesArray(np.arange(4).reshape((2, 2)), axes) flat = arr[[0, 1], [0, 1]] @@ -208,21 +205,23 @@ def test_fancy_indexing_modifies_axes(): assert flat.shape == (2,) np.testing.assert_array_equal(np.asarray(flat), np.array([0, 3])) - assert flat.ax__timecoord == 0 + assert flat.ax_time_coord == 0 with pytest.raises(AttributeError): flat.ax_coord with pytest.raises(AttributeError): flat.ax_time assert same.shape == arr.shape - np.testing.assert_equal(same, arr) - assert same.ax_time == 0 - assert same.ax_coord == 1 + np.testing.assert_equal(np.asarray(same), np.asarray(arr)) + assert same.ax_time_coord == [0, 1] + with pytest.raises(AttributeError): + same.ax_coord assert tpose.shape == arr.shape - np.testing.assert_equal(same, arr.T) - assert same.ax_time == 1 - assert same.ax_coord == 0 + np.testing.assert_equal(np.asarray(tpose), np.asarray(arr.T)) + assert tpose.ax_time_coord == [0, 1] + with pytest.raises(AttributeError): + tpose.ax_coord fat = arr[[[0, 1], [0, 1]]] assert fat.shape == (2, 2, 2) @@ -407,18 +406,6 @@ def test_insert_misordered_AxisMapping(): assert result == expected -def test_squeeze_to_sublist(): - li = ["a", "b", "c", "d"] - result = axes._squeeze_to_sublist(li, [1, 2]) - assert result == ["a", ["b", "c"], "d"] - - result = axes._squeeze_to_sublist(li, []) - assert result == li - - with pytest.raises(ValueError, match="Indexes to squeeze"): - axes._squeeze_to_sublist(li, [0, 2]) - - def test_determine_adv_broadcasting(): indexers = (1, np.ones(1), np.ones((4, 1)), np.ones(3)) res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 2, 3]) From 49ac5a0a0af5e931b9f7987a3fa3fae86b76e810 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 4 Jan 2024 17:42:16 +0000 Subject: [PATCH 29/61] TST: Update tests for new helper function values --- pysindy/utils/axes.py | 7 +++++-- test/utils/test_axes.py | 45 +++++++++++++++++------------------------ 2 files changed, 24 
insertions(+), 28 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 8a0f9dd96..1e803b1ba 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -196,8 +196,11 @@ def __getattr__(self, name): except KeyError: raise AttributeError(f"AxesArray has no axis '{name}'") if parts[0] == "n": - fwd_map = self._ax_map.fwd_map - shape = tuple(self.shape[ax_id] for ax_id in fwd_map["ax_" + parts[1]]) + try: + ax_ids = self._ax_map.fwd_map["ax_" + parts[1]] + except KeyError: + raise AttributeError(f"AxesArray has no axis '{name}'") + shape = tuple(self.shape[ax_id] for ax_id in ax_ids) if len(shape) == 1: return shape[0] return shape diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index ebe703f47..f33b94750 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -145,19 +145,6 @@ def test_n_elements(): assert arr2.n_coord == 4 -@pytest.mark.skip("Expected error") -def test_limited_slice(): - arr = np.empty(np.arange(1, 5)) - arr = AxesArray(arr, {"ax_spatial": [0, 1], "ax_time": 2, "ax_coord": 3}) - arr3 = arr[..., :2, 0] - assert arr3.n_spatial == (1, 2) - assert arr3.n_time == 2 - # No way to intercept slicing and remove ax_coord - with pytest.raises(IndexError): - assert arr3.n_coord == 1 - assert arr3.n_sample == 1 - - def test_warn_toofew_axes(): axes = {"ax_time": 0, "ax_coord": 1} with pytest.warns(AxesWarning): @@ -176,21 +163,30 @@ def test_conflicting_axes_defn(): AxesArray(np.ones(4), axes) +def test_missing_axis_errors(): + axes = {"ax_time": 0} + arr = AxesArray(np.arange(3), axes) + with pytest.raises(AttributeError): + arr.ax_spatial + with pytest.raises(AttributeError): + arr.n_spatial + + def test_basic_indexing_modifies_axes(): axes = {"ax_time": 0, "ax_coord": 1} arr = AxesArray(np.ones(4).reshape((2, 2)), axes) slim = arr[1, :, None] - with pytest.raises(KeyError): + with pytest.raises(AttributeError): slim.ax_time assert slim.ax_unk == 1 assert slim.ax_coord == 0 reverse_slim = arr[None, :, 1] - with pytest.raises(KeyError): + with pytest.raises(AttributeError): reverse_slim.ax_coord assert reverse_slim.ax_unk == 0 assert reverse_slim.ax_time == 1 almost_new = arr[None, None, 1, :, None, None] - with pytest.raises(KeyError): + with pytest.raises(AttributeError): almost_new.ax_time assert almost_new.ax_coord == 2 assert set(almost_new.ax_unk) == {0, 1, 3, 4} @@ -232,26 +228,26 @@ def test_adv_indexing_modifies_axes(): def test_standardize_basic_indexer(): arr = np.arange(6).reshape(2, 3) result_indexer, result_fancy = axes.standardize_indexer(arr, Ellipsis) - assert result_indexer == (slice(None), slice(None)) + assert result_indexer == [slice(None), slice(None)] assert result_fancy == () result_indexer, result_fancy = axes.standardize_indexer( arr, (np.newaxis, 1, 1, Ellipsis) ) - assert result_indexer == (None, 1, 1) + assert result_indexer == [None, 1, 1] assert result_fancy == () def test_standardize_fancy_indexer(): arr = np.arange(6).reshape(2, 3) result_indexer, result_fancy = axes.standardize_indexer(arr, [1]) - assert result_indexer == (np.ones(1), slice(None)) + assert result_indexer == [np.ones(1), slice(None)] assert result_fancy == (0,) result_indexer, result_fancy = axes.standardize_indexer( arr, (np.newaxis, [1], 1, Ellipsis) ) - assert result_indexer == (None, np.ones(1), 1) + assert result_indexer == [None, np.ones(1), 1] assert result_fancy == (1,) @@ -408,18 +404,15 @@ def test_insert_misordered_AxisMapping(): def test_determine_adv_broadcasting(): indexers = (1, np.ones(1), np.ones((4, 1)), 
np.ones(3)) - res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 2, 3]) - assert res_adj is True + res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 2, 3]) assert res_nd == 2 assert res_start == 1 indexers = (None, np.ones(1), 2, np.ones(3)) - res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 3]) - assert res_adj is False + res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 3]) assert res_nd == 1 assert res_start == 0 - res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, []) - assert res_adj is True + res_nd, res_start = axes._determine_adv_broadcasting(indexers, []) assert res_nd == 0 assert res_start is None From 42088a172fe91d0d607d14edfded5812736b844a Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Fri, 5 Jan 2024 17:34:15 +0000 Subject: [PATCH 30/61] ENH: Enable boolean advanced indexing in AxesArray Modify the standardize_indexer() function Parameterize StandardIndexer --- pysindy/utils/axes.py | 59 ++++++++++++++++++++++++----------------- test/utils/test_axes.py | 17 ++++++------ 2 files changed, 44 insertions(+), 32 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 1e803b1ba..2235f4a53 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -17,8 +17,8 @@ AxesWarning = type("AxesWarning", (SyntaxWarning,), {}) BasicIndexer = Union[slice, int, type(Ellipsis), type(None)] -Indexer = BasicIndexer | NDArray -StandardIndexer = Union[slice, int, type(None), NDArray] +Indexer = BasicIndexer | NDArray | list +StandardIndexer = Union[slice, int, type(None), NDArray[np.dtype(int)]] OldIndex = NewType("OldIndex", int) # Before moving advanced axes adajent KeyIndex = NewType("KeyIndex", int) NewIndex = NewType("NewIndex", int) @@ -338,12 +338,13 @@ def concatenate(arrays, axis=0): def standardize_indexer( arr: np.ndarray, key: Indexer | Sequence[Indexer] -) -> tuple[list[StandardIndexer], tuple[KeyIndex]]: +) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]: """Convert any legal numpy indexer to a "standard" form. Standard form involves creating an equivalent indexer that is a tuple with one element per index of the original axis. All advanced indexer elements - are converted to numpy arrays + are converted to numpy arrays, and boolean arrays are converted to + integer arrays with obj.nonzero(). 
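A plain-numpy illustration of the boolean conversion described just above: nonzero() returns one integer array per axis of the mask, and those arrays become the advanced indexers.

import numpy as np

mask = np.array([[True, True]])
rows, cols = mask.nonzero()
assert list(rows) == [0, 0]          # a (1, 2) mask standardizes to two integer indexers
assert list(cols) == [0, 1]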
Returns: A tuple of the normalized indexer as well as the indexes of advanced indexers @@ -351,36 +352,46 @@ def standardize_indexer( if isinstance(key, tuple): key = list(key) else: - key = [ - key, - ] + key = [key] + if not any(ax_key is Ellipsis for ax_key in key): key = [*key, Ellipsis] - _expand_indexer_ellipsis(key, arr.ndim) - new_key: list[Indexer] = [] - adv_inds: list[int] = [] - for indexer_ind, ax_key in enumerate(key): + for ax_key in key: if not isinstance(ax_key, BasicIndexer): ax_key = np.array(ax_key) - adv_inds.append(indexer_ind) + if ax_key.dtype == np.dtype(np.bool_): + new_key += ax_key.nonzero() + continue new_key.append(ax_key) - return new_key, tuple(adv_inds) + new_key = _expand_indexer_ellipsis(new_key, arr.ndim) + # Can't identify position of advanced indexers before expanding ellipses + adv_inds: list[KeyIndex] = [] + for key_ind, ax_key in enumerate(new_key): + if isinstance(ax_key, np.ndarray): + adv_inds.append(KeyIndex(key_ind)) -def _expand_indexer_ellipsis(indexers: list[Indexer], ndim: int) -> None: - """Replace ellipsis in indexers with the appropriate amount of slice(None) + return new_key, tuple(adv_inds) - Mutates indexers - """ - try: - ellind = indexers.index(Ellipsis) - except ValueError: - return - n_new_dims = sum(k is None for k in indexers) - n_ellipsis_dims = ndim - (len(indexers) - n_new_dims - 1) - indexers[ellind : ellind + 1] = n_ellipsis_dims * (slice(None),) + +def _expand_indexer_ellipsis(key: list[Indexer], ndim: int) -> list[Indexer]: + """Replace ellipsis in indexers with the appropriate amount of slice(None)""" + # [...].index errors if list contains numpy array + ellind = [ind for ind, val in enumerate(key) if val is ...][0] + new_key = [] + n_new_dims = sum(ax_key is None for ax_key in key) + n_ellipsis_dims = ndim - (len(key) - n_new_dims - 1) + new_key = ( + key[:ellind] + + n_ellipsis_dims + * [ + slice(None), + ] + + key[ellind + 1 + n_ellipsis_dims :] + ) + return new_key def _determine_adv_broadcasting( diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index f33b94750..e396fad05 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -98,15 +98,9 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): c = np.add.reduce(a, 1, None, b) assert_equal(c, check) assert_(c is b) - - -@pytest.mark.skip("Expected error") -def test_ufunc_override_accumulate(): - d = np.array([[1, 2, 3], [1, 2, 3]]) - a = AxesArray(d, {"ax_time": [0, 1]}) check = np.add.accumulate(d, axis=0) c = np.add.accumulate(a, axis=0) - assert_equal(c, check) + # assert_equal(c, check) b = np.zeros_like(c) c = np.add.accumulate(a, 0, None, b) assert_equal(c, check) @@ -238,7 +232,7 @@ def test_standardize_basic_indexer(): assert result_fancy == () -def test_standardize_fancy_indexer(): +def test_standardize_advanced_indexer(): arr = np.arange(6).reshape(2, 3) result_indexer, result_fancy = axes.standardize_indexer(arr, [1]) assert result_indexer == [np.ones(1), slice(None)] @@ -251,6 +245,13 @@ def test_standardize_fancy_indexer(): assert result_fancy == (1,) +def test_standardize_bool_indexer(): + arr = np.ones((1, 2)) + result, result_adv = axes.standardize_indexer(arr, [[True, True]]) + assert_equal(result, [[0, 0], [0, 1]]) + assert result_adv == (0, 1) + + def test_reduce_AxisMapping(): ax_map = _AxisMapping( { From 62d12eab1be044ada227f59eb372a421e307f0fa Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Fri, 5 Jan 2024 17:59:29 +0000 Subject: [PATCH 31/61] ENH: Allow 
inserting axes by adding strings to index --- pysindy/utils/axes.py | 12 ++++++++---- test/utils/test_axes.py | 12 ++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 2235f4a53..b61b7ca05 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -16,7 +16,7 @@ HANDLED_FUNCTIONS = {} AxesWarning = type("AxesWarning", (SyntaxWarning,), {}) -BasicIndexer = Union[slice, int, type(Ellipsis), type(None)] +BasicIndexer = Union[slice, int, type(Ellipsis), type(None), str] Indexer = BasicIndexer | NDArray | list StandardIndexer = Union[slice, int, type(None), NDArray[np.dtype(int)]] OldIndex = NewType("OldIndex", int) # Before moving advanced axes adajent @@ -207,7 +207,11 @@ def __getattr__(self, name): raise AttributeError(f"'{type(self)}' object has no attribute '{name}'") def __getitem__(self, key: Indexer | Sequence[Indexer], /): - output = super().__getitem__(key) + if isinstance(key, list | np.ndarray): + base_indexer = key + else: + base_indexer = tuple(None if isinstance(k, str) else k for k in key) + output = super().__getitem__(base_indexer) if not isinstance(output, AxesArray): return output # why? in_dim = self.shape @@ -381,7 +385,7 @@ def _expand_indexer_ellipsis(key: list[Indexer], ndim: int) -> list[Indexer]: # [...].index errors if list contains numpy array ellind = [ind for ind, val in enumerate(key) if val is ...][0] new_key = [] - n_new_dims = sum(ax_key is None for ax_key in key) + n_new_dims = sum(ax_key is None or isinstance(ax_key, str) for ax_key in key) n_ellipsis_dims = ndim - (len(key) - n_new_dims - 1) new_key = ( key[:ellind] @@ -427,7 +431,7 @@ def _calc_bcast_name(*names: str) -> str: elif ax_name is Sentinels.ADV_NAME: renamed_axes.append((ax_ind, bcast_name)) else: - renamed_axes.append((ax_ind, ax_name)) + renamed_axes.append((ax_ind, "ax_" + ax_name)) return renamed_axes diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index e396fad05..d9cb02bcb 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -186,6 +186,14 @@ def test_basic_indexing_modifies_axes(): assert set(almost_new.ax_unk) == {0, 1, 3, 4} +def test_insert_named_axis(): + arr = AxesArray(np.ones(1), axes={"ax_time": 0}) + expanded = arr["time", :] + result = expanded.axes + expected = {"ax_time": [0, 1]} + assert result == expected + + def test_adv_indexing_modifies_axes(): axes = {"ax_time": 0, "ax_coord": 1} arr = AxesArray(np.arange(4).reshape((2, 2)), axes) @@ -213,6 +221,10 @@ def test_adv_indexing_modifies_axes(): with pytest.raises(AttributeError): tpose.ax_coord + +def test_adv_indexing_adds_axes(): + axes = {"ax_time": 0, "ax_coord": 1} + arr = AxesArray(np.arange(4).reshape((2, 2)), axes) fat = arr[[[0, 1], [0, 1]]] assert fat.shape == (2, 2, 2) assert fat.ax_time == [0, 1] From 341f15c605a1cb14684fdb2f21171488bbb7e0f2 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Fri, 5 Jan 2024 18:02:12 +0000 Subject: [PATCH 32/61] CLN: Remove default name for a new axis --- pysindy/utils/axes.py | 3 +-- test/utils/test_axes.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index b61b7ca05..298d8ae04 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -123,8 +123,7 @@ def remove_axis(self, axis: Union[Collection[int], int, None] = None): new_axes[ax_dec_name].append(old_ax_dec - 1) return self._compat_axes(new_axes) - # TODO: delete default kwarg value - def 
insert_axis(self, axis: Union[Collection[int], int], new_name: str = "ax_unk"): + def insert_axis(self, axis: Union[Collection[int], int], new_name: str): """Create an axes dict from self with specified axis or axes added and all greater axes incremented. diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index d9cb02bcb..184667873 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -383,7 +383,7 @@ def test_insert_multiple_AxisMapping(): }, 6, ) - result = ax_map.insert_axis([1, 4]) + result = ax_map.insert_axis([1, 4], new_name="ax_unk") expected = { "ax_a": [0, 2], "ax_unk": [1, 4], @@ -404,7 +404,7 @@ def test_insert_misordered_AxisMapping(): }, 6, ) - result = ax_map.insert_axis([4, 1]) + result = ax_map.insert_axis([4, 1], new_name="ax_unk") expected = { "ax_a": [0, 2], "ax_unk": [1, 4], From 0ae9f6abf47c8fa36e4cf3bf59e4bd4ea173cfe2 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sat, 6 Jan 2024 00:10:35 +0000 Subject: [PATCH 33/61] CLN: Remove default name for a new axis --- pysindy/utils/axes.py | 4 ++-- test/utils/test_axes.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 298d8ae04..0c8af2294 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -242,7 +242,7 @@ def __array_finalize__(self, obj) -> None: ( not isinstance(obj, AxesArray), self.shape == (), - not hasattr(self, "__ax_map"), + not hasattr(self, "_ax_map"), ) ): self._ax_map = _AxisMapping({}) @@ -250,7 +250,7 @@ def __array_finalize__(self, obj) -> None: elif all( ( isinstance(obj, AxesArray), - not hasattr(self, "__ax_map"), + not hasattr(self, "_ax_map"), self.shape == obj.shape, ) ): diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 184667873..870bbab11 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -100,7 +100,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): assert_(c is b) check = np.add.accumulate(d, axis=0) c = np.add.accumulate(a, axis=0) - # assert_equal(c, check) + assert_equal(c, check) b = np.zeros_like(c) c = np.add.accumulate(a, 0, None, b) assert_equal(c, check) @@ -119,11 +119,11 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): a = d.copy().view(AxesArray) np.add.at(check, ([0, 1], [0, 2]), 1.0) np.add.at(a, ([0, 1], [0, 2]), 1.0) - assert_equal(a, check) + assert_equal(np.asarray(a), np.asarray(check)) # modified b = np.array(1.0).view(AxesArray) a = d.copy().view(AxesArray) np.add.at(a, ([0, 1], [0, 2]), b) - assert_equal(a, check) + assert_equal(np.asarray(a), np.asarray(check)) # modified def test_n_elements(): From 80452a6c0ee11ef155c80582c6274b42d09536b0 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 7 Jan 2024 18:27:59 +0000 Subject: [PATCH 34/61] CLN: Remove unused type expressions --- pysindy/utils/axes.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 0c8af2294..5758c0fbe 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -22,10 +22,6 @@ OldIndex = NewType("OldIndex", int) # Before moving advanced axes adajent KeyIndex = NewType("KeyIndex", int) NewIndex = NewType("NewIndex", int) -PartialReIndexer = tuple[KeyIndex, Optional[OldIndex], str] -CompleteReIndexer = tuple[ - list[KeyIndex], Optional[list[OldIndex]], Optional[list[NewIndex]] -] class Sentinels(Enum): From 51bae6737072046f06504e0797d4f4c9a4565959 Mon Sep 17 00:00:00 
2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Fri, 12 Jan 2024 11:03:48 +0000 Subject: [PATCH 35/61] bug(axes): Only pre-standardize tuple indexers Previously, numpy arrays were the only non-iterables allowed --- pysindy/utils/axes.py | 7 ++++--- test/utils/test_axes.py | 8 ++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 5758c0fbe..0de44e036 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -184,6 +184,7 @@ def shape(self): return super().shape def __getattr__(self, name): + # TODO: replace with structural pattern matching on Oct 2025 (3.9 EOL) parts = name.split("_", 1) if parts[0] == "ax": try: @@ -202,10 +203,10 @@ def __getattr__(self, name): raise AttributeError(f"'{type(self)}' object has no attribute '{name}'") def __getitem__(self, key: Indexer | Sequence[Indexer], /): - if isinstance(key, list | np.ndarray): - base_indexer = key - else: + if isinstance(key, tuple): base_indexer = tuple(None if isinstance(k, str) else k for k in key) + else: + base_indexer = key output = super().__getitem__(base_indexer) if not isinstance(output, AxesArray): return output # why? diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 870bbab11..a52a75a99 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -1,6 +1,7 @@ import numpy as np import pytest from numpy.testing import assert_ +from numpy.testing import assert_array_equal from numpy.testing import assert_equal from numpy.testing import assert_raises @@ -166,6 +167,13 @@ def test_missing_axis_errors(): arr.n_spatial +def test_simple_slice(): + arr = AxesArray(np.ones(2), {"ax_coord": 0}) + assert_array_equal(arr[:], arr) + assert_array_equal(arr[slice(None)], arr) + assert arr[0] == 1 + + def test_basic_indexing_modifies_axes(): axes = {"ax_time": 0, "ax_coord": 1} arr = AxesArray(np.ones(4).reshape((2, 2)), axes) From 23817f0999491073ebb6c5efeecaa7e983a8d7b1 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Fri, 12 Jan 2024 11:30:11 +0000 Subject: [PATCH 36/61] feat(_AxisMapping): create an ndim property Used in fixing bug: Handle removing negative axis indexes --- pysindy/utils/axes.py | 11 +++++++++-- test/utils/test_axes.py | 17 ++++------------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 0de44e036..7fda04f49 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -93,7 +93,8 @@ def compat_axes(self): def remove_axis(self, axis: Union[Collection[int], int, None] = None): """Create an axes dict from self with specified axis or axes - removed and all greater axes decremented. + removed and all greater axes decremented. This can be passed to + the constructor to create a new _AxisMapping Arguments: axis: the axis index or axes indexes to remove. 
By numpy @@ -105,6 +106,7 @@ def remove_axis(self, axis: Union[Collection[int], int, None] = None): in_ndim = len(self.reverse_map) if not isinstance(axis, Collection): axis = [axis] + axis = [ax_id if ax_id >= 0 else (self.ndim + ax_id) for ax_id in axis] for cum_shift, orig_ax_remove in enumerate(sorted(axis)): remove_ax_name = self.reverse_map[orig_ax_remove] curr_ax_remove = orig_ax_remove - cum_shift @@ -146,6 +148,10 @@ def insert_axis(self, axis: Union[Collection[int], int], new_name: str): new_axes[ax_name].append(ax_id + 1) return self._compat_axes(new_axes) + @property + def ndim(self): + return len(self.reverse_map) + class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): """A numpy-like array that keeps track of the meaning of its axes. @@ -481,7 +487,8 @@ def comprehend_axes(x): axes = {} axes["ax_coord"] = len(x.shape) - 1 axes["ax_time"] = len(x.shape) - 2 - axes["ax_spatial"] = list(range(len(x.shape) - 2)) + if x.ndim > 2: + axes["ax_spatial"] = list(range(len(x.shape) - 2)) return axes diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index a52a75a99..f5576f48b 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -274,22 +274,13 @@ def test_standardize_bool_indexer(): def test_reduce_AxisMapping(): ax_map = _AxisMapping( - { - "ax_a": [0, 1], - "ax_b": 2, - "ax_c": 3, - "ax_d": 4, - "ax_e": [5, 6], - }, + {"ax_a": [0, 1], "ax_b": 2, "ax_c": 3, "ax_d": 4, "ax_e": [5, 6]}, 7, ) result = ax_map.remove_axis(3) - expected = { - "ax_a": [0, 1], - "ax_b": 2, - "ax_d": 3, - "ax_e": [4, 5], - } + expected = {"ax_a": [0, 1], "ax_b": 2, "ax_d": 3, "ax_e": [4, 5]} + assert result == expected + result = ax_map.remove_axis(-4) assert result == expected From c11c0d69fcdb489414d99b9e8a6e31b1b73a7fb1 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Fri, 12 Jan 2024 12:36:06 +0000 Subject: [PATCH 37/61] bug(axes): enable 0-degree arrays If arr[key] returns an element of an array, arr[key, ...] returns a 0-degree array. --- pysindy/utils/axes.py | 13 +++---------- test/utils/test_axes.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 7fda04f49..731c5080b 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -215,7 +215,7 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): base_indexer = key output = super().__getitem__(base_indexer) if not isinstance(output, AxesArray): - return output # why? 
+ return output # return an element from the array in_dim = self.shape key, adv_inds = standardize_indexer(self, key) bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds) @@ -386,17 +386,10 @@ def _expand_indexer_ellipsis(key: list[Indexer], ndim: int) -> list[Indexer]: """Replace ellipsis in indexers with the appropriate amount of slice(None)""" # [...].index errors if list contains numpy array ellind = [ind for ind, val in enumerate(key) if val is ...][0] - new_key = [] n_new_dims = sum(ax_key is None or isinstance(ax_key, str) for ax_key in key) n_ellipsis_dims = ndim - (len(key) - n_new_dims - 1) - new_key = ( - key[:ellind] - + n_ellipsis_dims - * [ - slice(None), - ] - + key[ellind + 1 + n_ellipsis_dims :] - ) + new_key = key[:ellind] + key[ellind + 1 :] + new_key = new_key[:ellind] + (n_ellipsis_dims * [slice(None)]) + new_key[ellind:] return new_key diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index f5576f48b..e3910e29e 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -174,6 +174,15 @@ def test_simple_slice(): assert arr[0] == 1 +# @pytest.mark.skip # TODO: make this pass +def test_0d_indexer(): + arr = AxesArray(np.ones(2), {"ax_coord": 0}) + arr_out = arr[1, ...] + assert arr_out.ndim == 0 + assert arr_out.axes == {} + assert arr_out[()] == 1 + + def test_basic_indexing_modifies_axes(): axes = {"ax_time": 0, "ax_coord": 1} arr = AxesArray(np.ones(4).reshape((2, 2)), axes) @@ -428,3 +437,22 @@ def test_determine_adv_broadcasting(): res_nd, res_start = axes._determine_adv_broadcasting(indexers, []) assert res_nd == 0 assert res_start is None + + +def test_replace_ellipsis(): + key = [..., 0] + result = axes._expand_indexer_ellipsis(key, 2) + expected = [slice(None), 0] + assert result == expected + + +def test_strip_ellipsis(): + key = [1, ...] + result = axes._expand_indexer_ellipsis(key, 1) + expected = [1] + assert result == expected + + key = [..., 1] + result = axes._expand_indexer_ellipsis(key, 1) + expected = [1] + assert result == expected From bb1c73d082d5ef76a7279f31c37340581ee9be19 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sat, 13 Jan 2024 02:34:22 +0000 Subject: [PATCH 38/61] feat(axes): Enable np.reshape on AxesArrays Only a limited subset of reshapes with obvious relabeling semantics are allowed: For this version, it's just an outer product of some axes Also clean up typing and documentation and add reshape tests --- pysindy/utils/axes.py | 130 +++++++++++++++++++++++++++++++++++----- test/utils/test_axes.py | 37 +++++++++++- 2 files changed, 151 insertions(+), 16 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 731c5080b..dcfd6d8a1 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -42,7 +42,7 @@ class _AxisMapping: def __init__( self, - axes: dict[str, Union[int, Sequence[int]]] = None, + axes: Optional[dict[str, Union[int, Sequence[int]]]] = None, in_ndim: int = 0, ): if axes is None: @@ -75,9 +75,7 @@ def coerce_sequence(obj): ) @staticmethod - def _compat_axes( - in_dict: dict[str, Sequence[int]] - ) -> dict[str, Union[Sequence[int], int]]: + def _compat_axes(in_dict: dict[str, list[int]]) -> dict[str, Union[list[int], int]]: """Like fwd_map, but unpack single-element axis lists""" axes = {} for k, v in in_dict.items(): @@ -156,20 +154,35 @@ def ndim(self): class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): """A numpy-like array that keeps track of the meaning of its axes. 
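# Illustrative sketch (not part of this diff): the axis bookkeeping the
# docstring describes, using hypothetical axis names in the style of the
# tests elsewhere in this series.
import numpy as np
from pysindy import AxesArray

arr = AxesArray(np.ones((2, 3)), {"ax_time": 0, "ax_coord": 1})
assert arr.ax_time == 0 and arr.n_coord == 3
assert arr[0].axes == {"ax_coord": 0}  # an integer index drops the named axis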
+ Limitations: + * Not all numpy functions, such as ``np.flatten()``, does not have an + implementation for AxesArray, a regular numpy array is returned. + * For functions that are implemented for `AxesArray`, such as + ``np.reshape()``, use the numpy function rather than the bound + method (e.g. arr.reshape) + * Such such functions may raise ValueErrors where numpy would not, when + it is impossible to determine the output axis labels. + + Bound methods, such as arr.reshape, are not implemented. Use the functions. + While the functions in the numpy namespace will work on ``AxesArray`` + objects, the documentation must be found in their equivalent names here. + Parameters: - input_array (array-like): the data to create the array. - axes (dict): A dictionary of axis labels to shape indices. - Allowed keys: - - ax_time: int - - ax_coord: int - - ax_sample: int - - ax_spatial: List[int] + input_array: the data to create the array. + axes: A dictionary of axis labels to shape indices. Axes labels must + be of the format "ax_name". indices can be either an int or a + list of ints. Raises: - AxesWarning if axes does not match shape of input_array + * AxesWarning if axes does not match shape of input_array. + * ValueError if assigning the same axis index to multiple meanings or + assigning an axis beyond ndim. + """ - def __new__(cls, input_array, axes): + _ax_map: _AxisMapping + + def __new__(cls, input_array: NDArray, axes: dict[str, int | list[int]]): obj = np.asarray(input_array).view(cls) if axes is None: axes = {} @@ -226,10 +239,10 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): new_map = _AxisMapping( self._ax_map.remove_axis(remove_axes), len(in_dim) - len(remove_axes) ) - for new_ax_ind, new_ax_name in new_axes: + for insert_counter, (new_ax_ind, new_ax_name) in enumerate(new_axes): new_map = _AxisMapping( new_map.insert_axis(new_ax_ind, new_ax_name), - len(in_dim) - len(remove_axes) + len(new_axes), + in_ndim=len(in_dim) - len(remove_axes) + (insert_counter + 1), ) output._ax_map = new_map return output @@ -342,6 +355,72 @@ def concatenate(arrays, axis=0): return AxesArray(np.concatenate(parents, axis), axes=ax_list[0]) +@implements(np.reshape) +def reshape(a: AxesArray, newshape: int | tuple[int], order="C"): + """Gives a new shape to an array without changing its data. + + Args: + a: Array to be reshaped + newshape: int or tuple of ints + The new shape should be compatible with the original shape. In + addition, the axis labels must make sense when the data is + translated to a new shape. Currently, the only use case supported + is to flatten an outer product of two or more axes with the same + label and size. + order: Must be "C" + """ + if order != "C": + raise ValueError("AxesArray only supports reshaping in 'C' order currently.") + out = np.reshape(np.asarray(a), newshape, order) # handle any regular errors + + new_axes = {} + if isinstance(newshape, int): + newshape = [newshape] + newshape = list(newshape) + explicit_new_size = np.multiply.reduce(np.array(newshape)) + if explicit_new_size < 0: + replace_ind = newshape.index(-1) + newshape[replace_ind] = a.size // (-1 * explicit_new_size) + + curr_base = 0 + for curr_new in range(len(newshape)): + if curr_base >= a.ndim: + raise ValueError( + "Cannot reshape an AxesArray this way. Adding a length-1 axis at" + f" dimension {curr_new} not understood." 
+ ) + base_name = a._ax_map.reverse_map[curr_base] + if a.shape[curr_base] == newshape[curr_new]: + _compat_axes_append(new_axes, base_name, curr_new) + curr_base += 1 + elif newshape[curr_new] == 1: + raise ValueError( + f"Cannot reshape an AxesArray this way. Inserting a new axis at" + f" dimension {curr_new} of new shape is not supported" + ) + else: # outer product + remaining = newshape[curr_new] + while remaining > 1: + if a._ax_map.reverse_map[curr_base] != base_name: + raise ValueError( + "Cannot reshape an AxesArray this way. It would combine" + f" {base_name} with {a._ax_map.reverse_map[curr_base]}" + ) + remaining, error = divmod(remaining, a.shape[curr_base]) + if error: + raise ValueError( + f"Cannot reshape an AxesArray this way. Array dimension" + f" {curr_base} has size {a.shape[curr_base]}, must divide into" + f" newshape dimension {curr_new} with size" + f" {newshape[curr_new]}." + ) + curr_base += 1 + + _compat_axes_append(new_axes, base_name, curr_new) + + return AxesArray(out, axes=new_axes) + + def standardize_indexer( arr: np.ndarray, key: Indexer | Sequence[Indexer] ) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]: @@ -524,3 +603,24 @@ def wrap_axes(axes: dict, obj): except KeyError: pass return obj + + +def _compat_axes_append( + axes_dict: dict[str, Union[int, list[int]]], + ax_name: str, + newaxis: Union[int, list[int]], +) -> None: + if isinstance(newaxis, int): + try: + axes_dict[ax_name].append(newaxis) + except KeyError: + axes_dict[ax_name] = newaxis + except AttributeError: + axes_dict[ax_name] = [axes_dict[ax_name], newaxis] + else: + try: + axes_dict[ax_name] += newaxis + except KeyError: + axes_dict[ax_name] = newaxis + except AttributeError: + axes_dict[ax_name] = [axes_dict[ax_name], *newaxis] diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index e3910e29e..7f19596c2 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -140,6 +140,41 @@ def test_n_elements(): assert arr2.n_coord == 4 +def test_reshape_outer_product(): + arr = AxesArray(np.arange(4).reshape((2, 2)), {"ax_a": [0, 1]}) + merge = np.reshape(arr, (4,)) + assert merge.axes == {"ax_a": 0} + + +def test_reshape_fill_outer_product(): + arr = AxesArray(np.arange(4).reshape((2, 2)), {"ax_a": [0, 1]}) + merge = np.reshape(arr, (-1,)) + assert merge.axes == {"ax_a": 0} + + +def test_reshape_fill_regular(): + arr = AxesArray(np.arange(8).reshape((2, 2, 2)), {"ax_a": [0, 1], "ax_b": 2}) + merge = np.reshape(arr, (4, -1)) + assert merge.axes == {"ax_a": 0, "ax_b": 1} + + +def test_illegal_reshape(): + arr = AxesArray(np.arange(4).reshape((2, 2)), {"ax_a": [0, 1]}) + # melding across axes + with pytest.raises(ValueError, match="Cannot reshape an AxesArray"): + np.reshape(arr, (4, 1)) + + # Add a hidden 1 in the middle! 
maybe a matching 1 + + # different name outer product + arr = AxesArray(np.arange(4).reshape((2, 2)), {"ax_a": 0, "ax_b": 1}) + with pytest.raises(ValueError, match="Cannot reshape an AxesArray"): + np.reshape(arr, (4,)) + # newaxes + with pytest.raises(ValueError, match="Cannot reshape an AxesArray"): + np.reshape(arr, (2, 1, 2)) + + def test_warn_toofew_axes(): axes = {"ax_time": 0, "ax_coord": 1} with pytest.warns(AxesWarning): @@ -334,7 +369,7 @@ def test_reduce_twisted_AxisMapping(): def test_reduce_misordered_AxisMapping(): - ax_map = _AxisMapping({"ax_a": [0, 1], "ax_b": 2, "ax_c": 3}, 7) + ax_map = _AxisMapping({"ax_a": [0, 1], "ax_b": 2, "ax_c": 3}, 4) result = ax_map.remove_axis([2, 1]) expected = {"ax_a": 0, "ax_c": 1} assert result == expected From f13d5936cf8de0cce248944ae7da9b37166e0dc2 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sat, 13 Jan 2024 02:50:31 +0000 Subject: [PATCH 39/61] bug: Make caller more explicit to create AxesArray replace AxesArray.__dict__ with AxesArray.axes Correct the axes definitions where caller just was ok with being wrong before --- pysindy/feature_library/base.py | 9 ++++----- pysindy/feature_library/generalized_library.py | 2 +- pysindy/feature_library/polynomial_library.py | 2 +- pysindy/feature_library/sindy_pi_library.py | 2 +- pysindy/optimizers/base.py | 3 ++- test/test_feature_library.py | 1 + test/test_optimizers.py | 2 +- 7 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pysindy/feature_library/base.py b/pysindy/feature_library/base.py index 16149b27c..54697da45 100644 --- a/pysindy/feature_library/base.py +++ b/pysindy/feature_library/base.py @@ -63,10 +63,9 @@ def correct_shape(self, x: AxesArray): return x def calc_trajectory(self, diff_method, x, t): - axes = x.__dict__ x_dot = diff_method(x, t=t) - x = AxesArray(diff_method.smoothed_x_, axes) - return x, AxesArray(x_dot, axes) + x = AxesArray(diff_method.smoothed_x_, x.axes) + return x, AxesArray(x_dot, x.axes) def get_spatial_grid(self): return None @@ -337,7 +336,7 @@ def __init__( self.libraries = libraries self.inputs_per_library = inputs_per_library - def _combinations(self, lib_i, lib_j): + def _combinations(self, lib_i: AxesArray, lib_j: AxesArray) -> AxesArray: """ Compute combinations of the numerical libraries. 
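# Illustrative sketch (not part of this patch): the hunk below swaps
# np.newaxis for a string key, which the earlier __getitem__ change turns
# into a newly inserted axis named "ax_<string>". Hypothetical example:
import numpy as np
from pysindy import AxesArray

a = AxesArray(np.ones(3), {"ax_coord": 0})
assert a[:, "coord"].axes == {"ax_coord": [0, 1]}  # named new axis
assert a[:, None].axes == {"ax_coord": 0, "ax_unk": 1}  # unnamed new axis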
@@ -351,7 +350,7 @@ def _combinations(self, lib_i, lib_j): lib_i.shape[lib_i.ax_coord] * lib_j.shape[lib_j.ax_coord] ) lib_full = np.reshape( - lib_i[..., :, np.newaxis] * lib_j[..., np.newaxis, :], + lib_i[..., :, "coord"] * lib_j[..., "coord", :], shape, ) diff --git a/pysindy/feature_library/generalized_library.py b/pysindy/feature_library/generalized_library.py index 3e5e24055..29834c2a8 100644 --- a/pysindy/feature_library/generalized_library.py +++ b/pysindy/feature_library/generalized_library.py @@ -237,7 +237,7 @@ def transform(self, x_full): else: xps.append(lib.transform([x])[0]) - xp = AxesArray(np.concatenate(xps, axis=xps[0].ax_coord), xps[0].__dict__) + xp = AxesArray(np.concatenate(xps, axis=xps[0].ax_coord), xps[0].axes) xp_full = xp_full + [xp] return xp_full diff --git a/pysindy/feature_library/polynomial_library.py b/pysindy/feature_library/polynomial_library.py index 75dbf5637..e62af38bd 100644 --- a/pysindy/feature_library/polynomial_library.py +++ b/pysindy/feature_library/polynomial_library.py @@ -225,7 +225,7 @@ def transform(self, x_full): dtype=x.dtype, order=self.order, ), - x.__dict__, + x.axes, ) for i, comb in enumerate(combinations): xp[..., i] = x[..., comb].prod(-1) diff --git a/pysindy/feature_library/sindy_pi_library.py b/pysindy/feature_library/sindy_pi_library.py index 8d5f054a7..f45cf567f 100644 --- a/pysindy/feature_library/sindy_pi_library.py +++ b/pysindy/feature_library/sindy_pi_library.py @@ -404,5 +404,5 @@ def transform(self, x_full): *[x[:, comb] for comb in f_combs] ) * f_dot(*[x_dot[:, comb] for comb in f_dot_combs]) library_idx += 1 - xp_full = xp_full + [AxesArray(xp, x.__dict__)] + xp_full = xp_full + [AxesArray(xp, x.axes)] return xp_full diff --git a/pysindy/optimizers/base.py b/pysindy/optimizers/base.py index 45d4842b2..614341b54 100644 --- a/pysindy/optimizers/base.py +++ b/pysindy/optimizers/base.py @@ -144,7 +144,8 @@ def fit(self, x_, y, sample_weight=None, **reduce_kws): self : returns an instance of self """ x_ = AxesArray(np.asarray(x_), {"ax_sample": 0, "ax_coord": 1}) - y = AxesArray(np.asarray(y), {"ax_sample": 0, "ax_coord": 1}) + y_axes = {"ax_sample": 0} if y.ndim == 1 else {"ax_sample": 0, "ax_coord": 1} + y = AxesArray(np.asarray(y), y_axes) x_, y = drop_nan_samples(x_, y) x_, y = check_X_y(x_, y, accept_sparse=[], y_numeric=True, multi_output=True) diff --git a/test/test_feature_library.py b/test/test_feature_library.py index 8e98b1a0d..6fba611a5 100644 --- a/test/test_feature_library.py +++ b/test/test_feature_library.py @@ -247,6 +247,7 @@ def test_sindypi_library_bad_params(params): pytest.lazy_fixture("ode_library"), pytest.lazy_fixture("sindypi_library"), ], + ids=type, ) def test_fit_transform(data_lorenz, library): x, t = data_lorenz diff --git a/test/test_optimizers.py b/test/test_optimizers.py index c69ce9823..7bd657aa1 100644 --- a/test/test_optimizers.py +++ b/test/test_optimizers.py @@ -587,7 +587,7 @@ def test_specific_bad_parameters(error, optimizer, params, data_lorenz): def test_bad_optimizers(data_derivative_1d): x, x_dot = data_derivative_1d x = x.reshape(-1, 1) - + x_dot = x_dot.reshape(-1, 1) with pytest.raises(InvalidParameterError): # Error: optimizer does not have a callable fit method opt = WrappedOptimizer(DummyEmptyModel()) From 996d555dd9bc2ca2c54c7d144d8b91731ffd956a Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sat, 13 Jan 2024 04:12:32 +0000 Subject: [PATCH 40/61] feat(axes): Make np.transpose work on AxesArray Finally a simple one --- 
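An illustrative sketch of the intended behavior, not part of the patch:
np.transpose should permute the axis labels along with the data, so that

    import numpy as np
    from pysindy import AxesArray

    arr = AxesArray(np.arange(6).reshape(2, 3), {"ax_time": 0, "ax_coord": 1})
    assert np.transpose(arr).axes == {"ax_coord": 0, "ax_time": 1}
    assert np.transpose(arr, [1, 0]).axes == {"ax_coord": 0, "ax_time": 1}

with the underlying values matching np.transpose of the plain ndarray.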
pysindy/utils/axes.py | 25 +++++++++++++++++++++++++ test/utils/test_axes.py | 14 ++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index dcfd6d8a1..e5e8f9d67 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -7,6 +7,7 @@ from typing import NewType from typing import Optional from typing import Sequence +from typing import Tuple from typing import Union import numpy as np @@ -167,6 +168,11 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): While the functions in the numpy namespace will work on ``AxesArray`` objects, the documentation must be found in their equivalent names here. + Current array function implementations: + * ``np.concatenate`` + * ``np.reshape`` + * ``np.transpose`` + Parameters: input_array: the data to create the array. axes: A dictionary of axis labels to shape indices. Axes labels must @@ -421,6 +427,25 @@ def reshape(a: AxesArray, newshape: int | tuple[int], order="C"): return AxesArray(out, axes=new_axes) +@implements(np.transpose) +def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None): + """Returns an array with axes transposed. + + Args: + a: input array + axes: As the numpy function + """ + out = np.transpose(np.asarray(a), axes) + if axes is None: + axes = range(a.ndim)[::-1] + new_axes = {} + old_reverse = a._ax_map.reverse_map + for new_ind, old_ind in enumerate(axes): + _compat_axes_append(new_axes, old_reverse[old_ind], new_ind) + + return AxesArray(out, new_axes) + + def standardize_indexer( arr: np.ndarray, key: Indexer | Sequence[Indexer] ) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]: diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 7f19596c2..2e6b127de 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -491,3 +491,17 @@ def test_strip_ellipsis(): result = axes._expand_indexer_ellipsis(key, 1) expected = [1] assert result == expected + + +def test_transpose(): + axes = {"ax_a": 0, "ax_b": [1, 2]} + arr = AxesArray(np.arange(8).reshape(2, 2, 2), axes) + tp = np.transpose(arr, [2, 0, 1]) + result = tp.axes + expected = {"ax_a": 1, "ax_b": [0, 2]} + assert result == expected + assert_array_equal(tp, np.transpose(np.asarray(arr), [2, 0, 1])) + arr = arr[..., 0] + tp = arr.T + expected = {"ax_a": 1, "ax_b": 0} + assert_array_equal(tp, np.asarray(arr).T) From 3298f5f96da9aee0495d9d32496f803e95b66617 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 14 Jan 2024 03:34:45 +0000 Subject: [PATCH 41/61] feat(axes): Make np.einsum work on AxesArray Added helper function to create fwd axis map from list of axis names --- pysindy/utils/axes.py | 122 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 105 insertions(+), 17 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index e5e8f9d67..5f042c7f8 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -1,13 +1,23 @@ +""" +A module that defines one external class, AxesArray, to act like a numpy array +but keep track of axis definitions. + +TODO: Add developer documentation here. 
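Illustrative sketch (not part of the diff): by the end of this series, the
einsum support added below is expected to propagate axis names through the
subscripts, with contracted labels dropping out, e.g.

    import numpy as np
    from pysindy import AxesArray

    a = AxesArray(np.ones((2, 3)), {"ax_time": 0, "ax_coord": 1})
    b = AxesArray(np.ones((3, 4)), {"ax_coord": 0, "ax_feature": 1})
    c = np.einsum("ij,jk->ik", a, b)
    assert c.axes == {"ax_time": 0, "ax_feature": 1}

The "ax_feature" label here is hypothetical; any "ax_*" name works.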
+""" +from __future__ import annotations + import copy import warnings from enum import Enum from typing import Collection +from typing import Dict from typing import List from typing import Literal from typing import NewType from typing import Optional from typing import Sequence from typing import Tuple +from typing import TypeVar from typing import Union import numpy as np @@ -34,9 +44,7 @@ class Sentinels(Enum): class _AxisMapping: - """Convenience wrapper for a two-way map between axis names and - indexes. - """ + """Convenience wrapper for a two-way map between axis names and indexes.""" fwd_map: dict[str, list[int]] reverse_map: dict[int, str] @@ -75,6 +83,13 @@ def coerce_sequence(obj): AxesWarning, ) + @staticmethod + def fwd_from_names(names: List[str]) -> dict[str, Sequence[int]]: + fwd_map: dict[str, Sequence[int]] = {} + for ax_ind, name in enumerate(names): + _compat_dict_append(fwd_map, name, [ax_ind]) + return fwd_map + @staticmethod def _compat_axes(in_dict: dict[str, list[int]]) -> dict[str, Union[list[int], int]]: """Like fwd_map, but unpack single-element axis lists""" @@ -397,7 +412,7 @@ def reshape(a: AxesArray, newshape: int | tuple[int], order="C"): ) base_name = a._ax_map.reverse_map[curr_base] if a.shape[curr_base] == newshape[curr_new]: - _compat_axes_append(new_axes, base_name, curr_new) + _compat_dict_append(new_axes, base_name, curr_new) curr_base += 1 elif newshape[curr_new] == 1: raise ValueError( @@ -422,7 +437,7 @@ def reshape(a: AxesArray, newshape: int | tuple[int], order="C"): ) curr_base += 1 - _compat_axes_append(new_axes, base_name, curr_new) + _compat_dict_append(new_axes, base_name, curr_new) return AxesArray(out, axes=new_axes) @@ -441,11 +456,80 @@ def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None) new_axes = {} old_reverse = a._ax_map.reverse_map for new_ind, old_ind in enumerate(axes): - _compat_axes_append(new_axes, old_reverse[old_ind], new_ind) + _compat_dict_append(new_axes, old_reverse[old_ind], new_ind) return AxesArray(out, new_axes) +@implements(np.einsum) +def _einsum( + subscripts: str, *operands: AxesArray, out: Optional[NDArray] = None, **kwargs +) -> AxesArray: + calc = np.einsum(subscripts, *operands, out=out, **kwargs) + try: + # explicit mode + lscripts, rscript = "->".split(subscripts) + except ValueError: + # implicit mode + lscripts = subscripts + rscripts = "".join( + sorted(c for c in set(subscripts) if subscripts.count(c) > 1 and c != ",") + ) + # 0-dimensional case, may just be better to check type of "calc": + if rscripts == "": + return calc + allscript_names: List[Dict[str, List[str]]] = [] + # script -> axis name for each left script + for lscr, op in zip(lscripts, operands): + script_names: Dict[str, List[str]] = {} + allscript_names.append(script_names) + # handle script ellipses + try: + ell_ind = lscr.index("...") + ell_width = op.ndim - (len(lscr) - 3) + ell_expand = range(ell_ind, ell_ind + ell_width) + ell_names = [op._ax_map.reverse_map[ax_ind] for ax_ind in ell_expand] + script_names["..."] = ell_names + except ValueError: + ell_ind = len(lscr) + ell_width = 0 + # handle script non-ellipsis chars + shift = 0 + for ax_ind, char in enumerate(lscr): + if char == ".": + shift += 1 + continue + if ax_ind < ell_ind: + scr_name = op._ax_map.reverse_map[ax_ind] + else: + scr_name = op._ax_map.reverse_map[ax_ind - 3 + ell_width] + _compat_dict_append(script_names, char, [scr_name]) + + # assemble output reverse map + out_names = [] + shift = 0 + + def _join_unique_names(l_of_s: List[str]) 
-> str: + ordered_uniques = dict.fromkeys(l_of_s).keys() + return "_".join(ax_name.lstrip("ax_") for ax_name in ordered_uniques) + + for char in rscript.replace("...", "."): + if char == ".": + for script_names in allscript_names: + out_names += script_names.get("...", []) + else: + ax_names = [] + for script_names in allscript_names: + ax_names += script_names.get(char, []) + ax_names = "ax_" + _join_unique_names(ax_names) + out_names.append(ax_names) + + out_axes = _AxisMapping.fwd_from_names(out_names) + if isinstance(out, AxesArray): + out._ax_map = _AxisMapping(out_axes, calc.ndim) + return AxesArray(calc, axes=out_axes) + + def standardize_indexer( arr: np.ndarray, key: Indexer | Sequence[Indexer] ) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]: @@ -630,22 +714,26 @@ def wrap_axes(axes: dict, obj): return obj -def _compat_axes_append( - axes_dict: dict[str, Union[int, list[int]]], - ax_name: str, - newaxis: Union[int, list[int]], +T = TypeVar("T") # TODO: Bind to a non-sequence after type-negation PEP + + +def _compat_dict_append( + compat_dict: dict[str, Union[T, list[T]]], + key: str, + item_or_list: Union[T, list[T]], ) -> None: - if isinstance(newaxis, int): + """Add an element or list of elements to a dictionary, preserving old values""" + if not isinstance(item_or_list, list): try: - axes_dict[ax_name].append(newaxis) + compat_dict[key].append(item_or_list) except KeyError: - axes_dict[ax_name] = newaxis + compat_dict[key] = item_or_list except AttributeError: - axes_dict[ax_name] = [axes_dict[ax_name], newaxis] + compat_dict[key] = [compat_dict[key], item_or_list] else: try: - axes_dict[ax_name] += newaxis + compat_dict[key] += item_or_list except KeyError: - axes_dict[ax_name] = newaxis + compat_dict[key] = item_or_list except AttributeError: - axes_dict[ax_name] = [axes_dict[ax_name], *newaxis] + compat_dict[key] = [compat_dict[key], *item_or_list] From 18d449e35940207b31c7fdafba76d25a2921e339 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 14 Jan 2024 03:43:14 +0000 Subject: [PATCH 42/61] bug: clean up callers of AxesArray --- pysindy/differentiation/finite_difference.py | 8 ++++---- pysindy/utils/axes.py | 14 ++++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pysindy/differentiation/finite_difference.py b/pysindy/differentiation/finite_difference.py index 39f9bddd3..5783045f3 100644 --- a/pysindy/differentiation/finite_difference.py +++ b/pysindy/differentiation/finite_difference.py @@ -201,20 +201,20 @@ def _constant_coefficients(self, dt): def _accumulate(self, coeffs, x): # slice to select the stencil indices - s = [slice(None)] * len(x.shape) + s = [slice(None)] * x.ndim s[self.axis] = self.stencil_inds - # a new axis is introduced after self.axis for the stencil indices + # a new axis is introduced before self.axis for the stencil indices # To contract with the coefficients, roll by -self.axis to put it first # Then roll back by self.axis to return the order - trans = np.roll(np.arange(len(x.shape) + 1), -self.axis) + trans = np.roll(np.arange(x.ndim + 1), -self.axis) return np.transpose( np.einsum( "ij...,ij->j...", np.transpose(x[tuple(s)], axes=trans), np.transpose(coeffs), ), - np.roll(np.arange(len(x.shape)), self.axis), + np.roll(np.arange(x.ndim), self.axis), ) def _differentiate(self, x, t): diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 5f042c7f8..1347922df 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -691,13 +691,15 @@ def 
concat_sample_axis(x_list: List[AxesArray]): """Concatenate all trajectories and axes used to create samples.""" new_arrs = [] for x in x_list: - sample_axes = ( - x.ax_spatial - + ([x.ax_time] if x.ax_time is not None else []) - + ([x.ax_sample] if x.ax_sample is not None else []) - ) + sample_ax_names = ("ax_spatial", "ax_time", "ax_sample") + sample_ax_inds = [] + for name in sample_ax_names: + ax_inds = getattr(x, name, []) + if isinstance(ax_inds, int): + ax_inds = [ax_inds] + sample_ax_inds += ax_inds new_axes = {"ax_sample": 0, "ax_coord": 1} - n_samples = np.prod([x.shape[ax] for ax in sample_axes]) + n_samples = np.prod([x.shape[ax] for ax in sample_ax_inds]) arr = AxesArray(x.reshape((n_samples, x.shape[x.ax_coord])), new_axes) new_arrs.append(arr) return np.concatenate(new_arrs, axis=new_arrs[0].ax_sample) From cc6025e663bc7edcba2befa57667ba6031417e63 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 14 Jan 2024 16:57:45 +0000 Subject: [PATCH 43/61] feat(axes): Support numpy.ix_ --- pysindy/utils/axes.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index a634f2fdb..4f698daa7 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -188,6 +188,12 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): * ``np.reshape`` * ``np.transpose`` + Indexing: + AxesArray supports all of the basic and advanced indexing of numpy + arrays, with the addition that new axes can be inserted with a string + name for the axis. If ``None`` or ``np.newaxis`` are passed, the + axis is named "unk". + Parameters: input_array: the data to create the array. axes: A dictionary of axis labels to shape indices. Axes labels must @@ -366,6 +372,14 @@ def decorator(func): return decorator +@implements(np.ix_) +def ix_(*args: AxesArray): + calc = np.ix_(*(np.asarray(arg) for arg in args)) + ax_names = [list(arr.axes)[0] for arr in args] + axes = _AxisMapping.fwd_from_names(ax_names) + return tuple(AxesArray(arr, axes) for arr in calc) + + @implements(np.concatenate) def concatenate(arrays, axis=0, out=None, dtype=None, casting="same_kind"): parents = [np.asarray(obj) for obj in arrays] From f5b201594d90d1c47fd0e8d38e59a4dd78c2055c Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 14 Jan 2024 18:30:28 +0000 Subject: [PATCH 44/61] bug: Make axes explicit in PDEs --- pysindy/feature_library/pde_library.py | 8 +--- pysindy/feature_library/weak_pde_library.py | 41 +++++++++++++-------- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/pysindy/feature_library/pde_library.py b/pysindy/feature_library/pde_library.py index d6c8666c9..8042a221b 100644 --- a/pysindy/feature_library/pde_library.py +++ b/pysindy/feature_library/pde_library.py @@ -276,13 +276,7 @@ def get_feature_names(self, input_features=None): def derivative_string(multiindex): ret = "" for axis in range(self.ind_range): - if self.implicit_terms and ( - axis - in [ - self.spatiotemporal_grid.ax_time, - self.spatiotemporal_grid.ax_sample, - ] - ): + if self.implicit_terms and (axis == self.spatiotemporal_grid.ax_time,): str_deriv = "t" else: str_deriv = str(axis + 1) diff --git a/pysindy/feature_library/weak_pde_library.py b/pysindy/feature_library/weak_pde_library.py index 5aa3cbbb6..02ed2851f 100644 --- a/pysindy/feature_library/weak_pde_library.py +++ b/pysindy/feature_library/weak_pde_library.py @@ -9,6 +9,7 @@ from sklearn.utils.validation import 
check_is_fitted from ..utils import AxesArray +from ..utils import comprehend_axes from .base import BaseFeatureLibrary from .base import x_sequence_or_item from pysindy.differentiation import FiniteDifference @@ -245,7 +246,10 @@ def __init__( self.num_derivatives = num_derivatives self.multiindices = multiindices - self.spatiotemporal_grid = spatiotemporal_grid + + self.spatiotemporal_grid = AxesArray( + spatiotemporal_grid, axes=comprehend_axes(spatiotemporal_grid) + ) # Weak form checks and setup self._weak_form_setup() @@ -255,12 +259,14 @@ def _weak_form_setup(self): L_xt = xt2 - xt1 if self.H_xt is not None: if np.isscalar(self.H_xt): - self.H_xt = np.array(self.grid_ndim * [self.H_xt]) + self.H_xt = AxesArray( + np.array(self.grid_ndim * [self.H_xt]), {"ax_coord": 0} + ) if self.grid_ndim != len(self.H_xt): raise ValueError( "The user-defined grid (spatiotemporal_grid) and " "the user-defined sizes of the subdomains for the " - "weak form, do not have the same # of spatiotemporal " + "weak form do not have the same # of spatiotemporal " "dimensions. For instance, if spatiotemporal_grid is 4D, " "then H_xt should be a 4D list of the subdomain lengths." ) @@ -285,8 +291,8 @@ def _weak_form_setup(self): self._set_up_weights() def _get_spatial_endpoints(self): - x1 = np.zeros(self.grid_ndim) - x2 = np.zeros(self.grid_ndim) + x1 = AxesArray(np.zeros(self.grid_ndim), {"ax_coord": 0}) + x2 = AxesArray(np.zeros(self.grid_ndim), {"ax_coord": 0}) for i in range(self.grid_ndim): inds = [slice(None)] * (self.grid_ndim + 1) for j in range(self.grid_ndim): @@ -306,7 +312,9 @@ def _set_up_weights(self): # Sample the random domain centers xt1, xt2 = self._get_spatial_endpoints() - domain_centers = np.zeros((self.K, self.grid_ndim)) + domain_centers = AxesArray( + np.zeros((self.K, self.grid_ndim)), {"ax_sample": 0, "ax_coord": 1} + ) for i in range(self.grid_ndim): domain_centers[:, i] = np.random.uniform( xt1[i] + self.H_xt[i], xt2[i] - self.H_xt[i], size=self.K @@ -321,15 +329,12 @@ def _set_up_weights(self): s = [0] * (self.grid_ndim + 1) s[i] = slice(None) s[-1] = i - newinds = np.intersect1d( - np.where( - self.spatiotemporal_grid[tuple(s)] - >= domain_centers[k][i] - self.H_xt[i] - ), - np.where( - self.spatiotemporal_grid[tuple(s)] - <= domain_centers[k][i] + self.H_xt[i] - ), + ax_vals = self.spatiotemporal_grid[tuple(s)] + cell_left = domain_centers[k][i] - self.H_xt[i] + cell_right = domain_centers[k][i] + self.H_xt[i] + newinds = AxesArray( + ((ax_vals > cell_left) & (ax_vals < cell_right)).nonzero()[0], + ax_vals.axes, ) # If less than two indices along any axis, resample if len(newinds) < 2: @@ -346,6 +351,7 @@ def _set_up_weights(self): self.inds_k = self.inds_k + [inds] k = k + 1 + # TODO: fix meaning of axes in XT_k # Values of the spatiotemporal grid on the domain cells XT_k = [ self.spatiotemporal_grid[np.ix_(*self.inds_k[k])] for k in range(self.K) @@ -468,6 +474,11 @@ def _set_up_weights(self): ) weights1 = weights1 + [weights2] + # TODO: get rest of code to work with AxesArray + deaxify = lambda arr_list: [np.asarray(arr) for arr in arr_list] + tweights = deaxify(tweights) + weights0 = deaxify(weights0) + weights1 = deaxify(weights1) # Product weights over the axes for time derivatives, shaped as inds_k self.fulltweights = [] deriv = np.zeros(self.grid_ndim) From e0eb87e42835016460c6c066be8613bc76847d55 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 14 Jan 2024 19:53:04 +0000 Subject: [PATCH 45/61] fix(axes): Prevent inf 
recursive einsum --- pysindy/utils/axes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 4f698daa7..cea8ca1dd 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -482,7 +482,9 @@ def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None) def _einsum( subscripts: str, *operands: AxesArray, out: Optional[NDArray] = None, **kwargs ) -> AxesArray: - calc = np.einsum(subscripts, *operands, out=out, **kwargs) + calc = np.einsum( + subscripts, *(np.asarray(arr) for arr in operands), out=out, **kwargs + ) try: # explicit mode lscripts, rscript = "->".split(subscripts) From c93ca169e4809e85a10535e400a48cc40721b9aa Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 14 Jan 2024 19:54:09 +0000 Subject: [PATCH 46/61] feat(axes): Add np.linalg.solve --- pysindy/differentiation/finite_difference.py | 9 +++++---- pysindy/utils/axes.py | 13 +++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/pysindy/differentiation/finite_difference.py b/pysindy/differentiation/finite_difference.py index 5783045f3..f29ab8ec5 100644 --- a/pysindy/differentiation/finite_difference.py +++ b/pysindy/differentiation/finite_difference.py @@ -1,6 +1,7 @@ import numpy as np from .base import BaseDifferentiation +from pysindy.utils.axes import AxesArray class FiniteDifference(BaseDifferentiation): @@ -94,12 +95,12 @@ def _coefficients(self, t): self.stencil - t[ (self.n_stencil - 1) // 2 : -(self.n_stencil - 1) // 2, - np.newaxis, + "coord", ] )[:, np.newaxis, :] ** pows - b = np.zeros(self.n_stencil) - b[self.d] = np.math.factorial(self.d) - return np.linalg.solve(matrices, [b]) + b = AxesArray(np.zeros((1, self.n_stencil)), self.stencil.axes) + b[0, self.d] = np.math.factorial(self.d) + return np.linalg.solve(matrices, b) def _coefficients_boundary_forward(self, t): # use the same stencil for each boundary point, diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index cea8ca1dd..ca24d5094 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -549,6 +549,19 @@ def _join_unique_names(l_of_s: List[str]) -> str: return AxesArray(calc, axes=out_axes) +@implements(np.linalg.solve) +def solve(a: AxesArray, b: AxesArray): + result = np.linalg.solve(np.asarray(a), np.asarray(b)) + a_rev = a._ax_map.reverse_map + contracted_axis_name = a_rev[sorted(a_rev)[-1]] + b_rev = b._ax_map.reverse_map + rest_of_names = [b_rev[k] for k in sorted(b_rev)] + axes = _AxisMapping.fwd_from_names( + [*rest_of_names[:-2], contracted_axis_name, rest_of_names[-1]] + ) + return AxesArray(result, axes) + + def standardize_indexer( arr: np.ndarray, key: Indexer | Sequence[Indexer] ) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]: From 690aa0604a3bf97301ce1072cfbc65f8f53257eb Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 15 Jan 2024 11:31:34 +0000 Subject: [PATCH 47/61] bug: Enable AxesArray in FiniteDifference internals --- pysindy/differentiation/finite_difference.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pysindy/differentiation/finite_difference.py b/pysindy/differentiation/finite_difference.py index f29ab8ec5..17e828dc1 100644 --- a/pysindy/differentiation/finite_difference.py +++ b/pysindy/differentiation/finite_difference.py @@ -1,4 +1,7 @@ +from typing import Union + import numpy as np +from numpy.typing import NDArray from .base import 
BaseDifferentiation from pysindy.utils.axes import AxesArray @@ -218,7 +221,9 @@ def _accumulate(self, coeffs, x): np.roll(np.arange(x.ndim), self.axis), ) - def _differentiate(self, x, t): + def _differentiate( + self, x: NDArray, t: Union[NDArray, float, list[float]] + ) -> NDArray: """ Apply finite difference method. """ @@ -249,6 +254,7 @@ def _differentiate(self, x, t): s[self.axis] = slice(start, stop) interior = interior + x[tuple(s)] * coeffs[i] else: + t = AxesArray(np.array(t), axes={"ax_time": 0}) coeffs = self._coefficients(t) interior = self._accumulate(coeffs, x) s[self.axis] = slice((self.n_stencil - 1) // 2, -(self.n_stencil - 1) // 2) From 264f8b324a57a24e91deeddf133ca3f07d938164 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 15 Jan 2024 12:20:34 +0000 Subject: [PATCH 48/61] bug(axes): Fix transpose and einsum bugs --- pysindy/utils/axes.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index ca24d5094..8f7ad59e1 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -487,19 +487,19 @@ def _einsum( ) try: # explicit mode - lscripts, rscript = "->".split(subscripts) + lscripts, rscript = subscripts.split("->") except ValueError: # implicit mode lscripts = subscripts - rscripts = "".join( + rscript = "".join( sorted(c for c in set(subscripts) if subscripts.count(c) > 1 and c != ",") ) # 0-dimensional case, may just be better to check type of "calc": - if rscripts == "": + if rscript == "": return calc allscript_names: List[Dict[str, List[str]]] = [] # script -> axis name for each left script - for lscr, op in zip(lscripts, operands): + for lscr, op in zip(lscripts.split(","), operands): script_names: Dict[str, List[str]] = {} allscript_names.append(script_names) # handle script ellipses @@ -540,8 +540,8 @@ def _join_unique_names(l_of_s: List[str]) -> str: ax_names = [] for script_names in allscript_names: ax_names += script_names.get(char, []) - ax_names = "ax_" + _join_unique_names(ax_names) - out_names.append(ax_names) + ax_name = "ax_" + _join_unique_names(ax_names) + out_names.append(ax_name) out_axes = _AxisMapping.fwd_from_names(out_names) if isinstance(out, AxesArray): @@ -550,7 +550,7 @@ def _join_unique_names(l_of_s: List[str]) -> str: @implements(np.linalg.solve) -def solve(a: AxesArray, b: AxesArray): +def solve(a: AxesArray, b: AxesArray) -> AxesArray: result = np.linalg.solve(np.asarray(a), np.asarray(b)) a_rev = a._ax_map.reverse_map contracted_axis_name = a_rev[sorted(a_rev)[-1]] From f0fc6b3715287bc691f297bb203125b25bc53374 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 15 Jan 2024 12:36:06 +0000 Subject: [PATCH 49/61] fix(axes): Pass correct ndim to _AxisMapping() --- pysindy/utils/axes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 8f7ad59e1..334887b9b 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -389,7 +389,7 @@ def concatenate(arrays, axis=0, out=None, dtype=None, casting="same_kind"): raise TypeError("Concatenating >1 AxesArray with incompatible axes") result = np.concatenate(parents, axis, out=out, dtype=dtype, casting=casting) if isinstance(out, AxesArray): - out.__dict__ = ax_list[0] + out._ax_map = _AxisMapping(ax_list[0], in_ndim=result.ndim) return AxesArray(result, axes=ax_list[0]) From 8084fd47d1dffe18b48ffef57cc1c9ff35208df9 Mon Sep 17 00:00:00 2001 From: 
Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 15 Jan 2024 20:12:15 +0000 Subject: [PATCH 50/61] feat(axes): Add tensordot function for AxesArrays dispatches to einsum, which is apparently faster anyways Still need to write tests for linalg_solve, einsum, and tensordot --- pysindy/feature_library/weak_pde_library.py | 10 ++-- pysindy/utils/axes.py | 65 ++++++++++++++------- 2 files changed, 49 insertions(+), 26 deletions(-) diff --git a/pysindy/feature_library/weak_pde_library.py b/pysindy/feature_library/weak_pde_library.py index 02ed2851f..1566a2bca 100644 --- a/pysindy/feature_library/weak_pde_library.py +++ b/pysindy/feature_library/weak_pde_library.py @@ -474,11 +474,11 @@ def _set_up_weights(self): ) weights1 = weights1 + [weights2] - # TODO: get rest of code to work with AxesArray - deaxify = lambda arr_list: [np.asarray(arr) for arr in arr_list] - tweights = deaxify(tweights) - weights0 = deaxify(weights0) - weights1 = deaxify(weights1) + # TODO: get rest of code to work with AxesArray. Too unsure of + # which axis labels to use at this point to continue + tweights = [np.asarray(arr) for arr in tweights] + weights0 = [np.asarray(arr) for arr in weights0] + weights1 = [[np.asarray(arr) for arr in sublist] for sublist in weights1] # Product weights over the axes for time derivatives, shaped as inds_k self.fulltweights = [] deriv = np.zeros(self.grid_ndim) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 334887b9b..0c592a9dc 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -40,9 +40,6 @@ class Sentinels(Enum): ADV_REMOVE = object() -Literal[Sentinels.ADV_NAME] - - class _AxisMapping: """Convenience wrapper for a two-way map between axis names and indexes.""" @@ -479,7 +476,7 @@ def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None) @implements(np.einsum) -def _einsum( +def einsum( subscripts: str, *operands: AxesArray, out: Optional[NDArray] = None, **kwargs ) -> AxesArray: calc = np.einsum( @@ -550,7 +547,7 @@ def _join_unique_names(l_of_s: List[str]) -> str: @implements(np.linalg.solve) -def solve(a: AxesArray, b: AxesArray) -> AxesArray: +def linalg_solve(a: AxesArray, b: AxesArray) -> AxesArray: result = np.linalg.solve(np.asarray(a), np.asarray(b)) a_rev = a._ax_map.reverse_map contracted_axis_name = a_rev[sorted(a_rev)[-1]] @@ -562,6 +559,34 @@ def solve(a: AxesArray, b: AxesArray) -> AxesArray: return AxesArray(result, axes) +@implements(np.tensordot) +def tensordot( + a: AxesArray, b: AxesArray, axes: Union[int, Sequence[Sequence[int]]] = 2 +) -> AxesArray: + sub = _tensordot_to_einsum(a.ndim, b.ndim, axes) + return einsum(sub, a, b) + + +def _tensordot_to_einsum( + a_ndim: int, b_ndim: int, axes: Union[int, Sequence[Sequence[int]]] +) -> str: + lc_ord = range(97, 123) + if isinstance(axes, int): + if axes > 26: + raise ValueError("Too many axes") + sub_a = f"...{[chr(code) for code in lc_ord[:axes]]}" + sub_b_li = f"{[chr(code) for code in lc_ord[:axes]]}..." 
+ sub = sub_a + sub_b_li + else: + sub_a = f"{[chr(code) for code in lc_ord[:a_ndim]]}" + sub_b_li = [chr(code) for code in lc_ord[a_ndim : a_ndim + b_ndim]] + for a_ind, b_ind in zip(*axes): + sub_b_li[b_ind] - sub_a[a_ind] + sub_b = "".join(sub_b_li) + sub = f"{sub_a},{sub_b}" + return sub + + def standardize_indexer( arr: np.ndarray, key: Indexer | Sequence[Indexer] ) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]: @@ -748,26 +773,24 @@ def wrap_axes(axes: dict, obj): return obj -T = TypeVar("T") # TODO: Bind to a non-sequence after type-negation PEP +T = TypeVar("T", bound=int) # TODO: Bind to a non-sequence after type-negation PEP +ItemOrList = Union[T, list[T]] +CompatDict = dict[str, ItemOrList[T]] def _compat_dict_append( - compat_dict: dict[str, Union[T, list[T]]], + compat_dict: CompatDict[T], key: str, - item_or_list: Union[T, list[T]], + item_or_list: ItemOrList[T], ) -> None: """Add an element or list of elements to a dictionary, preserving old values""" + try: + prev_val = compat_dict[key] + except KeyError: + compat_dict[key] = item_or_list + return if not isinstance(item_or_list, list): - try: - compat_dict[key].append(item_or_list) - except KeyError: - compat_dict[key] = item_or_list - except AttributeError: - compat_dict[key] = [compat_dict[key], item_or_list] - else: - try: - compat_dict[key] += item_or_list - except KeyError: - compat_dict[key] = item_or_list - except AttributeError: - compat_dict[key] = [compat_dict[key], *item_or_list] + item_or_list = [item_or_list] + if not isinstance(prev_val, list): + prev_val = [prev_val] + compat_dict[key] = prev_val + item_or_list From 8f1e4bc6b651a0b6f07a626fdcac1ab2cf05e95c Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 15 Jan 2024 20:44:39 +0000 Subject: [PATCH 51/61] test(axes): Add linalg.solve() tests for AxesArray --- test/utils/test_axes.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index f60eaef9d..00d992547 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -512,3 +512,36 @@ def test_transpose(): tp = arr.T expected = {"ax_a": 1, "ax_b": 0} assert_array_equal(tp, np.asarray(arr).T) + + +def test_linalg_solve_align_left(): + axesA = {"ax_prob": 0, "ax_sample": 1, "ax_coord": 2} + arrA = AxesArray(np.arange(8).reshape(2, 2, 2), axesA) + axesb = {"ax_prob": 0, "ax_sample": 1} + arrb = AxesArray(np.arange(4).reshape(2, 2), axesb) + result = np.linalg.solve(arrA, arrb) + expected_axes = {"ax_prob": 0, "ax_coord": 1} + assert result.axes == expected_axes + super_result = np.linalg.solve(np.asarray(arrA), np.asarray(arrb)) + assert_array_equal(result, super_result) + + +def test_linalg_solve_align_right(): + axesA = {"ax_sample": 0, "ax_feature": 1} + arrA = AxesArray(np.arange(4).reshape(2, 2), axesA) + axesb = {"ax_sample": 0, "ax_target": 1} + arrb = AxesArray(np.arange(4).reshape(2, 2), axesb) + result = np.linalg.solve(arrA, arrb) + expected_axes = {"ax_feature": 0, "ax_target": 1} + assert result.axes == expected_axes + super_result = np.linalg.solve(np.asarray(arrA), np.asarray(arrb)) + assert_array_equal(result, super_result) + + +def test_linalg_solve_incompatible_left(): + axesA = {"ax_prob": 0, "ax_sample": 1, "ax_coord": 2} + arrA = AxesArray(np.arange(8).reshape(2, 2, 2), axesA) + axesb = {"ax_foo": 0, "ax_sample": 1} + arrb = AxesArray(np.arange(4).reshape(2, 2), axesb) + with pytest.raises(ValueError, match="fdsafds"): + 
np.linalg.solve(arrA, arrb) From 3dacb89368d48cd778573e64391d8f3876c7616d Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 15 Jan 2024 21:19:35 +0000 Subject: [PATCH 52/61] bug(axes) Change axis alignment linalg_solve + test --- pysindy/utils/axes.py | 16 ++++++++++++---- test/utils/test_axes.py | 42 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 0c592a9dc..27c10abc5 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -550,12 +550,20 @@ def _join_unique_names(l_of_s: List[str]) -> str: def linalg_solve(a: AxesArray, b: AxesArray) -> AxesArray: result = np.linalg.solve(np.asarray(a), np.asarray(b)) a_rev = a._ax_map.reverse_map - contracted_axis_name = a_rev[sorted(a_rev)[-1]] + a_names = [a_rev[k] for k in sorted(a_rev)] + contracted_axis_name = a_names[-1] b_rev = b._ax_map.reverse_map - rest_of_names = [b_rev[k] for k in sorted(b_rev)] - axes = _AxisMapping.fwd_from_names( - [*rest_of_names[:-2], contracted_axis_name, rest_of_names[-1]] + b_names = [b_rev[k] for k in sorted(b_rev)] + match_axes_list = a_names[:-1] + start = max(b.ndim - a.ndim, 0) + end = start + len(match_axes_list) + align = slice(start, end) + if match_axes_list != b_names[align]: + raise ValueError("Mismatch in operand axis names when aligning A and b") + all_names = ( + b_names[: align.stop - 1] + [contracted_axis_name] + b_names[align.stop :] ) + axes = _AxisMapping.fwd_from_names(all_names) return AxesArray(result, axes) diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 00d992547..38b19350b 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -538,10 +538,50 @@ def test_linalg_solve_align_right(): assert_array_equal(result, super_result) +def test_linalg_solve_align_right_xl(): + axesA = {"ax_sample": 0, "ax_feature": 1} + arrA = AxesArray(np.arange(4).reshape(2, 2), axesA) + axesb = {"ax_prob": 0, "ax_sample": 1, "ax_target": 2} + arrb = AxesArray(np.arange(8).reshape(2, 2, 2), axesb) + result = np.linalg.solve(arrA, arrb) + expected_axes = {"ax_prob": 0, "ax_feature": 1, "ax_target": 2} + assert result.axes == expected_axes + super_result = np.linalg.solve(np.asarray(arrA), np.asarray(arrb)) + assert_array_equal(result, super_result) + + def test_linalg_solve_incompatible_left(): axesA = {"ax_prob": 0, "ax_sample": 1, "ax_coord": 2} arrA = AxesArray(np.arange(8).reshape(2, 2, 2), axesA) axesb = {"ax_foo": 0, "ax_sample": 1} arrb = AxesArray(np.arange(4).reshape(2, 2), axesb) - with pytest.raises(ValueError, match="fdsafds"): + with pytest.raises(ValueError, match="Mismatch in operand axis names"): np.linalg.solve(arrA, arrb) + + +def test_tensordot_int_axes(): + ... + + +def test_tensordot_list_axes(): + ... + + +def test_einsum_implicit(): + ... + + +def test_einsum_trace(): + ... + + +def test_einsum_diag(): + ... + + +def test_einsum_contraction(): + ... + + +def test_einsum_mixed(): + ... 
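The alignment rule introduced in the patch above can be read off its tests: the last axis of A is the one solved for, and its name replaces the matching axis name in b. A minimal usage sketch follows (illustrative only, not part of the patch series; it mirrors test_linalg_solve_align_right and assumes the AxesArray constructor and np.linalg.solve override shown above):

    import numpy as np
    from pysindy import AxesArray

    A = AxesArray(np.eye(2), {"ax_sample": 0, "ax_feature": 1})
    b = AxesArray(np.arange(2.0).reshape(2, 1), {"ax_sample": 0, "ax_target": 1})
    x = np.linalg.solve(A, b)
    # b's ax_sample is matched against A's leading axis and replaced by A's
    # last axis name, so x.axes == {"ax_feature": 0, "ax_target": 1}; the
    # values are identical to solving the bare numpy arrays.
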
From 21c3b0100cde38886264f692d1201259feb90287 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 15 Jan 2024 21:27:33 +0000 Subject: [PATCH 53/61] test(axes): Add tensordot tests --- test/utils/test_axes.py | 43 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 38b19350b..a59fd2891 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -559,12 +559,51 @@ def test_linalg_solve_incompatible_left(): np.linalg.solve(arrA, arrb) +def test_ts_to_einsum_int_axes(): + a_str, b_str = axes._tensordot_to_einsum(3, 3, 2).split(",") + # expecting 'abc,bcf + assert a_str[1] == b_str[0] + assert a_str[2] == b_str[1] + assert a_str[0] not in b_str + assert b_str[2] not in a_str + + +def test_ts_to_einsum_list_axes(): + a_str, b_str = axes._tensordot_to_einsum(3, 3, [[1], [2]]).split(",") + # expecting 'abcd,efbh + assert a_str[0] not in b_str + assert a_str[1] == b_str[2] + assert a_str[2] not in b_str + assert a_str[3] not in b_str + assert b_str[0] not in a_str + assert b_str[1] not in a_str + assert b_str[3] not in a_str + + def test_tensordot_int_axes(): - ... + axes_a = {"ax_a": 0, "ax_b": [1, 2]} + axes_b = {"ax_b": [0, 1], "ax_c": 2} + arr = np.arange(8).reshape((2, 2, 2)) + arr_a = AxesArray(arr, axes_a) + arr_b = AxesArray(arr, axes_b) + result = np.tensordot(arr_a, arr_b, 2) + super_result = np.tensordot(arr, arr, 2) + expected_axes = {"ax_a": 0, "ax_c": 1} + assert result.axes == expected_axes + assert_array_equal(result, super_result) def test_tensordot_list_axes(): - ... + axes_a = {"ax_a": 0, "ax_b": [1, 2]} + axes_b = {"ax_c": [0, 1], "ax_b": 2} + arr = np.arange(8).reshape((2, 2, 2)) + arr_a = AxesArray(arr, axes_a) + arr_b = AxesArray(arr, axes_b) + result = np.tensordot(arr_a, arr_b, [[1], [2]]) + super_result = np.tensordot(arr, arr, 2) + expected_axes = {"ax_a": 0, "ax_b": 1, "ax_c": [2, 3]} + assert result.axes == expected_axes + assert_array_equal(result, super_result) def test_einsum_implicit(): From 36ae58c685712facef5dce7edc947703ea0911a6 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 16 Jan 2024 11:00:16 +0000 Subject: [PATCH 54/61] fix(axes): pass ts-to-einsum tests --- pysindy/utils/axes.py | 13 +++++++------ test/utils/test_axes.py | 15 +++++++-------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 27c10abc5..7e4438848 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -582,16 +582,17 @@ def _tensordot_to_einsum( if isinstance(axes, int): if axes > 26: raise ValueError("Too many axes") - sub_a = f"...{[chr(code) for code in lc_ord[:axes]]}" - sub_b_li = f"{[chr(code) for code in lc_ord[:axes]]}..." - sub = sub_a + sub_b_li + sub_a = "..." + "".join([chr(code) for code in lc_ord[:axes]]) + sub_b = "".join([chr(code) for code in lc_ord[:axes]]) + "..." 
else: - sub_a = f"{[chr(code) for code in lc_ord[:a_ndim]]}" + sub_a = "".join([chr(code) for code in lc_ord[:a_ndim]]) sub_b_li = [chr(code) for code in lc_ord[a_ndim : a_ndim + b_ndim]] for a_ind, b_ind in zip(*axes): - sub_b_li[b_ind] - sub_a[a_ind] + if a_ind > 26 or b_ind > 26: + raise ValueError("Too many axes") + sub_b_li[b_ind] = sub_a[a_ind] sub_b = "".join(sub_b_li) - sub = f"{sub_a},{sub_b}" + sub = f"{sub_a},{sub_b}" return sub diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index a59fd2891..3ed1f75da 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -562,22 +562,21 @@ def test_linalg_solve_incompatible_left(): def test_ts_to_einsum_int_axes(): a_str, b_str = axes._tensordot_to_einsum(3, 3, 2).split(",") # expecting 'abc,bcf - assert a_str[1] == b_str[0] - assert a_str[2] == b_str[1] - assert a_str[0] not in b_str - assert b_str[2] not in a_str + assert a_str[:3] == "..." + assert b_str[-3:] == "..." + a_str = a_str.lstrip("...") + b_str = b_str.rstrip("...") + assert a_str == b_str def test_ts_to_einsum_list_axes(): a_str, b_str = axes._tensordot_to_einsum(3, 3, [[1], [2]]).split(",") # expecting 'abcd,efbh - assert a_str[0] not in b_str assert a_str[1] == b_str[2] + assert a_str[0] not in b_str assert a_str[2] not in b_str - assert a_str[3] not in b_str assert b_str[0] not in a_str assert b_str[1] not in a_str - assert b_str[3] not in a_str def test_tensordot_int_axes(): @@ -600,7 +599,7 @@ def test_tensordot_list_axes(): arr_a = AxesArray(arr, axes_a) arr_b = AxesArray(arr, axes_b) result = np.tensordot(arr_a, arr_b, [[1], [2]]) - super_result = np.tensordot(arr, arr, 2) + super_result = np.tensordot(arr, arr, [[1], [2]]) expected_axes = {"ax_a": 0, "ax_b": 1, "ax_c": [2, 3]} assert result.axes == expected_axes assert_array_equal(result, super_result) From 730a58237f2f708aea255e3a66ee4033d920cebb Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 16 Jan 2024 12:46:33 +0000 Subject: [PATCH 55/61] fix(einsum): Replace lstrip with removeprefix in renaming axes --- pysindy/utils/axes.py | 59 +++++++++++++++++++++++------------------ test/utils/test_axes.py | 20 ++++++++++++-- 2 files changed, 51 insertions(+), 28 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 7e4438848..fd39e12c2 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -489,13 +489,43 @@ def einsum( # implicit mode lscripts = subscripts rscript = "".join( - sorted(c for c in set(subscripts) if subscripts.count(c) > 1 and c != ",") + sorted(c for c in set(subscripts) if subscripts.count(c) == 1 and c != ",") ) # 0-dimensional case, may just be better to check type of "calc": if rscript == "": return calc + + # assemble output reverse map + allscript_names = _label_einsum_scripts(lscripts, operands) + out_names = [] + + for char in rscript.replace("...", "."): + if char == ".": + for script_names in allscript_names: + out_names += script_names.get("...", []) + else: + ax_names = [] + for script_names in allscript_names: + ax_names += script_names.get(char, []) + ax_name = "ax_" + _join_unique_names(ax_names) + out_names.append(ax_name) + + out_axes = _AxisMapping.fwd_from_names(out_names) + if isinstance(out, AxesArray): + out._ax_map = _AxisMapping(out_axes, calc.ndim) + return AxesArray(calc, axes=out_axes) + + +def _join_unique_names(l_of_s: List[str]) -> str: + ordered_uniques = dict.fromkeys(l_of_s).keys() + return "_".join(ax_name.removeprefix("ax_") for ax_name in ordered_uniques) + 
+ +def _label_einsum_scripts( + lscripts: list[str], operands: tuple[AxesArray] +) -> list[dict[str, str]]: + """Create a list of what axis name each script refers to in its operand.""" allscript_names: List[Dict[str, List[str]]] = [] - # script -> axis name for each left script for lscr, op in zip(lscripts.split(","), operands): script_names: Dict[str, List[str]] = {} allscript_names.append(script_names) @@ -520,30 +550,7 @@ def einsum( else: scr_name = op._ax_map.reverse_map[ax_ind - 3 + ell_width] _compat_dict_append(script_names, char, [scr_name]) - - # assemble output reverse map - out_names = [] - shift = 0 - - def _join_unique_names(l_of_s: List[str]) -> str: - ordered_uniques = dict.fromkeys(l_of_s).keys() - return "_".join(ax_name.lstrip("ax_") for ax_name in ordered_uniques) - - for char in rscript.replace("...", "."): - if char == ".": - for script_names in allscript_names: - out_names += script_names.get("...", []) - else: - ax_names = [] - for script_names in allscript_names: - ax_names += script_names.get(char, []) - ax_name = "ax_" + _join_unique_names(ax_names) - out_names.append(ax_name) - - out_axes = _AxisMapping.fwd_from_names(out_names) - if isinstance(out, AxesArray): - out._ax_map = _AxisMapping(out_axes, calc.ndim) - return AxesArray(calc, axes=out_axes) + return allscript_names @implements(np.linalg.solve) diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 3ed1f75da..126f1adce 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -585,8 +585,8 @@ def test_tensordot_int_axes(): arr = np.arange(8).reshape((2, 2, 2)) arr_a = AxesArray(arr, axes_a) arr_b = AxesArray(arr, axes_b) - result = np.tensordot(arr_a, arr_b, 2) super_result = np.tensordot(arr, arr, 2) + result = np.tensordot(arr_a, arr_b, 2) expected_axes = {"ax_a": 0, "ax_c": 1} assert result.axes == expected_axes assert_array_equal(result, super_result) @@ -598,8 +598,8 @@ def test_tensordot_list_axes(): arr = np.arange(8).reshape((2, 2, 2)) arr_a = AxesArray(arr, axes_a) arr_b = AxesArray(arr, axes_b) - result = np.tensordot(arr_a, arr_b, [[1], [2]]) super_result = np.tensordot(arr, arr, [[1], [2]]) + result = np.tensordot(arr_a, arr_b, [[1], [2]]) expected_axes = {"ax_a": 0, "ax_b": 1, "ax_c": [2, 3]} assert result.axes == expected_axes assert_array_equal(result, super_result) @@ -617,9 +617,25 @@ def test_einsum_diag(): ... +def test_einsum_1dsum(): + ... + + +def test_einsum_alldsum(): + ... + + def test_einsum_contraction(): ... +def test_einsum_explicit_ellipsis(): + ... + + +def test_einsum_scalar(): + ... + + def test_einsum_mixed(): ... 
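For the tensordot/einsum wrappers developed in the last few patches, a short usage sketch (illustrative only, mirroring test_tensordot_int_axes above and assuming the AxesArray constructor from pysindy.utils.axes): with an integer axes argument, the contracted axes lose their names and the remaining axes keep theirs.

    import numpy as np
    from pysindy import AxesArray

    arr = np.arange(8).reshape((2, 2, 2))
    a = AxesArray(arr, {"ax_a": 0, "ax_b": [1, 2]})
    b = AxesArray(arr, {"ax_b": [0, 1], "ax_c": 2})
    result = np.tensordot(a, b, 2)  # dispatches through the einsum wrapper
    # The two ax_b axes are contracted away, so result.axes == {"ax_a": 0, "ax_c": 1},
    # and the values match np.tensordot on the underlying numpy arrays.
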
From 93f14a76c2199d18326944a5757ac89cc3a7be4b Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 16 Jan 2024 14:29:25 +0000 Subject: [PATCH 56/61] bug(finite_difference): Wrap internal arrays as AxesArrays Also add helper methods to AxesArray class --- pysindy/differentiation/finite_difference.py | 34 +++++++++++++------- pysindy/utils/axes.py | 17 +++++++++- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/pysindy/differentiation/finite_difference.py b/pysindy/differentiation/finite_difference.py index 17e828dc1..9ccb7ca5c 100644 --- a/pysindy/differentiation/finite_difference.py +++ b/pysindy/differentiation/finite_difference.py @@ -88,20 +88,25 @@ def __init__( def _coefficients(self, t): nt = len(t) - self.stencil_inds = np.array( - [np.arange(i, nt - self.n_stencil + i + 1) for i in range(self.n_stencil)] + self.stencil_inds = AxesArray( + np.array( + [ + np.arange(i, nt - self.n_stencil + i + 1) + for i in range(self.n_stencil) + ] + ), + {"ax_offset": 0, "ax_ti": 1}, + ) + self.stencil = AxesArray( + np.transpose(t[self.stencil_inds]), {"ax_time": 0, "ax_offset": 1} ) - self.stencil = np.transpose(t[self.stencil_inds]) - pows = np.arange(self.n_stencil)[np.newaxis, :, np.newaxis] - matrices = ( + dt_endpoints = ( self.stencil - - t[ - (self.n_stencil - 1) // 2 : -(self.n_stencil - 1) // 2, - "coord", - ] - )[:, np.newaxis, :] ** pows - b = AxesArray(np.zeros((1, self.n_stencil)), self.stencil.axes) + - t[(self.n_stencil - 1) // 2 : -(self.n_stencil - 1) // 2, "offset"] + ) + matrices = dt_endpoints[:, "power", :] ** pows + b = AxesArray(np.zeros((1, self.n_stencil)), {"ax_time": 0, "ax_power": 1}) b[0, self.d] = np.math.factorial(self.d) return np.linalg.solve(matrices, b) @@ -212,10 +217,15 @@ def _accumulate(self, coeffs, x): # To contract with the coefficients, roll by -self.axis to put it first # Then roll back by self.axis to return the order trans = np.roll(np.arange(x.ndim + 1), -self.axis) + # TODO: assign x's axes much earlier in the call stack + x = AxesArray(x, {"ax_unk": list(range(x.ndim))}) + x_expanded = AxesArray( + np.transpose(x[tuple(s)], axes=trans), x.insert_axis(0, "ax_offset") + ) return np.transpose( np.einsum( "ij...,ij->j...", - np.transpose(x[tuple(s)], axes=trans), + x_expanded, np.transpose(coeffs), ), np.roll(np.arange(x.ndim), self.axis), diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index fd39e12c2..1d5136cdf 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -197,6 +197,11 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): be of the format "ax_name". indices can be either an int or a list of ints. + Attributes: + axes: dictionary of axis name to dimension index/indices + ax_: lookup ax_name in axes + n_: lookup shape of subarray defined by ax_name + Raises: * AxesWarning if axes does not match shape of input_array. 
* ValueError if assigning the same axis index to multiple meanings or @@ -206,7 +211,7 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): _ax_map: _AxisMapping - def __new__(cls, input_array: NDArray, axes: dict[str, int | list[int]]): + def __new__(cls, input_array: NDArray, axes: CompatDict[int]): obj = np.asarray(input_array).view(cls) if axes is None: axes = {} @@ -226,6 +231,16 @@ def _reverse_map(self): def shape(self): return super().shape + def insert_axis( + self, axis: Union[Collection[int], int], new_name: str + ) -> CompatDict[int]: + """Create the constructor axes dict from this array, with new axis/axes""" + return self._ax_map.insert_axis(axis, new_name) + + def remove_axis(self, axis: Union[Collection[int], int]) -> CompatDict[int]: + """Create the constructor axes dict from this array, without axis/axes""" + return self._ax_map.remove_axis(axis) + def __getattr__(self, name): # TODO: replace with structural pattern matching on Oct 2025 (3.9 EOL) parts = name.split("_", 1) From 323d115897b3ff815885e1c7af7559317ed13807 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 16 Jan 2024 14:41:21 +0000 Subject: [PATCH 57/61] bug(AxesArray.tensordot): Adapt int index to list of lists --- pysindy/utils/axes.py | 29 +++++++++++++++++------------ test/utils/test_axes.py | 17 ++++++++++++----- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 1d5136cdf..46a0d4666 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -3,6 +3,15 @@ but keep track of axis definitions. TODO: Add developer documentation here. + +The recommended way to refactor existing code to use AxesArrays is to add them +at the lowest level possible. Enter debug mode and see how long the expected +axes persist throughout array operations. When AxesArray loses track of the +correct axes, re-assign them with an AxesArray constructor (which only uses a +view of the data). + +Starting at the macro level runs the risk of triggering a great deal of errors +from unimplemented functions. """ from __future__ import annotations @@ -601,19 +610,15 @@ def _tensordot_to_einsum( a_ndim: int, b_ndim: int, axes: Union[int, Sequence[Sequence[int]]] ) -> str: lc_ord = range(97, 123) + sub_a = "".join([chr(code) for code in lc_ord[:a_ndim]]) if isinstance(axes, int): - if axes > 26: - raise ValueError("Too many axes") - sub_a = "..." + "".join([chr(code) for code in lc_ord[:axes]]) - sub_b = "".join([chr(code) for code in lc_ord[:axes]]) + "..." 
- else: - sub_a = "".join([chr(code) for code in lc_ord[:a_ndim]]) - sub_b_li = [chr(code) for code in lc_ord[a_ndim : a_ndim + b_ndim]] - for a_ind, b_ind in zip(*axes): - if a_ind > 26 or b_ind > 26: - raise ValueError("Too many axes") - sub_b_li[b_ind] = sub_a[a_ind] - sub_b = "".join(sub_b_li) + axes = [range(-axes, 0), range(0, axes)] + sub_b_li = [chr(code) for code in lc_ord[a_ndim : a_ndim + b_ndim]] + if np.array(axes).max() > 26: + raise ValueError("Too many axes") + for a_ind, b_ind in zip(*axes): + sub_b_li[b_ind] = sub_a[a_ind] + sub_b = "".join(sub_b_li) sub = f"{sub_a},{sub_b}" return sub diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 126f1adce..b4f8fb3d4 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -562,11 +562,9 @@ def test_linalg_solve_incompatible_left(): def test_ts_to_einsum_int_axes(): a_str, b_str = axes._tensordot_to_einsum(3, 3, 2).split(",") # expecting 'abc,bcf - assert a_str[:3] == "..." - assert b_str[-3:] == "..." - a_str = a_str.lstrip("...") - b_str = b_str.rstrip("...") - assert a_str == b_str + assert a_str[0] not in b_str + assert b_str[-1] not in a_str + assert a_str[1:] == b_str[:-1] def test_ts_to_einsum_list_axes(): @@ -605,37 +603,46 @@ def test_tensordot_list_axes(): assert_array_equal(result, super_result) +@pytest.mark.skip def test_einsum_implicit(): ... +@pytest.mark.skip def test_einsum_trace(): ... +@pytest.mark.skip def test_einsum_diag(): ... +@pytest.mark.skip def test_einsum_1dsum(): ... +@pytest.mark.skip def test_einsum_alldsum(): ... +@pytest.mark.skip def test_einsum_contraction(): ... +@pytest.mark.skip def test_einsum_explicit_ellipsis(): ... +@pytest.mark.skip def test_einsum_scalar(): ... +@pytest.mark.skip def test_einsum_mixed(): ... From cef716744ba4c4c4b950c8f18c162501fa6bd1cf Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 16 Jan 2024 15:12:22 +0000 Subject: [PATCH 58/61] clean: downgrade typing syntax and stdlibrary use to python 3.8 Also gitignore env8 directory, where I keep my python3.8 environment --- .gitignore | 1 + pysindy/differentiation/finite_difference.py | 3 +- pysindy/utils/axes.py | 51 +++++++++++--------- 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index cebac22c5..f862d8a3a 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ venv/ ENV/ env.bak/ venv.bak/ +env8 # automatically generated by setuptools-scm pysindy/version.py diff --git a/pysindy/differentiation/finite_difference.py b/pysindy/differentiation/finite_difference.py index 9ccb7ca5c..69c5ebafa 100644 --- a/pysindy/differentiation/finite_difference.py +++ b/pysindy/differentiation/finite_difference.py @@ -1,3 +1,4 @@ +from typing import List from typing import Union import numpy as np @@ -232,7 +233,7 @@ def _accumulate(self, coeffs, x): ) def _differentiate( - self, x: NDArray, t: Union[NDArray, float, list[float]] + self, x: NDArray, t: Union[NDArray, float, List[float]] ) -> NDArray: """ Apply finite difference method. 
diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 46a0d4666..fa51e4863 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -20,6 +20,7 @@ from enum import Enum from typing import Collection from typing import Dict +from typing import get_args from typing import List from typing import Literal from typing import NewType @@ -36,9 +37,9 @@ HANDLED_FUNCTIONS = {} AxesWarning = type("AxesWarning", (SyntaxWarning,), {}) -BasicIndexer = Union[slice, int, type(Ellipsis), type(None), str] -Indexer = BasicIndexer | NDArray | list -StandardIndexer = Union[slice, int, type(None), NDArray[np.dtype(int)]] +BasicIndexer = Union[slice, int, type(Ellipsis), None, str] +Indexer = Union[BasicIndexer, NDArray, List] +StandardIndexer = Union[slice, int, None, NDArray[np.dtype(int)]] OldIndex = NewType("OldIndex", int) # Before moving advanced axes adajent KeyIndex = NewType("KeyIndex", int) NewIndex = NewType("NewIndex", int) @@ -52,8 +53,8 @@ class Sentinels(Enum): class _AxisMapping: """Convenience wrapper for a two-way map between axis names and indexes.""" - fwd_map: dict[str, list[int]] - reverse_map: dict[int, str] + fwd_map: Dict[str, List[int]] + reverse_map: Dict[int, str] def __init__( self, @@ -90,14 +91,14 @@ def coerce_sequence(obj): ) @staticmethod - def fwd_from_names(names: List[str]) -> dict[str, Sequence[int]]: - fwd_map: dict[str, Sequence[int]] = {} + def fwd_from_names(names: List[str]) -> Dict[str, Sequence[int]]: + fwd_map: Dict[str, Sequence[int]] = {} for ax_ind, name in enumerate(names): _compat_dict_append(fwd_map, name, [ax_ind]) return fwd_map @staticmethod - def _compat_axes(in_dict: dict[str, list[int]]) -> dict[str, Union[list[int], int]]: + def _compat_axes(in_dict: Dict[str, List[int]]) -> Dict[str, Union[list[int], int]]: """Like fwd_map, but unpack single-element axis lists""" axes = {} for k, v in in_dict.items(): @@ -269,7 +270,7 @@ def __getattr__(self, name): return shape raise AttributeError(f"'{type(self)}' object has no attribute '{name}'") - def __getitem__(self, key: Indexer | Sequence[Indexer], /): + def __getitem__(self, key: Union[Indexer, Sequence[Indexer]], /): if isinstance(key, tuple): base_indexer = tuple(None if isinstance(k, str) else k for k in key) else: @@ -542,12 +543,14 @@ def einsum( def _join_unique_names(l_of_s: List[str]) -> str: ordered_uniques = dict.fromkeys(l_of_s).keys() - return "_".join(ax_name.removeprefix("ax_") for ax_name in ordered_uniques) + return "_".join( + ax_name[3:] if ax_name[:3] == "ax_" else ax_name for ax_name in ordered_uniques + ) def _label_einsum_scripts( - lscripts: list[str], operands: tuple[AxesArray] -) -> list[dict[str, str]]: + lscripts: List[str], operands: tuple[AxesArray] +) -> List[dict[str, str]]: """Create a list of what axis name each script refers to in its operand.""" allscript_names: List[Dict[str, List[str]]] = [] for lscr, op in zip(lscripts.split(","), operands): @@ -644,9 +647,9 @@ def standardize_indexer( if not any(ax_key is Ellipsis for ax_key in key): key = [*key, Ellipsis] - new_key: list[Indexer] = [] + new_key: List[Indexer] = [] for ax_key in key: - if not isinstance(ax_key, BasicIndexer): + if not isinstance(ax_key, get_args(BasicIndexer)): ax_key = np.array(ax_key) if ax_key.dtype == np.dtype(np.bool_): new_key += ax_key.nonzero() @@ -655,7 +658,7 @@ def standardize_indexer( new_key = _expand_indexer_ellipsis(new_key, arr.ndim) # Can't identify position of advanced indexers before expanding ellipses - adv_inds: list[KeyIndex] = [] + adv_inds: List[KeyIndex] = [] 
for key_ind, ax_key in enumerate(new_key): if isinstance(ax_key, np.ndarray): adv_inds.append(KeyIndex(key_ind)) @@ -663,7 +666,7 @@ def standardize_indexer( return new_key, tuple(adv_inds) -def _expand_indexer_ellipsis(key: list[Indexer], ndim: int) -> list[Indexer]: +def _expand_indexer_ellipsis(key: List[Indexer], ndim: int) -> List[Indexer]: """Replace ellipsis in indexers with the appropriate amount of slice(None)""" # [...].index errors if list contains numpy array ellind = [ind for ind, val in enumerate(key) if val is ...][0] @@ -686,9 +689,9 @@ def _determine_adv_broadcasting( def _rename_broadcast_axes( - new_axes: list[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]], - adv_names: list[str], -) -> list[tuple[int, str]]: + new_axes: List[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]], + adv_names: List[str], +) -> List[tuple[int, str]]: """Normalize sentinel and NoneType names""" def _calc_bcast_name(*names: str) -> str: @@ -713,7 +716,7 @@ def _calc_bcast_name(*names: str) -> str: def replace_adv_indexers( key: Sequence[StandardIndexer], - adv_inds: list[int], + adv_inds: List[int], bcast_start_ax: int, bcast_nd: int, ) -> tuple[ @@ -727,9 +730,9 @@ def replace_adv_indexers( def _apply_indexing( - key: tuple[StandardIndexer], reverse_map: dict[int, str] + key: tuple[StandardIndexer], reverse_map: Dict[int, str] ) -> tuple[ - list[int], list[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]], list[str] + List[int], List[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]], List[str] ]: """Determine where axes should be removed and added @@ -810,8 +813,8 @@ def wrap_axes(axes: dict, obj): T = TypeVar("T", bound=int) # TODO: Bind to a non-sequence after type-negation PEP -ItemOrList = Union[T, list[T]] -CompatDict = dict[str, ItemOrList[T]] +ItemOrList = Union[T, List[T]] +CompatDict = Dict[str, ItemOrList[T]] def _compat_dict_append( From 2c56053c2b65b6bb67835bafda9524637ffd17fa Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 17 Jan 2024 02:31:49 +0000 Subject: [PATCH 59/61] doc: Fix doc build errors. Upgrade sphinx Newer sphinx gives more accurate line numbers for errors --- pyproject.toml | 2 +- pysindy/utils/axes.py | 68 ++++++++++++++++++++++------------------- test/utils/test_axes.py | 10 +++--- 3 files changed, 43 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 60028b62a..65fc0b263 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ docs = [ "ipython", "pandoc", "sphinx-rtd-theme", - "sphinx==5.3.0", + "sphinx==7.1.2", "sphinxcontrib-apidoc", "nbsphinx" ] diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index fa51e4863..d90c6d959 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -2,7 +2,9 @@ A module that defines one external class, AxesArray, to act like a numpy array but keep track of axis definitions. -TODO: Add developer documentation here. +.. todo:: + + Add developer documentation here. The recommended way to refactor existing code to use AxesArrays is to add them at the lowest level possible. 
Enter debug mode and see how long the expected @@ -43,6 +45,9 @@ OldIndex = NewType("OldIndex", int) # Before moving advanced axes adajent KeyIndex = NewType("KeyIndex", int) NewIndex = NewType("NewIndex", int) +T = TypeVar("T", bound=int) # TODO: Bind to a non-sequence after type-negation PEP +ItemOrList = Union[T, List[T]] +CompatDict = Dict[str, ItemOrList[T]] class Sentinels(Enum): @@ -178,27 +183,31 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): """A numpy-like array that keeps track of the meaning of its axes. Limitations: - * Not all numpy functions, such as ``np.flatten()``, does not have an - implementation for AxesArray, a regular numpy array is returned. - * For functions that are implemented for `AxesArray`, such as - ``np.reshape()``, use the numpy function rather than the bound - method (e.g. arr.reshape) - * Such such functions may raise ValueErrors where numpy would not, when - it is impossible to determine the output axis labels. - Bound methods, such as arr.reshape, are not implemented. Use the functions. - While the functions in the numpy namespace will work on ``AxesArray`` - objects, the documentation must be found in their equivalent names here. + * Not all numpy functions, such as ``np.flatten()``, have an + implementation for ``AxesArray``. In such cases a regular numpy array + is returned. + * For functions that are implemented for `AxesArray`, such as + ``np.reshape()``, use the numpy function rather than the bound + method (e.g. ``arr.reshape``) + * Such functions may raise ``ValueError`` where numpy would not, when + it is impossible to determine the output axis labels. Current array function implementations: + * ``np.concatenate`` * ``np.reshape`` * ``np.transpose`` + * ``np.linalg.solve`` + * ``np.einsum`` + * ``np.tensordot`` Indexing: AxesArray supports all of the basic and advanced indexing of numpy arrays, with the addition that new axes can be inserted with a string - name for the axis. If ``None`` or ``np.newaxis`` are passed, the + name for the axis. E.g. ``arr = arr[..., "lineno"]`` will add a + length-one axis at the end, along with the properties ``arr.ax_lineno`` + and ``arr.n_lineno``. If ``None`` or ``np.newaxis`` are passed, the axis is named "unk". Parameters: @@ -215,7 +224,7 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): Raises: * AxesWarning if axes does not match shape of input_array. * ValueError if assigning the same axis index to multiple meanings or - assigning an axis beyond ndim. + assigning an axis beyond ndim. """ @@ -239,6 +248,7 @@ def _reverse_map(self): @property def shape(self): + """Shape of array. 
Unlike numpy ndarray, this is not assignable.""" return super().shape def insert_axis( @@ -279,10 +289,10 @@ def __getitem__(self, key: Union[Indexer, Sequence[Indexer]], /): if not isinstance(output, AxesArray): return output # return an element from the array in_dim = self.shape - key, adv_inds = standardize_indexer(self, key) + key, adv_inds = _standardize_indexer(self, key) bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds) if adv_inds: - key = replace_adv_indexers(key, adv_inds, bcast_start_ax, bcast_nd) + key = _replace_adv_indexers(key, adv_inds, bcast_start_ax, bcast_nd) remove_axes, new_axes, adv_names = _apply_indexing(key, self._reverse_map) new_axes = _rename_broadcast_axes(new_axes, adv_names) new_map = _AxisMapping( @@ -384,8 +394,8 @@ def __array_function__(self, func, types, args, kwargs): return HANDLED_FUNCTIONS[func](*args, **kwargs) -def implements(numpy_function): - """Register an __array_function__ implementation for MyArray objects.""" +def _implements(numpy_function): + """Register an __array_function__ implementation for AxesArray objects.""" def decorator(func): HANDLED_FUNCTIONS[numpy_function] = func @@ -394,7 +404,7 @@ def decorator(func): return decorator -@implements(np.ix_) +@_implements(np.ix_) def ix_(*args: AxesArray): calc = np.ix_(*(np.asarray(arg) for arg in args)) ax_names = [list(arr.axes)[0] for arr in args] @@ -402,7 +412,7 @@ def ix_(*args: AxesArray): return tuple(AxesArray(arr, axes) for arr in calc) -@implements(np.concatenate) +@_implements(np.concatenate) def concatenate(arrays, axis=0, out=None, dtype=None, casting="same_kind"): parents = [np.asarray(obj) for obj in arrays] ax_list = [obj.axes for obj in arrays if isinstance(obj, AxesArray)] @@ -415,7 +425,7 @@ def concatenate(arrays, axis=0, out=None, dtype=None, casting="same_kind"): return AxesArray(result, axes=ax_list[0]) -@implements(np.reshape) +@_implements(np.reshape) def reshape(a: AxesArray, newshape: int | tuple[int], order="C"): """Gives a new shape to an array without changing its data. @@ -481,7 +491,7 @@ def reshape(a: AxesArray, newshape: int | tuple[int], order="C"): return AxesArray(out, axes=new_axes) -@implements(np.transpose) +@_implements(np.transpose) def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None): """Returns an array with axes transposed. @@ -500,7 +510,7 @@ def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None) return AxesArray(out, new_axes) -@implements(np.einsum) +@_implements(np.einsum) def einsum( subscripts: str, *operands: AxesArray, out: Optional[NDArray] = None, **kwargs ) -> AxesArray: @@ -580,7 +590,7 @@ def _label_einsum_scripts( return allscript_names -@implements(np.linalg.solve) +@_implements(np.linalg.solve) def linalg_solve(a: AxesArray, b: AxesArray) -> AxesArray: result = np.linalg.solve(np.asarray(a), np.asarray(b)) a_rev = a._ax_map.reverse_map @@ -601,7 +611,7 @@ def linalg_solve(a: AxesArray, b: AxesArray) -> AxesArray: return AxesArray(result, axes) -@implements(np.tensordot) +@_implements(np.tensordot) def tensordot( a: AxesArray, b: AxesArray, axes: Union[int, Sequence[Sequence[int]]] = 2 ) -> AxesArray: @@ -626,7 +636,7 @@ def _tensordot_to_einsum( return sub -def standardize_indexer( +def _standardize_indexer( arr: np.ndarray, key: Indexer | Sequence[Indexer] ) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]: """Convert any legal numpy indexer to a "standard" form. 
@@ -635,6 +645,7 @@ def standardize_indexer( one element per index of the original axis. All advanced indexer elements are converted to numpy arrays, and boolean arrays are converted to integer arrays with obj.nonzero(). + Returns: A tuple of the normalized indexer as well as the indexes of advanced indexers @@ -714,7 +725,7 @@ def _calc_bcast_name(*names: str) -> str: return renamed_axes -def replace_adv_indexers( +def _replace_adv_indexers( key: Sequence[StandardIndexer], adv_inds: List[int], bcast_start_ax: int, @@ -812,11 +823,6 @@ def wrap_axes(axes: dict, obj): return obj -T = TypeVar("T", bound=int) # TODO: Bind to a non-sequence after type-negation PEP -ItemOrList = Union[T, List[T]] -CompatDict = Dict[str, ItemOrList[T]] - - def _compat_dict_append( compat_dict: CompatDict[T], key: str, diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index b4f8fb3d4..c7327f240 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -292,11 +292,11 @@ def test_adv_indexing_adds_axes(): def test_standardize_basic_indexer(): arr = np.arange(6).reshape(2, 3) - result_indexer, result_fancy = axes.standardize_indexer(arr, Ellipsis) + result_indexer, result_fancy = axes._standardize_indexer(arr, Ellipsis) assert result_indexer == [slice(None), slice(None)] assert result_fancy == () - result_indexer, result_fancy = axes.standardize_indexer( + result_indexer, result_fancy = axes._standardize_indexer( arr, (np.newaxis, 1, 1, Ellipsis) ) assert result_indexer == [None, 1, 1] @@ -305,11 +305,11 @@ def test_standardize_basic_indexer(): def test_standardize_advanced_indexer(): arr = np.arange(6).reshape(2, 3) - result_indexer, result_fancy = axes.standardize_indexer(arr, [1]) + result_indexer, result_fancy = axes._standardize_indexer(arr, [1]) assert result_indexer == [np.ones(1), slice(None)] assert result_fancy == (0,) - result_indexer, result_fancy = axes.standardize_indexer( + result_indexer, result_fancy = axes._standardize_indexer( arr, (np.newaxis, [1], 1, Ellipsis) ) assert result_indexer == [None, np.ones(1), 1] @@ -318,7 +318,7 @@ def test_standardize_advanced_indexer(): def test_standardize_bool_indexer(): arr = np.ones((1, 2)) - result, result_adv = axes.standardize_indexer(arr, [[True, True]]) + result, result_adv = axes._standardize_indexer(arr, [[True, True]]) assert_equal(result, [[0, 0], [0, 1]]) assert result_adv == (0, 1) From 3ede6d0bc25e00b427d5a05ef95b61f24e38c40f Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 17 Jan 2024 14:48:42 +0000 Subject: [PATCH 60/61] feat/doc(axes): Make helpers public so docs pick them up --- pysindy/utils/axes.py | 98 ++++++++++++++++++++++++++++++------------- 1 file changed, 68 insertions(+), 30 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index d90c6d959..4224cf550 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -1,6 +1,41 @@ """ A module that defines one external class, AxesArray, to act like a numpy array -but keep track of axis definitions. +but keep track of axis definitions. It aims to allow meaningful replacement +of magic numbers for axis conventions in code. 
E.g:: + + import numpy as np + + arr = AxesArray(np.ones((2,3,4)), {"ax_time": 0, "ax_spatial": [1, 2]}) + print(arr.axes) + print(arr.ax_time) + print(arr.n_time) + print(arr.ax_spatial) + print(arr.n_spatial) + +Would show:: + + {"ax_time": 0, "ax_spatial": [1, 2]} + 0 + 2 + [1, 2] + [3, 4] + +It is up to the user to handle the ``list[int] | int`` return values, but this +module has several functions to deal with the axes dictionary, internally +referred to as type ``CompatDict[T]``: + +Appending an item to a ``CompatDict[T]`` + :py:func:`compat_dict_append` + +Generating a ``CompatDict[int]`` of axes from list of axes names: + :py:func:`fwd_from_names` + +Create new ``CompatDict[int]`` from this ``AxesArray`` with new axis/axes added: + :py:meth:`AxesArray.insert_axis` + +Create new ``CompatDict[int]`` from this ``AxesArray`` with axis/axes removed: + :py:meth:`AxesArray.remove_axis` + .. todo:: @@ -50,7 +85,7 @@ CompatDict = Dict[str, ItemOrList[T]] -class Sentinels(Enum): +class _Sentinels(Enum): ADV_NAME = object() ADV_REMOVE = object() @@ -95,13 +130,6 @@ def coerce_sequence(obj): AxesWarning, ) - @staticmethod - def fwd_from_names(names: List[str]) -> Dict[str, Sequence[int]]: - fwd_map: Dict[str, Sequence[int]] = {} - for ax_ind, name in enumerate(names): - _compat_dict_append(fwd_map, name, [ax_ind]) - return fwd_map - @staticmethod def _compat_axes(in_dict: Dict[str, List[int]]) -> Dict[str, Union[list[int], int]]: """Like fwd_map, but unpack single-element axis lists""" @@ -222,9 +250,9 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): n_: lookup shape of subarray defined by ax_name Raises: - * AxesWarning if axes does not match shape of input_array. - * ValueError if assigning the same axis index to multiple meanings or - assigning an axis beyond ndim. + AxesWarning if axes does not match shape of input_array. + ValueError if assigning the same axis index to multiple meanings or + assigning an axis beyond ndim. """ @@ -310,7 +338,7 @@ def __array_wrap__(self, out_arr, context=None): return super().__array_wrap__(self, out_arr, context) def __array_finalize__(self, obj) -> None: - if obj is None: # explicit construction via super().__new__().. not called? 
+ if obj is None: # explicit construction via super().__new__() return # view from numpy array, called in constructor but also tests if all( @@ -408,7 +436,7 @@ def decorator(func): def ix_(*args: AxesArray): calc = np.ix_(*(np.asarray(arg) for arg in args)) ax_names = [list(arr.axes)[0] for arr in args] - axes = _AxisMapping.fwd_from_names(ax_names) + axes = fwd_from_names(ax_names) return tuple(AxesArray(arr, axes) for arr in calc) @@ -461,7 +489,7 @@ def reshape(a: AxesArray, newshape: int | tuple[int], order="C"): ) base_name = a._ax_map.reverse_map[curr_base] if a.shape[curr_base] == newshape[curr_new]: - _compat_dict_append(new_axes, base_name, curr_new) + compat_dict_append(new_axes, base_name, curr_new) curr_base += 1 elif newshape[curr_new] == 1: raise ValueError( @@ -486,7 +514,7 @@ def reshape(a: AxesArray, newshape: int | tuple[int], order="C"): ) curr_base += 1 - _compat_dict_append(new_axes, base_name, curr_new) + compat_dict_append(new_axes, base_name, curr_new) return AxesArray(out, axes=new_axes) @@ -505,7 +533,7 @@ def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None) new_axes = {} old_reverse = a._ax_map.reverse_map for new_ind, old_ind in enumerate(axes): - _compat_dict_append(new_axes, old_reverse[old_ind], new_ind) + compat_dict_append(new_axes, old_reverse[old_ind], new_ind) return AxesArray(out, new_axes) @@ -545,7 +573,7 @@ def einsum( ax_name = "ax_" + _join_unique_names(ax_names) out_names.append(ax_name) - out_axes = _AxisMapping.fwd_from_names(out_names) + out_axes = fwd_from_names(out_names) if isinstance(out, AxesArray): out._ax_map = _AxisMapping(out_axes, calc.ndim) return AxesArray(calc, axes=out_axes) @@ -586,7 +614,7 @@ def _label_einsum_scripts( scr_name = op._ax_map.reverse_map[ax_ind] else: scr_name = op._ax_map.reverse_map[ax_ind - 3 + ell_width] - _compat_dict_append(script_names, char, [scr_name]) + compat_dict_append(script_names, char, [scr_name]) return allscript_names @@ -607,7 +635,7 @@ def linalg_solve(a: AxesArray, b: AxesArray) -> AxesArray: all_names = ( b_names[: align.stop - 1] + [contracted_axis_name] + b_names[align.stop :] ) - axes = _AxisMapping.fwd_from_names(all_names) + axes = fwd_from_names(all_names) return AxesArray(result, axes) @@ -700,7 +728,7 @@ def _determine_adv_broadcasting( def _rename_broadcast_axes( - new_axes: List[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]], + new_axes: List[tuple[int, None | str | Literal[_Sentinels.ADV_NAME]]], adv_names: List[str], ) -> List[tuple[int, str]]: """Normalize sentinel and NoneType names""" @@ -718,7 +746,7 @@ def _calc_bcast_name(*names: str) -> str: for ax_ind, ax_name in new_axes: if ax_name is None: renamed_axes.append((ax_ind, "ax_unk")) - elif ax_name is Sentinels.ADV_NAME: + elif ax_name is _Sentinels.ADV_NAME: renamed_axes.append((ax_ind, bcast_name)) else: renamed_axes.append((ax_ind, "ax_" + ax_name)) @@ -731,19 +759,19 @@ def _replace_adv_indexers( bcast_start_ax: int, bcast_nd: int, ) -> tuple[ - Union[None, str, int, Literal[Sentinels.ADV_NAME], Literal[Sentinels.ADV_REMOVE]], + Union[None, str, int, Literal[_Sentinels.ADV_NAME], Literal[_Sentinels.ADV_REMOVE]], ..., ]: for adv_ind in adv_inds: - key[adv_ind] = Sentinels.ADV_REMOVE - key = key[:bcast_start_ax] + bcast_nd * [Sentinels.ADV_NAME] + key[bcast_start_ax:] + key[adv_ind] = _Sentinels.ADV_REMOVE + key = key[:bcast_start_ax] + bcast_nd * [_Sentinels.ADV_NAME] + key[bcast_start_ax:] return key def _apply_indexing( key: tuple[StandardIndexer], reverse_map: Dict[int, str] ) -> 
tuple[ - List[int], List[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]], List[str] + List[int], List[tuple[int, None | str | Literal[_Sentinels.ADV_NAME]]], List[str] ]: """Determine where axes should be removed and added @@ -756,14 +784,16 @@ def _apply_indexing( deleted_to_left = 0 added_to_left = 0 for key_ind, indexer in enumerate(key): - if isinstance(indexer, int) or indexer is Sentinels.ADV_REMOVE: + if isinstance(indexer, int) or indexer is _Sentinels.ADV_REMOVE: orig_arr_axis = key_ind - added_to_left - if indexer is Sentinels.ADV_REMOVE: + if indexer is _Sentinels.ADV_REMOVE: adv_names.append(reverse_map[orig_arr_axis]) remove_axes.append(orig_arr_axis) deleted_to_left += 1 elif ( - indexer is None or indexer is Sentinels.ADV_NAME or isinstance(indexer, str) + indexer is None + or indexer is _Sentinels.ADV_NAME + or isinstance(indexer, str) ): new_arr_axis = key_ind - deleted_to_left new_axes.append((new_arr_axis, indexer)) @@ -823,7 +853,7 @@ def wrap_axes(axes: dict, obj): return obj -def _compat_dict_append( +def compat_dict_append( compat_dict: CompatDict[T], key: str, item_or_list: ItemOrList[T], @@ -839,3 +869,11 @@ def _compat_dict_append( if not isinstance(prev_val, list): prev_val = [prev_val] compat_dict[key] = prev_val + item_or_list + + +def fwd_from_names(names: List[str]) -> CompatDict[int]: + """Create mapping of name: axis or name: [ax_1, ax_2, ...]""" + fwd_map: Dict[str, Sequence[int]] = {} + for ax_ind, name in enumerate(names): + compat_dict_append(fwd_map, name, [ax_ind]) + return fwd_map From 9c76e797ffee5cd28eab7307694871cf2c56e8d0 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 17 Jan 2024 15:52:33 +0000 Subject: [PATCH 61/61] tst(axes): Cover more lines! Added a bunch of tests, mostly to check ValueErrors emitted correctly Remove default argument to _AxisMapping, AxesArray Remove __array_wrap__ --- pysindy/utils/axes.py | 26 ++++++-------------------- test/utils/test_axes.py | 32 ++++++++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 4224cf550..ed27957aa 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -98,11 +98,9 @@ class _AxisMapping: def __init__( self, - axes: Optional[dict[str, Union[int, Sequence[int]]]] = None, - in_ndim: int = 0, + axes: dict[str, Union[int, Sequence[int]]], + in_ndim: int, ): - if axes is None: - axes = {} self.fwd_map = {} self.reverse_map = {} @@ -260,8 +258,6 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): def __new__(cls, input_array: NDArray, axes: CompatDict[int]): obj = np.asarray(input_array).view(cls) - if axes is None: - axes = {} in_ndim = len(input_array.shape) obj._ax_map = _AxisMapping(axes, in_ndim) return obj @@ -334,9 +330,6 @@ def __getitem__(self, key: Union[Indexer, Sequence[Indexer]], /): output._ax_map = new_map return output - def __array_wrap__(self, out_arr, context=None): - return super().__array_wrap__(self, out_arr, context) - def __array_finalize__(self, obj) -> None: if obj is None: # explicit construction via super().__new__() return @@ -348,7 +341,7 @@ def __array_finalize__(self, obj) -> None: not hasattr(self, "_ax_map"), ) ): - self._ax_map = _AxisMapping({}) + self._ax_map = _AxisMapping({}, in_ndim=0) # required by ravel() and view() used in numpy testing. Also for zeros_like... 
elif all( ( @@ -357,7 +350,7 @@ def __array_finalize__(self, obj) -> None: self.shape == obj.shape, ) ): - self._ax_map = _AxisMapping(obj.axes, len(obj.shape)) + self._ax_map = _AxisMapping(obj.axes, obj.ndim) # maybe add errors for incompatible views? def __array_ufunc__( @@ -409,14 +402,7 @@ def __array_ufunc__( def __array_function__(self, func, types, args, kwargs): if func not in HANDLED_FUNCTIONS: - arr = super(AxesArray, self).__array_function__(func, types, args, kwargs) - if isinstance(arr, AxesArray): - return arr - elif isinstance(arr, np.ndarray): - return AxesArray(arr, axes=self.axes) - elif arr is not None: - return arr - return + return super(AxesArray, self).__array_function__(func, types, args, kwargs) if not all(issubclass(t, AxesArray) for t in types): return NotImplemented return HANDLED_FUNCTIONS[func](*args, **kwargs) @@ -446,7 +432,7 @@ def concatenate(arrays, axis=0, out=None, dtype=None, casting="same_kind"): ax_list = [obj.axes for obj in arrays if isinstance(obj, AxesArray)] for ax1, ax2 in zip(ax_list[:-1], ax_list[1:]): if ax1 != ax2: - raise TypeError("Concatenating >1 AxesArray with incompatible axes") + raise ValueError("Concatenating >1 AxesArray with incompatible axes") result = np.concatenate(parents, axis, out=out, dtype=dtype, casting=casting) if isinstance(out, AxesArray): out._ax_map = _AxisMapping(ax_list[0], in_ndim=result.ndim) diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index c7327f240..b26a73890 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -11,6 +11,10 @@ from pysindy.utils.axes import AxesWarning +def test_axesarray_create(): + AxesArray(np.array(1), {}) + + def test_concat_out(): arr = AxesArray(np.arange(3).reshape(1, 3), {"ax_a": 0, "ax_b": 1}) arr_out = np.empty((2, 3)).view(AxesArray) @@ -18,6 +22,13 @@ def test_concat_out(): assert_equal(result, arr_out) +def test_bad_concat(): + arr = AxesArray(np.arange(3).reshape(1, 3), {"ax_a": 0, "ax_b": 1}) + arr2 = AxesArray(np.arange(3).reshape(1, 3), {"ax_b": 0, "ax_c": 1}) + with pytest.raises(ValueError): + np.concatenate((arr, arr2), axis=0) + + def test_reduce_mean_noinf_recursion(): arr = AxesArray(np.array([[1]]), {"ax_a": [0, 1]}) np.mean(arr, axis=0) @@ -153,6 +164,14 @@ def test_reshape_outer_product(): assert merge.axes == {"ax_a": 0} +def test_reshape_bad_divmod(): + arr = AxesArray(np.arange(12).reshape((2, 3, 2)), {"ax_a": [0, 1], "ax_b": 2}) + with pytest.raises( + ValueError, match="Cannot reshape an AxesArray this way. Array dimension" + ): + np.reshape(arr, (4, 3)) + + def test_reshape_fill_outer_product(): arr = AxesArray(np.arange(4).reshape((2, 2)), {"ax_a": [0, 1]}) merge = np.reshape(arr, (-1,)) @@ -335,6 +354,13 @@ def test_reduce_AxisMapping(): assert result == expected +def test_reduce_all_AxisMapping(): + ax_map = _AxisMapping({"ax_a": [0, 1], "ax_b": 2}, 3) + result = ax_map.remove_axis() + expected = {} + assert result == expected + + def test_reduce_multiple_AxisMapping(): ax_map = _AxisMapping( { @@ -638,9 +664,11 @@ def test_einsum_explicit_ellipsis(): ... -@pytest.mark.skip def test_einsum_scalar(): - ... + arr = AxesArray(np.ones(1), {"ax_a": 0}) + expected = 1 + result = np.einsum("i,i", arr, arr) + assert result == expected @pytest.mark.skip