Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(framework) Move necessary utils from flwr.cli.config_utils to flwr.common.config #4838

Merged
merged 8 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 8 additions & 139 deletions src/py/flwr/cli/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,53 +15,13 @@
"""Utility to validate the `pyproject.toml` file."""


import zipfile
from io import BytesIO
from pathlib import Path
from typing import IO, Any, Optional, Union, get_args
from typing import Any, Optional, Union

import tomli
import typer

from flwr.common import object_ref
from flwr.common.typing import UserConfigValue


def get_fab_config(fab_file: Union[Path, bytes]) -> dict[str, Any]:
"""Extract the config from a FAB file or path.

Parameters
----------
fab_file : Union[Path, bytes]
The Flower App Bundle file to validate and extract the metadata from.
It can either be a path to the file or the file itself as bytes.

Returns
-------
Dict[str, Any]
The `config` of the given Flower App Bundle.
"""
fab_file_archive: Union[Path, IO[bytes]]
if isinstance(fab_file, bytes):
fab_file_archive = BytesIO(fab_file)
elif isinstance(fab_file, Path):
fab_file_archive = fab_file
else:
raise ValueError("fab_file must be either a Path or bytes")

with zipfile.ZipFile(fab_file_archive, "r") as zipf:
with zipf.open("pyproject.toml") as file:
toml_content = file.read().decode("utf-8")

conf = load_from_string(toml_content)
if conf is None:
raise ValueError("Invalid TOML content in pyproject.toml")

is_valid, errors, _ = validate(conf, check_module=False)
if not is_valid:
raise ValueError(errors)

return conf
from flwr.common.config import get_fab_config, validate_config


def get_fab_metadata(fab_file: Union[Path, bytes]) -> tuple[str, str]:
Expand Down Expand Up @@ -120,7 +80,7 @@ def load_and_validate(
]
return (None, errors, [])

is_valid, errors, warnings = validate(config, check_module, path.parent)
is_valid, errors, warnings = validate_config(config, check_module, path.parent)

if not is_valid:
return (None, errors, warnings)
Expand All @@ -133,102 +93,11 @@ def load(toml_path: Path) -> Optional[dict[str, Any]]:
if not toml_path.is_file():
return None

with toml_path.open(encoding="utf-8") as toml_file:
return load_from_string(toml_file.read())


def _validate_run_config(config_dict: dict[str, Any], errors: list[str]) -> None:
for key, value in config_dict.items():
if isinstance(value, dict):
_validate_run_config(config_dict[key], errors)
elif not isinstance(value, get_args(UserConfigValue)):
raise ValueError(
f"The value for key {key} needs to be of type `int`, `float`, "
"`bool, `str`, or a `dict` of those.",
)


# pylint: disable=too-many-branches
def validate_fields(config: dict[str, Any]) -> tuple[bool, list[str], list[str]]:
"""Validate pyproject.toml fields."""
errors = []
warnings = []

if "project" not in config:
errors.append("Missing [project] section")
else:
if "name" not in config["project"]:
errors.append('Property "name" missing in [project]')
if "version" not in config["project"]:
errors.append('Property "version" missing in [project]')
if "description" not in config["project"]:
warnings.append('Recommended property "description" missing in [project]')
if "license" not in config["project"]:
warnings.append('Recommended property "license" missing in [project]')
if "authors" not in config["project"]:
warnings.append('Recommended property "authors" missing in [project]')

if (
"tool" not in config
or "flwr" not in config["tool"]
or "app" not in config["tool"]["flwr"]
):
errors.append("Missing [tool.flwr.app] section")
else:
if "publisher" not in config["tool"]["flwr"]["app"]:
errors.append('Property "publisher" missing in [tool.flwr.app]')
if "config" in config["tool"]["flwr"]["app"]:
_validate_run_config(config["tool"]["flwr"]["app"]["config"], errors)
if "components" not in config["tool"]["flwr"]["app"]:
errors.append("Missing [tool.flwr.app.components] section")
else:
if "serverapp" not in config["tool"]["flwr"]["app"]["components"]:
errors.append(
'Property "serverapp" missing in [tool.flwr.app.components]'
)
if "clientapp" not in config["tool"]["flwr"]["app"]["components"]:
errors.append(
'Property "clientapp" missing in [tool.flwr.app.components]'
)

return len(errors) == 0, errors, warnings


def validate(
config: dict[str, Any],
check_module: bool = True,
project_dir: Optional[Union[str, Path]] = None,
) -> tuple[bool, list[str], list[str]]:
"""Validate pyproject.toml."""
is_valid, errors, warnings = validate_fields(config)

if not is_valid:
return False, errors, warnings

# Validate serverapp
serverapp_ref = config["tool"]["flwr"]["app"]["components"]["serverapp"]
is_valid, reason = object_ref.validate(serverapp_ref, check_module, project_dir)

