Feature/ssl support dbtspark #169

Merged · 22 commits · Jun 7, 2021
Changes from 7 commits
69 changes: 64 additions & 5 deletions dbt/adapters/spark/connections.py
@@ -26,6 +26,9 @@
from hologram.helpers import StrEnum
from dataclasses import dataclass
from typing import Optional
from thrift.transport.TSSLSocket import TSSLSocket
import thrift
import ssl
Contributor:
Let's include these up in the try import (line 11) for requirements that may or may not be installed.
ssl is a net-new dependency, right? We'll need to add it in setup.py, within the PyHive extra.

Contributor Author:
ssl is part of the Python standard library (https://docs.python.org/3/library/ssl.html), so I don't think it needs to be added to setup.py.

Contributor (@jtcohen6, Jun 6, 2021):
Sorry, I said ssl here, but I was thinking of thrift_sasl. That's a net-new dependency, right? Just kidding, I follow now: those are dependencies of PyHive[hive]. Thanks for the clarification around ssl.
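For reference, the "try import" the reviewer points at is the usual soft-dependency guard. Below is a minimal sketch of that pattern; the exact module list and fallback values in dbt-spark's connections.py may differ:

try:
    # Soft dependencies: only needed for the thrift connection method,
    # installed via the PyHive extra rather than core requirements.
    from pyhive import hive
    from thrift.transport.TSSLSocket import TSSLSocket
    import thrift
    import thrift_sasl
except ImportError:
    hive = None
    TSSLSocket = None
    thrift = None
    thrift_sasl = None

import ssl  # standard library: needs no guard and no setup.py entry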


import base64
import time
@@ -59,6 +62,7 @@ class SparkCredentials(Credentials):
    organization: str = '0'
    connect_retries: int = 0
    connect_timeout: int = 10
    use_ssl: bool = False

    @classmethod
    def __pre_deserialize__(cls, data):
@@ -348,11 +352,20 @@ def open(cls, connection):
    cls.validate_creds(creds,
                       ['host', 'port', 'user', 'schema'])

    if creds.use_ssl:
        transport = build_ssl_transport(
            host=creds.host,
            port=creds.port,
            username=creds.user,
            auth=creds.auth,
            kerberos_service_name=creds.kerberos_service_name)
        conn = hive.connect(thrift_transport=transport)
    else:
        conn = hive.connect(host=creds.host,
                            port=creds.port,
                            username=creds.user,
                            auth=creds.auth,
                            kerberos_service_name=creds.kerberos_service_name)  # noqa
    handle = PyhiveConnectionWrapper(conn)
elif creds.method == SparkConnectionMethod.ODBC:
    if creds.cluster is not None:
@@ -431,6 +444,52 @@ def open(cls, connection):
return connection


def build_ssl_transport(host, port, username, auth,
                        kerberos_service_name, password=None):
    transport = None
    if port is None:
        port = 10000
    if auth is None:
        auth = 'NONE'
    socket = TSSLSocket(host, port, cert_reqs=ssl.CERT_NONE)
    if auth == 'NOSASL':
        # NOSASL corresponds to hive.server2.authentication=NOSASL
        # in hive-site.xml
        transport = thrift.transport.TTransport.TBufferedTransport(socket)
    elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'):
        # Defer import so package dependency is optional
        import sasl
        import thrift_sasl
Contributor:
We have a different way of handling these optional imports. See the try logic at the top of the file.

Contributor Author:
Done

Contributor:
Thanks! Could you move these up alongside the other imports (lines 29-32)? I'd prefer not to have imports nested so far down in the file.


        if auth == 'KERBEROS':
            # KERBEROS mode in hive.server2.authentication is GSSAPI
            # in sasl library
            sasl_auth = 'GSSAPI'
        else:
            sasl_auth = 'PLAIN'
            if password is None:
                # Password doesn't matter in NONE mode, just needs
                # to be nonempty.
                password = 'x'

        def sasl_factory():
            sasl_client = sasl.Client()
            sasl_client.setAttr('host', host)
            if sasl_auth == 'GSSAPI':
                sasl_client.setAttr('service', kerberos_service_name)
            elif sasl_auth == 'PLAIN':
                sasl_client.setAttr('username', username)
                sasl_client.setAttr('password', password)
            else:
                raise AssertionError
            sasl_client.init()
            return sasl_client

        transport = thrift_sasl.TSaslClientTransport(sasl_factory,
                                                     sasl_auth, socket)
    return transport


def _is_retryable_error(exc: Exception) -> Optional[str]:
    message = getattr(exc, 'message', None)
    if message is None:
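Putting the pieces together, this is roughly what the new code path does for a thrift target with use_ssl: true. The sketch below is illustrative only: the values are placeholders, it assumes pyhive, thrift, sasl, and thrift_sasl are installed and build_ssl_transport is in scope as defined above, and it highlights the trade-off that cert_reqs=ssl.CERT_NONE encrypts the connection but skips server-certificate verification.

# Illustrative sketch, not dbt-spark code; all values are placeholders.
transport = build_ssl_transport(
    host='myorg.sparkhost.com',
    port=10001,
    username='dbt',
    auth='LDAP',                  # one of NOSASL/LDAP/KERBEROS/NONE/CUSTOM
    kerberos_service_name=None,
    password='secret')
# The socket is TLS-wrapped, but with cert_reqs=ssl.CERT_NONE the
# server certificate is never validated.
conn = hive.connect(thrift_transport=transport)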
35 changes: 35 additions & 0 deletions test/unit/test_adapter.py
@@ -75,6 +75,22 @@ def _get_target_thrift_kerberos(self, project):
            'target': 'test'
        })

    def _get_target_use_ssl_thrift(self, project):
        return config_from_parts_or_dicts(project, {
            'outputs': {
                'test': {
                    'type': 'spark',
                    'method': 'thrift',
                    'use_ssl': True,
                    'schema': 'analytics',
                    'host': 'myorg.sparkhost.com',
                    'port': 10001,
                    'user': 'dbt'
                }
            },
            'target': 'test'
        })

    def _get_target_odbc_cluster(self, project):
        return config_from_parts_or_dicts(project, {
            'outputs': {
@@ -154,6 +170,25 @@ def hive_thrift_connect(host, port, username, auth, kerberos_service_name):
        self.assertEqual(connection.credentials.schema, 'analytics')
        self.assertIsNone(connection.credentials.database)

    def test_thrift_ssl_connection(self):
        config = self._get_target_use_ssl_thrift(self.project_cfg)
        adapter = SparkAdapter(config)

        def hive_thrift_connect(thrift_transport):
            self.assertIsNotNone(thrift_transport)
            transport = thrift_transport._trans
            self.assertEqual(transport.host, 'myorg.sparkhost.com')
            self.assertEqual(transport.port, 10001)

        with mock.patch.object(hive, 'connect', new=hive_thrift_connect):
            connection = adapter.acquire_connection('dummy')
            connection.handle  # trigger lazy-load

        self.assertEqual(connection.state, 'open')
        self.assertIsNotNone(connection.handle)
        self.assertEqual(connection.credentials.schema, 'analytics')
        self.assertIsNone(connection.credentials.database)
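The assertion through thrift_transport._trans works because thrift_sasl's TSaslClientTransport keeps the wrapped socket in its private _trans attribute; the mock unwraps the SASL layer to reach the underlying TSSLSocket. Conceptually, the test checks something like the sketch below (which relies on those same private internals, so attribute names may change between library versions):

# What the mocked hive.connect receives for this target (sketch):
sasl_transport = build_ssl_transport(
    host='myorg.sparkhost.com', port=10001,
    username='dbt', auth=None, kerberos_service_name=None)
assert sasl_transport._trans.host == 'myorg.sparkhost.com'  # the TSSLSocket
assert sasl_transport._trans.port == 10001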

    def test_thrift_connection_kerberos(self):
        config = self._get_target_thrift_kerberos(self.project_cfg)
        adapter = SparkAdapter(config)