Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CLI support for optional and variadic positional args #519

Merged
merged 5 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -842,9 +842,10 @@ print(User().model_dump())

### Subcommands and Positional Arguments

Subcommands and positional arguments are expressed using the `CliSubCommand` and `CliPositionalArg` annotations. These
annotations can only be applied to required fields (i.e. fields that do not have a default value). Furthermore,
subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses `dataclass`.
Subcommands and positional arguments are expressed using the `CliSubCommand` and `CliPositionalArg` annotations. The
subcommand annotation can only be applied to required fields (i.e. fields that do not have a default value).
Furthermore, subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses
`dataclass`.

Parsed subcommands can be retrieved from model instances using the `get_subcommand` utility function. If a subcommand is
not required, set the `is_required` flag to `False` to disable raising an error if no subcommand is found.
Expand Down Expand Up @@ -1284,6 +1285,9 @@ However, if your use case [aligns more with #2](#command-line-support), using Py
likely want required fields to be _strictly required at the CLI_. We can enable this behavior by using
`cli_enforce_required`.

!!! note
A required `CliPositionalArg` field is always strictly required (enforced) at the CLI.

```py
import os
import sys
Expand Down
64 changes: 51 additions & 13 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -1333,7 +1333,11 @@ def _load_env_vars(
if subcommand_dest not in selected_subcommands:
parsed_args[subcommand_dest] = self.cli_parse_none_str

parsed_args = {key: val for key, val in parsed_args.items() if not key.endswith(':subcommand')}
parsed_args = {
key: val
for key, val in parsed_args.items()
if not key.endswith(':subcommand') and val is not PydanticUndefined
}
if selected_subcommands:
last_selected_subcommand = max(selected_subcommands, key=len)
if not any(field_name for field_name in parsed_args.keys() if f'{last_selected_subcommand}.' in field_name):
Expand Down Expand Up @@ -1494,6 +1498,7 @@ def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str,
)

def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]]:
positional_variadic_arg = []
positional_args, subcommand_args, optional_args = [], [], []
for field_name, field_info in _get_model_fields(model).items():
if _CliSubCommand in field_info.metadata:
Expand All @@ -1511,17 +1516,31 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
)
subcommand_args.append((field_name, field_info))
elif _CliPositionalArg in field_info.metadata:
if not field_info.is_required():
raise SettingsError(f'positional argument {model.__name__}.{field_name} has a default value')
alias_names, *_ = _get_alias_names(field_name, field_info)
if len(alias_names) > 1:
raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases')
is_append_action = _annotation_contains_types(
field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True
)
if not is_append_action:
positional_args.append((field_name, field_info))
else:
alias_names, *_ = _get_alias_names(field_name, field_info)
if len(alias_names) > 1:
raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases')
positional_args.append((field_name, field_info))
positional_variadic_arg.append((field_name, field_info))
else:
self._verify_cli_flag_annotations(model, field_name, field_info)
optional_args.append((field_name, field_info))
return positional_args + subcommand_args + optional_args

if positional_variadic_arg:
if len(positional_variadic_arg) > 1:
field_names = ', '.join([name for name, info in positional_variadic_arg])
raise SettingsError(f'{model.__name__} has multiple variadic positonal arguments: {field_names}')
elif subcommand_args:
field_names = ', '.join([name for name, info in positional_variadic_arg + subcommand_args])
raise SettingsError(
f'{model.__name__} has variadic positonal arguments and subcommand arguments: {field_names}'
)

return positional_args + positional_variadic_arg + subcommand_args + optional_args

@property
def root_parser(self) -> T:
Expand Down Expand Up @@ -1727,11 +1746,9 @@ def _add_parser_args(
self._cli_dict_args[kwargs['dest']] = field_info.annotation

if _CliPositionalArg in field_info.metadata:
kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper())
arg_names = [kwargs['dest']]
del kwargs['dest']
del kwargs['required']
flag_prefix = ''
arg_names, flag_prefix = self._convert_positional_arg(
kwargs, field_info, preferred_alias, model_default
)

self._convert_bool_flag(kwargs, field_info, model_default)

Expand Down Expand Up @@ -1787,6 +1804,27 @@ def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, mode
BooleanOptionalAction if sys.version_info >= (3, 9) else f'store_{str(not default).lower()}'
)

def _convert_positional_arg(
self, kwargs: dict[str, Any], field_info: FieldInfo, preferred_alias: str, model_default: Any
) -> tuple[list[str], str]:
flag_prefix = ''
arg_names = [kwargs['dest']]
kwargs['default'] = PydanticUndefined
kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper())

# Note: CLI positional args are always strictly required at the CLI. Therefore, use field_info.is_required in
# conjunction with model_default instead of the derived kwargs['required'].
is_required = field_info.is_required() and model_default is PydanticUndefined
if kwargs.get('action') == 'append':
del kwargs['action']
kwargs['nargs'] = '+' if is_required else '*'
elif not is_required:
kwargs['nargs'] = '?'

del kwargs['dest']
del kwargs['required']
return arg_names, flag_prefix

def _get_arg_names(
self,
arg_prefix: str,
Expand Down
60 changes: 56 additions & 4 deletions tests/test_source_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,45 @@ class Cfg(BaseSettings):
assert cfg.model_dump() == {'child': {'name': 'new name a', 'diff_a': 'new diff a'}}


def test_cli_optional_positional_arg(env):
class Main(BaseSettings):
model_config = SettingsConfigDict(
cli_parse_args=True,
cli_enforce_required=True,
)

value: CliPositionalArg[int] = 123

assert CliApp.run(Main, cli_args=[]).model_dump() == {'value': 123}

env.set('VALUE', '456')
assert CliApp.run(Main, cli_args=[]).model_dump() == {'value': 456}

assert CliApp.run(Main, cli_args=['789']).model_dump() == {'value': 789}


def test_cli_variadic_positional_arg(env):
class MainRequired(BaseSettings):
model_config = SettingsConfigDict(cli_parse_args=True)

values: CliPositionalArg[List[int]]

class MainOptional(MainRequired):
values: CliPositionalArg[List[int]] = [1, 2, 3]

assert CliApp.run(MainOptional, cli_args=[]).model_dump() == {'values': [1, 2, 3]}
with pytest.raises(SettingsError, match='error parsing CLI: the following arguments are required: VALUES'):
CliApp.run(MainRequired, cli_args=[], cli_exit_on_error=False)

env.set('VALUES', '[4,5,6]')
assert CliApp.run(MainOptional, cli_args=[]).model_dump() == {'values': [4, 5, 6]}
with pytest.raises(SettingsError, match='error parsing CLI: the following arguments are required: VALUES'):
CliApp.run(MainRequired, cli_args=[], cli_exit_on_error=False)

assert CliApp.run(MainOptional, cli_args=['7', '8', '9']).model_dump() == {'values': [7, 8, 9]}
assert CliApp.run(MainRequired, cli_args=['7', '8', '9']).model_dump() == {'values': [7, 8, 9]}


def test_cli_enums(capsys, monkeypatch):
class Pet(IntEnum):
dog = 0
Expand Down Expand Up @@ -1416,13 +1455,26 @@ class PositionalArgNotOutermost(BaseSettings, cli_parse_args=True):
PositionalArgNotOutermost()

with pytest.raises(
SettingsError, match='positional argument PositionalArgHasDefault.pos_arg has a default value'
SettingsError,
match='MultipleVariadicPositionialArgs has multiple variadic positonal arguments: strings, numbers',
):

class MultipleVariadicPositionialArgs(BaseSettings, cli_parse_args=True):
strings: CliPositionalArg[List[str]]
numbers: CliPositionalArg[List[int]]

MultipleVariadicPositionialArgs()

with pytest.raises(
SettingsError,
match='VariadicPositionialArgAndSubCommand has variadic positonal arguments and subcommand arguments: strings, sub_cmd',
):

class PositionalArgHasDefault(BaseSettings, cli_parse_args=True):
pos_arg: CliPositionalArg[str] = 'bad'
class VariadicPositionialArgAndSubCommand(BaseSettings, cli_parse_args=True):
strings: CliPositionalArg[List[str]]
sub_cmd: CliSubCommand[SubCmd]

PositionalArgHasDefault()
VariadicPositionialArgAndSubCommand()

with pytest.raises(
SettingsError, match=re.escape("cli_parse_args must be List[str] or Tuple[str, ...], recieved <class 'str'>")
Expand Down
Loading