if not is_valid and isinstance(reason, str):
return False, [reason], []

# Validate clientapp
clientapp_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
is_valid, reason = object_ref.validate(clientapp_ref, check_module, project_dir)

if not is_valid and isinstance(reason, str):
return False, [reason], []

return True, [], []


def load_from_string(toml_content: str) -> Optional[dict[str, Any]]:
"""Load TOML content from a string and return as dict."""
try:
data = tomli.loads(toml_content)
return data
except tomli.TOMLDecodeError:
return None
with toml_path.open("rb") as toml_file:
try:
return tomli.load(toml_file)
except tomli.TOMLDecodeError:
return None


def process_loaded_project_config(
Expand Down
180 changes: 0 additions & 180 deletions src/py/flwr/cli/config_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@
from .config_utils import (
load,
process_loaded_project_config,
validate,
validate_certificate_in_federation_config,
validate_federation_in_project_config,
validate_fields,
)


Expand Down Expand Up @@ -163,184 +161,6 @@ def test_load_pyproject_toml_from_path(tmp_path: Path) -> None:
os.chdir(origin)


def test_validate_pyproject_toml_fields_empty() -> None:
"""Test that validate_pyproject_toml_fields fails correctly."""
# Prepare
config: dict[str, Any] = {}

# Execute
is_valid, errors, warnings = validate_fields(config)

# Assert
assert not is_valid
assert len(errors) == 2
assert len(warnings) == 0


def test_validate_pyproject_toml_fields_no_flower() -> None:
"""Test that validate_pyproject_toml_fields fails correctly."""
# Prepare
config = {
"project": {
"name": "fedgpt",
"version": "1.0.0",
"description": "",
"license": "",
"authors": [],
}
}

# Execute
is_valid, errors, warnings = validate_fields(config)

# Assert
assert not is_valid
assert len(errors) == 1
assert len(warnings) == 0


def test_validate_pyproject_toml_fields_no_flower_components() -> None:
"""Test that validate_pyproject_toml_fields fails correctly."""
# Prepare
config = {
"project": {
"name": "fedgpt",
"version": "1.0.0",
"description": "",
"license": "",
"authors": [],
},
"tool": {"flwr": {"app": {}}},
}

# Execute
is_valid, errors, warnings = validate_fields(config)

# Assert
assert not is_valid
assert len(errors) == 2
assert len(warnings) == 0


def test_validate_pyproject_toml_fields_no_server_and_client_app() -> None:
"""Test that validate_pyproject_toml_fields fails correctly."""
# Prepare
config = {
"project": {
"name": "fedgpt",
"version": "1.0.0",
"description": "",
"license": "",
"authors": [],
},
"tool": {"flwr": {"app": {"components": {}}}},
}

# Execute
is_valid, errors, warnings = validate_fields(config)

# Assert
assert not is_valid
assert len(errors) == 3
assert len(warnings) == 0


def test_validate_pyproject_toml_fields() -> None:
"""Test that validate_pyproject_toml_fields succeeds correctly."""
# Prepare
config = {
"project": {
"name": "fedgpt",
"version": "1.0.0",
"description": "",
"license": "",
"authors": [],
},
"tool": {
"flwr": {
"app": {
"publisher": "flwrlabs",
"components": {"serverapp": "", "clientapp": ""},
},
},
},
}

# Execute
is_valid, errors, warnings = validate_fields(config)

# Assert
assert is_valid
assert len(errors) == 0
assert len(warnings) == 0


def test_validate_pyproject_toml() -> None:
"""Test that validate_pyproject_toml succeeds correctly."""
# Prepare
config = {
"project": {
"name": "fedgpt",
"version": "1.0.0",
"description": "",
"license": "",
"authors": [],
},
"tool": {
"flwr": {
"app": {
"publisher": "flwrlabs",
"components": {
"serverapp": "flwr.cli.run:run",
"clientapp": "flwr.cli.run:run",
},
},
},
},
}

# Execute
is_valid, errors, warnings = validate(config)

# Assert
assert is_valid
assert not errors
assert not warnings


def test_validate_pyproject_toml_fail() -> None:
"""Test that validate_pyproject_toml fails correctly."""
# Prepare
config = {
"project": {
"name": "fedgpt",
"version": "1.0.0",
"description": "",
"license": "",
"authors": [],
},
"tool": {
"flwr": {
"app": {
"publisher": "flwrlabs",
"components": {
"serverapp": "flwr.cli.run:run",
"clientapp": "flwr.cli.run:runa",
},
},
},
},
}

# Execute
is_valid, errors, warnings = validate(config)

# Assert
assert not is_valid
assert len(errors) == 1
assert len(warnings) == 0


def test_validate_project_config_fail() -> None:
"""Test that validate_project_config fails correctly."""
# Prepare
Expand Down
Loading