Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/ssl support dbtspark #169

Merged
merged 22 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
## dbt next

### Features

- Allow user to specify `use_ssl` ([#169](https://github.com/fishtown-analytics/dbt-spark/pull/169))
- Allow setting table `OPTIONS` using `config` ([#171](https://github.com/fishtown-analytics/dbt-spark/pull/171))

### Fixes
Expand All @@ -16,6 +18,7 @@
- [@friendofasquid](https://github.com/friendofasquid) ([#159](https://github.com/fishtown-analytics/dbt-spark/pull/159))
- [@franloza](https://github.com/franloza) ([#160](https://github.com/fishtown-analytics/dbt-spark/pull/160))
- [@Fokko](https://github.com/Fokko) ([#165](https://github.com/fishtown-analytics/dbt-spark/pull/165))
- [@rahulgoyal2987](https://github.com/rahulgoyal2987) ([#169](https://github.com/fishtown-analytics/dbt-spark/pull/169))
- [@JCZuurmond](https://github.com/JCZuurmond) ([#171](https://github.com/fishtown-analytics/dbt-spark/pull/171))

## dbt-spark 0.19.1 (Release TBD)
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ A dbt profile for Spark connections support the following configurations:
| user | The username to use to connect to the cluster | ❔ | ❔ | ❔ | `hadoop` |
| connect_timeout | The number of seconds to wait before retrying to connect to a Pending Spark cluster | ❌ | ❔ (`10`) | ❔ (`10`) | `60` |
| connect_retries | The number of times to try connecting to a Pending Spark cluster before giving up | ❌ | ❔ (`0`) | ❔ (`0`) | `5` |
| use_ssl | The value of `hive.server2.use.SSL` (`True` or `False`). The default SSL store (`ssl.get_default_verify_paths()`) is used as the location for the SSL certificate | ❌ | ❔ (`False`) | ❌ | `True` |

**Databricks** connections differ based on the cloud provider:

Expand Down
75 changes: 70 additions & 5 deletions dbt/adapters/spark/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@
from hologram.helpers import StrEnum
from dataclasses import dataclass
from typing import Optional
try:
from thrift.transport.TSSLSocket import TSSLSocket
import thrift
import ssl
import sasl
import thrift_sasl
except ImportError:
TSSLSocket = None
thrift = None
ssl = None
sasl = None
thrift_sasl = None

import base64
import time
Expand Down Expand Up @@ -59,6 +71,7 @@ class SparkCredentials(Credentials):
organization: str = '0'
connect_retries: int = 0
connect_timeout: int = 10
use_ssl: bool = False

@classmethod
def __pre_deserialize__(cls, data):
Expand Down Expand Up @@ -348,11 +361,20 @@ def open(cls, connection):
cls.validate_creds(creds,
['host', 'port', 'user', 'schema'])

conn = hive.connect(host=creds.host,
port=creds.port,
username=creds.user,
auth=creds.auth,
kerberos_service_name=creds.kerberos_service_name) # noqa
if creds.use_ssl:
transport = build_ssl_transport(
host=creds.host,
port=creds.port,
username=creds.user,
auth=creds.auth,
kerberos_service_name=creds.kerberos_service_name)
conn = hive.connect(thrift_transport=transport)
else:
conn = hive.connect(host=creds.host,
port=creds.port,
username=creds.user,
auth=creds.auth,
kerberos_service_name=creds.kerberos_service_name) # noqa
handle = PyhiveConnectionWrapper(conn)
elif creds.method == SparkConnectionMethod.ODBC:
if creds.cluster is not None:
Expand Down Expand Up @@ -431,6 +453,49 @@ def open(cls, connection):
return connection


def build_ssl_transport(host, port, username, auth,
                        kerberos_service_name, password=None):
    """Build a Thrift transport that tunnels the HiveServer2 protocol
    over SSL, for use as the ``thrift_transport`` argument to
    ``hive.connect``.

    :param host: HiveServer2 hostname.
    :param port: HiveServer2 port; defaults to 10000 when ``None``.
    :param username: Username for PLAIN (LDAP/NONE/CUSTOM) SASL auth.
    :param auth: Value of ``hive.server2.authentication``; one of
        ``NOSASL``, ``LDAP``, ``KERBEROS``, ``NONE``, ``CUSTOM``.
        ``None`` is treated as ``NONE``.
    :param kerberos_service_name: Kerberos service principal, used only
        when ``auth == 'KERBEROS'``.
    :param password: Password for PLAIN auth; a non-empty placeholder is
        substituted when omitted (the server ignores it in NONE mode).
    :returns: A Thrift transport wrapping a TLS socket.
    :raises ValueError: if ``auth`` names an unsupported mode (the
        previous behavior silently returned ``None``, which failed much
        later with a confusing error).
    """
    if port is None:
        port = 10000
    if auth is None:
        auth = 'NONE'
    # NOTE(security): CERT_NONE encrypts the channel but skips server
    # certificate verification — the peer's identity is not validated.
    socket = TSSLSocket(host, port, cert_reqs=ssl.CERT_NONE)

    if auth == 'NOSASL':
        # NOSASL corresponds to hive.server2.authentication=NOSASL
        # in hive-site.xml
        return thrift.transport.TTransport.TBufferedTransport(socket)

    if auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'):
        # KERBEROS mode in hive.server2.authentication maps to the
        # GSSAPI mechanism in the sasl library; all others use PLAIN.
        sasl_auth = 'GSSAPI' if auth == 'KERBEROS' else 'PLAIN'
        if password is None:
            # Password doesn't matter in NONE mode, just needs
            # to be nonempty.
            password = 'x'

        def sasl_factory():
            sasl_client = sasl.Client()
            sasl_client.setAttr('host', host)
            if sasl_auth == 'GSSAPI':
                sasl_client.setAttr('service', kerberos_service_name)
            else:
                sasl_client.setAttr('username', username)
                sasl_client.setAttr('password', password)
            sasl_client.init()
            return sasl_client

        return thrift_sasl.TSaslClientTransport(sasl_factory,
                                                sasl_auth, socket)

    # Fail fast on unsupported modes instead of returning None.
    raise ValueError(
        "Unsupported auth mode for SSL transport: {}".format(auth)
    )


def _is_retryable_error(exc: Exception) -> Optional[str]:
message = getattr(exc, 'message', None)
if message is None:
Expand Down
35 changes: 35 additions & 0 deletions test/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,22 @@ def _get_target_thrift_kerberos(self, project):
'target': 'test'
})

def _get_target_use_ssl_thrift(self, project):
    """Profile fixture: a thrift-method target with ``use_ssl`` enabled."""
    credentials = {
        'type': 'spark',
        'method': 'thrift',
        'use_ssl': True,
        'schema': 'analytics',
        'host': 'myorg.sparkhost.com',
        'port': 10001,
        'user': 'dbt',
    }
    profile = {
        'outputs': {'test': credentials},
        'target': 'test',
    }
    return config_from_parts_or_dicts(project, profile)

def _get_target_odbc_cluster(self, project):
return config_from_parts_or_dicts(project, {
'outputs': {
Expand Down Expand Up @@ -154,6 +170,25 @@ def hive_thrift_connect(host, port, username, auth, kerberos_service_name):
self.assertEqual(connection.credentials.schema, 'analytics')
self.assertIsNone(connection.credentials.database)

def test_thrift_ssl_connection(self):
    """With ``use_ssl: true``, the adapter should hand ``hive.connect``
    a prebuilt transport instead of host/port keyword arguments."""
    config = self._get_target_use_ssl_thrift(self.project_cfg)
    adapter = SparkAdapter(config)

    def fake_hive_connect(thrift_transport):
        # The SSL path calls hive.connect(thrift_transport=...) only;
        # check the wrapped socket targets the configured host/port.
        self.assertIsNotNone(thrift_transport)
        inner_socket = thrift_transport._trans
        self.assertEqual(inner_socket.host, 'myorg.sparkhost.com')
        self.assertEqual(inner_socket.port, 10001)

    with mock.patch.object(hive, 'connect', new=fake_hive_connect):
        connection = adapter.acquire_connection('dummy')
        connection.handle  # trigger lazy-load

    self.assertEqual(connection.state, 'open')
    self.assertIsNotNone(connection.handle)
    self.assertEqual(connection.credentials.schema, 'analytics')
    self.assertIsNone(connection.credentials.database)

def test_thrift_connection_kerberos(self):
config = self._get_target_thrift_kerberos(self.project_cfg)
adapter = SparkAdapter(config)
Expand Down