Skip to content

Commit

Permalink
Add get_subcommand function. (pydantic#341)
Browse files Browse the repository at this point in the history
  • Loading branch information
kschwab authored Aug 27, 2024
1 parent cabcdee commit 47924f5
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 101 deletions.
122 changes: 35 additions & 87 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,9 @@ Subcommands and positional arguments are expressed using the `CliSubCommand` and
annotations can only be applied to required fields (i.e. fields that do not have a default value). Furthermore,
subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses `dataclass`.

Parsed subcommands can be retrieved from model instances using the `get_subcommand` utility function. If a subcommand is
not required, set the `is_required` flag to `False` to disable raising an error if no subcommand is found.

!!! note
CLI settings subcommands are limited to a single subparser per model. In other words, all subcommands for a model
are grouped under a single subparser; it does not allow for multiple subparsers with each subparser having its own
Expand All @@ -759,114 +762,59 @@ subcommands must be a valid type derived from either a pydantic `BaseModel` or p
```py
import sys

from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass
from pydantic import BaseModel

from pydantic_settings import (
BaseSettings,
CliPositionalArg,
CliSubCommand,
SettingsError,
get_subcommand,
)


@dataclass
class FooPlugin:
"""git-plugins-foo - Extra deep foo plugin command"""

x_feature: bool = Field(default=False, description='Enable "X" feature')


@dataclass
class BarPlugin:
"""git-plugins-bar - Extra deep bar plugin command"""

y_feature: bool = Field(default=False, description='Enable "Y" feature')


@dataclass
class Plugins:
"""git-plugins - Fake plugins for GIT"""

foo: CliSubCommand[FooPlugin] = Field(description='Foo is fake plugin')

bar: CliSubCommand[BarPlugin] = Field(description='Bar is fake plugin')
class Init(BaseModel):
directory: CliPositionalArg[str]


class Clone(BaseModel):
"""git-clone - Clone a repository into a new directory"""

repository: CliPositionalArg[str] = Field(description='The repo ...')

directory: CliPositionalArg[str] = Field(description='The dir ...')

local: bool = Field(default=False, description='When the repo ...')


class Git(BaseSettings, cli_parse_args=True, cli_prog_name='git'):
"""git - The stupid content tracker"""

clone: CliSubCommand[Clone] = Field(description='Clone a repo ...')

plugins: CliSubCommand[Plugins] = Field(description='Fake GIT plugins')


try:
sys.argv = ['example.py', '--help']
Git()
except SystemExit as e:
print(e)
#> 0
"""
usage: git [-h] {clone,plugins} ...
repository: CliPositionalArg[str]
directory: CliPositionalArg[str]

git - The stupid content tracker

options:
-h, --help show this help message and exit
class Git(BaseSettings, cli_parse_args=True, cli_exit_on_error=False):
clone: CliSubCommand[Clone]
init: CliSubCommand[Init]

subcommands:
{clone,plugins}
clone Clone a repo ...
plugins Fake GIT plugins
"""

# Run without subcommands
sys.argv = ['example.py']
cmd = Git()
assert cmd.model_dump() == {'clone': None, 'init': None}

try:
sys.argv = ['example.py', 'clone', '--help']
Git()
except SystemExit as e:
print(e)
#> 0
"""
usage: git clone [-h] [--local bool] [--shared bool] REPOSITORY DIRECTORY
# Will raise an error since no subcommand was provided
get_subcommand(cmd).model_dump()
except SettingsError as err:
assert str(err) == 'Error: CLI subcommand is required {clone, init}'

git-clone - Clone a repository into a new directory
# Will not raise an error since subcommand is not required
assert get_subcommand(cmd, is_required=False) is None

positional arguments:
REPOSITORY The repo ...
DIRECTORY The dir ...
options:
-h, --help show this help message and exit
--local bool When the repo ... (default: False)
"""

# Run the clone subcommand
sys.argv = ['example.py', 'clone', 'repo', 'dest']
cmd = Git()
assert cmd.model_dump() == {
'clone': {'repository': 'repo', 'directory': 'dest'},
'init': None,
}

try:
sys.argv = ['example.py', 'plugins', 'bar', '--help']
Git()
except SystemExit as e:
print(e)
#> 0
"""
usage: git plugins bar [-h] [--my_feature bool]
git-plugins-bar - Extra deep bar plugin command
options:
-h, --help show this help message and exit
--y_feature bool Enable "Y" feature (default: False)
"""
# Returns the subcommand model instance (in this case, 'clone')
assert get_subcommand(cmd).model_dump() == {
'directory': 'dest',
'repository': 'repo',
}
```

### Customizing the CLI Experience
Expand Down
2 changes: 2 additions & 0 deletions pydantic_settings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SettingsError,
TomlConfigSettingsSource,
YamlConfigSettingsSource,
get_subcommand,
)
from .version import VERSION

