Skip to content

Commit

Permalink
Add CLI support for variadic positional args.
Browse files Browse the repository at this point in the history
  • Loading branch information
kschwab committed Jan 11, 2025
1 parent eca638b commit cb48bb5
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 7 deletions.
3 changes: 3 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,9 @@ However, if your use case [aligns more with #2](#command-line-support), using Py
likely want required fields to be _strictly required at the CLI_. We can enable this behavior by using
`cli_enforce_required`.

!!! note
A required `CliPositionalArg` field is always strictly required (enforced) at the CLI.

```py
import os
import sys
Expand Down
38 changes: 31 additions & 7 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,6 +1498,7 @@ def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str,
)

def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]]:
positional_variadic_arg = []
positional_args, subcommand_args, optional_args = [], [], []
for field_name, field_info in _get_model_fields(model).items():
if _CliSubCommand in field_info.metadata:
Expand All @@ -1518,11 +1519,28 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
alias_names, *_ = _get_alias_names(field_name, field_info)
if len(alias_names) > 1:
raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases')
positional_args.append((field_name, field_info))
is_append_action = _annotation_contains_types(
field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True
)
if not is_append_action:
positional_args.append((field_name, field_info))
else:
positional_variadic_arg.append((field_name, field_info))
else:
self._verify_cli_flag_annotations(model, field_name, field_info)
optional_args.append((field_name, field_info))
return positional_args + subcommand_args + optional_args

if positional_variadic_arg:
if len(positional_variadic_arg) > 1:
field_names = ', '.join([name for name, info in positional_variadic_arg])
raise SettingsError(f'{model.__name__} has multiple variadic positonal arguments: {field_names}')
elif subcommand_args:
field_names = ', '.join([name for name, info in positional_variadic_arg + subcommand_args])
raise SettingsError(
f'{model.__name__} has variadic positonal arguments and subcommand arguments: {field_names}'
)

return positional_args + positional_variadic_arg + subcommand_args + optional_args

@property
def root_parser(self) -> T:
Expand Down Expand Up @@ -1728,7 +1746,9 @@ def _add_parser_args(
self._cli_dict_args[kwargs['dest']] = field_info.annotation

if _CliPositionalArg in field_info.metadata:
arg_names, flag_prefix = self._convert_positional_arg(kwargs, field_info, preferred_alias)
arg_names, flag_prefix = self._convert_positional_arg(
kwargs, field_info, preferred_alias, model_default
)

self._convert_bool_flag(kwargs, field_info, model_default)

Expand Down Expand Up @@ -1785,16 +1805,20 @@ def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, mode
)

def _convert_positional_arg(
self, kwargs: dict[str, Any], field_info: FieldInfo, preferred_alias: str
self, kwargs: dict[str, Any], field_info: FieldInfo, preferred_alias: str, model_default: Any
) -> tuple[list[str], str]:
flag_prefix = ''
arg_names = [kwargs['dest']]
kwargs['default'] = PydanticUndefined
kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper())

# Note: For positional args, we must strictly look at field_info.is_required instead of our derived
# kwargs['required'].
if not field_info.is_required():
# Note: CLI positional args are always strictly required at the CLI. Therefore, use field_info.is_required in
# conjunction with model_default instead of the derived kwargs['required'].
is_required = field_info.is_required() and model_default is PydanticUndefined
if kwargs.get('action') == 'append':
del kwargs['action']
kwargs['nargs'] = '+' if is_required else '*'
elif not is_required:
kwargs['nargs'] = '?'

del kwargs['dest']
Expand Down
44 changes: 44 additions & 0 deletions tests/test_source_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,28 @@ class Main(BaseSettings):
assert CliApp.run(Main, cli_args=['789']).model_dump() == {'value': 789}


def test_cli_variadic_positional_arg(env):
class MainRequired(BaseSettings):
model_config = SettingsConfigDict(cli_parse_args=True)

values: CliPositionalArg[list[int]]

class MainOptional(MainRequired):
values: CliPositionalArg[list[int]] = [1, 2, 3]

assert CliApp.run(MainOptional, cli_args=[]).model_dump() == {'values': [1, 2, 3]}
with pytest.raises(SettingsError, match='error parsing CLI: the following arguments are required: VALUES'):
CliApp.run(MainRequired, cli_args=[], cli_exit_on_error=False)

env.set('VALUES', '[4,5,6]')
assert CliApp.run(MainOptional, cli_args=[]).model_dump() == {'values': [4, 5, 6]}
with pytest.raises(SettingsError, match='error parsing CLI: the following arguments are required: VALUES'):
CliApp.run(MainRequired, cli_args=[], cli_exit_on_error=False)

assert CliApp.run(MainOptional, cli_args=['7', '8', '9']).model_dump() == {'values': [7, 8, 9]}
assert CliApp.run(MainRequired, cli_args=['7', '8', '9']).model_dump() == {'values': [7, 8, 9]}


def test_cli_enums(capsys, monkeypatch):
class Pet(IntEnum):
dog = 0
Expand Down Expand Up @@ -1432,6 +1454,28 @@ class PositionalArgNotOutermost(BaseSettings, cli_parse_args=True):

PositionalArgNotOutermost()

with pytest.raises(
SettingsError,
match='MultipleVariadicPositionialArgs has multiple variadic positonal arguments: strings, numbers',
):

class MultipleVariadicPositionialArgs(BaseSettings, cli_parse_args=True):
strings: CliPositionalArg[list[str]]
numbers: CliPositionalArg[list[int]]

MultipleVariadicPositionialArgs()

with pytest.raises(
SettingsError,
match='VariadicPositionialArgAndSubCommand has variadic positonal arguments and subcommand arguments: strings, sub_cmd',
):

class VariadicPositionialArgAndSubCommand(BaseSettings, cli_parse_args=True):
strings: CliPositionalArg[list[str]]
sub_cmd: CliSubCommand[SubCmd]

VariadicPositionialArgAndSubCommand()

with pytest.raises(
SettingsError, match=re.escape("cli_parse_args must be List[str] or Tuple[str, ...], recieved <class 'str'>")
):
Expand Down

0 comments on commit cb48bb5

Please sign in to comment.