From 7413344dedcefd6bf2251ef5d5de85771f615aeb Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Tue, 22 Oct 2024 10:57:20 +0800 Subject: [PATCH] feat: support FTS on multiple fields (#3025) --- rust/lance/src/dataset.rs | 71 ++++++++++++++++++++++++++++++- rust/lance/src/dataset/scanner.rs | 44 +++++++++---------- rust/lance/src/io/exec/fts.rs | 10 +++-- 3 files changed, 96 insertions(+), 29 deletions(-) diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 13f1c141bb..799576e66a 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -1692,7 +1692,10 @@ mod tests { ArrayRef, DictionaryArray, Float32Array, Int32Array, Int64Array, Int8Array, Int8DictionaryArray, RecordBatchIterator, StringArray, UInt16Array, UInt32Array, }; - use arrow_array::{FixedSizeListArray, Int16Array, Int16DictionaryArray, StructArray}; + use arrow_array::{ + Array, FixedSizeListArray, GenericStringArray, Int16Array, Int16DictionaryArray, + StructArray, + }; use arrow_ord::sort::sort_to_indices; use arrow_schema::{ DataType, Field as ArrowField, Fields as ArrowFields, Schema as ArrowSchema, @@ -4325,6 +4328,72 @@ mod tests { ); } + #[tokio::test] + async fn test_fts_on_multiple_columns() { + let tempdir = tempfile::tempdir().unwrap(); + + let params = InvertedIndexParams::default(); + let title_col = + GenericStringArray::::from(vec!["title hello", "title lance", "title common"]); + let content_col = GenericStringArray::::from(vec![ + "content world", + "content database", + "content common", + ]); + let batch = RecordBatch::try_new( + arrow_schema::Schema::new(vec![ + arrow_schema::Field::new("title", title_col.data_type().to_owned(), false), + arrow_schema::Field::new("content", title_col.data_type().to_owned(), false), + ]) + .into(), + vec![ + Arc::new(title_col) as ArrayRef, + Arc::new(content_col) as ArrayRef, + ], + ) + .unwrap(); + let schema = batch.schema(); + let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); + let mut dataset = Dataset::write(batches, tempdir.path().to_str().unwrap(), None) + .await + .unwrap(); + dataset + .create_index(&["title"], IndexType::Inverted, None, ¶ms, true) + .await + .unwrap(); + dataset + .create_index(&["content"], IndexType::Inverted, None, ¶ms, true) + .await + .unwrap(); + + let results = dataset + .scan() + .full_text_search(FullTextSearchQuery::new("title".to_owned())) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(results.num_rows(), 3); + + let results = dataset + .scan() + .full_text_search(FullTextSearchQuery::new("content".to_owned())) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(results.num_rows(), 3); + + let results = dataset + .scan() + .full_text_search(FullTextSearchQuery::new("common".to_owned())) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(results.num_rows(), 2); + } + #[tokio::test] async fn concurrent_create() { async fn write(uri: &str) -> Result<()> { diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index fa834a889c..ce49568087 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -1325,35 +1325,29 @@ impl Scanner { )); } - // Now the full text search supports only one column - if columns.len() != 1 { - return Err(Error::invalid_input( - format!( - "Full text search supports only one column right now, but got {} columns", - columns.len() - ), - location!(), - )); + // load indices + let mut indices = HashMap::with_capacity(columns.len()); + for column in &columns { + let index = self + .dataset + .load_scalar_index_for_column(column) + .await? + .ok_or(Error::invalid_input( + format!("Column {} has no inverted index", column), + location!(), + ))?; + let index_uuids = self + .dataset + .load_indices_by_name(&index.name) + .await? + .into_iter() + .collect(); + indices.insert(column.clone(), index_uuids); } - let column = &columns[0]; - let index = self - .dataset - .load_scalar_index_for_column(column) - .await? - .ok_or(Error::invalid_input( - format!("Column {} has no inverted index", column), - location!(), - ))?; - let index_uuids = self - .dataset - .load_indices_by_name(&index.name) - .await? - .into_iter() - .collect(); let query = query.clone().columns(Some(columns)).limit(self.limit); let prefilter_source = self.prefilter_source(filter_plan).await?; - let fts_plan = FtsExec::new(self.dataset.clone(), index_uuids, query, prefilter_source); + let fts_plan = FtsExec::new(self.dataset.clone(), indices, query, prefilter_source); let sort_expr = PhysicalSortExpr { expr: expressions::col(SCORE_COL, fts_plan.schema().as_ref())?, options: SortOptions { diff --git a/rust/lance/src/io/exec/fts.rs b/rust/lance/src/io/exec/fts.rs index 54a8e96f97..e9a08a979f 100644 --- a/rust/lance/src/io/exec/fts.rs +++ b/rust/lance/src/io/exec/fts.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::collections::HashMap; use std::sync::Arc; use arrow_array::{Float32Array, RecordBatch, UInt64Array}; @@ -15,6 +16,7 @@ use datafusion::physical_plan::{ use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; use futures::stream::{self}; use futures::StreamExt; +use itertools::Itertools; use lance_core::ROW_ID_FIELD; use lance_index::prefilter::{FilterLoader, PreFilter}; use lance_index::scalar::inverted::{InvertedIndex, SCORE_FIELD}; @@ -40,7 +42,8 @@ lazy_static::lazy_static! { #[derive(Debug)] pub struct FtsExec { dataset: Arc, - indices: Vec, + // column -> indices + indices: HashMap>, query: FullTextSearchQuery, /// Prefiltering input prefilter_source: PreFilterSource, @@ -60,7 +63,7 @@ impl DisplayAs for FtsExec { impl FtsExec { pub fn new( dataset: Arc, - indices: Vec, + indices: HashMap>, query: FullTextSearchQuery, prefilter_source: PreFilterSource, ) -> Self { @@ -117,7 +120,8 @@ impl ExecutionPlan for FtsExec { let ds = self.dataset.clone(); let prefilter_source = self.prefilter_source.clone(); - let stream = stream::iter(self.indices.clone()) + let indices = self.indices.values().flatten().cloned().collect_vec(); + let stream = stream::iter(indices) .map(move |index_meta| { let uuid = index_meta.uuid.to_string(); let query = query.clone();