From 95751b6c17d1bc6cfc1965ad46e1c02d711e7929 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Wed, 23 Mar 2022 23:52:49 -0400 Subject: [PATCH 01/11] draft: initial sqlite window functions --- siuba/sql/dialects/sqlite.py | 106 ++++++++++++++++++++++-- siuba/tests/test_dply_series_methods.py | 2 +- 2 files changed, 99 insertions(+), 9 deletions(-) diff --git a/siuba/sql/dialects/sqlite.py b/siuba/sql/dialects/sqlite.py index 479b2bc1..eecf212b 100644 --- a/siuba/sql/dialects/sqlite.py +++ b/siuba/sql/dialects/sqlite.py @@ -1,13 +1,19 @@ # sqlvariant, allow defining 3 namespaces to override defaults from ..translate import ( SqlColumn, SqlColumnAgg, extend_base, - SqlTranslator + SqlTranslator, + sql_not_impl, + annotate, + wrap_annotate ) from .base import base_nowin +#from .postgresql import PostgresqlColumn as SqlColumn, PostgresqlColumnAgg as SqlColumnAgg +from . import _dt_generics as _dt import sqlalchemy.sql.sqltypes as sa_types from sqlalchemy import sql +from sqlalchemy.sql import func as fn # Custom dispatching in call trees ============================================ @@ -16,18 +22,102 @@ class SqliteColumn(SqlColumn): pass class SqliteColumnAgg(SqlColumnAgg, SqliteColumn): pass -scalar = extend_base( +# Note this is taken from the postgres dialect, but it seems that there are 2 key points +# compared to postgresql, which always returns a float +# * sqlite date parts are returned as floats +# * sqlite time parts are returned as integers +def returns_float(func_names): + # TODO: MC-NOTE - shift all translations to directly register + # TODO: MC-NOTE - make an AliasAnnotated class or something, that signals + # it is using another method, but w/ an updated annotation. + from siuba.ops import ALL_OPS + + for name in func_names: + generic = ALL_OPS[name] + f_concrete = generic.dispatch(SqlColumn) + f_annotated = wrap_annotate(f_concrete, result_type="float") + generic.register(SqliteColumn, f_annotated) + +# detect first and last date (similar to the mysql dialect) --- +@annotate(return_type="float") +def sql_extract(name): + if name == "quarter": + # division in sqlite automatically rounds down + # so for jan, 1 + 2 = 3, and 3 / 1 is Q1 + return lambda _, col: (fn.strftime("%m", col) + 2) / 3 + return lambda _, col: fn.extract(name, col) + + +@_dt.sql_is_last_day_of.register +def _sql_is_last_day_of(codata: SqliteColumn, col, period): + valid_periods = {"month", "year"} + if period not in valid_periods: + raise ValueError(f"Period must be one of {valid_periods}") + + incr = f"+1 {period}" + + target_date = fn.date(col, f'start of {period}', incr, "-1 day") + return col == target_date + +@_dt.sql_is_first_day_of.register +def _sql_is_first_day_of(codata: SqliteColumn, col, period): + valid_periods = {"month", "year"} + if period not in valid_periods: + raise ValueError(f"Period must be one of {valid_periods}") + + target_date = fn.date(col, f'start of {period}') + return fn.date(col) == target_date + + +def sql_days_in_month(_, col): + date_last_day = fn.date(col, 'start of month', '+1 month', '-1 day') + return fn.strftime("%d", date_last_day).cast(sa_types.Integer()) + + +@annotate(result_type = "float") +def sql_round(_, col, n): + return sql.func.round(col, n) + +def sql_func_truediv(_, x, y): + return sql.cast(x, sa_types.Float()) / y + +extend_base( SqliteColumn, - ) + div = sql_func_truediv, + divide = sql_func_truediv, + rdiv = lambda _, x,y: sql_func_truediv(_, y, x), -aggregate = extend_base( - SqliteColumnAgg, - ) + __truediv__ = sql_func_truediv, + truediv = 
sql_func_truediv, + __rtruediv__ = lambda _, x, y: sql_func_truediv(_, y, x), + + round = sql_round, + __round__ = sql_round, + **{ + "dt.quarter": sql_extract("quarter"), + "dt.is_quarter_start": sql_not_impl("TODO"), + "dt.is_quarter_end": sql_not_impl("TODO"), + "dt.days_in_month": sql_days_in_month, + "dt.daysinmonth": sql_days_in_month, + + } +) -window = extend_base( +returns_float([ + "dt.dayofweek", + "dt.weekday", +]) + + +extend_base( SqliteColumn, - **base_nowin # TODO: should check sqlite version, since < 3.25 can't use windows + quantile = sql_not_impl("sqlite does not support ordered set aggregates"), + ) + +extend_base( + SqliteColumnAgg, + quantile = sql_not_impl("sqlite does not support ordered set aggregates"), ) diff --git a/siuba/tests/test_dply_series_methods.py b/siuba/tests/test_dply_series_methods.py index 31a65af7..302d3e88 100644 --- a/siuba/tests/test_dply_series_methods.py +++ b/siuba/tests/test_dply_series_methods.py @@ -231,7 +231,7 @@ def test_pandas_grouped_frame_fast_not_implemented(notimpl_entry): #@backend_pandas -@pytest.mark.skip_backend('sqlite') +#@pytest.mark.skip_backend('sqlite') def test_frame_mutate(skip_backend, backend, entry): do_test_missing_implementation(entry, backend) skip_no_mutate(entry, backend) From 5e3e0db073cd2ac4551fb728aa7505e604ced70f Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Thu, 24 Mar 2022 23:17:47 -0400 Subject: [PATCH 02/11] feat: complete sqlite window translations --- siuba/sql/dialects/sqlite.py | 55 +++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/siuba/sql/dialects/sqlite.py b/siuba/sql/dialects/sqlite.py index eecf212b..e33ec8f6 100644 --- a/siuba/sql/dialects/sqlite.py +++ b/siuba/sql/dialects/sqlite.py @@ -3,6 +3,8 @@ SqlColumn, SqlColumnAgg, extend_base, SqlTranslator, sql_not_impl, + win_cumul, + win_agg, annotate, wrap_annotate ) @@ -22,6 +24,11 @@ class SqliteColumn(SqlColumn): pass class SqliteColumnAgg(SqlColumnAgg, SqliteColumn): pass + +# Translations ================================================================ + +# fix some annotations -------------------------------------------------------- + # Note this is taken from the postgres dialect, but it seems that there are 2 key points # compared to postgresql, which always returns a float # * sqlite date parts are returned as floats @@ -38,7 +45,8 @@ def returns_float(func_names): f_annotated = wrap_annotate(f_concrete, result_type="float") generic.register(SqliteColumn, f_annotated) -# detect first and last date (similar to the mysql dialect) --- +# detect first and last date (similar to the mysql dialect) ------------------- + @annotate(return_type="float") def sql_extract(name): if name == "quarter": @@ -59,6 +67,7 @@ def _sql_is_last_day_of(codata: SqliteColumn, col, period): target_date = fn.date(col, f'start of {period}', incr, "-1 day") return col == target_date + @_dt.sql_is_first_day_of.register def _sql_is_first_day_of(codata: SqliteColumn, col, period): valid_periods = {"month", "year"} @@ -69,20 +78,53 @@ def _sql_is_first_day_of(codata: SqliteColumn, col, period): return fn.date(col) == target_date +# date part of period calculations -------------------------------------------- + def sql_days_in_month(_, col): date_last_day = fn.date(col, 'start of month', '+1 month', '-1 day') return fn.strftime("%d", date_last_day).cast(sa_types.Integer()) + +def sql_week_of_year(_, col): + # convert sqlite week to ISO week + # adapted from: https://stackoverflow.com/a/15511864 + iso_dow = 
(fn.strftime("%j", fn.date(col, "-3 days", "weekday 4")) - 1) + + return (iso_dow / 7) + 1 + + +# misc ------------------------------------------------------------------------ @annotate(result_type = "float") def sql_round(_, col, n): return sql.func.round(col, n) + def sql_func_truediv(_, x, y): return sql.cast(x, sa_types.Float()) / y + +def between(_, col, x, y): + res = col.between(x, y) + + # tell sqlalchemy the result is a boolean. this causes it to be correctly + # converted from an integer to bool when the results are collected. + # note that this is consistent with what col == col returns + res.type = sa_types.Boolean() + return res + +def sql_str_capitalize(_, col): + # capitalize first letter, then concatenate with lowercased rest + first_upper = fn.upper(fn.substr(col, 1, 1)) + rest_lower = fn.lower(fn.substr(col, 2)) + return first_upper.concat(rest_lower) + extend_base( SqliteColumn, + + between = between, + clip = sql_not_impl("sqlite does not have a least or greatest function."), + div = sql_func_truediv, divide = sql_func_truediv, rdiv = lambda _, x,y: sql_func_truediv(_, y, x), @@ -93,12 +135,20 @@ def sql_func_truediv(_, x, y): round = sql_round, __round__ = sql_round, + + **{ + "str.title": sql_not_impl("TODO"), + "str.capitalize": sql_str_capitalize, + }, + **{ "dt.quarter": sql_extract("quarter"), "dt.is_quarter_start": sql_not_impl("TODO"), "dt.is_quarter_end": sql_not_impl("TODO"), "dt.days_in_month": sql_days_in_month, "dt.daysinmonth": sql_days_in_month, + "dt.week": sql_week_of_year, + "dt.weekofyear": sql_week_of_year, } ) @@ -112,7 +162,10 @@ def sql_func_truediv(_, x, y): extend_base( SqliteColumn, # TODO: should check sqlite version, since < 3.25 can't use windows + cumsum = win_cumul("sum"), + quantile = sql_not_impl("sqlite does not support ordered set aggregates"), + sum = win_agg("sum"), ) extend_base( From 515dc604f58383c7e8cfd4a7d4b4133f48207527 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 27 Mar 2022 18:30:07 -0400 Subject: [PATCH 03/11] feat(sql): all sqlite tests passing --- siuba/sql/dply/vector.py | 15 ++++++--------- siuba/tests/test_verb_arrange.py | 2 -- siuba/tests/test_verb_filter.py | 6 ------ siuba/tests/test_verb_mutate.py | 5 +---- siuba/tests/test_verb_summarize.py | 1 - 5 files changed, 7 insertions(+), 22 deletions(-) diff --git a/siuba/sql/dply/vector.py b/siuba/sql/dply/vector.py index b95fa90e..de5408cf 100644 --- a/siuba/sql/dply/vector.py +++ b/siuba/sql/dply/vector.py @@ -63,6 +63,8 @@ def _sql_rank_over(rank_func, col, partition, nulls_last): return sql.case({col.isnot(None): over_clause}) def _sql_rank(func_name, partition = False, nulls_last = False): + # partition controls whether to make partition by NOT NULL + rank_func = getattr(sql.func, func_name) def f(_, col, na_option = None) -> RankOver: @@ -82,10 +84,10 @@ def f(_, col, na_option = None) -> RankOver: min_rank .register(SqlColumn, _sql_rank("rank", partition = True)) -dense_rank .register(SqliteColumn, win_absent("DENSE_RANK")) -percent_rank.register(SqliteColumn, win_absent("PERCENT_RANK")) -cume_dist .register(SqliteColumn, win_absent("CUME_DIST")) -min_rank .register(SqliteColumn, win_absent("MIN_RANK")) +dense_rank .register(SqliteColumn, _sql_rank("dense_rank", nulls_last=True)) +percent_rank.register(SqliteColumn, _sql_rank("percent_rank", nulls_last=True)) +cume_dist .register(SqliteColumn, _sql_rank("cume_dist", nulls_last=True)) +min_rank .register(SqliteColumn, _sql_rank("min_rank", nulls_last=True)) # partition everything, since MySQL puts NULLs 
first # see: https://stackoverflow.com/q/1498648/1144523 @@ -94,8 +96,6 @@ def f(_, col, na_option = None) -> RankOver: cume_dist .register(MysqlColumn, _sql_rank("cume_dist", partition = True)) min_rank .register(MysqlColumn, _sql_rank("rank", partition = True)) -# partition everything, since MySQL puts NULLs first -# see: https://stackoverflow.com/q/1498648/1144523 dense_rank .register(BigqueryColumn, _sql_rank("dense_rank", nulls_last = True)) percent_rank.register(BigqueryColumn, _sql_rank("percent_rank", nulls_last = True)) @@ -113,8 +113,6 @@ def _row_number_sql(codata: SqlColumn, col: ClauseElement) -> CumlOver: """ return CumlOver(sql.func.row_number()) -row_number.register(SqliteColumn, win_absent("ROW_NUMBER")) - # between --------------------------------------------------------------------- @between.register @@ -204,7 +202,6 @@ def _n_sql_agg(codata: SqlColumnAgg, x) -> ClauseElement: n.register(SqliteColumn, win_absent("N")) -row_number.register(SqliteColumn, win_absent("ROW_NUMBER")) # n_distinct ------------------------------------------------------------------ diff --git a/siuba/tests/test_verb_arrange.py b/siuba/tests/test_verb_arrange.py index 58c651e9..ca588fdc 100644 --- a/siuba/tests/test_verb_arrange.py +++ b/siuba/tests/test_verb_arrange.py @@ -50,7 +50,6 @@ def test_arrange_grouped_trivial(df): DATA.sort_values(['x']) ) -@backend_notimpl("sqlite") def test_arrange_grouped(backend, df): q = group_by(_.y) >> arrange(_.x) assert_equal_query( @@ -70,7 +69,6 @@ def test_arrange_grouped(backend, df): # SQL ------------------------------------------------------------------------- @backend_sql -@backend_notimpl("sqlite") def test_no_arrange_before_cuml_window_warning(backend): data = data_frame(x = range(1, 5), g = [1,1,2,2]) dfs = backend.load_df(data) diff --git a/siuba/tests/test_verb_filter.py b/siuba/tests/test_verb_filter.py index 0492781a..41984a7b 100644 --- a/siuba/tests/test_verb_filter.py +++ b/siuba/tests/test_verb_filter.py @@ -31,7 +31,6 @@ def test_filter_basic_two_args(backend): assert_equal_query(dfs, filter(_.x > 3, _.y < 2), df[lambda _: (_.x > 3) & (_.y < 2)]) -@backend_notimpl("sqlite") def test_filter_via_group_by(backend): df = data_frame( x = range(1, 11), @@ -48,7 +47,6 @@ def test_filter_via_group_by(backend): ) -@backend_notimpl("sqlite") def test_filter_via_group_by_agg(backend): dfs = backend.load_df(x = range(1,11), g = [1]*5 + [2]*5) @@ -59,7 +57,6 @@ def test_filter_via_group_by_agg(backend): ) -@backend_notimpl("sqlite") def test_filter_via_group_by_agg_two_args(backend): dfs = backend.load_df(x = range(1,11), g = [1]*5 + [2]*5) @@ -71,7 +68,6 @@ def test_filter_via_group_by_agg_two_args(backend): @backend_sql("TODO: pandas - implement arrange over group by") -@backend_notimpl("sqlite") def test_filter_via_group_by_arrange(backend): dfs = backend.load_df(x = [3,2,1] + [2,3,4], g = [1]*3 + [2]*3) @@ -82,7 +78,6 @@ def test_filter_via_group_by_arrange(backend): ) @backend_sql("TODO: pandas - implement arrange over group by") -@backend_notimpl("sqlite") def test_filter_via_group_by_desc_arrange(backend): dfs = backend.load_df(x = [3,2,1] + [2,3,4], g = [1]*3 + [2]*3) @@ -103,7 +98,6 @@ def test_filter_before_summarize(backend): check_dtype=False ) -@backend_notimpl("sqlite") def test_filter_before_summarize_grouped(backend): dfs = backend.load_df(x = [1,2,3], g = ["a", "a", "b"]) diff --git a/siuba/tests/test_verb_mutate.py b/siuba/tests/test_verb_mutate.py index 86fd4dec..38d00843 100644 --- a/siuba/tests/test_verb_mutate.py +++ 
b/siuba/tests/test_verb_mutate.py @@ -19,7 +19,7 @@ def dfs(backend): @pytest.mark.parametrize("query, output", [ (mutate(x = _.a + _.b), DATA.assign(x = [10, 10, 10])), - pytest.param( mutate(x = _.a + _.b) >> summarize(ttl = _.x.sum()), data_frame(ttl = 30.0), marks = pytest.mark.skip("TODO: failing sqlite?")), + (mutate(x = _.a + _.b) >> summarize(ttl = _.x.sum().astype(float)), data_frame(ttl = 30.0)), (mutate(x = _.a + 1, y = _.b - 1), DATA.assign(x = [2,3,4], y = [8,7,6])), (mutate(x = _.a + 1) >> mutate(y = _.b - 1), DATA.assign(x = [2,3,4], y = [8,7,6])), (mutate(x = _.a + 1, y = _.x + 1), DATA.assign(x = [2,3,4], y = [3,4,5])) @@ -75,7 +75,6 @@ def test_mutate_reassign_all_cols_keeps_rowsize(dfs): ) @backend_sql -@backend_notimpl("sqlite") def test_mutate_window_funcs(backend): data = data_frame(idx = range(0, 4), x = range(1, 5), g = [1,1,2,2]) dfs = backend.load_df(data) @@ -86,7 +85,6 @@ def test_mutate_window_funcs(backend): ) -@backend_notimpl("sqlite") def test_mutate_using_agg_expr(backend): data = data_frame(x = range(1, 5), g = [1,1,2,2]) dfs = backend.load_df(data) @@ -97,7 +95,6 @@ def test_mutate_using_agg_expr(backend): ) @backend_sql # TODO: pandas outputs a int column -@backend_notimpl("sqlite") def test_mutate_using_cuml_agg(backend): data = data_frame(idx = range(0, 4), x = range(1, 5), g = [1,1,2,2]) dfs = backend.load_df(data) diff --git a/siuba/tests/test_verb_summarize.py b/siuba/tests/test_verb_summarize.py index 978ee3ee..eaa26e4f 100644 --- a/siuba/tests/test_verb_summarize.py +++ b/siuba/tests/test_verb_summarize.py @@ -39,7 +39,6 @@ def test_ungrouped_summarize_literal(df): assert_equal_query(df, summarize(y = 1), data_frame(y = 1)) -@backend_notimpl("sqlite") def test_summarize_after_mutate_cuml_win(backend, df_float): assert_equal_query( df_float, From 43f091068524a38734479852b13f36e4d6ac9773 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 27 Mar 2022 18:43:55 -0400 Subject: [PATCH 04/11] tests: enable sqlite dply vector tests --- siuba/sql/dply/vector.py | 5 ++--- siuba/tests/test_dply_vector.py | 11 ----------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/siuba/sql/dply/vector.py b/siuba/sql/dply/vector.py index de5408cf..698d6d5d 100644 --- a/siuba/sql/dply/vector.py +++ b/siuba/sql/dply/vector.py @@ -86,8 +86,8 @@ def f(_, col, na_option = None) -> RankOver: dense_rank .register(SqliteColumn, _sql_rank("dense_rank", nulls_last=True)) percent_rank.register(SqliteColumn, _sql_rank("percent_rank", nulls_last=True)) -cume_dist .register(SqliteColumn, _sql_rank("cume_dist", nulls_last=True)) -min_rank .register(SqliteColumn, _sql_rank("min_rank", nulls_last=True)) +cume_dist .register(SqliteColumn, _sql_rank("cume_dist", partition = True)) +min_rank .register(SqliteColumn, _sql_rank("rank", nulls_last=True)) # partition everything, since MySQL puts NULLs first # see: https://stackoverflow.com/q/1498648/1144523 @@ -201,7 +201,6 @@ def _n_sql_agg(codata: SqlColumnAgg, x) -> ClauseElement: return sql.func.count() -n.register(SqliteColumn, win_absent("N")) # n_distinct ------------------------------------------------------------------ diff --git a/siuba/tests/test_dply_vector.py b/siuba/tests/test_dply_vector.py index d5038d68..cf22e17e 100644 --- a/siuba/tests/test_dply_vector.py +++ b/siuba/tests/test_dply_vector.py @@ -73,9 +73,6 @@ def simple_data(request): @pytest.mark.parametrize('func', OMNIBUS_VECTOR_FUNCS) def test_mutate_vector(backend, func, simple_data): - if backend.name == 'sqlite': - pytest.skip() - df = 
backend.load_cached_df(simple_data) assert_equal_query( @@ -96,9 +93,6 @@ def test_mutate_vector(backend, func, simple_data): @pytest.mark.parametrize('func', VECTOR_AGG_FUNCS) def test_agg_vector(backend, func, simple_data): - if backend.name == 'sqlite': - pytest.skip() - df = backend.load_cached_df(simple_data) res = data_frame(y = func(simple_data)) @@ -120,9 +114,6 @@ def test_agg_vector(backend, func, simple_data): @backend_sql @pytest.mark.parametrize('func', VECTOR_FILTER_FUNCS) def test_filter_vector(backend, func, simple_data): - if backend.name == 'sqlite': - pytest.skip() - df = backend.load_cached_df(simple_data) res = data_frame(y = func(simple_data)) @@ -147,8 +138,6 @@ def test_filter_vector(backend, func, simple_data): #@given(DATA_SPEC) #@settings(max_examples = 50, deadline = 1000) #def test_hypothesis_mutate_vector_funcs(backend, data): -# if backend.name == 'sqlite': -# pytest.skip() # # df = backend.load_df(data) # From 7d7d60e90bbea8f4c4c5821301990b93552a0321 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 27 Mar 2022 19:01:21 -0400 Subject: [PATCH 05/11] tests(sql): re-enable more sqlite tests --- siuba/tests/test_dply_series_methods.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/siuba/tests/test_dply_series_methods.py b/siuba/tests/test_dply_series_methods.py index 302d3e88..c3394a5f 100644 --- a/siuba/tests/test_dply_series_methods.py +++ b/siuba/tests/test_dply_series_methods.py @@ -231,7 +231,6 @@ def test_pandas_grouped_frame_fast_not_implemented(notimpl_entry): #@backend_pandas -#@pytest.mark.skip_backend('sqlite') def test_frame_mutate(skip_backend, backend, entry): do_test_missing_implementation(entry, backend) skip_no_mutate(entry, backend) @@ -296,7 +295,6 @@ def test_pandas_grouped_frame_fast_mutate(entry): assert_frame_equal(res_obj, dst.obj) -@pytest.mark.skip_backend('sqlite') def test_frame_summarize(skip_backend, backend, agg_entry): entry = agg_entry From da444106a5eea636a876b1390d3e7ee5c078c574 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 27 Mar 2022 19:23:39 -0400 Subject: [PATCH 06/11] ci: separate snowflake tests into own job --- .github/workflows/ci.yml | 31 ++++++++++++++++++++++++++----- pytest.ini | 2 +- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5d1f7e99..f9cc7442 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -59,8 +59,6 @@ jobs: env: SB_TEST_PGPORT: 5432 PYTEST_FLAGS: ${{ matrix.pytest_flags }} - SB_TEST_SNOWFLAKEPASSWORD: ${{ secrets.SB_TEST_SNOWFLAKEPASSWORD }} - SB_TEST_SNOWFLAKEHOST: ${{ secrets.SB_TEST_SNOWFLAKEHOST }} # optional step for running bigquery tests ---- - name: Set up Cloud SDK @@ -78,13 +76,13 @@ jobs: test-bigquery: name: "Test BigQuery" runs-on: ubuntu-latest - if: contains(github.ref, 'bigquery') || contains(github.ref, 'refs/tags') + if: contains(github.ref, 'bigquery') || ${{ !github.event.pull_request.draft }} steps: - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} + - name: Set up Python uses: actions/setup-python@v2 with: - python-version: ${{ matrix.python-version }} + python-version: "3.8" - name: Install dependencies run: | python -m pip install --upgrade pip @@ -104,6 +102,29 @@ jobs: env: SB_TEST_BQDATABASE: "ci_github" + test-snowflake: + name: "Test snowflake" + runs-on: ubuntu-latest + if: contains(github.ref, 'snowflake') || ${{ !github.event.pull_request.draft }} + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: 
actions/setup-python@v2 + with: + python-version: "3.8" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + python -m pip install -r requirements-test.txt + python -m pip install snowflake-sqlalchemy==1.3.3 + python -m pip install . + - name: Test with pytest + run: | + pytest siuba -m snowflake + env: + SB_TEST_SNOWFLAKEPASSWORD: ${{ secrets.SB_TEST_SNOWFLAKEPASSWORD }} + SB_TEST_SNOWFLAKEHOST: ${{ secrets.SB_TEST_SNOWFLAKEHOST }} build-docs: name: "Build Documentation" diff --git a/pytest.ini b/pytest.ini index 42a021da..99c40cf9 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,6 @@ [pytest] # since bigquery takes a long time to execute, # the tests are disabled by default. -addopts = --doctest-modules -m 'not bigquery' +addopts = --doctest-modules -m 'not bigquery and not snowflake' markers = skip_backend From 5755ba90124dd942268aedee69e74bd01e8f4f41 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 28 Mar 2022 21:09:02 -0400 Subject: [PATCH 07/11] ci: run bigquery tests in parallel --- .github/workflows/ci.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f9cc7442..acb23ed7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -49,7 +49,6 @@ jobs: python -m pip install --upgrade pip python -m pip install $REQUIREMENTS python -m pip install -r requirements-test.txt - python -m pip install snowflake-sqlalchemy==1.3.3 python -m pip install . env: REQUIREMENTS: ${{ matrix.requirements }} @@ -76,7 +75,7 @@ jobs: test-bigquery: name: "Test BigQuery" runs-on: ubuntu-latest - if: contains(github.ref, 'bigquery') || ${{ !github.event.pull_request.draft }} + if: ${{ contains('bigquery', github.ref) || !github.event.pull_request.draft }} steps: - uses: actions/checkout@v2 - name: Set up Python @@ -88,6 +87,7 @@ jobs: python -m pip install --upgrade pip python -m pip install -r requirements.txt python -m pip install -r requirements-test.txt + python -m pip install pytest-parallel python -m pip install git+https://github.com/machow/pybigquery.git pandas-gbq==0.15.0 python -m pip install . 
- name: Set up Cloud SDK @@ -98,14 +98,14 @@ jobs: export_default_credentials: true - name: Test with pytest run: | - pytest siuba -m bigquery + pytest siuba -m bigquery --workers auto --tests-per-worker 2 env: SB_TEST_BQDATABASE: "ci_github" test-snowflake: name: "Test snowflake" runs-on: ubuntu-latest - if: contains(github.ref, 'snowflake') || ${{ !github.event.pull_request.draft }} + if: ${{ contains('snowflake', github.ref) || !github.event.pull_request.draft }} steps: - uses: actions/checkout@v2 - name: Set up Python @@ -206,7 +206,7 @@ jobs: name: "Deploy to PyPI" runs-on: ubuntu-latest if: github.event_name == 'release' - needs: [checks, test-bigquery] + needs: [checks, test-bigquery, test-snowflake] steps: - uses: actions/checkout@v2 - name: "Set up Python 3.8" From be9ee456d8061d4388941beececffd1090a86027 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 28 Mar 2022 22:51:52 -0400 Subject: [PATCH 08/11] tests: use the multi method for inserting sql --- .github/workflows/ci.yml | 5 ++++- siuba/tests/helpers.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index acb23ed7..28c753a0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -98,7 +98,10 @@ jobs: export_default_credentials: true - name: Test with pytest run: | - pytest siuba -m bigquery --workers auto --tests-per-worker 2 + # tests are mostly waiting on http requests to bigquery api + # note that test backends can cache data, so more processes + # is not always faster + pytest siuba -m bigquery --workers 2 --tests-per-worker 20 env: SB_TEST_BQDATABASE: "ci_github" diff --git a/siuba/tests/helpers.py b/siuba/tests/helpers.py index 26087f85..75a77e5f 100644 --- a/siuba/tests/helpers.py +++ b/siuba/tests/helpers.py @@ -243,7 +243,7 @@ def copy_to_sql(df, name, engine): df.to_gbq(qual_name, project_id, if_exists="replace") else: - df.to_sql(name, engine, index = False, if_exists = "replace") + df.to_sql(name, engine, index = False, if_exists = "replace", method="multi") # manually create table, so we can be explicit about boolean columns. # this is necessary because MySQL reflection reports them as TinyInts, From d71acf2ca973111d66dcc3e29d1d2e0a6200fc38 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 28 Mar 2022 23:01:03 -0400 Subject: [PATCH 09/11] ci: move away from custom, patched pybigquery --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 28c753a0..e51473f9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -88,7 +88,7 @@ jobs: python -m pip install -r requirements.txt python -m pip install -r requirements-test.txt python -m pip install pytest-parallel - python -m pip install git+https://github.com/machow/pybigquery.git pandas-gbq==0.15.0 + python -m pip install sqlalchemy-bigquery==1.3.0 pandas-gbq==0.15.0 python -m pip install . 
- name: Set up Cloud SDK uses: google-github-actions/setup-gcloud@v0 From 5c34b21d39637ea78fa7db24af17ded81e1279f9 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 28 Mar 2022 23:06:51 -0400 Subject: [PATCH 10/11] fix(sql): broken bigquery dt methods --- siuba/sql/dialects/bigquery.py | 38 ++++++++++++---------------------- 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/siuba/sql/dialects/bigquery.py b/siuba/sql/dialects/bigquery.py index b4cbd805..f68d1056 100644 --- a/siuba/sql/dialects/bigquery.py +++ b/siuba/sql/dialects/bigquery.py @@ -9,6 +9,7 @@ import sqlalchemy.sql.sqltypes as sa_types from sqlalchemy import sql from sqlalchemy.sql import func as fn +from . import _dt_generics as _dt # Custom dispatching in call trees ============================================ @@ -23,30 +24,24 @@ def sql_floordiv(_, x, y): # datetime ---- -def _date_trunc(_, col, name): +@_dt.date_trunc.register +def _date_trunc(_: BigqueryColumn, col, name): return fn.datetime_trunc(col, sql.text(name)) -def sql_extract(field): - return lambda _, col: fn.extract(field, col) - -def sql_is_first_of(name, reference): - def f(codata, col): - return _date_trunc(codata, col, name) == _date_trunc(codata, col, reference) - - return f -def sql_func_last_day_in_period(_, col, period): +@_dt.sql_func_last_day_in_period.register +def sql_func_last_day_in_period(_: BigqueryColumn, col, period): return fn.last_day(col, sql.text(period)) -def sql_func_days_in_month(_, col): - return fn.extract('DAY', sql_func_last_day_in_period(col, 'MONTH')) -def sql_is_last_day_of(period): - def f(codata, col): - last_day = sql_func_last_day_in_period(codata, col, period) - return _date_trunc(codata, col, "DAY") == last_day +@_dt.sql_is_last_day_of.register +def sql_is_last_day_of(codata: BigqueryColumn, col, period): + last_day = sql_func_last_day_in_period(codata, col, period) + return _date_trunc(codata, col, "DAY") == last_day - return f + +def sql_extract(field): + return lambda _, col: fn.extract(field, col) # string ---- @@ -115,13 +110,6 @@ def f(_, col): # bigquery has Sunday as 1, pandas wants Monday as 0 "dt.dayofweek": lambda _, col: fn.extract("DAYOFWEEK", col) - 2, "dt.dayofyear": sql_extract("DAYOFYEAR"), - "dt.daysinmonth": sql_func_days_in_month, - "dt.days_in_month": sql_func_days_in_month, - "dt.is_month_end": sql_is_last_day_of("MONTH"), - "dt.is_month_start": sql_is_first_of("DAY", "MONTH"), - "dt.is_quarter_start": sql_is_first_of("DAY", "QUARTER"), - "dt.is_year_end": sql_is_last_day_of("YEAR"), - "dt.is_year_start": sql_is_first_of("DAY", "YEAR"), "dt.month_name": lambda _, col: fn.format_date("%B", col), "dt.week": sql_extract("ISOWEEK"), "dt.weekday": lambda _, col: fn.extract("DAYOFWEEK", col) - 2, @@ -154,7 +142,7 @@ def f(_, col): all = sql_all(window = True), count = lambda _, col: AggOver(fn.count(col)), cumsum = win_cumul("sum"), - median = lambda _, col: RankOver(sql_median(col)), + median = lambda _, col: RankOver(fn.percentile_cont(col, .5)), nunique = lambda _, col: AggOver(fn.count(fn.distinct(col))), quantile = lambda _, col, q: RankOver(fn.percentile_cont(col, q)), std = win_agg("stddev"), From e8bdd46f82f3265a5be225aa22a6e059c0fe30ee Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 28 Mar 2022 23:10:56 -0400 Subject: [PATCH 11/11] ci: also run tests in parallel for snowflake --- .github/workflows/ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e51473f9..cb256df6 100644 --- 
a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -120,11 +120,12 @@ jobs: python -m pip install --upgrade pip python -m pip install -r requirements.txt python -m pip install -r requirements-test.txt + python -m pip install pytest-parallel python -m pip install snowflake-sqlalchemy==1.3.3 python -m pip install . - name: Test with pytest run: | - pytest siuba -m snowflake + pytest siuba -m snowflake --workers 2 --tests-per-worker 20 env: SB_TEST_SNOWFLAKEPASSWORD: ${{ secrets.SB_TEST_SNOWFLAKEPASSWORD }} SB_TEST_SNOWFLAKEHOST: ${{ secrets.SB_TEST_SNOWFLAKEHOST }}
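
A minimal end-to-end sketch of what the series enables, exercising the new SqliteColumn
window translations (cumsum registered via win_cumul("sum") in patch 02) against an
in-memory SQLite database. LazyTbl(engine, "table") and collect() are assumed to work as
in siuba's README and are not part of the patches above; it also assumes SQLite >= 3.25,
since older versions lack window functions (see the TODO note in the sqlite dialect).

    # sketch only: run a grouped cumulative sum through the sqlite dialect
    # assumes SQLite >= 3.25 (no window functions before that version)
    import pandas as pd
    from sqlalchemy import create_engine

    from siuba import _, group_by, mutate, collect
    from siuba.sql import LazyTbl   # assumed two-argument constructor

    engine = create_engine("sqlite:///:memory:")
    pd.DataFrame({"g": [1, 1, 2, 2], "x": [1, 2, 3, 4]}).to_sql(
        "dat", engine, index=False
    )

    tbl = LazyTbl(engine, "dat")

    res = (
        tbl
        >> group_by(_.g)
        >> mutate(x_cumsum = _.x.cumsum())   # translated via win_cumul("sum")
        >> collect()                         # execute and return a DataFrame
    )
    print(res)

The same pipeline is what the re-enabled tests in test_verb_mutate.py and
test_dply_series_methods.py cover once the sqlite skips are removed.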