Skip to content

Commit

Permalink
docs: more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Dec 18, 2024
1 parent aa76539 commit 49983e4
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,17 @@ def optimize(dsk: HighLevelGraph, keys: Sequence[Key], **_: Any) -> Mapping:
def _prepare_buffer_projection(
dsk: HighLevelGraph, keys: Sequence[Key]
) -> tuple[dict[str, TypeTracerReport], dict[str, Any]] | None:
"""Pair layer names with lists of necessary columns."""
import awkward as ak
"""Prepare for buffer projection by building and evaluating a version of the graph that
has been annotated for typetracer reporting.
Parameters
----------
dsk : HighLevelGraph
Task graph to optimize.
keys : list[str]
Sequence of keys to optimize with respect to.
"""
# Skip early if we can't meaningfully optimise any layers
if not _has_projectable_awkward_io_layer(dsk):
return None

Expand All @@ -97,39 +105,50 @@ def _prepare_buffer_projection(

for name, lay in dsk.layers.items():
if isinstance(lay, AwkwardInputLayer):
# The layer supports buffer projection
if lay.is_projectable:
# Insert mocked array into layers, replacing generation func
# Keep track of mocked state
# Replace input layer with one that is ready for input projection using a report
# Store the report for subsequent retrieval, and cache the transient state
# that column projection later needs to finalise the optimisation
(
projection_layers[name],
layer_to_reports[name],
layer_to_projection_state[name],
) = lay.prepare_for_projection()
# Layers that don't support buffer projection might support mocking
# This means that we at least do not have to compute them during evaluation of the optimisation graph
elif lay.is_mockable:
projection_layers[name] = lay.mock()
# Layers that don't support buffer projection might support mocking
# This means that we at least do not have to compute them during evaluation of the optimisation graph
elif hasattr(lay, "mock"):
projection_layers[name] = lay.mock()

# Ensure that the buffers of each output are entirely touched
for name in _ak_output_layer_names(dsk):
projection_layers[name] = _mock_output(projection_layers[name])

hlg = HighLevelGraph(projection_layers, dsk.dependencies)

# The caller should apply this optimisation with respect to a number of output keys
minimal_keys: set[Key] = set()
for k in keys:
if isinstance(k, tuple) and len(k) == 2:
minimal_keys.add((k[0], 0))
else:
minimal_keys.add(k)

# now we try to compute for each possible output layer key (leaf
# Now we try to compute for each possible output layer key (leaf
# node on partition 0); this will cause the typetacer reports to
# get correct fields/columns touched. If the result is a record or
# an array we of course want to touch all of the data/fields.
try:
for layer in hlg.layers.values():
layer.__dict__.pop("_cached_dict", None)

results = get_sync(hlg, list(minimal_keys))

# Touch all the buffers associated with the given output keys
for out in results:
if isinstance(out, (ak.Array, ak.Record)):
touch_data(out)
Expand Down Expand Up @@ -185,14 +204,15 @@ def optimize_columns(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelGraph
New, optimized task graph with column-projected ``AwkwardInputLayer``.
"""
# 1. Build-and-evaluate typetracer-annotated graph
projection_data = _prepare_buffer_projection(dsk, keys)
if projection_data is None:
return dsk

# Unpack result
# 2. Unpack result
layer_to_reports, layer_to_projection_state = projection_data

# Project layers using projection state
# 3. Project layers using projection state from (1)
layers = dict(dsk.layers)
for name, state in layer_to_projection_state.items():
layers[name] = cast(AwkwardInputLayer, layers[name]).project(
Expand Down Expand Up @@ -432,9 +452,3 @@ def _recursive_replace(args, layer, parent, indices):
else:
args2.append(arg)
return args2


def _buffer_keys_for_layer(
buffer_keys: Iterable[str], known_buffer_keys: frozenset[str]
) -> set[str]:
return {k for k in buffer_keys if k in known_buffer_keys}

0 comments on commit 49983e4

Please sign in to comment.