From 25069baef48c7b770be9fc24d572a84fa912fd6b Mon Sep 17 00:00:00 2001 From: kngwyu Date: Wed, 4 Mar 2020 23:44:40 +0900 Subject: [PATCH 1/4] Fix the interpretation of '*' --- CHANGELOG.md | 1 + pyo3-derive-backend/src/method.rs | 20 -------------------- pyo3-derive-backend/src/pymethod.rs | 25 ++++++++++++------------- src/derive_utils.rs | 2 +- tests/test_methods.rs | 18 ++++++++++++++++-- 5 files changed, 30 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 66146d40e1d..a6605d347f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. * Usage of raw identifiers with `#[pyo3(set)]`. [#745](https://github.com/PyO3/pyo3/pull/745) * Usage of `PyObject` with `#[pyo3(get)]`. [#760](https://github.com/PyO3/pyo3/pull/760) * `#[pymethods]` used in conjunction with `#[cfg]`. #[769](https://github.com/PyO3/pyo3/pull/769) +* Interpretation of `*`. #[792](https://github.com/PyO3/pyo3/pull/792) ### Removed diff --git a/pyo3-derive-backend/src/method.rs b/pyo3-derive-backend/src/method.rs index 6d06c377945..e138924b868 100644 --- a/pyo3-derive-backend/src/method.rs +++ b/pyo3-derive-backend/src/method.rs @@ -206,17 +206,6 @@ impl<'a> FnSpec<'a> { false } - pub fn accept_args(&self) -> bool { - for s in self.attrs.iter() { - match *s { - Argument::VarArgs(_) => return true, - Argument::VarArgsSeparator => return true, - _ => (), - } - } - false - } - pub fn is_kwargs(&self, name: &syn::Ident) -> bool { for s in self.attrs.iter() { if let Argument::KeywordArgs(ref path) = s { @@ -226,15 +215,6 @@ impl<'a> FnSpec<'a> { false } - pub fn accept_kwargs(&self) -> bool { - for s in self.attrs.iter() { - if let Argument::KeywordArgs(_) = s { - return true; - } - } - false - } - pub fn default_value(&self, name: &syn::Ident) -> Option { for s in self.attrs.iter() { match *s { diff --git a/pyo3-derive-backend/src/pymethod.rs b/pyo3-derive-backend/src/pymethod.rs index 7149d9c7a27..d590d2b2752 100644 --- a/pyo3-derive-backend/src/pymethod.rs +++ b/pyo3-derive-backend/src/pymethod.rs @@ -406,15 +406,6 @@ fn impl_call(_cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { quote! { _slf.#fname(#(#names),*) } } -/// Converts a bool to "true" or "false" -fn bool_to_ident(condition: bool) -> syn::Ident { - if condition { - syn::Ident::new("true", Span::call_site()) - } else { - syn::Ident::new("false", Span::call_site()) - } -} - fn impl_arg_params_(spec: &FnSpec<'_>, body: TokenStream, into_result: TokenStream) -> TokenStream { if spec.args.is_empty() { return quote! { @@ -431,8 +422,8 @@ fn impl_arg_params_(spec: &FnSpec<'_>, body: TokenStream, into_result: TokenStre continue; } let name = arg.name; - let kwonly = bool_to_ident(spec.is_kw_only(&arg.name)); - let opt = bool_to_ident(arg.optional.is_some() || spec.default_value(&arg.name).is_some()); + let kwonly = spec.is_kw_only(&arg.name); + let opt = arg.optional.is_some() || spec.default_value(&arg.name).is_some(); params.push(quote! { pyo3::derive_utils::ParamDescription { @@ -449,8 +440,16 @@ fn impl_arg_params_(spec: &FnSpec<'_>, body: TokenStream, into_result: TokenStre param_conversion.push(impl_arg_param(&arg, &spec, idx, &mut option_pos)); } - let accept_args = bool_to_ident(spec.accept_args()); - let accept_kwargs = bool_to_ident(spec.accept_kwargs()); + let (mut accept_args, mut accept_kwargs) = (false, false); + + for s in spec.attrs.iter() { + use crate::pyfunction::Argument; + match s { + Argument::VarArgs(_) => accept_args = true, + Argument::KeywordArgs(_) => accept_kwargs = true, + _ => continue, + } + } let num_normal_params = params.len(); // create array of arguments, and then parse quote! { diff --git a/src/derive_utils.rs b/src/derive_utils.rs index baa134f6e8a..019b49c5df9 100644 --- a/src/derive_utils.rs +++ b/src/derive_utils.rs @@ -61,7 +61,7 @@ pub fn parse_fn_args<'p>( if i < nargs { raise_error!("got multiple values for argument: {}", p.name) } - kwargs.as_ref().unwrap().del_item(p.name).unwrap(); + kwargs.as_ref().unwrap().del_item(p.name)?; Some(kwarg) } None => { diff --git a/tests/test_methods.rs b/tests/test_methods.rs index e1b2546b66b..35906d8183f 100644 --- a/tests/test_methods.rs +++ b/tests/test_methods.rs @@ -231,11 +231,15 @@ impl MethArgs { [a.to_object(py), args.into(), kwargs.to_object(py)].to_object(py) } + #[args("*", c = 10)] + fn get_pos_arg_kw_sep(&self, a: i32, b: i32, c: i32) -> PyResult { + Ok(a + b + c) + } + #[args(kwargs = "**")] fn get_pos_kw(&self, py: Python, a: i32, kwargs: Option<&PyDict>) -> PyObject { [a.to_object(py), kwargs.to_object(py)].to_object(py) } - // "args" can be anything that can be extracted from PyTuple #[args(args = "*")] fn args_as_vec(&self, args: Vec) -> i32 { @@ -264,7 +268,7 @@ fn meth_args() { py_run!(py, inst, "assert inst.get_default() == 10"); py_run!(py, inst, "assert inst.get_default(100) == 100"); py_run!(py, inst, "assert inst.get_kwarg() == 10"); - py_run!(py, inst, "assert inst.get_kwarg(100) == 10"); + py_expect_exception!(py, inst, "inst.get_kwarg(100)", TypeError); py_run!(py, inst, "assert inst.get_kwarg(test=100) == 100"); py_run!(py, inst, "assert inst.get_kwargs() == [(), None]"); py_run!(py, inst, "assert inst.get_kwargs(1,2,3) == [(1,2,3), None]"); @@ -295,6 +299,16 @@ fn meth_args() { py_expect_exception!(py, inst, "inst.get_pos_arg_kw(1, a=1)", TypeError); py_expect_exception!(py, inst, "inst.get_pos_arg_kw(b=2)", TypeError); + py_run!(py, inst, "assert inst.get_pos_arg_kw_sep(1, 2, c=3) == 6"); + py_run!(py, inst, "assert inst.get_pos_arg_kw_sep(1, 2) == 13"); + py_expect_exception!(py, inst, "assert inst.get_pos_arg_kw_sep(1)", TypeError); + py_expect_exception!( + py, + inst, + "assert inst.get_pos_arg_kw_sep(1, 2, 3)", + TypeError + ); + py_run!(py, inst, "assert inst.get_pos_kw(1, b=2) == [1, {'b': 2}]"); py_expect_exception!(py, inst, "inst.get_pos_kw(1,2)", TypeError); From cea707dd1c749f399dd47b4bbbcf771397502b7d Mon Sep 17 00:00:00 2001 From: kngwyu Date: Fri, 6 Mar 2020 14:01:27 +0900 Subject: [PATCH 2/4] Inhibit positional args after * --- CHANGELOG.md | 2 +- pyo3-derive-backend/src/pyfunction.rs | 122 ++++++++++---------------- tests/common.rs | 2 +- tests/test_compile_error.rs | 1 + tests/test_methods.rs | 28 ++++-- tests/ui/invalid_macro_args.rs | 18 ++++ tests/ui/invalid_macro_args.stderr | 17 ++++ 7 files changed, 106 insertions(+), 84 deletions(-) create mode 100644 tests/ui/invalid_macro_args.rs create mode 100644 tests/ui/invalid_macro_args.stderr diff --git a/CHANGELOG.md b/CHANGELOG.md index a6605d347f3..473b972680e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,7 +37,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. * Usage of raw identifiers with `#[pyo3(set)]`. [#745](https://github.com/PyO3/pyo3/pull/745) * Usage of `PyObject` with `#[pyo3(get)]`. [#760](https://github.com/PyO3/pyo3/pull/760) * `#[pymethods]` used in conjunction with `#[cfg]`. #[769](https://github.com/PyO3/pyo3/pull/769) -* Interpretation of `*`. #[792](https://github.com/PyO3/pyo3/pull/792) +* Interpretation of `*` and some unreasonable behaviors of `#[args]`. #[792](https://github.com/PyO3/pyo3/pull/792) ### Removed diff --git a/pyo3-derive-backend/src/pyfunction.rs b/pyo3-derive-backend/src/pyfunction.rs index c7bcacd3214..aa82cceaa6c 100644 --- a/pyo3-derive-backend/src/pyfunction.rs +++ b/pyo3-derive-backend/src/pyfunction.rs @@ -65,18 +65,7 @@ impl PyFunctionAttr { syn::Lit::Str(ref lits) => { // "*" if lits.value() == "*" { - if self.has_kwargs { - return Err(syn::Error::new_spanned( - item, - "syntax error, keyword self.arguments is defined", - )); - } - if self.has_varargs { - return Err(syn::Error::new_spanned( - item, - "self.arguments already define * (var args)", - )); - } + self.vararg_is_ok(item)?; self.has_varargs = true; self.arguments.push(Argument::VarArgsSeparator); } else { @@ -94,97 +83,82 @@ impl PyFunctionAttr { } fn add_work(&mut self, item: &NestedMeta, path: &Path) -> syn::Result<()> { - // self.arguments in form somename - if self.has_kwargs { + if self.has_kw || self.has_kwargs { return Err(syn::Error::new_spanned( item, - "syntax error, keyword self.arguments is defined", + "Positional argument or varargs(*) is not allowed after keyword arguments", )); } - if self.has_kw { + if self.has_varargs { return Err(syn::Error::new_spanned( item, - "syntax error, argument is not allowed after keyword argument", + "Positional argument or varargs(*) is not allowed after *", )); } self.arguments.push(Argument::Arg(path.clone(), None)); Ok(()) } + fn vararg_is_ok(&self, item: &NestedMeta) -> syn::Result<()> { + if self.has_kwargs || self.has_varargs { + return Err(syn::Error::new_spanned( + item, + "* is not allowed after varargs(*) or kwargs(**)", + )); + } + Ok(()) + } + + fn kw_arg_is_ok(&self, item: &NestedMeta) -> syn::Result<()> { + if self.has_kwargs { + return Err(syn::Error::new_spanned( + item, + "Keyword argument or kwargs(**) is not allowed after kwargs(**)", + )); + } + Ok(()) + } + + fn add_nv_common( + &mut self, + item: &NestedMeta, + name: &syn::Path, + value: String, + ) -> syn::Result<()> { + self.kw_arg_is_ok(item)?; + if self.has_varargs { + // kw only + self.arguments.push(Argument::Kwarg(name.clone(), value)); + } else { + self.has_kw = true; + self.arguments + .push(Argument::Arg(name.clone(), Some(value))); + } + Ok(()) + } + fn add_name_value(&mut self, item: &NestedMeta, nv: &syn::MetaNameValue) -> syn::Result<()> { match nv.lit { syn::Lit::Str(ref litstr) => { if litstr.value() == "*" { // args="*" - if self.has_kwargs { - return Err(syn::Error::new_spanned( - item, - "* - syntax error, keyword self.arguments is defined", - )); - } - if self.has_varargs { - return Err(syn::Error::new_spanned(item, "*(var args) is defined")); - } + self.vararg_is_ok(item)?; self.has_varargs = true; self.arguments.push(Argument::VarArgs(nv.path.clone())); } else if litstr.value() == "**" { // kwargs="**" - if self.has_kwargs { - return Err(syn::Error::new_spanned( - item, - "self.arguments already define ** (kw args)", - )); - } + self.kw_arg_is_ok(item)?; self.has_kwargs = true; self.arguments.push(Argument::KeywordArgs(nv.path.clone())); - } else if self.has_varargs { - self.arguments - .push(Argument::Kwarg(nv.path.clone(), litstr.value())) } else { - if self.has_kwargs { - return Err(syn::Error::new_spanned( - item, - "syntax error, keyword self.arguments is defined", - )); - } - self.has_kw = true; - self.arguments - .push(Argument::Arg(nv.path.clone(), Some(litstr.value()))) + self.add_nv_common(item, &nv.path, litstr.value())?; } } syn::Lit::Int(ref litint) => { - if self.has_varargs { - self.arguments - .push(Argument::Kwarg(nv.path.clone(), format!("{}", litint))); - } else { - if self.has_kwargs { - return Err(syn::Error::new_spanned( - item, - "syntax error, keyword self.arguments is defined", - )); - } - self.has_kw = true; - self.arguments - .push(Argument::Arg(nv.path.clone(), Some(format!("{}", litint)))); - } + self.add_nv_common(item, &nv.path, format!("{}", litint))?; } syn::Lit::Bool(ref litb) => { - if self.has_varargs { - self.arguments - .push(Argument::Kwarg(nv.path.clone(), format!("{}", litb.value))); - } else { - if self.has_kwargs { - return Err(syn::Error::new_spanned( - item, - "syntax error, keyword self.arguments is defined", - )); - } - self.has_kw = true; - self.arguments.push(Argument::Arg( - nv.path.clone(), - Some(format!("{}", litb.value)), - )); - } + self.add_nv_common(item, &nv.path, format!("{}", litb.value))?; } _ => { return Err(syn::Error::new_spanned( diff --git a/tests/common.rs b/tests/common.rs index 546220a27b2..4cd3c084f9f 100644 --- a/tests/common.rs +++ b/tests/common.rs @@ -18,7 +18,7 @@ macro_rules! py_expect_exception { let res = $py.run($code, None, Some(d)); let err = res.unwrap_err(); if !err.matches($py, $py.get_type::()) { - panic!(format!("Expected {} but got {:?}", stringify!($err), err)) + panic!("Expected {} but got {:?}", stringify!($err), err) } }}; } diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index b316e011aaf..c0cf995ac23 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -1,6 +1,7 @@ #[test] fn test_compile_errors() { let t = trybuild::TestCases::new(); + t.compile_fail("tests/ui/invalid_macro_args.rs"); t.compile_fail("tests/ui/invalid_property_args.rs"); t.compile_fail("tests/ui/invalid_pymethod_names.rs"); t.compile_fail("tests/ui/missing_clone.rs"); diff --git a/tests/test_methods.rs b/tests/test_methods.rs index 35906d8183f..5aaf9f6e6c4 100644 --- a/tests/test_methods.rs +++ b/tests/test_methods.rs @@ -231,8 +231,13 @@ impl MethArgs { [a.to_object(py), args.into(), kwargs.to_object(py)].to_object(py) } - #[args("*", c = 10)] - fn get_pos_arg_kw_sep(&self, a: i32, b: i32, c: i32) -> PyResult { + #[args(a, b = 2, "*", c = 3)] + fn get_pos_arg_kw_sep1(&self, a: i32, b: i32, c: i32) -> PyResult { + Ok(a + b + c) + } + + #[args(a, "*", b = 2, c = 3)] + fn get_pos_arg_kw_sep2(&self, a: i32, b: i32, c: i32) -> PyResult { Ok(a + b + c) } @@ -299,15 +304,22 @@ fn meth_args() { py_expect_exception!(py, inst, "inst.get_pos_arg_kw(1, a=1)", TypeError); py_expect_exception!(py, inst, "inst.get_pos_arg_kw(b=2)", TypeError); - py_run!(py, inst, "assert inst.get_pos_arg_kw_sep(1, 2, c=3) == 6"); - py_run!(py, inst, "assert inst.get_pos_arg_kw_sep(1, 2) == 13"); - py_expect_exception!(py, inst, "assert inst.get_pos_arg_kw_sep(1)", TypeError); - py_expect_exception!( + py_run!(py, inst, "assert inst.get_pos_arg_kw_sep1(1) == 6"); + py_run!(py, inst, "assert inst.get_pos_arg_kw_sep1(1, 2) == 6"); + py_run!( + py, + inst, + "assert inst.get_pos_arg_kw_sep1(1, 2, c=13) == 16" + ); + py_expect_exception!(py, inst, "inst.get_pos_arg_kw_sep1(1, 2, 3)", TypeError); + + py_run!(py, inst, "assert inst.get_pos_arg_kw_sep2(1) == 6"); + py_run!( py, inst, - "assert inst.get_pos_arg_kw_sep(1, 2, 3)", - TypeError + "assert inst.get_pos_arg_kw_sep2(1, b=12, c=13) == 26" ); + py_expect_exception!(py, inst, "inst.get_pos_arg_kw_sep2(1, 2)", TypeError); py_run!(py, inst, "assert inst.get_pos_kw(1, b=2) == [1, {'b': 2}]"); py_expect_exception!(py, inst, "inst.get_pos_kw(1,2)", TypeError); diff --git a/tests/ui/invalid_macro_args.rs b/tests/ui/invalid_macro_args.rs new file mode 100644 index 00000000000..f99f5814f61 --- /dev/null +++ b/tests/ui/invalid_macro_args.rs @@ -0,0 +1,18 @@ +use pyo3::prelude::*; + +#[pyfunction(a = 5, b)] +fn pos_after_kw(py: Python, a: i32, b: i32) -> PyObject { + [a.to_object(py), vararg.into()].to_object(py) +} + +#[pyfunction(a, "*", b)] +fn pos_after_separator(py: Python, a: i32, b: i32) -> PyObject { + [a.to_object(py), vararg.into()].to_object(py) +} + +#[pyfunction(kwargs = "**", a = 5)] +fn kw_after_kwargs(py: Python, kwargs: &PyDict, a: i32) -> PyObject { + [a.to_object(py), vararg.into()].to_object(py) +} + +fn main() {} diff --git a/tests/ui/invalid_macro_args.stderr b/tests/ui/invalid_macro_args.stderr new file mode 100644 index 00000000000..1e9dab29b7d --- /dev/null +++ b/tests/ui/invalid_macro_args.stderr @@ -0,0 +1,17 @@ +error: Positional argument or varargs(*) is not allowed after keyword arguments + --> $DIR/invalid_macro_args.rs:3:21 + | +3 | #[pyfunction(a = 5, b)] + | ^ + +error: Positional argument or varargs(*) is not allowed after * + --> $DIR/invalid_macro_args.rs:8:22 + | +8 | #[pyfunction(a, "*", b)] + | ^ + +error: Keyword argument or kwargs(**) is not allowed after kwargs(**) + --> $DIR/invalid_macro_args.rs:13:29 + | +13 | #[pyfunction(kwargs = "**", a = 5)] + | ^^^^^ From 26fe29f0cbb3af3c706636cfbcea70cc9c80088f Mon Sep 17 00:00:00 2001 From: kngwyu Date: Fri, 6 Mar 2020 19:01:05 +0900 Subject: [PATCH 3/4] Some refactorings for pyfunction.rs --- pyo3-derive-backend/src/pyfunction.rs | 41 +++++++++++---------------- 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/pyo3-derive-backend/src/pyfunction.rs b/pyo3-derive-backend/src/pyfunction.rs index aa82cceaa6c..689cd45f516 100644 --- a/pyo3-derive-backend/src/pyfunction.rs +++ b/pyo3-derive-backend/src/pyfunction.rs @@ -52,34 +52,30 @@ impl PyFunctionAttr { NestedMeta::Lit(ref lit) => { self.add_literal(item, lit)?; } - _ => { - return Err(syn::Error::new_spanned(item, "Unknown argument")); + NestedMeta::Meta(syn::Meta::List(ref list)) => { + return Err(syn::Error::new_spanned( + list, + "List is not supported as argument", + )); } } - Ok(()) } fn add_literal(&mut self, item: &NestedMeta, lit: &syn::Lit) -> syn::Result<()> { match lit { - syn::Lit::Str(ref lits) => { + syn::Lit::Str(ref lits) if lits.value() == "*" => { // "*" - if lits.value() == "*" { - self.vararg_is_ok(item)?; - self.has_varargs = true; - self.arguments.push(Argument::VarArgsSeparator); - } else { - return Err(syn::Error::new_spanned(lits, "Unknown string literal")); - } + self.vararg_is_ok(item)?; + self.has_varargs = true; + self.arguments.push(Argument::VarArgsSeparator); + Ok(()) } - _ => { - return Err(syn::Error::new_spanned( - item, - format!("Only string literal is supported, got: {:?}", lit), - )); - } - }; - Ok(()) + _ => Err(syn::Error::new_spanned( + item, + format!("Only \"*\" is supported here, got: {:?}", lit), + )), + } } fn add_work(&mut self, item: &NestedMeta, path: &Path) -> syn::Result<()> { @@ -195,11 +191,8 @@ pub fn parse_name_attribute(attrs: &mut Vec) -> syn::Result { - _ => Err(syn::Error::new( - name_attrs[0].1, + [(_, span), ..] => Err(syn::Error::new( + *span, "#[name] can not be specified multiple times", )), } From b7c4fdb9bc3773ebc61c858b2b5fe1b2a2b945f5 Mon Sep 17 00:00:00 2001 From: Yuji Kanagawa Date: Sun, 8 Mar 2020 00:57:11 +0900 Subject: [PATCH 4/4] Update CHANGELOG.md Co-Authored-By: David Hewitt <1939362+davidhewitt@users.noreply.github.com> --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 473b972680e..ecf442b51e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,7 +37,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. * Usage of raw identifiers with `#[pyo3(set)]`. [#745](https://github.com/PyO3/pyo3/pull/745) * Usage of `PyObject` with `#[pyo3(get)]`. [#760](https://github.com/PyO3/pyo3/pull/760) * `#[pymethods]` used in conjunction with `#[cfg]`. #[769](https://github.com/PyO3/pyo3/pull/769) -* Interpretation of `*` and some unreasonable behaviors of `#[args]`. #[792](https://github.com/PyO3/pyo3/pull/792) +* `"*"` in a `#[pyfunction()]` argument list incorrectly accepting any number of positional arguments (use `args = "*"` when this behaviour is desired). #[792](https://github.com/PyO3/pyo3/pull/792) ### Removed