Skip to content

Commit

Permalink
Initial support for SQLAlchemy
Browse files Browse the repository at this point in the history
This is cleaned up code developed at tobii.com (in collaboration with
knowit.se), which we use to connect from Superset (http://airbnb.io/superset/,
apache/superset#2531) to an Athena DB.

Usage:
athena://<user>:<password>@athena.us-east-1.amazonaws.com/?region_name=<region>&s3_staging_dir=s3%3A//<staging_bucket_of_choice>
  • Loading branch information
David Wallin committed Apr 1, 2017
1 parent 51b8d63 commit be7893f
Show file tree
Hide file tree
Showing 2 changed files with 243 additions and 1 deletion.
231 changes: 231 additions & 0 deletions pyathenajdbc/sqlalchemy_athena.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
"""Integration between SQLAlchemy and Athena.
Some code based on
https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py
which is released under the MIT license.
"""

from __future__ import absolute_import
from __future__ import unicode_literals
import re
#from distutils.version import StrictVersion
#from pyhive import presto
#from pyhive.common import UniversalSet
from sqlalchemy import exc
from sqlalchemy import types
#from sqlalchemy import util
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.sql.compiler import IdentifierPreparer
from pyathenajdbc import error
from pyathenajdbc.converter import JDBCTypeConverter

# try:
# from sqlalchemy.sql.compiler import SQLCompiler
# except ImportError:
# from sqlalchemy.sql.compiler import DefaultCompiler as SQLCompiler


class UniversalSet(object):
    """Container stand-in whose membership test succeeds for every item."""

    def __contains__(self, item):
        # Nothing is ever "not in" this set, whatever is asked for.
        return True


class AthenaIdentifierPreparer(IdentifierPreparer):
    """Identifier preparer that treats every identifier as a reserved word."""

    # ``UniversalSet`` answers True to every ``in`` test, so any identifier
    # checked against ``reserved_words`` is reported as reserved. Quoting
    # everything keeps things simpler and easier to upgrade than maintaining
    # an explicit keyword list.
    reserved_words = UniversalSet()

# Maps a JDBC type name to the SQLAlchemy type reported for reflected columns.
# Names that are commented out have no mapping yet.
_type_map = {
    'NULL': types.NullType,
    'BOOLEAN': types.Boolean,
    'TINYINT': types.Integer,
    'SMALLINT': types.Integer,
    'BIGINT': types.BigInteger,
    'INTEGER': types.Integer,
    'REAL': types.Float,
    'DOUBLE': types.Float,
    'FLOAT': types.Float,
    'CHAR': types.String,
    'NCHAR': types.String,
    'VARCHAR': types.String,
    'NVARCHAR': types.String,
    'LONGVARCHAR': types.String,
    'LONGNVARCHAR': types.String,
    'DATE': types.DATE,
    'TIMESTAMP': types.TIMESTAMP,
    'TIMESTAMP_WITH_TIMEZONE': types.TIMESTAMP,
    'ARRAY': types.ARRAY,
    'DECIMAL': types.DECIMAL,
    'NUMERIC': types.Numeric,
    # ``types.Binary`` is a deprecated alias of ``types.LargeBinary`` and is
    # no longer available in newer SQLAlchemy releases; use the real name so
    # the module keeps importing there.
    'BINARY': types.LargeBinary,
    'VARBINARY': types.LargeBinary,
    'LONGVARBINARY': types.LargeBinary,
    # TODO Converter impl
    # 'TIME': ???,
    # 'BIT': ???,
    # 'CLOB': ???,
    'BLOB': types.BLOB,
    # 'NCLOB': ???,
    # 'STRUCT': ???,
    'JAVA_OBJECT': types.BLOB,
    # 'REF_CURSOR': ???,
    # 'REF': ???,
    # 'DISTINCT': ???,
    # 'DATALINK': ???,
    # 'SQLXML': ???,
    # 'OTHER': ???,
    # 'ROWID': ???,
}


# class AthenaCompiler(SQLCompiler):
# def visit_char_length_func(self, fn, **kw):
# return 'length{}'.format(self.function_argspec(fn, **kw))


class AthenaDialect(DefaultDialect):
    """SQLAlchemy dialect that talks to Athena through ``pyathenajdbc``.

    Reflection is implemented with plain statements (``SHOW SCHEMAS``,
    ``SHOW TABLES``, ``SHOW COLUMNS`` and ``SELECT * ... LIMIT 0``); keys,
    constraints and indexes are always reported as absent.
    """

    name = 'athena'
    driver = 'athena'
    preparer = AthenaIdentifierPreparer
    # statement_compiler = AthenaCompiler
    supports_alter = False
    supports_pk_autoincrement = False
    supports_default_values = False
    supports_empty_insert = False
    supports_unicode_statements = True
    supports_unicode_binds = True
    returns_unicode_strings = True
    description_encoding = None
    supports_native_boolean = True

    # Both are filled in lazily by the first get_columns() call.
    jdbctypeconverter = None
    jdbc_type_map = None

    @classmethod
    def dbapi(cls):
        """Return the ``pyathenajdbc`` module to use as the DBAPI.

        The exception classes from ``pyathenajdbc.error`` are first copied
        onto the ``pyathenajdbc`` module itself, so they are reachable as
        attributes of the returned module.
        """
        import pyathenajdbc
        import pyathenajdbc.error

        for exc_name in ('Error', 'Warning', 'InterfaceError',
                         'DatabaseError', 'InternalError',
                         'OperationalError', 'ProgrammingError',
                         'IntegrityError', 'DataError',
                         'NotSupportedError'):
            setattr(pyathenajdbc, exc_name,
                    getattr(pyathenajdbc.error, exc_name))

        return pyathenajdbc

    def create_connect_args(self, url):
        """Translate a SQLAlchemy URL into DBAPI ``connect()`` arguments.

        :param url: the parsed connection URL. Its query string must contain
            ``region_name`` and ``s3_staging_dir``.
        :returns: ``([], kwargs)`` -- no positional arguments, and keyword
            arguments built from the host, the credentials, every query
            parameter, and the catalog/schema taken from the database part.
        :raises KeyError: if ``region_name`` or ``s3_staging_dir`` is missing
            from the URL query.
        :raises ValueError: if the database part has more than two
            ``/``-separated components.
        """
        db_parts = (url.database or 'hive').split('/')

        # TODO:
        # - schema_name='default'
        # - profile_name=None
        # - credential_file=None
        kwargs = {
            'host': url.host,
            'access_key': url.username,
            'secret_key': url.password,
            # Indexing (rather than .get) makes these two mandatory.
            'region_name': url.query['region_name'],
            's3_staging_dir': url.query['s3_staging_dir'],
        }
        # Forward every query parameter as a keyword argument as well.
        kwargs.update(url.query)
        if len(db_parts) == 1:
            kwargs['catalog'] = db_parts[0]
        elif len(db_parts) == 2:
            kwargs['catalog'] = db_parts[0]
            kwargs['schema'] = db_parts[1]
        else:
            raise ValueError(
                "Unexpected database format {}".format(url.database))
        return ([], kwargs)

    def get_schema_names(self, connection, **kw):
        """Return the schema names listed by ``SHOW SCHEMAS``."""
        return [schema for (schema,) in connection.execute('SHOW SCHEMAS')]

    def _get_table_columns(self, connection, table_name, schema):
        """Run ``SHOW COLUMNS`` for a table and return the result object.

        :raises sqlalchemy.exc.NoSuchTableError: if the database error
            message says that the table does not exist. Any other database
            error is re-raised unchanged.
        """
        name = table_name
        if schema is not None:
            name = '%s.%s' % (schema, name)
        try:
            return connection.execute('SHOW COLUMNS IN {}'.format(name))
        except (error.DatabaseError, exc.DatabaseError) as e:
            # Normally SQLAlchemy wraps driver errors in
            # sqlalchemy.exc.DatabaseError. The unwrapped
            # pyathenajdbc.error.DatabaseError is caught here as well, in
            # case the failure reaches us without having been wrapped.
            # Does the table exist?
            msg = None
            if e.args:
                first_arg = e.args[0]
                if isinstance(first_arg, dict):
                    msg = first_arg.get('message')
                elif isinstance(first_arg, str):
                    msg = first_arg
            regex = r"Table\ \'.*{}\'\ does\ not\ exist".format(
                re.escape(table_name))
            if msg and re.search(regex, msg):
                raise exc.NoSuchTableError(table_name)
            else:
                raise

    def has_table(self, connection, table_name, schema=None):
        """Return True if ``SHOW COLUMNS`` succeeds for the table."""
        try:
            self._get_table_columns(connection, table_name, schema)
            return True
        except exc.NoSuchTableError:
            return False

    def get_columns(self, connection, table_name, schema=None, **kwargs):
        """Describe the columns of a table.

        The column list is taken from the cursor description of a
        ``SELECT * ... LIMIT 0`` statement.

        :returns: a list with one dict per column, holding ``name``,
            ``type``, ``nullable`` (always True) and ``autoincrement``
            (always False). A column whose type has no entry in
            ``_type_map`` is reported as ``types.NullType`` instead of
            aborting reflection of the whole table.
        """
        # pylint: disable=unused-argument
        if self.jdbctypeconverter is None:
            self.jdbctypeconverter = JDBCTypeConverter()
            # Inverted copy of the converter's mapping, so the type value
            # found in the cursor description (col[1]) can be looked up to
            # get the key used by ``_type_map``.
            self.jdbc_type_map = {
                v: k for (k, v) in
                self.jdbctypeconverter.jdbc_type_mappings.items()}

        name = table_name
        if schema is not None:
            name = '%s.%s' % (schema, name)
        query = 'SELECT * FROM %s LIMIT 0' % name
        result = connection.execute(query)
        # Kept in its own local so the ``schema`` argument is not rebound.
        description = result.cursor.description
        # We need to fetch the empty results otherwise these queries remain
        # in flight
        result.fetchall()
        column_info = []
        for col in description:
            type_name = self.jdbc_type_map.get(col[1])
            column_info.append({
                'name': col[0],
                # Several JDBC types are not mapped yet (see the TODO list
                # in ``_type_map``); fall back to NullType for those.
                'type': _type_map.get(type_name, types.NullType),
                'nullable': True,
                'autoincrement': False})
        return column_info

    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        """Return an empty list: no foreign keys are reported."""
        return []

    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        """Return an empty primary key constraint description.

        SQLAlchemy documents the result of this hook as a dict with the
        ``constrained_columns`` and ``name`` keys, and callers may index
        into it, so an empty constraint is returned in that shape rather
        than as a bare list.
        """
        return {'constrained_columns': [], 'name': None}

    def get_indexes(self, connection, table_name, schema=None, **kw):
        """Return an empty list: no indexes are reported."""
        return []

    def get_table_names(self, connection, schema=None, **kw):
        """Return the table names listed by ``SHOW TABLES``.

        :param schema: optional schema; when given, the statement becomes
            ``SHOW TABLES IN <schema>``.
        """
        query = 'SHOW TABLES'
        if schema:
            query += ' IN {}'.format(schema)
        return [tbl for (tbl,) in connection.execute(query).fetchall()]

    def do_rollback(self, dbapi_connection):
        """Do nothing on rollback."""
        # No transactions for Athena
        pass

    def _check_unicode_returns(self, connection, additional_tests=None):
        """Report, without probing the connection, that strings are unicode."""
        # requests gives back Unicode strings
        return True

    def _check_unicode_description(self, connection):
        """Report, without probing, that cursor descriptions are unicode."""
        # requests gives back Unicode strings
        return True
13 changes: 12 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def run(self):
'botocore>=1.0.0'
],
extras_require={
'Pandas': ['pandas>=0.19.0']
'Pandas': ['pandas>=0.19.0'],
'SQLAlchemy': ['sqlalchemy>=1.0.0'],
},
tests_require=[
'futures',
Expand All @@ -138,4 +139,14 @@ def run(self):
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
],
entry_points={
# New versions
'sqlalchemy.dialects': [
'athena = pyathenajdbc.sqlalchemy_athena:AthenaDialect',
],
# Version 0.5
'sqlalchemy.databases': [
'athena = pyathenajdbc.sqlalchemy_athena:AthenaDialect',
],
}
)

0 comments on commit be7893f

Please sign in to comment.