Skip to content

Commit

Permalink
min_should_match for pure shoulds
Browse files Browse the repository at this point in the history
  • Loading branch information
jtong11 committed Oct 9, 2020
1 parent 60f21c3 commit a76f3de
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 43 deletions.
39 changes: 27 additions & 12 deletions src/core/search/query/boolean_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub struct BooleanQuery<C: Codec> {
should_queries: Vec<Box<dyn Query<C>>>,
filter_queries: Vec<Box<dyn Query<C>>>,
must_not_queries: Vec<Box<dyn Query<C>>>,
minimum_should_match: i32,
min_should_match: i32,
}

pub const BOOLEAN: &str = "boolean";
Expand All @@ -42,8 +42,18 @@ impl<C: Codec> BooleanQuery<C> {
shoulds: Vec<Box<dyn Query<C>>>,
filters: Vec<Box<dyn Query<C>>>,
must_nots: Vec<Box<dyn Query<C>>>,
min_should_match: i32,
) -> Result<Box<dyn Query<C>>> {
let minimum_should_match = if musts.is_empty() { 1 } else { 0 };
let min_should_match = if min_should_match > 0 {
min_should_match
} else {
if musts.is_empty() {
1
} else {
0
}
};

let mut musts = musts;
let mut shoulds = shoulds;
let mut filters = filters;
Expand Down Expand Up @@ -72,7 +82,7 @@ impl<C: Codec> BooleanQuery<C> {
should_queries: shoulds,
filter_queries: filters,
must_not_queries: must_nots,
minimum_should_match,
min_should_match,
}))
}

Expand Down Expand Up @@ -110,6 +120,7 @@ impl<C: Codec> Query<C> for BooleanQuery<C> {
should_weights,
must_not_weights,
needs_scores,
self.min_should_match,
)))
}

Expand Down Expand Up @@ -145,7 +156,7 @@ impl<C: Codec> fmt::Display for BooleanQuery<C> {
write!(
f,
"BooleanQuery(must: [{}], should: [{}], filters: [{}], must_not: [{}], match: {})",
must_str, should_str, filters_str, must_not_str, self.minimum_should_match
must_str, should_str, filters_str, must_not_str, self.min_should_match
)
}
}
Expand All @@ -154,8 +165,7 @@ struct BooleanWeight<C: Codec> {
must_weights: Vec<Box<dyn Weight<C>>>,
should_weights: Vec<Box<dyn Weight<C>>>,
must_not_weights: Vec<Box<dyn Weight<C>>>,
#[allow(dead_code)]
minimum_should_match: i32,
min_should_match: i32,
needs_scores: bool,
}

Expand All @@ -165,13 +175,13 @@ impl<C: Codec> BooleanWeight<C> {
shoulds: Vec<Box<dyn Weight<C>>>,
must_nots: Vec<Box<dyn Weight<C>>>,
needs_scores: bool,
min_should_match: i32,
) -> BooleanWeight<C> {
let minimum_should_match = if musts.is_empty() { 1 } else { 0 };
BooleanWeight {
must_weights: musts,
should_weights: shoulds,
must_not_weights: must_nots,
minimum_should_match,
min_should_match,
needs_scores,
}
}
Expand Down Expand Up @@ -217,6 +227,7 @@ impl<C: Codec> Weight<C> for BooleanWeight<C> {
_ => Some(Box::new(DisjunctionSumScorer::new(
scorers,
self.needs_scores,
self.min_should_match,
))),
}
};
Expand All @@ -230,7 +241,11 @@ impl<C: Codec> Weight<C> for BooleanWeight<C> {
match scorers.len() {
0 => None,
1 => Some(scorers.remove(0)),
_ => Some(Box::new(DisjunctionSumScorer::new(scorers, false))),
_ => Some(Box::new(DisjunctionSumScorer::new(
scorers,
false,
self.min_should_match,
))),
}
};

Expand Down Expand Up @@ -348,13 +363,13 @@ impl<C: Codec> Weight<C> for BooleanWeight<C> {
"No matching clauses".to_string(),
subs,
))
} else if should_match_count < self.minimum_should_match {
} else if should_match_count < self.min_should_match {
Ok(Explanation::new(
false,
0.0f32,
format!(
"Failure to match minimum number of optional clauses: {}<{}",
should_match_count, self.minimum_should_match
should_match_count, self.min_should_match
),
subs,
))
Expand Down Expand Up @@ -394,7 +409,7 @@ impl<C: Codec> fmt::Display for BooleanWeight<C> {
f,
"BooleanWeight(must: [{}], should: [{}], must_not: [{}], min match: {}, needs score: \
{})",
must_str, should_str, must_not_str, self.minimum_should_match, self.needs_scores
must_str, should_str, must_not_str, self.min_should_match, self.needs_scores
)
}
}
28 changes: 14 additions & 14 deletions src/core/search/query/query_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub struct QueryStringQueryBuilder {
query_string: String,
fields: Vec<(String, f32)>,
#[allow(dead_code)]
minimum_should_match: i32,
min_should_match: i32,
#[allow(dead_code)]
boost: f32,
}
Expand All @@ -38,13 +38,13 @@ impl QueryStringQueryBuilder {
pub fn new(
query_string: String,
fields: Vec<(String, f32)>,
minimum_should_match: i32,
min_should_match: i32,
boost: f32,
) -> QueryStringQueryBuilder {
QueryStringQueryBuilder {
query_string,
fields,
minimum_should_match,
min_should_match,
boost,
}
}
Expand Down Expand Up @@ -171,7 +171,7 @@ impl QueryStringQueryBuilder {
shoulds.remove(0)
}
} else {
BooleanQuery::build(musts, shoulds, vec![], vec![])?
BooleanQuery::build(musts, shoulds, vec![], vec![], self.min_should_match)?
};
Ok(Some(query))
}
Expand All @@ -190,7 +190,7 @@ impl QueryStringQueryBuilder {
let res = if queries.len() == 1 {
queries.remove(0)
} else {
BooleanQuery::build(Vec::new(), queries, vec![], vec![])?
BooleanQuery::build(Vec::new(), queries, vec![], vec![], self.min_should_match)?
};
Ok(res)
}
Expand Down Expand Up @@ -259,7 +259,7 @@ mod tests {
let term = String::from("test");
let field = String::from("title");
let q: Box<dyn Query<TestCodec>> =
QueryStringQueryBuilder::new(term.clone(), vec![(field, 1.0)], 1, 1.0)
QueryStringQueryBuilder::new(term.clone(), vec![(field, 1.0)], 0, 1.0)
.build()
.unwrap();
let term_str: String = q.to_string();
Expand All @@ -271,7 +271,7 @@ mod tests {
let term = String::from("(test^0.2 | 测试^2)");
let field = String::from("title");
let q: Box<dyn Query<TestCodec>> =
QueryStringQueryBuilder::new(term.clone(), vec![(field, 1.0)], 1, 2.0)
QueryStringQueryBuilder::new(term.clone(), vec![(field, 1.0)], 0, 2.0)
.build()
.unwrap();
let term_str: String = q.to_string();
Expand All @@ -287,7 +287,7 @@ mod tests {
let term = String::from("test^0.2 \"测试\"^2");
let field = String::from("title");
let q: Box<dyn Query<TestCodec>> =
QueryStringQueryBuilder::new(term.clone(), vec![(field, 1.0)], 1, 2.0)
QueryStringQueryBuilder::new(term.clone(), vec![(field, 1.0)], 0, 2.0)
.build()
.unwrap();
let term_str: String = q.to_string();
Expand All @@ -302,7 +302,7 @@ mod tests {

let field = String::from("title");
let q: Box<dyn Query<TestCodec>> =
QueryStringQueryBuilder::new(String::from("+test"), vec![(field, 1.0)], 1, 1.0)
QueryStringQueryBuilder::new(String::from("+test"), vec![(field, 1.0)], 0, 1.0)
.build()
.unwrap();
let term_str: String = q.to_string();
Expand All @@ -314,7 +314,7 @@ mod tests {
let query_string = String::from("test search");
let field = String::from("title");
let q: Box<dyn Query<TestCodec>> =
QueryStringQueryBuilder::new(query_string.clone(), vec![(field, 1.0)], 1, 1.0)
QueryStringQueryBuilder::new(query_string.clone(), vec![(field, 1.0)], 0, 1.0)
.build()
.unwrap();
let term_str: String = q.to_string();
Expand All @@ -330,7 +330,7 @@ mod tests {
let query_string = String::from("test +search");
let field = String::from("title");
let q: Box<dyn Query<TestCodec>> =
QueryStringQueryBuilder::new(query_string.clone(), vec![(field, 1.0)], 1, 1.0)
QueryStringQueryBuilder::new(query_string.clone(), vec![(field, 1.0)], 0, 1.0)
.build()
.unwrap();
let term_str: String = q.to_string();
Expand All @@ -346,7 +346,7 @@ mod tests {
let query_string = String::from("test +(search 搜索)");
let field = String::from("title");
let q: Box<dyn Query<TestCodec>> =
QueryStringQueryBuilder::new(query_string.clone(), vec![(field, 1.0)], 1, 1.0)
QueryStringQueryBuilder::new(query_string.clone(), vec![(field, 1.0)], 0, 1.0)
.build()
.unwrap();
let term_str: String = q.to_string();
Expand All @@ -364,7 +364,7 @@ mod tests {
let q: Box<dyn Query<TestCodec>> = QueryStringQueryBuilder::new(
query_string.clone(),
vec![("title".to_string(), 1.0), ("content".to_string(), 1.0)],
1,
0,
1.0,
)
.build()
Expand All @@ -387,7 +387,7 @@ mod tests {
);
let field = String::from("title");
let q: Box<dyn Query<TestCodec>> =
QueryStringQueryBuilder::new(query_string.clone(), vec![(field, 1.0)], 1, 1.0)
QueryStringQueryBuilder::new(query_string.clone(), vec![(field, 1.0)], 0, 1.0)
.build()
.unwrap();
let term_str: String = q.to_string();
Expand Down
45 changes: 31 additions & 14 deletions src/core/search/scorer/disjunction_scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,20 @@ pub struct DisjunctionSumScorer<T: Scorer> {
sub_scorers: SubScorers<T>,
needs_scores: bool,
cost: usize,
min_should_match: i32,
}

impl<T: Scorer> DisjunctionSumScorer<T> {
pub fn new(children: Vec<T>, needs_scores: bool) -> DisjunctionSumScorer<T> {
pub fn new(
children: Vec<T>,
needs_scores: bool,
min_should_match: i32,
) -> DisjunctionSumScorer<T> {
assert!(children.len() > 1);

let cost = children.iter().map(|w| w.cost()).sum();

let sub_scorers = if children.len() < 10 {
let sub_scorers = if children.len() < 10 || min_should_match > 1 {
SubScorers::SQ(SimpleQueue::new(children))
} else {
SubScorers::DPQ(DisiPriorityQueue::new(children))
Expand All @@ -41,6 +46,7 @@ impl<T: Scorer> DisjunctionSumScorer<T> {
sub_scorers,
needs_scores,
cost,
min_should_match,
}
}
}
Expand Down Expand Up @@ -81,7 +87,8 @@ impl<T: Scorer> DocIterator for DisjunctionSumScorer<T> {
}

fn approximate_next(&mut self) -> Result<DocId> {
self.sub_scorers.approximate_next()
self.sub_scorers
.approximate_next(Some(self.min_should_match))
}

fn approximate_advance(&mut self, target: DocId) -> Result<DocId> {
Expand Down Expand Up @@ -162,7 +169,7 @@ impl<T: Scorer> DocIterator for DisjunctionMaxScorer<T> {
}

fn approximate_next(&mut self) -> Result<DocId> {
self.sub_scorers.approximate_next()
self.sub_scorers.approximate_next(None)
}

fn approximate_advance(&mut self, target: DocId) -> Result<DocId> {
Expand Down Expand Up @@ -278,23 +285,33 @@ impl<T: Scorer> SubScorers<T> {
}
}

fn approximate_next(&mut self) -> Result<DocId> {
fn approximate_next(&mut self, min_should_match: Option<i32>) -> Result<DocId> {
let min_should_match = min_should_match.unwrap_or(0);

match self {
SubScorers::SQ(sq) => {
let curr_doc = sq.curr_doc;
let mut min_doc = NO_MORE_DOCS;
for s in sq.scorers.iter_mut() {
if s.doc_id() == curr_doc {
s.approximate_next()?;
loop {
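// curr_doc is re-read from sq.curr_doc on every pass of the loop: after the first pass it is
// the min_doc stored by the previous pass, not the starting value (presumably -1; confirm).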
let curr_doc = sq.curr_doc;
let mut min_doc = NO_MORE_DOCS;
let mut should_count = 0;
for s in sq.scorers.iter_mut() {
if s.doc_id() == curr_doc {
should_count += 1;
s.approximate_next()?;
}

min_doc = min_doc.min(s.doc_id());
}

min_doc = min_doc.min(s.doc_id());
sq.curr_doc = min_doc;
if should_count >= min_should_match || sq.curr_doc == NO_MORE_DOCS {
return Ok(sq.curr_doc);
}
}

sq.curr_doc = min_doc;
Ok(sq.curr_doc)
}
SubScorers::DPQ(dbq) => {
// reset with -1, @posting_reader.rs#1208
let doc = dbq.peek().doc();

loop {
Expand Down
4 changes: 2 additions & 2 deletions src/core/search/scorer/req_not_scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ mod tests {

let conjunction_scorer: Box<dyn Scorer> = Box::new(ConjunctionScorer::new(vec![s1, s2]));
let disjunction_scorer: Box<dyn Scorer> =
Box::new(DisjunctionSumScorer::new(vec![s3, s4], true));
Box::new(DisjunctionSumScorer::new(vec![s3, s4], true, 0));
let mut scorer = ReqNotScorer::new(conjunction_scorer, disjunction_scorer);

assert_eq!(scorer.doc_id(), -1);
Expand All @@ -154,7 +154,7 @@ mod tests {

let conjunction_scorer: Box<dyn Scorer> = Box::new(ConjunctionScorer::new(vec![s1, s2]));
let disjunction_scorer: Box<dyn Scorer> =
Box::new(DisjunctionSumScorer::new(vec![s3, s4], true));
Box::new(DisjunctionSumScorer::new(vec![s3, s4], true, 0));
let mut scorer = ReqNotScorer::new(conjunction_scorer, disjunction_scorer);

// 2, 3, 5, 7, 9
Expand Down
2 changes: 1 addition & 1 deletion src/core/search/scorer/req_opt_scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ mod tests {

let conjunction_scorer: Box<dyn Scorer> = Box::new(ConjunctionScorer::new(vec![s1, s2]));
let disjunction_scorer: Box<dyn Scorer> =
Box::new(DisjunctionSumScorer::new(vec![s3, s4], true));
Box::new(DisjunctionSumScorer::new(vec![s3, s4], true, 0));
let mut scorer = ReqOptScorer::new(conjunction_scorer, disjunction_scorer);

assert_eq!(scorer.doc_id(), -1);
Expand Down

0 comments on commit a76f3de

Please sign in to comment.