Skip to content

Commit

Permalink
[AIRFLOW-2948] Arg check & better doc - SSHOperator & SFTPOperator (#…
Browse files Browse the repository at this point in the history
…3793)

There may be different combinations of arguments, and
some processing is done 'silently', so users
may not be fully aware of it.

For example
- User only needs to provide either `ssh_hook`
  or `ssh_conn_id`, while this is not clear in doc
- if both provided, `ssh_conn_id` will be ignored.
- if `remote_host` is provided, it will replace
  the `remote_host` which was defined in `ssh_hook`
  or predefined in the connection of `ssh_conn_id`

These should be documented clearly to ensure it's
transparent to the users. log.info() should also be
used to remind users and provide clear logs.

In addition, add instance check for ssh_hook to ensure
it is of the correct type (SSHHook).

Tests are updated for this PR.
  • Loading branch information
XD-DENG authored and kaxil committed Jan 9, 2019
1 parent 304f09a commit 319cf17
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 15 deletions.
22 changes: 17 additions & 5 deletions airflow/contrib/operators/sftp_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,15 @@ class SFTPOperator(BaseOperator):
This operator uses ssh_hook to open sftp transport channel that serves as basis
for file transfer.
:param ssh_hook: predefined ssh_hook to use for remote execution
:param ssh_hook: predefined ssh_hook to use for remote execution.
Either `ssh_hook` or `ssh_conn_id` needs to be provided.
:type ssh_hook: :class:`SSHHook`
:param ssh_conn_id: connection id from airflow Connections
:param ssh_conn_id: connection id from airflow Connections.
`ssh_conn_id` will be ignored if `ssh_hook` is provided.
:type ssh_conn_id: str
:param remote_host: remote host to connect (templated)
Nullable. If provided, it will replace the `remote_host` which was
defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
:type remote_host: str
:param local_filepath: local file path to get or put. (templated)
:type local_filepath: str
Expand Down Expand Up @@ -77,13 +81,21 @@ def __init__(self,
def execute(self, context):
file_msg = None
try:
if self.ssh_conn_id and not self.ssh_hook:
self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)
if self.ssh_conn_id:
if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
else:
self.log.info("ssh_hook is not provided or invalid. " +
"Trying ssh_conn_id to create SSHHook.")
self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)

if not self.ssh_hook:
raise AirflowException("can not operate without ssh_hook or ssh_conn_id")
raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")

if self.remote_host is not None:
self.log.info("remote_host is provided explicitly. " +
"It will replace the remote_host which was defined " +
"in ssh_hook or predefined in connection of ssh_conn_id.")
self.ssh_hook.remote_host = self.remote_host

