Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: check project is trainable before build #113

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions src/ansys/simai/core/data/model_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# SOFTWARE.

from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Literal, Optional
from typing import TYPE_CHECKING, Any, List, Literal, Optional

from ansys.simai.core.errors import InvalidArguments, ProcessingError
from ansys.simai.core.utils.misc import dict_get
Expand Down Expand Up @@ -223,7 +223,7 @@
| *2_days*: < 2 days, default value.

| *7_days*: < 1 week
continuous: indicates if continuous learning is enabled. Default is False.
build_on_top: indicates if build_on_top learning is enabled. Default is False.
input: the inputs of the model.
output: the outputs of the model.
global_coefficients: the Global Coefficients of the model.
Expand Down Expand Up @@ -275,20 +275,20 @@
new_conf = ModelConfiguration(
project=aero_dyn_project,
build_preset="debug",
continuous=False,
build_on_top=False,
input=model_input,
output=model_output,
global_coefficients=global_coefficients,
domain_of_analysis=doa,
pp_input=pp_input,
)

# Launch a mode build with the new configuration
# Launch a model build with the new configuration
new_model = simai.models.build(new_conf)
"""

project: "Optional[Project]" = None
continuous: bool = False
build_on_top: bool = False
input: ModelInput = field(default_factory=lambda: ModelInput())
output: ModelOutput = field(default_factory=lambda: ModelOutput())
domain_of_analysis: DomainOfAnalysis = field(default_factory=lambda: DomainOfAnalysis())
Expand Down Expand Up @@ -335,7 +335,7 @@
project: "Project",
boundary_conditions: Optional[dict[str, Any]] = None,
build_preset: Optional[str] = "debug",
continuous: bool = False,
build_on_top: bool = False,
fields: Optional[dict[str, Any]] = None,
global_coefficients: Optional[list[GlobalCoefficientDefinition]] = None,
simulation_volume: Optional[dict[str, Any]] = None,
Expand All @@ -358,7 +358,7 @@
if boundary_conditions is not None and self.input.boundary_conditions is None:
self.input.boundary_conditions = list(boundary_conditions.keys())
self.build_preset = build_preset
self.continuous = continuous
self.build_on_top = build_on_top
if fields is not None:
if fields.get("surface_input"):
self.input.surface = [fd.get("name") for fd in fields["surface_input"]]
Expand Down Expand Up @@ -464,13 +464,24 @@
return {
"boundary_conditions": bcs,
"build_preset": SupportedBuildPresets[self.build_preset],
"continuous": self.continuous,
"continuous": self.build_on_top,
"fields": flds,
"global_coefficients": gcs,
"simulation_volume": simulation_volume,
}

def compute_global_coefficient(self):
@classmethod
def _from_payload(cls, **kwargs) -> "ModelConfiguration":
# Retrieve SDK version of build preset from API version of build preset
build_preset = next(
(k for k, v in SupportedBuildPresets.items() if v == kwargs.get("build_preset")), None
)
kwargs["build_preset"] = build_preset
if build_on_top := kwargs.pop("continuous", None):
kwargs["build_on_top"] = build_on_top

Check warning on line 481 in src/ansys/simai/core/data/model_configuration.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/simai/core/data/model_configuration.py#L481

Added line #L481 was not covered by tests
return ModelConfiguration(**kwargs)

def compute_global_coefficient(self) -> List[float]:
"""Computes the results of the formula for all global coefficients with respect to the project's sample."""

if self.project is None:
Expand Down
9 changes: 8 additions & 1 deletion src/ansys/simai/core/data/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from ansys.simai.core.data.base import ComputableDataModel, Directory
from ansys.simai.core.data.model_configuration import ModelConfiguration
from ansys.simai.core.errors import InvalidArguments


