Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add table join to formhandler. #705

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 66 additions & 3 deletions gramex/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def filter(
args: dict = {},
meta: dict = {},
engine: str = None,
join: str = None,
table: str = None,
ext: str = None,
id: List[str] = None,
Expand Down Expand Up @@ -271,6 +272,7 @@ def filter(
argstype=argstype,
id=id,
table=table,
join=join,
columns=columns,
ext=ext,
query=query,
Expand Down Expand Up @@ -313,7 +315,7 @@ def filter(
data = gramex.cache.query(table, engine, [table])
return _filter_frame(transform(data), meta, controls, args, argstype)
else:
return _filter_db(engine, table, meta, controls, args, argstype)
return _filter_db(engine, table, meta, controls, args, argstype, join=join)
else:
raise ValueError('No table: or query: specified')
else:
Expand Down Expand Up @@ -1686,6 +1688,7 @@ def _filter_db(
argstype: Dict[str, dict] = {},
source: str = 'select',
id: List[str] = None,
join: dict = None,
):
'''
Parameters:
Expand All @@ -1698,16 +1701,76 @@ def _filter_db(
argstype: optional dict that specifies `args` type and behavior.
id: list of keys specific to data using which values can be updated
'''

def get_joins(table, join):
    """Build a SELECT over `table` joined with the tables named in `join`.

    Parameters:
        table: the main (left-most) SQLAlchemy ``Table`` object.
        join: dict mapping joined-table name -> spec dict with optional keys
            ``type`` ('left'/'outer' -> LEFT OUTER JOIN, anything else ->
            INNER JOIN) and ``column`` ({left_col: right_col} equi-join pairs,
            each written as ``table.column``).

    Returns:
        (cols, query): dict of selectable label -> Column, and the
        ``sa.select()`` query with all joins applied.
    """
    # No join requested: plain SELECT * of the main table.
    if not join:
        return table.columns, sa.select([table])

    # Main-table columns keep their bare name; joined-table columns are
    # labelled '<table>_<column>' to avoid name clashes in the result.
    cols = {}
    labels = []
    label_texts = []
    for col in table.columns:
        cols[col.name] = col
        labels.append(col.label(col.name))
        label_texts.append(f'{table.name}.{col.name}')

    # Reflect every joined table and register its (prefixed) columns.
    tables_map = {}
    for name in join:
        tables_map[name] = tbl = get_table(engine, name)
        for col in tbl.columns:
            label = f'{name}_{col.name}'
            cols[label] = col
            labels.append(col.label(label))
            label_texts.append(f'{name}.{col.name}')

    # Establish an explicit left side by selecting FROM the main table.
    query = sa.select().select_from(table)

    for name, extras in join.items():
        join_attr = [tables_map[name]]
        if 'column' in extras:
            conditions = []
            for left, right in extras['column'].items():
                # Skip (with a warning) any pair that references a column
                # not present on the main or joined tables.
                invalid = [c for c in (left, right) if c not in label_texts]
                if invalid:
                    app_log.warning(f'invalid column(s): {", ".join(invalid)}')
                    continue

                # NOTE(review): raw SQL text assembled from the join config.
                # `join:` comes from gramex.yaml (not end-user input), but
                # sa.text() here would be unsafe if this were ever
                # user-controlled — confirm the trust boundary.
                conditions.append(f'{left}={right}')
                # Drop the join-key columns from the SELECT list.
                # NOTE(review): main-table columns are labelled with the bare
                # column name, so `left.replace('.', '_')` never matches them
                # and the main table's join key is NOT dropped — confirm
                # whether that asymmetry is intended.
                drop = {left.replace('.', '_'), right.replace('.', '_')}
                labels = [lbl for lbl in labels if lbl.name not in drop]

            join_attr.append(sa.text(' AND '.join(conditions)))

        query = query.join(
            *join_attr,
            isouter='type' in extras and extras['type'].lower() in ['left', 'outer'],
        )

    query = query.with_only_columns(labels)
    return cols, query

table = get_table(engine, table)
cols = table.columns
colslist = cols.keys()

if source == 'delete':
query = sa.delete(table)
elif source == 'update':
query = sa.update(table)
else:
query = sa.select([table])
cols, query = get_joins(table, join)
colslist = list(cols.keys())

cols_for_update = {}
cols_having = []
for key, vals in args.items():
Expand Down
25 changes: 25 additions & 0 deletions pytest/Docker-compose.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Databases for pytest/test_formhandler.py (used by the mysql()/postgres()
# context managers). Start with: docker compose -f Docker-compose.yaml up -d
# NOTE: passwordless / trust authentication is for local test use only.
version: '3.9'

services:

  mysql:
    image: mysql:8.0
    container_name: mysql
    restart: always
    environment:
      MYSQL_ALLOW_EMPTY_PASSWORD: 'yes'  # test-only: root has no password
    ports:
      - 3306:3306
    expose:
      - 3306

  postgres:
    image: postgres:13.2
    container_name: postgres
    restart: always
    environment:
      POSTGRES_HOST_AUTH_METHOD: trust  # test-only: no authentication
    ports:
      - 5432:5432
    expose:
      - 5432
6 changes: 6 additions & 0 deletions pytest/formhandler-basic/test-case.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Baseline case: no join — filtering a plain table must match SELECT *.
kwargs:
  url: ""  # filled in at runtime by the test with the database URL
  table: "sales"
expected: "SELECT * FROM sales"
formatting:
  # Column -> pandas function applied to the SQL result before comparison
  # (DATE columns come back as strings from some drivers).
  sale_date: to_datetime
28 changes: 28 additions & 0 deletions pytest/formhandler-join-controls/test-case.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Join combined with formhandler controls (_c aggregation, _by grouping,
# and a `>` filter on customer_id).
kwargs:
  url: ""  # filled in at runtime by the test with the database URL
  table: sales
  join:
    products:
      type: inner  # INNER JOIN
      column:
        products.id: sales.product_id
    customers:
      type: left  # LEFT OUTER JOIN
      column:
        sales.customer_id: customers.id
args:
  _c:
    - "id|count"
  _by:
    - "customer_id"
  customer_id>:
    - '3'
expected: >
  SELECT
  customers.id AS customer_id,
  count(customers.id) as 'id|count'
  FROM sales
  JOIN products ON products.id = sales.product_id
  LEFT OUTER JOIN customers ON sales.customer_id = customers.id
  WHERE customers.id > 3
  GROUP BY customers.id
32 changes: 32 additions & 0 deletions pytest/formhandler-join/test-case.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Plain two-table join with no controls: all columns selected.
kwargs:
  url: ""  # filled in at runtime by the test with the database URL
  table: sales
  join:
    products:
      type: inner  # INNER JOIN
      column:
        products.id: sales.product_id
    customers:
      type: left  # LEFT OUTER JOIN
      column:
        sales.customer_id: customers.id
# NOTE(review): the aliases below reuse `sales_id`/`sales_name`/`sales_city`
# for columns from three different tables — duplicate SELECT labels look
# wrong; expected labelling per get_joins would be `products_*`/`customers_*`.
# Confirm against _filter_db.get_joins.
# NOTE(review): `==` is SQLite-only equality syntax; MySQL/Postgres need `=`.
expected: >
  SELECT
  sales.id AS sales_id,
  sales.customer_id AS sales_customer_id,
  sales.product_id AS sales_product_id,
  sales.sale_date AS sales_sale_date,
  sales.amount AS sales_amount,
  sales.city AS sales_city,
  products.id AS sales_id,
  products.name AS sales_name,
  products.price AS sales_price,
  products.manufacturer AS sales_manufacturer,
  customers.id AS sales_id,
  customers.name AS sales_name,
  customers.city AS sales_city
  FROM sales
  JOIN products ON products.id==sales.product_id
  LEFT OUTER JOIN customers ON sales.customer_id==customers.id
formatting:
  # Column -> pandas function applied to the SQL result before comparison
  sales_sale_date: to_datetime
86 changes: 86 additions & 0 deletions pytest/test_formhandler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import os
import pytest
import gramex.data
import gramex.cache
from itertools import product
from contextlib import contextmanager
import pandas as pd
import dbutils
from pandas.testing import assert_frame_equal as afe
from glob import glob


# Directory of this test module; the fixture workbook lives in ../tests.
folder = os.path.dirname(os.path.abspath(__file__))
sales_join_file = os.path.join(folder, "..", "tests", "sales_join.xlsx")
# One worksheet per table; these frames are loaded into every test database.
sales_join_data: pd.DataFrame = gramex.cache.open(sales_join_file, sheet_name="sales")
customers_data: pd.DataFrame = gramex.cache.open(sales_join_file, sheet_name="customers")
products_data: pd.DataFrame = gramex.cache.open(sales_join_file, sheet_name="products")


@contextmanager
def sqlite():
    """Yield a SQLite URL with the join test tables loaded; drop the DB after.

    The drop runs in a ``finally`` so the database is cleaned up even when
    the test body raises (the original leaked the DB on test failure).
    """
    url = dbutils.sqlite_create_db(
        "test_formhandler_join.db",
        sales=sales_join_data,
        customers=customers_data,
        products=products_data,
    )
    try:
        yield url
    finally:
        dbutils.sqlite_drop_db("test_formhandler_join.db")

@contextmanager
def mysql():
    """Yield a MySQL URL with the join test tables loaded; drop the DB after.

    Server host comes from $MYSQL_SERVER (default: localhost). The drop runs
    in a ``finally`` so the database is cleaned up even when the test body
    raises (the original leaked the DB on test failure).
    """
    server = os.environ.get('MYSQL_SERVER', 'localhost')
    url = dbutils.mysql_create_db(
        server,
        "test_formhandler_join",
        sales=sales_join_data,
        customers=customers_data,
        products=products_data,
    )
    try:
        yield url
    finally:
        dbutils.mysql_drop_db(server, "test_formhandler_join")


@contextmanager
def postgres():
    """Yield a Postgres URL with the join test tables loaded; drop the DB after.

    Server host comes from $POSTGRES_SERVER (default: localhost). The drop
    runs in a ``finally`` so the database is cleaned up even when the test
    body raises (the original leaked the DB on test failure).
    """
    server = os.environ.get('POSTGRES_SERVER', 'localhost')
    url = dbutils.postgres_create_db(
        server,
        "test_formhandler_join",
        sales=sales_join_data,
        customers=customers_data,
        products=products_data,
    )
    try:
        yield url
    finally:
        dbutils.postgres_drop_db(server, "test_formhandler_join")

# @contextmanager
# def dataframe():
# yield {'url': sales_join_data.copy()}


# Database backends every test case is run against (parametrized below).
db_setups = [
    # dataframe,  # in-memory DataFrame backend — disabled, see above
    sqlite,
    mysql,
    postgres,
]


@pytest.mark.parametrize(
    "result,db_setup",
    product(glob(os.path.join(folder, "formhandler-*", "*.yaml")), db_setups),
)
def test_formhandler_join(result, db_setup):
    """Run each formhandler-*/test-case.yaml against every database backend.

    Each case supplies `kwargs` for gramex.data.filter, an `expected` SQL
    query, optional `args` (filter controls) and optional `formatting`
    (column -> pandas function applied to the SQL result before comparing).
    """
    case = gramex.cache.open(result)
    # BUG FIX: gramex.data.filter declares `args` as a dict; the previous
    # default was a list, which broke cases without an `args:` key.
    args = case.get("args", {})
    with db_setup() as url:
        case["kwargs"]["url"] = url
        actual = gramex.data.filter(args=args, meta={}, **case["kwargs"])
        expected = pd.read_sql(case["expected"], url)
        if not expected.empty and "formatting" in case:
            # Normalize driver-dependent column types (e.g. dates as strings).
            for col, func_name in case["formatting"].items():
                expected[col] = expected[col].apply(getattr(pd, func_name))
        afe(expected, actual)
12 changes: 12 additions & 0 deletions tests/gramex.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,18 @@ url:
formats:
json:
date_format: iso
# LEFT JOIN of the sales table with cities via the shared `city` column.
formhandler/join:
  pattern: /formhandler/join
  handler: FormHandler
  kwargs:
    url: sqlite:///formhandler.db
    table: sales
    join:
      cities:
        type: left  # LEFT OUTER JOIN
        column:
          sales.city: cities.city
          # NOTE(review): these columns do not exist — presumably included
          # to exercise the invalid-column warning path in _filter_db;
          # confirm this is intentional and not a typo.
          sales.nonexistent: cities.nonexistent
formhandler/dir:
pattern: /formhandler/dir
handler: FormHandler
Expand Down
Binary file modified tests/sales.xlsx
Binary file not shown.
Binary file added tests/sales_join.xlsx
Binary file not shown.
55 changes: 53 additions & 2 deletions tests/test_formhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ def copy_file(source, target):


class TestFormHandler(TestGramex):
sales = gramex.cache.open(os.path.join(folder, 'sales.xlsx'), 'xlsx')
sales = gramex.cache.open(os.path.join(folder, 'sales.xlsx'), sheet_name='sales')
cities = gramex.cache.open(os.path.join(folder, 'sales.xlsx'), sheet_name='cities')

@classmethod
def setUpClass(cls):
dbutils.sqlite_create_db('formhandler.db', sales=cls.sales)
dbutils.sqlite_create_db('formhandler.db', sales=cls.sales, cities=cls.cities)

@classmethod
def tearDownClass(cls):
Expand Down Expand Up @@ -834,6 +835,56 @@ def test_date_comparison(self):
expected.index = actual.index
afe(actual, expected, check_like=True)

def test_join(self):
    """End-to-end checks of /formhandler/join (sales LEFT JOIN cities)."""

    def check(expected, *args, **params):
        # `args` are pre-encoded "key=value" strings appended to the URL;
        # `params` go through the HTTP client's query-string encoding.
        url = '/formhandler/join'
        if args:
            url += f'?{"&".join(args)}'
            # NOTE(review): clears kwargs whenever positional args are
            # given. All current callers pass args XOR params, so this is
            # harmless today — confirm the reset is intentional.
            params = {}

        r = self.get(url, params=params)
        actual = pd.DataFrame(r.json())
        afe(actual, expected.reset_index(drop=True), check_like=True)

    # Build the expected frame with pandas; joined-table columns are
    # prefixed `cities_` to mirror the server's labelling.
    expected = self.sales.merge(self.cities, how='left')
    expected = expected.rename(columns={'demand': 'cities_demand', 'drive': 'cities_drive'})
    check(expected)
    check(expected[expected['city'] == 'Singapore'], city='Singapore')
    # NOTE(review): `sales%33` URL-decodes to `sales3`, not the `sales!`
    # (not-equal) operator — confirm this key is encoded as intended.
    check(expected[expected['sales'] != 500], **{"sales%33": '500'})
    check(expected[expected['sales'] > 500], **{"sales>": '500'})
    check(expected[expected['sales'] >= 500], **{"sales>~": '500'})
    check(expected[expected['sales'] < 500], **{"sales<": '500'})
    check(expected[expected['sales'] <= 500], **{"sales<~": '500'})
    # Filter on a joined-table column plus ascending / descending sort.
    check(
        expected[expected['cities_demand'] > 400].sort_values(by='product'),
        **{"cities_demand>": '400', "_sort": 'product'},
    )
    check(
        expected[expected['cities_demand'] > 400].sort_values(by='product', ascending=False),
        **{"cities_demand>": '400', "_sort": '-product'},
    )
    # Column selection via _c.
    check(
        # FIXME: we should not have to rename the columns, the column name must always be same
        expected[['sales', 'growth', 'cities_drive']].rename(
            columns={'cities_drive': 'drive'}
        ),
        "_c=sales",
        "_c=growth",
        "_c=cities_drive",
    )
    # check(
    #     # FIXME: Test Failing
    #     expected.drop(['sales', 'growth', 'cities_drive'], axis=1),
    #     "_c=-sales",
    #     "_c=-growth",
    #     "_c=-cities_drive",
    # )
    # Presence (`sales`) and absence (`sales!`) filters on NULLs.
    check(expected.dropna(subset=['sales']), "sales")
    check(
        expected[expected['sales'].isna()].applymap(lambda x: None if pd.isnull(x) else x),
        "sales!",
    )

def test_edit_id_type(self):
target = copy_file('sales.xlsx', 'sales-edits.xlsx')
tempfiles[target] = target
Expand Down