Skip to content

Commit

Permalink
Factor out handling of special nodes like optax.EmptyState. Clients s…
Browse files Browse the repository at this point in the history
…hould perform pre- and post-processing to do this on their end. This conceptually simplifies PyTreeCheckpointHandler and lessens the need to concern ourselves with the PyTree / state_dict distinction.

PiperOrigin-RevId: 466125368
  • Loading branch information
cpgaffney1 authored and copybara-github committed Aug 8, 2022
1 parent d45ad9f commit c6c9540
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 46 deletions.
17 changes: 14 additions & 3 deletions orbax/checkpoint/checkpointer_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,26 @@ def init_state():
with Mesh(mesh.devices, mesh.axis_names):
state = init()
state_shape = jax.eval_shape(init)
processed_state, processed_state_shape = test_utils.preprocess_flax_pytree(
state, state, state_shape)

restore_args = jax.tree_map(
lambda _: RestoreArgs(mesh=mesh, mesh_axes=mesh_axes), state_shape)
lambda _: RestoreArgs(mesh=mesh, mesh_axes=mesh_axes),
processed_state_shape)

checkpointer = self.checkpointer(PyTreeCheckpointHandler())
checkpointer.save(self.directory, state)
checkpointer.save(self.directory, processed_state)
self.wait_if_async(checkpointer)
# Already fully replicated, don't need to provide args.
restored = checkpointer.restore(
self.directory, item=state_shape, restore_args=restore_args)
self.directory, item=processed_state_shape, restore_args=restore_args)

test_utils.assert_tree_equal(self, processed_state['params'],
restored['params'])
test_utils.assert_tree_equal(self, processed_state['opt_state'],
restored['opt_state'])

restored = test_utils.postprocess_flax_pytree(state, restored)

test_utils.assert_tree_equal(self, state.params, restored.params)
test_utils.assert_tree_equal(self, state.opt_state, restored.opt_state)
83 changes: 40 additions & 43 deletions orbax/checkpoint/pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,6 @@ async def _deserialize_array(
param_info: ParamInfo) -> Union[ArrayOrScalar, GlobalDeviceArray]:
"""Writes an array (np.ndarray) or GlobalDeviceArray."""
tspec = param_info.tspec
if not epath.Path(tspec['kvstore']['path']).iterdir():
# Empty dictionaries are written as directories containing no files.
return {}

