extend FuzzyTermQuery to support json field #2173

Merged · 5 commits · Sep 11, 2023
29 changes: 28 additions & 1 deletion src/query/automaton_weight.rs
@@ -4,6 +4,7 @@ use std::sync::Arc;
 use common::BitSet;
 use tantivy_fst::Automaton;
 
+use super::phrase_prefix_query::prefix_end;
 use crate::core::SegmentReader;
 use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight};
 use crate::schema::{Field, IndexRecordOption};
@@ -14,6 +15,10 @@ use crate::{DocId, Score, TantivyError};
 pub struct AutomatonWeight<A> {
     field: Field,
     automaton: Arc<A>,
+    // For JSON fields, the term dictionary includes terms from all paths.
+    // We apply additional filtering based on the given JSON path when searching
+    // within the term dictionary. This prevents terms from unrelated paths from
+    // matching the search criteria.
+    json_path_bytes: Option<Box<[u8]>>,
 }

@@ -26,6 +31,20 @@ impl<A> AutomatonWeight<A> where
         AutomatonWeight {
             field,
             automaton: automaton.into(),
+            json_path_bytes: None,
         }
     }
+
+    /// Creates a new AutomatonWeight for a JSON path.
+    pub fn new_for_json_path<IntoArcA: Into<Arc<A>>>(
+        field: Field,
+        automaton: IntoArcA,
+        json_path_bytes: &[u8],
+    ) -> AutomatonWeight<A> {
+        AutomatonWeight {
+            field,
+            automaton: automaton.into(),
+            json_path_bytes: Some(json_path_bytes.to_vec().into_boxed_slice()),
+        }
+    }

@@ -34,7 +53,15 @@ where
         term_dict: &'a TermDictionary,
     ) -> io::Result<TermStreamer<'a, &'a A>> {
         let automaton: &A = &self.automaton;
-        let term_stream_builder = term_dict.search(automaton);
+        let mut term_stream_builder = term_dict.search(automaton);
+
+        if let Some(json_path_bytes) = &self.json_path_bytes {
+            term_stream_builder = term_stream_builder.ge(json_path_bytes);
+            if let Some(end) = prefix_end(json_path_bytes) {
+                term_stream_builder = term_stream_builder.lt(&end);
+            }
+        }
+
         term_stream_builder.into_stream()
     }
 }
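Why the `ge`/`lt` bounds work: within a JSON field, every term in the segment's term dictionary begins with its encoded path followed by an end-of-path marker byte, so the terms of one path occupy a contiguous, prefix-delimited slice of the sorted dictionary. Bounding the streamer to [path_bytes, prefix_end(path_bytes)) keeps the automaton from ever visiting terms of other paths. The sketch below is illustrative only — the marker value and the flat "attributes.a" path are stand-ins for tantivy's real encoding — but it shows why the marker byte must be part of the bound: without it, the range for path "attributes.a" would also capture "attributes.aa".

// Illustrative sketch of the term layout assumed by the range filter.
// The marker value and path encoding are stand-ins, not tantivy's constants.
const END_OF_PATH: u8 = 0;

fn json_term(path: &[u8], value: &[u8]) -> Vec<u8> {
    let mut term = path.to_vec();
    term.push(END_OF_PATH);
    term.extend_from_slice(value);
    term
}

// Equivalent of bounding the streamer with ge(lower) / lt(prefix_end(lower)):
// a term falls in the range iff it starts with `lower`.
fn in_range(term: &[u8], path: &[u8]) -> bool {
    let mut lower = path.to_vec();
    lower.push(END_OF_PATH);
    term.starts_with(&lower)
}

fn main() {
    let term_a = json_term(b"attributes.a", b"japan");
    let term_aa = json_term(b"attributes.aa", b"japan");

    // With the END_OF_PATH byte in the bound, path "attributes.a" does not
    // capture terms of path "attributes.aa"...
    assert!(in_range(&term_a, b"attributes.a"));
    assert!(!in_range(&term_aa, b"attributes.a"));

    // ...whereas a bound on the raw path bytes alone would.
    assert!(term_aa.starts_with(b"attributes.a"));
}

This is also why, later in this PR, Term::as_json is changed to return the path bytes including the end-of-path marker: the returned slice can then serve directly as the lower bound.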
126 changes: 117 additions & 9 deletions src/query/fuzzy_query.rs
@@ -3,7 +3,7 @@ use once_cell::sync::OnceCell;
 use tantivy_fst::Automaton;
 
 use crate::query::{AutomatonWeight, EnableScoring, Query, Weight};
-use crate::schema::Term;
+use crate::schema::{Term, Type};
 use crate::TantivyError::InvalidArgument;
 
 pub(crate) struct DfaWrapper(pub DFA);
@@ -132,18 +132,46 @@ impl FuzzyTermQuery {
         });
 
         let term_value = self.term.value();
-        let term_text = term_value.as_str().ok_or_else(|| {
-            InvalidArgument("The fuzzy term query requires a string term.".to_string())
-        })?;
+
+        let term_text = if term_value.typ() == Type::Json {
+            if let Some(json_path_type) = term_value.json_path_type() {
+                if json_path_type != Type::Str {
+                    return Err(InvalidArgument(format!(
+                        "The fuzzy term query requires a string path type for a json term. Found \
+                         {:?}",
+                        json_path_type
+                    )));
+                }
+            }
+
+            std::str::from_utf8(self.term.serialized_value_bytes()).map_err(|_| {
+                InvalidArgument(
+                    "Failed to convert json term value bytes to utf8 string.".to_string(),
+                )
+            })?
+        } else {
+            term_value.as_str().ok_or_else(|| {
+                InvalidArgument("The fuzzy term query requires a string term.".to_string())
+            })?
+        };
         let automaton = if self.prefix {
             automaton_builder.build_prefix_dfa(term_text)
         } else {
             automaton_builder.build_dfa(term_text)
         };
-        Ok(AutomatonWeight::new(
-            self.term.field(),
-            DfaWrapper(automaton),
-        ))
+
+        if let Some((json_path_bytes, _)) = term_value.as_json() {
+            Ok(AutomatonWeight::new_for_json_path(
+                self.term.field(),
+                DfaWrapper(automaton),
+                json_path_bytes,
+            ))
+        } else {
+            Ok(AutomatonWeight::new(
+                self.term.field(),
+                DfaWrapper(automaton),
+            ))
+        }
     }
 }

@@ -157,9 +185,89 @@ mod test {
     use super::FuzzyTermQuery;
     use crate::collector::{Count, TopDocs};
-    use crate::schema::{Schema, TEXT};
+    use crate::indexer::NoMergePolicy;
+    use crate::query::QueryParser;
+    use crate::schema::{Schema, STORED, TEXT};
     use crate::{assert_nearly_equals, Index, Term};
 
+    #[test]
+    pub fn test_fuzzy_json_path() -> crate::Result<()> {
+        // # Defining the schema
+        let mut schema_builder = Schema::builder();
+        let attributes = schema_builder.add_json_field("attributes", TEXT | STORED);
+        let schema = schema_builder.build();
+
+        // # Indexing documents
+        let index = Index::create_in_ram(schema.clone());
+
+        let mut index_writer = index.writer_for_tests()?;
+        index_writer.set_merge_policy(Box::new(NoMergePolicy));
+        let doc = schema.parse_document(
+            r#"{
+                "attributes": {
+                    "a": "japan"
+                }
+            }"#,
+        )?;
+        index_writer.add_document(doc)?;
+        let doc = schema.parse_document(
+            r#"{
+                "attributes": {
+                    "aa": "japan"
+                }
+            }"#,
+        )?;
+        index_writer.add_document(doc)?;
+        index_writer.commit()?;
+
+        let reader = index.reader()?;
+        let searcher = reader.searcher();
+
+        // # Fuzzy search
+        let query_parser = QueryParser::for_index(&index, vec![attributes]);
+
+        let get_json_path_term = |query: &str| -> crate::Result<Term> {
+            let query = query_parser.parse_query(query)?;
+            let mut terms = Vec::new();
+            query.query_terms(&mut |term, _| {
+                terms.push(term.clone());
+            });
+
+            Ok(terms[0].clone())
+        };
+
+        // shall not match the first document due to json path mismatch
+        {
+            let term = get_json_path_term("attributes.aa:japan")?;
+            let fuzzy_query = FuzzyTermQuery::new(term, 2, true);
+            let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
+            assert_eq!(top_docs.len(), 1, "Expected only 1 document");
+            assert_eq!(top_docs[0].1.doc_id, 1, "Expected the second document");
+        }
+
+        // shall match the first document because Levenshtein distance is 1
+        // (substitute 'o' with 'a')
+        {
+            let term = get_json_path_term("attributes.a:japon")?;
+
+            let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
+            let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
+            assert_eq!(top_docs.len(), 1, "Expected only 1 document");
+            assert_eq!(top_docs[0].1.doc_id, 0, "Expected the first document");
+        }
+
+        // shall not match because non-prefix Levenshtein distance is more than 1
+        // (add 'a' and 'n')
+        {
+            let term = get_json_path_term("attributes.a:jap")?;
+
+            let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
+            let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
+            assert_eq!(top_docs.len(), 0, "Expected no document");
+        }
+
+        Ok(())
+    }
+
     #[test]
     pub fn test_fuzzy_term() -> crate::Result<()> {
         let mut schema_builder = Schema::builder();
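As a sanity check on the distances asserted in the test above: "japan"/"japon" differ by one substitution, while "japan"/"jap" differ by two insertions. The throwaway dynamic-programming implementation below is not the precompiled Levenshtein DFA that FuzzyTermQuery actually builds, but the distances agree:

// Plain DP Levenshtein distance, for sanity-checking the test's claims only.
fn levenshtein(a: &str, b: &str) -> usize {
    let a: Vec<char> = a.chars().collect();
    let b: Vec<char> = b.chars().collect();
    // prev[j] holds the distance between a[..i] and b[..j].
    let mut prev: Vec<usize> = (0..=b.len()).collect();
    for (i, ca) in a.iter().enumerate() {
        let mut cur = vec![i + 1];
        for (j, cb) in b.iter().enumerate() {
            let cost = if ca == cb { 0 } else { 1 };
            // min of substitution (diagonal), deletion (above), insertion (left)
            cur.push((prev[j] + cost).min(prev[j + 1] + 1).min(cur[j] + 1));
        }
        prev = cur;
    }
    prev[b.len()]
}

fn main() {
    assert_eq!(levenshtein("japon", "japan"), 1); // one substitution
    assert_eq!(levenshtein("jap", "japan"), 2);   // two insertions
}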
2 changes: 1 addition & 1 deletion src/query/phrase_prefix_query/mod.rs
@@ -6,7 +6,7 @@ pub use phrase_prefix_query::PhrasePrefixQuery;
 pub use phrase_prefix_scorer::PhrasePrefixScorer;
 pub use phrase_prefix_weight::PhrasePrefixWeight;
 
-fn prefix_end(prefix_start: &[u8]) -> Option<Vec<u8>> {
+pub(crate) fn prefix_end(prefix_start: &[u8]) -> Option<Vec<u8>> {
Review thread on prefix_end:

Contributor: I don't understand the u8::MAX logic here.

Contributor Author: I referred to the PhrasePrefixQuery implementation, which also filters terms based on the term value prefix.

Contributor Author: I believe it attempts to find the next larger prefix. Typically, this involves incrementing the last u8 value by 1. However, there is an edge case to consider when the last u8 value is u8::MAX.

     let mut res = prefix_start.to_owned();
     while !res.is_empty() {
         let end = res.len() - 1;
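To answer the question in the thread concretely: prefix_end computes the smallest byte string strictly greater than every string that begins with the given prefix, which then serves as the exclusive upper bound of the term range. Incrementing the last byte usually suffices; a trailing 0xFF (u8::MAX) cannot be incremented, so it is popped and the carry moves left; if the whole prefix is 0xFF bytes, no finite bound exists. A standalone re-sketch consistent with the snippet shown above, with the edge cases exercised:

// Find the smallest byte string strictly greater than every string that
// starts with `prefix_start` (exclusive upper bound of the prefix range).
fn prefix_end(prefix_start: &[u8]) -> Option<Vec<u8>> {
    let mut res = prefix_start.to_owned();
    while !res.is_empty() {
        let end = res.len() - 1;
        if res[end] == u8::MAX {
            // 0xFF cannot be incremented; drop it and carry into the
            // previous byte instead.
            res.pop();
        } else {
            res[end] += 1;
            return Some(res);
        }
    }
    // Every byte was 0xFF: no finite upper bound exists.
    None
}

fn main() {
    assert_eq!(prefix_end(b"ab"), Some(b"ac".to_vec()));     // simple increment
    assert_eq!(prefix_end(&[0x61, 0xff]), Some(vec![0x62])); // carry over 0xFF
    assert_eq!(prefix_end(&[0xff, 0xff]), None);             // unbounded
}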
24 changes: 18 additions & 6 deletions src/schema/term.rs
@@ -397,20 +397,29 @@ where B: AsRef<[u8]>
         Some(Ipv6Addr::from_u128(ip_u128))
     }
 
-    /// Returns the json path (without non-human friendly separators),
+    /// Returns the json path type.
+    ///
+    /// Returns `None` if the value is not JSON.
+    pub fn json_path_type(&self) -> Option<Type> {
+        let json_value_bytes = self.as_json_value_bytes()?;
+
+        Some(json_value_bytes.typ())
+    }
+
+    /// Returns the json path bytes (including the JSON_END_OF_PATH byte),
     /// and the encoded ValueBytes after the json path.
     ///
    /// Returns `None` if the value is not JSON.
-    pub(crate) fn as_json(&self) -> Option<(&str, ValueBytes<&[u8]>)> {
+    pub(crate) fn as_json(&self) -> Option<(&[u8], ValueBytes<&[u8]>)> {
         if self.typ() != Type::Json {
             return None;
         }
         let bytes = self.value_bytes();
 
         let pos = bytes.iter().cloned().position(|b| b == JSON_END_OF_PATH)?;
-        let (json_path_bytes, term) = bytes.split_at(pos);
-        let json_path = str::from_utf8(json_path_bytes).ok()?;
-        Some((json_path, ValueBytes::wrap(&term[1..])))
+        // Split at pos + 1, so that json_path_bytes includes the JSON_END_OF_PATH byte.
+        let (json_path_bytes, term) = bytes.split_at(pos + 1);
+        Some((json_path_bytes, ValueBytes::wrap(&term)))
     }
 
     /// Returns the encoded ValueBytes after the json path.
@@ -469,7 +478,10 @@ where B: AsRef<[u8]>
                 write_opt(f, self.as_bytes())?;
             }
             Type::Json => {
-                if let Some((path, sub_value_bytes)) = self.as_json() {
+                if let Some((path_bytes, sub_value_bytes)) = self.as_json() {
+                    // Remove the JSON_END_OF_PATH byte & convert to utf8.
+                    let path = str::from_utf8(&path_bytes[..path_bytes.len() - 1])
+                        .map_err(|_| std::fmt::Error)?;
                     let path_pretty = path.replace(JSON_PATH_SEGMENT_SEP_STR, ".");
                     write!(f, "path={path_pretty}, ")?;
                     sub_value_bytes.debug_value_bytes(f)?;
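To make the byte layout concrete: a JSON term's value bytes are the encoded path (segments joined by a separator byte, terminated by JSON_END_OF_PATH), followed by a type tag and the value. The sketch below mirrors the as_json split and the Debug pretty-printing above; the marker values and the type-tag byte are illustrative stand-ins for the actual constants in tantivy's schema module.

// Stand-in marker values; the real JSON_END_OF_PATH / JSON_PATH_SEGMENT_SEP
// constants are defined in tantivy's schema module.
const END_OF_PATH: u8 = 0;
const SEGMENT_SEP: u8 = 1;

// Mirrors `as_json`: split after the END_OF_PATH byte so the path half can
// double as the lower bound of the path's term range.
fn as_json(bytes: &[u8]) -> Option<(&[u8], &[u8])> {
    let pos = bytes.iter().position(|&b| b == END_OF_PATH)?;
    Some(bytes.split_at(pos + 1))
}

fn main() {
    // Encode a term for path attributes -> a with string value "japan".
    let mut bytes = b"attributes".to_vec();
    bytes.push(SEGMENT_SEP);
    bytes.push(b'a');
    bytes.push(END_OF_PATH);
    bytes.push(b's'); // illustrative type tag for a string value
    bytes.extend_from_slice(b"japan");

    let (path_bytes, value_bytes) = as_json(&bytes).unwrap();
    assert_eq!(path_bytes.last(), Some(&END_OF_PATH)); // marker kept, as in the PR

    // Mirrors the Debug impl: drop the marker, then dots for separators.
    let path = std::str::from_utf8(&path_bytes[..path_bytes.len() - 1]).unwrap();
    let path_pretty = path.replace(SEGMENT_SEP as char, ".");
    assert_eq!(path_pretty, "attributes.a");
    assert_eq!(&value_bytes[1..], b"japan");
}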