Skip to content

Commit

Permalink
add expected_return_code parameter to WinRMOperator (apache#46534)
Browse files Browse the repository at this point in the history
  • Loading branch information
darkag authored and ambika-garg committed Feb 13, 2025
1 parent c420710 commit 451a37d
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class WinRMOperator(BaseOperator):
If specified, it will execute the command as powershell script.
:param output_encoding: the encoding used to decode stout and stderr
:param timeout: timeout for executing the command.
:param expected_return_code: expected return code value(s) of command.
"""

template_fields: Sequence[str] = ("command",)
Expand All @@ -64,6 +65,7 @@ def __init__(
ps_path: str | None = None,
output_encoding: str = "utf-8",
timeout: int = 10,
expected_return_code: int | list[int] | range = 0,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -74,6 +76,7 @@ def __init__(
self.ps_path = ps_path
self.output_encoding = output_encoding
self.timeout = timeout
self.expected_return_code = expected_return_code

def execute(self, context: Context) -> list | str:
if self.ssh_conn_id and not self.winrm_hook:
Expand All @@ -96,7 +99,13 @@ def execute(self, context: Context) -> list | str:
return_output=self.do_xcom_push,
)

if return_code == 0:
success = False
if isinstance(self.expected_return_code, int):
success = return_code == self.expected_return_code
elif isinstance(self.expected_return_code, list) or isinstance(self.expected_return_code, range):
success = return_code in self.expected_return_code

if success:
# returning output if do_xcom_push is set
# TODO: Remove this after minimum Airflow version is 3.0
enable_pickling = conf.getboolean("core", "enable_xcom_pickling", fallback=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

from base64 import b64encode
from unittest import mock

import pytest
Expand All @@ -38,3 +39,63 @@ def test_no_command(self, mock_hook):
exception_msg = "No command specified so nothing to execute here."
with pytest.raises(AirflowException, match=exception_msg):
op.execute(None)

@mock.patch("airflow.providers.microsoft.winrm.operators.winrm.WinRMHook")
def test_default_returning_0_command(self, mock_hook):
stdout = [b"O", b"K"]
command = "not_empty"
mock_hook.run.return_value = (0, stdout, [])
op = WinRMOperator(task_id="test_task_id", winrm_hook=mock_hook, command=command)
execute_result = op.execute(None)
assert execute_result == b64encode(b"".join(stdout)).decode("utf-8")
mock_hook.run.assert_called_once_with(
command=command,
ps_path=None,
output_encoding="utf-8",
return_output=True,
)

@mock.patch("airflow.providers.microsoft.winrm.operators.winrm.WinRMHook")
def test_default_returning_1_command(self, mock_hook):
stderr = [b"K", b"O"]
command = "not_empty"
mock_hook.run.return_value = (1, [], stderr)
op = WinRMOperator(task_id="test_task_id", winrm_hook=mock_hook, command=command)
exception_msg = f"Error running cmd: {command}, return code: 1, error: KO"
with pytest.raises(AirflowException, match=exception_msg):
op.execute(None)

@mock.patch("airflow.providers.microsoft.winrm.operators.winrm.WinRMHook")
@pytest.mark.parametrize("expected_return_code", [1, [1, 2], range(1, 3)])
@pytest.mark.parametrize("real_return_code", [0, 1, 2])
def test_expected_return_code_command(self, mock_hook, expected_return_code, real_return_code):
stdout = [b"O", b"K"]
stderr = [b"K", b"O"]
command = "not_empty"
mock_hook.run.return_value = (real_return_code, stdout, stderr)
op = WinRMOperator(
task_id="test_task_id",
winrm_hook=mock_hook,
command=command,
expected_return_code=expected_return_code,
)

should_task_succeed = False
if isinstance(expected_return_code, int):
should_task_succeed = real_return_code == expected_return_code
elif isinstance(expected_return_code, list) or isinstance(expected_return_code, range):
should_task_succeed = real_return_code in expected_return_code

if should_task_succeed:
execute_result = op.execute(None)
assert execute_result == b64encode(b"".join(stdout)).decode("utf-8")
mock_hook.run.assert_called_once_with(
command=command,
ps_path=None,
output_encoding="utf-8",
return_output=True,
)
else:
exception_msg = f"Error running cmd: {command}, return code: {real_return_code}, error: KO"
with pytest.raises(AirflowException, match=exception_msg):
op.execute(None)

0 comments on commit 451a37d

Please sign in to comment.