Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make add_dataclass_options public, separate field extraction into public helper function #59

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 96 additions & 84 deletions argparse_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@

"""
import argparse
from collections import namedtuple
from typing import (
TypeVar,
Optional,
Expand Down Expand Up @@ -308,7 +309,7 @@ def format_usage(self):
def parse_args(options_class: Type[OptionsType], args: ArgsType = None) -> OptionsType:
"""Parse arguments and return as the dataclass type."""
parser = argparse.ArgumentParser()
_add_dataclass_options(options_class, parser)
add_dataclass_options(options_class, parser)
kwargs = _get_kwargs(parser.parse_args(args))
return options_class(**kwargs)

Expand All @@ -320,102 +321,113 @@ def parse_known_args(
and list of remaining arguments.
"""
parser = argparse.ArgumentParser()
_add_dataclass_options(options_class, parser)
add_dataclass_options(options_class, parser)
namespace, others = parser.parse_known_args(args=args)
kwargs = _get_kwargs(namespace)
return options_class(**kwargs), others


def _add_dataclass_options(
options_class: Type[OptionsType], parser: argparse.ArgumentParser
) -> None:
if not is_dataclass(options_class):
raise TypeError("cls must be a dataclass")

for field in fields(options_class):
args = field.metadata.get("args", [f"--{field.name.replace('_', '-')}"])
positional = not args[0].startswith("-")
kwargs = {
"type": field.metadata.get("type", field.type),
"help": field.metadata.get("help", None),
}

if field.metadata.get("args") and not positional:
# We want to ensure that we store the argument based on the
# name of the field and not whatever flag name was provided
kwargs["dest"] = field.name
def extract_argparse_kwargs(field: Field[Any]) -> Tuple[List[str], Dict[str, Any]]:
"""Extract kwargs of ArgumentParser.add_argument from a dataclass field.

Returns pair of (args, kwargs) to be passed to ArgumentParser.add_argument.
"""
args = field.metadata.get("args", [f"--{field.name.replace('_', '-')}"])
positional = not args[0].startswith("-")
kwargs = {
"type": field.metadata.get("type", field.type),
"help": field.metadata.get("help", None),
}

if field.metadata.get("args") and not positional:
# We want to ensure that we store the argument based on the
# name of the field and not whatever flag name was provided
kwargs["dest"] = field.name

if field.metadata.get("choices") is not None:
kwargs["choices"] = field.metadata["choices"]

# Support Literal types as an alternative means of specifying choices.
if get_origin(field.type) is Literal:
# Prohibit a potential collision with the choices field
if field.metadata.get("choices") is not None:
kwargs["choices"] = field.metadata["choices"]
raise ValueError(
f"Cannot infer type of items in field: {field.name}. "
"Literal type arguments should not be combined with choices in the metadata. "
"Remove the redundant choices field from the metadata."
)

# Get the types of the arguments of the Literal
types = [type(arg) for arg in get_args(field.type)]

# Make sure just a single type has been used
if len(set(types)) > 1:
raise ValueError(
f"Cannot infer type of items in field: {field.name}. "
"Literal type arguments should contain choices of a single type. "
f"Instead, {len(set(types))} types where found: "
+ ", ".join([type_.__name__ for type_ in set(types)])
+ "."
)

# Support Literal types as an alternative means of specifying choices.
if get_origin(field.type) is Literal:
# Prohibit a potential collision with the choices field
if field.metadata.get("choices") is not None:
# Overwrite the type kwarg
kwargs["type"] = types[0]
# Use the literal arguments as choices
kwargs["choices"] = get_args(field.type)

if field.metadata.get("metavar") is not None:
kwargs["metavar"] = field.metadata["metavar"]

