Skip to content

Commit

Permalink
Merge pull request #472 from machow/fix-sqlalchemy-compat
Browse files Browse the repository at this point in the history
fix(sql): sqlalchemy 2.0 compat imports
  • Loading branch information
machow authored Sep 18, 2023
2 parents df88cc0 + c797c04 commit f0b5335
Show file tree
Hide file tree
Showing 24 changed files with 158 additions and 88 deletions.
20 changes: 9 additions & 11 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,12 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.7, 3.8]
requirements: ['-r requirements.txt']
python-version: ["3.8", "3.9", "3.10", "3.11"]
# unset doctests for earlier versions of python.
pytest_flags: ["-o addopts='' -m 'not bigquery and not snowflake'"]
pytest_flags: ["-v -o addopts='' -m 'not bigquery and not snowflake'"]
requirements: [""]
include:
# historical requirements
- name: "2020-mid dependencies"
python-version: 3.8
requirements: numpy~=1.19.1 pandas~=1.2.0 SQLAlchemy~=1.3.18 psycopg2~=2.8.5 PyMySQL==1.0.2
pytest_flags: "-o addopts='' -m 'not bigquery and not snowflake'"
- name: "2021-mid dependencies"
python-version: 3.8
requirements: numpy~=1.19.1 pandas~=1.2.0 SQLAlchemy~=1.4.13 psycopg2~=2.8.5 PyMySQL==1.0.2
Expand All @@ -33,7 +29,7 @@ jobs:
python-version: 3.8
requirements: numpy~=1.22.0 pandas~=1.3.5 SQLAlchemy~=1.4.29 psycopg2-binary~=2.9.3 PyMySQL==1.0.2
- name: "2022-early dependencies"
python-version: 3.10.1
python-version: "3.10"
requirements: numpy~=1.22.0 pandas~=1.3.5 SQLAlchemy~=1.4.29 psycopg2-binary~=2.9.3 PyMySQL==1.0.2
latest: true

Expand All @@ -49,12 +45,14 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install $REQUIREMENTS
python -m pip install -r requirements-test.txt
if [ -n "$REQUIREMENTS" ]; then
python -m pip install $REQUIREMENTS '.[test,docs]'
else
python -m pip install '.[test,docs]'
fi
# step to test duckdb
# TODO: move these requirements into the test matrix
pip install duckdb_engine
python -m pip install .
env:
REQUIREMENTS: ${{ matrix.requirements }}
- name: Test with pytest
Expand Down
1 change: 0 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ test:
pytest siuba/

test-travis:
py.test --nbval-lax $(filter-out %postgres.ipynb, $(NOTEBOOK_TESTS))
pytest $(PYTEST_FLAGS) siuba/

