gh-113202: Add a strict option to itertools.batched() (gh-113203)
rhettinger authored Dec 16, 2023
1 parent fe479fb commit 1583c40
Showing 5 changed files with 60 additions and 24 deletions.
18 changes: 15 additions & 3 deletions Doc/library/itertools.rst
@@ -164,11 +164,14 @@ loops that truncate the stream.
Added the optional *initial* parameter.


-.. function:: batched(iterable, n)
+.. function:: batched(iterable, n, *, strict=False)

Batch data from the *iterable* into tuples of length *n*. The last
batch may be shorter than *n*.

+If *strict* is true, will raise a :exc:`ValueError` if the final
+batch is shorter than *n*.

Loops over the input iterable and accumulates data into tuples up to
size *n*. The input is consumed lazily, just enough to fill a batch.
The result is yielded as soon as the batch is full or when the input
@@ -190,16 +193,21 @@ loops that truncate the stream.

Roughly equivalent to::

-def batched(iterable, n):
+def batched(iterable, n, *, strict=False):
     # batched('ABCDEFG', 3) --> ABC DEF G
     if n < 1:
         raise ValueError('n must be at least one')
     it = iter(iterable)
     while batch := tuple(islice(it, n)):
+        if strict and len(batch) != n:
+            raise ValueError('batched(): incomplete batch')
         yield batch

.. versionadded:: 3.12

+.. versionchanged:: 3.13
+   Added the *strict* option.


.. function:: chain(*iterables)

@@ -1039,7 +1047,7 @@ The following recipes have a more mathematical flavor:
 def reshape(matrix, cols):
     "Reshape a 2-D matrix to have a given number of columns."
     # reshape([(0, 1), (2, 3), (4, 5)], 3) --> (0, 1, 2), (3, 4, 5)
-    return batched(chain.from_iterable(matrix), cols)
+    return batched(chain.from_iterable(matrix), cols, strict=True)

def transpose(matrix):
"Swap the rows and columns of a 2-D matrix."
@@ -1270,6 +1278,10 @@ The following recipes have a more mathematical flavor:
[(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)]
>>> list(reshape(M, 4))
 [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)]
+>>> list(reshape(M, 5))
+Traceback (most recent call last):
+...
+ValueError: batched(): incomplete batch
>>> list(reshape(M, 6))
[(0, 1, 2, 3, 4, 5), (6, 7, 8, 9, 10, 11)]
>>> list(reshape(M, 12))
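The effect of the `reshape` recipe change can be reproduced outside the doctests. The following is a minimal sketch, not the recipe as shipped: `strict_batched` is a hypothetical stand-in modeled on the "roughly equivalent" code in the documentation hunk (so the snippet also runs where the `strict` option is unavailable), the `strict` keyword on `reshape` is added here only to contrast the old and new behavior, and `M` is assumed sample data.

```python
from itertools import chain, islice


def strict_batched(iterable, n, *, strict=False):
    # Hypothetical stand-in modeled on the documented pure-Python equivalent.
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        if strict and len(batch) != n:
            raise ValueError('batched(): incomplete batch')
        yield batch


def reshape(matrix, cols, *, strict=True):
    "Reshape a 2-D matrix to have a given number of columns."
    # The strict keyword here is illustrative; the recipe hard-codes strict=True.
    return strict_batched(chain.from_iterable(matrix), cols, strict=strict)


# Assumed sample data: 12 values, so only divisors of 12 reshape cleanly.
M = [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)]

print(list(reshape(M, 6)))                # two full rows of six
print(list(reshape(M, 5, strict=False)))  # old behavior: ragged last row

try:
    list(reshape(M, 5))                   # new behavior: an error instead
except ValueError as exc:
    print('ValueError:', exc)
```

With `strict=False` the last row silently comes out shorter than the others, which is rarely what a reshape caller wants; passing `strict=True` turns that case into an explicit error.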
4 changes: 4 additions & 0 deletions Lib/test/test_itertools.py
@@ -187,7 +187,11 @@ def test_batched(self):
                          [('A', 'B'), ('C', 'D'), ('E', 'F'), ('G',)])
         self.assertEqual(list(batched('ABCDEFG', 1)),
                          [('A',), ('B',), ('C',), ('D',), ('E',), ('F',), ('G',)])
+        self.assertEqual(list(batched('ABCDEF', 2, strict=True)),
+                         [('A', 'B'), ('C', 'D'), ('E', 'F')])

+        with self.assertRaises(ValueError):  # Incomplete batch when strict
+            list(batched('ABCDEFG', 3, strict=True))
         with self.assertRaises(TypeError):  # Too few arguments
             list(batched('ABCDEFG'))
         with self.assertRaises(TypeError):
@@ -0,0 +1 @@
Add a ``strict`` option to ``batched()`` in the ``itertools`` module.
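Because the documentation hunk marks `batched()` as added in 3.12 and the *strict* option as added in 3.13, code that must run on several versions may want to detect the option at import time. The following is one possible sketch, not part of this commit: `_batched_fallback` and `_native_supports_strict` are hypothetical names, and the fallback is modeled on the documented pure-Python equivalent.

```python
import itertools
from itertools import islice


def _batched_fallback(iterable, n, *, strict=False):
    # Hypothetical fallback, modeled on the documented pure-Python equivalent.
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        if strict and len(batch) != n:
            raise ValueError('batched(): incomplete batch')
        yield batch


def _native_supports_strict():
    # batched() may be missing entirely, or present but without the keyword.
    native = getattr(itertools, 'batched', None)
    if native is None:
        return False
    try:
        native((), 1, strict=False)
    except TypeError:
        return False
    return True


# Prefer the built-in when it accepts strict=; otherwise use the fallback.
batched = itertools.batched if _native_supports_strict() else _batched_fallback

print(list(batched('ABCDEF', 2, strict=True)))
```

Probing with an actual call, rather than comparing version numbers, keeps the check tied to the behavior the caller depends on.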
32 changes: 23 additions & 9 deletions Modules/clinic/itertoolsmodule.c.h

Some generated files are not rendered by default.

29 changes: 17 additions & 12 deletions Modules/itertoolsmodule.c
@@ -105,27 +105,21 @@ class itertools.pairwise "pairwiseobject *" "clinic_state()->pairwise_type"

/* batched object ************************************************************/

-/* Note: The built-in zip() function includes a "strict" argument
-   that was needed because that function would silently truncate data,
-   and there was no easy way for a user to detect the data loss.
-   The same reasoning does not apply to batched() which never drops data.
-   Instead, batched() produces a shorter tuple which can be handled
-   as the user sees fit. If requested, it would be reasonable to add
-   "fillvalue" support which had demonstrated value in zip_longest().
-   For now, the API is kept simple and clean.
- */

 typedef struct {
     PyObject_HEAD
     PyObject *it;
     Py_ssize_t batch_size;
+    bool strict;
 } batchedobject;

/*[clinic input]
@classmethod
itertools.batched.__new__ as batched_new
     iterable: object
     n: Py_ssize_t
+    *
+    strict: bool = False
Batch data into tuples of length n. The last batch may be shorter than n.
Loops over the input iterable and accumulates data into tuples
@@ -140,11 +134,15 @@ or when the input iterable is exhausted.
('D', 'E', 'F')
('G',)
+If "strict" is True, raises a ValueError if the final batch is shorter
+than n.
[clinic start generated code]*/

static PyObject *
-batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
-/*[clinic end generated code: output=7ebc954d655371b6 input=ffd70726927c5129]*/
+batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n,
+                 int strict)
+/*[clinic end generated code: output=c6de11b061529d3e input=7814b47e222f5467]*/
{
PyObject *it;
batchedobject *bo;
@@ -170,6 +168,7 @@ batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
     }
     bo->batch_size = n;
     bo->it = it;
+    bo->strict = (bool) strict;
     return (PyObject *)bo;
 }

@@ -233,6 +232,12 @@ batched_next(batchedobject *bo)
         Py_DECREF(result);
         return NULL;
     }
+    if (bo->strict) {
+        Py_CLEAR(bo->it);
+        Py_DECREF(result);
+        PyErr_SetString(PyExc_ValueError, "batched(): incomplete batch");
+        return NULL;
+    }
     _PyTuple_Resize(&result, i);
     return result;
 }
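The state handling in the `batched_next` hunk can be easier to follow as a small Python model. This is an illustrative sketch only, not the C implementation: `BatchedModel` is a hypothetical class, the early `None` check assumes that a cleared `bo->it` makes the iterator behave as exhausted (that code is outside the hunk shown), and exceptions raised by the input iterable are not modeled.

```python
class BatchedModel:
    """Illustrative Python model of the C iterator state (hypothetical)."""

    def __init__(self, iterable, n, *, strict=False):
        if n < 1:
            raise ValueError('n must be at least one')
        self.it = iter(iterable)    # mirrors bo->it
        self.batch_size = n         # mirrors bo->batch_size
        self.strict = bool(strict)  # mirrors bo->strict

    def __iter__(self):
        return self

    def __next__(self):
        if self.it is None:         # assumed: a cleared iterator stays exhausted
            raise StopIteration
        result = []
        for item in self.it:
            result.append(item)
            if len(result) == self.batch_size:
                return tuple(result)
        # The input ran out before the batch was full.
        if not result:
            self.it = None          # nothing left at all
            raise StopIteration
        if self.strict:
            self.it = None          # mirrors Py_CLEAR(bo->it) in the new branch
            raise ValueError('batched(): incomplete batch')
        return tuple(result)        # non-strict: hand back the short batch


it = BatchedModel('ABCDEFG', 3, strict=True)
print(next(it), next(it))           # the two full batches come out first
try:
    next(it)
except ValueError as exc:
    print('ValueError:', exc)
print(list(it))                     # the model yields nothing after the error
```

In this model the error surfaces only when the short final batch is reached, so any full batches before it have already been delivered to the caller.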