Skip to content

Commit

Permalink
Fix row format decode loses timezone (#3063) (#3064)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold authored Nov 9, 2022
1 parent af5f1e4 commit 5a3ecc2
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 31 deletions.
36 changes: 21 additions & 15 deletions arrow/src/row/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ pub fn encode_dictionary<K: ArrowDictionaryKeyType>(
}

macro_rules! decode_primitive_helper {
($t:ty, $values: ident) => {
decode_primitive::<$t>(&$values)
($t:ty, $values: ident, $data_type:ident) => {
decode_primitive::<$t>(&$values, $data_type.clone())
};
}

Expand Down Expand Up @@ -170,11 +170,11 @@ pub unsafe fn decode_dictionary<K: ArrowDictionaryKeyType>(
}

let child = downcast_primitive! {
&value_type => (decode_primitive_helper, values),
value_type => (decode_primitive_helper, values, value_type),
DataType::Null => NullArray::new(values.len()).into_data(),
DataType::Boolean => decode_bool(&values),
DataType::Decimal128(p, s) => decode_decimal::<Decimal128Type>(&values, *p, *s),
DataType::Decimal256(p, s) => decode_decimal::<Decimal256Type>(&values, *p, *s),
DataType::Decimal128(_, _) => decode_primitive_helper!(Decimal128Type, values, value_type),
DataType::Decimal256(_, _) => decode_primitive_helper!(Decimal256Type, values, value_type),
DataType::Utf8 => decode_string::<i32>(&values),
DataType::LargeUtf8 => decode_string::<i64>(&values),
DataType::Binary => decode_binary::<i32>(&values),
Expand Down Expand Up @@ -247,7 +247,11 @@ fn decode_bool(values: &[&[u8]]) -> ArrayData {
}

/// Decodes a fixed length type array from dictionary values
fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
///
/// # Safety
///
/// `data_type` must be appropriate native type for `T`
unsafe fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
values: &[&[u8]],
data_type: DataType,
) -> ArrayData {
Expand All @@ -267,17 +271,19 @@ fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
}

/// Decodes a `PrimitiveArray` from dictionary values
fn decode_primitive<T: ArrowPrimitiveType>(values: &[&[u8]]) -> ArrayData
fn decode_primitive<T: ArrowPrimitiveType>(
values: &[&[u8]],
data_type: DataType,
) -> ArrayData
where
T::Native: FixedLengthEncoding,
{
decode_fixed::<T::Native>(values, T::DATA_TYPE)
}
assert_eq!(
std::mem::discriminant(&T::DATA_TYPE),
std::mem::discriminant(&data_type),
);

/// Decodes a `DecimalArray` from dictionary values
fn decode_decimal<T: DecimalType>(values: &[&[u8]], precision: u8, scale: u8) -> ArrayData
where
T::Native: FixedLengthEncoding,
{
decode_fixed::<T::Native>(values, T::TYPE_CONSTRUCTOR(precision, scale))
// SAFETY:
// Validated data type above
unsafe { decode_fixed::<T::Native>(values, data_type) }
}
17 changes: 14 additions & 3 deletions arrow/src/row/fixed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,11 @@ pub fn decode_bool(rows: &mut [&[u8]], options: SortOptions) -> BooleanArray {
}

/// Decodes a `ArrayData` from rows based on the provided `FixedLengthEncoding` `T`
fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
///
/// # Safety
///
/// `data_type` must be appropriate native type for `T`
unsafe fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
rows: &mut [&[u8]],
data_type: DataType,
options: SortOptions,
Expand Down Expand Up @@ -319,16 +323,23 @@ fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
.null_bit_buffer(Some(nulls.into()));

// SAFETY: Buffers correct length
unsafe { builder.build_unchecked() }
builder.build_unchecked()
}

/// Decodes a `PrimitiveArray` from rows
pub fn decode_primitive<T: ArrowPrimitiveType>(
rows: &mut [&[u8]],
data_type: DataType,
options: SortOptions,
) -> PrimitiveArray<T>
where
T::Native: FixedLengthEncoding + ToByteSlice,
{
decode_fixed::<T::Native>(rows, T::DATA_TYPE, options).into()
assert_eq!(
std::mem::discriminant(&T::DATA_TYPE),
std::mem::discriminant(&data_type),
);
// SAFETY:
// Validated data type above
unsafe { decode_fixed::<T::Native>(rows, data_type, options).into() }
}
61 changes: 48 additions & 13 deletions arrow/src/row/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,8 @@ fn encode_column(
}

macro_rules! decode_primitive_helper {
($t:ty, $rows: ident, $options:ident) => {
Arc::new(decode_primitive::<$t>($rows, $options))
($t:ty, $rows: ident, $data_type:ident, $options:ident) => {
Arc::new(decode_primitive::<$t>($rows, $data_type, $options))
};
}

Expand All @@ -645,24 +645,17 @@ unsafe fn decode_column(
interner: Option<&OrderPreservingInterner>,
) -> Result<ArrayRef> {
let options = field.options;
let data_type = field.data_type.clone();
let array: ArrayRef = downcast_primitive! {
&field.data_type => (decode_primitive_helper, rows, options),
data_type => (decode_primitive_helper, rows, data_type, options),
DataType::Null => Arc::new(NullArray::new(rows.len())),
DataType::Boolean => Arc::new(decode_bool(rows, options)),
DataType::Binary => Arc::new(decode_binary::<i32>(rows, options)),
DataType::LargeBinary => Arc::new(decode_binary::<i64>(rows, options)),
DataType::Utf8 => Arc::new(decode_string::<i32>(rows, options)),
DataType::LargeUtf8 => Arc::new(decode_string::<i64>(rows, options)),
DataType::Decimal128(p, s) => Arc::new(
decode_primitive::<Decimal128Type>(rows, options)
.with_precision_and_scale(*p, *s)
.unwrap(),
),
DataType::Decimal256(p, s) => Arc::new(
decode_primitive::<Decimal256Type>(rows, options)
.with_precision_and_scale(*p, *s)
.unwrap(),
),
DataType::Decimal128(_, _) => decode_primitive_helper!(Decimal128Type, rows, data_type, options),
DataType::Decimal256(_, _) => decode_primitive_helper!(Decimal256Type, rows, data_type, options),
DataType::Dictionary(k, v) => match k.as_ref() {
DataType::Int8 => Arc::new(decode_dictionary::<Int8Type>(
interner.unwrap(),
Expand Down Expand Up @@ -900,6 +893,48 @@ mod tests {
assert_eq!(&cols[0], &col);
}

#[test]
fn test_timezone() {
let a = TimestampNanosecondArray::from(vec![1, 2, 3, 4, 5])
.with_timezone("+01:00".to_string());
let d = a.data_type().clone();

let mut converter =
RowConverter::new(vec![SortField::new(a.data_type().clone())]);
let rows = converter.convert_columns(&[Arc::new(a) as _]).unwrap();
let back = converter.convert_rows(&rows).unwrap();
assert_eq!(back.len(), 1);
assert_eq!(back[0].data_type(), &d);

// Test dictionary
let mut a =
PrimitiveDictionaryBuilder::<Int32Type, TimestampNanosecondType>::new();
a.append(34).unwrap();
a.append_null();
a.append(345).unwrap();

// Construct dictionary with a timezone
let dict = a.finish();
let values = TimestampNanosecondArray::from(dict.values().data().clone());
let dict_with_tz = dict.with_values(&values.with_timezone("+02:00".to_string()));
let d = DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Timestamp(
TimeUnit::Nanosecond,
Some("+02:00".to_string()),
)),
);

assert_eq!(dict_with_tz.data_type(), &d);
let mut converter = RowConverter::new(vec![SortField::new(d.clone())]);
let rows = converter
.convert_columns(&[Arc::new(dict_with_tz) as _])
.unwrap();
let back = converter.convert_rows(&rows).unwrap();
assert_eq!(back.len(), 1);
assert_eq!(back[0].data_type(), &d);
}

#[test]
fn test_null_encoding() {
let col = Arc::new(NullArray::new(10));
Expand Down

0 comments on commit 5a3ecc2

Please sign in to comment.