Skip to content

Commit

Permalink
Add python tests for duckdb extension, and bind function with multipl…
Browse files Browse the repository at this point in the history
…e input types (#243)
  • Loading branch information
eddyxu authored Oct 18, 2022
1 parent 1e7386e commit ab3c694
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 48 deletions.
44 changes: 30 additions & 14 deletions integration/duckdb/src/lance/duckdb/list_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,42 @@ void ListArgMax(::duckdb::DataChunk &args,
}
}

::duckdb::ScalarFunction ListArgMaxOp(const ::duckdb::LogicalType &type) {
switch (type.InternalType()) {
case ::duckdb::PhysicalType::INT32:
return ::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::INTEGER)},
::duckdb::LogicalType::INTEGER,
ListArgMax<int>);
default:
return ::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::FLOAT)},
::duckdb::LogicalType::INTEGER,
ListArgMax<float>);
}
}

std::unique_ptr<::duckdb::FunctionData> ListArgMaxBind(
::duckdb::ClientContext &context,
::duckdb::ScalarFunction &function,
std::vector<std::unique_ptr<::duckdb::Expression>> &arguments) {
auto input_type = arguments[0]->return_type;
auto name = std::move(function.name);
function = ListArgMaxOp(input_type);
function.name = std::move(name);
if (function.bind) {
return function.bind(context, function, arguments);
}
return nullptr;
}

std::vector<std::unique_ptr<::duckdb::CreateFunctionInfo>> GetListFunctions() {
std::vector<std::unique_ptr<::duckdb::CreateFunctionInfo>> functions;

::duckdb::ScalarFunctionSet list_argmax("list_argmax");
list_argmax.AddFunction(
::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::BIGINT)},
::duckdb::LogicalType::INTEGER,
ListArgMax<int64_t>));
list_argmax.AddFunction(
::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::INTEGER)},
::duckdb::LogicalType::INTEGER,
ListArgMax<int>));
list_argmax.AddFunction(
::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::FLOAT)},
::duckdb::LogicalType::INTEGER,
ListArgMax<float>));
list_argmax.AddFunction(
::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::DOUBLE)},
::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::ANY)},
::duckdb::LogicalType::INTEGER,
ListArgMax<double>));
nullptr,
ListArgMaxBind));
functions.emplace_back(std::make_unique<::duckdb::CreateScalarFunctionInfo>(list_argmax));

return functions;
Expand Down
84 changes: 50 additions & 34 deletions integration/duckdb/src/lance/duckdb/vector_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <cstdint>
#include <duckdb/parser/parsed_data/create_scalar_function_info.hpp>
#include <iostream>
#include <memory>

namespace lance::duckdb {
Expand All @@ -40,6 +41,49 @@ void L2Distance(::duckdb::DataChunk &args,
}
}

::duckdb::ScalarFunction L2DistanceOp(const ::duckdb::LogicalType &type) {
if (type.InternalType() != ::duckdb::PhysicalType::LIST) {
throw ::duckdb::BinderException("l2_distance expects list type, got: ", type.ToString());
}
auto child_type = ::duckdb::ListType::GetChildType(type);
switch (child_type.InternalType()) {
case ::duckdb::PhysicalType::INT32:
return ::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::INTEGER),
::duckdb::LogicalType::LIST(::duckdb::LogicalType::INTEGER)},
::duckdb::LogicalType::INTEGER,
L2Distance<int>);
case ::duckdb::PhysicalType::INT64:
return ::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::BIGINT),
::duckdb::LogicalType::LIST(::duckdb::LogicalType::BIGINT)},
::duckdb::LogicalType::BIGINT,
L2Distance<int64_t>);
case ::duckdb::PhysicalType::FLOAT:
return ::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::FLOAT),
::duckdb::LogicalType::LIST(::duckdb::LogicalType::FLOAT)},
::duckdb::LogicalType::FLOAT,
L2Distance<float>);
default:
return ::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::DOUBLE),
::duckdb::LogicalType::LIST(::duckdb::LogicalType::DOUBLE)},
::duckdb::LogicalType::DOUBLE,
L2Distance<double>);
}
}

