Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Poc dunder without specialization #552

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/rustapi_module/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ def get_py_version_cfgs():
"Cargo.toml",
rustc_flags=get_py_version_cfgs(),
),
RustExtension(
"rustapi_module.dunder",
"Cargo.toml",
rustc_flags=get_py_version_cfgs(),
),
],
install_requires=install_requires,
tests_require=tests_require,
Expand Down
25 changes: 25 additions & 0 deletions examples/rustapi_module/src/dunder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use pyo3::prelude::*;

#[pymodule]
fn dunder(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<Number>()?;
Ok(())
}

#[pyclass]
pub struct Number {
value: u32,
}

#[pymethods]
impl Number {
#[new]
fn new(obj: &PyRawObject, value: u32) {
obj.init(Number { value })
}

/// Very basic add function
fn __add__(&self, other: u32) -> PyResult<u32> {
Ok(self.value + other)
}
}
1 change: 1 addition & 0 deletions examples/rustapi_module/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod datetime;
pub mod dict_iter;
pub mod dunder;
pub mod othermod;
pub mod subclassing;
3 changes: 1 addition & 2 deletions examples/rustapi_module/tests/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from hypothesis import strategies as st
from hypothesis.strategies import dates, datetimes


# Constants
def _get_utc():
timezone = getattr(pdt, "timezone", None)
Expand Down Expand Up @@ -310,4 +309,4 @@ def test_tz_class_introspection():
tzi = rdt.TzClass()

assert tzi.__class__ == rdt.TzClass
assert repr(tzi).startswith("<TzClass object at")
assert repr(tzi).startswith("<TzClass object at")
5 changes: 5 additions & 0 deletions examples/rustapi_module/tests/test_dunder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import rustapi_module.dunder


def test_add():
assert rustapi_module.dunder.Number(10) + 20 == 30
27 changes: 27 additions & 0 deletions pyo3-derive-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ fn impl_inventory(cls: &syn::Ident) -> TokenStream {
// it comes up in error messages
let name = cls.to_string() + "GeneratedPyo3Inventory";
let inventory_cls = syn::Ident::new(&name, Span::call_site());
let protocol_name = cls.to_string() + "GeneratedPyo3InventoryProtocol";
let protocol_inventory_cls = syn::Ident::new(&protocol_name, Span::call_site());

quote! {
#[doc(hidden)]
Expand All @@ -241,6 +243,31 @@ fn impl_inventory(cls: &syn::Ident) -> TokenStream {
}

pyo3::inventory::collect!(#inventory_cls);

// Dunder methods/Protocol support

#[doc(hidden)]
pub struct #protocol_inventory_cls {
methods: &'static [pyo3::methods::protocols::PyProcotolMethodWrapped],
}

impl pyo3::class::methods::protocols::PyProtocolInventory for #protocol_inventory_cls {
fn new(methods: &'static [pyo3::methods::protocols::PyProcotolMethodWrapped]) -> Self {
Self {
methods
}
}

fn get_methods(&self) -> &'static [pyo3::methods::protocols::PyProcotolMethodWrapped] {
self.methods
}
}

impl pyo3::class::methods::protocols::PyProtocolInventoryDispatch for #cls {
type ProtocolInventoryType = #protocol_inventory_cls;
}

pyo3::inventory::collect!(#protocol_inventory_cls);
}
}

Expand Down
57 changes: 56 additions & 1 deletion pyo3-derive-backend/src/pyimpl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,60 @@ pub fn build_py_methods(ast: &mut syn::ItemImpl) -> syn::Result<TokenStream> {
}
}

fn binary_func_protocol_wrap(ty: &syn::Type, name: &syn::Ident) -> TokenStream {
quote! {{
#[allow(unused_mut)]
unsafe extern "C" fn wrap(
lhs: *mut pyo3::ffi::PyObject,
rhs: *mut pyo3::ffi::PyObject,
) -> *mut pyo3::ffi::PyObject {
use pyo3::ObjectProtocol;
let _pool = pyo3::GILPool::new();
let py = pyo3::Python::assume_gil_acquired();
let lhs = py.from_borrowed_ptr::<pyo3::types::PyAny>(lhs);
let rhs = py.from_borrowed_ptr::<pyo3::types::PyAny>(rhs);

let result = match lhs.extract() {
Ok(lhs) => match rhs.extract() {
Ok(rhs) => #ty::#name(lhs, rhs).into(),
Err(e) => Err(e.into()),
},
Err(e) => Err(e.into()),
};
pyo3::callback::cb_convert(pyo3::callback::PyObjectCallbackConverter, py, result)
}
pyo3::class::methods::protocols::PyProcotolMethodWrapped::Add(wrap)
}}
}

