Skip to content

Commit

Permalink
[DUCKDB] Add a Derivative macro (#393)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu authored Dec 24, 2022
1 parent 6051b7b commit 3815fea
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
7 changes: 6 additions & 1 deletion integration/duckdb/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Machine Learning functions
| `predict(model, blob)` | Run model inference over image |
| `ml_models()` | Show all ML models |

Currently the Lance duckdb extension is compiled against pytorch 1.13
Currently, the Lance duckdb extension is compiled against pytorch 1.13

```sql
CALL create_pytorch_model('resnet', './resnet.pth', 'cpu')
Expand All @@ -36,6 +36,11 @@ Vector functions
| `l2_distance(list, list)` | Calculate L2 distance between two vectors |
| `in_rectangle(list, list[list])` | Whether the point is in a bounding box |

Misc functions

| Function | Description |
|--------------|--------------------------------------|
| `dydx(y, x)` | Calculate derivative $\frac{dy}{dx}$ |

## Development

Expand Down
12 changes: 12 additions & 0 deletions integration/duckdb/src/lance/duckdb/lance-extension.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "lance-extension.h"

#include <duckdb.hpp>
#include <duckdb/catalog/default/default_functions.hpp>
#include <duckdb/parser/parsed_data/create_table_function_info.hpp>

#include "lance/duckdb/lance_reader.h"
Expand All @@ -28,6 +29,12 @@

namespace duckdb {

static DefaultMacro macros[] = {{DEFAULT_SCHEMA,
"dydx",
{"y", "x", nullptr},
"y - lag(y, 1) OVER (ORDER BY x) / (x - lag(x, 1, 0) OVER (ORDER BY x))"},
{nullptr, nullptr, {nullptr}, nullptr}};

void LanceExtension::Load(::duckdb::DuckDB &db) {
duckdb::Connection con(db);
con.BeginTransaction();
Expand All @@ -43,6 +50,11 @@ void LanceExtension::Load(::duckdb::DuckDB &db) {
catalog.CreateFunction(context, func.get());
}

for (idx_t index = 0; macros[index].name != nullptr; index++) {
auto info = DefaultFunctionGenerator::CreateInternalMacroInfo(macros[index]);
catalog.CreateFunction(*con.context, info.get());
}

#if defined(WITH_PYTORCH)
for (auto &func : lance::duckdb::ml::GetMLFunctions()) {
catalog.CreateFunction(context, func.get());
Expand Down
5 changes: 5 additions & 0 deletions integration/duckdb/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,8 @@ def test_list_argmax(db: duckdb.DuckDBPyConnection):
for dtype in ["INT", "BIGINT", "FLOAT", "DOUBLE"]:
df = db.query(f"""SELECT list_argmax([1, 2, 3, 2, 1]::{dtype}[]) as idx""").to_df()
assert_series_equal(df.idx, pd.Series([2], name='idx', dtype='int32'))

def test_derivative(db: duckdb.DuckDBPyConnection):
tbl = pa.Table.from_pylist([{"x": i * 0.2, "y": i * 1} for i in range(5)])
df = db.query("SELECT dydx(y, x) as d FROM tbl").to_df()
assert_series_equal(df.d, pd.Series([None, 5, 5, 5, 5]))

0 comments on commit 3815fea

Please sign in to comment.