Skip to content

Commit

Permalink
Remove some numpy support
Browse files Browse the repository at this point in the history
Resolves   #689.
  • Loading branch information
evhub committed Dec 30, 2022
1 parent 3dcb9a0 commit 1a3fe8d
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 24 deletions.
7 changes: 1 addition & 6 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,6 @@ To allow for better use of [`numpy`](https://numpy.org/) objects in Coconut, all
* [`fmap`](#fmap) will use [`numpy.vectorize`](https://numpy.org/doc/stable/reference/generated/numpy.vectorize.html) to map over `numpy` arrays.
* [`multi_enumerate`](#multi_enumerate) allows for easily looping over all the multi-dimensional indices in a `numpy` array.
* [`cartesian_product`](#cartesian_product) can compute the Cartesian product of given `numpy` arrays as a `numpy` array.
* [`flatten`](#flatten) can flatten the first axis of a given `numpy` array.
* [`all_equal`](#all_equal) allows for easily checking if all the elements in a `numpy` array are the same.
- [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html) is registered as a [`collections.abc.Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence), enabling it to be used in [sequence patterns](#semantics-specification).
- Coconut supports `@` for matrix multiplication of `numpy` arrays on all Python versions, as well as supplying the `(@)` [operator function](#operator-functions).
Expand Down Expand Up @@ -3030,8 +3029,6 @@ _Can't be done quickly without Coconut's iterator slicing, which requires many c

Coconut's `cycle` is a modified version of `itertools.cycle` with a `times` parameter that controls the number of times to cycle through _iterable_ before stopping. `cycle` also supports `in`, slicing, `len`, `reversed`, `.count()`, `.index()`, and `repr`.

When given a [`numpy`](#numpy-integration) array and a finite _times_, `cycle` will return a `numpy` array of _iterable_ concatenated with itself along the first axis _times_ times.

##### Python Docs

**cycle**(_iterable_)
Expand Down Expand Up @@ -3219,9 +3216,7 @@ for x in input_data:

Coconut provides an enhanced version of `itertools.chain.from_iterable` as a built-in under the name `flatten` with added support for `reversed`, `repr`, `in`, `.count()`, `.index()`, and `fmap`.

Additionally, `flatten` includes special support for [`numpy`](#numpy-integration) objects, in which case a multidimensional array is returned instead of an iterator.

Note that `flatten` only flattens the top level (first axis) of the given iterable/array.
Note that `flatten` only flattens the top level of the given iterable/array.

##### Python Docs

Expand Down
19 changes: 4 additions & 15 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -472,13 +472,9 @@ class reversed(_coconut_has_iter):
return self.__class__(_coconut_map(func, self.iter))
class flatten(_coconut_has_iter):{COMMENT.cant_implement_len_else_list_calls_become_very_innefficient}
"""Flatten an iterable of iterables into a single iterable.
Flattens the first axis of numpy arrays."""
Only flattens the top level of the iterable."""
__slots__ = ()
def __new__(cls, iterable):
if iterable.__class__.__module__ in _coconut.numpy_modules:
if len(iterable.shape) < 2:
raise _coconut.TypeError("flatten() on numpy arrays requires two or more dimensions")
return iterable.reshape(-1, *iterable.shape[2:])
self = _coconut_has_iter.__new__(cls, iterable)
return self
def get_new_iter(self):
Expand Down Expand Up @@ -529,12 +525,11 @@ Additionally supports Cartesian products of numpy arrays."""
else:
numpy = _coconut.numpy
iterables *= repeat
la = _coconut.len(iterables)
dtype = numpy.result_type(*iterables)
arr = numpy.empty([_coconut.len(a) for a in iterables] + [la], dtype=dtype)
arr = numpy.empty([_coconut.len(a) for a in iterables] + [_coconut.len(iterables)], dtype=dtype)
for i, a in _coconut.enumerate(numpy.ix_(*iterables)):
arr[...,i] = a
return arr.reshape(-1, la)
arr[..., i] = a
return arr.reshape(-1, _coconut.len(iterables))
self = _coconut.object.__new__(cls)
self.iters = iterables
self.repeat = repeat
Expand Down Expand Up @@ -934,12 +929,6 @@ class count(_coconut_base_hashable):
class cycle(_coconut_has_iter):
__slots__ = ("times",)
def __new__(cls, iterable, times=None):
if times is not None:
if iterable.__class__.__module__ in _coconut.numpy_modules:
return _coconut.numpy.concatenate((iterable,) * times)
if iterable.__class__.__module__ in _coconut.jax_numpy_modules:
import jax.numpy as jnp
return jnp.concatenate((iterable,) * times)
self = _coconut_has_iter.__new__(cls, iterable)
self.times = times
return self
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "2.1.1"
VERSION_NAME = "The Spanish Inquisition"
# False for release, int >= 1 for develop
DEVELOP = 27
DEVELOP = 28
ALPHA = False # for pre releases rather than post releases

# -----------------------------------------------------------------------------------------------------------------------
Expand Down
6 changes: 4 additions & 2 deletions coconut/tests/src/extras.coco
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,10 @@ def test_numpy() -> bool:
`np.array_equal`
np.array([1, 1;; 1, 2;; 2, 1;; 2, 2])
) # type: ignore
assert flatten(np.array([1,2;;3,4])) `np.array_equal` np.array([1,2,3,4]) # type: ignore
assert cycle(np.array([1,2;;3,4]), 2) `np.array_equal` np.array([1,2;;3,4;;1,2;;3,4]) # type: ignore
assert flatten(np.array([1,2;;3,4])) `isinstance` flatten
assert (flatten(np.array([1,2;;3,4])) |> list) == [1,2,3,4]
assert cycle(np.array([1,2;;3,4]), 2) `isinstance` cycle
assert (cycle(np.array([1,2;;3,4]), 2) |> np.asarray) `np.array_equal` np.array([1,2;;3,4;;1,2;;3,4])
return True


Expand Down

0 comments on commit 1a3fe8d

Please sign in to comment.