Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support sqlite window functions #406

Merged
merged 11 commits into from
Mar 29, 2022
43 changes: 34 additions & 9 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ jobs:
python -m pip install --upgrade pip
python -m pip install $REQUIREMENTS
python -m pip install -r requirements-test.txt
python -m pip install snowflake-sqlalchemy==1.3.3
python -m pip install .
env:
REQUIREMENTS: ${{ matrix.requirements }}
Expand All @@ -59,8 +58,6 @@ jobs:
env:
SB_TEST_PGPORT: 5432
PYTEST_FLAGS: ${{ matrix.pytest_flags }}
SB_TEST_SNOWFLAKEPASSWORD: ${{ secrets.SB_TEST_SNOWFLAKEPASSWORD }}
SB_TEST_SNOWFLAKEHOST: ${{ secrets.SB_TEST_SNOWFLAKEHOST }}

# optional step for running bigquery tests ----
- name: Set up Cloud SDK
Expand All @@ -78,19 +75,20 @@ jobs:
test-bigquery:
name: "Test BigQuery"
runs-on: ubuntu-latest
if: contains(github.ref, 'bigquery') || contains(github.ref, 'refs/tags')
# contains(search, item) is true when `search` contains `item`, so the ref
# must be the first argument to detect refs with "bigquery" in their name.
if: ${{ contains(github.ref, 'bigquery') || !github.event.pull_request.draft }}
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
python -m pip install -r requirements-test.txt
python -m pip install git+https://github.com/machow/pybigquery.git pandas-gbq==0.15.0
python -m pip install pytest-parallel
python -m pip install sqlalchemy-bigquery==1.3.0 pandas-gbq==0.15.0
python -m pip install .
- name: Set up Cloud SDK
uses: google-github-actions/setup-gcloud@v0
Expand All @@ -100,10 +98,37 @@ jobs:
export_default_credentials: true
- name: Test with pytest
run: |
pytest siuba -m bigquery
# tests are mostly waiting on http requests to bigquery api
# note that test backends can cache data, so more processes
# is not always faster
pytest siuba -m bigquery --workers 2 --tests-per-worker 20
env:
SB_TEST_BQDATABASE: "ci_github"

# Runs only the tests marked `snowflake`, in a job separate from the main checks.
test-snowflake:
  name: "Test snowflake"
  runs-on: ubuntu-latest
  # contains(search, item) is true when `search` contains `item`, so the ref
  # must be the first argument to detect refs with "snowflake" in their name.
  if: ${{ contains(github.ref, 'snowflake') || !github.event.pull_request.draft }}
  steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: "3.8"
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip install -r requirements.txt
        python -m pip install -r requirements-test.txt
        python -m pip install pytest-parallel
        python -m pip install snowflake-sqlalchemy==1.3.3
        python -m pip install .
    - name: Test with pytest
      run: |
        pytest siuba -m snowflake --workers 2 --tests-per-worker 20
      env:
        SB_TEST_SNOWFLAKEPASSWORD: ${{ secrets.SB_TEST_SNOWFLAKEPASSWORD }}
        SB_TEST_SNOWFLAKEHOST: ${{ secrets.SB_TEST_SNOWFLAKEHOST }}

build-docs:
name: "Build Documentation"
Expand Down Expand Up @@ -185,7 +210,7 @@ jobs:
name: "Deploy to PyPI"
runs-on: ubuntu-latest
if: github.event_name == 'release'
needs: [checks, test-bigquery]
needs: [checks, test-bigquery, test-snowflake]
steps:
- uses: actions/checkout@v2
- name: "Set up Python 3.8"
Expand Down
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[pytest]
# since bigquery tests take a long time to execute, they are disabled by
# default. the snowflake tests are deselected by default in the same way.
addopts = --doctest-modules -m 'not bigquery'
addopts = --doctest-modules -m 'not bigquery and not snowflake'
markers =
skip_backend
38 changes: 13 additions & 25 deletions siuba/sql/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sqlalchemy.sql.sqltypes as sa_types
from sqlalchemy import sql
from sqlalchemy.sql import func as fn
from . import _dt_generics as _dt

