diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 3767f451..2e185636 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -242,6 +242,16 @@ def __init__(self, det: WorkflowInstanceDetails) -> None: # Set ourselves on our own loop temporalio.workflow._Runtime.set_on_loop(self, self) + # After GC, Python raises GeneratorExit calls from all awaiting tasks. + # Then in a finally of such an await, another exception can swallow + # these causing even more issues. We will set ourselves as deleted so we + # can check in some places to swallow these errors on tear down. + self._deleting = False + + def __del__(self) -> None: + # We have confirmed there are no super() versions of __del__ + self._deleting = True + #### Activation functions #### # These are in alphabetical order and besides "activate", all other calls # are "_apply_" + the job field name. @@ -629,14 +639,26 @@ def _apply_start_workflow( # Async call to run on the scheduler thread. This will be wrapped in # another function which applies exception handling. async def run_workflow(input: ExecuteWorkflowInput) -> None: - result = await self._inbound.execute_workflow(input) - result_payloads = self._payload_converter.to_payloads([result]) - if len(result_payloads) != 1: - raise ValueError( - f"Expected 1 result payload, got {len(result_payloads)}" - ) - command = self._add_command() - command.complete_workflow_execution.result.CopyFrom(result_payloads[0]) + try: + result = await self._inbound.execute_workflow(input) + result_payloads = self._payload_converter.to_payloads([result]) + if len(result_payloads) != 1: + raise ValueError( + f"Expected 1 result payload, got {len(result_payloads)}" + ) + command = self._add_command() + command.complete_workflow_execution.result.CopyFrom(result_payloads[0]) + except BaseException as err: + # During tear down, generator exit and event loop exceptions can occur + if not self._deleting: + raise + if not isinstance( + err, + (GeneratorExit, temporalio.workflow._NotInWorkflowEventLoopError), + ): + logger.debug( + "Ignoring exception while deleting workflow", exc_info=True + ) # Schedule it input = ExecuteWorkflowInput( @@ -1260,6 +1282,16 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None: else: # All other exceptions fail the task self._current_activation_error = err + except BaseException as err: + # During tear down, generator exit and no-runtime exceptions can appear + if not self._deleting: + raise + if not isinstance( + err, (GeneratorExit, temporalio.workflow._NotInWorkflowEventLoopError) + ): + logger.debug( + "Ignoring exception while deleting workflow", exc_info=True + ) def _set_workflow_failure(self, err: temporalio.exceptions.FailureError) -> None: # All other failure errors fail the workflow diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 4c735e53..67720f5a 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -340,7 +340,7 @@ class _Runtime(ABC): def current() -> _Runtime: loop = _Runtime.maybe_current() if not loop: - raise RuntimeError("Not in workflow event loop") + raise _NotInWorkflowEventLoopError("Not in workflow event loop") return loop @staticmethod @@ -3843,6 +3843,12 @@ def __init__(self, message: str) -> None: self.message = message +class _NotInWorkflowEventLoopError(temporalio.exceptions.TemporalError): + def __init__(self, *args: object) -> None: + super().__init__("Not in workflow event loop") + self.message = "Not in workflow event loop" + + class VersioningIntent(Enum): """Indicates whether the user intends certain commands to be run on a compatible worker Build Id version or not. diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 6bec62f9..d84ed646 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -5,6 +5,7 @@ import logging.handlers import pickle import queue +import sys import threading import uuid from abc import ABC, abstractmethod @@ -2843,3 +2844,71 @@ async def test_manual_result_type(client: Client): assert res3 == {"some_string": "from-query"} res4 = await handle.query("some_query", result_type=ManualResultType) assert res4 == ManualResultType(some_string="from-query") + + +@workflow.defn +class SwallowGeneratorExitWorkflow: + def __init__(self) -> None: + self._signal_count = 0 + + @workflow.run + async def run(self) -> None: + try: + # Wait for signal count to reach 2 + await workflow.wait_condition(lambda: self._signal_count > 1) + finally: + # This finally, on eviction, is actually called because the above + # await raises GeneratorExit. Then this will raise a + # _NotInWorkflowEventLoopError swallowing that. + await workflow.wait_condition(lambda: self._signal_count > 2) + + @workflow.signal + async def signal(self) -> None: + self._signal_count += 1 + + @workflow.query + async def signal_count(self) -> int: + return self._signal_count + + +async def test_swallow_generator_exit(client: Client): + if sys.version_info < (3, 8): + pytest.skip("sys.unraisablehook not in 3.7") + # This test simulates GeneratorExit and GC issues by forcing eviction on + # each step + async with new_worker( + client, SwallowGeneratorExitWorkflow, max_cached_workflows=0 + ) as worker: + # Put a hook to catch unraisable exceptions + old_hook = sys.unraisablehook + hook_calls: List[Any] = [] + sys.unraisablehook = hook_calls.append + try: + handle = await client.start_workflow( + SwallowGeneratorExitWorkflow.run, + id=f"wf-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + async def signal_count() -> int: + return await handle.query(SwallowGeneratorExitWorkflow.signal_count) + + # Confirm signal count as 0 + await assert_eq_eventually(0, signal_count) + + # Send signal and confirm it's at 1 + await handle.signal(SwallowGeneratorExitWorkflow.signal) + await assert_eq_eventually(1, signal_count) + + await handle.signal(SwallowGeneratorExitWorkflow.signal) + await assert_eq_eventually(2, signal_count) + + await handle.signal(SwallowGeneratorExitWorkflow.signal) + await assert_eq_eventually(3, signal_count) + + await handle.result() + finally: + sys.unraisablehook = old_hook + + # Confirm no unraisable exceptions + assert not hook_calls