diff --git a/rust/lance-index/src/scalar/inverted/wand.rs b/rust/lance-index/src/scalar/inverted/wand.rs
index d5e1631671..9bb176e0d7 100644
--- a/rust/lance-index/src/scalar/inverted/wand.rs
+++ b/rust/lance-index/src/scalar/inverted/wand.rs
@@ -165,7 +165,7 @@ impl Wand {
             let score = self.score(doc, &scorer);
             if self.candidates.len() < limit {
                 self.candidates.push(Reverse(OrderedDoc::new(doc, score)));
-            } else if score > self.threshold {
+            } else if score > self.candidates.peek().unwrap().0.score.0 {
                 self.candidates.pop();
                 self.candidates.push(Reverse(OrderedDoc::new(doc, score)));
                 self.threshold = self.candidates.peek().unwrap().0.score.0 * factor;
diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs
index 84ba4bf528..cbcf878d78 100644
--- a/rust/lance/src/dataset.rs
+++ b/rust/lance/src/dataset.rs
@@ -1680,6 +1680,7 @@ mod tests {
 
     use arrow::array::{as_struct_array, AsArray};
     use arrow::compute::concat_batches;
+    use arrow::datatypes::UInt64Type;
    use arrow_array::{
         builder::StringDictionaryBuilder,
         cast::as_string_array,
@@ -4614,6 +4615,76 @@ mod tests {
         assert_eq!(results.num_rows(), 1);
     }
 
+    #[tokio::test]
+    async fn test_fts_rank() {
+        let tempdir = tempfile::tempdir().unwrap();
+
+        let params = InvertedIndexParams::default();
+        let text_col =
+            GenericStringArray::<i32>::from(vec!["score", "find score", "try to find score"]);
+        let batch = RecordBatch::try_new(
+            arrow_schema::Schema::new(vec![arrow_schema::Field::new(
+                "text",
+                text_col.data_type().to_owned(),
+                false,
+            )])
+            .into(),
+            vec![Arc::new(text_col) as ArrayRef],
+        )
+        .unwrap();
+        let schema = batch.schema();
+        let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema);
+        let mut dataset = Dataset::write(batches, tempdir.path().to_str().unwrap(), None)
+            .await
+            .unwrap();
+        dataset
+            .create_index(&["text"], IndexType::Inverted, None, &params, true)
+            .await
+            .unwrap();
+
+        let results = dataset
+            .scan()
+            .with_row_id()
+            .full_text_search(FullTextSearchQuery::new("score".to_owned()))
+            .unwrap()
+            .limit(Some(3), None)
+            .unwrap()
+            .try_into_batch()
+            .await
+            .unwrap();
+        assert_eq!(results.num_rows(), 3);
+        let row_ids = results[ROW_ID].as_primitive::<UInt64Type>().values();
+        assert_eq!(row_ids, &[0, 1, 2]);
+
+        let results = dataset
+            .scan()
+            .with_row_id()
+            .full_text_search(FullTextSearchQuery::new("score".to_owned()))
+            .unwrap()
+            .limit(Some(2), None)
+            .unwrap()
+            .try_into_batch()
+            .await
+            .unwrap();
+        assert_eq!(results.num_rows(), 2);
+        let row_ids = results[ROW_ID].as_primitive::<UInt64Type>().values();
+        assert_eq!(row_ids, &[0, 1]);
+
+        let results = dataset
+            .scan()
+            .with_row_id()
+            .full_text_search(FullTextSearchQuery::new("score".to_owned()))
+            .unwrap()
+            .limit(Some(1), None)
+            .unwrap()
+            .try_into_batch()
+            .await
+            .unwrap();
+        assert_eq!(results.num_rows(), 1);
+        let row_ids = results[ROW_ID].as_primitive::<UInt64Type>().values();
+        assert_eq!(row_ids, &[0]);
+    }
+
     #[tokio::test]
     async fn concurrent_create() {
         async fn write(uri: &str) -> Result<()> {
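
Note on the wand.rs hunk: `self.threshold` is the current heap minimum scaled by `factor`, so using it as the admission test can, depending on `factor`, either admit hits that do not beat the weakest kept candidate (displacing a better document) or reject hits that do belong in the top results. The patch keeps updating `self.threshold` (presumably it still drives pruning elsewhere in the WAND iteration) but admits a new hit only when it beats `self.candidates.peek()`, the weakest candidate actually held. Below is a minimal, self-contained sketch of that top-k admission pattern, not Lance code: the `Scored` struct, the `top_k` function, and the scores in `main` are hypothetical names and illustrative numbers, and only the standard library is used.

use std::cmp::{Ordering, Reverse};
use std::collections::BinaryHeap;

/// A scored hit. Only `score` participates in the ordering.
#[derive(Debug, Clone, Copy)]
struct Scored {
    doc: u64,
    score: f32,
}

impl PartialEq for Scored {
    fn eq(&self, other: &Self) -> bool {
        self.score.total_cmp(&other.score) == Ordering::Equal
    }
}
impl Eq for Scored {}
impl PartialOrd for Scored {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Scored {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score.total_cmp(&other.score)
    }
}

/// Keep the `limit` highest-scoring hits.
///
/// `Reverse` turns `BinaryHeap`'s max-heap into a min-heap, so `peek()` is the
/// weakest hit currently kept. Once the heap is full, a new hit is admitted
/// only if it beats that weakest hit; testing against anything looser (such as
/// a cached threshold scaled away from the heap minimum) can keep the wrong hits.
fn top_k(hits: impl IntoIterator<Item = Scored>, limit: usize) -> Vec<Scored> {
    if limit == 0 {
        return Vec::new();
    }
    let mut heap: BinaryHeap<Reverse<Scored>> = BinaryHeap::with_capacity(limit);
    for hit in hits {
        if heap.len() < limit {
            heap.push(Reverse(hit));
        } else if hit.score > heap.peek().unwrap().0.score {
            // Heap is full: replace the current minimum with the better hit.
            heap.pop();
            heap.push(Reverse(hit));
        }
    }
    let mut out: Vec<Scored> = heap.into_iter().map(|Reverse(s)| s).collect();
    out.sort_by(|a, b| b.score.total_cmp(&a.score)); // best first
    out
}

fn main() {
    // Illustrative scores only: shorter documents containing "score" rank higher.
    let hits = [
        Scored { doc: 0, score: 1.9 }, // "score"
        Scored { doc: 1, score: 1.4 }, // "find score"
        Scored { doc: 2, score: 1.1 }, // "try to find score"
    ];
    let top2 = top_k(hits, 2);
    assert_eq!(top2.iter().map(|s| s.doc).collect::<Vec<_>>(), vec![0, 1]);
}

With this admission rule the heap always holds the true top `limit` hits seen so far, which is what the new `test_fts_rank` asserts through row ids: `limit(Some(1))` returns row 0, `limit(Some(2))` returns rows 0 and 1, and `limit(Some(3))` returns all three in rank order.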