diff --git a/src/tyro/_arguments.py b/src/tyro/_arguments.py index 5c96b31b..6e1235ee 100644 --- a/src/tyro/_arguments.py +++ b/src/tyro/_arguments.py @@ -197,6 +197,13 @@ def lowered(self) -> LoweredArgumentDefinition: _rule_apply_argconf(self, lowered) return lowered + def is_suppressed(self) -> bool: + """Returns if the argument is suppressed. Suppressed arguments won't be + added to the parser.""" + return _markers.Suppress in self.field.markers or ( + _markers.SuppressFixed in self.field.markers and self.lowered.is_fixed() + ) + @dataclasses.dataclass class LoweredArgumentDefinition: @@ -427,13 +434,6 @@ def _rule_generate_helptext( ) -> None: """Generate helptext from docstring, argument name, default values.""" - # If the suppress marker is attached, hide the argument. - if _markers.Suppress in arg.field.markers or ( - _markers.SuppressFixed in arg.field.markers and lowered.is_fixed() - ): - lowered.help = argparse.SUPPRESS - return - help_parts = [] primary_help = arg.field.helptext @@ -525,7 +525,6 @@ def _rule_generate_helptext( # The percent symbol needs some extra handling in argparse. # https://stackoverflow.com/questions/21168120/python-argparse-errors-with-in-help-string lowered.help = " ".join([p for p in help_parts if len(p) > 0]).replace("%", "%%") - return def _rule_set_name_or_flag_and_dest( diff --git a/src/tyro/_calling.py b/src/tyro/_calling.py index 4e74374c..a4d9bea4 100644 --- a/src/tyro/_calling.py +++ b/src/tyro/_calling.py @@ -48,7 +48,7 @@ def callable_with_args( consumed_keywords: Set[str] = set() def get_value_from_arg( - prefixed_field_name: str, field_def: _fields.FieldDefinition + prefixed_field_name: str, arg: _arguments.ArgumentDefinition ) -> tuple[Any, bool]: """Helper for getting values from `value_from_arg` + doing some extra asserts. @@ -64,11 +64,11 @@ def get_value_from_arg( # When would the value not be found? Only if we have # `tyro.conf.ConslidateSubcommandArgs` for one of the contained # subparsers. - assert ( + assert arg.is_suppressed() or ( parser_definition.subparsers is not None and parser_definition.consolidate_subcommand_args ), "Field value is unexpectedly missing. This is likely a bug in tyro." - return field_def.default, False + return arg.field.default, False else: return value_from_prefixed_field_name[prefixed_field_name], True @@ -96,7 +96,7 @@ def get_value_from_arg( name_maybe_prefixed = prefixed_field_name consumed_keywords.add(name_maybe_prefixed) if not arg.lowered.is_fixed(): - value, value_found = get_value_from_arg(name_maybe_prefixed, field) + value, value_found = get_value_from_arg(name_maybe_prefixed, arg) if value in _fields.MISSING_AND_MISSING_NONPROP: value = arg.field.default @@ -133,7 +133,9 @@ def get_value_from_arg( else: assert arg.field.default not in _fields.MISSING_AND_MISSING_NONPROP value = arg.field.default - parsed_value = value_from_prefixed_field_name.get(prefixed_field_name) + parsed_value = value_from_prefixed_field_name.get( + prefixed_field_name, _singleton.MISSING_NONPROP + ) if parsed_value not in _fields.MISSING_AND_MISSING_NONPROP: raise InstantiationError( f"{'/'.join(arg.lowered.name_or_flags)} was passed in, but" @@ -162,10 +164,7 @@ def get_value_from_arg( subparser_dest = _strings.make_subparser_dest(name=prefixed_field_name) consumed_keywords.add(subparser_dest) if subparser_dest in value_from_prefixed_field_name: - subparser_name, subparser_name_found = get_value_from_arg( - subparser_dest, field - ) - assert subparser_name_found + subparser_name = value_from_prefixed_field_name[subparser_dest] else: assert ( subparser_def.default_instance diff --git a/src/tyro/_parsers.py b/src/tyro/_parsers.py index 06ec18a5..e55e82bd 100644 --- a/src/tyro/_parsers.py +++ b/src/tyro/_parsers.py @@ -38,6 +38,7 @@ class ParserSpecification: """Each parser contains a list of arguments and optionally some subparsers.""" f: Callable + markers: Set[_markers._Marker] description: str args: List[_arguments.ArgumentDefinition] field_list: List[_fields.FieldDefinition] @@ -172,6 +173,7 @@ def from_callable_or_type( return ParserSpecification( f=f, + markers=markers, description=_strings.remove_single_line_breaks( description if description is not None @@ -268,11 +270,12 @@ def group_name_from_arg(arg: _arguments.ArgumentDefinition) -> str: # Add each argument group. Groups with only suppressed arguments won't # be added. for arg in self.args: + # Don't add suppressed arguments to the parser. + if arg.is_suppressed(): + continue + group_name = group_name_from_arg(arg) - if ( - arg.lowered.help is not argparse.SUPPRESS - and group_name not in group_from_group_name - ): + if group_name not in group_from_group_name: description = ( parent.helptext_from_intern_prefixed_field_name.get( arg.intern_prefix @@ -290,13 +293,8 @@ def group_name_from_arg(arg: _arguments.ArgumentDefinition) -> str: arg.add_argument(positional_group) continue - if group_name in group_from_group_name: - arg.add_argument(group_from_group_name[group_name]) - else: - # Suppressed argument: still need to add them, but they won't show up in - # the helptext so it doesn't matter which group. - assert arg.lowered.help is argparse.SUPPRESS - arg.add_argument(group_from_group_name[""]) + assert group_name in group_from_group_name + arg.add_argument(group_from_group_name[group_name]) for child in self.child_from_prefix.values(): child.apply_args(parser, parent=self) @@ -475,16 +473,16 @@ def from_field( subcommand_config_from_name: Dict[str, _confstruct._SubcommandConfig] = {} subcommand_type_from_name: Dict[str, type] = {} for option in options: + option_unwrapped, found_subcommand_configs = _resolver.unwrap_annotated( + option, _confstruct._SubcommandConfig + ) subcommand_name = _strings.subparser_name_from_type( ( "" if _markers.OmitSubcommandPrefixes in field.markers else extern_prefix ), - type(None) if option is none_proxy else cast(type, option), - ) - option_unwrapped, found_subcommand_configs = _resolver.unwrap_annotated( - option, _confstruct._SubcommandConfig + type(None) if option_unwrapped is none_proxy else cast(type, option), ) if len(found_subcommand_configs) != 0: # Explicitly annotated default. @@ -567,6 +565,9 @@ def from_field( for a in annotations if not isinstance(a, _confstruct._SubcommandConfig) ) + if _markers.Suppress in annotations: + continue + if len(annotations) == 0: option = option_origin else: @@ -595,6 +596,10 @@ def from_field( ) parser_from_name[subcommand_name] = subparser + # Default parser was suppressed! + if default_name not in parser_from_name: + default_name = None + # Required if a default is passed in, but the default value has missing # parameters. default_parser = None @@ -603,13 +608,19 @@ def from_field( else: required = False default_parser = parser_from_name[default_name] + + # If there are any required arguments. if any(map(lambda arg: arg.lowered.required, default_parser.args)): required = True - if ( + default_parser = None + + # If there are any required subparsers. + elif ( default_parser.subparsers is not None and default_parser.subparsers.required ): required = True + default_parser = None return SubparsersSpecification( name=field.intern_name, diff --git a/src/tyro/constructors/_primitive_spec.py b/src/tyro/constructors/_primitive_spec.py index a22eec2b..dd7c7300 100644 --- a/src/tyro/constructors/_primitive_spec.py +++ b/src/tyro/constructors/_primitive_spec.py @@ -619,12 +619,20 @@ def union_rule( nargs: int | Literal["*"] = 1 first = True for t in options: - option_spec = ConstructorRegistry.get_primitive_spec( - PrimitiveTypeInfo.make( - raw_annotation=t, - parent_markers=type_info.markers, - ) + option_type_info = PrimitiveTypeInfo.make( + raw_annotation=t, + parent_markers=type_info.markers, ) + + # If the argument is not suppressed, we can add the ability to + # suppress individual options. + if ( + _markers.Suppress not in type_info.markers + and _markers.Suppress in option_type_info.markers + ): + continue + + option_spec = ConstructorRegistry.get_primitive_spec(option_type_info) if isinstance(option_spec, UnsupportedTypeAnnotationError): return option_spec if option_spec.choices is None: diff --git a/src/tyro/extras/_subcommand_app.py b/src/tyro/extras/_subcommand_app.py index 560dc27b..ce9b8e41 100644 --- a/src/tyro/extras/_subcommand_app.py +++ b/src/tyro/extras/_subcommand_app.py @@ -138,23 +138,12 @@ def cli( for orig_name in orig_subcommand_names: subcommands[swap_delimeters(orig_name)] = subcommands.pop(orig_name) - if len(subcommands) == 1: - return tyro.cli( - next(iter(subcommands.values())), - prog=prog, - description=description, - args=args, - use_underscores=use_underscores, - console_outputs=console_outputs, - config=config, - ) - else: - return tyro.extras.subcommand_cli_from_dict( - subcommands, - prog=prog, - description=description, - args=args, - use_underscores=use_underscores, - console_outputs=console_outputs, - config=config, - ) + return tyro.extras.subcommand_cli_from_dict( + subcommands, + prog=prog, + description=description, + args=args, + use_underscores=use_underscores, + console_outputs=console_outputs, + config=config, + ) diff --git a/src/tyro/extras/_subcommand_cli_from_dict.py b/src/tyro/extras/_subcommand_cli_from_dict.py index b7933afc..f674103c 100644 --- a/src/tyro/extras/_subcommand_cli_from_dict.py +++ b/src/tyro/extras/_subcommand_cli_from_dict.py @@ -2,7 +2,7 @@ from typing_extensions import Annotated -from tyro.conf._markers import Marker +from tyro.conf._markers import Marker, Suppress from .._cli import cli from ..conf import subcommand @@ -101,7 +101,6 @@ def subcommand_cli_from_dict( config: Sequence of config marker objects, from :mod:`tyro.conf`. """ # We need to form a union type, which requires at least two elements. - assert len(subcommands) >= 2, "At least two subcommands are required." return cli( Union.__getitem__( # type: ignore tuple( @@ -115,6 +114,10 @@ def subcommand_cli_from_dict( ] for k, v in subcommands.items() ] + # Union types need at least two types. To support the case + # where we only pass one subcommand in, we'll pad with `None` + # but suppress it. + + [Annotated[None, Suppress]] ) ), prog=prog, diff --git a/tests/test_conf.py b/tests/test_conf.py index 9e34915d..8c350a51 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -260,14 +260,17 @@ class Parent: Parent, args="nested1.nested2.subcommand:command-a".split(" "), ) == Parent(Nested1(Nested2(A(7)))) - assert tyro.cli( - Parent, - args=( - "nested1.nested2.subcommand:command-a --nested1.nested2.subcommand.a 3".split( - " " - ) - ), - ) == Parent(Nested1(Nested2(A(3)))) + + # The `a` argument is suppresed. + with pytest.raises(SystemExit): + tyro.cli( + Parent, + args=( + "nested1.nested2.subcommand:command-a --nested1.nested2.subcommand.a 3".split( + " " + ) + ), + ) assert tyro.cli( Parent, @@ -1757,3 +1760,37 @@ class Config: == Config() ) assert tyro.cli(Config, args=["optimizer:adam"]) == Config() + + +def test_suppress_in_union() -> None: + @dataclasses.dataclass + class ConfigA: + x: int + + @dataclasses.dataclass + class ConfigB: + y: Union[int, Annotated[str, tyro.conf.Suppress]] + z: Annotated[Union[str, int], tyro.conf.Suppress] = 3 + + def main( + x: Union[Annotated[ConfigA, tyro.conf.Suppress], ConfigB] = ConfigA(3), + ) -> Any: + return x + + assert tyro.cli(main, args="x:config-b --x.y 5".split(" ")) == ConfigB(5) + + with pytest.raises(SystemExit): + # ConfigA is suppressed, so there'll be no default. + tyro.cli(main, args=[]) + with pytest.raises(SystemExit): + # ConfigB needs an int, since str is suppressed. + tyro.cli(main, args="x:config-b --x.y five".split(" ")) + with pytest.raises(SystemExit): + # The z argument is suppressed. + tyro.cli(main, args="x:config-b --x.y 5 --x.z 3".split(" ")) + with pytest.raises(SystemExit): + # ConfigA is suppressed. + assert tyro.cli(main, args=["x:config-a"]) + with pytest.raises(SystemExit): + # ConfigB has a required argument. + assert tyro.cli(main, args=["x:config-b"]) diff --git a/tests/test_decorator_subcommands.py b/tests/test_decorator_subcommands.py index d9e8f9de..4002dbef 100644 --- a/tests/test_decorator_subcommands.py +++ b/tests/test_decorator_subcommands.py @@ -28,8 +28,14 @@ def test_app_just_one_cli(capsys): app_just_one.cli(args=["--help"]) captured = capsys.readouterr() assert "usage: " in captured.out - assert "greet-person" not in captured.out - assert "addition" not in captured.out + assert "greet-person" in captured.out + assert "--name" not in captured.out + + # Test: `python my_script.py greet-person --help` + with pytest.raises(SystemExit): + app_just_one.cli(args=["greet-person", "--help"], sort_subcommands=False) + captured = capsys.readouterr() + assert "usage: " in captured.out assert "--name" in captured.out diff --git a/tests/test_nested.py b/tests/test_nested.py index e036c109..3b077f12 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -1181,16 +1181,12 @@ def commit(message: str, all: bool = False) -> Tuple[str, bool]: """Make a commit.""" return (message, all) - # If we only get one, we unfortunately can't form subcommands. This is - # because unions in Python require at least 2 types. - with pytest.raises(AssertionError): - tyro.extras.subcommand_cli_from_dict( - { - "commit": commit, - }, - args="--message hello --all".split(" "), - ) - + assert tyro.extras.subcommand_cli_from_dict( + { + "commit": commit, + }, + args="commit --message hello --all".split(" "), + ) == ("hello", True) assert ( tyro.extras.subcommand_cli_from_dict( { diff --git a/tests/test_py311_generated/test_conf_generated.py b/tests/test_py311_generated/test_conf_generated.py index 28b8c2de..08cfa1b9 100644 --- a/tests/test_py311_generated/test_conf_generated.py +++ b/tests/test_py311_generated/test_conf_generated.py @@ -270,14 +270,17 @@ class Parent: Parent, args="nested1.nested2.subcommand:command-a".split(" "), ) == Parent(Nested1(Nested2(A(7)))) - assert tyro.cli( - Parent, - args=( - "nested1.nested2.subcommand:command-a --nested1.nested2.subcommand.a 3".split( - " " - ) - ), - ) == Parent(Nested1(Nested2(A(3)))) + + # The `a` argument is suppresed. + with pytest.raises(SystemExit): + tyro.cli( + Parent, + args=( + "nested1.nested2.subcommand:command-a --nested1.nested2.subcommand.a 3".split( + " " + ) + ), + ) assert tyro.cli( Parent, @@ -1763,3 +1766,37 @@ class Config: == Config() ) assert tyro.cli(Config, args=["optimizer:adam"]) == Config() + + +def test_suppress_in_union() -> None: + @dataclasses.dataclass + class ConfigA: + x: int + + @dataclasses.dataclass + class ConfigB: + y: int | Annotated[str, tyro.conf.Suppress] + z: Annotated[str | int, tyro.conf.Suppress] = 3 + + def main( + x: Annotated[ConfigA, tyro.conf.Suppress] | ConfigB = ConfigA(3), + ) -> Any: + return x + + assert tyro.cli(main, args="x:config-b --x.y 5".split(" ")) == ConfigB(5) + + with pytest.raises(SystemExit): + # ConfigA is suppressed, so there'll be no default. + tyro.cli(main, args=[]) + with pytest.raises(SystemExit): + # ConfigB needs an int, since str is suppressed. + tyro.cli(main, args="x:config-b --x.y five".split(" ")) + with pytest.raises(SystemExit): + # The z argument is suppressed. + tyro.cli(main, args="x:config-b --x.y 5 --x.z 3".split(" ")) + with pytest.raises(SystemExit): + # ConfigA is suppressed. + assert tyro.cli(main, args=["x:config-a"]) + with pytest.raises(SystemExit): + # ConfigB has a required argument. + assert tyro.cli(main, args=["x:config-b"]) diff --git a/tests/test_py311_generated/test_decorator_subcommands_generated.py b/tests/test_py311_generated/test_decorator_subcommands_generated.py index d9e8f9de..4002dbef 100644 --- a/tests/test_py311_generated/test_decorator_subcommands_generated.py +++ b/tests/test_py311_generated/test_decorator_subcommands_generated.py @@ -28,8 +28,14 @@ def test_app_just_one_cli(capsys): app_just_one.cli(args=["--help"]) captured = capsys.readouterr() assert "usage: " in captured.out - assert "greet-person" not in captured.out - assert "addition" not in captured.out + assert "greet-person" in captured.out + assert "--name" not in captured.out + + # Test: `python my_script.py greet-person --help` + with pytest.raises(SystemExit): + app_just_one.cli(args=["greet-person", "--help"], sort_subcommands=False) + captured = capsys.readouterr() + assert "usage: " in captured.out assert "--name" in captured.out diff --git a/tests/test_py311_generated/test_nested_generated.py b/tests/test_py311_generated/test_nested_generated.py index b8bbf94a..068538f9 100644 --- a/tests/test_py311_generated/test_nested_generated.py +++ b/tests/test_py311_generated/test_nested_generated.py @@ -1191,16 +1191,12 @@ def commit(message: str, all: bool = False) -> Tuple[str, bool]: """Make a commit.""" return (message, all) - # If we only get one, we unfortunately can't form subcommands. This is - # because unions in Python require at least 2 types. - with pytest.raises(AssertionError): - tyro.extras.subcommand_cli_from_dict( - { - "commit": commit, - }, - args="--message hello --all".split(" "), - ) - + assert tyro.extras.subcommand_cli_from_dict( + { + "commit": commit, + }, + args="commit --message hello --all".split(" "), + ) == ("hello", True) assert ( tyro.extras.subcommand_cli_from_dict( {