Skip to content

Commit

Permalink
add experimental support for PEP 489 multi-phase module initialization
Browse files Browse the repository at this point in the history
This commit adds experimental but *almost* fully functional support
for multi-phase module initialization as specified in PEP 489.

With the exception of adding submodules to modules, all other
functionality should be retained. In other words, unless submodules
are used, multi-phase initialization is a complete drop-in
replacement.

Note that this commit serves as a demonstration & basis for further
improvements only; many tests have therefore not been adapted
correspondingly yet.

Nevertheless, it runs.

Signed-off-by: Max Carrara <[email protected]>
  • Loading branch information
Aequitosh committed May 6, 2024
1 parent e792269 commit 02eb845
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 72 deletions.
67 changes: 62 additions & 5 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,16 +249,15 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {

#initialization

// FIXME: Support multi-phase initialization here
#[allow(unknown_lints, non_local_definitions)]
impl MakeDef {
const fn make_def() -> #pyo3_path::impl_::pymodule::ModuleDef {
use #pyo3_path::impl_::pymodule as impl_;
const INITIALIZER: impl_::ModuleInitializer = impl_::ModuleInitializer(__pyo3_pymodule);
unsafe {
impl_::ModuleDef::new(
__PYO3_NAME,
#doc,
INITIALIZER
)
}
}
Expand Down Expand Up @@ -289,6 +288,9 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
let vis = &function.vis;
let doc = get_doc(&function.attrs, None);

// Each generated `module_objects_init` function is exported as a separate symbol.
let module_objects_init_symbol = format!("__module_objects_init__{}", ident.unraw());

let initialization = module_initialization(options, ident);

// Module function called with optional Python<'_> marker as first arg, followed by the module.
Expand Down Expand Up @@ -327,6 +329,32 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
#function
#[doc(hidden)]
#vis mod #ident {
/// Function used to add classes, functions, etc. to the module during
/// multi-phase initialization.
#[doc(hidden)]
#[export_name = #module_objects_init_symbol]
pub unsafe extern "C" fn __module_objects_init(module: *mut #pyo3_path::ffi::PyObject) -> ::std::ffi::c_int {
let module = unsafe {
let nonnull = ::std::ptr::NonNull::new(module).expect("'module' shouldn't be NULL");
#pyo3_path::Py::<#pyo3_path::types::PyModule>::from_non_null(nonnull)
};

let res = unsafe {
#pyo3_path::Python::with_gil_unchecked(|py| {
let bound = module.bind(py);
MakeDef::do_init_multiphase(bound)
})
};

// FIXME: Better error handling?
let _ = res.unwrap();

0
}

#[doc(hidden)]
pub const __PYO3_INIT: *mut ::std::ffi::c_void = __module_objects_init as *mut ::std::ffi::c_void;

#initialization
}

Expand All @@ -336,17 +364,22 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
// inside a function body)
#[allow(unknown_lints, non_local_definitions)]
impl #ident::MakeDef {
/// Helper function for `__module_objects_init`. Should probably be put
/// somewhere else.
#[doc(hidden)]
pub fn do_init_multiphase(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
#ident(#(#module_args),*)
}

const fn make_def() -> #pyo3_path::impl_::pymodule::ModuleDef {
fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
#ident(#(#module_args),*)
}

const INITIALIZER: #pyo3_path::impl_::pymodule::ModuleInitializer = #pyo3_path::impl_::pymodule::ModuleInitializer(__pyo3_pymodule);
unsafe {
#pyo3_path::impl_::pymodule::ModuleDef::new(
#ident::__PYO3_NAME,
#doc,
INITIALIZER
)
}
}
Expand All @@ -368,12 +401,36 @@ fn module_initialization(options: PyModuleOptions, ident: &syn::Ident) -> TokenS
#[doc(hidden)]
pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = MakeDef::make_def();

