Skip to content

Commit

Permalink
feat: check project is trainable before build
Browse files Browse the repository at this point in the history
Also fix an issue with translating a model configuration payload into ModelConfiguration object
  • Loading branch information
tmpbeing committed Dec 2, 2024
1 parent bd2b409 commit 180f159
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 27 deletions.
29 changes: 20 additions & 9 deletions src/ansys/simai/core/data/model_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# SOFTWARE.

from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Literal, Optional
from typing import TYPE_CHECKING, Any, List, Literal, Optional

from ansys.simai.core.errors import InvalidArguments, ProcessingError
from ansys.simai.core.utils.misc import dict_get
Expand Down Expand Up @@ -223,7 +223,7 @@ class ModelConfiguration:
| *2_days*: < 2 days, default value.
| *7_days*: < 1 week
continuous: indicates if continuous learning is enabled. Default is False.
build_on_top: indicates if build_on_top learning is enabled. Default is False.
input: the inputs of the model.
output: the outputs of the model.
global_coefficients: the Global Coefficients of the model.
Expand Down Expand Up @@ -275,20 +275,20 @@ class ModelConfiguration:
new_conf = ModelConfiguration(
project=aero_dyn_project,
build_preset="debug",
continuous=False,
build_on_top=False,
input=model_input,
output=model_output,
global_coefficients=global_coefficients,
domain_of_analysis=doa,
pp_input=pp_input,
)
# Launch a mode build with the new configuration
# Launch a model build with the new configuration
new_model = simai.models.build(new_conf)
"""

project: "Optional[Project]" = None
continuous: bool = False
build_on_top: bool = False
input: ModelInput = field(default_factory=lambda: ModelInput())
output: ModelOutput = field(default_factory=lambda: ModelOutput())
domain_of_analysis: DomainOfAnalysis = field(default_factory=lambda: DomainOfAnalysis())
Expand Down Expand Up @@ -335,7 +335,7 @@ def __init__(
project: "Project",
boundary_conditions: Optional[dict[str, Any]] = None,
build_preset: Optional[str] = "debug",
continuous: bool = False,
build_on_top: bool = False,
fields: Optional[dict[str, Any]] = None,
global_coefficients: Optional[list[GlobalCoefficientDefinition]] = None,
simulation_volume: Optional[dict[str, Any]] = None,
Expand All @@ -358,7 +358,7 @@ def __init__(
if boundary_conditions is not None and self.input.boundary_conditions is None:
self.input.boundary_conditions = list(boundary_conditions.keys())
self.build_preset = build_preset
self.continuous = continuous
self.build_on_top = build_on_top
if fields is not None:
if fields.get("surface_input"):
self.input.surface = [fd.get("name") for fd in fields["surface_input"]]
Expand Down Expand Up @@ -464,13 +464,24 @@ def _to_payload(self):
return {
"boundary_conditions": bcs,
"build_preset": SupportedBuildPresets[self.build_preset],
"continuous": self.continuous,
"continuous": self.build_on_top,
"fields": flds,
"global_coefficients": gcs,
"simulation_volume": simulation_volume,
}

def compute_global_coefficient(self):
@classmethod
def _from_payload(cls, **kwargs) -> "ModelConfiguration":
# Retrieve SDK version of build preset from API version of build preset
build_preset = next(
(k for k, v in SupportedBuildPresets.items() if v == kwargs.get("build_preset")), None
)
kwargs["build_preset"] = build_preset
if build_on_top := kwargs.pop("continuous", None):
kwargs["build_on_top"] = build_on_top

Check warning on line 481 in src/ansys/simai/core/data/model_configuration.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/simai/core/data/model_configuration.py#L481

Added line #L481 was not covered by tests
return ModelConfiguration(**kwargs)

def compute_global_coefficient(self) -> List[float]:
"""Computes the results of the formula for all global coefficients with respect to the project's sample."""

if self.project is None:
Expand Down
9 changes: 8 additions & 1 deletion src/ansys/simai/core/data/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from ansys.simai.core.data.base import ComputableDataModel, Directory
from ansys.simai.core.data.model_configuration import ModelConfiguration
from ansys.simai.core.errors import InvalidArguments


