Skip to content

Commit

Permalink
fix(sqlalchemy): v2.0 compat by using only ColumnCollection
Browse files Browse the repository at this point in the history
  • Loading branch information
machow committed May 18, 2023
1 parent df22664 commit 9908df9
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 16 deletions.
7 changes: 4 additions & 3 deletions siuba/sql/across.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .backend import LazyTbl
from .utils import _sql_select, _sql_column_collection
from .translate import ColumnCollection

from sqlalchemy import sql

Expand Down Expand Up @@ -39,13 +40,13 @@ def _across_lazy_tbl(__data: LazyTbl, cols, fns, names: "str | None" = None) ->
#return __data.append_op(_sql_select(res_cols))


@across.register(sql.base.ImmutableColumnCollection)
@across.register(ColumnCollection)
def _across_sql_cols(
__data: sql.base.ImmutableColumnCollection,
__data: ColumnCollection,
cols,
fns,
names: "str | None" = None
) -> sql.base.ImmutableColumnCollection:
) -> ColumnCollection:

lazy_tbl = ctx_verb_data.get()
window = ctx_verb_window.get()
Expand Down
4 changes: 2 additions & 2 deletions siuba/sql/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import warnings

from .translate import CustomOverClause
from .translate import CustomOverClause, ColumnCollection
from .utils import (
get_dialect_translator,
_sql_column_collection,
Expand Down Expand Up @@ -106,7 +106,7 @@ def replace_call_windows(col_expr, group_by, order_by, window_cte = None):
raise TypeError(str(type(col_expr)))


@replace_call_windows.register(sql.base.ImmutableColumnCollection)
@replace_call_windows.register(ColumnCollection)
def _(col_expr, group_by, order_by, window_cte = None):
all_over_clauses = []
for col in col_expr:
Expand Down
1 change: 0 additions & 1 deletion siuba/sql/dply/vector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import warnings

from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy import sql

from ..translate import (
Expand Down
4 changes: 4 additions & 0 deletions siuba/sql/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ class SqlColumn(SqlBase): pass
class SqlColumnAgg(SqlBase): pass


# Columns container ===========================================================

from sqlalchemy.sql import ColumnCollection

# Custom over clause handling ================================================

from sqlalchemy.sql.elements import Over
Expand Down
4 changes: 3 additions & 1 deletion siuba/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ def _sql_select(columns, *args, **kwargs):


def _sql_column_collection(columns):
from sqlalchemy.sql.base import ColumnCollection
# This function largely handles the removal of ImmutableColumnCollection in
# sqlalchemy, in favor of ColumnCollection being immutable.
from sqlalchemy.sql.base import ColumnCollection, ImmutableColumnCollection

data = {col.key: col for col in columns}

Expand Down
5 changes: 2 additions & 3 deletions siuba/sql/verbs/arrange.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from sqlalchemy.sql.base import ImmutableColumnCollection

from siuba.dply.verbs import arrange, _call_strip_ascending
from siuba.dply.across import _set_data_context

from ..utils import lift_inner_cols
from ..backend import LazyTbl
from ..translate import ColumnCollection

# Helpers ---------------------------------------------------------------------

Expand Down Expand Up @@ -38,7 +37,7 @@ def _eval_arrange_args(__data, args, cols):
with _set_data_context(__data, window=True):
res = new_call(cols)

if isinstance(res, ImmutableColumnCollection):
if isinstance(res, ColumnCollection):
raise NotImplementedError(
f"`arrange()` expression {ii} of {len(args)} returned multiple columns, "
"which is currently unsupported."
Expand Down
3 changes: 2 additions & 1 deletion siuba/sql/verbs/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from siuba.siu import Call

from ..backend import LazyTbl
from ..translate import ColumnCollection


@case_when.register(sql.base.ImmutableColumnCollection)
@case_when.register(ColumnCollection)
def _case_when(__data, cases):
# TODO: will need listener to enter case statements, to handle when they use windows
if isinstance(cases, Call):
Expand Down
3 changes: 2 additions & 1 deletion siuba/sql/verbs/filter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from siuba.dply.verbs import filter

from ..backend import LazyTbl
from ..translate import ColumnCollection
from ..utils import _sql_select

from sqlalchemy import sql
Expand Down Expand Up @@ -33,7 +34,7 @@ def _filter(__data, *args):
window_cte = win_sel
)

if isinstance(col_expr, sql.base.ImmutableColumnCollection):
if isinstance(col_expr, ColumnCollection):
conds.extend(col_expr)
else:
conds.append(col_expr)
Expand Down
9 changes: 5 additions & 4 deletions siuba/sql/verbs/mutate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
)

from ..backend import LazyTbl, SqlLabelReplacer
from ..translate import ColumnCollection
from ..utils import (
_sql_with_only_columns,
lift_inner_cols
Expand Down Expand Up @@ -42,7 +43,7 @@ def _select_mutate_result(src_sel, expr_result):
src_columns = set(lift_inner_cols(src_sel))
replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns)

if isinstance(expr_result, sql.base.ImmutableColumnCollection):
if isinstance(expr_result, ColumnCollection):
replaced_cols = list(map(replacer, expr_result))
orig_cols = expr_result
#elif isinstance(expr_result, None):
Expand Down Expand Up @@ -71,7 +72,7 @@ def _eval_expr_arg(__data, sel, func, verb_name, window=True):
cols_result = _eval_with_context(__data, window, inner_cols, func)

# TODO: remove or raise a more informative error
assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result)
assert isinstance(cols_result, ColumnCollection), type(cols_result)

return cols_result

Expand All @@ -82,7 +83,7 @@ def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True):
expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name)
new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols)

if isinstance(new_col, sql.base.ImmutableColumnCollection):
if isinstance(new_col, ColumnCollection):
raise TypeError(
f"{verb_name} named arguments must return a single column, but `{new_name}` "
"returned multiple columns."
Expand All @@ -101,7 +102,7 @@ def _mutate_cols(__data, args, kwargs, verb_name):
# replace any labels that require a subquery ----
sel = _select_mutate_result(sel, cols_result)

if isinstance(cols_result, sql.base.ImmutableColumnCollection):
if isinstance(cols_result, ColumnCollection):
result_names.update({k: True for k in cols_result.keys()})
else:
result_names[cols_result.name] = True
Expand Down

0 comments on commit 9908df9

Please sign in to comment.