diff --git a/siuba/sql/utils.py b/siuba/sql/utils.py index b5aab320..3dbc72ac 100644 --- a/siuba/sql/utils.py +++ b/siuba/sql/utils.py @@ -56,13 +56,16 @@ def execute(self, *args, **kwargs): RE_VERSION=r"(?P\d+)\.(?P\d+).(?P\d+)" SQLA_VERSION=tuple(map(int, re.match(RE_VERSION, sqlalchemy.__version__).groups())) +def is_sqla_12(): + return SQLA_VERSION[:-1] == (1, 2) + def is_sqla_13(): return SQLA_VERSION[:-1] == (1, 3) def _sql_select(columns, *args, **kwargs): from sqlalchemy import sql - if is_sqla_13(): + if is_sqla_12() or is_sqla_13(): # use old syntax, where columns are passed as a list return sql.select(columns, *args, **kwargs) @@ -72,7 +75,7 @@ def _sql_select(columns, *args, **kwargs): def _sql_column_collection(data, columns): from sqlalchemy.sql.base import ColumnCollection, ImmutableColumnCollection - if is_sqla_13(): + if is_sqla_12() or is_sqla_13(): return ImmutableColumnCollection(data, columns) return ColumnCollection(list(data.items())).as_immutable()