Skip to content

Commit

Permalink
Make jax dependency optional.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712534349
  • Loading branch information
iindyk authored and copybara-github committed Jan 6, 2025
1 parent 20f1174 commit 3099aec
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: tests
name: Build & Test

on:
push:
Expand Down
1 change: 0 additions & 1 deletion grain/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ py_library(
deps = [
"//grain/_src/core:config", # build_cleaner: keep
"//grain/_src/core:constants", # build_cleaner: keep
"//grain/_src/core:grain_random", # build_cleaner: keep
"//grain/_src/core:sharding", # build_cleaner: keep
],
)
Expand Down
4 changes: 4 additions & 0 deletions grain/_src/core/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def assert_same_structure(a, b):
def flatten(structure):
  """Returns the flattened leaves of `structure`.

  `tree_util.tree_flatten` returns the leaves first, followed by the tree
  definition; only the leaves are kept here.
  """
  flattened = tree_util.tree_flatten(structure)
  return flattened[0]

def flatten_with_path(structure):
  """Returns the `(path, leaf)` entries of `structure`.

  Only the first element of `tree_util.tree_flatten_with_path`'s result is
  kept; the accompanying tree definition is dropped.
  """
  flattened = tree_util.tree_flatten_with_path(structure)
  return flattened[0]

