diff --git a/newsfragments/4453.changed.md b/newsfragments/4453.changed.md new file mode 100644 index 00000000000..58a1e0ffcbd --- /dev/null +++ b/newsfragments/4453.changed.md @@ -0,0 +1 @@ +Make subclassing a class that doesn't allow that a compile-time error instead of runtime \ No newline at end of file diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 9e3dbceaa91..9d5535f6e17 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -2244,7 +2244,20 @@ impl<'a> PyClassImplsBuilder<'a> { quote! { #pyo3_path::PyAny } }; + let pyclass_base_type_impl = attr.options.subclass.map(|subclass| { + quote_spanned! { subclass.span() => + impl #pyo3_path::impl_::pyclass::PyClassBaseType for #cls { + type LayoutAsBase = #pyo3_path::impl_::pycell::PyClassObject; + type BaseNativeType = ::BaseNativeType; + type Initializer = #pyo3_path::pyclass_init::PyClassInitializer; + type PyClassMutability = ::PyClassMutability; + } + } + }); + Ok(quote! { + #pyclass_base_type_impl + impl #pyo3_path::impl_::pyclass::PyClassImpl for #cls { const IS_BASETYPE: bool = #is_basetype; const IS_SUBCLASS: bool = #is_subclass; diff --git a/src/exceptions.rs b/src/exceptions.rs index fbd2eb077b5..292f6e43237 100644 --- a/src/exceptions.rs +++ b/src/exceptions.rs @@ -275,6 +275,7 @@ macro_rules! impl_native_exception ( $crate::impl_exception_boilerplate!($name); $crate::pyobject_native_type!($name, $layout, |_py| unsafe { $crate::ffi::$exc_name as *mut $crate::ffi::PyTypeObject } $(, #checkfunction=$checkfunction)?); + $crate::pyobject_subclassable_native_type!($name, $layout); ); ($name:ident, $exc_name:ident, $doc:expr) => ( impl_native_exception!($name, $exc_name, $doc, $crate::ffi::PyBaseExceptionObject); diff --git a/src/impl_/pyclass.rs b/src/impl_/pyclass.rs index 1680eca34ac..2256c0db7c9 100644 --- a/src/impl_/pyclass.rs +++ b/src/impl_/pyclass.rs @@ -1113,7 +1113,18 @@ impl PyClassThreadChecker for ThreadCheckerImpl { #[cfg_attr( all(diagnostic_namespace, feature = "abi3"), diagnostic::on_unimplemented( - note = "with the `abi3` feature enabled, PyO3 does not support subclassing native types" + message = "pyclass `{Self}` cannot be subclassed", + label = "required for `#[pyclass(extends={Self})]`", + note = "if you own `{Self}`, add `subclass` to the `#[pyclass]` macro: `#[pyclass(subclass)]`", + note = "with the `abi3` feature enabled, PyO3 does not support subclassing native types", + ) +)] +#[cfg_attr( + all(diagnostic_namespace, not(feature = "abi3")), + diagnostic::on_unimplemented( + message = "pyclass `{Self}` cannot be subclassed", + label = "required for `#[pyclass(extends={Self})]`", + note = "if you own `{Self}`, add `subclass` to the `#[pyclass]` macro: `#[pyclass(subclass)]`", ) )] pub trait PyClassBaseType: Sized { @@ -1123,16 +1134,6 @@ pub trait PyClassBaseType: Sized { type PyClassMutability: PyClassMutability; } -/// All mutable PyClasses can be used as a base type. -/// -/// In the future this will be extended to immutable PyClasses too. -impl PyClassBaseType for T { - type LayoutAsBase = crate::impl_::pycell::PyClassObject; - type BaseNativeType = T::BaseNativeType; - type Initializer = crate::pyclass_init::PyClassInitializer; - type PyClassMutability = T::PyClassMutability; -} - /// Implementation of tp_dealloc for pyclasses without gc pub(crate) unsafe extern "C" fn tp_dealloc(obj: *mut ffi::PyObject) { crate::impl_::trampoline::dealloc(obj, PyClassObject::::tp_dealloc) diff --git a/src/pyclass_init.rs b/src/pyclass_init.rs index 01983c79b13..8e331e78229 100644 --- a/src/pyclass_init.rs +++ b/src/pyclass_init.rs @@ -285,7 +285,7 @@ where impl From<(S, B)> for PyClassInitializer where S: PyClass, - B: PyClass, + B: PyClass + PyClassBaseType>, B::BaseType: PyClassBaseType>, { fn from(sub_and_base: (S, B)) -> PyClassInitializer { diff --git a/src/types/any.rs b/src/types/any.rs index e7c7c578e3e..422c04bee9d 100644 --- a/src/types/any.rs +++ b/src/types/any.rs @@ -44,6 +44,7 @@ pyobject_native_type_info!( ); pyobject_native_type_sized!(PyAny, ffi::PyObject); +pyobject_subclassable_native_type!(PyAny, ffi::PyObject); /// This trait represents the Python APIs which are usable on all Python objects. /// diff --git a/src/types/complex.rs b/src/types/complex.rs index 131bcc09347..58651569b47 100644 --- a/src/types/complex.rs +++ b/src/types/complex.rs @@ -19,6 +19,8 @@ use std::os::raw::c_double; #[repr(transparent)] pub struct PyComplex(PyAny); +pyobject_subclassable_native_type!(PyComplex, ffi::PyComplexObject); + pyobject_native_type!( PyComplex, ffi::PyComplexObject, diff --git a/src/types/datetime.rs b/src/types/datetime.rs index a70cb6c885e..6f9ba17dce7 100644 --- a/src/types/datetime.rs +++ b/src/types/datetime.rs @@ -192,6 +192,7 @@ pyobject_native_type!( #module=Some("datetime"), #checkfunction=PyDate_Check ); +pyobject_subclassable_native_type!(PyDate, crate::ffi::PyDateTime_Date); impl PyDate { /// Creates a new `datetime.date`. @@ -248,6 +249,7 @@ pyobject_native_type!( #module=Some("datetime"), #checkfunction=PyDateTime_Check ); +pyobject_subclassable_native_type!(PyDateTime, crate::ffi::PyDateTime_DateTime); impl PyDateTime { /// Creates a new `datetime.datetime` object. @@ -424,6 +426,7 @@ pyobject_native_type!( #module=Some("datetime"), #checkfunction=PyTime_Check ); +pyobject_subclassable_native_type!(PyTime, crate::ffi::PyDateTime_Time); impl PyTime { /// Creates a new `datetime.time` object. @@ -550,6 +553,7 @@ pyobject_native_type!( #module=Some("datetime"), #checkfunction=PyTZInfo_Check ); +pyobject_subclassable_native_type!(PyTzInfo, crate::ffi::PyObject); /// Equivalent to `datetime.timezone.utc` pub fn timezone_utc_bound(py: Python<'_>) -> Bound<'_, PyTzInfo> { @@ -594,6 +598,7 @@ pyobject_native_type!( #module=Some("datetime"), #checkfunction=PyDelta_Check ); +pyobject_subclassable_native_type!(PyDelta, crate::ffi::PyDateTime_Delta); impl PyDelta { /// Creates a new `timedelta`. diff --git a/src/types/dict.rs b/src/types/dict.rs index 50a9e139355..3b05f02df8e 100644 --- a/src/types/dict.rs +++ b/src/types/dict.rs @@ -18,6 +18,9 @@ use crate::{ffi, Python, ToPyObject}; #[repr(transparent)] pub struct PyDict(PyAny); +#[cfg(not(feature = "abi3"))] +pyobject_subclassable_native_type!(PyDict, crate::ffi::PyDictObject); + pyobject_native_type!( PyDict, ffi::PyDictObject, diff --git a/src/types/float.rs b/src/types/float.rs index 5e637af3b62..2fe2952761a 100644 --- a/src/types/float.rs +++ b/src/types/float.rs @@ -23,6 +23,9 @@ use std::os::raw::c_double; #[repr(transparent)] pub struct PyFloat(PyAny); +#[cfg(not(feature = "abi3"))] +pyobject_subclassable_native_type!(PyFloat, crate::ffi::PyFloatObject); + pyobject_native_type!( PyFloat, ffi::PyFloatObject, diff --git a/src/types/frozenset.rs b/src/types/frozenset.rs index 6a0cdca89d5..c76d087d635 100644 --- a/src/types/frozenset.rs +++ b/src/types/frozenset.rs @@ -61,6 +61,8 @@ impl<'py> PyFrozenSetBuilder<'py> { #[repr(transparent)] pub struct PyFrozenSet(PyAny); +#[cfg(not(feature = "abi3"))] +pyobject_subclassable_native_type!(PyFrozenSet, crate::ffi::PySetObject); #[cfg(not(any(PyPy, GraalPy)))] pyobject_native_type!( PyFrozenSet, diff --git a/src/types/mod.rs b/src/types/mod.rs index 9a54eee9661..81f2b7acef0 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -194,10 +194,8 @@ macro_rules! pyobject_native_type_core { #[doc(hidden)] #[macro_export] -macro_rules! pyobject_native_type_sized { +macro_rules! pyobject_subclassable_native_type { ($name:ty, $layout:path $(;$generics:ident)*) => { - unsafe impl $crate::type_object::PyLayout<$name> for $layout {} - impl $crate::type_object::PySizedLayout<$name> for $layout {} impl<$($generics,)*> $crate::impl_::pyclass::PyClassBaseType for $name { type LayoutAsBase = $crate::impl_::pycell::PyClassObjectBase<$layout>; type BaseNativeType = $name; @@ -207,6 +205,15 @@ macro_rules! pyobject_native_type_sized { } } +#[doc(hidden)] +#[macro_export] +macro_rules! pyobject_native_type_sized { + ($name:ty, $layout:path $(;$generics:ident)*) => { + unsafe impl $crate::type_object::PyLayout<$name> for $layout {} + impl $crate::type_object::PySizedLayout<$name> for $layout {} + }; +} + /// Declares all of the boilerplate for Python types which can be inherited from (because the exact /// Python layout is known). #[doc(hidden)] diff --git a/src/types/set.rs b/src/types/set.rs index fd65d4bcaa4..5be019a31e3 100644 --- a/src/types/set.rs +++ b/src/types/set.rs @@ -19,6 +19,9 @@ use std::ptr; #[repr(transparent)] pub struct PySet(PyAny); +#[cfg(not(feature = "abi3"))] +pyobject_subclassable_native_type!(PySet, crate::ffi::PySetObject); + #[cfg(not(any(PyPy, GraalPy)))] pyobject_native_type!( PySet, diff --git a/src/types/weakref/reference.rs b/src/types/weakref/reference.rs index 59f6f5bf3be..ed9528a1eb3 100644 --- a/src/types/weakref/reference.rs +++ b/src/types/weakref/reference.rs @@ -15,6 +15,9 @@ use super::PyWeakrefMethods; #[repr(transparent)] pub struct PyWeakrefReference(PyAny); +#[cfg(not(any(PyPy, GraalPy)))] +pyobject_subclassable_native_type!(PyWeakrefReference, crate::ffi::PyWeakReference); + #[cfg(not(any(PyPy, GraalPy, Py_LIMITED_API)))] pyobject_native_type!( PyWeakrefReference, diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index b1fcdc09fb7..bbca8ff6cb9 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -62,4 +62,5 @@ fn test_compile_errors() { #[cfg(all(Py_LIMITED_API, not(Py_3_9)))] t.compile_fail("tests/ui/abi3_dict.rs"); t.compile_fail("tests/ui/duplicate_pymodule_submodule.rs"); + t.compile_fail("tests/ui/invalid_base_class.rs"); } diff --git a/tests/test_declarative_module.rs b/tests/test_declarative_module.rs index 9d3250d79ef..a911702ce20 100644 --- a/tests/test_declarative_module.rs +++ b/tests/test_declarative_module.rs @@ -4,8 +4,6 @@ use pyo3::create_exception; use pyo3::exceptions::PyException; use pyo3::prelude::*; use pyo3::sync::GILOnceCell; -#[cfg(not(Py_LIMITED_API))] -use pyo3::types::PyBool; #[path = "../src/tests/common.rs"] mod common; @@ -186,31 +184,6 @@ fn test_declarative_module() { }) } -#[cfg(not(Py_LIMITED_API))] -#[pyclass(extends = PyBool)] -struct ExtendsBool; - -#[cfg(not(Py_LIMITED_API))] -#[pymodule] -mod class_initialization_module { - #[pymodule_export] - use super::ExtendsBool; -} - -#[test] -#[cfg(not(Py_LIMITED_API))] -fn test_class_initialization_fails() { - Python::with_gil(|py| { - let err = class_initialization_module::_PYO3_DEF - .make_module(py) - .unwrap_err(); - assert_eq!( - err.to_string(), - "RuntimeError: An error occurred while initializing class ExtendsBool" - ); - }) -} - #[pymodule] mod r#type { #[pymodule_export] diff --git a/tests/test_inheritance.rs b/tests/test_inheritance.rs index a43ab57b6c1..d3980152120 100644 --- a/tests/test_inheritance.rs +++ b/tests/test_inheritance.rs @@ -345,26 +345,3 @@ fn test_subclass_ref_counts() { ); }) } - -#[test] -#[cfg(not(Py_LIMITED_API))] -fn module_add_class_inherit_bool_fails() { - use pyo3::types::PyBool; - - #[pyclass(extends = PyBool)] - struct ExtendsBool; - - Python::with_gil(|py| { - let m = PyModule::new(py, "test_module").unwrap(); - - let err = m.add_class::().unwrap_err(); - assert_eq!( - err.to_string(), - "RuntimeError: An error occurred while initializing class ExtendsBool" - ); - assert_eq!( - err.cause(py).unwrap().to_string(), - "TypeError: type 'bool' is not an acceptable base type" - ); - }) -} diff --git a/tests/test_macros.rs b/tests/test_macros.rs index 40fd4847679..4b5feddf3f9 100644 --- a/tests/test_macros.rs +++ b/tests/test_macros.rs @@ -12,7 +12,7 @@ macro_rules! make_struct_using_macro { // Ensure that one doesn't need to fall back on the escape type: tt // in order to macro create pyclass. ($class_name:ident, $py_name:literal) => { - #[pyclass(name=$py_name)] + #[pyclass(name=$py_name, subclass)] struct $class_name {} }; } diff --git a/tests/ui/invalid_base_class.rs b/tests/ui/invalid_base_class.rs new file mode 100644 index 00000000000..7433dcb2b96 --- /dev/null +++ b/tests/ui/invalid_base_class.rs @@ -0,0 +1,7 @@ +use pyo3::prelude::*; +use pyo3::types::PyBool; + +#[pyclass(extends=PyBool)] +struct ExtendsBool; + +fn main() {}