class Model(ComputableDataModel):
Expand All @@ -39,7 +40,7 @@ def project_id(self) -> str:
@property
def configuration(self) -> ModelConfiguration:
"""Build configuration of a model."""
return ModelConfiguration(
return ModelConfiguration._from_payload(
project=self._client.projects.get(self.fields["project_id"]),
**self.fields["configuration"],
)
Expand Down Expand Up @@ -99,6 +100,12 @@ def build(
new_model = simai.models.build(build_conf)
"""
if not configuration.project:
raise InvalidArguments("The model configuration does not have a project set")

Check warning on line 104 in src/ansys/simai/core/data/models.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/simai/core/data/models.py#L104

Added line #L104 was not covered by tests

is_trainable = configuration.project.is_trainable()
if not is_trainable:
raise InvalidArguments(f"Cannot train model because: {is_trainable.reason}")

Check warning on line 108 in src/ansys/simai/core/data/models.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/simai/core/data/models.py#L108

Added line #L108 was not covered by tests

return self._model_from(
self._client._api.launch_build(
Expand Down
19 changes: 11 additions & 8 deletions src/ansys/simai/core/data/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ansys.simai.core.data.model_configuration import ModelConfiguration
from ansys.simai.core.data.types import Identifiable, get_id_from_identifiable
from ansys.simai.core.errors import InvalidArguments, ProcessingError
from ansys.simai.core.utils.numerical import cast_values_to_float

if TYPE_CHECKING:
from ansys.simai.core.data.global_coefficients_requests import (
Expand Down Expand Up @@ -124,17 +125,19 @@ def sample(self, new_sample: Identifiable["TrainingData"]):
@property
def last_model_configuration(self) -> ModelConfiguration:
"""Last :class:`configuration <ansys.simai.core.data.model_configuration.ModelConfiguration>` used for model training in this project."""
return ModelConfiguration(project=self, **self.fields.get("last_model_configuration"))

def delete(self) -> None:
"""Delete the project."""
self._client._api.delete_project(self.id)
return ModelConfiguration._from_payload(
project=self, **self.fields.get("last_model_configuration")
)

def is_trainable(self) -> bool:
def is_trainable(self) -> IsTrainableInfo:
"""Check if the project meets the prerequisites to be trained."""
tt = self._client._api.is_project_trainable(self.id)
return IsTrainableInfo(**tt)

def delete(self) -> None:
"""Delete the project."""
self._client._api.delete_project(self.id)

Check warning on line 139 in src/ansys/simai/core/data/projects.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/simai/core/data/projects.py#L139

Added line #L139 was not covered by tests

def get_variables(self) -> Optional[Dict[str, List[str]]]:
"""Get the available variables for the model's input/output."""
if not self.sample:
Expand Down Expand Up @@ -177,7 +180,7 @@ def verify_gc_formula(

def compute_gc_formula(
self, gc_formula: str, bc: list[str] = None, surface_variables: list[str] = None
):
) -> float:
"""Compute the result of a global coefficient formula according to the project sample."""

if not self.sample:
Expand All @@ -203,7 +206,7 @@ def compute_gc_formula(
gc_compute.run()
gc_compute.wait()

return gc_compute.result if gc_compute.is_ready else None
return cast_values_to_float(gc_compute.result["value"])

Check warning on line 209 in src/ansys/simai/core/data/projects.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/simai/core/data/projects.py#L209

Added line #L209 was not covered by tests

def cancel_build(self):
"""Cancels a build if there is one pending."""
Expand Down
42 changes: 33 additions & 9 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,18 @@ def test_build_with_last_config(simai_client):
json=raw_project,
status=200,
)
responses.add(
responses.GET,
f"https://test.test/projects/{MODEL_RAW['project_id']}/trainable",
json={"is_trainable": True},
status=200,
)

project: Project = simai_client._project_directory._model_from(raw_project)

project.verify_gc_formula = Mock()

in_model_conf = ModelConfiguration(project=project, **MODEL_CONF_RAW)
in_model_conf = ModelConfiguration._from_payload(project=project, **MODEL_CONF_RAW)

launched_model: Model = simai_client.models.build(in_model_conf)

Expand Down Expand Up @@ -261,6 +267,12 @@ def test_build_with_new_config(simai_client):
json=raw_project,
status=200,
)
responses.add(
responses.GET,
f"https://test.test/projects/{MODEL_RAW['project_id']}/trainable",
json={"is_trainable": True},
status=200,
)

project: Project = simai_client._project_directory._model_from(raw_project)

Expand All @@ -280,7 +292,7 @@ def test_build_with_new_config(simai_client):
new_conf = ModelConfiguration(
project=project,
build_preset="debug",
continuous=False,
build_on_top=False,
input=model_input,
output=model_output,
global_coefficients=global_coefficients,
Expand Down Expand Up @@ -317,7 +329,7 @@ def test_set_doa(simai_client):

project.verify_gc_formula = Mock()

model_conf = ModelConfiguration(project=project, **MODEL_CONF_RAW)
model_conf = ModelConfiguration._from_payload(project=project, **MODEL_CONF_RAW)

new_height = {"position": "relative_to_center", "value": 0.5, "length": 15.2}

Expand All @@ -344,7 +356,7 @@ def test_get_doa(simai_client):

project.verify_gc_formula = Mock()

model_conf = ModelConfiguration(project=project, **MODEL_CONF_RAW)
model_conf = ModelConfiguration._from_payload(project=project, **MODEL_CONF_RAW)

doa_length_raw = MODEL_CONF_RAW.get("simulation_volume").get("X")

Expand Down Expand Up @@ -436,7 +448,7 @@ def test_exception_compute_global_coefficient(simai_client):

project.verify_gc_formula = Mock()

model_conf = ModelConfiguration(project=project, **MODEL_CONF_RAW)
model_conf = ModelConfiguration._from_payload(project=project, **MODEL_CONF_RAW)

model_conf.project = None

Expand All @@ -449,7 +461,7 @@ def test_exception_setting_global_coefficient():
THEN an error is raise."""

with pytest.raises(ProcessingError):
ModelConfiguration(project=None, **MODEL_CONF_RAW)
ModelConfiguration._from_payload(project=None, **MODEL_CONF_RAW)


def test_sse_event_handler(simai_client, model_factory):
Expand Down Expand Up @@ -508,7 +520,7 @@ def test_throw_error_when_volume_is_missing_from_sample(simai_client):
model_output = ModelOutput(surface=[], volume=["Velocity_0"])
global_coefficients = []

model_conf = ModelConfiguration(
model_conf = ModelConfiguration._from_payload(
project=project,
build_preset="debug",
continuous=False,
Expand Down Expand Up @@ -539,6 +551,12 @@ def test_post_process_input(simai_client):
json=raw_project,
status=200,
)
responses.add(
responses.GET,
f"https://test.test/projects/{MODEL_RAW['project_id']}/trainable",
json={"is_trainable": True},
status=200,
)

project: Project = simai_client._project_directory._model_from(raw_project)
project.verify_gc_formula = Mock()
Expand All @@ -559,7 +577,7 @@ def test_post_process_input(simai_client):
model_request = deepcopy(MODEL_RAW)
model_request["configuration"] = model_conf_dict

config_with_pp_input = ModelConfiguration(
config_with_pp_input = ModelConfiguration._from_payload(
project=project,
**model_conf_dict,
pp_input=pp_input,
Expand Down Expand Up @@ -606,6 +624,12 @@ def test_failed_build_with_resolution(simai_client):
json=raw_project,
status=200,
)
responses.add(
responses.GET,
f"https://test.test/projects/{MODEL_RAW['project_id']}/trainable",
json={"is_trainable": True},
status=200,
)

project: Project = simai_client._project_directory._model_from(raw_project)

Expand All @@ -619,7 +643,7 @@ def test_failed_build_with_resolution(simai_client):
height=hght,
)

new_conf = ModelConfiguration(
new_conf = ModelConfiguration._from_payload(
project=project,
build_preset="debug",
domain_of_analysis=doa,
Expand Down

0 comments on commit 180f159

Please sign in to comment.