Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
rculbertson committed Oct 28, 2024
1 parent 128a718 commit 0e4c9e1
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 49 deletions.
11 changes: 10 additions & 1 deletion modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@
from .._serialization import deserialize, deserialize_data_format, serialize
from .._traceback import append_modal_tb
from ..config import config, logger
from ..exception import DeserializationError, ExecutionError, FunctionTimeoutError, InvalidError, RemoteError
from ..exception import (
DeserializationError,
ExecutionError,
FunctionTimeoutError,
InvalidError,
RemoteError,
SystemEventException,
)
from ..mount import ROOT_DIR, _is_modal_path, _Mount
from .blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES
Expand Down Expand Up @@ -437,6 +444,8 @@ async def _process_result(result: api_pb2.GenericResult, data_format: int, stub,

if result.status == api_pb2.GenericResult.GENERIC_STATUS_TIMEOUT:
raise FunctionTimeoutError(result.exception)
if result.status == api_pb2.GenericResult.GENERIC_STATUS_SYSTEM_EVENT:
raise SystemEventException(result.exception)
elif result.status != api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
if data:
try:
Expand Down
4 changes: 4 additions & 0 deletions modal/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class SandboxTerminatedError(Error):
"""Raised when a Sandbox is terminated for an internal reason."""


class SystemEventException(Exception):
"""Raised when an input failed due to an internal reason, such as preemption, network loss, or container crash."""


class FunctionTimeoutError(TimeoutError):
"""Raised when a Function exceeds its execution duration limit and times out."""

Expand Down
55 changes: 16 additions & 39 deletions modal/functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright Modal Labs 2023
import asyncio
import inspect
import textwrap
import time
Expand Down Expand Up @@ -63,9 +62,11 @@
from .config import config
from .exception import (
ExecutionError,
FunctionTimeoutError,
InvalidError,
NotFoundError,
OutputExpiredError,
SystemEventException,
deprecation_error,
deprecation_warning,
)
Expand All @@ -87,7 +88,7 @@
_SynchronizedQueue,
)
from .proxy import _Proxy
from .retries import Retries
from .retries import Retries, RetryManager
from .schedule import Schedule
from .scheduler_placement import SchedulerPlacement
from .secret import _Secret
Expand All @@ -98,8 +99,7 @@
import modal.cls
import modal.partial_function

MIN_INPUT_RETRY_DELAY_MS = 1000
MAX_INPUT_RETRY_DELAY_MS = 24 * 60 * 60 * 1000
SYSTEM_RETRY_POLICY = api_pb2.FunctionRetryPolicy(initial_delay_ms=1000, backoff_coefficient=1.0, retries=100)


class _Invocation:
Expand Down Expand Up @@ -208,18 +208,6 @@ async def retry_input(self) -> None:
if not processed_inputs:
raise Exception(f"Could not retry input {self._item.idx} - the input queue seems to be full")

async def get_sync_output(self) -> Any:
# TODO(ryan): What should timeout be?
get_outputs_timeout = 120
response: api_pb2.FunctionGetOutputsResponse = await self.pop_function_call_outputs(
timeout=get_outputs_timeout, clear_on_success=True
)
if response.outputs:
item: api_pb2.FunctionGetOutputsItem = response.outputs[0]
return await _process_result(item.result, item.data_format, self.stub, self.client)
else:
raise TimeoutError(f"Received zero outputs in {get_outputs_timeout} seconds")

async def run_function(self) -> Any:
# waits indefinitely for a single result for the function, and clear the outputs buffer after
item: api_pb2.FunctionGetOutputsItem = (
Expand Down Expand Up @@ -1311,33 +1299,22 @@ async def _call_function(self, args, kwargs) -> ReturnType:
client=self._client,
function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY,
)
if not self._retry_policy:
return await invocation.run_function()
# User errors including timeouts are managed by the user specified retry policy.
# System events (preemption, network loss, container crashes) are handled according to system retry policy.
user_retry_manager = RetryManager(self._retry_policy) if self._retry_policy else None
system_retry_manager = RetryManager(SYSTEM_RETRY_POLICY)

retry_policy = self._retry_policy
max_retry_count = retry_policy.retries if retry_policy else 1
attempt_count = 0
while True:
try:
return await invocation.get_sync_output()
except UserCodeException as exc:
attempt_count += 1
if attempt_count >= max_retry_count:
return await invocation.run_function()
except SystemEventException as exc:
await system_retry_manager.handle_exception(exc)
except (UserCodeException, FunctionTimeoutError) as exc:
# Retry these errors only if user specified retry policy
if not user_retry_manager:
raise exc
delay_ms = self._retry_delay_ms(attempt_count, retry_policy)
await asyncio.sleep(delay_ms / 1000)
await invocation.retry_input()

@staticmethod
def _retry_delay_ms(attempt_count: int, retry_policy: api_pb2.FunctionRetryPolicy) -> float:
if attempt_count < 1:
raise ValueError(f"Cannot compute retry delay. attempt_count must be at least 1, but was {attempt_count}")
delay_ms = retry_policy.initial_delay_ms * (retry_policy.backoff_coefficient ** (attempt_count - 1))
if delay_ms < MIN_INPUT_RETRY_DELAY_MS:
return MIN_INPUT_RETRY_DELAY_MS
if delay_ms > MAX_INPUT_RETRY_DELAY_MS:
return MAX_INPUT_RETRY_DELAY_MS
return delay_ms
await user_retry_manager.handle_exception(exc)
await invocation.retry_input()

