-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add pandas data manipulation transform (#460)
* add __len__ for dataframes * add new transform * add missing files * fix 3.9 compatibility issue * fix typing
- Loading branch information
Showing
16 changed files
with
269 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,7 @@ | ||
from typing import Union | ||
|
||
from bofire.data_models.transforms.drop_data import DropDataTransform | ||
from bofire.data_models.transforms.manipulate_data import ManipulateDataTransform | ||
|
||
|
||
# Type alias covering the transform data models imported above. The earlier
# single-class alias (`AnyTransform = DropDataTransform`) was immediately
# overwritten by this assignment, so only the union is kept.
AnyTransform = Union[DropDataTransform, ManipulateDataTransform]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,9 @@ | ||
from typing import List, Literal, Optional | ||
|
||
from pydantic import BaseModel | ||
from bofire.data_models.transforms.transform import Transform | ||
|
||
|
||
class DropDataTransform(Transform):
    """Data model describing which experiments and/or candidates to remove.

    Derives from the shared ``Transform`` data-model base rather than directly
    from pydantic's ``BaseModel``.

    Attributes:
        to_be_removed_experiments: Optional list of integers selecting the
            experiments to drop; ``None`` (the default) leaves this unset.
            NOTE(review): presumably row indices/labels of the experiments
            dataframe -- confirm against the runtime transform.
        to_be_removed_candidates: Optional list of integers selecting the
            candidates to drop; ``None`` (the default) leaves this unset.
            NOTE(review): same assumption about the meaning of the integers.
    """

    type: Literal["DropDataTransform"] = "DropDataTransform"
    to_be_removed_experiments: Optional[List[int]] = None
    to_be_removed_candidates: Optional[List[int]] = None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from typing import List, Literal, Optional | ||
|
||
from pydantic import Field, model_validator | ||
|
||
from bofire.data_models.transforms.transform import Transform | ||
|
||
|
||
class ManipulateDataTransform(Transform):
    """Transform that manipulates experiments/candidates with pandas expressions.

    Every transformation is an expression string as described here:
    https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.eval.html#pandas.DataFrame.eval

    Attributes:
        experiment_transforms: List of strings representing the transformations
            to be applied to the experiments.
        candidate_transforms: List of strings representing the transformations
            to be applied to the candidates.
        candidate_untransforms: List of strings representing the transformations
            to be applied to untransform the generated candidates.

    Each list is optional, but a list that is given must contain at least one
    entry (``min_length=1``), and at least one of the three lists has to be
    provided (enforced by ``validate_transformations``).
    """

    type: Literal["ManipulateDataTransform"] = "ManipulateDataTransform"
    experiment_transforms: Optional[List[str]] = Field(None, min_length=1)
    candidate_transforms: Optional[List[str]] = Field(None, min_length=1)
    candidate_untransforms: Optional[List[str]] = Field(None, min_length=1)

    @model_validator(mode="after")
    def validate_transformations(self):
        """Reject a model in which none of the three transform lists is set.

        Returns:
            The validated model instance itself.

        Raises:
            ValueError: If all three lists are ``None``.
        """
        if not any(
            [
                self.experiment_transforms,
                self.candidate_transforms,
                self.candidate_untransforms,
            ]
        ):
            raise ValueError(
                "At least one of experiment_transforms, candidate_transforms, or candidate_untransforms must be provided."
            )

        return self
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from typing import Any | ||
|
||
from bofire.data_models.base import BaseModel | ||
|
||
|
||
class Transform(BaseModel):
    """Base data model shared by the transform data models.

    Only declares the ``type`` field; it is left as ``Any`` here so that each
    subclass can declare its own concrete type for it.
    """

    type: Any
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import pandas as pd | ||
|
||
from bofire.data_models.transforms.api import ManipulateDataTransform as DataModel | ||
from bofire.transforms.transform import Transform | ||
|
||
|
||
class ManipulateDataTransform(Transform):
    """Apply ``pandas.DataFrame.eval`` expressions to experiments and candidates.

    The expression lists are taken from the data model; a list that is not set
    there is treated as empty, which makes the corresponding step a no-op.

    NOTE(review): the expression strings are handed to ``DataFrame.eval``
    unchanged, so they should only come from trusted configuration.
    """

    def __init__(self, data_model: DataModel):
        """Copy the three expression lists from ``data_model``.

        Args:
            data_model: Data model holding the (optional) expression lists.
        """
        # `or []` normalizes an unset (None) list to an empty one.
        self.experiment_transforms = data_model.experiment_transforms or []
        self.candidate_transforms = data_model.candidate_transforms or []
        self.candidate_untransforms = data_model.candidate_untransforms or []

    def _apply_pd_transforms(self, df: pd.DataFrame, transforms: list) -> pd.DataFrame:
        """Evaluate ``transforms`` one after another on a copy of ``df``.

        Args:
            df: Dataframe to transform; it is never modified.
            transforms: Expression strings for ``DataFrame.eval``.

        Returns:
            ``df`` itself when ``transforms`` is empty, otherwise a transformed
            copy.
        """
        if not transforms:
            return df
        # Work on a copy because `eval(..., inplace=True)` mutates its frame.
        transformed_df = df.copy()
        for tr in transforms:
            transformed_df.eval(tr, inplace=True)

        return transformed_df

    def transform_experiments(self, experiments: pd.DataFrame) -> pd.DataFrame:
        """Return ``experiments`` with the experiment transforms applied."""
        return self._apply_pd_transforms(experiments, self.experiment_transforms)

    def transform_candidates(self, candidates: pd.DataFrame) -> pd.DataFrame:
        """Return ``candidates`` with the candidate transforms applied."""
        return self._apply_pd_transforms(candidates, self.candidate_transforms)

    def untransform_candidates(self, candidates: pd.DataFrame) -> pd.DataFrame:
        """Return ``candidates`` with the candidate untransforms applied."""
        return self._apply_pd_transforms(candidates, self.candidate_untransforms)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
75 changes: 75 additions & 0 deletions
75
tests/bofire/strategies/stepwise/test_manipulate_data_transform.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from copy import deepcopy | ||
|
||
from pandas.testing import assert_frame_equal, assert_series_equal | ||
|
||
import bofire.strategies.api as strategies | ||
import bofire.transforms.api as transforms | ||
from bofire.benchmarks.api import Himmelblau | ||
from bofire.data_models.strategies.predictives.sobo import SoboStrategy | ||
from bofire.data_models.strategies.random import RandomStrategy | ||
from bofire.data_models.strategies.stepwise.conditions import ( | ||
AlwaysTrueCondition, | ||
NumberOfExperimentsCondition, | ||
) | ||
from bofire.data_models.strategies.stepwise.stepwise import Step, StepwiseStrategy | ||
from bofire.data_models.transforms.api import ManipulateDataTransform | ||
|
||
|
||
def test_dropdata_transform():
    """Check the three stages of ``ManipulateDataTransform`` on Himmelblau data.

    NOTE(review): the function name looks like a leftover from the drop-data
    test; it is kept unchanged here.
    """
    bench = Himmelblau()
    candidates = bench.domain.inputs.sample(10)
    experiments = bench.f(bench.domain.inputs.sample(10), return_complete=True)

    transform_data = ManipulateDataTransform(
        experiment_transforms=["x_1 = x_1 + 100", "x_2 = x_2 / 2.0"],
        candidate_transforms=["x_1 = x_1 -20", "x_2 = x_2 / 2.0"],
        candidate_untransforms=["x_1 = x_1 + 20", "x_2 = x_2 * 2.0"],
    )

    transform = transforms.map(transform_data)

    transformed_experiments = transform.transform_experiments(experiments)
    transformed_candidates = transform.transform_candidates(candidates)
    untransformed_candidates = transform.untransform_candidates(transformed_candidates)

    assert_series_equal(experiments.x_1 + 100, transformed_experiments.x_1)
    assert_series_equal(experiments.x_2 / 2.0, transformed_experiments.x_2)

    # The candidate transforms must really change the data. The previous
    # try/except around assert_frame_equal swallowed the AssertionError and
    # therefore could never fail.
    assert not candidates.equals(transformed_candidates)

    # Untransforming has to undo the candidate transforms again.
    assert_frame_equal(candidates, untransformed_candidates)
|
||
|
||
def test_stepwise():
    """Check that a step's transform untransforms the candidates it proposes.

    The inner strategies work on the original benchmark domain, while the
    stepwise strategy itself uses a copy with widened bounds so that the
    shifted candidates (+200 on x_1, -200 on x_2) fall inside its domain.
    """
    bench = Himmelblau()

    transform_data = ManipulateDataTransform(
        candidate_untransforms=["x_1 = x_1 + 200", "x_2 = x_2 - 200"],
    )

    domain = deepcopy(bench.domain)
    domain.inputs.get_by_key("x_1").bounds = (-6, 300)
    domain.inputs.get_by_key("x_2").bounds = (-300, 6)
    strategy_data = StepwiseStrategy(
        domain=domain,
        steps=[
            Step(
                condition=NumberOfExperimentsCondition(n_experiments=5),
                strategy_data=RandomStrategy(domain=bench.domain),
                transform=transform_data,
            ),
            Step(
                condition=AlwaysTrueCondition(),
                strategy_data=SoboStrategy(domain=bench.domain),
            ),
        ],
    )

    strategy = strategies.map(strategy_data)
    candidates = strategy.ask(candidate_count=1)
    # The thresholds only hold if the untransform shift was applied.
    assert all(candidates.x_1 >= 150)
    assert all(candidates.x_2 <= -150)
Oops, something went wrong.