diff --git a/airflow/providers/ssh/hooks/ssh.py b/airflow/providers/ssh/hooks/ssh.py index d6b51b422e95e..8489932e5de1e 100644 --- a/airflow/providers/ssh/hooks/ssh.py +++ b/airflow/providers/ssh/hooks/ssh.py @@ -63,13 +63,13 @@ class SSHHook(BaseHook): :type keepalive_interval: int """ - # key type name to paramiko PKey class - _default_pkey_mappings = { - 'dsa': paramiko.DSSKey, - 'ecdsa': paramiko.ECDSAKey, - 'ed25519': paramiko.Ed25519Key, - 'rsa': paramiko.RSAKey, - } + # List of classes to try loading private keys as, ordered (roughly) by most common to least common + _pkey_loaders = ( + paramiko.RSAKey, + paramiko.ECDSAKey, + paramiko.Ed25519Key, + paramiko.DSSKey, + ) _host_key_mappings = { 'rsa': paramiko.RSAKey, @@ -357,15 +357,17 @@ def _pkey_from_private_key(self, private_key: str, passphrase: Optional[str] = N Creates appropriate paramiko key for given private key :param private_key: string containing private key - :return: `paramiko.PKey` appropriate for given key + :return: ``paramiko.PKey`` appropriate for given key :raises AirflowException: if key cannot be read """ - allowed_pkey_types = self._default_pkey_mappings.values() - for pkey_type in allowed_pkey_types: + for pkey_class in self._pkey_loaders: try: - key = pkey_type.from_private_key(StringIO(private_key), password=passphrase) + key = pkey_class.from_private_key(StringIO(private_key), password=passphrase) + # Test it acutally works. If Paramiko loads an openssh generated key, sometimes it will + # happily load it as the wrong type, only to fail when actually used. + key.sign_ssh_data(b'') return key - except paramiko.ssh_exception.SSHException: + except (paramiko.ssh_exception.SSHException, ValueError): continue raise AirflowException( 'Private key provided cannot be read by paramiko.' diff --git a/tests/providers/ssh/hooks/test_ssh.py b/tests/providers/ssh/hooks/test_ssh.py index 1f6d805b81565..70ba4eb0ce642 100644 --- a/tests/providers/ssh/hooks/test_ssh.py +++ b/tests/providers/ssh/hooks/test_ssh.py @@ -18,6 +18,7 @@ import json import random import string +import textwrap import unittest from io import StringIO from typing import Optional @@ -25,6 +26,7 @@ import paramiko +from airflow import settings from airflow.models import Connection from airflow.providers.ssh.hooks.ssh import SSHHook from airflow.utils import db @@ -473,6 +475,47 @@ def test_ssh_connection_with_no_host_key_where_no_host_key_check_is_false(self, assert ssh_client.return_value.connect.called is True assert ssh_client.return_value.get_host_keys.return_value.add.called is False + def test_openssh_private_key(self): + # Paramiko behaves differently with OpenSSH generated keys to paramiko + # generated keys, so we need a test one. + # This has been gernerated specifically to put here, it is not otherwise in use + TEST_OPENSSH_PRIVATE_KEY = "-----BEGIN OPENSSH " + textwrap.dedent( + """\ + PRIVATE KEY----- + b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAlwAAAAdzc2gtcn + NhAAAAAwEAAQAAAIEAuPKIGPWtIpMDrXwMAvNKQlhQ1gXV/tKyufElw/n6hrr6lvtfGhwX + DihHMsAF+8+KKWQjWgh0fttbIF3+3C56Ns8hgvgMQJT2nyWd7egwqn+LQa08uCEBEka3MO + arKzj39P66EZ/KQDD29VErlVOd97dPhaR8pOZvzcHxtLbU6rMAAAIA3uBiZd7gYmUAAAAH + c3NoLXJzYQAAAIEAuPKIGPWtIpMDrXwMAvNKQlhQ1gXV/tKyufElw/n6hrr6lvtfGhwXDi + hHMsAF+8+KKWQjWgh0fttbIF3+3C56Ns8hgvgMQJT2nyWd7egwqn+LQa08uCEBEka3MOar + Kzj39P66EZ/KQDD29VErlVOd97dPhaR8pOZvzcHxtLbU6rMAAAADAQABAAAAgA2QC5b4/T + dZ3J0uSZs1yC5RV6w6RVUokl68Zm6WuF6E+7dyu6iogrBRF9eK6WVr9M/QPh9uG0zqPSaE + fhobdm7KeycXmtDtrJnXE2ZSk4oU29++TvYZBrAqAli9aHlSArwiLnOIMzY/kIHoSJLJmd + jwXykdQ7QAd93KPEnkaMzBAAAAQGTyp6/wWqtqpMmYJ5prCGNtpVOGthW5upeiuQUytE/K + 5pyPoq6dUCUxQpkprtkuNAv/ff9nW6yy1v2DWohKfaEAAABBAO3y+erRXmiMreMOAd1S84 + RK2E/LUHOvClQqf6GnVavmIgkxIYEgjcFiWv4xIkTc1/FN6aX5aT4MB3srvuM7sxEAAABB + AMb6QAkvxo4hT/xKY0E0nG7zCUMXeBV35MEXQK0/InFC7aZ0tjzFsQJzLe/7q7ljIf+9/O + rCqNhxgOrv7XrRuYMAAAAKYXNoQHNpbm9wZQE= + -----END OPENSSH PRIVATE KEY----- + """ + ) + + session = settings.Session() + try: + conn = Connection( + conn_id='openssh_pkey', + host='localhost', + conn_type='ssh', + extra={"private_key": TEST_OPENSSH_PRIVATE_KEY}, + ) + session.add(conn) + session.flush() + hook = SSHHook(ssh_conn_id=conn.conn_id) + assert isinstance(hook.pkey, paramiko.RSAKey) + finally: + session.delete(conn) + session.commit() + if __name__ == '__main__': unittest.main()