Missing column when both nearest and filter are applied #686

Merged (3 commits) on Mar 16, 2023
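The fix is user-visible from the Python API: a query that combines nearest with a filter and an explicit columns projection could previously come back missing a projected column. A minimal sketch of the pattern the updated test exercises (the dataset URI is a placeholder, and ds is assumed to have price and vector columns as in the test below):

import lance
import numpy as np
import pyarrow.compute as pc

ds = lance.dataset("vectors.lance")  # placeholder URI
q = np.random.randn(768)

rs = ds.to_table(
    columns=["price"],
    filter=pc.field("price") > 50.0,
    nearest={"column": "vector", "q": q, "k": 10, "nprobes": 1},
)
# Before this change, combining filter with nearest could drop the projected
# column from the result.
assert "price" in rs.column_names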
47 changes: 25 additions & 22 deletions python/python/tests/test_vector_index.py
@@ -30,6 +30,7 @@
import lance
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pytest
from lance.vector import vec_to_table

@@ -70,11 +71,10 @@ def run(ds):
q = np.random.randn(768)
project = [None, ["price"], ["vector"], ["vector", "meta"]]
refine = [None, 1, 2]
# filters = [None, pc.field("price") > 50.0]
filters = [None, pc.field("price") > 50.0]
times = []

for columns in project:

expected_columns = []
if columns is None:
expected_columns.extend(ds.schema.names)
@@ -84,26 +84,29 @@
if c not in expected_columns:
expected_columns.append(c)

for rf in refine:
# for filter_ in filters:
start = time.time()
rs = ds.to_table(
columns=columns,
# filter=filter_,
nearest={
"column": "vector",
"q": q,
"k": 10,
"nprobes": 1,
"refine_factor": rf,
},
)
end = time.time()
times.append(end - start)
assert rs.column_names == expected_columns
assert len(rs) == 10
scores = rs["score"].to_numpy()
assert (scores.max() - scores.min()) > 1e-6
for filter_ in filters:
for rf in refine:
start = time.time()
rs = ds.to_table(
columns=columns,
nearest={
"column": "vector",
"q": q,
"k": 10,
"nprobes": 1,
"refine_factor": rf,
},
filter=filter_,
)
end = time.time()
times.append(end - start)
assert rs.column_names == expected_columns
if filter_ is not None and "price" in (columns or []):
inmem = pa.dataset.dataset(rs)
assert len(inmem.to_table(filter=filter_)) == len(rs)
assert len(rs) == 10
scores = rs["score"].to_numpy()
assert (scores.max() - scores.min()) > 1e-6
return times


60 changes: 34 additions & 26 deletions rust/src/dataset/scanner.rs
@@ -33,7 +33,7 @@ use sqlparser::{dialect::GenericDialect, parser::Parser};

use super::Dataset;
use crate::datafusion::physical_expr::column_names_in_expr;
use crate::datatypes::Schema;
use crate::datatypes::{Field, Schema};
use crate::format::Index;
use crate::index::vector::{MetricType, Query};
use crate::io::exec::{GlobalTakeExec, KNNFlatExec, KNNIndexExec, LanceScanExec, LocalTakeExec};
@@ -212,28 +212,29 @@ impl Scanner {
self
}

/// The schema of the output, a.k.a, projection schema.
/// The Arrow schema of the output, including projections and vector / score
pub fn schema(&self) -> Result<SchemaRef> {
self.scanner_output_schema()
.map(|s| SchemaRef::new(ArrowSchema::from(s.as_ref())))
}

fn scanner_output_schema(&self) -> Result<Arc<Schema>> {
Contributor: Should this just be Scanner::schema?

Contributor: From the API perspective, Scanner::schema should just return the schema of the output, so that the contract can be established with other system components.

Contributor (Author): Scanner::schema is the ArrowSchema and should not be used in Lance reader internals; using it there is what got the field ids messed up.

Contributor (Author): I'll just make Scanner::schema call the other.

if self.nearest.as_ref().is_some() {
let q = self.nearest.as_ref().unwrap();
let column: ArrowField = self
.dataset
.schema()
.field(q.column.as_str())
.ok_or_else(|| {
Error::Schema(format!("Vector column {} not found in schema", q.column))
})?
.into();
let score = ArrowField::new("score", Float32, false);
let score_schema = ArrowSchema::new(vec![column, score]);
let vector_search_columns = &Schema::try_from(&score_schema)?;
let merged = self.projections.merge(vector_search_columns);
Ok(SchemaRef::new(ArrowSchema::from(&merged)))
let merged = self.projections.merge(&self.vector_search_schema()?);
Ok(Arc::new(merged))
} else {
Ok(Arc::new(ArrowSchema::from(&self.projections)))
Ok(Arc::new(self.projections.clone()))
}
}

fn vector_search_schema(&self) -> Result<Schema> {
let q = self.nearest.as_ref().unwrap();
let vector_schema = self.dataset.schema().project(&[&q.column])?;
let score = ArrowField::new("score", Float32, false);
let score_schema = Schema::try_from(&ArrowSchema::new(vec![score]))?;
Ok(vector_schema.merge(&score_schema))
}

/// Create a stream of this Scanner.
///
/// TODO: implement as IntoStream/IntoIterator.
@@ -267,29 +268,35 @@ impl Scanner {
}
}

let knn_node = self.ann(q, &index);
let knn_node = self.ann(q, &index); // score, _rowid
let with_vector = self.dataset.schema().project(&[&q.column])?;
let knn_node_with_vector = self.take(knn_node, &with_vector, false);
let knn_node = if q.refine_factor.is_some() {
self.flat_knn(knn_node_with_vector, q)
} else {
knn_node_with_vector
};
}; // vector, score, _rowid

let knn_node = if let Some(filter_expression) = filter_expr {
let columns_in_filter = column_names_in_expr(filter_expression.as_ref());
let columns_refs = columns_in_filter
.iter()
.map(|c| c.as_str())
.collect::<Vec<_>>();
let filter_projection = Arc::new(self.dataset.schema().project(&columns_refs)?);
let filter_projection = self.dataset.schema().project(&columns_refs)?;

let take_node = Arc::new(GlobalTakeExec::new(
self.dataset.clone(),
filter_projection,
Arc::new(filter_projection),
knn_node,
false,
));
self.filter_node(filter_expression, take_node, false)?
self.filter_node(
filter_expression,
take_node,
false,
Some(Arc::new(self.vector_search_schema()?)),
)?
} else {
knn_node
};
@@ -313,7 +320,7 @@
)?,
);
let scan = self.scan(true, filter_schema);
self.filter_node(filter, scan, true)?
self.filter_node(filter, scan, true, None)?
} else {
self.scan(with_row_id, Arc::new(self.projections.clone()))
};
@@ -388,12 +395,15 @@ impl Scanner {
filter: Arc<dyn PhysicalExpr>,
input: Arc<dyn ExecutionPlan>,
drop_row_id: bool,
ann_schema: Option<Arc<Schema>>,
) -> Result<Arc<dyn ExecutionPlan>> {
let filter_node = Arc::new(FilterExec::try_new(filter, input)?);
let output_schema = self.scanner_output_schema()?;
Ok(Arc::new(LocalTakeExec::new(
filter_node,
self.dataset.clone(),
Arc::new(self.projections.clone()),
output_schema,
Contributor: This is leaking the abstraction of ANN into take? What is in output_schema and ann_schema, though? Can we just use one?

Contributor (Author): output_schema includes the user-supplied projections; ann_schema includes the vector/score columns.

ann_schema,
drop_row_id,
)))
}
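To make the exchange above concrete: output_schema is what the caller asked for via columns, while ann_schema carries the vector and score columns produced by the vector search. A hedged sketch of the user-visible effect, reusing ds, q, and pc from the sketch near the top of this page:

# The user projection ("price") comes from output_schema; the ANN stage
# contributes "vector" and "score" via ann_schema, matching the
# expected_columns logic in the Python test above.
expected_columns = ["price", "vector", "score"]
rs = ds.to_table(
    columns=["price"],
    filter=pc.field("price") > 50.0,
    nearest={"column": "vector", "q": q, "k": 10, "nprobes": 1},
)
assert rs.column_names == expected_columns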
@@ -518,7 +528,6 @@ mod test {
assert!(scan.filter.is_none());

scan.filter("i > 50").unwrap();
println!("Filter is: {:?}", scan.filter);
assert_eq!(scan.filter, Some("i > 50".to_string()));

let batches = scan
@@ -530,7 +539,6 @@
.try_collect::<Vec<_>>()
.await
.unwrap();
println!("Batches: {:?}\n", batches);
let batch = concat_batches(&batches[0].schema(), &batches).unwrap();

let expected_batch = RecordBatch::try_new(
31 changes: 20 additions & 11 deletions rust/src/io/exec/take.rs
@@ -27,6 +27,7 @@ use datafusion::physical_plan::{
ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, Statistics,
};
use futures::stream::{self, Stream, StreamExt, TryStreamExt};
use futures::{FutureExt, TryFutureExt};
use tokio::sync::mpsc::{self, Receiver};
use tokio::task::JoinHandle;

@@ -236,12 +237,15 @@ impl LocalTake {
input: SendableRecordBatchStream,
dataset: Arc<Dataset>,
schema: Arc<Schema>,
ann_schema: Option<Arc<Schema>>, // TODO add input/output schema contract to exec nodes and remove this
Contributor: File a ticket to remove this leaking abstraction?

drop_row_id: bool,
) -> Result<Self> {
let (tx, rx) = mpsc::channel(4);

let inner_schema = Schema::try_from(input.schema().as_ref())?;
let take_schema = schema.exclude(&inner_schema)?;
let mut take_schema = schema.exclude(&inner_schema)?;
if ann_schema.is_some() {
take_schema = take_schema.exclude(&ann_schema.unwrap())?;
}
let projection = schema.clone();

let _bg_thread = tokio::spawn(async move {
@@ -252,22 +256,22 @@
.then(|(b, (dataset, take_schema, projection))| async move {
// TODO: need to cache the fragments.
let batch = b?;
if take_schema.fields.is_empty() {
return Ok(batch);
};
let projection_schema = ArrowSchema::from(projection.as_ref());
if batch.num_rows() == 0 {
return Ok(RecordBatch::new_empty(Arc::new(projection_schema)));
}

let row_id_arr = batch.column_by_name(ROW_ID).unwrap();
let row_ids: &UInt64Array = as_primitive_array(row_id_arr);
let remaining_columns =
dataset.take_rows(row_ids.values(), &take_schema).await?;

let batch = batch
.merge(&remaining_columns)?
.project_by_schema(&projection_schema)?;
let batch = if take_schema.fields.is_empty() {
batch.project_by_schema(&projection_schema)?
} else {
let remaining_columns =
dataset.take_rows(row_ids.values(), &take_schema).await?;
batch
.merge(&remaining_columns)?
.project_by_schema(&projection_schema)?
};

if !drop_row_id {
Ok(batch.try_with_column(
@@ -333,6 +337,7 @@ pub struct LocalTakeExec {
dataset: Arc<Dataset>,
input: Arc<dyn ExecutionPlan>,
schema: Arc<Schema>,
ann_schema: Option<Arc<Schema>>,
drop_row_id: bool,
}

@@ -341,13 +346,15 @@ impl LocalTakeExec {
input: Arc<dyn ExecutionPlan>,
dataset: Arc<Dataset>,
schema: Arc<Schema>,
ann_schema: Option<Arc<Schema>>,
drop_row_id: bool,
) -> Self {
assert!(input.schema().column_with_name(ROW_ID).is_some());
Self {
dataset,
input,
schema,
ann_schema,
drop_row_id,
}
}
@@ -387,6 +394,7 @@ impl ExecutionPlan for LocalTakeExec {
input: children[0].clone(),
dataset: self.dataset.clone(),
schema: self.schema.clone(),
ann_schema: self.ann_schema.clone(),
drop_row_id: self.drop_row_id,
}))
}
@@ -401,6 +409,7 @@
input_stream,
self.dataset.clone(),
self.schema.clone(),
self.ann_schema.clone(),
self.drop_row_id,
)?))
}