diff --git a/src/ansys/simai/core/api/training_data.py b/src/ansys/simai/core/api/training_data.py index 520ac7ea..afca5f4a 100644 --- a/src/ansys/simai/core/api/training_data.py +++ b/src/ansys/simai/core/api/training_data.py @@ -65,3 +65,8 @@ def compute_training_data(self, training_data_id: str) -> None: def get_training_data_subset(self, project_id: str, training_data_id: str) -> Dict[str, Any]: return self._get(f"projects/{project_id}/data/{training_data_id}/subset") + + def put_training_data_subset(self, project_id: str, training_data_id: str, subset: str) -> None: + return self._put( + f"projects/{project_id}/data/{training_data_id}/subset", json={"subset": subset} + ) diff --git a/src/ansys/simai/core/data/training_data.py b/src/ansys/simai/core/data/training_data.py index d9e76578..a9cb8291 100644 --- a/src/ansys/simai/core/data/training_data.py +++ b/src/ansys/simai/core/data/training_data.py @@ -29,6 +29,7 @@ MonitorCallback, NamedFile, Path, + SubsetEnum, get_id_from_identifiable, unpack_named_file, ) @@ -75,7 +76,7 @@ def parts(self) -> List["TrainingDataPart"]: for training_data_part in self.fields["parts"] ] - def get_subset(self, project: Identifiable["Project"]) -> Optional[str]: + def get_subset(self, project: Identifiable["Project"]) -> Optional[SubsetEnum]: """Get the subset that the training data belongs to, in relation to the given project. Args: @@ -83,10 +84,28 @@ def get_subset(self, project: Identifiable["Project"]) -> Optional[str]: the :class:`~.projects.Project` object for, or its ID. Returns: - Name of the subset that the training data belongs to in the given project. + SubsetEnum of the subset that the training data belongs to in the given project. + (e.g. ) """ project_id = get_id_from_identifiable(project, default=self._client._current_project) - return self._client._api.get_training_data_subset(project_id, self.id).get("subset") + subset_value = self._client._api.get_training_data_subset(project_id, self.id).get("subset") + return SubsetEnum(subset_value) if subset_value else None + + def assign_subset(self, project: Identifiable["Project"], subset: SubsetEnum) -> None: + """Assign the training data subset in relation to a given project. + + Args: + project: ID or :class:`model <.projects.Project>` + subset: SubsetEnum attribute (e.g. SubsetEnum.TRAINING) or string value (e.g. "Training"). + Available options: (Training, Validation, Test, Ignored) + + Returns: + None + """ + if subset not in SubsetEnum.__members__.values(): + raise InvalidArguments("Must be one of: Ignored, Training, Test, Validation.") + project_id = get_id_from_identifiable(project, default=self._client._current_project) + self._client._api.put_training_data_subset(project_id, self.id, subset) @property def extracted_metadata(self) -> Optional[Dict]: diff --git a/src/ansys/simai/core/data/types.py b/src/ansys/simai/core/data/types.py index 088ec4b8..1feb7830 100644 --- a/src/ansys/simai/core/data/types.py +++ b/src/ansys/simai/core/data/types.py @@ -24,6 +24,7 @@ import os import pathlib from contextlib import contextmanager +from enum import Enum from numbers import Number from typing import Any, BinaryIO, Callable, Dict, Generator, List, Optional, Tuple, Union @@ -262,3 +263,10 @@ def get_object_from_identifiable( return get_object_from_identifiable(default, directory) else: raise InvalidArguments(f"Argument {identifiable} is neither a data model nor an ID string.") + + +class SubsetEnum(str, Enum): + IGNORED = "Ignored" + TRAINING = "Training" + VALIDATION = "Validation" + TEST = "Test" diff --git a/tests/test_training_data.py b/tests/test_training_data.py index 5c723664..3c41b918 100644 --- a/tests/test_training_data.py +++ b/tests/test_training_data.py @@ -26,6 +26,9 @@ import pytest import responses +from ansys.simai.core.data.types import SubsetEnum +from ansys.simai.core.errors import InvalidArguments + if TYPE_CHECKING: from ansys.simai.core.data.training_data import TrainingData @@ -84,9 +87,11 @@ def test_training_data_remove_from_project(simai_client, training_data_factory, @pytest.mark.parametrize( "td_factory_args", [ - ({"id": "777", "name": "ICBM", "subset": "Training"}), - ({"id": "888", "name": "Duke Nukem", "subset": ""}), + ({"id": "777", "name": "ICBM", "subset": SubsetEnum.TRAINING}), + ({"id": "888", "name": "Duke Nukem", "subset": "Validation"}), ({"id": "999", "name": "Roman"}), + ({"id": "81", "name": "Diablo", "subset": None}), + ({"id": "9191", "name": "Deckard", "subset": "Ignored"}), ], ) @responses.activate @@ -103,3 +108,37 @@ def test_get_subset(training_data_factory, project_factory, td_factory_args): json={"subset": td_subset}, ) assert td.get_subset(project=project) == td_subset + + +@responses.activate +def test_get_subset_fails_enum_check(training_data_factory, project_factory): + project = project_factory(id="bon5ai", name="coolest_proj") + td: TrainingData = training_data_factory(project=project, id="415") + td_subset = "Trainidation" + responses.add( + responses.GET, + f"https://test.test/projects/{project.id}/data/{td.id}/subset", + status=200, + json={"subset": td_subset}, + ) + with pytest.raises(ValueError) as e: + td.get_subset(project=project) + assert str(e.value) == f"'{td_subset}' is not a valid SubsetEnum" + + +@responses.activate +def test_assign_subset(training_data_factory, project_factory): + project = project_factory(id="n07e45y", name="bananarama") + td: TrainingData = training_data_factory(project=project, subset=SubsetEnum.TRAINING) + + responses.add( + responses.PUT, + f"https://test.test/projects/{project.id}/data/{td.id}/subset", + status=200, + json={"subset": SubsetEnum.VALIDATION}, + ) + td.assign_subset(project=project, subset=SubsetEnum.VALIDATION) + td.assign_subset(project=project, subset="Validation") + with pytest.raises(InvalidArguments) as ve: + td.assign_subset(project=project, subset="Travalignorinestidation") + assert str(ve.value) == "Must be one of: Ignored, Training, Test, Validation."