diff --git a/CHANGELOG.md b/CHANGELOG.md index 39086421190..53020074e40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ### Changed - Simplify internals of `#[pyo3(get)]` attribute. (Remove the hidden API `GetPropertyValue`.) [#934](https://github.com/PyO3/pyo3/pull/934) - Call `Py_Finalize` at exit to flush buffers, etc. [#943](https://github.com/PyO3/pyo3/pull/943) +- Add type parameter to PyBuffer. #[951](https://github.com/PyO3/pyo3/pull/951) ### Removed - Remove `ManagedPyRef` (unused, and needs specialization) [#930](https://github.com/PyO3/pyo3/pull/930) diff --git a/src/buffer.rs b/src/buffer.rs index 84737e0ccb1..fd004568b6b 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -18,8 +18,9 @@ //! `PyBuffer` implementation use crate::err::{self, PyResult}; -use crate::{exceptions, ffi, AsPyPointer, PyAny, Python}; +use crate::{exceptions, ffi, AsPyPointer, FromPyObject, PyAny, PyNativeType, Python}; use std::ffi::CStr; +use std::marker::PhantomData; use std::os::raw; use std::pin::Pin; use std::{cell, mem, ptr, slice}; @@ -27,12 +28,12 @@ use std::{cell, mem, ptr, slice}; /// Allows access to the underlying buffer used by a python object such as `bytes`, `bytearray` or `array.array`. // use Pin because Python expects that the Py_buffer struct has a stable memory address #[repr(transparent)] -pub struct PyBuffer(Pin>); +pub struct PyBuffer(Pin>, PhantomData); // PyBuffer is thread-safe: the shape of the buffer is immutable while a Py_buffer exists. // Accessing the buffer contents is protected using the GIL. -unsafe impl Send for PyBuffer {} -unsafe impl Sync for PyBuffer {} +unsafe impl Send for PyBuffer {} +unsafe impl Sync for PyBuffer {} #[derive(Copy, Clone, Eq, PartialEq)] pub enum ElementType { @@ -146,29 +147,51 @@ fn is_matching_endian(c: u8) -> bool { } /// Trait implemented for possible element types of `PyBuffer`. -pub unsafe trait Element { +pub unsafe trait Element: Copy { /// Gets whether the element specified in the format string is potentially compatible. /// Alignment and size are checked separately from this function. fn is_compatible_format(format: &CStr) -> bool; } -fn validate(b: &ffi::Py_buffer) { +fn validate(b: &ffi::Py_buffer) -> PyResult<()> { // shape and stride information must be provided when we use PyBUF_FULL_RO - assert!(!b.shape.is_null()); - assert!(!b.strides.is_null()); + if b.shape.is_null() { + return Err(exceptions::BufferError::py_err("Shape is Null")); + } + if b.strides.is_null() { + return Err(exceptions::BufferError::py_err("PyBuffer: Strides is Null")); + } + Ok(()) +} + +impl<'source, T: Element> FromPyObject<'source> for PyBuffer { + fn extract(obj: &PyAny) -> PyResult> { + Self::get(obj) + } } -impl PyBuffer { +impl PyBuffer { /// Get the underlying buffer from the specified python object. - pub fn get(py: Python, obj: &PyAny) -> PyResult { + pub fn get(obj: &PyAny) -> PyResult> { unsafe { - let mut buf = Box::pin(mem::zeroed::()); + let mut buf = Box::pin(ffi::Py_buffer::new()); err::error_on_minusone( - py, + obj.py(), ffi::PyObject_GetBuffer(obj.as_ptr(), &mut *buf, ffi::PyBUF_FULL_RO), )?; - validate(&buf); - Ok(PyBuffer(buf)) + validate(&buf)?; + let buf = PyBuffer(buf, PhantomData); + // Type Check + if mem::size_of::() == buf.item_size() + && (buf.0.buf as usize) % mem::align_of::() == 0 + && T::is_compatible_format(buf.format()) + { + Ok(buf) + } else { + Err(exceptions::BufferError::py_err( + "Incompatible type as buffer", + )) + } } } @@ -307,12 +330,8 @@ impl PyBuffer { /// /// The returned slice uses type `Cell` because it's theoretically possible for any call into the Python runtime /// to modify the values in the slice. - pub fn as_slice<'a, T: Element>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell]> { - if mem::size_of::() == self.item_size() - && (self.0.buf as usize) % mem::align_of::() == 0 - && self.is_c_contiguous() - && T::is_compatible_format(self.format()) - { + pub fn as_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell]> { + if self.is_c_contiguous() { unsafe { Some(slice::from_raw_parts( self.0.buf as *mut ReadOnlyCell, @@ -334,13 +353,8 @@ impl PyBuffer { /// /// The returned slice uses type `Cell` because it's theoretically possible for any call into the Python runtime /// to modify the values in the slice. - pub fn as_mut_slice<'a, T: Element>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell]> { - if !self.readonly() - && mem::size_of::() == self.item_size() - && (self.0.buf as usize) % mem::align_of::() == 0 - && self.is_c_contiguous() - && T::is_compatible_format(self.format()) - { + pub fn as_mut_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell]> { + if !self.readonly() && self.is_c_contiguous() { unsafe { Some(slice::from_raw_parts( self.0.buf as *mut cell::Cell, @@ -361,15 +375,8 @@ impl PyBuffer { /// /// The returned slice uses type `Cell` because it's theoretically possible for any call into the Python runtime /// to modify the values in the slice. - pub fn as_fortran_slice<'a, T: Element>( - &'a self, - _py: Python<'a>, - ) -> Option<&'a [ReadOnlyCell]> { - if mem::size_of::() == self.item_size() - && (self.0.buf as usize) % mem::align_of::() == 0 - && self.is_fortran_contiguous() - && T::is_compatible_format(self.format()) - { + pub fn as_fortran_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell]> { + if mem::size_of::() == self.item_size() && self.is_fortran_contiguous() { unsafe { Some(slice::from_raw_parts( self.0.buf as *mut ReadOnlyCell, @@ -391,16 +398,8 @@ impl PyBuffer { /// /// The returned slice uses type `Cell` because it's theoretically possible for any call into the Python runtime /// to modify the values in the slice. - pub fn as_fortran_mut_slice<'a, T: Element>( - &'a self, - _py: Python<'a>, - ) -> Option<&'a [cell::Cell]> { - if !self.readonly() - && mem::size_of::() == self.item_size() - && (self.0.buf as usize) % mem::align_of::() == 0 - && self.is_fortran_contiguous() - && T::is_compatible_format(self.format()) - { + pub fn as_fortran_mut_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell]> { + if !self.readonly() && self.is_fortran_contiguous() { unsafe { Some(slice::from_raw_parts( self.0.buf as *mut cell::Cell, @@ -421,7 +420,7 @@ impl PyBuffer { /// To check whether the buffer format is compatible before calling this method, /// you can use `::is_compatible_format(buf.format())`. /// Alternatively, `match buffer::ElementType::from_format(buf.format())`. - pub fn copy_to_slice(&self, py: Python, target: &mut [T]) -> PyResult<()> { + pub fn copy_to_slice(&self, py: Python, target: &mut [T]) -> PyResult<()> { self.copy_to_slice_impl(py, target, b'C') } @@ -434,28 +433,16 @@ impl PyBuffer { /// To check whether the buffer format is compatible before calling this method, /// you can use `::is_compatible_format(buf.format())`. /// Alternatively, `match buffer::ElementType::from_format(buf.format())`. - pub fn copy_to_fortran_slice( - &self, - py: Python, - target: &mut [T], - ) -> PyResult<()> { + pub fn copy_to_fortran_slice(&self, py: Python, target: &mut [T]) -> PyResult<()> { self.copy_to_slice_impl(py, target, b'F') } - fn copy_to_slice_impl( - &self, - py: Python, - target: &mut [T], - fort: u8, - ) -> PyResult<()> { + fn copy_to_slice_impl(&self, py: Python, target: &mut [T], fort: u8) -> PyResult<()> { if mem::size_of_val(target) != self.len_bytes() { return Err(exceptions::BufferError::py_err( "Slice length does not match buffer length.", )); } - if !T::is_compatible_format(self.format()) || mem::size_of::() != self.item_size() { - return incompatible_format_error(); - } unsafe { err::error_on_minusone( py, @@ -473,7 +460,7 @@ impl PyBuffer { /// If the buffer is multi-dimensional, the elements are written in C-style order. /// /// Fails if the buffer format is not compatible with type `T`. - pub fn to_vec(&self, py: Python) -> PyResult> { + pub fn to_vec(&self, py: Python) -> PyResult> { self.to_vec_impl(py, b'C') } @@ -481,15 +468,11 @@ impl PyBuffer { /// If the buffer is multi-dimensional, the elements are written in Fortran-style order. /// /// Fails if the buffer format is not compatible with type `T`. - pub fn to_fortran_vec(&self, py: Python) -> PyResult> { + pub fn to_fortran_vec(&self, py: Python) -> PyResult> { self.to_vec_impl(py, b'F') } - fn to_vec_impl(&self, py: Python, fort: u8) -> PyResult> { - if !T::is_compatible_format(self.format()) || mem::size_of::() != self.item_size() { - incompatible_format_error()?; - unreachable!(); - } + fn to_vec_impl(&self, py: Python, fort: u8) -> PyResult> { let item_count = self.item_count(); let mut vec: Vec = Vec::with_capacity(item_count); unsafe { @@ -520,7 +503,7 @@ impl PyBuffer { /// To check whether the buffer format is compatible before calling this method, /// use `::is_compatible_format(buf.format())`. /// Alternatively, `match buffer::ElementType::from_format(buf.format())`. - pub fn copy_from_slice(&self, py: Python, source: &[T]) -> PyResult<()> { + pub fn copy_from_slice(&self, py: Python, source: &[T]) -> PyResult<()> { self.copy_from_slice_impl(py, source, b'C') } @@ -534,20 +517,11 @@ impl PyBuffer { /// To check whether the buffer format is compatible before calling this method, /// use `::is_compatible_format(buf.format())`. /// Alternatively, `match buffer::ElementType::from_format(buf.format())`. - pub fn copy_from_fortran_slice( - &self, - py: Python, - source: &[T], - ) -> PyResult<()> { + pub fn copy_from_fortran_slice(&self, py: Python, source: &[T]) -> PyResult<()> { self.copy_from_slice_impl(py, source, b'F') } - fn copy_from_slice_impl( - &self, - py: Python, - source: &[T], - fort: u8, - ) -> PyResult<()> { + fn copy_from_slice_impl(&self, py: Python, source: &[T], fort: u8) -> PyResult<()> { if self.readonly() { return buffer_readonly_error(); } @@ -556,9 +530,6 @@ impl PyBuffer { "Slice length does not match buffer length.", )); } - if !T::is_compatible_format(self.format()) || mem::size_of::() != self.item_size() { - return incompatible_format_error(); - } unsafe { err::error_on_minusone( py, @@ -589,19 +560,14 @@ impl PyBuffer { } } -fn incompatible_format_error() -> PyResult<()> { - Err(exceptions::BufferError::py_err( - "Slice type is incompatible with buffer format.", - )) -} - +#[inline(always)] fn buffer_readonly_error() -> PyResult<()> { Err(exceptions::BufferError::py_err( "Cannot write to read-only buffer.", )) } -impl Drop for PyBuffer { +impl Drop for PyBuffer { fn drop(&mut self) { let _gil_guard = Python::acquire_gil(); unsafe { ffi::PyBuffer_Release(&mut *self.0) } @@ -614,9 +580,9 @@ impl Drop for PyBuffer { /// The data cannot be modified through the reference, but other references may /// be modifying the data. #[repr(transparent)] -pub struct ReadOnlyCell(cell::UnsafeCell); +pub struct ReadOnlyCell(cell::UnsafeCell); -impl ReadOnlyCell { +impl ReadOnlyCell { #[inline] pub fn get(&self) -> T { unsafe { *self.0.get() } @@ -675,7 +641,7 @@ mod test { let gil = Python::acquire_gil(); let py = gil.python(); let bytes = py.eval("b'abcde'", None, None).unwrap(); - let buffer = PyBuffer::get(py, &bytes).unwrap(); + let buffer = PyBuffer::get(&bytes).unwrap(); assert_eq!(buffer.dimensions(), 1); assert_eq!(buffer.item_count(), 5); assert_eq!(buffer.format().to_str().unwrap(), "B"); @@ -684,26 +650,18 @@ mod test { assert!(buffer.is_c_contiguous()); assert!(buffer.is_fortran_contiguous()); - assert!(buffer.as_slice::(py).is_none()); - assert!(buffer.as_slice::(py).is_none()); - - let slice = buffer.as_slice::(py).unwrap(); + let slice = buffer.as_slice(py).unwrap(); assert_eq!(slice.len(), 5); assert_eq!(slice[0].get(), b'a'); assert_eq!(slice[2].get(), b'c'); - assert!(buffer.as_mut_slice::(py).is_none()); - assert!(buffer.copy_to_slice(py, &mut [0u8]).is_err()); let mut arr = [0; 5]; buffer.copy_to_slice(py, &mut arr).unwrap(); assert_eq!(arr, b"abcde" as &[u8]); assert!(buffer.copy_from_slice(py, &[0u8; 5]).is_err()); - - assert!(buffer.to_vec::(py).is_err()); - assert!(buffer.to_vec::(py).is_err()); - assert_eq!(buffer.to_vec::(py).unwrap(), b"abcde"); + assert_eq!(buffer.to_vec(py).unwrap(), b"abcde"); } #[allow(clippy::float_cmp)] // The test wants to ensure that no precision was lost on the Python round-trip @@ -716,21 +674,18 @@ mod test { .unwrap() .call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None) .unwrap(); - let buffer = PyBuffer::get(py, array).unwrap(); + let buffer = PyBuffer::get(array).unwrap(); assert_eq!(buffer.dimensions(), 1); assert_eq!(buffer.item_count(), 4); assert_eq!(buffer.format().to_str().unwrap(), "f"); assert_eq!(buffer.shape(), [4]); - assert!(buffer.as_slice::(py).is_none()); - assert!(buffer.as_slice::(py).is_none()); - - let slice = buffer.as_slice::(py).unwrap(); + let slice = buffer.as_slice(py).unwrap(); assert_eq!(slice.len(), 4); assert_eq!(slice[0].get(), 1.0); assert_eq!(slice[3].get(), 2.5); - let mut_slice = buffer.as_mut_slice::(py).unwrap(); + let mut_slice = buffer.as_mut_slice(py).unwrap(); assert_eq!(mut_slice.len(), 4); assert_eq!(mut_slice[0].get(), 1.0); mut_slice[3].set(2.75); @@ -741,6 +696,6 @@ mod test { .unwrap(); assert_eq!(slice[2].get(), 12.0); - assert_eq!(buffer.to_vec::(py).unwrap(), [10.0, 11.0, 12.0, 13.0]); + assert_eq!(buffer.to_vec(py).unwrap(), [10.0, 11.0, 12.0, 13.0]); } } diff --git a/src/ffi/object.rs b/src/ffi/object.rs index 67fd1133b09..175105b8e1f 100644 --- a/src/ffi/object.rs +++ b/src/ffi/object.rs @@ -133,8 +133,8 @@ pub type objobjargproc = #[cfg(not(Py_LIMITED_API))] mod bufferinfo { use crate::ffi::pyport::Py_ssize_t; - use std::mem; use std::os::raw::{c_char, c_int, c_void}; + use std::ptr; #[repr(C)] #[derive(Copy, Clone)] @@ -152,10 +152,21 @@ mod bufferinfo { pub internal: *mut c_void, } - impl Default for Py_buffer { - #[inline] - fn default() -> Self { - unsafe { mem::zeroed() } + impl Py_buffer { + pub const fn new() -> Self { + Py_buffer { + buf: ptr::null_mut(), + obj: ptr::null_mut(), + len: 0, + itemsize: 0, + readonly: 0, + ndim: 0, + format: ptr::null_mut(), + shape: ptr::null_mut(), + strides: ptr::null_mut(), + suboffsets: ptr::null_mut(), + internal: ptr::null_mut(), + } } } diff --git a/src/types/sequence.rs b/src/types/sequence.rs index 4adaf5a8b53..98d667ea202 100644 --- a/src/types/sequence.rs +++ b/src/types/sequence.rs @@ -279,7 +279,7 @@ macro_rules! array_impls { fn extract(obj: &'source PyAny) -> PyResult { let mut array = [T::default(); $N]; // first try buffer protocol - if let Ok(buf) = buffer::PyBuffer::get(obj.py(), obj) { + if let Ok(buf) = buffer::PyBuffer::get(obj) { if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() { buf.release(obj.py()); return Ok(array); @@ -315,9 +315,9 @@ where { fn extract(obj: &'source PyAny) -> PyResult { // first try buffer protocol - if let Ok(buf) = buffer::PyBuffer::get(obj.py(), obj) { + if let Ok(buf) = buffer::PyBuffer::get(obj) { if buf.dimensions() == 1 { - if let Ok(v) = buf.to_vec::(obj.py()) { + if let Ok(v) = buf.to_vec(obj.py()) { buf.release(obj.py()); return Ok(v); } diff --git a/tests/test_buffer_protocol.rs b/tests/test_buffer_protocol.rs index c576e1fd00c..6f0e468381e 100644 --- a/tests/test_buffer_protocol.rs +++ b/tests/test_buffer_protocol.rs @@ -114,8 +114,8 @@ fn test_buffer_referenced() { } .into_py(py); - let buf = PyBuffer::get(py, instance.as_ref(py)).unwrap(); - assert_eq!(buf.to_vec::(py).unwrap(), input); + let buf = PyBuffer::::get(instance.as_ref(py)).unwrap(); + assert_eq!(buf.to_vec(py).unwrap(), input); drop(instance); buf }; diff --git a/tests/test_pyfunction.rs b/tests/test_pyfunction.rs index c3e56547ac3..825670f2bc2 100644 --- a/tests/test_pyfunction.rs +++ b/tests/test_pyfunction.rs @@ -1,3 +1,4 @@ +use pyo3::buffer::PyBuffer; use pyo3::prelude::*; use pyo3::wrap_pyfunction; @@ -20,3 +21,44 @@ fn test_optional_bool() { py_assert!(py, f, "f(False) == 'Some(false)'"); py_assert!(py, f, "f(None) == 'None'"); } + +#[pyfunction] +fn buffer_inplace_add(py: Python, x: PyBuffer, y: PyBuffer) { + let x = x.as_mut_slice(py).unwrap(); + let y = y.as_slice(py).unwrap(); + for (xi, yi) in x.iter().zip(y) { + let xi_plus_yi = xi.get() + yi.get(); + xi.set(xi_plus_yi); + } +} + +#[test] +fn test_buffer_add() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let f = wrap_pyfunction!(buffer_inplace_add)(py); + + py_expect_exception!( + py, + f, + r#" +import array +a = array.array("i", [0, 1, 2, 3]) +b = array.array("I", [0, 1, 2, 3]) +f(a, b) +"#, + BufferError + ); + + pyo3::py_run!( + py, + f, + r#" +import array +a = array.array("i", [0, 1, 2, 3]) +b = array.array("i", [2, 3, 4, 5]) +f(a, b) +assert a, array.array("i", [2, 4, 6, 8]) +"# + ); +}