From 1a3fe8d03b42a7613319f5eb0cea5ea1443c745b Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Thu, 8 Dec 2022 17:03:18 -0800 Subject: [PATCH] Remove some numpy support Resolves #689. --- DOCS.md | 7 +------ coconut/compiler/templates/header.py_template | 19 ++++--------------- coconut/root.py | 2 +- coconut/tests/src/extras.coco | 6 ++++-- 4 files changed, 10 insertions(+), 24 deletions(-) diff --git a/DOCS.md b/DOCS.md index 76b49ea9c..43a15c80b 100644 --- a/DOCS.md +++ b/DOCS.md @@ -427,7 +427,6 @@ To allow for better use of [`numpy`](https://numpy.org/) objects in Coconut, all * [`fmap`](#fmap) will use [`numpy.vectorize`](https://numpy.org/doc/stable/reference/generated/numpy.vectorize.html) to map over `numpy` arrays. * [`multi_enumerate`](#multi_enumerate) allows for easily looping over all the multi-dimensional indices in a `numpy` array. * [`cartesian_product`](#cartesian_product) can compute the Cartesian product of given `numpy` arrays as a `numpy` array. - * [`flatten`](#flatten) can flatten the first axis of a given `numpy` array. * [`all_equal`](#all_equal) allows for easily checking if all the elements in a `numpy` array are the same. - [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html) is registered as a [`collections.abc.Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence), enabling it to be used in [sequence patterns](#semantics-specification). - Coconut supports `@` for matrix multiplication of `numpy` arrays on all Python versions, as well as supplying the `(@)` [operator function](#operator-functions). @@ -3030,8 +3029,6 @@ _Can't be done quickly without Coconut's iterator slicing, which requires many c Coconut's `cycle` is a modified version of `itertools.cycle` with a `times` parameter that controls the number of times to cycle through _iterable_ before stopping. `cycle` also supports `in`, slicing, `len`, `reversed`, `.count()`, `.index()`, and `repr`. -When given a [`numpy`](#numpy-integration) array and a finite _times_, `cycle` will return a `numpy` array of _iterable_ concatenated with itself along the first axis _times_ times. - ##### Python Docs **cycle**(_iterable_) @@ -3219,9 +3216,7 @@ for x in input_data: Coconut provides an enhanced version of `itertools.chain.from_iterable` as a built-in under the name `flatten` with added support for `reversed`, `repr`, `in`, `.count()`, `.index()`, and `fmap`. -Additionally, `flatten` includes special support for [`numpy`](#numpy-integration) objects, in which case a multidimensional array is returned instead of an iterator. - -Note that `flatten` only flattens the top level (first axis) of the given iterable/array. +Note that `flatten` only flattens the top level of the given iterable/array. ##### Python Docs diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index d0d50c916..b481539bb 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -472,13 +472,9 @@ class reversed(_coconut_has_iter): return self.__class__(_coconut_map(func, self.iter)) class flatten(_coconut_has_iter):{COMMENT.cant_implement_len_else_list_calls_become_very_innefficient} """Flatten an iterable of iterables into a single iterable. - Flattens the first axis of numpy arrays.""" + Only flattens the top level of the iterable.""" __slots__ = () def __new__(cls, iterable): - if iterable.__class__.__module__ in _coconut.numpy_modules: - if len(iterable.shape) < 2: - raise _coconut.TypeError("flatten() on numpy arrays requires two or more dimensions") - return iterable.reshape(-1, *iterable.shape[2:]) self = _coconut_has_iter.__new__(cls, iterable) return self def get_new_iter(self): @@ -529,12 +525,11 @@ Additionally supports Cartesian products of numpy arrays.""" else: numpy = _coconut.numpy iterables *= repeat - la = _coconut.len(iterables) dtype = numpy.result_type(*iterables) - arr = numpy.empty([_coconut.len(a) for a in iterables] + [la], dtype=dtype) + arr = numpy.empty([_coconut.len(a) for a in iterables] + [_coconut.len(iterables)], dtype=dtype) for i, a in _coconut.enumerate(numpy.ix_(*iterables)): - arr[...,i] = a - return arr.reshape(-1, la) + arr[..., i] = a + return arr.reshape(-1, _coconut.len(iterables)) self = _coconut.object.__new__(cls) self.iters = iterables self.repeat = repeat @@ -934,12 +929,6 @@ class count(_coconut_base_hashable): class cycle(_coconut_has_iter): __slots__ = ("times",) def __new__(cls, iterable, times=None): - if times is not None: - if iterable.__class__.__module__ in _coconut.numpy_modules: - return _coconut.numpy.concatenate((iterable,) * times) - if iterable.__class__.__module__ in _coconut.jax_numpy_modules: - import jax.numpy as jnp - return jnp.concatenate((iterable,) * times) self = _coconut_has_iter.__new__(cls, iterable) self.times = times return self diff --git a/coconut/root.py b/coconut/root.py index 3233a1dbc..69c5218d6 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "2.1.1" VERSION_NAME = "The Spanish Inquisition" # False for release, int >= 1 for develop -DEVELOP = 27 +DEVELOP = 28 ALPHA = False # for pre releases rather than post releases # ----------------------------------------------------------------------------------------------------------------------- diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index 036f5197c..8855d83cd 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -386,8 +386,10 @@ def test_numpy() -> bool: `np.array_equal` np.array([1, 1;; 1, 2;; 2, 1;; 2, 2]) ) # type: ignore - assert flatten(np.array([1,2;;3,4])) `np.array_equal` np.array([1,2,3,4]) # type: ignore - assert cycle(np.array([1,2;;3,4]), 2) `np.array_equal` np.array([1,2;;3,4;;1,2;;3,4]) # type: ignore + assert flatten(np.array([1,2;;3,4])) `isinstance` flatten + assert (flatten(np.array([1,2;;3,4])) |> list) == [1,2,3,4] + assert cycle(np.array([1,2;;3,4]), 2) `isinstance` cycle + assert (cycle(np.array([1,2;;3,4]), 2) |> np.asarray) `np.array_equal` np.array([1,2;;3,4;;1,2;;3,4]) return True