Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix edge case when combining PEP 695 aliases with typing.Annotated[] #178

Merged
merged 3 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def make(
markers: Tuple[_markers.Marker, ...] = (),
):
# Try to extract argconf overrides from type.
_, argconfs = _resolver.unwrap_annotated(
_, argconfs = _resolver.unwrap_annotated_and_aliases(
type_or_callable, _confstruct._ArgConfiguration
)
argconf = _confstruct._ArgConfiguration(
Expand All @@ -135,7 +135,7 @@ def make(
if argconf.help is not None:
helptext = argconf.help

type_or_callable, inferred_markers = _resolver.unwrap_annotated(
type_or_callable, inferred_markers = _resolver.unwrap_annotated_and_aliases(
type_or_callable, _markers._Marker
)
markers = inferred_markers + markers
Expand Down Expand Up @@ -285,7 +285,7 @@ def field_list_from_callable(

# Try to generate field list.
# We recursively apply markers.
_, parent_markers = _resolver.unwrap_annotated(f, _markers._Marker)
_, parent_markers = _resolver.unwrap_annotated_and_aliases(f, _markers._Marker)
with FieldDefinition.marker_context(parent_markers):
field_list = _try_field_list_from_callable(f, default_instance)

Expand Down Expand Up @@ -386,7 +386,7 @@ def _try_field_list_from_callable(
# Check for default instances in subcommand configs. This is needed for
# is_nested_type() when arguments are not valid without a default, and this
# default is specified in the subcommand config.
f, found_subcommand_configs = _resolver.unwrap_annotated(
f, found_subcommand_configs = _resolver.unwrap_annotated_and_aliases(
f, conf._confstruct._SubcommandConfiguration
)
if len(found_subcommand_configs) > 0:
Expand Down
2 changes: 1 addition & 1 deletion src/tyro/_instantiators.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def instantiator(strings: List[str]) -> None:
if maybe_newtype_name is not None:
metavar = maybe_newtype_name.upper()

typ = _resolver.unwrap_annotated(typ)
typ = _resolver.unwrap_annotated_and_aliases(typ)

# Address container types. If a matching container is found, this will recursively
# call instantiator_from_type().
Expand Down
18 changes: 11 additions & 7 deletions src/tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def from_callable_or_type(
"""Create a parser definition from a callable or type."""

# Consolidate subcommand types.
markers = _resolver.unwrap_annotated(f, _markers._Marker)[1]
markers = _resolver.unwrap_annotated_and_aliases(f, _markers._Marker)[1]
consolidate_subcommand_args = _markers.ConsolidateSubcommandArgs in markers

# Resolve the type of `f`, generate a field list.
Expand Down Expand Up @@ -387,7 +387,7 @@ def from_field(
extern_prefix: str,
) -> Optional[SubparsersSpecification]:
# Union of classes should create subparsers.
typ = _resolver.unwrap_annotated(field.type_or_callable)
typ = _resolver.unwrap_annotated_and_aliases(field.type_or_callable)
if get_origin(typ) is not Union:
return None

Expand All @@ -407,7 +407,7 @@ def from_field(

# If specified, swap types using tyro.conf.subcommand(constructor=...).
for i, option in enumerate(options):
_, found_subcommand_configs = _resolver.unwrap_annotated(
_, found_subcommand_configs = _resolver.unwrap_annotated_and_aliases(
option, _confstruct._SubcommandConfiguration
)
if (
Expand All @@ -417,7 +417,7 @@ def from_field(
options[i] = Annotated.__class_getitem__( # type: ignore
(
found_subcommand_configs[0].constructor_factory(),
*_resolver.unwrap_annotated(option, "all")[1],
*_resolver.unwrap_annotated_and_aliases(option, "all")[1],
)
)

Expand All @@ -443,8 +443,10 @@ def from_field(
else extern_prefix,
type(None) if option is none_proxy else cast(type, option),
)
option_unwrapped, found_subcommand_configs = _resolver.unwrap_annotated(
option, _confstruct._SubcommandConfiguration
option_unwrapped, found_subcommand_configs = (
_resolver.unwrap_annotated_and_aliases(
option, _confstruct._SubcommandConfiguration
)
)
if len(found_subcommand_configs) != 0:
# Explicitly annotated default.
Expand Down Expand Up @@ -508,7 +510,9 @@ def from_field(

# Strip the subcommand config from the option type.
# Relevant: https://github.com/brentyi/tyro/pull/117
option_origin, annotations = _resolver.unwrap_annotated(option, "all")
option_origin, annotations = _resolver.unwrap_annotated_and_aliases(
option, "all"
)
annotations = tuple(
a
for a in annotations
Expand Down
20 changes: 12 additions & 8 deletions src/tyro/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def unwrap_origin_strip_extras(typ: TypeOrCallable) -> TypeOrCallable:
"""Returns the origin, ignoring typing.Annotated, of typ if it exists. Otherwise,
returns typ."""
# TODO: Annotated[] handling should be revisited...
typ = unwrap_annotated(typ)
typ = unwrap_annotated_and_aliases(typ)
origin = get_origin(typ)

if origin is not None:
Expand All @@ -72,7 +72,7 @@ def resolve_generic_types(
# ^We need this `if` statement for an obscure edge case: when `cls` is a
# function with `__tyro_markers__` set, we don't want/need to return
# Annotated[func, markers].
cls, annotations = unwrap_annotated(cls, "all")
cls, annotations = unwrap_annotated_and_aliases(cls, "all")

# We'll ignore NewType when getting the origin + args for generics.
origin_cls = get_origin(unwrap_newtype_and_aliases(cls)[0])
Expand Down Expand Up @@ -218,7 +218,7 @@ def unwrap_newtype_and_narrow_subtypes(
# it doesn't really make sense to parse this case.
return typ

superclass = unwrap_annotated(typ)
superclass = unwrap_annotated_and_aliases(typ)

# For Python 3.10.
if get_origin(superclass) is Union:
Expand All @@ -243,7 +243,7 @@ def swap_type_using_confstruct(typ: TypeOrCallable) -> TypeOrCallable:
`tyro.conf.arg` and `tyro.conf.subcommand`. Runtime annotations are
kept, but the type is swapped."""
# Need to swap types.
_, annotations = unwrap_annotated(typ, search_type="all")
_, annotations = unwrap_annotated_and_aliases(typ, search_type="all")
for anno in reversed(annotations):
if (
isinstance(
Expand Down Expand Up @@ -305,27 +305,27 @@ def narrow_collection_types(


@overload
def unwrap_annotated(
def unwrap_annotated_and_aliases(
typ: TypeOrCallable,
search_type: TypeForm[MetadataType],
) -> Tuple[TypeOrCallable, Tuple[MetadataType, ...]]: ...


@overload
def unwrap_annotated(
def unwrap_annotated_and_aliases(
typ: TypeOrCallable,
search_type: Literal["all"],
) -> Tuple[TypeOrCallable, Tuple[Any, ...]]: ...


@overload
def unwrap_annotated(
def unwrap_annotated_and_aliases(
typ: TypeOrCallable,
search_type: None = None,
) -> TypeOrCallable: ...


def unwrap_annotated(
def unwrap_annotated_and_aliases(
typ: TypeOrCallable,
search_type: Union[TypeForm[MetadataType], Literal["all"], object, None] = None,
) -> Union[Tuple[TypeOrCallable, Tuple[MetadataType, ...]], TypeOrCallable]:
Expand All @@ -337,6 +337,10 @@ def unwrap_annotated(
- Annotated[int, "1"], int => (int, ())
"""

# Unwrap aliases defined using Python 3.12's `type` syntax.
if isinstance(typ, TypeAliasType):
return unwrap_annotated_and_aliases(typ.__value__, search_type)

# `Final` and `ReadOnly` types are ignored in tyro.
while get_origin(typ) in STRIP_WRAPPER_TYPES:
typ = get_args(typ)[0]
Expand Down
2 changes: 1 addition & 1 deletion src/tyro/_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _subparser_name_from_type(cls: Type) -> Tuple[str, bool]:
from .conf import _confstruct # Prevent circular imports

cls, type_from_typevar = _resolver.resolve_generic_types(cls)
cls, found_subcommand_configs = _resolver.unwrap_annotated(
cls, found_subcommand_configs = _resolver.unwrap_annotated_and_aliases(
cls, _confstruct._SubcommandConfiguration
)

Expand Down
4 changes: 2 additions & 2 deletions src/tyro/_subcommand_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ def _get_type_options(typ: _typing.TypeForm) -> Tuple[_typing.TypeForm, ...]:

# Check against supertypes.
for self_type in self_types:
self_type = _resolver.unwrap_annotated(self_type)
self_type = _resolver.unwrap_annotated_and_aliases(self_type)
self_type, _ = _resolver.unwrap_newtype_and_aliases(self_type)
ok = False
for super_type in super_types:
super_type = _resolver.unwrap_annotated(super_type)
super_type = _resolver.unwrap_annotated_and_aliases(super_type)
self_type, _ = _resolver.unwrap_newtype_and_aliases(self_type)
if issubclass(self_type, super_type):
ok = True
Expand Down
2 changes: 1 addition & 1 deletion src/tyro/extras/_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _get_contained_special_types_from_type(
else _parent_contained_dataclasses
)

cls = _resolver.unwrap_annotated(cls)
cls = _resolver.unwrap_annotated_and_aliases(cls)
cls, type_from_typevar = _resolver.resolve_generic_types(cls)

contained_special_types = {cls}
Expand Down
12 changes: 12 additions & 0 deletions tests/test_new_style_annotations_min_py312.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# PEP 695 isn't yet supported in mypy. (April 4, 2024)
from dataclasses import dataclass
from typing import Annotated

import tyro

Expand Down Expand Up @@ -52,3 +53,14 @@ class Container[T]:
assert tyro.cli(Container[Y], args="--a.a 1 --a.b 2".split(" ")) == Container(
Inner(1, 2)
)


type AnnotatedBasic = Annotated[int, tyro.conf.arg(name="basic")]


def test_annotated_alias():
@dataclass(frozen=True)
class Container:
a: AnnotatedBasic

assert tyro.cli(Container, args="--basic 1".split(" ")) == Container(1)
19 changes: 19 additions & 0 deletions tests/test_py311_generated/test_errors_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,25 @@ class Class:
assert error == ""


def test_suppress_console_outputs_fromdict() -> None:
def foo(track: bool) -> None:
print(track)

def bar(track: bool) -> None:
print(track)

target = io.StringIO()
with pytest.raises(SystemExit), contextlib.redirect_stderr(target):
tyro.extras.subcommand_cli_from_dict(
{"foo": foo, "bar": bar},
args="foo --reward.trac".split(" "),
console_outputs=False,
)

error = target.getvalue()
assert error == ""


def test_similar_arguments_subcommands() -> None:
@dataclasses.dataclass
class RewardConfig:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# PEP 695 isn't yet supported in mypy. (April 4, 2024)
from dataclasses import dataclass
from typing import Annotated

import tyro

Expand Down Expand Up @@ -52,3 +53,14 @@ class Container[T]:
assert tyro.cli(Container[Y], args="--a.a 1 --a.b 2".split(" ")) == Container(
Inner(1, 2)
)


type AnnotatedBasic = Annotated[int, tyro.conf.arg(name="basic")]


def test_annotated_alias():
@dataclass(frozen=True)
class Container:
a: AnnotatedBasic

assert tyro.cli(Container, args="--basic 1".split(" ")) == Container(1)
Loading