Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine tyro.conf.Suppress behavior in unions #231

Merged
merged 24 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/tyro/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,13 @@ def lowered(self) -> LoweredArgumentDefinition:
_rule_apply_argconf(self, lowered)
return lowered

def is_suppressed(self) -> bool:
"""Returns if the argument is suppressed. Suppressed arguments won't be
added to the parser."""
return _markers.Suppress in self.field.markers or (
_markers.SuppressFixed in self.field.markers and self.lowered.is_fixed()
)


@dataclasses.dataclass
class LoweredArgumentDefinition:
Expand Down Expand Up @@ -427,13 +434,6 @@ def _rule_generate_helptext(
) -> None:
"""Generate helptext from docstring, argument name, default values."""

# If the suppress marker is attached, hide the argument.
if _markers.Suppress in arg.field.markers or (
_markers.SuppressFixed in arg.field.markers and lowered.is_fixed()
):
lowered.help = argparse.SUPPRESS
return

help_parts = []

primary_help = arg.field.helptext
Expand Down Expand Up @@ -525,7 +525,6 @@ def _rule_generate_helptext(
# The percent symbol needs some extra handling in argparse.
# https://stackoverflow.com/questions/21168120/python-argparse-errors-with-in-help-string
lowered.help = " ".join([p for p in help_parts if len(p) > 0]).replace("%", "%%")
return


def _rule_set_name_or_flag_and_dest(
Expand Down
17 changes: 8 additions & 9 deletions src/tyro/_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def callable_with_args(
consumed_keywords: Set[str] = set()

def get_value_from_arg(
prefixed_field_name: str, field_def: _fields.FieldDefinition
prefixed_field_name: str, arg: _arguments.ArgumentDefinition
) -> tuple[Any, bool]:
"""Helper for getting values from `value_from_arg` + doing some extra
asserts.
Expand All @@ -64,11 +64,11 @@ def get_value_from_arg(
# When would the value not be found? Only if we have
# `tyro.conf.ConslidateSubcommandArgs` for one of the contained
# subparsers.
assert (
assert arg.is_suppressed() or (
parser_definition.subparsers is not None
and parser_definition.consolidate_subcommand_args
), "Field value is unexpectedly missing. This is likely a bug in tyro."
return field_def.default, False
return arg.field.default, False
else:
return value_from_prefixed_field_name[prefixed_field_name], True

Expand Down Expand Up @@ -96,7 +96,7 @@ def get_value_from_arg(
name_maybe_prefixed = prefixed_field_name
consumed_keywords.add(name_maybe_prefixed)
if not arg.lowered.is_fixed():
value, value_found = get_value_from_arg(name_maybe_prefixed, field)
value, value_found = get_value_from_arg(name_maybe_prefixed, arg)

if value in _fields.MISSING_AND_MISSING_NONPROP:
value = arg.field.default
Expand Down Expand Up @@ -133,7 +133,9 @@ def get_value_from_arg(
else:
assert arg.field.default not in _fields.MISSING_AND_MISSING_NONPROP
value = arg.field.default
parsed_value = value_from_prefixed_field_name.get(prefixed_field_name)
parsed_value = value_from_prefixed_field_name.get(
prefixed_field_name, _singleton.MISSING_NONPROP
)
if parsed_value not in _fields.MISSING_AND_MISSING_NONPROP:
raise InstantiationError(
f"{'/'.join(arg.lowered.name_or_flags)} was passed in, but"
Expand Down Expand Up @@ -162,10 +164,7 @@ def get_value_from_arg(
subparser_dest = _strings.make_subparser_dest(name=prefixed_field_name)
consumed_keywords.add(subparser_dest)
if subparser_dest in value_from_prefixed_field_name:
subparser_name, subparser_name_found = get_value_from_arg(
subparser_dest, field
)
assert subparser_name_found
subparser_name = value_from_prefixed_field_name[subparser_dest]
else:
assert (
subparser_def.default_instance
Expand Down
43 changes: 27 additions & 16 deletions src/tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class ParserSpecification:
"""Each parser contains a list of arguments and optionally some subparsers."""

f: Callable
markers: Set[_markers._Marker]
description: str
args: List[_arguments.ArgumentDefinition]
field_list: List[_fields.FieldDefinition]
Expand Down Expand Up @@ -172,6 +173,7 @@ def from_callable_or_type(

return ParserSpecification(
f=f,
markers=markers,
description=_strings.remove_single_line_breaks(
description
if description is not None
Expand Down Expand Up @@ -268,11 +270,12 @@ def group_name_from_arg(arg: _arguments.ArgumentDefinition) -> str:
# Add each argument group. Groups with only suppressed arguments won't
# be added.
for arg in self.args:
# Don't add suppressed arguments to the parser.
if arg.is_suppressed():
continue

group_name = group_name_from_arg(arg)
if (
arg.lowered.help is not argparse.SUPPRESS
and group_name not in group_from_group_name
):
if group_name not in group_from_group_name:
description = (
parent.helptext_from_intern_prefixed_field_name.get(
arg.intern_prefix
Expand All @@ -290,13 +293,8 @@ def group_name_from_arg(arg: _arguments.ArgumentDefinition) -> str:
arg.add_argument(positional_group)
continue

if group_name in group_from_group_name:
arg.add_argument(group_from_group_name[group_name])
else:
# Suppressed argument: still need to add them, but they won't show up in
# the helptext so it doesn't matter which group.
assert arg.lowered.help is argparse.SUPPRESS
arg.add_argument(group_from_group_name[""])
assert group_name in group_from_group_name
arg.add_argument(group_from_group_name[group_name])

for child in self.child_from_prefix.values():
child.apply_args(parser, parent=self)
Expand Down Expand Up @@ -475,16 +473,16 @@ def from_field(
subcommand_config_from_name: Dict[str, _confstruct._SubcommandConfig] = {}
subcommand_type_from_name: Dict[str, type] = {}
for option in options:
option_unwrapped, found_subcommand_configs = _resolver.unwrap_annotated(
option, _confstruct._SubcommandConfig
)
subcommand_name = _strings.subparser_name_from_type(
(
""
if _markers.OmitSubcommandPrefixes in field.markers
else extern_prefix
),
type(None) if option is none_proxy else cast(type, option),
)
option_unwrapped, found_subcommand_configs = _resolver.unwrap_annotated(
option, _confstruct._SubcommandConfig
type(None) if option_unwrapped is none_proxy else cast(type, option),
)
if len(found_subcommand_configs) != 0:
# Explicitly annotated default.
Expand Down Expand Up @@ -567,6 +565,9 @@ def from_field(
for a in annotations
if not isinstance(a, _confstruct._SubcommandConfig)
)
if _markers.Suppress in annotations:
continue

if len(annotations) == 0:
option = option_origin
else:
Expand Down Expand Up @@ -595,6 +596,10 @@ def from_field(
)
parser_from_name[subcommand_name] = subparser

# Default parser was suppressed!
if default_name not in parser_from_name:
default_name = None

# Required if a default is passed in, but the default value has missing
# parameters.
default_parser = None
Expand All @@ -603,13 +608,19 @@ def from_field(
else:
required = False
default_parser = parser_from_name[default_name]

# If there are any required arguments.
if any(map(lambda arg: arg.lowered.required, default_parser.args)):
required = True
if (
default_parser = None

# If there are any required subparsers.
elif (
default_parser.subparsers is not None
and default_parser.subparsers.required
):
required = True
default_parser = None

return SubparsersSpecification(
name=field.intern_name,
Expand Down
18 changes: 13 additions & 5 deletions src/tyro/constructors/_primitive_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,12 +619,20 @@ def union_rule(
nargs: int | Literal["*"] = 1
first = True
for t in options:
option_spec = ConstructorRegistry.get_primitive_spec(
PrimitiveTypeInfo.make(
raw_annotation=t,
parent_markers=type_info.markers,
)
option_type_info = PrimitiveTypeInfo.make(
raw_annotation=t,
parent_markers=type_info.markers,
)

# If the argument is not suppressed, we can add the ability to
# suppress individual options.
if (
_markers.Suppress not in type_info.markers
and _markers.Suppress in option_type_info.markers
):
continue

option_spec = ConstructorRegistry.get_primitive_spec(option_type_info)
if isinstance(option_spec, UnsupportedTypeAnnotationError):
return option_spec
if option_spec.choices is None:
Expand Down
29 changes: 9 additions & 20 deletions src/tyro/extras/_subcommand_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,23 +138,12 @@ def cli(
for orig_name in orig_subcommand_names:
subcommands[swap_delimeters(orig_name)] = subcommands.pop(orig_name)

if len(subcommands) == 1:
return tyro.cli(
next(iter(subcommands.values())),
prog=prog,
description=description,
args=args,
use_underscores=use_underscores,
console_outputs=console_outputs,
config=config,
)
else:
return tyro.extras.subcommand_cli_from_dict(
subcommands,
prog=prog,
description=description,
args=args,
use_underscores=use_underscores,
console_outputs=console_outputs,
config=config,
)
return tyro.extras.subcommand_cli_from_dict(
subcommands,
prog=prog,
description=description,
args=args,
use_underscores=use_underscores,
console_outputs=console_outputs,
config=config,
)
7 changes: 5 additions & 2 deletions src/tyro/extras/_subcommand_cli_from_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing_extensions import Annotated

from tyro.conf._markers import Marker
from tyro.conf._markers import Marker, Suppress

from .._cli import cli
from ..conf import subcommand
Expand Down Expand Up @@ -101,7 +101,6 @@ def subcommand_cli_from_dict(
config: Sequence of config marker objects, from :mod:`tyro.conf`.
"""
# We need to form a union type, which requires at least two elements.
assert len(subcommands) >= 2, "At least two subcommands are required."
return cli(
Union.__getitem__( # type: ignore
tuple(
Expand All @@ -115,6 +114,10 @@ def subcommand_cli_from_dict(
]
for k, v in subcommands.items()
]
# Union types need at least two types. To support the case
# where we only pass one subcommand in, we'll pad with `None`
# but suppress it.
+ [Annotated[None, Suppress]]
)
),
prog=prog,
Expand Down
53 changes: 45 additions & 8 deletions tests/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,14 +260,17 @@ class Parent:
Parent,
args="nested1.nested2.subcommand:command-a".split(" "),
) == Parent(Nested1(Nested2(A(7))))
assert tyro.cli(
Parent,
args=(
"nested1.nested2.subcommand:command-a --nested1.nested2.subcommand.a 3".split(
" "
)
),
) == Parent(Nested1(Nested2(A(3))))

# The `a` argument is suppresed.
with pytest.raises(SystemExit):
tyro.cli(
Parent,
args=(
"nested1.nested2.subcommand:command-a --nested1.nested2.subcommand.a 3".split(
" "
)
),
)

assert tyro.cli(
Parent,
Expand Down Expand Up @@ -1757,3 +1760,37 @@ class Config:
== Config()
)
assert tyro.cli(Config, args=["optimizer:adam"]) == Config()


def test_suppress_in_union() -> None:
@dataclasses.dataclass
class ConfigA:
x: int

@dataclasses.dataclass
class ConfigB:
y: Union[int, Annotated[str, tyro.conf.Suppress]]
z: Annotated[Union[str, int], tyro.conf.Suppress] = 3

def main(
x: Union[Annotated[ConfigA, tyro.conf.Suppress], ConfigB] = ConfigA(3),
) -> Any:
return x

assert tyro.cli(main, args="x:config-b --x.y 5".split(" ")) == ConfigB(5)

with pytest.raises(SystemExit):
# ConfigA is suppressed, so there'll be no default.
tyro.cli(main, args=[])
with pytest.raises(SystemExit):
# ConfigB needs an int, since str is suppressed.
tyro.cli(main, args="x:config-b --x.y five".split(" "))
with pytest.raises(SystemExit):
# The z argument is suppressed.
tyro.cli(main, args="x:config-b --x.y 5 --x.z 3".split(" "))
with pytest.raises(SystemExit):
# ConfigA is suppressed.
assert tyro.cli(main, args=["x:config-a"])
with pytest.raises(SystemExit):
# ConfigB has a required argument.
assert tyro.cli(main, args=["x:config-b"])
10 changes: 8 additions & 2 deletions tests/test_decorator_subcommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,14 @@ def test_app_just_one_cli(capsys):
app_just_one.cli(args=["--help"])
captured = capsys.readouterr()
assert "usage: " in captured.out
assert "greet-person" not in captured.out
assert "addition" not in captured.out
assert "greet-person" in captured.out
assert "--name" not in captured.out

# Test: `python my_script.py greet-person --help`
with pytest.raises(SystemExit):
app_just_one.cli(args=["greet-person", "--help"], sort_subcommands=False)
captured = capsys.readouterr()
assert "usage: " in captured.out
assert "--name" in captured.out


Expand Down
16 changes: 6 additions & 10 deletions tests/test_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,16 +1181,12 @@ def commit(message: str, all: bool = False) -> Tuple[str, bool]:
"""Make a commit."""
return (message, all)

# If we only get one, we unfortunately can't form subcommands. This is
# because unions in Python require at least 2 types.
with pytest.raises(AssertionError):
tyro.extras.subcommand_cli_from_dict(
{
"commit": commit,
},
args="--message hello --all".split(" "),
)

assert tyro.extras.subcommand_cli_from_dict(
{
"commit": commit,
},
args="commit --message hello --all".split(" "),
) == ("hello", True)
assert (
tyro.extras.subcommand_cli_from_dict(
{
Expand Down
Loading
Loading