def unflatten_as(structure, flat_sequence):
return tree_util.tree_unflatten(
tree_util.tree_structure(structure), flat_sequence
Expand Down Expand Up @@ -132,6 +135,7 @@ def _shape(obj):
map_structure_with_path = tree.map_structure_with_path
assert_same_structure = tree.assert_same_structure
flatten = tree.flatten
flatten_with_path = tree.flatten_with_path
unflatten_as = tree.unflatten_as

def spec_like(structure):
Expand Down
9 changes: 9 additions & 0 deletions grain/_src/core/tree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def assert_same_structure(self, a, b):
def flatten(self, structure):
...

def flatten_with_path(self, structure):
...

def unflatten_as(self, structure, flat_sequence):
...

Expand Down Expand Up @@ -85,6 +88,12 @@ def test_assert_same_structure(self):
def test_flatten(self):
  # Flattening a dict yields its values as a plain list.
  leaves = tree.flatten({"A": "v2", "B": "v1"})
  self.assertEqual(leaves, ["v2", "v1"])

def test_flatten_with_path(self):
  flat = tree.flatten_with_path({"A": "v2", "B": "v1"})

  def _unwrap_key(node):
    # Path entries exposing a `.key` attribute are replaced by that key;
    # anything else (plain keys, leaf values) passes through unchanged.
    return getattr(node, "key", node)

  normalized = tree.map_structure(_unwrap_key, flat)
  self.assertEqual(normalized, [(("A",), "v2"), (("B",), "v1")])

def test_unflatten_as(self):
self.assertEqual(
tree.unflatten_as({"A": "v2", "B": "v1"}, [1, 2]), {"A": 1, "B": 2}
Expand Down
22 changes: 13 additions & 9 deletions grain/_src/python/checkpoint_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,21 @@
from etils import epath
from grain._src.python import data_loader
from grain._src.python.dataset import dataset
import jax

IteratorType = TypeVar(
"IteratorType", data_loader.PyGrainDatasetIterator, dataset.DatasetIterator
)


def _get_process_index_and_count():
try:
import jax # pylint:disable=g-import-not-at-top # pytype:disable=import-error

return jax.process_index(), jax.process_count()
except ImportError:
return 0, 1


# Implements orbax.checkpoint.CheckpointHandler.
class PyGrainCheckpointHandler:
"""Orbax CheckpointHandler for PyGrain iterators."""
Expand All @@ -44,10 +52,8 @@ def save(
state = json.dumps(item.get_state(), indent=4)
else:
state = item.get_state().decode()
filename = (
directory
/ f"process_{jax.process_index()}-of-{jax.process_count()}.json"
)
process_index, process_count = _get_process_index_and_count()
filename = directory / f"process_{process_index}-of-{process_count}.json"
filename.write_text(state)

def restore(
Expand All @@ -58,10 +64,8 @@ def restore(
) -> IteratorType:
"""Restores the given iterator from the checkpoint in `directory`."""
item = item or args.item # pytype:disable=attribute-error
filename = (
directory
/ f"process_{jax.process_index()}-of-{jax.process_count()}.json"
)
process_index, process_count = _get_process_index_and_count()
filename = directory / f"process_{process_index}-of-{process_count}.json"
if not filename.exists():
raise ValueError(f"File {filename} does not exist.")
state = filename.read_text()
Expand Down
1 change: 1 addition & 0 deletions grain/_src/python/dataset/transformations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ py_library(
srcs = ["packing_packed_batch.py"],
srcs_version = "PY3",
deps = [
"//grain/_src/core:tree",
],
)

Expand Down
37 changes: 26 additions & 11 deletions grain/_src/python/dataset/transformations/packing_packed_batch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides a helper class for multi-bin first-fit packing.
Example packing is a step in many input pipelines for sequence to sequence
Expand All @@ -20,10 +33,9 @@
import dataclasses
from typing import Generic, TypeVar

import jax
from grain._src.core import tree
import jaxtyping as jt
import numpy as np
import tree


_T = TypeVar("_T")
Expand Down Expand Up @@ -110,19 +122,19 @@ def make_packed_buffer(length: int, x: np.ndarray | int):
dtype=dtype,
)

self._values = jax.tree.map(
self._values = tree.map_structure(
make_packed_buffer, length_struct, element_for_shapes
)

def make_packed_aux_info(length: int):
return zeros(shape=(num_packing_bins, length), dtype=np.int32)

self._segment_ids = jax.tree.map(make_packed_aux_info, length_struct)
self._positions = jax.tree.map(make_packed_aux_info, length_struct)
self._segment_ids = tree.map_structure(make_packed_aux_info, length_struct)
self._positions = tree.map_structure(make_packed_aux_info, length_struct)

# Tracks the next empty position to insert an example for each row
# in the batch, for each feature in features_to_pack.
self._first_free_cell_per_row = jax.tree.map(
self._first_free_cell_per_row = tree.map_structure(
lambda _: zeros(num_packing_bins, dtype=np.int64), length_struct
)

Expand All @@ -131,14 +143,17 @@ def make_packed_aux_info(length: int):
self._num_examples_per_row = [0 for _ in range(num_packing_bins)]

def get_packed_batch(self):
"""Returns the current packed batch."""
rows_with_values = sum(x > 0 for x in self._num_examples_per_row)
if rows_with_values < len(self._num_examples_per_row):
# Partial batch, last rows don't have values.
self._values = jax.tree.map(lambda x: x[:rows_with_values], self._values)
self._segment_ids = jax.tree.map(
self._values = tree.map_structure(
lambda x: x[:rows_with_values], self._values
)
self._segment_ids = tree.map_structure(
lambda x: x[:rows_with_values], self._segment_ids
)
self._positions = jax.tree.map(
self._positions = tree.map_structure(
lambda x: x[:rows_with_values], self._positions
)
return _extract_and_rekey_packed_batch(
Expand All @@ -163,7 +178,7 @@ def _can_add_at_row(
"""
tree.assert_same_structure(element, self._length_struct)

element_feature_lengths = jax.tree.map(
element_feature_lengths = tree.map_structure(
lambda x: 1 if np.ndim(x) == 0 else len(x), element
)

Expand Down Expand Up @@ -191,7 +206,7 @@ def _feature_will_fit(feature_length, first_free_cell, max_length):
return feature_length + first_free_cell <= max_length

is_row_free_struct = tree.flatten_with_path(
jax.tree.map(
tree.map_structure(
_feature_will_fit,
element_feature_lengths,
self._first_free_cell_per_row,
Expand Down
17 changes: 9 additions & 8 deletions grain/_src/python/experimental/example_packing/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from absl import logging
from grain._src.core import tree
from grain._src.python import record
import jax
import jaxtyping as jt
import numpy as np

Expand All @@ -46,19 +45,21 @@ def make_packed_buffer(length: int, input_arr: np.ndarray):
dtype=input_arr.dtype,
)

self._batch = jax.tree.map(
self._batch = tree.map_structure(
make_packed_buffer, length_struct, element_for_shapes
)

def make_packed_aux_info(length: int):
return np.zeros(shape=(batch_size, length), dtype=np.int32)

self._segmentations = jax.tree.map(make_packed_aux_info, length_struct)
self._positions = jax.tree.map(make_packed_aux_info, length_struct)
self._segmentations = tree.map_structure(
make_packed_aux_info, length_struct
)
self._positions = tree.map_structure(make_packed_aux_info, length_struct)

# Tracks the next empty position to insert an example for each row
# in the batch, for each feature in features_to_pack.
self._first_free_cell_per_row = jax.tree.map(
self._first_free_cell_per_row = tree.map_structure(
lambda _: np.zeros(batch_size, dtype=np.int32), length_struct
)

Expand All @@ -79,10 +80,10 @@ def get_packed_batch(self) -> record.Record[tuple[_T, _T, _T]]:

def _can_add_at_row(self, element: jt.PyTree[np.ndarray]) -> int:
"""Returns the index of the first row which fits element, or -1 if none."""
element_feature_lengths = jax.tree.map(len, element)
element_feature_lengths = tree.map_structure(len, element)

# Check no feature exceeds max length
length_exceeded = jax.tree.map(
length_exceeded = tree.map_structure(
lambda feature_length, max_length: feature_length > max_length,
element_feature_lengths,
self._length_struct,
Expand All @@ -97,7 +98,7 @@ def _can_add_at_row(self, element: jt.PyTree[np.ndarray]) -> int:
def _feature_will_fit(feature_length, first_free_cell, max_length):
return feature_length + first_free_cell <= max_length

is_row_free_struct = jax.tree.map(
is_row_free_struct = tree.map_structure(
_feature_will_fit,
element_feature_lengths,
self._first_free_cell_per_row,
Expand Down
12 changes: 10 additions & 2 deletions grain/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,16 @@

# pylint: disable=g-multiple-import
# pylint: disable=unused-import
# pylint: disable=g-importing-member

from ._src.core import grain_random as random
from ._src.core.config import config
from ._src.core.constants import DATASET_INDEX, EPOCH, INDEX, META_FEATURES, RECORD, RECORD_KEY, SEED
from ._src.core.constants import (
DATASET_INDEX,
EPOCH,
INDEX,
META_FEATURES,
RECORD,
RECORD_KEY,
SEED,
)
from ._src.core.sharding import NoSharding, ShardByJaxProcess, ShardOptions

0 comments on commit 3099aec

Please sign in to comment.