Skip to content

Commit

Permalink
PHOENIX-6892 Add support for SqlAlchemy 2.0
Browse files Browse the repository at this point in the history
also
- SQLAlchemy is no longer an install dependency
- add all supported Python-SQLAlchemy combinations to tox
- replace deprecated failUnless test method
  • Loading branch information
stoty committed Mar 20, 2023
1 parent d086ca5 commit 5634468
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 103 deletions.
11 changes: 11 additions & 0 deletions python-phoenixdb/NEWS.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
Changelog
=========

Unreleased
----------
- Update python-phoenixdb/RELEASING.rst (PHOENIX-6820)
- Add Python 3.11 to supported languages and update docker test image for phoenixdb (PHOENIX-6858)
- Document workaround for PhoenixDB 1.2+ not working with Python2 on some systems (PHOENIX-6863)
- Update install instructions in README.rst (PHOENIX-6812)
- Add support for SQLAlchemy 2.0 (PHOENIX-6892)
- SQLAlchemy is no longer an install dependency (PHOENIX-6892)
- Run tests with all supported Python + SqlAlchemy versions (1.3, 1.4, 2.0) (PHOENIX-6892)
- Replace deprecated failUnless methods in tests (PHOENIX-6892)

Version 1.2.1
-------------
- Defined authentication mechanism for SPNEGO explicitly (PHOENIX-6781)
Expand Down
29 changes: 21 additions & 8 deletions python-phoenixdb/phoenixdb/sqlalchemy_phoenix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import sys

import phoenixdb
import sqlalchemy

from sqlalchemy import types
from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
Expand Down Expand Up @@ -44,6 +45,14 @@ def visit_primary_key_constraint(self, constraint):
)


if sqlalchemy.__version__.startswith('1.3'):
def _get_dbapi(connectable):
return connectable.connect().connection.connection
else:
def _get_dbapi(connectable):
return connectable.connection


class PhoenixExecutionContext(DefaultExecutionContext):

def should_autocommit_text(self, statement):
Expand Down Expand Up @@ -107,7 +116,11 @@ def __init__(self, tls=False, path='/', **opts):
# There is no way to pass these via the SqlAlchemy url object
self.tls = tls
self.path = path
super(PhoenixDialect, self).__init__(self, **opts)
super(PhoenixDialect, self).__init__(**opts)

@classmethod
def import_dbapi(cls):
return phoenixdb

