Skip to content

Commit

Permalink
Improve Cannot serialize host local jax.Array error message.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 730455747
  • Loading branch information
niketkumar authored and Orbax Authors committed Feb 24, 2025
1 parent 7ddcff2 commit 8f07ef9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 11 deletions.
6 changes: 5 additions & 1 deletion checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed

- Improve `Cannot serialize host local jax.Array` error message.

## [0.11.6] - 2025-02-20

### Added
Expand All @@ -18,7 +22,7 @@ types and save/load free functions.
`ocp.Future`, but serves as a generic type representing a result from a
background task. It also incorporates typing, in contrast with its predecessor.

## Changed
### Changed

- Move process_metadata to the checkpoint_num directory
- Updated default JAX version requirement to 0.5.0.
Expand Down
34 changes: 24 additions & 10 deletions checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,16 @@ def check_input_arguments(*args):
raise ValueError('Found input args with mismatched lengths.')


def check_array_values(values: Sequence[Union[jax.Array, np.ndarray]]):
for v in values:
def check_array_values(
values: Sequence[Union[jax.Array, np.ndarray]],
infos: Sequence[types.ParamInfo],
):
for v, info in zip(values, infos):
if v.size == 0:
raise ValueError('Cannot save arrays with zero size.')
raise ValueError(
f'Cannot save arrays with zero size: ParamInfo: [name={info.name},'
f'value_typestr={info.value_typestr}]'
)


async def _validate_params(
Expand Down Expand Up @@ -637,7 +643,7 @@ async def serialize(
"""Uses Tensorstore to serialize a numpy array."""
args = args or [types.SaveArgs()] * len(values)
check_input_arguments(values, infos, args)
check_array_values(values)
check_array_values(values, infos)
if logging.vlog_is_on(1):
_print_ts_debug_data(self._metadata_key, infos)
copied_values = [copy.deepcopy(v) for v in values]
Expand Down Expand Up @@ -1094,22 +1100,30 @@ async def serialize(
args: Optional[Sequence[types.SaveArgs]] = None,
) -> Sequence[future.Future]:
"""See superclass documentation."""
for v in values:
args = args or [types.SaveArgs()] * len(values)
check_input_arguments(values, infos, args)
check_array_values(values, infos)
for v, info in zip(values, infos):
if (
isinstance(v, jax.Array)
and jax.process_count() > 1
and v.is_fully_addressable
):
debug_param_info = (
f'ParamInfo=[name={info.name},value_typestr={info.value_typestr}]'
)
debug_array = (
f'jax.Array=[value={v},shape={v.shape},dtype={v.dtype},'
f'sharding={v.sharding},device={v.device}]'
)
raise ValueError(
'Cannot serialize host local arrays. Arrays like this are typically'
' obtained using pmap. Consider using'
f'Cannot serialize host local jax.Array ({debug_param_info},'
f' {debug_array}) in multi-host setting. Arrays like this are'
' typically obtained using pmap. Consider using'
' fully_replicated_host_local_array_to_global_array in'
' orbax/checkpoint/utils.py to convert your arrays into'
' serializable objects.'
)
args = args or [types.SaveArgs()] * len(values)
check_input_arguments(values, infos, args)
check_array_values(values)

assert all([info.enable_pinned_host_transfer for info in infos]) or all(
[not info.enable_pinned_host_transfer for info in infos]
Expand Down

0 comments on commit 8f07ef9

Please sign in to comment.