diff --git a/siuba/sql/across.py b/siuba/sql/across.py index 59a2ba5e..042656d7 100644 --- a/siuba/sql/across.py +++ b/siuba/sql/across.py @@ -6,6 +6,7 @@ from .backend import LazyTbl from .utils import _sql_select, _sql_column_collection +from .translate import ColumnCollection from sqlalchemy import sql @@ -39,13 +40,13 @@ def _across_lazy_tbl(__data: LazyTbl, cols, fns, names: "str | None" = None) -> #return __data.append_op(_sql_select(res_cols)) -@across.register(sql.base.ImmutableColumnCollection) +@across.register(ColumnCollection) def _across_sql_cols( - __data: sql.base.ImmutableColumnCollection, + __data: ColumnCollection, cols, fns, names: "str | None" = None -) -> sql.base.ImmutableColumnCollection: +) -> ColumnCollection: lazy_tbl = ctx_verb_data.get() window = ctx_verb_window.get() diff --git a/siuba/sql/backend.py b/siuba/sql/backend.py index d2ed3cc6..d05022f8 100644 --- a/siuba/sql/backend.py +++ b/siuba/sql/backend.py @@ -9,7 +9,7 @@ import warnings -from .translate import CustomOverClause +from .translate import CustomOverClause, ColumnCollection from .utils import ( get_dialect_translator, _sql_column_collection, @@ -106,7 +106,7 @@ def replace_call_windows(col_expr, group_by, order_by, window_cte = None): raise TypeError(str(type(col_expr))) -@replace_call_windows.register(sql.base.ImmutableColumnCollection) +@replace_call_windows.register(ColumnCollection) def _(col_expr, group_by, order_by, window_cte = None): all_over_clauses = [] for col in col_expr: diff --git a/siuba/sql/dply/vector.py b/siuba/sql/dply/vector.py index d29c4fd2..542dd86c 100644 --- a/siuba/sql/dply/vector.py +++ b/siuba/sql/dply/vector.py @@ -1,7 +1,6 @@ import warnings from sqlalchemy.sql.elements import ClauseElement -from sqlalchemy.sql.base import ImmutableColumnCollection from sqlalchemy import sql from ..translate import ( diff --git a/siuba/sql/translate.py b/siuba/sql/translate.py index d54f2fb9..62348727 100644 --- a/siuba/sql/translate.py +++ b/siuba/sql/translate.py @@ -53,6 +53,10 @@ class SqlColumn(SqlBase): pass class SqlColumnAgg(SqlBase): pass +# Columns container =========================================================== + +from sqlalchemy.sql import ColumnCollection + # Custom over clause handling ================================================ from sqlalchemy.sql.elements import Over diff --git a/siuba/sql/utils.py b/siuba/sql/utils.py index aab903e4..b818370a 100644 --- a/siuba/sql/utils.py +++ b/siuba/sql/utils.py @@ -115,7 +115,9 @@ def _sql_select(columns, *args, **kwargs): def _sql_column_collection(columns): - from sqlalchemy.sql.base import ColumnCollection + # This function largely handles the removal of ImmutableColumnCollection in + # sqlalchemy, in favor of ColumnCollection being immutable. + from sqlalchemy.sql.base import ColumnCollection, ImmutableColumnCollection data = {col.key: col for col in columns} diff --git a/siuba/sql/verbs/arrange.py b/siuba/sql/verbs/arrange.py index 2fd7bd98..8b6ee81e 100644 --- a/siuba/sql/verbs/arrange.py +++ b/siuba/sql/verbs/arrange.py @@ -1,10 +1,9 @@ -from sqlalchemy.sql.base import ImmutableColumnCollection - from siuba.dply.verbs import arrange, _call_strip_ascending from siuba.dply.across import _set_data_context from ..utils import lift_inner_cols from ..backend import LazyTbl +from ..translate import ColumnCollection # Helpers --------------------------------------------------------------------- @@ -38,7 +37,7 @@ def _eval_arrange_args(__data, args, cols): with _set_data_context(__data, window=True): res = new_call(cols) - if isinstance(res, ImmutableColumnCollection): + if isinstance(res, ColumnCollection): raise NotImplementedError( f"`arrange()` expression {ii} of {len(args)} returned multiple columns, " "which is currently unsupported." diff --git a/siuba/sql/verbs/conditional.py b/siuba/sql/verbs/conditional.py index d9acd73a..aa551f97 100644 --- a/siuba/sql/verbs/conditional.py +++ b/siuba/sql/verbs/conditional.py @@ -6,9 +6,10 @@ from siuba.siu import Call from ..backend import LazyTbl +from ..translate import ColumnCollection -@case_when.register(sql.base.ImmutableColumnCollection) +@case_when.register(ColumnCollection) def _case_when(__data, cases): # TODO: will need listener to enter case statements, to handle when they use windows if isinstance(cases, Call): diff --git a/siuba/sql/verbs/filter.py b/siuba/sql/verbs/filter.py index 9275c651..ad705f49 100644 --- a/siuba/sql/verbs/filter.py +++ b/siuba/sql/verbs/filter.py @@ -1,6 +1,7 @@ from siuba.dply.verbs import filter from ..backend import LazyTbl +from ..translate import ColumnCollection from ..utils import _sql_select from sqlalchemy import sql @@ -33,7 +34,7 @@ def _filter(__data, *args): window_cte = win_sel ) - if isinstance(col_expr, sql.base.ImmutableColumnCollection): + if isinstance(col_expr, ColumnCollection): conds.extend(col_expr) else: conds.append(col_expr) diff --git a/siuba/sql/verbs/mutate.py b/siuba/sql/verbs/mutate.py index 5f4e8da4..b053674e 100644 --- a/siuba/sql/verbs/mutate.py +++ b/siuba/sql/verbs/mutate.py @@ -5,6 +5,7 @@ ) from ..backend import LazyTbl, SqlLabelReplacer +from ..translate import ColumnCollection from ..utils import ( _sql_with_only_columns, lift_inner_cols @@ -42,7 +43,7 @@ def _select_mutate_result(src_sel, expr_result): src_columns = set(lift_inner_cols(src_sel)) replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) - if isinstance(expr_result, sql.base.ImmutableColumnCollection): + if isinstance(expr_result, ColumnCollection): replaced_cols = list(map(replacer, expr_result)) orig_cols = expr_result #elif isinstance(expr_result, None): @@ -71,7 +72,7 @@ def _eval_expr_arg(__data, sel, func, verb_name, window=True): cols_result = _eval_with_context(__data, window, inner_cols, func) # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) + assert isinstance(cols_result, ColumnCollection), type(cols_result) return cols_result @@ -82,7 +83,7 @@ def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - if isinstance(new_col, sql.base.ImmutableColumnCollection): + if isinstance(new_col, ColumnCollection): raise TypeError( f"{verb_name} named arguments must return a single column, but `{new_name}` " "returned multiple columns." @@ -101,7 +102,7 @@ def _mutate_cols(__data, args, kwargs, verb_name): # replace any labels that require a subquery ---- sel = _select_mutate_result(sel, cols_result) - if isinstance(cols_result, sql.base.ImmutableColumnCollection): + if isinstance(cols_result, ColumnCollection): result_names.update({k: True for k in cols_result.keys()}) else: result_names[cols_result.name] = True