pub fn impl_methods(ty: &syn::Type, impls: &mut Vec<syn::ImplItem>) -> syn::Result<TokenStream> {
// get method names in impl block
let mut methods = Vec::new();
let mut protocol_methods = Vec::new();
for iimpl in impls.iter_mut() {
if let syn::ImplItem::Method(ref mut meth) = iimpl {
let name = meth.sig.ident.clone();

if name.to_string().starts_with("__") && name.to_string().ends_with("__") {
#[allow(clippy::single_match)]
{
match name.to_string().as_str() {
"__add__" => {
protocol_methods.push(binary_func_protocol_wrap(&ty, &name));
}
_ => {
// This currently breaks the tests
/*
return Err(syn::Error::new_spanned(
meth.sig.ident.clone(),
"Unknown dunder method",
))
*/
}
}
}
}

methods.push(pymethod::gen_py_method(
ty,
&name,
Expand All @@ -36,11 +84,18 @@ pub fn impl_methods(ty: &syn::Type, impls: &mut Vec<syn::ImplItem>) -> syn::Resu
}

Ok(quote! {
pyo3::inventory::submit! {
pyo3::inventory::submit! {
#![crate = pyo3] {
type TyInventory = <#ty as pyo3::class::methods::PyMethodsInventoryDispatch>::InventoryType;
<TyInventory as pyo3::class::methods::PyMethodsInventory>::new(&[#(#methods),*])
}
}

pyo3::inventory::submit! {
#![crate = pyo3] {
type ProtocolInventory = <#ty as pyo3::class::methods::protocols::PyProtocolInventoryDispatch>::ProtocolInventoryType;
<ProtocolInventory as pyo3::class::methods::protocols::PyProtocolInventory>::new(&[#(#protocol_methods),*])
}
}
})
}
74 changes: 73 additions & 1 deletion src/class/methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ impl PySetterDef {
}

#[doc(hidden)] // Only to be used through the proc macros, use PyMethodsProtocol in custom code
/// This trait is implemented for all pyclass so to implement the [PyMethodsProtocol]
/// This trait is implemented for all pyclass to implement the [PyMethodsProtocol]
/// through inventory
pub trait PyMethodsInventoryDispatch {
/// This allows us to get the inventory type when only the pyclass is in scope
Expand Down Expand Up @@ -153,3 +153,75 @@ where
.collect()
}
}

/// Utils to define and collect dunder methods, powered by inventory
pub mod protocols {
use crate::ffi;

#[doc(hidden)] // Only to be used through the proc macros, use PyMethodsProtocol in custom code
/// The c wrapper around a dunder method defined in an impl block
pub enum PyProcotolMethodWrapped {
Add(ffi::binaryfunc),
}

#[doc(hidden)] // Only to be used through the proc macros, use PyMethodsProtocol in custom code
/// All defined dunder methods collected into a single struct
#[derive(Default)]
pub struct PyProcolTypes {
pub(crate) add: Option<ffi::binaryfunc>,
}

impl PyProcolTypes {
/// Returns whether any dunder method has been defined
pub fn any_defined(&self) -> bool {
self.add.is_some()
}
}

#[doc(hidden)] // Only to be used through the proc macros, use PyMethodsProtocol in custom code
/// This trait is implemented for all pyclass to implement the [PyProtocolInventory]
/// through inventory
pub trait PyProtocolInventoryDispatch {
/// This allows us to get the inventory type when only the pyclass is in scope
type ProtocolInventoryType: PyProtocolInventory;
}

#[doc(hidden)]
/// Allows arbitrary pymethod blocks to submit dunder methods, which are eventually collected
/// into [PyProcolTypes]
pub trait PyProtocolInventory: inventory::Collect {
fn new(methods: &'static [PyProcotolMethodWrapped]) -> Self;
fn get_methods(&self) -> &'static [PyProcotolMethodWrapped];
}

/// Defines which protocols this class implements
pub trait PyProtocol {
/// Returns all methods that are defined for a class
fn py_protocols() -> PyProcolTypes;
}

impl<T> PyProtocol for T
where
T: PyProtocolInventoryDispatch,
{
/// Collects all defined dunder methods into a single [PyProcolTypes] instance
fn py_protocols() -> PyProcolTypes {
let mut py_protocol_types = PyProcolTypes::default();
let flattened = inventory::iter::<T::ProtocolInventoryType>
.into_iter()
.flat_map(PyProtocolInventory::get_methods);
for method in flattened {
match method {
PyProcotolMethodWrapped::Add(add) => {
if py_protocol_types.add.is_some() {
panic!("You can't define `__add__` more than once");
}
py_protocol_types.add = Some(*add);
}
}
}

py_protocol_types
}
}
}
22 changes: 15 additions & 7 deletions src/class/number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

use crate::callback::PyObjectCallbackConverter;
use crate::class::basic::PyObjectProtocolImpl;
use crate::class::methods::protocols::PyProtocol;
use crate::class::methods::PyMethodDef;
use crate::err::PyResult;
use crate::ffi;
Expand Down Expand Up @@ -621,14 +622,21 @@ pub trait PyNumberIndexProtocol<'p>: PyNumberProtocol<'p> {
}

#[doc(hidden)]
pub trait PyNumberProtocolImpl: PyObjectProtocolImpl {
pub trait PyNumberProtocolImpl: PyObjectProtocolImpl + PyProtocol {
fn methods() -> Vec<PyMethodDef> {
Vec::new()
}
fn tp_as_number() -> Option<ffi::PyNumberMethods> {
if let Some(nb_bool) = <Self as PyObjectProtocolImpl>::nb_bool_fn() {
let meth = ffi::PyNumberMethods {
nb_bool: Some(nb_bool),
nb_add: Self::py_protocols().add,
..ffi::PyNumberMethods_INIT
};
Some(meth)
} else if Self::py_protocols().any_defined() {
let meth = ffi::PyNumberMethods {
nb_add: Self::py_protocols().add,
..ffi::PyNumberMethods_INIT
};
Some(meth)
Expand All @@ -638,11 +646,11 @@ pub trait PyNumberProtocolImpl: PyObjectProtocolImpl {
}
}

impl<'p, T> PyNumberProtocolImpl for T {}
impl<'p, T> PyNumberProtocolImpl for T where T: PyProtocol {}

impl<'p, T> PyNumberProtocolImpl for T
where
T: PyNumberProtocol<'p>,
T: PyNumberProtocol<'p> + PyProtocol,
{
fn tp_as_number() -> Option<ffi::PyNumberMethods> {
Some(ffi::PyNumberMethods {
Expand Down Expand Up @@ -742,17 +750,17 @@ where
}
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't the selection of the PyNumberProtocolImpl impl depend upon specialisation? This branch doesn't compile for me on nightly 2019-11-09.

trait PyNumberAddProtocolImpl {
trait PyNumberAddProtocolImpl: PyProtocol {
fn nb_add() -> Option<ffi::binaryfunc> {
None
Self::py_protocols().add
}
}

impl<'p, T> PyNumberAddProtocolImpl for T where T: PyNumberProtocol<'p> {}
impl<'p, T> PyNumberAddProtocolImpl for T where T: PyNumberProtocol<'p> + PyProtocol {}

impl<T> PyNumberAddProtocolImpl for T
where
T: for<'p> PyNumberAddProtocol<'p>,
T: for<'p> PyNumberAddProtocol<'p> + PyProtocol,
{
fn nb_add() -> Option<ffi::binaryfunc> {
py_binary_num_func!(
Expand Down
2 changes: 1 addition & 1 deletion src/ffi3/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ mod typeobject {
impl Default for PyNumberMethods {
#[inline]
fn default() -> Self {
unsafe { mem::zeroed() }
PyNumberMethods_INIT
}
}
macro_rules! as_expr {
Expand Down
7 changes: 4 additions & 3 deletions src/type_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

//! Python type object information

use crate::class::methods::protocols::PyProtocol;
use crate::class::methods::PyMethodDefType;
use crate::err::{PyErr, PyResult};
use crate::instance::{Py, PyNativeType};
Expand Down Expand Up @@ -249,7 +250,7 @@ pub unsafe trait PyTypeObject {

unsafe impl<T> PyTypeObject for T
where
T: PyTypeInfo + PyMethodsProtocol + PyObjectAlloc,
T: PyTypeInfo + PyMethodsProtocol + PyObjectAlloc + PyProtocol,
{
fn init_type() -> NonNull<ffi::PyTypeObject> {
let type_object = unsafe { <Self as PyTypeInfo>::type_object() };
Expand Down Expand Up @@ -297,7 +298,7 @@ impl<T> PyTypeCreate for T where T: PyObjectAlloc + PyTypeObject + Sized {}
#[cfg(not(Py_LIMITED_API))]
pub fn initialize_type<T>(py: Python, module_name: Option<&str>) -> PyResult<*mut ffi::PyTypeObject>
where
T: PyObjectAlloc + PyTypeInfo + PyMethodsProtocol,
T: PyObjectAlloc + PyTypeInfo + PyMethodsProtocol + PyProtocol,
{
let type_object: &mut ffi::PyTypeObject = unsafe { T::type_object() };
let base_type_object: &mut ffi::PyTypeObject =
Expand Down Expand Up @@ -438,7 +439,7 @@ fn py_class_flags<T: PyTypeInfo>(type_object: &mut ffi::PyTypeObject) {
}
}

fn py_class_method_defs<T: PyMethodsProtocol>() -> (
fn py_class_method_defs<T: PyMethodsProtocol + PyProtocol>() -> (
Option<ffi::newfunc>,
Option<ffi::initproc>,
Option<ffi::PyCFunctionWithKeywords>,
Expand Down