From c6c954072a964080c819f8b6c3fcadf7cd02446c Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Mon, 8 Aug 2022 13:02:57 -0700 Subject: [PATCH] Factor out handling of special nodes like optax.EmptyState. Clients should perform pre- and post-processing to do this on their end. This conceptually simplifies PyTreeCheckpointHandler and lessens the need to concern ourselves with the PyTree / state_dict distinction. PiperOrigin-RevId: 466125368 --- orbax/checkpoint/checkpointer_test_utils.py | 17 +++- orbax/checkpoint/pytree_checkpoint_handler.py | 83 +++++++++---------- orbax/checkpoint/test_utils.py | 37 +++++++++ 3 files changed, 91 insertions(+), 46 deletions(-) diff --git a/orbax/checkpoint/checkpointer_test_utils.py b/orbax/checkpoint/checkpointer_test_utils.py index 03200553..4be1944a 100644 --- a/orbax/checkpoint/checkpointer_test_utils.py +++ b/orbax/checkpoint/checkpointer_test_utils.py @@ -159,15 +159,26 @@ def init_state(): with Mesh(mesh.devices, mesh.axis_names): state = init() state_shape = jax.eval_shape(init) + processed_state, processed_state_shape = test_utils.preprocess_flax_pytree( + state, state, state_shape) restore_args = jax.tree_map( - lambda _: RestoreArgs(mesh=mesh, mesh_axes=mesh_axes), state_shape) + lambda _: RestoreArgs(mesh=mesh, mesh_axes=mesh_axes), + processed_state_shape) checkpointer = self.checkpointer(PyTreeCheckpointHandler()) - checkpointer.save(self.directory, state) + checkpointer.save(self.directory, processed_state) self.wait_if_async(checkpointer) # Already fully replicated, don't need to provide args. restored = checkpointer.restore( - self.directory, item=state_shape, restore_args=restore_args) + self.directory, item=processed_state_shape, restore_args=restore_args) + + test_utils.assert_tree_equal(self, processed_state['params'], + restored['params']) + test_utils.assert_tree_equal(self, processed_state['opt_state'], + restored['opt_state']) + + restored = test_utils.postprocess_flax_pytree(state, restored) + test_utils.assert_tree_equal(self, state.params, restored.params) test_utils.assert_tree_equal(self, state.opt_state, restored.opt_state) diff --git a/orbax/checkpoint/pytree_checkpoint_handler.py b/orbax/checkpoint/pytree_checkpoint_handler.py index 222afe16..a8e12408 100644 --- a/orbax/checkpoint/pytree_checkpoint_handler.py +++ b/orbax/checkpoint/pytree_checkpoint_handler.py @@ -271,9 +271,6 @@ async def _deserialize_array( param_info: ParamInfo) -> Union[ArrayOrScalar, GlobalDeviceArray]: """Writes an array (np.ndarray) or GlobalDeviceArray.""" tspec = param_info.tspec - if not epath.Path(tspec['kvstore']['path']).iterdir(): - # Empty dictionaries are written as directories containing no files. - return {} if restore_args.dtype is not None: tspec = { @@ -309,37 +306,38 @@ def __init__(self, enable_flax=True): self._structure = None self._enable_flax = enable_flax - def _get_param_names(self, state_dict: PyTree) -> PyTree: + def _get_param_names(self, item: PyTree) -> PyTree: """Gets parameter names for PyTree elements.""" + state_dict = utils.to_state_dict(item) names = traverse_util.unflatten_dict({ k: '.'.join(k) for k in traverse_util.flatten_dict(state_dict, keep_empty_nodes=True) }) - return names + r = flax.serialization.from_state_dict(item, names) + return r - def _get_param_infos(self, state_dict: PyTree, - directory: epath.Path) -> PyTree: - """Returns parameter information for elements in `state_dict`. + def _get_param_infos(self, item: PyTree, directory: epath.Path) -> PyTree: + """Returns parameter information for elements in `item`. At minimum, this method should extract the names of each parameter for saving/restoring. Args: - state_dict: a nested dict to extract information from. + item: a PyTree to extract information from. directory: a directory where checkpoint files are located. Returns: - A PyTree matching `state_dict` of ParamInfo. + A PyTree matching `item` of ParamInfo. """ - if not state_dict: - raise ValueError('Found empty state_dict') - names = self._get_param_names(state_dict) + if not item: + raise ValueError('Found empty item') + names = self._get_param_names(item) def _param_info(name, arr_or_spec): path = directory / name return ParamInfo(name, _get_tensorstore_spec(path, arr_or_spec)) - return jax.tree_map(_param_info, names, state_dict) + return jax.tree_map(_param_info, names, item) async def async_save( self, @@ -367,8 +365,6 @@ async def async_save( A Future that will commit the data to `directory` when awaited. Copying the data from its source will be awaited in this function. """ - # Convert arbitrary pytree into dictionary. - item = utils.to_state_dict(item) item = await lazy_array.maybe_get_tree_async(item) if save_args is None: @@ -424,7 +420,8 @@ def restore(self, directory: epath.Path, item: Optional[PyTree] = None, restore_args: Optional[PyTree] = None, - transforms: Optional[PyTree] = None) -> PyTree: + transforms: Optional[PyTree] = None, + transforms_default_to_original: bool = True) -> PyTree: """Restores a PyTree from the checkpoint directory at the given step. Optional arguments meshes and mesh_axes define how each array in the @@ -445,6 +442,7 @@ def restore(self, for further information. Note that if `transforms` is provided, all other arguments are structured relative to the saved object, not to the restored object post-transformation. + transforms_default_to_original: See transform_utils.apply_transformations. Returns: A PyTree matching the structure of `item` with restored array values as @@ -461,52 +459,51 @@ def restore(self, raise FileNotFoundError( f'Requested directory for restore does not exist at {directory}') - # Note: because of None values in `state_dict`, avoid jax.tree_map with it - # in first position. Use param_infos. - state_dict = self.structure(directory) - param_infos = self._get_param_infos(state_dict, directory) - - if transforms is not None: - if item is None: + structure = self.structure(directory) + param_infos = self._get_param_infos(structure, directory) + if item is None: + if transforms is not None: raise ValueError( 'If providing `transforms`, must provide `item` matching structure of expected result.' ) - if transform_utils.has_value_functions(transforms): - raise ValueError( - 'Found disallowed `value_fn` or `multi_value-fn` in `transforms`.') - state_dict = transform_utils.apply_transformations( - self.structure(directory), transforms, utils.to_state_dict(item)) - # param_infos must be transformed because the ts.Spec of saved params may - # correspond to the original structure rather than the new. - param_infos = transform_utils.apply_transformations( - param_infos, transforms, utils.to_state_dict(item)) + item = structure + else: + if transforms is None: + item = flax.serialization.from_state_dict(item, structure) + param_infos = flax.serialization.from_state_dict(item, param_infos) + else: + if transform_utils.has_value_functions(transforms): + raise ValueError( + 'Found disallowed `value_fn` or `multi_value_fn` in `transforms`.' + ) + item = transform_utils.apply_transformations( + structure, transforms, item, transforms_default_to_original) + # param_infos must be transformed because the ts.Spec of saved params + # may correspond to the original structure rather than the new. + param_infos = transform_utils.apply_transformations( + param_infos, transforms, item, transforms_default_to_original) if restore_args is None: - restore_args = jax.tree_map(lambda x: RestoreArgs(as_gda=False), - state_dict) - else: - restore_args = utils.to_state_dict(restore_args) + restore_args = jax.tree_map(lambda x: RestoreArgs(as_gda=False), item) future_arrays = jax.tree_map( _maybe_deserialize, restore_args, - state_dict, + item, param_infos, is_leaf=lambda args: isinstance(args, RestoreArgs) or not args) - future_arrays, state_dict_def = jax.tree_flatten(future_arrays) + future_arrays, item_def = jax.tree_flatten(future_arrays) async def _async_restore(future_arrays): return await asyncio.gather(*future_arrays) result = asyncio.run(_async_restore(future_arrays)) - restored_state_dict = jax.tree_unflatten(state_dict_def, result) - # Convert back into original object. - restored = flax.serialization.from_state_dict(item, restored_state_dict) + restored_item = jax.tree_unflatten(item_def, result) multihost_utils.sync_global_devices('PyTreeCheckpointHandler:restore') - return restored + return restored_item def structure(self, directory: epath.Path) -> PyTree: """Restores the PyTree structure from a Flax checkpoint. diff --git a/orbax/checkpoint/test_utils.py b/orbax/checkpoint/test_utils.py index 3a5ecb40..d2bd1de8 100644 --- a/orbax/checkpoint/test_utils.py +++ b/orbax/checkpoint/test_utils.py @@ -17,6 +17,7 @@ from typing import List, Optional from etils import epath +import flax.serialization from flax.training.train_state import TrainState import jax from jax.experimental import multihost_utils @@ -27,6 +28,7 @@ import numpy as np import optax from orbax.checkpoint import pytree_checkpoint_handler +from orbax.checkpoint import transform_utils def save_fake_tmp_dir(directory: epath.Path, @@ -164,3 +166,38 @@ def init_flax_model(model): tx = optax.adamw(learning_rate=0.001) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx) return jax.tree_map(np.asarray, state) + + +def _filter(mask_tree, *other_trees): + """Filter optax EmptyState and MaskedNode out of PyTree.""" + if not isinstance(mask_tree, dict): + return other_trees + result_trees = [{} for _ in other_trees] + for k, v in mask_tree.items(): + if not isinstance(v, (optax.EmptyState, optax.MaskedNode)): + values = _preprocess_helper(v, *(t[k] for t in other_trees)) + for i, v1 in enumerate(values): + if isinstance(v1, dict): + if v1: + result_trees[i][k] = v1 + else: + result_trees[i][k] = v1 + return tuple(result_trees) + + +def _preprocess_helper(mask_tree, *other_trees): + return jax.tree_map( + _filter, mask_tree, *other_trees, is_leaf=lambda x: isinstance(x, dict)) + + +def preprocess_flax_pytree(mask_tree, *other_trees): + mask_tree = flax.serialization.to_state_dict(mask_tree) + other_trees = (flax.serialization.to_state_dict(t) for t in other_trees) + r = _preprocess_helper(mask_tree, *other_trees) + return r[0] if len(r) == 1 else r + + +def postprocess_flax_pytree(reference, processed): + """Merge trees by taking values from processed or reference if key not present.""" + transforms = {} + return transform_utils.apply_transformations(processed, transforms, reference)