with self.ssh_hook.get_conn() as ssh_client:
Expand Down
22 changes: 17 additions & 5 deletions airflow/contrib/operators/ssh_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@ class SSHOperator(BaseOperator):
"""
SSHOperator to execute commands on given remote host using the ssh_hook.
:param ssh_hook: predefined ssh_hook to use for remote execution
:param ssh_hook: predefined ssh_hook to use for remote execution.
Either `ssh_hook` or `ssh_conn_id` needs to be provided.
:type ssh_hook: :class:`SSHHook`
:param ssh_conn_id: connection id from airflow Connections
:param ssh_conn_id: connection id from airflow Connections.
`ssh_conn_id` will be ignored if `ssh_hook` is provided.
:type ssh_conn_id: str
:param remote_host: remote host to connect (templated)
Nullable. If provided, it will replace the `remote_host` which was
defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
:type remote_host: str
:param command: command to execute on remote host. (templated)
:type command: str
Expand Down Expand Up @@ -68,14 +72,22 @@ def __init__(self,

def execute(self, context):
try:
if self.ssh_conn_id and not self.ssh_hook:
self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
timeout=self.timeout)
if self.ssh_conn_id:
if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
else:
self.log.info("ssh_hook is not provided or invalid. " +
"Trying ssh_conn_id to create SSHHook.")
self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
timeout=self.timeout)

if not self.ssh_hook:
raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")

if self.remote_host is not None:
self.log.info("remote_host is provided explicitly. " +
"It will replace the remote_host which was defined " +
"in ssh_hook or predefined in connection of ssh_conn_id.")
self.ssh_hook.remote_host = self.remote_host

if not self.command:
Expand Down
76 changes: 71 additions & 5 deletions tests/contrib/operators/test_sftp_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import unittest
from base64 import b64encode
import six

from airflow import configuration
from airflow import models
Expand Down Expand Up @@ -219,18 +220,83 @@ def test_json_file_transfer_get(self):
self.assertEqual(content_received.strip(),
test_remote_file_content.encode('utf-8').decode('utf-8'))

def test_arg_checking(self):
from airflow.exceptions import AirflowException
conn_id = "conn_id_for_testing"
os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost"

# Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
if six.PY2:
self.assertRaisesRegex = self.assertRaisesRegexp
with self.assertRaisesRegex(AirflowException,
"Cannot operate without ssh_hook or ssh_conn_id."):
task_0 = SFTPOperator(
task_id="test_sftp",
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.PUT,
dag=self.dag
)
task_0.execute(None)

# if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook
task_1 = SFTPOperator(
task_id="test_sftp",
ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook
ssh_conn_id=conn_id,
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.PUT,
dag=self.dag
)
try:
task_1.execute(None)
except Exception:
pass
self.assertEqual(task_1.ssh_hook.ssh_conn_id, conn_id)

task_2 = SFTPOperator(
task_id="test_sftp",
ssh_conn_id=conn_id, # no ssh_hook provided
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.PUT,
dag=self.dag
)
try:
task_2.execute(None)
except Exception:
pass
self.assertEqual(task_2.ssh_hook.ssh_conn_id, conn_id)

# if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
task_3 = SFTPOperator(
task_id="test_sftp",
ssh_hook=self.hook,
ssh_conn_id=conn_id,
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.PUT,
dag=self.dag
)
try:
task_3.execute(None)
except Exception:
pass
self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id)

def delete_local_resource(self):
# Remove the file at self.test_local_filepath if it exists; the existence
# check means a missing file is simply skipped.
if os.path.exists(self.test_local_filepath):
os.remove(self.test_local_filepath)

def delete_remote_resource(self):
# check the remote file content
remove_file_task = SSHOperator(
task_id="test_check_file",
ssh_hook=self.hook,
command="rm {0}".format(self.test_remote_filepath),
do_xcom_push=True,
dag=self.dag
task_id="test_check_file",
ssh_hook=self.hook,
command="rm {0}".format(self.test_remote_filepath),
do_xcom_push=True,
dag=self.dag
)
self.assertIsNotNone(remove_file_task)
ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow())
Expand Down
60 changes: 60 additions & 0 deletions tests/contrib/operators/test_ssh_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import unittest
from base64 import b64encode
import six

from airflow import configuration
from airflow import models
Expand Down Expand Up @@ -148,6 +149,65 @@ def test_no_output_command(self):
self.assertIsNotNone(ti.duration)
self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), b'')

def test_arg_checking(self):
# Checks which SSH hook SSHOperator ends up with for different combinations
# of the `ssh_hook` and `ssh_conn_id` arguments.
# NOTE(review): leading indentation was lost in this rendering of the diff;
# statement nesting has to be inferred from context.
import os
from airflow.exceptions import AirflowException
conn_id = "conn_id_for_testing"
# Passed as `timeout` to every SSHOperator built below.
TIMEOUT = 5
# Sets AIRFLOW_CONN_CONN_ID_FOR_TESTING; presumably Airflow resolves `conn_id`
# from this variable - confirm against the connection docs.
# NOTE(review): the variable is not removed again within this method.
os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost"

# Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
# On Python 2 the unittest method is named assertRaisesRegexp, so alias it to
# the Python 3 name used below.
if six.PY2:
self.assertRaisesRegex = self.assertRaisesRegexp
with self.assertRaisesRegex(AirflowException,
"Cannot operate without ssh_hook or ssh_conn_id."):
task_0 = SSHOperator(task_id="test", command="echo -n airflow",
timeout=TIMEOUT, dag=self.dag)
task_0.execute(None)

# if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook
task_1 = SSHOperator(
task_id="test_1",
ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook
ssh_conn_id=conn_id,
command="echo -n airflow",
timeout=TIMEOUT,
dag=self.dag
)
# Any exception from execute() is deliberately swallowed: only the hook left
# on the task afterwards is asserted on. Presumably the remote command can
# fail in the test environment - TODO confirm.
try:
task_1.execute(None)
except Exception:
pass
self.assertEqual(task_1.ssh_hook.ssh_conn_id, conn_id)

task_2 = SSHOperator(
task_id="test_2",
ssh_conn_id=conn_id, # no ssh_hook provided
command="echo -n airflow",
timeout=TIMEOUT,
dag=self.dag
)
# Same best-effort execute as above; only the resulting hook is checked.
try:
task_2.execute(None)
except Exception:
pass
self.assertEqual(task_2.ssh_hook.ssh_conn_id, conn_id)

# if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
task_3 = SSHOperator(
task_id="test_3",
ssh_hook=self.hook,
ssh_conn_id=conn_id,
command="echo -n airflow",
timeout=TIMEOUT,
dag=self.dag
)
# Same best-effort execute as above; the hook's conn id must still be the one
# from self.hook rather than `conn_id`.
try:
task_3.execute(None)
except Exception:
pass
self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id)


if __name__ == '__main__':
unittest.main()

0 comments on commit 319cf17

Please sign in to comment.