async def _call_function_nowait(
self, args, kwargs, function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType"
Expand Down
38 changes: 38 additions & 0 deletions modal/retries.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Copyright Modal Labs 2022
import asyncio
from datetime import timedelta

from modal_proto import api_pb2

from .exception import InvalidError

MIN_INPUT_RETRY_DELAY_MS = 1000
MAX_INPUT_RETRY_DELAY_MS = 24 * 60 * 60 * 1000


class Retries:
"""Adds a retry policy to a Modal function.
Expand Down Expand Up @@ -103,3 +107,37 @@ def _to_proto(self) -> api_pb2.FunctionRetryPolicy:
initial_delay_ms=self.initial_delay // timedelta(milliseconds=1),
max_delay_ms=self.max_delay // timedelta(milliseconds=1),
)


class RetryManager:
"""
Helper class to apply the specified retry policy.
"""

def __init__(self, retry_policy: api_pb2.FunctionRetryPolicy):
self.retry_policy = retry_policy
self.attempt_count = 0

async def handle_exception(self, exc: Exception):
"""
Raises an exception if the maximum retry count has been reached, otherwise sleeps for calculated delay.
"""
self.attempt_count += 1
if self.attempt_count > self.retry_policy.retries:
raise exc
delay_ms = self._retry_delay_ms(self.attempt_count, self.retry_policy)
await asyncio.sleep(delay_ms / 1000)

@staticmethod
def _retry_delay_ms(attempt_count: int, retry_policy: api_pb2.FunctionRetryPolicy) -> float:
"""
Computes the amount of time to sleep before retrying based on the backend_coefficient and initial_delay_ms args.
"""
if attempt_count < 1:
raise ValueError(f"Cannot compute retry delay. attempt_count must be at least 1, but was {attempt_count}")
delay_ms = retry_policy.initial_delay_ms * (retry_policy.backoff_coefficient ** (attempt_count - 1))
if delay_ms < MIN_INPUT_RETRY_DELAY_MS:
return MIN_INPUT_RETRY_DELAY_MS
if delay_ms > MAX_INPUT_RETRY_DELAY_MS:
return MAX_INPUT_RETRY_DELAY_MS
return delay_ms
7 changes: 7 additions & 0 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,11 @@ message FunctionGetOutputsItem {
string task_id = 6;
double input_started_at = 7;
double output_created_at = 8;
// Client should wait this long before retrying
// If -1, do not retry. If -2, use user retry policy
// -2 user defined retry
// positive value indicates
int32 retry_delay = 9;
}

message FunctionGetOutputsRequest {
Expand Down Expand Up @@ -1583,6 +1588,8 @@ message GenericResult { // Used for both tasks and function outputs
// Used when the user's function fails to initialize (ex. S3 mount failed due to invalid credentials).
// Terminates the function and all remaining inputs.
GENERIC_STATUS_INIT_FAILURE = 5;
// Used when an input must be retried due to preemption, network loss, or container crash
GENERIC_STATUS_SYSTEM_EVENT = 6;
}

GenericStatus status = 1; // Status of the task or function output.
Expand Down
18 changes: 9 additions & 9 deletions test/function_retry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import modal
from modal import App
from modal.functions import _Function
from modal.retries import RetryManager
from modal_proto import api_pb2

function_call_count = 0
Expand All @@ -12,7 +12,7 @@
@pytest.fixture(autouse=True)
def reset_function_call_count(monkeypatch):
# Set default retry delay to something small so we don't slow down tests
monkeypatch.setattr("modal.functions.MIN_INPUT_RETRY_DELAY_MS", 0.00001)
monkeypatch.setattr("modal.retries.MIN_INPUT_RETRY_DELAY_MS", 0.00001)
global function_call_count
function_call_count = 0

Expand All @@ -26,14 +26,14 @@ def __init__(self, function_call_count):
self.function_call_count = function_call_count


def counting_function(function_calls_until_success: int):
def counting_function(attempt_to_return_success: int):
"""
A function that updates the global function_call_count counter each time it is called.
"""
global function_call_count
function_call_count += 1
if function_call_count < function_calls_until_success:
if function_call_count < attempt_to_return_success:
raise FunctionCallCountException(function_call_count)
return function_call_count

Expand All @@ -55,8 +55,8 @@ def test_all_retries_fail_raises_error(client, setup_app_and_function):
app, f = setup_app_and_function
with app.run(client=client):
with pytest.raises(FunctionCallCountException) as exc_info:
f.remote(4)
assert exc_info.value.function_call_count == 3
f.remote(5)
assert exc_info.value.function_call_count == 4


def test_failures_followed_by_success(client, setup_app_and_function):
Expand All @@ -75,10 +75,10 @@ def test_no_retries_when_first_call_succeeds(client, setup_app_and_function):

def test_retry_dealy_ms():
with pytest.raises(ValueError):
_Function._retry_delay_ms(0, api_pb2.FunctionRetryPolicy())
RetryManager._retry_delay_ms(0, api_pb2.FunctionRetryPolicy())

retry_policy = api_pb2.FunctionRetryPolicy(retries=2, backoff_coefficient=3, initial_delay_ms=2000)
assert _Function._retry_delay_ms(1, retry_policy) == 2000
assert RetryManager._retry_delay_ms(1, retry_policy) == 2000

retry_policy = api_pb2.FunctionRetryPolicy(retries=2, backoff_coefficient=3, initial_delay_ms=2000)
assert _Function._retry_delay_ms(2, retry_policy) == 6000
assert RetryManager._retry_delay_ms(2, retry_policy) == 6000

0 comments on commit 0e4c9e1

Please sign in to comment.