chore!: drop JSON support on intermediate agg result (#1992)
* chore!: drop JSON support on intermediate agg result

add support for other formats by removing the `skip_serializing_if` and `untagged` serde attributes
JSON support is broken anyway due to its lack of f64::INF etc. handling (see the sketch below)

* Update src/aggregation/intermediate_agg_result.rs

Co-authored-by: Paul Masurel <[email protected]>

* move `From` impl

---------

Co-authored-by: Paul Masurel <[email protected]>
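
The f64::INF problem referenced above is easy to reproduce. A minimal sketch, not part of the commit and assuming only a `serde_json` dependency: JSON has no representation for non-finite floats, so `serde_json` writes them as `null`, and the round-trip the intermediate results relied on fails.

```rust
fn main() {
    // serde_json has no encoding for non-finite floats; it emits `null`.
    let json = serde_json::to_string(&f64::INFINITY).unwrap();
    assert_eq!(json, "null");

    // Reading that `null` back as an f64 is an error, so any intermediate
    // aggregation result containing INF/NaN breaks the JSON round-trip.
    let round_trip: Result<f64, _> = serde_json::from_str(&json);
    assert!(round_trip.is_err());
}
```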
PSeitz and fulmicoton authored Apr 26, 2023
1 parent 80df1d9 commit c599bf3
Showing 5 changed files with 68 additions and 97 deletions.
4 changes: 0 additions & 4 deletions src/aggregation/agg_tests.rs
@@ -4,7 +4,6 @@ use crate::aggregation::agg_req::{Aggregation, Aggregations};
 use crate::aggregation::agg_result::AggregationResults;
 use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
 use crate::aggregation::collector::AggregationCollector;
-use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
 use crate::aggregation::segment_agg_result::AggregationLimits;
 use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
 use crate::aggregation::DistributedAggregationCollector;
@@ -421,9 +420,6 @@ fn test_aggregation_level2(
 
         let searcher = reader.searcher();
         let res = searcher.search(&term_query, &collector).unwrap();
-        // Test de/serialization roundtrip on intermediate_agg_result
-        let res: IntermediateAggregationResults =
-            serde_json::from_str(&serde_json::to_string(&res).unwrap()).unwrap();
         res.into_final_result(agg_req.clone(), &Default::default())
             .unwrap()
     } else {
2 changes: 1 addition & 1 deletion src/aggregation/bucket/range.rs
@@ -166,7 +166,7 @@ impl SegmentRangeBucketEntry {
         };
 
         Ok(IntermediateRangeBucketEntry {
-            key: self.key,
+            key: self.key.into(),
             doc_count: self.doc_count,
             sub_aggregation: sub_aggregation_res,
             from: self.from,
16 changes: 6 additions & 10 deletions src/aggregation/bucket/term_agg.rs
@@ -9,14 +9,14 @@ use crate::aggregation::agg_limits::MemoryConsumption;
 use crate::aggregation::agg_req_with_accessor::{
     AggregationWithAccessor, AggregationsWithAccessor,
 };
+use crate::aggregation::f64_from_fastfield_u64;
 use crate::aggregation::intermediate_agg_result::{
     IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
-    IntermediateTermBucketEntry, IntermediateTermBucketResult,
+    IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
 };
 use crate::aggregation::segment_agg_result::{
     build_segment_agg_collector, SegmentAggregationCollector,
 };
-use crate::aggregation::{f64_from_fastfield_u64, Key};
 use crate::error::DataCorruption;
 use crate::TantivyError;
 
@@ -30,10 +30,6 @@ use crate::TantivyError;
 /// Term aggregations work only on [fast fields](`crate::fastfield`) of type `u64`, `f64`, `i64` and
 /// text.
 ///
-/// ### Terminology
-/// Shard parameters are supposed to be equivalent to elasticsearch shard parameter.
-/// Since they are
-///
 /// ## Document count error
 /// To improve performance, results from one segment are cut off at `segment_size`. On a index with
 /// a single segment this is fine. When combining results of multiple segments, terms that
@@ -402,7 +398,7 @@ impl SegmentTermCollector {
             cut_off_buckets(&mut entries, self.req.segment_size as usize)
         };
 
-        let mut dict: FxHashMap<Key, IntermediateTermBucketEntry> = Default::default();
+        let mut dict: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> = Default::default();
         dict.reserve(entries.len());
 
         let mut into_intermediate_bucket_entry =
@@ -453,7 +449,7 @@ impl SegmentTermCollector {
 
                 let intermediate_entry = into_intermediate_bucket_entry(term_id, doc_count)?;
 
-                dict.insert(Key::Str(buffer.to_string()), intermediate_entry);
+                dict.insert(IntermediateKey::Str(buffer.to_string()), intermediate_entry);
             }
             if self.req.min_doc_count == 0 {
                 // TODO: Handle rev streaming for descending sorting by keys
@@ -463,7 +459,7 @@ impl SegmentTermCollector {
                     break;
                 }
 
-                let key = Key::Str(
+                let key = IntermediateKey::Str(
                     std::str::from_utf8(key)
                         .map_err(|utf8_err| DataCorruption::comment_only(utf8_err.to_string()))?
                         .to_string(),
@@ -475,7 +471,7 @@ impl SegmentTermCollector {
             for (val, doc_count) in entries {
                 let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
                 let val = f64_from_fastfield_u64(val, &self.field_type);
-                dict.insert(Key::F64(val), intermediate_entry);
+                dict.insert(IntermediateKey::F64(val), intermediate_entry);
             }
         };
137 changes: 55 additions & 82 deletions src/aggregation/intermediate_agg_result.rs
@@ -3,12 +3,12 @@
 //! indices.
 use std::cmp::Ordering;
+use std::hash::Hash;
 
 use columnar::ColumnType;
 use itertools::Itertools;
 use rustc_hash::FxHashMap;
-use serde::ser::SerializeSeq;
-use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use serde::{Deserialize, Serialize};
 
 use super::agg_req::{Aggregation, AggregationVariants, Aggregations};
 use super::agg_result::{AggregationResult, BucketResult, MetricResult, RangeBucketEntry};
@@ -29,11 +29,52 @@ use crate::TantivyError;
 
 /// Contains the intermediate aggregation result, which is optimized to be merged with other
 /// intermediate results.
+///
+/// Notice: This struct should not be de/serialized via JSON format.
 #[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
 pub struct IntermediateAggregationResults {
     pub(crate) aggs_res: VecWithNames<IntermediateAggregationResult>,
 }
 
+#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd, PartialEq)]
+/// The key to identify a bucket.
+/// This might seem redundant with `Key`, but the point is to have a different
+/// Serialize implementation.
+pub enum IntermediateKey {
+    /// String key
+    Str(String),
+    /// `f64` key
+    F64(f64),
+}
+impl From<Key> for IntermediateKey {
+    fn from(value: Key) -> Self {
+        match value {
+            Key::Str(s) => Self::Str(s),
+            Key::F64(f) => Self::F64(f),
+        }
+    }
+}
+impl From<IntermediateKey> for Key {
+    fn from(value: IntermediateKey) -> Self {
+        match value {
+            IntermediateKey::Str(s) => Self::Str(s),
+            IntermediateKey::F64(f) => Self::F64(f),
+        }
+    }
+}
+
+impl Eq for IntermediateKey {}
+
+impl std::hash::Hash for IntermediateKey {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        core::mem::discriminant(self).hash(state);
+        match self {
+            IntermediateKey::Str(text) => text.hash(state),
+            IntermediateKey::F64(val) => val.to_bits().hash(state),
+        }
+    }
+}
+
 impl IntermediateAggregationResults {
     /// Add a result
     pub fn push(&mut self, key: String, value: IntermediateAggregationResult) {
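
A note on the `IntermediateKey` additions above: `f64` implements neither `Eq` nor `Hash`, so both are provided manually to allow the key to serve as an `FxHashMap` key, hashing the enum discriminant plus the float's bit pattern. A standalone sketch of the same pattern, using a hypothetical `BucketKey` type and the std `HashMap` rather than the crate's types:

```rust
use std::collections::HashMap;

// Hypothetical mirror of `IntermediateKey`, only to illustrate the pattern.
#[derive(Clone, Debug, PartialEq)]
enum BucketKey {
    Str(String),
    F64(f64),
}

// `f64` is only `PartialEq`; `Eq` is asserted manually, as in the commit.
// (A NaN key would still compare unequal to itself under the derived `PartialEq`.)
impl Eq for BucketKey {}

impl std::hash::Hash for BucketKey {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Distinguish the variants, then hash the payload.
        core::mem::discriminant(self).hash(state);
        match self {
            BucketKey::Str(text) => text.hash(state),
            // Floats aren't `Hash`; hash the raw bit pattern instead.
            BucketKey::F64(val) => val.to_bits().hash(state),
        }
    }
}

fn main() {
    let mut doc_counts: HashMap<BucketKey, u64> = HashMap::new();
    doc_counts.insert(BucketKey::F64(5.0), 10);
    assert_eq!(doc_counts[&BucketKey::F64(5.0)], 10);
}
```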
Expand Down Expand Up @@ -387,7 +428,7 @@ impl IntermediateBucketResult {
IntermediateBucketResult::Terms(term_res_left),
IntermediateBucketResult::Terms(term_res_right),
) => {
merge_key_maps(&mut term_res_left.entries, term_res_right.entries)?;
merge_maps(&mut term_res_left.entries, term_res_right.entries)?;
term_res_left.sum_other_doc_count += term_res_right.sum_other_doc_count;
term_res_left.doc_count_error_upper_bound +=
term_res_right.doc_count_error_upper_bound;
Expand All @@ -397,7 +438,7 @@ impl IntermediateBucketResult {
IntermediateBucketResult::Range(range_res_left),
IntermediateBucketResult::Range(range_res_right),
) => {
merge_serialized_key_maps(&mut range_res_left.buckets, range_res_right.buckets)?;
merge_maps(&mut range_res_left.buckets, range_res_right.buckets)?;
}
(
IntermediateBucketResult::Histogram {
Expand Down Expand Up @@ -451,39 +492,11 @@ pub struct IntermediateRangeBucketResult {
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
/// Term aggregation including error counts
pub struct IntermediateTermBucketResult {
#[serde(
serialize_with = "serialize_entries",
deserialize_with = "deserialize_entries"
)]
pub(crate) entries: FxHashMap<Key, IntermediateTermBucketEntry>,
pub(crate) entries: FxHashMap<IntermediateKey, IntermediateTermBucketEntry>,
pub(crate) sum_other_doc_count: u64,
pub(crate) doc_count_error_upper_bound: u64,
}

// Serialize into a Vec to circument the JSON limitation, where keys can't be numbers
fn serialize_entries<S>(
entries: &FxHashMap<Key, IntermediateTermBucketEntry>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut seq = serializer.serialize_seq(Some(entries.len()))?;
for (k, v) in entries {
seq.serialize_element(&(k, v))?;
}
seq.end()
}

fn deserialize_entries<'de, D>(
deserializer: D,
) -> Result<FxHashMap<Key, IntermediateTermBucketEntry>, D::Error>
where D: Deserializer<'de> {
let vec_entries: Vec<(Key, IntermediateTermBucketEntry)> =
Deserialize::deserialize(deserializer)?;
Ok(vec_entries.into_iter().collect())
}

impl IntermediateTermBucketResult {
pub(crate) fn into_final_result(
self,
Expand All @@ -499,7 +512,7 @@ impl IntermediateTermBucketResult {
.map(|(key, entry)| {
Ok(BucketEntry {
key_as_string: None,
key,
key: key.into(),
doc_count: entry.doc_count,
sub_aggregation: entry
.sub_aggregation
Expand Down Expand Up @@ -577,25 +590,9 @@ trait MergeFruits {
fn merge_fruits(&mut self, other: Self) -> crate::Result<()>;
}

fn merge_serialized_key_maps<V: MergeFruits + Clone>(
entries_left: &mut FxHashMap<SerializedKey, V>,
mut entries_right: FxHashMap<SerializedKey, V>,
) -> crate::Result<()> {
for (name, entry_left) in entries_left.iter_mut() {
if let Some(entry_right) = entries_right.remove(name) {
entry_left.merge_fruits(entry_right)?;
}
}

for (key, res) in entries_right.into_iter() {
entries_left.entry(key).or_insert(res);
}
Ok(())
}

fn merge_key_maps<V: MergeFruits + Clone>(
entries_left: &mut FxHashMap<Key, V>,
mut entries_right: FxHashMap<Key, V>,
fn merge_maps<V: MergeFruits + Clone, T: Eq + PartialEq + Hash>(
entries_left: &mut FxHashMap<T, V>,
mut entries_right: FxHashMap<T, V>,
) -> crate::Result<()> {
for (name, entry_left) in entries_left.iter_mut() {
if let Some(entry_right) = entries_right.remove(name) {
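
For readers skimming the diff, the merge semantics of the generified `merge_maps` above are: keys present on both sides are merged pairwise, and keys present only on the right are moved into the left map. A self-contained sketch with a toy counter type standing in for `MergeFruits` (hypothetical names, not from the crate):

```rust
use std::collections::HashMap;
use std::hash::Hash;

// Toy stand-in for the crate's `MergeFruits` trait.
trait Merge {
    fn merge(&mut self, other: Self);
}

impl Merge for u64 {
    fn merge(&mut self, other: Self) {
        *self += other;
    }
}

// Same shape as `merge_maps`: merge shared keys, then adopt right-only keys.
fn merge_maps<T: Eq + Hash, V: Merge>(left: &mut HashMap<T, V>, mut right: HashMap<T, V>) {
    for (key, entry_left) in left.iter_mut() {
        if let Some(entry_right) = right.remove(key) {
            entry_left.merge(entry_right);
        }
    }
    for (key, entry) in right {
        left.entry(key).or_insert(entry);
    }
}

fn main() {
    let mut left = HashMap::from([("a", 1u64), ("b", 2)]);
    let right = HashMap::from([("b", 3u64), ("c", 4)]);
    merge_maps(&mut left, right);
    assert_eq!(left, HashMap::from([("a", 1), ("b", 5), ("c", 4)]));
}
```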
@@ -652,17 +649,15 @@ impl From<SegmentHistogramBucketEntry> for IntermediateHistogramBucketEntry {
 /// sub_aggregations.
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
 pub struct IntermediateRangeBucketEntry {
-    /// The unique the bucket is identified.
-    pub key: Key,
+    /// The unique key the bucket is identified with.
+    pub key: IntermediateKey,
     /// The number of documents in the bucket.
     pub doc_count: u64,
     /// The sub_aggregation in this bucket.
     pub sub_aggregation: IntermediateAggregationResults,
     /// The from range of the bucket. Equals `f64::MIN` when `None`.
-    #[serde(skip_serializing_if = "Option::is_none")]
     pub from: Option<f64>,
     /// The to range of the bucket. Equals `f64::MAX` when `None`.
-    #[serde(skip_serializing_if = "Option::is_none")]
     pub to: Option<f64>,
 }
@@ -675,7 +670,7 @@ impl IntermediateRangeBucketEntry {
         limits: &AggregationLimits,
     ) -> crate::Result<RangeBucketEntry> {
         let mut range_bucket_entry = RangeBucketEntry {
-            key: self.key,
+            key: self.key.into(),
             doc_count: self.doc_count,
             sub_aggregation: self
                 .sub_aggregation
@@ -752,7 +747,7 @@ mod tests {
         buckets.insert(
            key.to_string(),
            IntermediateRangeBucketEntry {
-                key: Key::Str(key.to_string()),
+                key: IntermediateKey::Str(key.to_string()),
                doc_count: *doc_count,
                sub_aggregation: Default::default(),
                from: None,
@@ -783,7 +778,7 @@ mod tests {
        buckets.insert(
            key.to_string(),
            IntermediateRangeBucketEntry {
-                key: Key::Str(key.to_string()),
+                key: IntermediateKey::Str(key.to_string()),
                doc_count: *doc_count,
                from: None,
                to: None,
@@ -866,26 +861,4 @@ mod tests {
 
         assert_eq!(tree_left, orig);
     }
-
-    #[test]
-    fn test_term_bucket_json_roundtrip() {
-        let term_buckets = IntermediateTermBucketResult {
-            entries: vec![(
-                Key::F64(5.0),
-                IntermediateTermBucketEntry {
-                    doc_count: 10,
-                    sub_aggregation: Default::default(),
-                },
-            )]
-            .into_iter()
-            .collect(),
-            sum_other_doc_count: 0,
-            doc_count_error_upper_bound: 0,
-        };
-
-        let term_buckets_round: IntermediateTermBucketResult =
-            serde_json::from_str(&serde_json::to_string(&term_buckets).unwrap()).unwrap();
-
-        assert_eq!(term_buckets, term_buckets_round);
-    }
 }
6 changes: 6 additions & 0 deletions src/aggregation/mod.rs
@@ -24,6 +24,10 @@
 //! ## JSON Format
 //! Aggregations request and result structures de/serialize into elasticsearch compatible JSON.
 //!
+//! Notice: Intermediate aggregation results should not be de/serialized via JSON format.
+//! See compatibility tests here: https://github.com/PSeitz/test_serde_formats
+//! TLDR: use ciborium.
+//!
 //! ```verbatim
 //! let agg_req: Aggregations = serde_json::from_str(json_request_string).unwrap();
 //! let collector = AggregationCollector::from_aggs(agg_req, None);
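
A minimal sketch of the recommended alternative, not part of this commit, assuming a `ciborium` dependency in Cargo.toml and tantivy's public `intermediate_agg_result` module: CBOR is self-describing and preserves non-finite floats, so intermediate results survive the round-trip that JSON cannot provide.

```rust
use tantivy::aggregation::intermediate_agg_result::IntermediateAggregationResults;

// Serialize an intermediate result to CBOR bytes and read it back, e.g. to
// ship partial results between nodes in a distributed search setup.
fn cbor_roundtrip(res: &IntermediateAggregationResults) -> IntermediateAggregationResults {
    let mut bytes = Vec::new();
    ciborium::ser::into_writer(res, &mut bytes).expect("CBOR serialization failed");
    ciborium::de::from_reader(bytes.as_slice()).expect("CBOR deserialization failed")
}
```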
@@ -151,6 +155,8 @@ pub use error::AggregationError;
 use itertools::Itertools;
 use serde::{Deserialize, Serialize};
 
+use self::intermediate_agg_result::IntermediateKey;
+
 /// Represents an associative array `(key => values)` in a very efficient manner.
 #[derive(Clone, PartialEq, Serialize, Deserialize)]
 pub(crate) struct VecWithNames<T: Clone> {
