From dd7e745392904a3fc236e5508c7d2374e1010732 Mon Sep 17 00:00:00 2001 From: Tushar Date: Fri, 26 Jan 2024 21:16:41 +0530 Subject: [PATCH] feat(aggregators/metric): Add a top_hits aggregator (#2198) * feat(aggregators/metric): Implement a top_hits aggregator * fix: Expose get_fields * fix: Serializer for top_hits request Also removes the extraneous third-party serialization helper. * chore: Avert panic on parsing invalid top_hits query * refactor: Allow multiple field names from aggregations * perf: Replace binary heap with TopNComputer * fix: Avoid comparator inversion by ComparableDoc * fix: Rank missing field values lower than present values * refactor: Make KeyOrder a struct * feat: Rough attempt at docvalue_fields * feat: Complete stab at docvalue_fields - Rename "SearchResult*" => "Retrieval*" - Revert Vec => HashMap for aggregation accessors. - Split accessors for core aggregation and field retrieval. - Resolve globbed field names in docvalue_fields retrieval. - Handle strings/bytes and other column types with DynamicColumn * test(unit): Add tests for top_hits aggregator * fix: docfield_value field globbing * test(unit): Include dynamic fields * fix: Value -> OwnedValue * fix: Use OwnedValue's native Null variant * chore: Improve readability of test asserts * chore: Remove DocAddress from top_hits result * docs: Update aggregator doc * revert: accidental doc test * chore: enable time macros only for tests * chore: Apply suggestions from review * chore: Apply suggestions from review * fix: Retrieve all values for fields * test(unit): Update for multi-value retrieval * chore: Assert term existence * feat: Include all columns for a column name Since a (name, type) constitutes a unique column. * fix: Resolve json fields Introduces a translation step to bridge the difference between ColumnarReader's null (`\0`) separated JSON field keys and the common `.`-separated keys used by SegmentReader. This should probably be the default behavior for ColumnarReader's public API, though.
* chore: Address review on mutability * chore: s/segment_id/segment_ordinal instances of SegmentOrdinal * chore: Revert erroneous grammar change --- Cargo.toml | 1 + src/aggregation/agg_req.rs | 39 +- src/aggregation/agg_req_with_accessor.rs | 181 +++- src/aggregation/agg_result.rs | 9 +- .../bucket/histogram/date_histogram.rs | 1 + src/aggregation/bucket/term_agg.rs | 2 +- src/aggregation/bucket/term_missing_agg.rs | 5 +- src/aggregation/collector.rs | 23 +- src/aggregation/intermediate_agg_result.rs | 14 +- src/aggregation/metric/mod.rs | 23 + src/aggregation/metric/percentiles.rs | 2 - src/aggregation/metric/top_hits.rs | 837 ++++++++++++++++++ src/aggregation/segment_agg_result.rs | 6 + src/collector/mod.rs | 1 + src/collector/top_collector.rs | 47 +- src/collector/top_score_collector.rs | 89 +- src/lib.rs | 2 +- 17 files changed, 1134 insertions(+), 148 deletions(-) create mode 100644 src/aggregation/metric/top_hits.rs diff --git a/Cargo.toml b/Cargo.toml index 85febab574..3345d18856 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,6 +77,7 @@ futures = "0.3.21" paste = "1.0.11" more-asserts = "0.3.1" rand_distr = "0.4.3" +time = { version = "0.3.10", features = ["serde-well-known", "macros"] } [target.'cfg(not(windows))'.dev-dependencies] criterion = { version = "0.5", default-features = false } diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index 8e35991080..eee4090dec 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -35,7 +35,7 @@ use super::bucket::{ }; use super::metric::{ AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, - PercentilesAggregationReq, StatsAggregation, SumAggregation, + PercentilesAggregationReq, StatsAggregation, SumAggregation, TopHitsAggregation, }; /// The top-level aggregation request structure, which contains [`Aggregation`] and their user @@ -93,7 +93,12 @@ impl Aggregation { } fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) { - fast_field_names.insert(self.agg.get_fast_field_name().to_string()); + fast_field_names.extend( + self.agg + .get_fast_field_names() + .iter() + .map(|s| s.to_string()), + ); fast_field_names.extend(get_fast_field_names(&self.sub_aggregation)); } } @@ -147,23 +152,27 @@ pub enum AggregationVariants { /// Computes the sum of the extracted values. #[serde(rename = "percentiles")] Percentiles(PercentilesAggregationReq), + /// Finds the top k values matching some order + #[serde(rename = "top_hits")] + TopHits(TopHitsAggregation), } impl AggregationVariants { - /// Returns the name of the field used by the aggregation. - pub fn get_fast_field_name(&self) -> &str { + /// Returns the names of the fields used by the aggregation.
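+ /// Most variants are backed by a single fast field and yield a one-element `Vec`; `top_hits` may return several names (one per sort key).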
+ pub fn get_fast_field_names(&self) -> Vec<&str> { match self { - AggregationVariants::Terms(terms) => terms.field.as_str(), - AggregationVariants::Range(range) => range.field.as_str(), - AggregationVariants::Histogram(histogram) => histogram.field.as_str(), - AggregationVariants::DateHistogram(histogram) => histogram.field.as_str(), - AggregationVariants::Average(avg) => avg.field_name(), - AggregationVariants::Count(count) => count.field_name(), - AggregationVariants::Max(max) => max.field_name(), - AggregationVariants::Min(min) => min.field_name(), - AggregationVariants::Stats(stats) => stats.field_name(), - AggregationVariants::Sum(sum) => sum.field_name(), - AggregationVariants::Percentiles(per) => per.field_name(), + AggregationVariants::Terms(terms) => vec![terms.field.as_str()], + AggregationVariants::Range(range) => vec![range.field.as_str()], + AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()], + AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()], + AggregationVariants::Average(avg) => vec![avg.field_name()], + AggregationVariants::Count(count) => vec![count.field_name()], + AggregationVariants::Max(max) => vec![max.field_name()], + AggregationVariants::Min(min) => vec![min.field_name()], + AggregationVariants::Stats(stats) => vec![stats.field_name()], + AggregationVariants::Sum(sum) => vec![sum.field_name()], + AggregationVariants::Percentiles(per) => vec![per.field_name()], + AggregationVariants::TopHits(top_hits) => top_hits.field_names(), } } diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index e6f960d057..2e4e3b9388 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -1,6 +1,9 @@ //! This will enhance the request tree with access to the fastfield and metadata. -use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn}; +use std::collections::HashMap; +use std::io; + +use columnar::{Column, ColumnBlockAccessor, ColumnType, DynamicColumn, StrColumn}; use super::agg_limits::ResourceLimitGuard; use super::agg_req::{Aggregation, AggregationVariants, Aggregations}; @@ -14,7 +17,7 @@ use super::metric::{ use super::segment_agg_result::AggregationLimits; use super::VecWithNames; use crate::aggregation::{f64_to_fastfield_u64, Key}; -use crate::SegmentReader; +use crate::{SegmentOrdinal, SegmentReader}; #[derive(Default)] pub(crate) struct AggregationsWithAccessor { @@ -32,6 +35,7 @@ impl AggregationsWithAccessor { } pub struct AggregationWithAccessor { + pub(crate) segment_ordinal: SegmentOrdinal, /// In general there can be buckets without fast field access, e.g. buckets that are created /// based on search terms. That is not the case currently, but eventually this needs to be /// Option or moved. @@ -44,10 +48,16 @@ pub struct AggregationWithAccessor { pub(crate) limits: ResourceLimitGuard, pub(crate) column_block_accessor: ColumnBlockAccessor, /// Used for missing term aggregation, which checks all columns for existence. + /// And also for `top_hits` aggregation, which may sort on multiple fields. /// By convention the missing aggregation is chosen, when this property is set /// (instead of being set in `agg`). /// If this needs to be used by other aggregations, we need to refactor this. - pub(crate) accessors: Vec<Column<u64>>, + // NOTE: we can make all other aggregations use this instead of the `accessor` and `field_type` // (making them obsolete), but will it have a performance impact?
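+ // A (field name, column type) pair identifies a unique column, so each accessor is stored together with its `ColumnType` (see "Include all columns for a column name" in the commit message).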
+ pub(crate) accessors: Vec<(Column<u64>, ColumnType)>, + /// Map field names to all associated column accessors. + /// This field is used for `docvalue_fields`, which is currently only supported for `top_hits`. + pub(crate) value_accessors: HashMap<String, Vec<DynamicColumn>>, pub(crate) agg: Aggregation, } @@ -57,19 +67,55 @@ impl AggregationWithAccessor { agg: &Aggregation, sub_aggregation: &Aggregations, reader: &SegmentReader, + segment_ordinal: SegmentOrdinal, limits: AggregationLimits, ) -> crate::Result<Vec<AggregationWithAccessor>> { - let add_agg_with_accessor = |accessor: Column<u64>, + let mut agg = agg.clone(); + + let add_agg_with_accessor = |agg: &Aggregation, + accessor: Column<u64>, column_type: ColumnType, aggs: &mut Vec<AggregationWithAccessor>| -> crate::Result<()> { let res = AggregationWithAccessor { + segment_ordinal, accessor, - accessors: Vec::new(), + accessors: Default::default(), + value_accessors: Default::default(), field_type: column_type, sub_aggregation: get_aggs_with_segment_accessor_and_validate( sub_aggregation, reader, + segment_ordinal, &limits, )?, agg: agg.clone(), limits: limits.new_guard(), missing_value_for_accessor: None, str_dict_column: None, column_block_accessor: Default::default(), }; aggs.push(res); Ok(()) }; + + let add_agg_with_accessors = |agg: &Aggregation, + accessors: Vec<(Column<u64>, ColumnType)>, + aggs: &mut Vec<AggregationWithAccessor>, + value_accessors: HashMap<String, Vec<DynamicColumn>>| + -> crate::Result<()> { + let (accessor, field_type) = accessors.first().expect("at least one accessor"); + let res = AggregationWithAccessor { + segment_ordinal, + // TODO: We should do away with the `accessor` field altogether + accessor: accessor.clone(), + value_accessors, + field_type: *field_type, + accessors, + sub_aggregation: get_aggs_with_segment_accessor_and_validate( + sub_aggregation, + reader, + segment_ordinal, &limits, )?, agg: agg.clone(), @@ -84,32 +130,36 @@ impl AggregationWithAccessor { let mut res: Vec<AggregationWithAccessor> = Vec::new(); use AggregationVariants::*; - match &agg.agg { + + match agg.agg { Range(RangeAggregation { - field: field_name, .. + field: ref field_name, + .. }) => { let (accessor, column_type) = get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?; - add_agg_with_accessor(accessor, column_type, &mut res)?; + add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; } Histogram(HistogramAggregation { - field: field_name, .. + field: ref field_name, + .. }) => { let (accessor, column_type) = get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?; - add_agg_with_accessor(accessor, column_type, &mut res)?; + add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; } DateHistogram(DateHistogramAggregationReq { - field: field_name, .. + field: ref field_name, + .. }) => { let (accessor, column_type) = // Only DateTime is supported for DateHistogram get_ff_reader(reader, field_name, Some(&[ColumnType::DateTime]))?; - add_agg_with_accessor(accessor, column_type, &mut res)?; + add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; } Terms(TermsAggregation { - field: field_name, - missing, + field: ref field_name, + ref missing, ..
}) => { let str_dict_column = reader.fast_fields().str(field_name)?; @@ -162,24 +212,11 @@ impl AggregationWithAccessor { let column_and_types = get_all_ff_reader_or_empty(reader, field_name, None, fallback_type)?; - let accessors: Vec<Column<u64>> = - column_and_types.iter().map(|(a, _)| a.clone()).collect(); - let agg_wit_acc = AggregationWithAccessor { - missing_value_for_accessor: None, - accessor: accessors[0].clone(), - accessors, - field_type: ColumnType::U64, - sub_aggregation: get_aggs_with_segment_accessor_and_validate( - sub_aggregation, - reader, - &limits, - )?, - agg: agg.clone(), - str_dict_column: str_dict_column.clone(), - limits: limits.new_guard(), - column_block_accessor: Default::default(), - }; - res.push(agg_wit_acc); + let accessors = column_and_types + .iter() + .map(|c_t| (c_t.0.clone(), c_t.1)) + .collect(); + add_agg_with_accessors(&agg, accessors, &mut res, Default::default())?; } for (accessor, column_type) in column_and_types { @@ -189,21 +226,25 @@ impl AggregationWithAccessor { missing.clone() }; - let missing_value_for_accessor = - if let Some(missing) = missing_value_term_agg.as_ref() { - get_missing_val(column_type, missing, agg.agg.get_fast_field_name())? - } else { - None - }; + let missing_value_for_accessor = if let Some(missing) = + missing_value_term_agg.as_ref() + { + get_missing_val(column_type, missing, agg.agg.get_fast_field_names()[0])? + } else { + None + }; let agg = AggregationWithAccessor { + segment_ordinal, missing_value_for_accessor, accessor, - accessors: Vec::new(), + accessors: Default::default(), + value_accessors: Default::default(), field_type: column_type, sub_aggregation: get_aggs_with_segment_accessor_and_validate( sub_aggregation, reader, + segment_ordinal, &limits, )?, agg: agg.clone(), @@ -215,34 +256,63 @@ impl AggregationWithAccessor { } } Average(AverageAggregation { - field: field_name, .. + field: ref field_name, + .. }) | Count(CountAggregation { - field: field_name, .. + field: ref field_name, + .. }) | Max(MaxAggregation { - field: field_name, .. + field: ref field_name, + .. }) | Min(MinAggregation { - field: field_name, .. + field: ref field_name, + .. }) | Stats(StatsAggregation { - field: field_name, .. + field: ref field_name, + .. }) | Sum(SumAggregation { - field: field_name, .. + field: ref field_name, + ..
}) => { let (accessor, column_type) = get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?; - add_agg_with_accessor(accessor, column_type, &mut res)?; + add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; } - Percentiles(percentiles) => { + Percentiles(ref percentiles) => { let (accessor, column_type) = get_ff_reader( reader, percentiles.field_name(), Some(get_numeric_or_date_column_types()), )?; - add_agg_with_accessor(accessor, column_type, &mut res)?; + add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; + } + TopHits(ref mut top_hits) => { + top_hits.validate_and_resolve(reader.fast_fields().columnar())?; + let accessors: Vec<(Column<u64>, ColumnType)> = top_hits + .field_names() + .iter() + .map(|field| { + get_ff_reader(reader, field, Some(get_numeric_or_date_column_types())) + }) + .collect::<crate::Result<Vec<_>>>()?; + + let value_accessors = top_hits + .value_field_names() + .iter() + .map(|field_name| { + Ok(( + field_name.to_string(), + get_dynamic_columns(reader, field_name)?, + )) + }) + .collect::<crate::Result<HashMap<_, _>>>()?; + + add_agg_with_accessors(&agg, accessors, &mut res, value_accessors)?; } }; @@ -284,6 +354,7 @@ fn get_numeric_or_date_column_types() -> &'static [ColumnType] { pub(crate) fn get_aggs_with_segment_accessor_and_validate( aggs: &Aggregations, reader: &SegmentReader, + segment_ordinal: SegmentOrdinal, limits: &AggregationLimits, ) -> crate::Result<AggregationsWithAccessor> { let mut aggss = Vec::new(); @@ -292,6 +363,7 @@ pub(crate) fn get_aggs_with_segment_accessor_and_validate( agg, agg.sub_aggregation(), reader, + segment_ordinal, limits.clone(), )?; for agg in aggs { @@ -321,6 +393,19 @@ fn get_ff_reader( Ok(ff_field_with_type) } +fn get_dynamic_columns( + reader: &SegmentReader, + field_name: &str, +) -> crate::Result<Vec<DynamicColumn>> { + let ff_fields = reader.fast_fields().dynamic_column_handles(field_name)?; + let cols = ff_fields + .iter() + .map(|h| h.open()) + .collect::<io::Result<Vec<_>>>()?; + assert!(!ff_fields.is_empty(), "field {} not found", field_name); + Ok(cols) +} + /// Get all fast field reader or empty as default. /// /// Is guaranteed to return at least one column. diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index ff9e7716fb..b7ef247068 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -8,7 +8,7 @@ use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use super::bucket::GetDocCount; -use super::metric::{PercentilesMetricResult, SingleMetricResult, Stats}; +use super::metric::{PercentilesMetricResult, SingleMetricResult, Stats, TopHitsMetricResult}; use super::{AggregationError, Key}; use crate::TantivyError; @@ -90,8 +90,10 @@ pub enum MetricResult { Stats(Stats), /// Sum metric result. Sum(SingleMetricResult), - /// Sum metric result. + /// Percentiles metric result.
Percentiles(PercentilesMetricResult), + /// Top hits metric result. + TopHits(TopHitsMetricResult), } impl MetricResult { @@ -106,6 +108,9 @@ impl MetricResult { MetricResult::Percentiles(_) => Err(TantivyError::AggregationError( AggregationError::InvalidRequest("percentiles can't be used to order".to_string()), )), + MetricResult::TopHits(_) => Err(TantivyError::AggregationError( + AggregationError::InvalidRequest("top_hits can't be used to order".to_string()), + )), } } } diff --git a/src/aggregation/bucket/histogram/date_histogram.rs b/src/aggregation/bucket/histogram/date_histogram.rs index d0502af73f..b4919a4e56 100644 --- a/src/aggregation/bucket/histogram/date_histogram.rs +++ b/src/aggregation/bucket/histogram/date_histogram.rs @@ -307,6 +307,7 @@ pub mod tests { ) -> crate::Result<Index> { let mut schema_builder = Schema::builder(); schema_builder.add_date_field("date", FAST); + schema_builder.add_json_field("mixed", FAST); schema_builder.add_text_field("text", FAST | STRING); schema_builder.add_text_field("text2", FAST | STRING); let schema = schema_builder.build(); diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 1b29e361ef..b0e40f88db 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -110,7 +110,7 @@ pub struct TermsAggregation { #[serde(alias = "shard_size")] pub split_size: Option<u32>, - /// The get more accurate results, we fetch more than `size` from each segment. + /// To get more accurate results, we fetch more than `size` from each segment. /// /// Increasing this value will increase the cost for more accuracy. /// diff --git a/src/aggregation/bucket/term_missing_agg.rs b/src/aggregation/bucket/term_missing_agg.rs index 1d43a2e65d..a863d5eb2c 100644 --- a/src/aggregation/bucket/term_missing_agg.rs +++ b/src/aggregation/bucket/term_missing_agg.rs @@ -90,7 +90,10 @@ impl SegmentAggregationCollector for TermMissingAgg { agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { let agg = &mut agg_with_accessor.aggs.values[self.accessor_idx]; - let has_value = agg.accessors.iter().any(|acc| acc.index.has_value(doc)); + let has_value = agg + .accessors + .iter() + .any(|(acc, _)| acc.index.has_value(doc)); if !has_value { self.missing_count += 1; if let Some(sub_agg) = self.sub_agg.as_mut() { diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index b3a0ed9176..d0e9ec5b83 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -8,7 +8,7 @@ use super::segment_agg_result::{ }; use crate::aggregation::agg_req_with_accessor::get_aggs_with_segment_accessor_and_validate; use crate::collector::{Collector, SegmentCollector}; -use crate::{DocId, SegmentReader, TantivyError}; +use crate::{DocId, SegmentOrdinal, SegmentReader, TantivyError}; /// The default max bucket count, before the aggregation fails.
pub const DEFAULT_BUCKET_LIMIT: u32 = 65000; @@ -64,10 +64,15 @@ impl Collector for DistributedAggregationCollector { fn for_segment( &self, - _segment_local_id: crate::SegmentOrdinal, + segment_local_id: crate::SegmentOrdinal, reader: &crate::SegmentReader, ) -> crate::Result<Self::Child> { - AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader, &self.limits) + AggregationSegmentCollector::from_agg_req_and_reader( + &self.agg, + reader, + segment_local_id, + &self.limits, + ) } fn requires_scoring(&self) -> bool { @@ -89,10 +94,15 @@ impl Collector for AggregationCollector { fn for_segment( &self, - _segment_local_id: crate::SegmentOrdinal, + segment_local_id: crate::SegmentOrdinal, reader: &crate::SegmentReader, ) -> crate::Result<Self::Child> { - AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader, &self.limits) + AggregationSegmentCollector::from_agg_req_and_reader( + &self.agg, + reader, + segment_local_id, + &self.limits, + ) } fn requires_scoring(&self) -> bool { @@ -135,10 +145,11 @@ impl AggregationSegmentCollector { pub fn from_agg_req_and_reader( agg: &Aggregations, reader: &SegmentReader, + segment_ordinal: SegmentOrdinal, limits: &AggregationLimits, ) -> crate::Result<Self> { let mut aggs_with_accessor = - get_aggs_with_segment_accessor_and_validate(agg, reader, limits)?; + get_aggs_with_segment_accessor_and_validate(agg, reader, segment_ordinal, limits)?; let result = BufAggregationCollector::new(build_segment_agg_collector(&mut aggs_with_accessor)?); Ok(AggregationSegmentCollector { diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 4bb056d5c2..9e07d68aa9 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -19,7 +19,7 @@ use super::bucket::{ }; use super::metric::{ IntermediateAverage, IntermediateCount, IntermediateMax, IntermediateMin, IntermediateStats, - IntermediateSum, PercentilesCollector, + IntermediateSum, PercentilesCollector, TopHitsCollector, }; use super::segment_agg_result::AggregationLimits; use super::{format_date, AggregationError, Key, SerializedKey}; @@ -205,6 +205,9 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult Percentiles(_) => IntermediateAggregationResult::Metric( IntermediateMetricResult::Percentiles(PercentilesCollector::default()), ), + TopHits(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::TopHits( + TopHitsCollector::default(), + )), } } @@ -265,6 +268,8 @@ pub enum IntermediateMetricResult { Stats(IntermediateStats), /// Intermediate sum result.
Sum(IntermediateSum), + /// Intermediate top_hits result. + TopHits(TopHitsCollector), } impl IntermediateMetricResult { @@ -292,9 +297,13 @@ impl IntermediateMetricResult { percentiles .into_final_result(req.agg.as_percentile().expect("unexpected metric type")), ), + IntermediateMetricResult::TopHits(top_hits) => { + MetricResult::TopHits(top_hits.finalize()) + } } } + // TODO: this is our top-of-the-chain fruit merge mechanism fn merge_fruits(&mut self, other: IntermediateMetricResult) -> crate::Result<()> { match (self, other) { ( @@ -330,6 +339,9 @@ impl IntermediateMetricResult { ) => { left.merge_fruits(right)?; } + (IntermediateMetricResult::TopHits(left), IntermediateMetricResult::TopHits(right)) => { + left.merge_fruits(right)?; + } _ => { panic!("incompatible fruit types in tree or missing merge_fruits handler"); } diff --git a/src/aggregation/metric/mod.rs b/src/aggregation/metric/mod.rs index 998793d2d6..a1b9cc4118 100644 --- a/src/aggregation/metric/mod.rs +++ b/src/aggregation/metric/mod.rs @@ -23,6 +23,7 @@ mod min; mod percentiles; mod stats; mod sum; +mod top_hits; pub use average::*; pub use count::*; pub use max::*; @@ -32,6 +33,7 @@ use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; pub use stats::*; pub use sum::*; +pub use top_hits::*; /// Single-metric aggregations use this common result structure. /// @@ -81,6 +83,27 @@ pub struct PercentilesMetricResult { pub values: PercentileValues, } +/// The top_hits metric results entry +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct TopHitsVecEntry { + /// The sort values of the document, depending on the sort criteria in the request. + pub sort: Vec<Option<u64>>, + + /// Search results, for queries that include field retrieval requests + /// (`docvalue_fields`). + #[serde(flatten)] + pub search_results: FieldRetrivalResult, +} + +/// The top_hits metric aggregation returns a list of top hits, ordered by the sort criteria. +/// +/// The main reason for wrapping it in `hits` is to match the Elasticsearch output structure. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct TopHitsMetricResult { + /// The result of the top_hits metric.
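+ /// Each entry carries the document's sort key values and, when `docvalue_fields` was requested, the retrieved field values.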
+ pub hits: Vec<TopHitsVecEntry>, +} + #[cfg(test)] mod tests { use crate::aggregation::agg_req::Aggregations; diff --git a/src/aggregation/metric/percentiles.rs b/src/aggregation/metric/percentiles.rs index 7b66b12732..886aa488c3 100644 --- a/src/aggregation/metric/percentiles.rs +++ b/src/aggregation/metric/percentiles.rs @@ -133,7 +133,6 @@ pub(crate) struct SegmentPercentilesCollector { field_type: ColumnType, pub(crate) percentiles: PercentilesCollector, pub(crate) accessor_idx: usize, - val_cache: Vec<u64>, missing: Option<u64>, } @@ -243,7 +242,6 @@ impl SegmentPercentilesCollector { field_type, percentiles: PercentilesCollector::new(), accessor_idx, - val_cache: Default::default(), missing, }) } diff --git a/src/aggregation/metric/top_hits.rs b/src/aggregation/metric/top_hits.rs new file mode 100644 index 0000000000..fe3e7ba7f7 --- /dev/null +++ b/src/aggregation/metric/top_hits.rs @@ -0,0 +1,837 @@ +use std::collections::HashMap; +use std::fmt::Formatter; + +use columnar::{ColumnarReader, DynamicColumn}; +use regex::Regex; +use serde::ser::SerializeMap; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use super::{TopHitsMetricResult, TopHitsVecEntry}; +use crate::aggregation::bucket::Order; +use crate::aggregation::intermediate_agg_result::{ + IntermediateAggregationResult, IntermediateMetricResult, +}; +use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::collector::TopNComputer; +use crate::schema::term::JSON_PATH_SEGMENT_SEP_STR; +use crate::schema::OwnedValue; +use crate::{DocAddress, DocId, SegmentOrdinal}; + +/// # Top Hits +/// +/// The top hits aggregation is a useful tool to answer questions like: +/// - "What are the most recent posts by each author?" +/// - "What are the most popular items in each category?" +/// +/// It does so by keeping track of the most relevant documents being aggregated, +/// in terms of a sort criterion that can consist of multiple fields and their +/// sort-orders (ascending or descending). +/// +/// `top_hits` should not be used as a top-level aggregation. It is intended to be +/// used as a sub-aggregation, inside a `terms` aggregation or a `filters` aggregation, +/// for example. +/// +/// Note that this aggregator does not return the actual document addresses, but +/// rather a list of the values of the fields that were requested to be retrieved. +/// These values can be specified in the `docvalue_fields` parameter, which can include +/// a list of fast fields to be retrieved. At the moment, only fast fields are supported, +/// but it is possible that we support the `fields` parameter to retrieve any stored +/// field in the future. +/// +/// The following example demonstrates a request for the top_hits aggregation: +/// ```JSON +/// { +/// "aggs": { +/// "top_authors": { +/// "terms": { +/// "field": "author", +/// "size": 5 +/// }, +/// "aggs": { +/// "top_hits": { +/// "size": 2, +/// "from": 0, +/// "sort": [ +/// { "date": "desc" } +/// ], +/// "docvalue_fields": ["date", "title", "iden"] +/// } +/// } +/// } +/// } +/// } +/// ``` +/// +/// This request will return an object containing the top two documents, sorted +/// by the `date` field in descending order. You can also sort by multiple fields, which +/// helps to resolve ties.
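+/// For example, a two-key sort that uses a second field to break ties (a sketch; the secondary key is illustrative): +/// ```JSON +/// "sort": [ +/// { "date": "desc" }, +/// { "title": "asc" } +/// ] +/// ```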
The aggregation object for each bucket will look like: +/// ```JSON +/// { +/// "hits": [ +/// { +/// "sort": [<time_u64>], +/// "docvalue_fields": { +/// "date": "<date_RFC3339>", +/// "title": "<title>", +/// "iden": "<iden>" +/// } +/// }, +/// { +/// "sort": [<time_u64>], +/// "docvalue_fields": { +/// "date": "<date_RFC3339>", +/// "title": "<title>", +/// "iden": "<iden>" +/// } +/// } +/// ] +/// } +/// ``` +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +pub struct TopHitsAggregation { + sort: Vec<KeyOrder>, + size: usize, + from: Option<usize>, + + #[serde(flatten)] + retrieval: RetrievalFields, +} + +const fn default_doc_value_fields() -> Vec<String> { + Vec::new() +} + +/// Search query spec for each matched document +/// TODO: move this to a common module +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +pub struct RetrievalFields { + /// The fast fields to return for each hit. + /// This is the only variant supported for now. + /// TODO: support the {field, format} variant for custom formatting. + #[serde(rename = "docvalue_fields")] + #[serde(default = "default_doc_value_fields")] + pub doc_value_fields: Vec<String>, +} + +/// Search query result for each matched document +/// TODO: move this to a common module +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +pub struct FieldRetrivalResult { + /// The fast fields returned for each hit. + #[serde(rename = "docvalue_fields")] + #[serde(skip_serializing_if = "HashMap::is_empty")] + pub doc_value_fields: HashMap<String, OwnedValue>, +} + +impl RetrievalFields { + fn get_field_names(&self) -> Vec<&str> { + self.doc_value_fields.iter().map(|s| s.as_str()).collect() + } + + fn resolve_field_names(&mut self, reader: &ColumnarReader) -> crate::Result<()> { + // Transform a glob (`pattern*`, for example) into a regex::Regex (`^pattern.*$`) + let globbed_string_to_regex = |glob: &str| { + // Replace the `*` glob with the `.*` regex + let sanitized = format!("^{}$", regex::escape(glob).replace(r"\*", ".*")); + Regex::new(&sanitized).map_err(|e| { + crate::TantivyError::SchemaError(format!( + "Invalid regex '{}' in docvalue_fields: {}", + glob, e + )) + }) + }; + self.doc_value_fields = self + .doc_value_fields + .iter() + .map(|field| { + if !field.contains('*') + && reader + .iter_columns()? + .any(|(name, _)| name.as_str() == field) + { + return Ok(vec![field.to_owned()]); + } + + let pattern = globbed_string_to_regex(&field)?; + let fields = reader + .iter_columns()? + .map(|(name, _)| { + // normalize the path from the internal fast field representation + name.replace(JSON_PATH_SEGMENT_SEP_STR, ".") + }) + .filter(|name| pattern.is_match(name)) + .collect::<Vec<_>>(); + assert!( + !fields.is_empty(), + "No fields matched the glob '{}' in docvalue_fields", + field + ); + Ok(fields) + }) + .collect::<crate::Result<Vec<_>>>()?
+ .into_iter() + .flatten() + .collect(); + + Ok(()) + } + + fn get_document_field_data( + &self, + accessors: &HashMap<String, Vec<DynamicColumn>>, + doc_id: DocId, + ) -> FieldRetrivalResult { + let dvf = self + .doc_value_fields + .iter() + .map(|field| { + let accessors = accessors + .get(field) + .unwrap_or_else(|| panic!("field '{}' not found in accessors", field)); + + let values: Vec<OwnedValue> = accessors + .iter() + .flat_map(|accessor| match accessor { + DynamicColumn::U64(accessor) => accessor + .values_for_doc(doc_id) + .map(OwnedValue::U64) + .collect::<Vec<_>>(), + DynamicColumn::I64(accessor) => accessor + .values_for_doc(doc_id) + .map(OwnedValue::I64) + .collect::<Vec<_>>(), + DynamicColumn::F64(accessor) => accessor + .values_for_doc(doc_id) + .map(OwnedValue::F64) + .collect::<Vec<_>>(), + DynamicColumn::Bytes(accessor) => accessor + .term_ords(doc_id) + .map(|term_ord| { + let mut buffer = vec![]; + assert!( + accessor + .ord_to_bytes(term_ord, &mut buffer) + .expect("could not read term dictionary"), + "term corresponding to term_ord does not exist" + ); + OwnedValue::Bytes(buffer) + }) + .collect::<Vec<_>>(), + DynamicColumn::Str(accessor) => accessor + .term_ords(doc_id) + .map(|term_ord| { + let mut buffer = vec![]; + assert!( + accessor + .ord_to_bytes(term_ord, &mut buffer) + .expect("could not read term dictionary"), + "term corresponding to term_ord does not exist" + ); + OwnedValue::Str(String::from_utf8(buffer).unwrap()) + }) + .collect::<Vec<_>>(), + DynamicColumn::Bool(accessor) => accessor + .values_for_doc(doc_id) + .map(OwnedValue::Bool) + .collect::<Vec<_>>(), + DynamicColumn::IpAddr(accessor) => accessor + .values_for_doc(doc_id) + .map(OwnedValue::IpAddr) + .collect::<Vec<_>>(), + DynamicColumn::DateTime(accessor) => accessor + .values_for_doc(doc_id) + .map(OwnedValue::Date) + .collect::<Vec<_>>(), + }) + .collect(); + + (field.to_owned(), OwnedValue::Array(values)) + }) + .collect(); + FieldRetrivalResult { + doc_value_fields: dvf, + } + } +} + +#[derive(Debug, Clone, PartialEq, Default)] +struct KeyOrder { + field: String, + order: Order, +} + +impl Serialize for KeyOrder { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + let KeyOrder { field, order } = self; + let mut map = serializer.serialize_map(Some(1))?; + map.serialize_entry(field, order)?; + map.end() + } +} + +impl<'de> Deserialize<'de> for KeyOrder { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where D: Deserializer<'de> { + let mut k_o = <HashMap<String, Order>>::deserialize(deserializer)?.into_iter(); + let (k, v) = k_o.next().ok_or(serde::de::Error::custom( + "Expected exactly one key-value pair in KeyOrder, found none", + ))?; + if k_o.next().is_some() { + return Err(serde::de::Error::custom( + "Expected exactly one key-value pair in KeyOrder, found more", + )); + } + Ok(Self { field: k, order: v }) + } +} + +impl TopHitsAggregation { + /// Validate and resolve field retrieval parameters + pub fn validate_and_resolve(&mut self, reader: &ColumnarReader) -> crate::Result<()> { + self.retrieval.resolve_field_names(reader) + } + + /// Return fields accessed by the aggregator, in order. + pub fn field_names(&self) -> Vec<&str> { + self.sort + .iter() + .map(|KeyOrder { field, .. }| field.as_str()) + .collect() + } + + /// Return fields accessed by the aggregator's value retrieval. 
+ pub fn value_field_names(&self) -> Vec<&str> { + self.retrieval.get_field_names() + } +} + +/// Holds a single comparable doc feature, and the order in which it should be sorted. +#[derive(Clone, Serialize, Deserialize, Debug)] +struct ComparableDocFeature { + /// Stores any u64-mappable feature. + value: Option<u64>, + /// Sort order for the doc feature + order: Order, +} + +impl Ord for ComparableDocFeature { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + let invert = |cmp: std::cmp::Ordering| match self.order { + Order::Asc => cmp, + Order::Desc => cmp.reverse(), + }; + + match (self.value, other.value) { + (Some(self_value), Some(other_value)) => invert(self_value.cmp(&other_value)), + (Some(_), None) => std::cmp::Ordering::Greater, + (None, Some(_)) => std::cmp::Ordering::Less, + (None, None) => std::cmp::Ordering::Equal, + } + } +} + +impl PartialOrd for ComparableDocFeature { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + Some(self.cmp(other)) + } +} + +impl PartialEq for ComparableDocFeature { + fn eq(&self, other: &Self) -> bool { + self.value.cmp(&other.value) == std::cmp::Ordering::Equal + } +} + +impl Eq for ComparableDocFeature {} + +#[derive(Clone, Serialize, Deserialize, Debug)] +struct ComparableDocFeatures(Vec<ComparableDocFeature>, FieldRetrivalResult); + +impl Ord for ComparableDocFeatures { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + for (self_feature, other_feature) in self.0.iter().zip(other.0.iter()) { + let cmp = self_feature.cmp(other_feature); + if cmp != std::cmp::Ordering::Equal { + return cmp; + } + } + std::cmp::Ordering::Equal + } +} + +impl PartialOrd for ComparableDocFeatures { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + Some(self.cmp(other)) + } +} + +impl PartialEq for ComparableDocFeatures { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == std::cmp::Ordering::Equal + } +} + +impl Eq for ComparableDocFeatures {} + +/// The TopHitsCollector used for collecting over segments and merging results. 
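+/// Each segment collects into its own `SegmentTopHitsCollector` (below); the per-segment results are then combined through `merge_fruits`.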
+#[derive(Clone, Serialize, Deserialize)] +pub struct TopHitsCollector { + req: TopHitsAggregation, + top_n: TopNComputer<ComparableDocFeatures, DocAddress, false>, +} + +impl Default for TopHitsCollector { + fn default() -> Self { + Self { + req: TopHitsAggregation::default(), + top_n: TopNComputer::new(1), + } + } +} + +impl std::fmt::Debug for TopHitsCollector { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TopHitsCollector") + .field("req", &self.req) + .field("top_n_threshold", &self.top_n.threshold) + .finish() + } +} + +impl std::cmp::PartialEq for TopHitsCollector { + fn eq(&self, _other: &Self) -> bool { + false + } +} + +impl TopHitsCollector { + fn collect(&mut self, features: ComparableDocFeatures, doc: DocAddress) { + self.top_n.push(features, doc); + } + + pub(crate) fn merge_fruits(&mut self, other_fruit: Self) -> crate::Result<()> { + for doc in other_fruit.top_n.into_vec() { + self.collect(doc.feature, doc.doc); + } + Ok(()) + } + + /// Finalize by converting self into the final result form + pub fn finalize(self) -> TopHitsMetricResult { + let mut hits: Vec<TopHitsVecEntry> = self + .top_n + .into_sorted_vec() + .into_iter() + .map(|doc| TopHitsVecEntry { + sort: doc.feature.0.iter().map(|f| f.value).collect(), + search_results: doc.feature.1, + }) + .collect(); + + // Remove the first `from` elements. + // Truncating from the end would be more efficient, but we need to truncate from the front + // because `into_sorted_vec` returns the hits in descending order, due to the inverted + // `Ord` semantics of the heap elements. + hits.drain(..self.req.from.unwrap_or(0)); + TopHitsMetricResult { hits } + } +} + +#[derive(Clone)] +pub(crate) struct SegmentTopHitsCollector { + segment_ordinal: SegmentOrdinal, + accessor_idx: usize, + inner_collector: TopHitsCollector, +} + +impl SegmentTopHitsCollector { + pub fn from_req( + req: &TopHitsAggregation, + accessor_idx: usize, + segment_ordinal: SegmentOrdinal, + ) -> Self { + Self { + inner_collector: TopHitsCollector { + req: req.clone(), + top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), + }, + segment_ordinal, + accessor_idx, + } + } +} + +impl std::fmt::Debug for SegmentTopHitsCollector { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SegmentTopHitsCollector") + .field("segment_id", &self.segment_ordinal) + .field("accessor_idx", &self.accessor_idx) + .field("inner_collector", &self.inner_collector) + .finish() + } +} + +impl SegmentAggregationCollector for SegmentTopHitsCollector { + fn add_intermediate_aggregation_result( + self: Box<Self>, + agg_with_accessor: &crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor, + results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults, + ) -> crate::Result<()> { + let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); + let intermediate_result = IntermediateMetricResult::TopHits(self.inner_collector); + results.push( + name, + IntermediateAggregationResult::Metric(intermediate_result), + ) + } + + fn collect( + &mut self, + doc_id: crate::DocId, + agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor, + ) -> crate::Result<()> { + let accessors = &agg_with_accessor.aggs.values[self.accessor_idx].accessors; + let value_accessors = &agg_with_accessor.aggs.values[self.accessor_idx].value_accessors; + let features: Vec<ComparableDocFeature> = self + .inner_collector + .req + .sort + .iter() + .enumerate() + .map(|(idx, KeyOrder {
order, .. })| { + let order = *order; + let value = accessors + .get(idx) + .expect("could not find field in accessors") + .0 + .values_for_doc(doc_id) + .next(); + ComparableDocFeature { value, order } + }) + .collect(); + + let retrieval_result = self + .inner_collector + .req + .retrieval + .get_document_field_data(value_accessors, doc_id); + + self.inner_collector.collect( + ComparableDocFeatures(features, retrieval_result), + DocAddress { + segment_ord: self.segment_ordinal, + doc_id, + }, + ); + Ok(()) + } + + fn collect_block( + &mut self, + docs: &[crate::DocId], + agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor, + ) -> crate::Result<()> { + // TODO: Consider getting fields with the column block accessor and refactor this. + // --- + // Would the additional complexity of getting fields with the column_block_accessor + // make sense here? Probably yes, but I want to get a first-pass review first + // before proceeding. + for doc in docs { + self.collect(*doc, agg_with_accessor)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use common::DateTime; + use pretty_assertions::assert_eq; + use serde_json::{json, Value}; + use time::macros::datetime; + + use super::{ComparableDocFeature, ComparableDocFeatures, Order}; + use crate::aggregation::agg_req::Aggregations; + use crate::aggregation::agg_result::AggregationResults; + use crate::aggregation::bucket::tests::get_test_index_from_docs; + use crate::aggregation::tests::get_test_index_from_values; + use crate::aggregation::AggregationCollector; + use crate::collector::ComparableDoc; + use crate::query::AllQuery; + use crate::schema::OwnedValue as SchemaValue; + + fn invert_order(cmp_feature: ComparableDocFeature) -> ComparableDocFeature { + let ComparableDocFeature { value, order } = cmp_feature; + let order = match order { + Order::Asc => Order::Desc, + Order::Desc => Order::Asc, + }; + ComparableDocFeature { value, order } + } + + fn collector_with_capacity(capacity: usize) -> super::TopHitsCollector { + super::TopHitsCollector { + top_n: super::TopNComputer::new(capacity), + ..Default::default() + } + } + + fn invert_order_features(cmp_features: ComparableDocFeatures) -> ComparableDocFeatures { + let ComparableDocFeatures(cmp_features, search_results) = cmp_features; + let cmp_features = cmp_features + .into_iter() + .map(invert_order) + .collect::<Vec<_>>(); + ComparableDocFeatures(cmp_features, search_results) + } + + #[test] + fn test_comparable_doc_feature() -> crate::Result<()> { + let small = ComparableDocFeature { + value: Some(1), + order: Order::Asc, + }; + let big = ComparableDocFeature { + value: Some(2), + order: Order::Asc, + }; + let none = ComparableDocFeature { + value: None, + order: Order::Asc, + }; + + assert!(small < big); + assert!(none < small); + assert!(none < big); + + let small = invert_order(small); + let big = invert_order(big); + let none = invert_order(none); + + assert!(small > big); + assert!(none < small); + assert!(none < big); + + Ok(()) + } + + #[test] + fn test_comparable_doc_features() -> crate::Result<()> { + let features_1 = ComparableDocFeatures( + vec![ComparableDocFeature { + value: Some(1), + order: Order::Asc, + }], + Default::default(), + ); + + let features_2 = ComparableDocFeatures( + vec![ComparableDocFeature { + value: Some(2), + order: Order::Asc, + }], + Default::default(), + ); + + assert!(features_1 < features_2); + + assert!(invert_order_features(features_1.clone()) > invert_order_features(features_2)); + + Ok(()) + } + + #[test] + fn
test_aggregation_top_hits_empty_index() -> crate::Result<()> { + let values = vec![]; + + let index = get_test_index_from_values(false, &values)?; + + let d: Aggregations = serde_json::from_value(json!({ + "top_hits_req": { + "top_hits": { + "size": 2, + "sort": [ + { "date": "desc" } + ], + "from": 0, + } + } + })) + .unwrap(); + + let collector = AggregationCollector::from_aggs(d, Default::default()); + + let reader = index.reader()?; + let searcher = reader.searcher(); + let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); + + let res: Value = serde_json::from_str( + &serde_json::to_string(&agg_res).expect("JSON serialization failed"), + ) + .expect("JSON parsing failed"); + + assert_eq!( + res, + json!({ + "top_hits_req": { + "hits": [] + } + }) + ); + + Ok(()) + } + + #[test] + fn test_top_hits_collector_single_feature() -> crate::Result<()> { + let docs = vec![ + ComparableDoc::<_, _, false> { + doc: crate::DocAddress { + segment_ord: 0, + doc_id: 0, + }, + feature: ComparableDocFeatures( + vec![ComparableDocFeature { + value: Some(1), + order: Order::Asc, + }], + Default::default(), + ), + }, + ComparableDoc { + doc: crate::DocAddress { + segment_ord: 0, + doc_id: 2, + }, + feature: ComparableDocFeatures( + vec![ComparableDocFeature { + value: Some(3), + order: Order::Asc, + }], + Default::default(), + ), + }, + ComparableDoc { + doc: crate::DocAddress { + segment_ord: 0, + doc_id: 1, + }, + feature: ComparableDocFeatures( + vec![ComparableDocFeature { + value: Some(5), + order: Order::Asc, + }], + Default::default(), + ), + }, + ]; + + let mut collector = collector_with_capacity(3); + for doc in docs.clone() { + collector.collect(doc.feature, doc.doc); + } + + let res = collector.finalize(); + + assert_eq!( + res, + super::TopHitsMetricResult { + hits: vec![ + super::TopHitsVecEntry { + sort: vec![docs[0].feature.0[0].value], + search_results: Default::default(), + }, + super::TopHitsVecEntry { + sort: vec![docs[1].feature.0[0].value], + search_results: Default::default(), + }, + super::TopHitsVecEntry { + sort: vec![docs[2].feature.0[0].value], + search_results: Default::default(), + }, + ] + } + ); + + Ok(()) + } + + fn test_aggregation_top_hits(merge_segments: bool) -> crate::Result<()> { + let docs = vec![ + vec![ + r#"{ "date": "2015-01-02T00:00:00Z", "text": "bbb", "text2": "bbb", "mixed": { "dyn_arr": [1, "2"] } }"#, + r#"{ "date": "2017-06-15T00:00:00Z", "text": "ccc", "text2": "ddd", "mixed": { "dyn_arr": [3, "4"] } }"#, + ], + vec![ + r#"{ "text": "aaa", "text2": "bbb", "date": "2018-01-02T00:00:00Z", "mixed": { "dyn_arr": ["9", 8] } }"#, + r#"{ "text": "aaa", "text2": "bbb", "date": "2016-01-02T00:00:00Z", "mixed": { "dyn_arr": ["7", 6] } }"#, + ], + ]; + + let index = get_test_index_from_docs(merge_segments, &docs)?; + + let d: Aggregations = serde_json::from_value(json!({ + "top_hits_req": { + "top_hits": { + "size": 2, + "sort": [ + { "date": "desc" } + ], + "from": 1, + "docvalue_fields": [ + "date", + "tex*", + "mixed.*", + ], + } + } + }))?; + + let collector = AggregationCollector::from_aggs(d, Default::default()); + let reader = index.reader()?; + let searcher = reader.searcher(); + + let agg_res = + serde_json::to_value(searcher.search(&AllQuery, &collector).unwrap()).unwrap(); + + let date_2017 = datetime!(2017-06-15 00:00:00 UTC); + let date_2016 = datetime!(2016-01-02 00:00:00 UTC); + + assert_eq!( + agg_res["top_hits_req"], + json!({ + "hits": [ + { + "sort": [common::i64_to_u64(date_2017.unix_timestamp_nanos() as i64)], + 
"docvalue_fields": { + "date": [ SchemaValue::Date(DateTime::from_utc(date_2017)) ], + "text": [ "ccc" ], + "text2": [ "ddd" ], + "mixed.dyn_arr": [ 3, "4" ], + } + }, + { + "sort": [common::i64_to_u64(date_2016.unix_timestamp_nanos() as i64)], + "docvalue_fields": { + "date": [ SchemaValue::Date(DateTime::from_utc(date_2016)) ], + "text": [ "aaa" ], + "text2": [ "bbb" ], + "mixed.dyn_arr": [ 6, "7" ], + } + } + ] + }), + ); + + Ok(()) + } + + #[test] + fn test_aggregation_top_hits_single_segment() -> crate::Result<()> { + test_aggregation_top_hits(true) + } + + #[test] + fn test_aggregation_top_hits_multi_segment() -> crate::Result<()> { + test_aggregation_top_hits(false) + } +} diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index e575796477..570dc3f034 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -16,6 +16,7 @@ use super::metric::{ SumAggregation, }; use crate::aggregation::bucket::TermMissingAgg; +use crate::aggregation::metric::SegmentTopHitsCollector; pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug { fn add_intermediate_aggregation_result( @@ -160,6 +161,11 @@ pub(crate) fn build_single_agg_segment_collector( accessor_idx, )?, )), + TopHits(top_hits_req) => Ok(Box::new(SegmentTopHitsCollector::from_req( + top_hits_req, + accessor_idx, + req.segment_ordinal, + ))), } } diff --git a/src/collector/mod.rs b/src/collector/mod.rs index 4d9b43d653..de6c69f280 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -97,6 +97,7 @@ pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit}; mod top_collector; mod top_score_collector; +pub use self::top_collector::ComparableDoc; pub use self::top_score_collector::{TopDocs, TopNComputer}; mod custom_score_top_collector; diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index ddb78c7b12..5a07e4218b 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -1,47 +1,58 @@ use std::cmp::Ordering; use std::marker::PhantomData; +use serde::{Deserialize, Serialize}; + use super::top_score_collector::TopNComputer; use crate::{DocAddress, DocId, SegmentOrdinal, SegmentReader}; /// Contains a feature (field, score, etc.) of a document along with the document address. /// -/// It has a custom implementation of `PartialOrd` that reverses the order. This is because the -/// default Rust heap is a max heap, whereas a min heap is needed. -/// -/// Additionally, it guarantees stable sorting: in case of a tie on the feature, the document +/// It guarantees stable sorting: in case of a tie on the feature, the document /// address is used. /// +/// The REVERSE_ORDER generic parameter controls whether the by-feature order +/// should be reversed, which is useful for achieving for example largest-first +/// semantics without having to wrap the feature in a `Reverse`. +/// /// WARNING: equality is not what you would expect here. /// Two elements are equal if their feature is equal, and regardless of whether `doc` /// is equal. This should be perfectly fine for this usage, but let's make sure this /// struct is never public. -pub(crate) struct ComparableDoc<T, D> { +#[derive(Clone, Default, Serialize, Deserialize)] +pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> { + /// The feature of the document. In practice, this is + /// is any type that implements `PartialOrd`. pub feature: T, + /// The document address. 
In practice, this is any + /// type that implements `PartialOrd`, and is guaranteed + /// to be unique for each document. pub doc: D, } -impl<T: std::fmt::Debug, D: std::fmt::Debug> std::fmt::Debug for ComparableDoc<T, D> { +impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug + for ComparableDoc<T, D, R> +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ComparableDoc") + f.debug_struct(format!("ComparableDoc<_, _, {R}>").as_str()) .field("feature", &self.feature) .field("doc", &self.doc) .finish() } } -impl<T: PartialOrd, D: PartialOrd> PartialOrd for ComparableDoc<T, D> { +impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialOrd for ComparableDoc<T, D, R> { fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) } } -impl<T: PartialOrd, D: PartialOrd> Ord for ComparableDoc<T, D> { +impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R> { #[inline] fn cmp(&self, other: &Self) -> Ordering { - // Reversed to make BinaryHeap work as a min-heap - let by_feature = other + let by_feature = self .feature - .partial_cmp(&self.feature) + .partial_cmp(&other.feature) + .map(|ord| if R { ord.reverse() } else { ord }) .unwrap_or(Ordering::Equal); let lazy_by_doc_address = || self.doc.partial_cmp(&other.doc).unwrap_or(Ordering::Equal); @@ -53,13 +64,13 @@ impl<T: PartialOrd, D: PartialOrd> Ord for ComparableDoc<T, D> { } } -impl<T: PartialOrd, D: PartialOrd> PartialEq for ComparableDoc<T, D> { +impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T, D, R> { fn eq(&self, other: &Self) -> bool { self.cmp(other) == Ordering::Equal } } -impl<T: PartialOrd, D: PartialOrd> Eq for ComparableDoc<T, D> {} +impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {} pub(crate) struct TopCollector<T> { pub limit: usize, @@ -99,10 +110,10 @@ where T: PartialOrd + Clone if self.limit == 0 { return Ok(Vec::new()); } - let mut top_collector = TopNComputer::new(self.limit + self.offset); + let mut top_collector: TopNComputer<_, _> = TopNComputer::new(self.limit + self.offset); for child_fruit in children { for (feature, doc) in child_fruit { - top_collector.push(ComparableDoc { feature, doc }); + top_collector.push(feature, doc); } } @@ -143,6 +154,8 @@ where T: PartialOrd + Clone /// The theoretical complexity for collecting the top `K` out of `n` documents /// is `O(n + K)`. pub(crate) struct TopSegmentCollector<T> { + /// We reverse the order of the feature in order to + /// have top semantics instead of bottom semantics. topn_computer: TopNComputer<T, DocId>, segment_ord: u32, } @@ -180,7 +193,7 @@ impl<T: PartialOrd + Clone> TopSegmentCollector<T> { /// will compare the lowest scoring item with the given one and keep whichever is greater.
#[inline] pub fn collect(&mut self, doc: DocId, feature: T) { - self.topn_computer.push(ComparableDoc { feature, doc }); + self.topn_computer.push(feature, doc); } } diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 8c6c49d9e4..b6312a3bda 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use std::sync::Arc; use columnar::ColumnValues; +use serde::{Deserialize, Serialize}; use super::Collector; use crate::collector::custom_score_top_collector::CustomScoreTopCollector; @@ -663,7 +664,7 @@ impl Collector for TopDocs { reader: &SegmentReader, ) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> { let heap_len = self.0.limit + self.0.offset; - let mut top_n = TopNComputer::new(heap_len); + let mut top_n: TopNComputer<_, _> = TopNComputer::new(heap_len); if let Some(alive_bitset) = reader.alive_bitset() { let mut threshold = Score::MIN; @@ -672,21 +673,13 @@ impl Collector for TopDocs { if alive_bitset.is_deleted(doc) { return threshold; } - let doc = ComparableDoc { - feature: score, - doc, - }; - top_n.push(doc); + top_n.push(score, doc); threshold = top_n.threshold.unwrap_or(Score::MIN); threshold })?; } else { weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| { - let doc = ComparableDoc { - feature: score, - doc, - }; - top_n.push(doc); + top_n.push(score, doc); top_n.threshold.unwrap_or(Score::MIN) })?; } @@ -726,13 +719,15 @@ impl SegmentCollector for TopScoreSegmentCollector { /// Fast TopN Computation /// /// For TopN == 0, it will be relatively expensive. -pub struct TopNComputer<Score, DocId> { - buffer: Vec<ComparableDoc<Score, DocId>>, +#[derive(Clone, Serialize, Deserialize)] +pub struct TopNComputer<Score, DocId, const REVERSE_ORDER: bool = true> { + /// The buffer reverses sort order to get top semantics instead of bottom semantics + buffer: Vec<ComparableDoc<Score, DocId, REVERSE_ORDER>>, top_n: usize, pub(crate) threshold: Option<Score>, } -impl<Score, DocId> TopNComputer<Score, DocId> +impl<Score, DocId, const R: bool> TopNComputer<Score, DocId, R> where Score: PartialOrd + Clone, DocId: Ord + Clone, @@ -748,10 +743,12 @@ where } } + /// Push a new document to the top n. + /// If the document is below the current threshold, it will be ignored. #[inline] - pub(crate) fn push(&mut self, doc: ComparableDoc<Score, DocId>) { + pub fn push(&mut self, feature: Score, doc: DocId) { if let Some(last_median) = self.threshold.clone() { - if doc.feature < last_median { + if feature < last_median { return; } } @@ -766,7 +763,7 @@ where let uninit = self.buffer.spare_capacity_mut(); // This cannot panic, because truncate_median will remove at least one element, since // the min capacity is 2. - uninit[0].write(doc); + uninit[0].write(ComparableDoc { doc, feature }); // This is safe because it would panic in the line above unsafe { self.buffer.set_len(self.buffer.len() + 1); @@ -785,13 +782,24 @@ where median_score } - pub(crate) fn into_sorted_vec(mut self) -> Vec<ComparableDoc<Score, DocId>> { + /// Returns the top n elements in sorted order. + pub fn into_sorted_vec(mut self) -> Vec<ComparableDoc<Score, DocId, R>> { if self.buffer.len() > self.top_n { self.truncate_top_n(); } self.buffer.sort_unstable(); self.buffer } + + /// Returns the top n elements in stored order. + /// Useful if you do not need the elements in sorted order, + /// for example when merging the results of multiple segments.
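+ /// A minimal usage sketch (scores and doc ids are illustrative `u32`s, as in the tests below): + /// ```ignore + /// let mut computer: TopNComputer<u32, u32> = TopNComputer::new(2); + /// computer.push(3u32, 1u32); + /// computer.push(1u32, 2u32); + /// let top = computer.into_vec(); // at most `top_n` entries, unsorted + /// ```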
+ pub fn into_vec(mut self) -> Vec<ComparableDoc<Score, DocId, R>> { + if self.buffer.len() > self.top_n { + self.truncate_top_n(); + } + self.buffer + } } #[cfg(test)] @@ -830,44 +838,20 @@ mod tests { fn test_empty_topn_computer() { let mut computer: TopNComputer<u32, u32> = TopNComputer::new(0); - computer.push(ComparableDoc { - feature: 1u32, - doc: 1u32, - }); - computer.push(ComparableDoc { - feature: 1u32, - doc: 2u32, - }); - computer.push(ComparableDoc { - feature: 1u32, - doc: 3u32, - }); + computer.push(1u32, 1u32); + computer.push(1u32, 2u32); + computer.push(1u32, 3u32); assert!(computer.into_sorted_vec().is_empty()); } #[test] fn test_topn_computer() { let mut computer: TopNComputer<u32, u32> = TopNComputer::new(2); - computer.push(ComparableDoc { - feature: 1u32, - doc: 1u32, - }); - computer.push(ComparableDoc { - feature: 2u32, - doc: 2u32, - }); - computer.push(ComparableDoc { - feature: 3u32, - doc: 3u32, - }); - computer.push(ComparableDoc { - feature: 2u32, - doc: 4u32, - }); - computer.push(ComparableDoc { - feature: 1u32, - doc: 5u32, - }); + computer.push(1u32, 1u32); + computer.push(2u32, 2u32); + computer.push(3u32, 3u32); + computer.push(2u32, 4u32); + computer.push(1u32, 5u32); assert_eq!( computer.into_sorted_vec(), &[ @@ -889,10 +873,7 @@ mod tests { let mut computer: TopNComputer<u32, u32> = TopNComputer::new(top_n); for _ in 0..1 + top_n * 2 { - computer.push(ComparableDoc { - feature: 1u32, - doc: 1u32, - }); + computer.push(1u32, 1u32); } let _vals = computer.into_sorted_vec(); } diff --git a/src/lib.rs b/src/lib.rs index e14d02a6c9..f9aa19c5a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -338,7 +338,7 @@ impl DocAddress { /// /// The id used for the segment is actually an ordinal /// in the list of `Segment`s held by a `Searcher`. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub struct DocAddress { /// The segment ordinal id that identifies the segment /// hosting the document in the `Searcher` it is called from.