diff --git a/pyhive/hive.py b/pyhive/hive.py
index 3f71df33..110faf96 100644
--- a/pyhive/hive.py
+++ b/pyhive/hive.py
@@ -200,8 +200,6 @@ def __init__(
                 self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
             elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'):
                 # Defer import so package dependency is optional
-                import sasl
-                import thrift_sasl
 
                 if auth == 'KERBEROS':
                     # KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library
@@ -212,7 +210,9 @@ def __init__(
                         # Password doesn't matter in NONE mode, just needs to be nonempty.
                         password = 'x'
 
-                def sasl_factory():
+
+                def get_sasl_client():
+                    import sasl
                     sasl_client = sasl.Client()
                     sasl_client.setAttr('host', host)
                     if sasl_auth == 'GSSAPI':
@@ -224,6 +224,33 @@ def sasl_factory():
                         raise AssertionError
                     sasl_client.init()
                     return sasl_client
+
+
+                def get_pure_sasl_client():
+                    from pyhive.sasl_compat import PureSASLClient
+                    sasl_kwargs = {}
+                    if sasl_auth == 'GSSAPI':
+                        sasl_kwargs['service'] = kerberos_service_name
+                    elif sasl_auth == 'PLAIN':
+                        sasl_kwargs['username'] = username
+                        sasl_kwargs['password'] = password
+                    else:
+                        raise AssertionError
+                    return PureSASLClient(host=host, **sasl_kwargs)
+
+
+                def sasl_factory():
+                    try:
+                        sasl_client = get_sasl_client()
+                        # The sasl library is available
+                    except ImportError:
+                        # Fallback to pure-sasl library
+                        sasl_client = get_pure_sasl_client()
+
+                    return sasl_client
+
+
+                import thrift_sasl
                 self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
             else:
                 # All HS2 config options:
diff --git a/pyhive/sasl_compat.py b/pyhive/sasl_compat.py
new file mode 100644
index 00000000..dc65abe9
--- /dev/null
+++ b/pyhive/sasl_compat.py
@@ -0,0 +1,89 @@
+# Original source of this file is https://github.com/cloudera/impyla/blob/master/impala/sasl_compat.py
+# which uses Apache-2.0 license as of 21 May 2023.
+# This code was added to Impyla in 2016 as a compatibility layer to allow use of either python-sasl or pure-sasl
+# via PR https://github.com/cloudera/impyla/pull/179
+# thrift_sasl lists pure-sasl as a dependency here https://github.com/cloudera/thrift_sasl/blob/master/setup.py#L34
+# but it still calls functions native to python-sasl in this file https://github.com/cloudera/thrift_sasl/blob/master/thrift_sasl/__init__.py#L82
+# Hence this code is required for the fallback to work.
+
+
+from puresasl.client import SASLClient, SASLError
+from contextlib import contextmanager
+
+
+@contextmanager
+def error_catcher(self, Exc=Exception):
+    """Record an exception raised by the ``with`` body instead of propagating it.
+
+    ``self.error`` is reset to None on entry. When the body raises ``Exc`` the
+    exception is swallowed and ``str(e)`` is stored in ``self.error``, so control
+    resumes at the statement following the ``with`` block.
+    """
+    try:
+        self.error = None
+        yield
+    except Exc as e:
+        self.error = str(e)
+
+
+class PureSASLClient(SASLClient):
+    """pure-sasl client exposing the python-sasl style calls used by thrift_sasl.
+
+    Every operation returns a tuple whose first element is a success flag; after
+    a failure the message is available from ``getError()``.
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.error = None
+        super(PureSASLClient, self).__init__(*args, **kwargs)
+
+    def start(self, mechanism):
+        """Choose a mechanism (name or list of names) and produce the first response.
+
+        Returns ``(True, chosen_mechanism, response)``, or
+        ``(False, mechanism, None)`` when a ``SASLError`` was raised.
+        """
+        with error_catcher(self, SASLError):
+            if isinstance(mechanism, list):
+                self.choose_mechanism(mechanism)
+            else:
+                self.choose_mechanism([mechanism])
+            return True, self.mechanism, self.process()
+        # Only reached when error_catcher swallowed a SASLError
+        return False, mechanism, None
+
+    def encode(self, incoming):
+        """Wrap a message that is about to be sent to the server.
+
+        thrift_sasl calls ``encode`` on its write path, so this maps to ``wrap``.
+        The parameter name is kept as-is for backward compatibility.
+        Returns ``(True, data)``, or ``(False, None)`` on error.
+        """
+        with error_catcher(self):
+            return True, self.wrap(incoming)
+        # Only reached when error_catcher swallowed an exception
+        return False, None
+
+    def decode(self, outgoing):
+        """Unwrap a message that was received from the server.
+
+        thrift_sasl calls ``decode`` on its read path, so this maps to ``unwrap``.
+        The parameter name is kept as-is for backward compatibility.
+        Returns ``(True, data)``, or ``(False, None)`` on error.
+        """
+        with error_catcher(self):
+            return True, self.unwrap(outgoing)
+        # Only reached when error_catcher swallowed an exception
+        return False, None
+
+    def step(self, challenge=None):
+        """Process a server challenge.
+
+        Returns ``(True, response)``, or ``(False, None)`` on error.
+        """
+        with error_catcher(self):
+            return True, self.process(challenge)
+        # Only reached when error_catcher swallowed an exception
+        return False, None
+
+    def getError(self):
+        """Return the message of the last recorded error, or None."""
+        return self.error
diff --git a/pyhive/tests/test_hive.py b/pyhive/tests/test_hive.py
index c70ed962..d7700594 100644
--- a/pyhive/tests/test_hive.py
+++ b/pyhive/tests/test_hive.py
@@ -8,6 +8,7 @@
 from __future__ import unicode_literals
 
 import contextlib
+import importlib.util
 import datetime
 import os
 import socket
@@ -15,12 +16,9 @@
 import time
 import unittest
 from decimal import Decimal
-
 import mock
-import sasl
 import thrift.transport.TSocket
 import thrift.transport.TTransport
-import thrift_sasl
 from thrift.transport.TTransport import TTransportException
 from TCLIService import ttypes
 
@@ -204,14 +202,35 @@ def test_custom_transport(self):
         socket = thrift.transport.TSocket.TSocket('localhost', 10000)
         sasl_auth = 'PLAIN'
 
-        def sasl_factory():
+        spec_sasl = importlib.util.find_spec('sasl')
+        spec_puresasl = importlib.util.find_spec('puresasl')
+
+        def get_sasl_client():
+            import sasl
             sasl_client = sasl.Client()
             sasl_client.setAttr('host', 'localhost')
             sasl_client.setAttr('username', 'test_username')
             sasl_client.setAttr('password', 'x')
             sasl_client.init()
             return sasl_client
+
+        def get_pure_sasl_client():
+            from pyhive.sasl_compat import PureSASLClient
+            sasl_client = PureSASLClient(host='localhost', username='test_username', password='x')
+            return sasl_client
+
+        def sasl_factory():
+            if spec_sasl:
+                sasl_client = get_sasl_client()
+                return sasl_client
+            elif spec_puresasl:
+                sasl_client = get_pure_sasl_client()
+                return sasl_client
+            else:
+                raise ValueError("No suitable SASL module available. Please install either sasl or pure-sasl.")
+
+        import thrift_sasl
 
         transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
         conn = hive.connect(thrift_transport=transport)
         with contextlib.closing(conn):
diff --git a/pyhive/tests/test_sasl_compat.py b/pyhive/tests/test_sasl_compat.py
new file mode 100644
index 00000000..f1c81a62
--- /dev/null
+++ b/pyhive/tests/test_sasl_compat.py
@@ -0,0 +1,307 @@
+import unittest
+import base64
+import hashlib
+import hmac
+import kerberos
+from mock import patch
+import six
+import struct
+from puresasl import SASLProtocolException, QOP
+from puresasl.client import SASLError
+from pyhive.sasl_compat import PureSASLClient, error_catcher
+
+
+class TestPureSASLClient(unittest.TestCase):
+    """Test cases for initialization of SASL client using PureSASLClient class"""
+
+    def setUp(self):
+        self.sasl_kwargs = {}
+        self.sasl = PureSASLClient('localhost', **self.sasl_kwargs)
+
+    def test_start_no_mechanism(self):
+        """Test starting SASL authentication with no mechanism."""
+        success, mechanism, response = self.sasl.start(mechanism=None)
+        self.assertFalse(success)
+        self.assertIsNone(mechanism)
+        self.assertIsNone(response)
+        self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')
+
+    def test_start_wrong_mechanism(self):
+        """Test starting SASL authentication with a single unsupported mechanism."""
+        success, mechanism, response = self.sasl.start(mechanism='WRONG')
+        self.assertFalse(success)
+        self.assertEqual(mechanism, 'WRONG')
+        self.assertIsNone(response)
+        self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')
+
+    def test_start_list_of_invalid_mechanisms(self):
+        """Test starting SASL authentication with a list of unsupported mechanisms."""
+        self.sasl.start(['invalid1', 'invalid2'])
+        self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')
+
+    def test_start_list_of_valid_mechanisms(self):
+        """Test starting SASL authentication with a list of supported mechanisms."""
+        self.sasl.start(['PLAIN', 'DIGEST-MD5', 'CRAM-MD5'])
+        # Validate right mechanism is chosen based on score.
+        self.assertEqual(self.sasl._chosen_mech.name, 'DIGEST-MD5')
+
+    def test_error_catcher_no_error(self):
+        """Test the error_catcher with no error."""
+        with error_catcher(self.sasl):
+            result, _, _ = self.sasl.start(mechanism='ANONYMOUS')
+
+        self.assertEqual(self.sasl.getError(), None)
+        self.assertEqual(result, True)
+
+    def test_error_catcher_with_error(self):
+        """Test the error_catcher with an error."""
+        with error_catcher(self.sasl):
+            result, _, _ = self.sasl.start(mechanism='WRONG')
+
+        self.assertEqual(result, False)
+        self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')
+
+# Assuming client initialization went well and a mechanism is chosen, below are the test cases for the different mechanisms.
+
+class _BaseMechanismTests(unittest.TestCase):
+    """Base test case for SASL mechanisms."""
+
+    mechanism = 'ANONYMOUS'
+    sasl_kwargs = {}
+
+    def setUp(self):
+        self.sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs)
+        self.mechanism_class = self.sasl._chosen_mech
+
+    def test_init_basic(self, *args):
+        sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs)
+        mech = sasl._chosen_mech
+        self.assertIs(mech.sasl, sasl)
+
+    def test_step_basic(self, *args):
+        success, response = self.sasl.step(six.b('string'))
+        self.assertTrue(success)
+        self.assertIsInstance(response, six.binary_type)
+
+    def test_decode_encode(self, *args):
+        self.assertEqual(self.sasl.encode('msg'), (False, None))
+        self.assertEqual(self.sasl.getError(), '')
+        self.assertEqual(self.sasl.decode('msg'), (False, None))
+        self.assertEqual(self.sasl.getError(), '')
+
+
+class AnonymousMechanismTest(_BaseMechanismTests):
+    """Test case for the Anonymous SASL mechanism."""
+
+    mechanism = 'ANONYMOUS'
+
+
+class PlainTextMechanismTest(_BaseMechanismTests):
+    """Test case for the PlainText SASL mechanism."""
+
+    mechanism = 'PLAIN'
+    username = 'user'
+    password = 'pass'
+    sasl_kwargs = {'username': username, 'password': password}
+
+    def test_step(self):
+        for challenge in (None, '', b'asdf', u"\U0001F44D"):
+            success, response = self.sasl.step(challenge)
+            self.assertTrue(success)
+            self.assertEqual(response, six.b(f'\x00{self.username}\x00{self.password}'))
+            self.assertIsInstance(response, six.binary_type)
+
+    def test_step_with_authorization_id_or_identity(self):
+        challenge = u"\U0001F44D"
+        identity = 'user2'
+
+        # Test that we can pass an identity
+        sasl_kwargs = self.sasl_kwargs.copy()
+        sasl_kwargs.update({'identity': identity})
+        sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs)
+        success, response = sasl.step(challenge)
+        self.assertTrue(success)
+        self.assertEqual(response, six.b(f'{identity}\x00{self.username}\x00{self.password}'))
+        self.assertIsInstance(response, six.binary_type)
+        self.assertTrue(sasl.complete)
+
+        # Test that the sasl authorization_id has priority over identity
+        auth_id = 'user3'
+        sasl_kwargs.update({'authorization_id': auth_id})
+        sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs)
+        success, response = sasl.step(challenge)
+        self.assertTrue(success)
+        self.assertEqual(response, six.b(f'{auth_id}\x00{self.username}\x00{self.password}'))
+        self.assertIsInstance(response, six.binary_type)
+        self.assertTrue(sasl.complete)
+
+    def test_decode_encode(self):
+        msg = 'msg'
+        self.assertEqual(self.sasl.decode(msg), (True, msg))
+        self.assertEqual(self.sasl.encode(msg), (True, msg))
+
+
+class ExternalMechanismTest(_BaseMechanismTests):
+    """Test case for the External SASL mechanisms"""
+
+    mechanism = 'EXTERNAL'
+
+    def test_step(self):
+        self.assertEqual(self.sasl.step(), (True, b''))
+
+    def test_decode_encode(self):
+        msg = 'msg'
+        self.assertEqual(self.sasl.decode(msg), (True, msg))
+        self.assertEqual(self.sasl.encode(msg), (True, msg))
+
+
+@patch('puresasl.mechanisms.kerberos.authGSSClientStep')
+@patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=base64.b64encode(six.b('some\x00 response')))
+class GSSAPIMechanismTest(_BaseMechanismTests):
+    """Test case for the GSSAPI SASL mechanism."""
+
+    mechanism = 'GSSAPI'
+    service = 'GSSAPI'
+    sasl_kwargs = {'service': service}
+
+    @patch('puresasl.mechanisms.kerberos.authGSSClientWrap')
+    @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap')
+    def test_decode_encode(self, _inner1, _inner2, authGSSClientResponse, *args):
+        # bypassing step setup by setting qop directly
+        self.mechanism_class.qop = QOP.AUTH
+        msg = b'msg'
+        self.assertEqual(self.sasl.decode(msg), (True, msg))
+        self.assertEqual(self.sasl.encode(msg), (True, msg))
+
+        # Test for behavior with different QOP like data integrity and confidentiality for Kerberos authentication
+        for qop in (QOP.AUTH_INT, QOP.AUTH_CONF):
+            self.mechanism_class.qop = qop
+            with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=1):
+                self.assertEqual(self.sasl.decode(msg), (True, base64.b64decode(authGSSClientResponse.return_value)))
+                self.assertEqual(self.sasl.encode(msg), (True, base64.b64decode(authGSSClientResponse.return_value)))
+            if qop == QOP.AUTH_CONF:
+                with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=0):
+                    self.assertEqual(self.sasl.decode(msg), (False, None))
+                    self.assertEqual(self.sasl.getError(), 'Error: confidentiality requested, but not honored by the server.')
+
+    def test_step_no_user(self, authGSSClientResponse, *args):
+        msg = six.b('whatever')
+
+        # no user
+        self.assertEqual(self.sasl.step(msg), (True, base64.b64decode(authGSSClientResponse.return_value)))
+        with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=''):
+            self.assertEqual(self.sasl.step(msg), (True, six.b('')))
+
+        username = 'username'
+        # with user; this has to be last because it sets mechanism.user
+        with patch('puresasl.mechanisms.kerberos.authGSSClientStep', return_value=kerberos.AUTH_GSS_COMPLETE):
+            with patch('puresasl.mechanisms.kerberos.authGSSClientUserName', return_value=six.b(username)):
+                self.assertEqual(self.sasl.step(msg), (True, six.b('')))
+                self.assertEqual(self.mechanism_class.user, six.b(username))
+
+    @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap')
+    def test_step_qop(self, *args):
+        self.mechanism_class._have_negotiated_details = True
+        self.mechanism_class.user = 'user'
+        msg = six.b('msg')
+        self.assertEqual(self.sasl.step(msg), (False, None))
+        self.assertEqual(self.sasl.getError(), 'Bad response from server')
+
+        max_len = 100
+        self.assertLess(max_len, self.sasl.max_buffer)
+        for i, qop in QOP.bit_map.items():
+            qop_size = struct.pack('!i', i << 24 | max_len)
+            response = base64.b64encode(qop_size)
+            with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=response):
+                with patch('puresasl.mechanisms.kerberos.authGSSClientWrap') as authGSSClientWrap:
+                    self.mechanism_class.complete = False
+                    self.assertEqual(self.sasl.step(msg), (True, qop_size))
+                    self.assertTrue(self.mechanism_class.complete)
+                    self.assertEqual(self.mechanism_class.qop, qop)
+                    self.assertEqual(self.mechanism_class.max_buffer, max_len)
+
+                    args = authGSSClientWrap.call_args[0]
+                    out_data = args[1]
+                    out = base64.b64decode(out_data)
+                    self.assertEqual(out[:4], qop_size)
+                    self.assertEqual(out[4:], six.b(self.mechanism_class.user))
+
+
+class CramMD5MechanismTest(_BaseMechanismTests):
+    """Test case for the CRAM-MD5 SASL mechanism."""
+
+    mechanism = 'CRAM-MD5'
+    username = 'user'
+    password = 'pass'
+    sasl_kwargs = {'username': username, 'password': password}
+
+    def test_step(self):
+        success, response = self.sasl.step(None)
+        self.assertTrue(success)
+        self.assertIsNone(response)
+        challenge = six.b('msg')
+        hash = hmac.HMAC(key=six.b(self.password), digestmod=hashlib.md5)
+        hash.update(challenge)
+        success, response = self.sasl.step(challenge)
+        self.assertTrue(success)
+        self.assertIn(six.b(self.username), response)
+        self.assertIn(six.b(hash.hexdigest()), response)
+        self.assertIsInstance(response, six.binary_type)
+        self.assertTrue(self.sasl.complete)
+
+    def test_decode_encode(self):
+        msg = 'msg'
+        self.assertEqual(self.sasl.decode(msg), (True, msg))
+        self.assertEqual(self.sasl.encode(msg), (True, msg))
+
+
+class DigestMD5MechanismTest(_BaseMechanismTests):
+    """Test case for the DIGEST-MD5 SASL mechanism."""
+
+    mechanism = 'DIGEST-MD5'
+    username = 'user'
+    password = 'pass'
+    sasl_kwargs = {'username': username, 'password': password}
+
+    def test_decode_encode(self):
+        msg = 'msg'
+        self.assertEqual(self.sasl.decode(msg), (True, msg))
+        self.assertEqual(self.sasl.encode(msg), (True, msg))
+
+    def test_step_basic(self, *args):
+        pass
+
+    def test_step(self):
+        """Test a SASL step with dummy challenge for DIGEST-MD5 mechanism."""
+        testChallenge = (
+            b'nonce="rmD6R8aMYVWH+/ih9HGBr3xNGAR6o2DUxpKlgDz6gUQ=",r'
+            b'ealm="example.org",qop="auth,auth-int,auth-conf",cipher="rc4-40,rc'
+            b'4-56,rc4,des,3des",maxbuf=65536,charset=utf-8,algorithm=md5-sess'
+        )
+        result, response = self.sasl.step(testChallenge)
+        self.assertTrue(result)
+        self.assertIsNotNone(response)
+
+    def test_step_server_answer(self):
+        """Test a SASL step with a proper server answer for DIGEST-MD5 mechanism."""
+        sasl_kwargs = {'username': "chris", 'password': "secret"}
+        sasl = PureSASLClient('elwood.innosoft.com',
+                              service="imap",
+                              mechanism=self.mechanism,
+                              mutual_auth=True,
+                              **sasl_kwargs)
+        testChallenge = (
+            b'utf-8,username="chris",realm="elwood.innosoft.com",'
+            b'nonce="OA6MG9tEQGm2hh",nc=00000001,cnonce="OA6MHXh6VqTrRk",'
+            b'digest-uri="imap/elwood.innosoft.com",'
+            b'response=d388dad90d4bbd760a152321f2143af7,qop=auth'
+        )
+        sasl.step(testChallenge)
+        sasl._chosen_mech.cnonce = b"OA6MHXh6VqTrRk"
+
+        serverResponse = (
+            b'rspauth=ea40f60335c427b5527b84dbabcdfffd'
+        )
+        sasl.step(serverResponse)
+        # assert that step chooses the only supported QOP for DIGEST-MD5
+        self.assertEqual(sasl.qop, QOP.AUTH)
diff --git a/setup.py b/setup.py
index be593fc0..9b198b26 100755
--- a/setup.py
+++ b/setup.py
@@ -46,6 +46,7 @@ def run_tests(self):
         'presto': ['requests>=1.0.0'],
         'trino': ['requests>=1.0.0'],
         'hive': ['sasl>=0.2.1', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'],
+        'hive_pure': ['pure-sasl>=0.6.2', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'],
         'sqlalchemy': ['sqlalchemy>=1.3.0'],
         'kerberos': ['requests_kerberos>=0.12.0'],
     },
@@ -55,7 +56,9 @@ def run_tests(self):
         'pytest-cov',
         'requests>=1.0.0',
         'requests_kerberos>=0.12.0',
+        'kerberos>=1.3.0',
         'sasl>=0.2.1',
+        'pure-sasl>=0.6.2',
         'sqlalchemy>=1.3.0',
         'thrift>=0.10.0',
     ],