Skip to content

Commit

Permalink
Add numpy cycle support
Browse files Browse the repository at this point in the history
Resolves   #690.
  • Loading branch information
evhub committed Dec 30, 2022
1 parent f2686ce commit 3dcb9a0
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 12 deletions.
18 changes: 11 additions & 7 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,8 @@ To allow for better use of [`numpy`](https://numpy.org/) objects in Coconut, all
- [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html) is registered as a [`collections.abc.Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence), enabling it to be used in [sequence patterns](#semantics-specification).
- Coconut supports `@` for matrix multiplication of `numpy` arrays on all Python versions, as well as supplying the `(@)` [operator function](#operator-functions).

Additionally, Coconut provides the exact same support for [`pandas`](https://pandas.pydata.org/) and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html), including using `jax.numpy` methods over `numpy` methods when given `jax` arrays.

### `xonsh` Support

Coconut integrates with [`xonsh`](https://xon.sh/) to allow the use of Coconut code directly from your command line. To use Coconut in `xonsh`, simply `pip install coconut` should be all you need to enable the use of Coconut syntax in the `xonsh` shell. In some circumstances, in particular depending on the installed `xonsh` version, adding `xontrib load coconut` to your [`xonshrc`](https://xon.sh/xonshrc.html) file might be necessary.
Expand Down Expand Up @@ -1707,7 +1709,7 @@ def int_map(

Coconut supports multidimensional array literal and array [concatenation](https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html)/[stack](https://numpy.org/doc/stable/reference/generated/numpy.stack.html) syntax.

By default, all multidimensional array syntax will simply operate on Python lists of lists. However, if [`numpy`](http://www.numpy.org/)/[`pandas`](https://pandas.pydata.org/)/[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects are used, the appropriate `numpy` calls will be made instead. To give custom objects multidimensional array concatenation support, define `type(obj).__matconcat__` (should behave as `np.concat`), `obj.ndim` (should behave as `np.ndarray.ndim`), and `obj.reshape` (should behave as `np.ndarray.reshape`).
By default, all multidimensional array syntax will simply operate on Python lists of lists. However, if [`numpy`](#numpy-integration) objects are used, the appropriate `numpy` calls will be made instead. To give custom objects multidimensional array concatenation support, define `type(obj).__matconcat__` (should behave as `np.concat`), `obj.ndim` (should behave as `np.ndarray.ndim`), and `obj.reshape` (should behave as `np.ndarray.reshape`).

As a simple example, 2D matrices can be constructed by separating the rows with `;;` inside of a list literal:
```coconut_pycon
Expand Down Expand Up @@ -3028,6 +3030,8 @@ _Can't be done quickly without Coconut's iterator slicing, which requires many c

Coconut's `cycle` is a modified version of `itertools.cycle` with a `times` parameter that controls the number of times to cycle through _iterable_ before stopping. `cycle` also supports `in`, slicing, `len`, `reversed`, `.count()`, `.index()`, and `repr`.

When given a [`numpy`](#numpy-integration) array and a finite _times_, `cycle` will return a `numpy` array of _iterable_ concatenated with itself along the first axis _times_ times.

##### Python Docs

**cycle**(_iterable_)
Expand Down Expand Up @@ -3101,7 +3105,7 @@ In functional programming, `fmap(func, obj)` takes a data type `obj` and returns

For `dict`, or any other `collections.abc.Mapping`, `fmap` will map over the mapping's `.items()` instead of the default iteration through its `.keys()`, with the new mapping reconstructed from the mapped over items. Additionally, for backwards compatibility with old versions of Coconut, `fmap$(starmap_over_mappings=True)` will `starmap` over the `.items()` instead of `map` over them.

For [`numpy`](http://www.numpy.org/), [`pandas`](https://pandas.pydata.org/), and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects, `fmap` will use [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) to produce the result.
For [`numpy`](#numpy-integration) objects, `fmap` will use [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) to produce the result.

For asynchronous iterables, `fmap` will map asynchronously, making `fmap` equivalent in that case to:
```coconut_python
Expand Down Expand Up @@ -3215,7 +3219,7 @@ for x in input_data:

Coconut provides an enhanced version of `itertools.chain.from_iterable` as a built-in under the name `flatten` with added support for `reversed`, `repr`, `in`, `.count()`, `.index()`, and `fmap`.

Additionally, `flatten` includes special support for [`numpy`](http://www.numpy.org/)/[`pandas`](https://pandas.pydata.org/)/[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects, in which case a multidimensional array is returned instead of an iterator.
Additionally, `flatten` includes special support for [`numpy`](#numpy-integration) objects, in which case a multidimensional array is returned instead of an iterator.

Note that `flatten` only flattens the top level (first axis) of the given iterable/array.

Expand Down Expand Up @@ -3254,7 +3258,7 @@ flat_it = iter_of_iters |> chain.from_iterable |> list

Coconut provides an enhanced version of `itertools.product` as a built-in under the name `cartesian_product` with added support for `len`, `repr`, `in`, `.count()`, and `fmap`.

Additionally, `cartesian_product` includes special support for [`numpy`](http://www.numpy.org/)/[`pandas`](https://pandas.pydata.org/)/[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects, in which case a multidimensional array is returned instead of an iterator.
Additionally, `cartesian_product` includes special support for [`numpy`](#numpy-integration) objects, in which case a multidimensional array is returned instead of an iterator.

##### Python Docs

Expand Down Expand Up @@ -3305,15 +3309,15 @@ assert list(product(v, v)) == [(1, 1), (1, 2), (2, 1), (2, 2)]

Coconut's `multi_enumerate` enumerates through an iterable of iterables. `multi_enumerate` works like enumerate, but indexes through inner iterables and produces a tuple index representing the index in each inner iterable. Supports indexing.

For [`numpy`](http://www.numpy.org/)/[`pandas`](https://pandas.pydata.org/)/[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects, effectively equivalent to:
For [`numpy`](#numpy-integration) objects, effectively equivalent to:
```coconut_python
def multi_enumerate(iterable):
it = np.nditer(iterable, flags=["multi_index"])
for x in it:
yield it.multi_index, x
```

Also supports `len` for [`numpy`](http://www.numpy.org/)/[`pandas`](https://pandas.pydata.org/)/[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html).
Also supports `len` for [`numpy`](#numpy-integration).

##### Example

Expand Down Expand Up @@ -3386,7 +3390,7 @@ for item in balance_data:

**all\_equal**(_iterable_)

Coconut's `all_equal` built-in takes in an iterable and determines whether all of its elements are equal to each other. `all_equal` assumes transitivity of equality and that `!=` is the negation of `==`. Special support is provided for [`numpy`](http://www.numpy.org/)/[`pandas`](https://pandas.pydata.org/)/[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects.
Coconut's `all_equal` built-in takes in an iterable and determines whether all of its elements are equal to each other. `all_equal` assumes transitivity of equality and that `!=` is the negation of `==`. Special support is provided for [`numpy`](#numpy-integration) objects.

##### Example

Expand Down
6 changes: 2 additions & 4 deletions coconut/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3554,9 +3554,7 @@ def testlist_star_expr_handle(self, original, loc, tokens, is_list=False):
groups, has_star, has_comma = self.split_star_expr_tokens(tokens)
is_sequence = has_comma or is_list

if not is_sequence:
if has_star:
raise CoconutDeferredSyntaxError("can't use starred expression here", loc)
if not is_sequence and not has_star:
self.internal_assert(len(groups) == 1 and len(groups[0]) == 1, original, loc, "invalid single-item testlist_star_expr tokens", tokens)
out = groups[0][0]

Expand All @@ -3565,7 +3563,7 @@ def testlist_star_expr_handle(self, original, loc, tokens, is_list=False):
out = tuple_str_of(groups[0], add_parens=False)

# naturally supported on 3.5+
elif self.target_info >= (3, 5):
elif is_sequence and self.target_info >= (3, 5):
to_literal = []
for g in groups:
if isinstance(g, list):
Expand Down
6 changes: 6 additions & 0 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,12 @@ class count(_coconut_base_hashable):
class cycle(_coconut_has_iter):
__slots__ = ("times",)
def __new__(cls, iterable, times=None):
if times is not None:
if iterable.__class__.__module__ in _coconut.numpy_modules:
return _coconut.numpy.concatenate((iterable,) * times)
if iterable.__class__.__module__ in _coconut.jax_numpy_modules:
import jax.numpy as jnp
return jnp.concatenate((iterable,) * times)
self = _coconut_has_iter.__new__(cls, iterable)
self.times = times
return self
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "2.1.1"
VERSION_NAME = "The Spanish Inquisition"
# False for release, int >= 1 for develop
DEVELOP = 26
DEVELOP = 27
ALPHA = False # for pre releases rather than post releases

# -----------------------------------------------------------------------------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions coconut/tests/src/extras.coco
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def test_numpy() -> bool:
np.array([1, 1;; 1, 2;; 2, 1;; 2, 2])
) # type: ignore
assert flatten(np.array([1,2;;3,4])) `np.array_equal` np.array([1,2,3,4]) # type: ignore
assert cycle(np.array([1,2;;3,4]), 2) `np.array_equal` np.array([1,2;;3,4;;1,2;;3,4]) # type: ignore
return True


Expand Down

0 comments on commit 3dcb9a0

Please sign in to comment.