From f4d446a5a1e35dab507ede22420a6dc3660aa734 Mon Sep 17 00:00:00 2001 From: Niklas Rousset <75939868+niklasr22@users.noreply.github.com> Date: Mon, 10 Feb 2025 11:52:05 +0100 Subject: [PATCH] Update DockerSwarmOperator auto_remove to align with DockerOperator (#45745) * Update DockerSwarmOperator auto_remove to align with DockerOperator * add docker swarm auto remove test --- .../docker/operators/docker_swarm.py | 16 ++++---- .../docker/operators/test_docker_swarm.py | 38 ++++++++++++++++++- 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/providers/docker/src/airflow/providers/docker/operators/docker_swarm.py b/providers/docker/src/airflow/providers/docker/operators/docker_swarm.py index e899b20f648d0..200c03b127694 100644 --- a/providers/docker/src/airflow/providers/docker/operators/docker_swarm.py +++ b/providers/docker/src/airflow/providers/docker/operators/docker_swarm.py @@ -59,9 +59,11 @@ class DockerSwarmOperator(DockerOperator): If image tag is omitted, "latest" will be used. :param api_version: Remote API version. Set to ``auto`` to automatically detect the server's version. - :param auto_remove: Auto-removal of the container on daemon side when the - container's process exits. - The default is False. + :param auto_remove: Enable removal of the service when the service has terminated. Possible values: + + - ``never``: (default) do not remove service + - ``success``: remove on success + - ``force``: always remove service :param command: Command to be run in the container. (templated) :param args: Arguments to the command. :param docker_url: URL of the host running the docker daemon. @@ -214,18 +216,16 @@ def _run_service(self) -> None: container_id = task["Status"]["ContainerStatus"]["ContainerID"] container = self.cli.inspect_container(container_id) self.containers.append(container) - else: - raise AirflowException(f"Service did not complete: {self.service!r}") if self.retrieve_output: return self._attempt_to_retrieve_results() - self.log.info("auto_removeauto_removeauto_removeauto_removeauto_remove : %s", str(self.auto_remove)) + self.log.info("auto_remove: %s", str(self.auto_remove)) if self.service and self._service_status() != "complete": - if self.auto_remove == "success": + if self.auto_remove == "force": self.cli.remove_service(self.service["ID"]) raise AirflowException(f"Service did not complete: {self.service!r}") - elif self.auto_remove == "success": + elif self.auto_remove in ["success", "force"]: if not self.service: raise RuntimeError("The 'service' should be initialized before!") self.cli.remove_service(self.service["ID"]) diff --git a/providers/docker/tests/provider_tests/docker/operators/test_docker_swarm.py b/providers/docker/tests/provider_tests/docker/operators/test_docker_swarm.py index 67976ccf4ac70..0887785739b1e 100644 --- a/providers/docker/tests/provider_tests/docker/operators/test_docker_swarm.py +++ b/providers/docker/tests/provider_tests/docker/operators/test_docker_swarm.py @@ -130,7 +130,8 @@ def _client_service_logs_effect(): client_mock.remove_service.assert_called_once_with("some_id") @mock.patch("airflow.providers.docker.operators.docker_swarm.types") - def test_auto_remove(self, types_mock, docker_api_client_patcher): + @pytest.mark.parametrize("auto_remove", ["success", "force"]) + def test_auto_remove(self, types_mock, docker_api_client_patcher, auto_remove): mock_obj = mock.Mock() client_mock = mock.Mock(spec=APIClient) @@ -148,12 +149,45 @@ def test_auto_remove(self, types_mock, docker_api_client_patcher): docker_api_client_patcher.return_value = client_mock operator = DockerSwarmOperator( - image="", auto_remove="success", task_id="unittest", enable_logging=False + image="", auto_remove=auto_remove, task_id="unittest", enable_logging=False ) operator.execute(None) client_mock.remove_service.assert_called_once_with("some_id") + @mock.patch("airflow.providers.docker.operators.docker_swarm.types") + @pytest.mark.parametrize( + "auto_remove,expected_remove_call", [("success", False), ("force", True), ("never", False)] + ) + def test_auto_remove_failed( + self, types_mock, docker_api_client_patcher, auto_remove, expected_remove_call + ): + mock_obj = mock.Mock() + + client_mock = mock.Mock(spec=APIClient) + client_mock.create_service.return_value = {"ID": "some_id"} + client_mock.images.return_value = [] + client_mock.pull.return_value = [b'{"status":"pull log"}'] + client_mock.tasks.return_value = [ + {"Status": {"State": "failed", "ContainerStatus": {"ContainerID": "some_id"}}} + ] + types_mock.TaskTemplate.return_value = mock_obj + types_mock.ContainerSpec.return_value = mock_obj + types_mock.RestartPolicy.return_value = mock_obj + types_mock.Resources.return_value = mock_obj + + docker_api_client_patcher.return_value = client_mock + + operator = DockerSwarmOperator( + image="", auto_remove=auto_remove, task_id="unittest", enable_logging=False + ) + try: + operator.execute(None) + except AirflowException: + pass + + assert (client_mock.remove_service.call_count > 0) == expected_remove_call + @mock.patch("airflow.providers.docker.operators.docker_swarm.types") def test_no_auto_remove(self, types_mock, docker_api_client_patcher): mock_obj = mock.Mock()