class Model(ComputableDataModel):
Expand All @@ -39,7 +40,7 @@
@property
def configuration(self) -> ModelConfiguration:
"""Build configuration of a model."""
return ModelConfiguration(
return ModelConfiguration._from_payload(
project=self._client.projects.get(self.fields["project_id"]),
**self.fields["configuration"],
)
Expand Down Expand Up @@ -99,6 +100,12 @@
new_model = simai.models.build(build_conf)

"""
if not configuration.project:
raise InvalidArguments("The model configuration does not have a project set")

Check warning on line 104 in src/ansys/simai/core/data/models.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/simai/core/data/models.py#L104

Added line #L104 was not covered by tests

is_trainable = configuration.project.is_trainable()
if not is_trainable:
raise InvalidArguments(f"Cannot train model because: {is_trainable.reason}")

Check warning on line 108 in src/ansys/simai/core/data/models.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/simai/core/data/models.py#L108

Added line #L108 was not covered by tests

return self._model_from(
self._client._api.launch_build(
Expand Down
19 changes: 11 additions & 8 deletions src/ansys/simai/core/data/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ansys.simai.core.data.model_configuration import ModelConfiguration
from ansys.simai.core.data.types import Identifiable, get_id_from_identifiable
from ansys.simai.core.errors import InvalidArguments, ProcessingError
from ansys.simai.core.utils.numerical import cast_values_to_float

if TYPE_CHECKING:
from ansys.simai.core.data.global_coefficients_requests import (
Expand Down Expand Up @@ -124,17 +125,19 @@
@property
def last_model_configuration(self) -> ModelConfiguration:
"""Last :class:`configuration <ansys.simai.core.data.model_configuration.ModelConfiguration>` used for model training in this project."""
return ModelConfiguration(project=self, **self.fields.get("last_model_configuration"))

def delete(self) -> None:
"""Delete the project."""
self._client._api.delete_project(self.id)
return ModelConfiguration._from_payload(
project=self, **self.fields.get("last_model_configuration")
)

def is_trainable(self) -> bool:
def is_trainable(self) -> IsTrainableInfo:
"""Check if the project meets the prerequisites to be trained."""
tt = self._client._api.is_project_trainable(self.id)
return IsTrainableInfo(**tt)

def delete(self) -> None:
"""Delete the project."""
self._client._api.delete_project(self.id)

Check warning on line 139 in src/ansys/simai/core/data/projects.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/simai/core/data/projects.py#L139

Added line #L139 was not covered by tests

def get_variables(self) -> Optional[Dict[str, List[str]]]:
"""Get the available variables for the model's input/output."""
if not self.sample:
Expand Down Expand Up @@ -177,7 +180,7 @@

def compute_gc_formula(
self, gc_formula: str, bc: list[str] = None, surface_variables: list[str] = None
):
) -> float:
"""Compute the result of a global coefficient formula according to the project sample."""

if not self.sample:
Expand All @@ -203,7 +206,7 @@
gc_compute.run()
gc_compute.wait()

return gc_compute.result if gc_compute.is_ready else None
return cast_values_to_float(gc_compute.result["value"])

Check warning on line 209 in src/ansys/simai/core/data/projects.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/simai/core/data/projects.py#L209

Added line #L209 was not covered by tests

def cancel_build(self):
"""Cancels a build if there is one pending."""
Expand Down
42 changes: 33 additions & 9 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,18 @@ def test_build_with_last_config(simai_client):
json=raw_project,
status=200,
)
responses.add(
responses.GET,
f"https://test.test/projects/{MODEL_RAW['project_id']}/trainable",
json={"is_trainable": True},
status=200,
)

project: Project = simai_client._project_directory._model_from(raw_project)

project.verify_gc_formula = Mock()

in_model_conf = ModelConfiguration(project=project, **MODEL_CONF_RAW)
in_model_conf = ModelConfiguration._from_payload(project=project, **MODEL_CONF_RAW)

launched_model: Model = simai_client.models.build(in_model_conf)

