Run a significantly smaller Dask Graph. (apache#33806)
* Run a significantly smaller Dask Graph.

`dask.optimize()` isn't a cure-all. It won't remove tasks in a highly duplicated graph, like the one we collect in the DaskRunner's visitor. The fix, however, is quite simple: we can get away with having the Dask distributed client compute only the last value in the translated operation graph.

Even with this small change, all tests pass. Further, it should lead to a significant performance improvement.

* Fixing all run_pylint.sh issues.

* *shakes fist at `isort`*.
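
To make the duplication argument above concrete, here is a minimal, dependency-free sketch. It is a toy model, not Dask: `apply_step` and the dict-based "bags" are hypothetical stand-ins, and the sketch does not reproduce how Dask merges or optimizes graphs. It only illustrates that when every recorded entry carries its own upstream tasks, the last entry of a linear pipeline already covers all of them.

```python
from collections import OrderedDict


def apply_step(name, upstream=None):
    """Hypothetical stand-in for a Bag operation.

    Each "bag" carries the whole task graph needed to compute it,
    stored as a dict of task name -> tuple of upstream task names.
    """
    graph = dict(upstream["graph"]) if upstream else {}
    graph[name] = (upstream["name"],) if upstream else ()
    return {"name": name, "graph": graph}


# Mimic the visitor: one entry per transform, recorded in visit order.
bags = OrderedDict()
bags["Create"] = apply_step("create")
bags["Map"] = apply_step("map", bags["Create"])
bags["Filter"] = apply_step("filter", bags["Map"])

# Handing over every entry repeats the shared upstream tasks per entry.
tasks_all = sum(len(bag["graph"]) for bag in bags.values())

# Handing over only the last entry still reaches every task, once each.
last = list(bags.values())[-1]
tasks_last = len(last["graph"])

# Every task that appears in any entry also appears in the last entry.
covered = set().union(*(bag["graph"] for bag in bags.values()))

print(tasks_all, tasks_last, covered == set(last["graph"]))
```

In this toy model the coverage holds because the pipeline is linear; the commit relies on the same property, namely that the last value is connected to the full Bag task graph.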
alxmrs authored Jan 30, 2025
1 parent bbe6394 commit ebd8898
Showing 2 changed files with 8 additions and 5 deletions.
11 changes: 8 additions & 3 deletions sdks/python/apache_beam/runners/dask/dask_runner.py
@@ -22,6 +22,7 @@
 scheduler.
 """
 import argparse
+import collections
 import dataclasses
 import typing as t

@@ -143,8 +144,8 @@ def to_dask_bag_visitor() -> PipelineVisitor:

     @dataclasses.dataclass
     class DaskBagVisitor(PipelineVisitor):
-      bags: t.Dict[AppliedPTransform,
-                   db.Bag] = dataclasses.field(default_factory=dict)
+      bags: t.Dict[AppliedPTransform, db.Bag] = dataclasses.field(
+          default_factory=collections.OrderedDict)

       def visit_transform(self, transform_node: AppliedPTransform) -> None:
         op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
@@ -212,6 +213,10 @@ def run_pipeline(self, pipeline, options):

     dask_visitor = self.to_dask_bag_visitor()
     pipeline.visit(dask_visitor)
-    opt_graph = dask.optimize(*list(dask_visitor.bags.values()))
+    # The dictionary in this visitor keeps a mapping of every Beam
+    # PTransform to the equivalent Bag operation. This is highly
+    # redundant. Thus, we can get away with computing just the last
+    # value, which should be connected to the full Bag Task Graph.
+    opt_graph = dask.optimize(list(dask_visitor.bags.values())[-1])
     futures = client.compute(opt_graph)
     return DaskRunnerResult(client, futures)
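
The change above depends on the visitor's mapping keeping visit order, so that the last value corresponds to the last transform visited. A small sketch of that pattern follows, assuming nothing about Beam or Dask: `RecordingVisitor` and its string values are hypothetical stand-ins for `DaskBagVisitor` and its Bags.

```python
import collections
import dataclasses
import typing as t


@dataclasses.dataclass
class RecordingVisitor:
    # Hypothetical stand-in for DaskBagVisitor: values are plain strings
    # instead of Dask Bags. default_factory gives each instance its own
    # OrderedDict rather than one shared mutable default.
    bags: t.Dict[str, str] = dataclasses.field(
        default_factory=collections.OrderedDict)

    def visit(self, name: str) -> None:
        self.bags[name] = f"bag<{name}>"


visitor = RecordingVisitor()
for step in ["Create", "Map", "Filter"]:
    visitor.visit(step)

# Same expression shape as in the diff: materialize the values, take the last.
last_via_list = list(visitor.bags.values())[-1]

# An alternative that avoids building the intermediate list.
last_via_reversed = next(reversed(visitor.bags.values()))

print(last_via_list, last_via_reversed)
```

Both expressions pick the entry for the last visited step. The list-based form is the one the commit uses; the `reversed` form is only an illustrative alternative.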
2 changes: 0 additions & 2 deletions sdks/python/apache_beam/runners/dask/transform_evaluator.py
@@ -19,8 +19,6 @@
A minimum set of operation substitutions, to adap Beam's PTransform model
to Dask Bag functions.
TODO(alxr): Translate ops from https://docs.dask.org/en/latest/bag-api.html.
"""
import abc
import dataclasses