diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py index 1a7ba21bfbf48..cfb0015b8ee72 100644 --- a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py +++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py @@ -49,6 +49,7 @@ class WinRMOperator(BaseOperator): If specified, it will execute the command as powershell script. :param output_encoding: the encoding used to decode stout and stderr :param timeout: timeout for executing the command. + :param expected_return_code: expected return code value(s) of command. """ template_fields: Sequence[str] = ("command",) @@ -64,6 +65,7 @@ def __init__( ps_path: str | None = None, output_encoding: str = "utf-8", timeout: int = 10, + expected_return_code: int | list[int] | range = 0, **kwargs, ) -> None: super().__init__(**kwargs) @@ -74,6 +76,7 @@ def __init__( self.ps_path = ps_path self.output_encoding = output_encoding self.timeout = timeout + self.expected_return_code = expected_return_code def execute(self, context: Context) -> list | str: if self.ssh_conn_id and not self.winrm_hook: @@ -96,7 +99,13 @@ def execute(self, context: Context) -> list | str: return_output=self.do_xcom_push, ) - if return_code == 0: + success = False + if isinstance(self.expected_return_code, int): + success = return_code == self.expected_return_code + elif isinstance(self.expected_return_code, list) or isinstance(self.expected_return_code, range): + success = return_code in self.expected_return_code + + if success: # returning output if do_xcom_push is set # TODO: Remove this after minimum Airflow version is 3.0 enable_pickling = conf.getboolean("core", "enable_xcom_pickling", fallback=False) diff --git a/providers/microsoft/winrm/tests/provider_tests/microsoft/winrm/operators/test_winrm.py b/providers/microsoft/winrm/tests/provider_tests/microsoft/winrm/operators/test_winrm.py index b05ae8f060255..8997395b0bf5c 100644 --- a/providers/microsoft/winrm/tests/provider_tests/microsoft/winrm/operators/test_winrm.py +++ b/providers/microsoft/winrm/tests/provider_tests/microsoft/winrm/operators/test_winrm.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +from base64 import b64encode from unittest import mock import pytest @@ -38,3 +39,63 @@ def test_no_command(self, mock_hook): exception_msg = "No command specified so nothing to execute here." with pytest.raises(AirflowException, match=exception_msg): op.execute(None) + + @mock.patch("airflow.providers.microsoft.winrm.operators.winrm.WinRMHook") + def test_default_returning_0_command(self, mock_hook): + stdout = [b"O", b"K"] + command = "not_empty" + mock_hook.run.return_value = (0, stdout, []) + op = WinRMOperator(task_id="test_task_id", winrm_hook=mock_hook, command=command) + execute_result = op.execute(None) + assert execute_result == b64encode(b"".join(stdout)).decode("utf-8") + mock_hook.run.assert_called_once_with( + command=command, + ps_path=None, + output_encoding="utf-8", + return_output=True, + ) + + @mock.patch("airflow.providers.microsoft.winrm.operators.winrm.WinRMHook") + def test_default_returning_1_command(self, mock_hook): + stderr = [b"K", b"O"] + command = "not_empty" + mock_hook.run.return_value = (1, [], stderr) + op = WinRMOperator(task_id="test_task_id", winrm_hook=mock_hook, command=command) + exception_msg = f"Error running cmd: {command}, return code: 1, error: KO" + with pytest.raises(AirflowException, match=exception_msg): + op.execute(None) + + @mock.patch("airflow.providers.microsoft.winrm.operators.winrm.WinRMHook") + @pytest.mark.parametrize("expected_return_code", [1, [1, 2], range(1, 3)]) + @pytest.mark.parametrize("real_return_code", [0, 1, 2]) + def test_expected_return_code_command(self, mock_hook, expected_return_code, real_return_code): + stdout = [b"O", b"K"] + stderr = [b"K", b"O"] + command = "not_empty" + mock_hook.run.return_value = (real_return_code, stdout, stderr) + op = WinRMOperator( + task_id="test_task_id", + winrm_hook=mock_hook, + command=command, + expected_return_code=expected_return_code, + ) + + should_task_succeed = False + if isinstance(expected_return_code, int): + should_task_succeed = real_return_code == expected_return_code + elif isinstance(expected_return_code, list) or isinstance(expected_return_code, range): + should_task_succeed = real_return_code in expected_return_code + + if should_task_succeed: + execute_result = op.execute(None) + assert execute_result == b64encode(b"".join(stdout)).decode("utf-8") + mock_hook.run.assert_called_once_with( + command=command, + ps_path=None, + output_encoding="utf-8", + return_output=True, + ) + else: + exception_msg = f"Error running cmd: {command}, return code: {real_return_code}, error: KO" + with pytest.raises(AirflowException, match=exception_msg): + op.execute(None)