Skip to content

Commit

Permalink
make it more flexible (#468)
Browse files Browse the repository at this point in the history
  • Loading branch information
jduerholt authored Nov 28, 2024
1 parent ffc1757 commit fd9315a
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 28 deletions.
43 changes: 29 additions & 14 deletions bofire/data_models/strategies/stepwise/stepwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,37 @@ def validate_domain_compatibility(domain1: Domain, domain2: Domain):
Raises:
ValueError: If one of the the conditions mentioned above is not met.
"""
features1 = domain1.inputs + domain1.outputs
features2 = domain2.inputs + domain2.outputs
if len(features1) != len(features2):
raise ValueError("Domains have different number of features.")
if features1.get_keys() != features2.get_keys():
raise ValueError("Domains have different feature keys.")
for feature1, feature2 in zip(features1.get(), features2.get()):
if feature1.__class__ != feature2.__class__:
raise ValueError(f"Features with key {feature1.key} have different types.")
if isinstance(feature1, (CategoricalInput, CategoricalOutput)) and isinstance(
feature2, (CategoricalInput, CategoricalOutput)
):
if feature1.categories != feature2.categories:

def validate(equals: List[str], features1, features2):
for key in equals:
feature1 = features1.get_by_key(key)
feature2 = features2.get_by_key(key)
if feature1.__class__ != feature2.__class__:
raise ValueError(
f"Features with key {feature1.key} have different categories."
f"Features with key {feature1.key} have different types."
)
if isinstance(
feature1, (CategoricalInput, CategoricalOutput)
) and isinstance(feature2, (CategoricalInput, CategoricalOutput)):
if feature1.categories != feature2.categories:
raise ValueError(
f"Features with key {feature1.key} have different categories."
)

validate(
[key for key in domain1.inputs.get_keys() if key in domain2.inputs.get_keys()],
domain1.inputs,
domain2.inputs,
)
validate(
[
key
for key in domain1.outputs.get_keys()
if key in domain2.outputs.get_keys()
],
domain1.outputs,
domain2.outputs,
)


class StepwiseStrategy(Strategy):
Expand Down
2 changes: 0 additions & 2 deletions bofire/transforms/manipulate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@ def _apply_pd_transforms(self, df: pd.DataFrame, transforms: list) -> pd.DataFra
if len(transforms) == 0:
return df
transformed_df = df.copy()
print(transformed_df)
for tr in transforms:
transformed_df.eval(tr, inplace=True)
print(transformed_df)

return transformed_df

Expand Down
15 changes: 3 additions & 12 deletions tests/bofire/strategies/stepwise/test_stepwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import bofire.strategies.api as strategies
from bofire.benchmarks.single import Himmelblau
from bofire.data_models.acquisition_functions.api import qNEI
from bofire.data_models.features.api import CategoricalInput, ContinuousInput
from bofire.data_models.features.api import CategoricalInput
from bofire.data_models.strategies.api import (
AlwaysTrueCondition,
NumberOfExperimentsCondition,
Expand All @@ -22,15 +22,6 @@

def test_validate_domain_compatibility():
bench = Himmelblau()
domain2 = deepcopy(bench.domain)
domain2.inputs.features.append(ContinuousInput(key="a", bounds=(0, 1)))
with pytest.raises(ValueError, match="Domains have different number of features."):
validate_domain_compatibility(bench.domain, domain2)

domain2 = deepcopy(bench.domain)
domain2.inputs.features[0].key = "mama"
with pytest.raises(ValueError, match="Domains have different feature keys."):
validate_domain_compatibility(bench.domain, domain2)

domain2 = deepcopy(bench.domain)
domain2.inputs = bench.domain.inputs.get_by_keys(["x_1"])
Expand All @@ -51,10 +42,10 @@ def test_validate_domain_compatibility():
def test_StepwiseStrategy_invalid_domains():
benchmark = Himmelblau()
domain2 = deepcopy(benchmark.domain)
domain2.inputs[0].key = "mama"
domain2.inputs.features[0] = CategoricalInput(key="x_1", categories=["a", "b"])
with pytest.raises(
ValueError,
match="Domains have different feature keys.",
match="Features with key x_1 have different types.",
):
StepwiseStrategy(
domain=benchmark.domain,
Expand Down

0 comments on commit fd9315a

Please sign in to comment.