Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/ssl support dbtspark #169

Merged
merged 22 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions dbt/adapters/spark/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class SparkCredentials(Credentials):
organization: str = '0'
connect_retries: int = 0
connect_timeout: int = 10
use_ssl: bool = False

@classmethod
def __pre_deserialize__(cls, data):
Expand Down Expand Up @@ -348,11 +349,21 @@ def open(cls, connection):
cls.validate_creds(creds,
['host', 'port', 'user', 'schema'])

conn = hive.connect(host=creds.host,
port=creds.port,
username=creds.user,
auth=creds.auth,
kerberos_service_name=creds.kerberos_service_name) # noqa
if creds.use_ssl:
import puretransport
rahulgoyal2987 marked this conversation as resolved.
Show resolved Hide resolved
transport = puretransport.transport_factory(host=creds.host,
port=creds.port,
username=creds.user,
password='dummy',
use_ssl=creds.use_ssl,
kerberos_service_name=creds.kerberos_service_name)
conn = hive.connect(thrift_transport=transport)
else:
conn = hive.connect(host=creds.host,
port=creds.port,
username=creds.user,
auth=creds.auth,
kerberos_service_name=creds.kerberos_service_name) # noqa
handle = PyhiveConnectionWrapper(conn)
elif creds.method == SparkConnectionMethod.ODBC:
if creds.cluster is not None:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ PyHive[hive]>=0.6.0,<0.7.0
pyodbc>=4.0.30
sqlparams>=3.0.0
thrift>=0.11.0,<0.12.0
pure-transport>=0.2.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we put an upper bound on the version here, since it's version-0 software?

You'll need to add this to the extra setup requirements as well:
https://github.com/fishtown-analytics/dbt-spark/blob/dff1b613ddf87e4e72e8a47475bcfd1d55796a5c/setup.py#L41-L44

I'm inclined to bundle it in with the larger PyHive extra, unless there's a good reason to bundle it separately.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jtcohen6 I have added the logic to build the transport object directly, rather than depending on pure-transport

34 changes: 34 additions & 0 deletions test/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,22 @@ def _get_target_thrift_kerberos(self, project):
'target': 'test'
})

def _get_target_use_ssl_thrift(self, project):
    """Build a config whose 'test' target is a thrift connection with
    SSL enabled (``use_ssl: True``), for exercising the SSL code path."""
    ssl_thrift_target = {
        'type': 'spark',
        'method': 'thrift',
        'use_ssl': True,
        'schema': 'analytics',
        'host': 'myorg.sparkhost.com',
        'port': 10001,
        'user': 'dbt',
    }
    profile = {
        'outputs': {'test': ssl_thrift_target},
        'target': 'test',
    }
    return config_from_parts_or_dicts(project, profile)

def _get_target_odbc_cluster(self, project):
return config_from_parts_or_dicts(project, {
'outputs': {
Expand Down Expand Up @@ -154,6 +170,24 @@ def hive_thrift_connect(host, port, username, auth, kerberos_service_name):
self.assertEqual(connection.credentials.schema, 'analytics')
self.assertIsNone(connection.credentials.database)

def test_thrift_ssl_connection(self):
    """With use_ssl on, acquiring a connection must call hive.connect with
    a pre-built ``thrift_transport`` (no host/port/auth kwargs)."""
    config = self._get_target_use_ssl_thrift(self.project_cfg)
    adapter = SparkAdapter(config)

    def fake_hive_connect(thrift_transport):
        # The transport should already be configured for the target host.
        factory = thrift_transport.sasl_client_factory()
        self.assertIsNotNone(thrift_transport)
        self.assertEqual(factory.host, 'myorg.sparkhost.com')

    with mock.patch.object(hive, 'connect', new=fake_hive_connect):
        connection = adapter.acquire_connection('dummy')
        connection.handle  # force the lazy connection open

    self.assertEqual(connection.state, 'open')
    self.assertIsNotNone(connection.handle)
    self.assertEqual(connection.credentials.schema, 'analytics')
    self.assertIsNone(connection.credentials.database)

def test_thrift_connection_kerberos(self):
config = self._get_target_thrift_kerberos(self.project_cfg)
adapter = SparkAdapter(config)
Expand Down