diff --git a/src/dask_awkward/lib/inspect.py b/src/dask_awkward/lib/inspect.py
index 53c1ea8e..3f63fc22 100644
--- a/src/dask_awkward/lib/inspect.py
+++ b/src/dask_awkward/lib/inspect.py
@@ -177,36 +177,39 @@ def report_necessary_columns(
     seen_names = set()
     name_to_necessary_columns: dict[str, frozenset | None] = {}
 
-    for obj in collections:
-        dsk = obj.__dask_graph__()
-        keys = obj.__dask_keys__()
-        projection_data = o._prepare_buffer_projection(dsk, keys)
-
-        # If the projection failed, or there are no input layers
-        if projection_data is None:
-            # Ensure that we have a record of the seen layers, if they're inputs
-            for name, layer in dsk.items():
-                if isinstance(layer, AwkwardInputLayer):
-                    seen_names.add(name)
-            continue
-
-        # Unpack projection information
-        layer_to_reports, layer_to_projection_state = projection_data
-        for name, report in layer_to_reports.items():
-            layer = dsk.layers[name]
-            if not (isinstance(layer, AwkwardInputLayer) and layer.is_columnar):
+    with o.typetracer_nochecks():
+        for obj in collections:
+            dsk = obj.__dask_graph__()
+            keys = obj.__dask_keys__()
+            projection_data = o._prepare_buffer_projection(dsk, keys)
+
+            # If the projection failed, or there are no input layers
+            if projection_data is None:
+                # Ensure that we have a record of the seen layers, if they're inputs
+                for name, layer in dsk.items():
+                    if isinstance(layer, AwkwardInputLayer):
+                        seen_names.add(name)
                 continue
 
-            existing_columns = name_to_necessary_columns.setdefault(name, frozenset())
+            # Unpack projection information
+            layer_to_reports, layer_to_projection_state = projection_data
+            for name, report in layer_to_reports.items():
+                layer = dsk.layers[name]
+                if not (isinstance(layer, AwkwardInputLayer) and layer.is_columnar):
+                    continue
 
-            assert existing_columns is not None
-            # Update set of touched keys
-            name_to_necessary_columns[name] = (
-                existing_columns
-                | layer.necessary_columns(
-                    report=report, state=layer_to_projection_state[name]
+                existing_columns = name_to_necessary_columns.setdefault(
+                    name, frozenset()
+                )
+
+                assert existing_columns is not None
+                # Update set of touched keys
+                name_to_necessary_columns[name] = (
+                    existing_columns
+                    | layer.necessary_columns(
+                        report=report, state=layer_to_projection_state[name]
+                    )
                 )
-            )
 
     # Populate result with names of seen layers
     for k in seen_names:
diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py
index 41612edd..a1ab9d6a 100644
--- a/src/dask_awkward/lib/optimize.py
+++ b/src/dask_awkward/lib/optimize.py
@@ -14,6 +14,7 @@
 from dask.local import get_sync
 
 from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardInputLayer
+from dask_awkward.lib.utils import typetracer_nochecks
 from dask_awkward.utils import first
 
 if TYPE_CHECKING:
@@ -45,7 +46,8 @@ def all_optimizations(dsk: Mapping, keys: Sequence[Key], **_: Any) -> Mapping:
         dsk = HighLevelGraph.from_collections(str(id(dsk)), dsk, dependencies=())
 
     # Perform dask-awkward specific optimizations.
-    dsk = optimize(dsk, keys=keys)
+    with typetracer_nochecks():
+        dsk = optimize(dsk, keys=keys)
     # Perform Blockwise optimizations for HLG input
     dsk = optimize_blockwise(dsk, keys=keys)
     # fuse nearby layers
diff --git a/src/dask_awkward/lib/utils.py b/src/dask_awkward/lib/utils.py
index a25b3dd5..7b067386 100644
--- a/src/dask_awkward/lib/utils.py
+++ b/src/dask_awkward/lib/utils.py
@@ -3,6 +3,7 @@
 __all__ = ("trace_form_structure", "buffer_keys_required_to_compute_shapes")
 
 from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, TypedDict, TypeVar
 
 import awkward as ak
@@ -164,3 +165,23 @@ def impl(form: Form, key: str) -> None:
     form = ak.forms.from_dict(form.to_dict())
     impl(form, key)
     return form
+
+
+@contextmanager
+def typetracer_nochecks():
+    """Temporarily set ``TypeTracerArray.runtime_typechecks`` to ``False``.
+
+    The previous value (or the absence of the attribute) is restored on
+    exit, including when the managed block raises.
+    """
+    from awkward._nplikes.typetracer import TypeTracerArray
+
+    oldval = getattr(TypeTracerArray, "runtime_typechecks", None)
+    TypeTracerArray.runtime_typechecks = False
+    try:
+        yield
+    finally:
+        if oldval is not None:
+            TypeTracerArray.runtime_typechecks = oldval
+        else:
+            del TypeTracerArray.runtime_typechecks
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 3d31a964..3d1b38c7 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -2,6 +2,7 @@
 
 import pytest
 
+from dask_awkward.lib.utils import typetracer_nochecks
 from dask_awkward.utils import (
     LazyInputsDict,
     field_access_to_front,
@@ -78,3 +79,16 @@ def test_field_access_to_front(pairs):
     res = field_access_to_front(pairs[0])
     assert res[0] == pairs[1]
     assert res[1] == pairs[2]
+
+
+def test_nocheck_context():
+    from awkward._nplikes.typetracer import TypeTracerArray
+
+    assert getattr(TypeTracerArray, "runtime_typechecks", True)
+    with typetracer_nochecks():
+        assert not TypeTracerArray.runtime_typechecks
+    assert getattr(TypeTracerArray, "runtime_typechecks", True)
+    with pytest.raises(RuntimeError):
+        with typetracer_nochecks():
+            raise RuntimeError
+    assert getattr(TypeTracerArray, "runtime_typechecks", True)