std::unique_ptr<::duckdb::FunctionData> L2DistanceBind(
::duckdb::ClientContext &context,
::duckdb::ScalarFunction &function,
std::vector<std::unique_ptr<::duckdb::Expression>> &arguments) {
auto input_type = arguments[0]->return_type;
auto name = std::move(function.name);
function = L2DistanceOp(input_type);
function.name = std::move(name);
if (function.bind) {
return function.bind(context, function, arguments);
}
return nullptr;
}

void IsInRectangle(::duckdb::DataChunk &args,
::duckdb::ExpressionState &state,
::duckdb::Vector &result) {
Expand All @@ -64,43 +108,15 @@ std::vector<std::unique_ptr<::duckdb::CreateFunctionInfo>> GetVectorFunctions()

::duckdb::ScalarFunctionSet l2_distance("l2_distance");
l2_distance.AddFunction(
::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::INTEGER),
::duckdb::LogicalType::LIST(::duckdb::LogicalType::INTEGER)},
::duckdb::LogicalType::INTEGER,
L2Distance<int>));
l2_distance.AddFunction(
::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::BIGINT),
::duckdb::LogicalType::LIST(::duckdb::LogicalType::BIGINT)},
::duckdb::LogicalType::BIGINT,
L2Distance<int64_t>));
l2_distance.AddFunction(
::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::FLOAT),
::duckdb::LogicalType::LIST(::duckdb::LogicalType::FLOAT)},
::duckdb::LogicalType::FLOAT,
L2Distance<float>));
l2_distance.AddFunction(
::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::DOUBLE),
::duckdb::LogicalType::LIST(::duckdb::LogicalType::DOUBLE)},
::duckdb::LogicalType::DOUBLE,
L2Distance<double>));
::duckdb::ScalarFunction({::duckdb::LogicalType::LIST(::duckdb::LogicalType::ANY),
::duckdb::LogicalType::LIST(::duckdb::LogicalType::ANY)},
::duckdb::LogicalType::ANY,
nullptr,
L2DistanceBind));
functions.emplace_back(std::make_unique<::duckdb::CreateScalarFunctionInfo>(l2_distance));