Expand Down Expand Up @@ -261,6 +267,12 @@ def test_build_with_new_config(simai_client):
json=raw_project,
status=200,
)
responses.add(
responses.GET,
f"https://test.test/projects/{MODEL_RAW['project_id']}/trainable",
json={"is_trainable": True},
status=200,
)

project: Project = simai_client._project_directory._model_from(raw_project)

Expand All @@ -280,7 +292,7 @@ def test_build_with_new_config(simai_client):
new_conf = ModelConfiguration(
project=project,
build_preset="debug",
continuous=False,
build_on_top=False,
input=model_input,
output=model_output,
global_coefficients=global_coefficients,
Expand Down Expand Up @@ -317,7 +329,7 @@ def test_set_doa(simai_client):

project.verify_gc_formula = Mock()

model_conf = ModelConfiguration(project=project, **MODEL_CONF_RAW)
model_conf = ModelConfiguration._from_payload(project=project, **MODEL_CONF_RAW)

new_height = {"position": "relative_to_center", "value": 0.5, "length": 15.2}

Expand All @@ -344,7 +356,7 @@ def test_get_doa(simai_client):

project.verify_gc_formula = Mock()

model_conf = ModelConfiguration(project=project, **MODEL_CONF_RAW)
model_conf = ModelConfiguration._from_payload(project=project, **MODEL_CONF_RAW)

doa_length_raw = MODEL_CONF_RAW.get("simulation_volume").get("X")

Expand Down Expand Up @@ -436,7 +448,7 @@ def test_exception_compute_global_coefficient(simai_client):

project.verify_gc_formula = Mock()

model_conf = ModelConfiguration(project=project, **MODEL_CONF_RAW)
model_conf = ModelConfiguration._from_payload(project=project, **MODEL_CONF_RAW)

model_conf.project = None

Expand All @@ -449,7 +461,7 @@ def test_exception_setting_global_coefficient():
THEN an error is raise."""

with pytest.raises(ProcessingError):
ModelConfiguration(project=None, **MODEL_CONF_RAW)
ModelConfiguration._from_payload(project=None, **MODEL_CONF_RAW)


def test_sse_event_handler(simai_client, model_factory):
Expand Down Expand Up @@ -508,7 +520,7 @@ def test_throw_error_when_volume_is_missing_from_sample(simai_client):
model_output = ModelOutput(surface=[], volume=["Velocity_0"])
global_coefficients = []

model_conf = ModelConfiguration(
model_conf = ModelConfiguration._from_payload(
project=project,
build_preset="debug",
continuous=False,
Expand Down Expand Up @@ -539,6 +551,12 @@ def test_post_process_input(simai_client):
json=raw_project,
status=200,
)
responses.add(
responses.GET,
f"https://test.test/projects/{MODEL_RAW['project_id']}/trainable",
json={"is_trainable": True},
status=200,
)

project: Project = simai_client._project_directory._model_from(raw_project)
project.verify_gc_formula = Mock()
Expand All @@ -559,7 +577,7 @@ def test_post_process_input(simai_client):
model_request = deepcopy(MODEL_RAW)
model_request["configuration"] = model_conf_dict

config_with_pp_input = ModelConfiguration(
config_with_pp_input = ModelConfiguration._from_payload(
project=project,
**model_conf_dict,
pp_input=pp_input,
Expand Down Expand Up @@ -606,6 +624,12 @@ def test_failed_build_with_resolution(simai_client):
json=raw_project,
status=200,
)
responses.add(
responses.GET,
f"https://test.test/projects/{MODEL_RAW['project_id']}/trainable",
json={"is_trainable": True},
status=200,
)

project: Project = simai_client._project_directory._model_from(raw_project)

Expand All @@ -619,7 +643,7 @@ def test_failed_build_with_resolution(simai_client):
height=hght,
)

new_conf = ModelConfiguration(
new_conf = ModelConfiguration._from_payload(
project=project,
build_preset="debug",
domain_of_analysis=doa,
Expand Down
Loading