
Missing column when both nearest and filter are applied (#686)
* failing unit test to repro #685

* fix

* address PR comments
changhiskhan authored Mar 16, 2023
1 parent fa1847b commit 9a3364c
Showing 3 changed files with 79 additions and 59 deletions.
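
For context, a minimal sketch of the scenario this commit fixes, adapted from the new test below (the dataset path and its contents are assumptions; the `to_table` arguments mirror `python/python/tests/test_vector_index.py`): projecting columns while passing both `nearest` and `filter` previously returned a table missing a requested column.

```python
import lance
import numpy as np
import pyarrow.compute as pc

# Hypothetical dataset with a 768-dim "vector" column and a "price" column,
# matching the fixture used by the test suite.
ds = lance.dataset("/tmp/test.lance")
q = np.random.randn(768)

rs = ds.to_table(
    columns=["price"],
    nearest={"column": "vector", "q": q, "k": 10, "nprobes": 1},
    filter=pc.field("price") > 50.0,  # combining filter with nearest triggered #685
)

# Expected after the fix: the projected column plus the vector and score columns.
assert set(rs.column_names) >= {"price", "vector", "score"}
```
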
47 changes: 25 additions & 22 deletions python/python/tests/test_vector_index.py
@@ -30,6 +30,7 @@
import lance
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pytest
from lance.vector import vec_to_table

@@ -70,11 +71,10 @@ def run(ds):
q = np.random.randn(768)
project = [None, ["price"], ["vector"], ["vector", "meta"]]
refine = [None, 1, 2]
# filters = [None, pc.field("price") > 50.0]
filters = [None, pc.field("price") > 50.0]
times = []

for columns in project:

expected_columns = []
if columns is None:
expected_columns.extend(ds.schema.names)
@@ -84,26 +84,29 @@ def run(ds):
if c not in expected_columns:
expected_columns.append(c)

for rf in refine:
# for filter_ in filters:
start = time.time()
rs = ds.to_table(
columns=columns,
# filter=filter_,
nearest={
"column": "vector",
"q": q,
"k": 10,
"nprobes": 1,
"refine_factor": rf,
},
)
end = time.time()
times.append(end - start)
assert rs.column_names == expected_columns
assert len(rs) == 10
scores = rs["score"].to_numpy()
assert (scores.max() - scores.min()) > 1e-6
for filter_ in filters:
for rf in refine:
start = time.time()
rs = ds.to_table(
columns=columns,
nearest={
"column": "vector",
"q": q,
"k": 10,
"nprobes": 1,
"refine_factor": rf,
},
filter=filter_,
)
end = time.time()
times.append(end - start)
assert rs.column_names == expected_columns
if filter_ is not None and "price" in (columns or []):
inmem = pa.dataset.dataset(rs)
assert len(inmem.to_table(filter=filter_)) == len(rs)
assert len(rs) == 10
scores = rs["score"].to_numpy()
assert (scores.max() - scores.min()) > 1e-6
return times


60 changes: 34 additions & 26 deletions rust/src/dataset/scanner.rs
@@ -33,7 +33,7 @@ use sqlparser::{dialect::GenericDialect, parser::Parser};

use super::Dataset;
use crate::datafusion::physical_expr::column_names_in_expr;
use crate::datatypes::Schema;
use crate::datatypes::{Field, Schema};
use crate::format::Index;
use crate::index::vector::{MetricType, Query};
use crate::io::exec::{GlobalTakeExec, KNNFlatExec, KNNIndexExec, LanceScanExec, LocalTakeExec};
@@ -212,28 +212,29 @@ impl Scanner {
self
}

/// The schema of the output, a.k.a, projection schema.
    /// The Arrow schema of the output, including the projected columns plus the vector and score columns.
pub fn schema(&self) -> Result<SchemaRef> {
self.scanner_output_schema()
.map(|s| SchemaRef::new(ArrowSchema::from(s.as_ref())))
}

fn scanner_output_schema(&self) -> Result<Arc<Schema>> {
if self.nearest.as_ref().is_some() {
let q = self.nearest.as_ref().unwrap();
let column: ArrowField = self
.dataset
.schema()
.field(q.column.as_str())
.ok_or_else(|| {
Error::Schema(format!("Vector column {} not found in schema", q.column))
})?
.into();
let score = ArrowField::new("score", Float32, false);
let score_schema = ArrowSchema::new(vec![column, score]);
let vector_search_columns = &Schema::try_from(&score_schema)?;
let merged = self.projections.merge(vector_search_columns);
Ok(SchemaRef::new(ArrowSchema::from(&merged)))
let merged = self.projections.merge(&self.vector_search_schema()?);
Ok(Arc::new(merged))
} else {
Ok(Arc::new(ArrowSchema::from(&self.projections)))
Ok(Arc::new(self.projections.clone()))
}
}

fn vector_search_schema(&self) -> Result<Schema> {
let q = self.nearest.as_ref().unwrap();
let vector_schema = self.dataset.schema().project(&[&q.column])?;
let score = ArrowField::new("score", Float32, false);
let score_schema = Schema::try_from(&ArrowSchema::new(vec![score]))?;
Ok(vector_schema.merge(&score_schema))
}

/// Create a stream of this Scanner.
///
/// TODO: implement as IntoStream/IntoIterator.
@@ -267,29 +268,35 @@ impl Scanner {
}
}

let knn_node = self.ann(q, &index);
let knn_node = self.ann(q, &index); // score, _rowid
let with_vector = self.dataset.schema().project(&[&q.column])?;
let knn_node_with_vector = self.take(knn_node, &with_vector, false);
let knn_node = if q.refine_factor.is_some() {
self.flat_knn(knn_node_with_vector, q)
} else {
knn_node_with_vector
};
}; // vector, score, _rowid

let knn_node = if let Some(filter_expression) = filter_expr {
let columns_in_filter = column_names_in_expr(filter_expression.as_ref());
let columns_refs = columns_in_filter
.iter()
.map(|c| c.as_str())
.collect::<Vec<_>>();
let filter_projection = Arc::new(self.dataset.schema().project(&columns_refs)?);
let filter_projection = self.dataset.schema().project(&columns_refs)?;

let take_node = Arc::new(GlobalTakeExec::new(
self.dataset.clone(),
filter_projection,
Arc::new(filter_projection),
knn_node,
false,
));
self.filter_node(filter_expression, take_node, false)?
self.filter_node(
filter_expression,
take_node,
false,
Some(Arc::new(self.vector_search_schema()?)),
)?
} else {
knn_node
};
@@ -313,7 +320,7 @@
)?,
);
let scan = self.scan(true, filter_schema);
self.filter_node(filter, scan, true)?
self.filter_node(filter, scan, true, None)?
} else {
self.scan(with_row_id, Arc::new(self.projections.clone()))
};
@@ -388,12 +395,15 @@
filter: Arc<dyn PhysicalExpr>,
input: Arc<dyn ExecutionPlan>,
drop_row_id: bool,
ann_schema: Option<Arc<Schema>>,
) -> Result<Arc<dyn ExecutionPlan>> {
let filter_node = Arc::new(FilterExec::try_new(filter, input)?);
let output_schema = self.scanner_output_schema()?;
Ok(Arc::new(LocalTakeExec::new(
filter_node,
self.dataset.clone(),
Arc::new(self.projections.clone()),
output_schema,
ann_schema,
drop_row_id,
)))
}
@@ -518,7 +528,6 @@ mod test {
assert!(scan.filter.is_none());

scan.filter("i > 50").unwrap();
println!("Filter is: {:?}", scan.filter);
assert_eq!(scan.filter, Some("i > 50".to_string()));

let batches = scan
@@ -530,7 +539,6 @@
.try_collect::<Vec<_>>()
.await
.unwrap();
println!("Batches: {:?}\n", batches);
let batch = concat_batches(&batches[0].schema(), &batches).unwrap();

let expected_batch = RecordBatch::try_new(
31 changes: 20 additions & 11 deletions rust/src/io/exec/take.rs
@@ -27,6 +27,7 @@ use datafusion::physical_plan::{
ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, Statistics,
};
use futures::stream::{self, Stream, StreamExt, TryStreamExt};
use futures::{FutureExt, TryFutureExt};
use tokio::sync::mpsc::{self, Receiver};
use tokio::task::JoinHandle;

@@ -236,12 +237,15 @@ impl LocalTake {
input: SendableRecordBatchStream,
dataset: Arc<Dataset>,
schema: Arc<Schema>,
ann_schema: Option<Arc<Schema>>, // TODO add input/output schema contract to exec nodes and remove this
drop_row_id: bool,
) -> Result<Self> {
let (tx, rx) = mpsc::channel(4);

let inner_schema = Schema::try_from(input.schema().as_ref())?;
let take_schema = schema.exclude(&inner_schema)?;
let mut take_schema = schema.exclude(&inner_schema)?;
if ann_schema.is_some() {
take_schema = take_schema.exclude(&ann_schema.unwrap())?;
}
let projection = schema.clone();

let _bg_thread = tokio::spawn(async move {
@@ -252,22 +256,22 @@
.then(|(b, (dataset, take_schema, projection))| async move {
// TODO: need to cache the fragments.
let batch = b?;
if take_schema.fields.is_empty() {
return Ok(batch);
};
let projection_schema = ArrowSchema::from(projection.as_ref());
if batch.num_rows() == 0 {
return Ok(RecordBatch::new_empty(Arc::new(projection_schema)));
}

let row_id_arr = batch.column_by_name(ROW_ID).unwrap();
let row_ids: &UInt64Array = as_primitive_array(row_id_arr);
let remaining_columns =
dataset.take_rows(row_ids.values(), &take_schema).await?;

let batch = batch
.merge(&remaining_columns)?
.project_by_schema(&projection_schema)?;
let batch = if take_schema.fields.is_empty() {
batch.project_by_schema(&projection_schema)?
} else {
let remaining_columns =
dataset.take_rows(row_ids.values(), &take_schema).await?;
batch
.merge(&remaining_columns)?
.project_by_schema(&projection_schema)?
};

if !drop_row_id {
Ok(batch.try_with_column(
@@ -333,6 +337,7 @@ pub struct LocalTakeExec {
dataset: Arc<Dataset>,
input: Arc<dyn ExecutionPlan>,
schema: Arc<Schema>,
ann_schema: Option<Arc<Schema>>,
drop_row_id: bool,
}

@@ -341,13 +346,15 @@ impl LocalTakeExec {
input: Arc<dyn ExecutionPlan>,
dataset: Arc<Dataset>,
schema: Arc<Schema>,
ann_schema: Option<Arc<Schema>>,
drop_row_id: bool,
) -> Self {
assert!(input.schema().column_with_name(ROW_ID).is_some());
Self {
dataset,
input,
schema,
ann_schema,
drop_row_id,
}
}
@@ -387,6 +394,7 @@ impl ExecutionPlan for LocalTakeExec {
input: children[0].clone(),
dataset: self.dataset.clone(),
schema: self.schema.clone(),
ann_schema: self.ann_schema.clone(),
drop_row_id: self.drop_row_id,
}))
}
@@ -401,6 +409,7 @@ impl ExecutionPlan for LocalTakeExec {
input_stream,
self.dataset.clone(),
self.schema.clone(),
self.ann_schema.clone(),
self.drop_row_id,
)?))
}