::duckdb::ScalarFunctionSet in_rectangle("in_rectangle");
in_rectangle.AddFunction(::duckdb::ScalarFunction(
{::duckdb::LogicalType::LIST(::duckdb::LogicalType::INTEGER),
::duckdb::LogicalType::LIST(::duckdb::LogicalType::LIST(::duckdb::LogicalType::INTEGER))},
::duckdb::LogicalType::BOOLEAN,
IsInRectangle));
in_rectangle.AddFunction(::duckdb::ScalarFunction(
{::duckdb::LogicalType::LIST(::duckdb::LogicalType::BIGINT),
::duckdb::LogicalType::LIST(::duckdb::LogicalType::LIST(::duckdb::LogicalType::BIGINT))},
::duckdb::LogicalType::BOOLEAN,
IsInRectangle));
in_rectangle.AddFunction(::duckdb::ScalarFunction(
{::duckdb::LogicalType::LIST(::duckdb::LogicalType::FLOAT),
::duckdb::LogicalType::LIST(::duckdb::LogicalType::LIST(::duckdb::LogicalType::FLOAT))},
::duckdb::LogicalType::BOOLEAN,
IsInRectangle));
/// All upcast to double
in_rectangle.AddFunction(::duckdb::ScalarFunction(
{::duckdb::LogicalType::LIST(::duckdb::LogicalType::DOUBLE),
::duckdb::LogicalType::LIST(::duckdb::LogicalType::LIST(::duckdb::LogicalType::DOUBLE))},
Expand Down
17 changes: 17 additions & 0 deletions integration/duckdb/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/usr/bin/env python3

from pathlib import Path

import duckdb
import pytest


@pytest.fixture
def db() -> duckdb.DuckDBPyConnection:
"""Initialize duckdb with lance extension"""
db = duckdb.connect(config={"allow_unsigned_extensions": True})

cur_path = Path(__file__).parent
db.install_extension(str(cur_path.parent / "manylinux-build" / "lance.duckdb_extension"), force_install=True)
db.load_extension("lance")
return db
60 changes: 60 additions & 0 deletions integration/duckdb/tests/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env python3

from pathlib import Path

import pandas as pd
import torch
import pyarrow as pa

from PIL import Image
from duckdb import DuckDBPyConnection
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights


def test_resnet(db: DuckDBPyConnection, tmp_path: Path):
resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
m = torch.jit.script(resnet)
model_path = tmp_path / "resnet.pth"
torch.jit.save(m, str(model_path))

db.execute(f"CALL create_pytorch_model('resnet', '{str(model_path)}');")

expected_models = pd.DataFrame([{
"name": "resnet",
"uri": str(model_path),
"type": "torchscript",
}])
pd.testing.assert_frame_equal(db.query("SELECT * FROM ml_models()").to_df(), expected_models)

cat_path = Path(__file__).parent / "testdata" / "cat.jpg"
cat = cat_path.read_bytes()
tbl = pa.Table.from_pylist([{"img": cat}])

df = db.query("SELECT predict('resnet', img) as prob FROM tbl").to_df()
actual_prob = torch.tensor(df["prob"].iloc[0])
actual_class = torch.argmax(actual_prob)

resnet.eval()

preprocess = transforms.Compose([
# transforms.Resize(256),
# transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
image: Image = Image.open(cat_path)
image_tensor = preprocess(image)
batch: torch.Tensor = image_tensor.unsqueeze(0)
with torch.no_grad():
output = resnet(batch)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
assert torch.equal(actual_class, torch.argmax(probabilities))
assert torch.allclose(actual_prob, probabilities)

argmax = (db.query("SELECT list_argmax(predict('resnet', img)) as pred FROM tbl")
.to_df().pred)
assert (argmax.values == torch.argmax(probabilities).numpy()).all()

db.execute("CALL drop_model('resnet')")
assert db.query("SELECT * FROM ml_models()").to_df().size == 0
50 changes: 50 additions & 0 deletions integration/duckdb/tests/test_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/usr/bin/env python3
#

import duckdb
import numpy as np
import pandas as pd
import pyarrow as pa
from pandas.testing import assert_series_equal


def test_l2_distance(db: duckdb.DuckDBPyConnection):
"""GH-7"""
embeddings = np.random.randn(10, 10)
tbl = pa.Table.from_arrays([embeddings.tolist()], names=["embedding"])

df = db.query("""
SELECT l2_distance(embedding,
[0.14132072948046223,-0.8578304618530145,1.03418279173152,-0.01988450184766287,0.20275013403601405,
1.1907599349042708, 1.592254025308326, -0.5606235353210591, -1.353943627981242, 0.10803636704591536]) AS score
FROM tbl
""").to_df()

v1 = np.array(
[0.14132072948046223, -0.8578304618530145, 1.03418279173152, -0.01988450184766287, 0.20275013403601405,
1.1907599349042708, 1.592254025308326, -0.5606235353210591, -1.353943627981242, 0.10803636704591536])
expected = pd.Series(((embeddings - v1) ** 2).sum(axis=1))
assert np.allclose(df.score.to_numpy(), expected)

df = db.query("""SELECT l2_distance([1, 2], [1, 2]) as score""").to_df()
assert_series_equal(df.score, pd.Series([0], name="score", dtype='int32'))


def test_in_rectangle(db: duckdb.DuckDBPyConnection):
tbl = pa.Table.from_pylist([{"box": [[1, 2], [3, 4]]}, {"box": [[10, 20], [30, 45]]}])
df = db.query("""SELECT in_rectangle([15, 35], box) AS contain FROM tbl""").to_df()
assert_series_equal(df.contain, pd.Series([False, True]), check_names=False)

tbl = pa.Table.from_pylist([{"point": [1, 2]}, {"point": [10, 20]}])
df = db.query("""SELECT in_rectangle(point, [[5, 10], [30, 40]]) AS contain FROM tbl""").to_df()
assert_series_equal(df.contain, pd.Series([False, True]), check_names=False)

df = db.query(
"""SELECT in_rectangle([15.0, 35.5], [[5.5, 10.1], [30.3, 40.4]]) as contain""").to_df()
assert_series_equal(df.contain, pd.Series([True]), check_names=False)


def test_list_argmax(db: duckdb.DuckDBPyConnection):
for dtype in ["INT", "BIGINT", "FLOAT", "DOUBLE"]:
df = db.query(f"""SELECT list_argmax([1, 2, 3, 2, 1]::{dtype}[]) as idx""").to_df()
assert_series_equal(df.idx, pd.Series([2], name='idx', dtype='int32'))
Binary file added integration/duckdb/tests/testdata/cat.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit ab3c694

Please sign in to comment.