diff --git a/sdks/python/apache_beam/runners/dask/dask_runner.py b/sdks/python/apache_beam/runners/dask/dask_runner.py index 0f2317074cea..cc17d9919b8e 100644 --- a/sdks/python/apache_beam/runners/dask/dask_runner.py +++ b/sdks/python/apache_beam/runners/dask/dask_runner.py @@ -22,6 +22,7 @@ scheduler. """ import argparse +import collections import dataclasses import typing as t @@ -143,8 +144,8 @@ def to_dask_bag_visitor() -> PipelineVisitor: @dataclasses.dataclass class DaskBagVisitor(PipelineVisitor): - bags: t.Dict[AppliedPTransform, - db.Bag] = dataclasses.field(default_factory=dict) + bags: t.Dict[AppliedPTransform, db.Bag] = dataclasses.field( + default_factory=collections.OrderedDict) def visit_transform(self, transform_node: AppliedPTransform) -> None: op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp) @@ -212,6 +213,10 @@ def run_pipeline(self, pipeline, options): dask_visitor = self.to_dask_bag_visitor() pipeline.visit(dask_visitor) - opt_graph = dask.optimize(*list(dask_visitor.bags.values())) + # The dictionary in this visitor keeps a mapping of every Beam + # PTransform to the equivalent Bag operation. This is highly + # redundant. Thus, we can get away with computing just the last + # value, which should be connected to the full Bag Task Graph. + opt_graph = dask.optimize(list(dask_visitor.bags.values())[-1]) futures = client.compute(opt_graph) return DaskRunnerResult(client, futures) diff --git a/sdks/python/apache_beam/runners/dask/transform_evaluator.py b/sdks/python/apache_beam/runners/dask/transform_evaluator.py index e3bd5fd87763..e72ebcce8b13 100644 --- a/sdks/python/apache_beam/runners/dask/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/dask/transform_evaluator.py @@ -19,8 +19,6 @@ A minimum set of operation substitutions, to adap Beam's PTransform model to Dask Bag functions. - -TODO(alxr): Translate ops from https://docs.dask.org/en/latest/bag-api.html. """ import abc import dataclasses