Expand All @@ -38,6 +39,7 @@
'TomlConfigSettingsSource',
'YamlConfigSettingsSource',
'AzureKeyVaultSettingsSource',
'get_subcommand',
'__version__',
)

Expand Down
72 changes: 58 additions & 14 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,53 @@ def error(self, message: str) -> NoReturn:
CliExplicitFlag = Annotated[_CliBoolFlag, _CliExplicitFlag]


def get_subcommand(model: BaseModel, is_required: bool = True, cli_exit_on_error: bool | None = None) -> Any:
"""
Get the subcommand from a model.
Args:
model: The model to get the subcommand from.
is_required: Determines whether a model must have subcommand set and raises error if not
found. Defaults to `True`.
cli_exit_on_error: Determines whether this function exits with error if no subcommand is found.
Defaults to model_config `cli_exit_on_error` value if set. Otherwise, defaults to `True`.
Returns:
The subcommand model if found, otherwise `None`.
Raises:
SystemExit: When no subcommand is found and is_required=`True` and cli_exit_on_error=`True`
(the default).
SettingsError: When no subcommand is found and is_required=`True` and
cli_exit_on_error=`False`.
"""

model_cls = type(model)
if cli_exit_on_error is None and is_model_class(model_cls):
model_default = model.model_config.get('cli_exit_on_error')
if isinstance(model_default, bool):
cli_exit_on_error = model_default
if cli_exit_on_error is None:
cli_exit_on_error = True

subcommands: list[str] = []
for field_name, field_info in _get_model_fields(model_cls).items():
if _CliSubCommand in field_info.metadata:
if getattr(model, field_name) is not None:
return getattr(model, field_name)
subcommands.append(field_name)

if is_required:
error_message = (
f'Error: CLI subcommand is required {{{", ".join(subcommands)}}}'
if subcommands
else 'Error: CLI subcommand is required but no subcommands were found.'
)
raise SystemExit(error_message) if cli_exit_on_error else SettingsError(error_message)

return None


class EnvNoneType(str):
pass

Expand Down Expand Up @@ -763,11 +810,7 @@ class Cfg(BaseSettings):
if type_has_key:
return type_has_key
elif is_model_class(annotation) or is_pydantic_dataclass(annotation):
fields = (
annotation.__pydantic_fields__
if is_pydantic_dataclass(annotation) and hasattr(annotation, '__pydantic_fields__')
else cast(BaseModel, annotation).model_fields
)
fields = _get_model_fields(annotation)
# `case_sensitive is None` is here to be compatible with the old behavior.
# Has to be removed in V3.
if (case_sensitive is None or case_sensitive) and fields.get(key):
Expand Down Expand Up @@ -1376,12 +1419,7 @@ def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str,

