Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support FTS on multiple fields #3025

Merged
merged 3 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 70 additions & 1 deletion rust/lance/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1691,7 +1691,10 @@ mod tests {
ArrayRef, DictionaryArray, Float32Array, Int32Array, Int64Array, Int8Array,
Int8DictionaryArray, RecordBatchIterator, StringArray, UInt16Array, UInt32Array,
};
use arrow_array::{FixedSizeListArray, Int16Array, Int16DictionaryArray, StructArray};
use arrow_array::{
Array, FixedSizeListArray, GenericStringArray, Int16Array, Int16DictionaryArray,
StructArray,
};
use arrow_ord::sort::sort_to_indices;
use arrow_schema::{
DataType, Field as ArrowField, Fields as ArrowFields, Schema as ArrowSchema,
Expand Down Expand Up @@ -4324,6 +4327,72 @@ mod tests {
);
}

#[tokio::test]
async fn test_fts_on_multiple_columns() {
    // Verifies that full-text search works when two independent inverted
    // indices exist on different string columns of the same dataset:
    //  - a term unique to one column matches only through that column's index,
    //  - a term present in both columns is found via either index.
    let tempdir = tempfile::tempdir().unwrap();

    let params = InvertedIndexParams::default();
    let title_col =
        GenericStringArray::<i32>::from(vec!["title hello", "title lance", "title common"]);
    let content_col = GenericStringArray::<i32>::from(vec![
        "content world",
        "content database",
        "content common",
    ]);
    let batch = RecordBatch::try_new(
        arrow_schema::Schema::new(vec![
            arrow_schema::Field::new("title", title_col.data_type().to_owned(), false),
            // Derive the field type from the matching array (was title_col —
            // harmless only because both arrays are Utf8, but misleading).
            arrow_schema::Field::new("content", content_col.data_type().to_owned(), false),
        ])
        .into(),
        vec![
            Arc::new(title_col) as ArrayRef,
            Arc::new(content_col) as ArrayRef,
        ],
    )
    .unwrap();
    let schema = batch.schema();
    let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema);
    let mut dataset = Dataset::write(batches, tempdir.path().to_str().unwrap(), None)
        .await
        .unwrap();
    // Build a separate inverted index on each column.
    dataset
        .create_index(&["title"], IndexType::Inverted, None, &params, true)
        .await
        .unwrap();
    dataset
        .create_index(&["content"], IndexType::Inverted, None, &params, true)
        .await
        .unwrap();

    // "title" appears in every row of the title column only.
    let results = dataset
        .scan()
        .full_text_search(FullTextSearchQuery::new("title".to_owned()))
        .unwrap()
        .try_into_batch()
        .await
        .unwrap();
    assert_eq!(results.num_rows(), 3);

    // "content" appears in every row of the content column only.
    let results = dataset
        .scan()
        .full_text_search(FullTextSearchQuery::new("content".to_owned()))
        .unwrap()
        .try_into_batch()
        .await
        .unwrap();
    assert_eq!(results.num_rows(), 3);

    // "common" appears once in each column (same logical row count: 2 rows
    // match — the last row of each column).
    let results = dataset
        .scan()
        .full_text_search(FullTextSearchQuery::new("common".to_owned()))
        .unwrap()
        .try_into_batch()
        .await
        .unwrap();
    assert_eq!(results.num_rows(), 2);
}

#[tokio::test]
async fn concurrent_create() {
async fn write(uri: &str) -> Result<()> {
Expand Down
44 changes: 19 additions & 25 deletions rust/lance/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1325,35 +1325,29 @@ impl Scanner {
));
}

// Now the full text search supports only one column
if columns.len() != 1 {
return Err(Error::invalid_input(
format!(
"Full text search supports only one column right now, but got {} columns",
columns.len()
),
location!(),
));
// load indices
let mut indices = HashMap::with_capacity(columns.len());
for column in &columns {
let index = self
.dataset
.load_scalar_index_for_column(column)
.await?
.ok_or(Error::invalid_input(
format!("Column {} has no inverted index", column),
location!(),
))?;
let index_uuids = self
.dataset
.load_indices_by_name(&index.name)
.await?
.into_iter()
.collect();
indices.insert(column.clone(), index_uuids);
}
let column = &columns[0];
let index = self
.dataset
.load_scalar_index_for_column(column)
.await?
.ok_or(Error::invalid_input(
format!("Column {} has no inverted index", column),
location!(),
))?;
let index_uuids = self
.dataset
.load_indices_by_name(&index.name)
.await?
.into_iter()
.collect();

let query = query.clone().columns(Some(columns)).limit(self.limit);
let prefilter_source = self.prefilter_source(filter_plan).await?;
let fts_plan = FtsExec::new(self.dataset.clone(), index_uuids, query, prefilter_source);
let fts_plan = FtsExec::new(self.dataset.clone(), indices, query, prefilter_source);
let sort_expr = PhysicalSortExpr {
expr: expressions::col(SCORE_COL, fts_plan.schema().as_ref())?,
options: SortOptions {
Expand Down
10 changes: 7 additions & 3 deletions rust/lance/src/io/exec/fts.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::{Float32Array, RecordBatch, UInt64Array};
Expand All @@ -15,6 +16,7 @@ use datafusion::physical_plan::{
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use futures::stream::{self};
use futures::StreamExt;
use itertools::Itertools;
use lance_core::ROW_ID_FIELD;
use lance_index::prefilter::{FilterLoader, PreFilter};
use lance_index::scalar::inverted::{InvertedIndex, SCORE_FIELD};
Expand All @@ -40,7 +42,8 @@ lazy_static::lazy_static! {
#[derive(Debug)]
pub struct FtsExec {
dataset: Arc<Dataset>,
indices: Vec<Index>,
// column -> indices
indices: HashMap<String, Vec<Index>>,
query: FullTextSearchQuery,
/// Prefiltering input
prefilter_source: PreFilterSource,
Expand All @@ -60,7 +63,7 @@ impl DisplayAs for FtsExec {
impl FtsExec {
pub fn new(
dataset: Arc<Dataset>,
indices: Vec<Index>,
indices: HashMap<String, Vec<Index>>,
query: FullTextSearchQuery,
prefilter_source: PreFilterSource,
) -> Self {
Expand Down Expand Up @@ -117,7 +120,8 @@ impl ExecutionPlan for FtsExec {
let ds = self.dataset.clone();
let prefilter_source = self.prefilter_source.clone();

let stream = stream::iter(self.indices.clone())
let indices = self.indices.values().flatten().cloned().collect_vec();
let stream = stream::iter(indices)
.map(move |index_meta| {
let uuid = index_meta.uuid.to_string();
let query = query.clone();
Expand Down
Loading