From 0b0f4ac3caca72e67273f9e80221677d78ad5c0e Mon Sep 17 00:00:00 2001 From: Xiaodong Date: Thu, 30 Aug 2018 03:20:11 +0800 Subject: [PATCH] [AIRFLOW-2948] Arg check & better doc - SSHOperator & SFTPOperator (#3793) There may be different combinations of arguments, and some processings are being done 'silently', while users may not be fully aware of them. For example - User only needs to provide either `ssh_hook` or `ssh_conn_id`, while this is not clear in doc - if both provided, `ssh_conn_id` will be ignored. - if `remote_host` is provided, it will replace the `remote_host` which was defined in `ssh_hook` or predefined in the connection of `ssh_conn_id` These should be documented clearly to ensure it's transparent to the users. log.info() should also be used to remind users and provide clear logs. In addition, add instance check for ssh_hook to ensure it is of the correct type (SSHHook). Tests are updated for this PR. --- airflow/contrib/operators/sftp_operator.py | 22 ++++-- airflow/contrib/operators/ssh_operator.py | 22 ++++-- tests/contrib/operators/test_sftp_operator.py | 76 +++++++++++++++++-- tests/contrib/operators/test_ssh_operator.py | 60 +++++++++++++++ 4 files changed, 165 insertions(+), 15 deletions(-) diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py index 3c736c8b95101..a3b5c1f24492b 100644 --- a/airflow/contrib/operators/sftp_operator.py +++ b/airflow/contrib/operators/sftp_operator.py @@ -33,11 +33,15 @@ class SFTPOperator(BaseOperator): This operator uses ssh_hook to open sftp trasport channel that serve as basis for file transfer. - :param ssh_hook: predefined ssh_hook to use for remote execution + :param ssh_hook: predefined ssh_hook to use for remote execution. + Either `ssh_hook` or `ssh_conn_id` needs to be provided. :type ssh_hook: :class:`SSHHook` - :param ssh_conn_id: connection id from airflow Connections + :param ssh_conn_id: connection id from airflow Connections. 
+ `ssh_conn_id` will be ignored if `ssh_hook` is provided. :type ssh_conn_id: str :param remote_host: remote host to connect (templated) + Nullable. If provided, it will replace the `remote_host` which was + defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`. :type remote_host: str :param local_filepath: local file path to get or put. (templated) :type local_filepath: str @@ -77,13 +81,21 @@ def __init__(self, def execute(self, context): file_msg = None try: - if self.ssh_conn_id and not self.ssh_hook: - self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) + if self.ssh_conn_id: + if self.ssh_hook and isinstance(self.ssh_hook, SSHHook): + self.log.info("ssh_conn_id is ignored when ssh_hook is provided.") + else: + self.log.info("ssh_hook is not provided or invalid. " + + "Trying ssh_conn_id to create SSHHook.") + self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) if not self.ssh_hook: - raise AirflowException("can not operate without ssh_hook or ssh_conn_id") + raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.") if self.remote_host is not None: + self.log.info("remote_host is provided explicitly. " + + "It will replace the remote_host which was defined " + + "in ssh_hook or predefined in connection of ssh_conn_id.") self.ssh_hook.remote_host = self.remote_host with self.ssh_hook.get_conn() as ssh_client: diff --git a/airflow/contrib/operators/ssh_operator.py b/airflow/contrib/operators/ssh_operator.py index c0e8953d2c344..2bf342935d60c 100644 --- a/airflow/contrib/operators/ssh_operator.py +++ b/airflow/contrib/operators/ssh_operator.py @@ -31,11 +31,15 @@ class SSHOperator(BaseOperator): """ SSHOperator to execute commands on given remote host using the ssh_hook. - :param ssh_hook: predefined ssh_hook to use for remote execution + :param ssh_hook: predefined ssh_hook to use for remote execution. + Either `ssh_hook` or `ssh_conn_id` needs to be provided. 
:type ssh_hook: :class:`SSHHook` - :param ssh_conn_id: connection id from airflow Connections + :param ssh_conn_id: connection id from airflow Connections. + `ssh_conn_id` will be ignored if `ssh_hook` is provided. :type ssh_conn_id: str :param remote_host: remote host to connect (templated) + Nullable. If provided, it will replace the `remote_host` which was + defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`. :type remote_host: str :param command: command to execute on remote host. (templated) :type command: str @@ -68,14 +72,22 @@ def __init__(self, def execute(self, context): try: - if self.ssh_conn_id and not self.ssh_hook: - self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, - timeout=self.timeout) + if self.ssh_conn_id: + if self.ssh_hook and isinstance(self.ssh_hook, SSHHook): + self.log.info("ssh_conn_id is ignored when ssh_hook is provided.") + else: + self.log.info("ssh_hook is not provided or invalid. " + + "Trying ssh_conn_id to create SSHHook.") + self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, + timeout=self.timeout) if not self.ssh_hook: raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.") if self.remote_host is not None: + self.log.info("remote_host is provided explicitly. 
" + + "It will replace the remote_host which was defined " + + "in ssh_hook or predefined in connection of ssh_conn_id.") self.ssh_hook.remote_host = self.remote_host if not self.command: diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py index 01446a6fddd49..5770c1b940eb5 100644 --- a/tests/contrib/operators/test_sftp_operator.py +++ b/tests/contrib/operators/test_sftp_operator.py @@ -20,6 +20,7 @@ import os import unittest from base64 import b64encode +import six from airflow import configuration from airflow import models @@ -219,6 +220,71 @@ def test_json_file_transfer_get(self): self.assertEqual(content_received.strip(), test_remote_file_content.encode('utf-8').decode('utf-8')) + def test_arg_checking(self): + from airflow.exceptions import AirflowException + conn_id = "conn_id_for_testing" + os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost" + + # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided + if six.PY2: + self.assertRaisesRegex = self.assertRaisesRegexp + with self.assertRaisesRegex(AirflowException, + "Cannot operate without ssh_hook or ssh_conn_id."): + task_0 = SFTPOperator( + task_id="test_sftp", + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + task_0.execute(None) + + # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook + task_1 = SFTPOperator( + task_id="test_sftp", + ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook + ssh_conn_id=conn_id, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + try: + task_1.execute(None) + except Exception: + pass + self.assertEqual(task_1.ssh_hook.ssh_conn_id, conn_id) + + task_2 = SFTPOperator( + task_id="test_sftp", + ssh_conn_id=conn_id, # no ssh_hook provided + local_filepath=self.test_local_filepath, + 
remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + try: + task_2.execute(None) + except Exception: + pass + self.assertEqual(task_2.ssh_hook.ssh_conn_id, conn_id) + + # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id + task_3 = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + ssh_conn_id=conn_id, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + try: + task_3.execute(None) + except Exception: + pass + self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id) + def delete_local_resource(self): if os.path.exists(self.test_local_filepath): os.remove(self.test_local_filepath) @@ -226,11 +292,11 @@ def delete_local_resource(self): def delete_remote_resource(self): # check the remote file content remove_file_task = SSHOperator( - task_id="test_check_file", - ssh_hook=self.hook, - command="rm {0}".format(self.test_remote_filepath), - do_xcom_push=True, - dag=self.dag + task_id="test_check_file", + ssh_hook=self.hook, + command="rm {0}".format(self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag ) self.assertIsNotNone(remove_file_task) ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow()) diff --git a/tests/contrib/operators/test_ssh_operator.py b/tests/contrib/operators/test_ssh_operator.py index 7ddd24b2ac2ca..1a2c788596671 100644 --- a/tests/contrib/operators/test_ssh_operator.py +++ b/tests/contrib/operators/test_ssh_operator.py @@ -19,6 +19,7 @@ import unittest from base64 import b64encode +import six from airflow import configuration from airflow import models @@ -148,6 +149,65 @@ def test_no_output_command(self): self.assertIsNotNone(ti.duration) self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), b'') + def test_arg_checking(self): + import os + from airflow.exceptions import AirflowException + conn_id = "conn_id_for_testing" + TIMEOUT = 5 
+ os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost" + + # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided + if six.PY2: + self.assertRaisesRegex = self.assertRaisesRegexp + with self.assertRaisesRegex(AirflowException, + "Cannot operate without ssh_hook or ssh_conn_id."): + task_0 = SSHOperator(task_id="test", command="echo -n airflow", + timeout=TIMEOUT, dag=self.dag) + task_0.execute(None) + + # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook + task_1 = SSHOperator( + task_id="test_1", + ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook + ssh_conn_id=conn_id, + command="echo -n airflow", + timeout=TIMEOUT, + dag=self.dag + ) + try: + task_1.execute(None) + except Exception: + pass + self.assertEqual(task_1.ssh_hook.ssh_conn_id, conn_id) + + task_2 = SSHOperator( + task_id="test_2", + ssh_conn_id=conn_id, # no ssh_hook provided + command="echo -n airflow", + timeout=TIMEOUT, + dag=self.dag + ) + try: + task_2.execute(None) + except Exception: + pass + self.assertEqual(task_2.ssh_hook.ssh_conn_id, conn_id) + + # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id + task_3 = SSHOperator( + task_id="test_3", + ssh_hook=self.hook, + ssh_conn_id=conn_id, + command="echo -n airflow", + timeout=TIMEOUT, + dag=self.dag + ) + try: + task_3.execute(None) + except Exception: + pass + self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id) + if __name__ == '__main__': unittest.main()