Skip to content

Commit

Permalink
Internal Change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715526271
  • Loading branch information
Grain Team authored and copybara-github committed Jan 14, 2025
1 parent 009cc09 commit 1fab87c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
7 changes: 7 additions & 0 deletions grain/_src/python/dataset/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def _pretty_format_summary(
# the visualization graph.
col_names.remove("output_spec")
col_names.remove("is_output")
# TODO: Add a column for `is_prefetch` in the logged execution
# summary.
col_names.remove("wait_time_ratio")
col_names.remove("is_prefetch")
# Insert the average processing time column after the max processing time
# column.
index = col_names.index("max_processing_time_ns")
Expand Down Expand Up @@ -273,6 +277,8 @@ class StatsConfig:
# Whether this transformation mutates the element spec. This is used to
# determine element spec of the current transformation.
transform_mutates_spec: bool = True
# Whether this transformation is a prefetch transformation.
is_prefetch: bool = False
# Whether to log the execution summary.
log_summary: bool = False

Expand Down Expand Up @@ -539,6 +545,7 @@ def _build_execution_summary(
self._summary.name = self._config.name
self._summary.output_spec = str(self.output_spec)
self._summary.is_output = self._is_output
self._summary.is_prefetch = self._config.is_prefetch
execution_summary.nodes.get_or_create(node_id)
execution_summary.nodes[node_id].CopyFrom(self._summary)
current_node_id = node_id
Expand Down
16 changes: 16 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def _stats(self):
dataset_stats.StatsConfig(
name=str(self),
transform_mutates_spec=self._MUTATES_ELEMENT_SPEC,
is_prefetch=True,
),
(parent_stats,),
execution_tracking_mode,
Expand Down Expand Up @@ -452,6 +453,21 @@ def __init__(
_LAST_WORKER_INDEX: -1,
}

@functools.cached_property
def _stats(self):
config = dataset_stats.StatsConfig(
name=str(self),
transform_mutates_spec=self._MUTATES_ELEMENT_SPEC,
is_prefetch=True,
)
return dataset_stats.make_stats(
config,
[],
execution_tracking_mode=(
self._ctx.dataset_options.execution_tracking_mode
),
)

def __iter__(self) -> dataset.DatasetIterator[T]:
return self

Expand Down

0 comments on commit 1fab87c

Please sign in to comment.