Skip to content

Commit

Permalink
fix strict candidate enforcement (#492)
Browse files Browse the repository at this point in the history
  • Loading branch information
jduerholt authored Jan 14, 2025
1 parent 4b6d7bf commit e560833
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 11 deletions.
2 changes: 1 addition & 1 deletion bofire/data_models/strategies/factorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class FactorialStrategy(Strategy):
This strategy is deprecated, please use FractionalFactorialStrategy instead.
"""

type: Literal["FactorialStrategy"] = "FactorialStrategy"
type: Literal["FactorialStrategy"] = "FactorialStrategy" # type: ignore

@classmethod
def is_constraint_implemented(cls, my_type: Type[Constraint]) -> bool:
Expand Down
4 changes: 3 additions & 1 deletion bofire/strategies/fractional_factorial.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -58,10 +59,11 @@ def _get_categorical_design(self) -> pd.DataFrame:

def _ask(self, candidate_count: Optional[int] = None) -> pd.DataFrame:
if candidate_count is not None:
raise ValueError(
warnings.warn(
"FractionalFactorialStrategy will ignore the specified value of candidate_count. "
"The strategy automatically determines how many candidates to "
"propose.",
UserWarning,
)
design = None
if len(self.domain.inputs.get(ContinuousInput)) > 0:
Expand Down
6 changes: 4 additions & 2 deletions bofire/strategies/shortest_path.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Optional, Tuple

import cvxpy as cp
Expand Down Expand Up @@ -127,10 +128,11 @@ def _ask(self, candidate_count: Optional[int] = None) -> pd.DataFrame:
"""
if candidate_count is not None:
raise ValueError(
"ShortestPath will ignore the specified value of candidate_count. "
warnings.warn(
"ShortestPathStrategy will ignore the specified value of candidate_count. "
"The strategy automatically determines how many candidates to "
"propose.",
UserWarning,
)
start = self.start
steps = []
Expand Down
6 changes: 4 additions & 2 deletions bofire/strategies/strategy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from abc import ABC, abstractmethod
from typing import List, Optional

Expand Down Expand Up @@ -129,8 +130,9 @@ def ask(

if candidate_count is not None:
if len(candidates) != candidate_count:
raise ValueError(
f"expected {candidate_count} candidates, got {len(candidates)}",
warnings.warn(
f"Expected {candidate_count} candidates, got {len(candidates)}",
UserWarning,
)

if add_pending:
Expand Down
7 changes: 4 additions & 3 deletions tests/bofire/strategies/test_fractional_factorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,11 @@ def test_FractionalFactorialStrategy_ask_invalid():
),
)
strategy = strategies.map(strategy_data)
with pytest.raises(
ValueError,
with pytest.warns(
UserWarning,
match="FractionalFactorialStrategy will ignore the specified value of candidate_count. "
"The strategy automatically determines how many candidates to "
"propose.",
):
strategy.ask(5)
candidates = strategy.ask(7)
assert len(candidates) == 5
4 changes: 3 additions & 1 deletion tests/bofire/strategies/test_shortest_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def test_step():
def test_ask():
data_model = specs.valid(data_models.ShortestPathStrategy).obj()
strategy = strategies.map(data_model=data_model)
with pytest.raises(ValueError, match="ShortestPath will ignore the specified "):
with pytest.warns(
UserWarning, match="ShortestPathStrategy will ignore the specified "
):
strategy.ask(candidate_count=4)
steps = strategy.ask()
assert np.allclose(
Expand Down
2 changes: 1 addition & 1 deletion tests/bofire/strategies/test_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def test_ask(self: Strategy, candidate_count: int):
return candidates

with mock.patch.object(dummy.DummyStrategy, "_ask", new=test_ask):
with pytest.raises(ValueError):
with pytest.warns(UserWarning, match="Expected"):
strategy.ask(candidate_count=4)


Expand Down

0 comments on commit e560833

Please sign in to comment.