if restore_args.dtype is not None:
tspec = {
Expand Down Expand Up @@ -309,37 +306,38 @@ def __init__(self, enable_flax=True):
self._structure = None
self._enable_flax = enable_flax

def _get_param_names(self, state_dict: PyTree) -> PyTree:
def _get_param_names(self, item: PyTree) -> PyTree:
"""Gets parameter names for PyTree elements."""
state_dict = utils.to_state_dict(item)
names = traverse_util.unflatten_dict({
k: '.'.join(k)
for k in traverse_util.flatten_dict(state_dict, keep_empty_nodes=True)
})
return names
r = flax.serialization.from_state_dict(item, names)
return r

def _get_param_infos(self, state_dict: PyTree,
directory: epath.Path) -> PyTree:
"""Returns parameter information for elements in `state_dict`.
def _get_param_infos(self, item: PyTree, directory: epath.Path) -> PyTree:
"""Returns parameter information for elements in `item`.
At minimum, this method should extract the names of each parameter for
saving/restoring.
Args:
state_dict: a nested dict to extract information from.
item: a PyTree to extract information from.
directory: a directory where checkpoint files are located.
Returns:
A PyTree matching `state_dict` of ParamInfo.
A PyTree matching `item` of ParamInfo.
"""
if not state_dict:
raise ValueError('Found empty state_dict')
names = self._get_param_names(state_dict)
if not item:
raise ValueError('Found empty item')
names = self._get_param_names(item)

def _param_info(name, arr_or_spec):
path = directory / name
return ParamInfo(name, _get_tensorstore_spec(path, arr_or_spec))

return jax.tree_map(_param_info, names, state_dict)
return jax.tree_map(_param_info, names, item)

async def async_save(
self,
Expand Down Expand Up @@ -367,8 +365,6 @@ async def async_save(
A Future that will commit the data to `directory` when awaited. Copying
the data from its source will be awaited in this function.
"""
# Convert arbitrary pytree into dictionary.
item = utils.to_state_dict(item)
item = await lazy_array.maybe_get_tree_async(item)

if save_args is None:
Expand Down Expand Up @@ -424,7 +420,8 @@ def restore(self,
directory: epath.Path,
item: Optional[PyTree] = None,
restore_args: Optional[PyTree] = None,
transforms: Optional[PyTree] = None) -> PyTree:
transforms: Optional[PyTree] = None,
transforms_default_to_original: bool = True) -> PyTree:
"""Restores a PyTree from the checkpoint directory at the given step.
Optional arguments meshes and mesh_axes define how each array in the
Expand All @@ -445,6 +442,7 @@ def restore(self,
for further information. Note that if `transforms` is provided, all
other arguments are structured relative to the saved object, not to the
restored object post-transformation.
transforms_default_to_original: See transform_utils.apply_transformations.
Returns:
A PyTree matching the structure of `item` with restored array values as
Expand All @@ -461,52 +459,51 @@ def restore(self,
raise FileNotFoundError(
f'Requested directory for restore does not exist at {directory}')

# Note: because of None values in `state_dict`, avoid jax.tree_map with it
# in first position. Use param_infos.
state_dict = self.structure(directory)
param_infos = self._get_param_infos(state_dict, directory)

if transforms is not None:
if item is None:
structure = self.structure(directory)
param_infos = self._get_param_infos(structure, directory)
if item is None:
if transforms is not None:
raise ValueError(
'If providing `transforms`, must provide `item` matching structure of expected result.'
)
if transform_utils.has_value_functions(transforms):
raise ValueError(
'Found disallowed `value_fn` or `multi_value-fn` in `transforms`.')
state_dict = transform_utils.apply_transformations(
self.structure(directory), transforms, utils.to_state_dict(item))
# param_infos must be transformed because the ts.Spec of saved params may
# correspond to the original structure rather than the new.
param_infos = transform_utils.apply_transformations(
param_infos, transforms, utils.to_state_dict(item))
item = structure
else:
if transforms is None:
item = flax.serialization.from_state_dict(item, structure)
param_infos = flax.serialization.from_state_dict(item, param_infos)
else:
if transform_utils.has_value_functions(transforms):
raise ValueError(
'Found disallowed `value_fn` or `multi_value_fn` in `transforms`.'
)
item = transform_utils.apply_transformations(
structure, transforms, item, transforms_default_to_original)
# param_infos must be transformed because the ts.Spec of saved params
# may correspond to the original structure rather than the new.
param_infos = transform_utils.apply_transformations(
param_infos, transforms, item, transforms_default_to_original)

if restore_args is None:
restore_args = jax.tree_map(lambda x: RestoreArgs(as_gda=False),
state_dict)
else:
restore_args = utils.to_state_dict(restore_args)
restore_args = jax.tree_map(lambda x: RestoreArgs(as_gda=False), item)

future_arrays = jax.tree_map(
_maybe_deserialize,
restore_args,
state_dict,
item,
param_infos,
is_leaf=lambda args: isinstance(args, RestoreArgs) or not args)

future_arrays, state_dict_def = jax.tree_flatten(future_arrays)
future_arrays, item_def = jax.tree_flatten(future_arrays)

async def _async_restore(future_arrays):
return await asyncio.gather(*future_arrays)

result = asyncio.run(_async_restore(future_arrays))

restored_state_dict = jax.tree_unflatten(state_dict_def, result)
# Convert back into original object.
restored = flax.serialization.from_state_dict(item, restored_state_dict)
restored_item = jax.tree_unflatten(item_def, result)

multihost_utils.sync_global_devices('PyTreeCheckpointHandler:restore')
return restored
return restored_item

def structure(self, directory: epath.Path) -> PyTree:
"""Restores the PyTree structure from a Flax checkpoint.
Expand Down
37 changes: 37 additions & 0 deletions orbax/checkpoint/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import List, Optional

from etils import epath
import flax.serialization
from flax.training.train_state import TrainState
import jax
from jax.experimental import multihost_utils
Expand All @@ -27,6 +28,7 @@
import numpy as np
import optax
from orbax.checkpoint import pytree_checkpoint_handler
from orbax.checkpoint import transform_utils


def save_fake_tmp_dir(directory: epath.Path,
Expand Down Expand Up @@ -164,3 +166,38 @@ def init_flax_model(model):
tx = optax.adamw(learning_rate=0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
return jax.tree_map(np.asarray, state)


def _filter(mask_tree, *other_trees):
"""Filter optax EmptyState and MaskedNode out of PyTree."""
if not isinstance(mask_tree, dict):
return other_trees
result_trees = [{} for _ in other_trees]
for k, v in mask_tree.items():
if not isinstance(v, (optax.EmptyState, optax.MaskedNode)):
values = _preprocess_helper(v, *(t[k] for t in other_trees))
for i, v1 in enumerate(values):
if isinstance(v1, dict):
if v1:
result_trees[i][k] = v1
else:
result_trees[i][k] = v1
return tuple(result_trees)


def _preprocess_helper(mask_tree, *other_trees):
return jax.tree_map(
_filter, mask_tree, *other_trees, is_leaf=lambda x: isinstance(x, dict))


def preprocess_flax_pytree(mask_tree, *other_trees):
mask_tree = flax.serialization.to_state_dict(mask_tree)
other_trees = (flax.serialization.to_state_dict(t) for t in other_trees)
r = _preprocess_helper(mask_tree, *other_trees)
return r[0] if len(r) == 1 else r


def postprocess_flax_pytree(reference, processed):
"""Merge trees by taking values from processed or reference if key not present."""
transforms = {}
return transform_utils.apply_transformations(processed, transforms, reference)

0 comments on commit c6c9540

Please sign in to comment.