Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move responsibility to run a command from WinRMOperator to WinRMHook #43646

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
42d1378
refactor: Moved responsibility to run a command away from WinRmOperat…
davidblain-infrabel Nov 4, 2024
701b9d4
Merge branch 'main' into refactor/closable-winrm-hook-with-run-method
dabla Nov 4, 2024
e08922b
Merge branch 'main' into refactor/closable-winrm-hook-with-run-method
dabla Nov 4, 2024
c94e621
refactor: Reformatted exception message in WinRMOperator
davidblain-infrabel Nov 5, 2024
ddd7bdc
refactor: command parameter of run method in WinRMHook must be specified
davidblain-infrabel Nov 5, 2024
132d39a
refactor: Changed return type of run method in WinRMHook
davidblain-infrabel Nov 5, 2024
006e9c1
Merge branch 'main' into refactor/closable-winrm-hook-with-run-method
dabla Nov 5, 2024
48e62c4
Merge branch 'main' into refactor/closable-winrm-hook-with-run-method
dabla Nov 5, 2024
ed48f08
refactor: WinRMHook cannot be closable as it doesn't have the winrm_c…
davidblain-infrabel Nov 5, 2024
c47a7a5
refactor: Reorganized imports in WinRMHook
davidblain-infrabel Nov 5, 2024
6faf18e
Merge branch 'main' into refactor/closable-winrm-hook-with-run-method
dabla Nov 5, 2024
150bb3e
refactor: Added unit tests for new run method in WinRMHook
davidblain-infrabel Nov 5, 2024
aae74bf
Merge branch 'main' into refactor/closable-winrm-hook-with-run-method
dabla Nov 5, 2024
f8868ec
Merge branch 'main' into refactor/closable-winrm-hook-with-run-method
dabla Nov 5, 2024
5f0dcfe
refactor: Reorganized imports in TestWinRMHook
davidblain-infrabel Nov 5, 2024
b7019a3
Merge branch 'main' into refactor/closable-winrm-hook-with-run-method
dabla Nov 5, 2024
11cef77
Merge branch 'main' into refactor/closable-winrm-hook-with-run-method
dabla Nov 5, 2024
cc407f5
Merge branch 'main' into refactor/closable-winrm-hook-with-run-method
dabla Nov 6, 2024
d6f3c0a
Merge branch 'main' into refactor/closable-winrm-hook-with-run-method
dabla Nov 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@

from __future__ import annotations

from base64 import b64encode
from contextlib import suppress

from winrm.exceptions import WinRMOperationTimeoutError
from winrm.protocol import Protocol

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -218,3 +222,71 @@ def get_conn(self):
raise AirflowException(error_msg)

return self.client

def run(
self,
command: str,
ps_path: str | None = None,
output_encoding: str = "utf-8",
return_output: bool = True,
) -> tuple[int, list[bytes], list[bytes]]:
"""
Run a command.

:param command: command to execute on remote host.
:param ps_path: path to powershell, `powershell` for v5.1- and `pwsh` for v6+.
If specified, it will execute the command as powershell script.
:param output_encoding: the encoding used to decode stout and stderr.
:param return_output: Whether to accumulate and return the stdout or not.
:return: returns a tuple containing return_code, stdout and stderr in order.
"""
winrm_client = self.get_conn()

try:
if ps_path is not None:
self.log.info("Running command as powershell script: '%s'...", command)
encoded_ps = b64encode(command.encode("utf_16_le")).decode("ascii")
command_id = self.winrm_protocol.run_command( # type: ignore[attr-defined]
winrm_client, f"{ps_path} -encodedcommand {encoded_ps}"
)
else:
self.log.info("Running command: '%s'...", command)
command_id = self.winrm_protocol.run_command( # type: ignore[attr-defined]
winrm_client, command
)

# See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
stdout_buffer = []
stderr_buffer = []
command_done = False
while not command_done:
# this is an expected error when waiting for a long-running process, just silently retry
with suppress(WinRMOperationTimeoutError):
(
stdout,
stderr,
return_code,
command_done,
) = self.winrm_protocol._raw_get_command_output( # type: ignore[attr-defined]
winrm_client, command_id
)

# Only buffer stdout if we need to so that we minimize memory usage.
if return_output:
stdout_buffer.append(stdout)
stderr_buffer.append(stderr)

for line in stdout.decode(output_encoding).splitlines():
self.log.info(line)
for line in stderr.decode(output_encoding).splitlines():
self.log.warning(line)

self.winrm_protocol.cleanup_command( # type: ignore[attr-defined]
winrm_client, command_id
)

return return_code, stdout_buffer, stderr_buffer
except Exception as e:
raise AirflowException(f"WinRM operator error: {e}")
finally:
self.winrm_protocol.close_shell(winrm_client) # type: ignore[attr-defined]
73 changes: 12 additions & 61 deletions providers/src/airflow/providers/microsoft/winrm/operators/winrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
from base64 import b64encode
from typing import TYPE_CHECKING, Sequence

from winrm.exceptions import WinRMOperationTimeoutError

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
Expand Down Expand Up @@ -90,68 +88,21 @@ def execute(self, context: Context) -> list | str:
if not self.command:
raise AirflowException("No command specified so nothing to execute here.")

winrm_client = self.winrm_hook.get_conn()