def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]]:
positional_args, subcommand_args, optional_args = [], [], []
fields = (
model.__pydantic_fields__
if hasattr(model, '__pydantic_fields__') and is_pydantic_dataclass(model)
else model.model_fields
)
for field_name, field_info in fields.items():
for field_name, field_info in _get_model_fields(model).items():
if _CliSubCommand in field_info.metadata:
if not field_info.is_required():
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has a default value')
Expand Down Expand Up @@ -1496,9 +1534,7 @@ def _add_parser_args(
sub_models: list[type[BaseModel]] = self._get_sub_models(model, field_name, field_info)
if _CliSubCommand in field_info.metadata:
if subparsers is None:
subparsers = self._add_subparsers(
parser, title='subcommands', dest=f'{arg_prefix}:subcommand', required=self.cli_enforce_required
)
subparsers = self._add_subparsers(parser, title='subcommands', dest=f'{arg_prefix}:subcommand')
self._cli_subcommands[f'{arg_prefix}:subcommand'] = [f'{arg_prefix}{field_name}']
else:
self._cli_subcommands[f'{arg_prefix}:subcommand'].append(f'{arg_prefix}{field_name}')
Expand Down Expand Up @@ -2095,5 +2131,13 @@ def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any
return None


def _get_model_fields(model_cls: type[Any]) -> dict[str, FieldInfo]:
if is_pydantic_dataclass(model_cls) and hasattr(model_cls, '__pydantic_fields__'):
return model_cls.__pydantic_fields__
if is_model_class(model_cls):
return model_cls.model_fields
raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass')


def _is_function(obj: Any) -> bool:
return inspect.isfunction(obj) or inspect.isbuiltin(obj) or inspect.isroutine(obj) or inspect.ismethod(obj)
44 changes: 44 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
AliasChoices,
AliasPath,
BaseModel,
ConfigDict,
DirectoryPath,
Discriminator,
Field,
Expand Down Expand Up @@ -59,6 +60,7 @@
CliSubCommand,
DefaultSettingsSource,
SettingsError,
get_subcommand,
)

try:
Expand Down Expand Up @@ -3095,6 +3097,12 @@ class FooPlugin:
class BarPlugin:
my_feature: bool = False

bar = BarPlugin()
with pytest.raises(SystemExit, match='Error: CLI subcommand is required but no subcommands were found.'):
get_subcommand(bar)
with pytest.raises(SettingsError, match='Error: CLI subcommand is required but no subcommands were found.'):
get_subcommand(bar, cli_exit_on_error=False)

@pydantic_dataclasses.dataclass
class Plugins:
foo: CliSubCommand[FooPlugin]
Expand All @@ -3116,26 +3124,62 @@ class Git(BaseSettings):
init: CliSubCommand[Init]
plugins: CliSubCommand[Plugins]

git = Git(_cli_parse_args=[])
assert git.model_dump() == {
'clone': None,
'init': None,
'plugins': None,
}
assert get_subcommand(git, is_required=False) is None
with pytest.raises(SystemExit, match='Error: CLI subcommand is required {clone, init, plugins}'):
get_subcommand(git)
with pytest.raises(SettingsError, match='Error: CLI subcommand is required {clone, init, plugins}'):
get_subcommand(git, cli_exit_on_error=False)

git = Git(_cli_parse_args=['init', '--quiet', 'true', 'dir/path'])
assert git.model_dump() == {
'clone': None,
'init': {'directory': 'dir/path', 'quiet': True, 'bare': False},
'plugins': None,
}
assert get_subcommand(git) == git.init
assert get_subcommand(git, is_required=False) == git.init

git = Git(_cli_parse_args=['clone', 'repo', '.', '--shared', 'true'])
assert git.model_dump() == {
'clone': {'repository': 'repo', 'directory': '.', 'local': False, 'shared': True},
'init': None,
'plugins': None,
}
assert get_subcommand(git) == git.clone
assert get_subcommand(git, is_required=False) == git.clone

git = Git(_cli_parse_args=['plugins', 'bar'])
assert git.model_dump() == {
'clone': None,
'init': None,
'plugins': {'foo': None, 'bar': {'my_feature': False}},
}
assert get_subcommand(git) == git.plugins
assert get_subcommand(git, is_required=False) == git.plugins
assert get_subcommand(get_subcommand(git)) == git.plugins.bar
assert get_subcommand(get_subcommand(git), is_required=False) == git.plugins.bar

class NotModel: ...

with pytest.raises(
SettingsError, match='Error: NotModel is not subclass of BaseModel or pydantic.dataclasses.dataclass'
):
get_subcommand(NotModel())

class NotSettingsConfigDict(BaseModel):
model_config = ConfigDict(cli_exit_on_error='not a bool')

with pytest.raises(SystemExit, match='Error: CLI subcommand is required but no subcommands were found.'):
get_subcommand(NotSettingsConfigDict())

with pytest.raises(SettingsError, match='Error: CLI subcommand is required but no subcommands were found.'):
get_subcommand(NotSettingsConfigDict(), cli_exit_on_error=False)


def test_cli_union_similar_sub_models():
Expand Down

0 comments on commit 47924f5

Please sign in to comment.