Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow user to set Train/Validation/Test subset on training data #42

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/ansys/simai/core/api/training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from typing import Any, Dict, Iterable, Optional

from ansys.simai.core.api.mixin import ApiClientMixin
from ansys.simai.core.data.types import SubsetEnum
from ansys.simai.core.errors import InvalidArguments


class TrainingDataClientMixin(ApiClientMixin):
Expand Down Expand Up @@ -65,3 +67,12 @@ def compute_training_data(self, training_data_id: str) -> None:

def get_training_data_subset(self, project_id: str, training_data_id: str) -> Dict[str, Any]:
return self._get(f"projects/{project_id}/data/{training_data_id}/subset")

def put_training_data_subset(
self, project_id: str, training_data_id: str, subset: SubsetEnum
) -> None:
if subset not in SubsetEnum.__members__.values():
raise InvalidArguments("Must be one of: Ignored, Training, Test, Validation.")
kliment-slice marked this conversation as resolved.
Show resolved Hide resolved
return self._put(
f"projects/{project_id}/data/{training_data_id}/subset", json={"subset": subset}
)
23 changes: 20 additions & 3 deletions src/ansys/simai/core/data/training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
MonitorCallback,
NamedFile,
Path,
SubsetEnum,
get_id_from_identifiable,
unpack_named_file,
)
Expand Down Expand Up @@ -75,18 +76,34 @@ def parts(self) -> List["TrainingDataPart"]:
for training_data_part in self.fields["parts"]
]

def get_subset(self, project: Identifiable["Project"]) -> Optional[str]:
def get_subset(self, project: Identifiable["Project"]) -> SubsetEnum:
"""Get the subset that the training data belongs to, in relation to the given project.

Args:
project: ID or :class:`model <.projects.Project>` of the project to check
the :class:`~.projects.Project` object for, or its ID.

Returns:
Name of the subset that the training data belongs to in the given project.
SubsetEnum of the subset that the training data belongs to in the given project.
(e.g. <SubsetEnum.VALIDATION: 'Validation'>)
"""
project_id = get_id_from_identifiable(project, default=self._client._current_project)
return self._client._api.get_training_data_subset(project_id, self.id).get("subset")
subset_value = self._client._api.get_training_data_subset(project_id, self.id).get("subset")
return SubsetEnum(subset_value)
kliment-slice marked this conversation as resolved.
Show resolved Hide resolved

def assign_subset(self, project: Identifiable["Project"], subset: SubsetEnum) -> None:
"""Assign the training data subset in relation to a given project.

Args:
project: ID or :class:`model <.projects.Project>`
subset: SubsetEnum attribute (e.g. SubsetEnum.TRAINING) or string value (e.g. "Training").
Available options: (Training, Validation, Test, Ignored)

Returns:
None
"""
project_id = get_id_from_identifiable(project, default=self._client._current_project)
self._client._api.put_training_data_subset(project_id, self.id, subset)

@property
def extracted_metadata(self) -> Optional[Dict]:
Expand Down
8 changes: 8 additions & 0 deletions src/ansys/simai/core/data/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import os
import pathlib
from contextlib import contextmanager
from enum import Enum
from numbers import Number
from typing import Any, BinaryIO, Callable, Dict, Generator, List, Optional, Tuple, Union

Expand Down Expand Up @@ -262,3 +263,10 @@ def get_object_from_identifiable(
return get_object_from_identifiable(default, directory)
else:
raise InvalidArguments(f"Argument {identifiable} is neither a data model nor an ID string.")


class SubsetEnum(str, Enum):
kliment-slice marked this conversation as resolved.
Show resolved Hide resolved
kliment-slice marked this conversation as resolved.
Show resolved Hide resolved
kliment-slice marked this conversation as resolved.
Show resolved Hide resolved
IGNORED = "Ignored"
TRAINING = "Training"
VALIDATION = "Validation"
TEST = "Test"
42 changes: 39 additions & 3 deletions tests/test_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
import pytest
import responses

from ansys.simai.core.data.types import SubsetEnum
from ansys.simai.core.errors import InvalidArguments

if TYPE_CHECKING:
from ansys.simai.core.data.training_data import TrainingData

Expand Down Expand Up @@ -84,9 +87,8 @@ def test_training_data_remove_from_project(simai_client, training_data_factory,
@pytest.mark.parametrize(
"td_factory_args",
[
({"id": "777", "name": "ICBM", "subset": "Training"}),
({"id": "888", "name": "Duke Nukem", "subset": ""}),
({"id": "999", "name": "Roman"}),
({"id": "777", "name": "ICBM", "subset": SubsetEnum.TRAINING}),
({"id": "888", "name": "Duke Nukem", "subset": "Validation"}),
],
)
@responses.activate
Expand All @@ -103,3 +105,37 @@ def test_get_subset(training_data_factory, project_factory, td_factory_args):
json={"subset": td_subset},
)
assert td.get_subset(project=project) == td_subset


@responses.activate
def test_get_subset_fails_enum_check(training_data_factory, project_factory):
project = project_factory(id="bon5ai", name="coolest_proj")
td: TrainingData = training_data_factory(project=project, id="415")
td_subset = "Trainidation"
responses.add(
responses.GET,
f"https://test.test/projects/{project.id}/data/{td.id}/subset",
status=200,
json={"subset": td_subset},
)
with pytest.raises(ValueError) as e:
td.get_subset(project=project)
assert str(e.value) == f"'{td_subset}' is not a valid SubsetEnum"


@responses.activate
def test_assign_subset(training_data_factory, project_factory):
project = project_factory(id="n07e45y", name="bananarama")
td: TrainingData = training_data_factory(project=project, subset=SubsetEnum.TRAINING)

responses.add(
responses.PUT,
f"https://test.test/projects/{project.id}/data/{td.id}/subset",
status=200,
json={"subset": SubsetEnum.VALIDATION},
)
td.assign_subset(project=project, subset=SubsetEnum.VALIDATION)
td.assign_subset(project=project, subset="Validation")
with pytest.raises(InvalidArguments) as ve:
td.assign_subset(project=project, subset="Travalignorinestidation")
assert str(ve.value) == "Must be one of: Ignored, Training, Test, Validation."
Loading