Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mark typed buffer APIs safe (#996) (#1027) #1866

Merged
merged 5 commits into from
Jun 15, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions arrow/src/array/array_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl UnionArray {
}

// Check the type_ids
let type_id_slice: &[i8] = unsafe { type_ids.typed_data() };
let type_id_slice: &[i8] = type_ids.typed_data();
let invalid_type_ids = type_id_slice
.iter()
.filter(|i| *i < &0)
Expand All @@ -201,7 +201,7 @@ impl UnionArray {
// Check the value offsets if provided
if let Some(offset_buffer) = &value_offsets {
let max_len = type_ids.len() as i32;
let offsets_slice: &[i32] = unsafe { offset_buffer.typed_data() };
let offsets_slice: &[i32] = offset_buffer.typed_data();
let invalid_offsets = offsets_slice
.iter()
.filter(|i| *i < &0 || *i > &max_len)
Expand Down Expand Up @@ -255,9 +255,7 @@ impl UnionArray {
pub fn value_offset(&self, index: usize) -> i32 {
assert!(index - self.offset() < self.len());
if self.is_dense() {
// safety: reinterpreting is safe since the offset buffer contains `i32` values and is
// properly aligned.
unsafe { self.data().buffers()[1].typed_data::<i32>()[index] }
self.data().buffers()[1].typed_data::<i32>()[index]
} else {
index as i32
}
Expand Down
4 changes: 2 additions & 2 deletions arrow/src/array/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ pub(crate) fn builder_to_mutable_buffer<T: ArrowNativeType>(
/// builder.append(45);
/// let buffer = builder.finish();
///
/// assert_eq!(unsafe { buffer.typed_data::<u8>() }, &[42, 43, 44, 45]);
/// assert_eq!(buffer.typed_data::<u8>(), &[42, 43, 44, 45]);
/// # Ok(())
/// # }
/// ```
Expand Down Expand Up @@ -291,7 +291,7 @@ impl<T: ArrowNativeType> BufferBuilder<T> {
///
/// let buffer = builder.finish();
///
/// assert_eq!(unsafe { buffer.typed_data::<u8>() }, &[42, 44, 46]);
/// assert_eq!(buffer.typed_data::<u8>(), &[42, 44, 46]);
/// ```
#[inline]
pub fn finish(&mut self) -> Buffer {
Expand Down
5 changes: 2 additions & 3 deletions arrow/src/array/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -767,8 +767,7 @@ impl ArrayData {
)));
}

// SAFETY: Bounds checked above
Ok(unsafe { &(buffer.typed_data::<T>()[self.offset..self.offset + len]) })
Ok(&buffer.typed_data::<T>()[self.offset..self.offset + len])
}

/// Does a cheap sanity check that the `self.len` values in `buffer` are valid
Expand Down Expand Up @@ -1161,7 +1160,7 @@ impl ArrayData {

// Justification: buffer size was validated above
let indexes: &[T] =
unsafe { &(buffer.typed_data::<T>()[self.offset..self.offset + self.len]) };
&buffer.typed_data::<T>()[self.offset..self.offset + self.len];

indexes.iter().enumerate().try_for_each(|(i, &dict_index)| {
// Do not check the value is null (value can be arbitrary)
Expand Down
25 changes: 10 additions & 15 deletions arrow/src/buffer/immutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,19 +181,14 @@ impl Buffer {

/// View buffer as typed slice.
///
/// # Safety
/// # Panics
///
/// `ArrowNativeType` is public so that it can be used as a trait bound for other public
/// components, such as the `ToByteSlice` trait. However, this means that it can be
/// implemented by user defined types, which it is not intended for.
pub unsafe fn typed_data<T: ArrowNativeType + num::Num>(&self) -> &[T] {
// JUSTIFICATION
// Benefit
// Many of the buffers represent specific types, and consumers of `Buffer` often need to re-interpret them.
// Soundness
// * The pointer is non-null by construction
// * alignment asserted below.
let (prefix, offsets, suffix) = self.as_slice().align_to::<T>();
/// This function panics if the underlying buffer is not aligned
/// correctly for type `T`.
pub fn typed_data<T: ArrowNativeType>(&self) -> &[T] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this is truly "safe" -- is it really true that any bit pattern is a valid ArrowNativeType? I am thinking about floating point representations in particular -- I wonder if this API could potentially create invalid f32 / f64 which seems like it would thus still be unsafe 🤔

Copy link
Contributor Author

@tustvold tustvold Jun 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think https://doc.rust-lang.org/std/primitive.f32.html#method.from_bits is relevant here, the short answer is it is perfectly safe to transmute arbitrary bytes to floats, it may not be wise, but it is not UB.

In particular the standard library provides safe functions that transmute u32 -> f32, u64 -> f64, and so I think it is fair to say all bit sequences are valid.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that it is safe because there are no actual undefined bit patterns in any of the native types (as opposed to bool or Option<...> for example). Certain bit patterns might get canonicalized when interpreted as floating point values, but I don't think that would be considered undefined behavior. There are more details about specific behavior in the docs for f64::from_bits (which is considered safe).

// SAFETY
// ArrowNativeType are trivially transmutable, and this method checks alignment
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// ArrowNativeType are trivially transmutable, and this method checks alignment
// ArrowNativeType is sealed (can't be implemented outside the arrow crate,
// trivially transmutable, and this method checks alignment

let (prefix, offsets, suffix) = unsafe { self.as_slice().align_to::<T>() };
assert!(prefix.is_empty() && suffix.is_empty());
offsets
}
Expand Down Expand Up @@ -451,7 +446,7 @@ mod tests {
macro_rules! check_as_typed_data {
($input: expr, $native_t: ty) => {{
let buffer = Buffer::from_slice_ref($input);
let slice: &[$native_t] = unsafe { buffer.typed_data::<$native_t>() };
let slice: &[$native_t] = buffer.typed_data::<$native_t>();
assert_eq!($input, slice);
}};
}
Expand Down Expand Up @@ -573,12 +568,12 @@ mod tests {
)
};

let slice = unsafe { buffer.typed_data::<i32>() };
let slice = buffer.typed_data::<i32>();
assert_eq!(slice, &[1, 2, 3, 4, 5]);

let buffer = buffer.slice(std::mem::size_of::<i32>());

let slice = unsafe { buffer.typed_data::<i32>() };
let slice = buffer.typed_data::<i32>();
assert_eq!(slice, &[2, 3, 4, 5]);
}
}
13 changes: 5 additions & 8 deletions arrow/src/buffer/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,17 +275,14 @@ impl MutableBuffer {

/// View this buffer asa slice of a specific type.
///
/// # Safety
///
/// This function must only be used with buffers which are treated
/// as type `T` (e.g. extended with items of type `T`).
///
/// # Panics
///
/// This function panics if the underlying buffer is not aligned
/// correctly for type `T`.
pub unsafe fn typed_data_mut<T: ArrowNativeType>(&mut self) -> &mut [T] {
let (prefix, offsets, suffix) = self.as_slice_mut().align_to_mut::<T>();
pub fn typed_data_mut<T: ArrowNativeType>(&mut self) -> &mut [T] {
// SAFETY
// ArrowNativeType are trivially transmutable, and this method checks alignment
let (prefix, offsets, suffix) = unsafe { self.as_slice_mut().align_to_mut::<T>() };
assert!(prefix.is_empty() && suffix.is_empty());
offsets
}
Expand All @@ -299,7 +296,7 @@ impl MutableBuffer {
/// assert_eq!(buffer.len(), 8) // u32 has 4 bytes
/// ```
#[inline]
pub fn extend_from_slice<T: ToByteSlice>(&mut self, items: &[T]) {
pub fn extend_from_slice<T: ArrowNativeType>(&mut self, items: &[T]) {
Copy link
Contributor Author

@tustvold tustvold Jun 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method was potentially unsound, as ToByteSlice is not sealed and so could theoretically be implemented for a type that is not trivially transmutable (which the implementation of this method implicitly assumes).

Edit: this is an API change

let len = items.len();
let additional = len * std::mem::size_of::<T>();
self.reserve(additional);
Expand Down
4 changes: 1 addition & 3 deletions arrow/src/buffer/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ where

let left_chunks = left.bit_chunks(offset_in_bits, len_in_bits);

// Safety: buffer is always treated as type `u64` in the code
// below.
let result_chunks = unsafe { result.typed_data_mut::<u64>().iter_mut() };
let result_chunks = result.typed_data_mut::<u64>().iter_mut();

result_chunks
.zip(left_chunks.iter())
Expand Down
2 changes: 1 addition & 1 deletion arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2084,7 +2084,7 @@ where
let list_data = array.data();
let str_values_buf = str_array.value_data();

let offsets = unsafe { list_data.buffers()[0].typed_data::<OffsetSizeFrom>() };
let offsets = list_data.buffers()[0].typed_data::<OffsetSizeFrom>();

let mut offset_builder = BufferBuilder::<OffsetSizeTo>::new(offsets.len());
offsets.iter().try_for_each::<_, Result<_>>(|offset| {
Expand Down
6 changes: 2 additions & 4 deletions arrow/src/compute/kernels/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,7 @@ fn sort_boolean(
let mut result = MutableBuffer::new(result_capacity);
// sets len to capacity so we can access the whole buffer as a typed slice
result.resize(result_capacity, 0);
// Safety: the buffer is always treated as `u32` in the code below
let result_slice: &mut [u32] = unsafe { result.typed_data_mut() };
let result_slice: &mut [u32] = result.typed_data_mut();

if options.nulls_first {
let size = nulls_len.min(len);
Expand Down Expand Up @@ -565,8 +564,7 @@ where
let mut result = MutableBuffer::new(result_capacity);
// sets len to capacity so we can access the whole buffer as a typed slice
result.resize(result_capacity, 0);
// Safety: the buffer is always treated as `u32` in the code below
let result_slice: &mut [u32] = unsafe { result.typed_data_mut() };
let result_slice: &mut [u32] = result.typed_data_mut();

if options.nulls_first {
let size = nulls_len.min(len);
Expand Down
3 changes: 1 addition & 2 deletions arrow/src/compute/kernels/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -688,8 +688,7 @@ where
let bytes_offset = (data_len + 1) * std::mem::size_of::<OffsetSize>();
let mut offsets_buffer = MutableBuffer::from_len_zeroed(bytes_offset);

// Safety: the buffer is always treated as as a type of `OffsetSize` in the code below
let offsets = unsafe { offsets_buffer.typed_data_mut() };
let offsets = offsets_buffer.typed_data_mut();
let mut values = MutableBuffer::new(0);
let mut length_so_far = OffsetSize::zero();
offsets[0] = length_so_far;
Expand Down
4 changes: 2 additions & 2 deletions parquet/src/arrow/array_reader/byte_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ impl<I: OffsetSizeTrait + ScalarValue> ArrayReader for ByteArrayReader<I> {
fn get_def_levels(&self) -> Option<&[i16]> {
self.def_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}

fn get_rep_levels(&self) -> Option<&[i16]> {
self.rep_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}
}

Expand Down
6 changes: 3 additions & 3 deletions parquet/src/arrow/array_reader/byte_array_dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,13 @@ where
fn get_def_levels(&self) -> Option<&[i16]> {
self.def_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}

fn get_rep_levels(&self) -> Option<&[i16]> {
self.rep_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}
}

Expand Down Expand Up @@ -356,7 +356,7 @@ where
assert_eq!(dict.data_type(), &self.value_type);

let dict_buffers = dict.data().buffers();
let dict_offsets = unsafe { dict_buffers[0].typed_data::<V>() };
let dict_offsets = dict_buffers[0].typed_data::<V>();
let dict_values = dict_buffers[1].as_slice();

values.extend_from_dictionary(
Expand Down
8 changes: 4 additions & 4 deletions parquet/src/arrow/array_reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,13 @@ where
fn get_def_levels(&self) -> Option<&[i16]> {
self.def_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}

fn get_rep_levels(&self) -> Option<&[i16]> {
self.rep_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}
}

Expand Down Expand Up @@ -447,13 +447,13 @@ where
fn get_def_levels(&self) -> Option<&[i16]> {
self.def_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}

fn get_rep_levels(&self) -> Option<&[i16]> {
self.rep_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}
}

Expand Down
2 changes: 1 addition & 1 deletion parquet/src/arrow/arrow_writer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ macro_rules! def_get_binary_array_fn {
fn $name(array: &$ty) -> Vec<ByteArray> {
let mut byte_array = ByteArray::new();
let ptr = crate::util::memory::ByteBufferPtr::new(
unsafe { array.value_data().typed_data::<u8>() }.to_vec(),
array.value_data().as_slice().to_vec(),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know why this was ever using typed_data...

);
byte_array.set_data(ptr);
array
Expand Down
2 changes: 1 addition & 1 deletion parquet/src/arrow/buffer/dictionary_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ impl<K: ScalarValue + ArrowNativeType + Ord, V: ScalarValue + OffsetSizeTrait>
Self::Dict { keys, values } => {
let mut spilled = OffsetBuffer::default();
let dict_buffers = values.data().buffers();
let dict_offsets = unsafe { dict_buffers[0].typed_data::<V>() };
let dict_offsets = dict_buffers[0].typed_data::<V>();
let dict_values = dict_buffers[1].as_slice();

if values.is_empty() {
Expand Down
4 changes: 2 additions & 2 deletions parquet/src/arrow/record_reader/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::marker::PhantomData;

use crate::arrow::buffer::bit_util::iter_set_bits_rev;
use arrow::buffer::{Buffer, MutableBuffer};
use arrow::datatypes::ToByteSlice;
use arrow::datatypes::ArrowNativeType;

/// A buffer that supports writing new data to the end, and removing data from the front
///
Expand Down Expand Up @@ -172,7 +172,7 @@ impl<T: ScalarValue> ScalarBuffer<T> {
}
}

impl<T: ScalarValue + ToByteSlice> ScalarBuffer<T> {
impl<T: ScalarValue + ArrowNativeType> ScalarBuffer<T> {
pub fn push(&mut self, v: T) {
self.buffer.push(v);
self.len += 1;
Expand Down
4 changes: 2 additions & 2 deletions parquet/src/arrow/record_reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ mod tests {

// Verify result record data
let actual = record_reader.consume_record_data().unwrap();
let actual_values = unsafe { actual.typed_data::<i32>() };
let actual_values = actual.typed_data::<i32>();

let expected = &[0, 7, 0, 6, 3, 0, 8];
assert_eq!(actual_values.len(), expected.len());
Expand Down Expand Up @@ -687,7 +687,7 @@ mod tests {

// Verify result record data
let actual = record_reader.consume_record_data().unwrap();
let actual_values = unsafe { actual.typed_data::<i32>() };
let actual_values = actual.typed_data::<i32>();
let expected = &[4, 0, 0, 7, 6, 3, 2, 8, 9];
assert_eq!(actual_values.len(), expected.len());

Expand Down