diff --git a/test/test_graph.py b/test/test_graph.py index b8de4189e..7a94ba004 100644 --- a/test/test_graph.py +++ b/test/test_graph.py @@ -51,14 +51,26 @@ def finalize_iteration(self) -> None: dp.started = False +def _x_and_x_plus_5(x): + return [x, x + 5] + + +def _x_mod_2(x): + return x % 2 + + +def _x_mult_2(x): + return x * 2 + + class TestGraph(expecttest.TestCase): def _get_datapipes(self) -> Tuple[IterDataPipe, IterDataPipe, IterDataPipe]: src_dp = IterableWrapper(range(20)) - m1 = src_dp.map(lambda x: [x, x + 5]) + m1 = src_dp.map(_x_and_x_plus_5) ub = m1.unbatch() - c1, c2 = ub.demux(2, lambda x: x % 2) + c1, c2 = ub.demux(2, _x_mod_2) dm = c1.main_datapipe - m2 = c1.map(lambda x: x * 2) + m2 = c1.map(_x_mult_2) dp = m2.zip(c2) return traverse(dp, only_datapipe=True), (src_dp, m1, ub, dm, c1, c2, m2, dp)