From 3815feac640c2c06d02b0ce48bd8e3ec03d559ea Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 23 Dec 2022 16:55:37 -0800 Subject: [PATCH] [DUCKDB] Add a Derivative macro (#393) --- integration/duckdb/README.md | 7 ++++++- .../duckdb/src/lance/duckdb/lance-extension.cc | 12 ++++++++++++ integration/duckdb/tests/test_query.py | 5 +++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/integration/duckdb/README.md b/integration/duckdb/README.md index aff888025a..801106ed06 100644 --- a/integration/duckdb/README.md +++ b/integration/duckdb/README.md @@ -20,7 +20,7 @@ Machine Learning functions | `predict(model, blob)` | Run model inference over image | | `ml_models()` | Show all ML models | -Currently the Lance duckdb extension is compiled against pytorch 1.13 +Currently, the Lance duckdb extension is compiled against pytorch 1.13 ```sql CALL create_pytorch_model('resnet', './resnet.pth', 'cpu') @@ -36,6 +36,11 @@ Vector functions | `l2_distance(list, list)` | Calculate L2 distance between two vectors | | `in_rectangle(list, list[list])` | Whether the point is in a bounding box | +Misc functions + +| Function | Description | +|--------------|--------------------------------------| +| `dydx(y, x)` | Calculate derivative $\frac{dy}{dx}$ | ## Development diff --git a/integration/duckdb/src/lance/duckdb/lance-extension.cc b/integration/duckdb/src/lance/duckdb/lance-extension.cc index c3b4f0dadf..1c6f61c8f9 100644 --- a/integration/duckdb/src/lance/duckdb/lance-extension.cc +++ b/integration/duckdb/src/lance/duckdb/lance-extension.cc @@ -17,6 +17,7 @@ #include "lance-extension.h" #include +#include #include #include "lance/duckdb/lance_reader.h" @@ -28,6 +29,12 @@ namespace duckdb { +static DefaultMacro macros[] = {{DEFAULT_SCHEMA, + "dydx", + {"y", "x", nullptr}, + "y - lag(y, 1) OVER (ORDER BY x) / (x - lag(x, 1, 0) OVER (ORDER BY x))"}, + {nullptr, nullptr, {nullptr}, nullptr}}; + void LanceExtension::Load(::duckdb::DuckDB &db) { duckdb::Connection con(db); con.BeginTransaction(); @@ -43,6 +50,11 @@ void LanceExtension::Load(::duckdb::DuckDB &db) { catalog.CreateFunction(context, func.get()); } + for (idx_t index = 0; macros[index].name != nullptr; index++) { + auto info = DefaultFunctionGenerator::CreateInternalMacroInfo(macros[index]); + catalog.CreateFunction(*con.context, info.get()); + } + #if defined(WITH_PYTORCH) for (auto &func : lance::duckdb::ml::GetMLFunctions()) { catalog.CreateFunction(context, func.get()); diff --git a/integration/duckdb/tests/test_query.py b/integration/duckdb/tests/test_query.py index 99f477387f..1feb153b04 100644 --- a/integration/duckdb/tests/test_query.py +++ b/integration/duckdb/tests/test_query.py @@ -48,3 +48,8 @@ def test_list_argmax(db: duckdb.DuckDBPyConnection): for dtype in ["INT", "BIGINT", "FLOAT", "DOUBLE"]: df = db.query(f"""SELECT list_argmax([1, 2, 3, 2, 1]::{dtype}[]) as idx""").to_df() assert_series_equal(df.idx, pd.Series([2], name='idx', dtype='int32')) + +def test_derivative(db: duckdb.DuckDBPyConnection): + tbl = pa.Table.from_pylist([{"x": i * 0.2, "y": i * 1} for i in range(5)]) + df = db.query("SELECT dydx(y, x) as d FROM tbl").to_df() + assert_series_equal(df.d, pd.Series([None, 5, 5, 5, 5]))