Skip to content

Commit

Permalink
feat: basic mysql translation
Browse files Browse the repository at this point in the history
  • Loading branch information
machow committed Feb 22, 2021
1 parent fe0f0a9 commit a9b97a9
Show file tree
Hide file tree
Showing 15 changed files with 298 additions and 105 deletions.
41 changes: 26 additions & 15 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,35 @@ jobs:
requirements: numpy~=1.19.1 pandas~=1.1.0 SQLAlchemy~=1.3.18 psycopg2~=2.8.5

# Service containers to run with `container-job`
services:
# Label used to access the service container
postgres:
image: postgres
env:
POSTGRES_PASSWORD: ""
POSTGRES_HOST_AUTH_METHOD: "trust"
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432
#services:
# # Label used to access the service container
# postgres:
# image: postgres
# env:
# POSTGRES_PASSWORD: ""
# POSTGRES_HOST_AUTH_METHOD: "trust"
# # Set health checks to wait until postgres has started
# options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
# ports:
# - 5432:5432
# mysql:
# image: mysql
# env:
# MYSQL_ROOT_PASSWORD: ""
# MYSQL_ALLOW_EMPTY_PASSWORD: 1
# MYSQL_DATABASE: "public"
# ports:
# - 3306:3306
# # by default, mysql rounds to 4 decimals, but tests require more precision
# command: --div-precision-increment=30
# options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=5


steps:
- uses: actions/checkout@v2
- name: Run docker-compose
run: |
docker-compose up --build -d
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
Expand Down
12 changes: 12 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@ version: '3.1'

services:

db_mysql:
image: mysql
restart: always
environment:
MYSQL_ROOT_PASSWORD: ""
MYSQL_ALLOW_EMPTY_PASSWORD: 1
MYSQL_DATABASE: "public"
ports:
- 3306:3306
# by default, mysql rounds to 4 decimals, but tests require more precision
command: --div-precision-increment=30

db:
image: postgres
restart: always
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ psycopg2==2.8.5
ptyprocess==0.6.0
py==1.8.1
Pygments==2.5.2
PyMySQL==1.0.2
pyparsing==2.4.6
pyrsistent==0.15.7
pytest==5.3.5
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
numpy==1.19.1
pandas==1.1.0
psycopg2==2.8.5
PyMySQL==1.0.2
python-dateutil==2.8.1
pytz==2020.1
PyYAML==5.3.1
Expand Down
2 changes: 1 addition & 1 deletion siuba/spec/series2.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def rm_na_entries(mapping):

from siuba import *

sql_backend_names = ["postgresql", "redshift", "sqlite"]
sql_backend_names = ["postgresql", "redshift", "sqlite", "mysql"]

sql_methods = pd.concat(list(map(read_dialect, sql_backend_names)))

Expand Down
28 changes: 21 additions & 7 deletions siuba/sql/dialects/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# we (1) use full import paths, (2) define everything a new backend would need
# here.
from sqlalchemy import sql
from sqlalchemy import types as sa_types
from sqlalchemy.sql import func as fn

from siuba import ops
Expand All @@ -18,7 +19,8 @@
sql_colmeth,
sql_not_impl,
create_sql_translators,
annotate
annotate,
RankOver
)

# TODO: move anything using this into base.py
Expand Down Expand Up @@ -55,6 +57,17 @@
# cot = sql_scalar("cot"),
#

def sql_func_floordiv(x, y):
return sql.cast(x / y, sa_types.Integer())

def sql_func_rank(col):
# see https://stackoverflow.com/a/36823637/1144523
min_rank = RankOver(sql.func.rank(), order_by = col)
to_mean = (RankOver(sql.func.count(), partition_by = col) - 1) / 2.0

return min_rank + to_mean


def req_bool(f):
return annotate(f, input_type = "bool")

