Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add expected_return_code parameter to WinRMOperator #46534

Merged
merged 1 commit into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class WinRMOperator(BaseOperator):
If specified, it will execute the command as powershell script.
:param output_encoding: the encoding used to decode stout and stderr
:param timeout: timeout for executing the command.
:param expected_return_code: expected return code value(s) of command.
"""

template_fields: Sequence[str] = ("command",)
Expand All @@ -64,6 +65,7 @@ def __init__(
ps_path: str | None = None,
output_encoding: str = "utf-8",
timeout: int = 10,
expected_return_code: int | list[int] | range = 0,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -74,6 +76,7 @@ def __init__(
self.ps_path = ps_path
self.output_encoding = output_encoding
self.timeout = timeout
self.expected_return_code = expected_return_code

def execute(self, context: Context) -> list | str:
if self.ssh_conn_id and not self.winrm_hook:
Expand All @@ -96,7 +99,13 @@ def execute(self, context: Context) -> list | str:
return_output=self.do_xcom_push,
)

if return_code == 0:
success = False
if isinstance(self.expected_return_code, int):
success = return_code == self.expected_return_code
elif isinstance(self.expected_return_code, list) or isinstance(self.expected_return_code, range):
success = return_code in self.expected_return_code

if success:
# returning output if do_xcom_push is set
# TODO: Remove this after minimum Airflow version is 3.0
enable_pickling = conf.getboolean("core", "enable_xcom_pickling", fallback=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

from base64 import b64encode
from unittest import mock

import pytest
Expand All @@ -38,3 +39,63 @@ def test_no_command(self, mock_hook):
exception_msg = "No command specified so nothing to execute here."
with pytest.raises(AirflowException, match=exception_msg):
op.execute(None)

@mock.patch("airflow.providers.microsoft.winrm.operators.winrm.WinRMHook")
def test_default_returning_0_command(self, mock_hook):
stdout = [b"O", b"K"]
command = "not_empty"
mock_hook.run.return_value = (0, stdout, [])
op = WinRMOperator(task_id="test_task_id", winrm_hook=mock_hook, command=command)
execute_result = op.execute(None)
assert execute_result == b64encode(b"".join(stdout)).decode("utf-8")
mock_hook.run.assert_called_once_with(
command=command,
ps_path=None,
output_encoding="utf-8",
return_output=True,
)

@mock.patch("airflow.providers.microsoft.winrm.operators.winrm.WinRMHook")
def test_default_returning_1_command(self, mock_hook):
stderr = [b"K", b"O"]
command = "not_empty"
mock_hook.run.return_value = (1, [], stderr)
op = WinRMOperator(task_id="test_task_id", winrm_hook=mock_hook, command=command)
exception_msg = f"Error running cmd: {command}, return code: 1, error: KO"
with pytest.raises(AirflowException, match=exception_msg):
op.execute(None)

@mock.patch("airflow.providers.microsoft.winrm.operators.winrm.WinRMHook")
@pytest.mark.parametrize("expected_return_code", [1, [1, 2], range(1, 3)])
@pytest.mark.parametrize("real_return_code", [0, 1, 2])
def test_expected_return_code_command(self, mock_hook, expected_return_code, real_return_code):
stdout = [b"O", b"K"]
stderr = [b"K", b"O"]
command = "not_empty"
mock_hook.run.return_value = (real_return_code, stdout, stderr)
op = WinRMOperator(
task_id="test_task_id",
winrm_hook=mock_hook,
command=command,
expected_return_code=expected_return_code,
)

should_task_succeed = False
if isinstance(expected_return_code, int):
should_task_succeed = real_return_code == expected_return_code
elif isinstance(expected_return_code, list) or isinstance(expected_return_code, range):
should_task_succeed = real_return_code in expected_return_code

if should_task_succeed:
execute_result = op.execute(None)
assert execute_result == b64encode(b"".join(stdout)).decode("utf-8")
mock_hook.run.assert_called_once_with(
command=command,
ps_path=None,
output_encoding="utf-8",
return_output=True,
)
else:
exception_msg = f"Error running cmd: {command}, return code: {real_return_code}, error: KO"
with pytest.raises(AirflowException, match=exception_msg):
op.execute(None)