Skip to content

Commit

Permalink
multi-col agg
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangpengHao committed Aug 3, 2024
1 parent f4e519f commit 24283fb
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 68 deletions.
66 changes: 54 additions & 12 deletions datafusion/physical-expr-common/src/binary_view_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use arrow::array::cast::AsArray;
use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder};
use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType};
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
use datafusion_common::utils::proxy::RawTableAllocExt;
use std::fmt::Debug;
use std::sync::Arc;

Expand Down Expand Up @@ -207,6 +207,7 @@ where
values,
make_payload_fn,
observe_payload_fn,
None,
)
}
OutputType::Utf8View => {
Expand All @@ -215,6 +216,43 @@ where
values,
make_payload_fn,
observe_payload_fn,
None,
)
}
_ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
};
}

/// Similar to [`Self::insert_if_new`] but allows the caller to provide the
/// hash values for the values in `values` instead of computing them
pub fn insert_if_new_with_hash<MP, OP>(
&mut self,
values: &ArrayRef,
make_payload_fn: MP,
observe_payload_fn: OP,
provided_hash: &Vec<u64>,
) where
MP: FnMut(Option<&[u8]>) -> V,
OP: FnMut(V),
{
// Sanity check array type
match self.output_type {
OutputType::BinaryView => {
assert!(matches!(values.data_type(), DataType::BinaryView));
self.insert_if_new_inner::<MP, OP, BinaryViewType>(
values,
make_payload_fn,
observe_payload_fn,
Some(provided_hash),
)
}
OutputType::Utf8View => {
assert!(matches!(values.data_type(), DataType::Utf8View));
self.insert_if_new_inner::<MP, OP, StringViewType>(
values,
make_payload_fn,
observe_payload_fn,
Some(provided_hash),
)
}
_ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
Expand All @@ -234,19 +272,26 @@ where
values: &ArrayRef,
mut make_payload_fn: MP,
mut observe_payload_fn: OP,
provided_hash: Option<&Vec<u64>>,
) where
MP: FnMut(Option<&[u8]>) -> V,
OP: FnMut(V),
B: ByteViewType,
{
// step 1: compute hashes
let batch_hashes = &mut self.hashes_buffer;
batch_hashes.clear();
batch_hashes.resize(values.len(), 0);
create_hashes(&[values.clone()], &self.random_state, batch_hashes)
// hash is supported for all types and create_hashes only
// returns errors for unsupported types
.unwrap();
let batch_hashes = match provided_hash {
Some(h) => h,
None => {
let batch_hashes = &mut self.hashes_buffer;
batch_hashes.clear();
batch_hashes.resize(values.len(), 0);
create_hashes(&[values.clone()], &self.random_state, batch_hashes)
// hash is supported for all types and create_hashes only
// returns errors for unsupported types
.unwrap();
batch_hashes
}
};

// step 2: insert each value into the set, if not already present
let values = values.as_byte_view::<B>();
Expand Down Expand Up @@ -353,9 +398,7 @@ where
/// Return the total size, in bytes, of memory used to store the data in
/// this set, not including `self`
pub fn size(&self) -> usize {
self.map_size
+ self.builder.allocated_size()
+ self.hashes_buffer.allocated_size()
self.map_size + self.builder.allocated_size()
}
}

Expand All @@ -369,7 +412,6 @@ where
.field("map_size", &self.map_size)
.field("view_builder", &self.builder)
.field("random_state", &self.random_state)
.field("hashes_buffer", &self.hashes_buffer)
.finish()
}
}
Expand Down
Loading

0 comments on commit 24283fb

Please sign in to comment.