From f10185157e6a748aeecdcf77c4d06c16e1db5a0a Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Fri, 10 Jan 2025 08:00:00 -0700 Subject: [PATCH 1/5] Add CLI support for optional positional args. --- pydantic_settings/sources.py | 38 ++++++++++++++++++++++++------------ tests/test_source_cli.py | 25 +++++++++++++++--------- 2 files changed, 42 insertions(+), 21 deletions(-) diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 5e64164..cc907b0 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -1333,7 +1333,11 @@ def _load_env_vars( if subcommand_dest not in selected_subcommands: parsed_args[subcommand_dest] = self.cli_parse_none_str - parsed_args = {key: val for key, val in parsed_args.items() if not key.endswith(':subcommand')} + parsed_args = { + key: val + for key, val in parsed_args.items() + if not key.endswith(':subcommand') and val is not PydanticUndefined + } if selected_subcommands: last_selected_subcommand = max(selected_subcommands, key=len) if not any(field_name for field_name in parsed_args.keys() if f'{last_selected_subcommand}.' in field_name): @@ -1511,12 +1515,9 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo] ) subcommand_args.append((field_name, field_info)) elif _CliPositionalArg in field_info.metadata: - if not field_info.is_required(): - raise SettingsError(f'positional argument {model.__name__}.{field_name} has a default value') - else: - alias_names, *_ = _get_alias_names(field_name, field_info) - if len(alias_names) > 1: - raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases') + alias_names, *_ = _get_alias_names(field_name, field_info) + if len(alias_names) > 1: + raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases') positional_args.append((field_name, field_info)) else: self._verify_cli_flag_annotations(model, field_name, field_info) @@ -1727,11 +1728,7 @@ def _add_parser_args( self._cli_dict_args[kwargs['dest']] = field_info.annotation if _CliPositionalArg in field_info.metadata: - kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper()) - arg_names = [kwargs['dest']] - del kwargs['dest'] - del kwargs['required'] - flag_prefix = '' + arg_names, flag_prefix = self._convert_positional_arg(kwargs, field_info, preferred_alias) self._convert_bool_flag(kwargs, field_info, model_default) @@ -1787,6 +1784,23 @@ def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, mode BooleanOptionalAction if sys.version_info >= (3, 9) else f'store_{str(not default).lower()}' ) + def _convert_positional_arg( + self, kwargs: dict[str, Any], field_info: FieldInfo, preferred_alias: str + ) -> tuple[list[str], str]: + flag_prefix = '' + arg_names = [kwargs['dest']] + kwargs['default'] = PydanticUndefined + kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper()) + + # Note: For positional args, we must strictly look at field_info.is_required instead of our derived + # kwargs['required']. + if not field_info.is_required(): + kwargs['nargs'] = '?' + + del kwargs['dest'] + del kwargs['required'] + return arg_names, flag_prefix + def _get_arg_names( self, arg_prefix: str, diff --git a/tests/test_source_cli.py b/tests/test_source_cli.py index 35bfcda..d944c9e 100644 --- a/tests/test_source_cli.py +++ b/tests/test_source_cli.py @@ -1297,6 +1297,22 @@ class Cfg(BaseSettings): assert cfg.model_dump() == {'child': {'name': 'new name a', 'diff_a': 'new diff a'}} +def test_cli_optional_positional_arg(env): + class Main(BaseSettings): + model_config = SettingsConfigDict( + cli_parse_args=True, + cli_enforce_required=True, + ) + + value: CliPositionalArg[int] = 123 + + assert CliApp.run(Main, cli_args=[]).model_dump() == {'value': 123} + + env.set('VALUE', '456') + assert CliApp.run(Main, cli_args=[]).model_dump() == {'value': 456} + + assert CliApp.run(Main, cli_args=['789']).model_dump() == {'value': 789} + def test_cli_enums(capsys, monkeypatch): class Pet(IntEnum): dog = 0 @@ -1415,15 +1431,6 @@ class PositionalArgNotOutermost(BaseSettings, cli_parse_args=True): PositionalArgNotOutermost() - with pytest.raises( - SettingsError, match='positional argument PositionalArgHasDefault.pos_arg has a default value' - ): - - class PositionalArgHasDefault(BaseSettings, cli_parse_args=True): - pos_arg: CliPositionalArg[str] = 'bad' - - PositionalArgHasDefault() - with pytest.raises( SettingsError, match=re.escape("cli_parse_args must be List[str] or Tuple[str, ...], recieved ") ): From 7b3b48d83a49e68cd21412c83baa30b5621ba9a3 Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Fri, 10 Jan 2025 08:03:09 -0700 Subject: [PATCH 2/5] Docs. --- docs/index.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/index.md b/docs/index.md index e4e0d26..18969f4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -842,9 +842,10 @@ print(User().model_dump()) ### Subcommands and Positional Arguments -Subcommands and positional arguments are expressed using the `CliSubCommand` and `CliPositionalArg` annotations. These -annotations can only be applied to required fields (i.e. fields that do not have a default value). Furthermore, -subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses `dataclass`. +Subcommands and positional arguments are expressed using the `CliSubCommand` and `CliPositionalArg` annotations. The +subcommand annotation can only be applied to required fields (i.e. fields that do not have a default value). +Furthermore, subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses +`dataclass`. Parsed subcommands can be retrieved from model instances using the `get_subcommand` utility function. If a subcommand is not required, set the `is_required` flag to `False` to disable raising an error if no subcommand is found. From eca638bf7b5ac9736c4fc7ee8e8e84d4fbf2c19b Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Fri, 10 Jan 2025 08:05:19 -0700 Subject: [PATCH 3/5] Lint. --- tests/test_source_cli.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_source_cli.py b/tests/test_source_cli.py index d944c9e..c7a3526 100644 --- a/tests/test_source_cli.py +++ b/tests/test_source_cli.py @@ -1313,6 +1313,7 @@ class Main(BaseSettings): assert CliApp.run(Main, cli_args=['789']).model_dump() == {'value': 789} + def test_cli_enums(capsys, monkeypatch): class Pet(IntEnum): dog = 0 From cb48bb55a2b670cf335392d689334754d2ca35dd Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Sat, 11 Jan 2025 08:17:08 -0700 Subject: [PATCH 4/5] Add CLI support for variadic positional args. --- docs/index.md | 3 +++ pydantic_settings/sources.py | 38 +++++++++++++++++++++++++------ tests/test_source_cli.py | 44 ++++++++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 7 deletions(-) diff --git a/docs/index.md b/docs/index.md index 18969f4..fd43bab 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1285,6 +1285,9 @@ However, if your use case [aligns more with #2](#command-line-support), using Py likely want required fields to be _strictly required at the CLI_. We can enable this behavior by using `cli_enforce_required`. +!!! note + A required `CliPositionalArg` field is always strictly required (enforced) at the CLI. + ```py import os import sys diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index cc907b0..584c7cd 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -1498,6 +1498,7 @@ def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, ) def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]]: + positional_variadic_arg = [] positional_args, subcommand_args, optional_args = [], [], [] for field_name, field_info in _get_model_fields(model).items(): if _CliSubCommand in field_info.metadata: @@ -1518,11 +1519,28 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo] alias_names, *_ = _get_alias_names(field_name, field_info) if len(alias_names) > 1: raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases') - positional_args.append((field_name, field_info)) + is_append_action = _annotation_contains_types( + field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True + ) + if not is_append_action: + positional_args.append((field_name, field_info)) + else: + positional_variadic_arg.append((field_name, field_info)) else: self._verify_cli_flag_annotations(model, field_name, field_info) optional_args.append((field_name, field_info)) - return positional_args + subcommand_args + optional_args + + if positional_variadic_arg: + if len(positional_variadic_arg) > 1: + field_names = ', '.join([name for name, info in positional_variadic_arg]) + raise SettingsError(f'{model.__name__} has multiple variadic positonal arguments: {field_names}') + elif subcommand_args: + field_names = ', '.join([name for name, info in positional_variadic_arg + subcommand_args]) + raise SettingsError( + f'{model.__name__} has variadic positonal arguments and subcommand arguments: {field_names}' + ) + + return positional_args + positional_variadic_arg + subcommand_args + optional_args @property def root_parser(self) -> T: @@ -1728,7 +1746,9 @@ def _add_parser_args( self._cli_dict_args[kwargs['dest']] = field_info.annotation if _CliPositionalArg in field_info.metadata: - arg_names, flag_prefix = self._convert_positional_arg(kwargs, field_info, preferred_alias) + arg_names, flag_prefix = self._convert_positional_arg( + kwargs, field_info, preferred_alias, model_default + ) self._convert_bool_flag(kwargs, field_info, model_default) @@ -1785,16 +1805,20 @@ def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, mode ) def _convert_positional_arg( - self, kwargs: dict[str, Any], field_info: FieldInfo, preferred_alias: str + self, kwargs: dict[str, Any], field_info: FieldInfo, preferred_alias: str, model_default: Any ) -> tuple[list[str], str]: flag_prefix = '' arg_names = [kwargs['dest']] kwargs['default'] = PydanticUndefined kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper()) - # Note: For positional args, we must strictly look at field_info.is_required instead of our derived - # kwargs['required']. - if not field_info.is_required(): + # Note: CLI positional args are always strictly required at the CLI. Therefore, use field_info.is_required in + # conjunction with model_default instead of the derived kwargs['required']. + is_required = field_info.is_required() and model_default is PydanticUndefined + if kwargs.get('action') == 'append': + del kwargs['action'] + kwargs['nargs'] = '+' if is_required else '*' + elif not is_required: kwargs['nargs'] = '?' del kwargs['dest'] diff --git a/tests/test_source_cli.py b/tests/test_source_cli.py index c7a3526..5cf0319 100644 --- a/tests/test_source_cli.py +++ b/tests/test_source_cli.py @@ -1314,6 +1314,28 @@ class Main(BaseSettings): assert CliApp.run(Main, cli_args=['789']).model_dump() == {'value': 789} +def test_cli_variadic_positional_arg(env): + class MainRequired(BaseSettings): + model_config = SettingsConfigDict(cli_parse_args=True) + + values: CliPositionalArg[list[int]] + + class MainOptional(MainRequired): + values: CliPositionalArg[list[int]] = [1, 2, 3] + + assert CliApp.run(MainOptional, cli_args=[]).model_dump() == {'values': [1, 2, 3]} + with pytest.raises(SettingsError, match='error parsing CLI: the following arguments are required: VALUES'): + CliApp.run(MainRequired, cli_args=[], cli_exit_on_error=False) + + env.set('VALUES', '[4,5,6]') + assert CliApp.run(MainOptional, cli_args=[]).model_dump() == {'values': [4, 5, 6]} + with pytest.raises(SettingsError, match='error parsing CLI: the following arguments are required: VALUES'): + CliApp.run(MainRequired, cli_args=[], cli_exit_on_error=False) + + assert CliApp.run(MainOptional, cli_args=['7', '8', '9']).model_dump() == {'values': [7, 8, 9]} + assert CliApp.run(MainRequired, cli_args=['7', '8', '9']).model_dump() == {'values': [7, 8, 9]} + + def test_cli_enums(capsys, monkeypatch): class Pet(IntEnum): dog = 0 @@ -1432,6 +1454,28 @@ class PositionalArgNotOutermost(BaseSettings, cli_parse_args=True): PositionalArgNotOutermost() + with pytest.raises( + SettingsError, + match='MultipleVariadicPositionialArgs has multiple variadic positonal arguments: strings, numbers', + ): + + class MultipleVariadicPositionialArgs(BaseSettings, cli_parse_args=True): + strings: CliPositionalArg[list[str]] + numbers: CliPositionalArg[list[int]] + + MultipleVariadicPositionialArgs() + + with pytest.raises( + SettingsError, + match='VariadicPositionialArgAndSubCommand has variadic positonal arguments and subcommand arguments: strings, sub_cmd', + ): + + class VariadicPositionialArgAndSubCommand(BaseSettings, cli_parse_args=True): + strings: CliPositionalArg[list[str]] + sub_cmd: CliSubCommand[SubCmd] + + VariadicPositionialArgAndSubCommand() + with pytest.raises( SettingsError, match=re.escape("cli_parse_args must be List[str] or Tuple[str, ...], recieved ") ): From 93443fb545463982d327a6cd9581060bb9b27845 Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Sat, 11 Jan 2025 08:21:00 -0700 Subject: [PATCH 5/5] Py 3.8 type hints. --- tests/test_source_cli.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_source_cli.py b/tests/test_source_cli.py index 5cf0319..3c59016 100644 --- a/tests/test_source_cli.py +++ b/tests/test_source_cli.py @@ -1318,10 +1318,10 @@ def test_cli_variadic_positional_arg(env): class MainRequired(BaseSettings): model_config = SettingsConfigDict(cli_parse_args=True) - values: CliPositionalArg[list[int]] + values: CliPositionalArg[List[int]] class MainOptional(MainRequired): - values: CliPositionalArg[list[int]] = [1, 2, 3] + values: CliPositionalArg[List[int]] = [1, 2, 3] assert CliApp.run(MainOptional, cli_args=[]).model_dump() == {'values': [1, 2, 3]} with pytest.raises(SettingsError, match='error parsing CLI: the following arguments are required: VALUES'): @@ -1460,8 +1460,8 @@ class PositionalArgNotOutermost(BaseSettings, cli_parse_args=True): ): class MultipleVariadicPositionialArgs(BaseSettings, cli_parse_args=True): - strings: CliPositionalArg[list[str]] - numbers: CliPositionalArg[list[int]] + strings: CliPositionalArg[List[str]] + numbers: CliPositionalArg[List[int]] MultipleVariadicPositionialArgs() @@ -1471,7 +1471,7 @@ class MultipleVariadicPositionialArgs(BaseSettings, cli_parse_args=True): ): class VariadicPositionialArgAndSubCommand(BaseSettings, cli_parse_args=True): - strings: CliPositionalArg[list[str]] + strings: CliPositionalArg[List[str]] sub_cmd: CliSubCommand[SubCmd] VariadicPositionialArgAndSubCommand()