# Custom dispatching in call trees ============================================

Expand All @@ -23,30 +24,24 @@ def sql_floordiv(_, x, y):

# datetime ----

def _date_trunc(_, col, name):
@_dt.date_trunc.register
def _date_trunc(_: BigqueryColumn, col, name):
return fn.datetime_trunc(col, sql.text(name))

def sql_extract(field):
return lambda _, col: fn.extract(field, col)

def sql_is_first_of(name, reference):
def f(codata, col):
return _date_trunc(codata, col, name) == _date_trunc(codata, col, reference)

return f

def sql_func_last_day_in_period(_, col, period):
@_dt.sql_func_last_day_in_period.register
def sql_func_last_day_in_period(_: BigqueryColumn, col, period):
return fn.last_day(col, sql.text(period))

def sql_func_days_in_month(_, col):
return fn.extract('DAY', sql_func_last_day_in_period(col, 'MONTH'))

def sql_is_last_day_of(period):
def f(codata, col):
last_day = sql_func_last_day_in_period(codata, col, period)
return _date_trunc(codata, col, "DAY") == last_day
@_dt.sql_is_last_day_of.register
def sql_is_last_day_of(codata: BigqueryColumn, col, period):
last_day = sql_func_last_day_in_period(codata, col, period)
return _date_trunc(codata, col, "DAY") == last_day

return f

def sql_extract(field):
return lambda _, col: fn.extract(field, col)


# string ----
Expand Down Expand Up @@ -115,13 +110,6 @@ def f(_, col):
# bigquery has Sunday as 1, pandas wants Monday as 0
"dt.dayofweek": lambda _, col: fn.extract("DAYOFWEEK", col) - 2,
"dt.dayofyear": sql_extract("DAYOFYEAR"),
"dt.daysinmonth": sql_func_days_in_month,
"dt.days_in_month": sql_func_days_in_month,
"dt.is_month_end": sql_is_last_day_of("MONTH"),
"dt.is_month_start": sql_is_first_of("DAY", "MONTH"),
"dt.is_quarter_start": sql_is_first_of("DAY", "QUARTER"),
"dt.is_year_end": sql_is_last_day_of("YEAR"),
"dt.is_year_start": sql_is_first_of("DAY", "YEAR"),
"dt.month_name": lambda _, col: fn.format_date("%B", col),
"dt.week": sql_extract("ISOWEEK"),
"dt.weekday": lambda _, col: fn.extract("DAYOFWEEK", col) - 2,
Expand Down Expand Up @@ -154,7 +142,7 @@ def f(_, col):
all = sql_all(window = True),
count = lambda _, col: AggOver(fn.count(col)),
cumsum = win_cumul("sum"),
median = lambda _, col: RankOver(sql_median(col)),
median = lambda _, col: RankOver(fn.percentile_cont(col, .5)),
nunique = lambda _, col: AggOver(fn.count(fn.distinct(col))),
quantile = lambda _, col, q: RankOver(fn.percentile_cont(col, q)),
std = win_agg("stddev"),
Expand Down
159 changes: 151 additions & 8 deletions siuba/sql/dialects/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
# sqlvariant, allow defining 3 namespaces to override defaults
from ..translate import (
SqlColumn, SqlColumnAgg, extend_base,
SqlTranslator
SqlTranslator,
sql_not_impl,
win_cumul,
win_agg,
annotate,
wrap_annotate
)

from .base import base_nowin
#from .postgresql import PostgresqlColumn as SqlColumn, PostgresqlColumnAgg as SqlColumnAgg
from . import _dt_generics as _dt

import sqlalchemy.sql.sqltypes as sa_types
from sqlalchemy import sql
from sqlalchemy.sql import func as fn

