-
Notifications
You must be signed in to change notification settings - Fork 251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Missing column when both `nearest` and `filter` are applied
#686
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,7 @@ use sqlparser::{dialect::GenericDialect, parser::Parser}; | |
|
||
use super::Dataset; | ||
use crate::datafusion::physical_expr::column_names_in_expr; | ||
use crate::datatypes::Schema; | ||
use crate::datatypes::{Field, Schema}; | ||
use crate::format::Index; | ||
use crate::index::vector::{MetricType, Query}; | ||
use crate::io::exec::{GlobalTakeExec, KNNFlatExec, KNNIndexExec, LanceScanExec, LocalTakeExec}; | ||
|
@@ -216,24 +216,37 @@ impl Scanner { | |
pub fn schema(&self) -> Result<SchemaRef> { | ||
if self.nearest.as_ref().is_some() { | ||
let q = self.nearest.as_ref().unwrap(); | ||
let column: ArrowField = self | ||
.dataset | ||
.schema() | ||
.field(q.column.as_str()) | ||
.ok_or_else(|| { | ||
Error::Schema(format!("Vector column {} not found in schema", q.column)) | ||
})? | ||
.into(); | ||
let vector_schema = self.dataset.schema().project(&[&q.column])?; | ||
let score = ArrowField::new("score", Float32, false); | ||
let score_schema = ArrowSchema::new(vec![column, score]); | ||
let vector_search_columns = &Schema::try_from(&score_schema)?; | ||
let merged = self.projections.merge(vector_search_columns); | ||
let score_schema = ArrowSchema::new(vec![score]); | ||
|
||
let merged = self | ||
.projections | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What is the difference between the two? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. => vector_schema is just the vector and score. Neither is a superset of the other. |
||
.merge(&vector_schema) | ||
.merge(&Schema::try_from(&score_schema)?); | ||
Ok(SchemaRef::new(ArrowSchema::from(&merged))) | ||
} else { | ||
Ok(Arc::new(ArrowSchema::from(&self.projections))) | ||
} | ||
} | ||
|
||
fn scanner_output_schema(&self) -> Result<Arc<Schema>> { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should this just be There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. From the API aspect, There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Scanner::schema is the ArrowSchema and should not be used in Lance reader internals, because that's why the field ids are messed up. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'll just make Scanner::schema call the other |
||
if self.nearest.as_ref().is_some() { | ||
let merged = self.projections.merge(&self.vector_search_schema()?); | ||
Ok(Arc::new(merged)) | ||
} else { | ||
Ok(Arc::new(self.projections.clone())) | ||
} | ||
} | ||
|
||
fn vector_search_schema(&self) -> Result<Schema> { | ||
let q = self.nearest.as_ref().unwrap(); | ||
let vector_schema = self.dataset.schema().project(&[&q.column])?; | ||
let score = ArrowField::new("score", Float32, false); | ||
let score_schema = Schema::try_from(&ArrowSchema::new(vec![score]))?; | ||
Ok(vector_schema.merge(&score_schema)) | ||
} | ||
|
||
/// Create a stream of this Scanner. | ||
/// | ||
/// TODO: implement as IntoStream/IntoIterator. | ||
|
@@ -267,29 +280,35 @@ impl Scanner { | |
} | ||
} | ||
|
||
let knn_node = self.ann(q, &index); | ||
let knn_node = self.ann(q, &index); // score, _rowid | ||
let with_vector = self.dataset.schema().project(&[&q.column])?; | ||
let knn_node_with_vector = self.take(knn_node, &with_vector, false); | ||
let knn_node = if q.refine_factor.is_some() { | ||
self.flat_knn(knn_node_with_vector, q) | ||
} else { | ||
knn_node_with_vector | ||
}; | ||
}; // vector, score, _rowid | ||
|
||
let knn_node = if let Some(filter_expression) = filter_expr { | ||
let columns_in_filter = column_names_in_expr(filter_expression.as_ref()); | ||
let columns_refs = columns_in_filter | ||
.iter() | ||
.map(|c| c.as_str()) | ||
.collect::<Vec<_>>(); | ||
let filter_projection = Arc::new(self.dataset.schema().project(&columns_refs)?); | ||
let filter_projection = self.dataset.schema().project(&columns_refs)?; | ||
|
||
let take_node = Arc::new(GlobalTakeExec::new( | ||
self.dataset.clone(), | ||
filter_projection, | ||
Arc::new(filter_projection), | ||
knn_node, | ||
false, | ||
)); | ||
self.filter_node(filter_expression, take_node, false)? | ||
self.filter_node( | ||
filter_expression, | ||
take_node, | ||
false, | ||
Some(Arc::new(self.vector_search_schema()?)), | ||
)? | ||
} else { | ||
knn_node | ||
}; | ||
|
@@ -313,7 +332,7 @@ impl Scanner { | |
)?, | ||
); | ||
let scan = self.scan(true, filter_schema); | ||
self.filter_node(filter, scan, true)? | ||
self.filter_node(filter, scan, true, None)? | ||
} else { | ||
self.scan(with_row_id, Arc::new(self.projections.clone())) | ||
}; | ||
|
@@ -388,12 +407,15 @@ impl Scanner { | |
filter: Arc<dyn PhysicalExpr>, | ||
input: Arc<dyn ExecutionPlan>, | ||
drop_row_id: bool, | ||
ann_schema: Option<Arc<Schema>>, | ||
) -> Result<Arc<dyn ExecutionPlan>> { | ||
let filter_node = Arc::new(FilterExec::try_new(filter, input)?); | ||
let output_schema = self.scanner_output_schema()?; | ||
Ok(Arc::new(LocalTakeExec::new( | ||
filter_node, | ||
self.dataset.clone(), | ||
Arc::new(self.projections.clone()), | ||
output_schema, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is leaking the abstraction of ANN to take? What is in output_schema and ann_schema? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. output_schema includes user-supplied projections. ann_schema includes the vector/score columns. |
||
ann_schema, | ||
drop_row_id, | ||
))) | ||
} | ||
|
@@ -518,7 +540,6 @@ mod test { | |
assert!(scan.filter.is_none()); | ||
|
||
scan.filter("i > 50").unwrap(); | ||
println!("Filter is: {:?}", scan.filter); | ||
assert_eq!(scan.filter, Some("i > 50".to_string())); | ||
|
||
let batches = scan | ||
|
@@ -530,7 +551,6 @@ mod test { | |
.try_collect::<Vec<_>>() | ||
.await | ||
.unwrap(); | ||
println!("Batches: {:?}\n", batches); | ||
let batch = concat_batches(&batches[0].schema(), &batches).unwrap(); | ||
|
||
let expected_batch = RecordBatch::try_new( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ use datafusion::physical_plan::{ | |
ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, Statistics, | ||
}; | ||
use futures::stream::{self, Stream, StreamExt, TryStreamExt}; | ||
use futures::{FutureExt, TryFutureExt}; | ||
use tokio::sync::mpsc::{self, Receiver}; | ||
use tokio::task::JoinHandle; | ||
|
||
|
@@ -236,12 +237,15 @@ impl LocalTake { | |
input: SendableRecordBatchStream, | ||
dataset: Arc<Dataset>, | ||
schema: Arc<Schema>, | ||
ann_schema: Option<Arc<Schema>>, // TODO add input/output schema contract to exec nodes and remove this | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. File a ticket to remove this leaking abstraction? |
||
drop_row_id: bool, | ||
) -> Result<Self> { | ||
let (tx, rx) = mpsc::channel(4); | ||
|
||
let inner_schema = Schema::try_from(input.schema().as_ref())?; | ||
let take_schema = schema.exclude(&inner_schema)?; | ||
let mut take_schema = schema.exclude(&inner_schema)?; | ||
if ann_schema.is_some() { | ||
take_schema = take_schema.exclude(&ann_schema.unwrap())?; | ||
} | ||
let projection = schema.clone(); | ||
|
||
let _bg_thread = tokio::spawn(async move { | ||
|
@@ -252,22 +256,22 @@ impl LocalTake { | |
.then(|(b, (dataset, take_schema, projection))| async move { | ||
// TODO: need to cache the fragments. | ||
let batch = b?; | ||
if take_schema.fields.is_empty() { | ||
return Ok(batch); | ||
}; | ||
let projection_schema = ArrowSchema::from(projection.as_ref()); | ||
if batch.num_rows() == 0 { | ||
return Ok(RecordBatch::new_empty(Arc::new(projection_schema))); | ||
} | ||
|
||
let row_id_arr = batch.column_by_name(ROW_ID).unwrap(); | ||
let row_ids: &UInt64Array = as_primitive_array(row_id_arr); | ||
let remaining_columns = | ||
dataset.take_rows(row_ids.values(), &take_schema).await?; | ||
|
||
let batch = batch | ||
.merge(&remaining_columns)? | ||
.project_by_schema(&projection_schema)?; | ||
let batch = if take_schema.fields.is_empty() { | ||
batch.project_by_schema(&projection_schema)? | ||
} else { | ||
let remaining_columns = | ||
dataset.take_rows(row_ids.values(), &take_schema).await?; | ||
batch | ||
.merge(&remaining_columns)? | ||
.project_by_schema(&projection_schema)? | ||
}; | ||
|
||
if !drop_row_id { | ||
Ok(batch.try_with_column( | ||
|
@@ -333,6 +337,7 @@ pub struct LocalTakeExec { | |
dataset: Arc<Dataset>, | ||
input: Arc<dyn ExecutionPlan>, | ||
schema: Arc<Schema>, | ||
ann_schema: Option<Arc<Schema>>, | ||
drop_row_id: bool, | ||
} | ||
|
||
|
@@ -341,13 +346,15 @@ impl LocalTakeExec { | |
input: Arc<dyn ExecutionPlan>, | ||
dataset: Arc<Dataset>, | ||
schema: Arc<Schema>, | ||
ann_schema: Option<Arc<Schema>>, | ||
drop_row_id: bool, | ||
) -> Self { | ||
assert!(input.schema().column_with_name(ROW_ID).is_some()); | ||
Self { | ||
dataset, | ||
input, | ||
schema, | ||
ann_schema, | ||
drop_row_id, | ||
} | ||
} | ||
|
@@ -387,6 +394,7 @@ impl ExecutionPlan for LocalTakeExec { | |
input: children[0].clone(), | ||
dataset: self.dataset.clone(), | ||
schema: self.schema.clone(), | ||
ann_schema: self.ann_schema.clone(), | ||
drop_row_id: self.drop_row_id, | ||
})) | ||
} | ||
|
@@ -401,6 +409,7 @@ impl ExecutionPlan for LocalTakeExec { | |
input_stream, | ||
self.dataset.clone(), | ||
self.schema.clone(), | ||
self.ann_schema.clone(), | ||
self.drop_row_id, | ||
)?)) | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Use `self.vector_search_schema()`?