Skip to content

Commit

Permalink
add experimental support for PEP 489 multi-phase module initialization
Browse files Browse the repository at this point in the history
This commit adds experimental but fully functional support for
multi-phase module initialization as specified in PEP 489.

Note that this commit serves only as a demonstration and as a basis for
further improvements; many tests have therefore not yet been adapted
accordingly.

Signed-off-by: Max R. Carrara <[email protected]>
  • Loading branch information
Aequitosh committed Jul 23, 2024
1 parent 8749371 commit bb977a1
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 74 deletions.
60 changes: 55 additions & 5 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,7 @@ pub fn pymodule_module_impl(
unsafe {
impl_::ModuleDef::new(
__PYO3_NAME,
#doc,
INITIALIZER
#doc
)
}
}};
Expand Down Expand Up @@ -385,6 +384,9 @@ pub fn pymodule_function_impl(

let initialization = module_initialization(&name, ctx, quote! { MakeDef::make_def() }, false);

// Each generated `module_objects_init` function is exported as a separate symbol.
let module_objects_init_symbol = format!("__module_objects_init__{}", ident.unraw());

// Module function called with optional Python<'_> marker as first arg, followed by the module.
let mut module_args = Vec::new();
if function.sig.inputs.len() == 2 {
Expand All @@ -396,6 +398,32 @@ pub fn pymodule_function_impl(
Ok(quote! {
#[doc(hidden)]
#vis mod #ident {
/// Function used to add classes, functions, etc. to the module during
/// multi-phase initialization. Its address is exported through `__PYO3_INIT`.
///
/// Returns `0` on success. On failure the error is restored as the current
/// Python exception and `-1` is returned, instead of panicking inside an
/// `extern "C"` function.
///
/// # Safety
///
/// Must be called with the GIL held (the body uses `with_gil_unchecked`), and
/// `module` must be either NULL or a valid pointer to a Python module object.
#[doc(hidden)]
#[export_name = #module_objects_init_symbol]
pub unsafe extern "C" fn __module_objects_init(module: *mut #pyo3_path::ffi::PyObject) -> ::std::ffi::c_int {
    // SAFETY: see the `# Safety` section above; the caller guarantees that the
    // GIL is held and that `module` is NULL or a valid module object.
    unsafe {
        #pyo3_path::Python::with_gil_unchecked(|py| {
            // `module` is a borrowed reference: the caller keeps ownership of it.
            // Take a new strong reference rather than adopting the pointer, so that
            // dropping the `Py` below releases only the reference created here.
            let result = #pyo3_path::Py::<#pyo3_path::types::PyModule>::from_borrowed_ptr_or_opt(py, module)
                .ok_or_else(|| {
                    #pyo3_path::exceptions::PySystemError::new_err("'module' shouldn't be NULL")
                })
                .and_then(|module| MakeDef::do_init_multiphase(module.bind(py)));

            match result {
                ::std::result::Result::Ok(()) => 0,
                ::std::result::Result::Err(err) => {
                    // Hand the error to the interpreter as the current exception.
                    err.restore(py);
                    -1
                }
            }
        })
    }
    // NOTE(review): a panic raised by the user's module function still unwinds
    // into this `extern "C"` frame; consider routing this through a panic
    // trampoline.
}

// Type-erased pointer to `__module_objects_init`, stored as a `*mut c_void` so that
// it can be used as the value of a module definition slot.
// NOTE(review): presumably consumed by the `Py_mod_exec` entry of `_PYO3_SLOTS`
// emitted by `module_initialization` - confirm.
#[doc(hidden)]
pub const __PYO3_INIT: *mut ::std::ffi::c_void = __module_objects_init as *mut ::std::ffi::c_void;

#initialization
}

Expand All @@ -405,17 +433,22 @@ pub fn pymodule_function_impl(
// inside a function body)
#[allow(unknown_lints, non_local_definitions)]
impl #ident::MakeDef {
/// Helper function for `__module_objects_init`: calls the user's module
/// function with the arguments collected in `module_args` (an optional
/// `Python<'_>` marker followed by the module) and returns its `PyResult`
/// unchanged. Should probably be put somewhere else.
#[doc(hidden)]
pub fn do_init_multiphase(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
// `module` is referenced by the interpolated `module_args` tokens.
#ident(#(#module_args),*)
}

const fn make_def() -> #pyo3_path::impl_::pymodule::ModuleDef {
fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
#ident(#(#module_args),*)
}

const INITIALIZER: #pyo3_path::impl_::pymodule::ModuleInitializer = #pyo3_path::impl_::pymodule::ModuleInitializer(__pyo3_pymodule);
unsafe {
#pyo3_path::impl_::pymodule::ModuleDef::new(
#ident::__PYO3_NAME,
#doc,
INITIALIZER
)
}
}
Expand All @@ -442,14 +475,31 @@ fn module_initialization(
#[doc(hidden)]
pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = #module_def;
};

