Skip to content

Commit

Permalink
Merge pull request #792 from kngwyu/separator
Browse files Browse the repository at this point in the history
Fix the interpretation of vararg separator
  • Loading branch information
kngwyu authored Mar 7, 2020
2 parents 74b22eb + b7c4fdb commit 1a8ebc2
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 134 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
* Usage of raw identifiers with `#[pyo3(set)]`. [#745](https://github.com/PyO3/pyo3/pull/745)
* Usage of `PyObject` with `#[pyo3(get)]`. [#760](https://github.com/PyO3/pyo3/pull/760)
* `#[pymethods]` used in conjunction with `#[cfg]`. #[769](https://github.com/PyO3/pyo3/pull/769)
* `"*"` in a `#[pyfunction()]` argument list incorrectly accepting any number of positional arguments (use `args = "*"` when this behaviour is desired). #[792](https://github.com/PyO3/pyo3/pull/792)

### Removed

Expand Down
20 changes: 0 additions & 20 deletions pyo3-derive-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,6 @@ impl<'a> FnSpec<'a> {
false
}

pub fn accept_args(&self) -> bool {
for s in self.attrs.iter() {
match *s {
Argument::VarArgs(_) => return true,
Argument::VarArgsSeparator => return true,
_ => (),
}
}
false
}

pub fn is_kwargs(&self, name: &syn::Ident) -> bool {
for s in self.attrs.iter() {
if let Argument::KeywordArgs(ref path) = s {
Expand All @@ -226,15 +215,6 @@ impl<'a> FnSpec<'a> {
false
}

pub fn accept_kwargs(&self) -> bool {
for s in self.attrs.iter() {
if let Argument::KeywordArgs(_) = s {
return true;
}
}
false
}

pub fn default_value(&self, name: &syn::Ident) -> Option<TokenStream> {
for s in self.attrs.iter() {
match *s {
Expand Down
161 changes: 64 additions & 97 deletions pyo3-derive-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,139 +52,109 @@ impl PyFunctionAttr {
NestedMeta::Lit(ref lit) => {
self.add_literal(item, lit)?;
}
_ => {
return Err(syn::Error::new_spanned(item, "Unknown argument"));
NestedMeta::Meta(syn::Meta::List(ref list)) => {
return Err(syn::Error::new_spanned(
list,
"List is not supported as argument",
));
}
}

Ok(())
}

fn add_literal(&mut self, item: &NestedMeta, lit: &syn::Lit) -> syn::Result<()> {
match lit {
syn::Lit::Str(ref lits) => {
syn::Lit::Str(ref lits) if lits.value() == "*" => {
// "*"
if lits.value() == "*" {
if self.has_kwargs {
return Err(syn::Error::new_spanned(
item,
"syntax error, keyword self.arguments is defined",
));
}
if self.has_varargs {
return Err(syn::Error::new_spanned(
item,
"self.arguments already define * (var args)",
));
}
self.has_varargs = true;
self.arguments.push(Argument::VarArgsSeparator);
} else {
return Err(syn::Error::new_spanned(lits, "Unknown string literal"));
}
}
_ => {
return Err(syn::Error::new_spanned(
item,
format!("Only string literal is supported, got: {:?}", lit),
));
self.vararg_is_ok(item)?;
self.has_varargs = true;
self.arguments.push(Argument::VarArgsSeparator);
Ok(())
}
};
Ok(())
_ => Err(syn::Error::new_spanned(
item,
format!("Only \"*\" is supported here, got: {:?}", lit),
)),
}
}

fn add_work(&mut self, item: &NestedMeta, path: &Path) -> syn::Result<()> {
// self.arguments in form somename
if self.has_kwargs {
if self.has_kw || self.has_kwargs {
return Err(syn::Error::new_spanned(
item,
"syntax error, keyword self.arguments is defined",
"Positional argument or varargs(*) is not allowed after keyword arguments",
));
}
if self.has_kw {
if self.has_varargs {
return Err(syn::Error::new_spanned(
item,
"syntax error, argument is not allowed after keyword argument",
"Positional argument or varargs(*) is not allowed after *",
));
}
self.arguments.push(Argument::Arg(path.clone(), None));
Ok(())
}

fn vararg_is_ok(&self, item: &NestedMeta) -> syn::Result<()> {
if self.has_kwargs || self.has_varargs {
return Err(syn::Error::new_spanned(
item,
"* is not allowed after varargs(*) or kwargs(**)",
));
}
Ok(())
}

fn kw_arg_is_ok(&self, item: &NestedMeta) -> syn::Result<()> {
if self.has_kwargs {
return Err(syn::Error::new_spanned(
item,
"Keyword argument or kwargs(**) is not allowed after kwargs(**)",
));
}
Ok(())
}

fn add_nv_common(
&mut self,
item: &NestedMeta,
name: &syn::Path,
value: String,
) -> syn::Result<()> {
self.kw_arg_is_ok(item)?;
if self.has_varargs {
// kw only
self.arguments.push(Argument::Kwarg(name.clone(), value));
} else {
self.has_kw = true;
self.arguments
.push(Argument::Arg(name.clone(), Some(value)));
}
Ok(())
}

fn add_name_value(&mut self, item: &NestedMeta, nv: &syn::MetaNameValue) -> syn::Result<()> {
match nv.lit {
syn::Lit::Str(ref litstr) => {
if litstr.value() == "*" {
// args="*"
if self.has_kwargs {
return Err(syn::Error::new_spanned(
item,
"* - syntax error, keyword self.arguments is defined",
));
}
if self.has_varargs {
return Err(syn::Error::new_spanned(item, "*(var args) is defined"));
}
self.vararg_is_ok(item)?;
self.has_varargs = true;
self.arguments.push(Argument::VarArgs(nv.path.clone()));
} else if litstr.value() == "**" {
// kwargs="**"
if self.has_kwargs {
return Err(syn::Error::new_spanned(
item,
"self.arguments already define ** (kw args)",
));
}
self.kw_arg_is_ok(item)?;
self.has_kwargs = true;
self.arguments.push(Argument::KeywordArgs(nv.path.clone()));
} else if self.has_varargs {
self.arguments
.push(Argument::Kwarg(nv.path.clone(), litstr.value()))
} else {
if self.has_kwargs {
return Err(syn::Error::new_spanned(
item,
"syntax error, keyword self.arguments is defined",
));
}
self.has_kw = true;
self.arguments
.push(Argument::Arg(nv.path.clone(), Some(litstr.value())))
self.add_nv_common(item, &nv.path, litstr.value())?;
}
}
syn::Lit::Int(ref litint) => {
if self.has_varargs {
self.arguments
.push(Argument::Kwarg(nv.path.clone(), format!("{}", litint)));
} else {
if self.has_kwargs {
return Err(syn::Error::new_spanned(
item,
"syntax error, keyword self.arguments is defined",
));
}
self.has_kw = true;
self.arguments
.push(Argument::Arg(nv.path.clone(), Some(format!("{}", litint))));
}
self.add_nv_common(item, &nv.path, format!("{}", litint))?;
}
syn::Lit::Bool(ref litb) => {
if self.has_varargs {
self.arguments
.push(Argument::Kwarg(nv.path.clone(), format!("{}", litb.value)));
} else {
if self.has_kwargs {
return Err(syn::Error::new_spanned(
item,
"syntax error, keyword self.arguments is defined",
));
}
self.has_kw = true;
self.arguments.push(Argument::Arg(
nv.path.clone(),
Some(format!("{}", litb.value)),
));
}
self.add_nv_common(item, &nv.path, format!("{}", litb.value))?;
}
_ => {
return Err(syn::Error::new_spanned(
Expand Down Expand Up @@ -221,11 +191,8 @@ pub fn parse_name_attribute(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Opti
*span,
"Expected string literal for #[name] argument",
)),
// TODO: The below pattern is unstable, so instead we match the wildcard.
// slice_patterns due to be stable soon: https://github.com/rust-lang/rust/issues/62254
// [(_, span), _, ..] => {
_ => Err(syn::Error::new(
name_attrs[0].1,
[(_, span), ..] => Err(syn::Error::new(
*span,
"#[name] can not be specified multiple times",
)),
}
Expand Down
25 changes: 12 additions & 13 deletions pyo3-derive-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,15 +406,6 @@ fn impl_call(_cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream {
quote! { _slf.#fname(#(#names),*) }
}

/// Converts a bool to "true" or "false"
fn bool_to_ident(condition: bool) -> syn::Ident {
if condition {
syn::Ident::new("true", Span::call_site())
} else {
syn::Ident::new("false", Span::call_site())
}
}

fn impl_arg_params_(spec: &FnSpec<'_>, body: TokenStream, into_result: TokenStream) -> TokenStream {
if spec.args.is_empty() {
return quote! {
Expand All @@ -431,8 +422,8 @@ fn impl_arg_params_(spec: &FnSpec<'_>, body: TokenStream, into_result: TokenStre
continue;
}
let name = arg.name;
let kwonly = bool_to_ident(spec.is_kw_only(&arg.name));
let opt = bool_to_ident(arg.optional.is_some() || spec.default_value(&arg.name).is_some());
let kwonly = spec.is_kw_only(&arg.name);
let opt = arg.optional.is_some() || spec.default_value(&arg.name).is_some();

params.push(quote! {
pyo3::derive_utils::ParamDescription {
Expand All @@ -449,8 +440,16 @@ fn impl_arg_params_(spec: &FnSpec<'_>, body: TokenStream, into_result: TokenStre
param_conversion.push(impl_arg_param(&arg, &spec, idx, &mut option_pos));
}

let accept_args = bool_to_ident(spec.accept_args());
let accept_kwargs = bool_to_ident(spec.accept_kwargs());
let (mut accept_args, mut accept_kwargs) = (false, false);

for s in spec.attrs.iter() {
use crate::pyfunction::Argument;
match s {
Argument::VarArgs(_) => accept_args = true,
Argument::KeywordArgs(_) => accept_kwargs = true,
_ => continue,
}
}
let num_normal_params = params.len();
// create array of arguments, and then parse
quote! {
Expand Down
2 changes: 1 addition & 1 deletion src/derive_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub fn parse_fn_args<'p>(
if i < nargs {
raise_error!("got multiple values for argument: {}", p.name)
}
kwargs.as_ref().unwrap().del_item(p.name).unwrap();
kwargs.as_ref().unwrap().del_item(p.name)?;
Some(kwarg)
}
None => {
Expand Down
2 changes: 1 addition & 1 deletion tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ macro_rules! py_expect_exception {
let res = $py.run($code, None, Some(d));
let err = res.unwrap_err();
if !err.matches($py, $py.get_type::<pyo3::exceptions::$err>()) {
panic!(format!("Expected {} but got {:?}", stringify!($err), err))
panic!("Expected {} but got {:?}", stringify!($err), err)
}
}};
}
1 change: 1 addition & 0 deletions tests/test_compile_error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[test]
fn test_compile_errors() {
let t = trybuild::TestCases::new();
t.compile_fail("tests/ui/invalid_macro_args.rs");
t.compile_fail("tests/ui/invalid_property_args.rs");
t.compile_fail("tests/ui/invalid_pymethod_names.rs");
t.compile_fail("tests/ui/missing_clone.rs");
Expand Down
30 changes: 28 additions & 2 deletions tests/test_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,20 @@ impl MethArgs {
[a.to_object(py), args.into(), kwargs.to_object(py)].to_object(py)
}

#[args(a, b = 2, "*", c = 3)]
fn get_pos_arg_kw_sep1(&self, a: i32, b: i32, c: i32) -> PyResult<i32> {
Ok(a + b + c)
}

#[args(a, "*", b = 2, c = 3)]
fn get_pos_arg_kw_sep2(&self, a: i32, b: i32, c: i32) -> PyResult<i32> {
Ok(a + b + c)
}

#[args(kwargs = "**")]
fn get_pos_kw(&self, py: Python, a: i32, kwargs: Option<&PyDict>) -> PyObject {
[a.to_object(py), kwargs.to_object(py)].to_object(py)
}

// "args" can be anything that can be extracted from PyTuple
#[args(args = "*")]
fn args_as_vec(&self, args: Vec<i32>) -> i32 {
Expand Down Expand Up @@ -264,7 +273,7 @@ fn meth_args() {
py_run!(py, inst, "assert inst.get_default() == 10");
py_run!(py, inst, "assert inst.get_default(100) == 100");
py_run!(py, inst, "assert inst.get_kwarg() == 10");
py_run!(py, inst, "assert inst.get_kwarg(100) == 10");
py_expect_exception!(py, inst, "inst.get_kwarg(100)", TypeError);
py_run!(py, inst, "assert inst.get_kwarg(test=100) == 100");
py_run!(py, inst, "assert inst.get_kwargs() == [(), None]");
py_run!(py, inst, "assert inst.get_kwargs(1,2,3) == [(1,2,3), None]");
Expand Down Expand Up @@ -295,6 +304,23 @@ fn meth_args() {
py_expect_exception!(py, inst, "inst.get_pos_arg_kw(1, a=1)", TypeError);
py_expect_exception!(py, inst, "inst.get_pos_arg_kw(b=2)", TypeError);

py_run!(py, inst, "assert inst.get_pos_arg_kw_sep1(1) == 6");
py_run!(py, inst, "assert inst.get_pos_arg_kw_sep1(1, 2) == 6");
py_run!(
py,
inst,
"assert inst.get_pos_arg_kw_sep1(1, 2, c=13) == 16"
);
py_expect_exception!(py, inst, "inst.get_pos_arg_kw_sep1(1, 2, 3)", TypeError);

py_run!(py, inst, "assert inst.get_pos_arg_kw_sep2(1) == 6");
py_run!(
py,
inst,
"assert inst.get_pos_arg_kw_sep2(1, b=12, c=13) == 26"
);
py_expect_exception!(py, inst, "inst.get_pos_arg_kw_sep2(1, 2)", TypeError);

py_run!(py, inst, "assert inst.get_pos_kw(1, b=2) == [1, {'b': 2}]");
py_expect_exception!(py, inst, "inst.get_pos_kw(1,2)", TypeError);

Expand Down
Loading

0 comments on commit 1a8ebc2

Please sign in to comment.