Skip to content

Commit

Permalink
Merge pull request #952 from kngwyu/typed-pybuffer
Browse files Browse the repository at this point in the history
Typed PyBuffer
  • Loading branch information
kngwyu authored Jun 5, 2020
2 parents be1b704 + 5939362 commit d674b5f
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 120 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
### Changed
- Simplify internals of `#[pyo3(get)]` attribute. (Remove the hidden API `GetPropertyValue`.) [#934](https://github.com/PyO3/pyo3/pull/934)
- Call `Py_Finalize` at exit to flush buffers, etc. [#943](https://github.com/PyO3/pyo3/pull/943)
- Add type parameter to `PyBuffer`. [#951](https://github.com/PyO3/pyo3/pull/951)

### Removed
- Remove `ManagedPyRef` (unused, and needs specialization) [#930](https://github.com/PyO3/pyo3/pull/930)
Expand Down
175 changes: 65 additions & 110 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,22 @@

//! `PyBuffer` implementation
use crate::err::{self, PyResult};
use crate::{exceptions, ffi, AsPyPointer, PyAny, Python};
use crate::{exceptions, ffi, AsPyPointer, FromPyObject, PyAny, PyNativeType, Python};
use std::ffi::CStr;
use std::marker::PhantomData;
use std::os::raw;
use std::pin::Pin;
use std::{cell, mem, ptr, slice};

/// Allows access to the underlying buffer used by a python object such as `bytes`, `bytearray` or `array.array`.
// use Pin<Box> because Python expects that the Py_buffer struct has a stable memory address
#[repr(transparent)]
pub struct PyBuffer(Pin<Box<ffi::Py_buffer>>);
pub struct PyBuffer<T>(Pin<Box<ffi::Py_buffer>>, PhantomData<T>);

// PyBuffer is thread-safe: the shape of the buffer is immutable while a Py_buffer exists.
// Accessing the buffer contents is protected using the GIL.
unsafe impl Send for PyBuffer {}
unsafe impl Sync for PyBuffer {}
unsafe impl<T> Send for PyBuffer<T> {}
unsafe impl<T> Sync for PyBuffer<T> {}

#[derive(Copy, Clone, Eq, PartialEq)]
pub enum ElementType {
Expand Down Expand Up @@ -146,29 +147,51 @@ fn is_matching_endian(c: u8) -> bool {
}

/// Trait implemented for possible element types of `PyBuffer`.
pub unsafe trait Element {
pub unsafe trait Element: Copy {
/// Gets whether the element specified in the format string is potentially compatible.
/// Alignment and size are checked separately from this function.
fn is_compatible_format(format: &CStr) -> bool;
}

fn validate(b: &ffi::Py_buffer) {
fn validate(b: &ffi::Py_buffer) -> PyResult<()> {
// shape and stride information must be provided when we use PyBUF_FULL_RO
assert!(!b.shape.is_null());
assert!(!b.strides.is_null());
if b.shape.is_null() {
return Err(exceptions::BufferError::py_err("Shape is Null"));
}
if b.strides.is_null() {
return Err(exceptions::BufferError::py_err("PyBuffer: Strides is Null"));
}
Ok(())
}

impl<'source, T: Element> FromPyObject<'source> for PyBuffer<T> {
fn extract(obj: &PyAny) -> PyResult<PyBuffer<T>> {
Self::get(obj)
}
}

impl PyBuffer {
impl<T: Element> PyBuffer<T> {
/// Get the underlying buffer from the specified python object.
pub fn get(py: Python, obj: &PyAny) -> PyResult<PyBuffer> {
pub fn get(obj: &PyAny) -> PyResult<PyBuffer<T>> {
unsafe {
let mut buf = Box::pin(mem::zeroed::<ffi::Py_buffer>());
let mut buf = Box::pin(ffi::Py_buffer::new());
err::error_on_minusone(
py,
obj.py(),
ffi::PyObject_GetBuffer(obj.as_ptr(), &mut *buf, ffi::PyBUF_FULL_RO),
)?;
validate(&buf);
Ok(PyBuffer(buf))
validate(&buf)?;
let buf = PyBuffer(buf, PhantomData);
// Type Check
if mem::size_of::<T>() == buf.item_size()
&& (buf.0.buf as usize) % mem::align_of::<T>() == 0
&& T::is_compatible_format(buf.format())
{
Ok(buf)
} else {
Err(exceptions::BufferError::py_err(
"Incompatible type as buffer",
))
}
}
}

Expand Down Expand Up @@ -307,12 +330,8 @@ impl PyBuffer {
///
/// The returned slice uses type `ReadOnlyCell<T>` because it's theoretically possible for any call into the Python runtime
/// to modify the values in the slice.
pub fn as_slice<'a, T: Element>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell<T>]> {
if mem::size_of::<T>() == self.item_size()
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
&& self.is_c_contiguous()
&& T::is_compatible_format(self.format())
{
pub fn as_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell<T>]> {
if self.is_c_contiguous() {
unsafe {
Some(slice::from_raw_parts(
self.0.buf as *mut ReadOnlyCell<T>,
Expand All @@ -334,13 +353,8 @@ impl PyBuffer {
///
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
/// to modify the values in the slice.
pub fn as_mut_slice<'a, T: Element>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell<T>]> {
if !self.readonly()
&& mem::size_of::<T>() == self.item_size()
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
&& self.is_c_contiguous()
&& T::is_compatible_format(self.format())
{
pub fn as_mut_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell<T>]> {
if !self.readonly() && self.is_c_contiguous() {
unsafe {
Some(slice::from_raw_parts(
self.0.buf as *mut cell::Cell<T>,
Expand All @@ -361,15 +375,8 @@ impl PyBuffer {
///
/// The returned slice uses type `ReadOnlyCell<T>` because it's theoretically possible for any call into the Python runtime
/// to modify the values in the slice.
pub fn as_fortran_slice<'a, T: Element>(
&'a self,
_py: Python<'a>,
) -> Option<&'a [ReadOnlyCell<T>]> {
if mem::size_of::<T>() == self.item_size()
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
&& self.is_fortran_contiguous()
&& T::is_compatible_format(self.format())
{
pub fn as_fortran_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell<T>]> {
if mem::size_of::<T>() == self.item_size() && self.is_fortran_contiguous() {
unsafe {
Some(slice::from_raw_parts(
self.0.buf as *mut ReadOnlyCell<T>,
Expand All @@ -391,16 +398,8 @@ impl PyBuffer {
///
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
/// to modify the values in the slice.
pub fn as_fortran_mut_slice<'a, T: Element>(
&'a self,
_py: Python<'a>,
) -> Option<&'a [cell::Cell<T>]> {
if !self.readonly()
&& mem::size_of::<T>() == self.item_size()
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
&& self.is_fortran_contiguous()
&& T::is_compatible_format(self.format())
{
pub fn as_fortran_mut_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell<T>]> {
if !self.readonly() && self.is_fortran_contiguous() {
unsafe {
Some(slice::from_raw_parts(
self.0.buf as *mut cell::Cell<T>,
Expand All @@ -421,7 +420,7 @@ impl PyBuffer {
/// To check whether the buffer format is compatible before calling this method,
/// you can use `<T as buffer::Element>::is_compatible_format(buf.format())`.
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
pub fn copy_to_slice<T: Element + Copy>(&self, py: Python, target: &mut [T]) -> PyResult<()> {
pub fn copy_to_slice(&self, py: Python, target: &mut [T]) -> PyResult<()> {
self.copy_to_slice_impl(py, target, b'C')
}

Expand All @@ -434,28 +433,16 @@ impl PyBuffer {
/// To check whether the buffer format is compatible before calling this method,
/// you can use `<T as buffer::Element>::is_compatible_format(buf.format())`.
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
pub fn copy_to_fortran_slice<T: Element + Copy>(
&self,
py: Python,
target: &mut [T],
) -> PyResult<()> {
pub fn copy_to_fortran_slice(&self, py: Python, target: &mut [T]) -> PyResult<()> {
self.copy_to_slice_impl(py, target, b'F')
}

fn copy_to_slice_impl<T: Element + Copy>(
&self,
py: Python,
target: &mut [T],
fort: u8,
) -> PyResult<()> {
fn copy_to_slice_impl(&self, py: Python, target: &mut [T], fort: u8) -> PyResult<()> {
if mem::size_of_val(target) != self.len_bytes() {
return Err(exceptions::BufferError::py_err(
"Slice length does not match buffer length.",
));
}
if !T::is_compatible_format(self.format()) || mem::size_of::<T>() != self.item_size() {
return incompatible_format_error();
}
unsafe {
err::error_on_minusone(
py,
Expand All @@ -473,23 +460,19 @@ impl PyBuffer {
/// If the buffer is multi-dimensional, the elements are written in C-style order.
///
/// Fails if the buffer format is not compatible with type `T`.
pub fn to_vec<T: Element + Copy>(&self, py: Python) -> PyResult<Vec<T>> {
pub fn to_vec(&self, py: Python) -> PyResult<Vec<T>> {
self.to_vec_impl(py, b'C')
}

/// Copies the buffer elements to a newly allocated vector.
/// If the buffer is multi-dimensional, the elements are written in Fortran-style order.
///
/// Fails if the buffer format is not compatible with type `T`.
pub fn to_fortran_vec<T: Element + Copy>(&self, py: Python) -> PyResult<Vec<T>> {
pub fn to_fortran_vec(&self, py: Python) -> PyResult<Vec<T>> {
self.to_vec_impl(py, b'F')
}

fn to_vec_impl<T: Element + Copy>(&self, py: Python, fort: u8) -> PyResult<Vec<T>> {
if !T::is_compatible_format(self.format()) || mem::size_of::<T>() != self.item_size() {
incompatible_format_error()?;
unreachable!();
}
fn to_vec_impl(&self, py: Python, fort: u8) -> PyResult<Vec<T>> {
let item_count = self.item_count();
let mut vec: Vec<T> = Vec::with_capacity(item_count);
unsafe {
Expand Down Expand Up @@ -520,7 +503,7 @@ impl PyBuffer {
/// To check whether the buffer format is compatible before calling this method,
/// use `<T as buffer::Element>::is_compatible_format(buf.format())`.
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
pub fn copy_from_slice<T: Element + Copy>(&self, py: Python, source: &[T]) -> PyResult<()> {
pub fn copy_from_slice(&self, py: Python, source: &[T]) -> PyResult<()> {
self.copy_from_slice_impl(py, source, b'C')
}

Expand All @@ -534,20 +517,11 @@ impl PyBuffer {
/// To check whether the buffer format is compatible before calling this method,
/// use `<T as buffer::Element>::is_compatible_format(buf.format())`.
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
pub fn copy_from_fortran_slice<T: Element + Copy>(
&self,
py: Python,
source: &[T],
) -> PyResult<()> {
pub fn copy_from_fortran_slice(&self, py: Python, source: &[T]) -> PyResult<()> {
self.copy_from_slice_impl(py, source, b'F')
}

fn copy_from_slice_impl<T: Element + Copy>(
&self,
py: Python,
source: &[T],
fort: u8,
) -> PyResult<()> {
fn copy_from_slice_impl(&self, py: Python, source: &[T], fort: u8) -> PyResult<()> {
if self.readonly() {
return buffer_readonly_error();
}
Expand All @@ -556,9 +530,6 @@ impl PyBuffer {
"Slice length does not match buffer length.",
));
}
if !T::is_compatible_format(self.format()) || mem::size_of::<T>() != self.item_size() {
return incompatible_format_error();
}
unsafe {
err::error_on_minusone(
py,
Expand Down Expand Up @@ -589,19 +560,14 @@ impl PyBuffer {
}
}

fn incompatible_format_error() -> PyResult<()> {
Err(exceptions::BufferError::py_err(
"Slice type is incompatible with buffer format.",
))
}

/// Builds the `BufferError` reported when a write is attempted on a read-only buffer.
///
/// Always returns `Err`; the `PyResult<()>` return type lets a caller
/// `return buffer_readonly_error();` directly from a fallible method.
#[inline(always)]
fn buffer_readonly_error() -> PyResult<()> {
    let message = "Cannot write to read-only buffer.";
    Err(exceptions::BufferError::py_err(message))
}

impl Drop for PyBuffer {
impl<T> Drop for PyBuffer<T> {
fn drop(&mut self) {
let _gil_guard = Python::acquire_gil();
unsafe { ffi::PyBuffer_Release(&mut *self.0) }
Expand All @@ -614,9 +580,9 @@ impl Drop for PyBuffer {
/// The data cannot be modified through the reference, but other references may
/// be modifying the data.
#[repr(transparent)]
pub struct ReadOnlyCell<T>(cell::UnsafeCell<T>);
pub struct ReadOnlyCell<T: Element>(cell::UnsafeCell<T>);

impl<T: Copy> ReadOnlyCell<T> {
impl<T: Element> ReadOnlyCell<T> {
#[inline]
pub fn get(&self) -> T {
unsafe { *self.0.get() }
Expand Down Expand Up @@ -675,7 +641,7 @@ mod test {
let gil = Python::acquire_gil();
let py = gil.python();
let bytes = py.eval("b'abcde'", None, None).unwrap();
let buffer = PyBuffer::get(py, &bytes).unwrap();
let buffer = PyBuffer::get(&bytes).unwrap();
assert_eq!(buffer.dimensions(), 1);
assert_eq!(buffer.item_count(), 5);
assert_eq!(buffer.format().to_str().unwrap(), "B");
Expand All @@ -684,26 +650,18 @@ mod test {
assert!(buffer.is_c_contiguous());
assert!(buffer.is_fortran_contiguous());

assert!(buffer.as_slice::<f64>(py).is_none());
assert!(buffer.as_slice::<i8>(py).is_none());

let slice = buffer.as_slice::<u8>(py).unwrap();
let slice = buffer.as_slice(py).unwrap();
assert_eq!(slice.len(), 5);
assert_eq!(slice[0].get(), b'a');
assert_eq!(slice[2].get(), b'c');

assert!(buffer.as_mut_slice::<u8>(py).is_none());

assert!(buffer.copy_to_slice(py, &mut [0u8]).is_err());
let mut arr = [0; 5];
buffer.copy_to_slice(py, &mut arr).unwrap();
assert_eq!(arr, b"abcde" as &[u8]);

assert!(buffer.copy_from_slice(py, &[0u8; 5]).is_err());

assert!(buffer.to_vec::<i8>(py).is_err());
assert!(buffer.to_vec::<u16>(py).is_err());
assert_eq!(buffer.to_vec::<u8>(py).unwrap(), b"abcde");
assert_eq!(buffer.to_vec(py).unwrap(), b"abcde");
}

#[allow(clippy::float_cmp)] // The test wants to ensure that no precision was lost on the Python round-trip
Expand All @@ -716,21 +674,18 @@ mod test {
.unwrap()
.call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None)
.unwrap();
let buffer = PyBuffer::get(py, array).unwrap();
let buffer = PyBuffer::get(array).unwrap();
assert_eq!(buffer.dimensions(), 1);
assert_eq!(buffer.item_count(), 4);
assert_eq!(buffer.format().to_str().unwrap(), "f");
assert_eq!(buffer.shape(), [4]);

assert!(buffer.as_slice::<f64>(py).is_none());
assert!(buffer.as_slice::<i32>(py).is_none());

let slice = buffer.as_slice::<f32>(py).unwrap();
let slice = buffer.as_slice(py).unwrap();
assert_eq!(slice.len(), 4);
assert_eq!(slice[0].get(), 1.0);
assert_eq!(slice[3].get(), 2.5);

let mut_slice = buffer.as_mut_slice::<f32>(py).unwrap();
let mut_slice = buffer.as_mut_slice(py).unwrap();
assert_eq!(mut_slice.len(), 4);
assert_eq!(mut_slice[0].get(), 1.0);
mut_slice[3].set(2.75);
Expand All @@ -741,6 +696,6 @@ mod test {
.unwrap();
assert_eq!(slice[2].get(), 12.0);

assert_eq!(buffer.to_vec::<f32>(py).unwrap(), [10.0, 11.0, 12.0, 13.0]);
assert_eq!(buffer.to_vec(py).unwrap(), [10.0, 11.0, 12.0, 13.0]);
}
}
Loading

0 comments on commit d674b5f

Please sign in to comment.