diff --git a/src/prefect/workers/base.py b/src/prefect/workers/base.py index afe0cceb8a853..5473a947c6939 100644 --- a/src/prefect/workers/base.py +++ b/src/prefect/workers/base.py @@ -385,7 +385,10 @@ async def run( ) async def kill_infrastructure( - self, infrastructure_pid: str, grace_seconds: int = 30 + self, + infrastructure_pid: str, + configuration: BaseJobConfiguration, + grace_seconds: int = 30, ): """ Method for killing infrastructure created by a worker. Should be implemented by @@ -505,8 +508,13 @@ async def cancel_run(self, flow_run: "FlowRun"): ) return + configuration = await self._get_configuration(flow_run) + try: - await self.kill_infrastructure(flow_run.infrastructure_pid) + await self.kill_infrastructure( + infrastructure_pid=flow_run.infrastructure_pid, + configuration=configuration, + ) except NotImplementedError: self._logger.error( f"Worker type {self.type!r} does not support killing created " @@ -711,8 +719,6 @@ async def _submit_run_and_capture_errors( self, flow_run: "FlowRun", task_status: anyio.abc.TaskStatus = None ) -> Union[BaseWorkerResult, Exception]: try: - # TODO: Add functionality to handle base job configuration and - # job configuration variables when kicking off a flow run configuration = await self._get_configuration(flow_run) submitted_event = self._emit_flow_run_submitted_event(configuration) result = await self.run( diff --git a/src/prefect/workers/process.py b/src/prefect/workers/process.py index 1db8cbb239f83..7a2d7dfaac772 100644 --- a/src/prefect/workers/process.py +++ b/src/prefect/workers/process.py @@ -59,8 +59,8 @@ def _parse_infrastructure_pid(infrastructure_pid: str) -> Tuple[str, int]: class ProcessJobConfiguration(BaseJobConfiguration): - stream_output: bool - working_dir: Optional[Path] + stream_output: bool = Field(default=True) + working_dir: Optional[Path] = Field(default=None) @validator("working_dir") def validate_command(cls, v): @@ -195,7 +195,10 @@ async def run( ) async def kill_infrastructure( - self, infrastructure_pid: str, grace_seconds: int = 30 + self, + infrastructure_pid: str, + configuration: ProcessJobConfiguration, + grace_seconds: int = 30, ): hostname, pid = _parse_infrastructure_pid(infrastructure_pid) diff --git a/tests/workers/test_base_worker.py b/tests/workers/test_base_worker.py index 75e53447a165b..284ed511aca6c 100644 --- a/tests/workers/test_base_worker.py +++ b/tests/workers/test_base_worker.py @@ -37,7 +37,25 @@ async def run(self): pass async def kill_infrastructure( - self, infrastructure_pid: str, grace_seconds: int = 30 + self, + infrastructure_pid: str, + grace_seconds: int = 30, + configuration: Optional[BaseJobConfiguration] = None, + ): + pass + + +class WorkerWithOldKillInfrastructureInterface(BaseWorker): + type = "test-old-interface" + job_configuration = BaseJobConfiguration + + async def run(self): + pass + + async def kill_infrastructure( + self, + infrastructure_pid: str, + grace_seconds: int = 30, ): pass @@ -1265,16 +1283,17 @@ class TestCancellation: async def test_worker_cancel_run_called_for_cancelling_run( self, orion_client: PrefectClient, - deployment, + worker_deployment_wq1, cancelling_constructor, work_pool, ): flow_run = await orion_client.create_flow_run_from_deployment( - deployment.id, + worker_deployment_wq1.id, state=cancelling_constructor(), ) async with WorkerTestImpl(work_pool_name=work_pool.name) as worker: + await worker.sync_with_backend() worker.cancel_run = AsyncMock() await worker.check_for_cancelled_flow_runs() @@ -1294,14 +1313,15 @@ async def test_worker_cancel_run_called_for_cancelling_run( ], ) async def test_worker_cancel_run_not_called_for_other_states( - self, orion_client: PrefectClient, deployment, state, work_pool + self, orion_client: PrefectClient, worker_deployment_wq1, state, work_pool ): await orion_client.create_flow_run_from_deployment( - deployment.id, + worker_deployment_wq1.id, state=state, ) async with WorkerTestImpl(work_pool_name=work_pool.name) as worker: + await worker.sync_with_backend() worker.cancel_run = AsyncMock() await worker.check_for_cancelled_flow_runs() @@ -1313,18 +1333,14 @@ async def test_worker_cancel_run_not_called_for_other_states( async def test_worker_cancel_run_called_for_cancelling_run_with_multiple_work_queues( self, orion_client: PrefectClient, - deployment, + worker_deployment_wq1, cancelling_constructor, work_pool, work_queue_1, work_queue_2, ): - deployment.work_pool_name = work_pool.name - deployment.work_queue_name = work_queue_1.name - await orion_client.update_deployment(deployment) - flow_run = await orion_client.create_flow_run_from_deployment( - deployment.id, + worker_deployment_wq1.id, state=cancelling_constructor(), ) @@ -1332,6 +1348,7 @@ async def test_worker_cancel_run_called_for_cancelling_run_with_multiple_work_qu work_pool_name=work_pool.name, work_queues=[work_queue_1.name, work_queue_2.name], ) as worker: + await worker.sync_with_backend() worker.cancel_run = AsyncMock() await worker.check_for_cancelled_flow_runs() @@ -1362,6 +1379,7 @@ async def test_worker_cancel_run_not_called_for_same_queue_names_in_different_wo work_pool_name=work_pool.name, work_queues=[work_queue_1.name], ) as worker: + await worker.sync_with_backend() worker.cancel_run = AsyncMock() await worker.check_for_cancelled_flow_runs() @@ -1371,18 +1389,23 @@ async def test_worker_cancel_run_not_called_for_same_queue_names_in_different_wo "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] ) async def test_worker_cancel_run_not_called_for_other_work_queues( - self, orion_client: PrefectClient, deployment, cancelling_constructor, work_pool + self, + orion_client: PrefectClient, + worker_deployment_wq1, + cancelling_constructor, + work_pool, ): await orion_client.create_flow_run_from_deployment( - deployment.id, + worker_deployment_wq1.id, state=cancelling_constructor(), ) async with WorkerTestImpl( work_pool_name=work_pool.name, - work_queues=[f"not-{deployment.work_queue_name}"], + work_queues=[f"not-{worker_deployment_wq1.work_queue_name}"], prefetch_seconds=10, ) as worker: + await worker.sync_with_backend() worker.cancel_run = AsyncMock() await worker.check_for_cancelled_flow_runs() @@ -1394,10 +1417,14 @@ async def test_worker_cancel_run_not_called_for_other_work_queues( "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] ) async def test_worker_cancel_run_kills_run_with_infrastructure_pid( - self, orion_client: PrefectClient, deployment, cancelling_constructor, work_pool + self, + orion_client: PrefectClient, + worker_deployment_wq1, + cancelling_constructor, + work_pool, ): flow_run = await orion_client.create_flow_run_from_deployment( - deployment.id, + worker_deployment_wq1.id, state=cancelling_constructor(), ) @@ -1406,10 +1433,14 @@ async def test_worker_cancel_run_kills_run_with_infrastructure_pid( async with WorkerTestImpl( work_pool_name=work_pool.name, prefetch_seconds=10 ) as worker: + await worker.sync_with_backend() worker.kill_infrastructure = AsyncMock() await worker.check_for_cancelled_flow_runs() + configuration = await worker._get_configuration(flow_run) - worker.kill_infrastructure.assert_awaited_once_with("test") + worker.kill_infrastructure.assert_awaited_once_with( + infrastructure_pid="test", configuration=configuration + ) @pytest.mark.parametrize( "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] @@ -1417,19 +1448,20 @@ async def test_worker_cancel_run_kills_run_with_infrastructure_pid( async def test_worker_cancel_run_with_missing_infrastructure_pid( self, orion_client: PrefectClient, - deployment, + worker_deployment_wq1, caplog, cancelling_constructor, work_pool, ): flow_run = await orion_client.create_flow_run_from_deployment( - deployment.id, + worker_deployment_wq1.id, state=cancelling_constructor(), ) async with WorkerTestImpl( work_pool_name=work_pool.name, prefetch_seconds=10 ) as worker: + await worker.sync_with_backend() worker.kill_infrastructure = AsyncMock() await worker.check_for_cancelled_flow_runs() @@ -1453,10 +1485,14 @@ async def test_worker_cancel_run_with_missing_infrastructure_pid( "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] ) async def test_worker_cancel_run_updates_state_type( - self, orion_client: PrefectClient, deployment, cancelling_constructor, work_pool + self, + orion_client: PrefectClient, + worker_deployment_wq1, + cancelling_constructor, + work_pool, ): flow_run = await orion_client.create_flow_run_from_deployment( - deployment.id, + worker_deployment_wq1.id, state=cancelling_constructor(), ) @@ -1465,6 +1501,7 @@ async def test_worker_cancel_run_updates_state_type( async with WorkerTestImpl( work_pool_name=work_pool.name, prefetch_seconds=10 ) as worker: + await worker.sync_with_backend() await worker.check_for_cancelled_flow_runs() post_flow_run = await orion_client.read_flow_run(flow_run.id) @@ -1474,12 +1511,16 @@ async def test_worker_cancel_run_updates_state_type( "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] ) async def test_worker_cancel_run_preserves_other_state_properties( - self, orion_client: PrefectClient, deployment, cancelling_constructor, work_pool + self, + orion_client: PrefectClient, + worker_deployment_wq1, + cancelling_constructor, + work_pool, ): expected_changed_fields = {"type", "name", "timestamp", "id"} flow_run = await orion_client.create_flow_run_from_deployment( - deployment.id, + worker_deployment_wq1.id, state=cancelling_constructor(message="test"), ) @@ -1488,6 +1529,7 @@ async def test_worker_cancel_run_preserves_other_state_properties( async with WorkerTestImpl( work_pool_name=work_pool.name, prefetch_seconds=10 ) as worker: + await worker.sync_with_backend() await worker.check_for_cancelled_flow_runs() post_flow_run = await orion_client.read_flow_run(flow_run.id) @@ -1501,13 +1543,13 @@ async def test_worker_cancel_run_preserves_other_state_properties( async def test_worker_cancel_run_with_infrastructure_not_available_during_kill( self, orion_client: PrefectClient, - deployment, + worker_deployment_wq1, caplog, cancelling_constructor, work_pool, ): flow_run = await orion_client.create_flow_run_from_deployment( - deployment.id, + worker_deployment_wq1.id, state=cancelling_constructor(), ) @@ -1516,15 +1558,19 @@ async def test_worker_cancel_run_with_infrastructure_not_available_during_kill( async with WorkerTestImpl( work_pool_name=work_pool.name, prefetch_seconds=10 ) as worker: + await worker.sync_with_backend() worker.kill_infrastructure = AsyncMock() worker.kill_infrastructure.side_effect = InfrastructureNotAvailable("Test!") await worker.check_for_cancelled_flow_runs() # Perform a second call to check that it is tracked locally that this worker # should not try again await worker.check_for_cancelled_flow_runs() + configuration = await worker._get_configuration(flow_run) # Only awaited once - worker.kill_infrastructure.assert_awaited_once_with("test") + worker.kill_infrastructure.assert_awaited_once_with( + infrastructure_pid="test", configuration=configuration + ) # State name not updated; other workers may attempt the kill post_flow_run = await orion_client.read_flow_run(flow_run.id) @@ -1542,13 +1588,13 @@ async def test_worker_cancel_run_with_infrastructure_not_available_during_kill( async def test_worker_cancel_run_with_infrastructure_not_found_during_kill( self, orion_client: PrefectClient, - deployment, + worker_deployment_wq1, caplog, cancelling_constructor, work_pool, ): flow_run = await orion_client.create_flow_run_from_deployment( - deployment.id, + worker_deployment_wq1.id, state=cancelling_constructor(), ) @@ -1557,14 +1603,18 @@ async def test_worker_cancel_run_with_infrastructure_not_found_during_kill( async with WorkerTestImpl( work_pool_name=work_pool.name, prefetch_seconds=10 ) as worker: + await worker.sync_with_backend() worker.kill_infrastructure = AsyncMock() worker.kill_infrastructure.side_effect = InfrastructureNotFound("Test!") await worker.check_for_cancelled_flow_runs() # Perform a second call to check that another cancellation attempt is not made await worker.check_for_cancelled_flow_runs() + configuration = await worker._get_configuration(flow_run) # Only awaited once - worker.kill_infrastructure.assert_awaited_once_with("test") + worker.kill_infrastructure.assert_awaited_once_with( + infrastructure_pid="test", configuration=configuration + ) # State name updated to prevent further attempts post_flow_run = await orion_client.read_flow_run(flow_run.id) @@ -1582,13 +1632,13 @@ async def test_worker_cancel_run_with_infrastructure_not_found_during_kill( async def test_worker_cancel_run_with_unknown_error_during_kill( self, orion_client: PrefectClient, - deployment, + worker_deployment_wq1, caplog, cancelling_constructor, work_pool, ): flow_run = await orion_client.create_flow_run_from_deployment( - deployment.id, + worker_deployment_wq1.id, state=cancelling_constructor(), ) await orion_client.update_flow_run(flow_run.id, infrastructure_pid="test") @@ -1596,14 +1646,21 @@ async def test_worker_cancel_run_with_unknown_error_during_kill( async with WorkerTestImpl( work_pool_name=work_pool.name, prefetch_seconds=10 ) as worker: + await worker.sync_with_backend() worker.kill_infrastructure = AsyncMock() worker.kill_infrastructure.side_effect = ValueError("Oh no!") await worker.check_for_cancelled_flow_runs() await anyio.sleep(0.5) await worker.check_for_cancelled_flow_runs() + configuration = await worker._get_configuration(flow_run) # Multiple attempts should be made - worker.kill_infrastructure.assert_has_awaits([call("test"), call("test")]) + worker.kill_infrastructure.assert_has_awaits( + [ + call(infrastructure_pid="test", configuration=configuration), + call(infrastructure_pid="test", configuration=configuration), + ] + ) # State name not updated post_flow_run = await orion_client.read_flow_run(flow_run.id) @@ -1622,7 +1679,7 @@ async def test_worker_cancel_run_with_unknown_error_during_kill( async def test_worker_cancel_run_without_infrastructure_support_for_kill( self, orion_client: PrefectClient, - deployment, + worker_deployment_wq1, caplog, cancelling_constructor, work_pool, @@ -1636,7 +1693,7 @@ async def run(self, flow_run, configuration, task_status=None): pass flow_run = await orion_client.create_flow_run_from_deployment( - deployment.id, + worker_deployment_wq1.id, state=cancelling_constructor(), ) await orion_client.update_flow_run(flow_run.id, infrastructure_pid="test") @@ -1644,6 +1701,7 @@ async def run(self, flow_run, configuration, task_status=None): async with WorkerNoKill( work_pool_name=work_pool.name, prefetch_seconds=10 ) as worker: + await worker.sync_with_backend() await worker.check_for_cancelled_flow_runs() # State name not updated; another worker may have a code version that supports @@ -1657,3 +1715,32 @@ async def run(self, flow_run, configuration, task_status=None): in caplog.text ) assert "Cancellation cannot be guaranteed." in caplog.text + + @pytest.mark.parametrize( + "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] + ) + async def test_worker_cancel_run_does_not_raise_for_old_interface( + self, + orion_client: PrefectClient, + worker_deployment_wq1, + cancelling_constructor, + work_pool, + ): + flow_run = await orion_client.create_flow_run_from_deployment( + worker_deployment_wq1.id, + state=cancelling_constructor(), + ) + + await orion_client.update_flow_run(flow_run.id, infrastructure_pid="test") + + async with WorkerWithOldKillInfrastructureInterface( + work_pool_name=work_pool.name, prefetch_seconds=10 + ) as worker: + await worker.sync_with_backend() + worker.kill_infrastructure = AsyncMock() + await worker.check_for_cancelled_flow_runs() + configuration = await worker._get_configuration(flow_run) + + worker.kill_infrastructure.assert_awaited_once_with( + infrastructure_pid="test", configuration=configuration + ) diff --git a/tests/workers/test_process_worker.py b/tests/workers/test_process_worker.py index 90b23cd3013b5..8e5784e0e993b 100644 --- a/tests/workers/test_process_worker.py +++ b/tests/workers/test_process_worker.py @@ -20,7 +20,11 @@ from prefect.server.schemas.core import WorkPool from prefect.server.schemas.states import StateDetails, StateType from prefect.testing.utilities import AsyncMock, MagicMock -from prefect.workers.process import ProcessWorker, ProcessWorkerResult +from prefect.workers.process import ( + ProcessJobConfiguration, + ProcessWorker, + ProcessWorkerResult, +) @flow @@ -463,17 +467,24 @@ async def test_process_kill_mismatching_hostname(monkeypatch, work_pool): async with ProcessWorker(work_pool_name=work_pool.name) as worker: with pytest.raises(InfrastructureNotAvailable): - await worker.kill_infrastructure(infrastructure_pid=infrastructure_pid) + await worker.kill_infrastructure( + infrastructure_pid=infrastructure_pid, + configuration=ProcessJobConfiguration(), + ) os_kill.assert_not_called() async def test_process_kill_no_matching_pid(monkeypatch, work_pool): + patch_client(monkeypatch) infrastructure_pid = f"{socket.gethostname()}:12345" async with ProcessWorker(work_pool_name=work_pool.name) as worker: with pytest.raises(InfrastructureNotFound): - await worker.kill_infrastructure(infrastructure_pid=infrastructure_pid) + await worker.kill_infrastructure( + infrastructure_pid=infrastructure_pid, + configuration=ProcessJobConfiguration(), + ) @pytest.mark.skipif( @@ -481,6 +492,7 @@ async def test_process_kill_no_matching_pid(monkeypatch, work_pool): reason="SIGTERM/SIGKILL are only used in non-Windows environments", ) async def test_process_kill_sends_sigterm_then_sigkill(monkeypatch, work_pool): + patch_client(monkeypatch) os_kill = MagicMock() monkeypatch.setattr("os.kill", os_kill) @@ -489,7 +501,9 @@ async def test_process_kill_sends_sigterm_then_sigkill(monkeypatch, work_pool): async with ProcessWorker(work_pool_name=work_pool.name) as worker: await worker.kill_infrastructure( - infrastructure_pid=infrastructure_pid, grace_seconds=grace_seconds + infrastructure_pid=infrastructure_pid, + grace_seconds=grace_seconds, + configuration=ProcessJobConfiguration(), ) os_kill.assert_has_calls( @@ -506,6 +520,7 @@ async def test_process_kill_sends_sigterm_then_sigkill(monkeypatch, work_pool): reason="SIGTERM/SIGKILL are only used in non-Windows environments", ) async def test_process_kill_early_return(monkeypatch, work_pool): + patch_client(monkeypatch) os_kill = MagicMock(side_effect=[None, ProcessLookupError]) anyio_sleep = AsyncMock() monkeypatch.setattr("os.kill", os_kill) @@ -516,7 +531,9 @@ async def test_process_kill_early_return(monkeypatch, work_pool): async with ProcessWorker(work_pool_name=work_pool.name) as worker: await worker.kill_infrastructure( - infrastructure_pid=infrastructure_pid, grace_seconds=grace_seconds + infrastructure_pid=infrastructure_pid, + grace_seconds=grace_seconds, + configuration=ProcessJobConfiguration(), ) os_kill.assert_has_calls( @@ -534,6 +551,7 @@ async def test_process_kill_early_return(monkeypatch, work_pool): reason="CTRL_BREAK_EVENT is only defined in Windows", ) async def test_process_kill_windows_sends_ctrl_break(monkeypatch, work_pool): + patch_client(monkeypatch) os_kill = MagicMock() monkeypatch.setattr("os.kill", os_kill) @@ -542,7 +560,9 @@ async def test_process_kill_windows_sends_ctrl_break(monkeypatch, work_pool): async with ProcessWorker(work_pool_name=work_pool.name) as worker: await worker.kill_infrastructure( - infrastructure_pid=infrastructure_pid, grace_seconds=grace_seconds + infrastructure_pid=infrastructure_pid, + grace_seconds=grace_seconds, + configuration=ProcessJobConfiguration(), ) os_kill.assert_called_once_with(12345, signal.CTRL_BREAK_EVENT)