From 2ce3ecc85e5c3172f612af82c5257bf2a9e6d738 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Thu, 25 May 2023 11:34:52 +0100 Subject: [PATCH 1/4] Buffer Pages in ArrowWriter instead of RecordBatch (#3871) --- parquet/src/arrow/arrow_writer/byte_array.rs | 57 +- parquet/src/arrow/arrow_writer/mod.rs | 608 ++++++++++--------- parquet/src/column/page.rs | 69 +++ parquet/src/column/writer/encoder.rs | 2 +- parquet/src/column/writer/mod.rs | 22 +- parquet/src/file/writer.rs | 106 +--- parquet/src/util/memory.rs | 6 + 7 files changed, 454 insertions(+), 416 deletions(-) diff --git a/parquet/src/arrow/arrow_writer/byte_array.rs b/parquet/src/arrow/arrow_writer/byte_array.rs index 77f9598b23fe..6dbc83dd05c4 100644 --- a/parquet/src/arrow/arrow_writer/byte_array.rs +++ b/parquet/src/arrow/arrow_writer/byte_array.rs @@ -15,25 +15,21 @@ // specific language governing permissions and limitations // under the License. -use crate::arrow::arrow_writer::levels::LevelInfo; use crate::basic::Encoding; use crate::bloom_filter::Sbbf; -use crate::column::page::PageWriter; use crate::column::writer::encoder::{ ColumnValueEncoder, DataPageValues, DictionaryPage, }; -use crate::column::writer::GenericColumnWriter; use crate::data_type::{AsBytes, ByteArray, Int32Type}; use crate::encodings::encoding::{DeltaBitPackEncoder, Encoder}; use crate::encodings::rle::RleEncoder; use crate::errors::{ParquetError, Result}; -use crate::file::properties::{WriterProperties, WriterPropertiesPtr, WriterVersion}; -use crate::file::writer::OnCloseColumnChunk; +use crate::file::properties::{WriterProperties, WriterVersion}; use crate::schema::types::ColumnDescPtr; use crate::util::bit_util::num_required_bits; use crate::util::interner::{Interner, Storage}; use arrow_array::{ - Array, ArrayAccessor, ArrayRef, BinaryArray, DictionaryArray, LargeBinaryArray, + Array, ArrayAccessor, BinaryArray, DictionaryArray, LargeBinaryArray, LargeStringArray, StringArray, }; use arrow_schema::DataType; @@ -94,49 +90,6 @@ macro_rules! downcast_op { }; } -/// A writer for byte array types -pub(super) struct ByteArrayWriter<'a> { - writer: GenericColumnWriter<'a, ByteArrayEncoder>, - on_close: Option>, -} - -impl<'a> ByteArrayWriter<'a> { - /// Returns a new [`ByteArrayWriter`] - pub fn new( - descr: ColumnDescPtr, - props: WriterPropertiesPtr, - page_writer: Box, - on_close: OnCloseColumnChunk<'a>, - ) -> Result { - Ok(Self { - writer: GenericColumnWriter::new(descr, props, page_writer), - on_close: Some(on_close), - }) - } - - pub fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()> { - self.writer.write_batch_internal( - array, - Some(levels.non_null_indices()), - levels.def_levels(), - levels.rep_levels(), - None, - None, - None, - )?; - Ok(()) - } - - pub fn close(self) -> Result<()> { - let r = self.writer.close()?; - - if let Some(on_close) = self.on_close { - on_close(r)?; - } - Ok(()) - } -} - /// A fallback encoder, i.e. non-dictionary, for [`ByteArray`] struct FallbackEncoder { encoder: FallbackEncoderImpl, @@ -427,7 +380,7 @@ impl DictEncoder { } } -struct ByteArrayEncoder { +pub struct ByteArrayEncoder { fallback: FallbackEncoder, dict_encoder: Option, min_value: Option, @@ -437,11 +390,11 @@ struct ByteArrayEncoder { impl ColumnValueEncoder for ByteArrayEncoder { type T = ByteArray; - type Values = ArrayRef; + type Values = dyn Array; fn min_max( &self, - values: &ArrayRef, + values: &dyn Array, value_indices: Option<&[usize]>, ) -> Option<(Self::T, Self::T)> { match value_indices { diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index 08cfc7ea3ebf..c265058c11d5 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -17,16 +17,21 @@ //! Contains writer which writes arrow data into parquet data. -use std::collections::VecDeque; +use bytes::Bytes; use std::fmt::Debug; -use std::io::Write; -use std::sync::Arc; +use std::io::{Read, Write}; +use std::iter::Peekable; +use std::slice::Iter; +use std::sync::{Arc, Mutex}; +use std::vec::IntoIter; +use thrift::protocol::{TCompactOutputProtocol, TSerializable}; use arrow_array::cast::AsArray; -use arrow_array::types::{Decimal128Type, Int32Type, Int64Type, UInt32Type, UInt64Type}; -use arrow_array::{ - types, Array, ArrayRef, FixedSizeListArray, RecordBatch, RecordBatchWriter, +use arrow_array::types::{ + Decimal128Type, Float32Type, Float64Type, Int32Type, Int64Type, UInt32Type, + UInt64Type, }; +use arrow_array::{Array, FixedSizeListArray, RecordBatch, RecordBatchWriter}; use arrow_schema::{ArrowError, DataType as ArrowDataType, IntervalUnit, SchemaRef}; use super::schema::{ @@ -34,14 +39,19 @@ use super::schema::{ decimal_length_from_precision, }; -use crate::arrow::arrow_writer::byte_array::ByteArrayWriter; -use crate::column::writer::{ColumnWriter, ColumnWriterImpl}; -use crate::data_type::{ByteArray, DataType, FixedLenByteArray}; +use crate::arrow::arrow_writer::byte_array::ByteArrayEncoder; +use crate::column::page::{CompressedPage, PageWriteSpec, PageWriter}; +use crate::column::writer::encoder::ColumnValueEncoder; +use crate::column::writer::{ + get_column_writer, ColumnCloseResult, ColumnWriter, GenericColumnWriter, +}; +use crate::data_type::{ByteArray, FixedLenByteArray}; use crate::errors::{ParquetError, Result}; -use crate::file::metadata::{KeyValue, RowGroupMetaDataPtr}; -use crate::file::properties::WriterProperties; +use crate::file::metadata::{ColumnChunkMetaData, KeyValue, RowGroupMetaDataPtr}; +use crate::file::properties::{WriterProperties, WriterPropertiesPtr}; +use crate::file::reader::{ChunkReader, Length}; use crate::file::writer::SerializedFileWriter; -use crate::file::writer::SerializedRowGroupWriter; +use crate::schema::types::{ColumnDescPtr, SchemaDescriptor}; use levels::{calculate_array_levels, LevelInfo}; mod byte_array; @@ -49,8 +59,8 @@ mod levels; /// Arrow writer /// -/// Writes Arrow `RecordBatch`es to a Parquet writer, buffering up `RecordBatch` in order -/// to produce row groups with `max_row_group_size` rows. Any remaining rows will be +/// Writes Arrow `RecordBatch`es to a Parquet writer. Multiple [`RecordBatch`] will be encoded +/// to the same row group, up to `max_row_group_size` rows. Any remaining rows will be /// flushed on close, leading the final row group in the output file to potentially /// contain fewer than `max_row_group_size` rows /// @@ -78,10 +88,10 @@ pub struct ArrowWriter { /// Underlying Parquet writer writer: SerializedFileWriter, - /// For each column, maintain an ordered queue of arrays to write - buffer: Vec>, + /// The in-progress row group if any + in_progress: Option, - /// The total number of rows currently buffered + /// The total number of rows in the in_progress row group, if any buffered_rows: usize, /// A copy of the Arrow schema. @@ -95,21 +105,10 @@ pub struct ArrowWriter { impl Debug for ArrowWriter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let buffered_batches = self.buffer.len(); - let mut buffered_memory = 0; - - for batch in self.buffer.iter() { - for arr in batch.iter() { - buffered_memory += arr.get_array_memory_size() - } - } - + let buffered_memory = self.in_progress_size(); f.debug_struct("ArrowWriter") .field("writer", &self.writer) - .field( - "buffer", - &format!("{buffered_batches} , {buffered_memory} bytes"), - ) + .field("buffer", &format_args!("{buffered_memory} bytes")) .field("buffered_rows", &self.buffered_rows) .field("arrow_schema", &self.arrow_schema) .field("max_row_group_size", &self.max_row_group_size) @@ -140,7 +139,7 @@ impl ArrowWriter { Ok(Self { writer: file_writer, - buffer: vec![Default::default(); arrow_schema.fields().len()], + in_progress: None, buffered_rows: 0, arrow_schema, max_row_group_size, @@ -152,43 +151,69 @@ impl ArrowWriter { self.writer.flushed_row_groups() } - /// Enqueues the provided `RecordBatch` to be written + /// Returns the length in bytes of the current in progress row group + pub fn in_progress_size(&self) -> usize { + match &self.in_progress { + Some(in_progress) => in_progress + .writers + .iter() + .map(|(x, _)| x.lock().unwrap().length) + .sum(), + None => 0, + } + } + + /// Encodes the provided [`RecordBatch`] /// - /// If following this there are more than `max_row_group_size` rows buffered, - /// this will flush out one or more row groups with `max_row_group_size` rows, - /// and drop any fully written `RecordBatch` + /// If this would cause the current row group to exceed [`WriterProperties::max_row_group_size`] + /// rows, the contents of `batch` will be distributed across multiple row groups such that all + /// but the final row group in the file contain [`WriterProperties::max_row_group_size`] rows pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { - // validate batch schema against writer's supplied schema - let batch_schema = batch.schema(); - if !(Arc::ptr_eq(&self.arrow_schema, &batch_schema) - || self.arrow_schema.contains(&batch_schema)) - { - return Err(ParquetError::ArrowError( - "Record batch schema does not match writer schema".to_string(), - )); + if batch.num_rows() == 0 { + return Ok(()); } - for (buffer, column) in self.buffer.iter_mut().zip(batch.columns()) { - buffer.push_back(column.clone()) + // If would exceed max_row_group_size, split batch + if self.buffered_rows + batch.num_rows() > self.max_row_group_size { + let to_write = self.max_row_group_size - self.buffered_rows; + let a = batch.slice(0, to_write); + let b = batch.slice(to_write, batch.num_rows() - to_write); + self.write(&a)?; + return self.write(&b); } self.buffered_rows += batch.num_rows(); - self.flush_completed()?; + let in_progress = match &mut self.in_progress { + Some(in_progress) => in_progress, + x => x.insert(ArrowRowGroupWriter::new( + self.writer.schema_descr(), + self.writer.properties(), + &self.arrow_schema, + )?), + }; - Ok(()) - } + in_progress.write(batch)?; - /// Flushes buffered data until there are less than `max_row_group_size` rows buffered - fn flush_completed(&mut self) -> Result<()> { - while self.buffered_rows >= self.max_row_group_size { - self.flush_rows(self.max_row_group_size)?; + if self.buffered_rows >= self.max_row_group_size { + self.flush()? } Ok(()) } /// Flushes all buffered rows into a new row group pub fn flush(&mut self) -> Result<()> { - self.flush_rows(self.buffered_rows) + let in_progress = match self.in_progress.take() { + Some(in_progress) => in_progress, + None => return Ok(()), + }; + + self.buffered_rows = 0; + let mut row_group_writer = self.writer.next_row_group()?; + for (chunk, close) in in_progress.close()? { + row_group_writer.append_column(&chunk, close)?; + } + row_group_writer.close()?; + Ok(()) } /// Additional [`KeyValue`] metadata to be written in addition to those from [`WriterProperties`] @@ -198,68 +223,6 @@ impl ArrowWriter { self.writer.append_key_value_metadata(kv_metadata) } - /// Flushes `num_rows` from the buffer into a new row group - fn flush_rows(&mut self, num_rows: usize) -> Result<()> { - if num_rows == 0 { - return Ok(()); - } - - assert!( - num_rows <= self.buffered_rows, - "cannot flush {} rows only have {}", - num_rows, - self.buffered_rows - ); - - assert!( - num_rows <= self.max_row_group_size, - "cannot flush {} rows would exceed max row group size of {}", - num_rows, - self.max_row_group_size - ); - - let mut row_group_writer = self.writer.next_row_group()?; - - for (col_buffer, field) in self.buffer.iter_mut().zip(self.arrow_schema.fields()) - { - // Collect the number of arrays to append - let mut remaining = num_rows; - let mut arrays = Vec::with_capacity(col_buffer.len()); - while remaining != 0 { - match col_buffer.pop_front() { - Some(next) if next.len() > remaining => { - col_buffer - .push_front(next.slice(remaining, next.len() - remaining)); - arrays.push(next.slice(0, remaining)); - remaining = 0; - } - Some(next) => { - remaining -= next.len(); - arrays.push(next); - } - _ => break, - } - } - - let mut levels = arrays - .iter() - .map(|array| { - let mut levels = calculate_array_levels(array, field)?; - // Reverse levels as we pop() them when writing arrays - levels.reverse(); - Ok(levels) - }) - .collect::>>()?; - - write_leaves(&mut row_group_writer, &arrays, &mut levels)?; - } - - row_group_writer.close()?; - self.buffered_rows -= num_rows; - - Ok(()) - } - /// Flushes any outstanding data and returns the underlying writer. pub fn into_inner(mut self) -> Result { self.flush()?; @@ -284,156 +247,271 @@ impl RecordBatchWriter for ArrowWriter { } } -fn write_leaves( - row_group_writer: &mut SerializedRowGroupWriter<'_, W>, - arrays: &[ArrayRef], - levels: &mut [Vec], -) -> Result<()> { - assert_eq!(arrays.len(), levels.len()); - assert!(!arrays.is_empty()); - - let data_type = arrays.first().unwrap().data_type().clone(); - assert!(arrays.iter().all(|a| a.data_type() == &data_type)); - - match &data_type { - ArrowDataType::Null - | ArrowDataType::Boolean - | ArrowDataType::Int8 - | ArrowDataType::Int16 - | ArrowDataType::Int32 - | ArrowDataType::Int64 - | ArrowDataType::UInt8 - | ArrowDataType::UInt16 - | ArrowDataType::UInt32 - | ArrowDataType::UInt64 - | ArrowDataType::Float32 - | ArrowDataType::Float64 - | ArrowDataType::Timestamp(_, _) - | ArrowDataType::Date32 - | ArrowDataType::Date64 - | ArrowDataType::Time32(_) - | ArrowDataType::Time64(_) - | ArrowDataType::Duration(_) - | ArrowDataType::Interval(_) - | ArrowDataType::Decimal128(_, _) - | ArrowDataType::Decimal256(_, _) - | ArrowDataType::FixedSizeBinary(_) => { - let mut col_writer = row_group_writer.next_column()?.unwrap(); - for (array, levels) in arrays.iter().zip(levels.iter_mut()) { - write_leaf(col_writer.untyped(), array, levels.pop().expect("Levels exhausted"))?; +/// A list of [`Bytes`] comprising a single column chunk +#[derive(Default)] +struct ArrowColumnChunk { + length: usize, + data: Vec, +} + +impl Length for ArrowColumnChunk { + fn len(&self) -> u64 { + self.length as _ + } +} + +impl ChunkReader for ArrowColumnChunk { + type T = ChainReader; + + fn get_read(&self, start: u64) -> Result { + assert_eq!(start, 0); + Ok(ChainReader(self.data.clone().into_iter().peekable())) + } + + fn get_bytes(&self, _start: u64, _length: usize) -> Result { + unimplemented!() + } +} + +/// A [`Read`] for an iterator of [`Bytes`] +struct ChainReader(Peekable>); + +impl Read for ChainReader { + fn read(&mut self, out: &mut [u8]) -> std::io::Result { + let buffer = loop { + match self.0.peek_mut() { + Some(b) if b.is_empty() => { + self.0.next(); + continue; + } + Some(b) => break b, + None => return Ok(0), } - col_writer.close() + }; + + let len = buffer.len().min(out.len()); + let b = buffer.split_to(len); + out[..len].copy_from_slice(&b); + Ok(len) + } +} + +/// A shared [`ArrowColumnChunk`] +/// +/// This allows it to be owned by [`ArrowPageWriter`] whilst allowing access via +/// [`ArrowRowGroupWriter`] on flush, without requiring self-referential borrows +type SharedColumnChunk = Arc>; + +#[derive(Default)] +struct ArrowPageWriter { + buffer: SharedColumnChunk, +} + +impl PageWriter for ArrowPageWriter { + fn write_page(&mut self, page: CompressedPage) -> Result { + let mut buf = self.buffer.try_lock().unwrap(); + let page_header = page.to_thrift_header(); + let header = { + let mut header = Vec::with_capacity(1024); + let mut protocol = TCompactOutputProtocol::new(&mut header); + page_header.write_to_out_protocol(&mut protocol)?; + Bytes::from(header) + }; + + let data = page.compressed_page().buffer().clone(); + let compressed_size = data.len() + header.len(); + + let mut spec = PageWriteSpec::new(); + spec.page_type = page.page_type(); + spec.num_values = page.num_values(); + spec.uncompressed_size = page.uncompressed_size() + header.len(); + spec.offset = buf.length as u64; + spec.compressed_size = compressed_size; + spec.bytes_written = compressed_size as u64; + + buf.length += compressed_size; + buf.data.push(header); + buf.data.push(data.into()); + + Ok(spec) + } + + fn write_metadata(&mut self, _metadata: &ColumnChunkMetaData) -> Result<()> { + // Skip writing metadata as won't be copied anyway + Ok(()) + } + + fn close(&mut self) -> Result<()> { + Ok(()) + } +} + +/// Encodes a leaf column to [`ArrowPageWriter`] +enum ArrowColumnWriter { + ByteArray(GenericColumnWriter<'static, ByteArrayEncoder>), + Column(ColumnWriter<'static>), +} + +/// Encodes [`RecordBatch`] to a parquet row group +struct ArrowRowGroupWriter { + writers: Vec<(SharedColumnChunk, ArrowColumnWriter)>, + schema: SchemaRef, +} + +impl ArrowRowGroupWriter { + fn new( + parquet: &SchemaDescriptor, + props: &WriterPropertiesPtr, + arrow: &SchemaRef, + ) -> Result { + let mut writers = Vec::with_capacity(arrow.fields.len()); + let mut leaves = parquet.columns().iter(); + for field in &arrow.fields { + get_arrow_column_writer(field.data_type(), props, &mut leaves, &mut writers)?; } + Ok(Self { + writers, + schema: arrow.clone(), + }) + } + + fn write(&mut self, batch: &RecordBatch) -> Result<()> { + let mut writers = self.writers.iter_mut().map(|(_, x)| x); + for (array, field) in batch.columns().iter().zip(&self.schema.fields) { + let mut levels = calculate_array_levels(array, field)?.into_iter(); + write_leaves(&mut writers, &mut levels, array.as_ref())?; + } + Ok(()) + } + + fn close(self) -> Result> { + self.writers + .into_iter() + .map(|(chunk, writer)| { + let close_result = match writer { + ArrowColumnWriter::ByteArray(c) => c.close()?, + ArrowColumnWriter::Column(c) => c.close()?, + }; + + let chunk = Arc::try_unwrap(chunk).ok().unwrap().into_inner().unwrap(); + Ok((chunk, close_result)) + }) + .collect() + } +} + +/// Get an [`ArrowColumnWriter`] along with a reference to its [`SharedColumnChunk`] +fn get_arrow_column_writer( + data_type: &ArrowDataType, + props: &WriterPropertiesPtr, + leaves: &mut Iter<'_, ColumnDescPtr>, + out: &mut Vec<(SharedColumnChunk, ArrowColumnWriter)>, +) -> Result<()> { + let col = |desc: &ColumnDescPtr| { + let page_writer = Box::::default(); + let chunk = page_writer.buffer.clone(); + let writer = get_column_writer(desc.clone(), props.clone(), page_writer); + (chunk, ArrowColumnWriter::Column(writer)) + }; + + let bytes = |desc: &ColumnDescPtr| { + let page_writer = Box::::default(); + let chunk = page_writer.buffer.clone(); + let writer = GenericColumnWriter::new(desc.clone(), props.clone(), page_writer); + (chunk, ArrowColumnWriter::ByteArray(writer)) + }; + + match data_type { + _ if data_type.is_primitive() => out.push(col(leaves.next().unwrap())), + ArrowDataType::FixedSizeBinary(_) | ArrowDataType::Boolean | ArrowDataType::Null => out.push(col(leaves.next().unwrap())), ArrowDataType::LargeBinary | ArrowDataType::Binary | ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => { - let mut col_writer = row_group_writer.next_column_with_factory(ByteArrayWriter::new)?.unwrap(); - for (array, levels) in arrays.iter().zip(levels.iter_mut()) { - col_writer.write(array, levels.pop().expect("Levels exhausted"))?; - } - col_writer.close() + out.push(bytes(leaves.next().unwrap())) } - ArrowDataType::List(_) => { - let arrays: Vec<_> = arrays.iter().map(|array|{ - array.as_list::().values().clone() - }).collect(); - - write_leaves(row_group_writer, &arrays, levels)?; - Ok(()) - } - ArrowDataType::LargeList(_) => { - let arrays: Vec<_> = arrays.iter().map(|array|{ - array.as_list::().values().clone() - }).collect(); - write_leaves(row_group_writer, &arrays, levels)?; - Ok(()) + ArrowDataType::List(f) + | ArrowDataType::LargeList(f) + | ArrowDataType::FixedSizeList(f, _) => { + get_arrow_column_writer(f.data_type(), props, leaves, out)? } ArrowDataType::Struct(fields) => { - // Groups child arrays by field - let mut field_arrays = vec![Vec::with_capacity(arrays.len()); fields.len()]; - - for array in arrays { - let struct_array: &arrow_array::StructArray = array - .as_any() - .downcast_ref::() - .expect("Unable to get struct array"); - - assert_eq!(struct_array.columns().len(), fields.len()); - - for (child_array, field) in field_arrays.iter_mut().zip(struct_array.columns()) { - child_array.push(field.clone()) - } + for field in fields { + get_arrow_column_writer(field.data_type(), props, leaves, out)? } - - for field in field_arrays { - write_leaves(row_group_writer, &field, levels)?; - } - - Ok(()) } - ArrowDataType::Map(_, _) => { - let mut keys = Vec::with_capacity(arrays.len()); - let mut values = Vec::with_capacity(arrays.len()); - for array in arrays { - let map_array: &arrow_array::MapArray = array - .as_any() - .downcast_ref::() - .expect("Unable to get map array"); - keys.push(map_array.keys().clone()); - values.push(map_array.values().clone()); + ArrowDataType::Map(f, _) => match f.data_type() { + ArrowDataType::Struct(f) => { + get_arrow_column_writer(f[0].data_type(), props, leaves, out)?; + get_arrow_column_writer(f[1].data_type(), props, leaves, out)? } - - write_leaves(row_group_writer, &keys, levels)?; - write_leaves(row_group_writer, &values, levels)?; - Ok(()) + _ => unreachable!("invalid map type"), } ArrowDataType::Dictionary(_, value_type) => match value_type.as_ref() { ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Binary | ArrowDataType::LargeBinary => { - let mut col_writer = row_group_writer.next_column_with_factory(ByteArrayWriter::new)?.unwrap(); - for (array, levels) in arrays.iter().zip(levels.iter_mut()) { - col_writer.write(array, levels.pop().expect("Levels exhausted"))?; - } - col_writer.close() + out.push(bytes(leaves.next().unwrap())) } _ => { - let mut col_writer = row_group_writer.next_column()?.unwrap(); - for (array, levels) in arrays.iter().zip(levels.iter_mut()) { - write_leaf(col_writer.untyped(), array, levels.pop().expect("Levels exhausted"))?; - } - col_writer.close() + out.push(col(leaves.next().unwrap())) } } - ArrowDataType::Float16 => Err(ParquetError::ArrowError( - "Float16 arrays not supported".to_string(), - )), + _ => return Err(ParquetError::NYI( + format!( + "Attempting to write an Arrow type {data_type:?} to parquet that is not yet implemented" + ) + )) + } + Ok(()) +} + +/// Write the leaves of `array` in depth-first order to `writers` with `levels` +fn write_leaves<'a, W>( + writers: &mut W, + levels: &mut IntoIter, + array: &(dyn Array + 'static), +) -> Result<()> +where + W: Iterator, +{ + match array.data_type() { + ArrowDataType::List(_) => { + write_leaves(writers, levels, array.as_list::().values().as_ref())? + } + ArrowDataType::LargeList(_) => { + write_leaves(writers, levels, array.as_list::().values().as_ref())? + } ArrowDataType::FixedSizeList(_, _) => { - let arrays: Vec<_> = arrays.iter().map(|array|{ - array.as_any().downcast_ref::() - .expect("unable to get fixed-size list array") - .values() - .clone() - }).collect(); - write_leaves(row_group_writer, &arrays, levels)?; - Ok(()) - }, - ArrowDataType::Union(_, _) | ArrowDataType::RunEndEncoded(_, _) => { - Err(ParquetError::NYI( - format!( - "Attempting to write an Arrow type {data_type:?} to parquet that is not yet implemented" - ) - )) + let array = array.as_any().downcast_ref::().unwrap(); + write_leaves(writers, levels, array.values().as_ref())? + } + ArrowDataType::Struct(_) => { + for column in array.as_struct().columns() { + write_leaves(writers, levels, column.as_ref())? + } + } + ArrowDataType::Map(_, _) => { + let map = array.as_map(); + write_leaves(writers, levels, map.keys().as_ref())?; + write_leaves(writers, levels, map.values().as_ref())? + } + _ => { + let levels = levels.next().unwrap(); + match writers.next().unwrap() { + ArrowColumnWriter::Column(c) => write_leaf(c, array, levels)?, + ArrowColumnWriter::ByteArray(c) => write_primitive(c, array, levels)?, + }; } } + Ok(()) } fn write_leaf( writer: &mut ColumnWriter<'_>, - column: &ArrayRef, + column: &dyn Array, levels: LevelInfo, -) -> Result { +) -> Result { let indices = levels.non_null_indices(); - let written = match writer { + match writer { ColumnWriter::Int32ColumnWriter(ref mut typed) => { match column.data_type() { ArrowDataType::Date64 => { @@ -442,26 +520,26 @@ fn write_leaf( let array = arrow_cast::cast(&array, &ArrowDataType::Int32)?; let array = array.as_primitive::(); - write_primitive(typed, array.values(), levels)? + write_primitive(typed, array.values(), levels) } ArrowDataType::UInt32 => { let values = column.as_primitive::().values(); // follow C++ implementation and use overflow/reinterpret cast from u32 to i32 which will map // `(i32::MAX as u32)..u32::MAX` to `i32::MIN..0` let array = values.inner().typed_data::(); - write_primitive(typed, array, levels)? + write_primitive(typed, array, levels) } ArrowDataType::Decimal128(_, _) => { // use the int32 to represent the decimal with low precision let array = column .as_primitive::() - .unary::<_, types::Int32Type>(|v| v as i32); - write_primitive(typed, array.values(), levels)? + .unary::<_, Int32Type>(|v| v as i32); + write_primitive(typed, array.values(), levels) } _ => { let array = arrow_cast::cast(column, &ArrowDataType::Int32)?; let array = array.as_primitive::(); - write_primitive(typed, array.values(), levels)? + write_primitive(typed, array.values(), levels) } } } @@ -471,32 +549,32 @@ fn write_leaf( get_bool_array_slice(array, indices).as_slice(), levels.def_levels(), levels.rep_levels(), - )? + ) } ColumnWriter::Int64ColumnWriter(ref mut typed) => { match column.data_type() { ArrowDataType::Int64 => { let array = column.as_primitive::(); - write_primitive(typed, array.values(), levels)? + write_primitive(typed, array.values(), levels) } ArrowDataType::UInt64 => { let values = column.as_primitive::().values(); // follow C++ implementation and use overflow/reinterpret cast from u64 to i64 which will map // `(i64::MAX as u64)..u64::MAX` to `i64::MIN..0` let array = values.inner().typed_data::(); - write_primitive(typed, array, levels)? + write_primitive(typed, array, levels) } ArrowDataType::Decimal128(_, _) => { // use the int64 to represent the decimal with low precision let array = column .as_primitive::() - .unary::<_, types::Int64Type>(|v| v as i64); - write_primitive(typed, array.values(), levels)? + .unary::<_, Int64Type>(|v| v as i64); + write_primitive(typed, array.values(), levels) } _ => { let array = arrow_cast::cast(column, &ArrowDataType::Int64)?; let array = array.as_primitive::(); - write_primitive(typed, array.values(), levels)? + write_primitive(typed, array.values(), levels) } } } @@ -504,18 +582,12 @@ fn write_leaf( unreachable!("Currently unreachable because data type not supported") } ColumnWriter::FloatColumnWriter(ref mut typed) => { - let array = column - .as_any() - .downcast_ref::() - .expect("Unable to get Float32 array"); - write_primitive(typed, array.values(), levels)? + let array = column.as_primitive::(); + write_primitive(typed, array.values(), levels) } ColumnWriter::DoubleColumnWriter(ref mut typed) => { - let array = column - .as_any() - .downcast_ref::() - .expect("Unable to get Float64 array"); - write_primitive(typed, array.values(), levels)? + let array = column.as_primitive::(); + write_primitive(typed, array.values(), levels) } ColumnWriter::ByteArrayColumnWriter(_) => { unreachable!("should use ByteArrayWriter") @@ -553,10 +625,7 @@ fn write_leaf( get_fsb_array_slice(array, indices) } ArrowDataType::Decimal128(_, _) => { - let array = column - .as_any() - .downcast_ref::() - .unwrap(); + let array = column.as_primitive::(); get_decimal_array_slice(array, indices) } _ => { @@ -566,19 +635,14 @@ fn write_leaf( )); } }; - typed.write_batch( - bytes.as_slice(), - levels.def_levels(), - levels.rep_levels(), - )? + typed.write_batch(bytes.as_slice(), levels.def_levels(), levels.rep_levels()) } - }; - Ok(written as i64) + } } -fn write_primitive( - writer: &mut ColumnWriterImpl<'_, T>, - values: &[T::T], +fn write_primitive( + writer: &mut GenericColumnWriter, + values: &E::Values, levels: LevelInfo, ) -> Result { writer.write_batch_internal( diff --git a/parquet/src/column/page.rs b/parquet/src/column/page.rs index bd3568d13cee..9140a21fefb5 100644 --- a/parquet/src/column/page.rs +++ b/parquet/src/column/page.rs @@ -162,6 +162,75 @@ impl CompressedPage { pub fn data(&self) -> &[u8] { self.compressed_page.buffer().data() } + + /// Returns the thrift page header + pub(crate) fn to_thrift_header(&self) -> PageHeader { + let uncompressed_size = self.uncompressed_size(); + let compressed_size = self.compressed_size(); + let num_values = self.num_values(); + let encoding = self.encoding(); + let page_type = self.page_type(); + + let mut page_header = PageHeader { + type_: page_type.into(), + uncompressed_page_size: uncompressed_size as i32, + compressed_page_size: compressed_size as i32, + // TODO: Add support for crc checksum + crc: None, + data_page_header: None, + index_page_header: None, + dictionary_page_header: None, + data_page_header_v2: None, + }; + + match self.compressed_page { + Page::DataPage { + def_level_encoding, + rep_level_encoding, + ref statistics, + .. + } => { + let data_page_header = crate::format::DataPageHeader { + num_values: num_values as i32, + encoding: encoding.into(), + definition_level_encoding: def_level_encoding.into(), + repetition_level_encoding: rep_level_encoding.into(), + statistics: crate::file::statistics::to_thrift(statistics.as_ref()), + }; + page_header.data_page_header = Some(data_page_header); + } + Page::DataPageV2 { + num_nulls, + num_rows, + def_levels_byte_len, + rep_levels_byte_len, + is_compressed, + ref statistics, + .. + } => { + let data_page_header_v2 = crate::format::DataPageHeaderV2 { + num_values: num_values as i32, + num_nulls: num_nulls as i32, + num_rows: num_rows as i32, + encoding: encoding.into(), + definition_levels_byte_length: def_levels_byte_len as i32, + repetition_levels_byte_length: rep_levels_byte_len as i32, + is_compressed: Some(is_compressed), + statistics: crate::file::statistics::to_thrift(statistics.as_ref()), + }; + page_header.data_page_header_v2 = Some(data_page_header_v2); + } + Page::DictionaryPage { is_sorted, .. } => { + let dictionary_page_header = crate::format::DictionaryPageHeader { + num_values: num_values as i32, + encoding: encoding.into(), + is_sorted: Some(is_sorted), + }; + page_header.dictionary_page_header = Some(dictionary_page_header); + } + } + page_header + } } /// Contains page write metrics. diff --git a/parquet/src/column/writer/encoder.rs b/parquet/src/column/writer/encoder.rs index c343f1d6c824..fb5889b785a8 100644 --- a/parquet/src/column/writer/encoder.rs +++ b/parquet/src/column/writer/encoder.rs @@ -36,7 +36,7 @@ pub trait ColumnValues { } #[cfg(feature = "arrow")] -impl ColumnValues for T { +impl ColumnValues for dyn arrow_array::Array { fn len(&self) -> usize { arrow_array::Array::len(self) } diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs index 137893092405..6b5117c8631d 100644 --- a/parquet/src/column/writer/mod.rs +++ b/parquet/src/column/writer/mod.rs @@ -55,6 +55,22 @@ pub enum ColumnWriter<'a> { FixedLenByteArrayColumnWriter(ColumnWriterImpl<'a, FixedLenByteArrayType>), } +impl<'a> ColumnWriter<'a> { + /// Close this [`ColumnWriter`] + pub fn close(self) -> Result { + match self { + Self::BoolColumnWriter(typed) => typed.close(), + Self::Int32ColumnWriter(typed) => typed.close(), + Self::Int64ColumnWriter(typed) => typed.close(), + Self::Int96ColumnWriter(typed) => typed.close(), + Self::FloatColumnWriter(typed) => typed.close(), + Self::DoubleColumnWriter(typed) => typed.close(), + Self::ByteArrayColumnWriter(typed) => typed.close(), + Self::FixedLenByteArrayColumnWriter(typed) => typed.close(), + } + } +} + pub enum Level { Page, Column, @@ -915,11 +931,11 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { fn update_metrics_for_page(&mut self, page_spec: PageWriteSpec) { self.column_metrics.total_uncompressed_size += page_spec.uncompressed_size as u64; self.column_metrics.total_compressed_size += page_spec.compressed_size as u64; - self.column_metrics.total_num_values += page_spec.num_values as u64; self.column_metrics.total_bytes_written += page_spec.bytes_written; match page_spec.page_type { PageType::DATA_PAGE | PageType::DATA_PAGE_V2 => { + self.column_metrics.total_num_values += page_spec.num_values as u64; if self.column_metrics.data_page_offset.is_none() { self.column_metrics.data_page_offset = Some(page_spec.offset); } @@ -1512,7 +1528,7 @@ mod tests { metadata.encodings(), &vec![Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY] ); - assert_eq!(metadata.num_values(), 8); // dictionary + value indexes + assert_eq!(metadata.num_values(), 4); assert_eq!(metadata.compressed_size(), 20); assert_eq!(metadata.uncompressed_size(), 20); assert_eq!(metadata.data_page_offset(), 0); @@ -1639,7 +1655,7 @@ mod tests { metadata.encodings(), &vec![Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY] ); - assert_eq!(metadata.num_values(), 8); // dictionary + value indexes + assert_eq!(metadata.num_values(), 4); assert_eq!(metadata.compressed_size(), 20); assert_eq!(metadata.uncompressed_size(), 20); assert_eq!(metadata.data_page_offset(), 0); diff --git a/parquet/src/file/writer.rs b/parquet/src/file/writer.rs index 4b1c4bad92e1..322d515400bc 100644 --- a/parquet/src/file/writer.rs +++ b/parquet/src/file/writer.rs @@ -26,21 +26,17 @@ use std::io::{BufWriter, IoSlice, Read}; use std::{io::Write, sync::Arc}; use thrift::protocol::{TCompactOutputProtocol, TSerializable}; -use crate::basic::PageType; use crate::column::writer::{ get_typed_column_writer_mut, ColumnCloseResult, ColumnWriterImpl, }; use crate::column::{ - page::{CompressedPage, Page, PageWriteSpec, PageWriter}, + page::{CompressedPage, PageWriteSpec, PageWriter}, writer::{get_column_writer, ColumnWriter}, }; use crate::data_type::DataType; use crate::errors::{ParquetError, Result}; use crate::file::reader::ChunkReader; -use crate::file::{ - metadata::*, properties::WriterPropertiesPtr, - statistics::to_thrift as statistics_to_thrift, PARQUET_MAGIC, -}; +use crate::file::{metadata::*, properties::WriterPropertiesPtr, PARQUET_MAGIC}; use crate::schema::types::{ self, ColumnDescPtr, SchemaDescPtr, SchemaDescriptor, TypePtr, }; @@ -371,6 +367,16 @@ impl SerializedFileWriter { self.kv_metadatas.push(kv_metadata); } + /// Returns a reference to schema descriptor. + pub fn schema_descr(&self) -> &SchemaDescriptor { + &self.descr + } + + /// Returns a reference to the writer properties + pub fn properties(&self) -> &WriterPropertiesPtr { + &self.props + } + /// Writes the file footer and returns the underlying writer. pub fn into_inner(mut self) -> Result { self.assert_previous_writer_closed()?; @@ -654,17 +660,7 @@ impl<'a> SerializedColumnWriter<'a> { /// Close this [`SerializedColumnWriter`] pub fn close(mut self) -> Result<()> { - let r = match self.inner { - ColumnWriter::BoolColumnWriter(typed) => typed.close()?, - ColumnWriter::Int32ColumnWriter(typed) => typed.close()?, - ColumnWriter::Int64ColumnWriter(typed) => typed.close()?, - ColumnWriter::Int96ColumnWriter(typed) => typed.close()?, - ColumnWriter::FloatColumnWriter(typed) => typed.close()?, - ColumnWriter::DoubleColumnWriter(typed) => typed.close()?, - ColumnWriter::ByteArrayColumnWriter(typed) => typed.close()?, - ColumnWriter::FixedLenByteArrayColumnWriter(typed) => typed.close()?, - }; - + let r = self.inner.close()?; if let Some(on_close) = self.on_close.take() { on_close(r)? } @@ -702,86 +698,20 @@ impl<'a, W: Write> SerializedPageWriter<'a, W> { impl<'a, W: Write> PageWriter for SerializedPageWriter<'a, W> { fn write_page(&mut self, page: CompressedPage) -> Result { - let uncompressed_size = page.uncompressed_size(); - let compressed_size = page.compressed_size(); - let num_values = page.num_values(); - let encoding = page.encoding(); let page_type = page.page_type(); - - let mut page_header = parquet::PageHeader { - type_: page_type.into(), - uncompressed_page_size: uncompressed_size as i32, - compressed_page_size: compressed_size as i32, - // TODO: Add support for crc checksum - crc: None, - data_page_header: None, - index_page_header: None, - dictionary_page_header: None, - data_page_header_v2: None, - }; - - match *page.compressed_page() { - Page::DataPage { - def_level_encoding, - rep_level_encoding, - ref statistics, - .. - } => { - let data_page_header = parquet::DataPageHeader { - num_values: num_values as i32, - encoding: encoding.into(), - definition_level_encoding: def_level_encoding.into(), - repetition_level_encoding: rep_level_encoding.into(), - statistics: statistics_to_thrift(statistics.as_ref()), - }; - page_header.data_page_header = Some(data_page_header); - } - Page::DataPageV2 { - num_nulls, - num_rows, - def_levels_byte_len, - rep_levels_byte_len, - is_compressed, - ref statistics, - .. - } => { - let data_page_header_v2 = parquet::DataPageHeaderV2 { - num_values: num_values as i32, - num_nulls: num_nulls as i32, - num_rows: num_rows as i32, - encoding: encoding.into(), - definition_levels_byte_length: def_levels_byte_len as i32, - repetition_levels_byte_length: rep_levels_byte_len as i32, - is_compressed: Some(is_compressed), - statistics: statistics_to_thrift(statistics.as_ref()), - }; - page_header.data_page_header_v2 = Some(data_page_header_v2); - } - Page::DictionaryPage { is_sorted, .. } => { - let dictionary_page_header = parquet::DictionaryPageHeader { - num_values: num_values as i32, - encoding: encoding.into(), - is_sorted: Some(is_sorted), - }; - page_header.dictionary_page_header = Some(dictionary_page_header); - } - } - let start_pos = self.sink.bytes_written() as u64; + let page_header = page.to_thrift_header(); let header_size = self.serialize_page_header(page_header)?; self.sink.write_all(page.data())?; let mut spec = PageWriteSpec::new(); spec.page_type = page_type; - spec.uncompressed_size = uncompressed_size + header_size; - spec.compressed_size = compressed_size + header_size; + spec.uncompressed_size = page.uncompressed_size() + header_size; + spec.compressed_size = page.compressed_size() + header_size; spec.offset = start_pos; spec.bytes_written = self.sink.bytes_written() as u64 - start_pos; - // Number of values is incremented for data pages only - if page_type == PageType::DATA_PAGE || page_type == PageType::DATA_PAGE_V2 { - spec.num_values = num_values; - } + spec.num_values = page.num_values(); Ok(spec) } @@ -808,7 +738,7 @@ mod tests { use std::fs::File; use crate::basic::{Compression, Encoding, LogicalType, Repetition, Type}; - use crate::column::page::PageReader; + use crate::column::page::{Page, PageReader}; use crate::column::reader::get_typed_column_reader; use crate::compression::{create_codec, Codec, CodecOptionsBuilder}; use crate::data_type::{BoolType, Int32Type}; diff --git a/parquet/src/util/memory.rs b/parquet/src/util/memory.rs index 909878a6d538..25d15dd4ff73 100644 --- a/parquet/src/util/memory.rs +++ b/parquet/src/util/memory.rs @@ -114,6 +114,12 @@ impl From for ByteBufferPtr { } } +impl From for Bytes { + fn from(value: ByteBufferPtr) -> Self { + value.data + } +} + #[cfg(test)] mod tests { use super::*; From a3c4fc34f487a1e538822fe6c3205dd3c5bab6f8 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Thu, 25 May 2023 20:20:09 +0100 Subject: [PATCH 2/4] Review feedback --- parquet/src/arrow/arrow_writer/mod.rs | 47 +++++++++++++++------------ 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index c265058c11d5..3449de43e559 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -91,9 +91,6 @@ pub struct ArrowWriter { /// The in-progress row group if any in_progress: Option, - /// The total number of rows in the in_progress row group, if any - buffered_rows: usize, - /// A copy of the Arrow schema. /// /// The schema is used to verify that each record batch written has the correct schema @@ -108,8 +105,8 @@ impl Debug for ArrowWriter { let buffered_memory = self.in_progress_size(); f.debug_struct("ArrowWriter") .field("writer", &self.writer) - .field("buffer", &format_args!("{buffered_memory} bytes")) - .field("buffered_rows", &self.buffered_rows) + .field("in_progress_size", &format_args!("{buffered_memory} bytes")) + .field("in_progress_rows", &self.in_progress_rows()) .field("arrow_schema", &self.arrow_schema) .field("max_row_group_size", &self.max_row_group_size) .finish() @@ -140,7 +137,6 @@ impl ArrowWriter { Ok(Self { writer: file_writer, in_progress: None, - buffered_rows: 0, arrow_schema, max_row_group_size, }) @@ -163,26 +159,24 @@ impl ArrowWriter { } } + /// Returns the number of rows buffered in the in progress row group + pub fn in_progress_rows(&self) -> usize { + self.in_progress + .as_ref() + .map(|x| x.buffered_rows) + .unwrap_or_default() + } + /// Encodes the provided [`RecordBatch`] /// /// If this would cause the current row group to exceed [`WriterProperties::max_row_group_size`] - /// rows, the contents of `batch` will be distributed across multiple row groups such that all - /// but the final row group in the file contain [`WriterProperties::max_row_group_size`] rows + /// rows, the contents of `batch` will be written to one or more row groups such that all but + /// the final row group in the file contain [`WriterProperties::max_row_group_size`] rows pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { if batch.num_rows() == 0 { return Ok(()); } - // If would exceed max_row_group_size, split batch - if self.buffered_rows + batch.num_rows() > self.max_row_group_size { - let to_write = self.max_row_group_size - self.buffered_rows; - let a = batch.slice(0, to_write); - let b = batch.slice(to_write, batch.num_rows() - to_write); - self.write(&a)?; - return self.write(&b); - } - - self.buffered_rows += batch.num_rows(); let in_progress = match &mut self.in_progress { Some(in_progress) => in_progress, x => x.insert(ArrowRowGroupWriter::new( @@ -192,9 +186,18 @@ impl ArrowWriter { )?), }; + // If would exceed max_row_group_size, split batch + if in_progress.buffered_rows + batch.num_rows() > self.max_row_group_size { + let to_write = self.max_row_group_size - in_progress.buffered_rows; + let a = batch.slice(0, to_write); + let b = batch.slice(to_write, batch.num_rows() - to_write); + self.write(&a)?; + return self.write(&b); + } + in_progress.write(batch)?; - if self.buffered_rows >= self.max_row_group_size { + if in_progress.buffered_rows >= self.max_row_group_size { self.flush()? } Ok(()) @@ -207,7 +210,6 @@ impl ArrowWriter { None => return Ok(()), }; - self.buffered_rows = 0; let mut row_group_writer = self.writer.next_row_group()?; for (chunk, close) in in_progress.close()? { row_group_writer.append_column(&chunk, close)?; @@ -264,7 +266,7 @@ impl ChunkReader for ArrowColumnChunk { type T = ChainReader; fn get_read(&self, start: u64) -> Result { - assert_eq!(start, 0); + assert_eq!(start, 0); // Assume append_column writes all data in one-shot Ok(ChainReader(self.data.clone().into_iter().peekable())) } @@ -356,6 +358,7 @@ enum ArrowColumnWriter { struct ArrowRowGroupWriter { writers: Vec<(SharedColumnChunk, ArrowColumnWriter)>, schema: SchemaRef, + buffered_rows: usize, } impl ArrowRowGroupWriter { @@ -372,10 +375,12 @@ impl ArrowRowGroupWriter { Ok(Self { writers, schema: arrow.clone(), + buffered_rows: 0, }) } fn write(&mut self, batch: &RecordBatch) -> Result<()> { + self.buffered_rows += batch.num_rows(); let mut writers = self.writers.iter_mut().map(|(_, x)| x); for (array, field) in batch.columns().iter().zip(&self.schema.fields) { let mut levels = calculate_array_levels(array, field)?.into_iter(); From 4683a301048d826344dc5d483891997928ffe446 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Sat, 27 May 2023 15:52:17 +0100 Subject: [PATCH 3/4] Improved memory accounting --- parquet/src/arrow/arrow_writer/mod.rs | 50 +++++++++++++++++++++++++-- parquet/src/column/writer/mod.rs | 41 ++++++++++++++++------ 2 files changed, 79 insertions(+), 12 deletions(-) diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index 835e7bc232bb..bde21ae856d0 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -147,13 +147,13 @@ impl ArrowWriter { self.writer.flushed_row_groups() } - /// Returns the length in bytes of the current in progress row group + /// Returns the estimated length in bytes of the current in progress row group pub fn in_progress_size(&self) -> usize { match &self.in_progress { Some(in_progress) => in_progress .writers .iter() - .map(|(x, _)| x.lock().unwrap().length) + .map(|(_, x)| x.get_estimated_total_bytes() as usize) .sum(), None => 0, } @@ -354,6 +354,16 @@ enum ArrowColumnWriter { Column(ColumnWriter<'static>), } +impl ArrowColumnWriter { + /// Returns the estimated total bytes for this column writer + fn get_estimated_total_bytes(&self) -> u64 { + match self { + ArrowColumnWriter::ByteArray(c) => c.get_estimated_total_bytes(), + ArrowColumnWriter::Column(c) => c.get_estimated_total_bytes(), + } + } +} + /// Encodes [`RecordBatch`] to a parquet row group struct ArrowRowGroupWriter { writers: Vec<(SharedColumnChunk, ArrowColumnWriter)>, @@ -2531,4 +2541,40 @@ mod tests { assert_ne!(back.schema(), batch.schema()); assert_eq!(back.column(0).as_ref(), batch.column(0).as_ref()); } + + #[test] + fn in_progress_accounting() { + // define schema + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + // create some data + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + + // build a record batch + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap(); + + let mut writer = ArrowWriter::try_new(vec![], batch.schema(), None).unwrap(); + + // starts empty + assert_eq!(writer.in_progress_size(), 0); + assert_eq!(writer.in_progress_rows(), 0); + writer.write(&batch).unwrap(); + + // updated on write + let initial_size = writer.in_progress_size(); + assert!(initial_size > 0); + assert_eq!(writer.in_progress_rows(), 5); + + // updated on second write + writer.write(&batch).unwrap(); + assert!(writer.in_progress_size() > initial_size); + assert_eq!(writer.in_progress_rows(), 10); + + // cleared on flush + writer.flush().unwrap(); + assert_eq!(writer.in_progress_size(), 0); + assert_eq!(writer.in_progress_rows(), 0); + + writer.close().unwrap(); + } } diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs index 847c1fec9b56..6e8058d433e2 100644 --- a/parquet/src/column/writer/mod.rs +++ b/parquet/src/column/writer/mod.rs @@ -43,6 +43,21 @@ use crate::util::memory::ByteBufferPtr; pub(crate) mod encoder; +macro_rules! downcast_writer { + ($e:expr, $i:ident, $b:expr) => { + match $e { + Self::BoolColumnWriter($i) => $b, + Self::Int32ColumnWriter($i) => $b, + Self::Int64ColumnWriter($i) => $b, + Self::Int96ColumnWriter($i) => $b, + Self::FloatColumnWriter($i) => $b, + Self::DoubleColumnWriter($i) => $b, + Self::ByteArrayColumnWriter($i) => $b, + Self::FixedLenByteArrayColumnWriter($i) => $b, + } + }; +} + /// Column writer for a Parquet type. pub enum ColumnWriter<'a> { BoolColumnWriter(ColumnWriterImpl<'a, BoolType>), @@ -56,18 +71,14 @@ pub enum ColumnWriter<'a> { } impl<'a> ColumnWriter<'a> { + /// Returns the estimated total bytes for this column writer + pub(crate) fn get_estimated_total_bytes(&self) -> u64 { + downcast_writer!(self, typed, typed.get_estimated_total_bytes()) + } + /// Close this [`ColumnWriter`] pub fn close(self) -> Result { - match self { - Self::BoolColumnWriter(typed) => typed.close(), - Self::Int32ColumnWriter(typed) => typed.close(), - Self::Int64ColumnWriter(typed) => typed.close(), - Self::Int96ColumnWriter(typed) => typed.close(), - Self::FloatColumnWriter(typed) => typed.close(), - Self::DoubleColumnWriter(typed) => typed.close(), - Self::ByteArrayColumnWriter(typed) => typed.close(), - Self::FixedLenByteArrayColumnWriter(typed) => typed.close(), - } + downcast_writer!(self, typed, typed.close()) } } @@ -441,6 +452,16 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { self.column_metrics.total_bytes_written } + /// Returns the estimated total bytes for this column writer + /// + /// Unlike [`Self::get_total_bytes_written`] this includes an estimate + /// of any data that has not yet been flushed to a pge + pub(crate) fn get_estimated_total_bytes(&self) -> u64 { + self.column_metrics.total_bytes_written + + self.encoder.estimated_data_page_size() as u64 + + self.encoder.estimated_dict_page_size().unwrap_or_default() as u64 + } + /// Returns total number of rows written by this column writer so far. /// This value is also returned when column writer is closed. pub fn get_total_rows_written(&self) -> u64 { From 5fa48439ae6cf5d048fec5df6f00008328468f2c Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Sat, 27 May 2023 16:01:27 +0100 Subject: [PATCH 4/4] Clippy --- parquet/src/column/writer/mod.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs index 6e8058d433e2..5e623d281157 100644 --- a/parquet/src/column/writer/mod.rs +++ b/parquet/src/column/writer/mod.rs @@ -72,6 +72,7 @@ pub enum ColumnWriter<'a> { impl<'a> ColumnWriter<'a> { /// Returns the estimated total bytes for this column writer + #[cfg(feature = "arrow")] pub(crate) fn get_estimated_total_bytes(&self) -> u64 { downcast_writer!(self, typed, typed.get_estimated_total_bytes()) } @@ -448,6 +449,9 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { /// Returns total number of bytes written by this column writer so far. /// This value is also returned when column writer is closed. + /// + /// Note: this value does not include any buffered data that has not + /// yet been flushed to a page. pub fn get_total_bytes_written(&self) -> u64 { self.column_metrics.total_bytes_written } @@ -455,7 +459,8 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { /// Returns the estimated total bytes for this column writer /// /// Unlike [`Self::get_total_bytes_written`] this includes an estimate - /// of any data that has not yet been flushed to a pge + /// of any data that has not yet been flushed to a page + #[cfg(feature = "arrow")] pub(crate) fn get_estimated_total_bytes(&self) -> u64 { self.column_metrics.total_bytes_written + self.encoder.estimated_data_page_size() as u64