examples/%.ipynb:
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy==1.19.1
pandas==1.2.0
pandas==1.2.5
psycopg2==2.8.5
PyMySQL==1.0.2
python-dateutil==2.8.1
Expand Down
9 changes: 8 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
url='https://github.com/machow/siuba',
keywords=['package', ],
install_requires=[
"pandas>=0.24.0",
"pandas>=0.24.0,<2.1.0",
"numpy>=1.12.0",
"SQLAlchemy>=1.2.19",
"PyYAML>=3.0.0"
Expand All @@ -35,6 +35,13 @@
"test": [
"pytest",
"hypothesis",
"IPython",
"pymysql",
"psycopg2-binary",
"duckdb_engine",
# duckdb 0.8.0 has a bug which always errors for pandas v2+
# it's been fixed, but we need to pin until duckdb v0.9.0
"duckdb",
],
"docs": [
"plotnine",
Expand Down
13 changes: 11 additions & 2 deletions siuba/dply/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2573,7 +2573,7 @@ def tbl(src, *args, **kwargs):
>>> from siuba import count, show_query, collect
>>> engine = create_engine("sqlite:///:memory:")
>>> cars.to_sql("cars", engine, index=False)
>>> _rows = cars.to_sql("cars", engine, index=False)
>>> tbl_sql_cars = tbl(engine, "cars")
>>> tbl_sql_cars >> count()
Expand Down Expand Up @@ -2607,7 +2607,6 @@ def tbl(src, *args, **kwargs):

return src


@tbl.register
def _tbl_sqla(src: SqlaEngine, table_name, columns=None):
from siuba.sql import LazyTbl
Expand All @@ -2623,6 +2622,16 @@ def _tbl_sqla(src: SqlaEngine, table_name, columns=None):

@tbl.register(object)
def _tbl(__data, *args, **kwargs):
# sqlalchemy v2 does not have MockConnection inherit from anything
# even though it is a mock :/.
try:
from sqlalchemy.engine.mock import MockConnection

if isinstance(__data, MockConnection):
return tbl.dispatch(SqlaEngine)(__data, *args, **kwargs)
except ImportError:
pass

raise NotImplementedError(
f"Unsupported type {type(__data)}. "
"Note that tbl currently can be used at the start of a pipe, but not as "
Expand Down
2 changes: 1 addition & 1 deletion siuba/experimental/pivot/sql_pivot_wide.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _pivot_wider_spec(
when_clause = sql.and_(sel_cols[k] == row[k] for k in name_vars)
when_then = (when_clause, sel_cols[row[".value"]])

col = values_fn(dispatch_cls(), _sql_case(when_then))
col = values_fn(dispatch_cls(), _sql_case([when_then]))

wide_name_cols.append(col)

Expand Down
2 changes: 1 addition & 1 deletion siuba/ops/support/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def replace_next(node):

wide_backends = (
pd.concat([sql_methods, pandas_methods])
.pivot("full_name", "backend", "metadata")
.pivot(index="full_name", columns="backend", values="metadata")
)

full_methods = methods.merge(wide_backends, how = "left", on = "full_name")
Expand Down
7 changes: 4 additions & 3 deletions siuba/sql/across.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .backend import LazyTbl
from .utils import _sql_select, _sql_column_collection
from .translate import ColumnCollection

from sqlalchemy import sql

Expand Down Expand Up @@ -39,13 +40,13 @@ def _across_lazy_tbl(__data: LazyTbl, cols, fns, names: "str | None" = None) ->
#return __data.append_op(_sql_select(res_cols))


@across.register(sql.base.ImmutableColumnCollection)
@across.register(ColumnCollection)
def _across_sql_cols(
__data: sql.base.ImmutableColumnCollection,
__data: ColumnCollection,
cols,
fns,
names: "str | None" = None
) -> sql.base.ImmutableColumnCollection:
) -> ColumnCollection:

lazy_tbl = ctx_verb_data.get()
window = ctx_verb_window.get()
Expand Down
6 changes: 3 additions & 3 deletions siuba/sql/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import warnings

from .translate import CustomOverClause
from .translate import CustomOverClause, ColumnCollection
from .utils import (
get_dialect_translator,
_sql_column_collection,
Expand Down Expand Up @@ -106,7 +106,7 @@ def replace_call_windows(col_expr, group_by, order_by, window_cte = None):
raise TypeError(str(type(col_expr)))


@replace_call_windows.register(sql.base.ImmutableColumnCollection)
@replace_call_windows.register(ColumnCollection)
def _(col_expr, group_by, order_by, window_cte = None):
all_over_clauses = []
for col in col_expr:
Expand Down Expand Up @@ -315,7 +315,7 @@ def _create_table(tbl, columns = None, source = None):

return sqlalchemy.Table(
table_name,
sqlalchemy.MetaData(bind = source),
sqlalchemy.MetaData(),
*columns,
schema = schema,
autoload_with = source if not columns else None
Expand Down
5 changes: 5 additions & 0 deletions siuba/sql/dialects/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
FunctionLookupBound
)

from siuba.sql.utils import SQLA_VERSION


# =============================================================================
# Column data classes
Expand All @@ -65,6 +67,9 @@ def sql_func_diff(_, col, periods = 1):
raise ValueError("periods argument to sql diff cannot be 0")

def sql_func_floordiv(_, x, y):
if SQLA_VERSION >= (2, 0, 0):
return sql.cast(x // y, sa_types.Integer())

return sql.cast(x / y, sa_types.Integer())

def sql_func_rank(_, col):
Expand Down
2 changes: 2 additions & 0 deletions siuba/sql/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def f_quantile(codata, col, q, *args):

extend_base(
DuckdbColumn,
__floordiv__ = lambda _, x, y: x.op("//")(y),
__rfloordiv__ = lambda _, y, x: x.op("//")(y),
rank = sql_func_rank,
#quantile = sql_quantile(is_analytic=True),
)
Expand Down
5 changes: 5 additions & 0 deletions siuba/sql/dialects/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
annotate,
wrap_annotate
)
from ..utils import SQLA_VERSION

from .base import base_nowin
#from .postgresql import PostgresqlColumn as SqlColumn, PostgresqlColumnAgg as SqlColumnAgg
Expand Down Expand Up @@ -90,6 +91,10 @@ def sql_week_of_year(_, col):
# adapted from: https://stackoverflow.com/a/15511864
iso_dow = (fn.strftime("%j", fn.date(col, "-3 days", "weekday 4")) - 1)

if SQLA_VERSION >= (2, 0, 0):
# in v2, regular division will cause sqlalchemy to coerce the 7 to a float
return (iso_dow // 7) + 1

return (iso_dow / 7) + 1


Expand Down
1 change: 0 additions & 1 deletion siuba/sql/dply/vector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import warnings

from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy import sql

from ..translate import (
Expand Down
4 changes: 4 additions & 0 deletions siuba/sql/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ class SqlColumn(SqlBase): pass
class SqlColumnAgg(SqlBase): pass


# Columns container ===========================================================

from sqlalchemy.sql import ColumnCollection

# Custom over clause handling ================================================

from sqlalchemy.sql.elements import Over
Expand Down
29 changes: 23 additions & 6 deletions siuba/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def mock_sqlalchemy_engine(dialect):

class _FixedSqlDatabase(_pd_sql.SQLDatabase):
def execute(self, *args, **kwargs):
return self.connectable.execute(*args, **kwargs)
if hasattr(self, "connectable"):
return self.connectable.execute(*args, **kwargs)

return self.con.execute(*args, **kwargs)


# Detect duckdb for temporary workarounds -------------------------------------
Expand All @@ -104,6 +107,9 @@ def is_sqla_12():
def is_sqla_13():
return SQLA_VERSION[:-1] == (1, 3)

def is_sqla_14():
return SQLA_VERSION[:-1] == (1, 4)


def _sql_select(columns, *args, **kwargs):
from sqlalchemy import sql
Expand All @@ -115,14 +121,25 @@ def _sql_select(columns, *args, **kwargs):


def _sql_column_collection(columns):
from sqlalchemy.sql.base import ColumnCollection, ImmutableColumnCollection
# This function largely handles the removal of ImmutableColumnCollection in
# sqlalchemy, in favor of ColumnCollection being immutable.

data = {col.key: col for col in columns}

if is_sqla_12() or is_sqla_13():
from sqlalchemy.sql.base import ImmutableColumnCollection

return ImmutableColumnCollection(data, columns)

return ColumnCollection(list(data.items())).as_immutable()
elif is_sqla_14():
from sqlalchemy.sql.base import ColumnCollection

return ColumnCollection(list(data.items())).as_immutable()

else:
from sqlalchemy.sql.base import ColumnCollection

return ColumnCollection(list(data.items()))


def _sql_add_columns(select, columns):
Expand All @@ -147,12 +164,12 @@ def _sql_with_only_columns(select, columns):
return out


def _sql_case(*args, **kwargs):
def _sql_case(whens, **kwargs):
from sqlalchemy import sql
if is_sqla_12() or is_sqla_13():
return sql.case(args, **kwargs)
return sql.case(whens, **kwargs)

return sql.case(*args, **kwargs)
return sql.case(*whens, **kwargs)


# Simplify option in show_query -----------------------------------------------
Expand Down
5 changes: 2 additions & 3 deletions siuba/sql/verbs/arrange.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from sqlalchemy.sql.base import ImmutableColumnCollection

from siuba.dply.verbs import arrange, _call_strip_ascending
from siuba.dply.across import _set_data_context

from ..utils import lift_inner_cols
from ..backend import LazyTbl
from ..translate import ColumnCollection

# Helpers ---------------------------------------------------------------------

Expand Down Expand Up @@ -38,7 +37,7 @@ def _eval_arrange_args(__data, args, cols):
with _set_data_context(__data, window=True):
res = new_call(cols)

if isinstance(res, ImmutableColumnCollection):
if isinstance(res, ColumnCollection):
raise NotImplementedError(
f"`arrange()` expression {ii} of {len(args)} returned multiple columns, "
"which is currently unsupported."
Expand Down
8 changes: 5 additions & 3 deletions siuba/sql/verbs/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from siuba.dply.verbs import case_when, if_else
from siuba.siu import Call

from ..utils import _sql_case
from ..backend import LazyTbl
from ..translate import ColumnCollection


@case_when.register(sql.base.ImmutableColumnCollection)
@case_when.register(ColumnCollection)
def _case_when(__data, cases):
# TODO: will need listener to enter case statements, to handle when they use windows
if isinstance(cases, Call):
Expand All @@ -35,7 +37,7 @@ def _case_when(__data, cases):
else:
whens.append((expr, val))

return sql.case(whens, else_ = else_val)
return _sql_case(whens, else_ = else_val)


@case_when.register(LazyTbl)
Expand All @@ -51,4 +53,4 @@ def _case_when(__data, cases):
@if_else.register(sql.elements.ColumnElement)
def _if_else(cond, true_vals, false_vals):
whens = [(cond, true_vals)]
return sql.case(whens, else_ = false_vals)
return _sql_case(whens, else_ = false_vals)
3 changes: 2 additions & 1 deletion siuba/sql/verbs/filter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from siuba.dply.verbs import filter

from ..backend import LazyTbl
from ..translate import ColumnCollection
from ..utils import _sql_select

from sqlalchemy import sql
Expand Down Expand Up @@ -33,7 +34,7 @@ def _filter(__data, *args):
window_cte = win_sel
)

if isinstance(col_expr, sql.base.ImmutableColumnCollection):
if isinstance(col_expr, ColumnCollection):
conds.extend(col_expr)
else:
conds.append(col_expr)
Expand Down
Loading

0 comments on commit f0b5335

Please sign in to comment.