@classmethod
def dbapi(cls):
Expand All @@ -131,13 +144,13 @@ def create_connect_args(self, url):
def has_table(self, connection, table_name, schema=None, **kw):
if schema is None:
schema = ''
return bool(connection.connect().connection.meta().get_tables(
return bool(_get_dbapi(connection).meta().get_tables(
tableNamePattern=table_name,
schemaPattern=schema,
typeList=('TABLE', 'SYSTEM_TABLE')))

def get_schema_names(self, connection, **kw):
schemas = connection.connect().connection.meta().get_schemas()
schemas = _get_dbapi(connection).meta().get_schemas()
schema_names = [schema['TABLE_SCHEM'] for schema in schemas]
# Phoenix won't return the default schema if there aren't any tables in it
if '' not in schema_names:
Expand All @@ -148,27 +161,27 @@ def get_table_names(self, connection, schema=None, order_by=None, **kw):
'''order_by is ignored'''
if schema is None:
schema = ''
tables = connection.connect().connection.meta().get_tables(
tables = _get_dbapi(connection).meta().get_tables(
schemaPattern=schema, typeList=('TABLE', 'SYSTEM TABLE'))
return [table['TABLE_NAME'] for table in tables]

def get_view_names(self, connection, schema=None, **kw):
if schema is None:
schema = ''
views = connection.connect().connection.meta().get_tables(schemaPattern=schema, typeList=('VIEW',))
views = _get_dbapi(connection).meta().get_tables(schemaPattern=schema, typeList=('VIEW',))
return [view['TABLE_NAME'] for view in views]

def get_columns(self, connection, table_name, schema=None, **kw):
if schema is None:
schema = ''
raw = connection.connect().connection.meta().get_columns(
raw = _get_dbapi(connection).meta().get_columns(
schemaPattern=schema, tableNamePattern=table_name)
return [self._map_column(row) for row in raw]

def get_pk_constraint(self, connection, table_name, schema=None, **kw):
if schema is None:
schema = ''
raw = connection.connect().connection.meta().get_primary_keys(
raw = _get_dbapi(connection).meta().get_primary_keys(
schema=schema, table=table_name)
cooked = {
'constrained_columns': []
Expand All @@ -182,7 +195,7 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
def get_indexes(self, connection, table_name, schema=None, **kw):
if schema is None:
schema = ''
raw = connection.connect().connection.meta().get_index_info(schema=schema, table=table_name)
raw = _get_dbapi(connection).meta().get_index_info(schema=schema, table=table_name)
# We know that Phoenix returns the rows ordered by INDEX_NAME and ORDINAL_POSITION
cooked = []
current = None
Expand Down
111 changes: 44 additions & 67 deletions python-phoenixdb/phoenixdb/tests/dbapi20.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,6 @@
import time
import sys

if sys.version[0] >= '3': #python 3.x
_BaseException = Exception
def _failUnless(self, expr, msg=None):
self.assertTrue(expr, msg)
else: #python 2.x
from exceptions import StandardError as _BaseException
def _failUnless(self, expr, msg=None):
self.failUnless(expr, msg) ## deprecated since Python 2.6

def str2bytes(sval):
if sys.version_info < (3,0) and isinstance(sval, str):
sval = sval.decode("latin1")
Expand Down Expand Up @@ -147,7 +138,7 @@ def test_threadsafety(self):
# Must exist
threadsafety = self.driver.threadsafety
# Must be a valid value
_failUnless(self, threadsafety in (0,1,2,3))
self.assertTrue(threadsafety in (0,1,2,3))
except AttributeError:
self.fail("Driver doesn't define threadsafety")

Expand All @@ -156,7 +147,7 @@ def test_paramstyle(self):
# Must exist
paramstyle = self.driver.paramstyle
# Must be a valid value
_failUnless(self, paramstyle in (
self.assertTrue(paramstyle in (
'qmark','numeric','named','format','pyformat'
))
except AttributeError:
Expand All @@ -169,30 +160,16 @@ def test_Exceptions(self):
self.assertTrue(issubclass(self.driver.Warning,Exception))
self.assertTrue(issubclass(self.driver.Error,Exception))
else:
self.failUnless(issubclass(self.driver.Warning,StandardError))
self.failUnless(issubclass(self.driver.Error,StandardError))
self.assertTrue(issubclass(self.driver.Warning,StandardError))
self.assertTrue(issubclass(self.driver.Error,StandardError))

_failUnless(self,
issubclass(self.driver.InterfaceError,self.driver.Error)
)
_failUnless(self,
issubclass(self.driver.DatabaseError,self.driver.Error)
)
_failUnless(self,
issubclass(self.driver.OperationalError,self.driver.Error)
)
_failUnless(self,
issubclass(self.driver.IntegrityError,self.driver.Error)
)
_failUnless(self,
issubclass(self.driver.InternalError,self.driver.Error)
)
_failUnless(self,
issubclass(self.driver.ProgrammingError,self.driver.Error)
)
_failUnless(self,
issubclass(self.driver.NotSupportedError,self.driver.Error)
)
self.assertTrue(issubclass(self.driver.InterfaceError,self.driver.Error))
self.assertTrue(issubclass(self.driver.DatabaseError,self.driver.Error))
self.assertTrue(issubclass(self.driver.OperationalError,self.driver.Error))
self.assertTrue(issubclass(self.driver.IntegrityError,self.driver.Error))
self.assertTrue(issubclass(self.driver.InternalError,self.driver.Error))
self.assertTrue(issubclass(self.driver.ProgrammingError,self.driver.Error))
self.assertTrue(issubclass(self.driver.NotSupportedError,self.driver.Error))

def test_ExceptionsAsConnectionAttributes(self):
# OPTIONAL EXTENSION
Expand All @@ -203,15 +180,15 @@ def test_ExceptionsAsConnectionAttributes(self):
# by default.
con = self._connect()
drv = self.driver
_failUnless(self,con.Warning is drv.Warning)
_failUnless(self,con.Error is drv.Error)
_failUnless(self,con.InterfaceError is drv.InterfaceError)
_failUnless(self,con.DatabaseError is drv.DatabaseError)
_failUnless(self,con.OperationalError is drv.OperationalError)
_failUnless(self,con.IntegrityError is drv.IntegrityError)
_failUnless(self,con.InternalError is drv.InternalError)
_failUnless(self,con.ProgrammingError is drv.ProgrammingError)
_failUnless(self,con.NotSupportedError is drv.NotSupportedError)
self.assertTrue(con.Warning is drv.Warning)
self.assertTrue(con.Error is drv.Error)
self.assertTrue(con.InterfaceError is drv.InterfaceError)
self.assertTrue(con.DatabaseError is drv.DatabaseError)
self.assertTrue(con.OperationalError is drv.OperationalError)
self.assertTrue(con.IntegrityError is drv.IntegrityError)
self.assertTrue(con.InternalError is drv.InternalError)
self.assertTrue(con.ProgrammingError is drv.ProgrammingError)
self.assertTrue(con.NotSupportedError is drv.NotSupportedError)


def test_commit(self):
Expand Down Expand Up @@ -296,24 +273,24 @@ def test_rowcount(self):
try:
cur = con.cursor()
self.executeDDL1(cur)
_failUnless(self,cur.rowcount in (-1,0), # Bug #543885
self.assertTrue(cur.rowcount in (-1,0), # Bug #543885
'cursor.rowcount should be -1 or 0 after executing no-result '
'statements'
)
cur.execute("%s into %sbooze values ('Victoria Bitter')" % (
self.insert, self.table_prefix
))
_failUnless(self,cur.rowcount in (-1,1),
self.assertTrue(cur.rowcount in (-1,1),
'cursor.rowcount should == number or rows inserted, or '
'set to -1 after executing an insert statement'
)
cur.execute("select name from %sbooze" % self.table_prefix)
_failUnless(self,cur.rowcount in (-1,1),
self.assertTrue(cur.rowcount in (-1,1),
'cursor.rowcount should == number of rows returned, or '
'set to -1 after executing a select statement'
)
self.executeDDL2(cur)
_failUnless(self,cur.rowcount in (-1,0), # Bug #543885
self.assertTrue(cur.rowcount in (-1,0), # Bug #543885
'cursor.rowcount should be -1 or 0 after executing no-result '
'statements'
)
Expand Down Expand Up @@ -375,7 +352,7 @@ def _paraminsert(self,cur):
cur.execute("%s into %sbarflys values ('Victoria Bitter', 'thi%%s :may ca%%(u)se? troub:1e')" % (
self.insert, self.table_prefix
))
_failUnless(self,cur.rowcount in (-1,1))
self.assertTrue(cur.rowcount in (-1,1))

if self.driver.paramstyle == 'qmark':
cur.execute(
Expand Down Expand Up @@ -404,7 +381,7 @@ def _paraminsert(self,cur):
)
else:
self.fail('Invalid paramstyle')
_failUnless(self,cur.rowcount in (-1,1))
self.assertTrue(cur.rowcount in (-1,1))

cur.execute('select name, drink from %sbarflys' % self.table_prefix)
res = cur.fetchall()
Expand Down Expand Up @@ -464,7 +441,7 @@ def test_executemany(self):
)
else:
self.fail('Unknown paramstyle')
_failUnless(self,cur.rowcount in (-1,2),
self.assertTrue(cur.rowcount in (-1,2),
'insert using cursor.executemany set cursor.rowcount to '
'incorrect value %r' % cur.rowcount
)
Expand Down Expand Up @@ -499,7 +476,7 @@ def test_fetchone(self):
'cursor.fetchone should return None if a query retrieves '
'no rows'
)
_failUnless(self,cur.rowcount in (-1,0))
self.assertTrue(cur.rowcount in (-1,0))

# cursor.fetchone should raise an Error if called after
# executing a query that cannnot return rows
Expand All @@ -519,7 +496,7 @@ def test_fetchone(self):
self.assertEqual(cur.fetchone(),None,
'cursor.fetchone should return None if no more rows available'
)
_failUnless(self,cur.rowcount in (-1,1))
self.assertTrue(cur.rowcount in (-1,1))
finally:
con.close()

Expand Down Expand Up @@ -575,7 +552,7 @@ def test_fetchmany(self):
'cursor.fetchmany should return an empty sequence after '
'results are exhausted'
)
_failUnless(self,cur.rowcount in (-1,6))
self.assertTrue(cur.rowcount in (-1,6))

# Same as above, using cursor.arraysize
cur.arraysize=4
Expand All @@ -588,12 +565,12 @@ def test_fetchmany(self):
self.assertEqual(len(r),2)
r = cur.fetchmany() # Should be an empty sequence
self.assertEqual(len(r),0)
_failUnless(self,cur.rowcount in (-1,6))
self.assertTrue(cur.rowcount in (-1,6))

cur.arraysize=6
cur.execute('select name from %sbooze' % self.table_prefix)
rows = cur.fetchmany() # Should get all rows
_failUnless(self,cur.rowcount in (-1,6))
self.assertTrue(cur.rowcount in (-1,6))
self.assertEqual(len(rows),6)
self.assertEqual(len(rows),6)
rows = [r[0] for r in rows]
Expand All @@ -610,7 +587,7 @@ def test_fetchmany(self):
'cursor.fetchmany should return an empty sequence if '
'called after the whole result set has been fetched'
)
_failUnless(self,cur.rowcount in (-1,6))
self.assertTrue(cur.rowcount in (-1,6))

self.executeDDL2(cur)
cur.execute('select name from %sbarflys' % self.table_prefix)
Expand All @@ -619,7 +596,7 @@ def test_fetchmany(self):
'cursor.fetchmany should return an empty sequence if '
'query retrieved no rows'
)
_failUnless(self,cur.rowcount in (-1,0))
self.assertTrue(cur.rowcount in (-1,0))

finally:
con.close()
Expand All @@ -643,7 +620,7 @@ def test_fetchall(self):

cur.execute('select name from %sbooze' % self.table_prefix)
rows = cur.fetchall()
_failUnless(self,cur.rowcount in (-1,len(self.samples)))
self.assertTrue(cur.rowcount in (-1,len(self.samples)))
self.assertEqual(len(rows),len(self.samples),
'cursor.fetchall did not retrieve all rows'
)
Expand All @@ -659,12 +636,12 @@ def test_fetchall(self):
'cursor.fetchall should return an empty list if called '
'after the whole result set has been fetched'
)
_failUnless(self,cur.rowcount in (-1,len(self.samples)))
self.assertTrue(cur.rowcount in (-1,len(self.samples)))

self.executeDDL2(cur)
cur.execute('select name from %sbarflys' % self.table_prefix)
rows = cur.fetchall()
_failUnless(self,cur.rowcount in (-1,0))
self.assertTrue(cur.rowcount in (-1,0))
self.assertEqual(len(rows),0,
'cursor.fetchall should return an empty list if '
'a select query returns no rows'
Expand All @@ -686,7 +663,7 @@ def test_mixedfetch(self):
rows23 = cur.fetchmany(2)
rows4 = cur.fetchone()
rows56 = cur.fetchall()
_failUnless(self,cur.rowcount in (-1,6))
self.assertTrue(cur.rowcount in (-1,6))
self.assertEqual(len(rows23),2,
'fetchmany returned incorrect number of rows'
)
Expand Down Expand Up @@ -763,7 +740,7 @@ def test_arraysize(self):
con = self._connect()
try:
cur = con.cursor()
_failUnless(self,hasattr(cur,'arraysize'),
self.assertTrue(hasattr(cur,'arraysize'),
'cursor.arraysize must be defined'
)
finally:
Expand Down Expand Up @@ -832,26 +809,26 @@ def test_Binary(self):
b = self.driver.Binary(str2bytes(''))

def test_STRING(self):
_failUnless(self, hasattr(self.driver,'STRING'),
self.assertTrue(hasattr(self.driver,'STRING'),
'module.STRING must be defined'
)

def test_BINARY(self):
_failUnless(self, hasattr(self.driver,'BINARY'),
self.assertTrue(hasattr(self.driver,'BINARY'),
'module.BINARY must be defined.'
)

def test_NUMBER(self):
_failUnless(self, hasattr(self.driver,'NUMBER'),
self.assertTrue(hasattr(self.driver,'NUMBER'),
'module.NUMBER must be defined.'
)

def test_DATETIME(self):
_failUnless(self, hasattr(self.driver,'DATETIME'),
self.assertTrue(hasattr(self.driver,'DATETIME'),
'module.DATETIME must be defined.'
)

def test_ROWID(self):
_failUnless(self, hasattr(self.driver,'ROWID'),
self.assertTrue(hasattr(self.driver,'ROWID'),
'module.ROWID must be defined.'
)
Loading

0 comments on commit 5634468

Please sign in to comment.