-
Notifications
You must be signed in to change notification settings - Fork 240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/ssl support dbtspark #169
Changes from 7 commits
cdf3168
3db68db
382e47e
776e27a
49588ff
2a278b3
8635b38
79a7b3c
7744680
735d83f
8942c55
8ef3c28
a5965fe
aab9f2e
ce0f1cf
6fdaee9
4a3ea46
5db0101
970930e
003cbdf
6aad441
39628d9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,9 @@ | |
from hologram.helpers import StrEnum | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
from thrift.transport.TSSLSocket import TSSLSocket | ||
import thrift | ||
import ssl | ||
|
||
import base64 | ||
import time | ||
|
@@ -59,6 +62,7 @@ class SparkCredentials(Credentials): | |
organization: str = '0' | ||
connect_retries: int = 0 | ||
connect_timeout: int = 10 | ||
use_ssl: bool = False | ||
|
||
@classmethod | ||
def __pre_deserialize__(cls, data): | ||
|
@@ -348,11 +352,20 @@ def open(cls, connection): | |
cls.validate_creds(creds, | ||
['host', 'port', 'user', 'schema']) | ||
|
||
conn = hive.connect(host=creds.host, | ||
port=creds.port, | ||
username=creds.user, | ||
auth=creds.auth, | ||
kerberos_service_name=creds.kerberos_service_name) # noqa | ||
if creds.use_ssl: | ||
transport = build_ssl_transport( | ||
host=creds.host, | ||
port=creds.port, | ||
username=creds.user, | ||
auth=creds.auth, | ||
kerberos_service_name=creds.kerberos_service_name) | ||
conn = hive.connect(thrift_transport=transport) | ||
else: | ||
conn = hive.connect(host=creds.host, | ||
port=creds.port, | ||
username=creds.user, | ||
auth=creds.auth, | ||
kerberos_service_name=creds.kerberos_service_name) # noqa | ||
handle = PyhiveConnectionWrapper(conn) | ||
elif creds.method == SparkConnectionMethod.ODBC: | ||
if creds.cluster is not None: | ||
|
@@ -431,6 +444,52 @@ def open(cls, connection): | |
return connection | ||
|
||
|
||
def build_ssl_transport(host, port, username, auth, | ||
kerberos_service_name, password=None): | ||
transport = None | ||
if port is None: | ||
port = 10000 | ||
if auth is None: | ||
auth = 'NONE' | ||
socket = TSSLSocket(host, port, cert_reqs=ssl.CERT_NONE) | ||
if auth == 'NOSASL': | ||
# NOSASL corresponds to hive.server2.authentication=NOSASL | ||
# in hive-site.xml | ||
transport = thrift.transport.TTransport.TBufferedTransport(socket) | ||
elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'): | ||
# Defer import so package dependency is optional | ||
import sasl | ||
import thrift_sasl | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have a different way of handling these optional imports. See the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! Could you move these up alongside the other imports (lines 29-32)? I'd prefer not to have imports nested so far down in the fil. |
||
|
||
if auth == 'KERBEROS': | ||
# KERBEROS mode in hive.server2.authentication is GSSAPI | ||
# in sasl library | ||
sasl_auth = 'GSSAPI' | ||
else: | ||
sasl_auth = 'PLAIN' | ||
if password is None: | ||
# Password doesn't matter in NONE mode, just needs | ||
# to be nonempty. | ||
password = 'x' | ||
|
||
def sasl_factory(): | ||
sasl_client = sasl.Client() | ||
sasl_client.setAttr('host', host) | ||
if sasl_auth == 'GSSAPI': | ||
sasl_client.setAttr('service', kerberos_service_name) | ||
elif sasl_auth == 'PLAIN': | ||
sasl_client.setAttr('username', username) | ||
sasl_client.setAttr('password', password) | ||
else: | ||
raise AssertionError | ||
sasl_client.init() | ||
return sasl_client | ||
|
||
transport = thrift_sasl.TSaslClientTransport(sasl_factory, | ||
sasl_auth, socket) | ||
return transport | ||
|
||
|
||
def _is_retryable_error(exc: Exception) -> Optional[str]: | ||
message = getattr(exc, 'message', None) | ||
if message is None: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's include these up in the
try
import (line 11) for requirements that may or may not be installed.ssl
is a net-new dependency, right? We'll need to add it insetup.py
, within thePyhive
extraThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ssl is part of python library https://docs.python.org/3/library/ssl.html so i think it is not required to add
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I saidJust kidding, I follow now, those are dependencies ofssl
here but I was thinking ofthrift_sasl
. That's a net-new dependency, right?PyHive[hive]
. Thanks for the clarification aroundssl
.