diff --git a/bofire/runners/run.py b/bofire/runners/run.py index 89729bf2d..9df64e6fe 100644 --- a/bofire/runners/run.py +++ b/bofire/runners/run.py @@ -23,7 +23,7 @@ def _single_run( strategy_factory: StrategyFactory, n_iterations: int, metric: Callable[[Domain, pd.DataFrame], float], - n_candidates_per_proposals: int, + n_candidates_per_proposals: Optional[int], safe_interval: int, initial_sampler: Optional[ Union[Callable[[Domain], pd.DataFrame], pd.DataFrame] @@ -70,11 +70,13 @@ def autosafe_results(benchmark): # pd.concat() changes datatype of str to np.int32 if column contains whole numbers. # column needs to be converted back to str to be added to the benchmark domain. strategy.tell(XY) + assert isinstance(strategy.experiments, pd.DataFrame) metric_values[i] = metric(strategy.domain, strategy.experiments) pbar.set_description(f"Run {run_idx}") pbar.set_postfix({"Current Best:": f"{metric_values[i]:0.3f}"}) if (i + 1) % safe_interval == 0: autosafe_results(benchmark=benchmark) + assert isinstance(strategy.experiments, pd.DataFrame) return strategy.experiments, pd.Series(metric_values) @@ -84,7 +86,7 @@ def run( n_iterations: int, metric: Callable[[Domain, pd.DataFrame], float], initial_sampler: Optional[Callable[[Domain], pd.DataFrame]] = None, - n_candidates_per_proposal: int = 1, + n_candidates_per_proposal: Optional[int] = None, n_runs: int = 5, n_procs: int = 5, safe_interval: int = 1000, diff --git a/bofire/strategies/predictives/botorch.py b/bofire/strategies/predictives/botorch.py index 71c48c67e..dadab03ad 100644 --- a/bofire/strategies/predictives/botorch.py +++ b/bofire/strategies/predictives/botorch.py @@ -403,7 +403,7 @@ def _optimize_acqf_continuous( ) return candidates, acqf_vals - def _ask(self, candidate_count: int) -> pd.DataFrame: # type: ignore + def _ask(self, candidate_count: Optional[int] = None) -> pd.DataFrame: """[summary] Args: @@ -413,6 +413,7 @@ def _ask(self, candidate_count: int) -> pd.DataFrame: # type: ignore pd.DataFrame: [description] """ + candidate_count = candidate_count or 1 assert candidate_count > 0, "candidate_count has to be larger than zero." if self.experiments is None: raise ValueError("No experiments have been provided yet.")