if !is_submodule {
result.extend(quote! {
// Slot table for multi-phase initialization of a top-level module. `__pyo3_init`
// wraps it with `ModuleDefSlots::new_from_static` and installs it on `_PYO3_DEF`
// via `set_multiphase_items` before calling `make_module`.
#[doc(hidden)]
pub static _PYO3_SLOTS: &[#pyo3_path::impl_::pymodule_state::ModuleDefSlot] = &[
// NOTE(review): `start()` and `end()` presumably delimit the table (the C API
// expects a zero-terminated slot array) - confirm against `pymodule_state`.
#pyo3_path::impl_::pymodule_state::ModuleDefSlot::start(),
// `Py_mod_exec` slot whose value is `__PYO3_INIT`.
#pyo3_path::impl_::pymodule_state::ModuleDefSlot::new(
#pyo3_path::ffi::Py_mod_exec,
__PYO3_INIT,
),
// Only emitted when the `Py_3_12` cfg is set.
#[cfg(Py_3_12)]
#pyo3_path::impl_::pymodule_state::ModuleDefSlot::per_interpreter_gil(),
#pyo3_path::impl_::pymodule_state::ModuleDefSlot::end(),
];

/// This autogenerated function is called by the python interpreter when importing
/// the module.
#[doc(hidden)]
#[export_name = #pyinit_symbol]
pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
#pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py))
#pyo3_path::impl_::trampoline::module_init(|py| {
let slots = #pyo3_path::impl_::pymodule_state::ModuleDefSlots::new_from_static(_PYO3_SLOTS);
_PYO3_DEF.set_multiphase_items(slots);
_PYO3_DEF.make_module(py);
})
}
});
}
Expand Down
167 changes: 102 additions & 65 deletions src/impl_/pymodule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,33 @@ use crate::{
Bound, Py, PyClass, PyMethodDef, PyResult, PyTypeInfo, Python,
};

use crate::impl_::pymodule_state as state;

// TODO: replace other usages (if this passes review :^) )
pub use state::ModuleDefSlot;

/// `Sync` wrapper of `ffi::PyModuleDef`.
pub struct ModuleDef {
// wrapped in UnsafeCell so that Rust compiler treats this as interior mutability
ffi_def: UnsafeCell<ffi::PyModuleDef>,
initializer: ModuleInitializer,
/// Interpreter ID where module was initialized (not applicable on PyPy).
#[cfg(all(
not(any(PyPy, GraalPy)),
Py_3_9,
not(all(windows, Py_LIMITED_API, not(Py_3_10)))
))]
interpreter: AtomicI64,
// TODO: `module` could probably go..?
/// Initialized module object, cached to avoid reinitialization.
#[allow(unused)]
module: GILOnceCell<Py<PyModule>>,
}

/// Wrapper to enable initializer to be used in const fns.
pub struct ModuleInitializer(pub for<'py> fn(&Bound<'py, PyModule>) -> PyResult<()>);

unsafe impl Sync for ModuleDef {}