try:
if self.ps_path is not None:
self.log.info("Running command as powershell script: '%s'...", self.command)
encoded_ps = b64encode(self.command.encode("utf_16_le")).decode("ascii")
command_id = self.winrm_hook.winrm_protocol.run_command( # type: ignore[attr-defined]
winrm_client, f"{self.ps_path} -encodedcommand {encoded_ps}"
)
else:
self.log.info("Running command: '%s'...", self.command)
command_id = self.winrm_hook.winrm_protocol.run_command( # type: ignore[attr-defined]
winrm_client, self.command
)

# See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
stdout_buffer = []
stderr_buffer = []
command_done = False
while not command_done:
try:
(
stdout,
stderr,
return_code,
command_done,
) = self.winrm_hook.winrm_protocol._raw_get_command_output( # type: ignore[attr-defined]
winrm_client, command_id
)

# Only buffer stdout if we need to so that we minimize memory usage.
if self.do_xcom_push:
stdout_buffer.append(stdout)
stderr_buffer.append(stderr)

for line in stdout.decode(self.output_encoding).splitlines():
self.log.info(line)
for line in stderr.decode(self.output_encoding).splitlines():
self.log.warning(line)
except WinRMOperationTimeoutError:
# this is an expected error when waiting for a
# long-running process, just silently retry
pass

self.winrm_hook.winrm_protocol.cleanup_command( # type: ignore[attr-defined]
winrm_client, command_id
)
self.winrm_hook.winrm_protocol.close_shell(winrm_client) # type: ignore[attr-defined]

except Exception as e:
raise AirflowException(f"WinRM operator error: {e}")
return_code, stdout_buffer, stderr_buffer = self.winrm_hook.run(
command=self.command,
ps_path=self.ps_path,
output_encoding=self.output_encoding,
return_output=self.do_xcom_push,
)

if return_code == 0:
# returning output if do_xcom_push is set
enable_pickling = conf.getboolean("core", "enable_xcom_pickling")

if enable_pickling:
return stdout_buffer
else:
return b64encode(b"".join(stdout_buffer)).decode(self.output_encoding)
else:
stderr_output = b"".join(stderr_buffer).decode(self.output_encoding)
error_msg = (
f"Error running cmd: {self.command}, return code: {return_code}, error: {stderr_output}"
)
raise AirflowException(error_msg)
return b64encode(b"".join(stdout_buffer)).decode(self.output_encoding)

stderr_output = b"".join(stderr_buffer).decode(self.output_encoding)
error_msg = f"Error running cmd: {self.command}, return code: {return_code}, error: {stderr_output}"
raise AirflowException(error_msg)
84 changes: 83 additions & 1 deletion providers/tests/microsoft/winrm/hooks/test_winrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest

Expand Down Expand Up @@ -119,3 +119,85 @@ def test_get_conn_no_endpoint(self, mock_protocol):
winrm_hook.get_conn()

assert f"http://{winrm_hook.remote_host}:{winrm_hook.remote_port}/wsman" == winrm_hook.endpoint

@patch("airflow.providers.microsoft.winrm.hooks.winrm.Protocol", autospec=True)
@patch(
"airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection",
return_value=Connection(
login="username",
password="password",
host="remote_host",
extra="""{
"endpoint": "endpoint",
"remote_port": 123,
"transport": "plaintext",
"service": "service",
"keytab": "keytab",
"ca_trust_path": "ca_trust_path",
"cert_pem": "cert_pem",
"cert_key_pem": "cert_key_pem",
"server_cert_validation": "validate",
"kerberos_delegation": "true",
"read_timeout_sec": 124,
"operation_timeout_sec": 123,
"kerberos_hostname_override": "kerberos_hostname_override",
"message_encryption": "auto",
"credssp_disable_tlsv1_2": "true",
"send_cbt": "false"
}""",
),
)
def test_run_with_stdout(self, mock_get_connection, mock_protocol):
winrm_hook = WinRMHook(ssh_conn_id="conn_id")

mock_protocol.return_value.run_command = MagicMock(return_value="command_id")
mock_protocol.return_value._raw_get_command_output = MagicMock(
return_value=(b"stdout", b"stderr", 0, True)
)

return_code, stdout_buffer, stderr_buffer = winrm_hook.run("dir")

assert return_code == 0
assert stdout_buffer == [b"stdout"]
assert stderr_buffer == [b"stderr"]

@patch("airflow.providers.microsoft.winrm.hooks.winrm.Protocol", autospec=True)
@patch(
"airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection",
return_value=Connection(
login="username",
password="password",
host="remote_host",
extra="""{
"endpoint": "endpoint",
"remote_port": 123,
"transport": "plaintext",
"service": "service",
"keytab": "keytab",
"ca_trust_path": "ca_trust_path",
"cert_pem": "cert_pem",
"cert_key_pem": "cert_key_pem",
"server_cert_validation": "validate",
"kerberos_delegation": "true",
"read_timeout_sec": 124,
"operation_timeout_sec": 123,
"kerberos_hostname_override": "kerberos_hostname_override",
"message_encryption": "auto",
"credssp_disable_tlsv1_2": "true",
"send_cbt": "false"
}""",
),
)
def test_run_without_stdout(self, mock_get_connection, mock_protocol):
winrm_hook = WinRMHook(ssh_conn_id="conn_id")

mock_protocol.return_value.run_command = MagicMock(return_value="command_id")
mock_protocol.return_value._raw_get_command_output = MagicMock(
return_value=(b"stdout", b"stderr", 0, True)
)

return_code, stdout_buffer, stderr_buffer = winrm_hook.run("dir", return_output=False)

assert return_code == 0
assert not stdout_buffer
assert stderr_buffer == [b"stderr"]