From 3dcb9a082b824b07456fb214b79300b5d7ae002b Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Thu, 8 Dec 2022 00:32:27 -0800 Subject: [PATCH] Add numpy cycle support Resolves #690. --- DOCS.md | 18 +++++++++++------- coconut/compiler/compiler.py | 6 ++---- coconut/compiler/templates/header.py_template | 6 ++++++ coconut/root.py | 2 +- coconut/tests/src/extras.coco | 1 + 5 files changed, 21 insertions(+), 12 deletions(-) diff --git a/DOCS.md b/DOCS.md index 67ce130c2..76b49ea9c 100644 --- a/DOCS.md +++ b/DOCS.md @@ -432,6 +432,8 @@ To allow for better use of [`numpy`](https://numpy.org/) objects in Coconut, all - [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html) is registered as a [`collections.abc.Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence), enabling it to be used in [sequence patterns](#semantics-specification). - Coconut supports `@` for matrix multiplication of `numpy` arrays on all Python versions, as well as supplying the `(@)` [operator function](#operator-functions). +Additionally, Coconut provides the exact same support for [`pandas`](https://pandas.pydata.org/) and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html), including using `jax.numpy` methods over `numpy` methods when given `jax` arrays. + ### `xonsh` Support Coconut integrates with [`xonsh`](https://xon.sh/) to allow the use of Coconut code directly from your command line. To use Coconut in `xonsh`, simply `pip install coconut` should be all you need to enable the use of Coconut syntax in the `xonsh` shell. In some circumstances, in particular depending on the installed `xonsh` version, adding `xontrib load coconut` to your [`xonshrc`](https://xon.sh/xonshrc.html) file might be necessary. @@ -1707,7 +1709,7 @@ def int_map( Coconut supports multidimensional array literal and array [concatenation](https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html)/[stack](https://numpy.org/doc/stable/reference/generated/numpy.stack.html) syntax. -By default, all multidimensional array syntax will simply operate on Python lists of lists. However, if [`numpy`](http://www.numpy.org/)/[`pandas`](https://pandas.pydata.org/)/[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects are used, the appropriate `numpy` calls will be made instead. To give custom objects multidimensional array concatenation support, define `type(obj).__matconcat__` (should behave as `np.concat`), `obj.ndim` (should behave as `np.ndarray.ndim`), and `obj.reshape` (should behave as `np.ndarray.reshape`). +By default, all multidimensional array syntax will simply operate on Python lists of lists. However, if [`numpy`](#numpy-integration) objects are used, the appropriate `numpy` calls will be made instead. To give custom objects multidimensional array concatenation support, define `type(obj).__matconcat__` (should behave as `np.concat`), `obj.ndim` (should behave as `np.ndarray.ndim`), and `obj.reshape` (should behave as `np.ndarray.reshape`). As a simple example, 2D matrices can be constructed by separating the rows with `;;` inside of a list literal: ```coconut_pycon @@ -3028,6 +3030,8 @@ _Can't be done quickly without Coconut's iterator slicing, which requires many c Coconut's `cycle` is a modified version of `itertools.cycle` with a `times` parameter that controls the number of times to cycle through _iterable_ before stopping. `cycle` also supports `in`, slicing, `len`, `reversed`, `.count()`, `.index()`, and `repr`. +When given a [`numpy`](#numpy-integration) array and a finite _times_, `cycle` will return a `numpy` array of _iterable_ concatenated with itself along the first axis _times_ times. + ##### Python Docs **cycle**(_iterable_) @@ -3101,7 +3105,7 @@ In functional programming, `fmap(func, obj)` takes a data type `obj` and returns For `dict`, or any other `collections.abc.Mapping`, `fmap` will map over the mapping's `.items()` instead of the default iteration through its `.keys()`, with the new mapping reconstructed from the mapped over items. Additionally, for backwards compatibility with old versions of Coconut, `fmap$(starmap_over_mappings=True)` will `starmap` over the `.items()` instead of `map` over them. -For [`numpy`](http://www.numpy.org/), [`pandas`](https://pandas.pydata.org/), and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects, `fmap` will use [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) to produce the result. +For [`numpy`](#numpy-integration) objects, `fmap` will use [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) to produce the result. For asynchronous iterables, `fmap` will map asynchronously, making `fmap` equivalent in that case to: ```coconut_python @@ -3215,7 +3219,7 @@ for x in input_data: Coconut provides an enhanced version of `itertools.chain.from_iterable` as a built-in under the name `flatten` with added support for `reversed`, `repr`, `in`, `.count()`, `.index()`, and `fmap`. -Additionally, `flatten` includes special support for [`numpy`](http://www.numpy.org/)/[`pandas`](https://pandas.pydata.org/)/[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects, in which case a multidimensional array is returned instead of an iterator. +Additionally, `flatten` includes special support for [`numpy`](#numpy-integration) objects, in which case a multidimensional array is returned instead of an iterator. Note that `flatten` only flattens the top level (first axis) of the given iterable/array. @@ -3254,7 +3258,7 @@ flat_it = iter_of_iters |> chain.from_iterable |> list Coconut provides an enhanced version of `itertools.product` as a built-in under the name `cartesian_product` with added support for `len`, `repr`, `in`, `.count()`, and `fmap`. -Additionally, `cartesian_product` includes special support for [`numpy`](http://www.numpy.org/)/[`pandas`](https://pandas.pydata.org/)/[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects, in which case a multidimensional array is returned instead of an iterator. +Additionally, `cartesian_product` includes special support for [`numpy`](#numpy-integration) objects, in which case a multidimensional array is returned instead of an iterator. ##### Python Docs @@ -3305,7 +3309,7 @@ assert list(product(v, v)) == [(1, 1), (1, 2), (2, 1), (2, 2)] Coconut's `multi_enumerate` enumerates through an iterable of iterables. `multi_enumerate` works like enumerate, but indexes through inner iterables and produces a tuple index representing the index in each inner iterable. Supports indexing. -For [`numpy`](http://www.numpy.org/)/[`pandas`](https://pandas.pydata.org/)/[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects, effectively equivalent to: +For [`numpy`](#numpy-integration) objects, effectively equivalent to: ```coconut_python def multi_enumerate(iterable): it = np.nditer(iterable, flags=["multi_index"]) @@ -3313,7 +3317,7 @@ def multi_enumerate(iterable): yield it.multi_index, x ``` -Also supports `len` for [`numpy`](http://www.numpy.org/)/[`pandas`](https://pandas.pydata.org/)/[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html). +Also supports `len` for [`numpy`](#numpy-integration). ##### Example @@ -3386,7 +3390,7 @@ for item in balance_data: **all\_equal**(_iterable_) -Coconut's `all_equal` built-in takes in an iterable and determines whether all of its elements are equal to each other. `all_equal` assumes transitivity of equality and that `!=` is the negation of `==`. Special support is provided for [`numpy`](http://www.numpy.org/)/[`pandas`](https://pandas.pydata.org/)/[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects. +Coconut's `all_equal` built-in takes in an iterable and determines whether all of its elements are equal to each other. `all_equal` assumes transitivity of equality and that `!=` is the negation of `==`. Special support is provided for [`numpy`](#numpy-integration) objects. ##### Example diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 750d4373b..1c5f048f4 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -3554,9 +3554,7 @@ def testlist_star_expr_handle(self, original, loc, tokens, is_list=False): groups, has_star, has_comma = self.split_star_expr_tokens(tokens) is_sequence = has_comma or is_list - if not is_sequence: - if has_star: - raise CoconutDeferredSyntaxError("can't use starred expression here", loc) + if not is_sequence and not has_star: self.internal_assert(len(groups) == 1 and len(groups[0]) == 1, original, loc, "invalid single-item testlist_star_expr tokens", tokens) out = groups[0][0] @@ -3565,7 +3563,7 @@ def testlist_star_expr_handle(self, original, loc, tokens, is_list=False): out = tuple_str_of(groups[0], add_parens=False) # naturally supported on 3.5+ - elif self.target_info >= (3, 5): + elif is_sequence and self.target_info >= (3, 5): to_literal = [] for g in groups: if isinstance(g, list): diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index 4094fd0ad..d0d50c916 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -934,6 +934,12 @@ class count(_coconut_base_hashable): class cycle(_coconut_has_iter): __slots__ = ("times",) def __new__(cls, iterable, times=None): + if times is not None: + if iterable.__class__.__module__ in _coconut.numpy_modules: + return _coconut.numpy.concatenate((iterable,) * times) + if iterable.__class__.__module__ in _coconut.jax_numpy_modules: + import jax.numpy as jnp + return jnp.concatenate((iterable,) * times) self = _coconut_has_iter.__new__(cls, iterable) self.times = times return self diff --git a/coconut/root.py b/coconut/root.py index 476d769fd..3233a1dbc 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "2.1.1" VERSION_NAME = "The Spanish Inquisition" # False for release, int >= 1 for develop -DEVELOP = 26 +DEVELOP = 27 ALPHA = False # for pre releases rather than post releases # ----------------------------------------------------------------------------------------------------------------------- diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index decde4f02..036f5197c 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -387,6 +387,7 @@ def test_numpy() -> bool: np.array([1, 1;; 1, 2;; 2, 1;; 2, 2]) ) # type: ignore assert flatten(np.array([1,2;;3,4])) `np.array_equal` np.array([1,2,3,4]) # type: ignore + assert cycle(np.array([1,2;;3,4]), 2) `np.array_equal` np.array([1,2;;3,4;;1,2;;3,4]) # type: ignore return True