#[doc(hidden)]
pub static _PYO3_SLOTS: &[#pyo3_path::impl_::pymodule::ModuleDefSlot] = &[
#pyo3_path::impl_::pymodule::ModuleDefSlot(#pyo3_path::ffi::PyModuleDef_Slot {
slot: #pyo3_path::ffi::Py_mod_exec,
value: #pyo3_path::impl_::pymodule_state::module_state_init as *mut ::std::ffi::c_void,
}),
#pyo3_path::impl_::pymodule::ModuleDefSlot(#pyo3_path::ffi::PyModuleDef_Slot {
slot: #pyo3_path::ffi::Py_mod_exec,
value: __PYO3_INIT,
}),
#[cfg(Py_3_12)]
#pyo3_path::impl_::pymodule::ModuleDefSlot(#pyo3_path::ffi::PyModuleDef_Slot {
slot: #pyo3_path::ffi::Py_mod_multiple_interpreters,
value: #pyo3_path::ffi::Py_MOD_PER_INTERPRETER_GIL_SUPPORTED,
}),
#pyo3_path::impl_::pymodule::ModuleDefSlot(#pyo3_path::ffi::PyModuleDef_Slot {
slot: 0,
value: ::std::ptr::null_mut(),
}),
];

/// This autogenerated function is called by the python interpreter when importing
/// the module.
#[doc(hidden)]
#[export_name = #pyinit_symbol]
pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
#pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py))
#pyo3_path::impl_::trampoline::module_init(|py| {
_PYO3_DEF.set_multiphase_items(_PYO3_SLOTS);
_PYO3_DEF.make_module(py)
})
}
}
}
Expand Down
149 changes: 86 additions & 63 deletions src/impl_/pymodule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,39 @@ use crate::{
Bound, Py, PyClass, PyMethodDef, PyResult, PyTypeInfo, Python,
};

use crate::impl_::pymodule_state as state;

/// `Sync` wrapper of `ffi::PyModuleDef_Slot`.
#[allow(unused)]
pub struct ModuleDefSlot(pub ffi::PyModuleDef_Slot);

unsafe impl Sync for ModuleDefSlot {}

/// `Sync` wrapper of `ffi::PyModuleDef`.
pub struct ModuleDef {
// wrapped in UnsafeCell so that Rust compiler treats this as interior mutability
ffi_def: UnsafeCell<ffi::PyModuleDef>,
initializer: ModuleInitializer,
/// Interpreter ID where module was initialized (not applicable on PyPy).
#[cfg(all(
not(any(PyPy, GraalPy)),
Py_3_9,
not(all(windows, Py_LIMITED_API, not(Py_3_10)))
))]
interpreter: AtomicI64,
// TODO: Figure out how to cache module with multi-phase init
/// Initialized module object, cached to avoid reinitialization.
#[allow(unused)]
module: GILOnceCell<Py<PyModule>>,
}

/// Wrapper to enable initializer to be used in const fns.
pub struct ModuleInitializer(pub for<'py> fn(&Bound<'py, PyModule>) -> PyResult<()>);

unsafe impl Sync for ModuleDef {}

