Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKPORT] Fix task hang when error object cannot be pickled (#2910) #2913

Merged
merged 1 commit into from
Apr 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,9 @@ mars/services/web/static

# docs
docs/source/savefig/

# Unit / Performance Testing #
##############################
asv_bench/env/
asv_bench/html/
asv_bench/results/
2 changes: 1 addition & 1 deletion mars/_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -381,4 +381,4 @@ cpdef long long ceildiv(long long x, long long y) nogil:


__all__ = ['to_str', 'to_binary', 'to_text', 'TypeDispatcher', 'tokenize', 'tokenize_int',
'register_tokenizer', 'insert_reversed_tuple', 'ceildiv']
'register_tokenizer', 'ceildiv']
10 changes: 6 additions & 4 deletions mars/core/graph/tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,15 @@ def test_to_dot():
graph_reprs = []
for n in graph:
graph_reprs.append(
f"{n.op.key} -> {[succ.op.key for succ in graph.successors(n)]}"
f"{n.op.key} -> {[succ.key for succ in graph.successors(n)]}"
)
logging.error(
"Unexpected error in test_to_dot.\ndot = %r\ngraph_repr: %r",
"Unexpected error in test_to_dot.\ndot = %r\ngraph_repr = %r",
dot,
"\n".join(graph_reprs),
)
missing_prefix = next(str(n.key)[5] not in dot for n in graph)
logging.error("Missing prefix %s", missing_prefix)
missing_prefix = next(n.key for n in graph if str(n.key)[5] not in dot)
logging.error(
"Missing prefix %r (type: %s)", missing_prefix, type(missing_prefix)
)
raise
8 changes: 7 additions & 1 deletion mars/oscar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@
from .batch import extensible
from .core import ActorRef
from .debug import set_debug_options, DebugOptions
from .errors import ActorNotExist, ActorAlreadyExist, ServerClosed, Return
from .errors import (
ActorNotExist,
ActorAlreadyExist,
ServerClosed,
SendMessageFailed,
Return,
)
from .utils import create_actor_ref

# make sure methods are registered
Expand Down
12 changes: 11 additions & 1 deletion mars/oscar/backends/mars/tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .....utils import get_next_port
from .... import create_actor_ref, Actor, kill_actor
from ....context import get_context
from ....errors import NoIdleSlot, ActorNotExist, ServerClosed
from ....errors import NoIdleSlot, ActorNotExist, ServerClosed, SendMessageFailed
from ...allocate_strategy import (
AddressSpecified,
IdleLabel,
Expand Down Expand Up @@ -53,6 +53,11 @@
from ..pool import MainActorPool, SubActorPool


class _CannotBePickled:
def __getstate__(self):
raise RuntimeError("cannot pickle")


class _CannotBeUnpickled:
def __getstate__(self):
return ()
Expand Down Expand Up @@ -85,6 +90,9 @@ async def sleep(self, second):
def return_cannot_unpickle(self):
return _CannotBeUnpickled()

def raise_cannot_pickle(self):
raise ValueError(_CannotBePickled())


def _add_pool_conf(
config: ActorPoolConfig,
Expand Down Expand Up @@ -506,6 +514,8 @@ async def test_create_actor_pool():
assert await actor_ref2.add(1) == 4
with pytest.raises(RuntimeError):
await actor_ref2.return_cannot_unpickle()
with pytest.raises(SendMessageFailed):
await actor_ref2.raise_cannot_pickle()
assert (await ctx.has_actor(actor_ref2)) is True
assert (await ctx.actor_ref(actor_ref2)) == actor_ref2
# test cancel
Expand Down
40 changes: 34 additions & 6 deletions mars/oscar/backends/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@
from ..api import Actor
from ..core import ActorRef
from ..debug import record_message_trace, debug_async_timeout
from ..errors import ActorAlreadyExist, ActorNotExist, ServerClosed, CannotCancelTask
from ..errors import (
ActorAlreadyExist,
ActorNotExist,
ServerClosed,
CannotCancelTask,
SendMessageFailed,
)
from ..utils import create_actor_ref
from .allocate_strategy import allocated_type, AddressSpecified
from .communication import Channel, Server, get_server_type, gen_local_address
Expand Down Expand Up @@ -320,6 +326,31 @@ async def _run_coro(self, message_id: bytes, coro: Coroutine):
finally:
self._process_messages.pop(message_id, None)

async def _send_channel(
self, result: _MessageBase, channel: Channel, resend_failure: bool = True
):
try:
await channel.send(result)
except (ChannelClosed, ConnectionResetError):
if not self._stopped.is_set():
raise
except Exception as ex:
logger.exception(
"Error when sending message %s from %s to %s",
result.message_id.hex(),
channel.local_address,
channel.dest_address,
)
if not resend_failure: # pragma: no cover
raise

with _ErrorProcessor(result.message_id, result.protocol) as processor:
raise SendMessageFailed(
f"Error when sending message {result.message_id.hex()}. "
f"Caused by {ex!r}. See server logs for more details"
) from None
await self._send_channel(processor.result, channel, resend_failure=False)

async def process_message(self, message: _MessageBase, channel: Channel):
handler = self._message_handler[message.message_type]
with _ErrorProcessor(message.message_id, message.protocol) as processor:
Expand All @@ -333,11 +364,8 @@ async def process_message(self, message: _MessageBase, channel: Channel):
processor.result = await self._run_coro(
message.message_id, handler(self, message)
)
try:
await channel.send(processor.result)
except (ChannelClosed, ConnectionResetError):
if not self._stopped.is_set():
raise

await self._send_channel(processor.result, channel)

async def call(self, dest_address: str, message: _MessageBase) -> ResultMessageType:
return await self._caller.call(self._router, dest_address, message)
Expand Down
4 changes: 4 additions & 0 deletions mars/oscar/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ class CannotCancelTask(MarsError):
pass


class SendMessageFailed(MarsError):
pass


class Return(MarsError):
def __init__(self, value):
self.value = value
14 changes: 14 additions & 0 deletions mars/remote/tests/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from ... import dataframe as md
from ... import tensor as mt
from ... import oscar as mo
from ...core import tile
from ...deploy.oscar.session import get_default_session
from ...learn.utils import shuffle
Expand Down Expand Up @@ -188,3 +189,16 @@ def f(*_args):
r4 = spawn(f, args=(r2, r3))

assert r4.execute().fetch() is None


def test_remote_with_unpickable(setup_cluster):
def f(*_):
class Unpickleable:
def __reduce__(self):
raise ValueError

raise KeyError(Unpickleable())

with pytest.raises(mo.SendMessageFailed):
d = spawn(f, retry_when_fail=False)
d.execute()