if field.metadata.get("nargs") is not None:
kwargs["nargs"] = field.metadata["nargs"]
if field.metadata.get("type") is None:
# When nargs is specified, field.type should be a list,
# or something equivalent, like typing.List.
# Using it would most likely result in an error, so if the user
# did not specify the type of the elements within the list, we
# try to infer it:
try:
kwargs["type"] = get_args(field.type)[0] # get_args returns a tuple
except IndexError:
# get_args returned an empty tuple, type cannot be inferred
raise ValueError(
f"Cannot infer type of items in field: {field.name}. "
"Literal type arguments should not be combined with choices in the metadata. "
"Remove the redundant choices field from the metadata."
"Try using a parameterized type hint, or "
"specifying the type explicitly using metadata['type']"
)

# Get the types of the arguments of the Literal
types = [type(arg) for arg in get_args(field.type)]

# Make sure just a single type has been used
if len(set(types)) > 1:
raise ValueError(
f"Cannot infer type of items in field: {field.name}. "
"Literal type arguments should contain choices of a single type. "
f"Instead, {len(set(types))} types where found: "
+ ", ".join([type_.__name__ for type_ in set(types)])
+ "."
if field.default == field.default_factory == MISSING and not positional:
kwargs["required"] = True
else:
kwargs["default"] = MISSING

if field.type is bool:
_handle_bool_type(field, args, kwargs)
elif get_origin(field.type) is Union:
if field.metadata.get("type") is None:
# Optional[X] is equivalent to Union[X, None].
f_args = get_args(field.type)
if len(f_args) == 2 and NoneType in f_args:
arg = next(a for a in f_args if a is not NoneType)
kwargs["type"] = arg
else:
raise TypeError(
"For Union types other than 'Optional', a custom 'type' must be specified using "
"'metadata'."
)

# Overwrite the type kwarg
kwargs["type"] = types[0]
# Use the literal arguments as choices
kwargs["choices"] = get_args(field.type)

if field.metadata.get("metavar") is not None:
kwargs["metavar"] = field.metadata["metavar"]

if field.metadata.get("nargs") is not None:
kwargs["nargs"] = field.metadata["nargs"]
if field.metadata.get("type") is None:
# When nargs is specified, field.type should be a list,
# or something equivalent, like typing.List.
# Using it would most likely result in an error, so if the user
# did not specify the type of the elements within the list, we
# try to infer it:
try:
kwargs["type"] = get_args(field.type)[0] # get_args returns a tuple
except IndexError:
# get_args returned an empty tuple, type cannot be inferred
raise ValueError(
f"Cannot infer type of items in field: {field.name}. "
"Try using a parameterized type hint, or "
"specifying the type explicitly using metadata['type']"
)

if field.default == field.default_factory == MISSING and not positional:
kwargs["required"] = True
else:
kwargs["default"] = MISSING

if field.type is bool:
_handle_bool_type(field, args, kwargs)
elif get_origin(field.type) is Union:
if field.metadata.get("type") is None:
# Optional[X] is equivalent to Union[X, None].
f_args = get_args(field.type)
if len(f_args) == 2 and NoneType in f_args:
arg = next(a for a in f_args if a is not NoneType)
kwargs["type"] = arg
else:
raise TypeError(
"For Union types other than 'Optional', a custom 'type' must be specified using "
"'metadata'."
)
return args, kwargs


def add_dataclass_options(
options_class: Type[OptionsType], parser: argparse.ArgumentParser
) -> None:
"""Adds options given as dataclass fields to the parser."""
if not is_dataclass(options_class):
raise TypeError("cls must be a dataclass")

for field in fields(options_class):
args, kwargs = extract_argparse_kwargs(field)

if "group" in field.metadata:
_handle_argument_group(parser, field, args, kwargs)
Expand Down Expand Up @@ -493,7 +505,7 @@ class ArgumentParser(argparse.ArgumentParser, Generic[OptionsType]):
def __init__(self, options_class: Type[OptionsType], *args, **kwargs):
super().__init__(*args, **kwargs)
self._options_type: Type[OptionsType] = options_class
_add_dataclass_options(options_class, self)
add_dataclass_options(options_class, self)

def parse_args(self, args: ArgsType = None, namespace=None) -> OptionsType:
"""Parse arguments and return as the dataclass type."""
Expand Down