Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[train] fix maximum recursion issue when serializing exceptions #43952

Merged
merged 6 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions python/ray/air/_internal/util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import socket
from contextlib import closing
import copy
import logging
import os
import queue
import socket
import threading
from contextlib import closing
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -35,7 +36,13 @@ class StartTraceback(Exception):


def skip_exceptions(exc: Optional[Exception]) -> Exception:
"""Skip all contained `StartTracebacks` to reduce traceback output"""
"""Skip all contained `StartTracebacks` to reduce traceback output.

Returns a shallow copy of the exception with all `StartTracebacks` removed.

If the RAY_AIR_FULL_TRACEBACKS environment variable is set to a non-zero
integer (e.g. "1"), the original exception (not a copy) is returned.
"""
should_not_shorten = bool(int(os.environ.get("RAY_AIR_FULL_TRACEBACKS", "0")))

if should_not_shorten:
Expand All @@ -45,12 +52,15 @@ def skip_exceptions(exc: Optional[Exception]) -> Exception:
# If this is a StartTraceback, skip
return skip_exceptions(exc.__cause__)

# Else, make sure nested exceptions are properly skipped
# Perform a shallow copy to prevent recursive __cause__/__context__.
new_exc = copy.copy(exc).with_traceback(exc.__traceback__)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: with_traceback is needed so that the traceback shows the original line that errored, rather than this line where the copy is happening.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does making a shallow copy remove the __context__ so that the new_exc has no context?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. More importantly, the `__context__` is set (to the `StartTraceback`) only after the exception is raised, so we need to make sure that no nested `__cause__` or `__context__` points back to this exception.

# Make sure nested exceptions are properly skipped.
cause = getattr(exc, "__cause__", None)
if cause:
exc.__cause__ = skip_exceptions(cause)
new_exc.__cause__ = skip_exceptions(cause)

return exc
return new_exc


def exception_cause(exc: Optional[Exception]) -> Optional[Exception]:
Expand Down
43 changes: 42 additions & 1 deletion python/ray/air/tests/test_tracebacks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest

import ray
from ray import cloudpickle
from tblib import pickling_support
from ray.train import ScalingConfig
from ray.air._internal.util import StartTraceback, skip_exceptions
from ray.air._internal.util import StartTraceback, skip_exceptions, exception_cause
from ray.train.data_parallel_trainer import DataParallelTrainer

from ray.tune import Tuner
Expand Down Expand Up @@ -47,6 +49,45 @@ def test_short_traceback(levels):
assert i == levels - start_traceback + 1


def test_recursion():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any idea why the maximum recursion error happens only when `pickling_support.install()` is called? Should we also test that here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like a bug in tblib.

Good point on testing. Originally I was going to remove tblib as a dependency in a separate PR, but even if I do I can add it as a test dependency.

"""Test that the skipped exception does not point to the original exception."""
root_exception = None

with pytest.raises(StartTraceback) as exc_info:
try:
raise Exception("Root Exception")
except Exception as e:
root_exception = e
raise StartTraceback from root_exception

assert root_exception, "Root exception was not captured."

start_traceback = exc_info.value
skipped_exception = skip_exceptions(start_traceback)

assert (
root_exception != skipped_exception
), "Skipped exception points to the original exception."


def test_tblib():
    """Check that a skipped exception can be pickled with tblib installed.

    The result of `skip_exceptions` is re-raised from inside the handler of
    a `StartTraceback`, and the captured exception is then serialized with
    cloudpickle after `pickling_support.install()` has been called.
    """
    with pytest.raises(Exception) as exc_info:
        try:
            try:
                raise Exception("Root Exception")
            except Exception as root_err:
                raise StartTraceback from root_err
        except Exception as wrapped:
            # Re-raise inside the handler so that the new exception's
            # __context__ is set to the StartTraceback being handled.
            shortened = skip_exceptions(wrapped)
            raise shortened from exception_cause(wrapped)

    # Enable tblib's pickling support before serializing.
    pickling_support.install()
    # Serializing must not raise a RecursionError/PicklingError.
    cloudpickle.dumps(exc_info.value)


def test_traceback_tuner(ray_start_2_cpus):
"""Ensure that the Tuner's stack trace is not too long."""

Expand Down
Loading