feat: support FTS on multiple fields (#3025)
BubbleCal authored Oct 22, 2024
1 parent f572d63 commit 7413344
Showing 3 changed files with 96 additions and 29 deletions.
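This commit lets full-text search span every indexed column instead of exactly one: the scanner resolves one inverted index per queried column, and FtsExec runs the query against each of them. A rough usage sketch, assuming a writable `dataset` with string columns "title" and "content" (it mirrors the test added below; error handling elided with `?`):

    let params = InvertedIndexParams::default();
    // One inverted index per column that should participate in FTS.
    dataset
        .create_index(&["title"], IndexType::Inverted, None, &params, true)
        .await?;
    dataset
        .create_index(&["content"], IndexType::Inverted, None, &params, true)
        .await?;
    // The query fans out across all indexed columns and returns one
    // scored batch of matching rows.
    let results = dataset
        .scan()
        .full_text_search(FullTextSearchQuery::new("common".to_owned()))?
        .try_into_batch()
        .await?;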
71 changes: 70 additions & 1 deletion rust/lance/src/dataset.rs
@@ -1692,7 +1692,10 @@ mod tests {
         ArrayRef, DictionaryArray, Float32Array, Int32Array, Int64Array, Int8Array,
         Int8DictionaryArray, RecordBatchIterator, StringArray, UInt16Array, UInt32Array,
     };
-    use arrow_array::{FixedSizeListArray, Int16Array, Int16DictionaryArray, StructArray};
+    use arrow_array::{
+        Array, FixedSizeListArray, GenericStringArray, Int16Array, Int16DictionaryArray,
+        StructArray,
+    };
     use arrow_ord::sort::sort_to_indices;
     use arrow_schema::{
         DataType, Field as ArrowField, Fields as ArrowFields, Schema as ArrowSchema,
@@ -4325,6 +4328,72 @@ mod tests {
         );
     }
 
+    #[tokio::test]
+    async fn test_fts_on_multiple_columns() {
+        let tempdir = tempfile::tempdir().unwrap();
+
+        let params = InvertedIndexParams::default();
+        let title_col =
+            GenericStringArray::<i32>::from(vec!["title hello", "title lance", "title common"]);
+        let content_col = GenericStringArray::<i32>::from(vec![
+            "content world",
+            "content database",
+            "content common",
+        ]);
+        let batch = RecordBatch::try_new(
+            arrow_schema::Schema::new(vec![
+                arrow_schema::Field::new("title", title_col.data_type().to_owned(), false),
+                arrow_schema::Field::new("content", title_col.data_type().to_owned(), false),
+            ])
+            .into(),
+            vec![
+                Arc::new(title_col) as ArrayRef,
+                Arc::new(content_col) as ArrayRef,
+            ],
+        )
+        .unwrap();
+        let schema = batch.schema();
+        let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema);
+        let mut dataset = Dataset::write(batches, tempdir.path().to_str().unwrap(), None)
+            .await
+            .unwrap();
+        dataset
+            .create_index(&["title"], IndexType::Inverted, None, &params, true)
+            .await
+            .unwrap();
+        dataset
+            .create_index(&["content"], IndexType::Inverted, None, &params, true)
+            .await
+            .unwrap();
+
+        let results = dataset
+            .scan()
+            .full_text_search(FullTextSearchQuery::new("title".to_owned()))
+            .unwrap()
+            .try_into_batch()
+            .await
+            .unwrap();
+        assert_eq!(results.num_rows(), 3);
+
+        let results = dataset
+            .scan()
+            .full_text_search(FullTextSearchQuery::new("content".to_owned()))
+            .unwrap()
+            .try_into_batch()
+            .await
+            .unwrap();
+        assert_eq!(results.num_rows(), 3);
+
+        let results = dataset
+            .scan()
+            .full_text_search(FullTextSearchQuery::new("common".to_owned()))
+            .unwrap()
+            .try_into_batch()
+            .await
+            .unwrap();
+        assert_eq!(results.num_rows(), 2);
+    }
+
     #[tokio::test]
     async fn concurrent_create() {
         async fn write(uri: &str) -> Result<()> {
44 changes: 19 additions & 25 deletions rust/lance/src/dataset/scanner.rs
@@ -1325,35 +1325,29 @@ impl Scanner {
             ));
         }
 
-        // Now the full text search supports only one column
-        if columns.len() != 1 {
-            return Err(Error::invalid_input(
-                format!(
-                    "Full text search supports only one column right now, but got {} columns",
-                    columns.len()
-                ),
-                location!(),
-            ));
+        // load indices
+        let mut indices = HashMap::with_capacity(columns.len());
+        for column in &columns {
+            let index = self
+                .dataset
+                .load_scalar_index_for_column(column)
+                .await?
+                .ok_or(Error::invalid_input(
+                    format!("Column {} has no inverted index", column),
+                    location!(),
+                ))?;
+            let index_uuids = self
+                .dataset
+                .load_indices_by_name(&index.name)
+                .await?
+                .into_iter()
+                .collect();
+            indices.insert(column.clone(), index_uuids);
         }
-        let column = &columns[0];
-        let index = self
-            .dataset
-            .load_scalar_index_for_column(column)
-            .await?
-            .ok_or(Error::invalid_input(
-                format!("Column {} has no inverted index", column),
-                location!(),
-            ))?;
-        let index_uuids = self
-            .dataset
-            .load_indices_by_name(&index.name)
-            .await?
-            .into_iter()
-            .collect();
 
         let query = query.clone().columns(Some(columns)).limit(self.limit);
         let prefilter_source = self.prefilter_source(filter_plan).await?;
-        let fts_plan = FtsExec::new(self.dataset.clone(), index_uuids, query, prefilter_source);
+        let fts_plan = FtsExec::new(self.dataset.clone(), indices, query, prefilter_source);
         let sort_expr = PhysicalSortExpr {
             expr: expressions::col(SCORE_COL, fts_plan.schema().as_ref())?,
             options: SortOptions {
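The scanner change above swaps the old single-column guard for a loop that resolves an inverted index per queried column into a HashMap<String, Vec<Index>>, erroring on the first column that has none. A minimal standalone sketch of that grouping, where `load_index` is a hypothetical stand-in for Dataset::load_scalar_index_for_column and index metadata is reduced to UUID strings:

    use std::collections::HashMap;

    // Hypothetical lookup: pretend only "title" and "content" are indexed.
    fn load_index(column: &str) -> Option<Vec<String>> {
        matches!(column, "title" | "content").then(|| vec![format!("{column}-uuid")])
    }

    // Mirrors the shape of the new scanner code: column name -> index UUIDs,
    // failing fast when a queried column has no inverted index.
    fn group_indices(columns: &[&str]) -> Result<HashMap<String, Vec<String>>, String> {
        let mut indices = HashMap::with_capacity(columns.len());
        for column in columns {
            let uuids = load_index(column)
                .ok_or_else(|| format!("Column {} has no inverted index", column))?;
            indices.insert(column.to_string(), uuids);
        }
        Ok(indices)
    }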
10 changes: 7 additions & 3 deletions rust/lance/src/io/exec/fts.rs
@@ -1,6 +1,7 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The Lance Authors
 
+use std::collections::HashMap;
 use std::sync::Arc;
 
 use arrow_array::{Float32Array, RecordBatch, UInt64Array};
@@ -15,6 +16,7 @@ use datafusion::physical_plan::{
 use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
 use futures::stream::{self};
 use futures::StreamExt;
+use itertools::Itertools;
 use lance_core::ROW_ID_FIELD;
 use lance_index::prefilter::{FilterLoader, PreFilter};
 use lance_index::scalar::inverted::{InvertedIndex, SCORE_FIELD};
@@ -40,7 +42,8 @@ lazy_static::lazy_static! {
 #[derive(Debug)]
 pub struct FtsExec {
     dataset: Arc<Dataset>,
-    indices: Vec<Index>,
+    // column -> indices
+    indices: HashMap<String, Vec<Index>>,
     query: FullTextSearchQuery,
     /// Prefiltering input
     prefilter_source: PreFilterSource,
@@ -60,7 +63,7 @@ impl DisplayAs for FtsExec {
 impl FtsExec {
     pub fn new(
         dataset: Arc<Dataset>,
-        indices: Vec<Index>,
+        indices: HashMap<String, Vec<Index>>,
         query: FullTextSearchQuery,
         prefilter_source: PreFilterSource,
     ) -> Self {
@@ -117,7 +120,8 @@ impl ExecutionPlan for FtsExec {
         let ds = self.dataset.clone();
         let prefilter_source = self.prefilter_source.clone();
 
-        let stream = stream::iter(self.indices.clone())
+        let indices = self.indices.values().flatten().cloned().collect_vec();
+        let stream = stream::iter(indices)
             .map(move |index_meta| {
                 let uuid = index_meta.uuid.to_string();
                 let query = query.clone();
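Inside execute, FtsExec flattens the column-to-indices map back into one list before building the per-index streams; `collect_vec` from itertools is just `collect::<Vec<_>>()`. A small self-contained sketch of that flattening step, with plain Strings standing in for Index metadata:

    use std::collections::HashMap;

    fn main() {
        let mut indices: HashMap<String, Vec<String>> = HashMap::new();
        indices.insert("title".into(), vec!["uuid-title".into()]);
        indices.insert("content".into(), vec!["uuid-content".into()]);

        // values() yields each column's Vec, flatten() walks the inner items,
        // cloned() turns &String into String.
        let mut all: Vec<String> = indices.values().flatten().cloned().collect();
        all.sort(); // HashMap iteration order is arbitrary
        assert_eq!(all, vec!["uuid-content".to_string(), "uuid-title".to_string()]);
    }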
