From fc97c56defb6057915a4649135aa7b07b57c81d5 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Wed, 15 Jan 2025 16:41:33 -0500 Subject: [PATCH] feat: add support for multi-file uploads to resource import workflow --- src/dioptra/client/workflows.py | 47 ++++++++++++++----- .../restapi/v1/workflows/controller.py | 1 + src/dioptra/restapi/v1/workflows/schema.py | 26 +++++++--- src/dioptra/restapi/v1/workflows/service.py | 32 +++++++++++-- tests/unit/restapi/v1/conftest.py | 22 +++++++-- .../v1/test_workflow_resource_import.py | 15 +++++- 6 files changed, 113 insertions(+), 30 deletions(-) diff --git a/src/dioptra/client/workflows.py b/src/dioptra/client/workflows.py index 86a869145..f31f3ce46 100644 --- a/src/dioptra/client/workflows.py +++ b/src/dioptra/client/workflows.py @@ -113,11 +113,22 @@ def import_resources( """Signature for using import_resource from archive file""" ... # pragma: nocover + @overload + def import_resources( + self, + files: list[DioptraFile], + config_path: str | None = "dioptra.toml", + resolve_name_conflicts_strategy: Literal["fail", "overwrite"] | None = "fail", + ) -> DioptraResponseProtocol: + """Signature for using import_resource from archive file""" + ... # pragma: nocover + def import_resources( self, group_id, git_url=None, archive_file=None, + files=None, config_path="dioptra.toml", resolve_name_conflicts_strategy="fail", ): @@ -126,36 +137,46 @@ def import_resources( Args: group_id: The group to import resources into - source_type: The source to import from (either "upload" or "git") + source_type: The source to import from git_url: The url to the git repository if source_type is "git" - archive_file: The contents of the upload if source_type is "upload" + archive_file: The contents of the upload if source_type is "upload_archive" + files: The contents of the upload if source_type is "upload_files" config_path: The path to the toml configuration file in the import source. resolve_name_conflicts_strategy: The strategy for resolving name conflicts. Either "fail" or "overwrite" Raises: - IllegalArgumentError: If only one of archive_file + IllegalArgumentError: If more than one import source is provided or if no + import source is provided. """ - if archive_file is None and git_url is None: + import_source_args = [git_url, archive_file, files] + num_provided_import_source_args = sum( + arg is not None for arg in import_source_args + ) + + if num_provided_import_source_args == 0: raise IllegalArgumentError( - "One of 'archive_file' and 'git_url' must be provided" + "One of (git_url, archive_file, or files) must be provided" ) - - if archive_file is not None and git_url is not None: + elif num_provided_import_source_args > 1: raise IllegalArgumentError( - "Only one of 'archive_file' and 'git_url' can be provided" + "Only one of (git_url, archive_file and files) can be provided" ) - data: dict[str, Any] = {"group": group_id} - files: dict[str, DioptraFile | list[DioptraFile]] = {} + data: dict[str, Any] = {"group": str(group_id)} + files_: dict[str, DioptraFile | list[DioptraFile]] = {} if git_url is not None: data["sourceType"] = "git" data["gitUrl"] = git_url if archive_file is not None: - data["sourceType"] = "upload" - files["archiveFile"] = archive_file + data["sourceType"] = "upload_archive" + files_["archiveFile"] = archive_file + + if files is not None: + data["sourceType"] = "upload_files" + files_["files"] = files if config_path is not None: data["configPath"] = config_path @@ -164,5 +185,5 @@ def import_resources( data["resolveNameConflictsStrategy"] = resolve_name_conflicts_strategy return self._session.post( - self.url, RESOURCE_IMPORT, data=data, files=files or None + self.url, RESOURCE_IMPORT, data=data, files=files_ or None ) diff --git a/src/dioptra/restapi/v1/workflows/controller.py b/src/dioptra/restapi/v1/workflows/controller.py index fd5cd9b6f..e469b6e23 100644 --- a/src/dioptra/restapi/v1/workflows/controller.py +++ b/src/dioptra/restapi/v1/workflows/controller.py @@ -124,6 +124,7 @@ def post(self): source_type=parsed_form["source_type"], git_url=parsed_form.get("git_url", None), archive_file=request.files.get("archiveFile", None), + files=request.files.getlist("files", None), config_path=parsed_form["config_path"], resolve_name_conflicts_strategy=parsed_form[ "resolve_name_conflicts_strategy" diff --git a/src/dioptra/restapi/v1/workflows/schema.py b/src/dioptra/restapi/v1/workflows/schema.py index 6103e6123..4b7fcda5e 100644 --- a/src/dioptra/restapi/v1/workflows/schema.py +++ b/src/dioptra/restapi/v1/workflows/schema.py @@ -19,7 +19,7 @@ from marshmallow import Schema, ValidationError, fields, validates_schema -from dioptra.restapi.custom_schema_fields import FileUpload +from dioptra.restapi.custom_schema_fields import FileUpload, MultiFileUpload class FileTypes(Enum): @@ -47,7 +47,8 @@ class JobFilesDownloadQueryParametersSchema(Schema): class ResourceImportSourceTypes(Enum): GIT = "git" - UPLOAD = "upload" + UPLOAD_ARCHIVE = "upload_archive" + UPLOAD_FILES = "upload_files" class ResourceImportResolveNameConflictsStrategy(Enum): @@ -69,14 +70,17 @@ class ResourceImportSchema(Schema): sourceType = fields.Enum( ResourceImportSourceTypes, attribute="source_type", - metadata=dict(description="The source of the resources to import."), + metadata=dict( + description="The source of the resources to import" + "('upload_archive', 'upload_files', or 'git'." + ), by_value=True, required=True, ) gitUrl = fields.String( attribute="git_url", metadata=dict( - description="The URL of the git repository containing resources to import. " + description="The URL of the git repository containing resources to import." "A git branch can optionally be specified by appending #BRANCH_NAME. " "Used when sourceType is 'git'." ), @@ -87,8 +91,18 @@ class ResourceImportSchema(Schema): metadata=dict( type="file", format="binary", - description="The archive file containing resources to import (.tar.gz). " - "Used when sourceType is 'upload'.", + description="The archive file containing resources to import (.tar.gz)." + "Used when sourceType is 'upload_archive'.", + ), + required=False, + ) + files = MultiFileUpload( + attribute="files", + metadata=dict( + type="file", + format="binary", + description="The files containing the resources to import." + "Used when sourceType is 'upload_files'.", ), required=False, ) diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index ab9012260..8ea0182b7 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -162,6 +162,7 @@ def import_resources( source_type: str, git_url: str | None, archive_file: FileStorage | None, + files: list[FileStorage] | None, config_path: str, resolve_name_conflicts_strategy: str, **kwargs, @@ -172,7 +173,8 @@ def import_resources( group_id: The group to import resources into source_type: The source to import from (either "upload" or "git") git_url: The url to the git repository if source_type is "git" - archive_file: The contents of the upload if source_type is "upload" + archive_file: The contents of the upload if source_type is "upload_archive" + files: The contents of the upload if source_type is "upload_files" config_path: The path to the toml configuration file in the import source. resolve_name_conflicts_strategy: The strategy for resolving name conflicts. Either "fail" or "overwrite" @@ -191,7 +193,7 @@ def import_resources( with TemporaryDirectory() as tmp_dir, set_cwd(tmp_dir): working_dir = Path(tmp_dir) - if source_type == ResourceImportSourceTypes.UPLOAD: + if source_type == ResourceImportSourceTypes.UPLOAD_ARCHIVE: bytes = archive_file.stream.read() try: with tarfile.open(fileobj=BytesIO(bytes), mode="r:*") as tar: @@ -199,11 +201,20 @@ def import_resources( except Exception as e: raise ImportFailedError("Failed to read uploaded tarfile") from e hash = str(sha256(bytes).hexdigest()) + elif source_type == ResourceImportSourceTypes.UPLOAD_FILES: + hashes = b"" + for file in files: + Path(file.filename).parent.mkdir(parents=True, exist_ok=True) + bytes = file.stream.read() + with open(working_dir / file.filename, "wb") as f: + f.write(bytes) + hashes = hashes + sha256(bytes).digest() + hash = str(sha256(hashes).hexdigest()) else: try: hash = clone_git_repository(cast(str, git_url), working_dir) except Exception as e: - raise GitError("Failed to clone repository: {git_url}") from e + raise GitError(f"Failed to clone repository: {git_url}") from e try: config = toml.load(working_dir / config_path) @@ -345,7 +356,12 @@ def _register_plugins( tasks = self._build_tasks(plugin.get("tasks", []), param_types) for plugin_file_path in Path(plugin["path"]).rglob("*.py"): filename = str(plugin_file_path.relative_to(plugin["path"])) - contents = plugin_file_path.read_text() + try: + contents = plugin_file_path.read_text() + except FileNotFoundError as e: + raise ImportFailedError( + f"Failed to read plugin file from {plugin_file_path}" + ) from e self._plugin_id_file_service.create( filename, @@ -394,7 +410,13 @@ def _register_entrypoints( entrypoint_id=existing.resource_id ) - contents = Path(entrypoint["path"]).read_text() + try: + contents = Path(entrypoint["path"]).read_text() + except FileNotFoundError as e: + raise ImportFailedError( + f"Failed to read plugin file from {entrypoint['path']}" + ) from e + params = [ { "name": param["name"], diff --git a/tests/unit/restapi/v1/conftest.py b/tests/unit/restapi/v1/conftest.py index 2e7f24f94..f1b09d941 100644 --- a/tests/unit/restapi/v1/conftest.py +++ b/tests/unit/restapi/v1/conftest.py @@ -15,6 +15,7 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode """Fixtures representing resources needed for test suites""" +import os import tarfile import textwrap from collections.abc import Iterator @@ -32,7 +33,11 @@ from injector import Injector from pytest import MonkeyPatch -from dioptra.client import DioptraFile, select_one_or_more_files +from dioptra.client import ( + DioptraFile, + select_files_in_directory, + select_one_or_more_files, +) from ..lib import actions, mock_rq @@ -730,13 +735,13 @@ def registered_mlflowrun_incomplete( @pytest.fixture def resources_tar_file() -> DioptraFile: - root_dir = Path(__file__).absolute().parent / "resource_import_files" + os.chdir(Path(__file__).absolute().parent / "resource_import_files") f = NamedTemporaryFile(suffix=".tar.gz") with tarfile.open(fileobj=f, mode="w:gz") as tar: - tar.add(root_dir / "dioptra.toml", arcname="dioptra.toml") - tar.add(root_dir / "hello_world", arcname="plugins/hello_world", recursive=True) - tar.add(root_dir / "hello-world.yaml", arcname="examples/hello-world.yaml") + tar.add("dioptra.toml") + tar.add("plugins", recursive=True) + tar.add("examples/hello-world.yaml") f.seek(0) yield select_one_or_more_files([f.name])[0] @@ -744,6 +749,13 @@ def resources_tar_file() -> DioptraFile: f.close() +@pytest.fixture +def resources_files() -> DioptraFile: + os.chdir(Path(__file__).absolute().parent / "resource_import_files") + + return select_files_in_directory(".", recursive=True) + + @pytest.fixture def resources_import_config() -> dict[str, Any]: root_dir = Path(__file__).absolute().parent / "resource_import_files" diff --git a/tests/unit/restapi/v1/test_workflow_resource_import.py b/tests/unit/restapi/v1/test_workflow_resource_import.py index 82437c4fd..f997c3d2c 100644 --- a/tests/unit/restapi/v1/test_workflow_resource_import.py +++ b/tests/unit/restapi/v1/test_workflow_resource_import.py @@ -90,7 +90,7 @@ def assert_resource_import_overwrite_works( # -- Tests ----------------------------------------------------------------------------- -def test_resource_import( +def test_resource_import_from_archive_file( dioptra_client: DioptraClient[DioptraResponseProtocol], db: SQLAlchemy, auth_account: dict[str, Any], @@ -103,6 +103,19 @@ def test_resource_import( assert_imported_resources_match_expected(dioptra_client, resources_import_config) +def test_resource_import_from_files( + dioptra_client: DioptraClient[DioptraResponseProtocol], + db: SQLAlchemy, + auth_account: dict[str, Any], + resources_files: list[DioptraFile], + resources_import_config: dict[str, Any], +): + group_id = auth_account["groups"][0]["id"] + dioptra_client.workflows.import_resources(group_id, files=resources_files) + + assert_imported_resources_match_expected(dioptra_client, resources_import_config) + + def test_resource_import_fails_from_name_clash( dioptra_client: DioptraClient[DioptraResponseProtocol], db: SQLAlchemy,