impl ModuleDef {
/// Make new module definition with given module name.
pub const unsafe fn new(
name: &'static CStr,
doc: &'static CStr,
initializer: ModuleInitializer,
) -> Self {
pub const unsafe fn new(name: &'static CStr, doc: &'static CStr) -> Self {
const INIT: ffi::PyModuleDef = ffi::PyModuleDef {
m_base: ffi::PyModuleDef_HEAD_INIT,
m_name: std::ptr::null(),
Expand All @@ -74,7 +73,6 @@ impl ModuleDef {

ModuleDef {
ffi_def,
initializer,
// -1 is never expected to be a valid interpreter ID
#[cfg(all(
not(any(PyPy, GraalPy)),
Expand All @@ -85,8 +83,9 @@ impl ModuleDef {
module: GILOnceCell::new(),
}
}

/// Initializes the module definition via `PyModuleDef_Init` and returns a pointer to it; the module's slots must already be set. Used for [`#[pymodule]`][crate::pymodule].
pub fn make_module(&'static self, py: Python<'_>) -> PyResult<Py<PyModule>> {
pub fn make_module(&'static self, py: Python<'_>) -> PyResult<*mut ffi::PyModuleDef> {
#[cfg(all(PyPy, not(Py_3_8)))]
{
use crate::types::any::PyAnyMethods;
Expand Down Expand Up @@ -140,18 +139,31 @@ impl ModuleDef {
}
}
}
self.module
.get_or_try_init(py, || {
let module = unsafe {
Py::<PyModule>::from_owned_ptr_or_err(
py,
ffi::PyModule_Create(self.ffi_def.get()),
)?
};
self.initializer.0(module.bind(py))?;
Ok(module)
})
.map(|py_module| py_module.clone_ref(py))

if (unsafe { *self.ffi_def.get() }).m_slots.is_null() {
return Err(PyImportError::new_err(
"'m_slots' of module definition is NULL",
));
}

let module_def_ptr = unsafe { ffi::PyModuleDef_Init(self.ffi_def.get()) };

if module_def_ptr.is_null() {
return Err(PyImportError::new_err("PyModuleDef_Init returned NULL"));
}

Ok(module_def_ptr.cast())
}

/// Fills in the fields of the wrapped `ffi::PyModuleDef` that are needed for
/// multi-phase initialization: the per-module state size, the slot table, and
/// the traverse / clear / free callbacks for the module state.
///
/// `slots` is consumed; the pointer returned by `into_inner()` is stored in
/// `m_slots`.
pub fn set_multiphase_items(&'static self, slots: state::ModuleDefSlots) {
    let state_size = std::mem::size_of::<state::ModuleState>() as ffi::Py_ssize_t;
    let def = self.ffi_def.get();

    // NOTE(review): this writes through the `UnsafeCell` without any
    // synchronization; presumably callers hold the GIL and call this before the
    // definition is handed to the interpreter - confirm.
    unsafe {
        (*def).m_size = state_size;
        (*def).m_slots = slots.into_inner();
        (*def).m_traverse = Some(state::module_state_traverse);
        (*def).m_clear = Some(state::module_state_clear);
        (*def).m_free = Some(state::module_state_free);
    }
}
}

Expand Down Expand Up @@ -204,7 +216,44 @@ impl PyAddToModule for PyMethodDef {
/// For adding a module to a module.
impl PyAddToModule for ModuleDef {
fn add_to_module(&'static self, module: &Bound<'_, PyModule>) -> PyResult<()> {
module.add_submodule(self.make_module(module.py())?.bind(module.py()))
let parent_ptr = module.as_ptr();
let parent_name = std::ffi::CString::new(module.name()?.to_string())?;

let add_to_parent = |child_ptr: *mut ffi::PyObject| -> std::ffi::c_int {
// TODO: reference to child_ptr is stolen - check if this is fine here?
let ret =
unsafe { ffi::PyModule_AddObject(parent_ptr, parent_name.as_ptr(), child_ptr) };

// TODO: .. as well as this error handling here - is this fine
// inside Py_mod_exec slots?
if ret < 0 {
unsafe { ffi::Py_DECREF(parent_ptr) };
return -1;
}

0
};

// SAFETY: We only use this closure inside the ModuleDef's slots and
// then immediately initialize the module - this closure /
// "function pointer" isn't used anywhere else afterwards and can't
// outlive the current thread.
let add_to_parent = unsafe { state::alloc_closure(add_to_parent) };

let slots = [
state::ModuleDefSlot::start(),
state::ModuleDefSlot::new(ffi::Py_mod_exec, add_to_parent),
#[cfg(Py_3_12)]
state::ModuleDefSlot::per_interpreter_gil(),
state::ModuleDefSlot::end(),
];

let slots = state::alloc_slots(slots);
self.set_multiphase_items(slots);

let _module_def_ptr = self.make_module(module.py())?;

Ok(())
}
}

Expand All @@ -218,50 +267,31 @@ mod tests {

use crate::{
ffi,
impl_::pymodule_state as state,
types::{any::PyAnyMethods, module::PyModuleMethods, PyModule},
Bound, PyResult, Python,
};

use super::{ModuleDef, ModuleInitializer};
use super::ModuleDef;

#[test]
fn module_init() {
static MODULE_DEF: ModuleDef = unsafe {
ModuleDef::new(
ffi::c_str!("test_module"),
ffi::c_str!("some doc"),
ModuleInitializer(|m| {
m.add("SOME_CONSTANT", 42)?;
Ok(())
}),
)
};
static MODULE_DEF: ModuleDef =
unsafe { ModuleDef::new(ffi::c_str!("test_module"), ffi::c_str!("some doc")) };

let slots = [
state::ModuleDefSlot::start(),
#[cfg(Py_3_12)]
state::ModuleDefSlot::per_interpreter_gil(),
state::ModuleDefSlot::end(),
];

MODULE_DEF.set_multiphase_items(state::alloc_slots(slots));

Python::with_gil(|py| {
let module = MODULE_DEF.make_module(py).unwrap().into_bound(py);
assert_eq!(
module
.getattr("__name__")
.unwrap()
.extract::<Cow<'_, str>>()
.unwrap(),
"test_module",
);
assert_eq!(
module
.getattr("__doc__")
.unwrap()
.extract::<Cow<'_, str>>()
.unwrap(),
"some doc",
);
assert_eq!(
module
.getattr("SOME_CONSTANT")
.unwrap()
.extract::<u8>()
.unwrap(),
42,
);
let module_def = MODULE_DEF.make_module(py).unwrap();
// FIXME: get PyModule from PyModuleDef ..?
unimplemented!("Test currently not implemented");
})
}

Expand All @@ -272,6 +302,13 @@ mod tests {
static NAME: &CStr = ffi::c_str!("test_module");
static DOC: &CStr = ffi::c_str!("some doc");

let slots = [
state::ModuleDefSlot::start(),
#[cfg(Py_3_12)]
state::ModuleDefSlot::per_interpreter_gil(),
state::ModuleDefSlot::end(),
];

static INIT_CALLED: AtomicBool = AtomicBool::new(false);

#[allow(clippy::unnecessary_wraps)]
Expand All @@ -281,12 +318,12 @@ mod tests {
}

unsafe {
let module_def: ModuleDef = ModuleDef::new(NAME, DOC, ModuleInitializer(init));
assert_eq!((*module_def.ffi_def.get()).m_name, NAME.as_ptr() as _);
assert_eq!((*module_def.ffi_def.get()).m_doc, DOC.as_ptr() as _);
static MODULE_DEF: ModuleDef = unsafe { ModuleDef::new(NAME, DOC) };
MODULE_DEF.set_multiphase_items(state::alloc_slots(slots));
assert_eq!((*MODULE_DEF.ffi_def.get()).m_name, NAME.as_ptr() as _);
assert_eq!((*MODULE_DEF.ffi_def.get()).m_doc, DOC.as_ptr() as _);

Python::with_gil(|py| {
module_def.initializer.0(&py.import_bound("builtins").unwrap()).unwrap();
Python::with_gil(|_py| {
assert!(INIT_CALLED.load(Ordering::SeqCst));
})
}
Expand Down
6 changes: 3 additions & 3 deletions src/impl_/trampoline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ use std::{
use crate::gil::GILGuard;
use crate::{
callback::PyCallbackOutput, ffi, ffi_ptr_ext::FfiPtrExt, impl_::panic::PanicTrap,
methods::IPowModulo, panic::PanicException, types::PyModule, Py, PyResult, Python,
methods::IPowModulo, panic::PanicException, PyResult, Python,
};

#[inline]
pub unsafe fn module_init(
f: for<'py> unsafe fn(Python<'py>) -> PyResult<Py<PyModule>>,
f: for<'py> unsafe fn(Python<'py>) -> PyResult<*mut ffi::PyModuleDef>,
) -> *mut ffi::PyObject {
trampoline(|py| f(py).map(|module| module.into_ptr()))
trampoline(|py| f(py).map(|module_def| module_def.cast()))
}

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1770,7 +1770,7 @@ impl<T> Py<T> {
///
/// # Safety
/// `ptr` must point to a Python object of type T.
unsafe fn from_non_null(ptr: NonNull<ffi::PyObject>) -> Self {
pub unsafe fn from_non_null(ptr: NonNull<ffi::PyObject>) -> Self {
Self(ptr, PhantomData)
}
}
Expand Down

0 comments on commit bb977a1

Please sign in to comment.