Expand All @@ -64,7 +77,7 @@ def req_bool(f):
__and__ = req_bool(sql_colmeth("__and__")),
__div__ = sql_colmeth("__div__"),
__eq__ = sql_colmeth("__eq__"),
__floordiv__ = sql_not_impl(),
__floordiv__ = sql_func_floordiv,
__ge__ = sql_colmeth("__ge__"),
__gt__ = sql_colmeth("__gt__"),
__invert__ = req_bool(sql_colmeth("__invert__")),
Expand All @@ -80,7 +93,7 @@ def req_bool(f):
__radd__ = sql_colmeth("__radd__"),
__rand__ = req_bool(sql_colmeth("__rand__")),
__rdiv__ = sql_colmeth("__rdiv__"),
__rfloordiv__ = sql_colmeth("__pow__"),
__rfloordiv__ = lambda x, y: sql_func_floordiv(y, x),
__rmod__ = sql_colmeth("__rmod__"),
__rmul__ = sql_colmeth("__rmul__"),
__ror__ = req_bool(sql_colmeth("__ror__")),
Expand Down Expand Up @@ -139,6 +152,7 @@ def req_bool(f):


**{
# TODO: check generality of trim functions, since MYSQL overrides
"str.capitalize" : sql_func_capitalize,
#"str.center" :,
#"str.contains" :,
Expand Down Expand Up @@ -261,13 +275,13 @@ def req_bool(f):
cummax = win_cumul("max"),
cummin = win_cumul("min"),
#cumprod =
cumsum = win_cumul("sum"),
cumsum = annotate(win_cumul("sum"), result_type = "float"),
diff = sql_func_diff,
#is_monotonic =
#is_monotonic_decreasing =
#is_monotonic_increasing =
#pct_change = TODO(?)
rank = win_over("rank"),
rank = sql_func_rank,

# computation (strict aggregates)
#all = #TODO(pg): all = sql_aggregate("BOOL_AND", "all")
Expand All @@ -290,7 +304,7 @@ def req_bool(f):
#sem =
#skew =
#std = # TODO(pg)
sum = win_agg("sum"),
sum = annotate(win_agg("sum"), result_type = "float"),
#var = # TODO(pg)


Expand Down Expand Up @@ -354,7 +368,7 @@ def req_bool(f):
#sem =
#skew =
#std = # TODO(pg)
sum = sql_agg("sum"),
sum = annotate(sql_agg("sum"), result_type = "float"),
#var = # TODO(pg)

# index ----
Expand Down
131 changes: 131 additions & 0 deletions siuba/sql/dialects/mysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# sqlvariant, allow defining 3 namespaces to override defaults
from ..translate import (
SqlColumn, SqlColumnAgg, SqlTranslations, win_agg,
create_sql_translators, sql_not_impl
)

from .base import base_scalar, base_agg, base_win

import sqlalchemy.sql.sqltypes as sa_types

from sqlalchemy import sql
from sqlalchemy.sql import func as fn

from sqlalchemy.dialects.mysql import DOUBLE

# Custom dispatching in call trees ============================================

class MysqlColumn(SqlColumn): pass
class MysqlColumnAgg(SqlColumnAgg, MysqlColumn): pass

def sql_str_strip(left = True, right = True):
def f(col):
# see https://stackoverflow.com/a/6858168/1144523
lstrip = "^[[:space:]]+" if left else ""
rstrip = "[[:space:]]+$" if right else ""

or_op = "|" if lstrip and rstrip else ""
regex = "(" + lstrip + or_op + rstrip + ")"

return fn.regexp_replace(col, regex, "")

return f

def sql_func_extract_dow_monday(col):
# MYSQL: sunday starts, equals 1 (an int)
# pandas: monday starts, equals 0 (also an int)

raw_dow = fn.dayofweek(col)

# monday is 2 in MYSQL, so use monday + 5 % 7
return (raw_dow + 5) % 7

def sql_is_date_offset(period, is_start = True):

# will check against one day in the past for is_start, v.v. otherwise
fn_add = fn.date_sub if is_start else fn.date_add

def f(col):
get_period = getattr(fn, period)
src_per = get_period(col)
incr_per = get_period(fn_add(col, sql.text("INTERVAL 1 DAY")))

return src_per != incr_per

return f

def sql_func_truediv(x, y):
return sql.cast(x, DOUBLE()) / y

def sql_func_floordiv(x, y):
return x.op("DIV")(y)

def sql_func_between(col, left, right, inclusive=True):
if not inclusive:
raise NotImplementedError("between must be inclusive")

# TODO: should figure out how sqlalchemy prefers to set types, rather
# than setting manually on this expression
expr = col.between(left, right)
expr.type = sa_types.Boolean()
return expr

scalar = SqlTranslations(
base_scalar,

# copied from postgres. MYSQL does true division over ints by default,
# but it does not produce double precision.
__div__ = sql_func_truediv,
div = sql_func_truediv,
divide = sql_func_truediv,
rdiv = lambda x,y: sql_func_truediv(y, x),
__rdiv__ = lambda x, y: sql_func_truediv(y, x),

__truediv__ = sql_func_truediv,
truediv = sql_func_truediv,
__rtruediv__ = lambda x, y: sql_func_truediv(y, x),

__floordiv__ = sql_func_floordiv,
__rfloordiv__ = lambda x, y: sql_func_floordiv(y, x),

between = sql_func_between,

**{
"str.lstrip": sql_str_strip(right = False),
"str.rstrip": sql_str_strip(left = False),
"str.strip": sql_str_strip(),
"str.title": sql_not_impl() # see https://stackoverflow.com/q/12364086/1144523
},
**{
"dt.dayofweek": sql_func_extract_dow_monday,
"dt.dayofyear": lambda col: fn.dayofyear(col),
"dt.days_in_month": lambda col: fn.dayofmonth(fn.last_day(col)),
"dt.daysinmonth": lambda col: fn.dayofmonth(fn.last_day(col)),
"dt.is_month_end": lambda col: col == fn.last_day(col),
"dt.is_month_start": lambda col: fn.dayofmonth(col) == 1,
"dt.is_quarter_start": sql_is_date_offset("QUARTER"),
"dt.is_year_start": sql_is_date_offset("YEAR"),
"dt.is_year_end": sql_is_date_offset("YEAR", is_start = False),
# see https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_week
"dt.week": lambda col: fn.week(col, 1),
"dt.weekday": sql_func_extract_dow_monday,
"dt.weekofyear": lambda col: fn.week(col, 1),
}
)

aggregate = SqlTranslations(
base_agg
)

window = SqlTranslations(
base_win,
sd = win_agg("stddev")
)

funcs = dict(scalar = scalar, aggregate = aggregate, window = window)

translator = create_sql_translators(
scalar, aggregate, window,
MysqlColumn, MysqlColumnAgg
)

13 changes: 0 additions & 13 deletions siuba/sql/dialects/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,6 @@ def sql_func_truediv(x, y):
def sql_func_floordiv(x, y):
return sql.cast(x / y, sa_types.Integer())

def sql_func_rank(col):
# see https://stackoverflow.com/a/36823637/1144523
min_rank = RankOver(sql.func.rank(), order_by = col)
to_mean = (RankOver(sql.func.count(), partition_by = col) - 1) / 2.0

return min_rank + to_mean

scalar = SqlTranslations(
base_scalar,
Expand All @@ -85,8 +79,6 @@ def sql_func_rank(col):
truediv = sql_func_truediv,
__rtruediv__ = lambda x, y: sql_func_truediv(y, x),

__floordiv__ = sql_func_floordiv,
__rfloordiv__ = lambda x, y: sql_func_floordiv(y, x),

round = sql_round,
__round__ = sql_round,
Expand All @@ -113,9 +105,6 @@ def sql_func_rank(col):
# overrides ----

# note that postgres does sum(bigint) -> numeric
sum = annotate(win_agg("sum"), result_type = "float"),
cumsum = annotate(win_cumul("sum"), result_type = "float"),
rank = sql_func_rank,
size = win_agg("count"), #TODO double check
)

Expand All @@ -125,8 +114,6 @@ def sql_func_rank(col):
any = sql_agg("bool_or"),
std = sql_agg("stddev_samp"),
var = sql_agg("var_samp"),

sum = annotate(sql_agg("sum"), result_type = "float"),
)


Expand Down
Loading

0 comments on commit a9b97a9

Please sign in to comment.