Skip to content

Commit

Permalink
Introduce retry module to retry coroutines.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 731379450
  • Loading branch information
niketkumar authored and Orbax Authors committed Feb 27, 2025
1 parent acec3f3 commit 2beaa27
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 1 deletion.
4 changes: 4 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Introduce `retry` module to retry coroutines.

### Changed

- Improve `Cannot serialize host local jax.Array` error message.
Expand Down
5 changes: 5 additions & 0 deletions checkpoint/orbax/checkpoint/_src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,8 @@ py_library(
name = "threading",
srcs = ["threading.py"],
)

py_library(
name = "retry",
srcs = ["retry.py"],
)
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ def __init__(
type_handler_registry
)
)
if self._array_metadata_store:
self._array_metadata_store.set_primary_host(self._primary_host)
self._array_metadata_validator = array_metadata_validator


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,44 @@ def __init__(
self,
path_resolver: PathResolver = PathResolver(),
ser_deser: SerDeserializer = SerDeserializer(),
primary_host: int | None = 0, # None means all hosts are primary hosts.
write_timeout_secs: int = 600, # 10 minutes.
):
self._path_resolver = path_resolver
self._ser_deser = ser_deser
self._primary_host = primary_host
self._write_timeout_secs = write_timeout_secs

def set_primary_host(self, primary_host: int | None) -> None:
"""Sets the primary host."""
self._primary_host = primary_host

async def _maybe_create_base_dir(self, base_dir: epath.Path) -> None:
"""Creates the base directory if it does not exist."""
# Create the base dir/folder if it does not exist.
if multihost.is_primary_host(self._primary_host):
# primary host creates, rest of the hosts wait.
await asyncio.to_thread(base_dir.mkdir, parents=True, exist_ok=True)
else:
# non-primary host waits for primary host to create the base dir/folder.
async def wait_for_base_dir_creation():
while not await asyncio.to_thread(base_dir.exists):
await asyncio.sleep(0.25)

try:
await asyncio.wait_for(
wait_for_base_dir_creation(), timeout=self._write_timeout_secs
)
except asyncio.TimeoutError as e:
primary_process = (
'LOCAL' if self._primary_host is None else self._primary_host
)
raise ValueError(
f'[process_index={multihost.process_index()}] Timed out waiting for'
f' array_metadatas base directory creation: {base_dir}.'
f' timeout={self._write_timeout_secs} seconds.'
f' primary_process={primary_process}'
) from e

async def write(
self,
Expand All @@ -155,7 +190,7 @@ async def write(
file_path = self._path_resolver.get_write_file_path(
checkpoint_dir, process_index
)
await asyncio.to_thread(file_path.parent.mkdir, parents=True, exist_ok=True)
await self._maybe_create_base_dir(file_path.parent)
await asyncio.to_thread(
file_path.write_text, self._ser_deser.serialize(array_metadatas)
)
Expand Down
12 changes: 12 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Provides async variants of path functions."""

import asyncio
from typing import Any

from etils import epath
Expand Down Expand Up @@ -57,3 +58,14 @@ def async_rmtree(path: epath.Path):

def async_is_tmp_checkpoint(path: epath.Path):
return asyncio_utils.as_async_function(step_lib.is_tmp_checkpoint)(path)


async def maybe_mkdir(
path: epath.PathLike, mode: int = 0o777, parents: bool = False
) -> bool:
"""Creates a new directory at `path` if it does not exist."""
path = epath.Path(path)
if await asyncio.to_thread(path.exists):
return False
await asyncio.to_thread(path.mkdir, mode=mode, parents=parents)
return True
92 changes: 92 additions & 0 deletions checkpoint/orbax/checkpoint/_src/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Retry strategies."""

import asyncio
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from absl import logging

_R = TypeVar('_R')


async def retry(
*,
awaitable_factory: Callable[[], Awaitable[_R]],
retry_on_result: Callable[[Any], bool],
retry_on_exception: Callable[[Exception], bool],
sleep_for_secs: Callable[[_R | None, Exception | None, int], float],
max_retries: int,
) -> _R:
"""Retries an awaitable based on result or exception.
Args:
awaitable_factory: Creates the awaitable to retry. It is needed because an
awaitable cannot be awaited more than once.
retry_on_result: A callable that takes the result of the awaitable and
returns True if the awaitable should be retried, False otherwise.
retry_on_exception: A callable that takes an exception and returns True if
the awaitable should be retried, False otherwise.
sleep_for_secs: A callable that takes the result of the awaitable, the
exception raised by the awaitable, and the number of retries remaining,
and returns the number of seconds to sleep between retries.
max_retries: The maximum number of times to retry the awaitable.
Returns:
The result of the awaitable.
Raises:
ValueError: If max_retries is negative.
Exception: Due to the awaitable raising an exception.
"""
if max_retries < 0:
raise ValueError('max_retries must be non-negative.')

awaitable = awaitable_factory()
try:
result = await awaitable_factory()
if max_retries == 0:
return result
if not retry_on_result(result):
return result
sleep_secs = sleep_for_secs(result, None, max_retries)
logging.warning(
'Will retry after %s seconds due to result=%s, awaitable=%s',
sleep_secs,
result,
awaitable,
)
await asyncio.sleep(sleep_secs)
except Exception as e: # pylint: disable=broad-except
if max_retries == 0:
raise
if not retry_on_exception(e):
raise
sleep_secs = sleep_for_secs(None, e, max_retries)
logging.warning(
'Will retry after %s seconds due to exception=%s, awaitable=%s',
sleep_secs,
e,
awaitable,
)
await asyncio.sleep(sleep_secs)

return await retry(
awaitable_factory=awaitable_factory,
retry_on_result=retry_on_result,
retry_on_exception=retry_on_exception,
sleep_for_secs=sleep_for_secs,
max_retries=max_retries - 1,
)

0 comments on commit 2beaa27

Please sign in to comment.