diff --git a/newsfragments/4213.added.md b/newsfragments/4213.added.md new file mode 100644 index 00000000000..6f553dc93ab --- /dev/null +++ b/newsfragments/4213.added.md @@ -0,0 +1 @@ +Properly fills the `module=` attribute of declarative modules child `#[pymodule]` and `#[pyclass]`. \ No newline at end of file diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index 756037263e3..71d776bf350 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -1,10 +1,13 @@ //! Code generation for the function that initializes a python module and adds classes and function. -use crate::utils::Ctx; use crate::{ - attributes::{self, take_attributes, take_pyo3_options, CrateAttribute, NameAttribute}, + attributes::{ + self, take_attributes, take_pyo3_options, CrateAttribute, ModuleAttribute, NameAttribute, + }, get_doc, + pyclass::PyClassPyO3Option, pyfunction::{impl_wrap_pyfunction, PyFunctionOptions}, + utils::Ctx, }; use proc_macro2::TokenStream; use quote::quote; @@ -12,15 +15,17 @@ use syn::{ ext::IdentExt, parse::{Parse, ParseStream}, parse_quote, parse_quote_spanned, + punctuated::Punctuated, spanned::Spanned, token::Comma, - Item, Path, Result, + Item, Meta, Path, Result, }; #[derive(Default)] pub struct PyModuleOptions { krate: Option<CrateAttribute>, name: Option<syn::Ident>, + module: Option<ModuleAttribute>, } impl PyModuleOptions { @@ -31,6 +36,7 @@ impl PyModuleOptions { match option { PyModulePyO3Option::Name(name) => options.set_name(name.value.0)?, PyModulePyO3Option::Crate(path) => options.set_crate(path)?, + PyModulePyO3Option::Module(module) => options.set_module(module)?, } } @@ -56,6 +62,16 @@ impl PyModuleOptions { self.krate = Some(path); Ok(()) } + + fn set_module(&mut self, name: ModuleAttribute) -> Result<()> { + ensure_spanned!( + self.module.is_none(), + name.span() => "`module` may only be specified once" + ); + + self.module = Some(name); + Ok(()) + } } pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> { @@ -77,6 +93,12 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> { let ctx = &Ctx::new(&options.krate); let Ctx { pyo3_path } = ctx; let doc = get_doc(attrs, None); + let name = options.name.unwrap_or_else(|| ident.unraw()); + let full_name = if let Some(module) = &options.module { + format!("{}.{}", module.value.value(), name) + } else { + name.to_string() + }; let mut module_items = Vec::new(); let mut module_items_cfg_attrs = Vec::new(); @@ -156,6 +178,13 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> { if has_attribute(&item_struct.attrs, "pyclass") { module_items.push(item_struct.ident.clone()); module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs)); + if !has_pyo3_module_declared::<PyClassPyO3Option>( + &item_struct.attrs, + "pyclass", + |option| matches!(option, PyClassPyO3Option::Module(_)), + )? { + set_module_attribute(&mut item_struct.attrs, &full_name); + } } } Item::Enum(item_enum) => { @@ -166,6 +195,13 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> { if has_attribute(&item_enum.attrs, "pyclass") { module_items.push(item_enum.ident.clone()); module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs)); + if !has_pyo3_module_declared::<PyClassPyO3Option>( + &item_enum.attrs, + "pyclass", + |option| matches!(option, PyClassPyO3Option::Module(_)), + )? { + set_module_attribute(&mut item_enum.attrs, &full_name); + } } } Item::Mod(item_mod) => { @@ -176,6 +212,13 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> { if has_attribute(&item_mod.attrs, "pymodule") { module_items.push(item_mod.ident.clone()); module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs)); + if !has_pyo3_module_declared::<PyModulePyO3Option>( + &item_mod.attrs, + "pymodule", + |option| matches!(option, PyModulePyO3Option::Module(_)), + )? { + set_module_attribute(&mut item_mod.attrs, &full_name); + } } } Item::ForeignMod(item) => { @@ -242,7 +285,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> { } } - let initialization = module_initialization(options, ident); + let initialization = module_initialization(&name, ctx); Ok(quote!( #vis mod #ident { #(#items)* @@ -286,10 +329,11 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream> let stmts = std::mem::take(&mut function.block.stmts); let Ctx { pyo3_path } = ctx; let ident = &function.sig.ident; + let name = options.name.unwrap_or_else(|| ident.unraw()); let vis = &function.vis; let doc = get_doc(&function.attrs, None); - let initialization = module_initialization(options, ident); + let initialization = module_initialization(&name, ctx); // Module function called with optional Python<'_> marker as first arg, followed by the module. let mut module_args = Vec::new(); @@ -354,9 +398,7 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream> }) } -fn module_initialization(options: PyModuleOptions, ident: &syn::Ident) -> TokenStream { - let name = options.name.unwrap_or_else(|| ident.unraw()); - let ctx = &Ctx::new(&options.krate); +fn module_initialization(name: &syn::Ident, ctx: &Ctx) -> TokenStream { let Ctx { pyo3_path } = ctx; let pyinit_symbol = format!("PyInit_{}", name); @@ -491,9 +533,33 @@ fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool { attrs.iter().any(|attr| attr.path().is_ident(ident)) } +fn set_module_attribute(attrs: &mut Vec<syn::Attribute>, module_name: &str) { + attrs.push(parse_quote!(#[pyo3(module = #module_name)])); +} + +fn has_pyo3_module_declared<T: Parse>( + attrs: &[syn::Attribute], + root_attribute_name: &str, + is_module_option: impl Fn(&T) -> bool + Copy, +) -> Result<bool> { + for attr in attrs { + if (attr.path().is_ident("pyo3") || attr.path().is_ident(root_attribute_name)) + && matches!(attr.meta, Meta::List(_)) + { + for option in &attr.parse_args_with(Punctuated::<T, Comma>::parse_terminated)? { + if is_module_option(option) { + return Ok(true); + } + } + } + } + Ok(false) +} + enum PyModulePyO3Option { Crate(CrateAttribute), Name(NameAttribute), + Module(ModuleAttribute), } impl Parse for PyModulePyO3Option { @@ -503,6 +569,8 @@ impl Parse for PyModulePyO3Option { input.parse().map(PyModulePyO3Option::Name) } else if lookahead.peek(syn::Token![crate]) { input.parse().map(PyModulePyO3Option::Crate) + } else if lookahead.peek(attributes::kw::module) { + input.parse().map(PyModulePyO3Option::Module) } else { Err(lookahead.error()) } diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 47c52c84518..2c24fc9a0a6 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -74,7 +74,7 @@ pub struct PyClassPyO3Options { pub weakref: Option<kw::weakref>, } -enum PyClassPyO3Option { +pub enum PyClassPyO3Option { Crate(CrateAttribute), Dict(kw::dict), Extends(ExtendsAttribute), diff --git a/tests/test_declarative_module.rs b/tests/test_declarative_module.rs index 2e46f4a64d1..d3ca9faa1ec 100644 --- a/tests/test_declarative_module.rs +++ b/tests/test_declarative_module.rs @@ -3,6 +3,7 @@ use pyo3::create_exception; use pyo3::exceptions::PyException; use pyo3::prelude::*; +use pyo3::sync::GILOnceCell; #[cfg(not(Py_LIMITED_API))] use pyo3::types::PyBool; @@ -78,7 +79,7 @@ mod declarative_module { x * 3 } - #[pyclass] + #[pyclass(name = "Struct")] struct Struct; #[pymethods] @@ -89,11 +90,29 @@ mod declarative_module { } } - #[pyclass] + #[pyclass(module = "foo")] + struct StructInCustomModule; + + #[pyclass(name = "Enum")] enum Enum { A, B, } + + #[pyclass(module = "foo")] + enum EnumInCustomModule { + A, + B, + } + } + + #[pymodule] + #[pyo3(module = "custom_root")] + mod inner_custom_root { + use super::*; + + #[pyclass] + struct Struct; } #[pymodule_init] @@ -120,10 +139,17 @@ mod declarative_module2 { use super::double; } +fn declarative_module(py: Python<'_>) -> &Bound<'_, PyModule> { + static MODULE: GILOnceCell<Py<PyModule>> = GILOnceCell::new(); + MODULE + .get_or_init(py, || pyo3::wrap_pymodule!(declarative_module)(py)) + .bind(py) +} + #[test] fn test_declarative_module() { Python::with_gil(|py| { - let m = pyo3::wrap_pymodule!(declarative_module)(py).into_bound(py); + let m = declarative_module(py); py_assert!( py, m, @@ -187,3 +213,27 @@ fn test_raw_ident_module() { py_assert!(py, m, "m.double(2) == 4"); }) } + +#[test] +fn test_module_names() { + Python::with_gil(|py| { + let m = declarative_module(py); + py_assert!( + py, + m, + "m.inner.Struct.__module__ == 'declarative_module.inner'" + ); + py_assert!(py, m, "m.inner.StructInCustomModule.__module__ == 'foo'"); + py_assert!( + py, + m, + "m.inner.Enum.__module__ == 'declarative_module.inner'" + ); + py_assert!(py, m, "m.inner.EnumInCustomModule.__module__ == 'foo'"); + py_assert!( + py, + m, + "m.inner_custom_root.Struct.__module__ == 'custom_root.inner_custom_root'" + ); + }) +}