Skip to content

Commit

Permalink
CLI support for optional and variadic positional args (#519)
Browse files Browse the repository at this point in the history
  • Loading branch information
kschwab authored Jan 13, 2025
1 parent c989335 commit e2cbb2d
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 20 deletions.
10 changes: 7 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -842,9 +842,10 @@ print(User().model_dump())

### Subcommands and Positional Arguments

Subcommands and positional arguments are expressed using the `CliSubCommand` and `CliPositionalArg` annotations. These
annotations can only be applied to required fields (i.e. fields that do not have a default value). Furthermore,
subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses `dataclass`.
Subcommands and positional arguments are expressed using the `CliSubCommand` and `CliPositionalArg` annotations. The
subcommand annotation can only be applied to required fields (i.e. fields that do not have a default value).
Furthermore, subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses
`dataclass`.

Parsed subcommands can be retrieved from model instances using the `get_subcommand` utility function. If a subcommand is
not required, set the `is_required` flag to `False` to disable raising an error if no subcommand is found.
Expand Down Expand Up @@ -1284,6 +1285,9 @@ However, if your use case [aligns more with #2](#command-line-support), using Py
likely want required fields to be _strictly required at the CLI_. We can enable this behavior by using
`cli_enforce_required`.

!!! note
A required `CliPositionalArg` field is always strictly required (enforced) at the CLI.

```py
import os
import sys
Expand Down
64 changes: 51 additions & 13 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -1333,7 +1333,11 @@ def _load_env_vars(
if subcommand_dest not in selected_subcommands:
parsed_args[subcommand_dest] = self.cli_parse_none_str

parsed_args = {key: val for key, val in parsed_args.items() if not key.endswith(':subcommand')}
parsed_args = {
key: val
for key, val in parsed_args.items()
if not key.endswith(':subcommand') and val is not PydanticUndefined
}
if selected_subcommands:
last_selected_subcommand = max(selected_subcommands, key=len)
if not any(field_name for field_name in parsed_args.keys() if f'{last_selected_subcommand}.' in field_name):
Expand Down Expand Up @@ -1494,6 +1498,7 @@ def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str,
)

def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]]:
positional_variadic_arg = []
positional_args, subcommand_args, optional_args = [], [], []
for field_name, field_info in _get_model_fields(model).items():
if _CliSubCommand in field_info.metadata:
Expand All @@ -1511,17 +1516,31 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
)
subcommand_args.append((field_name, field_info))
elif _CliPositionalArg in field_info.metadata:
if not field_info.is_required():
raise SettingsError(f'positional argument {model.__name__}.{field_name} has a default value')
alias_names, *_ = _get_alias_names(field_name, field_info)
if len(alias_names) > 1:
raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases')
is_append_action = _annotation_contains_types(
field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True
)
if not is_append_action:
positional_args.append((field_name, field_info))
else:
alias_names, *_ = _get_alias_names(field_name, field_info)
if len(alias_names) > 1:
raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases')
positional_args.append((field_name, field_info))
positional_variadic_arg.append((field_name, field_info))
else:
self._verify_cli_flag_annotations(model, field_name, field_info)
optional_args.append((field_name, field_info))
return positional_args + subcommand_args + optional_args

if positional_variadic_arg:
if len(positional_variadic_arg) > 1:
field_names = ', '.join([name for name, info in positional_variadic_arg])
raise SettingsError(f'{model.__name__} has multiple variadic positonal arguments: {field_names}')
elif subcommand_args:
field_names = ', '.join([name for name, info in positional_variadic_arg + subcommand_args])
raise SettingsError(
f'{model.__name__} has variadic positonal arguments and subcommand arguments: {field_names}'
)

return positional_args + positional_variadic_arg + subcommand_args + optional_args

@property
def root_parser(self) -> T:
Expand Down Expand Up @@ -1727,11 +1746,9 @@ def _add_parser_args(
self._cli_dict_args[kwargs['dest']] = field_info.annotation

if _CliPositionalArg in field_info.metadata:
kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper())
arg_names = [kwargs['dest']]
del kwargs['dest']
del kwargs['required']
flag_prefix = ''
arg_names, flag_prefix = self._convert_positional_arg(
kwargs, field_info, preferred_alias, model_default
)

self._convert_bool_flag(kwargs, field_info, model_default)

Expand Down Expand Up @@ -1787,6 +1804,27 @@ def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, mode
BooleanOptionalAction if sys.version_info >= (3, 9) else f'store_{str(not default).lower()}'
)

def _convert_positional_arg(
self, kwargs: dict[str, Any], field_info: FieldInfo, preferred_alias: str, model_default: Any
) -> tuple[list[str], str]:
flag_prefix = ''
arg_names = [kwargs['dest']]
kwargs['default'] = PydanticUndefined
kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper())

# Note: CLI positional args are always strictly required at the CLI. Therefore, use field_info.is_required in
# conjunction with model_default instead of the derived kwargs['required'].
is_required = field_info.is_required() and model_default is PydanticUndefined
if kwargs.get('action') == 'append':
del kwargs['action']
kwargs['nargs'] = '+' if is_required else '*'
elif not is_required:
kwargs['nargs'] = '?'

del kwargs['dest']
del kwargs['required']
return arg_names, flag_prefix

def _get_arg_names(
self,
arg_prefix: str,
Expand Down
60 changes: 56 additions & 4 deletions tests/test_source_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,45 @@ class Cfg(BaseSettings):
assert cfg.model_dump() == {'child': {'name': 'new name a', 'diff_a': 'new diff a'}}


def test_cli_optional_positional_arg(env):
class Main(BaseSettings):
model_config = SettingsConfigDict(
cli_parse_args=True,
cli_enforce_required=True,
)

value: CliPositionalArg[int] = 123

assert CliApp.run(Main, cli_args=[]).model_dump() == {'value': 123}

env.set('VALUE', '456')
assert CliApp.run(Main, cli_args=[]).model_dump() == {'value': 456}

assert CliApp.run(Main, cli_args=['789']).model_dump() == {'value': 789}


def test_cli_variadic_positional_arg(env):
class MainRequired(BaseSettings):
model_config = SettingsConfigDict(cli_parse_args=True)

values: CliPositionalArg[List[int]]

class MainOptional(MainRequired):
values: CliPositionalArg[List[int]] = [1, 2, 3]

assert CliApp.run(MainOptional, cli_args=[]).model_dump() == {'values': [1, 2, 3]}
with pytest.raises(SettingsError, match='error parsing CLI: the following arguments are required: VALUES'):
CliApp.run(MainRequired, cli_args=[], cli_exit_on_error=False)

env.set('VALUES', '[4,5,6]')
assert CliApp.run(MainOptional, cli_args=[]).model_dump() == {'values': [4, 5, 6]}
with pytest.raises(SettingsError, match='error parsing CLI: the following arguments are required: VALUES'):
CliApp.run(MainRequired, cli_args=[], cli_exit_on_error=False)

assert CliApp.run(MainOptional, cli_args=['7', '8', '9']).model_dump() == {'values': [7, 8, 9]}
assert CliApp.run(MainRequired, cli_args=['7', '8', '9']).model_dump() == {'values': [7, 8, 9]}


def test_cli_enums(capsys, monkeypatch):
class Pet(IntEnum):
dog = 0
Expand Down Expand Up @@ -1416,13 +1455,26 @@ class PositionalArgNotOutermost(BaseSettings, cli_parse_args=True):
PositionalArgNotOutermost()

with pytest.raises(
SettingsError, match='positional argument PositionalArgHasDefault.pos_arg has a default value'
SettingsError,
match='MultipleVariadicPositionialArgs has multiple variadic positonal arguments: strings, numbers',
):

class MultipleVariadicPositionialArgs(BaseSettings, cli_parse_args=True):
strings: CliPositionalArg[List[str]]
numbers: CliPositionalArg[List[int]]

MultipleVariadicPositionialArgs()

with pytest.raises(
SettingsError,
match='VariadicPositionialArgAndSubCommand has variadic positonal arguments and subcommand arguments: strings, sub_cmd',
):

class PositionalArgHasDefault(BaseSettings, cli_parse_args=True):
pos_arg: CliPositionalArg[str] = 'bad'
class VariadicPositionialArgAndSubCommand(BaseSettings, cli_parse_args=True):
strings: CliPositionalArg[List[str]]
sub_cmd: CliSubCommand[SubCmd]

PositionalArgHasDefault()
VariadicPositionialArgAndSubCommand()

with pytest.raises(
SettingsError, match=re.escape("cli_parse_args must be List[str] or Tuple[str, ...], recieved <class 'str'>")
Expand Down

0 comments on commit e2cbb2d

Please sign in to comment.