Skip to content

Commit

Permalink
Row decode cleanups (#3180)
Browse files Browse the repository at this point in the history
* Row decode cleanups

* Clippy
  • Loading branch information
tustvold authored Nov 25, 2022
1 parent 187bf61 commit d74c48e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 97 deletions.
65 changes: 24 additions & 41 deletions arrow/src/row/fixed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ use crate::array::PrimitiveArray;
use crate::compute::SortOptions;
use crate::datatypes::ArrowPrimitiveType;
use crate::row::{null_sentinel, Rows};
use arrow_array::builder::BufferBuilder;
use arrow_array::BooleanArray;
use arrow_buffer::{bit_util, i256, MutableBuffer, ToByteSlice};
use arrow_buffer::{bit_util, i256, ArrowNativeType, Buffer, MutableBuffer};
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::DataType;
use half::f16;
Expand Down Expand Up @@ -266,61 +267,43 @@ pub fn decode_bool(rows: &mut [&[u8]], options: SortOptions) -> BooleanArray {
unsafe { BooleanArray::from(builder.build_unchecked()) }
}

fn decode_nulls(rows: &[&[u8]]) -> (usize, Buffer) {
let mut null_count = 0;
let buffer = MutableBuffer::collect_bool(rows.len(), |idx| {
let valid = rows[idx][0] == 1;
null_count += !valid as usize;
valid
})
.into();
(null_count, buffer)
}

/// Decodes a `ArrayData` from rows based on the provided `FixedLengthEncoding` `T`
///
/// # Safety
///
/// `data_type` must be appropriate native type for `T`
unsafe fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
unsafe fn decode_fixed<T: FixedLengthEncoding + ArrowNativeType>(
rows: &mut [&[u8]],
data_type: DataType,
options: SortOptions,
) -> ArrayData {
let len = rows.len();

let mut null_count = 0;
let mut nulls = MutableBuffer::new(bit_util::ceil(len, 64) * 8);
let mut values = MutableBuffer::new(std::mem::size_of::<T>() * len);
let mut values = BufferBuilder::<T>::new(len);
let (null_count, nulls) = decode_nulls(rows);

let chunks = len / 64;
let remainder = len % 64;
for chunk in 0..chunks {
let mut null_packed = 0;

for bit_idx in 0..64 {
let i = split_off(&mut rows[bit_idx + chunk * 64], T::ENCODED_LEN);
let null = i[0] == 1;
null_count += !null as usize;
null_packed |= (null as u64) << bit_idx;

let value = T::Encoded::from_slice(&i[1..], options.descending);
values.push(T::decode(value));
}

nulls.push(null_packed);
}

if remainder != 0 {
let mut null_packed = 0;

for bit_idx in 0..remainder {
let i = split_off(&mut rows[bit_idx + chunks * 64], T::ENCODED_LEN);
let null = i[0] == 1;
null_count += !null as usize;
null_packed |= (null as u64) << bit_idx;

let value = T::Encoded::from_slice(&i[1..], options.descending);
values.push(T::decode(value));
}

nulls.push(null_packed);
for row in rows {
let i = split_off(row, T::ENCODED_LEN);
let value = T::Encoded::from_slice(&i[1..], options.descending);
values.append(T::decode(value));
}

let builder = ArrayDataBuilder::new(data_type)
.len(rows.len())
.len(len)
.null_count(null_count)
.add_buffer(values.into())
.null_bit_buffer(Some(nulls.into()));
.add_buffer(values.finish())
.null_bit_buffer(Some(nulls));

// SAFETY: Buffers correct length
builder.build_unchecked()
Expand All @@ -333,7 +316,7 @@ pub fn decode_primitive<T: ArrowPrimitiveType>(
options: SortOptions,
) -> PrimitiveArray<T>
where
T::Native: FixedLengthEncoding + ToByteSlice,
T::Native: FixedLengthEncoding,
{
assert_eq!(
std::mem::discriminant(&T::DATA_TYPE),
Expand Down
71 changes: 15 additions & 56 deletions arrow/src/row/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -908,11 +908,22 @@ fn encode_column(
}

macro_rules! decode_primitive_helper {
($t:ty, $rows: ident, $data_type:ident, $options:ident) => {
($t:ty, $rows:ident, $data_type:ident, $options:ident) => {
Arc::new(decode_primitive::<$t>($rows, $data_type, $options))
};
}

macro_rules! decode_dictionary_helper {
($t:ty, $interner:ident, $v:ident, $options:ident, $rows:ident) => {
Arc::new(decode_dictionary::<$t>(
$interner.unwrap(),
$v.as_ref(),
$options,
$rows,
)?)
};
}

/// Decodes a the provided `field` from `rows`
///
/// # Safety
Expand All @@ -934,61 +945,9 @@ unsafe fn decode_column(
DataType::LargeBinary => Arc::new(decode_binary::<i64>(rows, options)),
DataType::Utf8 => Arc::new(decode_string::<i32>(rows, options, validate_utf8)),
DataType::LargeUtf8 => Arc::new(decode_string::<i64>(rows, options, validate_utf8)),
DataType::Dictionary(k, v) => match k.as_ref() {
DataType::Int8 => Arc::new(decode_dictionary::<Int8Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
DataType::Int16 => Arc::new(decode_dictionary::<Int16Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
DataType::Int32 => Arc::new(decode_dictionary::<Int32Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
DataType::Int64 => Arc::new(decode_dictionary::<Int64Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
DataType::UInt8 => Arc::new(decode_dictionary::<UInt8Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
DataType::UInt16 => Arc::new(decode_dictionary::<UInt16Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
DataType::UInt32 => Arc::new(decode_dictionary::<UInt32Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
DataType::UInt64 => Arc::new(decode_dictionary::<UInt64Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
_ => {
return Err(ArrowError::InvalidArgumentError(format!(
"{} is not a valid dictionary key type",
field.data_type
)));
}
DataType::Dictionary(k, v) => downcast_integer! {
k.as_ref() => (decode_dictionary_helper, interner, v, options, rows),
_ => unreachable!()
},
_ => {
return Err(ArrowError::NotYetImplemented(format!(
Expand Down

0 comments on commit d74c48e

Please sign in to comment.