impl ModuleDef {
/// Make new module definition with given module name.
///
/// # Safety
/// `name` and `doc` must be null-terminated strings.
pub const unsafe fn new(
name: &'static str,
doc: &'static str,
initializer: ModuleInitializer,
) -> Self {
pub const unsafe fn new(name: &'static str, doc: &'static str) -> Self {
const INIT: ffi::PyModuleDef = ffi::PyModuleDef {
m_base: ffi::PyModuleDef_HEAD_INIT,
m_name: std::ptr::null(),
Expand All @@ -69,7 +71,6 @@ impl ModuleDef {

ModuleDef {
ffi_def,
initializer,
// -1 is never expected to be a valid interpreter ID
#[cfg(all(
not(any(PyPy, GraalPy)),
Expand All @@ -80,8 +81,9 @@ impl ModuleDef {
module: GILOnceCell::new(),
}
}

/// Builds a module using user given initializer. Used for [`#[pymodule]`][crate::pymodule].
pub fn make_module(&'static self, py: Python<'_>) -> PyResult<Py<PyModule>> {
pub fn make_module(&'static self, py: Python<'_>) -> PyResult<*mut ffi::PyModuleDef> {
#[cfg(all(PyPy, not(Py_3_8)))]
{
use crate::types::any::PyAnyMethods;
Expand Down Expand Up @@ -135,18 +137,32 @@ impl ModuleDef {
}
}
}
self.module
.get_or_try_init(py, || {
let module = unsafe {
Py::<PyModule>::from_owned_ptr_or_err(
py,
ffi::PyModule_Create(self.ffi_def.get()),
)?
};
self.initializer.0(module.bind(py))?;
Ok(module)
})
.map(|py_module| py_module.clone_ref(py))

if (unsafe { *self.ffi_def.get() }).m_slots.is_null() {
return Err(PyImportError::new_err(
"'m_slots' of module definition is NULL",
));
}

let module_def_ptr = unsafe { ffi::PyModuleDef_Init(self.ffi_def.get()) };

if module_def_ptr.is_null() {
return Err(PyImportError::new_err("PyModuleDef_Init returned NULL"));
}

Ok(module_def_ptr.cast())
}

pub fn set_multiphase_items(&'static self, slots: &'static [ModuleDefSlot]) {
let slots = slots as *const [ModuleDefSlot] as *mut ffi::PyModuleDef_Slot;
let ffi_def = self.ffi_def.get();
unsafe {
(*ffi_def).m_size = std::mem::size_of::<state::ModuleState>() as ffi::Py_ssize_t;
(*ffi_def).m_slots = slots;
(*ffi_def).m_traverse = Some(state::module_state_traverse);
(*ffi_def).m_clear = Some(state::module_state_clear);
(*ffi_def).m_free = Some(state::module_state_free);
};
}
}

Expand Down Expand Up @@ -198,8 +214,9 @@ impl PyAddToModule for PyMethodDef {

/// For adding a module to a module.
impl PyAddToModule for ModuleDef {
fn add_to_module(&'static self, module: &Bound<'_, PyModule>) -> PyResult<()> {
module.add_submodule(self.make_module(module.py())?.bind(module.py()))
fn add_to_module(&'static self, _module: &Bound<'_, PyModule>) -> PyResult<()> {
// FIXME: Support multi-phase initialization
unimplemented!("Adding submodules to a module is not supported at the moment.")
}
}

Expand All @@ -211,50 +228,40 @@ mod tests {
};

use crate::{
ffi,
impl_::pymodule_state::module_state_init,
types::{any::PyAnyMethods, module::PyModuleMethods, PyModule},
Bound, PyResult, Python,
};

use super::{ModuleDef, ModuleInitializer};
use super::{ModuleDef, ModuleDefSlot};

#[test]
fn module_init() {
static MODULE_DEF: ModuleDef = unsafe {
ModuleDef::new(
"test_module\0",
"some doc\0",
ModuleInitializer(|m| {
m.add("SOME_CONSTANT", 42)?;
Ok(())
}),
)
};
static MODULE_DEF: ModuleDef = unsafe { ModuleDef::new("test_module\0", "some doc\0") };

static SLOTS: &[ModuleDefSlot] = &[
ModuleDefSlot(ffi::PyModuleDef_Slot {
slot: ffi::Py_mod_exec,
value: module_state_init as *mut std::ffi::c_void,
}),
#[cfg(Py_3_12)]
ModuleDefSlot(ffi::PyModuleDef_Slot {
slot: ffi::Py_mod_multiple_interpreters,
value: ffi::Py_MOD_PER_INTERPRETER_GIL_SUPPORTED,
}),
ModuleDefSlot(ffi::PyModuleDef_Slot {
slot: 0,
value: std::ptr::null_mut(),
}),
];

MODULE_DEF.set_multiphase_items(SLOTS);

Python::with_gil(|py| {
let module = MODULE_DEF.make_module(py).unwrap().into_bound(py);
assert_eq!(
module
.getattr("__name__")
.unwrap()
.extract::<Cow<'_, str>>()
.unwrap(),
"test_module",
);
assert_eq!(
module
.getattr("__doc__")
.unwrap()
.extract::<Cow<'_, str>>()
.unwrap(),
"some doc",
);
assert_eq!(
module
.getattr("SOME_CONSTANT")
.unwrap()
.extract::<u8>()
.unwrap(),
42,
);
let module_def = MODULE_DEF.make_module(py).unwrap();
// FIXME: Use PyState_FindModule to retrieve module here?
unimplemented!("Test currently not implemented");
})
}

Expand All @@ -265,6 +272,22 @@ mod tests {
static NAME: &str = "test_module\0";
static DOC: &str = "some doc\0";

static SLOTS: &[ModuleDefSlot] = &[
ModuleDefSlot(ffi::PyModuleDef_Slot {
slot: ffi::Py_mod_exec,
value: module_state_init as *mut std::ffi::c_void,
}),
#[cfg(Py_3_12)]
ModuleDefSlot(ffi::PyModuleDef_Slot {
slot: ffi::Py_mod_multiple_interpreters,
value: ffi::Py_MOD_PER_INTERPRETER_GIL_SUPPORTED,
}),
ModuleDefSlot(ffi::PyModuleDef_Slot {
slot: 0,
value: std::ptr::null_mut(),
}),
];

static INIT_CALLED: AtomicBool = AtomicBool::new(false);

#[allow(clippy::unnecessary_wraps)]
Expand All @@ -274,12 +297,12 @@ mod tests {
}

unsafe {
let module_def: ModuleDef = ModuleDef::new(NAME, DOC, ModuleInitializer(init));
let module_def: ModuleDef = ModuleDef::new(NAME, DOC);
module_def.set_multiphase_items(SLOTS);
assert_eq!((*module_def.ffi_def.get()).m_name, NAME.as_ptr() as _);
assert_eq!((*module_def.ffi_def.get()).m_doc, DOC.as_ptr() as _);

Python::with_gil(|py| {
module_def.initializer.0(&py.import_bound("builtins").unwrap()).unwrap();
assert!(INIT_CALLED.load(Ordering::SeqCst));
})
}
Expand Down
6 changes: 3 additions & 3 deletions src/impl_/trampoline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ use std::{
use crate::gil::GILPool;
use crate::{
callback::PyCallbackOutput, ffi, ffi_ptr_ext::FfiPtrExt, impl_::panic::PanicTrap,
methods::IPowModulo, panic::PanicException, types::PyModule, Py, PyResult, Python,
methods::IPowModulo, panic::PanicException, PyResult, Python,
};

#[inline]
pub unsafe fn module_init(
f: for<'py> unsafe fn(Python<'py>) -> PyResult<Py<PyModule>>,
f: for<'py> unsafe fn(Python<'py>) -> PyResult<*mut ffi::PyModuleDef>,
) -> *mut ffi::PyObject {
trampoline(|py| f(py).map(|module| module.into_ptr()))
trampoline(|py| f(py).map(|module_def| module_def.cast()))
}

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1679,7 +1679,7 @@ impl<T> Py<T> {
///
/// # Safety
/// `ptr` must point to a Python object of type T.
unsafe fn from_non_null(ptr: NonNull<ffi::PyObject>) -> Self {
pub unsafe fn from_non_null(ptr: NonNull<ffi::PyObject>) -> Self {
Self(ptr, PhantomData)
}
}
Expand Down

0 comments on commit 02eb845

Please sign in to comment.