Skip to content

Commit

Permalink
Fix cartesian_product
Browse files Browse the repository at this point in the history
Resolves   #688.
  • Loading branch information
evhub committed Dec 30, 2022
1 parent b70b8a7 commit 626c1f4
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 14 deletions.
59 changes: 55 additions & 4 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,11 @@ To distribute your code with checkable type annotations, you'll need to include
To allow for better use of [`numpy`](https://numpy.org/) objects in Coconut, all compiled Coconut code will do a number of special things to better integrate with `numpy` (if `numpy` is available to import when the code is run). Specifically:

- Coconut's [multidimensional array literal and array concatenation syntax](#multidimensional-array-literalconcatenation-syntax) supports `numpy` objects, including using fast `numpy` concatenation methods if given `numpy` arrays rather than Coconut's default much slower implementation built for Python lists of lists.
- Coconut's [`multi_enumerate`](#multi_enumerate) built-in allows for easily looping over all the multi-dimensional indices in a `numpy` array.
- Coconut's [`all_equal`](#all_equal) built-in allows for easily checking if all the elements in a `numpy` array are the same.
- When a `numpy` object is passed to [`fmap`](#fmap), [`numpy.vectorize`](https://numpy.org/doc/stable/reference/generated/numpy.vectorize.html) is used instead of the default `fmap` implementation.
- Many of Coconut's built-ins include special `numpy` support, specifically:
* [`fmap`](#fmap) will use [`numpy.vectorize`](https://numpy.org/doc/stable/reference/generated/numpy.vectorize.html) to map over `numpy` arrays.
* [`multi_enumerate`](#multi_enumerate) allows for easily looping over all the multi-dimensional indices in a `numpy` array.
* [`cartesian_product`](#cartesian_product) can compute the Cartesian product of given `numpy` arrays as a `numpy` array.
* [`all_equal`](#all_equal) allows for easily checking if all the elements in a `numpy` array are the same.
- [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html) is registered as a [`collections.abc.Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence), enabling it to be used in [sequence patterns](#semantics-specification).
- Coconut supports `@` for matrix multiplication of `numpy` arrays on all Python versions, as well as supplying the `(@)` [operator function](#operator-functions).

Expand Down Expand Up @@ -3101,7 +3103,7 @@ for x in input_data:

### `flatten`

Coconut provides an enhanced version of `itertools.chain.from_iterable` as a built-in under the name `flatten` with added support for `reversed`, `repr`, `in`, `.count()`, `.index()`, and `fmap`.
Coconut provides an enhanced version of `itertools.chain.from_iterable` as a built-in under the name `flatten` with added support for `reversed`, `len`, `repr`, `in`, `.count()`, `.index()`, and `fmap`.

##### Python Docs

Expand Down Expand Up @@ -3132,6 +3134,55 @@ iter_of_iters = [[1, 2], [3, 4]]
flat_it = iter_of_iters |> chain.from_iterable |> list
```

### `cartesian_product`

Coconut provides an enhanced version of `itertools.product` as a built-in under the name `cartesian_product` with added support for `len`, `repr`, `in`, `.count()`, and `fmap`.

Additionally, `cartesian_product` includes special support for [`numpy`](http://www.numpy.org/)/[`pandas`](https://pandas.pydata.org/)/[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects, in which case a multidimensional array is returned instead of an iterator.

##### Python Docs

itertools.**product**(_\*iterables, repeat=1_)

Cartesian product of input iterables.

Roughly equivalent to nested for-loops in a generator expression. For example, `product(A, B)` returns the same as `((x,y) for x in A for y in B)`.

The nested loops cycle like an odometer with the rightmost element advancing on every iteration. This pattern creates a lexicographic ordering so that if the input’s iterables are sorted, the product tuples are emitted in sorted order.

To compute the product of an iterable with itself, specify the number of repetitions with the optional repeat keyword argument. For example, `product(A, repeat=4)` means the same as `product(A, A, A, A)`.

This function is roughly equivalent to the following code, except that the actual implementation does not build up intermediate results in memory:

```coconut_python
def product(*args, repeat=1):
# product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
# product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
pools = [tuple(pool) for pool in args] * repeat
result = [[]]
for pool in pools:
result = [x+[y] for x in result for y in pool]
for prod in result:
yield tuple(prod)
```

Before `product()` runs, it completely consumes the input iterables, keeping pools of values in memory to generate the products. Accordingly, it is only useful with finite inputs.

##### Example

**Coconut:**
```coconut
v = [1, 2]
assert cartesian_product(v, v) |> list == [(1, 1), (1, 2), (2, 1), (2, 2)]
```

**Python:**
```coconut_python
from itertools import product
v = [1, 2]
assert list(product(v, v)) == [(1, 1), (1, 2), (2, 1), (2, 2)]
```

### `multi_enumerate`

Coconut's `multi_enumerate` enumerates through an iterable of iterables. `multi_enumerate` works like enumerate, but indexes through inner iterables and produces a tuple index representing the index in each inner iterable. Supports indexing.
Expand Down
1 change: 1 addition & 0 deletions _coconut/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ property = property
range = range
reversed = reversed
set = set
setattr = setattr
slice = slice
str = str
sum = sum
Expand Down
22 changes: 16 additions & 6 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _coconut_super(type=None, object_or_type=None):
numpy_modules = {numpy_modules}
jax_numpy_modules = {jax_numpy_modules}
abc.Sequence.register(collections.deque)
Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bytes, classmethod, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, map, min, max, next, object, property, range, reversed, set, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bytes, classmethod, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, map, min, max, next, object, property, range, reversed, set, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray}
Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bytes, classmethod, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bytes, classmethod, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray}
class _coconut_sentinel{object}:
__slots__ = ()
class _coconut_base_hashable{object}:
Expand All @@ -43,8 +43,11 @@ class _coconut_base_hashable{object}:
return self.__class__ is other.__class__ and self.__reduce__() == other.__reduce__()
def __hash__(self):
return _coconut.hash(self.__reduce__())
def __bool__(self):
return True{COMMENT.avoids_expensive_len_calls}
def __bool__(self):{COMMENT.avoids_expensive_len_calls}
return True
def __setstate__(self, setvars):{COMMENT.fixes_unpickling_with_slots}
for k, v in setvars.items():
_coconut.setattr(self, k, v)
class MatchError(_coconut_base_hashable, Exception):
"""Pattern-matching error. Has attributes .pattern, .value, and .message."""
__slots__ = ("pattern", "value", "_message")
Expand Down Expand Up @@ -440,6 +443,9 @@ class flatten(_coconut_base_hashable):
def __contains__(self, elem):
self.iter, new_iter = _coconut_tee(self.iter)
return _coconut.any(elem in it for it in new_iter)
def __len__(self):
self.iter, new_iter = _coconut_tee(self.iter)
return _coconut.sum(_coconut.len(it) for it in new_iter)
def count(self, elem):
"""Count the number of times elem appears in the flattened iterable."""
self.iter, new_iter = _coconut_tee(self.iter)
Expand All @@ -466,11 +472,15 @@ Additionally supports Cartesian products of numpy arrays."""
if kwargs:
raise _coconut.TypeError("cartesian_product() got unexpected keyword arguments " + _coconut.repr(kwargs))
if iterables and _coconut.all(it.__class__.__module__ in _coconut.numpy_modules for it in iterables):
if _coconut.any(it.__class__.__module__ in _coconut.jax_numpy_modules for it in iterables):
from jax import numpy
else:
numpy = _coconut.numpy
iterables *= repeat
la = _coconut.len(iterables)
dtype = _coconut.numpy.result_type(*iterables)
arr = _coconut.numpy.empty([_coconut.len(a) for a in iterables] + [la], dtype=dtype)
for i, a in _coconut.enumerate(_coconut.numpy.ix_(*iterables)):
dtype = numpy.result_type(*iterables)
arr = numpy.empty([_coconut.len(a) for a in iterables] + [la], dtype=dtype)
for i, a in _coconut.enumerate(numpy.ix_(*iterables)):
arr[...,i] = a
return arr.reshape(-1, la)
self = _coconut.object.__new__(cls)
Expand Down
11 changes: 7 additions & 4 deletions coconut/tests/src/cocotest/agnostic/main.coco
Original file line number Diff line number Diff line change
Expand Up @@ -1235,10 +1235,13 @@ def main_test() -> bool:
\list = [1, 2, 3]
return \list
assert test_list() == list((1, 2, 3))
match def only_one(1) = 1
only_one.one = 1
assert only_one.one == 1
assert cartesian_product() |> list == [] == cartesian_product(repeat=10) |> list
match def one_or_two(1) = one_or_two.one
addpattern def one_or_two(2) = one_or_two.two # type: ignore
one_or_two.one = 10
one_or_two.two = 20
assert one_or_two(1) == 10
assert one_or_two(2) == 20
assert cartesian_product() |> list == [()] == cartesian_product(repeat=10) |> list
assert cartesian_product() |> len == 1 == cartesian_product(repeat=10) |> len
assert () in cartesian_product()
assert () in cartesian_product(repeat=10)
Expand Down

0 comments on commit 626c1f4

Please sign in to comment.