diff --git a/src/py/flwr/cli/config_utils.py b/src/py/flwr/cli/config_utils.py index ae649a44d062..ef97b9bcf4d9 100644 --- a/src/py/flwr/cli/config_utils.py +++ b/src/py/flwr/cli/config_utils.py @@ -15,53 +15,13 @@ """Utility to validate the `pyproject.toml` file.""" -import zipfile -from io import BytesIO from pathlib import Path -from typing import IO, Any, Optional, Union, get_args +from typing import Any, Optional, Union import tomli import typer -from flwr.common import object_ref -from flwr.common.typing import UserConfigValue - - -def get_fab_config(fab_file: Union[Path, bytes]) -> dict[str, Any]: - """Extract the config from a FAB file or path. - - Parameters - ---------- - fab_file : Union[Path, bytes] - The Flower App Bundle file to validate and extract the metadata from. - It can either be a path to the file or the file itself as bytes. - - Returns - ------- - Dict[str, Any] - The `config` of the given Flower App Bundle. - """ - fab_file_archive: Union[Path, IO[bytes]] - if isinstance(fab_file, bytes): - fab_file_archive = BytesIO(fab_file) - elif isinstance(fab_file, Path): - fab_file_archive = fab_file - else: - raise ValueError("fab_file must be either a Path or bytes") - - with zipfile.ZipFile(fab_file_archive, "r") as zipf: - with zipf.open("pyproject.toml") as file: - toml_content = file.read().decode("utf-8") - - conf = load_from_string(toml_content) - if conf is None: - raise ValueError("Invalid TOML content in pyproject.toml") - - is_valid, errors, _ = validate(conf, check_module=False) - if not is_valid: - raise ValueError(errors) - - return conf +from flwr.common.config import get_fab_config, validate_config def get_fab_metadata(fab_file: Union[Path, bytes]) -> tuple[str, str]: @@ -120,7 +80,7 @@ def load_and_validate( ] return (None, errors, []) - is_valid, errors, warnings = validate(config, check_module, path.parent) + is_valid, errors, warnings = validate_config(config, check_module, path.parent) if not is_valid: return (None, errors, warnings) @@ -133,102 +93,11 @@ def load(toml_path: Path) -> Optional[dict[str, Any]]: if not toml_path.is_file(): return None - with toml_path.open(encoding="utf-8") as toml_file: - return load_from_string(toml_file.read()) - - -def _validate_run_config(config_dict: dict[str, Any], errors: list[str]) -> None: - for key, value in config_dict.items(): - if isinstance(value, dict): - _validate_run_config(config_dict[key], errors) - elif not isinstance(value, get_args(UserConfigValue)): - raise ValueError( - f"The value for key {key} needs to be of type `int`, `float`, " - "`bool, `str`, or a `dict` of those.", - ) - - -# pylint: disable=too-many-branches -def validate_fields(config: dict[str, Any]) -> tuple[bool, list[str], list[str]]: - """Validate pyproject.toml fields.""" - errors = [] - warnings = [] - - if "project" not in config: - errors.append("Missing [project] section") - else: - if "name" not in config["project"]: - errors.append('Property "name" missing in [project]') - if "version" not in config["project"]: - errors.append('Property "version" missing in [project]') - if "description" not in config["project"]: - warnings.append('Recommended property "description" missing in [project]') - if "license" not in config["project"]: - warnings.append('Recommended property "license" missing in [project]') - if "authors" not in config["project"]: - warnings.append('Recommended property "authors" missing in [project]') - - if ( - "tool" not in config - or "flwr" not in config["tool"] - or "app" not in config["tool"]["flwr"] - ): - errors.append("Missing [tool.flwr.app] section") - else: - if "publisher" not in config["tool"]["flwr"]["app"]: - errors.append('Property "publisher" missing in [tool.flwr.app]') - if "config" in config["tool"]["flwr"]["app"]: - _validate_run_config(config["tool"]["flwr"]["app"]["config"], errors) - if "components" not in config["tool"]["flwr"]["app"]: - errors.append("Missing [tool.flwr.app.components] section") - else: - if "serverapp" not in config["tool"]["flwr"]["app"]["components"]: - errors.append( - 'Property "serverapp" missing in [tool.flwr.app.components]' - ) - if "clientapp" not in config["tool"]["flwr"]["app"]["components"]: - errors.append( - 'Property "clientapp" missing in [tool.flwr.app.components]' - ) - - return len(errors) == 0, errors, warnings - - -def validate( - config: dict[str, Any], - check_module: bool = True, - project_dir: Optional[Union[str, Path]] = None, -) -> tuple[bool, list[str], list[str]]: - """Validate pyproject.toml.""" - is_valid, errors, warnings = validate_fields(config) - - if not is_valid: - return False, errors, warnings - - # Validate serverapp - serverapp_ref = config["tool"]["flwr"]["app"]["components"]["serverapp"] - is_valid, reason = object_ref.validate(serverapp_ref, check_module, project_dir) - - if not is_valid and isinstance(reason, str): - return False, [reason], [] - - # Validate clientapp - clientapp_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"] - is_valid, reason = object_ref.validate(clientapp_ref, check_module, project_dir) - - if not is_valid and isinstance(reason, str): - return False, [reason], [] - - return True, [], [] - - -def load_from_string(toml_content: str) -> Optional[dict[str, Any]]: - """Load TOML content from a string and return as dict.""" - try: - data = tomli.loads(toml_content) - return data - except tomli.TOMLDecodeError: - return None + with toml_path.open("rb") as toml_file: + try: + return tomli.load(toml_file) + except tomli.TOMLDecodeError: + return None def process_loaded_project_config( diff --git a/src/py/flwr/cli/config_utils_test.py b/src/py/flwr/cli/config_utils_test.py index c14a792f09b5..3b8607d2f308 100644 --- a/src/py/flwr/cli/config_utils_test.py +++ b/src/py/flwr/cli/config_utils_test.py @@ -26,10 +26,8 @@ from .config_utils import ( load, process_loaded_project_config, - validate, validate_certificate_in_federation_config, validate_federation_in_project_config, - validate_fields, ) @@ -163,184 +161,6 @@ def test_load_pyproject_toml_from_path(tmp_path: Path) -> None: os.chdir(origin) -def test_validate_pyproject_toml_fields_empty() -> None: - """Test that validate_pyproject_toml_fields fails correctly.""" - # Prepare - config: dict[str, Any] = {} - - # Execute - is_valid, errors, warnings = validate_fields(config) - - # Assert - assert not is_valid - assert len(errors) == 2 - assert len(warnings) == 0 - - -def test_validate_pyproject_toml_fields_no_flower() -> None: - """Test that validate_pyproject_toml_fields fails correctly.""" - # Prepare - config = { - "project": { - "name": "fedgpt", - "version": "1.0.0", - "description": "", - "license": "", - "authors": [], - } - } - - # Execute - is_valid, errors, warnings = validate_fields(config) - - # Assert - assert not is_valid - assert len(errors) == 1 - assert len(warnings) == 0 - - -def test_validate_pyproject_toml_fields_no_flower_components() -> None: - """Test that validate_pyproject_toml_fields fails correctly.""" - # Prepare - config = { - "project": { - "name": "fedgpt", - "version": "1.0.0", - "description": "", - "license": "", - "authors": [], - }, - "tool": {"flwr": {"app": {}}}, - } - - # Execute - is_valid, errors, warnings = validate_fields(config) - - # Assert - assert not is_valid - assert len(errors) == 2 - assert len(warnings) == 0 - - -def test_validate_pyproject_toml_fields_no_server_and_client_app() -> None: - """Test that validate_pyproject_toml_fields fails correctly.""" - # Prepare - config = { - "project": { - "name": "fedgpt", - "version": "1.0.0", - "description": "", - "license": "", - "authors": [], - }, - "tool": {"flwr": {"app": {"components": {}}}}, - } - - # Execute - is_valid, errors, warnings = validate_fields(config) - - # Assert - assert not is_valid - assert len(errors) == 3 - assert len(warnings) == 0 - - -def test_validate_pyproject_toml_fields() -> None: - """Test that validate_pyproject_toml_fields succeeds correctly.""" - # Prepare - config = { - "project": { - "name": "fedgpt", - "version": "1.0.0", - "description": "", - "license": "", - "authors": [], - }, - "tool": { - "flwr": { - "app": { - "publisher": "flwrlabs", - "components": {"serverapp": "", "clientapp": ""}, - }, - }, - }, - } - - # Execute - is_valid, errors, warnings = validate_fields(config) - - # Assert - assert is_valid - assert len(errors) == 0 - assert len(warnings) == 0 - - -def test_validate_pyproject_toml() -> None: - """Test that validate_pyproject_toml succeeds correctly.""" - # Prepare - config = { - "project": { - "name": "fedgpt", - "version": "1.0.0", - "description": "", - "license": "", - "authors": [], - }, - "tool": { - "flwr": { - "app": { - "publisher": "flwrlabs", - "components": { - "serverapp": "flwr.cli.run:run", - "clientapp": "flwr.cli.run:run", - }, - }, - }, - }, - } - - # Execute - is_valid, errors, warnings = validate(config) - - # Assert - assert is_valid - assert not errors - assert not warnings - - -def test_validate_pyproject_toml_fail() -> None: - """Test that validate_pyproject_toml fails correctly.""" - # Prepare - config = { - "project": { - "name": "fedgpt", - "version": "1.0.0", - "description": "", - "license": "", - "authors": [], - }, - "tool": { - "flwr": { - "app": { - "publisher": "flwrlabs", - "components": { - "serverapp": "flwr.cli.run:run", - "clientapp": "flwr.cli.run:runa", - }, - }, - }, - }, - } - - # Execute - is_valid, errors, warnings = validate(config) - - # Assert - assert not is_valid - assert len(errors) == 1 - assert len(warnings) == 0 - - def test_validate_project_config_fail() -> None: """Test that validate_project_config fails correctly.""" # Prepare diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index f65da35ca562..e13c7f5f9fc4 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -17,13 +17,13 @@ import os import re +import zipfile +from io import BytesIO from pathlib import Path -from typing import Any, Optional, Union, cast, get_args +from typing import IO, Any, Optional, Union, cast, get_args import tomli -from flwr.cli.config_utils import get_fab_config, validate_fields -from flwr.common import ConfigsRecord from flwr.common.constant import ( APP_DIR, FAB_CONFIG_FILE, @@ -33,6 +33,8 @@ ) from flwr.common.typing import Run, UserConfig, UserConfigValue +from . import ConfigsRecord, object_ref + def get_flwr_dir(provided_path: Optional[str] = None) -> Path: """Return the Flower home directory based on env variables.""" @@ -80,7 +82,7 @@ def get_project_config(project_dir: Union[str, Path]) -> dict[str, Any]: config = tomli.loads(toml_file.read()) # Validate pyproject.toml fields - is_valid, errors, _ = validate_fields(config) + is_valid, errors, _ = validate_fields_in_config(config) if not is_valid: error_msg = "\n".join([f" - {error}" for error in errors]) raise ValueError( @@ -241,3 +243,127 @@ def user_config_to_configsrecord(config: UserConfig) -> ConfigsRecord: c_record[k] = v return c_record + + +def get_fab_config(fab_file: Union[Path, bytes]) -> dict[str, Any]: + """Extract the config from a FAB file or path. + + Parameters + ---------- + fab_file : Union[Path, bytes] + The Flower App Bundle file to validate and extract the metadata from. + It can either be a path to the file or the file itself as bytes. + + Returns + ------- + Dict[str, Any] + The `config` of the given Flower App Bundle. + """ + fab_file_archive: Union[Path, IO[bytes]] + if isinstance(fab_file, bytes): + fab_file_archive = BytesIO(fab_file) + elif isinstance(fab_file, Path): + fab_file_archive = fab_file + else: + raise ValueError("fab_file must be either a Path or bytes") + + with zipfile.ZipFile(fab_file_archive, "r") as zipf: + with zipf.open("pyproject.toml") as file: + toml_content = file.read().decode("utf-8") + try: + conf = tomli.loads(toml_content) + except tomli.TOMLDecodeError: + raise ValueError("Invalid TOML content in pyproject.toml") from None + + is_valid, errors, _ = validate_config(conf, check_module=False) + if not is_valid: + raise ValueError(errors) + + return conf + + +def _validate_run_config(config_dict: dict[str, Any], errors: list[str]) -> None: + for key, value in config_dict.items(): + if isinstance(value, dict): + _validate_run_config(config_dict[key], errors) + elif not isinstance(value, get_args(UserConfigValue)): + raise ValueError( + f"The value for key {key} needs to be of type `int`, `float`, " + "`bool, `str`, or a `dict` of those.", + ) + + +# pylint: disable=too-many-branches +def validate_fields_in_config( + config: dict[str, Any] +) -> tuple[bool, list[str], list[str]]: + """Validate pyproject.toml fields.""" + errors = [] + warnings = [] + + if "project" not in config: + errors.append("Missing [project] section") + else: + if "name" not in config["project"]: + errors.append('Property "name" missing in [project]') + if "version" not in config["project"]: + errors.append('Property "version" missing in [project]') + if "description" not in config["project"]: + warnings.append('Recommended property "description" missing in [project]') + if "license" not in config["project"]: + warnings.append('Recommended property "license" missing in [project]') + if "authors" not in config["project"]: + warnings.append('Recommended property "authors" missing in [project]') + + if ( + "tool" not in config + or "flwr" not in config["tool"] + or "app" not in config["tool"]["flwr"] + ): + errors.append("Missing [tool.flwr.app] section") + else: + if "publisher" not in config["tool"]["flwr"]["app"]: + errors.append('Property "publisher" missing in [tool.flwr.app]') + if "config" in config["tool"]["flwr"]["app"]: + _validate_run_config(config["tool"]["flwr"]["app"]["config"], errors) + if "components" not in config["tool"]["flwr"]["app"]: + errors.append("Missing [tool.flwr.app.components] section") + else: + if "serverapp" not in config["tool"]["flwr"]["app"]["components"]: + errors.append( + 'Property "serverapp" missing in [tool.flwr.app.components]' + ) + if "clientapp" not in config["tool"]["flwr"]["app"]["components"]: + errors.append( + 'Property "clientapp" missing in [tool.flwr.app.components]' + ) + + return len(errors) == 0, errors, warnings + + +def validate_config( + config: dict[str, Any], + check_module: bool = True, + project_dir: Optional[Union[str, Path]] = None, +) -> tuple[bool, list[str], list[str]]: + """Validate pyproject.toml.""" + is_valid, errors, warnings = validate_fields_in_config(config) + + if not is_valid: + return False, errors, warnings + + # Validate serverapp + serverapp_ref = config["tool"]["flwr"]["app"]["components"]["serverapp"] + is_valid, reason = object_ref.validate(serverapp_ref, check_module, project_dir) + + if not is_valid and isinstance(reason, str): + return False, [reason], [] + + # Validate clientapp + clientapp_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"] + is_valid, reason = object_ref.validate(clientapp_ref, check_module, project_dir) + + if not is_valid and isinstance(reason, str): + return False, [reason], [] + + return True, [], [] diff --git a/src/py/flwr/common/config_test.py b/src/py/flwr/common/config_test.py index d7440333e964..161ec412b1d3 100644 --- a/src/py/flwr/common/config_test.py +++ b/src/py/flwr/common/config_test.py @@ -19,6 +19,7 @@ import tempfile import textwrap from pathlib import Path +from typing import Any from unittest.mock import patch import pytest @@ -33,6 +34,8 @@ get_project_dir, parse_config_args, unflatten_dict, + validate_config, + validate_fields_in_config, ) # Mock constants @@ -312,3 +315,181 @@ def test_parse_config_args_passing_toml_and_key_value() -> None: config = ["my-other-config.toml", "lr=0.1", "epochs=99"] with pytest.raises(ValueError): parse_config_args(config) + + +def test_validate_pyproject_toml_fields_empty() -> None: + """Test that validate_pyproject_toml_fields fails correctly.""" + # Prepare + config: dict[str, Any] = {} + + # Execute + is_valid, errors, warnings = validate_fields_in_config(config) + + # Assert + assert not is_valid + assert len(errors) == 2 + assert len(warnings) == 0 + + +def test_validate_pyproject_toml_fields_no_flower() -> None: + """Test that validate_pyproject_toml_fields fails correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + } + } + + # Execute + is_valid, errors, warnings = validate_fields_in_config(config) + + # Assert + assert not is_valid + assert len(errors) == 1 + assert len(warnings) == 0 + + +def test_validate_pyproject_toml_fields_no_flower_components() -> None: + """Test that validate_pyproject_toml_fields fails correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + }, + "tool": {"flwr": {"app": {}}}, + } + + # Execute + is_valid, errors, warnings = validate_fields_in_config(config) + + # Assert + assert not is_valid + assert len(errors) == 2 + assert len(warnings) == 0 + + +def test_validate_pyproject_toml_fields_no_server_and_client_app() -> None: + """Test that validate_pyproject_toml_fields fails correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + }, + "tool": {"flwr": {"app": {"components": {}}}}, + } + + # Execute + is_valid, errors, warnings = validate_fields_in_config(config) + + # Assert + assert not is_valid + assert len(errors) == 3 + assert len(warnings) == 0 + + +def test_validate_pyproject_toml_fields() -> None: + """Test that validate_pyproject_toml_fields succeeds correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + }, + "tool": { + "flwr": { + "app": { + "publisher": "flwrlabs", + "components": {"serverapp": "", "clientapp": ""}, + }, + }, + }, + } + + # Execute + is_valid, errors, warnings = validate_fields_in_config(config) + + # Assert + assert is_valid + assert len(errors) == 0 + assert len(warnings) == 0 + + +def test_validate_pyproject_toml() -> None: + """Test that validate_pyproject_toml succeeds correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + }, + "tool": { + "flwr": { + "app": { + "publisher": "flwrlabs", + "components": { + "serverapp": "flwr.cli.run:run", + "clientapp": "flwr.cli.run:run", + }, + }, + }, + }, + } + + # Execute + is_valid, errors, warnings = validate_config(config) + + # Assert + assert is_valid + assert not errors + assert not warnings + + +def test_validate_pyproject_toml_fail() -> None: + """Test that validate_pyproject_toml fails correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + }, + "tool": { + "flwr": { + "app": { + "publisher": "flwrlabs", + "components": { + "serverapp": "flwr.cli.run:run", + "clientapp": "flwr.cli.run:runa", + }, + }, + }, + }, + } + + # Execute + is_valid, errors, warnings = validate_config(config) + + # Assert + assert not is_valid + assert len(errors) == 1 + assert len(warnings) == 0