# Custom dispatching in call trees ============================================

Expand All @@ -16,18 +24,153 @@
# Dialect marker type: translations in this module are registered against it
# so that they are dispatched for SQLite columns.
class SqliteColumn(SqlColumn): pass
# Aggregate counterpart of SqliteColumn; it also subclasses SqliteColumn.
# NOTE(review): presumably it falls back to SqliteColumn translations - confirm.
class SqliteColumnAgg(SqlColumnAgg, SqliteColumn): pass

scalar = extend_base(

# Translations ================================================================

# fix some annotations --------------------------------------------------------

# Note this is taken from the postgres dialect. Compared to postgresql, which
# always returns a float, there seem to be 2 key points to keep in mind:
# * sqlite date parts are returned as floats
# * sqlite time parts are returned as integers
def returns_float(func_names):
    """Re-register the named operations for SqliteColumn, annotated as float.

    For every name, the implementation currently dispatched for ``SqlColumn``
    is looked up, wrapped with ``wrap_annotate(..., result_type="float")`` and
    registered on the same generic for ``SqliteColumn``.

    Args:
        func_names: iterable of operation names, used as keys of
            ``siuba.ops.ALL_OPS``.
    """
    # TODO: MC-NOTE - shift all translations to directly register
    # TODO: MC-NOTE - make an AliasAnnotated class or something, that signals
    # it is using another method, but w/ an updated annotation.
    from siuba.ops import ALL_OPS

    for op_name in func_names:
        op = ALL_OPS[op_name]
        annotated = wrap_annotate(op.dispatch(SqlColumn), result_type="float")
        op.register(SqliteColumn, annotated)

# detect first and last date (similar to the mysql dialect) -------------------

# NOTE(review): other annotations in this file pass `result_type`, while this
# one passes `return_type`, and it decorates the factory rather than the
# translation functions it returns - confirm this is intended.
@annotate(return_type="float")
def sql_extract(name):
    """Create a translation that extracts the date part ``name`` from a column.

    Args:
        name: name of the date part; "quarter" gets a dedicated implementation.

    Returns:
        A function ``(codata, col)`` producing the SQL expression.
    """
    if name != "quarter":
        return lambda _, col: fn.extract(name, col)

    # division in sqlite automatically rounds down
    # so for jan, 1 + 2 = 3, and 3 / 3 is Q1
    return lambda _, col: (fn.strftime("%m", col) + 2) / 3


@_dt.sql_is_last_day_of.register
def _sql_is_last_day_of(codata: SqliteColumn, col, period):
    """Build an expression testing whether ``col`` is the last day of ``period``.

    Args:
        codata: dispatch marker for the SQLite dialect (not used in the body).
        col: column expression holding the date to test.
        period: either "month" or "year".

    Returns:
        A comparison between the date part of ``col`` and the last date of the
        period that ``col`` falls in.

    Raises:
        ValueError: if ``period`` is not "month" or "year".
    """
    valid_periods = {"month", "year"}
    if period not in valid_periods:
        raise ValueError(f"Period must be one of {valid_periods}")

    # go to the start of the period, one period forward, then one day back
    incr = f"+1 {period}"
    target_date = fn.date(col, f'start of {period}', incr, "-1 day")

    # compare on the date part only (as the first-day check does), so a value
    # that carries a time-of-day component still matches its calendar day
    return fn.date(col) == target_date


@_dt.sql_is_first_day_of.register
def _sql_is_first_day_of(codata: SqliteColumn, col, period):
    """Build an expression testing whether ``col`` is the first day of ``period``.

    Args:
        codata: dispatch marker for the SQLite dialect (not used in the body).
        col: column expression holding the date to test.
        period: either "month" or "year".

    Returns:
        A comparison between the date part of ``col`` and the start of its period.

    Raises:
        ValueError: if ``period`` is not "month" or "year".
    """
    allowed = {"month", "year"}
    if period in allowed:
        period_start = fn.date(col, f'start of {period}')
        return fn.date(col) == period_start

    raise ValueError(f"Period must be one of {allowed}")


# date part of period calculations --------------------------------------------

def sql_days_in_month(_, col):
    """Return an integer expression for the number of days in ``col``'s month."""
    # the day-of-month of the month's final date is the length of the month
    month_end = fn.date(col, 'start of month', '+1 month', '-1 day')
    day_of_month = fn.strftime("%d", month_end)
    return day_of_month.cast(sa_types.Integer())


def sql_week_of_year(_, col):
    """Return an expression for the ISO week number of ``col``."""
    # convert sqlite week to ISO week
    # adapted from: https://stackoverflow.com/a/15511864
    # "-3 days" followed by "weekday 4" lands on the Thursday of col's week
    week_thursday = fn.date(col, "-3 days", "weekday 4")

    # %j is the 1-based day of year, so subtract 1 before dividing into weeks
    zero_based_day_of_year = fn.strftime("%j", week_thursday) - 1

    return (zero_based_day_of_year / 7) + 1


# misc ------------------------------------------------------------------------

@annotate(result_type = "float")
def sql_round(_, col, n):
    """Translate rounding of ``col`` to ``n`` digits into SQL ``round(col, n)``.

    The result is annotated as a float.
    """
    rounded = fn.round(col, n)
    return rounded


def sql_func_truediv(_, x, y):
    """Translate true division ``x / y``, casting the numerator to a float first."""
    # cast the numerator up front, since division in sqlite otherwise rounds down
    numerator = sql.cast(x, sa_types.Float())
    return numerator / y


def between(_, col, x, y):
    """Translate ``col.between(x, y)`` and mark the result as a boolean."""
    expr = col.between(x, y)

    # Mark the expression as boolean for sqlalchemy, so that collected results
    # are converted from integers to bools. This is consistent with what
    # col == col returns.
    expr.type = sa_types.Boolean()
    return expr

def sql_str_capitalize(_, col):
    """Translate ``str.capitalize``: upper-case first character, lower-case rest."""
    # substr is 1-indexed: (col, 1, 1) is the first character, (col, 2) the remainder
    head = fn.upper(fn.substr(col, 1, 1))
    tail = fn.lower(fn.substr(col, 2))
    return head.concat(tail)

extend_base(
SqliteColumn,
)

aggregate = extend_base(
SqliteColumnAgg,
)
between = between,
clip = sql_not_impl("sqlite does not have a least or greatest function."),

div = sql_func_truediv,
divide = sql_func_truediv,
rdiv = lambda _, x,y: sql_func_truediv(_, y, x),

window = extend_base(
__truediv__ = sql_func_truediv,
truediv = sql_func_truediv,
__rtruediv__ = lambda _, x, y: sql_func_truediv(_, y, x),

round = sql_round,
__round__ = sql_round,

**{
"str.title": sql_not_impl("TODO"),
"str.capitalize": sql_str_capitalize,
},

**{
"dt.quarter": sql_extract("quarter"),
"dt.is_quarter_start": sql_not_impl("TODO"),
"dt.is_quarter_end": sql_not_impl("TODO"),
"dt.days_in_month": sql_days_in_month,
"dt.daysinmonth": sql_days_in_month,
"dt.week": sql_week_of_year,
"dt.weekofyear": sql_week_of_year,

}
)

returns_float([
"dt.dayofweek",
"dt.weekday",
])


extend_base(
SqliteColumn,
**base_nowin
# TODO: should check sqlite version, since < 3.25 can't use windows
cumsum = win_cumul("sum"),

quantile = sql_not_impl("sqlite does not support ordered set aggregates"),
sum = win_agg("sum"),
)

extend_base(
SqliteColumnAgg,
quantile = sql_not_impl("sqlite does not support ordered set aggregates"),
)


Expand Down
Loading