From 11cc40abcbf51ed4e792d3693b6665a3a7112469 Mon Sep 17 00:00:00 2001 From: gmancino Date: Fri, 22 Sep 2023 09:09:36 -0400 Subject: [PATCH 01/31] Begin refactor --- bofire/data_models/domain/domain.py | 6 +- bofire/data_models/domain/features.py | 20 +- bofire/data_models/features/categorical.py | 20 +- .../strategies/predictives/qparego.py | 4 +- .../strategies/predictives/sobo.py | 4 +- bofire/surrogates/api.py | 3 +- bofire/surrogates/mapper.py | 3 +- bofire/surrogates/mlp_classifier.py | 213 ++++++++++++++++++ tests/bofire/data_models/test_features.py | 4 +- 9 files changed, 245 insertions(+), 32 deletions(-) create mode 100644 bofire/surrogates/mlp_classifier.py diff --git a/bofire/data_models/domain/domain.py b/bofire/data_models/domain/domain.py index 2e86ac629..b324689a5 100644 --- a/bofire/data_models/domain/domain.py +++ b/bofire/data_models/domain/domain.py @@ -604,7 +604,7 @@ def validate_candidates( itertools.chain.from_iterable( [ [f"{key}_pred", f"{key}_sd", f"{key}_des"] - for key in self.outputs.get_keys_by_objective(Objective) + for key in self.outputs.get_keys_by_objective([Objective, List]) ] + [ [f"{key}_pred", f"{key}_sd"] @@ -626,8 +626,8 @@ def validate_candidates( # validate no additional cols exist if_count = len(self.get_features(Input)) - of_count = len(self.outputs.get_by_objective(includes=Objective)) - of_count_w = len(self.outputs.get_by_objective(excludes=Objective, includes=None)) # type: ignore + of_count = len(self.outputs.get_by_objective(includes=[Objective, List])) + of_count_w = len(self.outputs.get_by_objective(excludes=[Objective, List], includes=None)) # type: ignore # input features, prediction, standard deviation and reward for each output feature, 3 additional usefull infos: reward, aquisition function, strategy if len(candidates.columns) != if_count + 3 * of_count + 2 * of_count_w: raise ValueError("additional columns found") diff --git a/bofire/data_models/domain/features.py b/bofire/data_models/domain/features.py index dc53fa0e4..e04ad33ee 100644 --- a/bofire/data_models/domain/features.py +++ b/bofire/data_models/domain/features.py @@ -20,6 +20,7 @@ CategoricalDescriptorInput, CategoricalInput, CategoricalMolecularInput, + CategoricalOutput, ContinuousInput, ContinuousOutput, DiscreteInput, @@ -103,8 +104,8 @@ def get_by_keys(self, keys: Sequence[str]) -> Features: def get( self, - includes: Union[Type, List[Type]] = AnyFeature, - excludes: Union[Type, List[Type]] = None, + includes: Union[Type, List[Type], Tuple[Type]] = AnyFeature, + excludes: Union[Type, List[Type], Tuple[Type]] = None, exact: bool = False, ) -> Features: """get features of the domain @@ -131,8 +132,8 @@ def get( def get_keys( self, - includes: Union[Type, List[Type]] = AnyFeature, - excludes: Union[Type, List[Type]] = None, + includes: Union[Type, List[Type], Tuple[Type]] = AnyFeature, + excludes: Union[Type, List[Type], Tuple[Type]] = None, exact: bool = False, ) -> List[str]: """Method to get feature keys of the domain @@ -262,8 +263,8 @@ def validate_experiments( def get_categorical_combinations( self, - include: Union[Type, List[Type]] = Input, - exclude: Union[Type, List[Type]] = None, + include: Union[Type, List[Type], Tuple[Type]] = Input, + exclude: Union[Type, List[Type], Tuple[Type]] = None, ): """get a list of tuples pairing the feature keys with a list of valid categories @@ -539,11 +540,13 @@ def get_by_objective( self, includes: Union[ List[Type[AbstractObjective]], + Tuple[Type[AbstractObjective]], Type[AbstractObjective], Type[Objective], ] = Objective, excludes: Union[ List[Type[AbstractObjective]], + Tuple[Type[AbstractObjective]], Type[AbstractObjective], None, ] = None, @@ -566,7 +569,7 @@ def get_by_objective( return Outputs( features=sorted( filter_by_attribute( - self.get(ContinuousOutput).features, + self.get([ContinuousOutput, CategoricalOutput]).features, lambda of: of.objective, includes, excludes, @@ -579,11 +582,12 @@ def get_keys_by_objective( self, includes: Union[ List[Type[AbstractObjective]], + Tuple[Type[AbstractObjective]], Type[AbstractObjective], Type[Objective], ] = Objective, excludes: Union[ - List[Type[AbstractObjective]], Type[AbstractObjective], None + List[Type[AbstractObjective]], Tuple[Type[AbstractObjective]], Type[AbstractObjective], None ] = None, exact: bool = False, ) -> List[str]: diff --git a/bofire/data_models/features/categorical.py b/bofire/data_models/features/categorical.py index e0b29e2a7..2ff523bfb 100644 --- a/bofire/data_models/features/categorical.py +++ b/bofire/data_models/features/categorical.py @@ -20,8 +20,8 @@ class CategoricalInput(Input): """Base class for all categorical input features. Attributes: - categories (List[str]): Names of the categories. - allowed (List[bool]): List of bools indicating if a category is allowed within the optimization. + categories (Tuple[str]): Names of the categories. + allowed (Tuple[bool]): List of bools indicating if a category is allowed within the optimization. """ type: Literal["CategoricalInput"] = "CategoricalInput" @@ -35,18 +35,17 @@ def validate_categories_unique(cls, categories): """validates that categories have unique names Args: - categories (List[str]): List of category names + categories (Union[List[str], Tuple[str]]): List or tuple of category names Raises: ValueError: when categories have non-unique names Returns: - List[str]: List of the categories + Tuple[str]: Tuple of the categories """ - categories = list(categories) if len(categories) != len(set(categories)): raise ValueError("categories must be unique") - return categories + return tuple(categories) @root_validator(pre=False, skip_on_failure=True) def init_allowed(cls, values): @@ -65,7 +64,7 @@ def init_allowed(cls, values): if "categories" not in values or values["categories"] is None: return values if "allowed" not in values or values["allowed"] is None: - values["allowed"] = [True for _ in range(len(values["categories"]))] + values["allowed"] = tuple([True for _ in range(len(values["categories"]))]) if len(values["allowed"]) != len(values["categories"]): raise ValueError("allowed must have same length as categories") if sum(values["allowed"]) == 0: @@ -368,18 +367,17 @@ def validate_categories_unique(cls, categories): """validates that categories have unique names Args: - categories (List[str]): List of category names + categories (Union[List[str], Tuple[str]]): List or tuple of category names Raises: ValueError: when categories have non-unique names Returns: - List[str]: List of the categories + Tuple[str]: Tuple of the categories """ - categories = list(categories) if len(categories) != len(set(categories)): raise ValueError("categories must be unique") - return categories + return tuple(categories) @validator("objective") def validate_objective(cls, objective, values): diff --git a/bofire/data_models/strategies/predictives/qparego.py b/bofire/data_models/strategies/predictives/qparego.py index 257235809..9776bc001 100644 --- a/bofire/data_models/strategies/predictives/qparego.py +++ b/bofire/data_models/strategies/predictives/qparego.py @@ -26,9 +26,7 @@ class QparegoStrategy(MultiobjectiveStrategy): @classmethod def is_feature_implemented(cls, my_type: Type[Feature]) -> bool: - if my_type not in [CategoricalOutput]: - return True - return False + return True @classmethod def is_objective_implemented(cls, my_type: Type[Objective]) -> bool: diff --git a/bofire/data_models/strategies/predictives/sobo.py b/bofire/data_models/strategies/predictives/sobo.py index 052a6bf91..00e3f84a9 100644 --- a/bofire/data_models/strategies/predictives/sobo.py +++ b/bofire/data_models/strategies/predictives/sobo.py @@ -21,9 +21,7 @@ def is_feature_implemented(cls, my_type: Type[Feature]) -> bool: Returns: bool: True if the feature type is valid for the strategy chosen, False otherwise """ - if my_type not in [CategoricalOutput]: - return True - return False + return True @classmethod def is_objective_implemented(cls, my_type: Type[Objective]) -> bool: diff --git a/bofire/surrogates/api.py b/bofire/surrogates/api.py index 5bae145f7..65130eef8 100644 --- a/bofire/surrogates/api.py +++ b/bofire/surrogates/api.py @@ -2,7 +2,8 @@ from bofire.surrogates.empirical import EmpiricalSurrogate from bofire.surrogates.mapper import map from bofire.surrogates.mixed_single_task_gp import MixedSingleTaskGPSurrogate -from bofire.surrogates.mlp import MLPEnsemble +from bofire.surrogates.mlp_classifier import MLPEnsemble +# from bofire.surrogates.mlp import MLPEnsemble from bofire.surrogates.random_forest import RandomForestSurrogate from bofire.surrogates.single_task_gp import SingleTaskGPSurrogate from bofire.surrogates.surrogate import Surrogate diff --git a/bofire/surrogates/mapper.py b/bofire/surrogates/mapper.py index 5be8d4f3e..7079c6be9 100644 --- a/bofire/surrogates/mapper.py +++ b/bofire/surrogates/mapper.py @@ -4,7 +4,8 @@ from bofire.surrogates.empirical import EmpiricalSurrogate from bofire.surrogates.fully_bayesian import SaasSingleTaskGPSurrogate from bofire.surrogates.mixed_single_task_gp import MixedSingleTaskGPSurrogate -from bofire.surrogates.mlp import MLPEnsemble +# from bofire.surrogates.mlp import MLPEnsemble +from bofire.surrogates.mlp_classifier import MLPEnsemble from bofire.surrogates.random_forest import RandomForestSurrogate from bofire.surrogates.single_task_gp import SingleTaskGPSurrogate from bofire.surrogates.surrogate import Surrogate diff --git a/bofire/surrogates/mlp_classifier.py b/bofire/surrogates/mlp_classifier.py new file mode 100644 index 000000000..289440099 --- /dev/null +++ b/bofire/surrogates/mlp_classifier.py @@ -0,0 +1,213 @@ +from typing import Literal, Optional, Sequence + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +from botorch.models.ensemble import EnsembleModel +from torch import Tensor +from torch.utils.data import DataLoader, Dataset + +from bofire.data_models.enum import OutputFilteringEnum +from bofire.data_models.surrogates.api import MLPEnsemble as DataModel +from bofire.surrogates.botorch import BotorchSurrogate +from bofire.surrogates.single_task_gp import get_scaler +from bofire.surrogates.trainable import TrainableSurrogate +from bofire.utils.torch_tools import tkwargs + + +class ClassificationDataSet(Dataset): + """ + Prepare the dataset for classification + """ + + def __init__(self, X: Tensor, y: Tensor): + self.X = X.to(**tkwargs) + self.y = y.to(**tkwargs) + + def __len__(self): + return len(self.X) + + def __getitem__(self, i: int): + return self.X[i], self.y[i] + + +class MLPClassifier(nn.Module): + def __init__( + self, + input_size: int, + output_size: int = 1, + hidden_layer_sizes: Sequence = (100,), + dropout: float = 0.0, + activation: Literal["relu", "logistic", "tanh"] = "relu", + ): + super().__init__() + if activation == "relu": + f_activation = nn.ReLU + elif activation == "logistic": + f_activation = nn.Sigmoid + elif activation == "tanh": + f_activation = nn.Tanh + else: + raise ValueError(f"Activation {activation} not known.") + layers = [ + nn.Linear(input_size, hidden_layer_sizes[0]).to(**tkwargs), + f_activation(), + ] + if dropout > 0.0: + layers.append(nn.Dropout(dropout)) + if len(hidden_layer_sizes) > 1: + for i in range(len(hidden_layer_sizes) - 1): + layers += [ + nn.Linear(hidden_layer_sizes[i], hidden_layer_sizes[i + 1]).to( + **tkwargs + ), + f_activation(), + ] + if dropout > 0.0: + layers.append(nn.Dropout(dropout)) + layers.append(nn.Linear(hidden_layer_sizes[-1], output_size).to(**tkwargs)) + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class _MLPEnsemble(EnsembleModel): + def __init__(self, mlps: Sequence[MLPClassifier]): + super().__init__() + if len(mlps) == 0: + raise ValueError("List of mlps is empty.") + num_in_features = mlps[0].layers[0].in_features + num_out_features = mlps[0].layers[-1].out_features + for mlp in mlps: + assert mlp.layers[0].in_features == num_in_features + assert mlp.layers[-1].out_features == num_out_features + self.mlps = mlps + # put all models in eval mode + for mlp in self.mlps: + mlp.eval() + + def forward(self, X: Tensor): + r"""Compute the model output at X. + + Args: + X: A `batch_shape x n x d`-dim input tensor `X`. + + Returns: + A `batch_shape x s x n x m`-dimensional output tensor where + `s` is the size of the ensemble. + """ + return torch.stack([mlp(X) for mlp in self.mlps], dim=-3) + + @property + def num_outputs(self) -> int: + r"""The number of outputs of the model.""" + return self.mlps[0].layers[-1].out_features # type: ignore + + +def fit_mlp( + mlp: MLPClassifier, + dataset: ClassificationDataSet, + batch_size: int = 10, + n_epoches: int = 200, + lr: float = 1e-3, + shuffle: bool = True, + weight_decay: float = 0.0, +): + """Fit a MLP for classification to a dataset. + + Args: + mlp (MLP): The MLP that should be fitted. + dataset (ClassificationDataSet): The data that should be fitted + batch_size (int, optional): Batch size. Defaults to 10. + n_epoches (int, optional): Number of training epoches. Defaults to 200. + lr (float, optional): Initial learning rate. Defaults to 1e-4. + shuffle (bool, optional): Whereas the batches should be shuffled. Defaults to True. + weight_decay (float, optional): Weight decay (L2 regularization). Defaults to 0.0 (no regularization). + """ + mlp.train() + train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle) + loss_function = nn.BCEWithLogitsLoss() + optimizer = torch.optim.Adam(mlp.parameters(), lr=lr, weight_decay=weight_decay) + for _ in range(n_epoches): + current_loss = 0.0 + for data in train_loader: + # Get and prepare inputs + inputs, targets = data + if len(targets.shape) == 1: + targets = targets.reshape((targets.shape[0], 1)) + + # Zero the gradients + optimizer.zero_grad() + + # Perform forward pass + outputs = mlp(inputs) + + # Compute loss + loss = loss_function(outputs, targets) + + # Perform backward pass + loss.backward() + + # Perform optimization + optimizer.step() + + # Print statistics + current_loss += loss.item() + + +class MLPEnsemble(BotorchSurrogate, TrainableSurrogate): + def __init__(self, data_model: DataModel, **kwargs): + self.n_estimators = data_model.n_estimators + self.hidden_layer_sizes = data_model.hidden_layer_sizes + self.activation = data_model.activation + self.dropout = data_model.dropout + self.batch_size = data_model.batch_size + self.n_epochs = data_model.n_epochs + self.lr = data_model.lr + self.weight_decay = data_model.weight_decay + self.subsample_fraction = data_model.subsample_fraction + self.shuffle = data_model.shuffle + self.scaler = data_model.scaler + super().__init__(data_model, **kwargs) + + _output_filtering: OutputFilteringEnum = OutputFilteringEnum.ALL + model: Optional[_MLPEnsemble] = None + + def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): + scaler = get_scaler(self.inputs, self.input_preprocessing_specs, self.scaler, X) + transformed_X = self.inputs.transform(X, self.input_preprocessing_specs) + + mlps = [] + subsample_size = round(self.subsample_fraction * X.shape[0]) + for _ in range(self.n_estimators): + # resample X and Y + sample_idx = np.random.choice(X.shape[0], replace=True, size=subsample_size) + tX = torch.from_numpy(transformed_X.values[sample_idx]).to(**tkwargs) + ty = torch.from_numpy(Y.values[sample_idx].astype(int)).to(**tkwargs) + + dataset = ClassificationDataSet( + X=scaler.transform(tX) if scaler is not None else tX, + y=ty, + ) + mlp = MLPClassifier( + input_size=transformed_X.shape[1], + output_size=1, + hidden_layer_sizes=self.hidden_layer_sizes, + activation=self.activation, # type: ignore + dropout=self.dropout, + ) + fit_mlp( + mlp=mlp, + dataset=dataset, + batch_size=self.batch_size, + n_epoches=self.n_epochs, + lr=self.lr, + shuffle=self.shuffle, + weight_decay=self.weight_decay, + ) + mlps.append(mlp) + self.model = _MLPEnsemble(mlps=mlps) + if scaler is not None: + self.model.input_transform = scaler diff --git a/tests/bofire/data_models/test_features.py b/tests/bofire/data_models/test_features.py index 52e0d8c15..c41227d6a 100644 --- a/tests/bofire/data_models/test_features.py +++ b/tests/bofire/data_models/test_features.py @@ -1193,9 +1193,9 @@ def test_continuous_descriptor_input_feature_as_dataframe(descriptors, values): @pytest.mark.parametrize( "categories, descriptors, values", [ - (["c1", "c2"], ["d1", "d2", "d3"], [[1, 2, 3], [4, 5, 6]]), + (("c1", "c2"), ["d1", "d2", "d3"], [[1, 2, 3], [4, 5, 6]]), ( - ["c1", "c2", "c3", "c4"], + ("c1", "c2", "c3", "c4"), ["d1", "d2", "d3"], [ [1, 2, 3], From 9b688e0b7b36558ac8f918764d8e36f8477f9cf0 Mon Sep 17 00:00:00 2001 From: gmancino Date: Fri, 22 Sep 2023 12:50:39 -0400 Subject: [PATCH 02/31] Update bug fixes --- bofire/data_models/features/categorical.py | 2 +- bofire/surrogates/mlp_classifier.py | 13 +++++++++---- tests/bofire/data_models/specs/features.py | 10 +++++----- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/bofire/data_models/features/categorical.py b/bofire/data_models/features/categorical.py index 2ff523bfb..b9eb9d28b 100644 --- a/bofire/data_models/features/categorical.py +++ b/bofire/data_models/features/categorical.py @@ -64,7 +64,7 @@ def init_allowed(cls, values): if "categories" not in values or values["categories"] is None: return values if "allowed" not in values or values["allowed"] is None: - values["allowed"] = tuple([True for _ in range(len(values["categories"]))]) + values["allowed"] = [True for _ in range(len(values["categories"]))] if len(values["allowed"]) != len(values["categories"]): raise ValueError("allowed must have same length as categories") if sum(values["allowed"]) == 0: diff --git a/bofire/surrogates/mlp_classifier.py b/bofire/surrogates/mlp_classifier.py index 289440099..f3305fed6 100644 --- a/bofire/surrogates/mlp_classifier.py +++ b/bofire/surrogates/mlp_classifier.py @@ -73,7 +73,7 @@ def forward(self, x): return self.layers(x) -class _MLPEnsemble(EnsembleModel): +class _MLPClassifierEnsemble(EnsembleModel): def __init__(self, mlps: Sequence[MLPClassifier]): super().__init__() if len(mlps) == 0: @@ -87,6 +87,7 @@ def __init__(self, mlps: Sequence[MLPClassifier]): # put all models in eval mode for mlp in self.mlps: mlp.eval() + self.activation = nn.Sigmoid() def forward(self, X: Tensor): r"""Compute the model output at X. @@ -173,19 +174,23 @@ def __init__(self, data_model: DataModel, **kwargs): super().__init__(data_model, **kwargs) _output_filtering: OutputFilteringEnum = OutputFilteringEnum.ALL - model: Optional[_MLPEnsemble] = None + model: Optional[_MLPClassifierEnsemble] = None def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): scaler = get_scaler(self.inputs, self.input_preprocessing_specs, self.scaler, X) transformed_X = self.inputs.transform(X, self.input_preprocessing_specs) + # Convert Y to classification tensor + Y = pd.DataFrame.from_dict({col: np.unique(Y[col].values, return_inverse=True)[1] for col in Y.columns}) + # Y = Y.apply(lambda x: pd.factorize(x, sort=True)[0]) + print(f"X: {X}, Y={Y}") mlps = [] subsample_size = round(self.subsample_fraction * X.shape[0]) for _ in range(self.n_estimators): # resample X and Y sample_idx = np.random.choice(X.shape[0], replace=True, size=subsample_size) tX = torch.from_numpy(transformed_X.values[sample_idx]).to(**tkwargs) - ty = torch.from_numpy(Y.values[sample_idx].astype(int)).to(**tkwargs) + ty = torch.from_numpy(Y.values[sample_idx]).to(**tkwargs) dataset = ClassificationDataSet( X=scaler.transform(tX) if scaler is not None else tX, @@ -208,6 +213,6 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): weight_decay=self.weight_decay, ) mlps.append(mlp) - self.model = _MLPEnsemble(mlps=mlps) + self.model = _MLPClassifierEnsemble(mlps=mlps) if scaler is not None: self.model.input_transform = scaler diff --git a/tests/bofire/data_models/specs/features.py b/tests/bofire/data_models/specs/features.py index 06b50c65f..9fdbd1457 100644 --- a/tests/bofire/data_models/specs/features.py +++ b/tests/bofire/data_models/specs/features.py @@ -53,7 +53,7 @@ features.CategoricalInput, lambda: { "key": str(uuid.uuid4()), - "categories": ["c1", "c2", "c3"], + "categories": ("c1", "c2", "c3"), "allowed": [True, True, False], }, ) @@ -61,7 +61,7 @@ features.CategoricalDescriptorInput, lambda: { "key": str(uuid.uuid4()), - "categories": ["c1", "c2", "c3"], + "categories": ("c1", "c2", "c3"), "allowed": [True, True, False], "descriptors": ["d1", "d2"], "values": [ @@ -84,7 +84,7 @@ features.CategoricalOutput, lambda: { "key": str(uuid.uuid4()), - "categories": ["a", "b", "c"], + "categories": ("a", "b", "c"), "objective": [0.0, 1.0, 0.0], }, ) @@ -99,12 +99,12 @@ features.CategoricalMolecularInput, lambda: { "key": str(uuid.uuid4()), - "categories": [ + "categories": ( "CC(=O)Oc1ccccc1C(=O)O", "c1ccccc1", "[CH3][CH2][OH]", "N[C@](C)(F)C(=O)O", - ], + ), "allowed": [True, True, True, True], }, ) From 51ea59d6aa7c1e62cdcc1125e0b63ba92cdcdbcf Mon Sep 17 00:00:00 2001 From: gmancino Date: Fri, 22 Sep 2023 12:56:35 -0400 Subject: [PATCH 03/31] Update based on main --- for_gabe.ipynb | 718 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 718 insertions(+) create mode 100644 for_gabe.ipynb diff --git a/for_gabe.ipynb b/for_gabe.ipynb new file mode 100644 index 000000000..65bc37d7d --- /dev/null +++ b/for_gabe.ipynb @@ -0,0 +1,718 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# DTLZ2 Benchmark\n", + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from bofire.benchmarks.multi import DTLZ2, C2DTLZ2\n", + "from bofire.utils.multiobjective import compute_hypervolume\n", + "from bofire.data_models.strategies.api import QehviStrategy, QparegoStrategy, RandomStrategy, PolytopeSampler, SoboStrategy\n", + "import bofire.strategies.api as strategies\n", + "from bofire.data_models.api import Domain, Outputs, Inputs\n", + "from bofire.data_models.features.api import ContinuousInput, ContinuousOutput, CategoricalOutput, CategoricalInput\n", + "from bofire.data_models.objectives.api import MinimizeObjective, MinimizeSigmoidObjective, MaximizeSigmoidObjective\n", + "from functools import partial\n", + "import pandas as pd\n", + "import os\n", + "from bofire.plot.api import plot_objective_plotly\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Manual setup of the optimization domain\n", + "\n", + "The following cell shows how to manually setup the optimization problem in BoFire for didactic purposes. In the following the implemented benchmark module is then used." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "input_features = Inputs(features=[ContinuousInput(key=f\"x_{i}\", bounds=(0, 1)) for i in range(5)] + [CategoricalInput(key=f\"x_5\", categories=(0.5, 0.0))])\n", + "# here the minimize objective is used, if you want to maximize you have to use the maximize objective.\n", + "output_features = Outputs(features=[\n", + " ContinuousOutput(key=f\"f_{0}\", objective=MinimizeObjective(w=1.)),\n", + " CategoricalOutput(key=f\"f_{1}\", categories=[\"infeasible\", \"feasible\"], objective=[0, 1])\n", + " # ContinuousOutput(key=f\"f_{1}\", objective=MinimizeSigmoidObjective(w=1., steepness=50, tp=0.25)),\n", + " ]\n", + ")\n", + "# no constraints are present so we can create the domain\n", + "domain1 = Domain(inputs=input_features, outputs=output_features)\n", + "\n", + "# plot_objective_plotly(domain.outputs.get_by_key(\"f_0\"), lower=0, upper=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\G15361\\AppData\\Local\\Temp\\ipykernel_25932\\2798685259.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " sample_df[\"f_1\"][sample_df[\"x_0\"]+sample_df[\"x_1\"] <= 1.0] = \"feasible\"\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5f_0f_1valid_f_0
00.7695080.2400690.2770840.7364280.3443700.5-0.962660infeasible1
10.7882410.7779820.8927120.5424820.1092890.0-0.999523infeasible1
20.2996950.5440680.5464060.1696190.0000940.00.010914feasible1
30.2174950.7530680.8367950.3379950.6830570.0-0.951358feasible1
40.6261940.7799500.0779640.1816300.7288840.0-0.733750infeasible1
50.2006260.1691750.0859770.9318610.8926760.0-0.651469feasible1
60.1227090.5277600.8181040.0762930.5353400.0-0.487663feasible1
70.3098270.6174720.0808580.2583980.6283640.5-0.733951feasible1
80.4574090.8001570.5687100.7657640.4162110.0-0.991123infeasible1
90.8340060.3923990.1489890.4334850.1046700.5-0.746477infeasible1
100.0586600.7034930.7316530.1146170.8673400.0-0.786405feasible1
110.1747640.1418890.3930350.7490790.7701340.0-0.611618feasible1
120.2896990.4635330.3198790.0191260.5232050.0-0.044631feasible1
130.4465210.5433910.6287820.2744900.1722180.5-0.838544feasible1
140.3833320.5341240.0839690.9282720.3625200.0-0.660452feasible1
150.0884550.6719320.2993650.2624840.7232440.5-0.827524feasible1
160.3071720.0200330.6800260.3989330.5890410.0-0.411783feasible1
170.8964690.5321340.4554960.6935940.0654710.5-0.999999infeasible1
180.0014790.5014900.0074870.1928820.4548200.00.401028feasible1
190.6005530.0122210.5581150.0463140.6168230.5-0.691260feasible1
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 f_0 \\\n", + "0 0.769508 0.240069 0.277084 0.736428 0.344370 0.5 -0.962660 \n", + "1 0.788241 0.777982 0.892712 0.542482 0.109289 0.0 -0.999523 \n", + "2 0.299695 0.544068 0.546406 0.169619 0.000094 0.0 0.010914 \n", + "3 0.217495 0.753068 0.836795 0.337995 0.683057 0.0 -0.951358 \n", + "4 0.626194 0.779950 0.077964 0.181630 0.728884 0.0 -0.733750 \n", + "5 0.200626 0.169175 0.085977 0.931861 0.892676 0.0 -0.651469 \n", + "6 0.122709 0.527760 0.818104 0.076293 0.535340 0.0 -0.487663 \n", + "7 0.309827 0.617472 0.080858 0.258398 0.628364 0.5 -0.733951 \n", + "8 0.457409 0.800157 0.568710 0.765764 0.416211 0.0 -0.991123 \n", + "9 0.834006 0.392399 0.148989 0.433485 0.104670 0.5 -0.746477 \n", + "10 0.058660 0.703493 0.731653 0.114617 0.867340 0.0 -0.786405 \n", + "11 0.174764 0.141889 0.393035 0.749079 0.770134 0.0 -0.611618 \n", + "12 0.289699 0.463533 0.319879 0.019126 0.523205 0.0 -0.044631 \n", + "13 0.446521 0.543391 0.628782 0.274490 0.172218 0.5 -0.838544 \n", + "14 0.383332 0.534124 0.083969 0.928272 0.362520 0.0 -0.660452 \n", + "15 0.088455 0.671932 0.299365 0.262484 0.723244 0.5 -0.827524 \n", + "16 0.307172 0.020033 0.680026 0.398933 0.589041 0.0 -0.411783 \n", + "17 0.896469 0.532134 0.455496 0.693594 0.065471 0.5 -0.999999 \n", + "18 0.001479 0.501490 0.007487 0.192882 0.454820 0.0 0.401028 \n", + "19 0.600553 0.012221 0.558115 0.046314 0.616823 0.5 -0.691260 \n", + "\n", + " f_1 valid_f_0 \n", + "0 infeasible 1 \n", + "1 infeasible 1 \n", + "2 feasible 1 \n", + "3 feasible 1 \n", + "4 infeasible 1 \n", + "5 feasible 1 \n", + "6 feasible 1 \n", + "7 feasible 1 \n", + "8 infeasible 1 \n", + "9 infeasible 1 \n", + "10 feasible 1 \n", + "11 feasible 1 \n", + "12 feasible 1 \n", + "13 feasible 1 \n", + "14 feasible 1 \n", + "15 feasible 1 \n", + "16 feasible 1 \n", + "17 infeasible 1 \n", + "18 feasible 1 \n", + "19 feasible 1 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "sample_df = domain1.inputs.sample(20).astype(float) # Sample x's\n", + "\n", + "# Write a function which outputs one continuous variable and another discrete based on some logic\n", + "sample_df[\"f_0\"] = np.cos(sample_df.values.sum(1))\n", + "sample_df[\"f_1\"] = \"infeasible\"\n", + "sample_df[\"f_1\"][sample_df[\"x_0\"]+sample_df[\"x_1\"] <= 1.0] = \"feasible\"\n", + "sample_df[\n", + " [\n", + " \"valid_%s\" % feat\n", + " for feat in domain1.outputs.get_keys_by_objective( # type: ignore\n", + " includes=[MinimizeObjective]\n", + " )\n", + " ]\n", + " ] = 1\n", + "\n", + "sample_df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup of the Strategy and ask for Candidates\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "X: x_0 x_1 x_2 x_3 x_4 x_5\n", + "0 0.769508 0.240069 0.277084 0.736428 0.344370 0.5\n", + "1 0.788241 0.777982 0.892712 0.542482 0.109289 0.0\n", + "2 0.299695 0.544068 0.546406 0.169619 0.000094 0.0\n", + "3 0.217495 0.753068 0.836795 0.337995 0.683057 0.0\n", + "4 0.626194 0.779950 0.077964 0.181630 0.728884 0.0\n", + "5 0.200626 0.169175 0.085977 0.931861 0.892676 0.0\n", + "6 0.122709 0.527760 0.818104 0.076293 0.535340 0.0\n", + "7 0.309827 0.617472 0.080858 0.258398 0.628364 0.5\n", + "8 0.457409 0.800157 0.568710 0.765764 0.416211 0.0\n", + "9 0.834006 0.392399 0.148989 0.433485 0.104670 0.5\n", + "10 0.058660 0.703493 0.731653 0.114617 0.867340 0.0\n", + "11 0.174764 0.141889 0.393035 0.749079 0.770134 0.0\n", + "12 0.289699 0.463533 0.319879 0.019126 0.523205 0.0\n", + "13 0.446521 0.543391 0.628782 0.274490 0.172218 0.5\n", + "14 0.383332 0.534124 0.083969 0.928272 0.362520 0.0\n", + "15 0.088455 0.671932 0.299365 0.262484 0.723244 0.5\n", + "16 0.307172 0.020033 0.680026 0.398933 0.589041 0.0\n", + "17 0.896469 0.532134 0.455496 0.693594 0.065471 0.5\n", + "18 0.001479 0.501490 0.007487 0.192882 0.454820 0.0\n", + "19 0.600553 0.012221 0.558115 0.046314 0.616823 0.5, Y= f_1\n", + "0 1\n", + "1 1\n", + "2 0\n", + "3 0\n", + "4 1\n", + "5 0\n", + "6 0\n", + "7 0\n", + "8 1\n", + "9 1\n", + "10 0\n", + "11 0\n", + "12 0\n", + "13 0\n", + "14 0\n", + "15 0\n", + "16 0\n", + "17 1\n", + "18 0\n", + "19 0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " x_0 x_1 x_2 x_3 x_4 x_5 f_0_pred f_1_pred \\\n", + "0 0.326234 1.000000 1.000000 1.000000 1.000000 0.5 -2.118929 -14.356292 \n", + "1 1.000000 1.000000 1.000000 1.000000 1.000000 0.5 -1.965952 -1.856444 \n", + "2 0.599940 1.000000 1.000000 0.391617 1.000000 0.5 -1.997947 -13.875225 \n", + "3 1.000000 1.000000 1.000000 0.000000 1.000000 0.5 -1.829928 -8.801380 \n", + "4 1.000000 1.000000 0.000000 1.000000 1.000000 0.0 -1.794281 0.282144 \n", + "5 0.000000 1.000000 1.000000 1.000000 1.000000 0.5 -2.068119 -26.166718 \n", + "6 1.000000 1.000000 1.000000 1.000000 1.000000 0.0 -1.924939 0.120979 \n", + "7 0.846747 0.983599 0.551493 0.331172 1.000000 0.0 -1.546147 -0.139778 \n", + "8 0.835615 0.250119 0.238424 0.642778 0.343853 0.5 -0.929888 0.251818 \n", + "9 0.716120 0.923542 1.000000 0.301291 0.822325 0.5 -1.791427 -11.423274 \n", + "\n", + " f_0_sd f_1_sd f_0_des f_1_des \n", + "0 0.196473 18.588478 2.118929 NaN \n", + "1 0.239787 3.687648 1.965952 NaN \n", + "2 0.123578 18.389180 1.997947 NaN \n", + "3 0.232758 11.176665 1.829928 NaN \n", + "4 0.225904 1.452171 1.794281 NaN \n", + "5 0.241316 31.600352 2.068119 NaN \n", + "6 0.239853 1.585463 1.924939 NaN \n", + "7 0.114476 1.314039 1.546147 NaN \n", + "8 0.016918 1.475402 0.929888 NaN \n", + "9 0.119792 14.862474 1.791427 NaN \n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from bofire.data_models.acquisition_functions.api import qNEI, qUCB, qSR, qEI\n", + "from bofire.data_models.strategies.api import QparegoStrategy, MultiplicativeSoboStrategy, SoboStrategy\n", + "from bofire.data_models.surrogates.api import BotorchSurrogates, MLPEnsemble\n", + "from bofire.data_models.domain.api import Outputs\n", + "\n", + "strategy_data = SoboStrategy(domain=domain1, \n", + " acquisition_function=qEI(), \n", + " surrogate_specs=BotorchSurrogates(surrogates=\n", + " [\n", + " MLPEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs[1]]), lr=1.0, n_epochs=100)\n", + " ]\n", + " )\n", + " )\n", + "\n", + "strategy = strategies.map(strategy_data)\n", + "\n", + "# experiments = DTLZ2(dim=6).f(domain1.inputs.sample(20).astype(float), return_complete=True)\n", + "\n", + "strategy.tell(sample_df)\n", + "candidates = strategy.ask(10)\n", + "\n", + "print(candidates)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Outputs(type='Outputs', features=[ContinuousOutput(type='ContinuousOutput', key='f_0', unit=None, objective=MinimizeObjective(type='MinimizeObjective', w=1.0, bounds=(0, 1))), CategoricalOutput(type='CategoricalOutput', key='f_1', categories=('infeasible', 'feasible'), objective=[0.0, 1.0])])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from bofire.data_models.objectives.api import Objective\n", + "strategy.domain.outputs.get_by_objective(includes=[Objective, list])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "type='Outputs' features=[ContinuousOutput(type='ContinuousOutput', key='f_0', unit=None, objective=MinimizeObjective(type='MinimizeObjective', w=1.0, bounds=(0, 1))), CategoricalOutput(type='CategoricalOutput', key='f_1', categories=('infeasible', 'feasible'), objective=[0.0, 1.0])]\n", + "(GenericMCObjective(), None, 0.001)\n", + "type='qEI'\n" + ] + } + ], + "source": [ + "print(domain1.outputs)\n", + "\n", + "print(strategy._get_objective_and_constraints())\n", + "\n", + "print(strategy.acquisition_function)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test gpytorch things" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Original ===\n", + "Should be X^T 1 for 1 the all ones vector\n", + "tensor([0.5000, 1.5000], grad_fn=)\n", + "tensor([4., 6.])\n", + "\n", + "=== Likelihood ===\n", + "tensor([0.6915, 0.9332], grad_fn=)\n", + "tensor(1.6247, grad_fn=)\n", + "tensor([0.7406, 1.2222])\n" + ] + } + ], + "source": [ + "import torch\n", + "from gpytorch.likelihoods import BernoulliLikelihood\n", + "\n", + "# Set up dummy tensors for easy verification\n", + "x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])\n", + "y = torch.tensor([0.5, 0.0], requires_grad=True)\n", + "out = torch.matmul(x, y)\n", + "print(f\"=== Original ===\")\n", + "print(f\"Should be X^T 1 for 1 the all ones vector\")\n", + "print(out)\n", + "out = out.sum()\n", + "out.backward()\n", + "print(y.grad)\n", + "\n", + "print(f\"\\n=== Likelihood ===\")\n", + "like = BernoulliLikelihood()\n", + "x1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])\n", + "y1 = torch.tensor([0.5, 0.0], requires_grad=True)\n", + "out1 = like(torch.matmul(x1, y1))\n", + "print(out1.probs)\n", + "out1 = out1.probs.sum()\n", + "print(out1)\n", + "out1.backward()\n", + "print(y1.grad)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Add Classification Models for Surrogates\n", + "\n", + "Updating the surrogates to allow for classification of output values (i.e. 'feasible' or 'infeasible').\n", + "\n", + "### Housekeeping changes\n", + "\n", + "1. Update the categorical input/outputs ('bofire/data_models/features/categorical.py') to always return a tuple instead of a list for `categories` and attribute (to prevent mutation)\n", + " - Associated test are changed in 'tests/bofire/data_models/specs/features.py'\n", + "2. \n", + "\n", + "### Classification Models\n", + "\n", + "Initially, we are only interested in checking whether or not certain points are feasible or infeasible, hence this is a binary classification problem. \n", + "\n", + "\n", + "### Questions\n", + "\n", + "1. Should we force `allowed` to be a tuple for the categorical input/outputs? If so, we need to refactor indexing for Pandas DFs..." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "bofire", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "6f21737eef49beedf03d74399b47fe38d73eff760737ca33d38b9fe616638e91" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From e1414c5da7ffef92b1cdad86d6073cd1bcaafb39 Mon Sep 17 00:00:00 2001 From: gmancino Date: Fri, 22 Sep 2023 17:05:56 -0400 Subject: [PATCH 04/31] Start fixing constraint output checks --- bofire/data_models/domain/features.py | 28 +- bofire/data_models/features/categorical.py | 6 +- bofire/strategies/predictives/predictive.py | 7 + bofire/surrogates/mlp_classifier.py | 7 +- for_gabe.ipynb | 718 ------------------ ...own_Binary_Constraint_Classification.ipynb | 638 ++++++++++++++++ 6 files changed, 670 insertions(+), 734 deletions(-) delete mode 100644 for_gabe.ipynb create mode 100644 tutorials/basic_examples/Unknown_Binary_Constraint_Classification.ipynb diff --git a/bofire/data_models/domain/features.py b/bofire/data_models/domain/features.py index ed40c83ef..2c5d4fbbc 100644 --- a/bofire/data_models/domain/features.py +++ b/bofire/data_models/domain/features.py @@ -547,13 +547,11 @@ def get_by_objective( self, includes: Union[ List[Type[AbstractObjective]], - Tuple[Type[AbstractObjective]], Type[AbstractObjective], Type[Objective], ] = Objective, excludes: Union[ List[Type[AbstractObjective]], - Tuple[Type[AbstractObjective]], Type[AbstractObjective], None, ] = None, @@ -677,22 +675,16 @@ def validate_experiments(self, experiments: pd.DataFrame) -> pd.DataFrame: def validate_candidates(self, candidates: pd.DataFrame) -> pd.DataFrame: # for each continuous output feature with an attached objective object # ToDo: adjust it for the CategoricalOutput - cols = list( + continuous_cols = list( itertools.chain.from_iterable( [ [f"{key}_pred", f"{key}_sd", f"{key}_des"] - for key in self.get_keys_by_objective(Objective) - ] - + [ - [f"{key}_pred", f"{key}_sd"] - for key in self.get_keys_by_objective( - excludes=Objective, includes=None # type: ignore - ) + for key in self.get_keys_by_objective(includes=Objective) ] ) ) # check that pred, sd, and des cols are specified and numerical - for col in cols: + for col in continuous_cols: if col not in candidates: raise ValueError(f"missing column {col}") try: @@ -703,6 +695,20 @@ def validate_candidates(self, candidates: pd.DataFrame) -> pd.DataFrame: raise ValueError(f"Not all values of column `{col}` are numerical.") if candidates[col].isnull().to_numpy().any(): raise ValueError(f"Nan values are present in {col}.") + # Check for categorical output + categorical_objectives = self.get_by_objective(excludes=Objective, includes=None) + if len(categorical_objectives) == 0: + return candidates + categorical_cols = [ + f"{key}_pred" + for key in [categorical_output.key for categorical_output in categorical_objectives.features] + ] + categorical_values = [categorical_output.categories for categorical_output in categorical_objectives.features] + for ind, col in enumerate(categorical_cols): + if col not in candidates: + raise ValueError(f"missing column {col}") + if len(candidates[col]) - candidates[col].isin(categorical_values[ind]).sum() > 0: + raise ValueError(f"values present are not in {categorical_values[ind]}") return candidates def preprocess_experiments_one_valid_output( diff --git a/bofire/data_models/features/categorical.py b/bofire/data_models/features/categorical.py index c421136c5..a50020a84 100644 --- a/bofire/data_models/features/categorical.py +++ b/bofire/data_models/features/categorical.py @@ -402,6 +402,10 @@ def validate_experimental(self, values: pd.Series) -> pd.Series: def to_dict(self) -> Dict: """Returns the catergories and corresponding objective values as dictionary""" return dict(zip(self.categories, self.objective)) + + def to_dict_numeric(self) -> Dict: + """Returns the catergories and corresponding objective values as dictionary""" + return dict(zip(self.objective, self.categories)) def __call__(self, values: pd.Series) -> pd.Series: - return values.map(self.to_dict()).astype(float) + return values.round().map(self.to_dict_numeric()) diff --git a/bofire/strategies/predictives/predictive.py b/bofire/strategies/predictives/predictive.py index 3a27fa183..6a55b0f44 100644 --- a/bofire/strategies/predictives/predictive.py +++ b/bofire/strategies/predictives/predictive.py @@ -98,6 +98,13 @@ def predict(self, experiments: pd.DataFrame) -> pd.DataFrame: raise ValueError("Model not yet fitted.") # TODO: validate also here the experiments but only for the input_columns # transformed = self.transformer.transform(experiments) + + ############################ + # TODO: Here, we need to separate by domain.outputs into continuous and categorical outputs. For continuous outputs, we leave as is, for categorical, we perform the desired mapping + # We then need to modify the input to the ._fit method for the surrogates to be categorically appropriate based on domain.outputs + # Finally, we need to modify the acquisition function to handle hard constraints (see how we can modify based on CBO's implementation and and remedy with BoTorch only using differentiable constraints in the `constraints` argument) + # Then, write tests/specs + ############################ transformed = self.domain.inputs.transform( experiments=experiments, specs=self.input_preprocessing_specs ) diff --git a/bofire/surrogates/mlp_classifier.py b/bofire/surrogates/mlp_classifier.py index f3305fed6..58a888e22 100644 --- a/bofire/surrogates/mlp_classifier.py +++ b/bofire/surrogates/mlp_classifier.py @@ -70,7 +70,7 @@ def __init__( self.layers = nn.Sequential(*layers) def forward(self, x): - return self.layers(x) + return nn.functional.sigmoid(self.layers(x)) class _MLPClassifierEnsemble(EnsembleModel): @@ -87,7 +87,6 @@ def __init__(self, mlps: Sequence[MLPClassifier]): # put all models in eval mode for mlp in self.mlps: mlp.eval() - self.activation = nn.Sigmoid() def forward(self, X: Tensor): r"""Compute the model output at X. @@ -129,7 +128,7 @@ def fit_mlp( """ mlp.train() train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle) - loss_function = nn.BCEWithLogitsLoss() + loss_function = nn.BCELoss() optimizer = torch.optim.Adam(mlp.parameters(), lr=lr, weight_decay=weight_decay) for _ in range(n_epoches): current_loss = 0.0 @@ -183,7 +182,7 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): # Convert Y to classification tensor Y = pd.DataFrame.from_dict({col: np.unique(Y[col].values, return_inverse=True)[1] for col in Y.columns}) # Y = Y.apply(lambda x: pd.factorize(x, sort=True)[0]) - print(f"X: {X}, Y={Y}") + # print(f"X: {X}, Y={Y}") mlps = [] subsample_size = round(self.subsample_fraction * X.shape[0]) for _ in range(self.n_estimators): diff --git a/for_gabe.ipynb b/for_gabe.ipynb deleted file mode 100644 index 65bc37d7d..000000000 --- a/for_gabe.ipynb +++ /dev/null @@ -1,718 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# DTLZ2 Benchmark\n", - "## Imports" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "from bofire.benchmarks.multi import DTLZ2, C2DTLZ2\n", - "from bofire.utils.multiobjective import compute_hypervolume\n", - "from bofire.data_models.strategies.api import QehviStrategy, QparegoStrategy, RandomStrategy, PolytopeSampler, SoboStrategy\n", - "import bofire.strategies.api as strategies\n", - "from bofire.data_models.api import Domain, Outputs, Inputs\n", - "from bofire.data_models.features.api import ContinuousInput, ContinuousOutput, CategoricalOutput, CategoricalInput\n", - "from bofire.data_models.objectives.api import MinimizeObjective, MinimizeSigmoidObjective, MaximizeSigmoidObjective\n", - "from functools import partial\n", - "import pandas as pd\n", - "import os\n", - "from bofire.plot.api import plot_objective_plotly\n", - "\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Manual setup of the optimization domain\n", - "\n", - "The following cell shows how to manually setup the optimization problem in BoFire for didactic purposes. In the following the implemented benchmark module is then used." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "input_features = Inputs(features=[ContinuousInput(key=f\"x_{i}\", bounds=(0, 1)) for i in range(5)] + [CategoricalInput(key=f\"x_5\", categories=(0.5, 0.0))])\n", - "# here the minimize objective is used, if you want to maximize you have to use the maximize objective.\n", - "output_features = Outputs(features=[\n", - " ContinuousOutput(key=f\"f_{0}\", objective=MinimizeObjective(w=1.)),\n", - " CategoricalOutput(key=f\"f_{1}\", categories=[\"infeasible\", \"feasible\"], objective=[0, 1])\n", - " # ContinuousOutput(key=f\"f_{1}\", objective=MinimizeSigmoidObjective(w=1., steepness=50, tp=0.25)),\n", - " ]\n", - ")\n", - "# no constraints are present so we can create the domain\n", - "domain1 = Domain(inputs=input_features, outputs=output_features)\n", - "\n", - "# plot_objective_plotly(domain.outputs.get_by_key(\"f_0\"), lower=0, upper=2)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\Users\\G15361\\AppData\\Local\\Temp\\ipykernel_25932\\2798685259.py:7: SettingWithCopyWarning: \n", - "A value is trying to be set on a copy of a slice from a DataFrame\n", - "\n", - "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", - " sample_df[\"f_1\"][sample_df[\"x_0\"]+sample_df[\"x_1\"] <= 1.0] = \"feasible\"\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
x_0x_1x_2x_3x_4x_5f_0f_1valid_f_0
00.7695080.2400690.2770840.7364280.3443700.5-0.962660infeasible1
10.7882410.7779820.8927120.5424820.1092890.0-0.999523infeasible1
20.2996950.5440680.5464060.1696190.0000940.00.010914feasible1
30.2174950.7530680.8367950.3379950.6830570.0-0.951358feasible1
40.6261940.7799500.0779640.1816300.7288840.0-0.733750infeasible1
50.2006260.1691750.0859770.9318610.8926760.0-0.651469feasible1
60.1227090.5277600.8181040.0762930.5353400.0-0.487663feasible1
70.3098270.6174720.0808580.2583980.6283640.5-0.733951feasible1
80.4574090.8001570.5687100.7657640.4162110.0-0.991123infeasible1
90.8340060.3923990.1489890.4334850.1046700.5-0.746477infeasible1
100.0586600.7034930.7316530.1146170.8673400.0-0.786405feasible1
110.1747640.1418890.3930350.7490790.7701340.0-0.611618feasible1
120.2896990.4635330.3198790.0191260.5232050.0-0.044631feasible1
130.4465210.5433910.6287820.2744900.1722180.5-0.838544feasible1
140.3833320.5341240.0839690.9282720.3625200.0-0.660452feasible1
150.0884550.6719320.2993650.2624840.7232440.5-0.827524feasible1
160.3071720.0200330.6800260.3989330.5890410.0-0.411783feasible1
170.8964690.5321340.4554960.6935940.0654710.5-0.999999infeasible1
180.0014790.5014900.0074870.1928820.4548200.00.401028feasible1
190.6005530.0122210.5581150.0463140.6168230.5-0.691260feasible1
\n", - "
" - ], - "text/plain": [ - " x_0 x_1 x_2 x_3 x_4 x_5 f_0 \\\n", - "0 0.769508 0.240069 0.277084 0.736428 0.344370 0.5 -0.962660 \n", - "1 0.788241 0.777982 0.892712 0.542482 0.109289 0.0 -0.999523 \n", - "2 0.299695 0.544068 0.546406 0.169619 0.000094 0.0 0.010914 \n", - "3 0.217495 0.753068 0.836795 0.337995 0.683057 0.0 -0.951358 \n", - "4 0.626194 0.779950 0.077964 0.181630 0.728884 0.0 -0.733750 \n", - "5 0.200626 0.169175 0.085977 0.931861 0.892676 0.0 -0.651469 \n", - "6 0.122709 0.527760 0.818104 0.076293 0.535340 0.0 -0.487663 \n", - "7 0.309827 0.617472 0.080858 0.258398 0.628364 0.5 -0.733951 \n", - "8 0.457409 0.800157 0.568710 0.765764 0.416211 0.0 -0.991123 \n", - "9 0.834006 0.392399 0.148989 0.433485 0.104670 0.5 -0.746477 \n", - "10 0.058660 0.703493 0.731653 0.114617 0.867340 0.0 -0.786405 \n", - "11 0.174764 0.141889 0.393035 0.749079 0.770134 0.0 -0.611618 \n", - "12 0.289699 0.463533 0.319879 0.019126 0.523205 0.0 -0.044631 \n", - "13 0.446521 0.543391 0.628782 0.274490 0.172218 0.5 -0.838544 \n", - "14 0.383332 0.534124 0.083969 0.928272 0.362520 0.0 -0.660452 \n", - "15 0.088455 0.671932 0.299365 0.262484 0.723244 0.5 -0.827524 \n", - "16 0.307172 0.020033 0.680026 0.398933 0.589041 0.0 -0.411783 \n", - "17 0.896469 0.532134 0.455496 0.693594 0.065471 0.5 -0.999999 \n", - "18 0.001479 0.501490 0.007487 0.192882 0.454820 0.0 0.401028 \n", - "19 0.600553 0.012221 0.558115 0.046314 0.616823 0.5 -0.691260 \n", - "\n", - " f_1 valid_f_0 \n", - "0 infeasible 1 \n", - "1 infeasible 1 \n", - "2 feasible 1 \n", - "3 feasible 1 \n", - "4 infeasible 1 \n", - "5 feasible 1 \n", - "6 feasible 1 \n", - "7 feasible 1 \n", - "8 infeasible 1 \n", - "9 infeasible 1 \n", - "10 feasible 1 \n", - "11 feasible 1 \n", - "12 feasible 1 \n", - "13 feasible 1 \n", - "14 feasible 1 \n", - "15 feasible 1 \n", - "16 feasible 1 \n", - "17 infeasible 1 \n", - "18 feasible 1 \n", - "19 feasible 1 " - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import numpy as np\n", - "sample_df = domain1.inputs.sample(20).astype(float) # Sample x's\n", - "\n", - "# Write a function which outputs one continuous variable and another discrete based on some logic\n", - "sample_df[\"f_0\"] = np.cos(sample_df.values.sum(1))\n", - "sample_df[\"f_1\"] = \"infeasible\"\n", - "sample_df[\"f_1\"][sample_df[\"x_0\"]+sample_df[\"x_1\"] <= 1.0] = \"feasible\"\n", - "sample_df[\n", - " [\n", - " \"valid_%s\" % feat\n", - " for feat in domain1.outputs.get_keys_by_objective( # type: ignore\n", - " includes=[MinimizeObjective]\n", - " )\n", - " ]\n", - " ] = 1\n", - "\n", - "sample_df" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup of the Strategy and ask for Candidates\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "X: x_0 x_1 x_2 x_3 x_4 x_5\n", - "0 0.769508 0.240069 0.277084 0.736428 0.344370 0.5\n", - "1 0.788241 0.777982 0.892712 0.542482 0.109289 0.0\n", - "2 0.299695 0.544068 0.546406 0.169619 0.000094 0.0\n", - "3 0.217495 0.753068 0.836795 0.337995 0.683057 0.0\n", - "4 0.626194 0.779950 0.077964 0.181630 0.728884 0.0\n", - "5 0.200626 0.169175 0.085977 0.931861 0.892676 0.0\n", - "6 0.122709 0.527760 0.818104 0.076293 0.535340 0.0\n", - "7 0.309827 0.617472 0.080858 0.258398 0.628364 0.5\n", - "8 0.457409 0.800157 0.568710 0.765764 0.416211 0.0\n", - "9 0.834006 0.392399 0.148989 0.433485 0.104670 0.5\n", - "10 0.058660 0.703493 0.731653 0.114617 0.867340 0.0\n", - "11 0.174764 0.141889 0.393035 0.749079 0.770134 0.0\n", - "12 0.289699 0.463533 0.319879 0.019126 0.523205 0.0\n", - "13 0.446521 0.543391 0.628782 0.274490 0.172218 0.5\n", - "14 0.383332 0.534124 0.083969 0.928272 0.362520 0.0\n", - "15 0.088455 0.671932 0.299365 0.262484 0.723244 0.5\n", - "16 0.307172 0.020033 0.680026 0.398933 0.589041 0.0\n", - "17 0.896469 0.532134 0.455496 0.693594 0.065471 0.5\n", - "18 0.001479 0.501490 0.007487 0.192882 0.454820 0.0\n", - "19 0.600553 0.012221 0.558115 0.046314 0.616823 0.5, Y= f_1\n", - "0 1\n", - "1 1\n", - "2 0\n", - "3 0\n", - "4 1\n", - "5 0\n", - "6 0\n", - "7 0\n", - "8 1\n", - "9 1\n", - "10 0\n", - "11 0\n", - "12 0\n", - "13 0\n", - "14 0\n", - "15 0\n", - "16 0\n", - "17 1\n", - "18 0\n", - "19 0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " x_0 x_1 x_2 x_3 x_4 x_5 f_0_pred f_1_pred \\\n", - "0 0.326234 1.000000 1.000000 1.000000 1.000000 0.5 -2.118929 -14.356292 \n", - "1 1.000000 1.000000 1.000000 1.000000 1.000000 0.5 -1.965952 -1.856444 \n", - "2 0.599940 1.000000 1.000000 0.391617 1.000000 0.5 -1.997947 -13.875225 \n", - "3 1.000000 1.000000 1.000000 0.000000 1.000000 0.5 -1.829928 -8.801380 \n", - "4 1.000000 1.000000 0.000000 1.000000 1.000000 0.0 -1.794281 0.282144 \n", - "5 0.000000 1.000000 1.000000 1.000000 1.000000 0.5 -2.068119 -26.166718 \n", - "6 1.000000 1.000000 1.000000 1.000000 1.000000 0.0 -1.924939 0.120979 \n", - "7 0.846747 0.983599 0.551493 0.331172 1.000000 0.0 -1.546147 -0.139778 \n", - "8 0.835615 0.250119 0.238424 0.642778 0.343853 0.5 -0.929888 0.251818 \n", - "9 0.716120 0.923542 1.000000 0.301291 0.822325 0.5 -1.791427 -11.423274 \n", - "\n", - " f_0_sd f_1_sd f_0_des f_1_des \n", - "0 0.196473 18.588478 2.118929 NaN \n", - "1 0.239787 3.687648 1.965952 NaN \n", - "2 0.123578 18.389180 1.997947 NaN \n", - "3 0.232758 11.176665 1.829928 NaN \n", - "4 0.225904 1.452171 1.794281 NaN \n", - "5 0.241316 31.600352 2.068119 NaN \n", - "6 0.239853 1.585463 1.924939 NaN \n", - "7 0.114476 1.314039 1.546147 NaN \n", - "8 0.016918 1.475402 0.929888 NaN \n", - "9 0.119792 14.862474 1.791427 NaN \n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "from bofire.data_models.acquisition_functions.api import qNEI, qUCB, qSR, qEI\n", - "from bofire.data_models.strategies.api import QparegoStrategy, MultiplicativeSoboStrategy, SoboStrategy\n", - "from bofire.data_models.surrogates.api import BotorchSurrogates, MLPEnsemble\n", - "from bofire.data_models.domain.api import Outputs\n", - "\n", - "strategy_data = SoboStrategy(domain=domain1, \n", - " acquisition_function=qEI(), \n", - " surrogate_specs=BotorchSurrogates(surrogates=\n", - " [\n", - " MLPEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs[1]]), lr=1.0, n_epochs=100)\n", - " ]\n", - " )\n", - " )\n", - "\n", - "strategy = strategies.map(strategy_data)\n", - "\n", - "# experiments = DTLZ2(dim=6).f(domain1.inputs.sample(20).astype(float), return_complete=True)\n", - "\n", - "strategy.tell(sample_df)\n", - "candidates = strategy.ask(10)\n", - "\n", - "print(candidates)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Outputs(type='Outputs', features=[ContinuousOutput(type='ContinuousOutput', key='f_0', unit=None, objective=MinimizeObjective(type='MinimizeObjective', w=1.0, bounds=(0, 1))), CategoricalOutput(type='CategoricalOutput', key='f_1', categories=('infeasible', 'feasible'), objective=[0.0, 1.0])])" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from bofire.data_models.objectives.api import Objective\n", - "strategy.domain.outputs.get_by_objective(includes=[Objective, list])" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "type='Outputs' features=[ContinuousOutput(type='ContinuousOutput', key='f_0', unit=None, objective=MinimizeObjective(type='MinimizeObjective', w=1.0, bounds=(0, 1))), CategoricalOutput(type='CategoricalOutput', key='f_1', categories=('infeasible', 'feasible'), objective=[0.0, 1.0])]\n", - "(GenericMCObjective(), None, 0.001)\n", - "type='qEI'\n" - ] - } - ], - "source": [ - "print(domain1.outputs)\n", - "\n", - "print(strategy._get_objective_and_constraints())\n", - "\n", - "print(strategy.acquisition_function)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Test gpytorch things" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "=== Original ===\n", - "Should be X^T 1 for 1 the all ones vector\n", - "tensor([0.5000, 1.5000], grad_fn=)\n", - "tensor([4., 6.])\n", - "\n", - "=== Likelihood ===\n", - "tensor([0.6915, 0.9332], grad_fn=)\n", - "tensor(1.6247, grad_fn=)\n", - "tensor([0.7406, 1.2222])\n" - ] - } - ], - "source": [ - "import torch\n", - "from gpytorch.likelihoods import BernoulliLikelihood\n", - "\n", - "# Set up dummy tensors for easy verification\n", - "x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])\n", - "y = torch.tensor([0.5, 0.0], requires_grad=True)\n", - "out = torch.matmul(x, y)\n", - "print(f\"=== Original ===\")\n", - "print(f\"Should be X^T 1 for 1 the all ones vector\")\n", - "print(out)\n", - "out = out.sum()\n", - "out.backward()\n", - "print(y.grad)\n", - "\n", - "print(f\"\\n=== Likelihood ===\")\n", - "like = BernoulliLikelihood()\n", - "x1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])\n", - "y1 = torch.tensor([0.5, 0.0], requires_grad=True)\n", - "out1 = like(torch.matmul(x1, y1))\n", - "print(out1.probs)\n", - "out1 = out1.probs.sum()\n", - "print(out1)\n", - "out1.backward()\n", - "print(y1.grad)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Add Classification Models for Surrogates\n", - "\n", - "Updating the surrogates to allow for classification of output values (i.e. 'feasible' or 'infeasible').\n", - "\n", - "### Housekeeping changes\n", - "\n", - "1. Update the categorical input/outputs ('bofire/data_models/features/categorical.py') to always return a tuple instead of a list for `categories` and attribute (to prevent mutation)\n", - " - Associated test are changed in 'tests/bofire/data_models/specs/features.py'\n", - "2. \n", - "\n", - "### Classification Models\n", - "\n", - "Initially, we are only interested in checking whether or not certain points are feasible or infeasible, hence this is a binary classification problem. \n", - "\n", - "\n", - "### Questions\n", - "\n", - "1. Should we force `allowed` to be a tuple for the categorical input/outputs? If so, we need to refactor indexing for Pandas DFs..." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "bofire", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.0" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "6f21737eef49beedf03d74399b47fe38d73eff760737ca33d38b9fe616638e91" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tutorials/basic_examples/Unknown_Binary_Constraint_Classification.ipynb b/tutorials/basic_examples/Unknown_Binary_Constraint_Classification.ipynb new file mode 100644 index 000000000..6ef78882d --- /dev/null +++ b/tutorials/basic_examples/Unknown_Binary_Constraint_Classification.ipynb @@ -0,0 +1,638 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Classification Surrogate Tests\n", + "\n", + "We are interested in testing whether or not a surrogate model can correctly identify unknown constraints based on binary feasibility/infeasibility. This involves new models which produce `CategoricalOutput`s rather than continuous outputs. Mathematically, instead of multiplying the objective by $\\sigma(x)\\in(0,1)$, we multiply by $I(x)$ which is 1 if $x\\in X$ otherwise it is 0. Since currently BoTorch does not offer support for discrete feasibility constraints (see: [here](https://github.com/pytorch/botorch/blob/main/botorch/utils/objective.py#L122)), we will instead always multiply our objective directly by the feasibility value\n", + "\n", + "In our toy example, the feasible points satisfy $x_1+x_2<= 1.0$." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "# Import packages\n", + "import bofire.strategies.api as strategies\n", + "from bofire.data_models.api import Domain, Outputs, Inputs\n", + "from bofire.data_models.features.api import ContinuousInput, ContinuousOutput, CategoricalOutput, CategoricalInput\n", + "from bofire.data_models.objectives.api import MinimizeObjective\n", + "import numpy as np" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Manual setup of the optimization domain\n", + "\n", + "The following cell shows how to manually setup the optimization problem in BoFire for didactic purposes. We design a feasible set and output constraints for example." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5f_0f_1
00.9346010.9248620.2232390.6581410.3949670.0-0.999983infeasible
10.7057690.1402500.9892530.1564190.3472860.0-0.694829feasible
20.5285490.9678690.6534190.4012100.8224780.0-0.973224infeasible
30.5395490.0059630.6732140.9118840.6723870.0-0.943222feasible
40.4040300.0466330.6285720.7636450.9522510.5-0.988236feasible
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 f_0 f_1\n", + "0 0.934601 0.924862 0.223239 0.658141 0.394967 0.0 -0.999983 infeasible\n", + "1 0.705769 0.140250 0.989253 0.156419 0.347286 0.0 -0.694829 feasible\n", + "2 0.528549 0.967869 0.653419 0.401210 0.822478 0.0 -0.973224 infeasible\n", + "3 0.539549 0.005963 0.673214 0.911884 0.672387 0.0 -0.943222 feasible\n", + "4 0.404030 0.046633 0.628572 0.763645 0.952251 0.5 -0.988236 feasible" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Set-up the inputs and outputs, use categorical domain just as an example\n", + "input_features = Inputs(features=[ContinuousInput(key=f\"x_{i}\", bounds=(0, 1)) for i in range(5)] + [CategoricalInput(key=f\"x_5\", categories=(0.5, 0.0))])\n", + "\n", + "# here the minimize objective is used, if you want to maximize you have to use the maximize objective.\n", + "output_features = Outputs(features=[\n", + " ContinuousOutput(key=f\"f_{0}\", objective=MinimizeObjective(w=1.)),\n", + " CategoricalOutput(key=f\"f_{1}\", categories=[\"infeasible\", \"feasible\"], objective=[0, 1]) # This function will be associated with learning the feasibility/infeasibility\n", + " ]\n", + ")\n", + "\n", + "# Create domain\n", + "domain1 = Domain(inputs=input_features, outputs=output_features)\n", + "\n", + "# Sample random points\n", + "sample_df = domain1.inputs.sample(20).astype(float) # Sample x's\n", + "\n", + "# Write a function which outputs one continuous variable and another discrete based on some logic\n", + "# Here, feasible points are points whose first two components sum to less then 1.0 - in real experiments, these would not be known\n", + "sample_df[\"f_0\"] = np.cos(sample_df.values.sum(1))\n", + "sample_df[\"f_1\"] = \"infeasible\"\n", + "sample_df.loc[sample_df[\"x_0\"]+sample_df[\"x_1\"] <= 1.0, \"f_1\"] = \"feasible\"\n", + "\n", + "sample_df.head(5)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5f_0f_1
01.01.00.01.00.00.0-1.0infeasible
11.00.01.00.00.00.0-1.0feasible
21.01.01.00.01.00.0-1.0infeasible
31.00.01.01.01.00.0-1.0feasible
40.00.01.01.01.00.0-1.0feasible
50.00.01.01.00.00.0-1.0feasible
60.01.01.00.01.00.0-1.0feasible
70.01.00.00.00.00.0-0.0infeasible
81.01.01.01.00.00.0-1.0infeasible
91.01.01.00.01.00.0-1.0infeasible
101.00.01.01.01.00.0-1.0infeasible
111.01.01.01.01.00.0-1.0infeasible
120.01.00.00.01.00.0-1.0infeasible
131.00.01.01.00.00.0-1.0infeasible
140.00.01.00.01.00.0-1.0feasible
151.01.00.01.00.00.0-1.0infeasible
161.00.01.01.00.00.0-1.0feasible
170.01.00.01.01.00.0-1.0infeasible
181.01.01.01.01.00.0-1.0infeasible
191.01.00.00.00.00.0-1.0infeasible
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 f_0 f_1\n", + "0 1.0 1.0 0.0 1.0 0.0 0.0 -1.0 infeasible\n", + "1 1.0 0.0 1.0 0.0 0.0 0.0 -1.0 feasible\n", + "2 1.0 1.0 1.0 0.0 1.0 0.0 -1.0 infeasible\n", + "3 1.0 0.0 1.0 1.0 1.0 0.0 -1.0 feasible\n", + "4 0.0 0.0 1.0 1.0 1.0 0.0 -1.0 feasible\n", + "5 0.0 0.0 1.0 1.0 0.0 0.0 -1.0 feasible\n", + "6 0.0 1.0 1.0 0.0 1.0 0.0 -1.0 feasible\n", + "7 0.0 1.0 0.0 0.0 0.0 0.0 -0.0 infeasible\n", + "8 1.0 1.0 1.0 1.0 0.0 0.0 -1.0 infeasible\n", + "9 1.0 1.0 1.0 0.0 1.0 0.0 -1.0 infeasible\n", + "10 1.0 0.0 1.0 1.0 1.0 0.0 -1.0 infeasible\n", + "11 1.0 1.0 1.0 1.0 1.0 0.0 -1.0 infeasible\n", + "12 0.0 1.0 0.0 0.0 1.0 0.0 -1.0 infeasible\n", + "13 1.0 0.0 1.0 1.0 0.0 0.0 -1.0 infeasible\n", + "14 0.0 0.0 1.0 0.0 1.0 0.0 -1.0 feasible\n", + "15 1.0 1.0 0.0 1.0 0.0 0.0 -1.0 infeasible\n", + "16 1.0 0.0 1.0 1.0 0.0 0.0 -1.0 feasible\n", + "17 0.0 1.0 0.0 1.0 1.0 0.0 -1.0 infeasible\n", + "18 1.0 1.0 1.0 1.0 1.0 0.0 -1.0 infeasible\n", + "19 1.0 1.0 0.0 0.0 0.0 0.0 -1.0 infeasible" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample_df.round()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup of the Strategy and ask for Candidates\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CategoricalMethodEnum.EXHAUSTIVE\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n" + ] + }, + { + "ename": "ValueError", + "evalue": "values present are not in ('infeasible', 'feasible')", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32mc:\\Users\\G15361\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\tutorials\\basic_examples\\Unknown_Binary_Constraint_Classification.ipynb Cell 7\u001b[0m line \u001b[0;36m1\n\u001b[0;32m 15\u001b[0m strategy \u001b[39m=\u001b[39m strategies\u001b[39m.\u001b[39mmap(strategy_data)\n\u001b[0;32m 17\u001b[0m strategy\u001b[39m.\u001b[39mtell(sample_df)\n\u001b[1;32m---> 18\u001b[0m candidates \u001b[39m=\u001b[39m strategy\u001b[39m.\u001b[39;49mask(\u001b[39m1\u001b[39;49m)\n\u001b[0;32m 20\u001b[0m candidates\n", + "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\strategies\\predictives\\predictive.py:56\u001b[0m, in \u001b[0;36mPredictiveStrategy.ask\u001b[1;34m(self, candidate_count, add_pending, raise_validation_error)\u001b[0m\n\u001b[0;32m 40\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Function to generate new candidates.\u001b[39;00m\n\u001b[0;32m 41\u001b[0m \n\u001b[0;32m 42\u001b[0m \u001b[39mArgs:\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 49\u001b[0m \u001b[39m pd.DataFrame: DataFrame with candidates (proposed experiments)\u001b[39;00m\n\u001b[0;32m 50\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m 51\u001b[0m candidates \u001b[39m=\u001b[39m \u001b[39msuper\u001b[39m()\u001b[39m.\u001b[39mask(\n\u001b[0;32m 52\u001b[0m candidate_count\u001b[39m=\u001b[39mcandidate_count,\n\u001b[0;32m 53\u001b[0m add_pending\u001b[39m=\u001b[39madd_pending,\n\u001b[0;32m 54\u001b[0m raise_validation_error\u001b[39m=\u001b[39mraise_validation_error,\n\u001b[0;32m 55\u001b[0m )\n\u001b[1;32m---> 56\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdomain\u001b[39m.\u001b[39;49mvalidate_candidates(\n\u001b[0;32m 57\u001b[0m candidates\u001b[39m=\u001b[39;49mcandidates, raise_validation_error\u001b[39m=\u001b[39;49mraise_validation_error\n\u001b[0;32m 58\u001b[0m )\n\u001b[0;32m 59\u001b[0m \u001b[39mreturn\u001b[39;00m candidates\n", + "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\data_models\\domain\\domain.py:581\u001b[0m, in \u001b[0;36mDomain.validate_candidates\u001b[1;34m(self, candidates, only_inputs, tol, raise_validation_error)\u001b[0m\n\u001b[0;32m 579\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m only_inputs:\n\u001b[0;32m 580\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39misinstance\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutputs, Outputs)\n\u001b[1;32m--> 581\u001b[0m candidates \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49moutputs\u001b[39m.\u001b[39;49mvalidate_candidates(candidates\u001b[39m=\u001b[39;49mcandidates)\n\u001b[0;32m 582\u001b[0m \u001b[39mreturn\u001b[39;00m candidates\n", + "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\data_models\\domain\\features.py:711\u001b[0m, in \u001b[0;36mOutputs.validate_candidates\u001b[1;34m(self, candidates)\u001b[0m\n\u001b[0;32m 709\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mmissing column \u001b[39m\u001b[39m{\u001b[39;00mcol\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m 710\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(candidates[col]) \u001b[39m-\u001b[39m candidates[col]\u001b[39m.\u001b[39misin(categorical_values[ind])\u001b[39m.\u001b[39msum() \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m--> 711\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mvalues present are not in \u001b[39m\u001b[39m{\u001b[39;00mcategorical_values[ind]\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m 712\u001b[0m \u001b[39mreturn\u001b[39;00m candidates\n", + "\u001b[1;31mValueError\u001b[0m: values present are not in ('infeasible', 'feasible')" + ] + } + ], + "source": [ + "from bofire.data_models.acquisition_functions.api import qNEI, qUCB, qSR, qEI\n", + "from bofire.data_models.strategies.api import QparegoStrategy, MultiplicativeSoboStrategy, SoboStrategy\n", + "from bofire.data_models.surrogates.api import BotorchSurrogates, MLPEnsemble\n", + "from bofire.data_models.domain.api import Outputs\n", + "\n", + "strategy_data = SoboStrategy(domain=domain1, \n", + " acquisition_function=qEI(), \n", + " surrogate_specs=BotorchSurrogates(surrogates=\n", + " [\n", + " MLPEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs[1]]), lr=1.0, n_epochs=100, hidden_layer_sizes=(20,))\n", + " ]\n", + " )\n", + " )\n", + "\n", + "strategy = strategies.map(strategy_data)\n", + "\n", + "strategy.tell(sample_df)\n", + "candidates = strategy.ask(1)\n", + "\n", + "candidates" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "tensor([[-19.4762]], dtype=torch.float64, grad_fn=)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "import pandas as pd\n", + "# strategy.surrogates.surrogates[0].model.posterior(torch.tensor(domain1.inputs.sample(20).astype(float).values))\n", + "t = torch.tensor(domain1.inputs.transform(domain1.inputs.sample(1), strategy.surrogates.surrogates[0].input_preprocessing_specs).values)\n", + "strategy.surrogates.surrogates[0].model.posterior(t).mean\n", + "# domain1.outputs[1](pd.Series([1.0]))\n", + "# domain1.inputs.transform(domain1.inputs.sample(20), domain1.inputs.Config)\n", + "# strategy.surrogates.surrogates[0].input_preprocessing_specs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Add Classification Models for Surrogates\n", + "\n", + "Updating the surrogates to allow for classification of output values (i.e. 'feasible' or 'infeasible').\n", + "\n", + "### Housekeeping changes\n", + "\n", + "1. Update the categorical input/outputs ('bofire/data_models/features/categorical.py') to always return a tuple instead of a list for `categories` and attribute (to prevent mutation)\n", + " - Associated test are changed in 'tests/bofire/data_models/specs/features.py'\n", + "2. \n", + "\n", + "### Classification Models\n", + "\n", + "Initially, we are only interested in checking whether or not certain points are feasible or infeasible, hence this is a binary classification problem. \n", + "\n", + "\n", + "### Questions\n", + "\n", + "1. Should we force `allowed` to be a tuple for the categorical input/outputs? If so, we need to refactor indexing for Pandas DFs..." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "bofire", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "6f21737eef49beedf03d74399b47fe38d73eff760737ca33d38b9fe616638e91" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From bb14774a1c12a4f0c925db0a68689d2e3a732249 Mon Sep 17 00:00:00 2001 From: gmancino Date: Thu, 28 Sep 2023 16:29:39 -0400 Subject: [PATCH 05/31] Sync categories and objectives --- bofire/data_models/domain/features.py | 6 +- bofire/data_models/features/categorical.py | 17 +- bofire/data_models/surrogates/api.py | 3 + .../surrogates/botorch_surrogates.py | 2 + .../data_models/surrogates/mlp_classifier.py | 22 + bofire/strategies/predictives/predictive.py | 12 +- bofire/surrogates/api.py | 4 +- bofire/surrogates/mapper.py | 5 +- bofire/surrogates/mlp_classifier.py | 21 +- ...own_Binary_Constraint_Classification.ipynb | 501 +++++------------- 10 files changed, 183 insertions(+), 410 deletions(-) create mode 100644 bofire/data_models/surrogates/mlp_classifier.py diff --git a/bofire/data_models/domain/features.py b/bofire/data_models/domain/features.py index 2c5d4fbbc..c3436f7e0 100644 --- a/bofire/data_models/domain/features.py +++ b/bofire/data_models/domain/features.py @@ -678,8 +678,8 @@ def validate_candidates(self, candidates: pd.DataFrame) -> pd.DataFrame: continuous_cols = list( itertools.chain.from_iterable( [ - [f"{key}_pred", f"{key}_sd", f"{key}_des"] - for key in self.get_keys_by_objective(includes=Objective) + [f"{obj.key}_pred", f"{obj.key}_sd", f"{obj.key}_des"] + for obj in self.get_by_objective(includes=Objective) if not isinstance(obj.type, CategoricalOutput) ] ) ) @@ -696,7 +696,7 @@ def validate_candidates(self, candidates: pd.DataFrame) -> pd.DataFrame: if candidates[col].isnull().to_numpy().any(): raise ValueError(f"Nan values are present in {col}.") # Check for categorical output - categorical_objectives = self.get_by_objective(excludes=Objective, includes=None) + categorical_objectives = [obj for obj in self.get_by_objective(excludes=Objective, includes=None) if isinstance(obj.type, CategoricalOutput)] if len(categorical_objectives) == 0: return candidates categorical_cols = [ diff --git a/bofire/data_models/features/categorical.py b/bofire/data_models/features/categorical.py index a50020a84..c3c7c0ccd 100644 --- a/bofire/data_models/features/categorical.py +++ b/bofire/data_models/features/categorical.py @@ -403,9 +403,18 @@ def to_dict(self) -> Dict: """Returns the catergories and corresponding objective values as dictionary""" return dict(zip(self.categories, self.objective)) - def to_dict_numeric(self) -> Dict: - """Returns the catergories and corresponding objective values as dictionary""" - return dict(zip(self.objective, self.categories)) + def to_dict_label(self) -> Dict: + """Returns the catergories and label location of categories""" + return dict(zip(self.categories, [i for i in range(len(self.categories))])) + + def from_dict_label(self) -> Dict: + """Returns the label location and the categories""" + d = self.to_dict_label() + return dict(zip(d.values(), d.keys())) + + def map_to_categories(self, values: pd.Series) -> pd.Series: + """Maps the input array to the categories""" + return values.round().astype(int).map(self.from_dict_label()) def __call__(self, values: pd.Series) -> pd.Series: - return values.round().map(self.to_dict_numeric()) + return values.map(self.to_dict()) diff --git a/bofire/data_models/surrogates/api.py b/bofire/data_models/surrogates/api.py index 3ba5749e7..88fa46dda 100644 --- a/bofire/data_models/surrogates/api.py +++ b/bofire/data_models/surrogates/api.py @@ -15,6 +15,7 @@ MixedSingleTaskGPSurrogate, ) from bofire.data_models.surrogates.mlp import MLPEnsemble + from bofire.data_models.surrogates.mlp_classifier import MLPClassifierEnsemble from bofire.data_models.surrogates.random_forest import RandomForestSurrogate from bofire.data_models.surrogates.single_task_gp import ( SingleTaskGPHyperconfig, @@ -32,6 +33,7 @@ RandomForestSurrogate, SingleTaskGPSurrogate, MixedSingleTaskGPSurrogate, + MLPClassifierEnsemble, MLPEnsemble, SaasSingleTaskGPSurrogate, XGBoostSurrogate, @@ -43,6 +45,7 @@ RandomForestSurrogate, SingleTaskGPSurrogate, MixedSingleTaskGPSurrogate, + MLPClassifierEnsemble, MLPEnsemble, SaasSingleTaskGPSurrogate, XGBoostSurrogate, diff --git a/bofire/data_models/surrogates/botorch_surrogates.py b/bofire/data_models/surrogates/botorch_surrogates.py index b78f6c778..5fa18ac15 100644 --- a/bofire/data_models/surrogates/botorch_surrogates.py +++ b/bofire/data_models/surrogates/botorch_surrogates.py @@ -12,6 +12,7 @@ from bofire.data_models.surrogates.mixed_single_task_gp import ( MixedSingleTaskGPSurrogate, ) +from bofire.data_models.surrogates.mlp_classifier import MLPClassifierEnsemble from bofire.data_models.surrogates.mlp import MLPEnsemble from bofire.data_models.surrogates.random_forest import RandomForestSurrogate from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate @@ -22,6 +23,7 @@ RandomForestSurrogate, SingleTaskGPSurrogate, MixedSingleTaskGPSurrogate, + MLPClassifierEnsemble, MLPEnsemble, SaasSingleTaskGPSurrogate, TanimotoGPSurrogate, diff --git a/bofire/data_models/surrogates/mlp_classifier.py b/bofire/data_models/surrogates/mlp_classifier.py new file mode 100644 index 000000000..6e6c59400 --- /dev/null +++ b/bofire/data_models/surrogates/mlp_classifier.py @@ -0,0 +1,22 @@ +from typing import Annotated, Literal, Sequence + +from pydantic import Field + +from bofire.data_models.surrogates.botorch import BotorchSurrogate +from bofire.data_models.surrogates.scaler import ScalerEnum +from bofire.data_models.surrogates.trainable import TrainableSurrogate + + +class MLPClassifierEnsemble(BotorchSurrogate, TrainableSurrogate): + type: Literal["MLPClassifierEnsemble"] = "MLPClassifierEnsemble" + n_estimators: Annotated[int, Field(ge=1)] = 5 + hidden_layer_sizes: Sequence = (100,) + activation: Literal["relu", "logistic", "tanh"] = "relu" + dropout: Annotated[float, Field(ge=0.0)] = 0.0 + batch_size: Annotated[int, Field(ge=1)] = 10 + n_epochs: Annotated[int, Field(ge=1)] = 200 + lr: Annotated[float, Field(gt=0.0)] = 1e-4 + weight_decay: Annotated[float, Field(ge=0.0)] = 0.0 + subsample_fraction: Annotated[float, Field(gt=0.0)] = 1.0 + shuffle: bool = True + scaler: ScalerEnum = ScalerEnum.NORMALIZE diff --git a/bofire/strategies/predictives/predictive.py b/bofire/strategies/predictives/predictive.py index 6a55b0f44..a54de4e85 100644 --- a/bofire/strategies/predictives/predictive.py +++ b/bofire/strategies/predictives/predictive.py @@ -5,7 +5,7 @@ import pandas as pd from pydantic import PositiveInt -from bofire.data_models.features.api import TInputTransformSpecs +from bofire.data_models.features.api import CategoricalOutput, TInputTransformSpecs from bofire.data_models.strategies.api import Strategy as DataModel from bofire.strategies.data_models.candidate import Candidate from bofire.strategies.data_models.values import InputValue, OutputValue @@ -98,13 +98,6 @@ def predict(self, experiments: pd.DataFrame) -> pd.DataFrame: raise ValueError("Model not yet fitted.") # TODO: validate also here the experiments but only for the input_columns # transformed = self.transformer.transform(experiments) - - ############################ - # TODO: Here, we need to separate by domain.outputs into continuous and categorical outputs. For continuous outputs, we leave as is, for categorical, we perform the desired mapping - # We then need to modify the input to the ._fit method for the surrogates to be categorically appropriate based on domain.outputs - # Finally, we need to modify the acquisition function to handle hard constraints (see how we can modify based on CBO's implementation and and remedy with BoTorch only using differentiable constraints in the `constraints` argument) - # Then, write tests/specs - ############################ transformed = self.domain.inputs.transform( experiments=experiments, specs=self.input_preprocessing_specs ) @@ -120,6 +113,9 @@ def predict(self, experiments: pd.DataFrame) -> pd.DataFrame: data=preds, columns=["%s_pred" % feat.key for feat in self.domain.outputs.get()], ) + categorical_df = pd.DataFrame.from_dict({f"{feat.key}_pred": feat.map_to_categories(predictions[f"{feat.key}_pred"]) for feat in self.domain.outputs.get() if isinstance(feat, CategoricalOutput)}) + if not categorical_df.empty: + predictions.update(categorical_df) desis = self.domain.outputs(predictions, predictions=True) predictions = pd.concat((predictions, desis), axis=1) predictions.index = experiments.index diff --git a/bofire/surrogates/api.py b/bofire/surrogates/api.py index 65130eef8..b64b3f5aa 100644 --- a/bofire/surrogates/api.py +++ b/bofire/surrogates/api.py @@ -2,8 +2,8 @@ from bofire.surrogates.empirical import EmpiricalSurrogate from bofire.surrogates.mapper import map from bofire.surrogates.mixed_single_task_gp import MixedSingleTaskGPSurrogate -from bofire.surrogates.mlp_classifier import MLPEnsemble -# from bofire.surrogates.mlp import MLPEnsemble +from bofire.surrogates.mlp_classifier import MLPClassifierEnsemble +from bofire.surrogates.mlp import MLPEnsemble from bofire.surrogates.random_forest import RandomForestSurrogate from bofire.surrogates.single_task_gp import SingleTaskGPSurrogate from bofire.surrogates.surrogate import Surrogate diff --git a/bofire/surrogates/mapper.py b/bofire/surrogates/mapper.py index 7079c6be9..e17a4f3e9 100644 --- a/bofire/surrogates/mapper.py +++ b/bofire/surrogates/mapper.py @@ -4,8 +4,8 @@ from bofire.surrogates.empirical import EmpiricalSurrogate from bofire.surrogates.fully_bayesian import SaasSingleTaskGPSurrogate from bofire.surrogates.mixed_single_task_gp import MixedSingleTaskGPSurrogate -# from bofire.surrogates.mlp import MLPEnsemble -from bofire.surrogates.mlp_classifier import MLPEnsemble +from bofire.surrogates.mlp import MLPEnsemble +from bofire.surrogates.mlp_classifier import MLPClassifierEnsemble from bofire.surrogates.random_forest import RandomForestSurrogate from bofire.surrogates.single_task_gp import SingleTaskGPSurrogate from bofire.surrogates.surrogate import Surrogate @@ -16,6 +16,7 @@ data_models.RandomForestSurrogate: RandomForestSurrogate, data_models.SingleTaskGPSurrogate: SingleTaskGPSurrogate, data_models.MixedSingleTaskGPSurrogate: MixedSingleTaskGPSurrogate, + data_models.MLPClassifierEnsemble: MLPClassifierEnsemble, data_models.MLPEnsemble: MLPEnsemble, data_models.SaasSingleTaskGPSurrogate: SaasSingleTaskGPSurrogate, data_models.XGBoostSurrogate: XGBoostSurrogate, diff --git a/bofire/surrogates/mlp_classifier.py b/bofire/surrogates/mlp_classifier.py index 58a888e22..3c44a3050 100644 --- a/bofire/surrogates/mlp_classifier.py +++ b/bofire/surrogates/mlp_classifier.py @@ -70,7 +70,7 @@ def __init__( self.layers = nn.Sequential(*layers) def forward(self, x): - return nn.functional.sigmoid(self.layers(x)) + return nn.functional.log_softmax(self.layers(x), dim=1) class _MLPClassifierEnsemble(EnsembleModel): @@ -98,7 +98,7 @@ def forward(self, X: Tensor): A `batch_shape x s x n x m`-dimensional output tensor where `s` is the size of the ensemble. """ - return torch.stack([mlp(X) for mlp in self.mlps], dim=-3) + return torch.stack([torch.argmax(mlp(X), dim=-1).unsqueeze(-1).float() for mlp in self.mlps], dim=-3) @property def num_outputs(self) -> int: @@ -128,15 +128,13 @@ def fit_mlp( """ mlp.train() train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle) - loss_function = nn.BCELoss() + loss_function = nn.NLLLoss(reduction='mean') optimizer = torch.optim.Adam(mlp.parameters(), lr=lr, weight_decay=weight_decay) for _ in range(n_epoches): current_loss = 0.0 for data in train_loader: # Get and prepare inputs inputs, targets = data - if len(targets.shape) == 1: - targets = targets.reshape((targets.shape[0], 1)) # Zero the gradients optimizer.zero_grad() @@ -145,7 +143,7 @@ def fit_mlp( outputs = mlp(inputs) # Compute loss - loss = loss_function(outputs, targets) + loss = loss_function(outputs, targets.flatten().long()) # Perform backward pass loss.backward() @@ -157,7 +155,7 @@ def fit_mlp( current_loss += loss.item() -class MLPEnsemble(BotorchSurrogate, TrainableSurrogate): +class MLPClassifierEnsemble(BotorchSurrogate, TrainableSurrogate): def __init__(self, data_model: DataModel, **kwargs): self.n_estimators = data_model.n_estimators self.hidden_layer_sizes = data_model.hidden_layer_sizes @@ -178,11 +176,12 @@ def __init__(self, data_model: DataModel, **kwargs): def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): scaler = get_scaler(self.inputs, self.input_preprocessing_specs, self.scaler, X) transformed_X = self.inputs.transform(X, self.input_preprocessing_specs) + # Map dictionary to objective values - gives what is feasible and to labels - to perform opt + label_mapping = self.outputs[0].to_dict_label() # Convert Y to classification tensor - Y = pd.DataFrame.from_dict({col: np.unique(Y[col].values, return_inverse=True)[1] for col in Y.columns}) - # Y = Y.apply(lambda x: pd.factorize(x, sort=True)[0]) - # print(f"X: {X}, Y={Y}") + Y = pd.DataFrame.from_dict({col: Y[col].map(label_mapping) for col in Y.columns}) + mlps = [] subsample_size = round(self.subsample_fraction * X.shape[0]) for _ in range(self.n_estimators): @@ -197,7 +196,7 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): ) mlp = MLPClassifier( input_size=transformed_X.shape[1], - output_size=1, + output_size=len(label_mapping), # Set outputs based on number of categories hidden_layer_sizes=self.hidden_layer_sizes, activation=self.activation, # type: ignore dropout=self.dropout, diff --git a/tutorials/basic_examples/Unknown_Binary_Constraint_Classification.ipynb b/tutorials/basic_examples/Unknown_Binary_Constraint_Classification.ipynb index 6ef78882d..13518e12c 100644 --- a/tutorials/basic_examples/Unknown_Binary_Constraint_Classification.ipynb +++ b/tutorials/basic_examples/Unknown_Binary_Constraint_Classification.ipynb @@ -30,7 +30,7 @@ "import bofire.strategies.api as strategies\n", "from bofire.data_models.api import Domain, Outputs, Inputs\n", "from bofire.data_models.features.api import ContinuousInput, ContinuousOutput, CategoricalOutput, CategoricalInput\n", - "from bofire.data_models.objectives.api import MinimizeObjective\n", + "from bofire.data_models.objectives.api import MinimizeObjective, MinimizeSigmoidObjective\n", "import numpy as np" ] }, @@ -83,57 +83,57 @@ " \n", " \n", " 0\n", - " 0.934601\n", - " 0.924862\n", - " 0.223239\n", - " 0.658141\n", - " 0.394967\n", + " 0.502015\n", + " 0.556355\n", + " 0.184407\n", + " 0.135966\n", + " 0.191290\n", " 0.0\n", - " -0.999983\n", + " 0.000763\n", " infeasible\n", " \n", " \n", " 1\n", - " 0.705769\n", - " 0.140250\n", - " 0.989253\n", - " 0.156419\n", - " 0.347286\n", - " 0.0\n", - " -0.694829\n", + " 0.400085\n", + " 0.351817\n", + " 0.848917\n", + " 0.924594\n", + " 0.657947\n", + " 0.5\n", + " -0.856798\n", " feasible\n", " \n", " \n", " 2\n", - " 0.528549\n", - " 0.967869\n", - " 0.653419\n", - " 0.401210\n", - " 0.822478\n", + " 0.694559\n", + " 0.481801\n", + " 0.239983\n", + " 0.528414\n", + " 0.642616\n", " 0.0\n", - " -0.973224\n", + " -0.850312\n", " infeasible\n", " \n", " \n", " 3\n", - " 0.539549\n", - " 0.005963\n", - " 0.673214\n", - " 0.911884\n", - " 0.672387\n", - " 0.0\n", - " -0.943222\n", - " feasible\n", + " 0.678207\n", + " 0.840033\n", + " 0.298988\n", + " 0.925851\n", + " 0.740847\n", + " 0.5\n", + " -0.665724\n", + " infeasible\n", " \n", " \n", " 4\n", - " 0.404030\n", - " 0.046633\n", - " 0.628572\n", - " 0.763645\n", - " 0.952251\n", + " 0.006785\n", + " 0.787913\n", + " 0.778813\n", + " 0.861638\n", + " 0.290225\n", " 0.5\n", - " -0.988236\n", + " -0.996492\n", " feasible\n", " \n", " \n", @@ -142,11 +142,11 @@ ], "text/plain": [ " x_0 x_1 x_2 x_3 x_4 x_5 f_0 f_1\n", - "0 0.934601 0.924862 0.223239 0.658141 0.394967 0.0 -0.999983 infeasible\n", - "1 0.705769 0.140250 0.989253 0.156419 0.347286 0.0 -0.694829 feasible\n", - "2 0.528549 0.967869 0.653419 0.401210 0.822478 0.0 -0.973224 infeasible\n", - "3 0.539549 0.005963 0.673214 0.911884 0.672387 0.0 -0.943222 feasible\n", - "4 0.404030 0.046633 0.628572 0.763645 0.952251 0.5 -0.988236 feasible" + "0 0.502015 0.556355 0.184407 0.135966 0.191290 0.0 0.000763 infeasible\n", + "1 0.400085 0.351817 0.848917 0.924594 0.657947 0.5 -0.856798 feasible\n", + "2 0.694559 0.481801 0.239983 0.528414 0.642616 0.0 -0.850312 infeasible\n", + "3 0.678207 0.840033 0.298988 0.925851 0.740847 0.5 -0.665724 infeasible\n", + "4 0.006785 0.787913 0.778813 0.861638 0.290225 0.5 -0.996492 feasible" ] }, "execution_count": 2, @@ -161,6 +161,7 @@ "# here the minimize objective is used, if you want to maximize you have to use the maximize objective.\n", "output_features = Outputs(features=[\n", " ContinuousOutput(key=f\"f_{0}\", objective=MinimizeObjective(w=1.)),\n", + " # ContinuousOutput(key=f\"f_{2}\", objective=MinimizeSigmoidObjective(w=1., tp=0.0, steepness=0.5)),\n", " CategoricalOutput(key=f\"f_{1}\", categories=[\"infeasible\", \"feasible\"], objective=[0, 1]) # This function will be associated with learning the feasibility/infeasibility\n", " ]\n", ")\n", @@ -176,15 +177,62 @@ "sample_df[\"f_0\"] = np.cos(sample_df.values.sum(1))\n", "sample_df[\"f_1\"] = \"infeasible\"\n", "sample_df.loc[sample_df[\"x_0\"]+sample_df[\"x_1\"] <= 1.0, \"f_1\"] = \"feasible\"\n", + "# sample_df[\"f_2\"] = np.random.uniform(size=(len(sample_df),))\n", "\n", "sample_df.head(5)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup of the Strategy and ask for Candidates\n", + "\n" + ] + }, { "cell_type": "code", "execution_count": 3, "metadata": {}, + "outputs": [], + "source": [ + "from bofire.data_models.acquisition_functions.api import qNEI, qUCB, qSR, qEI\n", + "from bofire.data_models.strategies.api import QparegoStrategy, MultiplicativeSoboStrategy, SoboStrategy\n", + "from bofire.data_models.surrogates.api import BotorchSurrogates, MLPClassifierEnsemble, MixedSingleTaskGPSurrogate\n", + "from bofire.data_models.domain.api import Outputs\n", + "\n", + "strategy_data = SoboStrategy(domain=domain1, \n", + " acquisition_function=qEI(), \n", + " surrogate_specs=BotorchSurrogates(surrogates=\n", + " [\n", + " MLPClassifierEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs[-1]]), lr=1.0, n_epochs=100, hidden_layer_sizes=(20,)),\n", + " # MixedSingleTaskGPSurrogate(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs[1]]))\n", + " ]\n", + " )\n", + " )\n", + "\n", + "strategy = strategies.map(strategy_data)\n", + "\n", + "strategy.tell(sample_df)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n" + ] + }, { "data": { "text/html": [ @@ -212,377 +260,70 @@ " x_3\n", " x_4\n", " x_5\n", - " f_0\n", - " f_1\n", + " f_0_pred\n", + " f_1_pred\n", + " f_0_sd\n", + " f_1_sd\n", + " f_0_des\n", + " f_1_des\n", " \n", " \n", " \n", " \n", " 0\n", + " 0.723341\n", " 1.0\n", - " 1.0\n", - " 0.0\n", - " 1.0\n", - " 0.0\n", - " 0.0\n", - " -1.0\n", - " infeasible\n", - " \n", - " \n", - " 1\n", - " 1.0\n", - " 0.0\n", - " 1.0\n", - " 0.0\n", - " 0.0\n", - " 0.0\n", - " -1.0\n", - " feasible\n", - " \n", - " \n", - " 2\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 0.0\n", - " 1.0\n", - " 0.0\n", - " -1.0\n", - " infeasible\n", - " \n", - " \n", - " 3\n", - " 1.0\n", - " 0.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 0.0\n", - " -1.0\n", - " feasible\n", - " \n", - " \n", - " 4\n", - " 0.0\n", - " 0.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 0.0\n", - " -1.0\n", - " feasible\n", - " \n", - " \n", - " 5\n", - " 0.0\n", - " 0.0\n", - " 1.0\n", - " 1.0\n", - " 0.0\n", - " 0.0\n", - " -1.0\n", - " feasible\n", - " \n", - " \n", - " 6\n", - " 0.0\n", - " 1.0\n", - " 1.0\n", - " 0.0\n", - " 1.0\n", - " 0.0\n", - " -1.0\n", - " feasible\n", - " \n", - " \n", - " 7\n", - " 0.0\n", - " 1.0\n", - " 0.0\n", - " 0.0\n", - " 0.0\n", - " 0.0\n", - " -0.0\n", - " infeasible\n", - " \n", - " \n", - " 8\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 0.0\n", - " 0.0\n", - " -1.0\n", - " infeasible\n", - " \n", - " \n", - " 9\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 0.0\n", - " 1.0\n", - " 0.0\n", - " -1.0\n", - " infeasible\n", - " \n", - " \n", - " 10\n", - " 1.0\n", - " 0.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 0.0\n", - " -1.0\n", - " infeasible\n", - " \n", - " \n", - " 11\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 0.0\n", - " -1.0\n", - " infeasible\n", - " \n", - " \n", - " 12\n", - " 0.0\n", - " 1.0\n", - " 0.0\n", - " 0.0\n", - " 1.0\n", - " 0.0\n", - " -1.0\n", - " infeasible\n", - " \n", - " \n", - " 13\n", - " 1.0\n", - " 0.0\n", - " 1.0\n", - " 1.0\n", - " 0.0\n", - " 0.0\n", - " -1.0\n", - " infeasible\n", - " \n", - " \n", - " 14\n", - " 0.0\n", - " 0.0\n", - " 1.0\n", - " 0.0\n", - " 1.0\n", - " 0.0\n", - " -1.0\n", + " 0.703635\n", + " 0.426167\n", + " 0.328014\n", + " 0.5\n", + " -1.267539\n", " feasible\n", - " \n", - " \n", - " 15\n", - " 1.0\n", + " 0.042117\n", + " 0.547723\n", + " 1.267539\n", " 1.0\n", - " 0.0\n", - " 1.0\n", - " 0.0\n", - " 0.0\n", - " -1.0\n", - " infeasible\n", " \n", " \n", - " 16\n", - " 1.0\n", - " 0.0\n", - " 1.0\n", + " 1\n", + " 0.000000\n", " 1.0\n", - " 0.0\n", - " 0.0\n", - " -1.0\n", + " 0.312203\n", + " 0.450019\n", + " 0.388926\n", + " 0.5\n", + " -1.250402\n", " feasible\n", - " \n", - " \n", - " 17\n", - " 0.0\n", - " 1.0\n", - " 0.0\n", - " 1.0\n", - " 1.0\n", - " 0.0\n", - " -1.0\n", - " infeasible\n", - " \n", - " \n", - " 18\n", + " 0.042963\n", + " 0.547723\n", + " 1.250402\n", " 1.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 0.0\n", - " -1.0\n", - " infeasible\n", - " \n", - " \n", - " 19\n", - " 1.0\n", - " 1.0\n", - " 0.0\n", - " 0.0\n", - " 0.0\n", - " 0.0\n", - " -1.0\n", - " infeasible\n", " \n", " \n", "\n", "" ], "text/plain": [ - " x_0 x_1 x_2 x_3 x_4 x_5 f_0 f_1\n", - "0 1.0 1.0 0.0 1.0 0.0 0.0 -1.0 infeasible\n", - "1 1.0 0.0 1.0 0.0 0.0 0.0 -1.0 feasible\n", - "2 1.0 1.0 1.0 0.0 1.0 0.0 -1.0 infeasible\n", - "3 1.0 0.0 1.0 1.0 1.0 0.0 -1.0 feasible\n", - "4 0.0 0.0 1.0 1.0 1.0 0.0 -1.0 feasible\n", - "5 0.0 0.0 1.0 1.0 0.0 0.0 -1.0 feasible\n", - "6 0.0 1.0 1.0 0.0 1.0 0.0 -1.0 feasible\n", - "7 0.0 1.0 0.0 0.0 0.0 0.0 -0.0 infeasible\n", - "8 1.0 1.0 1.0 1.0 0.0 0.0 -1.0 infeasible\n", - "9 1.0 1.0 1.0 0.0 1.0 0.0 -1.0 infeasible\n", - "10 1.0 0.0 1.0 1.0 1.0 0.0 -1.0 infeasible\n", - "11 1.0 1.0 1.0 1.0 1.0 0.0 -1.0 infeasible\n", - "12 0.0 1.0 0.0 0.0 1.0 0.0 -1.0 infeasible\n", - "13 1.0 0.0 1.0 1.0 0.0 0.0 -1.0 infeasible\n", - "14 0.0 0.0 1.0 0.0 1.0 0.0 -1.0 feasible\n", - "15 1.0 1.0 0.0 1.0 0.0 0.0 -1.0 infeasible\n", - "16 1.0 0.0 1.0 1.0 0.0 0.0 -1.0 feasible\n", - "17 0.0 1.0 0.0 1.0 1.0 0.0 -1.0 infeasible\n", - "18 1.0 1.0 1.0 1.0 1.0 0.0 -1.0 infeasible\n", - "19 1.0 1.0 0.0 0.0 0.0 0.0 -1.0 infeasible" + " x_0 x_1 x_2 x_3 x_4 x_5 f_0_pred f_1_pred \\\n", + "0 0.723341 1.0 0.703635 0.426167 0.328014 0.5 -1.267539 feasible \n", + "1 0.000000 1.0 0.312203 0.450019 0.388926 0.5 -1.250402 feasible \n", + "\n", + " f_0_sd f_1_sd f_0_des f_1_des \n", + "0 0.042117 0.547723 1.267539 1.0 \n", + "1 0.042963 0.547723 1.250402 1.0 " ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "sample_df.round()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup of the Strategy and ask for Candidates\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CategoricalMethodEnum.EXHAUSTIVE\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n" - ] - }, - { - "ename": "ValueError", - "evalue": "values present are not in ('infeasible', 'feasible')", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32mc:\\Users\\G15361\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\tutorials\\basic_examples\\Unknown_Binary_Constraint_Classification.ipynb Cell 7\u001b[0m line \u001b[0;36m1\n\u001b[0;32m 15\u001b[0m strategy \u001b[39m=\u001b[39m strategies\u001b[39m.\u001b[39mmap(strategy_data)\n\u001b[0;32m 17\u001b[0m strategy\u001b[39m.\u001b[39mtell(sample_df)\n\u001b[1;32m---> 18\u001b[0m candidates \u001b[39m=\u001b[39m strategy\u001b[39m.\u001b[39;49mask(\u001b[39m1\u001b[39;49m)\n\u001b[0;32m 20\u001b[0m candidates\n", - "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\strategies\\predictives\\predictive.py:56\u001b[0m, in \u001b[0;36mPredictiveStrategy.ask\u001b[1;34m(self, candidate_count, add_pending, raise_validation_error)\u001b[0m\n\u001b[0;32m 40\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Function to generate new candidates.\u001b[39;00m\n\u001b[0;32m 41\u001b[0m \n\u001b[0;32m 42\u001b[0m \u001b[39mArgs:\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 49\u001b[0m \u001b[39m pd.DataFrame: DataFrame with candidates (proposed experiments)\u001b[39;00m\n\u001b[0;32m 50\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m 51\u001b[0m candidates \u001b[39m=\u001b[39m \u001b[39msuper\u001b[39m()\u001b[39m.\u001b[39mask(\n\u001b[0;32m 52\u001b[0m candidate_count\u001b[39m=\u001b[39mcandidate_count,\n\u001b[0;32m 53\u001b[0m add_pending\u001b[39m=\u001b[39madd_pending,\n\u001b[0;32m 54\u001b[0m raise_validation_error\u001b[39m=\u001b[39mraise_validation_error,\n\u001b[0;32m 55\u001b[0m )\n\u001b[1;32m---> 56\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdomain\u001b[39m.\u001b[39;49mvalidate_candidates(\n\u001b[0;32m 57\u001b[0m candidates\u001b[39m=\u001b[39;49mcandidates, raise_validation_error\u001b[39m=\u001b[39;49mraise_validation_error\n\u001b[0;32m 58\u001b[0m )\n\u001b[0;32m 59\u001b[0m \u001b[39mreturn\u001b[39;00m candidates\n", - "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\data_models\\domain\\domain.py:581\u001b[0m, in \u001b[0;36mDomain.validate_candidates\u001b[1;34m(self, candidates, only_inputs, tol, raise_validation_error)\u001b[0m\n\u001b[0;32m 579\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m only_inputs:\n\u001b[0;32m 580\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39misinstance\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutputs, Outputs)\n\u001b[1;32m--> 581\u001b[0m candidates \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49moutputs\u001b[39m.\u001b[39;49mvalidate_candidates(candidates\u001b[39m=\u001b[39;49mcandidates)\n\u001b[0;32m 582\u001b[0m \u001b[39mreturn\u001b[39;00m candidates\n", - "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\data_models\\domain\\features.py:711\u001b[0m, in \u001b[0;36mOutputs.validate_candidates\u001b[1;34m(self, candidates)\u001b[0m\n\u001b[0;32m 709\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mmissing column \u001b[39m\u001b[39m{\u001b[39;00mcol\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m 710\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(candidates[col]) \u001b[39m-\u001b[39m candidates[col]\u001b[39m.\u001b[39misin(categorical_values[ind])\u001b[39m.\u001b[39msum() \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m--> 711\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mvalues present are not in \u001b[39m\u001b[39m{\u001b[39;00mcategorical_values[ind]\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m 712\u001b[0m \u001b[39mreturn\u001b[39;00m candidates\n", - "\u001b[1;31mValueError\u001b[0m: values present are not in ('infeasible', 'feasible')" - ] - } - ], - "source": [ - "from bofire.data_models.acquisition_functions.api import qNEI, qUCB, qSR, qEI\n", - "from bofire.data_models.strategies.api import QparegoStrategy, MultiplicativeSoboStrategy, SoboStrategy\n", - "from bofire.data_models.surrogates.api import BotorchSurrogates, MLPEnsemble\n", - "from bofire.data_models.domain.api import Outputs\n", - "\n", - "strategy_data = SoboStrategy(domain=domain1, \n", - " acquisition_function=qEI(), \n", - " surrogate_specs=BotorchSurrogates(surrogates=\n", - " [\n", - " MLPEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs[1]]), lr=1.0, n_epochs=100, hidden_layer_sizes=(20,))\n", - " ]\n", - " )\n", - " )\n", - "\n", - "strategy = strategies.map(strategy_data)\n", - "\n", - "strategy.tell(sample_df)\n", - "candidates = strategy.ask(1)\n", + "candidates = strategy.ask(2)\n", "\n", "candidates" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "text/plain": [ - "tensor([[-19.4762]], dtype=torch.float64, grad_fn=)" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import torch\n", - "import pandas as pd\n", - "# strategy.surrogates.surrogates[0].model.posterior(torch.tensor(domain1.inputs.sample(20).astype(float).values))\n", - "t = torch.tensor(domain1.inputs.transform(domain1.inputs.sample(1), strategy.surrogates.surrogates[0].input_preprocessing_specs).values)\n", - "strategy.surrogates.surrogates[0].model.posterior(t).mean\n", - "# domain1.outputs[1](pd.Series([1.0]))\n", - "# domain1.inputs.transform(domain1.inputs.sample(20), domain1.inputs.Config)\n", - "# strategy.surrogates.surrogates[0].input_preprocessing_specs" - ] - }, { "cell_type": "markdown", "metadata": {}, From acdbbf2b9e64fd651f5bbc16a80c01f72fea4af8 Mon Sep 17 00:00:00 2001 From: gmancino Date: Tue, 3 Oct 2023 12:18:39 -0400 Subject: [PATCH 06/31] Add categorical objective --- bofire/data_models/domain/features.py | 42 +- bofire/data_models/features/categorical.py | 56 +- bofire/data_models/objectives/api.py | 2 + bofire/data_models/objectives/categorical.py | 40 ++ .../strategies/predictives/qparego.py | 2 +- .../strategies/predictives/sobo.py | 2 +- .../surrogates/botorch_surrogates.py | 2 +- .../doe/utils_categorical_discrete.py | 8 +- bofire/strategies/predictives/predictive.py | 41 +- bofire/surrogates/api.py | 2 +- bofire/surrogates/mlp_classifier.py | 16 +- bofire/utils/torch_tools.py | 58 +- tests/bofire/data_models/specs/features.py | 4 +- tests/bofire/data_models/test_features.py | 24 +- tests/bofire/surrogates/test_diagnostics.py | 2 +- tests/bofire/utils/test_torch_tools.py | 2 +- ...own_Binary_Constraint_Classification.ipynb | 379 ----------- .../Unknown_Constraint_Classification.ipynb | 632 ++++++++++++++++++ 18 files changed, 854 insertions(+), 460 deletions(-) create mode 100644 bofire/data_models/objectives/categorical.py delete mode 100644 tutorials/basic_examples/Unknown_Binary_Constraint_Classification.ipynb create mode 100644 tutorials/basic_examples/Unknown_Constraint_Classification.ipynb diff --git a/bofire/data_models/domain/features.py b/bofire/data_models/domain/features.py index c3436f7e0..fd65b0fb3 100644 --- a/bofire/data_models/domain/features.py +++ b/bofire/data_models/domain/features.py @@ -30,7 +30,11 @@ TInputTransformSpecs, ) from bofire.data_models.molfeatures.api import MolFeatures -from bofire.data_models.objectives.api import AbstractObjective, Objective +from bofire.data_models.objectives.api import ( + AbstractObjective, + CategoricalObjective, + Objective, +) FeatureSequence = Union[List[AnyFeature], Tuple[AnyFeature]] @@ -592,7 +596,10 @@ def get_keys_by_objective( Type[Objective], ] = Objective, excludes: Union[ - List[Type[AbstractObjective]], Tuple[Type[AbstractObjective]], Type[AbstractObjective], None + List[Type[AbstractObjective]], + Tuple[Type[AbstractObjective]], + Type[AbstractObjective], + None, ] = None, exact: bool = False, ) -> List[str]: @@ -627,6 +634,14 @@ def __call__( feat(experiments[f"{feat.key}_pred" if predictions else feat.key]) # type: ignore for feat in self.features if feat.objective is not None + and not isinstance(feat, CategoricalOutput) + ] + + [ + feat.compute_objective(experiments.filter(regex=f"{feat.key}_pred_")) # type: ignore + if predictions + else experiments[feat.key] + for feat in self.features + if isinstance(feat, CategoricalOutput) ], axis=1, ) @@ -674,12 +689,13 @@ def validate_experiments(self, experiments: pd.DataFrame) -> pd.DataFrame: def validate_candidates(self, candidates: pd.DataFrame) -> pd.DataFrame: # for each continuous output feature with an attached objective object - # ToDo: adjust it for the CategoricalOutput continuous_cols = list( itertools.chain.from_iterable( [ [f"{obj.key}_pred", f"{obj.key}_sd", f"{obj.key}_des"] - for obj in self.get_by_objective(includes=Objective) if not isinstance(obj.type, CategoricalOutput) + for obj in self.get_by_objective( + includes=Objective, excludes=CategoricalObjective + ) ] ) ) @@ -696,19 +712,17 @@ def validate_candidates(self, candidates: pd.DataFrame) -> pd.DataFrame: if candidates[col].isnull().to_numpy().any(): raise ValueError(f"Nan values are present in {col}.") # Check for categorical output - categorical_objectives = [obj for obj in self.get_by_objective(excludes=Objective, includes=None) if isinstance(obj.type, CategoricalOutput)] - if len(categorical_objectives) == 0: - return candidates categorical_cols = [ - f"{key}_pred" - for key in [categorical_output.key for categorical_output in categorical_objectives.features] + (f"{obj.key}_pred", obj.categories) + for obj in self.get_by_objective(includes=CategoricalObjective) ] - categorical_values = [categorical_output.categories for categorical_output in categorical_objectives.features] - for ind, col in enumerate(categorical_cols): - if col not in candidates: + if len(categorical_cols) == 0: + return candidates + for col in categorical_cols: + if col[0] not in candidates: raise ValueError(f"missing column {col}") - if len(candidates[col]) - candidates[col].isin(categorical_values[ind]).sum() > 0: - raise ValueError(f"values present are not in {categorical_values[ind]}") + if len(candidates[col[0]]) - candidates[col[0]].isin(col[1]).sum() > 0: + raise ValueError(f"values present are not in {col[1]}") return candidates def preprocess_experiments_one_valid_output( diff --git a/bofire/data_models/features/categorical.py b/bofire/data_models/features/categorical.py index c3c7c0ccd..4f24d7e9b 100644 --- a/bofire/data_models/features/categorical.py +++ b/bofire/data_models/features/categorical.py @@ -3,7 +3,6 @@ import numpy as np import pandas as pd from pydantic import Field, root_validator, validator -from typing_extensions import Annotated from bofire.data_models.enum import CategoricalEncodingEnum from bofire.data_models.features.feature import ( @@ -14,6 +13,7 @@ TCategoryVals, TTransform, ) +from bofire.data_models.objectives.categorical import CategoricalObjective class CategoricalInput(Input): @@ -359,9 +359,9 @@ class CategoricalOutput(Output): order_id: ClassVar[int] = 8 categories: TCategoryVals - objective: Annotated[ - List[Annotated[float, Field(type=float, ge=0, le=1)]], Field(min_items=2) - ] + objective: CategoricalObjective = Field( + default_factory=lambda: CategoricalObjective(w=1.0) + ) @validator("categories") def validate_categories_unique(cls, categories): @@ -382,13 +382,20 @@ def validate_categories_unique(cls, categories): @validator("objective") def validate_objective(cls, objective, values): - if len(objective) != len(values["categories"]): + weights = objective.weights + if len(weights) != len(values["categories"]): raise ValueError("Length of objectives and categories do not match.") - for o in objective: - if o > 1: - raise ValueError("Objective values has to be smaller equal than 1.") - if o < 0: - raise ValueError("Objective values has to be larger equal than zero") + for w in weights: + if w > 1: + raise ValueError("Objective weight has to be smaller equal than 1.") + if w < 0: + raise ValueError("Objective weight has to be larger equal than zero") + # Save the categories to the objective if they do not exist + objective.categories = ( + list(values["categories"]) + if objective.categories is None + else objective.categories + ) return objective def validate_experimental(self, values: pd.Series) -> pd.Series: @@ -399,22 +406,33 @@ def validate_experimental(self, values: pd.Series) -> pd.Series: ) return values + def __str__(self) -> str: + return "CategoricalOutputFeature" + def to_dict(self) -> Dict: """Returns the catergories and corresponding objective values as dictionary""" - return dict(zip(self.categories, self.objective)) - + return dict(zip(self.categories, self.objective.weights)) + def to_dict_label(self) -> Dict: """Returns the catergories and label location of categories""" - return dict(zip(self.categories, [i for i in range(len(self.categories))])) - + return {c: i for i, c in enumerate(self.categories)} + def from_dict_label(self) -> Dict: """Returns the label location and the categories""" d = self.to_dict_label() return dict(zip(d.values(), d.keys())) - - def map_to_categories(self, values: pd.Series) -> pd.Series: - """Maps the input array to the categories""" - return values.round().astype(int).map(self.from_dict_label()) + + def map_to_categories(self, values: pd.DataFrame) -> pd.Series: + """Maps the input matrix of probabilities to the categories via argmax""" + return values.idxmax(1).str.replace(f"{self.key}_pred_", "").values + + def compute_objective(self, values: pd.DataFrame) -> pd.Series: + """Computes the objective value as: (p.o).sum() where p is the vector of probabilities and o is the vector of objective values""" + values.columns = values.columns.str.replace(f"{self.key}_pred_", "") + scale_series = pd.Series(self.to_dict()) + return pd.Series( + data=(values * scale_series).sum(1).values, name=f"{self.key}_pred" + ) def __call__(self, values: pd.Series) -> pd.Series: - return values.map(self.to_dict()) + return self.objective(values) diff --git a/bofire/data_models/objectives/api.py b/bofire/data_models/objectives/api.py index d90f76fcc..f5c64c76e 100644 --- a/bofire/data_models/objectives/api.py +++ b/bofire/data_models/objectives/api.py @@ -1,5 +1,6 @@ from typing import Union +from bofire.data_models.objectives.categorical import CategoricalObjective from bofire.data_models.objectives.identity import ( IdentityObjective, MaximizeObjective, @@ -25,6 +26,7 @@ ] AnyConstraintObjective = Union[ + CategoricalObjective, MaximizeSigmoidObjective, MinimizeSigmoidObjective, TargetObjective, diff --git a/bofire/data_models/objectives/categorical.py b/bofire/data_models/objectives/categorical.py new file mode 100644 index 000000000..c2c5a577d --- /dev/null +++ b/bofire/data_models/objectives/categorical.py @@ -0,0 +1,40 @@ +from typing import List, Literal, Union + +import numpy as np +import pandas as pd + +from bofire.data_models.objectives.objective import ( + ConstrainedObjective, + Objective, + TWeight, +) + + +class CategoricalObjective(Objective, ConstrainedObjective): + """Compute the categorical objective value as: + + Po where P is an [n, c] matrix where each row is a probability vector + (P[i, :].sum()=1 for all i) and o is a column vector of objective values + + Attributes: + w (float): float between zero and one for weighting the objective. + weights (list): list of values of size c (c is number of categories) such that the i-th entry is in (0, 1) + """ + + w: TWeight = 1.0 + weights: List[float] + eta: float = 1.0 + categories: Union[List[str], None] = None + + type: Literal["CategoricalObjective"] = "CategoricalObjective" + + def __call__(self, x: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]: + """The call function returning a probabilistic reward for x. + + Args: + x (np.ndarray): A matrix of x values + + Returns: + np.ndarray: A reward calculated as inner product of probabilities and feasible objectives. + """ + return x.map(dict(zip(self.categories, self.weights))) diff --git a/bofire/data_models/strategies/predictives/qparego.py b/bofire/data_models/strategies/predictives/qparego.py index 9776bc001..9d228d4d4 100644 --- a/bofire/data_models/strategies/predictives/qparego.py +++ b/bofire/data_models/strategies/predictives/qparego.py @@ -3,7 +3,7 @@ from pydantic import Field from bofire.data_models.acquisition_functions.api import qEI, qLogEI, qLogNEI, qNEI -from bofire.data_models.features.api import CategoricalOutput, Feature +from bofire.data_models.features.api import Feature from bofire.data_models.objectives.api import ( CloseToTargetObjective, MaximizeObjective, diff --git a/bofire/data_models/strategies/predictives/sobo.py b/bofire/data_models/strategies/predictives/sobo.py index 00e3f84a9..4b777d6d9 100644 --- a/bofire/data_models/strategies/predictives/sobo.py +++ b/bofire/data_models/strategies/predictives/sobo.py @@ -3,7 +3,7 @@ from pydantic import Field, validator from bofire.data_models.acquisition_functions.api import AnyAcquisitionFunction, qNEI -from bofire.data_models.features.api import CategoricalOutput, Feature +from bofire.data_models.features.api import Feature from bofire.data_models.objectives.api import ConstrainedObjective, Objective from bofire.data_models.strategies.predictives.botorch import BotorchStrategy diff --git a/bofire/data_models/surrogates/botorch_surrogates.py b/bofire/data_models/surrogates/botorch_surrogates.py index 5fa18ac15..dcc50d5dc 100644 --- a/bofire/data_models/surrogates/botorch_surrogates.py +++ b/bofire/data_models/surrogates/botorch_surrogates.py @@ -12,8 +12,8 @@ from bofire.data_models.surrogates.mixed_single_task_gp import ( MixedSingleTaskGPSurrogate, ) -from bofire.data_models.surrogates.mlp_classifier import MLPClassifierEnsemble from bofire.data_models.surrogates.mlp import MLPEnsemble +from bofire.data_models.surrogates.mlp_classifier import MLPClassifierEnsemble from bofire.data_models.surrogates.random_forest import RandomForestSurrogate from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate from bofire.data_models.surrogates.tanimoto_gp import TanimotoGPSurrogate diff --git a/bofire/strategies/doe/utils_categorical_discrete.py b/bofire/strategies/doe/utils_categorical_discrete.py index abfe1e340..b70a9f8a0 100644 --- a/bofire/strategies/doe/utils_categorical_discrete.py +++ b/bofire/strategies/doe/utils_categorical_discrete.py @@ -157,13 +157,15 @@ def NChooseKGroup_with_quantity( and group restrictions. """ if quantity_if_picked is not None: - if type(quantity_if_picked) is list and len(keys) != len(quantity_if_picked): + if isinstance(quantity_if_picked, list) and len(keys) != len( + quantity_if_picked + ): raise ValueError( f"number of keys must be the same as corresponding quantities. Received {len(keys)} keys " f"and {len(quantity_if_picked)} quantities" ) - if type(quantity_if_picked) is list and True in [ + if isinstance(quantity_if_picked, list) and True in [ 0 in q for q in quantity_if_picked ]: raise ValueError( @@ -194,7 +196,7 @@ def NChooseKGroup_with_quantity( if True in ["_" in k for k in keys]: raise ValueError('"_" is not allowed as an character in the keys') - if quantity_if_picked is not None and type(quantity_if_picked) != list: + if quantity_if_picked is not None and not isinstance(quantity_if_picked, list): quantity_if_picked = [quantity_if_picked for k in keys] # type: ignore quantity_var, all_new_constraints = [], [] diff --git a/bofire/strategies/predictives/predictive.py b/bofire/strategies/predictives/predictive.py index a54de4e85..9e35391e7 100644 --- a/bofire/strategies/predictives/predictive.py +++ b/bofire/strategies/predictives/predictive.py @@ -1,3 +1,4 @@ +import itertools from abc import abstractmethod from typing import List, Optional, Tuple @@ -102,20 +103,48 @@ def predict(self, experiments: pd.DataFrame) -> pd.DataFrame: experiments=experiments, specs=self.input_preprocessing_specs ) preds, stds = self._predict(transformed) + column_names = list( + itertools.chain( + *[ + [f"{feat.key}_pred"] + if not isinstance(feat, CategoricalOutput) + else [f"{feat.key}_pred_{cat}" for cat in feat.categories] + for feat in self.domain.outputs.get() + ] + ) + ) if stds is not None: predictions = pd.DataFrame( data=np.hstack((preds, stds)), - columns=["%s_pred" % feat.key for feat in self.domain.outputs.get()] - + ["%s_sd" % feat.key for feat in self.domain.outputs.get()], + columns=column_names + + list( + itertools.chain( + *[ + [f"{feat.key}_sd"] + if not isinstance(feat, CategoricalOutput) + else [f"{feat.key}_sd_{cat}" for cat in feat.categories] + for feat in self.domain.outputs.get() + ] + ) + ), ) else: predictions = pd.DataFrame( data=preds, - columns=["%s_pred" % feat.key for feat in self.domain.outputs.get()], + columns=column_names, + ) + categorical_preds = { + f"{feat.key}_pred": ( + ind, + feat.map_to_categories(predictions.filter(regex=f"{feat.key}_pred_")), + ) + for ind, feat in enumerate(self.domain.outputs.get()) + if isinstance(feat, CategoricalOutput) + } + for key in categorical_preds.keys(): + predictions.insert( + categorical_preds[key][0], key, categorical_preds[key][1] ) - categorical_df = pd.DataFrame.from_dict({f"{feat.key}_pred": feat.map_to_categories(predictions[f"{feat.key}_pred"]) for feat in self.domain.outputs.get() if isinstance(feat, CategoricalOutput)}) - if not categorical_df.empty: - predictions.update(categorical_df) desis = self.domain.outputs(predictions, predictions=True) predictions = pd.concat((predictions, desis), axis=1) predictions.index = experiments.index diff --git a/bofire/surrogates/api.py b/bofire/surrogates/api.py index b64b3f5aa..aa317da08 100644 --- a/bofire/surrogates/api.py +++ b/bofire/surrogates/api.py @@ -2,8 +2,8 @@ from bofire.surrogates.empirical import EmpiricalSurrogate from bofire.surrogates.mapper import map from bofire.surrogates.mixed_single_task_gp import MixedSingleTaskGPSurrogate -from bofire.surrogates.mlp_classifier import MLPClassifierEnsemble from bofire.surrogates.mlp import MLPEnsemble +from bofire.surrogates.mlp_classifier import MLPClassifierEnsemble from bofire.surrogates.random_forest import RandomForestSurrogate from bofire.surrogates.single_task_gp import SingleTaskGPSurrogate from bofire.surrogates.surrogate import Surrogate diff --git a/bofire/surrogates/mlp_classifier.py b/bofire/surrogates/mlp_classifier.py index 3c44a3050..99d1b9391 100644 --- a/bofire/surrogates/mlp_classifier.py +++ b/bofire/surrogates/mlp_classifier.py @@ -95,10 +95,10 @@ def forward(self, X: Tensor): X: A `batch_shape x n x d`-dim input tensor `X`. Returns: - A `batch_shape x s x n x m`-dimensional output tensor where - `s` is the size of the ensemble. + A `batch_shape x s x n x m x C`-dimensional output tensor where + `s` is the size of the ensemble and `C` is the number of classes. """ - return torch.stack([torch.argmax(mlp(X), dim=-1).unsqueeze(-1).float() for mlp in self.mlps], dim=-3) + return torch.stack([torch.softmax(mlp(X), dim=-1) for mlp in self.mlps], dim=-3) @property def num_outputs(self) -> int: @@ -128,7 +128,7 @@ def fit_mlp( """ mlp.train() train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle) - loss_function = nn.NLLLoss(reduction='mean') + loss_function = nn.NLLLoss(reduction="mean") optimizer = torch.optim.Adam(mlp.parameters(), lr=lr, weight_decay=weight_decay) for _ in range(n_epoches): current_loss = 0.0 @@ -180,7 +180,9 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): label_mapping = self.outputs[0].to_dict_label() # Convert Y to classification tensor - Y = pd.DataFrame.from_dict({col: Y[col].map(label_mapping) for col in Y.columns}) + Y = pd.DataFrame.from_dict( + {col: Y[col].map(label_mapping) for col in Y.columns} + ) mlps = [] subsample_size = round(self.subsample_fraction * X.shape[0]) @@ -196,7 +198,9 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): ) mlp = MLPClassifier( input_size=transformed_X.shape[1], - output_size=len(label_mapping), # Set outputs based on number of categories + output_size=len( + label_mapping + ), # Set outputs based on number of categories hidden_layer_sizes=self.hidden_layer_sizes, activation=self.activation, # type: ignore dropout=self.dropout, diff --git a/bofire/utils/torch_tools.py b/bofire/utils/torch_tools.py index 6518e5197..e0c75f9ca 100644 --- a/bofire/utils/torch_tools.py +++ b/bofire/utils/torch_tools.py @@ -12,6 +12,7 @@ ) from bofire.data_models.features.api import ContinuousInput, Input from bofire.data_models.objectives.api import ( + CategoricalObjective, CloseToTargetObjective, ConstrainedObjective, MaximizeObjective, @@ -138,7 +139,7 @@ def min_constraint(indices: Tensor, num_features: int, min_count: int): def constrained_objective2botorch( idx: int, objective: ConstrainedObjective, -) -> Tuple[List[Callable[[Tensor], Tensor]], List[float]]: +) -> Tuple[List[Callable[[Tensor], Tensor]], List[float], int]: """Create a callable that can be used by `botorch.utils.objective.apply_constraints` to setup ouput constrained optimizations. @@ -147,24 +148,50 @@ def constrained_objective2botorch( objective (BotorchConstrainedObjective): The objective that should be transformed. Returns: - Tuple[List[Callable[[Tensor], Tensor]], List[float]]: List of callables that can be used by botorch for setting up the constrained objective, and - list of the corresponding botorch eta values. + Tuple[List[Callable[[Tensor], Tensor]], List[float], int]: List of callables that can be used by botorch for setting up the constrained objective, + list of the corresponding botorch eta values, final index used by the method (to track for categorical variables) """ assert isinstance( objective, ConstrainedObjective ), "Objective is not a `ConstrainedObjective`." if isinstance(objective, MaximizeSigmoidObjective): - return [lambda Z: (Z[..., idx] - objective.tp) * -1.0], [ - 1.0 / objective.steepness - ] + return ( + [lambda Z: (Z[..., idx] - objective.tp) * -1.0], + [1.0 / objective.steepness], + idx + 1, + ) elif isinstance(objective, MinimizeSigmoidObjective): - return [lambda Z: (Z[..., idx] - objective.tp)], [1.0 / objective.steepness] + return ( + [lambda Z: (Z[..., idx] - objective.tp)], + [1.0 / objective.steepness], + idx + 1, + ) elif isinstance(objective, TargetObjective): - return [ - lambda Z: (Z[..., idx] - (objective.target_value - objective.tolerance)) - * -1.0, - lambda Z: (Z[..., idx] - (objective.target_value + objective.tolerance)), - ], [1.0 / objective.steepness, 1.0 / objective.steepness] + return ( + [ + lambda Z: (Z[..., idx] - (objective.target_value - objective.tolerance)) + * -1.0, + lambda Z: ( + Z[..., idx] - (objective.target_value + objective.tolerance) + ), + ], + [1.0 / objective.steepness, 1.0 / objective.steepness], + idx + 1, + ) + elif isinstance(objective, CategoricalObjective): + # The output of a categorical objective has final dim `c` where `c` is number of classes + return ( + [ + lambda Z: -1.0 + * objective.w + * ( + Z[..., idx : idx + len(objective.weights)] + * torch.tensor(objective.weights).to(**tkwargs) + ).sum(-1) + ], + [objective.eta], + idx + len(objective.weights), + ) else: raise ValueError(f"Objective {objective.__class__.__name__} not known.") @@ -185,13 +212,16 @@ def get_output_constraints( """ constraints = [] etas = [] - for idx, feat in enumerate(outputs.get()): + idx = 0 + for feat in outputs.get(): if isinstance(feat.objective, ConstrainedObjective): # type: ignore - iconstraints, ietas = constrained_objective2botorch( + iconstraints, ietas, idx = constrained_objective2botorch( idx, objective=feat.objective # type: ignore ) constraints += iconstraints etas += ietas + else: + idx += 1 return constraints, etas diff --git a/tests/bofire/data_models/specs/features.py b/tests/bofire/data_models/specs/features.py index 9fdbd1457..264a29d19 100644 --- a/tests/bofire/data_models/specs/features.py +++ b/tests/bofire/data_models/specs/features.py @@ -3,7 +3,7 @@ import uuid import bofire.data_models.features.api as features -from bofire.data_models.objectives.api import MaximizeObjective +from bofire.data_models.objectives.api import CategoricalObjective, MaximizeObjective from tests.bofire.data_models.specs.objectives import specs as objectives from tests.bofire.data_models.specs.specs import Specs @@ -85,7 +85,7 @@ lambda: { "key": str(uuid.uuid4()), "categories": ("a", "b", "c"), - "objective": [0.0, 1.0, 0.0], + "objective": CategoricalObjective(weights=[0.0, 1.0, 0.0]), }, ) specs.add_valid( diff --git a/tests/bofire/data_models/test_features.py b/tests/bofire/data_models/test_features.py index c41227d6a..87f6935db 100644 --- a/tests/bofire/data_models/test_features.py +++ b/tests/bofire/data_models/test_features.py @@ -29,7 +29,11 @@ Fragments, MordredDescriptors, ) -from bofire.data_models.objectives.api import MinimizeObjective, Objective +from bofire.data_models.objectives.api import ( + CategoricalObjective, + MinimizeObjective, + Objective, +) from bofire.data_models.surrogates.api import ScalerEnum objective = MinimizeObjective(w=1) @@ -2215,7 +2219,9 @@ def test_inputs_get_bounds_fit(): of2, of3, CategoricalOutput( - key="of4", categories=["a", "b"], objective=[1.0, 0.0] + key="of4", + categories=["a", "b"], + objective=CategoricalObjective(weights=[1.0, 0.0]), ), ] ), @@ -2225,21 +2231,17 @@ def test_inputs_get_bounds_fit(): ) def test_outputs_call(features, samples): o = features(samples) - assert o.shape == ( - len(samples), - len(features.get_keys_by_objective(Objective)) - + len(features.get_keys(CategoricalOutput)), - ) + assert o.shape == (len(samples), len(features.get_keys_by_objective(Objective))) assert list(o.columns) == [ - f"{key}_des" - for key in features.get_keys_by_objective(Objective) - + features.get_keys(CategoricalOutput) + f"{key}_des" for key in features.get_keys_by_objective(Objective) ] def test_categorical_output(): feature = CategoricalOutput( - key="a", categories=["alpha", "beta", "gamma"], objective=[1.0, 0.0, 0.1] + key="a", + categories=["alpha", "beta", "gamma"], + objective=CategoricalObjective(weights=[1.0, 0.0, 0.1]), ) assert feature.to_dict() == {"alpha": 1.0, "beta": 0.0, "gamma": 0.1} diff --git a/tests/bofire/surrogates/test_diagnostics.py b/tests/bofire/surrogates/test_diagnostics.py index fc6cb28f5..3d8d6db91 100644 --- a/tests/bofire/surrogates/test_diagnostics.py +++ b/tests/bofire/surrogates/test_diagnostics.py @@ -212,7 +212,7 @@ def test_cvresult_get_UQ_metric_valid(): assert cv.n_samples == 10 for metric in UQ_metrics.keys(): m = cv.get_metric(metric=metric) - assert type(m) == float + assert isinstance(m, float) def test_cvresult_get_UQ_metric_invalid(): diff --git a/tests/bofire/utils/test_torch_tools.py b/tests/bofire/utils/test_torch_tools.py index b700014bd..4b434e146 100644 --- a/tests/bofire/utils/test_torch_tools.py +++ b/tests/bofire/utils/test_torch_tools.py @@ -678,7 +678,7 @@ def test_get_initial_conditions_generator(sequential: bool): ], ) def test_constrained_objective2botorch(objective): - cs, etas = constrained_objective2botorch(idx=0, objective=objective) + cs, etas, _ = constrained_objective2botorch(idx=0, objective=objective) x = torch.from_numpy(np.linspace(0, 30, 500)).unsqueeze(-1) y = torch.ones([500]) diff --git a/tutorials/basic_examples/Unknown_Binary_Constraint_Classification.ipynb b/tutorials/basic_examples/Unknown_Binary_Constraint_Classification.ipynb deleted file mode 100644 index 13518e12c..000000000 --- a/tutorials/basic_examples/Unknown_Binary_Constraint_Classification.ipynb +++ /dev/null @@ -1,379 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Classification Surrogate Tests\n", - "\n", - "We are interested in testing whether or not a surrogate model can correctly identify unknown constraints based on binary feasibility/infeasibility. This involves new models which produce `CategoricalOutput`s rather than continuous outputs. Mathematically, instead of multiplying the objective by $\\sigma(x)\\in(0,1)$, we multiply by $I(x)$ which is 1 if $x\\in X$ otherwise it is 0. Since currently BoTorch does not offer support for discrete feasibility constraints (see: [here](https://github.com/pytorch/botorch/blob/main/botorch/utils/objective.py#L122)), we will instead always multiply our objective directly by the feasibility value\n", - "\n", - "In our toy example, the feasible points satisfy $x_1+x_2<= 1.0$." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "# Import packages\n", - "import bofire.strategies.api as strategies\n", - "from bofire.data_models.api import Domain, Outputs, Inputs\n", - "from bofire.data_models.features.api import ContinuousInput, ContinuousOutput, CategoricalOutput, CategoricalInput\n", - "from bofire.data_models.objectives.api import MinimizeObjective, MinimizeSigmoidObjective\n", - "import numpy as np" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Manual setup of the optimization domain\n", - "\n", - "The following cell shows how to manually setup the optimization problem in BoFire for didactic purposes. We design a feasible set and output constraints for example." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
x_0x_1x_2x_3x_4x_5f_0f_1
00.5020150.5563550.1844070.1359660.1912900.00.000763infeasible
10.4000850.3518170.8489170.9245940.6579470.5-0.856798feasible
20.6945590.4818010.2399830.5284140.6426160.0-0.850312infeasible
30.6782070.8400330.2989880.9258510.7408470.5-0.665724infeasible
40.0067850.7879130.7788130.8616380.2902250.5-0.996492feasible
\n", - "
" - ], - "text/plain": [ - " x_0 x_1 x_2 x_3 x_4 x_5 f_0 f_1\n", - "0 0.502015 0.556355 0.184407 0.135966 0.191290 0.0 0.000763 infeasible\n", - "1 0.400085 0.351817 0.848917 0.924594 0.657947 0.5 -0.856798 feasible\n", - "2 0.694559 0.481801 0.239983 0.528414 0.642616 0.0 -0.850312 infeasible\n", - "3 0.678207 0.840033 0.298988 0.925851 0.740847 0.5 -0.665724 infeasible\n", - "4 0.006785 0.787913 0.778813 0.861638 0.290225 0.5 -0.996492 feasible" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Set-up the inputs and outputs, use categorical domain just as an example\n", - "input_features = Inputs(features=[ContinuousInput(key=f\"x_{i}\", bounds=(0, 1)) for i in range(5)] + [CategoricalInput(key=f\"x_5\", categories=(0.5, 0.0))])\n", - "\n", - "# here the minimize objective is used, if you want to maximize you have to use the maximize objective.\n", - "output_features = Outputs(features=[\n", - " ContinuousOutput(key=f\"f_{0}\", objective=MinimizeObjective(w=1.)),\n", - " # ContinuousOutput(key=f\"f_{2}\", objective=MinimizeSigmoidObjective(w=1., tp=0.0, steepness=0.5)),\n", - " CategoricalOutput(key=f\"f_{1}\", categories=[\"infeasible\", \"feasible\"], objective=[0, 1]) # This function will be associated with learning the feasibility/infeasibility\n", - " ]\n", - ")\n", - "\n", - "# Create domain\n", - "domain1 = Domain(inputs=input_features, outputs=output_features)\n", - "\n", - "# Sample random points\n", - "sample_df = domain1.inputs.sample(20).astype(float) # Sample x's\n", - "\n", - "# Write a function which outputs one continuous variable and another discrete based on some logic\n", - "# Here, feasible points are points whose first two components sum to less then 1.0 - in real experiments, these would not be known\n", - "sample_df[\"f_0\"] = np.cos(sample_df.values.sum(1))\n", - "sample_df[\"f_1\"] = \"infeasible\"\n", - "sample_df.loc[sample_df[\"x_0\"]+sample_df[\"x_1\"] <= 1.0, \"f_1\"] = \"feasible\"\n", - "# sample_df[\"f_2\"] = np.random.uniform(size=(len(sample_df),))\n", - "\n", - "sample_df.head(5)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup of the Strategy and ask for Candidates\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "from bofire.data_models.acquisition_functions.api import qNEI, qUCB, qSR, qEI\n", - "from bofire.data_models.strategies.api import QparegoStrategy, MultiplicativeSoboStrategy, SoboStrategy\n", - "from bofire.data_models.surrogates.api import BotorchSurrogates, MLPClassifierEnsemble, MixedSingleTaskGPSurrogate\n", - "from bofire.data_models.domain.api import Outputs\n", - "\n", - "strategy_data = SoboStrategy(domain=domain1, \n", - " acquisition_function=qEI(), \n", - " surrogate_specs=BotorchSurrogates(surrogates=\n", - " [\n", - " MLPClassifierEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs[-1]]), lr=1.0, n_epochs=100, hidden_layer_sizes=(20,)),\n", - " # MixedSingleTaskGPSurrogate(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs[1]]))\n", - " ]\n", - " )\n", - " )\n", - "\n", - "strategy = strategies.map(strategy_data)\n", - "\n", - "strategy.tell(sample_df)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
x_0x_1x_2x_3x_4x_5f_0_predf_1_predf_0_sdf_1_sdf_0_desf_1_des
00.7233411.00.7036350.4261670.3280140.5-1.267539feasible0.0421170.5477231.2675391.0
10.0000001.00.3122030.4500190.3889260.5-1.250402feasible0.0429630.5477231.2504021.0
\n", - "
" - ], - "text/plain": [ - " x_0 x_1 x_2 x_3 x_4 x_5 f_0_pred f_1_pred \\\n", - "0 0.723341 1.0 0.703635 0.426167 0.328014 0.5 -1.267539 feasible \n", - "1 0.000000 1.0 0.312203 0.450019 0.388926 0.5 -1.250402 feasible \n", - "\n", - " f_0_sd f_1_sd f_0_des f_1_des \n", - "0 0.042117 0.547723 1.267539 1.0 \n", - "1 0.042963 0.547723 1.250402 1.0 " - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "candidates = strategy.ask(2)\n", - "\n", - "candidates" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Add Classification Models for Surrogates\n", - "\n", - "Updating the surrogates to allow for classification of output values (i.e. 'feasible' or 'infeasible').\n", - "\n", - "### Housekeeping changes\n", - "\n", - "1. Update the categorical input/outputs ('bofire/data_models/features/categorical.py') to always return a tuple instead of a list for `categories` and attribute (to prevent mutation)\n", - " - Associated test are changed in 'tests/bofire/data_models/specs/features.py'\n", - "2. \n", - "\n", - "### Classification Models\n", - "\n", - "Initially, we are only interested in checking whether or not certain points are feasible or infeasible, hence this is a binary classification problem. \n", - "\n", - "\n", - "### Questions\n", - "\n", - "1. Should we force `allowed` to be a tuple for the categorical input/outputs? If so, we need to refactor indexing for Pandas DFs..." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "bofire", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.0" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "6f21737eef49beedf03d74399b47fe38d73eff760737ca33d38b9fe616638e91" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb new file mode 100644 index 000000000..fb78e2ef2 --- /dev/null +++ b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb @@ -0,0 +1,632 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Classification Surrogate Tests\n", + "\n", + "We are interested in testing whether or not a surrogate model can correctly identify unknown constraints based on categorical criteria with classification surrogates. Essentially, we want to account for scenarios where specialists can look at a set of experiments and label outcomes as 'acceptable', 'unacceptable', 'ideal', etc. \n", + "\n", + "This involves new models that produce `CategoricalOutput`'s rather than continuous outputs. Mathematically, if $g_{\\theta}:\\mathbb{R}^d\\to[0,1]^c$ represents the function governed by learnable parameters $\\theta$ which outputs a probability vector over $c$ potential classes (i.e. for input $x\\in\\mathbb{R}^d$, $g_{\\theta}(x)^\\top\\mathbf{1}=1$ where $\\mathbf{1}$ is the vector of all 1's) and we have acceptibility criteria for the corresponding classes given by $a\\in[0,1]^c$, we can compute a scalar output as $g_{\\theta}(x)^\\top a\\in[0,1]$ as an objective value to be passed in as a constrained function.\n", + "\n", + "In this script, we look at a modified and constrained version of the optimization problem associated with the [Levy function](https://www.sfu.ca/~ssurjano/levy.html), which has a global minima at $x^*=\\mathbf{1}$. We classify constraints for three classes: 'acceptable', 'unacceptable', and 'ideal' based on how close we are to the optimal decision variable; obviously, this value is unknown in a real-world setting, but this serves as a reasonable example." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "# Import packages\n", + "import bofire.strategies.api as strategies\n", + "from bofire.data_models.api import Domain, Outputs, Inputs\n", + "from bofire.data_models.features.api import ContinuousInput, ContinuousOutput, CategoricalOutput, CategoricalInput\n", + "from bofire.data_models.objectives.api import MinimizeObjective, MinimizeSigmoidObjective, CategoricalObjective\n", + "import numpy as np\n", + "import pandas as pd" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Manual setup of the optimization domain\n", + "\n", + "The following cells show how to manually setup the optimization problem in BoFire for didactic purposes." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Write a function which scales the inputs according to the Levy function - i.e. computes $w_i$\n", + "def scale_inputs(x: pd.Series) -> pd.Series:\n", + " return 1 + (x - 1) / 4" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5f_0f_1f_2
0-0.427145-1.239930-1.2896070.8360750.4597560.04.938796unacceptable-0.417839
1-1.439973-1.083066-0.7676011.8650771.2093881.05.837671unacceptable-1.434549
20.643776-0.0106310.6904850.5759220.3992340.00.468453ideal0.648536
30.955959-1.0966470.8496081.4617310.0093880.01.607040ideal0.960725
4-1.2697511.7248731.1083341.6412211.6773331.03.843082ideal-1.263064
5-1.347163-1.2631421.813048-1.7602041.5563391.010.255809unacceptable-1.344647
6-0.739443-0.5867441.6140160.0012261.9816230.02.718785ideal-0.734792
7-1.687649-1.804157-0.383259-1.128108-0.7169711.011.486883unacceptable-1.685165
8-1.3489171.6684771.684095-1.183607-1.4525101.08.185157unacceptable-1.347965
9-1.2980110.203581-1.267659-1.742074-0.5309090.09.691982unacceptable-1.290670
101.6302070.5654040.900429-0.752504-1.2485701.02.843968ideal1.636059
110.150783-1.017744-0.241076-0.747660-0.5342851.02.254525unacceptable0.156698
12-0.757027-0.7574520.2766290.054870-1.7877211.06.733639unacceptable-0.753403
131.909901-1.215048-1.3574051.2343051.8223690.05.587213ideal1.914966
14-0.712812-1.343809-0.549294-1.1133020.2704540.05.321709unacceptable-0.710558
150.414610-0.393348-0.287759-1.958587-0.2798580.06.394600unacceptable0.416899
16-1.627874-0.5609770.0629281.7141030.3948540.05.349225unacceptable-1.621146
17-0.9246770.326249-0.642047-0.8021131.5029431.02.878352unacceptable-0.916453
18-1.0068620.9767320.4192610.285177-1.6985250.06.397444unacceptable-1.000865
19-1.3817260.2564161.467449-0.750482-0.5580651.04.250857unacceptable-1.376709
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 f_0 \\\n", + "0 -0.427145 -1.239930 -1.289607 0.836075 0.459756 0.0 4.938796 \n", + "1 -1.439973 -1.083066 -0.767601 1.865077 1.209388 1.0 5.837671 \n", + "2 0.643776 -0.010631 0.690485 0.575922 0.399234 0.0 0.468453 \n", + "3 0.955959 -1.096647 0.849608 1.461731 0.009388 0.0 1.607040 \n", + "4 -1.269751 1.724873 1.108334 1.641221 1.677333 1.0 3.843082 \n", + "5 -1.347163 -1.263142 1.813048 -1.760204 1.556339 1.0 10.255809 \n", + "6 -0.739443 -0.586744 1.614016 0.001226 1.981623 0.0 2.718785 \n", + "7 -1.687649 -1.804157 -0.383259 -1.128108 -0.716971 1.0 11.486883 \n", + "8 -1.348917 1.668477 1.684095 -1.183607 -1.452510 1.0 8.185157 \n", + "9 -1.298011 0.203581 -1.267659 -1.742074 -0.530909 0.0 9.691982 \n", + "10 1.630207 0.565404 0.900429 -0.752504 -1.248570 1.0 2.843968 \n", + "11 0.150783 -1.017744 -0.241076 -0.747660 -0.534285 1.0 2.254525 \n", + "12 -0.757027 -0.757452 0.276629 0.054870 -1.787721 1.0 6.733639 \n", + "13 1.909901 -1.215048 -1.357405 1.234305 1.822369 0.0 5.587213 \n", + "14 -0.712812 -1.343809 -0.549294 -1.113302 0.270454 0.0 5.321709 \n", + "15 0.414610 -0.393348 -0.287759 -1.958587 -0.279858 0.0 6.394600 \n", + "16 -1.627874 -0.560977 0.062928 1.714103 0.394854 0.0 5.349225 \n", + "17 -0.924677 0.326249 -0.642047 -0.802113 1.502943 1.0 2.878352 \n", + "18 -1.006862 0.976732 0.419261 0.285177 -1.698525 0.0 6.397444 \n", + "19 -1.381726 0.256416 1.467449 -0.750482 -0.558065 1.0 4.250857 \n", + "\n", + " f_1 f_2 \n", + "0 unacceptable -0.417839 \n", + "1 unacceptable -1.434549 \n", + "2 ideal 0.648536 \n", + "3 ideal 0.960725 \n", + "4 ideal -1.263064 \n", + "5 unacceptable -1.344647 \n", + "6 ideal -0.734792 \n", + "7 unacceptable -1.685165 \n", + "8 unacceptable -1.347965 \n", + "9 unacceptable -1.290670 \n", + "10 ideal 1.636059 \n", + "11 unacceptable 0.156698 \n", + "12 unacceptable -0.753403 \n", + "13 ideal 1.914966 \n", + "14 unacceptable -0.710558 \n", + "15 unacceptable 0.416899 \n", + "16 unacceptable -1.621146 \n", + "17 unacceptable -0.916453 \n", + "18 unacceptable -1.000865 \n", + "19 unacceptable -1.376709 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Set-up the inputs and outputs, use categorical domain just as an example\n", + "input_features = Inputs(features=[ContinuousInput(key=f\"x_{i}\", bounds=(-2, 2)) for i in range(5)] + [CategoricalInput(key=f\"x_5\", categories=(0.0, 1.0))])\n", + "\n", + "# here the minimize objective is used, if you want to maximize you have to use the maximize objective.\n", + "output_features = Outputs(features=[\n", + " ContinuousOutput(key=f\"f_{0}\", objective=MinimizeObjective(w=1.)),\n", + " CategoricalOutput(key=f\"f_{1}\", categories=[\"unacceptable\", \"acceptable\", \"ideal\"], objective=CategoricalObjective(weights=(0, 0.5, 1))), # This function will be associated with learning the categories\n", + " ContinuousOutput(key=f\"f_{2}\", objective=MinimizeSigmoidObjective(w=1., tp=0.0, steepness=0.5)),\n", + " ]\n", + ")\n", + "\n", + "# Create domain\n", + "domain1 = Domain(inputs=input_features, outputs=output_features)\n", + "\n", + "# Sample random points\n", + "sample_df = domain1.inputs.sample(50).astype(float) # Sample x's\n", + "\n", + "# Write a function which outputs one continuous variable and another discrete based on some logic\n", + "sample_df[\"f_0\"] = np.sin(np.pi * scale_inputs(sample_df[\"x_0\"])) ** 2 + sum([(scale_inputs(sample_df[col]) - 1) ** 2 * (1 + 10 * np.sin(np.pi * scale_inputs(sample_df[col]) + 1) ** 2 if ind < len(sample_df.columns) else 1 + np.sin(2 * np.pi * scale_inputs(sample_df[col])) ** 2) for ind, col in enumerate(sample_df.columns)])\n", + "sample_df[\"f_1\"] = \"unacceptable\"\n", + "sample_df.loc[sample_df[input_features.get_keys()].sum(1) >= 1.0, \"f_1\"] = \"acceptable\"\n", + "sample_df.loc[sample_df[input_features.get_keys()].sum(1) >= 2.0, \"f_1\"] = \"ideal\"\n", + "sample_df[\"f_2\"] = sample_df[\"x_0\"] + 1e-2 * np.random.uniform(size=(len(sample_df),))\n", + "\n", + "sample_df.head(20)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup of the Strategy and ask for Candidates\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\optim\\fit.py:130: OptimizationWarning: `scipy_minimize` terminated with status 3, displaying original message from `scipy.optimize.minimize`: ABNORMAL_TERMINATION_IN_LNSRCH\n", + " warn(\n" + ] + } + ], + "source": [ + "from bofire.data_models.acquisition_functions.api import qEI\n", + "from bofire.data_models.strategies.api import SoboStrategy\n", + "from bofire.data_models.surrogates.api import BotorchSurrogates, MLPClassifierEnsemble, MixedSingleTaskGPSurrogate\n", + "from bofire.data_models.domain.api import Outputs\n", + "\n", + "strategy_data = SoboStrategy(domain=domain1, \n", + " acquisition_function=qEI(), \n", + " surrogate_specs=BotorchSurrogates(surrogates=\n", + " [\n", + " MLPClassifierEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_1\")]), lr=1.0, n_epochs=50, hidden_layer_sizes=(20,)),\n", + " MixedSingleTaskGPSurrogate(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_2\")]))\n", + " ]\n", + " )\n", + " )\n", + "\n", + "strategy = strategies.map(strategy_data)\n", + "\n", + "strategy.tell(sample_df)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5f_0_predf_2_predf_1_predf_1_pred_unacceptablef_1_pred_acceptablef_1_pred_idealf_0_sdf_2_sdf_1_sd_unacceptablef_1_sd_acceptablef_1_sd_idealf_0_desf_2_desf_1_des
00.2985420.6113790.5633672.0000000.3578451.0-1.1847540.303593unacceptable0.5125050.0644480.4230480.6328540.0030970.2947120.0587370.3249071.1847540.4621240.455272
10.1801370.5650720.6000421.0936290.3614720.0-1.0919470.185482unacceptable0.5125050.0644500.4230450.5813220.0029970.2947120.0587340.3249011.0919470.4768310.455270
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 f_0_pred f_2_pred \\\n", + "0 0.298542 0.611379 0.563367 2.000000 0.357845 1.0 -1.184754 0.303593 \n", + "1 0.180137 0.565072 0.600042 1.093629 0.361472 0.0 -1.091947 0.185482 \n", + "\n", + " f_1_pred f_1_pred_unacceptable f_1_pred_acceptable f_1_pred_ideal \\\n", + "0 unacceptable 0.512505 0.064448 0.423048 \n", + "1 unacceptable 0.512505 0.064450 0.423045 \n", + "\n", + " f_0_sd f_2_sd f_1_sd_unacceptable f_1_sd_acceptable f_1_sd_ideal \\\n", + "0 0.632854 0.003097 0.294712 0.058737 0.324907 \n", + "1 0.581322 0.002997 0.294712 0.058734 0.324901 \n", + "\n", + " f_0_des f_2_des f_1_des \n", + "0 1.184754 0.462124 0.455272 \n", + "1 1.091947 0.476831 0.455270 " + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "candidates = strategy.ask(2)\n", + "candidates" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "bofire", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "6f21737eef49beedf03d74399b47fe38d73eff760737ca33d38b9fe616638e91" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 53e61dac6b4ec9e071649a367a0b614c87d83d7e Mon Sep 17 00:00:00 2001 From: gmancino Date: Thu, 5 Oct 2023 10:32:32 -0400 Subject: [PATCH 07/31] Update validators, fix bugs, link categorical objectives and weights --- bofire/data_models/features/categorical.py | 30 +- bofire/data_models/objectives/categorical.py | 25 +- .../data_models/surrogates/fully_bayesian.py | 16 + bofire/data_models/surrogates/linear.py | 21 +- .../surrogates/mixed_single_task_gp.py | 16 + bofire/data_models/surrogates/mlp.py | 18 +- .../data_models/surrogates/mlp_classifier.py | 18 +- .../data_models/surrogates/random_forest.py | 18 +- .../data_models/surrogates/single_task_gp.py | 23 +- bofire/data_models/surrogates/tanimoto_gp.py | 18 +- bofire/data_models/surrogates/xgb.py | 16 + bofire/surrogates/mlp_classifier.py | 4 +- bofire/utils/torch_tools.py | 7 +- tests/bofire/data_models/specs/features.py | 2 +- tests/bofire/data_models/test_features.py | 4 +- .../Unknown_Constraint_Classification.ipynb | 491 +++++++++--------- 16 files changed, 434 insertions(+), 293 deletions(-) diff --git a/bofire/data_models/features/categorical.py b/bofire/data_models/features/categorical.py index 4f24d7e9b..3805ea5e9 100644 --- a/bofire/data_models/features/categorical.py +++ b/bofire/data_models/features/categorical.py @@ -382,20 +382,18 @@ def validate_categories_unique(cls, categories): @validator("objective") def validate_objective(cls, objective, values): - weights = objective.weights - if len(weights) != len(values["categories"]): - raise ValueError("Length of objectives and categories do not match.") - for w in weights: - if w > 1: - raise ValueError("Objective weight has to be smaller equal than 1.") - if w < 0: - raise ValueError("Objective weight has to be larger equal than zero") - # Save the categories to the objective if they do not exist - objective.categories = ( - list(values["categories"]) - if objective.categories is None - else objective.categories - ) + """validates that objective desirabilities are the same length as categories + + Raises: + ValueError: when len(objective.desirability) != len(categories) + + Returns: + CategoricalObjective + """ + if len(objective.desirability) != len(values["categories"]): + raise ValueError( + f"{len(objective.desirability)} desirabilities and {len(values['categories'])} categories" + ) return objective def validate_experimental(self, values: pd.Series) -> pd.Series: @@ -411,7 +409,7 @@ def __str__(self) -> str: def to_dict(self) -> Dict: """Returns the catergories and corresponding objective values as dictionary""" - return dict(zip(self.categories, self.objective.weights)) + return dict(zip(self.categories, self.objective.desirability)) def to_dict_label(self) -> Dict: """Returns the catergories and label location of categories""" @@ -435,4 +433,4 @@ def compute_objective(self, values: pd.DataFrame) -> pd.Series: ) def __call__(self, values: pd.Series) -> pd.Series: - return self.objective(values) + return values.map(self.to_dict()) diff --git a/bofire/data_models/objectives/categorical.py b/bofire/data_models/objectives/categorical.py index c2c5a577d..4e3d6ce6a 100644 --- a/bofire/data_models/objectives/categorical.py +++ b/bofire/data_models/objectives/categorical.py @@ -1,7 +1,8 @@ -from typing import List, Literal, Union +from typing import Literal, Tuple, Union import numpy as np import pandas as pd +from pydantic import validator from bofire.data_models.objectives.objective import ( ConstrainedObjective, @@ -14,20 +15,27 @@ class CategoricalObjective(Objective, ConstrainedObjective): """Compute the categorical objective value as: Po where P is an [n, c] matrix where each row is a probability vector - (P[i, :].sum()=1 for all i) and o is a column vector of objective values + (P[i, :].sum()=1 for all i) and o is a vector of size [c] of objective values Attributes: w (float): float between zero and one for weighting the objective. - weights (list): list of values of size c (c is number of categories) such that the i-th entry is in (0, 1) + desirability (tuple): tuple of values of size c (c is number of categories) such that the i-th entry is in (0, 1) """ w: TWeight = 1.0 - weights: List[float] + desirability: Tuple[float, ...] eta: float = 1.0 - categories: Union[List[str], None] = None - type: Literal["CategoricalObjective"] = "CategoricalObjective" + @validator("desirability") + def validate_desirability(cls, desirability): + for w in desirability: + if w > 1: + raise ValueError("Objective weight has to be smaller equal than 1.") + if w < 0: + raise ValueError("Objective weight has to be larger equal than zero") + return desirability + def __call__(self, x: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]: """The call function returning a probabilistic reward for x. @@ -37,4 +45,7 @@ def __call__(self, x: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarr Returns: np.ndarray: A reward calculated as inner product of probabilities and feasible objectives. """ - return x.map(dict(zip(self.categories, self.weights))) + print( + "Categorical objective currently does not have a function. Returning the original input." + ) + return x diff --git a/bofire/data_models/surrogates/fully_bayesian.py b/bofire/data_models/surrogates/fully_bayesian.py index e25366287..2add26b17 100644 --- a/bofire/data_models/surrogates/fully_bayesian.py +++ b/bofire/data_models/surrogates/fully_bayesian.py @@ -2,6 +2,7 @@ from pydantic import conint, validator +from bofire.data_models.features.api import ContinuousOutput from bofire.data_models.surrogates.botorch import BotorchSurrogate from bofire.data_models.surrogates.scaler import ScalerEnum from bofire.data_models.surrogates.trainable import TrainableSurrogate @@ -19,3 +20,18 @@ def validate_thinning(cls, value, values): if values["num_samples"] / value < 1: raise ValueError("`num_samples` has to be larger than `thinning`.") return value + + @validator("outputs") + def validate_outputs(cls, outputs): + """validates outputs + + Raises: + ValueError: if output type is not ContinuousOutput + + Returns: + List[ContinuousOutput] + """ + for o in outputs: + if not isinstance(o, ContinuousOutput): + raise ValueError("all outputs need to be continuous") + return outputs diff --git a/bofire/data_models/surrogates/linear.py b/bofire/data_models/surrogates/linear.py index a03d5cae4..fe03c99e9 100644 --- a/bofire/data_models/surrogates/linear.py +++ b/bofire/data_models/surrogates/linear.py @@ -1,15 +1,15 @@ from typing import Literal -from pydantic import Field +from pydantic import Field, validator +# from bofire.data_models.strategies.api import FactorialStrategy +from bofire.data_models.features.api import ContinuousOutput from bofire.data_models.kernels.api import LinearKernel from bofire.data_models.priors.api import ( BOTORCH_NOISE_PRIOR, BOTORCH_SCALE_PRIOR, AnyPrior, ) - -# from bofire.data_models.strategies.api import FactorialStrategy from bofire.data_models.surrogates.botorch import BotorchSurrogate from bofire.data_models.surrogates.scaler import ScalerEnum from bofire.data_models.surrogates.trainable import TrainableSurrogate @@ -23,3 +23,18 @@ class LinearSurrogate(BotorchSurrogate, TrainableSurrogate): ) noise_prior: AnyPrior = Field(default_factory=lambda: BOTORCH_NOISE_PRIOR()) scaler: ScalerEnum = ScalerEnum.NORMALIZE + + @validator("outputs") + def validate_outputs(cls, outputs): + """validates outputs + + Raises: + ValueError: if output type is not ContinuousOutput + + Returns: + List[ContinuousOutput] + """ + for o in outputs: + if not isinstance(o, ContinuousOutput): + raise ValueError("all outputs need to be continuous") + return outputs diff --git a/bofire/data_models/surrogates/mixed_single_task_gp.py b/bofire/data_models/surrogates/mixed_single_task_gp.py index f59b8c904..bbb3b541e 100644 --- a/bofire/data_models/surrogates/mixed_single_task_gp.py +++ b/bofire/data_models/surrogates/mixed_single_task_gp.py @@ -3,6 +3,7 @@ from pydantic import Field, validator from bofire.data_models.enum import CategoricalEncodingEnum +from bofire.data_models.features.api import ContinuousOutput from bofire.data_models.kernels.api import ( AnyCategoricalKernal, AnyContinuousKernel, @@ -32,3 +33,18 @@ def validate_categoricals(cls, v, values): "MixedSingleTaskGPSurrogate can only be used if at least one one-hot encoded categorical feature is present." ) return v + + @validator("outputs") + def validate_outputs(cls, outputs): + """validates outputs + + Raises: + ValueError: if output type is not ContinuousOutput + + Returns: + List[ContinuousOutput] + """ + for o in outputs: + if not isinstance(o, ContinuousOutput): + raise ValueError("all outputs need to be continuous") + return outputs diff --git a/bofire/data_models/surrogates/mlp.py b/bofire/data_models/surrogates/mlp.py index 57bd61bc1..6173f521f 100644 --- a/bofire/data_models/surrogates/mlp.py +++ b/bofire/data_models/surrogates/mlp.py @@ -1,7 +1,8 @@ from typing import Annotated, Literal, Sequence -from pydantic import Field +from pydantic import Field, validator +from bofire.data_models.features.api import ContinuousOutput from bofire.data_models.surrogates.botorch import BotorchSurrogate from bofire.data_models.surrogates.scaler import ScalerEnum from bofire.data_models.surrogates.trainable import TrainableSurrogate @@ -20,3 +21,18 @@ class MLPEnsemble(BotorchSurrogate, TrainableSurrogate): subsample_fraction: Annotated[float, Field(gt=0.0)] = 1.0 shuffle: bool = True scaler: ScalerEnum = ScalerEnum.NORMALIZE + + @validator("outputs") + def validate_outputs(cls, outputs): + """validates outputs + + Raises: + ValueError: if output type is not ContinuousOutput + + Returns: + List[ContinuousOutput] + """ + for o in outputs: + if not isinstance(o, ContinuousOutput): + raise ValueError("all outputs need to be continuous") + return outputs diff --git a/bofire/data_models/surrogates/mlp_classifier.py b/bofire/data_models/surrogates/mlp_classifier.py index 6e6c59400..ea9fc576d 100644 --- a/bofire/data_models/surrogates/mlp_classifier.py +++ b/bofire/data_models/surrogates/mlp_classifier.py @@ -1,7 +1,8 @@ from typing import Annotated, Literal, Sequence -from pydantic import Field +from pydantic import Field, validator +from bofire.data_models.features.api import CategoricalOutput from bofire.data_models.surrogates.botorch import BotorchSurrogate from bofire.data_models.surrogates.scaler import ScalerEnum from bofire.data_models.surrogates.trainable import TrainableSurrogate @@ -20,3 +21,18 @@ class MLPClassifierEnsemble(BotorchSurrogate, TrainableSurrogate): subsample_fraction: Annotated[float, Field(gt=0.0)] = 1.0 shuffle: bool = True scaler: ScalerEnum = ScalerEnum.NORMALIZE + + @validator("outputs") + def validate_outputs(cls, outputs): + """validates outputs + + Raises: + ValueError: if output type is not CategoricalOutput + + Returns: + List[CategoricalOutput] + """ + for o in outputs: + if not isinstance(o, CategoricalOutput): + raise ValueError("all outputs need to be categorical") + return outputs diff --git a/bofire/data_models/surrogates/random_forest.py b/bofire/data_models/surrogates/random_forest.py index 621bb30f7..b9f2855eb 100644 --- a/bofire/data_models/surrogates/random_forest.py +++ b/bofire/data_models/surrogates/random_forest.py @@ -1,8 +1,9 @@ from typing import Literal, Optional, Union -from pydantic import Field +from pydantic import Field, validator from typing_extensions import Annotated +from bofire.data_models.features.api import ContinuousOutput from bofire.data_models.surrogates.botorch import BotorchSurrogate from bofire.data_models.surrogates.trainable import TrainableSurrogate @@ -30,3 +31,18 @@ class RandomForestSurrogate(BotorchSurrogate, TrainableSurrogate): random_state: Optional[int] = None ccp_alpha: Annotated[float, Field(ge=0)] = 0.0 max_samples: Optional[Union[int, float]] = None + + @validator("outputs") + def validate_outputs(cls, outputs): + """validates outputs + + Raises: + ValueError: if output type is not ContinuousOutput + + Returns: + List[ContinuousOutput] + """ + for o in outputs: + if not isinstance(o, ContinuousOutput): + raise ValueError("all outputs need to be continuous") + return outputs diff --git a/bofire/data_models/surrogates/single_task_gp.py b/bofire/data_models/surrogates/single_task_gp.py index 2ccd1c617..c6ef592f7 100644 --- a/bofire/data_models/surrogates/single_task_gp.py +++ b/bofire/data_models/surrogates/single_task_gp.py @@ -1,11 +1,13 @@ from typing import Literal, Optional import pandas as pd -from pydantic import Field +from pydantic import Field, validator from bofire.data_models.domain.api import Inputs from bofire.data_models.enum import RegressionMetricsEnum -from bofire.data_models.features.api import CategoricalInput + +# from bofire.data_models.strategies.api import FactorialStrategy +from bofire.data_models.features.api import CategoricalInput, ContinuousOutput from bofire.data_models.kernels.api import ( AnyKernel, MaternKernel, @@ -21,8 +23,6 @@ MBO_OUTPUTSCALE_PRIOR, AnyPrior, ) - -# from bofire.data_models.strategies.api import FactorialStrategy from bofire.data_models.surrogates.botorch import BotorchSurrogate from bofire.data_models.surrogates.scaler import ScalerEnum from bofire.data_models.surrogates.trainable import Hyperconfig, TrainableSurrogate @@ -110,3 +110,18 @@ class SingleTaskGPSurrogate(BotorchSurrogate, TrainableSurrogate): hyperconfig: Optional[SingleTaskGPHyperconfig] = Field( default_factory=lambda: SingleTaskGPHyperconfig() ) + + @validator("outputs") + def validate_outputs(cls, outputs): + """validates outputs + + Raises: + ValueError: if output type is not ContinuousOutput + + Returns: + List[ContinuousOutput] + """ + for o in outputs: + if not isinstance(o, ContinuousOutput): + raise ValueError("all outputs need to be continuous") + return outputs diff --git a/bofire/data_models/surrogates/tanimoto_gp.py b/bofire/data_models/surrogates/tanimoto_gp.py index 1a1df1d89..adeea1293 100644 --- a/bofire/data_models/surrogates/tanimoto_gp.py +++ b/bofire/data_models/surrogates/tanimoto_gp.py @@ -1,7 +1,8 @@ from typing import Literal -from pydantic import Field +from pydantic import Field, validator +from bofire.data_models.features.api import ContinuousOutput from bofire.data_models.kernels.api import AnyKernel, ScaleKernel from bofire.data_models.kernels.molecular import TanimotoKernel from bofire.data_models.priors.api import ( @@ -27,3 +28,18 @@ class TanimotoGPSurrogate(BotorchSurrogate, TrainableSurrogate): ) noise_prior: AnyPrior = Field(default_factory=lambda: BOTORCH_NOISE_PRIOR()) scaler: ScalerEnum = ScalerEnum.IDENTITY + + @validator("outputs") + def validate_outputs(cls, outputs): + """validates outputs + + Raises: + ValueError: if output type is not ContinuousOutput + + Returns: + List[ContinuousOutput] + """ + for o in outputs: + if not isinstance(o, ContinuousOutput): + raise ValueError("all outputs need to be continuous") + return outputs diff --git a/bofire/data_models/surrogates/xgb.py b/bofire/data_models/surrogates/xgb.py index 467bdd2d4..7e38c8bd6 100644 --- a/bofire/data_models/surrogates/xgb.py +++ b/bofire/data_models/surrogates/xgb.py @@ -7,6 +7,7 @@ from bofire.data_models.features.api import ( CategoricalDescriptorInput, CategoricalInput, + ContinuousOutput, NumericalInput, ) from bofire.data_models.surrogates.surrogate import Surrogate @@ -75,3 +76,18 @@ def validate_input_preprocessing_specs(cls, v, values): if v.get(key) is not None: raise ValueError("Currently no numeric transforms are supported.") return v + + @validator("outputs") + def validate_outputs(cls, outputs): + """validates outputs + + Raises: + ValueError: if output type is not ContinuousOutput + + Returns: + List[ContinuousOutput] + """ + for o in outputs: + if not isinstance(o, ContinuousOutput): + raise ValueError("all outputs need to be continuous") + return outputs diff --git a/bofire/surrogates/mlp_classifier.py b/bofire/surrogates/mlp_classifier.py index 99d1b9391..a7789c04f 100644 --- a/bofire/surrogates/mlp_classifier.py +++ b/bofire/surrogates/mlp_classifier.py @@ -95,10 +95,10 @@ def forward(self, X: Tensor): X: A `batch_shape x n x d`-dim input tensor `X`. Returns: - A `batch_shape x s x n x m x C`-dimensional output tensor where + A `batch_shape x s x n x C`-dimensional output tensor where `s` is the size of the ensemble and `C` is the number of classes. """ - return torch.stack([torch.softmax(mlp(X), dim=-1) for mlp in self.mlps], dim=-3) + return torch.stack([mlp(X).exp() for mlp in self.mlps], dim=-3) @property def num_outputs(self) -> int: diff --git a/bofire/utils/torch_tools.py b/bofire/utils/torch_tools.py index e0c75f9ca..c80e38de0 100644 --- a/bofire/utils/torch_tools.py +++ b/bofire/utils/torch_tools.py @@ -183,14 +183,13 @@ def constrained_objective2botorch( return ( [ lambda Z: -1.0 - * objective.w * ( - Z[..., idx : idx + len(objective.weights)] - * torch.tensor(objective.weights).to(**tkwargs) + Z[..., idx : idx + len(objective.desirability)] + * torch.tensor(objective.desirability).to(**tkwargs) ).sum(-1) ], [objective.eta], - idx + len(objective.weights), + idx + len(objective.desirability), ) else: raise ValueError(f"Objective {objective.__class__.__name__} not known.") diff --git a/tests/bofire/data_models/specs/features.py b/tests/bofire/data_models/specs/features.py index 264a29d19..2847c3455 100644 --- a/tests/bofire/data_models/specs/features.py +++ b/tests/bofire/data_models/specs/features.py @@ -85,7 +85,7 @@ lambda: { "key": str(uuid.uuid4()), "categories": ("a", "b", "c"), - "objective": CategoricalObjective(weights=[0.0, 1.0, 0.0]), + "objective": CategoricalObjective(desirability=(0.0, 1.0, 0.0)), }, ) specs.add_valid( diff --git a/tests/bofire/data_models/test_features.py b/tests/bofire/data_models/test_features.py index 87f6935db..9c100bb4f 100644 --- a/tests/bofire/data_models/test_features.py +++ b/tests/bofire/data_models/test_features.py @@ -2221,7 +2221,7 @@ def test_inputs_get_bounds_fit(): CategoricalOutput( key="of4", categories=["a", "b"], - objective=CategoricalObjective(weights=[1.0, 0.0]), + objective=CategoricalObjective(desirability=[1.0, 0.0]), ), ] ), @@ -2241,7 +2241,7 @@ def test_categorical_output(): feature = CategoricalOutput( key="a", categories=["alpha", "beta", "gamma"], - objective=CategoricalObjective(weights=[1.0, 0.0, 0.1]), + objective=CategoricalObjective(desirability=[1.0, 0.0, 0.1]), ) assert feature.to_dict() == {"alpha": 1.0, "beta": 0.0, "gamma": 0.1} diff --git a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb index fb78e2ef2..2a39131cb 100644 --- a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb +++ b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb @@ -98,243 +98,243 @@ " \n", " \n", " 0\n", - " -0.427145\n", - " -1.239930\n", - " -1.289607\n", - " 0.836075\n", - " 0.459756\n", + " -1.607257\n", + " 0.066965\n", + " 1.956921\n", + " -1.206965\n", + " -1.726151\n", " 0.0\n", - " 4.938796\n", + " 11.168473\n", " unacceptable\n", - " -0.417839\n", + " -1.599521\n", " \n", " \n", " 1\n", - " -1.439973\n", - " -1.083066\n", - " -0.767601\n", - " 1.865077\n", - " 1.209388\n", - " 1.0\n", - " 5.837671\n", + " 0.576852\n", + " -0.826301\n", + " 0.188427\n", + " 0.228371\n", + " -0.548498\n", + " 0.0\n", + " 1.232865\n", " unacceptable\n", - " -1.434549\n", + " 0.583328\n", " \n", " \n", " 2\n", - " 0.643776\n", - " -0.010631\n", - " 0.690485\n", - " 0.575922\n", - " 0.399234\n", - " 0.0\n", - " 0.468453\n", - " ideal\n", - " 0.648536\n", + " 0.426233\n", + " -0.353617\n", + " -1.501839\n", + " -1.336984\n", + " 1.611887\n", + " 1.0\n", + " 5.894178\n", + " unacceptable\n", + " 0.432236\n", " \n", " \n", " 3\n", - " 0.955959\n", - " -1.096647\n", - " 0.849608\n", - " 1.461731\n", - " 0.009388\n", - " 0.0\n", - " 1.607040\n", + " 0.791177\n", + " 0.517637\n", + " 1.054840\n", + " 1.702119\n", + " -0.133859\n", + " 1.0\n", + " 0.538519\n", " ideal\n", - " 0.960725\n", + " 0.794129\n", " \n", " \n", " 4\n", - " -1.269751\n", - " 1.724873\n", - " 1.108334\n", - " 1.641221\n", - " 1.677333\n", + " -0.212858\n", + " 1.738505\n", + " 0.047741\n", + " 0.785163\n", + " -0.386461\n", " 1.0\n", - " 3.843082\n", + " 1.373196\n", " ideal\n", - " -1.263064\n", + " -0.204777\n", " \n", " \n", " 5\n", - " -1.347163\n", - " -1.263142\n", - " 1.813048\n", - " -1.760204\n", - " 1.556339\n", - " 1.0\n", - " 10.255809\n", + " -1.343810\n", + " 0.174728\n", + " -0.742157\n", + " 0.828275\n", + " -0.383426\n", + " 0.0\n", + " 3.939393\n", " unacceptable\n", - " -1.344647\n", + " -1.341086\n", " \n", " \n", " 6\n", - " -0.739443\n", - " -0.586744\n", - " 1.614016\n", - " 0.001226\n", - " 1.981623\n", + " 0.427134\n", + " -1.622157\n", + " 1.390718\n", + " -0.325319\n", + " -0.767477\n", " 0.0\n", - " 2.718785\n", - " ideal\n", - " -0.734792\n", + " 4.739526\n", + " unacceptable\n", + " 0.431781\n", " \n", " \n", " 7\n", - " -1.687649\n", - " -1.804157\n", - " -0.383259\n", - " -1.128108\n", - " -0.716971\n", - " 1.0\n", - " 11.486883\n", - " unacceptable\n", - " -1.685165\n", + " 0.314917\n", + " 1.367960\n", + " 0.926581\n", + " 1.710552\n", + " -0.843306\n", + " 0.0\n", + " 1.487580\n", + " ideal\n", + " 0.316096\n", " \n", " \n", " 8\n", - " -1.348917\n", - " 1.668477\n", - " 1.684095\n", - " -1.183607\n", - " -1.452510\n", + " -1.891133\n", + " -1.125407\n", + " 1.468634\n", + " -0.906353\n", + " 0.456227\n", " 1.0\n", - " 8.185157\n", + " 8.206412\n", " unacceptable\n", - " -1.347965\n", + " -1.888861\n", " \n", " \n", " 9\n", - " -1.298011\n", - " 0.203581\n", - " -1.267659\n", - " -1.742074\n", - " -0.530909\n", + " -0.281061\n", + " -0.167345\n", + " 0.589614\n", + " 1.681104\n", + " 1.800321\n", " 0.0\n", - " 9.691982\n", - " unacceptable\n", - " -1.290670\n", + " 1.807627\n", + " ideal\n", + " -0.272137\n", " \n", " \n", " 10\n", - " 1.630207\n", - " 0.565404\n", - " 0.900429\n", - " -0.752504\n", - " -1.248570\n", + " 0.936582\n", + " 1.536808\n", + " 1.250784\n", + " 1.286891\n", + " 0.246477\n", " 1.0\n", - " 2.843968\n", + " 0.378613\n", " ideal\n", - " 1.636059\n", + " 0.942305\n", " \n", " \n", " 11\n", - " 0.150783\n", - " -1.017744\n", - " -0.241076\n", - " -0.747660\n", - " -0.534285\n", + " 1.013835\n", + " -1.056954\n", + " -0.894305\n", + " -1.599589\n", + " -0.405500\n", " 1.0\n", - " 2.254525\n", + " 5.569871\n", " unacceptable\n", - " 0.156698\n", + " 1.013881\n", " \n", " \n", " 12\n", - " -0.757027\n", - " -0.757452\n", - " 0.276629\n", - " 0.054870\n", - " -1.787721\n", + " -1.144869\n", + " 0.845439\n", + " 0.511907\n", + " 1.450526\n", + " -0.416333\n", " 1.0\n", - " 6.733639\n", - " unacceptable\n", - " -0.753403\n", + " 2.774191\n", + " ideal\n", + " -1.140227\n", " \n", " \n", " 13\n", - " 1.909901\n", - " -1.215048\n", - " -1.357405\n", - " 1.234305\n", - " 1.822369\n", - " 0.0\n", - " 5.587213\n", + " -1.213794\n", + " 1.993482\n", + " -1.097520\n", + " 1.123132\n", + " 0.433351\n", + " 1.0\n", + " 4.678376\n", " ideal\n", - " 1.914966\n", + " -1.208828\n", " \n", " \n", " 14\n", - " -0.712812\n", - " -1.343809\n", - " -0.549294\n", - " -1.113302\n", - " 0.270454\n", - " 0.0\n", - " 5.321709\n", + " 1.523778\n", + " 0.713813\n", + " -0.675652\n", + " -1.730426\n", + " 0.057436\n", + " 1.0\n", + " 5.140214\n", " unacceptable\n", - " -0.710558\n", + " 1.530586\n", " \n", " \n", " 15\n", - " 0.414610\n", - " -0.393348\n", - " -0.287759\n", - " -1.958587\n", - " -0.279858\n", + " 0.504682\n", + " 0.182944\n", + " 1.127047\n", + " 0.739912\n", + " -1.141001\n", " 0.0\n", - " 6.394600\n", - " unacceptable\n", - " 0.416899\n", + " 1.851892\n", + " acceptable\n", + " 0.505863\n", " \n", " \n", " 16\n", - " -1.627874\n", - " -0.560977\n", - " 0.062928\n", - " 1.714103\n", - " 0.394854\n", + " 1.962748\n", + " -0.965457\n", + " 0.262051\n", + " 0.326342\n", + " -1.285648\n", " 0.0\n", - " 5.349225\n", + " 4.234999\n", " unacceptable\n", - " -1.621146\n", + " 1.968161\n", " \n", " \n", " 17\n", - " -0.924677\n", - " 0.326249\n", - " -0.642047\n", - " -0.802113\n", - " 1.502943\n", + " 0.054107\n", + " 0.017339\n", + " -1.421447\n", + " -0.475980\n", + " 0.825959\n", " 1.0\n", - " 2.878352\n", + " 3.445593\n", " unacceptable\n", - " -0.916453\n", + " 0.063285\n", " \n", " \n", " 18\n", - " -1.006862\n", - " 0.976732\n", - " 0.419261\n", - " 0.285177\n", - " -1.698525\n", + " -0.538143\n", + " -0.853412\n", + " 1.263803\n", + " -1.066115\n", + " -0.697828\n", " 0.0\n", - " 6.397444\n", + " 3.396296\n", " unacceptable\n", - " -1.000865\n", + " -0.532379\n", " \n", " \n", " 19\n", - " -1.381726\n", - " 0.256416\n", - " 1.467449\n", - " -0.750482\n", - " -0.558065\n", + " 0.492618\n", + " -1.346028\n", + " -0.405918\n", + " -1.280858\n", + " -0.913901\n", " 1.0\n", - " 4.250857\n", + " 5.346855\n", " unacceptable\n", - " -1.376709\n", + " 0.494498\n", " \n", " \n", "\n", @@ -342,48 +342,48 @@ ], "text/plain": [ " x_0 x_1 x_2 x_3 x_4 x_5 f_0 \\\n", - "0 -0.427145 -1.239930 -1.289607 0.836075 0.459756 0.0 4.938796 \n", - "1 -1.439973 -1.083066 -0.767601 1.865077 1.209388 1.0 5.837671 \n", - "2 0.643776 -0.010631 0.690485 0.575922 0.399234 0.0 0.468453 \n", - "3 0.955959 -1.096647 0.849608 1.461731 0.009388 0.0 1.607040 \n", - "4 -1.269751 1.724873 1.108334 1.641221 1.677333 1.0 3.843082 \n", - "5 -1.347163 -1.263142 1.813048 -1.760204 1.556339 1.0 10.255809 \n", - "6 -0.739443 -0.586744 1.614016 0.001226 1.981623 0.0 2.718785 \n", - "7 -1.687649 -1.804157 -0.383259 -1.128108 -0.716971 1.0 11.486883 \n", - "8 -1.348917 1.668477 1.684095 -1.183607 -1.452510 1.0 8.185157 \n", - "9 -1.298011 0.203581 -1.267659 -1.742074 -0.530909 0.0 9.691982 \n", - "10 1.630207 0.565404 0.900429 -0.752504 -1.248570 1.0 2.843968 \n", - "11 0.150783 -1.017744 -0.241076 -0.747660 -0.534285 1.0 2.254525 \n", - "12 -0.757027 -0.757452 0.276629 0.054870 -1.787721 1.0 6.733639 \n", - "13 1.909901 -1.215048 -1.357405 1.234305 1.822369 0.0 5.587213 \n", - "14 -0.712812 -1.343809 -0.549294 -1.113302 0.270454 0.0 5.321709 \n", - "15 0.414610 -0.393348 -0.287759 -1.958587 -0.279858 0.0 6.394600 \n", - "16 -1.627874 -0.560977 0.062928 1.714103 0.394854 0.0 5.349225 \n", - "17 -0.924677 0.326249 -0.642047 -0.802113 1.502943 1.0 2.878352 \n", - "18 -1.006862 0.976732 0.419261 0.285177 -1.698525 0.0 6.397444 \n", - "19 -1.381726 0.256416 1.467449 -0.750482 -0.558065 1.0 4.250857 \n", + "0 -1.607257 0.066965 1.956921 -1.206965 -1.726151 0.0 11.168473 \n", + "1 0.576852 -0.826301 0.188427 0.228371 -0.548498 0.0 1.232865 \n", + "2 0.426233 -0.353617 -1.501839 -1.336984 1.611887 1.0 5.894178 \n", + "3 0.791177 0.517637 1.054840 1.702119 -0.133859 1.0 0.538519 \n", + "4 -0.212858 1.738505 0.047741 0.785163 -0.386461 1.0 1.373196 \n", + "5 -1.343810 0.174728 -0.742157 0.828275 -0.383426 0.0 3.939393 \n", + "6 0.427134 -1.622157 1.390718 -0.325319 -0.767477 0.0 4.739526 \n", + "7 0.314917 1.367960 0.926581 1.710552 -0.843306 0.0 1.487580 \n", + "8 -1.891133 -1.125407 1.468634 -0.906353 0.456227 1.0 8.206412 \n", + "9 -0.281061 -0.167345 0.589614 1.681104 1.800321 0.0 1.807627 \n", + "10 0.936582 1.536808 1.250784 1.286891 0.246477 1.0 0.378613 \n", + "11 1.013835 -1.056954 -0.894305 -1.599589 -0.405500 1.0 5.569871 \n", + "12 -1.144869 0.845439 0.511907 1.450526 -0.416333 1.0 2.774191 \n", + "13 -1.213794 1.993482 -1.097520 1.123132 0.433351 1.0 4.678376 \n", + "14 1.523778 0.713813 -0.675652 -1.730426 0.057436 1.0 5.140214 \n", + "15 0.504682 0.182944 1.127047 0.739912 -1.141001 0.0 1.851892 \n", + "16 1.962748 -0.965457 0.262051 0.326342 -1.285648 0.0 4.234999 \n", + "17 0.054107 0.017339 -1.421447 -0.475980 0.825959 1.0 3.445593 \n", + "18 -0.538143 -0.853412 1.263803 -1.066115 -0.697828 0.0 3.396296 \n", + "19 0.492618 -1.346028 -0.405918 -1.280858 -0.913901 1.0 5.346855 \n", "\n", " f_1 f_2 \n", - "0 unacceptable -0.417839 \n", - "1 unacceptable -1.434549 \n", - "2 ideal 0.648536 \n", - "3 ideal 0.960725 \n", - "4 ideal -1.263064 \n", - "5 unacceptable -1.344647 \n", - "6 ideal -0.734792 \n", - "7 unacceptable -1.685165 \n", - "8 unacceptable -1.347965 \n", - "9 unacceptable -1.290670 \n", - "10 ideal 1.636059 \n", - "11 unacceptable 0.156698 \n", - "12 unacceptable -0.753403 \n", - "13 ideal 1.914966 \n", - "14 unacceptable -0.710558 \n", - "15 unacceptable 0.416899 \n", - "16 unacceptable -1.621146 \n", - "17 unacceptable -0.916453 \n", - "18 unacceptable -1.000865 \n", - "19 unacceptable -1.376709 " + "0 unacceptable -1.599521 \n", + "1 unacceptable 0.583328 \n", + "2 unacceptable 0.432236 \n", + "3 ideal 0.794129 \n", + "4 ideal -0.204777 \n", + "5 unacceptable -1.341086 \n", + "6 unacceptable 0.431781 \n", + "7 ideal 0.316096 \n", + "8 unacceptable -1.888861 \n", + "9 ideal -0.272137 \n", + "10 ideal 0.942305 \n", + "11 unacceptable 1.013881 \n", + "12 ideal -1.140227 \n", + "13 ideal -1.208828 \n", + "14 unacceptable 1.530586 \n", + "15 acceptable 0.505863 \n", + "16 unacceptable 1.968161 \n", + "17 unacceptable 0.063285 \n", + "18 unacceptable -0.532379 \n", + "19 unacceptable 0.494498 " ] }, "execution_count": 3, @@ -398,7 +398,7 @@ "# here the minimize objective is used, if you want to maximize you have to use the maximize objective.\n", "output_features = Outputs(features=[\n", " ContinuousOutput(key=f\"f_{0}\", objective=MinimizeObjective(w=1.)),\n", - " CategoricalOutput(key=f\"f_{1}\", categories=[\"unacceptable\", \"acceptable\", \"ideal\"], objective=CategoricalObjective(weights=(0, 0.5, 1))), # This function will be associated with learning the categories\n", + " CategoricalOutput(key=f\"f_{1}\", categories=(\"unacceptable\", \"acceptable\", \"ideal\"), objective=CategoricalObjective(desirability=(0, 0.5, 1))), # This function will be associated with learning the categories\n", " ContinuousOutput(key=f\"f_{2}\", objective=MinimizeSigmoidObjective(w=1., tp=0.0, steepness=0.5)),\n", " ]\n", ")\n", @@ -431,16 +431,7 @@ "cell_type": "code", "execution_count": 4, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\optim\\fit.py:130: OptimizationWarning: `scipy_minimize` terminated with status 3, displaying original message from `scipy.optimize.minimize`: ABNORMAL_TERMINATION_IN_LNSRCH\n", - " warn(\n" - ] - } - ], + "outputs": [], "source": [ "from bofire.data_models.acquisition_functions.api import qEI\n", "from bofire.data_models.strategies.api import SoboStrategy\n", @@ -525,70 +516,70 @@ " \n", " \n", " 0\n", - " 0.298542\n", - " 0.611379\n", - " 0.563367\n", - " 2.000000\n", - " 0.357845\n", + " 0.287768\n", + " -0.159834\n", + " 1.117793\n", + " -0.432811\n", + " 0.891773\n", " 1.0\n", - " -1.184754\n", - " 0.303593\n", + " -1.219115\n", + " 0.292788\n", " unacceptable\n", - " 0.512505\n", - " 0.064448\n", - " 0.423048\n", - " 0.632854\n", - " 0.003097\n", - " 0.294712\n", - " 0.058737\n", - " 0.324907\n", - " 1.184754\n", - " 0.462124\n", - " 0.455272\n", + " 0.458888\n", + " 0.113660\n", + " 0.427452\n", + " 0.473508\n", + " 0.003124\n", + " 0.101463\n", + " 0.037833\n", + " 0.098314\n", + " 1.219115\n", + " 0.463467\n", + " 0.484282\n", " \n", " \n", " 1\n", - " 0.180137\n", - " 0.565072\n", - " 0.600042\n", - " 1.093629\n", - " 0.361472\n", + " -1.768118\n", + " 0.077809\n", + " 2.000000\n", + " -1.189829\n", + " -2.000000\n", " 0.0\n", - " -1.091947\n", - " 0.185482\n", + " 12.989883\n", + " -1.764363\n", " unacceptable\n", - " 0.512505\n", - " 0.064450\n", - " 0.423045\n", - " 0.581322\n", - " 0.002997\n", - " 0.294712\n", - " 0.058734\n", - " 0.324901\n", - " 1.091947\n", - " 0.476831\n", - " 0.455270\n", + " 0.536629\n", + " 0.119498\n", + " 0.343873\n", + " 0.177809\n", + " 0.003669\n", + " 0.214828\n", + " 0.036302\n", + " 0.215846\n", + " -12.989883\n", + " 0.707274\n", + " 0.403622\n", " \n", " \n", "\n", "" ], "text/plain": [ - " x_0 x_1 x_2 x_3 x_4 x_5 f_0_pred f_2_pred \\\n", - "0 0.298542 0.611379 0.563367 2.000000 0.357845 1.0 -1.184754 0.303593 \n", - "1 0.180137 0.565072 0.600042 1.093629 0.361472 0.0 -1.091947 0.185482 \n", + " x_0 x_1 x_2 x_3 x_4 x_5 f_0_pred f_2_pred \\\n", + "0 0.287768 -0.159834 1.117793 -0.432811 0.891773 1.0 -1.219115 0.292788 \n", + "1 -1.768118 0.077809 2.000000 -1.189829 -2.000000 0.0 12.989883 -1.764363 \n", "\n", " f_1_pred f_1_pred_unacceptable f_1_pred_acceptable f_1_pred_ideal \\\n", - "0 unacceptable 0.512505 0.064448 0.423048 \n", - "1 unacceptable 0.512505 0.064450 0.423045 \n", + "0 unacceptable 0.458888 0.113660 0.427452 \n", + "1 unacceptable 0.536629 0.119498 0.343873 \n", "\n", " f_0_sd f_2_sd f_1_sd_unacceptable f_1_sd_acceptable f_1_sd_ideal \\\n", - "0 0.632854 0.003097 0.294712 0.058737 0.324907 \n", - "1 0.581322 0.002997 0.294712 0.058734 0.324901 \n", + "0 0.473508 0.003124 0.101463 0.037833 0.098314 \n", + "1 0.177809 0.003669 0.214828 0.036302 0.215846 \n", "\n", - " f_0_des f_2_des f_1_des \n", - "0 1.184754 0.462124 0.455272 \n", - "1 1.091947 0.476831 0.455270 " + " f_0_des f_2_des f_1_des \n", + "0 1.219115 0.463467 0.484282 \n", + "1 -12.989883 0.707274 0.403622 " ] }, "execution_count": 5, From 89df3a5739bb28aaa135cef69030ce64be7763d3 Mon Sep 17 00:00:00 2001 From: gmancino Date: Thu, 21 Dec 2023 08:55:52 -0500 Subject: [PATCH 08/31] Pre-merge commit --- bofire/data_models/objectives/categorical.py | 13 +- bofire/data_models/surrogates/surrogate.py | 20 +- bofire/data_models/surrogates/xgb.py | 1 + bofire/utils/torch_tools.py | 10 +- .../Unknown_Constraint_Classification.ipynb | 566 +++++++----------- 5 files changed, 258 insertions(+), 352 deletions(-) diff --git a/bofire/data_models/objectives/categorical.py b/bofire/data_models/objectives/categorical.py index 4e3d6ce6a..9818f08a1 100644 --- a/bofire/data_models/objectives/categorical.py +++ b/bofire/data_models/objectives/categorical.py @@ -19,23 +19,14 @@ class CategoricalObjective(Objective, ConstrainedObjective): Attributes: w (float): float between zero and one for weighting the objective. - desirability (tuple): tuple of values of size c (c is number of categories) such that the i-th entry is in (0, 1) + desirability (tuple): tuple of values of size c (c is number of categories) such that the i-th entry is in {True, False} """ w: TWeight = 1.0 - desirability: Tuple[float, ...] + desirability: Tuple[bool, ...] eta: float = 1.0 type: Literal["CategoricalObjective"] = "CategoricalObjective" - @validator("desirability") - def validate_desirability(cls, desirability): - for w in desirability: - if w > 1: - raise ValueError("Objective weight has to be smaller equal than 1.") - if w < 0: - raise ValueError("Objective weight has to be larger equal than zero") - return desirability - def __call__(self, x: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]: """The call function returning a probabilistic reward for x. diff --git a/bofire/data_models/surrogates/surrogate.py b/bofire/data_models/surrogates/surrogate.py index 170e6b4c7..2a5f73e7e 100644 --- a/bofire/data_models/surrogates/surrogate.py +++ b/bofire/data_models/surrogates/surrogate.py @@ -1,10 +1,10 @@ -from typing import Optional +from typing import Optional, Union from pydantic import Field, validator from bofire.data_models.base import BaseModel from bofire.data_models.domain.api import Inputs, Outputs -from bofire.data_models.features.api import TInputTransformSpecs +from bofire.data_models.features.api import ContinuousOutput, CategoricalOutput, TInputTransformSpecs class Surrogate(BaseModel): @@ -28,3 +28,19 @@ def validate_outputs(cls, v, values): if len(v) == 0: raise ValueError("At least one output feature has to be provided.") return v + + @classmethod # TODO: Remove this, change it, ??? + def is_output_implemented(cls, outputs, my_type: Union[ContinuousOutput, CategoricalOutput]) -> bool: + """Abstract method to check output type for surrogate models + + Args: + outputs: objective functions for the surrogate + my_type: continuous or categorical output + + Returns: + bool: True if the output type is valid for the surrogate chosen, False otherwise + """ + for o in outputs: + if not isinstance(o, my_type): + return False + return True diff --git a/bofire/data_models/surrogates/xgb.py b/bofire/data_models/surrogates/xgb.py index 7e38c8bd6..c77ee8faa 100644 --- a/bofire/data_models/surrogates/xgb.py +++ b/bofire/data_models/surrogates/xgb.py @@ -77,6 +77,7 @@ def validate_input_preprocessing_specs(cls, v, values): raise ValueError("Currently no numeric transforms are supported.") return v + @validator("outputs") def validate_outputs(cls, outputs): """validates outputs diff --git a/bofire/utils/torch_tools.py b/bofire/utils/torch_tools.py index c80e38de0..5c5bd3a15 100644 --- a/bofire/utils/torch_tools.py +++ b/bofire/utils/torch_tools.py @@ -139,6 +139,7 @@ def min_constraint(indices: Tensor, num_features: int, min_count: int): def constrained_objective2botorch( idx: int, objective: ConstrainedObjective, + eps: float = 1e-6 ) -> Tuple[List[Callable[[Tensor], Tensor]], List[float], int]: """Create a callable that can be used by `botorch.utils.objective.apply_constraints` to setup ouput constrained optimizations. @@ -180,13 +181,12 @@ def constrained_objective2botorch( ) elif isinstance(objective, CategoricalObjective): # The output of a categorical objective has final dim `c` where `c` is number of classes + # Pass in the expected acceptance probability and perform an inverse sigmoid to atain the original probabilities return ( [ - lambda Z: -1.0 - * ( - Z[..., idx : idx + len(objective.desirability)] - * torch.tensor(objective.desirability).to(**tkwargs) - ).sum(-1) + lambda Z: torch.log( + torch.clamp(1 / (Z[..., idx : idx + len(objective.desirability)] * torch.tensor(objective.desirability).to(**tkwargs)).sum(-1) - 1, min=eps, max=1-eps) + ) ], [objective.eta], idx + len(objective.desirability), diff --git a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb index 2a39131cb..4b0c12227 100644 --- a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb +++ b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb @@ -98,243 +98,243 @@ " \n", " \n", " 0\n", - " -1.607257\n", - " 0.066965\n", - " 1.956921\n", - " -1.206965\n", - " -1.726151\n", + " 0.435454\n", + " -1.788437\n", + " -1.988452\n", + " -0.421175\n", + " -1.708184\n", " 0.0\n", - " 11.168473\n", + " 15.230583\n", " unacceptable\n", - " -1.599521\n", + " 0.437269\n", " \n", " \n", " 1\n", - " 0.576852\n", - " -0.826301\n", - " 0.188427\n", - " 0.228371\n", - " -0.548498\n", - " 0.0\n", - " 1.232865\n", + " -0.839497\n", + " -1.317481\n", + " -0.621791\n", + " 1.029460\n", + " -0.029063\n", + " 1.0\n", + " 4.092890\n", " unacceptable\n", - " 0.583328\n", + " -0.829668\n", " \n", " \n", " 2\n", - " 0.426233\n", - " -0.353617\n", - " -1.501839\n", - " -1.336984\n", - " 1.611887\n", - " 1.0\n", - " 5.894178\n", + " 0.457487\n", + " -1.632370\n", + " -0.219897\n", + " -0.589791\n", + " -0.543485\n", + " 0.0\n", + " 4.653823\n", " unacceptable\n", - " 0.432236\n", + " 0.466307\n", " \n", " \n", " 3\n", - " 0.791177\n", - " 0.517637\n", - " 1.054840\n", - " 1.702119\n", - " -0.133859\n", - " 1.0\n", - " 0.538519\n", - " ideal\n", - " 0.794129\n", + " -0.545863\n", + " 1.068268\n", + " -0.840059\n", + " -0.156614\n", + " 0.532427\n", + " 0.0\n", + " 1.944284\n", + " unacceptable\n", + " -0.544115\n", " \n", " \n", " 4\n", - " -0.212858\n", - " 1.738505\n", - " 0.047741\n", - " 0.785163\n", - " -0.386461\n", + " 1.294465\n", + " -1.086366\n", + " -1.444690\n", + " -1.005573\n", + " -0.743576\n", " 1.0\n", - " 1.373196\n", - " ideal\n", - " -0.204777\n", + " 5.516022\n", + " unacceptable\n", + " 1.303366\n", " \n", " \n", " 5\n", - " -1.343810\n", - " 0.174728\n", - " -0.742157\n", - " 0.828275\n", - " -0.383426\n", - " 0.0\n", - " 3.939393\n", + " -0.826168\n", + " -1.615002\n", + " 1.624282\n", + " 0.041478\n", + " 0.413251\n", + " 1.0\n", + " 5.652505\n", " unacceptable\n", - " -1.341086\n", + " -0.819800\n", " \n", " \n", " 6\n", - " 0.427134\n", - " -1.622157\n", - " 1.390718\n", - " -0.325319\n", - " -0.767477\n", - " 0.0\n", - " 4.739526\n", - " unacceptable\n", - " 0.431781\n", + " 0.213634\n", + " -0.469763\n", + " 0.950611\n", + " 1.945084\n", + " -0.783820\n", + " 1.0\n", + " 1.695620\n", + " ideal\n", + " 0.214467\n", " \n", " \n", " 7\n", - " 0.314917\n", - " 1.367960\n", - " 0.926581\n", - " 1.710552\n", - " -0.843306\n", - " 0.0\n", - " 1.487580\n", - " ideal\n", - " 0.316096\n", + " -0.570017\n", + " 0.019903\n", + " -1.137774\n", + " -0.417314\n", + " 1.457074\n", + " 1.0\n", + " 2.909370\n", + " unacceptable\n", + " -0.567218\n", " \n", " \n", " 8\n", - " -1.891133\n", - " -1.125407\n", - " 1.468634\n", - " -0.906353\n", - " 0.456227\n", - " 1.0\n", - " 8.206412\n", + " -0.756360\n", + " -0.905112\n", + " 1.536840\n", + " -0.081425\n", + " 0.350125\n", + " 0.0\n", + " 2.621606\n", " unacceptable\n", - " -1.888861\n", + " -0.747130\n", " \n", " \n", " 9\n", - " -0.281061\n", - " -0.167345\n", - " 0.589614\n", - " 1.681104\n", - " 1.800321\n", + " 1.615900\n", + " -0.951538\n", + " 1.450884\n", + " 0.623330\n", + " 0.232692\n", " 0.0\n", - " 1.807627\n", + " 1.690015\n", " ideal\n", - " -0.272137\n", + " 1.622131\n", " \n", " \n", " 10\n", - " 0.936582\n", - " 1.536808\n", - " 1.250784\n", - " 1.286891\n", - " 0.246477\n", + " 1.679599\n", + " 0.482301\n", + " -0.331234\n", + " 0.343804\n", + " 1.968027\n", " 1.0\n", - " 0.378613\n", + " 1.466946\n", " ideal\n", - " 0.942305\n", + " 1.679712\n", " \n", " \n", " 11\n", - " 1.013835\n", - " -1.056954\n", - " -0.894305\n", - " -1.599589\n", - " -0.405500\n", - " 1.0\n", - " 5.569871\n", - " unacceptable\n", - " 1.013881\n", + " 1.967228\n", + " 1.924398\n", + " 0.038540\n", + " 1.206153\n", + " 1.048228\n", + " 0.0\n", + " 1.880371\n", + " ideal\n", + " 1.969058\n", " \n", " \n", " 12\n", - " -1.144869\n", - " 0.845439\n", - " 0.511907\n", - " 1.450526\n", - " -0.416333\n", + " 0.784865\n", + " -1.456000\n", + " -1.918741\n", + " -0.222590\n", + " -1.711922\n", " 1.0\n", - " 2.774191\n", - " ideal\n", - " -1.140227\n", + " 12.611483\n", + " unacceptable\n", + " 0.786212\n", " \n", " \n", " 13\n", - " -1.213794\n", - " 1.993482\n", - " -1.097520\n", - " 1.123132\n", - " 0.433351\n", - " 1.0\n", - " 4.678376\n", - " ideal\n", - " -1.208828\n", + " -0.690906\n", + " -0.902616\n", + " 1.696768\n", + " 0.247553\n", + " -0.051436\n", + " 0.0\n", + " 2.647546\n", + " unacceptable\n", + " -0.687487\n", " \n", " \n", " 14\n", - " 1.523778\n", - " 0.713813\n", - " -0.675652\n", - " -1.730426\n", - " 0.057436\n", + " -0.415655\n", + " 0.478458\n", + " 0.975439\n", + " 1.812020\n", + " 1.215072\n", " 1.0\n", - " 5.140214\n", - " unacceptable\n", - " 1.530586\n", + " 1.493446\n", + " ideal\n", + " -0.414414\n", " \n", " \n", " 15\n", - " 0.504682\n", - " 0.182944\n", - " 1.127047\n", - " 0.739912\n", - " -1.141001\n", - " 0.0\n", - " 1.851892\n", + " 0.355296\n", + " -1.806993\n", + " 0.823093\n", + " -0.596255\n", + " 1.884369\n", + " 1.0\n", + " 5.908147\n", " acceptable\n", - " 0.505863\n", + " 0.363767\n", " \n", " \n", " 16\n", - " 1.962748\n", - " -0.965457\n", - " 0.262051\n", - " 0.326342\n", - " -1.285648\n", + " -0.227270\n", + " 1.515492\n", + " -0.378421\n", + " -0.282083\n", + " -1.022800\n", " 0.0\n", - " 4.234999\n", + " 2.312977\n", " unacceptable\n", - " 1.968161\n", + " -0.226730\n", " \n", " \n", " 17\n", - " 0.054107\n", - " 0.017339\n", - " -1.421447\n", - " -0.475980\n", - " 0.825959\n", - " 1.0\n", - " 3.445593\n", + " -1.938272\n", + " -0.544588\n", + " 0.677599\n", + " 1.611933\n", + " -1.657255\n", + " 0.0\n", + " 10.617281\n", " unacceptable\n", - " 0.063285\n", + " -1.938064\n", " \n", " \n", " 18\n", - " -0.538143\n", - " -0.853412\n", - " 1.263803\n", - " -1.066115\n", - " -0.697828\n", + " -1.233844\n", + " 1.107200\n", + " -0.249493\n", + " 1.132248\n", + " -1.648399\n", " 0.0\n", - " 3.396296\n", + " 6.794837\n", " unacceptable\n", - " -0.532379\n", + " -1.226482\n", " \n", " \n", " 19\n", - " 0.492618\n", - " -1.346028\n", - " -0.405918\n", - " -1.280858\n", - " -0.913901\n", - " 1.0\n", - " 5.346855\n", + " 0.312220\n", + " -0.639545\n", + " 1.626105\n", + " -0.081847\n", + " -0.252003\n", + " 0.0\n", + " 1.202175\n", " unacceptable\n", - " 0.494498\n", + " 0.313409\n", " \n", " \n", "\n", @@ -342,48 +342,48 @@ ], "text/plain": [ " x_0 x_1 x_2 x_3 x_4 x_5 f_0 \\\n", - "0 -1.607257 0.066965 1.956921 -1.206965 -1.726151 0.0 11.168473 \n", - "1 0.576852 -0.826301 0.188427 0.228371 -0.548498 0.0 1.232865 \n", - "2 0.426233 -0.353617 -1.501839 -1.336984 1.611887 1.0 5.894178 \n", - "3 0.791177 0.517637 1.054840 1.702119 -0.133859 1.0 0.538519 \n", - "4 -0.212858 1.738505 0.047741 0.785163 -0.386461 1.0 1.373196 \n", - "5 -1.343810 0.174728 -0.742157 0.828275 -0.383426 0.0 3.939393 \n", - "6 0.427134 -1.622157 1.390718 -0.325319 -0.767477 0.0 4.739526 \n", - "7 0.314917 1.367960 0.926581 1.710552 -0.843306 0.0 1.487580 \n", - "8 -1.891133 -1.125407 1.468634 -0.906353 0.456227 1.0 8.206412 \n", - "9 -0.281061 -0.167345 0.589614 1.681104 1.800321 0.0 1.807627 \n", - "10 0.936582 1.536808 1.250784 1.286891 0.246477 1.0 0.378613 \n", - "11 1.013835 -1.056954 -0.894305 -1.599589 -0.405500 1.0 5.569871 \n", - "12 -1.144869 0.845439 0.511907 1.450526 -0.416333 1.0 2.774191 \n", - "13 -1.213794 1.993482 -1.097520 1.123132 0.433351 1.0 4.678376 \n", - "14 1.523778 0.713813 -0.675652 -1.730426 0.057436 1.0 5.140214 \n", - "15 0.504682 0.182944 1.127047 0.739912 -1.141001 0.0 1.851892 \n", - "16 1.962748 -0.965457 0.262051 0.326342 -1.285648 0.0 4.234999 \n", - "17 0.054107 0.017339 -1.421447 -0.475980 0.825959 1.0 3.445593 \n", - "18 -0.538143 -0.853412 1.263803 -1.066115 -0.697828 0.0 3.396296 \n", - "19 0.492618 -1.346028 -0.405918 -1.280858 -0.913901 1.0 5.346855 \n", + "0 0.435454 -1.788437 -1.988452 -0.421175 -1.708184 0.0 15.230583 \n", + "1 -0.839497 -1.317481 -0.621791 1.029460 -0.029063 1.0 4.092890 \n", + "2 0.457487 -1.632370 -0.219897 -0.589791 -0.543485 0.0 4.653823 \n", + "3 -0.545863 1.068268 -0.840059 -0.156614 0.532427 0.0 1.944284 \n", + "4 1.294465 -1.086366 -1.444690 -1.005573 -0.743576 1.0 5.516022 \n", + "5 -0.826168 -1.615002 1.624282 0.041478 0.413251 1.0 5.652505 \n", + "6 0.213634 -0.469763 0.950611 1.945084 -0.783820 1.0 1.695620 \n", + "7 -0.570017 0.019903 -1.137774 -0.417314 1.457074 1.0 2.909370 \n", + "8 -0.756360 -0.905112 1.536840 -0.081425 0.350125 0.0 2.621606 \n", + "9 1.615900 -0.951538 1.450884 0.623330 0.232692 0.0 1.690015 \n", + "10 1.679599 0.482301 -0.331234 0.343804 1.968027 1.0 1.466946 \n", + "11 1.967228 1.924398 0.038540 1.206153 1.048228 0.0 1.880371 \n", + "12 0.784865 -1.456000 -1.918741 -0.222590 -1.711922 1.0 12.611483 \n", + "13 -0.690906 -0.902616 1.696768 0.247553 -0.051436 0.0 2.647546 \n", + "14 -0.415655 0.478458 0.975439 1.812020 1.215072 1.0 1.493446 \n", + "15 0.355296 -1.806993 0.823093 -0.596255 1.884369 1.0 5.908147 \n", + "16 -0.227270 1.515492 -0.378421 -0.282083 -1.022800 0.0 2.312977 \n", + "17 -1.938272 -0.544588 0.677599 1.611933 -1.657255 0.0 10.617281 \n", + "18 -1.233844 1.107200 -0.249493 1.132248 -1.648399 0.0 6.794837 \n", + "19 0.312220 -0.639545 1.626105 -0.081847 -0.252003 0.0 1.202175 \n", "\n", " f_1 f_2 \n", - "0 unacceptable -1.599521 \n", - "1 unacceptable 0.583328 \n", - "2 unacceptable 0.432236 \n", - "3 ideal 0.794129 \n", - "4 ideal -0.204777 \n", - "5 unacceptable -1.341086 \n", - "6 unacceptable 0.431781 \n", - "7 ideal 0.316096 \n", - "8 unacceptable -1.888861 \n", - "9 ideal -0.272137 \n", - "10 ideal 0.942305 \n", - "11 unacceptable 1.013881 \n", - "12 ideal -1.140227 \n", - "13 ideal -1.208828 \n", - "14 unacceptable 1.530586 \n", - "15 acceptable 0.505863 \n", - "16 unacceptable 1.968161 \n", - "17 unacceptable 0.063285 \n", - "18 unacceptable -0.532379 \n", - "19 unacceptable 0.494498 " + "0 unacceptable 0.437269 \n", + "1 unacceptable -0.829668 \n", + "2 unacceptable 0.466307 \n", + "3 unacceptable -0.544115 \n", + "4 unacceptable 1.303366 \n", + "5 unacceptable -0.819800 \n", + "6 ideal 0.214467 \n", + "7 unacceptable -0.567218 \n", + "8 unacceptable -0.747130 \n", + "9 ideal 1.622131 \n", + "10 ideal 1.679712 \n", + "11 ideal 1.969058 \n", + "12 unacceptable 0.786212 \n", + "13 unacceptable -0.687487 \n", + "14 ideal -0.414414 \n", + "15 acceptable 0.363767 \n", + "16 unacceptable -0.226730 \n", + "17 unacceptable -1.938064 \n", + "18 unacceptable -1.226482 \n", + "19 unacceptable 0.313409 " ] }, "execution_count": 3, @@ -431,18 +431,45 @@ "cell_type": "code", "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "TypeError", + "evalue": "__init__() missing 1 required positional argument: 'task_feature'", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32mc:\\Users\\G15361\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\tutorials\\basic_examples\\Unknown_Constraint_Classification.ipynb Cell 7\u001b[0m line \u001b[0;36m2\n\u001b[0;32m 15\u001b[0m strategy_data \u001b[39m=\u001b[39m SoboStrategy(domain\u001b[39m=\u001b[39mdomain1, \n\u001b[0;32m 16\u001b[0m acquisition_function\u001b[39m=\u001b[39mqEI(), \n\u001b[0;32m 17\u001b[0m surrogate_specs\u001b[39m=\u001b[39mBotorchSurrogates(surrogates\u001b[39m=\u001b[39m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 22\u001b[0m )\n\u001b[0;32m 23\u001b[0m )\n\u001b[0;32m 25\u001b[0m strategy \u001b[39m=\u001b[39m strategies\u001b[39m.\u001b[39mmap(strategy_data)\n\u001b[1;32m---> 27\u001b[0m strategy\u001b[39m.\u001b[39;49mtell(sample_df)\n", + "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\strategies\\predictives\\predictive.py:84\u001b[0m, in \u001b[0;36mPredictiveStrategy.tell\u001b[1;34m(self, experiments, replace, retrain)\u001b[0m\n\u001b[0;32m 82\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39madd_experiments(experiments)\n\u001b[0;32m 83\u001b[0m \u001b[39mif\u001b[39;00m retrain \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mhas_sufficient_experiments():\n\u001b[1;32m---> 84\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfit()\n\u001b[0;32m 85\u001b[0m \u001b[39m# we have a seperate _tell here for things that are relevant when setting up the strategy but unrelated\u001b[39;00m\n\u001b[0;32m 86\u001b[0m \u001b[39m# to fitting the models like initializing the ACQF.\u001b[39;00m\n\u001b[0;32m 87\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_tell()\n", + "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\strategies\\predictives\\predictive.py:165\u001b[0m, in \u001b[0;36mPredictiveStrategy.fit\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 163\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdomain\u001b[39m.\u001b[39mvalidate_experiments(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mexperiments, strict\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[0;32m 164\u001b[0m \u001b[39m# transformed = self.transformer.fit_transform(self.experiments)\u001b[39;00m\n\u001b[1;32m--> 165\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fit(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mexperiments)\n\u001b[0;32m 166\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_fitted \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n", + "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\strategies\\predictives\\botorch.py:138\u001b[0m, in \u001b[0;36mBotorchStrategy._fit\u001b[1;34m(self, experiments)\u001b[0m\n\u001b[0;32m 134\u001b[0m \u001b[39m# map the surrogate spec, we keep it here as attribute to be able to save/dump\u001b[39;00m\n\u001b[0;32m 135\u001b[0m \u001b[39m# the surrogate\u001b[39;00m\n\u001b[0;32m 136\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msurrogates \u001b[39m=\u001b[39m BotorchSurrogates(data_model\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39msurrogate_specs) \u001b[39m# type: ignore\u001b[39;00m\n\u001b[1;32m--> 138\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msurrogates\u001b[39m.\u001b[39;49mfit(experiments) \u001b[39m# type: ignore\u001b[39;00m\n\u001b[0;32m 139\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msurrogates\u001b[39m.\u001b[39mcompatibilize( \u001b[39m# type: ignore\u001b[39;00m\n\u001b[0;32m 140\u001b[0m inputs\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdomain\u001b[39m.\u001b[39minputs, \u001b[39m# type: ignore\u001b[39;00m\n\u001b[0;32m 141\u001b[0m outputs\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdomain\u001b[39m.\u001b[39moutputs, \u001b[39m# type: ignore\u001b[39;00m\n\u001b[0;32m 142\u001b[0m )\n", + "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\surrogates\\botorch_surrogates.py:40\u001b[0m, in \u001b[0;36mBotorchSurrogates.fit\u001b[1;34m(self, experiments)\u001b[0m\n\u001b[0;32m 38\u001b[0m \u001b[39mfor\u001b[39;00m model \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msurrogates:\n\u001b[0;32m 39\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(model, TrainableSurrogate):\n\u001b[1;32m---> 40\u001b[0m model\u001b[39m.\u001b[39;49mfit(experiments)\n", + "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\surrogates\\trainable.py:35\u001b[0m, in \u001b[0;36mTrainableSurrogate.fit\u001b[1;34m(self, experiments, options)\u001b[0m\n\u001b[0;32m 33\u001b[0m \u001b[39m# fit\u001b[39;00m\n\u001b[0;32m 34\u001b[0m options \u001b[39m=\u001b[39m options \u001b[39mor\u001b[39;00m {}\n\u001b[1;32m---> 35\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_fit(X\u001b[39m=\u001b[39mX, Y\u001b[39m=\u001b[39mY, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39moptions)\n", + "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\surrogates\\gp_classifier.py:85\u001b[0m, in \u001b[0;36mGPClassifier._fit\u001b[1;34m(self, X, Y)\u001b[0m\n\u001b[0;32m 82\u001b[0m tf \u001b[39m=\u001b[39m ChainedInputTransform(tf1\u001b[39m=\u001b[39mscaler, tf2\u001b[39m=\u001b[39mo2n) \u001b[39mif\u001b[39;00m scaler \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m o2n\n\u001b[0;32m 84\u001b[0m \u001b[39m# fit the model\u001b[39;00m\n\u001b[1;32m---> 85\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel \u001b[39m=\u001b[39m botorch\u001b[39m.\u001b[39;49mmodels\u001b[39m.\u001b[39;49mMultiTaskGP(\n\u001b[0;32m 86\u001b[0m train_X\u001b[39m=\u001b[39;49mo2n\u001b[39m.\u001b[39;49mtransform(tX),\n\u001b[0;32m 87\u001b[0m train_Y\u001b[39m=\u001b[39;49mtY,\n\u001b[0;32m 88\u001b[0m likelihood\u001b[39m=\u001b[39;49mSoftmaxLikelihood,\n\u001b[0;32m 89\u001b[0m covar_module\u001b[39m=\u001b[39;49mpartial(kernels\u001b[39m.\u001b[39;49mmap, data_model\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcontinuous_kernel),\n\u001b[0;32m 90\u001b[0m outcome_transform\u001b[39m=\u001b[39;49mStandardize(m\u001b[39m=\u001b[39;49mtY\u001b[39m.\u001b[39;49mshape[\u001b[39m-\u001b[39;49m\u001b[39m1\u001b[39;49m]),\n\u001b[0;32m 91\u001b[0m input_transform\u001b[39m=\u001b[39;49mtf,\n\u001b[0;32m 92\u001b[0m )\n\u001b[0;32m 93\u001b[0m mll \u001b[39m=\u001b[39m ExactMarginalLogLikelihood(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel\u001b[39m.\u001b[39mlikelihood, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel)\n\u001b[0;32m 94\u001b[0m fit_gpytorch_mll(mll, options\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtraining_specs)\n", + "\u001b[1;31mTypeError\u001b[0m: __init__() missing 1 required positional argument: 'task_feature'" + ] + } + ], "source": [ "from bofire.data_models.acquisition_functions.api import qEI\n", "from bofire.data_models.strategies.api import SoboStrategy\n", - "from bofire.data_models.surrogates.api import BotorchSurrogates, MLPClassifierEnsemble, MixedSingleTaskGPSurrogate\n", + "from bofire.data_models.surrogates.api import BotorchSurrogates, MLPClassifierEnsemble, MixedSingleTaskGPSurrogate, GPClassifier\n", "from bofire.data_models.domain.api import Outputs\n", "\n", + "# strategy_data = SoboStrategy(domain=domain1, \n", + "# acquisition_function=qEI(), \n", + "# surrogate_specs=BotorchSurrogates(surrogates=\n", + "# [\n", + "# MLPClassifierEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_1\")]), lr=1.0, n_epochs=50, hidden_layer_sizes=(20,)),\n", + "# MixedSingleTaskGPSurrogate(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_2\")]))\n", + "# ]\n", + "# )\n", + "# )\n", "strategy_data = SoboStrategy(domain=domain1, \n", " acquisition_function=qEI(), \n", " surrogate_specs=BotorchSurrogates(surrogates=\n", " [\n", - " MLPClassifierEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_1\")]), lr=1.0, n_epochs=50, hidden_layer_sizes=(20,)),\n", + " GPClassifier(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_1\")])),\n", " MixedSingleTaskGPSurrogate(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_2\")]))\n", " ]\n", " )\n", @@ -455,138 +482,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:214: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
x_0x_1x_2x_3x_4x_5f_0_predf_2_predf_1_predf_1_pred_unacceptablef_1_pred_acceptablef_1_pred_idealf_0_sdf_2_sdf_1_sd_unacceptablef_1_sd_acceptablef_1_sd_idealf_0_desf_2_desf_1_des
00.287768-0.1598341.117793-0.4328110.8917731.0-1.2191150.292788unacceptable0.4588880.1136600.4274520.4735080.0031240.1014630.0378330.0983141.2191150.4634670.484282
1-1.7681180.0778092.000000-1.189829-2.0000000.012.989883-1.764363unacceptable0.5366290.1194980.3438730.1778090.0036690.2148280.0363020.215846-12.9898830.7072740.403622
\n", - "
" - ], - "text/plain": [ - " x_0 x_1 x_2 x_3 x_4 x_5 f_0_pred f_2_pred \\\n", - "0 0.287768 -0.159834 1.117793 -0.432811 0.891773 1.0 -1.219115 0.292788 \n", - "1 -1.768118 0.077809 2.000000 -1.189829 -2.000000 0.0 12.989883 -1.764363 \n", - "\n", - " f_1_pred f_1_pred_unacceptable f_1_pred_acceptable f_1_pred_ideal \\\n", - "0 unacceptable 0.458888 0.113660 0.427452 \n", - "1 unacceptable 0.536629 0.119498 0.343873 \n", - "\n", - " f_0_sd f_2_sd f_1_sd_unacceptable f_1_sd_acceptable f_1_sd_ideal \\\n", - "0 0.473508 0.003124 0.101463 0.037833 0.098314 \n", - "1 0.177809 0.003669 0.214828 0.036302 0.215846 \n", - "\n", - " f_0_des f_2_des f_1_des \n", - "0 1.219115 0.463467 0.484282 \n", - "1 -12.989883 0.707274 0.403622 " - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "candidates = strategy.ask(2)\n", "candidates" From b1f51473d9e164e20278e7c995701f571bfff333 Mon Sep 17 00:00:00 2001 From: gmancino Date: Thu, 21 Dec 2023 16:40:47 -0500 Subject: [PATCH 09/31] Refactor classification surrogates --- bofire/data_models/domain/features.py | 35 +- bofire/data_models/features/categorical.py | 79 +- bofire/data_models/features/feature.py | 2 +- bofire/data_models/objectives/api.py | 6 +- bofire/data_models/objectives/categorical.py | 80 +- .../strategies/predictives/sobo.py | 2 +- .../surrogates/botorch_surrogates.py | 2 +- bofire/data_models/surrogates/empirical.py | 16 +- .../data_models/surrogates/fully_bayesian.py | 26 +- bofire/data_models/surrogates/linear.py | 28 +- .../surrogates/mixed_single_task_gp.py | 22 +- bofire/data_models/surrogates/mlp.py | 25 +- .../data_models/surrogates/mlp_classifier.py | 24 +- bofire/data_models/surrogates/polynomial.py | 16 +- .../data_models/surrogates/random_forest.py | 27 +- .../data_models/surrogates/single_task_gp.py | 33 +- bofire/data_models/surrogates/surrogate.py | 17 +- bofire/data_models/surrogates/tanimoto_gp.py | 24 +- bofire/data_models/surrogates/xgb.py | 22 +- bofire/strategies/doe/design.py | 1 - bofire/strategies/predictives/predictive.py | 5 +- bofire/surrogates/mlp.py | 21 +- bofire/surrogates/mlp_classifier.py | 315 +++--- bofire/utils/torch_tools.py | 19 +- .../bofire/data_models/domain/test_outputs.py | 17 +- .../data_models/features/test_categorical.py | 11 - .../data_models/features/test_descriptor.py | 2 +- .../serialization/test_serialization.py | 1 - tests/bofire/data_models/specs/features.py | 9 +- tests/bofire/surrogates/test_mlp.py | 8 +- tests/bofire/surrogates/test_surrogates.py | 17 +- .../Unknown_Constraint_Classification.ipynb | 916 +++++++++++++----- 32 files changed, 1233 insertions(+), 595 deletions(-) diff --git a/bofire/data_models/domain/features.py b/bofire/data_models/domain/features.py index 477a32af0..67a1e8fdc 100644 --- a/bofire/data_models/domain/features.py +++ b/bofire/data_models/domain/features.py @@ -655,7 +655,7 @@ def __call__( and not isinstance(feat, CategoricalOutput) ] + [ - feat.compute_objective(experiments.filter(regex=f"{feat.key}_pred_")) # type: ignore + pd.Series(data=feat(experiments.filter(regex=f"{feat.key}_pred_")), name=f"{feat.key}_pred") # type: ignore if predictions else experiments[feat.key] for feat in self.features @@ -715,6 +715,12 @@ def validate_candidates(self, candidates: pd.DataFrame) -> pd.DataFrame: includes=Objective, excludes=CategoricalObjective ) ] + + [ + [f"{key}_pred", f"{key}_sd"] + for key in self.get_keys_by_objective( + excludes=Objective, includes=None # type: ignore + ) + ] ) ) # check that pred, sd, and des cols are specified and numerical @@ -729,18 +735,23 @@ def validate_candidates(self, candidates: pd.DataFrame) -> pd.DataFrame: raise ValueError(f"Not all values of column `{col}` are numerical.") if candidates[col].isnull().to_numpy().any(): raise ValueError(f"Nan values are present in {col}.") - # Check for categorical output - categorical_cols = [ - (f"{obj.key}_pred", obj.categories) - for obj in self.get_by_objective(includes=CategoricalObjective) - ] - if len(categorical_cols) == 0: - return candidates - for col in categorical_cols: - if col[0] not in candidates: + # # Check for categorical output + # categorical_cols = [ + # (f"{obj.key}_pred", obj.categories) + # for obj in self.get_by_objective(includes=CategoricalObjective) + # ] + # if len(categorical_cols) == 0: + # return candidates + # for col in categorical_cols: + # if col[0] not in candidates: + # raise ValueError(f"missing column {col}") + # if len(candidates[col[0]]) - candidates[col[0]].isin(col[1]).sum() > 0: + # raise ValueError(f"values present are not in {col[1]}") + for feat in self.get(CategoricalOutput): + col = f"{feat.key}_pred" + if col not in candidates: raise ValueError(f"missing column {col}") - if len(candidates[col[0]]) - candidates[col[0]].isin(col[1]).sum() > 0: - raise ValueError(f"values present are not in {col[1]}") + feat.validate_experimental(candidates[col]) return candidates def preprocess_experiments_one_valid_output( diff --git a/bofire/data_models/features/categorical.py b/bofire/data_models/features/categorical.py index ec355e2ef..a8a1cc78e 100644 --- a/bofire/data_models/features/categorical.py +++ b/bofire/data_models/features/categorical.py @@ -1,8 +1,8 @@ -from typing import ClassVar, Dict, List, Literal, Optional, Tuple, Union +from typing import ClassVar, List, Literal, Optional, Tuple, Union import numpy as np import pandas as pd -from pydantic import Field, root_validator, validator +from pydantic import root_validator, validator from bofire.data_models.enum import CategoricalEncodingEnum from bofire.data_models.features.feature import ( @@ -13,7 +13,10 @@ TCategoryVals, TTransform, ) -from bofire.data_models.objectives.categorical import CategoricalObjective +from bofire.data_models.objectives.categorical import ( + CategoricalObjective, + ConstrainedCategoricalObjective, +) class CategoricalInput(Input): @@ -362,11 +365,11 @@ class CategoricalOutput(Output): order_id: ClassVar[int] = 8 categories: TCategoryVals - objective: CategoricalObjective = Field( - default_factory=lambda: CategoricalObjective(w=1.0) - ) + objective: Optional[ + Union[CategoricalObjective, ConstrainedCategoricalObjective] + ] = None - @validator("categories") + @validator("categories", allow_reuse=True) def validate_categories_unique(cls, categories): """validates that categories have unique names @@ -384,21 +387,30 @@ def validate_categories_unique(cls, categories): return tuple(categories) @validator("objective") - def validate_objective(cls, objective, values): - """validates that objective desirabilities are the same length as categories + def validate_objectives_unique(cls, objective, values): + """validates that categories have unique names + + Args: + categories (Union[List[str], Tuple[str]]): List or tuple of category names Raises: - ValueError: when len(objective.desirability) != len(categories) + ValueError: when categories do not match objective categories Returns: - CategoricalObjective + Tuple[str]: Tuple of the categories """ - if len(objective.desirability) != len(values["categories"]): - raise ValueError( - f"{len(objective.desirability)} desirabilities and {len(values['categories'])} categories" - ) + if objective.categories != tuple(values["categories"]): + raise ValueError("categories must match to objective categories") return objective + @classmethod + def from_objective( + cls, + key: str, + objective: Union[CategoricalObjective, ConstrainedCategoricalObjective], + ): + return cls(key=key, objective=objective, categories=objective.categories) + def validate_experimental(self, values: pd.Series) -> pd.Series: values = values.map(str) if sum(values.isin(self.categories)) != len(values): @@ -407,33 +419,14 @@ def validate_experimental(self, values: pd.Series) -> pd.Series: ) return values + def __call__(self, values: pd.Series) -> pd.Series: + if self.objective is None: + return pd.Series( + data=[np.nan for _ in range(len(values))], + index=values.index, + name=values.name, + ) + return self.objective(values) # type: ignore + def __str__(self) -> str: return "CategoricalOutputFeature" - - def to_dict(self) -> Dict: - """Returns the catergories and corresponding objective values as dictionary""" - return dict(zip(self.categories, self.objective.desirability)) - - def to_dict_label(self) -> Dict: - """Returns the catergories and label location of categories""" - return {c: i for i, c in enumerate(self.categories)} - - def from_dict_label(self) -> Dict: - """Returns the label location and the categories""" - d = self.to_dict_label() - return dict(zip(d.values(), d.keys())) - - def map_to_categories(self, values: pd.DataFrame) -> pd.Series: - """Maps the input matrix of probabilities to the categories via argmax""" - return values.idxmax(1).str.replace(f"{self.key}_pred_", "").values - - def compute_objective(self, values: pd.DataFrame) -> pd.Series: - """Computes the objective value as: (p.o).sum() where p is the vector of probabilities and o is the vector of objective values""" - values.columns = values.columns.str.replace(f"{self.key}_pred_", "") - scale_series = pd.Series(self.to_dict()) - return pd.Series( - data=(values * scale_series).sum(1).values, name=f"{self.key}_pred" - ) - - def __call__(self, values: pd.Series) -> pd.Series: - return values.map(self.to_dict()) diff --git a/bofire/data_models/features/feature.py b/bofire/data_models/features/feature.py index 186c7c48c..215cf1948 100644 --- a/bofire/data_models/features/feature.py +++ b/bofire/data_models/features/feature.py @@ -165,7 +165,7 @@ def is_categorical(s: pd.Series, categories: List[str]): TDescriptors = Annotated[List[str], Field(min_items=1)] -TCategoryVals = Annotated[List[str], Field(min_items=2)] +TCategoryVals = Tuple[str, ...] TAllowedVals = Optional[Annotated[List[bool], Field(min_items=2)]] diff --git a/bofire/data_models/objectives/api.py b/bofire/data_models/objectives/api.py index f5c64c76e..0fe3f2e7f 100644 --- a/bofire/data_models/objectives/api.py +++ b/bofire/data_models/objectives/api.py @@ -1,6 +1,9 @@ from typing import Union -from bofire.data_models.objectives.categorical import CategoricalObjective +from bofire.data_models.objectives.categorical import ( + CategoricalObjective, + ConstrainedCategoricalObjective, +) from bofire.data_models.objectives.identity import ( IdentityObjective, MaximizeObjective, @@ -27,6 +30,7 @@ AnyConstraintObjective = Union[ CategoricalObjective, + ConstrainedCategoricalObjective, MaximizeSigmoidObjective, MinimizeSigmoidObjective, TargetObjective, diff --git a/bofire/data_models/objectives/categorical.py b/bofire/data_models/objectives/categorical.py index 9818f08a1..2540376de 100644 --- a/bofire/data_models/objectives/categorical.py +++ b/bofire/data_models/objectives/categorical.py @@ -1,9 +1,11 @@ -from typing import Literal, Tuple, Union +from typing import Dict, Literal, Tuple, Union +from warnings import warn import numpy as np import pandas as pd from pydantic import validator +from bofire.data_models.features.feature import TCategoryVals from bofire.data_models.objectives.objective import ( ConstrainedObjective, Objective, @@ -11,7 +13,37 @@ ) -class CategoricalObjective(Objective, ConstrainedObjective): +class CategoricalObjective(Objective): + """Categorical objective class; stores categories""" + + type: Literal["CategoricalObjective"] = "CategoricalObjective" + categories: TCategoryVals + + @validator("categories") + def validate_categories_unique(cls, categories): + """validates that categories have unique names + + Args: + categories (Union[List[str], Tuple[str]]): List or tuple of category names + + Raises: + ValueError: when categories have non-unique names + + Returns: + Tuple[str]: Tuple of the categories + """ + if len(categories) != len(set(categories)): + raise ValueError("categories must be unique") + return tuple(categories) + + def __call__(self, x: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]: + warn( + "Categorical objective currently does not have a function. Returning the original input." + ) + return x + + +class ConstrainedCategoricalObjective(ConstrainedObjective, CategoricalObjective): """Compute the categorical objective value as: Po where P is an [n, c] matrix where each row is a probability vector @@ -23,11 +55,46 @@ class CategoricalObjective(Objective, ConstrainedObjective): """ w: TWeight = 1.0 + categories: TCategoryVals desirability: Tuple[bool, ...] eta: float = 1.0 - type: Literal["CategoricalObjective"] = "CategoricalObjective" + type: Literal["ConstrainedCategoricalObjective"] = "ConstrainedCategoricalObjective" - def __call__(self, x: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]: + @validator("desirability") + def validate_categories_unique(cls, desirability, values): + """validates that desirabilities match the categories + + Args: + categories (Union[List[str], Tuple[str]]): List or tuple of category names + + Raises: + ValueError: when desirability count is not equal to category count + + Returns: + Tuple[bool]: Tuple of the desirability + """ + if len(desirability) != len(values["categories"]): + raise ValueError( + "number of categories differs from number of desirabilities" + ) + return tuple(desirability) + + def to_dict(self) -> Dict: + """Returns the categories and corresponding objective values as dictionary""" + return dict(zip(self.categories, self.desirability)) + + def to_dict_label(self) -> Dict: + """Returns the catergories and label location of categories""" + return {c: i for i, c in enumerate(self.categories)} + + def from_dict_label(self) -> Dict: + """Returns the label location and the categories""" + d = self.to_dict_label() + return dict(zip(d.values(), d.keys())) + + def __call__( + self, x: Union[pd.Series, np.ndarray] + ) -> Union[pd.Series, np.ndarray, float]: """The call function returning a probabilistic reward for x. Args: @@ -36,7 +103,4 @@ def __call__(self, x: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarr Returns: np.ndarray: A reward calculated as inner product of probabilities and feasible objectives. """ - print( - "Categorical objective currently does not have a function. Returning the original input." - ) - return x + return np.dot(x, np.array(self.desirability)) diff --git a/bofire/data_models/strategies/predictives/sobo.py b/bofire/data_models/strategies/predictives/sobo.py index c160661c8..6b24f92fc 100644 --- a/bofire/data_models/strategies/predictives/sobo.py +++ b/bofire/data_models/strategies/predictives/sobo.py @@ -6,7 +6,7 @@ AnySingleObjectiveAcquisitionFunction, qLogNEI, ) -from bofire.data_models.features.api import CategoricalOutput, Feature +from bofire.data_models.features.api import Feature from bofire.data_models.objectives.api import ConstrainedObjective, Objective from bofire.data_models.strategies.predictives.botorch import BotorchStrategy diff --git a/bofire/data_models/surrogates/botorch_surrogates.py b/bofire/data_models/surrogates/botorch_surrogates.py index 85dde7ece..48a8f9861 100644 --- a/bofire/data_models/surrogates/botorch_surrogates.py +++ b/bofire/data_models/surrogates/botorch_surrogates.py @@ -13,8 +13,8 @@ MixedSingleTaskGPSurrogate, ) from bofire.data_models.surrogates.mlp import MLPEnsemble -from bofire.data_models.surrogates.polynomial import PolynomialSurrogate from bofire.data_models.surrogates.mlp_classifier import MLPClassifierEnsemble +from bofire.data_models.surrogates.polynomial import PolynomialSurrogate from bofire.data_models.surrogates.random_forest import RandomForestSurrogate from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate from bofire.data_models.surrogates.tanimoto_gp import TanimotoGPSurrogate diff --git a/bofire/data_models/surrogates/empirical.py b/bofire/data_models/surrogates/empirical.py index 6a887c697..018f53208 100644 --- a/bofire/data_models/surrogates/empirical.py +++ b/bofire/data_models/surrogates/empirical.py @@ -1,7 +1,21 @@ -from typing import Literal +from typing import Literal, Type +from bofire.data_models.features.api import AnyOutput from bofire.data_models.surrogates.botorch import BotorchSurrogate class EmpiricalSurrogate(BotorchSurrogate): type: Literal["EmpiricalSurrogate"] = "EmpiricalSurrogate" + + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models + + Args: + outputs: objective functions for the surrogate + my_type: continuous or categorical output + + Returns: + bool: True if the output type is valid for the surrogate chosen, False otherwise + """ + return True diff --git a/bofire/data_models/surrogates/fully_bayesian.py b/bofire/data_models/surrogates/fully_bayesian.py index 18288379e..56e132b46 100644 --- a/bofire/data_models/surrogates/fully_bayesian.py +++ b/bofire/data_models/surrogates/fully_bayesian.py @@ -1,11 +1,9 @@ -from typing import Literal +from typing import Literal, Type from pydantic import conint, validator -from bofire.data_models.features.api import ContinuousOutput -from bofire.data_models.surrogates.botorch import BotorchSurrogate -from bofire.data_models.surrogates.scaler import ScalerEnum -from bofire.data_models.surrogates.trainable import TrainableSurrogate +from bofire.data_models.features.api import AnyOutput, ContinuousOutput +from bofire.data_models.surrogates.trainable_botorch import TrainableBotorchSurrogate class SaasSingleTaskGPSurrogate(TrainableBotorchSurrogate): @@ -20,17 +18,15 @@ def validate_thinning(cls, value, values): raise ValueError("`num_samples` has to be larger than `thinning`.") return value - @validator("outputs") - def validate_outputs(cls, outputs): - """validates outputs + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models - Raises: - ValueError: if output type is not ContinuousOutput + Args: + outputs: objective functions for the surrogate + my_type: continuous or categorical output Returns: - List[ContinuousOutput] + bool: True if the output type is valid for the surrogate chosen, False otherwise """ - for o in outputs: - if not isinstance(o, ContinuousOutput): - raise ValueError("all outputs need to be continuous") - return outputs + return True if isinstance(my_type, ContinuousOutput) else False diff --git a/bofire/data_models/surrogates/linear.py b/bofire/data_models/surrogates/linear.py index 9ff42c985..eba916356 100644 --- a/bofire/data_models/surrogates/linear.py +++ b/bofire/data_models/surrogates/linear.py @@ -1,18 +1,16 @@ -from typing import Literal +from typing import Literal, Type -from pydantic import Field, validator +from pydantic import Field # from bofire.data_models.strategies.api import FactorialStrategy -from bofire.data_models.features.api import ContinuousOutput +from bofire.data_models.features.api import AnyOutput, ContinuousOutput from bofire.data_models.kernels.api import LinearKernel from bofire.data_models.priors.api import ( BOTORCH_NOISE_PRIOR, - BOTORCH_SCALE_PRIOR, AnyPrior, ) -from bofire.data_models.surrogates.botorch import BotorchSurrogate from bofire.data_models.surrogates.scaler import ScalerEnum -from bofire.data_models.surrogates.trainable import TrainableSurrogate +from bofire.data_models.surrogates.trainable_botorch import TrainableBotorchSurrogate class LinearSurrogate(TrainableBotorchSurrogate): @@ -22,17 +20,15 @@ class LinearSurrogate(TrainableBotorchSurrogate): noise_prior: AnyPrior = Field(default_factory=lambda: BOTORCH_NOISE_PRIOR()) scaler: ScalerEnum = ScalerEnum.NORMALIZE - @validator("outputs") - def validate_outputs(cls, outputs): - """validates outputs + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models - Raises: - ValueError: if output type is not ContinuousOutput + Args: + outputs: objective functions for the surrogate + my_type: continuous or categorical output Returns: - List[ContinuousOutput] + bool: True if the output type is valid for the surrogate chosen, False otherwise """ - for o in outputs: - if not isinstance(o, ContinuousOutput): - raise ValueError("all outputs need to be continuous") - return outputs + return True if isinstance(my_type, ContinuousOutput) else False diff --git a/bofire/data_models/surrogates/mixed_single_task_gp.py b/bofire/data_models/surrogates/mixed_single_task_gp.py index 23c5ba5d4..df0b36478 100644 --- a/bofire/data_models/surrogates/mixed_single_task_gp.py +++ b/bofire/data_models/surrogates/mixed_single_task_gp.py @@ -1,9 +1,9 @@ -from typing import Literal +from typing import Literal, Type from pydantic import Field, validator from bofire.data_models.enum import CategoricalEncodingEnum -from bofire.data_models.features.api import ContinuousOutput +from bofire.data_models.features.api import AnyOutput, ContinuousOutput from bofire.data_models.kernels.api import ( AnyCategoricalKernal, AnyContinuousKernel, @@ -31,17 +31,15 @@ def validate_categoricals(cls, v, values): ) return v - @validator("outputs") - def validate_outputs(cls, outputs): - """validates outputs + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models - Raises: - ValueError: if output type is not ContinuousOutput + Args: + outputs: objective functions for the surrogate + my_type: continuous or categorical output Returns: - List[ContinuousOutput] + bool: True if the output type is valid for the surrogate chosen, False otherwise """ - for o in outputs: - if not isinstance(o, ContinuousOutput): - raise ValueError("all outputs need to be continuous") - return outputs + return True if isinstance(my_type, ContinuousOutput) else False diff --git a/bofire/data_models/surrogates/mlp.py b/bofire/data_models/surrogates/mlp.py index e023da801..9c6e6d3a2 100644 --- a/bofire/data_models/surrogates/mlp.py +++ b/bofire/data_models/surrogates/mlp.py @@ -1,9 +1,8 @@ -from typing import Annotated, Literal, Sequence +from typing import Annotated, Literal, Sequence, Type -from pydantic import Field, validator +from pydantic import Field -from bofire.data_models.features.api import ContinuousOutput -from bofire.data_models.surrogates.botorch import BotorchSurrogate +from bofire.data_models.features.api import AnyOutput, ContinuousOutput from bofire.data_models.surrogates.scaler import ScalerEnum from bofire.data_models.surrogates.trainable_botorch import TrainableBotorchSurrogate @@ -22,17 +21,15 @@ class MLPEnsemble(TrainableBotorchSurrogate): shuffle: bool = True scaler: ScalerEnum = ScalerEnum.NORMALIZE - @validator("outputs") - def validate_outputs(cls, outputs): - """validates outputs + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models - Raises: - ValueError: if output type is not ContinuousOutput + Args: + outputs: objective functions for the surrogate + my_type: continuous or categorical output Returns: - List[ContinuousOutput] + bool: True if the output type is valid for the surrogate chosen, False otherwise """ - for o in outputs: - if not isinstance(o, ContinuousOutput): - raise ValueError("all outputs need to be continuous") - return outputs + return True if isinstance(my_type, ContinuousOutput) else False diff --git a/bofire/data_models/surrogates/mlp_classifier.py b/bofire/data_models/surrogates/mlp_classifier.py index ea9fc576d..9e99ed6d5 100644 --- a/bofire/data_models/surrogates/mlp_classifier.py +++ b/bofire/data_models/surrogates/mlp_classifier.py @@ -1,8 +1,8 @@ -from typing import Annotated, Literal, Sequence +from typing import Annotated, Literal, Sequence, Type -from pydantic import Field, validator +from pydantic import Field -from bofire.data_models.features.api import CategoricalOutput +from bofire.data_models.features.api import AnyOutput, CategoricalOutput from bofire.data_models.surrogates.botorch import BotorchSurrogate from bofire.data_models.surrogates.scaler import ScalerEnum from bofire.data_models.surrogates.trainable import TrainableSurrogate @@ -22,17 +22,15 @@ class MLPClassifierEnsemble(BotorchSurrogate, TrainableSurrogate): shuffle: bool = True scaler: ScalerEnum = ScalerEnum.NORMALIZE - @validator("outputs") - def validate_outputs(cls, outputs): - """validates outputs + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models - Raises: - ValueError: if output type is not CategoricalOutput + Args: + outputs: objective functions for the surrogate + my_type: continuous or categorical output Returns: - List[CategoricalOutput] + bool: True if the output type is valid for the surrogate chosen, False otherwise """ - for o in outputs: - if not isinstance(o, CategoricalOutput): - raise ValueError("all outputs need to be categorical") - return outputs + return True if isinstance(my_type, CategoricalOutput) else False diff --git a/bofire/data_models/surrogates/polynomial.py b/bofire/data_models/surrogates/polynomial.py index 8f64ea106..2c001d442 100644 --- a/bofire/data_models/surrogates/polynomial.py +++ b/bofire/data_models/surrogates/polynomial.py @@ -1,8 +1,9 @@ -from typing import Literal +from typing import Literal, Type from pydantic import Field from bofire.data_models.domain.api import Inputs, Outputs +from bofire.data_models.features.api import AnyOutput, ContinuousOutput from bofire.data_models.kernels.api import ( PolynomialKernel, ) @@ -23,3 +24,16 @@ def from_power(power: int, inputs: Inputs, outputs: Outputs): return PolynomialSurrogate( kernel=PolynomialKernel(power=power), inputs=inputs, outputs=outputs ) + + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models + + Args: + outputs: objective functions for the surrogate + my_type: continuous or categorical output + + Returns: + bool: True if the output type is valid for the surrogate chosen, False otherwise + """ + return True if isinstance(my_type, ContinuousOutput) else False diff --git a/bofire/data_models/surrogates/random_forest.py b/bofire/data_models/surrogates/random_forest.py index 3c9361e2a..550d947ed 100644 --- a/bofire/data_models/surrogates/random_forest.py +++ b/bofire/data_models/surrogates/random_forest.py @@ -1,11 +1,10 @@ -from typing import Literal, Optional, Union +from typing import Literal, Optional, Type, Union -from pydantic import Field, validator +from pydantic import Field from typing_extensions import Annotated -from bofire.data_models.features.api import ContinuousOutput -from bofire.data_models.surrogates.botorch import BotorchSurrogate -from bofire.data_models.surrogates.trainable import TrainableSurrogate +from bofire.data_models.features.api import AnyOutput, ContinuousOutput +from bofire.data_models.surrogates.trainable_botorch import TrainableBotorchSurrogate class RandomForestSurrogate(TrainableBotorchSurrogate): @@ -32,17 +31,15 @@ class RandomForestSurrogate(TrainableBotorchSurrogate): ccp_alpha: Annotated[float, Field(ge=0)] = 0.0 max_samples: Optional[Union[int, float]] = None - @validator("outputs") - def validate_outputs(cls, outputs): - """validates outputs + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models - Raises: - ValueError: if output type is not ContinuousOutput + Args: + outputs: objective functions for the surrogate + my_type: continuous or categorical output Returns: - List[ContinuousOutput] + bool: True if the output type is valid for the surrogate chosen, False otherwise """ - for o in outputs: - if not isinstance(o, ContinuousOutput): - raise ValueError("all outputs need to be continuous") - return outputs + return True if isinstance(my_type, ContinuousOutput) else False diff --git a/bofire/data_models/surrogates/single_task_gp.py b/bofire/data_models/surrogates/single_task_gp.py index 13dc7cb58..55c47512a 100644 --- a/bofire/data_models/surrogates/single_task_gp.py +++ b/bofire/data_models/surrogates/single_task_gp.py @@ -1,13 +1,17 @@ -from typing import Literal, Optional +from typing import Literal, Optional, Type import pandas as pd -from pydantic import Field, validator +from pydantic import Field from bofire.data_models.domain.api import Inputs from bofire.data_models.enum import RegressionMetricsEnum # from bofire.data_models.strategies.api import FactorialStrategy -from bofire.data_models.features.api import CategoricalInput, ContinuousOutput +from bofire.data_models.features.api import ( + AnyOutput, + CategoricalInput, + ContinuousOutput, +) from bofire.data_models.kernels.api import ( AnyKernel, MaternKernel, @@ -23,9 +27,8 @@ MBO_OUTPUTSCALE_PRIOR, AnyPrior, ) -from bofire.data_models.surrogates.botorch import BotorchSurrogate -from bofire.data_models.surrogates.scaler import ScalerEnum -from bofire.data_models.surrogates.trainable import Hyperconfig, TrainableSurrogate +from bofire.data_models.surrogates.trainable import Hyperconfig +from bofire.data_models.surrogates.trainable_botorch import TrainableBotorchSurrogate class SingleTaskGPHyperconfig(Hyperconfig): @@ -110,17 +113,15 @@ class SingleTaskGPSurrogate(TrainableBotorchSurrogate): default_factory=lambda: SingleTaskGPHyperconfig() ) - @validator("outputs") - def validate_outputs(cls, outputs): - """validates outputs + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models - Raises: - ValueError: if output type is not ContinuousOutput + Args: + outputs: objective functions for the surrogate + my_type: continuous or categorical output Returns: - List[ContinuousOutput] + bool: True if the output type is valid for the surrogate chosen, False otherwise """ - for o in outputs: - if not isinstance(o, ContinuousOutput): - raise ValueError("all outputs need to be continuous") - return outputs + return True if isinstance(my_type, ContinuousOutput) else False diff --git a/bofire/data_models/surrogates/surrogate.py b/bofire/data_models/surrogates/surrogate.py index 2a5f73e7e..553a77cf7 100644 --- a/bofire/data_models/surrogates/surrogate.py +++ b/bofire/data_models/surrogates/surrogate.py @@ -1,10 +1,11 @@ -from typing import Optional, Union +from abc import abstractmethod +from typing import Optional, Type from pydantic import Field, validator from bofire.data_models.base import BaseModel from bofire.data_models.domain.api import Inputs, Outputs -from bofire.data_models.features.api import ContinuousOutput, CategoricalOutput, TInputTransformSpecs +from bofire.data_models.features.api import AnyOutput, TInputTransformSpecs class Surrogate(BaseModel): @@ -28,9 +29,10 @@ def validate_outputs(cls, v, values): if len(v) == 0: raise ValueError("At least one output feature has to be provided.") return v - - @classmethod # TODO: Remove this, change it, ??? - def is_output_implemented(cls, outputs, my_type: Union[ContinuousOutput, CategoricalOutput]) -> bool: + + @classmethod + @abstractmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: """Abstract method to check output type for surrogate models Args: @@ -40,7 +42,4 @@ def is_output_implemented(cls, outputs, my_type: Union[ContinuousOutput, Categor Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - for o in outputs: - if not isinstance(o, my_type): - return False - return True + pass diff --git a/bofire/data_models/surrogates/tanimoto_gp.py b/bofire/data_models/surrogates/tanimoto_gp.py index 9f2cf01fe..92efe8c05 100644 --- a/bofire/data_models/surrogates/tanimoto_gp.py +++ b/bofire/data_models/surrogates/tanimoto_gp.py @@ -1,8 +1,8 @@ -from typing import Literal +from typing import Literal, Type -from pydantic import Field, validator +from pydantic import Field -from bofire.data_models.features.api import ContinuousOutput +from bofire.data_models.features.api import AnyOutput, ContinuousOutput from bofire.data_models.kernels.api import AnyKernel, ScaleKernel from bofire.data_models.kernels.molecular import TanimotoKernel from bofire.data_models.priors.api import ( @@ -28,17 +28,15 @@ class TanimotoGPSurrogate(TrainableBotorchSurrogate): noise_prior: AnyPrior = Field(default_factory=lambda: BOTORCH_NOISE_PRIOR()) scaler: ScalerEnum = ScalerEnum.IDENTITY - @validator("outputs") - def validate_outputs(cls, outputs): - """validates outputs + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models - Raises: - ValueError: if output type is not ContinuousOutput + Args: + outputs: objective functions for the surrogate + my_type: continuous or categorical output Returns: - List[ContinuousOutput] + bool: True if the output type is valid for the surrogate chosen, False otherwise """ - for o in outputs: - if not isinstance(o, ContinuousOutput): - raise ValueError("all outputs need to be continuous") - return outputs + return True if isinstance(my_type, ContinuousOutput) else False diff --git a/bofire/data_models/surrogates/xgb.py b/bofire/data_models/surrogates/xgb.py index c77ee8faa..02b2f68a1 100644 --- a/bofire/data_models/surrogates/xgb.py +++ b/bofire/data_models/surrogates/xgb.py @@ -1,10 +1,11 @@ -from typing import Literal, Optional +from typing import Literal, Optional, Type from pydantic import Field, validator from typing_extensions import Annotated from bofire.data_models.enum import CategoricalEncodingEnum from bofire.data_models.features.api import ( + AnyOutput, CategoricalDescriptorInput, CategoricalInput, ContinuousOutput, @@ -77,18 +78,15 @@ def validate_input_preprocessing_specs(cls, v, values): raise ValueError("Currently no numeric transforms are supported.") return v - - @validator("outputs") - def validate_outputs(cls, outputs): - """validates outputs + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models - Raises: - ValueError: if output type is not ContinuousOutput + Args: + outputs: objective functions for the surrogate + my_type: continuous or categorical output Returns: - List[ContinuousOutput] + bool: True if the output type is valid for the surrogate chosen, False otherwise """ - for o in outputs: - if not isinstance(o, ContinuousOutput): - raise ValueError("all outputs need to be continuous") - return outputs + return True if isinstance(my_type, ContinuousOutput) else False diff --git a/bofire/strategies/doe/design.py b/bofire/strategies/doe/design.py index 42ab9510d..e84b7f5b4 100644 --- a/bofire/strategies/doe/design.py +++ b/bofire/strategies/doe/design.py @@ -632,7 +632,6 @@ def check_partially_fixed_experiments( n_experiments: int, partially_fixed_experiments: pd.DataFrame, ) -> None: - n_partially_fixed_experiments = len(partially_fixed_experiments.index) # for partially fixed experiments only check if all inputs are part of the domain diff --git a/bofire/strategies/predictives/predictive.py b/bofire/strategies/predictives/predictive.py index 9e35391e7..154a1d35d 100644 --- a/bofire/strategies/predictives/predictive.py +++ b/bofire/strategies/predictives/predictive.py @@ -136,7 +136,10 @@ def predict(self, experiments: pd.DataFrame) -> pd.DataFrame: categorical_preds = { f"{feat.key}_pred": ( ind, - feat.map_to_categories(predictions.filter(regex=f"{feat.key}_pred_")), + predictions.filter(regex=f"{feat.key}_pred_") + .idxmax(1) + .str.replace(f"{feat.key}_pred_", "") + .values, ) for ind, feat in enumerate(self.domain.outputs.get()) if isinstance(feat, CategoricalOutput) diff --git a/bofire/surrogates/mlp.py b/bofire/surrogates/mlp.py index 28b3eb87f..2d4614aa5 100644 --- a/bofire/surrogates/mlp.py +++ b/bofire/surrogates/mlp.py @@ -18,7 +18,7 @@ from bofire.utils.torch_tools import tkwargs -class RegressionDataSet(Dataset): +class MLPDataset(Dataset): """ Prepare the dataset for regression """ @@ -114,18 +114,19 @@ def num_outputs(self) -> int: def fit_mlp( mlp: MLP, - dataset: RegressionDataSet, + dataset: MLPDataset, batch_size: int = 10, n_epoches: int = 200, lr: float = 1e-4, shuffle: bool = True, weight_decay: float = 0.0, + loss_function: nn.Module = nn.L1Loss, ): - """Fit a MLP for regression to a dataset. + """Fit a MLP to a dataset. Args: mlp (MLP): The MLP that should be fitted. - dataset (RegressionDataSet): The data that should be fitted + dataset (MLPDataset): The data that should be fitted batch_size (int, optional): Batch size. Defaults to 10. n_epoches (int, optional): Number of training epoches. Defaults to 200. lr (float, optional): Initial learning rate. Defaults to 1e-4. @@ -134,15 +135,21 @@ def fit_mlp( """ mlp.train() train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle) - loss_function = nn.L1Loss() optimizer = torch.optim.Adam(mlp.parameters(), lr=lr, weight_decay=weight_decay) + loss_function = loss_function() for _ in range(n_epoches): current_loss = 0.0 for data in train_loader: # Get and prepare inputs inputs, targets = data - if len(targets.shape) == 1: + if isinstance(loss_function, nn.CrossEntropyLoss): + targets = targets.flatten().long() + elif len(targets.shape) == 1 and not isinstance( + loss_function, nn.CrossEntropyLoss + ): targets = targets.reshape((targets.shape[0], 1)) + else: + pass # Zero the gradients optimizer.zero_grad() @@ -199,7 +206,7 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): tX = torch.from_numpy(transformed_X.values[sample_idx]).to(**tkwargs) ty = torch.from_numpy(Y.values[sample_idx]).to(**tkwargs) - dataset = RegressionDataSet( + dataset = MLPDataset( X=scaler.transform(tX) if scaler is not None else tX, y=output_scaler(ty)[0] if output_scaler is not None else ty, ) diff --git a/bofire/surrogates/mlp_classifier.py b/bofire/surrogates/mlp_classifier.py index a7789c04f..dcbb12230 100644 --- a/bofire/surrogates/mlp_classifier.py +++ b/bofire/surrogates/mlp_classifier.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional, Sequence +from typing import Optional, Sequence import numpy as np import pandas as pd @@ -6,75 +6,18 @@ import torch.nn as nn from botorch.models.ensemble import EnsembleModel from torch import Tensor -from torch.utils.data import DataLoader, Dataset from bofire.data_models.enum import OutputFilteringEnum from bofire.data_models.surrogates.api import MLPEnsemble as DataModel from bofire.surrogates.botorch import BotorchSurrogate +from bofire.surrogates.mlp import MLP, MLPDataset, fit_mlp from bofire.surrogates.single_task_gp import get_scaler from bofire.surrogates.trainable import TrainableSurrogate from bofire.utils.torch_tools import tkwargs -class ClassificationDataSet(Dataset): - """ - Prepare the dataset for classification - """ - - def __init__(self, X: Tensor, y: Tensor): - self.X = X.to(**tkwargs) - self.y = y.to(**tkwargs) - - def __len__(self): - return len(self.X) - - def __getitem__(self, i: int): - return self.X[i], self.y[i] - - -class MLPClassifier(nn.Module): - def __init__( - self, - input_size: int, - output_size: int = 1, - hidden_layer_sizes: Sequence = (100,), - dropout: float = 0.0, - activation: Literal["relu", "logistic", "tanh"] = "relu", - ): - super().__init__() - if activation == "relu": - f_activation = nn.ReLU - elif activation == "logistic": - f_activation = nn.Sigmoid - elif activation == "tanh": - f_activation = nn.Tanh - else: - raise ValueError(f"Activation {activation} not known.") - layers = [ - nn.Linear(input_size, hidden_layer_sizes[0]).to(**tkwargs), - f_activation(), - ] - if dropout > 0.0: - layers.append(nn.Dropout(dropout)) - if len(hidden_layer_sizes) > 1: - for i in range(len(hidden_layer_sizes) - 1): - layers += [ - nn.Linear(hidden_layer_sizes[i], hidden_layer_sizes[i + 1]).to( - **tkwargs - ), - f_activation(), - ] - if dropout > 0.0: - layers.append(nn.Dropout(dropout)) - layers.append(nn.Linear(hidden_layer_sizes[-1], output_size).to(**tkwargs)) - self.layers = nn.Sequential(*layers) - - def forward(self, x): - return nn.functional.log_softmax(self.layers(x), dim=1) - - class _MLPClassifierEnsemble(EnsembleModel): - def __init__(self, mlps: Sequence[MLPClassifier]): + def __init__(self, mlps: Sequence[MLP]): super().__init__() if len(mlps) == 0: raise ValueError("List of mlps is empty.") @@ -89,7 +32,7 @@ def __init__(self, mlps: Sequence[MLPClassifier]): mlp.eval() def forward(self, X: Tensor): - r"""Compute the model output at X. + r"""Assumes that the OUTPUT of the MLPs are the logits and hence we take the softmax over the last dimension here Args: X: A `batch_shape x n x d`-dim input tensor `X`. @@ -98,7 +41,9 @@ def forward(self, X: Tensor): A `batch_shape x s x n x C`-dimensional output tensor where `s` is the size of the ensemble and `C` is the number of classes. """ - return torch.stack([mlp(X).exp() for mlp in self.mlps], dim=-3) + return torch.stack( + [nn.functional.softmax(mlp(X), dim=-1) for mlp in self.mlps], dim=-3 + ) @property def num_outputs(self) -> int: @@ -106,55 +51,6 @@ def num_outputs(self) -> int: return self.mlps[0].layers[-1].out_features # type: ignore -def fit_mlp( - mlp: MLPClassifier, - dataset: ClassificationDataSet, - batch_size: int = 10, - n_epoches: int = 200, - lr: float = 1e-3, - shuffle: bool = True, - weight_decay: float = 0.0, -): - """Fit a MLP for classification to a dataset. - - Args: - mlp (MLP): The MLP that should be fitted. - dataset (ClassificationDataSet): The data that should be fitted - batch_size (int, optional): Batch size. Defaults to 10. - n_epoches (int, optional): Number of training epoches. Defaults to 200. - lr (float, optional): Initial learning rate. Defaults to 1e-4. - shuffle (bool, optional): Whereas the batches should be shuffled. Defaults to True. - weight_decay (float, optional): Weight decay (L2 regularization). Defaults to 0.0 (no regularization). - """ - mlp.train() - train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle) - loss_function = nn.NLLLoss(reduction="mean") - optimizer = torch.optim.Adam(mlp.parameters(), lr=lr, weight_decay=weight_decay) - for _ in range(n_epoches): - current_loss = 0.0 - for data in train_loader: - # Get and prepare inputs - inputs, targets = data - - # Zero the gradients - optimizer.zero_grad() - - # Perform forward pass - outputs = mlp(inputs) - - # Compute loss - loss = loss_function(outputs, targets.flatten().long()) - - # Perform backward pass - loss.backward() - - # Perform optimization - optimizer.step() - - # Print statistics - current_loss += loss.item() - - class MLPClassifierEnsemble(BotorchSurrogate, TrainableSurrogate): def __init__(self, data_model: DataModel, **kwargs): self.n_estimators = data_model.n_estimators @@ -176,8 +72,8 @@ def __init__(self, data_model: DataModel, **kwargs): def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): scaler = get_scaler(self.inputs, self.input_preprocessing_specs, self.scaler, X) transformed_X = self.inputs.transform(X, self.input_preprocessing_specs) - # Map dictionary to objective values - gives what is feasible and to labels - to perform opt - label_mapping = self.outputs[0].to_dict_label() + # Map dictionary of objective values to labels + label_mapping = self.outputs[0].objective.to_dict_label() # Convert Y to classification tensor Y = pd.DataFrame.from_dict( @@ -192,11 +88,11 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): tX = torch.from_numpy(transformed_X.values[sample_idx]).to(**tkwargs) ty = torch.from_numpy(Y.values[sample_idx]).to(**tkwargs) - dataset = ClassificationDataSet( + dataset = MLPDataset( X=scaler.transform(tX) if scaler is not None else tX, y=ty, ) - mlp = MLPClassifier( + mlp = MLP( input_size=transformed_X.shape[1], output_size=len( label_mapping @@ -213,8 +109,197 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): lr=self.lr, shuffle=self.shuffle, weight_decay=self.weight_decay, + loss_function=nn.CrossEntropyLoss, # utilizes logits as input ) mlps.append(mlp) self.model = _MLPClassifierEnsemble(mlps=mlps) if scaler is not None: self.model.input_transform = scaler + + +# class MLPClassifier(nn.Module): +# def __init__( +# self, +# input_size: int, +# output_size: int = 1, +# hidden_layer_sizes: Sequence = (100,), +# dropout: float = 0.0, +# activation: Literal["relu", "logistic", "tanh"] = "relu", +# ): +# super().__init__() +# if activation == "relu": +# f_activation = nn.ReLU +# elif activation == "logistic": +# f_activation = nn.Sigmoid +# elif activation == "tanh": +# f_activation = nn.Tanh +# else: +# raise ValueError(f"Activation {activation} not known.") +# layers = [ +# nn.Linear(input_size, hidden_layer_sizes[0]).to(**tkwargs), +# f_activation(), +# ] +# if dropout > 0.0: +# layers.append(nn.Dropout(dropout)) +# if len(hidden_layer_sizes) > 1: +# for i in range(len(hidden_layer_sizes) - 1): +# layers += [ +# nn.Linear(hidden_layer_sizes[i], hidden_layer_sizes[i + 1]).to( +# **tkwargs +# ), +# f_activation(), +# ] +# if dropout > 0.0: +# layers.append(nn.Dropout(dropout)) +# layers.append(nn.Linear(hidden_layer_sizes[-1], output_size).to(**tkwargs)) +# self.layers = nn.Sequential(*layers) + +# def forward(self, x): +# return nn.functional.softmax(self.layers(x), dim=1) + + +# class _MLPClassifierEnsemble(EnsembleModel): +# def __init__(self, mlps: Sequence[MLPClassifier]): +# super().__init__() +# if len(mlps) == 0: +# raise ValueError("List of mlps is empty.") +# num_in_features = mlps[0].layers[0].in_features +# num_out_features = mlps[0].layers[-1].out_features +# for mlp in mlps: +# assert mlp.layers[0].in_features == num_in_features +# assert mlp.layers[-1].out_features == num_out_features +# self.mlps = mlps +# # put all models in eval mode +# for mlp in self.mlps: +# mlp.eval() + +# def forward(self, X: Tensor): +# r"""Compute the model output at X. + +# Args: +# X: A `batch_shape x n x d`-dim input tensor `X`. + +# Returns: +# A `batch_shape x s x n x C`-dimensional output tensor where +# `s` is the size of the ensemble and `C` is the number of classes. +# """ +# return torch.stack([mlp(X).exp() for mlp in self.mlps], dim=-3) + +# @property +# def num_outputs(self) -> int: +# r"""The number of outputs of the model.""" +# return self.mlps[0].layers[-1].out_features # type: ignore + + +# def fit_mlp( +# mlp: MLPClassifier, +# dataset: ClassificationDataSet, +# batch_size: int = 10, +# n_epoches: int = 200, +# lr: float = 1e-3, +# shuffle: bool = True, +# weight_decay: float = 0.0, +# ): +# """Fit a MLP for classification to a dataset. + +# Args: +# mlp (MLP): The MLP that should be fitted. +# dataset (ClassificationDataSet): The data that should be fitted +# batch_size (int, optional): Batch size. Defaults to 10. +# n_epoches (int, optional): Number of training epoches. Defaults to 200. +# lr (float, optional): Initial learning rate. Defaults to 1e-4. +# shuffle (bool, optional): Whereas the batches should be shuffled. Defaults to True. +# weight_decay (float, optional): Weight decay (L2 regularization). Defaults to 0.0 (no regularization). +# """ +# mlp.train() +# train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle) +# loss_function = nn.NLLLoss(reduction="mean") +# optimizer = torch.optim.Adam(mlp.parameters(), lr=lr, weight_decay=weight_decay) +# for _ in range(n_epoches): +# current_loss = 0.0 +# for data in train_loader: +# # Get and prepare inputs +# inputs, targets = data + +# # Zero the gradients +# optimizer.zero_grad() + +# # Perform forward pass +# outputs = mlp(inputs) + +# # Compute loss +# loss = loss_function(outputs, targets.flatten().long()) + +# # Perform backward pass +# loss.backward() + +# # Perform optimization +# optimizer.step() + +# # Print statistics +# current_loss += loss.item() + + +# class MLPClassifierEnsemble(BotorchSurrogate, TrainableSurrogate): +# def __init__(self, data_model: DataModel, **kwargs): +# self.n_estimators = data_model.n_estimators +# self.hidden_layer_sizes = data_model.hidden_layer_sizes +# self.activation = data_model.activation +# self.dropout = data_model.dropout +# self.batch_size = data_model.batch_size +# self.n_epochs = data_model.n_epochs +# self.lr = data_model.lr +# self.weight_decay = data_model.weight_decay +# self.subsample_fraction = data_model.subsample_fraction +# self.shuffle = data_model.shuffle +# self.scaler = data_model.scaler +# super().__init__(data_model, **kwargs) + +# _output_filtering: OutputFilteringEnum = OutputFilteringEnum.ALL +# model: Optional[_MLPClassifierEnsemble] = None + +# def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): +# scaler = get_scaler(self.inputs, self.input_preprocessing_specs, self.scaler, X) +# transformed_X = self.inputs.transform(X, self.input_preprocessing_specs) +# # Map dictionary to objective values - gives what is feasible and to labels - to perform opt +# label_mapping = self.outputs[0].to_dict_label() + +# # Convert Y to classification tensor +# Y = pd.DataFrame.from_dict( +# {col: Y[col].map(label_mapping) for col in Y.columns} +# ) + +# mlps = [] +# subsample_size = round(self.subsample_fraction * X.shape[0]) +# for _ in range(self.n_estimators): +# # resample X and Y +# sample_idx = np.random.choice(X.shape[0], replace=True, size=subsample_size) +# tX = torch.from_numpy(transformed_X.values[sample_idx]).to(**tkwargs) +# ty = torch.from_numpy(Y.values[sample_idx]).to(**tkwargs) + +# dataset = ClassificationDataSet( +# X=scaler.transform(tX) if scaler is not None else tX, +# y=ty, +# ) +# mlp = MLPClassifier( +# input_size=transformed_X.shape[1], +# output_size=len( +# label_mapping +# ), # Set outputs based on number of categories +# hidden_layer_sizes=self.hidden_layer_sizes, +# activation=self.activation, # type: ignore +# dropout=self.dropout, +# ) +# fit_mlp( +# mlp=mlp, +# dataset=dataset, +# batch_size=self.batch_size, +# n_epoches=self.n_epochs, +# lr=self.lr, +# shuffle=self.shuffle, +# weight_decay=self.weight_decay, +# ) +# mlps.append(mlp) +# self.model = _MLPClassifierEnsemble(mlps=mlps) +# if scaler is not None: +# self.model.input_transform = scaler diff --git a/bofire/utils/torch_tools.py b/bofire/utils/torch_tools.py index c00dbc975..682bf7995 100644 --- a/bofire/utils/torch_tools.py +++ b/bofire/utils/torch_tools.py @@ -14,8 +14,8 @@ ) from bofire.data_models.features.api import ContinuousInput, Input from bofire.data_models.objectives.api import ( - CategoricalObjective, CloseToTargetObjective, + ConstrainedCategoricalObjective, ConstrainedObjective, MaximizeObjective, MaximizeSigmoidObjective, @@ -178,9 +178,7 @@ def min_constraint(indices: Tensor, num_features: int, min_count: int): def constrained_objective2botorch( - idx: int, - objective: ConstrainedObjective, - eps: float = 1e-6 + idx: int, objective: ConstrainedObjective, eps: float = 1e-6 ) -> Tuple[List[Callable[[Tensor], Tensor]], List[float], int]: """Create a callable that can be used by `botorch.utils.objective.apply_constraints` to setup ouput constrained optimizations. @@ -220,13 +218,22 @@ def constrained_objective2botorch( [1.0 / objective.steepness, 1.0 / objective.steepness], idx + 1, ) - elif isinstance(objective, CategoricalObjective): + elif isinstance(objective, ConstrainedCategoricalObjective): # The output of a categorical objective has final dim `c` where `c` is number of classes # Pass in the expected acceptance probability and perform an inverse sigmoid to atain the original probabilities return ( [ lambda Z: torch.log( - torch.clamp(1 / (Z[..., idx : idx + len(objective.desirability)] * torch.tensor(objective.desirability).to(**tkwargs)).sum(-1) - 1, min=eps, max=1-eps) + torch.clamp( + 1 + / ( + Z[..., idx : idx + len(objective.desirability)] + * torch.tensor(objective.desirability).to(**tkwargs) + ).sum(-1) + - 1, + min=eps, + max=1 - eps, + ) ) ], [objective.eta], diff --git a/tests/bofire/data_models/domain/test_outputs.py b/tests/bofire/data_models/domain/test_outputs.py index 9d5240891..6f1b2456b 100644 --- a/tests/bofire/data_models/domain/test_outputs.py +++ b/tests/bofire/data_models/domain/test_outputs.py @@ -8,6 +8,7 @@ from bofire.data_models.domain.api import Outputs from bofire.data_models.features.api import CategoricalOutput, ContinuousOutput from bofire.data_models.objectives.api import ( + ConstrainedCategoricalObjective, ConstrainedObjective, MaximizeObjective, MaximizeSigmoidObjective, @@ -254,7 +255,11 @@ def test_get_outputs_by_objective_none(): of2, of3, CategoricalOutput( - key="of4", categories=["a", "b"], objective=[1.0, 0.0] + key="of4", + categories=["a", "b"], + objective=ConstrainedCategoricalObjective( + categories=("a", "b"), desirability=(True, False) + ), ), ] ), @@ -266,11 +271,17 @@ def test_outputs_call(features, samples): o = features(samples) assert o.shape == ( len(samples), - len(features.get_keys_by_objective(Objective)) + len( + features.get_keys_by_objective( + Objective, excludes=ConstrainedCategoricalObjective + ) + ) + len(features.get_keys(CategoricalOutput)), ) assert list(o.columns) == [ f"{key}_des" - for key in features.get_keys_by_objective(Objective) + for key in features.get_keys_by_objective( + Objective, excludes=ConstrainedCategoricalObjective + ) + features.get_keys(CategoricalOutput) ] diff --git a/tests/bofire/data_models/features/test_categorical.py b/tests/bofire/data_models/features/test_categorical.py index c486fc2a4..c3e91de2e 100644 --- a/tests/bofire/data_models/features/test_categorical.py +++ b/tests/bofire/data_models/features/test_categorical.py @@ -10,7 +10,6 @@ from bofire.data_models.features.api import ( CategoricalDescriptorInput, CategoricalInput, - CategoricalOutput, ) @@ -461,13 +460,3 @@ def test_categorical_input_feature_allowed_categories(input_feature, expected): ) def test_categorical_input_feature_forbidden_categories(input_feature, expected): assert input_feature.get_forbidden_categories() == expected - - -def test_categorical_output(): - feature = CategoricalOutput( - key="a", categories=["alpha", "beta", "gamma"], objective=[1.0, 0.0, 0.1] - ) - - assert feature.to_dict() == {"alpha": 1.0, "beta": 0.0, "gamma": 0.1} - data = pd.Series(data=["alpha", "beta", "beta", "gamma"], name="a") - assert_series_equal(feature(data), pd.Series(data=[1.0, 0.0, 0.0, 0.1], name="a")) diff --git a/tests/bofire/data_models/features/test_descriptor.py b/tests/bofire/data_models/features/test_descriptor.py index f9007116d..90b791f05 100644 --- a/tests/bofire/data_models/features/test_descriptor.py +++ b/tests/bofire/data_models/features/test_descriptor.py @@ -378,6 +378,6 @@ def test_categorical_descriptor_input_feature_from_dataframe( columns=descriptors, ) f = CategoricalDescriptorInput.from_df("k", df) - assert f.categories == categories + assert f.categories == tuple(categories) assert f.descriptors == descriptors assert f.values == values diff --git a/tests/bofire/data_models/serialization/test_serialization.py b/tests/bofire/data_models/serialization/test_serialization.py index 2dcc0ebaa..9f375f66c 100644 --- a/tests/bofire/data_models/serialization/test_serialization.py +++ b/tests/bofire/data_models/serialization/test_serialization.py @@ -28,7 +28,6 @@ def test_objective_should_be_serializable(objective_spec: Spec): def test_feature_should_be_serializable(feature_spec: Spec): spec = feature_spec.typed_spec() obj = feature_spec.cls(**spec) - print(spec) assert obj.dict() == spec diff --git a/tests/bofire/data_models/specs/features.py b/tests/bofire/data_models/specs/features.py index 5fe6c6067..163f61f50 100644 --- a/tests/bofire/data_models/specs/features.py +++ b/tests/bofire/data_models/specs/features.py @@ -3,7 +3,10 @@ import uuid import bofire.data_models.features.api as features -from bofire.data_models.objectives.api import CategoricalObjective, MaximizeObjective +from bofire.data_models.objectives.api import ( + ConstrainedCategoricalObjective, + MaximizeObjective, +) from tests.bofire.data_models.specs.objectives import specs as objectives from tests.bofire.data_models.specs.specs import Specs @@ -88,7 +91,9 @@ lambda: { "key": str(uuid.uuid4()), "categories": ("a", "b", "c"), - "objective": CategoricalObjective(desirability=(0.0, 1.0, 0.0)), + "objective": ConstrainedCategoricalObjective( + categories=("a", "b", "c"), desirability=(True, True, False) + ), }, ) specs.add_valid( diff --git a/tests/bofire/surrogates/test_mlp.py b/tests/bofire/surrogates/test_mlp.py index 8acde7c63..b447d6ada 100644 --- a/tests/bofire/surrogates/test_mlp.py +++ b/tests/bofire/surrogates/test_mlp.py @@ -14,7 +14,7 @@ ContinuousOutput, ) from bofire.data_models.surrogates.api import MLPEnsemble, ScalerEnum -from bofire.surrogates.mlp import MLP, RegressionDataSet, _MLPEnsemble, fit_mlp +from bofire.surrogates.mlp import MLP, MLPDataset, _MLPEnsemble, fit_mlp from bofire.utils.torch_tools import tkwargs @@ -78,10 +78,10 @@ def test_mlp_dropout(): assert mlp.layers[9].out_features == 1 -def test_RegressionDataSet(): +def test_DataSet(): X = torch.randn(10, 3) y = torch.randn(10, 1) - dset = RegressionDataSet(X=X, y=y) + dset = MLPDataset(X=X, y=y) assert len(dset) == 10 xi, yi = dset[3] assert torch.allclose(xi, X[3].to(**tkwargs)) @@ -114,7 +114,7 @@ def test_fit_mlp(mlp, weight_decay, n_epoches, lr, shuffle): X, y = torch.from_numpy(experiments[["x_1", "x_2"]].values), torch.from_numpy( experiments[["y"]].values ) - dset = RegressionDataSet(X=X, y=y) + dset = MLPDataset(X=X, y=y) fit_mlp( mlp=mlp, dataset=dset, diff --git a/tests/bofire/surrogates/test_surrogates.py b/tests/bofire/surrogates/test_surrogates.py index 30e360fa3..5e8554559 100644 --- a/tests/bofire/surrogates/test_surrogates.py +++ b/tests/bofire/surrogates/test_surrogates.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, Type import numpy as np import pandas as pd @@ -6,7 +6,7 @@ from pydantic.error_wrappers import ValidationError from bofire.data_models.domain.api import Inputs, Outputs -from bofire.data_models.features.api import ContinuousInput, ContinuousOutput +from bofire.data_models.features.api import AnyOutput, ContinuousInput, ContinuousOutput from bofire.data_models.surrogates.api import Surrogate as SurrogateDataModel from bofire.surrogates.api import PredictedValue, Surrogate @@ -14,6 +14,19 @@ class DummyDataModel(SurrogateDataModel): type: Literal["Dummy"] = "Dummy" + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models + + Args: + outputs: objective functions for the surrogate + my_type: continuous or categorical output + + Returns: + bool: True if the output type is valid for the surrogate chosen, False otherwise + """ + return True + class Dummy(Surrogate): def __init__( diff --git a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb index 4b0c12227..947549c37 100644 --- a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb +++ b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb @@ -8,7 +8,7 @@ "\n", "We are interested in testing whether or not a surrogate model can correctly identify unknown constraints based on categorical criteria with classification surrogates. Essentially, we want to account for scenarios where specialists can look at a set of experiments and label outcomes as 'acceptable', 'unacceptable', 'ideal', etc. \n", "\n", - "This involves new models that produce `CategoricalOutput`'s rather than continuous outputs. Mathematically, if $g_{\\theta}:\\mathbb{R}^d\\to[0,1]^c$ represents the function governed by learnable parameters $\\theta$ which outputs a probability vector over $c$ potential classes (i.e. for input $x\\in\\mathbb{R}^d$, $g_{\\theta}(x)^\\top\\mathbf{1}=1$ where $\\mathbf{1}$ is the vector of all 1's) and we have acceptibility criteria for the corresponding classes given by $a\\in[0,1]^c$, we can compute a scalar output as $g_{\\theta}(x)^\\top a\\in[0,1]$ as an objective value to be passed in as a constrained function.\n", + "This involves new models that produce `CategoricalOutput`'s rather than continuous outputs. Mathematically, if $g_{\\theta}:\\mathbb{R}^d\\to[0,1]^c$ represents the function governed by learnable parameters $\\theta$ which outputs a probability vector over $c$ potential classes (i.e. for input $x\\in\\mathbb{R}^d$, $g_{\\theta}(x)^\\top\\mathbf{1}=1$ where $\\mathbf{1}$ is the vector of all 1's) and we have acceptibility criteria for the corresponding classes given by $a\\in\\{0,1\\}^c$, we can compute the scalar output $g_{\\theta}(x)^\\top a\\in[0,1]$ which represents the expected value of acceptance as an objective value to be passed in as a constrained function.\n", "\n", "In this script, we look at a modified and constrained version of the optimization problem associated with the [Levy function](https://www.sfu.ca/~ssurjano/levy.html), which has a global minima at $x^*=\\mathbf{1}$. We classify constraints for three classes: 'acceptable', 'unacceptable', and 'ideal' based on how close we are to the optimal decision variable; obviously, this value is unknown in a real-world setting, but this serves as a reasonable example." ] @@ -32,7 +32,7 @@ "import bofire.strategies.api as strategies\n", "from bofire.data_models.api import Domain, Outputs, Inputs\n", "from bofire.data_models.features.api import ContinuousInput, ContinuousOutput, CategoricalOutput, CategoricalInput\n", - "from bofire.data_models.objectives.api import MinimizeObjective, MinimizeSigmoidObjective, CategoricalObjective\n", + "from bofire.data_models.objectives.api import MinimizeObjective, MinimizeSigmoidObjective, ConstrainedCategoricalObjective\n", "import numpy as np\n", "import pandas as pd" ] @@ -98,243 +98,243 @@ " \n", " \n", " 0\n", - " 0.435454\n", - " -1.788437\n", - " -1.988452\n", - " -0.421175\n", - " -1.708184\n", + " 0.043387\n", + " 0.471935\n", + " -1.903405\n", + " -1.505071\n", + " -1.631395\n", " 0.0\n", - " 15.230583\n", + " 12.885008\n", " unacceptable\n", - " 0.437269\n", + " 0.049104\n", " \n", " \n", " 1\n", - " -0.839497\n", - " -1.317481\n", - " -0.621791\n", - " 1.029460\n", - " -0.029063\n", + " -0.516480\n", + " -0.519954\n", + " 0.172009\n", + " -0.671419\n", + " 0.576339\n", " 1.0\n", - " 4.092890\n", + " 1.743262\n", " unacceptable\n", - " -0.829668\n", + " -0.511997\n", " \n", " \n", " 2\n", - " 0.457487\n", - " -1.632370\n", - " -0.219897\n", - " -0.589791\n", - " -0.543485\n", + " 1.420828\n", + " -0.997470\n", + " -0.869831\n", + " -0.603185\n", + " 0.988815\n", " 0.0\n", - " 4.653823\n", + " 2.214841\n", " unacceptable\n", - " 0.466307\n", + " 1.426101\n", " \n", " \n", " 3\n", - " -0.545863\n", - " 1.068268\n", - " -0.840059\n", - " -0.156614\n", - " 0.532427\n", - " 0.0\n", - " 1.944284\n", + " -1.729386\n", + " -0.517152\n", + " 0.792940\n", + " 1.841196\n", + " -1.018597\n", + " 1.0\n", + " 6.757267\n", " unacceptable\n", - " -0.544115\n", + " -1.723294\n", " \n", " \n", " 4\n", - " 1.294465\n", - " -1.086366\n", - " -1.444690\n", - " -1.005573\n", - " -0.743576\n", - " 1.0\n", - " 5.516022\n", - " unacceptable\n", - " 1.303366\n", + " 0.137756\n", + " -1.910416\n", + " 0.266167\n", + " 1.376514\n", + " 1.853052\n", + " 0.0\n", + " 6.659112\n", + " acceptable\n", + " 0.146894\n", " \n", " \n", " 5\n", - " -0.826168\n", - " -1.615002\n", - " 1.624282\n", - " 0.041478\n", - " 0.413251\n", - " 1.0\n", - " 5.652505\n", + " 0.615494\n", + " 0.638574\n", + " -0.846471\n", + " 0.486294\n", + " -1.844437\n", + " 0.0\n", + " 5.965467\n", " unacceptable\n", - " -0.819800\n", + " 0.623373\n", " \n", " \n", " 6\n", - " 0.213634\n", - " -0.469763\n", - " 0.950611\n", - " 1.945084\n", - " -0.783820\n", + " 0.577691\n", + " -1.316748\n", + " -0.598852\n", + " 1.499971\n", + " 0.435526\n", " 1.0\n", - " 1.695620\n", - " ideal\n", - " 0.214467\n", + " 2.791735\n", + " acceptable\n", + " 0.586726\n", " \n", " \n", " 7\n", - " -0.570017\n", - " 0.019903\n", - " -1.137774\n", - " -0.417314\n", - " 1.457074\n", + " -1.382985\n", + " -1.629080\n", + " 0.957606\n", + " 1.263108\n", + " -1.248810\n", " 1.0\n", - " 2.909370\n", + " 8.962357\n", " unacceptable\n", - " -0.567218\n", + " -1.376337\n", " \n", " \n", " 8\n", - " -0.756360\n", - " -0.905112\n", - " 1.536840\n", - " -0.081425\n", - " 0.350125\n", - " 0.0\n", - " 2.621606\n", + " -0.732977\n", + " -1.635727\n", + " -1.165820\n", + " -1.912441\n", + " 1.095427\n", + " 1.0\n", + " 12.089185\n", " unacceptable\n", - " -0.747130\n", + " -0.723748\n", " \n", " \n", " 9\n", - " 1.615900\n", - " -0.951538\n", - " 1.450884\n", - " 0.623330\n", - " 0.232692\n", - " 0.0\n", - " 1.690015\n", - " ideal\n", - " 1.622131\n", + " -1.324801\n", + " -1.119786\n", + " 1.021565\n", + " -0.304530\n", + " 0.425360\n", + " 1.0\n", + " 4.633603\n", + " unacceptable\n", + " -1.322161\n", " \n", " \n", " 10\n", - " 1.679599\n", - " 0.482301\n", - " -0.331234\n", - " 0.343804\n", - " 1.968027\n", - " 1.0\n", - " 1.466946\n", - " ideal\n", - " 1.679712\n", + " 0.957669\n", + " -1.272092\n", + " -0.461742\n", + " 1.557717\n", + " 1.954150\n", + " 0.0\n", + " 3.006800\n", + " acceptable\n", + " 0.965879\n", " \n", " \n", " 11\n", - " 1.967228\n", - " 1.924398\n", - " 0.038540\n", - " 1.206153\n", - " 1.048228\n", + " -1.609909\n", + " -1.972229\n", + " -0.950002\n", + " -1.645930\n", + " -0.068231\n", " 0.0\n", - " 1.880371\n", - " ideal\n", - " 1.969058\n", + " 15.049384\n", + " unacceptable\n", + " -1.602951\n", " \n", " \n", " 12\n", - " 0.784865\n", - " -1.456000\n", - " -1.918741\n", - " -0.222590\n", - " -1.711922\n", + " -0.872965\n", + " -0.263758\n", + " -1.418009\n", + " 1.434761\n", + " -0.156978\n", " 1.0\n", - " 12.611483\n", + " 4.579632\n", " unacceptable\n", - " 0.786212\n", + " -0.865004\n", " \n", " \n", " 13\n", - " -0.690906\n", - " -0.902616\n", - " 1.696768\n", - " 0.247553\n", - " -0.051436\n", - " 0.0\n", - " 2.647546\n", + " -1.650395\n", + " -1.291974\n", + " -1.693253\n", + " 1.404619\n", + " -0.617787\n", + " 1.0\n", + " 11.135251\n", " unacceptable\n", - " -0.687487\n", + " -1.641277\n", " \n", " \n", " 14\n", - " -0.415655\n", - " 0.478458\n", - " 0.975439\n", - " 1.812020\n", - " 1.215072\n", + " -0.122850\n", + " 0.062456\n", + " -1.093654\n", + " 0.762955\n", + " -0.608164\n", " 1.0\n", - " 1.493446\n", - " ideal\n", - " -0.414414\n", + " 2.333336\n", + " unacceptable\n", + " -0.116326\n", " \n", " \n", " 15\n", - " 0.355296\n", - " -1.806993\n", - " 0.823093\n", - " -0.596255\n", - " 1.884369\n", - " 1.0\n", - " 5.908147\n", - " acceptable\n", - " 0.363767\n", + " -0.343801\n", + " 1.969565\n", + " -0.400887\n", + " 1.546484\n", + " 0.955387\n", + " 0.0\n", + " 1.926944\n", + " ideal\n", + " -0.335117\n", " \n", " \n", " 16\n", - " -0.227270\n", - " 1.515492\n", - " -0.378421\n", - " -0.282083\n", - " -1.022800\n", + " -1.023179\n", + " -0.506223\n", + " -1.442436\n", + " 1.881762\n", + " 0.702930\n", " 0.0\n", - " 2.312977\n", + " 5.610748\n", " unacceptable\n", - " -0.226730\n", + " -1.015035\n", " \n", " \n", " 17\n", - " -1.938272\n", - " -0.544588\n", - " 0.677599\n", - " 1.611933\n", - " -1.657255\n", + " 0.806867\n", + " 0.223890\n", + " 1.955296\n", + " 1.470794\n", + " 1.646253\n", " 0.0\n", - " 10.617281\n", - " unacceptable\n", - " -1.938064\n", + " 1.263481\n", + " ideal\n", + " 0.814321\n", " \n", " \n", " 18\n", - " -1.233844\n", - " 1.107200\n", - " -0.249493\n", - " 1.132248\n", - " -1.648399\n", + " -0.063949\n", + " -1.605616\n", + " -1.694294\n", + " 1.014650\n", + " -0.542840\n", " 0.0\n", - " 6.794837\n", + " 8.664851\n", " unacceptable\n", - " -1.226482\n", + " -0.063911\n", " \n", " \n", " 19\n", - " 0.312220\n", - " -0.639545\n", - " 1.626105\n", - " -0.081847\n", - " -0.252003\n", - " 0.0\n", - " 1.202175\n", - " unacceptable\n", - " 0.313409\n", + " -0.791231\n", + " 0.269571\n", + " -0.625560\n", + " 1.688765\n", + " 0.347797\n", + " 1.0\n", + " 2.277518\n", + " acceptable\n", + " -0.788239\n", " \n", " \n", "\n", @@ -342,48 +342,48 @@ ], "text/plain": [ " x_0 x_1 x_2 x_3 x_4 x_5 f_0 \\\n", - "0 0.435454 -1.788437 -1.988452 -0.421175 -1.708184 0.0 15.230583 \n", - "1 -0.839497 -1.317481 -0.621791 1.029460 -0.029063 1.0 4.092890 \n", - "2 0.457487 -1.632370 -0.219897 -0.589791 -0.543485 0.0 4.653823 \n", - "3 -0.545863 1.068268 -0.840059 -0.156614 0.532427 0.0 1.944284 \n", - "4 1.294465 -1.086366 -1.444690 -1.005573 -0.743576 1.0 5.516022 \n", - "5 -0.826168 -1.615002 1.624282 0.041478 0.413251 1.0 5.652505 \n", - "6 0.213634 -0.469763 0.950611 1.945084 -0.783820 1.0 1.695620 \n", - "7 -0.570017 0.019903 -1.137774 -0.417314 1.457074 1.0 2.909370 \n", - "8 -0.756360 -0.905112 1.536840 -0.081425 0.350125 0.0 2.621606 \n", - "9 1.615900 -0.951538 1.450884 0.623330 0.232692 0.0 1.690015 \n", - "10 1.679599 0.482301 -0.331234 0.343804 1.968027 1.0 1.466946 \n", - "11 1.967228 1.924398 0.038540 1.206153 1.048228 0.0 1.880371 \n", - "12 0.784865 -1.456000 -1.918741 -0.222590 -1.711922 1.0 12.611483 \n", - "13 -0.690906 -0.902616 1.696768 0.247553 -0.051436 0.0 2.647546 \n", - "14 -0.415655 0.478458 0.975439 1.812020 1.215072 1.0 1.493446 \n", - "15 0.355296 -1.806993 0.823093 -0.596255 1.884369 1.0 5.908147 \n", - "16 -0.227270 1.515492 -0.378421 -0.282083 -1.022800 0.0 2.312977 \n", - "17 -1.938272 -0.544588 0.677599 1.611933 -1.657255 0.0 10.617281 \n", - "18 -1.233844 1.107200 -0.249493 1.132248 -1.648399 0.0 6.794837 \n", - "19 0.312220 -0.639545 1.626105 -0.081847 -0.252003 0.0 1.202175 \n", + "0 0.043387 0.471935 -1.903405 -1.505071 -1.631395 0.0 12.885008 \n", + "1 -0.516480 -0.519954 0.172009 -0.671419 0.576339 1.0 1.743262 \n", + "2 1.420828 -0.997470 -0.869831 -0.603185 0.988815 0.0 2.214841 \n", + "3 -1.729386 -0.517152 0.792940 1.841196 -1.018597 1.0 6.757267 \n", + "4 0.137756 -1.910416 0.266167 1.376514 1.853052 0.0 6.659112 \n", + "5 0.615494 0.638574 -0.846471 0.486294 -1.844437 0.0 5.965467 \n", + "6 0.577691 -1.316748 -0.598852 1.499971 0.435526 1.0 2.791735 \n", + "7 -1.382985 -1.629080 0.957606 1.263108 -1.248810 1.0 8.962357 \n", + "8 -0.732977 -1.635727 -1.165820 -1.912441 1.095427 1.0 12.089185 \n", + "9 -1.324801 -1.119786 1.021565 -0.304530 0.425360 1.0 4.633603 \n", + "10 0.957669 -1.272092 -0.461742 1.557717 1.954150 0.0 3.006800 \n", + "11 -1.609909 -1.972229 -0.950002 -1.645930 -0.068231 0.0 15.049384 \n", + "12 -0.872965 -0.263758 -1.418009 1.434761 -0.156978 1.0 4.579632 \n", + "13 -1.650395 -1.291974 -1.693253 1.404619 -0.617787 1.0 11.135251 \n", + "14 -0.122850 0.062456 -1.093654 0.762955 -0.608164 1.0 2.333336 \n", + "15 -0.343801 1.969565 -0.400887 1.546484 0.955387 0.0 1.926944 \n", + "16 -1.023179 -0.506223 -1.442436 1.881762 0.702930 0.0 5.610748 \n", + "17 0.806867 0.223890 1.955296 1.470794 1.646253 0.0 1.263481 \n", + "18 -0.063949 -1.605616 -1.694294 1.014650 -0.542840 0.0 8.664851 \n", + "19 -0.791231 0.269571 -0.625560 1.688765 0.347797 1.0 2.277518 \n", "\n", " f_1 f_2 \n", - "0 unacceptable 0.437269 \n", - "1 unacceptable -0.829668 \n", - "2 unacceptable 0.466307 \n", - "3 unacceptable -0.544115 \n", - "4 unacceptable 1.303366 \n", - "5 unacceptable -0.819800 \n", - "6 ideal 0.214467 \n", - "7 unacceptable -0.567218 \n", - "8 unacceptable -0.747130 \n", - "9 ideal 1.622131 \n", - "10 ideal 1.679712 \n", - "11 ideal 1.969058 \n", - "12 unacceptable 0.786212 \n", - "13 unacceptable -0.687487 \n", - "14 ideal -0.414414 \n", - "15 acceptable 0.363767 \n", - "16 unacceptable -0.226730 \n", - "17 unacceptable -1.938064 \n", - "18 unacceptable -1.226482 \n", - "19 unacceptable 0.313409 " + "0 unacceptable 0.049104 \n", + "1 unacceptable -0.511997 \n", + "2 unacceptable 1.426101 \n", + "3 unacceptable -1.723294 \n", + "4 acceptable 0.146894 \n", + "5 unacceptable 0.623373 \n", + "6 acceptable 0.586726 \n", + "7 unacceptable -1.376337 \n", + "8 unacceptable -0.723748 \n", + "9 unacceptable -1.322161 \n", + "10 acceptable 0.965879 \n", + "11 unacceptable -1.602951 \n", + "12 unacceptable -0.865004 \n", + "13 unacceptable -1.641277 \n", + "14 unacceptable -0.116326 \n", + "15 ideal -0.335117 \n", + "16 unacceptable -1.015035 \n", + "17 ideal 0.814321 \n", + "18 unacceptable -0.063911 \n", + "19 acceptable -0.788239 " ] }, "execution_count": 3, @@ -398,7 +398,7 @@ "# here the minimize objective is used, if you want to maximize you have to use the maximize objective.\n", "output_features = Outputs(features=[\n", " ContinuousOutput(key=f\"f_{0}\", objective=MinimizeObjective(w=1.)),\n", - " CategoricalOutput(key=f\"f_{1}\", categories=(\"unacceptable\", \"acceptable\", \"ideal\"), objective=CategoricalObjective(desirability=(0, 0.5, 1))), # This function will be associated with learning the categories\n", + " CategoricalOutput(key=f\"f_{1}\", categories=(\"unacceptable\", \"acceptable\", \"ideal\"), objective=ConstrainedCategoricalObjective(categories=(\"unacceptable\", \"acceptable\", \"ideal\"), desirability=(False, True, True))), # This function will be associated with learning the categories\n", " ContinuousOutput(key=f\"f_{2}\", objective=MinimizeSigmoidObjective(w=1., tp=0.0, steepness=0.5)),\n", " ]\n", ")\n", @@ -412,8 +412,8 @@ "# Write a function which outputs one continuous variable and another discrete based on some logic\n", "sample_df[\"f_0\"] = np.sin(np.pi * scale_inputs(sample_df[\"x_0\"])) ** 2 + sum([(scale_inputs(sample_df[col]) - 1) ** 2 * (1 + 10 * np.sin(np.pi * scale_inputs(sample_df[col]) + 1) ** 2 if ind < len(sample_df.columns) else 1 + np.sin(2 * np.pi * scale_inputs(sample_df[col])) ** 2) for ind, col in enumerate(sample_df.columns)])\n", "sample_df[\"f_1\"] = \"unacceptable\"\n", - "sample_df.loc[sample_df[input_features.get_keys()].sum(1) >= 1.0, \"f_1\"] = \"acceptable\"\n", - "sample_df.loc[sample_df[input_features.get_keys()].sum(1) >= 2.0, \"f_1\"] = \"ideal\"\n", + "sample_df.loc[sample_df[input_features.get_keys()].sum(1) >= 1.5, \"f_1\"] = \"acceptable\"\n", + "sample_df.loc[sample_df[input_features.get_keys()].sum(1) >= 3.5, \"f_1\"] = \"ideal\"\n", "sample_df[\"f_2\"] = sample_df[\"x_0\"] + 1e-2 * np.random.uniform(size=(len(sample_df),))\n", "\n", "sample_df.head(20)" @@ -423,7 +423,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Setup of the Strategy and ask for Candidates\n", + "## Setup strategy and ask for candidates\n", "\n" ] }, @@ -431,45 +431,18 @@ "cell_type": "code", "execution_count": 4, "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "__init__() missing 1 required positional argument: 'task_feature'", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32mc:\\Users\\G15361\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\tutorials\\basic_examples\\Unknown_Constraint_Classification.ipynb Cell 7\u001b[0m line \u001b[0;36m2\n\u001b[0;32m 15\u001b[0m strategy_data \u001b[39m=\u001b[39m SoboStrategy(domain\u001b[39m=\u001b[39mdomain1, \n\u001b[0;32m 16\u001b[0m acquisition_function\u001b[39m=\u001b[39mqEI(), \n\u001b[0;32m 17\u001b[0m surrogate_specs\u001b[39m=\u001b[39mBotorchSurrogates(surrogates\u001b[39m=\u001b[39m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 22\u001b[0m )\n\u001b[0;32m 23\u001b[0m )\n\u001b[0;32m 25\u001b[0m strategy \u001b[39m=\u001b[39m strategies\u001b[39m.\u001b[39mmap(strategy_data)\n\u001b[1;32m---> 27\u001b[0m strategy\u001b[39m.\u001b[39;49mtell(sample_df)\n", - "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\strategies\\predictives\\predictive.py:84\u001b[0m, in \u001b[0;36mPredictiveStrategy.tell\u001b[1;34m(self, experiments, replace, retrain)\u001b[0m\n\u001b[0;32m 82\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39madd_experiments(experiments)\n\u001b[0;32m 83\u001b[0m \u001b[39mif\u001b[39;00m retrain \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mhas_sufficient_experiments():\n\u001b[1;32m---> 84\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfit()\n\u001b[0;32m 85\u001b[0m \u001b[39m# we have a seperate _tell here for things that are relevant when setting up the strategy but unrelated\u001b[39;00m\n\u001b[0;32m 86\u001b[0m \u001b[39m# to fitting the models like initializing the ACQF.\u001b[39;00m\n\u001b[0;32m 87\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_tell()\n", - "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\strategies\\predictives\\predictive.py:165\u001b[0m, in \u001b[0;36mPredictiveStrategy.fit\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 163\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdomain\u001b[39m.\u001b[39mvalidate_experiments(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mexperiments, strict\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[0;32m 164\u001b[0m \u001b[39m# transformed = self.transformer.fit_transform(self.experiments)\u001b[39;00m\n\u001b[1;32m--> 165\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fit(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mexperiments)\n\u001b[0;32m 166\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_fitted \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n", - "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\strategies\\predictives\\botorch.py:138\u001b[0m, in \u001b[0;36mBotorchStrategy._fit\u001b[1;34m(self, experiments)\u001b[0m\n\u001b[0;32m 134\u001b[0m \u001b[39m# map the surrogate spec, we keep it here as attribute to be able to save/dump\u001b[39;00m\n\u001b[0;32m 135\u001b[0m \u001b[39m# the surrogate\u001b[39;00m\n\u001b[0;32m 136\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msurrogates \u001b[39m=\u001b[39m BotorchSurrogates(data_model\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39msurrogate_specs) \u001b[39m# type: ignore\u001b[39;00m\n\u001b[1;32m--> 138\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msurrogates\u001b[39m.\u001b[39;49mfit(experiments) \u001b[39m# type: ignore\u001b[39;00m\n\u001b[0;32m 139\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msurrogates\u001b[39m.\u001b[39mcompatibilize( \u001b[39m# type: ignore\u001b[39;00m\n\u001b[0;32m 140\u001b[0m inputs\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdomain\u001b[39m.\u001b[39minputs, \u001b[39m# type: ignore\u001b[39;00m\n\u001b[0;32m 141\u001b[0m outputs\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdomain\u001b[39m.\u001b[39moutputs, \u001b[39m# type: ignore\u001b[39;00m\n\u001b[0;32m 142\u001b[0m )\n", - "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\surrogates\\botorch_surrogates.py:40\u001b[0m, in \u001b[0;36mBotorchSurrogates.fit\u001b[1;34m(self, experiments)\u001b[0m\n\u001b[0;32m 38\u001b[0m \u001b[39mfor\u001b[39;00m model \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msurrogates:\n\u001b[0;32m 39\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(model, TrainableSurrogate):\n\u001b[1;32m---> 40\u001b[0m model\u001b[39m.\u001b[39;49mfit(experiments)\n", - "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\surrogates\\trainable.py:35\u001b[0m, in \u001b[0;36mTrainableSurrogate.fit\u001b[1;34m(self, experiments, options)\u001b[0m\n\u001b[0;32m 33\u001b[0m \u001b[39m# fit\u001b[39;00m\n\u001b[0;32m 34\u001b[0m options \u001b[39m=\u001b[39m options \u001b[39mor\u001b[39;00m {}\n\u001b[1;32m---> 35\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_fit(X\u001b[39m=\u001b[39mX, Y\u001b[39m=\u001b[39mY, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39moptions)\n", - "File \u001b[1;32m~\\OneDrive - Evonik Industries AG\\Documents\\Projects\\bofire\\bofire\\surrogates\\gp_classifier.py:85\u001b[0m, in \u001b[0;36mGPClassifier._fit\u001b[1;34m(self, X, Y)\u001b[0m\n\u001b[0;32m 82\u001b[0m tf \u001b[39m=\u001b[39m ChainedInputTransform(tf1\u001b[39m=\u001b[39mscaler, tf2\u001b[39m=\u001b[39mo2n) \u001b[39mif\u001b[39;00m scaler \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m o2n\n\u001b[0;32m 84\u001b[0m \u001b[39m# fit the model\u001b[39;00m\n\u001b[1;32m---> 85\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel \u001b[39m=\u001b[39m botorch\u001b[39m.\u001b[39;49mmodels\u001b[39m.\u001b[39;49mMultiTaskGP(\n\u001b[0;32m 86\u001b[0m train_X\u001b[39m=\u001b[39;49mo2n\u001b[39m.\u001b[39;49mtransform(tX),\n\u001b[0;32m 87\u001b[0m train_Y\u001b[39m=\u001b[39;49mtY,\n\u001b[0;32m 88\u001b[0m likelihood\u001b[39m=\u001b[39;49mSoftmaxLikelihood,\n\u001b[0;32m 89\u001b[0m covar_module\u001b[39m=\u001b[39;49mpartial(kernels\u001b[39m.\u001b[39;49mmap, data_model\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcontinuous_kernel),\n\u001b[0;32m 90\u001b[0m outcome_transform\u001b[39m=\u001b[39;49mStandardize(m\u001b[39m=\u001b[39;49mtY\u001b[39m.\u001b[39;49mshape[\u001b[39m-\u001b[39;49m\u001b[39m1\u001b[39;49m]),\n\u001b[0;32m 91\u001b[0m input_transform\u001b[39m=\u001b[39;49mtf,\n\u001b[0;32m 92\u001b[0m )\n\u001b[0;32m 93\u001b[0m mll \u001b[39m=\u001b[39m ExactMarginalLogLikelihood(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel\u001b[39m.\u001b[39mlikelihood, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel)\n\u001b[0;32m 94\u001b[0m fit_gpytorch_mll(mll, options\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtraining_specs)\n", - "\u001b[1;31mTypeError\u001b[0m: __init__() missing 1 required positional argument: 'task_feature'" - ] - } - ], + "outputs": [], "source": [ "from bofire.data_models.acquisition_functions.api import qEI\n", "from bofire.data_models.strategies.api import SoboStrategy\n", - "from bofire.data_models.surrogates.api import BotorchSurrogates, MLPClassifierEnsemble, MixedSingleTaskGPSurrogate, GPClassifier\n", + "from bofire.data_models.surrogates.api import BotorchSurrogates, MLPClassifierEnsemble, MixedSingleTaskGPSurrogate\n", "from bofire.data_models.domain.api import Outputs\n", "\n", - "# strategy_data = SoboStrategy(domain=domain1, \n", - "# acquisition_function=qEI(), \n", - "# surrogate_specs=BotorchSurrogates(surrogates=\n", - "# [\n", - "# MLPClassifierEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_1\")]), lr=1.0, n_epochs=50, hidden_layer_sizes=(20,)),\n", - "# MixedSingleTaskGPSurrogate(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_2\")]))\n", - "# ]\n", - "# )\n", - "# )\n", "strategy_data = SoboStrategy(domain=domain1, \n", " acquisition_function=qEI(), \n", " surrogate_specs=BotorchSurrogates(surrogates=\n", " [\n", - " GPClassifier(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_1\")])),\n", + " MLPClassifierEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_1\")]), lr=0.1, n_epochs=100, hidden_layer_sizes=(20,10,)),\n", " MixedSingleTaskGPSurrogate(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_2\")]))\n", " ]\n", " )\n", @@ -482,13 +455,482 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5f_0_predf_2_predf_1_predf_1_pred_unacceptablef_1_pred_acceptablef_1_pred_idealf_0_sdf_2_sdf_1_sd_unacceptablef_1_sd_acceptablef_1_sd_idealf_0_desf_2_desf_1_des
00.011502-0.1619201.9722170.4385550.4570980.0-0.5081090.018091ideal0.0051090.2577330.7371570.4247430.0028010.0045730.3467970.3449080.5081090.4977390.994891
10.165405-0.124200-0.2838040.3569750.6522891.0-0.5497020.171525acceptable0.0019800.9965530.0014670.3654220.0027200.0017160.0038790.0024280.5497020.4785720.998020
2-0.008095-0.1885121.6624390.3466880.4308951.0-0.371163-0.001649acceptable0.0004940.5782230.4212830.4183340.0027770.0006780.4842750.4843950.3711630.5002060.999506
3-0.0341152.000000-0.095091-0.3787932.0000000.00.887337-0.029129ideal0.0013760.1367120.8619120.8992510.0028710.0030500.3032320.302495-0.8873370.5036410.998624
4-0.0116000.1171771.6154370.2018910.7444780.0-0.337936-0.005347ideal0.0077440.2360330.7562230.4265270.0027730.0079030.3797990.3775170.3379360.5006680.992256
5-0.073124-0.2437462.0000000.4021200.3528891.0-0.395214-0.066696acceptable0.0003550.5077540.4918910.4089500.0028030.0005580.4720510.4722810.3952140.5083360.999645
60.0968870.753092-0.296234-0.2686971.7570061.00.2564620.102279acceptable0.0004520.6145700.3849780.5090060.0027770.0008300.4646930.464979-0.2564620.4872180.999548
70.1738880.237557-0.316069-0.0110890.8945071.0-0.3760640.179808acceptable0.0015980.9960050.0023960.4057100.0027270.0019530.0059980.0043700.3760640.4775390.998402
8-0.249681-0.0602912.0000000.3593770.6700420.0-0.281907-0.243816ideal0.0055390.2217720.7726900.4542170.0028040.0055960.3107030.3080600.2819070.5304390.994461
90.126138-0.276652-0.1651110.6719390.4427781.0-0.4360710.132389acceptable0.0021370.9964030.0014600.3461240.0027220.0015720.0036720.0024230.4360710.4834570.997863
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 f_0_pred f_2_pred \\\n", + "0 0.011502 -0.161920 1.972217 0.438555 0.457098 0.0 -0.508109 0.018091 \n", + "1 0.165405 -0.124200 -0.283804 0.356975 0.652289 1.0 -0.549702 0.171525 \n", + "2 -0.008095 -0.188512 1.662439 0.346688 0.430895 1.0 -0.371163 -0.001649 \n", + "3 -0.034115 2.000000 -0.095091 -0.378793 2.000000 0.0 0.887337 -0.029129 \n", + "4 -0.011600 0.117177 1.615437 0.201891 0.744478 0.0 -0.337936 -0.005347 \n", + "5 -0.073124 -0.243746 2.000000 0.402120 0.352889 1.0 -0.395214 -0.066696 \n", + "6 0.096887 0.753092 -0.296234 -0.268697 1.757006 1.0 0.256462 0.102279 \n", + "7 0.173888 0.237557 -0.316069 -0.011089 0.894507 1.0 -0.376064 0.179808 \n", + "8 -0.249681 -0.060291 2.000000 0.359377 0.670042 0.0 -0.281907 -0.243816 \n", + "9 0.126138 -0.276652 -0.165111 0.671939 0.442778 1.0 -0.436071 0.132389 \n", + "\n", + " f_1_pred f_1_pred_unacceptable f_1_pred_acceptable f_1_pred_ideal \\\n", + "0 ideal 0.005109 0.257733 0.737157 \n", + "1 acceptable 0.001980 0.996553 0.001467 \n", + "2 acceptable 0.000494 0.578223 0.421283 \n", + "3 ideal 0.001376 0.136712 0.861912 \n", + "4 ideal 0.007744 0.236033 0.756223 \n", + "5 acceptable 0.000355 0.507754 0.491891 \n", + "6 acceptable 0.000452 0.614570 0.384978 \n", + "7 acceptable 0.001598 0.996005 0.002396 \n", + "8 ideal 0.005539 0.221772 0.772690 \n", + "9 acceptable 0.002137 0.996403 0.001460 \n", + "\n", + " f_0_sd f_2_sd f_1_sd_unacceptable f_1_sd_acceptable f_1_sd_ideal \\\n", + "0 0.424743 0.002801 0.004573 0.346797 0.344908 \n", + "1 0.365422 0.002720 0.001716 0.003879 0.002428 \n", + "2 0.418334 0.002777 0.000678 0.484275 0.484395 \n", + "3 0.899251 0.002871 0.003050 0.303232 0.302495 \n", + "4 0.426527 0.002773 0.007903 0.379799 0.377517 \n", + "5 0.408950 0.002803 0.000558 0.472051 0.472281 \n", + "6 0.509006 0.002777 0.000830 0.464693 0.464979 \n", + "7 0.405710 0.002727 0.001953 0.005998 0.004370 \n", + "8 0.454217 0.002804 0.005596 0.310703 0.308060 \n", + "9 0.346124 0.002722 0.001572 0.003672 0.002423 \n", + "\n", + " f_0_des f_2_des f_1_des \n", + "0 0.508109 0.497739 0.994891 \n", + "1 0.549702 0.478572 0.998020 \n", + "2 0.371163 0.500206 0.999506 \n", + "3 -0.887337 0.503641 0.998624 \n", + "4 0.337936 0.500668 0.992256 \n", + "5 0.395214 0.508336 0.999645 \n", + "6 -0.256462 0.487218 0.999548 \n", + "7 0.376064 0.477539 0.998402 \n", + "8 0.281907 0.530439 0.994461 \n", + "9 0.436071 0.483457 0.997863 " + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "candidates = strategy.ask(2)\n", + "candidates = strategy.ask(10)\n", "candidates" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## See performance of the classifier" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "We defined 'unacceptable' as values in (-infinity, 1.5), 'acceptable' as values in [1.5, 3.5), and 'ideal' as values in [3.5, infinity)\n", + "\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
0f_1_pred
02.717452ideal
11.766665acceptable
23.243415acceptable
33.492002ideal
42.667383ideal
53.438138acceptable
63.042054acceptable
71.978794acceptable
82.719447ideal
91.799092acceptable
\n", + "
" + ], + "text/plain": [ + " 0 f_1_pred\n", + "0 2.717452 ideal\n", + "1 1.766665 acceptable\n", + "2 3.243415 acceptable\n", + "3 3.492002 ideal\n", + "4 2.667383 ideal\n", + "5 3.438138 acceptable\n", + "6 3.042054 acceptable\n", + "7 1.978794 acceptable\n", + "8 2.719447 ideal\n", + "9 1.799092 acceptable" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(f\"We defined 'unacceptable' as values in (-infinity, 1.5), 'acceptable' as values in [1.5, 3.5), and 'ideal' as values in [3.5, infinity)\\n\")\n", + "pd.concat((candidates[[feat.key for feat in input_features]].astype(float).sum(1), candidates[\"f_1_pred\"]), axis=1)" + ] } ], "metadata": { From 9caecc713624b46683b89ee9f3093911a52ed0db Mon Sep 17 00:00:00 2001 From: gmancino Date: Wed, 24 Jan 2024 11:24:23 -0500 Subject: [PATCH 10/31] Address previous PR issues --- bofire/data_models/domain/features.py | 37 +- bofire/data_models/enum.py | 7 + bofire/data_models/features/categorical.py | 58 +- bofire/data_models/objectives/api.py | 5 +- bofire/data_models/objectives/categorical.py | 55 +- bofire/data_models/surrogates/api.py | 15 +- .../surrogates/botorch_surrogates.py | 10 +- bofire/data_models/surrogates/empirical.py | 3 - .../data_models/surrogates/fully_bayesian.py | 7 +- bofire/data_models/surrogates/linear.py | 5 +- .../surrogates/mixed_single_task_gp.py | 5 +- bofire/data_models/surrogates/mlp.py | 29 +- .../data_models/surrogates/mlp_classifier.py | 36 - bofire/data_models/surrogates/polynomial.py | 5 +- .../data_models/surrogates/random_forest.py | 5 +- .../data_models/surrogates/single_task_gp.py | 5 +- bofire/data_models/surrogates/surrogate.py | 12 +- bofire/data_models/surrogates/tanimoto_gp.py | 5 +- bofire/data_models/surrogates/xgb.py | 5 +- bofire/strategies/predictives/predictive.py | 39 +- .../samplers/universal_constraint.py | 4 +- bofire/surrogates/api.py | 7 +- bofire/surrogates/diagnostics.py | 65 +- bofire/surrogates/mapper.py | 7 +- bofire/surrogates/mlp.py | 91 +- bofire/surrogates/mlp_classifier.py | 305 ---- bofire/surrogates/surrogate.py | 74 +- bofire/surrogates/trainable.py | 15 + bofire/surrogates/values.py | 4 +- .../bofire/data_models/domain/test_outputs.py | 36 +- tests/bofire/data_models/specs/features.py | 12 +- tests/bofire/data_models/specs/objectives.py | 21 + tests/bofire/data_models/specs/surrogates.py | 101 +- .../Unknown_Constraint_Classification.ipynb | 1224 ++++++++++------- 34 files changed, 1257 insertions(+), 1057 deletions(-) delete mode 100644 bofire/data_models/surrogates/mlp_classifier.py delete mode 100644 bofire/surrogates/mlp_classifier.py diff --git a/bofire/data_models/domain/features.py b/bofire/data_models/domain/features.py index 6880ce8c7..a0b7bd16f 100644 --- a/bofire/data_models/domain/features.py +++ b/bofire/data_models/domain/features.py @@ -619,13 +619,11 @@ def get_keys_by_objective( self, includes: Union[ List[Type[AbstractObjective]], - Tuple[Type[AbstractObjective]], Type[AbstractObjective], Type[Objective], ] = Objective, excludes: Union[ List[Type[AbstractObjective]], - Tuple[Type[AbstractObjective]], Type[AbstractObjective], None, ] = None, @@ -665,11 +663,11 @@ def __call__( and not isinstance(feat, CategoricalOutput) ] + [ - pd.Series(data=feat(experiments.filter(regex=f"{feat.key}_pred_")), name=f"{feat.key}_pred") # type: ignore + pd.Series(data=feat(experiments.filter(regex=f"{feat.key}(.*)_prob")), name=f"{feat.key}_pred") # type: ignore if predictions else experiments[feat.key] for feat in self.features - if isinstance(feat, CategoricalOutput) + if feat.objective is not None and isinstance(feat, CategoricalOutput) ], axis=1, ) @@ -720,8 +718,8 @@ def validate_candidates(self, candidates: pd.DataFrame) -> pd.DataFrame: continuous_cols = list( itertools.chain.from_iterable( [ - [f"{obj.key}_pred", f"{obj.key}_sd", f"{obj.key}_des"] - for obj in self.get_by_objective( + [f"{feat.key}_pred", f"{feat.key}_sd", f"{feat.key}_des"] + for feat in self.get_by_objective( includes=Objective, excludes=CategoricalObjective ) ] @@ -745,23 +743,18 @@ def validate_candidates(self, candidates: pd.DataFrame) -> pd.DataFrame: raise ValueError(f"Not all values of column `{col}` are numerical.") if candidates[col].isnull().to_numpy().any(): raise ValueError(f"Nan values are present in {col}.") - # # Check for categorical output - # categorical_cols = [ - # (f"{obj.key}_pred", obj.categories) - # for obj in self.get_by_objective(includes=CategoricalObjective) - # ] - # if len(categorical_cols) == 0: - # return candidates - # for col in categorical_cols: - # if col[0] not in candidates: - # raise ValueError(f"missing column {col}") - # if len(candidates[col[0]]) - candidates[col[0]].isin(col[1]).sum() > 0: - # raise ValueError(f"values present are not in {col[1]}") + # Looping over features allows to check categories objective wise for feat in self.get(CategoricalOutput): - col = f"{feat.key}_pred" - if col not in candidates: - raise ValueError(f"missing column {col}") - feat.validate_experimental(candidates[col]) + cols = [f"{feat.key}_pred", f"{feat.key}_des"] + for col in cols: + if col not in candidates: + raise ValueError(f"missing column {col}") + if col == f"{feat.key}_pred": + feat.validate_experimental(candidates[col]) + else: + # Check sd and desirability + if candidates[col].isnull().to_numpy().any(): + raise ValueError(f"Nan values are present in {col}.") return candidates def preprocess_experiments_one_valid_output( diff --git a/bofire/data_models/enum.py b/bofire/data_models/enum.py index 4331aa5fd..d055b5e08 100644 --- a/bofire/data_models/enum.py +++ b/bofire/data_models/enum.py @@ -28,6 +28,13 @@ class CategoricalEncodingEnum(Enum): DESCRIPTOR = "DESCRIPTOR" # only possible for categorical with descriptors +class ClassificationMetricsEnum(Enum): + """Enumeration class for classification metrics.""" + + ACCURACY = "ACCURACY" + F1 = "F1" + + class OutputFilteringEnum(Enum): ALL = "ALL" ANY = "ANY" diff --git a/bofire/data_models/features/categorical.py b/bofire/data_models/features/categorical.py index 368f6fc83..14ebfb3e3 100644 --- a/bofire/data_models/features/categorical.py +++ b/bofire/data_models/features/categorical.py @@ -2,7 +2,7 @@ import numpy as np import pandas as pd -from pydantic import root_validator, validator +from pydantic import Field, field_validator, model_validator from bofire.data_models.enum import CategoricalEncodingEnum from bofire.data_models.features.feature import ( @@ -13,8 +13,8 @@ TCategoryVals, TTransform, ) +from bofire.data_models.objectives.api import AnyCategoricalObjective from bofire.data_models.objectives.categorical import ( - CategoricalObjective, ConstrainedCategoricalObjective, ) @@ -23,8 +23,8 @@ class CategoricalInput(Input): """Base class for all categorical input features. Attributes: - categories (Tuple[str]): Names of the categories. - allowed (Tuple[bool]): List of bools indicating if a category is allowed within the optimization. + categories (List[str]): Names of the categories. + allowed (List[bool]): List of bools indicating if a category is allowed within the optimization. """ type: Literal["CategoricalInput"] = "CategoricalInput" @@ -39,7 +39,7 @@ def validate_categories_unique(cls, categories): """validates that categories have unique names Args: - categories (Union[List[str], Tuple[str]]): List or tuple of category names + categories (List[str]): List of category names Raises: ValueError: when categories have non-unique names @@ -49,7 +49,7 @@ def validate_categories_unique(cls, categories): """ if len(categories) != len(set(categories)): raise ValueError("categories must be unique") - return tuple(categories) + return categories @field_validator("allowed") @classmethod @@ -358,16 +358,19 @@ class CategoricalOutput(Output): order_id: ClassVar[int] = 8 categories: TCategoryVals - objective: Optional[ - Union[CategoricalObjective, ConstrainedCategoricalObjective] - ] = None + objective: Optional[AnyCategoricalObjective] = Field( + default_factory=lambda: ConstrainedCategoricalObjective( + w=1.0, categories=["a", "b"], desirability=[True, False] + ) + ) - @validator("categories", allow_reuse=True) - def validate_categories_unique(cls, categories): + @field_validator("categories") + @classmethod + def validate_categories_unique(cls, categories: List[str]) -> List["str"]: """validates that categories have unique names Args: - categories (Union[List[str], Tuple[str]]): List or tuple of category names + categories (List[str]): List or tuple of category names Raises: ValueError: when categories have non-unique names @@ -377,14 +380,17 @@ def validate_categories_unique(cls, categories): """ if len(categories) != len(set(categories)): raise ValueError("categories must be unique") - return tuple(categories) + return categories - @validator("objective") - def validate_objectives_unique(cls, objective, values): + @field_validator("objective") + @classmethod + def validate_objectives_unique( + cls, objective: AnyCategoricalObjective, info + ) -> AnyCategoricalObjective: """validates that categories have unique names Args: - categories (Union[List[str], Tuple[str]]): List or tuple of category names + categories (List[str]): List or tuple of category names Raises: ValueError: when categories do not match objective categories @@ -392,7 +398,7 @@ def validate_objectives_unique(cls, objective, values): Returns: Tuple[str]: Tuple of the categories """ - if objective.categories != tuple(values["categories"]): + if objective.categories != info.data["categories"]: raise ValueError("categories must match to objective categories") return objective @@ -400,18 +406,10 @@ def validate_objectives_unique(cls, objective, values): def from_objective( cls, key: str, - objective: Union[CategoricalObjective, ConstrainedCategoricalObjective], + objective: ConstrainedCategoricalObjective, ): return cls(key=key, objective=objective, categories=objective.categories) - def validate_experimental(self, values: pd.Series) -> pd.Series: - values = values.map(str) - if sum(values.isin(self.categories)) != len(values): - raise ValueError( - f"invalid values for `{self.key}`, allowed are: `{self.categories}`" - ) - return values - def __call__(self, values: pd.Series) -> pd.Series: if self.objective is None: return pd.Series( @@ -421,5 +419,13 @@ def __call__(self, values: pd.Series) -> pd.Series: ) return self.objective(values) # type: ignore + def validate_experimental(self, values: pd.Series) -> pd.Series: + values = values.map(str) + if sum(values.isin(self.categories)) != len(values): + raise ValueError( + f"invalid values for `{self.key}`, allowed are: `{self.categories}`" + ) + return values + def __str__(self) -> str: return "CategoricalOutputFeature" diff --git a/bofire/data_models/objectives/api.py b/bofire/data_models/objectives/api.py index 0fe3f2e7f..133440ad9 100644 --- a/bofire/data_models/objectives/api.py +++ b/bofire/data_models/objectives/api.py @@ -26,11 +26,12 @@ IdentityObjective, SigmoidObjective, ConstrainedObjective, + CategoricalObjective, ] +AnyCategoricalObjective = ConstrainedCategoricalObjective + AnyConstraintObjective = Union[ - CategoricalObjective, - ConstrainedCategoricalObjective, MaximizeSigmoidObjective, MinimizeSigmoidObjective, TargetObjective, diff --git a/bofire/data_models/objectives/categorical.py b/bofire/data_models/objectives/categorical.py index 2540376de..960a481d7 100644 --- a/bofire/data_models/objectives/categorical.py +++ b/bofire/data_models/objectives/categorical.py @@ -1,9 +1,8 @@ -from typing import Dict, Literal, Tuple, Union -from warnings import warn +from typing import Dict, List, Literal, Union import numpy as np import pandas as pd -from pydantic import validator +from pydantic import field_validator from bofire.data_models.features.feature import TCategoryVals from bofire.data_models.objectives.objective import ( @@ -13,37 +12,13 @@ ) -class CategoricalObjective(Objective): - """Categorical objective class; stores categories""" - - type: Literal["CategoricalObjective"] = "CategoricalObjective" - categories: TCategoryVals - - @validator("categories") - def validate_categories_unique(cls, categories): - """validates that categories have unique names - - Args: - categories (Union[List[str], Tuple[str]]): List or tuple of category names - - Raises: - ValueError: when categories have non-unique names - - Returns: - Tuple[str]: Tuple of the categories - """ - if len(categories) != len(set(categories)): - raise ValueError("categories must be unique") - return tuple(categories) - - def __call__(self, x: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]: - warn( - "Categorical objective currently does not have a function. Returning the original input." - ) - return x +class CategoricalObjective: + """Abstract categorical objective class""" -class ConstrainedCategoricalObjective(ConstrainedObjective, CategoricalObjective): +class ConstrainedCategoricalObjective( + ConstrainedObjective, CategoricalObjective, Objective +): """Compute the categorical objective value as: Po where P is an [n, c] matrix where each row is a probability vector @@ -51,21 +26,23 @@ class ConstrainedCategoricalObjective(ConstrainedObjective, CategoricalObjective Attributes: w (float): float between zero and one for weighting the objective. - desirability (tuple): tuple of values of size c (c is number of categories) such that the i-th entry is in {True, False} + desirability (list): list of values of size c (c is number of categories) such that the i-th entry is in {True, False} """ w: TWeight = 1.0 categories: TCategoryVals - desirability: Tuple[bool, ...] + desirability: List[bool] eta: float = 1.0 type: Literal["ConstrainedCategoricalObjective"] = "ConstrainedCategoricalObjective" - @validator("desirability") - def validate_categories_unique(cls, desirability, values): + @field_validator( + "desirability", + ) + def validate_categories_unique(cls, desirability: List[bool], info) -> List[bool]: """validates that desirabilities match the categories Args: - categories (Union[List[str], Tuple[str]]): List or tuple of category names + categories (List[str]): List or tuple of category names Raises: ValueError: when desirability count is not equal to category count @@ -73,11 +50,11 @@ def validate_categories_unique(cls, desirability, values): Returns: Tuple[bool]: Tuple of the desirability """ - if len(desirability) != len(values["categories"]): + if len(desirability) != len(info.data["categories"]): raise ValueError( "number of categories differs from number of desirabilities" ) - return tuple(desirability) + return desirability def to_dict(self) -> Dict: """Returns the categories and corresponding objective values as dictionary""" diff --git a/bofire/data_models/surrogates/api.py b/bofire/data_models/surrogates/api.py index 5f89ec894..4cc6b77d5 100644 --- a/bofire/data_models/surrogates/api.py +++ b/bofire/data_models/surrogates/api.py @@ -14,8 +14,11 @@ from bofire.data_models.surrogates.mixed_single_task_gp import ( MixedSingleTaskGPSurrogate, ) - from bofire.data_models.surrogates.mlp import MLPEnsemble - from bofire.data_models.surrogates.mlp_classifier import MLPClassifierEnsemble + from bofire.data_models.surrogates.mlp import ( + ClassificationMLPEnsemble, + MLPEnsemble, + RegressionMLPEnsemble, + ) from bofire.data_models.surrogates.polynomial import PolynomialSurrogate from bofire.data_models.surrogates.random_forest import RandomForestSurrogate from bofire.data_models.surrogates.single_task_gp import ( @@ -34,8 +37,8 @@ RandomForestSurrogate, SingleTaskGPSurrogate, MixedSingleTaskGPSurrogate, - MLPClassifierEnsemble, - MLPEnsemble, + ClassificationMLPEnsemble, + RegressionMLPEnsemble, SaasSingleTaskGPSurrogate, XGBoostSurrogate, LinearSurrogate, @@ -47,8 +50,8 @@ RandomForestSurrogate, SingleTaskGPSurrogate, MixedSingleTaskGPSurrogate, - MLPClassifierEnsemble, - MLPEnsemble, + ClassificationMLPEnsemble, + RegressionMLPEnsemble, SaasSingleTaskGPSurrogate, XGBoostSurrogate, LinearSurrogate, diff --git a/bofire/data_models/surrogates/botorch_surrogates.py b/bofire/data_models/surrogates/botorch_surrogates.py index 10331966a..9f6800c20 100644 --- a/bofire/data_models/surrogates/botorch_surrogates.py +++ b/bofire/data_models/surrogates/botorch_surrogates.py @@ -12,8 +12,10 @@ from bofire.data_models.surrogates.mixed_single_task_gp import ( MixedSingleTaskGPSurrogate, ) -from bofire.data_models.surrogates.mlp import MLPEnsemble -from bofire.data_models.surrogates.mlp_classifier import MLPClassifierEnsemble +from bofire.data_models.surrogates.mlp import ( + ClassificationMLPEnsemble, + RegressionMLPEnsemble, +) from bofire.data_models.surrogates.polynomial import PolynomialSurrogate from bofire.data_models.surrogates.random_forest import RandomForestSurrogate from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate @@ -24,8 +26,8 @@ RandomForestSurrogate, SingleTaskGPSurrogate, MixedSingleTaskGPSurrogate, - MLPClassifierEnsemble, - MLPEnsemble, + RegressionMLPEnsemble, + ClassificationMLPEnsemble, SaasSingleTaskGPSurrogate, TanimotoGPSurrogate, LinearSurrogate, diff --git a/bofire/data_models/surrogates/empirical.py b/bofire/data_models/surrogates/empirical.py index 018f53208..d055a37cf 100644 --- a/bofire/data_models/surrogates/empirical.py +++ b/bofire/data_models/surrogates/empirical.py @@ -10,11 +10,8 @@ class EmpiricalSurrogate(BotorchSurrogate): @classmethod def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: """Abstract method to check output type for surrogate models - Args: - outputs: objective functions for the surrogate my_type: continuous or categorical output - Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ diff --git a/bofire/data_models/surrogates/fully_bayesian.py b/bofire/data_models/surrogates/fully_bayesian.py index 5a67e8ae6..96e15215e 100644 --- a/bofire/data_models/surrogates/fully_bayesian.py +++ b/bofire/data_models/surrogates/fully_bayesian.py @@ -18,17 +18,14 @@ class SaasSingleTaskGPSurrogate(TrainableBotorchSurrogate): def validate_thinning(cls, thinning, info): if info.data["num_samples"] / thinning < 1: raise ValueError("`num_samples` has to be larger than `thinning`.") - return value + return thinning @classmethod def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: """Abstract method to check output type for surrogate models - Args: - outputs: objective functions for the surrogate my_type: continuous or categorical output - Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return True if isinstance(my_type, ContinuousOutput) else False + return isinstance(my_type, ContinuousOutput) diff --git a/bofire/data_models/surrogates/linear.py b/bofire/data_models/surrogates/linear.py index eba916356..49b6f3980 100644 --- a/bofire/data_models/surrogates/linear.py +++ b/bofire/data_models/surrogates/linear.py @@ -23,12 +23,9 @@ class LinearSurrogate(TrainableBotorchSurrogate): @classmethod def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: """Abstract method to check output type for surrogate models - Args: - outputs: objective functions for the surrogate my_type: continuous or categorical output - Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return True if isinstance(my_type, ContinuousOutput) else False + return isinstance(my_type, ContinuousOutput) diff --git a/bofire/data_models/surrogates/mixed_single_task_gp.py b/bofire/data_models/surrogates/mixed_single_task_gp.py index 553cabe30..3a767d1e7 100644 --- a/bofire/data_models/surrogates/mixed_single_task_gp.py +++ b/bofire/data_models/surrogates/mixed_single_task_gp.py @@ -40,12 +40,9 @@ def validate_categoricals(cls, v, values): @classmethod def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: """Abstract method to check output type for surrogate models - Args: - outputs: objective functions for the surrogate my_type: continuous or categorical output - Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return True if isinstance(my_type, ContinuousOutput) else False + return isinstance(my_type, ContinuousOutput) diff --git a/bofire/data_models/surrogates/mlp.py b/bofire/data_models/surrogates/mlp.py index 9c6e6d3a2..8e887cc9f 100644 --- a/bofire/data_models/surrogates/mlp.py +++ b/bofire/data_models/surrogates/mlp.py @@ -2,7 +2,11 @@ from pydantic import Field -from bofire.data_models.features.api import AnyOutput, ContinuousOutput +from bofire.data_models.features.api import ( + AnyOutput, + CategoricalOutput, + ContinuousOutput, +) from bofire.data_models.surrogates.scaler import ScalerEnum from bofire.data_models.surrogates.trainable_botorch import TrainableBotorchSurrogate @@ -21,15 +25,32 @@ class MLPEnsemble(TrainableBotorchSurrogate): shuffle: bool = True scaler: ScalerEnum = ScalerEnum.NORMALIZE + +class RegressionMLPEnsemble(MLPEnsemble): + type: Literal["RegressionMLPEnsemble"] = "RegressionMLPEnsemble" + final_activation: Literal["identity"] = "identity" + @classmethod def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: """Abstract method to check output type for surrogate models - Args: - outputs: objective functions for the surrogate my_type: continuous or categorical output + Returns: + bool: True if the output type is valid for the surrogate chosen, False otherwise + """ + return isinstance(my_type, ContinuousOutput) + +class ClassificationMLPEnsemble(MLPEnsemble): + type: Literal["ClassificationMLPEnsemble"] = "ClassificationMLPEnsemble" + final_activation: Literal["softmax"] = "softmax" + + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models + Args: + my_type: continuous or categorical output Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return True if isinstance(my_type, ContinuousOutput) else False + return isinstance(my_type, CategoricalOutput) diff --git a/bofire/data_models/surrogates/mlp_classifier.py b/bofire/data_models/surrogates/mlp_classifier.py deleted file mode 100644 index 9e99ed6d5..000000000 --- a/bofire/data_models/surrogates/mlp_classifier.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Annotated, Literal, Sequence, Type - -from pydantic import Field - -from bofire.data_models.features.api import AnyOutput, CategoricalOutput -from bofire.data_models.surrogates.botorch import BotorchSurrogate -from bofire.data_models.surrogates.scaler import ScalerEnum -from bofire.data_models.surrogates.trainable import TrainableSurrogate - - -class MLPClassifierEnsemble(BotorchSurrogate, TrainableSurrogate): - type: Literal["MLPClassifierEnsemble"] = "MLPClassifierEnsemble" - n_estimators: Annotated[int, Field(ge=1)] = 5 - hidden_layer_sizes: Sequence = (100,) - activation: Literal["relu", "logistic", "tanh"] = "relu" - dropout: Annotated[float, Field(ge=0.0)] = 0.0 - batch_size: Annotated[int, Field(ge=1)] = 10 - n_epochs: Annotated[int, Field(ge=1)] = 200 - lr: Annotated[float, Field(gt=0.0)] = 1e-4 - weight_decay: Annotated[float, Field(ge=0.0)] = 0.0 - subsample_fraction: Annotated[float, Field(gt=0.0)] = 1.0 - shuffle: bool = True - scaler: ScalerEnum = ScalerEnum.NORMALIZE - - @classmethod - def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: - """Abstract method to check output type for surrogate models - - Args: - outputs: objective functions for the surrogate - my_type: continuous or categorical output - - Returns: - bool: True if the output type is valid for the surrogate chosen, False otherwise - """ - return True if isinstance(my_type, CategoricalOutput) else False diff --git a/bofire/data_models/surrogates/polynomial.py b/bofire/data_models/surrogates/polynomial.py index 2c001d442..63e4539ad 100644 --- a/bofire/data_models/surrogates/polynomial.py +++ b/bofire/data_models/surrogates/polynomial.py @@ -28,12 +28,9 @@ def from_power(power: int, inputs: Inputs, outputs: Outputs): @classmethod def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: """Abstract method to check output type for surrogate models - Args: - outputs: objective functions for the surrogate my_type: continuous or categorical output - Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return True if isinstance(my_type, ContinuousOutput) else False + return isinstance(my_type, ContinuousOutput) diff --git a/bofire/data_models/surrogates/random_forest.py b/bofire/data_models/surrogates/random_forest.py index 550d947ed..650f5069f 100644 --- a/bofire/data_models/surrogates/random_forest.py +++ b/bofire/data_models/surrogates/random_forest.py @@ -34,12 +34,9 @@ class RandomForestSurrogate(TrainableBotorchSurrogate): @classmethod def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: """Abstract method to check output type for surrogate models - Args: - outputs: objective functions for the surrogate my_type: continuous or categorical output - Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return True if isinstance(my_type, ContinuousOutput) else False + return isinstance(my_type, ContinuousOutput) diff --git a/bofire/data_models/surrogates/single_task_gp.py b/bofire/data_models/surrogates/single_task_gp.py index 0994f50e7..594a9b8cf 100644 --- a/bofire/data_models/surrogates/single_task_gp.py +++ b/bofire/data_models/surrogates/single_task_gp.py @@ -116,12 +116,9 @@ class SingleTaskGPSurrogate(TrainableBotorchSurrogate): @classmethod def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: """Abstract method to check output type for surrogate models - Args: - outputs: objective functions for the surrogate my_type: continuous or categorical output - Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return True if isinstance(my_type, ContinuousOutput) else False + return isinstance(my_type, ContinuousOutput) diff --git a/bofire/data_models/surrogates/surrogate.py b/bofire/data_models/surrogates/surrogate.py index b2d60d977..2270da9c6 100644 --- a/bofire/data_models/surrogates/surrogate.py +++ b/bofire/data_models/surrogates/surrogate.py @@ -10,7 +10,6 @@ class Surrogate(BaseModel): type: str - inputs: Inputs outputs: Outputs input_preprocessing_specs: TInputTransformSpecs = Field( @@ -29,20 +28,21 @@ def validate_input_preprocessing_specs(cls, v, info): @field_validator("outputs") @classmethod - def validate_outputs(cls, v, values): - if len(v) == 0: + def validate_outputs(cls, outputs, info): + if len(outputs) == 0: raise ValueError("At least one output feature has to be provided.") - return v + for o in outputs: + if not cls.is_output_implemented(o): + raise ValueError("Invalid output type passed.") + return outputs @classmethod @abstractmethod def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: """Abstract method to check output type for surrogate models - Args: outputs: objective functions for the surrogate my_type: continuous or categorical output - Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ diff --git a/bofire/data_models/surrogates/tanimoto_gp.py b/bofire/data_models/surrogates/tanimoto_gp.py index 92efe8c05..00aed8641 100644 --- a/bofire/data_models/surrogates/tanimoto_gp.py +++ b/bofire/data_models/surrogates/tanimoto_gp.py @@ -31,12 +31,9 @@ class TanimotoGPSurrogate(TrainableBotorchSurrogate): @classmethod def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: """Abstract method to check output type for surrogate models - Args: - outputs: objective functions for the surrogate my_type: continuous or categorical output - Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return True if isinstance(my_type, ContinuousOutput) else False + return isinstance(my_type, ContinuousOutput) diff --git a/bofire/data_models/surrogates/xgb.py b/bofire/data_models/surrogates/xgb.py index fabd0e883..262ca1649 100644 --- a/bofire/data_models/surrogates/xgb.py +++ b/bofire/data_models/surrogates/xgb.py @@ -82,12 +82,9 @@ def validate_input_preprocessing_specs(cls, v, info): @classmethod def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: """Abstract method to check output type for surrogate models - Args: - outputs: objective functions for the surrogate my_type: continuous or categorical output - Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return True if isinstance(my_type, ContinuousOutput) else False + return isinstance(my_type, ContinuousOutput) diff --git a/bofire/strategies/predictives/predictive.py b/bofire/strategies/predictives/predictive.py index 09cad9416..da0417593 100644 --- a/bofire/strategies/predictives/predictive.py +++ b/bofire/strategies/predictives/predictive.py @@ -108,7 +108,7 @@ def predict(self, experiments: pd.DataFrame) -> pd.DataFrame: *[ [f"{feat.key}_pred"] if not isinstance(feat, CategoricalOutput) - else [f"{feat.key}_pred_{cat}" for cat in feat.categories] + else [f"{feat.key}_{cat}_prob" for cat in feat.categories] for feat in self.domain.outputs.get() ] ) @@ -122,7 +122,7 @@ def predict(self, experiments: pd.DataFrame) -> pd.DataFrame: *[ [f"{feat.key}_sd"] if not isinstance(feat, CategoricalOutput) - else [f"{feat.key}_sd_{cat}" for cat in feat.categories] + else [f"{feat.key}_{cat}_sd" for cat in feat.categories] for feat in self.domain.outputs.get() ] ) @@ -133,21 +133,26 @@ def predict(self, experiments: pd.DataFrame) -> pd.DataFrame: data=preds, columns=column_names, ) - categorical_preds = { - f"{feat.key}_pred": ( - ind, - predictions.filter(regex=f"{feat.key}_pred_") - .idxmax(1) - .str.replace(f"{feat.key}_pred_", "") - .values, - ) - for ind, feat in enumerate(self.domain.outputs.get()) - if isinstance(feat, CategoricalOutput) - } - for key in categorical_preds.keys(): - predictions.insert( - categorical_preds[key][0], key, categorical_preds[key][1] - ) + for feat in self.domain.outputs.get(): + if isinstance(feat, CategoricalOutput): + predictions.insert( + loc=0, + column=f"{feat.key}_pred", + value=predictions.filter(regex=f"{feat.key}(.*)_prob") + .idxmax(1) + .str.replace(f"{feat.key}_", "") + .str.replace("_prob", "") + .values, + ) + predictions.insert( + loc=1, + column=f"{feat.key}_sd", + value=predictions.filter(regex=f"{feat.key}(.*)_sd") + .pow(2.0) + .sum(1) + .pow(0.5) + .values, + ) desis = self.domain.outputs(predictions, predictions=True) predictions = pd.concat((predictions, desis), axis=1) predictions.index = experiments.index diff --git a/bofire/strategies/samplers/universal_constraint.py b/bofire/strategies/samplers/universal_constraint.py index f4990688a..fff8d524a 100644 --- a/bofire/strategies/samplers/universal_constraint.py +++ b/bofire/strategies/samplers/universal_constraint.py @@ -37,9 +37,7 @@ def _ask(self, candidate_count: int) -> pd.DataFrame: fixed_experiments=self.candidates, ) - samples = samples.iloc[ - self.num_candidates :, - ] + samples = samples.iloc[self.num_candidates :,] samples = samples.sample( n=candidate_count, replace=False, diff --git a/bofire/surrogates/api.py b/bofire/surrogates/api.py index aa317da08..3a47cd649 100644 --- a/bofire/surrogates/api.py +++ b/bofire/surrogates/api.py @@ -2,8 +2,11 @@ from bofire.surrogates.empirical import EmpiricalSurrogate from bofire.surrogates.mapper import map from bofire.surrogates.mixed_single_task_gp import MixedSingleTaskGPSurrogate -from bofire.surrogates.mlp import MLPEnsemble -from bofire.surrogates.mlp_classifier import MLPClassifierEnsemble +from bofire.surrogates.mlp import ( + ClassificationMLPEnsemble, + MLPEnsemble, + RegressionMLPEnsemble, +) from bofire.surrogates.random_forest import RandomForestSurrogate from bofire.surrogates.single_task_gp import SingleTaskGPSurrogate from bofire.surrogates.surrogate import Surrogate diff --git a/bofire/surrogates/diagnostics.py b/bofire/surrogates/diagnostics.py index eb55a6cb8..6e0935196 100644 --- a/bofire/surrogates/diagnostics.py +++ b/bofire/surrogates/diagnostics.py @@ -7,6 +7,8 @@ from scipy.integrate import simps from scipy.stats import fisher_exact, kendalltau, norm, pearsonr, spearmanr from sklearn.metrics import ( + accuracy_score, + f1_score, mean_absolute_error, mean_absolute_percentage_error, mean_squared_error, @@ -15,7 +17,49 @@ from bofire.data_models.base import BaseModel from bofire.data_models.domain.domain import is_numeric -from bofire.data_models.enum import RegressionMetricsEnum, UQRegressionMetricsEnum +from bofire.data_models.enum import ( + ClassificationMetricsEnum, + RegressionMetricsEnum, + UQRegressionMetricsEnum, +) + + +def _accuracy_score( + observed: np.ndarray, + predicted: np.ndarray, + standard_deviation: Optional[np.ndarray] = None, +) -> float: + """Calculates the standard accuracy score. + + Args: + observed (np.ndarray): Observed data. + predicted (np.ndarray): Predicted data. + standard_deviation (Optional[np.ndarray], optional): Predicted standard deviation. + Ignored in the calculation. Defaults to None. + + Returns: + float: Accuracy score. + """ + return float(accuracy_score(observed, predicted)) + + +def _f1_score( + observed: np.ndarray, + predicted: np.ndarray, + standard_deviation: Optional[np.ndarray] = None, +) -> float: + """Calculates the f1 accuracy score. + + Args: + observed (np.ndarray): Observed data. + predicted (np.ndarray): Predicted data. + standard_deviation (Optional[np.ndarray], optional): Predicted standard deviation. + Ignored in the calculation. Defaults to None. + + Returns: + float: Accuracy score. + """ + return float(f1_score(observed, predicted, average="micro")) def _mean_absolute_error( @@ -429,6 +473,8 @@ def _AbsoluteMiscalibrationArea( RegressionMetricsEnum.PEARSON: _pearson, RegressionMetricsEnum.SPEARMAN: _spearman, RegressionMetricsEnum.FISHER: _fisher_exact_test_p, + ClassificationMetricsEnum.ACCURACY: _accuracy_score, + ClassificationMetricsEnum.F1: _f1_score, } UQ_metrics = { @@ -511,7 +557,10 @@ def n_samples(self) -> int: return len(self.observed) def get_metric( - self, metric: Union[RegressionMetricsEnum, UQRegressionMetricsEnum] + self, + metric: Union[ + ClassificationMetricsEnum, RegressionMetricsEnum, UQRegressionMetricsEnum + ], ) -> float: """Calculates a metric for the fold. @@ -630,7 +679,9 @@ def _combine_folds(self) -> CvResult: def get_metric( self, - metric: Union[RegressionMetricsEnum, UQRegressionMetricsEnum], + metric: Union[ + ClassificationMetricsEnum, RegressionMetricsEnum, UQRegressionMetricsEnum + ], combine_folds: bool = True, ) -> pd.Series: """Calculates a metric for every fold and returns them as pd.Series. @@ -654,7 +705,13 @@ def get_metric( def get_metrics( self, - metrics: Sequence[Union[RegressionMetricsEnum, UQRegressionMetricsEnum]] = [ + metrics: Sequence[ + Union[ + ClassificationMetricsEnum, + RegressionMetricsEnum, + UQRegressionMetricsEnum, + ] + ] = [ RegressionMetricsEnum.MAE, RegressionMetricsEnum.MSD, RegressionMetricsEnum.R2, diff --git a/bofire/surrogates/mapper.py b/bofire/surrogates/mapper.py index 39dd8cedc..bacd95e2d 100644 --- a/bofire/surrogates/mapper.py +++ b/bofire/surrogates/mapper.py @@ -4,8 +4,7 @@ from bofire.surrogates.empirical import EmpiricalSurrogate from bofire.surrogates.fully_bayesian import SaasSingleTaskGPSurrogate from bofire.surrogates.mixed_single_task_gp import MixedSingleTaskGPSurrogate -from bofire.surrogates.mlp import MLPEnsemble -from bofire.surrogates.mlp_classifier import MLPClassifierEnsemble +from bofire.surrogates.mlp import ClassificationMLPEnsemble, RegressionMLPEnsemble from bofire.surrogates.random_forest import RandomForestSurrogate from bofire.surrogates.single_task_gp import SingleTaskGPSurrogate from bofire.surrogates.surrogate import Surrogate @@ -16,8 +15,8 @@ data_models.RandomForestSurrogate: RandomForestSurrogate, data_models.SingleTaskGPSurrogate: SingleTaskGPSurrogate, data_models.MixedSingleTaskGPSurrogate: MixedSingleTaskGPSurrogate, - data_models.MLPClassifierEnsemble: MLPClassifierEnsemble, - data_models.MLPEnsemble: MLPEnsemble, + data_models.RegressionMLPEnsemble: RegressionMLPEnsemble, + data_models.ClassificationMLPEnsemble: ClassificationMLPEnsemble, data_models.SaasSingleTaskGPSurrogate: SaasSingleTaskGPSurrogate, data_models.XGBoostSurrogate: XGBoostSurrogate, data_models.LinearSurrogate: SingleTaskGPSurrogate, diff --git a/bofire/surrogates/mlp.py b/bofire/surrogates/mlp.py index 2d4614aa5..9ca802206 100644 --- a/bofire/surrogates/mlp.py +++ b/bofire/surrogates/mlp.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from typing import Literal, Optional, Sequence import numpy as np @@ -10,7 +11,13 @@ from torch.utils.data import DataLoader, Dataset from bofire.data_models.enum import OutputFilteringEnum +from bofire.data_models.surrogates.api import ( + ClassificationMLPEnsemble as DataModelClassification, +) from bofire.data_models.surrogates.api import MLPEnsemble as DataModel +from bofire.data_models.surrogates.api import ( + RegressionMLPEnsemble as DataModelRegression, +) from bofire.data_models.surrogates.scaler import ScalerEnum from bofire.surrogates.botorch import BotorchSurrogate from bofire.surrogates.single_task_gp import get_scaler @@ -20,7 +27,7 @@ class MLPDataset(Dataset): """ - Prepare the dataset for regression + Prepare the dataset for MLP training """ def __init__(self, X: Tensor, y: Tensor): @@ -42,6 +49,7 @@ def __init__( hidden_layer_sizes: Sequence = (100,), dropout: float = 0.0, activation: Literal["relu", "logistic", "tanh"] = "relu", + final_activation: Literal["softmax", "identity"] = "identity", ): super().__init__() if activation == "relu": @@ -69,6 +77,14 @@ def __init__( if dropout > 0.0: layers.append(nn.Dropout(dropout)) layers.append(nn.Linear(hidden_layer_sizes[-1], output_size).to(**tkwargs)) + if final_activation == "softmax": + layers.append(nn.Softmax(dim=-1)) + elif final_activation == "identity": + layers.append(nn.Identity()) + else: + raise ValueError( + f"Currently only serving classification and regression problems; {final_activation} is not known." + ) self.layers = nn.Sequential(*layers) def forward(self, x): @@ -83,10 +99,10 @@ def __init__( if len(mlps) == 0: raise ValueError("List of mlps is empty.") num_in_features = mlps[0].layers[0].in_features - num_out_features = mlps[0].layers[-1].out_features + num_out_features = mlps[0].layers[-2].out_features for mlp in mlps: assert mlp.layers[0].in_features == num_in_features - assert mlp.layers[-1].out_features == num_out_features + assert mlp.layers[-2].out_features == num_out_features self.mlps = mlps if output_scaler is not None: self.outcome_transform = output_scaler @@ -109,7 +125,7 @@ def forward(self, X: Tensor): @property def num_outputs(self) -> int: r"""The number of outputs of the model.""" - return self.mlps[0].layers[-1].out_features # type: ignore + return self.mlps[0].layers[-2].out_features # type: ignore def fit_mlp( @@ -132,6 +148,7 @@ def fit_mlp( lr (float, optional): Initial learning rate. Defaults to 1e-4. shuffle (bool, optional): Whereas the batches should be shuffled. Defaults to True. weight_decay (float, optional): Weight decay (L2 regularization). Defaults to 0.0 (no regularization). + loss_function (Module, optional): Loss function specified by the problem type. Defaults to L1 loss for regression problems. """ mlp.train() train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle) @@ -189,6 +206,16 @@ def __init__(self, data_model: DataModel, **kwargs): _output_filtering: OutputFilteringEnum = OutputFilteringEnum.ALL model: Optional[_MLPEnsemble] = None + @abstractmethod + def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): + pass + + +class RegressionMLPEnsemble(MLPEnsemble): + def __init__(self, data_model: DataModelRegression, **kwargs): + self.final_activation = "identity" + super().__init__(data_model, **kwargs) + def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): scaler = get_scaler(self.inputs, self.input_preprocessing_specs, self.scaler, X) transformed_X = self.inputs.transform(X, self.input_preprocessing_specs) @@ -216,6 +243,7 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): hidden_layer_sizes=self.hidden_layer_sizes, activation=self.activation, # type: ignore dropout=self.dropout, + final_activation="identity", ) fit_mlp( mlp=mlp, @@ -225,8 +253,63 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): lr=self.lr, shuffle=self.shuffle, weight_decay=self.weight_decay, + loss_function=nn.L1Loss, ) mlps.append(mlp) self.model = _MLPEnsemble(mlps, output_scaler=output_scaler) if scaler is not None: self.model.input_transform = scaler + + +class ClassificationMLPEnsemble(MLPEnsemble): + def __init__(self, data_model: DataModelClassification, **kwargs): + self.final_activation = "softmax" + super().__init__(data_model, **kwargs) + + def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): + scaler = get_scaler(self.inputs, self.input_preprocessing_specs, self.scaler, X) + transformed_X = self.inputs.transform(X, self.input_preprocessing_specs) + # Map dictionary of objective values to labels + label_mapping = self.outputs[0].objective.to_dict_label() + + # Convert Y to classification tensor + Y = pd.DataFrame.from_dict( + {col: Y[col].map(label_mapping) for col in Y.columns} + ) + + mlps = [] + subsample_size = round(self.subsample_fraction * X.shape[0]) + for _ in range(self.n_estimators): + # resample X and Y + sample_idx = np.random.choice(X.shape[0], replace=True, size=subsample_size) + tX = torch.from_numpy(transformed_X.values[sample_idx]).to(**tkwargs) + ty = torch.from_numpy(Y.values[sample_idx]).to(**tkwargs) + + dataset = MLPDataset( + X=scaler.transform(tX) if scaler is not None else tX, + y=ty, + ) + mlp = MLP( + input_size=transformed_X.shape[1], + output_size=len( + label_mapping + ), # Set outputs based on number of categories + hidden_layer_sizes=self.hidden_layer_sizes, + activation=self.activation, # type: ignore + dropout=self.dropout, + final_activation="softmax", + ) + fit_mlp( + mlp=mlp, + dataset=dataset, + batch_size=self.batch_size, + n_epoches=self.n_epochs, + lr=self.lr, + shuffle=self.shuffle, + weight_decay=self.weight_decay, + loss_function=nn.CrossEntropyLoss, # utilizes logits as input + ) + mlps.append(mlp) + self.model = _MLPEnsemble(mlps=mlps) + if scaler is not None: + self.model.input_transform = scaler diff --git a/bofire/surrogates/mlp_classifier.py b/bofire/surrogates/mlp_classifier.py deleted file mode 100644 index dcbb12230..000000000 --- a/bofire/surrogates/mlp_classifier.py +++ /dev/null @@ -1,305 +0,0 @@ -from typing import Optional, Sequence - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -from botorch.models.ensemble import EnsembleModel -from torch import Tensor - -from bofire.data_models.enum import OutputFilteringEnum -from bofire.data_models.surrogates.api import MLPEnsemble as DataModel -from bofire.surrogates.botorch import BotorchSurrogate -from bofire.surrogates.mlp import MLP, MLPDataset, fit_mlp -from bofire.surrogates.single_task_gp import get_scaler -from bofire.surrogates.trainable import TrainableSurrogate -from bofire.utils.torch_tools import tkwargs - - -class _MLPClassifierEnsemble(EnsembleModel): - def __init__(self, mlps: Sequence[MLP]): - super().__init__() - if len(mlps) == 0: - raise ValueError("List of mlps is empty.") - num_in_features = mlps[0].layers[0].in_features - num_out_features = mlps[0].layers[-1].out_features - for mlp in mlps: - assert mlp.layers[0].in_features == num_in_features - assert mlp.layers[-1].out_features == num_out_features - self.mlps = mlps - # put all models in eval mode - for mlp in self.mlps: - mlp.eval() - - def forward(self, X: Tensor): - r"""Assumes that the OUTPUT of the MLPs are the logits and hence we take the softmax over the last dimension here - - Args: - X: A `batch_shape x n x d`-dim input tensor `X`. - - Returns: - A `batch_shape x s x n x C`-dimensional output tensor where - `s` is the size of the ensemble and `C` is the number of classes. - """ - return torch.stack( - [nn.functional.softmax(mlp(X), dim=-1) for mlp in self.mlps], dim=-3 - ) - - @property - def num_outputs(self) -> int: - r"""The number of outputs of the model.""" - return self.mlps[0].layers[-1].out_features # type: ignore - - -class MLPClassifierEnsemble(BotorchSurrogate, TrainableSurrogate): - def __init__(self, data_model: DataModel, **kwargs): - self.n_estimators = data_model.n_estimators - self.hidden_layer_sizes = data_model.hidden_layer_sizes - self.activation = data_model.activation - self.dropout = data_model.dropout - self.batch_size = data_model.batch_size - self.n_epochs = data_model.n_epochs - self.lr = data_model.lr - self.weight_decay = data_model.weight_decay - self.subsample_fraction = data_model.subsample_fraction - self.shuffle = data_model.shuffle - self.scaler = data_model.scaler - super().__init__(data_model, **kwargs) - - _output_filtering: OutputFilteringEnum = OutputFilteringEnum.ALL - model: Optional[_MLPClassifierEnsemble] = None - - def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): - scaler = get_scaler(self.inputs, self.input_preprocessing_specs, self.scaler, X) - transformed_X = self.inputs.transform(X, self.input_preprocessing_specs) - # Map dictionary of objective values to labels - label_mapping = self.outputs[0].objective.to_dict_label() - - # Convert Y to classification tensor - Y = pd.DataFrame.from_dict( - {col: Y[col].map(label_mapping) for col in Y.columns} - ) - - mlps = [] - subsample_size = round(self.subsample_fraction * X.shape[0]) - for _ in range(self.n_estimators): - # resample X and Y - sample_idx = np.random.choice(X.shape[0], replace=True, size=subsample_size) - tX = torch.from_numpy(transformed_X.values[sample_idx]).to(**tkwargs) - ty = torch.from_numpy(Y.values[sample_idx]).to(**tkwargs) - - dataset = MLPDataset( - X=scaler.transform(tX) if scaler is not None else tX, - y=ty, - ) - mlp = MLP( - input_size=transformed_X.shape[1], - output_size=len( - label_mapping - ), # Set outputs based on number of categories - hidden_layer_sizes=self.hidden_layer_sizes, - activation=self.activation, # type: ignore - dropout=self.dropout, - ) - fit_mlp( - mlp=mlp, - dataset=dataset, - batch_size=self.batch_size, - n_epoches=self.n_epochs, - lr=self.lr, - shuffle=self.shuffle, - weight_decay=self.weight_decay, - loss_function=nn.CrossEntropyLoss, # utilizes logits as input - ) - mlps.append(mlp) - self.model = _MLPClassifierEnsemble(mlps=mlps) - if scaler is not None: - self.model.input_transform = scaler - - -# class MLPClassifier(nn.Module): -# def __init__( -# self, -# input_size: int, -# output_size: int = 1, -# hidden_layer_sizes: Sequence = (100,), -# dropout: float = 0.0, -# activation: Literal["relu", "logistic", "tanh"] = "relu", -# ): -# super().__init__() -# if activation == "relu": -# f_activation = nn.ReLU -# elif activation == "logistic": -# f_activation = nn.Sigmoid -# elif activation == "tanh": -# f_activation = nn.Tanh -# else: -# raise ValueError(f"Activation {activation} not known.") -# layers = [ -# nn.Linear(input_size, hidden_layer_sizes[0]).to(**tkwargs), -# f_activation(), -# ] -# if dropout > 0.0: -# layers.append(nn.Dropout(dropout)) -# if len(hidden_layer_sizes) > 1: -# for i in range(len(hidden_layer_sizes) - 1): -# layers += [ -# nn.Linear(hidden_layer_sizes[i], hidden_layer_sizes[i + 1]).to( -# **tkwargs -# ), -# f_activation(), -# ] -# if dropout > 0.0: -# layers.append(nn.Dropout(dropout)) -# layers.append(nn.Linear(hidden_layer_sizes[-1], output_size).to(**tkwargs)) -# self.layers = nn.Sequential(*layers) - -# def forward(self, x): -# return nn.functional.softmax(self.layers(x), dim=1) - - -# class _MLPClassifierEnsemble(EnsembleModel): -# def __init__(self, mlps: Sequence[MLPClassifier]): -# super().__init__() -# if len(mlps) == 0: -# raise ValueError("List of mlps is empty.") -# num_in_features = mlps[0].layers[0].in_features -# num_out_features = mlps[0].layers[-1].out_features -# for mlp in mlps: -# assert mlp.layers[0].in_features == num_in_features -# assert mlp.layers[-1].out_features == num_out_features -# self.mlps = mlps -# # put all models in eval mode -# for mlp in self.mlps: -# mlp.eval() - -# def forward(self, X: Tensor): -# r"""Compute the model output at X. - -# Args: -# X: A `batch_shape x n x d`-dim input tensor `X`. - -# Returns: -# A `batch_shape x s x n x C`-dimensional output tensor where -# `s` is the size of the ensemble and `C` is the number of classes. -# """ -# return torch.stack([mlp(X).exp() for mlp in self.mlps], dim=-3) - -# @property -# def num_outputs(self) -> int: -# r"""The number of outputs of the model.""" -# return self.mlps[0].layers[-1].out_features # type: ignore - - -# def fit_mlp( -# mlp: MLPClassifier, -# dataset: ClassificationDataSet, -# batch_size: int = 10, -# n_epoches: int = 200, -# lr: float = 1e-3, -# shuffle: bool = True, -# weight_decay: float = 0.0, -# ): -# """Fit a MLP for classification to a dataset. - -# Args: -# mlp (MLP): The MLP that should be fitted. -# dataset (ClassificationDataSet): The data that should be fitted -# batch_size (int, optional): Batch size. Defaults to 10. -# n_epoches (int, optional): Number of training epoches. Defaults to 200. -# lr (float, optional): Initial learning rate. Defaults to 1e-4. -# shuffle (bool, optional): Whereas the batches should be shuffled. Defaults to True. -# weight_decay (float, optional): Weight decay (L2 regularization). Defaults to 0.0 (no regularization). -# """ -# mlp.train() -# train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle) -# loss_function = nn.NLLLoss(reduction="mean") -# optimizer = torch.optim.Adam(mlp.parameters(), lr=lr, weight_decay=weight_decay) -# for _ in range(n_epoches): -# current_loss = 0.0 -# for data in train_loader: -# # Get and prepare inputs -# inputs, targets = data - -# # Zero the gradients -# optimizer.zero_grad() - -# # Perform forward pass -# outputs = mlp(inputs) - -# # Compute loss -# loss = loss_function(outputs, targets.flatten().long()) - -# # Perform backward pass -# loss.backward() - -# # Perform optimization -# optimizer.step() - -# # Print statistics -# current_loss += loss.item() - - -# class MLPClassifierEnsemble(BotorchSurrogate, TrainableSurrogate): -# def __init__(self, data_model: DataModel, **kwargs): -# self.n_estimators = data_model.n_estimators -# self.hidden_layer_sizes = data_model.hidden_layer_sizes -# self.activation = data_model.activation -# self.dropout = data_model.dropout -# self.batch_size = data_model.batch_size -# self.n_epochs = data_model.n_epochs -# self.lr = data_model.lr -# self.weight_decay = data_model.weight_decay -# self.subsample_fraction = data_model.subsample_fraction -# self.shuffle = data_model.shuffle -# self.scaler = data_model.scaler -# super().__init__(data_model, **kwargs) - -# _output_filtering: OutputFilteringEnum = OutputFilteringEnum.ALL -# model: Optional[_MLPClassifierEnsemble] = None - -# def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): -# scaler = get_scaler(self.inputs, self.input_preprocessing_specs, self.scaler, X) -# transformed_X = self.inputs.transform(X, self.input_preprocessing_specs) -# # Map dictionary to objective values - gives what is feasible and to labels - to perform opt -# label_mapping = self.outputs[0].to_dict_label() - -# # Convert Y to classification tensor -# Y = pd.DataFrame.from_dict( -# {col: Y[col].map(label_mapping) for col in Y.columns} -# ) - -# mlps = [] -# subsample_size = round(self.subsample_fraction * X.shape[0]) -# for _ in range(self.n_estimators): -# # resample X and Y -# sample_idx = np.random.choice(X.shape[0], replace=True, size=subsample_size) -# tX = torch.from_numpy(transformed_X.values[sample_idx]).to(**tkwargs) -# ty = torch.from_numpy(Y.values[sample_idx]).to(**tkwargs) - -# dataset = ClassificationDataSet( -# X=scaler.transform(tX) if scaler is not None else tX, -# y=ty, -# ) -# mlp = MLPClassifier( -# input_size=transformed_X.shape[1], -# output_size=len( -# label_mapping -# ), # Set outputs based on number of categories -# hidden_layer_sizes=self.hidden_layer_sizes, -# activation=self.activation, # type: ignore -# dropout=self.dropout, -# ) -# fit_mlp( -# mlp=mlp, -# dataset=dataset, -# batch_size=self.batch_size, -# n_epoches=self.n_epochs, -# lr=self.lr, -# shuffle=self.shuffle, -# weight_decay=self.weight_decay, -# ) -# mlps.append(mlp) -# self.model = _MLPClassifierEnsemble(mlps=mlps) -# if scaler is not None: -# self.model.input_transform = scaler diff --git a/bofire/surrogates/surrogate.py b/bofire/surrogates/surrogate.py index 5a61acbc3..5bfce96c9 100644 --- a/bofire/surrogates/surrogate.py +++ b/bofire/surrogates/surrogate.py @@ -5,6 +5,7 @@ import pandas as pd from bofire.data_models.domain.domain import is_numeric +from bofire.data_models.features.api import CategoricalOutput from bofire.data_models.surrogates.api import Surrogate as DataModel from bofire.surrogates.values import PredictedValue @@ -44,27 +45,88 @@ def predict(self, X: pd.DataFrame) -> pd.DataFrame: Xt[c] = pd.to_numeric(Xt[c], errors="raise") # predict preds, stds = self._predict(Xt) + # set up column names + columns = [] + for featkey in self.outputs.get_keys(): + if isinstance(self.outputs.get_by_key(featkey), CategoricalOutput): + columns = ( + columns + + [ + f"{featkey}_{cat}_prob" + for cat in self.outputs.get_by_key(featkey).categories + ] + + [ + f"{featkey}_{cat}_sd" + for cat in self.outputs.get_by_key(featkey).categories + ] + ) + else: + columns = ( + columns + + [f"{featkey}_pred" for featkey in self.outputs.get_keys()] + + [f"{featkey}_sd" for featkey in self.outputs.get_keys()] + ) # postprocess predictions = pd.DataFrame( data=np.hstack((preds, stds)), - columns=["%s_pred" % featkey for featkey in self.outputs.get_keys()] - + ["%s_sd" % featkey for featkey in self.outputs.get_keys()], + columns=columns, ) + # append predictions for categorical cases + for feat in self.outputs.get(): + if isinstance(feat, CategoricalOutput): + predictions.insert( + loc=0, + column=f"{feat.key}_pred", + value=predictions.filter(regex=f"{feat.key}(.*)_prob") + .idxmax(1) + .str.replace(f"{feat.key}_", "") + .str.replace("_prob", "") + .values, + ) + predictions.insert( + loc=1, + column=f"{feat.key}_sd", + value=predictions.filter(regex=f"{feat.key}(.*)_sd") + .pow(2.0) + .sum(1) + .pow(0.5) + .values, + ) # validate self.validate_predictions(predictions=predictions) # return return predictions def validate_predictions(self, predictions: pd.DataFrame) -> pd.DataFrame: - expected_cols = [ - f"{key}_{t}" for key in self.outputs.get_keys() for t in ["pred", "sd"] - ] + expected_cols = [] + for featkey in self.outputs.get_keys(): + if isinstance(self.outputs.get_by_key(featkey), CategoricalOutput): + expected_cols = ( + expected_cols + + [f"{featkey}_{t}" for t in ["pred", "sd"]] + + [ + f"{featkey}_{cat}_prob" + for cat in self.outputs.get_by_key(featkey).categories + ] + + [ + f"{featkey}_{cat}_sd" + for cat in self.outputs.get_by_key(featkey).categories + ] + ) + check_columns = [ + col for col in expected_cols if col != f"{featkey}_pred" + ] + else: + expected_cols = expected_cols + [ + f"{featkey}_{t}" for t in ["pred", "sd"] + ] + check_columns = expected_cols if sorted(predictions.columns) != sorted(expected_cols): raise ValueError( f"Predictions are ill-formatted. Expected: {expected_cols}, got: {list(predictions.columns)}." ) # check that values are numeric - if not is_numeric(predictions): + if not is_numeric(predictions[check_columns]): raise ValueError("Not all values in predictions are numeric.") return predictions diff --git a/bofire/surrogates/trainable.py b/bofire/surrogates/trainable.py index 62caeed0c..794745c8a 100644 --- a/bofire/surrogates/trainable.py +++ b/bofire/surrogates/trainable.py @@ -12,6 +12,7 @@ ContinuousOutput, DiscreteInput, ) +from bofire.data_models.objectives.api import CategoricalObjective from bofire.surrogates.diagnostics import CvResult, CvResults from bofire.surrogates.surrogate import Surrogate @@ -181,6 +182,20 @@ def cross_validate( # now do the scoring y_test_pred = self.predict(X_test) # type: ignore y_train_pred = self.predict(X_train) # type: ignore + + # Convert to categorical if applicable + if isinstance(self.outputs[0].objective, CategoricalObjective): + y_test_pred[f"{key}_pred"] = y_test_pred[f"{key}_pred"].map( + self.outputs[0].objective.to_dict_label() + ) + y_train_pred[f"{key}_pred"] = y_train_pred[f"{key}_pred"].map( + self.outputs[0].objective.to_dict_label() + ) + y_test[key] = y_test[key].map(self.outputs[0].objective.to_dict_label()) + y_train[key] = y_train[key].map( + self.outputs[0].objective.to_dict_label() + ) + # now store the results train_results.append( CvResult( # type: ignore diff --git a/bofire/surrogates/values.py b/bofire/surrogates/values.py index 70845fd75..989f3e042 100644 --- a/bofire/surrogates/values.py +++ b/bofire/surrogates/values.py @@ -1,3 +1,5 @@ +from typing import Union + from pydantic import Field from typing_extensions import Annotated @@ -15,5 +17,5 @@ class PredictedValue(BaseModel): Has to be greater/equal than zero. """ - predictedValue: float + predictedValue: Union[float, str] standardDeviation: Annotated[float, Field(ge=0)] diff --git a/tests/bofire/data_models/domain/test_outputs.py b/tests/bofire/data_models/domain/test_outputs.py index 6f1b2456b..2c639e0bb 100644 --- a/tests/bofire/data_models/domain/test_outputs.py +++ b/tests/bofire/data_models/domain/test_outputs.py @@ -258,7 +258,7 @@ def test_get_outputs_by_objective_none(): key="of4", categories=["a", "b"], objective=ConstrainedCategoricalObjective( - categories=("a", "b"), desirability=(True, False) + categories=["a", "b"], desirability=[True, False] ), ), ] @@ -285,3 +285,37 @@ def test_outputs_call(features, samples): ) + features.get_keys(CategoricalOutput) ] + + +def test_categorical_objective_methods(): + obj = ConstrainedCategoricalObjective( + categories=["a", "b"], desirability=[True, False] + ) + assert {"a": True, "b": False} == obj.to_dict() + assert {"a": 0, "b": 1} == obj.to_dict_label() + assert {0: "a", 1: "b"} == obj.from_dict_label() + + +def test_categorical_output_methods(): + outputs = Outputs( + features=[ + of1, + of2, + of3, + CategoricalOutput( + key="of4", + categories=["a", "b"], + objective=ConstrainedCategoricalObjective( + categories=["a", "b"], desirability=[True, False] + ), + ), + ] + ) + + # Test the `get_keys_by_objective` + assert ["of1", "of2"] == outputs.get_keys_by_objective( + includes=Objective, excludes=ConstrainedObjective + ) + assert ["of4"] == outputs.get_keys_by_objective( + includes=ConstrainedObjective, excludes=None + ) diff --git a/tests/bofire/data_models/specs/features.py b/tests/bofire/data_models/specs/features.py index 8c6615d4a..fb13f5936 100644 --- a/tests/bofire/data_models/specs/features.py +++ b/tests/bofire/data_models/specs/features.py @@ -125,9 +125,9 @@ features.CategoricalOutput, lambda: { "key": str(uuid.uuid4()), - "categories": ("a", "b", "c"), + "categories": ["a", "b", "c"], "objective": ConstrainedCategoricalObjective( - categories=("a", "b", "c"), desirability=(True, True, False) + categories=["a", "b", "c"], desirability=[True, True, False] ), }, ) @@ -151,11 +151,3 @@ "allowed": [True, True, True, True], }, ) -specs.add_valid( - features.CategoricalOutput, - lambda: { - "key": str(uuid.uuid4()), - "categories": ["a", "b", "c"], - "objective": [0.0, 1.0, 0.0], - }, -) diff --git a/tests/bofire/data_models/specs/objectives.py b/tests/bofire/data_models/specs/objectives.py index 6812eaf66..85f306bec 100644 --- a/tests/bofire/data_models/specs/objectives.py +++ b/tests/bofire/data_models/specs/objectives.py @@ -46,3 +46,24 @@ "steepness": 0.3, }, ) + +specs.add_valid( + objectives.ConstrainedCategoricalObjective, + lambda: { + "w": 1.0, + "categories": ["green", "red", "blue"], + "desirability": [True, False, True], + "eta": 1.0, + }, +) +specs.add_invalid( + objectives.ConstrainedCategoricalObjective, + lambda: { + "w": 1.0, + "categories": ["green", "red", "blue"], + "desirability": [True, False, True, False], + "eta": 1.0, + }, + error=ValueError, + message="number of categories differs from number of desirabilities", +) diff --git a/tests/bofire/data_models/specs/surrogates.py b/tests/bofire/data_models/specs/surrogates.py index 047e7fb17..71000122b 100644 --- a/tests/bofire/data_models/specs/surrogates.py +++ b/tests/bofire/data_models/specs/surrogates.py @@ -5,6 +5,7 @@ from bofire.data_models.enum import CategoricalEncodingEnum from bofire.data_models.features.api import ( CategoricalInput, + CategoricalOutput, ContinuousInput, ContinuousOutput, MolecularInput, @@ -158,7 +159,7 @@ }, ) specs.add_valid( - models.MLPEnsemble, + models.RegressionMLPEnsemble, lambda: { "inputs": Inputs( features=[ @@ -188,6 +189,104 @@ "hyperconfig": None, }, ) +specs.add_invalid( + models.RegressionMLPEnsemble, + lambda: { + "inputs": Inputs( + features=[ + features.valid(ContinuousInput).obj(), + ] + ).model_dump(), + "outputs": Outputs( + features=[ + features.valid(CategoricalOutput).obj(), + ] + ).model_dump(), + "aggregations": None, + "n_estimators": 2, + "hidden_layer_sizes": (100,), + "activation": "relu", + "dropout": 0.0, + "batch_size": 10, + "n_epochs": 200, + "lr": 1e-4, + "weight_decay": 0.0, + "subsample_fraction": 1.0, + "shuffle": True, + "scaler": ScalerEnum.NORMALIZE, + "output_scaler": ScalerEnum.STANDARDIZE, + "input_preprocessing_specs": {}, + "dump": None, + "hyperconfig": None, + }, + error=ValueError, + message="Invalid output type passed.", +) + +specs.add_valid( + models.ClassificationMLPEnsemble, + lambda: { + "inputs": Inputs( + features=[ + features.valid(ContinuousInput).obj(), + ] + ).model_dump(), + "outputs": Outputs( + features=[ + features.valid(CategoricalOutput).obj(), + ] + ).model_dump(), + "aggregations": None, + "n_estimators": 2, + "hidden_layer_sizes": (100,), + "activation": "relu", + "dropout": 0.0, + "batch_size": 10, + "n_epochs": 200, + "lr": 1e-4, + "weight_decay": 0.0, + "subsample_fraction": 1.0, + "shuffle": True, + "scaler": ScalerEnum.NORMALIZE, + "output_scaler": ScalerEnum.STANDARDIZE, + "input_preprocessing_specs": {}, + "dump": None, + "hyperconfig": None, + }, +) +specs.add_invalid( + models.ClassificationMLPEnsemble, + lambda: { + "inputs": Inputs( + features=[ + features.valid(ContinuousInput).obj(), + ] + ).model_dump(), + "outputs": Outputs( + features=[ + features.valid(ContinuousOutput).obj(), + ] + ).model_dump(), + "aggregations": None, + "n_estimators": 2, + "hidden_layer_sizes": (100,), + "activation": "relu", + "dropout": 0.0, + "batch_size": 10, + "n_epochs": 200, + "lr": 1e-4, + "weight_decay": 0.0, + "subsample_fraction": 1.0, + "shuffle": True, + "scaler": ScalerEnum.NORMALIZE, + "output_scaler": ScalerEnum.STANDARDIZE, + "input_preprocessing_specs": {}, + "dump": None, + "hyperconfig": None, + }, + error=ValueError, + message="Invalid output type passed.", +) specs.add_valid( models.XGBoostSurrogate, diff --git a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb index 947549c37..8f8e203df 100644 --- a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb +++ b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb @@ -10,7 +10,9 @@ "\n", "This involves new models that produce `CategoricalOutput`'s rather than continuous outputs. Mathematically, if $g_{\\theta}:\\mathbb{R}^d\\to[0,1]^c$ represents the function governed by learnable parameters $\\theta$ which outputs a probability vector over $c$ potential classes (i.e. for input $x\\in\\mathbb{R}^d$, $g_{\\theta}(x)^\\top\\mathbf{1}=1$ where $\\mathbf{1}$ is the vector of all 1's) and we have acceptibility criteria for the corresponding classes given by $a\\in\\{0,1\\}^c$, we can compute the scalar output $g_{\\theta}(x)^\\top a\\in[0,1]$ which represents the expected value of acceptance as an objective value to be passed in as a constrained function.\n", "\n", - "In this script, we look at a modified and constrained version of the optimization problem associated with the [Levy function](https://www.sfu.ca/~ssurjano/levy.html), which has a global minima at $x^*=\\mathbf{1}$. We classify constraints for three classes: 'acceptable', 'unacceptable', and 'ideal' based on how close we are to the optimal decision variable; obviously, this value is unknown in a real-world setting, but this serves as a reasonable example." + "In this script, we look at a modified and constrained version of the optimization problem associated with the [Levy function](https://www.sfu.ca/~ssurjano/levy.html), which has a global minima at $x^*=\\mathbf{1}$. We classify constraints for three classes: 'acceptable', 'unacceptable', and 'ideal' based on how close we are to the optimal decision variable; obviously, this value is unknown in a real-world setting, but this serves as a reasonable example.\n", + "\n", + "Initially, this script contains an example of JUST training the classification surrogate on the generated data." ] }, { @@ -23,7 +25,9 @@ "output_type": "stream", "text": [ "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" + " from .autonotebook import tqdm as notebook_tqdm\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\pydantic\\_migration.py:283: UserWarning: `pydantic.error_wrappers:ValidationError` has been moved to `pydantic:ValidationError`.\n", + " warnings.warn(f'`{import_path}` has been moved to `{new_location}`.')\n" ] } ], @@ -98,292 +102,292 @@ " \n", " \n", " 0\n", - " 0.043387\n", - " 0.471935\n", - " -1.903405\n", - " -1.505071\n", - " -1.631395\n", - " 0.0\n", - " 12.885008\n", - " unacceptable\n", - " 0.049104\n", + " -0.310698\n", + " 0.392037\n", + " -1.806570\n", + " -1.981755\n", + " 1.384988\n", + " 0\n", + " 11.628122\n", + " acceptable\n", + " -0.310070\n", " \n", " \n", " 1\n", - " -0.516480\n", - " -0.519954\n", - " 0.172009\n", - " -0.671419\n", - " 0.576339\n", - " 1.0\n", - " 1.743262\n", - " unacceptable\n", - " -0.511997\n", + " -1.489428\n", + " -1.288171\n", + " -0.505947\n", + " -0.384544\n", + " 0.425808\n", + " 1\n", + " 6.224438\n", + " acceptable\n", + " -1.485448\n", " \n", " \n", " 2\n", - " 1.420828\n", - " -0.997470\n", - " -0.869831\n", - " -0.603185\n", - " 0.988815\n", - " 0.0\n", - " 2.214841\n", - " unacceptable\n", - " 1.426101\n", + " 1.461134\n", + " 0.337194\n", + " -0.082360\n", + " -1.730603\n", + " -1.588481\n", + " 0\n", + " 8.278724\n", + " ideal\n", + " 1.467232\n", " \n", " \n", " 3\n", - " -1.729386\n", - " -0.517152\n", - " 0.792940\n", - " 1.841196\n", - " -1.018597\n", - " 1.0\n", - " 6.757267\n", + " -1.781506\n", + " -1.436844\n", + " 1.930972\n", + " 1.783539\n", + " -0.900358\n", + " 1\n", + " 9.733454\n", " unacceptable\n", - " -1.723294\n", + " -1.779890\n", " \n", " \n", " 4\n", - " 0.137756\n", - " -1.910416\n", - " 0.266167\n", - " 1.376514\n", - " 1.853052\n", - " 0.0\n", - " 6.659112\n", + " 0.665722\n", + " -1.744514\n", + " 0.446853\n", + " 0.700885\n", + " 0.651237\n", + " 1\n", + " 4.665782\n", " acceptable\n", - " 0.146894\n", + " 0.668620\n", " \n", " \n", " 5\n", - " 0.615494\n", - " 0.638574\n", - " -0.846471\n", - " 0.486294\n", - " -1.844437\n", - " 0.0\n", - " 5.965467\n", - " unacceptable\n", - " 0.623373\n", + " -1.499912\n", + " -0.747193\n", + " 0.664655\n", + " -1.011433\n", + " -1.354635\n", + " 1\n", + " 7.671702\n", + " ideal\n", + " -1.491068\n", " \n", " \n", " 6\n", - " 0.577691\n", - " -1.316748\n", - " -0.598852\n", - " 1.499971\n", - " 0.435526\n", - " 1.0\n", - " 2.791735\n", - " acceptable\n", - " 0.586726\n", + " -1.543055\n", + " 1.962503\n", + " 1.798311\n", + " 0.168173\n", + " -1.495112\n", + " 1\n", + " 8.231485\n", + " unacceptable\n", + " -1.542496\n", " \n", " \n", " 7\n", - " -1.382985\n", - " -1.629080\n", - " 0.957606\n", - " 1.263108\n", - " -1.248810\n", - " 1.0\n", - " 8.962357\n", - " unacceptable\n", - " -1.376337\n", + " -0.708194\n", + " 1.516163\n", + " -1.662763\n", + " 0.717943\n", + " -0.500210\n", + " 0\n", + " 5.660023\n", + " ideal\n", + " -0.703393\n", " \n", " \n", " 8\n", - " -0.732977\n", - " -1.635727\n", - " -1.165820\n", - " -1.912441\n", - " 1.095427\n", - " 1.0\n", - " 12.089185\n", - " unacceptable\n", - " -0.723748\n", + " -0.146424\n", + " -0.733847\n", + " -1.398632\n", + " -0.390983\n", + " -1.966037\n", + " 1\n", + " 9.502397\n", + " ideal\n", + " -0.142366\n", " \n", " \n", " 9\n", - " -1.324801\n", - " -1.119786\n", - " 1.021565\n", - " -0.304530\n", - " 0.425360\n", - " 1.0\n", - " 4.633603\n", - " unacceptable\n", - " -1.322161\n", + " -1.435747\n", + " -0.197231\n", + " -0.085655\n", + " 1.959668\n", + " -1.492488\n", + " 0\n", + " 7.361934\n", + " ideal\n", + " -1.429517\n", " \n", " \n", " 10\n", - " 0.957669\n", - " -1.272092\n", - " -0.461742\n", - " 1.557717\n", - " 1.954150\n", - " 0.0\n", - " 3.006800\n", + " -0.672001\n", + " 0.989061\n", + " 0.625807\n", + " 1.945203\n", + " -1.505085\n", + " 0\n", + " 4.971067\n", " acceptable\n", - " 0.965879\n", + " -0.671859\n", " \n", " \n", " 11\n", - " -1.609909\n", - " -1.972229\n", - " -0.950002\n", - " -1.645930\n", - " -0.068231\n", - " 0.0\n", - " 15.049384\n", - " unacceptable\n", - " -1.602951\n", + " -0.930565\n", + " -1.942693\n", + " 0.550933\n", + " -0.877697\n", + " -1.120742\n", + " 0\n", + " 9.486661\n", + " ideal\n", + " -0.930336\n", " \n", " \n", " 12\n", - " -0.872965\n", - " -0.263758\n", - " -1.418009\n", - " 1.434761\n", - " -0.156978\n", - " 1.0\n", - " 4.579632\n", + " -1.762717\n", + " 0.738706\n", + " -1.863998\n", + " -1.565166\n", + " 0.792053\n", + " 1\n", + " 13.751706\n", " unacceptable\n", - " -0.865004\n", + " -1.757165\n", " \n", " \n", " 13\n", - " -1.650395\n", - " -1.291974\n", - " -1.693253\n", - " 1.404619\n", - " -0.617787\n", - " 1.0\n", - " 11.135251\n", - " unacceptable\n", - " -1.641277\n", + " -1.863466\n", + " 0.404012\n", + " 0.844637\n", + " -1.966831\n", + " -0.169374\n", + " 0\n", + " 11.650621\n", + " ideal\n", + " -1.857675\n", " \n", " \n", " 14\n", - " -0.122850\n", - " 0.062456\n", - " -1.093654\n", - " 0.762955\n", - " -0.608164\n", - " 1.0\n", - " 2.333336\n", - " unacceptable\n", - " -0.116326\n", + " -0.981088\n", + " 1.761401\n", + " -0.346133\n", + " -0.842548\n", + " -0.307780\n", + " 0\n", + " 3.160072\n", + " acceptable\n", + " -0.972805\n", " \n", " \n", " 15\n", - " -0.343801\n", - " 1.969565\n", - " -0.400887\n", - " 1.546484\n", - " 0.955387\n", - " 0.0\n", - " 1.926944\n", + " 0.703509\n", + " -1.285697\n", + " 0.011635\n", + " 1.999545\n", + " 0.608979\n", + " 0\n", + " 2.874619\n", " ideal\n", - " -0.335117\n", + " 0.710066\n", " \n", " \n", " 16\n", - " -1.023179\n", - " -0.506223\n", - " -1.442436\n", - " 1.881762\n", - " 0.702930\n", - " 0.0\n", - " 5.610748\n", - " unacceptable\n", - " -1.015035\n", + " -0.262847\n", + " -0.670558\n", + " -1.676711\n", + " 1.346576\n", + " 0.734364\n", + " 1\n", + " 5.254699\n", + " ideal\n", + " -0.255986\n", " \n", " \n", " 17\n", - " 0.806867\n", - " 0.223890\n", - " 1.955296\n", - " 1.470794\n", - " 1.646253\n", - " 0.0\n", - " 1.263481\n", - " ideal\n", - " 0.814321\n", + " 1.949815\n", + " -0.132073\n", + " -1.921915\n", + " 1.955173\n", + " -0.510576\n", + " 0\n", + " 7.427848\n", + " acceptable\n", + " 1.951420\n", " \n", " \n", " 18\n", - " -0.063949\n", - " -1.605616\n", - " -1.694294\n", - " 1.014650\n", - " -0.542840\n", - " 0.0\n", - " 8.664851\n", - " unacceptable\n", - " -0.063911\n", + " 1.352886\n", + " -0.307143\n", + " 0.472704\n", + " -1.201528\n", + " -1.226323\n", + " 0\n", + " 3.723930\n", + " ideal\n", + " 1.356302\n", " \n", " \n", " 19\n", - " -0.791231\n", - " 0.269571\n", - " -0.625560\n", - " 1.688765\n", - " 0.347797\n", - " 1.0\n", - " 2.277518\n", + " -1.093545\n", + " 1.693383\n", + " 0.291218\n", + " -0.979725\n", + " -1.454517\n", + " 0\n", + " 6.389192\n", " acceptable\n", - " -0.788239\n", + " -1.092528\n", " \n", " \n", "\n", "" ], "text/plain": [ - " x_0 x_1 x_2 x_3 x_4 x_5 f_0 \\\n", - "0 0.043387 0.471935 -1.903405 -1.505071 -1.631395 0.0 12.885008 \n", - "1 -0.516480 -0.519954 0.172009 -0.671419 0.576339 1.0 1.743262 \n", - "2 1.420828 -0.997470 -0.869831 -0.603185 0.988815 0.0 2.214841 \n", - "3 -1.729386 -0.517152 0.792940 1.841196 -1.018597 1.0 6.757267 \n", - "4 0.137756 -1.910416 0.266167 1.376514 1.853052 0.0 6.659112 \n", - "5 0.615494 0.638574 -0.846471 0.486294 -1.844437 0.0 5.965467 \n", - "6 0.577691 -1.316748 -0.598852 1.499971 0.435526 1.0 2.791735 \n", - "7 -1.382985 -1.629080 0.957606 1.263108 -1.248810 1.0 8.962357 \n", - "8 -0.732977 -1.635727 -1.165820 -1.912441 1.095427 1.0 12.089185 \n", - "9 -1.324801 -1.119786 1.021565 -0.304530 0.425360 1.0 4.633603 \n", - "10 0.957669 -1.272092 -0.461742 1.557717 1.954150 0.0 3.006800 \n", - "11 -1.609909 -1.972229 -0.950002 -1.645930 -0.068231 0.0 15.049384 \n", - "12 -0.872965 -0.263758 -1.418009 1.434761 -0.156978 1.0 4.579632 \n", - "13 -1.650395 -1.291974 -1.693253 1.404619 -0.617787 1.0 11.135251 \n", - "14 -0.122850 0.062456 -1.093654 0.762955 -0.608164 1.0 2.333336 \n", - "15 -0.343801 1.969565 -0.400887 1.546484 0.955387 0.0 1.926944 \n", - "16 -1.023179 -0.506223 -1.442436 1.881762 0.702930 0.0 5.610748 \n", - "17 0.806867 0.223890 1.955296 1.470794 1.646253 0.0 1.263481 \n", - "18 -0.063949 -1.605616 -1.694294 1.014650 -0.542840 0.0 8.664851 \n", - "19 -0.791231 0.269571 -0.625560 1.688765 0.347797 1.0 2.277518 \n", + " x_0 x_1 x_2 x_3 x_4 x_5 f_0 \\\n", + "0 -0.310698 0.392037 -1.806570 -1.981755 1.384988 0 11.628122 \n", + "1 -1.489428 -1.288171 -0.505947 -0.384544 0.425808 1 6.224438 \n", + "2 1.461134 0.337194 -0.082360 -1.730603 -1.588481 0 8.278724 \n", + "3 -1.781506 -1.436844 1.930972 1.783539 -0.900358 1 9.733454 \n", + "4 0.665722 -1.744514 0.446853 0.700885 0.651237 1 4.665782 \n", + "5 -1.499912 -0.747193 0.664655 -1.011433 -1.354635 1 7.671702 \n", + "6 -1.543055 1.962503 1.798311 0.168173 -1.495112 1 8.231485 \n", + "7 -0.708194 1.516163 -1.662763 0.717943 -0.500210 0 5.660023 \n", + "8 -0.146424 -0.733847 -1.398632 -0.390983 -1.966037 1 9.502397 \n", + "9 -1.435747 -0.197231 -0.085655 1.959668 -1.492488 0 7.361934 \n", + "10 -0.672001 0.989061 0.625807 1.945203 -1.505085 0 4.971067 \n", + "11 -0.930565 -1.942693 0.550933 -0.877697 -1.120742 0 9.486661 \n", + "12 -1.762717 0.738706 -1.863998 -1.565166 0.792053 1 13.751706 \n", + "13 -1.863466 0.404012 0.844637 -1.966831 -0.169374 0 11.650621 \n", + "14 -0.981088 1.761401 -0.346133 -0.842548 -0.307780 0 3.160072 \n", + "15 0.703509 -1.285697 0.011635 1.999545 0.608979 0 2.874619 \n", + "16 -0.262847 -0.670558 -1.676711 1.346576 0.734364 1 5.254699 \n", + "17 1.949815 -0.132073 -1.921915 1.955173 -0.510576 0 7.427848 \n", + "18 1.352886 -0.307143 0.472704 -1.201528 -1.226323 0 3.723930 \n", + "19 -1.093545 1.693383 0.291218 -0.979725 -1.454517 0 6.389192 \n", "\n", " f_1 f_2 \n", - "0 unacceptable 0.049104 \n", - "1 unacceptable -0.511997 \n", - "2 unacceptable 1.426101 \n", - "3 unacceptable -1.723294 \n", - "4 acceptable 0.146894 \n", - "5 unacceptable 0.623373 \n", - "6 acceptable 0.586726 \n", - "7 unacceptable -1.376337 \n", - "8 unacceptable -0.723748 \n", - "9 unacceptable -1.322161 \n", - "10 acceptable 0.965879 \n", - "11 unacceptable -1.602951 \n", - "12 unacceptable -0.865004 \n", - "13 unacceptable -1.641277 \n", - "14 unacceptable -0.116326 \n", - "15 ideal -0.335117 \n", - "16 unacceptable -1.015035 \n", - "17 ideal 0.814321 \n", - "18 unacceptable -0.063911 \n", - "19 acceptable -0.788239 " + "0 acceptable -0.310070 \n", + "1 acceptable -1.485448 \n", + "2 ideal 1.467232 \n", + "3 unacceptable -1.779890 \n", + "4 acceptable 0.668620 \n", + "5 ideal -1.491068 \n", + "6 unacceptable -1.542496 \n", + "7 ideal -0.703393 \n", + "8 ideal -0.142366 \n", + "9 ideal -1.429517 \n", + "10 acceptable -0.671859 \n", + "11 ideal -0.930336 \n", + "12 unacceptable -1.757165 \n", + "13 ideal -1.857675 \n", + "14 acceptable -0.972805 \n", + "15 ideal 0.710066 \n", + "16 ideal -0.255986 \n", + "17 acceptable 1.951420 \n", + "18 ideal 1.356302 \n", + "19 acceptable -1.092528 " ] }, "execution_count": 3, @@ -393,12 +397,12 @@ ], "source": [ "# Set-up the inputs and outputs, use categorical domain just as an example\n", - "input_features = Inputs(features=[ContinuousInput(key=f\"x_{i}\", bounds=(-2, 2)) for i in range(5)] + [CategoricalInput(key=f\"x_5\", categories=(0.0, 1.0))])\n", + "input_features = Inputs(features=[ContinuousInput(key=f\"x_{i}\", bounds=(-2, 2)) for i in range(5)] + [CategoricalInput(key=f\"x_5\", categories=[\"0\", \"1\"], allowed=[True, True])])\n", "\n", "# here the minimize objective is used, if you want to maximize you have to use the maximize objective.\n", "output_features = Outputs(features=[\n", " ContinuousOutput(key=f\"f_{0}\", objective=MinimizeObjective(w=1.)),\n", - " CategoricalOutput(key=f\"f_{1}\", categories=(\"unacceptable\", \"acceptable\", \"ideal\"), objective=ConstrainedCategoricalObjective(categories=(\"unacceptable\", \"acceptable\", \"ideal\"), desirability=(False, True, True))), # This function will be associated with learning the categories\n", + " CategoricalOutput(key=f\"f_{1}\", categories=[\"unacceptable\", \"acceptable\", \"ideal\"], objective=ConstrainedCategoricalObjective(categories=[\"unacceptable\", \"acceptable\", \"ideal\"], desirability=[False, True, True])), # This function will be associated with learning the categories\n", " ContinuousOutput(key=f\"f_{2}\", objective=MinimizeSigmoidObjective(w=1., tp=0.0, steepness=0.5)),\n", " ]\n", ")\n", @@ -407,18 +411,182 @@ "domain1 = Domain(inputs=input_features, outputs=output_features)\n", "\n", "# Sample random points\n", - "sample_df = domain1.inputs.sample(50).astype(float) # Sample x's\n", + "sample_df = domain1.inputs.sample(50)\n", "\n", "# Write a function which outputs one continuous variable and another discrete based on some logic\n", - "sample_df[\"f_0\"] = np.sin(np.pi * scale_inputs(sample_df[\"x_0\"])) ** 2 + sum([(scale_inputs(sample_df[col]) - 1) ** 2 * (1 + 10 * np.sin(np.pi * scale_inputs(sample_df[col]) + 1) ** 2 if ind < len(sample_df.columns) else 1 + np.sin(2 * np.pi * scale_inputs(sample_df[col])) ** 2) for ind, col in enumerate(sample_df.columns)])\n", + "sample_df[\"f_0\"] = np.sin(np.pi * scale_inputs(sample_df[\"x_0\"])) ** 2 + sum([(scale_inputs(sample_df[col]) - 1) ** 2 * (1 + 10 * np.sin(np.pi * scale_inputs(sample_df[col]) + 1) ** 2 if ind < len(sample_df.columns) else 1 + np.sin(2 * np.pi * scale_inputs(sample_df[col])) ** 2) for ind, col in enumerate(sample_df.columns) if not sample_df[col].dtype == \"O\"])\n", "sample_df[\"f_1\"] = \"unacceptable\"\n", - "sample_df.loc[sample_df[input_features.get_keys()].sum(1) >= 1.5, \"f_1\"] = \"acceptable\"\n", - "sample_df.loc[sample_df[input_features.get_keys()].sum(1) >= 3.5, \"f_1\"] = \"ideal\"\n", + "sample_df.loc[(sample_df[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) <= 6.5) * (sample_df[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) >= 3.5), \"f_1\"] = \"acceptable\"\n", + "sample_df.loc[(sample_df[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) <= 5.5) * (sample_df[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) >= 4.5), \"f_1\"] = \"ideal\"\n", "sample_df[\"f_2\"] = sample_df[\"x_0\"] + 1e-2 * np.random.uniform(size=(len(sample_df),))\n", "\n", "sample_df.head(20)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate the classification model performance (outside of the optimization procedure)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "# Import packages\n", + "import bofire.surrogates.api as surrogates\n", + "from bofire.data_models.surrogates.api import ClassificationMLPEnsemble\n", + "from bofire.surrogates.diagnostics import ClassificationMetricsEnum\n", + "\n", + "# Instantiate the surrogate model \n", + "model = ClassificationMLPEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_1\")]), lr=0.01, n_epochs=100, hidden_layer_sizes=(20,10,))\n", + "surrogate = surrogates.map(model)\n", + "\n", + "# Fit the model to the classification data\n", + "cv_df = sample_df.drop([\"f_0\", \"f_2\"], axis=1)\n", + "cv_df[\"valid_f_1\"] = 1\n", + "cv = surrogate.cross_validate(cv_df, folds=5)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ACCURACYF1
00.760.76
\n", + "
" + ], + "text/plain": [ + " ACCURACY F1\n", + "0 0.76 0.76" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Print results\n", + "cv[0].get_metrics(metrics=ClassificationMetricsEnum, combine_folds=True) # print training set performance" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ACCURACYF1
00.180.18
\n", + "
" + ], + "text/plain": [ + " ACCURACY F1\n", + "0 0.18 0.18" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cv[1].get_metrics(metrics=ClassificationMetricsEnum, combine_folds=True) # print test set performance" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -429,20 +597,20 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from bofire.data_models.acquisition_functions.api import qEI\n", "from bofire.data_models.strategies.api import SoboStrategy\n", - "from bofire.data_models.surrogates.api import BotorchSurrogates, MLPClassifierEnsemble, MixedSingleTaskGPSurrogate\n", + "from bofire.data_models.surrogates.api import BotorchSurrogates, ClassificationMLPEnsemble, MixedSingleTaskGPSurrogate\n", "from bofire.data_models.domain.api import Outputs\n", "\n", "strategy_data = SoboStrategy(domain=domain1, \n", " acquisition_function=qEI(), \n", " surrogate_specs=BotorchSurrogates(surrogates=\n", " [\n", - " MLPClassifierEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_1\")]), lr=0.1, n_epochs=100, hidden_layer_sizes=(20,10,)),\n", + " ClassificationMLPEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_1\")]), lr=0.01, n_epochs=100, hidden_layer_sizes=(20,10,)),\n", " MixedSingleTaskGPSurrogate(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_2\")]))\n", " ]\n", " )\n", @@ -455,18 +623,18 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPClassifierEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", " warnings.warn(\n" ] }, @@ -497,17 +665,18 @@ " x_3\n", " x_4\n", " x_5\n", + " f_1_pred\n", + " f_1_sd\n", " f_0_pred\n", " f_2_pred\n", - " f_1_pred\n", - " f_1_pred_unacceptable\n", - " f_1_pred_acceptable\n", - " f_1_pred_ideal\n", + " ...\n", + " f_1_acceptable_prob\n", + " f_1_ideal_prob\n", " f_0_sd\n", " f_2_sd\n", - " f_1_sd_unacceptable\n", - " f_1_sd_acceptable\n", - " f_1_sd_ideal\n", + " f_1_unacceptable_sd\n", + " f_1_acceptable_sd\n", + " f_1_ideal_sd\n", " f_0_des\n", " f_2_des\n", " f_1_des\n", @@ -516,289 +685,302 @@ " \n", " \n", " 0\n", - " 0.011502\n", - " -0.161920\n", - " 1.972217\n", - " 0.438555\n", - " 0.457098\n", - " 0.0\n", - " -0.508109\n", - " 0.018091\n", + " 0.012650\n", + " -0.082007\n", + " 0.196167\n", + " -0.068576\n", + " 2.000000\n", + " 0\n", " ideal\n", - " 0.005109\n", - " 0.257733\n", - " 0.737157\n", - " 0.424743\n", - " 0.002801\n", - " 0.004573\n", - " 0.346797\n", - " 0.344908\n", - " 0.508109\n", - " 0.497739\n", - " 0.994891\n", + " 0.585006\n", + " -0.178305\n", + " 0.017555\n", + " ...\n", + " 0.302352\n", + " 0.696724\n", + " 0.486615\n", + " 0.003456\n", + " 0.001530\n", + " 0.413975\n", + " 0.413346\n", + " 0.178305\n", + " 0.497806\n", + " 0.999076\n", " \n", " \n", " 1\n", - " 0.165405\n", - " -0.124200\n", - " -0.283804\n", - " 0.356975\n", - " 0.652289\n", - " 1.0\n", - " -0.549702\n", - " 0.171525\n", - " acceptable\n", - " 0.001980\n", - " 0.996553\n", - " 0.001467\n", - " 0.365422\n", - " 0.002720\n", - " 0.001716\n", - " 0.003879\n", - " 0.002428\n", - " 0.549702\n", - " 0.478572\n", - " 0.998020\n", + " 0.006090\n", + " -0.110946\n", + " 0.036496\n", + " -0.075202\n", + " 1.376258\n", + " 0\n", + " ideal\n", + " 0.581150\n", + " -0.033201\n", + " 0.011008\n", + " ...\n", + " 0.285243\n", + " 0.714701\n", + " 0.317448\n", + " 0.003389\n", + " 0.000087\n", + " 0.410947\n", + " 0.410923\n", + " 0.033201\n", + " 0.498624\n", + " 0.999944\n", " \n", " \n", " 2\n", - " -0.008095\n", - " -0.188512\n", - " 1.662439\n", - " 0.346688\n", - " 0.430895\n", - " 1.0\n", - " -0.371163\n", - " -0.001649\n", + " 0.161220\n", + " -0.056338\n", + " -0.295114\n", + " -0.061665\n", + " 0.096058\n", + " 0\n", " acceptable\n", - " 0.000494\n", - " 0.578223\n", - " 0.421283\n", - " 0.418334\n", - " 0.002777\n", - " 0.000678\n", - " 0.484275\n", - " 0.484395\n", - " 0.371163\n", - " 0.500206\n", - " 0.999506\n", + " 0.745874\n", + " 0.011562\n", + " 0.166371\n", + " ...\n", + " 0.591001\n", + " 0.408996\n", + " 0.271178\n", + " 0.003351\n", + " 0.000005\n", + " 0.527414\n", + " 0.527411\n", + " -0.011562\n", + " 0.479216\n", + " 0.999998\n", " \n", " \n", " 3\n", - " -0.034115\n", - " 2.000000\n", - " -0.095091\n", - " -0.378793\n", - " 2.000000\n", - " 0.0\n", - " 0.887337\n", - " -0.029129\n", + " -0.027612\n", + " 0.391135\n", + " 0.121035\n", + " 1.256868\n", + " 1.784862\n", + " 0\n", " ideal\n", - " 0.001376\n", - " 0.136712\n", - " 0.861912\n", - " 0.899251\n", - " 0.002871\n", - " 0.003050\n", - " 0.303232\n", - " 0.302495\n", - " -0.887337\n", - " 0.503641\n", - " 0.998624\n", + " 0.599128\n", + " 0.209979\n", + " -0.022735\n", + " ...\n", + " 0.297190\n", + " 0.702338\n", + " 0.469250\n", + " 0.003467\n", + " 0.000979\n", + " 0.423850\n", + " 0.423444\n", + " -0.209979\n", + " 0.502842\n", + " 0.999528\n", " \n", " \n", " 4\n", - " -0.011600\n", - " 0.117177\n", - " 1.615437\n", - " 0.201891\n", - " 0.744478\n", - " 0.0\n", - " -0.337936\n", - " -0.005347\n", - " ideal\n", - " 0.007744\n", - " 0.236033\n", - " 0.756223\n", - " 0.426527\n", - " 0.002773\n", - " 0.007903\n", - " 0.379799\n", - " 0.377517\n", - " 0.337936\n", - " 0.500668\n", - " 0.992256\n", + " -0.031088\n", + " 1.118860\n", + " 0.087444\n", + " -0.012355\n", + " 2.000000\n", + " 0\n", + " acceptable\n", + " 0.683969\n", + " 0.085227\n", + " -0.026003\n", + " ...\n", + " 0.502813\n", + " 0.468748\n", + " 0.500849\n", + " 0.003525\n", + " 0.053980\n", + " 0.498330\n", + " 0.465368\n", + " -0.085227\n", + " 0.503250\n", + " 0.971561\n", " \n", " \n", " 5\n", - " -0.073124\n", - " -0.243746\n", - " 2.000000\n", - " 0.402120\n", - " 0.352889\n", - " 1.0\n", - " -0.395214\n", - " -0.066696\n", - " acceptable\n", - " 0.000355\n", - " 0.507754\n", - " 0.491891\n", - " 0.408950\n", - " 0.002803\n", - " 0.000558\n", - " 0.472051\n", - " 0.472281\n", - " 0.395214\n", - " 0.508336\n", - " 0.999645\n", + " -0.015485\n", + " -0.106424\n", + " 0.251205\n", + " -0.105387\n", + " 1.627490\n", + " 0\n", + " ideal\n", + " 0.596716\n", + " -0.075283\n", + " -0.010639\n", + " ...\n", + " 0.242466\n", + " 0.757293\n", + " 0.352754\n", + " 0.003411\n", + " 0.000412\n", + " 0.422019\n", + " 0.421865\n", + " 0.075283\n", + " 0.501330\n", + " 0.999759\n", " \n", " \n", " 6\n", - " 0.096887\n", - " 0.753092\n", - " -0.296234\n", - " -0.268697\n", - " 1.757006\n", - " 1.0\n", - " 0.256462\n", - " 0.102279\n", - " acceptable\n", - " 0.000452\n", - " 0.614570\n", - " 0.384978\n", - " 0.509006\n", - " 0.002777\n", - " 0.000830\n", - " 0.464693\n", - " 0.464979\n", - " -0.256462\n", - " 0.487218\n", - " 0.999548\n", + " -0.177911\n", + " -0.006491\n", + " 0.004730\n", + " -0.100077\n", + " 2.000000\n", + " 0\n", + " ideal\n", + " 0.629016\n", + " 0.015120\n", + " -0.173160\n", + " ...\n", + " 0.391062\n", + " 0.608024\n", + " 0.520686\n", + " 0.003441\n", + " 0.001167\n", + " 0.444837\n", + " 0.444724\n", + " -0.015120\n", + " 0.521632\n", + " 0.999086\n", " \n", " \n", " 7\n", - " 0.173888\n", - " 0.237557\n", - " -0.316069\n", - " -0.011089\n", - " 0.894507\n", - " 1.0\n", - " -0.376064\n", - " 0.179808\n", - " acceptable\n", - " 0.001598\n", - " 0.996005\n", - " 0.002396\n", - " 0.405710\n", - " 0.002727\n", - " 0.001953\n", - " 0.005998\n", - " 0.004370\n", - " 0.376064\n", - " 0.477539\n", - " 0.998402\n", + " 0.019420\n", + " 1.192449\n", + " 0.083663\n", + " -0.014620\n", + " 1.388673\n", + " 0\n", + " ideal\n", + " 0.730330\n", + " 0.186320\n", + " 0.024592\n", + " ...\n", + " 0.430550\n", + " 0.567086\n", + " 0.322597\n", + " 0.003467\n", + " 0.004309\n", + " 0.517441\n", + " 0.515383\n", + " -0.186320\n", + " 0.496926\n", + " 0.997636\n", " \n", " \n", " 8\n", - " -0.249681\n", - " -0.060291\n", + " 0.735010\n", + " -0.009883\n", + " 2.000000\n", + " -0.052121\n", " 2.000000\n", - " 0.359377\n", - " 0.670042\n", - " 0.0\n", - " -0.281907\n", - " -0.243816\n", + " 0\n", " ideal\n", - " 0.005539\n", - " 0.221772\n", - " 0.772690\n", - " 0.454217\n", - " 0.002804\n", - " 0.005596\n", - " 0.310703\n", - " 0.308060\n", - " 0.281907\n", - " 0.530439\n", - " 0.994461\n", + " 0.702506\n", + " 1.088985\n", + " 0.740256\n", + " ...\n", + " 0.145336\n", + " 0.659978\n", + " 1.068470\n", + " 0.003609\n", + " 0.435210\n", + " 0.304475\n", + " 0.459785\n", + " -1.088985\n", + " 0.408510\n", + " 0.805314\n", " \n", " \n", " 9\n", - " 0.126138\n", - " -0.276652\n", - " -0.165111\n", - " 0.671939\n", - " 0.442778\n", - " 1.0\n", - " -0.436071\n", - " 0.132389\n", - " acceptable\n", - " 0.002137\n", - " 0.996403\n", - " 0.001460\n", - " 0.346124\n", - " 0.002722\n", - " 0.001572\n", - " 0.003672\n", - " 0.002423\n", - " 0.436071\n", - " 0.483457\n", - " 0.997863\n", + " 0.028036\n", + " 2.000000\n", + " -0.318588\n", + " 2.000000\n", + " 0.310111\n", + " 0\n", + " ideal\n", + " 0.603002\n", + " 1.676667\n", + " 0.033345\n", + " ...\n", + " 0.475370\n", + " 0.524512\n", + " 1.170120\n", + " 0.003564\n", + " 0.000234\n", + " 0.426380\n", + " 0.426393\n", + " -1.676667\n", + " 0.495832\n", + " 0.999883\n", " \n", " \n", "\n", + "

10 rows × 21 columns

\n", "" ], "text/plain": [ - " x_0 x_1 x_2 x_3 x_4 x_5 f_0_pred f_2_pred \\\n", - "0 0.011502 -0.161920 1.972217 0.438555 0.457098 0.0 -0.508109 0.018091 \n", - "1 0.165405 -0.124200 -0.283804 0.356975 0.652289 1.0 -0.549702 0.171525 \n", - "2 -0.008095 -0.188512 1.662439 0.346688 0.430895 1.0 -0.371163 -0.001649 \n", - "3 -0.034115 2.000000 -0.095091 -0.378793 2.000000 0.0 0.887337 -0.029129 \n", - "4 -0.011600 0.117177 1.615437 0.201891 0.744478 0.0 -0.337936 -0.005347 \n", - "5 -0.073124 -0.243746 2.000000 0.402120 0.352889 1.0 -0.395214 -0.066696 \n", - "6 0.096887 0.753092 -0.296234 -0.268697 1.757006 1.0 0.256462 0.102279 \n", - "7 0.173888 0.237557 -0.316069 -0.011089 0.894507 1.0 -0.376064 0.179808 \n", - "8 -0.249681 -0.060291 2.000000 0.359377 0.670042 0.0 -0.281907 -0.243816 \n", - "9 0.126138 -0.276652 -0.165111 0.671939 0.442778 1.0 -0.436071 0.132389 \n", + " x_0 x_1 x_2 x_3 x_4 x_5 f_1_pred f_1_sd \\\n", + "0 0.012650 -0.082007 0.196167 -0.068576 2.000000 0 ideal 0.585006 \n", + "1 0.006090 -0.110946 0.036496 -0.075202 1.376258 0 ideal 0.581150 \n", + "2 0.161220 -0.056338 -0.295114 -0.061665 0.096058 0 acceptable 0.745874 \n", + "3 -0.027612 0.391135 0.121035 1.256868 1.784862 0 ideal 0.599128 \n", + "4 -0.031088 1.118860 0.087444 -0.012355 2.000000 0 acceptable 0.683969 \n", + "5 -0.015485 -0.106424 0.251205 -0.105387 1.627490 0 ideal 0.596716 \n", + "6 -0.177911 -0.006491 0.004730 -0.100077 2.000000 0 ideal 0.629016 \n", + "7 0.019420 1.192449 0.083663 -0.014620 1.388673 0 ideal 0.730330 \n", + "8 0.735010 -0.009883 2.000000 -0.052121 2.000000 0 ideal 0.702506 \n", + "9 0.028036 2.000000 -0.318588 2.000000 0.310111 0 ideal 0.603002 \n", + "\n", + " f_0_pred f_2_pred ... f_1_acceptable_prob f_1_ideal_prob f_0_sd \\\n", + "0 -0.178305 0.017555 ... 0.302352 0.696724 0.486615 \n", + "1 -0.033201 0.011008 ... 0.285243 0.714701 0.317448 \n", + "2 0.011562 0.166371 ... 0.591001 0.408996 0.271178 \n", + "3 0.209979 -0.022735 ... 0.297190 0.702338 0.469250 \n", + "4 0.085227 -0.026003 ... 0.502813 0.468748 0.500849 \n", + "5 -0.075283 -0.010639 ... 0.242466 0.757293 0.352754 \n", + "6 0.015120 -0.173160 ... 0.391062 0.608024 0.520686 \n", + "7 0.186320 0.024592 ... 0.430550 0.567086 0.322597 \n", + "8 1.088985 0.740256 ... 0.145336 0.659978 1.068470 \n", + "9 1.676667 0.033345 ... 0.475370 0.524512 1.170120 \n", "\n", - " f_1_pred f_1_pred_unacceptable f_1_pred_acceptable f_1_pred_ideal \\\n", - "0 ideal 0.005109 0.257733 0.737157 \n", - "1 acceptable 0.001980 0.996553 0.001467 \n", - "2 acceptable 0.000494 0.578223 0.421283 \n", - "3 ideal 0.001376 0.136712 0.861912 \n", - "4 ideal 0.007744 0.236033 0.756223 \n", - "5 acceptable 0.000355 0.507754 0.491891 \n", - "6 acceptable 0.000452 0.614570 0.384978 \n", - "7 acceptable 0.001598 0.996005 0.002396 \n", - "8 ideal 0.005539 0.221772 0.772690 \n", - "9 acceptable 0.002137 0.996403 0.001460 \n", + " f_2_sd f_1_unacceptable_sd f_1_acceptable_sd f_1_ideal_sd f_0_des \\\n", + "0 0.003456 0.001530 0.413975 0.413346 0.178305 \n", + "1 0.003389 0.000087 0.410947 0.410923 0.033201 \n", + "2 0.003351 0.000005 0.527414 0.527411 -0.011562 \n", + "3 0.003467 0.000979 0.423850 0.423444 -0.209979 \n", + "4 0.003525 0.053980 0.498330 0.465368 -0.085227 \n", + "5 0.003411 0.000412 0.422019 0.421865 0.075283 \n", + "6 0.003441 0.001167 0.444837 0.444724 -0.015120 \n", + "7 0.003467 0.004309 0.517441 0.515383 -0.186320 \n", + "8 0.003609 0.435210 0.304475 0.459785 -1.088985 \n", + "9 0.003564 0.000234 0.426380 0.426393 -1.676667 \n", "\n", - " f_0_sd f_2_sd f_1_sd_unacceptable f_1_sd_acceptable f_1_sd_ideal \\\n", - "0 0.424743 0.002801 0.004573 0.346797 0.344908 \n", - "1 0.365422 0.002720 0.001716 0.003879 0.002428 \n", - "2 0.418334 0.002777 0.000678 0.484275 0.484395 \n", - "3 0.899251 0.002871 0.003050 0.303232 0.302495 \n", - "4 0.426527 0.002773 0.007903 0.379799 0.377517 \n", - "5 0.408950 0.002803 0.000558 0.472051 0.472281 \n", - "6 0.509006 0.002777 0.000830 0.464693 0.464979 \n", - "7 0.405710 0.002727 0.001953 0.005998 0.004370 \n", - "8 0.454217 0.002804 0.005596 0.310703 0.308060 \n", - "9 0.346124 0.002722 0.001572 0.003672 0.002423 \n", + " f_2_des f_1_des \n", + "0 0.497806 0.999076 \n", + "1 0.498624 0.999944 \n", + "2 0.479216 0.999998 \n", + "3 0.502842 0.999528 \n", + "4 0.503250 0.971561 \n", + "5 0.501330 0.999759 \n", + "6 0.521632 0.999086 \n", + "7 0.496926 0.997636 \n", + "8 0.408510 0.805314 \n", + "9 0.495832 0.999883 \n", "\n", - " f_0_des f_2_des f_1_des \n", - "0 0.508109 0.497739 0.994891 \n", - "1 0.549702 0.478572 0.998020 \n", - "2 0.371163 0.500206 0.999506 \n", - "3 -0.887337 0.503641 0.998624 \n", - "4 0.337936 0.500668 0.992256 \n", - "5 0.395214 0.508336 0.999645 \n", - "6 -0.256462 0.487218 0.999548 \n", - "7 0.376064 0.477539 0.998402 \n", - "8 0.281907 0.530439 0.994461 \n", - "9 0.436071 0.483457 0.997863 " + "[10 rows x 21 columns]" ] }, - "execution_count": 5, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -812,22 +994,28 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## See performance of the classifier" + "## Check classification of proposed candidates\n", + "\n", + "Use the logic from above to verify the classification values" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# Append to the candidates\n", + "candidates[\"f_1_true\"] = \"unacceptable\"\n", + "candidates.loc[(candidates[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) <= 6.5) * (candidates[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) >= 3.5), \"f_1_true\"] = \"acceptable\"\n", + "candidates.loc[(candidates[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) <= 5.5) * (candidates[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) >= 4.5), \"f_1_true\"] = \"ideal\"" + ] + }, + { + "cell_type": "code", + "execution_count": 12, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "We defined 'unacceptable' as values in (-infinity, 1.5), 'acceptable' as values in [1.5, 3.5), and 'ideal' as values in [3.5, infinity)\n", - "\n" - ] - }, { "data": { "text/html": [ @@ -849,87 +1037,87 @@ " \n", " \n", " \n", - " 0\n", " f_1_pred\n", + " f_1_true\n", " \n", " \n", " \n", " \n", " 0\n", - " 2.717452\n", " ideal\n", + " unacceptable\n", " \n", " \n", " 1\n", - " 1.766665\n", - " acceptable\n", + " ideal\n", + " unacceptable\n", " \n", " \n", " 2\n", - " 3.243415\n", " acceptable\n", + " unacceptable\n", " \n", " \n", " 3\n", - " 3.492002\n", " ideal\n", + " acceptable\n", " \n", " \n", " 4\n", - " 2.667383\n", - " ideal\n", + " acceptable\n", + " unacceptable\n", " \n", " \n", " 5\n", - " 3.438138\n", - " acceptable\n", + " ideal\n", + " unacceptable\n", " \n", " \n", " 6\n", - " 3.042054\n", - " acceptable\n", + " ideal\n", + " unacceptable\n", " \n", " \n", " 7\n", - " 1.978794\n", - " acceptable\n", + " ideal\n", + " unacceptable\n", " \n", " \n", " 8\n", - " 2.719447\n", + " ideal\n", " ideal\n", " \n", " \n", " 9\n", - " 1.799092\n", - " acceptable\n", + " ideal\n", + " ideal\n", " \n", " \n", "\n", "" ], "text/plain": [ - " 0 f_1_pred\n", - "0 2.717452 ideal\n", - "1 1.766665 acceptable\n", - "2 3.243415 acceptable\n", - "3 3.492002 ideal\n", - "4 2.667383 ideal\n", - "5 3.438138 acceptable\n", - "6 3.042054 acceptable\n", - "7 1.978794 acceptable\n", - "8 2.719447 ideal\n", - "9 1.799092 acceptable" + " f_1_pred f_1_true\n", + "0 ideal unacceptable\n", + "1 ideal unacceptable\n", + "2 acceptable unacceptable\n", + "3 ideal acceptable\n", + "4 acceptable unacceptable\n", + "5 ideal unacceptable\n", + "6 ideal unacceptable\n", + "7 ideal unacceptable\n", + "8 ideal ideal\n", + "9 ideal ideal" ] }, - "execution_count": 6, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "print(f\"We defined 'unacceptable' as values in (-infinity, 1.5), 'acceptable' as values in [1.5, 3.5), and 'ideal' as values in [3.5, infinity)\\n\")\n", - "pd.concat((candidates[[feat.key for feat in input_features]].astype(float).sum(1), candidates[\"f_1_pred\"]), axis=1)" + "# Print results\n", + "candidates[[\"f_1_pred\", \"f_1_true\"]]" ] } ], From 090dcfe92a890f169f026f5ff9cc6207b47fa5db Mon Sep 17 00:00:00 2001 From: gmancino Date: Fri, 26 Jan 2024 13:16:50 -0500 Subject: [PATCH 11/31] Initial test fixes --- bofire/data_models/objectives/api.py | 1 + bofire/data_models/surrogates/mlp.py | 10 + bofire/strategies/predictives/predictive.py | 6 +- .../samplers/universal_constraint.py | 6 +- bofire/surrogates/diagnostics.py | 5 +- bofire/surrogates/mlp.py | 6 +- bofire/surrogates/surrogate.py | 37 +- bofire/surrogates/trainable.py | 4 +- .../data_models/features/test_descriptor.py | 2 +- tests/bofire/data_models/specs/features.py | 8 +- tests/bofire/data_models/specs/surrogates.py | 4 + tests/bofire/strategies/doe/test_objective.py | 28 +- tests/bofire/strategies/doe/test_utils.py | 10 +- tests/bofire/surrogates/test_mlp.py | 12 +- .../Unknown_Constraint_Classification.ipynb | 946 +++++++++--------- tutorials/models_serial.ipynb | 4 +- 16 files changed, 548 insertions(+), 541 deletions(-) diff --git a/bofire/data_models/objectives/api.py b/bofire/data_models/objectives/api.py index 133440ad9..de1a5381f 100644 --- a/bofire/data_models/objectives/api.py +++ b/bofire/data_models/objectives/api.py @@ -46,4 +46,5 @@ MinimizeSigmoidObjective, TargetObjective, CloseToTargetObjective, + ConstrainedCategoricalObjective, ] diff --git a/bofire/data_models/surrogates/mlp.py b/bofire/data_models/surrogates/mlp.py index 8e887cc9f..9df107234 100644 --- a/bofire/data_models/surrogates/mlp.py +++ b/bofire/data_models/surrogates/mlp.py @@ -25,6 +25,16 @@ class MLPEnsemble(TrainableBotorchSurrogate): shuffle: bool = True scaler: ScalerEnum = ScalerEnum.NORMALIZE + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models + Args: + my_type: continuous or categorical output + Returns: + bool: True if the output type is valid for the surrogate chosen, False otherwise + """ + return isinstance(my_type, (CategoricalOutput, ContinuousOutput)) + class RegressionMLPEnsemble(MLPEnsemble): type: Literal["RegressionMLPEnsemble"] = "RegressionMLPEnsemble" diff --git a/bofire/strategies/predictives/predictive.py b/bofire/strategies/predictives/predictive.py index da0417593..cb7e39d75 100644 --- a/bofire/strategies/predictives/predictive.py +++ b/bofire/strategies/predictives/predictive.py @@ -147,11 +147,7 @@ def predict(self, experiments: pd.DataFrame) -> pd.DataFrame: predictions.insert( loc=1, column=f"{feat.key}_sd", - value=predictions.filter(regex=f"{feat.key}(.*)_sd") - .pow(2.0) - .sum(1) - .pow(0.5) - .values, + value=0.0, ) desis = self.domain.outputs(predictions, predictions=True) predictions = pd.concat((predictions, desis), axis=1) diff --git a/bofire/strategies/samplers/universal_constraint.py b/bofire/strategies/samplers/universal_constraint.py index fff8d524a..8251827f6 100644 --- a/bofire/strategies/samplers/universal_constraint.py +++ b/bofire/strategies/samplers/universal_constraint.py @@ -37,7 +37,9 @@ def _ask(self, candidate_count: int) -> pd.DataFrame: fixed_experiments=self.candidates, ) - samples = samples.iloc[self.num_candidates :,] + samples = samples.iloc[ + self.num_candidates :, + ] samples = samples.sample( n=candidate_count, replace=False, @@ -50,4 +52,4 @@ def _ask(self, candidate_count: int) -> pd.DataFrame: return samples def has_sufficient_experiments(self) -> bool: - return True + return True \ No newline at end of file diff --git a/bofire/surrogates/diagnostics.py b/bofire/surrogates/diagnostics.py index 6e0935196..a4b260545 100644 --- a/bofire/surrogates/diagnostics.py +++ b/bofire/surrogates/diagnostics.py @@ -473,6 +473,9 @@ def _AbsoluteMiscalibrationArea( RegressionMetricsEnum.PEARSON: _pearson, RegressionMetricsEnum.SPEARMAN: _spearman, RegressionMetricsEnum.FISHER: _fisher_exact_test_p, +} + +classification_metrics = { ClassificationMetricsEnum.ACCURACY: _accuracy_score, ClassificationMetricsEnum.F1: _f1_score, } @@ -486,7 +489,7 @@ def _AbsoluteMiscalibrationArea( UQRegressionMetricsEnum.ABSOLUTEMISCALIBRATIONAREA: _AbsoluteMiscalibrationArea, } -all_metrics = {**metrics, **UQ_metrics} +all_metrics = {**metrics, **UQ_metrics, **classification_metrics} class CvResult(BaseModel): diff --git a/bofire/surrogates/mlp.py b/bofire/surrogates/mlp.py index 9ca802206..9e42b14ae 100644 --- a/bofire/surrogates/mlp.py +++ b/bofire/surrogates/mlp.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Literal, Optional, Sequence +from typing import Literal, Optional, Sequence, Union import numpy as np import pandas as pd @@ -136,7 +136,7 @@ def fit_mlp( lr: float = 1e-4, shuffle: bool = True, weight_decay: float = 0.0, - loss_function: nn.Module = nn.L1Loss, + loss_function: Union[nn.L1Loss, nn.CrossEntropyLoss] = nn.L1Loss, ): """Fit a MLP to a dataset. @@ -148,7 +148,7 @@ def fit_mlp( lr (float, optional): Initial learning rate. Defaults to 1e-4. shuffle (bool, optional): Whereas the batches should be shuffled. Defaults to True. weight_decay (float, optional): Weight decay (L2 regularization). Defaults to 0.0 (no regularization). - loss_function (Module, optional): Loss function specified by the problem type. Defaults to L1 loss for regression problems. + loss_function (Loss function, NOT Optional): Loss function specified by the problem type. Defaults to L1 loss for regression problems. """ mlp.train() train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle) diff --git a/bofire/surrogates/surrogate.py b/bofire/surrogates/surrogate.py index 5bfce96c9..750bf07c7 100644 --- a/bofire/surrogates/surrogate.py +++ b/bofire/surrogates/surrogate.py @@ -46,30 +46,25 @@ def predict(self, X: pd.DataFrame) -> pd.DataFrame: # predict preds, stds = self._predict(Xt) # set up column names - columns = [] + pred_cols = [] + sd_cols = [] for featkey in self.outputs.get_keys(): if isinstance(self.outputs.get_by_key(featkey), CategoricalOutput): - columns = ( - columns - + [ - f"{featkey}_{cat}_prob" - for cat in self.outputs.get_by_key(featkey).categories - ] - + [ - f"{featkey}_{cat}_sd" - for cat in self.outputs.get_by_key(featkey).categories - ] - ) + pred_cols = pred_cols + [ + f"{featkey}_{cat}_prob" + for cat in self.outputs.get_by_key(featkey).categories + ] + sd_cols = sd_cols + [ + f"{featkey}_{cat}_sd" + for cat in self.outputs.get_by_key(featkey).categories + ] else: - columns = ( - columns - + [f"{featkey}_pred" for featkey in self.outputs.get_keys()] - + [f"{featkey}_sd" for featkey in self.outputs.get_keys()] - ) + pred_cols = pred_cols + [f"{featkey}_pred"] + sd_cols = sd_cols + [f"{featkey}_sd"] # postprocess predictions = pd.DataFrame( data=np.hstack((preds, stds)), - columns=columns, + columns=pred_cols + sd_cols, ) # append predictions for categorical cases for feat in self.outputs.get(): @@ -86,11 +81,7 @@ def predict(self, X: pd.DataFrame) -> pd.DataFrame: predictions.insert( loc=1, column=f"{feat.key}_sd", - value=predictions.filter(regex=f"{feat.key}(.*)_sd") - .pow(2.0) - .sum(1) - .pow(0.5) - .values, + value=0.0, ) # validate self.validate_predictions(predictions=predictions) diff --git a/bofire/surrogates/trainable.py b/bofire/surrogates/trainable.py index 794745c8a..a845b196e 100644 --- a/bofire/surrogates/trainable.py +++ b/bofire/surrogates/trainable.py @@ -12,7 +12,7 @@ ContinuousOutput, DiscreteInput, ) -from bofire.data_models.objectives.api import CategoricalObjective +from bofire.data_models.objectives.api import ConstrainedCategoricalObjective from bofire.surrogates.diagnostics import CvResult, CvResults from bofire.surrogates.surrogate import Surrogate @@ -184,7 +184,7 @@ def cross_validate( y_train_pred = self.predict(X_train) # type: ignore # Convert to categorical if applicable - if isinstance(self.outputs[0].objective, CategoricalObjective): + if isinstance(self.outputs[0].objective, ConstrainedCategoricalObjective): y_test_pred[f"{key}_pred"] = y_test_pred[f"{key}_pred"].map( self.outputs[0].objective.to_dict_label() ) diff --git a/tests/bofire/data_models/features/test_descriptor.py b/tests/bofire/data_models/features/test_descriptor.py index 90b791f05..f9007116d 100644 --- a/tests/bofire/data_models/features/test_descriptor.py +++ b/tests/bofire/data_models/features/test_descriptor.py @@ -378,6 +378,6 @@ def test_categorical_descriptor_input_feature_from_dataframe( columns=descriptors, ) f = CategoricalDescriptorInput.from_df("k", df) - assert f.categories == tuple(categories) + assert f.categories == categories assert f.descriptors == descriptors assert f.values == values diff --git a/tests/bofire/data_models/specs/features.py b/tests/bofire/data_models/specs/features.py index fb13f5936..ce920ac7a 100644 --- a/tests/bofire/data_models/specs/features.py +++ b/tests/bofire/data_models/specs/features.py @@ -59,7 +59,7 @@ features.CategoricalInput, lambda: { "key": str(uuid.uuid4()), - "categories": ("c1", "c2", "c3"), + "categories": ["c1", "c2", "c3"], "allowed": [True, True, False], }, ) @@ -102,7 +102,7 @@ features.CategoricalDescriptorInput, lambda: { "key": str(uuid.uuid4()), - "categories": ("c1", "c2", "c3"), + "categories": ["c1", "c2", "c3"], "allowed": [True, True, False], "descriptors": ["d1", "d2"], "values": [ @@ -142,12 +142,12 @@ features.CategoricalMolecularInput, lambda: { "key": str(uuid.uuid4()), - "categories": ( + "categories": [ "CC(=O)Oc1ccccc1C(=O)O", "c1ccccc1", "[CH3][CH2][OH]", "N[C@](C)(F)C(=O)O", - ), + ], "allowed": [True, True, True, True], }, ) diff --git a/tests/bofire/data_models/specs/surrogates.py b/tests/bofire/data_models/specs/surrogates.py index 71000122b..f294d7ae8 100644 --- a/tests/bofire/data_models/specs/surrogates.py +++ b/tests/bofire/data_models/specs/surrogates.py @@ -175,6 +175,7 @@ "n_estimators": 2, "hidden_layer_sizes": (100,), "activation": "relu", + "final_activation": "identity", "dropout": 0.0, "batch_size": 10, "n_epochs": 200, @@ -206,6 +207,7 @@ "n_estimators": 2, "hidden_layer_sizes": (100,), "activation": "relu", + "final_activation": "softmax", "dropout": 0.0, "batch_size": 10, "n_epochs": 200, @@ -240,6 +242,7 @@ "n_estimators": 2, "hidden_layer_sizes": (100,), "activation": "relu", + "final_activation": "softmax", "dropout": 0.0, "batch_size": 10, "n_epochs": 200, @@ -271,6 +274,7 @@ "n_estimators": 2, "hidden_layer_sizes": (100,), "activation": "relu", + "final_activation": "identity", "dropout": 0.0, "batch_size": 10, "n_epochs": 200, diff --git a/tests/bofire/strategies/doe/test_objective.py b/tests/bofire/strategies/doe/test_objective.py index 043a87782..4a1761509 100644 --- a/tests/bofire/strategies/doe/test_objective.py +++ b/tests/bofire/strategies/doe/test_objective.py @@ -68,9 +68,9 @@ def test_Objective_model_jacobian_t(): "x1", "x2", "x3", - "x1 ** 2", - "x2 ** 2", - "x3 ** 2", + "x1**2", + "x2**2", + "x3**2", "x1:x2", "x1:x3", "x2:x3", @@ -323,20 +323,20 @@ def test_Objective_model_jacobian_t(): columns=[ "1", "x1", - "x1 ** 2", - "x1 ** 3", + "x1**2", + "x1**3", "x2", - "x2 ** 2", - "x2 ** 3", + "x2**2", + "x2**3", "x3", - "x3 ** 2", - "x3 ** 3", + "x3**2", + "x3**3", "x4", - "x4 ** 2", - "x4 ** 3", + "x4**2", + "x4**3", "x5", - "x5 ** 2", - "x5 ** 3", + "x5**2", + "x5**3", "x2:x1", "x3:x1", "x4:x1", @@ -419,7 +419,7 @@ def test_DOptimality_instantiation(): assert isinstance(d_optimality.model, Formula) assert all( np.array(d_optimality.model, dtype=str) - == np.array(["1", "x1", "x2", "x3", "x3 ** 2", "x1:x2"]) + == np.array(["1", "x1", "x2", "x3", "x3**2", "x1:x2"]) ) x = np.array([[1, 2, 3], [1, 2, 3]]) diff --git a/tests/bofire/strategies/doe/test_utils.py b/tests/bofire/strategies/doe/test_utils.py index c3aa45d37..5e9284355 100644 --- a/tests/bofire/strategies/doe/test_utils.py +++ b/tests/bofire/strategies/doe/test_utils.py @@ -74,7 +74,7 @@ def test_get_formula_from_string(): assert all(term in np.array(model_formula, dtype=str) for term in terms) # linear and quadratic - terms = ["1", "x0", "x1", "x2", "x0 ** 2", "x1 ** 2", "x2 ** 2"] + terms = ["1", "x0", "x1", "x2", "x0**2", "x1**2", "x2**2"] model_formula = get_formula_from_string( domain=domain, model_type="linear-and-quadratic" ) @@ -90,9 +90,9 @@ def test_get_formula_from_string(): "x0:x1", "x0:x2", "x1:x2", - "x0 ** 2", - "x1 ** 2", - "x2 ** 2", + "x0**2", + "x1**2", + "x2**2", ] model_formula = get_formula_from_string(domain=domain, model_type="fully-quadratic") assert all(term in terms for term in model_formula) @@ -100,7 +100,7 @@ def test_get_formula_from_string(): # custom model terms_lhs = ["y"] - terms_rhs = ["1", "x0", "x0 ** 2", "x0:x1"] + terms_rhs = ["1", "x0", "x0**2", "x0:x1"] model_formula = get_formula_from_string( domain=domain, model_type="y ~ 1 + x0 + x0:x1 + {x0**2}", diff --git a/tests/bofire/surrogates/test_mlp.py b/tests/bofire/surrogates/test_mlp.py index b447d6ada..67708aff1 100644 --- a/tests/bofire/surrogates/test_mlp.py +++ b/tests/bofire/surrogates/test_mlp.py @@ -13,7 +13,7 @@ ContinuousInput, ContinuousOutput, ) -from bofire.data_models.surrogates.api import MLPEnsemble, ScalerEnum +from bofire.data_models.surrogates.api import RegressionMLPEnsemble, ScalerEnum from bofire.surrogates.mlp import MLP, MLPDataset, _MLPEnsemble, fit_mlp from bofire.utils.torch_tools import tkwargs @@ -39,12 +39,12 @@ def test_mlp_activation_invalid(): @pytest.mark.parametrize("output_size", [1, 2]) def test_mlp_input_size(output_size): mlp = MLP(input_size=2, output_size=output_size) - assert mlp.layers[-1].out_features == output_size + assert mlp.layers[-2].out_features == output_size def test_mlp_hidden_layer_sizes(): mlp = MLP(input_size=2, output_size=1, hidden_layer_sizes=(8, 4, 2)) - assert len(mlp.layers) == 7 + assert len(mlp.layers) == 8 # added final acitvation function as a layer assert mlp.layers[0].in_features == 2 assert mlp.layers[0].out_features == 8 assert mlp.layers[2].in_features == 8 @@ -60,7 +60,7 @@ def test_mlp_hidden_layer_sizes(): def test_mlp_dropout(): mlp = MLP(input_size=2, output_size=1, hidden_layer_sizes=(8, 4, 2), dropout=0.2) - assert len(mlp.layers) == 10 + assert len(mlp.layers) == 11 assert mlp.layers[0].in_features == 2 assert mlp.layers[0].out_features == 8 assert isinstance(mlp.layers[1], nn.modules.activation.ReLU) @@ -175,7 +175,7 @@ def test_mlp_ensemble_fit(scaler, output_scaler): bench = Himmelblau() samples = bench.domain.inputs.sample(10) experiments = bench.f(samples, return_complete=True) - ens = MLPEnsemble( + ens = RegressionMLPEnsemble( inputs=bench.domain.inputs, outputs=bench.domain.outputs, n_estimators=2, @@ -229,7 +229,7 @@ def test_mlp_ensemble_fit_categorical(scaler): experiments.loc[experiments.x_cat == "papa", "y"] /= 2.0 experiments["valid_y"] = 1 - ens = MLPEnsemble( + ens = RegressionMLPEnsemble( inputs=inputs, outputs=outputs, n_estimators=2, diff --git a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb index 8f8e203df..48d1f6ae6 100644 --- a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb +++ b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb @@ -102,243 +102,243 @@ " \n", " \n", " 0\n", - " -0.310698\n", - " 0.392037\n", - " -1.806570\n", - " -1.981755\n", - " 1.384988\n", + " 0.296884\n", + " 1.339136\n", + " -0.699016\n", + " -1.483660\n", + " -1.619869\n", " 0\n", - " 11.628122\n", - " acceptable\n", - " -0.310070\n", + " 7.435109\n", + " ideal\n", + " 0.302020\n", " \n", " \n", " 1\n", - " -1.489428\n", - " -1.288171\n", - " -0.505947\n", - " -0.384544\n", - " 0.425808\n", - " 1\n", - " 6.224438\n", - " acceptable\n", - " -1.485448\n", + " -1.198177\n", + " -1.871994\n", + " -0.096251\n", + " 1.753299\n", + " -1.822007\n", + " 0\n", + " 13.140989\n", + " unacceptable\n", + " -1.188586\n", " \n", " \n", " 2\n", - " 1.461134\n", - " 0.337194\n", - " -0.082360\n", - " -1.730603\n", - " -1.588481\n", + " -0.197863\n", + " 0.134761\n", + " -1.559977\n", + " -1.984249\n", + " 0.769223\n", " 0\n", - " 8.278724\n", + " 10.049967\n", " ideal\n", - " 1.467232\n", + " -0.189195\n", " \n", " \n", " 3\n", - " -1.781506\n", - " -1.436844\n", - " 1.930972\n", - " 1.783539\n", - " -0.900358\n", - " 1\n", - " 9.733454\n", + " 0.423355\n", + " -0.333732\n", + " -0.992391\n", + " 0.021709\n", + " 1.187608\n", + " 0\n", + " 1.452937\n", " unacceptable\n", - " -1.779890\n", + " 0.424980\n", " \n", " \n", " 4\n", - " 0.665722\n", - " -1.744514\n", - " 0.446853\n", - " 0.700885\n", - " 0.651237\n", - " 1\n", - " 4.665782\n", - " acceptable\n", - " 0.668620\n", + " -1.827201\n", + " -1.653218\n", + " -1.266475\n", + " 1.432659\n", + " -1.158900\n", + " 0\n", + " 12.936642\n", + " unacceptable\n", + " -1.826852\n", " \n", " \n", " 5\n", - " -1.499912\n", - " -0.747193\n", - " 0.664655\n", - " -1.011433\n", - " -1.354635\n", - " 1\n", - " 7.671702\n", - " ideal\n", - " -1.491068\n", + " 0.073122\n", + " -1.789563\n", + " 1.670582\n", + " 1.867550\n", + " 1.759474\n", + " 0\n", + " 6.432953\n", + " unacceptable\n", + " 0.080055\n", " \n", " \n", " 6\n", - " -1.543055\n", - " 1.962503\n", - " 1.798311\n", - " 0.168173\n", - " -1.495112\n", + " 0.847212\n", + " -1.236811\n", + " -0.533595\n", + " -1.842110\n", + " 1.425404\n", " 1\n", - " 8.231485\n", - " unacceptable\n", - " -1.542496\n", + " 7.133379\n", + " acceptable\n", + " 0.854993\n", " \n", " \n", " 7\n", - " -0.708194\n", - " 1.516163\n", - " -1.662763\n", - " 0.717943\n", - " -0.500210\n", - " 0\n", - " 5.660023\n", - " ideal\n", - " -0.703393\n", + " -0.728056\n", + " 1.114750\n", + " 1.837209\n", + " 0.768175\n", + " 1.179493\n", + " 1\n", + " 1.895683\n", + " acceptable\n", + " -0.724770\n", " \n", " \n", " 8\n", - " -0.146424\n", - " -0.733847\n", - " -1.398632\n", - " -0.390983\n", - " -1.966037\n", + " -0.787473\n", + " -1.280905\n", + " 1.964343\n", + " 1.051243\n", + " -1.969924\n", " 1\n", - " 9.502397\n", - " ideal\n", - " -0.142366\n", + " 9.828818\n", + " unacceptable\n", + " -0.778318\n", " \n", " \n", " 9\n", - " -1.435747\n", - " -0.197231\n", - " -0.085655\n", - " 1.959668\n", - " -1.492488\n", - " 0\n", - " 7.361934\n", - " ideal\n", - " -1.429517\n", + " -1.805895\n", + " 1.544210\n", + " -0.926871\n", + " 1.780826\n", + " 1.477554\n", + " 1\n", + " 6.990552\n", + " unacceptable\n", + " -1.797201\n", " \n", " \n", " 10\n", - " -0.672001\n", - " 0.989061\n", - " 0.625807\n", - " 1.945203\n", - " -1.505085\n", - " 0\n", - " 4.971067\n", + " -0.353654\n", + " 1.743341\n", + " -1.336186\n", + " -1.700009\n", + " 0.440015\n", + " 1\n", + " 7.701678\n", " acceptable\n", - " -0.671859\n", + " -0.345705\n", " \n", " \n", " 11\n", - " -0.930565\n", - " -1.942693\n", - " 0.550933\n", - " -0.877697\n", - " -1.120742\n", + " 1.181500\n", + " 1.070489\n", + " -0.555974\n", + " -0.604853\n", + " -0.795902\n", " 0\n", - " 9.486661\n", - " ideal\n", - " -0.930336\n", + " 1.056973\n", + " acceptable\n", + " 1.189100\n", " \n", " \n", " 12\n", - " -1.762717\n", - " 0.738706\n", - " -1.863998\n", - " -1.565166\n", - " 0.792053\n", + " 0.920993\n", + " -0.776714\n", + " -0.928656\n", + " 0.050378\n", + " -0.887893\n", " 1\n", - " 13.751706\n", - " unacceptable\n", - " -1.757165\n", + " 2.087592\n", + " acceptable\n", + " 0.929757\n", " \n", " \n", " 13\n", - " -1.863466\n", - " 0.404012\n", - " 0.844637\n", - " -1.966831\n", - " -0.169374\n", + " -0.599356\n", + " 1.893306\n", + " -1.413963\n", + " 0.700251\n", + " 1.361780\n", " 0\n", - " 11.650621\n", - " ideal\n", - " -1.857675\n", + " 4.407259\n", + " acceptable\n", + " -0.593129\n", " \n", " \n", " 14\n", - " -0.981088\n", - " 1.761401\n", - " -0.346133\n", - " -0.842548\n", - " -0.307780\n", + " -1.178618\n", + " -1.807959\n", + " 1.643023\n", + " -1.367653\n", + " -1.300478\n", " 0\n", - " 3.160072\n", - " acceptable\n", - " -0.972805\n", + " 12.031568\n", + " unacceptable\n", + " -1.170357\n", " \n", " \n", " 15\n", - " 0.703509\n", - " -1.285697\n", - " 0.011635\n", - " 1.999545\n", - " 0.608979\n", + " 0.012359\n", + " 0.674245\n", + " -1.080136\n", + " -0.569172\n", + " -1.401835\n", " 0\n", - " 2.874619\n", - " ideal\n", - " 0.710066\n", + " 4.597730\n", + " acceptable\n", + " 0.022354\n", " \n", " \n", " 16\n", - " -0.262847\n", - " -0.670558\n", - " -1.676711\n", - " 1.346576\n", - " 0.734364\n", - " 1\n", - " 5.254699\n", - " ideal\n", - " -0.255986\n", + " -1.161135\n", + " -1.765019\n", + " 1.158920\n", + " 1.827956\n", + " -1.107484\n", + " 0\n", + " 8.805522\n", + " unacceptable\n", + " -1.157363\n", " \n", " \n", " 17\n", - " 1.949815\n", - " -0.132073\n", - " -1.921915\n", - " 1.955173\n", - " -0.510576\n", + " -1.750851\n", + " 0.165063\n", + " -1.370027\n", + " 0.828754\n", + " -0.325600\n", " 0\n", - " 7.427848\n", + " 7.730277\n", " acceptable\n", - " 1.951420\n", + " -1.742488\n", " \n", " \n", " 18\n", - " 1.352886\n", - " -0.307143\n", - " 0.472704\n", - " -1.201528\n", - " -1.226323\n", - " 0\n", - " 3.723930\n", - " ideal\n", - " 1.356302\n", + " -1.695762\n", + " 1.902611\n", + " -1.071329\n", + " -0.927217\n", + " -0.700890\n", + " 1\n", + " 7.765145\n", + " acceptable\n", + " -1.693954\n", " \n", " \n", " 19\n", - " -1.093545\n", - " 1.693383\n", - " 0.291218\n", - " -0.979725\n", - " -1.454517\n", + " -0.528461\n", + " -1.958166\n", + " 0.104057\n", + " 0.300671\n", + " -1.417810\n", " 0\n", - " 6.389192\n", + " 9.545906\n", " acceptable\n", - " -1.092528\n", + " -0.525143\n", " \n", " \n", "\n", @@ -346,48 +346,48 @@ ], "text/plain": [ " x_0 x_1 x_2 x_3 x_4 x_5 f_0 \\\n", - "0 -0.310698 0.392037 -1.806570 -1.981755 1.384988 0 11.628122 \n", - "1 -1.489428 -1.288171 -0.505947 -0.384544 0.425808 1 6.224438 \n", - "2 1.461134 0.337194 -0.082360 -1.730603 -1.588481 0 8.278724 \n", - "3 -1.781506 -1.436844 1.930972 1.783539 -0.900358 1 9.733454 \n", - "4 0.665722 -1.744514 0.446853 0.700885 0.651237 1 4.665782 \n", - "5 -1.499912 -0.747193 0.664655 -1.011433 -1.354635 1 7.671702 \n", - "6 -1.543055 1.962503 1.798311 0.168173 -1.495112 1 8.231485 \n", - "7 -0.708194 1.516163 -1.662763 0.717943 -0.500210 0 5.660023 \n", - "8 -0.146424 -0.733847 -1.398632 -0.390983 -1.966037 1 9.502397 \n", - "9 -1.435747 -0.197231 -0.085655 1.959668 -1.492488 0 7.361934 \n", - "10 -0.672001 0.989061 0.625807 1.945203 -1.505085 0 4.971067 \n", - "11 -0.930565 -1.942693 0.550933 -0.877697 -1.120742 0 9.486661 \n", - "12 -1.762717 0.738706 -1.863998 -1.565166 0.792053 1 13.751706 \n", - "13 -1.863466 0.404012 0.844637 -1.966831 -0.169374 0 11.650621 \n", - "14 -0.981088 1.761401 -0.346133 -0.842548 -0.307780 0 3.160072 \n", - "15 0.703509 -1.285697 0.011635 1.999545 0.608979 0 2.874619 \n", - "16 -0.262847 -0.670558 -1.676711 1.346576 0.734364 1 5.254699 \n", - "17 1.949815 -0.132073 -1.921915 1.955173 -0.510576 0 7.427848 \n", - "18 1.352886 -0.307143 0.472704 -1.201528 -1.226323 0 3.723930 \n", - "19 -1.093545 1.693383 0.291218 -0.979725 -1.454517 0 6.389192 \n", + "0 0.296884 1.339136 -0.699016 -1.483660 -1.619869 0 7.435109 \n", + "1 -1.198177 -1.871994 -0.096251 1.753299 -1.822007 0 13.140989 \n", + "2 -0.197863 0.134761 -1.559977 -1.984249 0.769223 0 10.049967 \n", + "3 0.423355 -0.333732 -0.992391 0.021709 1.187608 0 1.452937 \n", + "4 -1.827201 -1.653218 -1.266475 1.432659 -1.158900 0 12.936642 \n", + "5 0.073122 -1.789563 1.670582 1.867550 1.759474 0 6.432953 \n", + "6 0.847212 -1.236811 -0.533595 -1.842110 1.425404 1 7.133379 \n", + "7 -0.728056 1.114750 1.837209 0.768175 1.179493 1 1.895683 \n", + "8 -0.787473 -1.280905 1.964343 1.051243 -1.969924 1 9.828818 \n", + "9 -1.805895 1.544210 -0.926871 1.780826 1.477554 1 6.990552 \n", + "10 -0.353654 1.743341 -1.336186 -1.700009 0.440015 1 7.701678 \n", + "11 1.181500 1.070489 -0.555974 -0.604853 -0.795902 0 1.056973 \n", + "12 0.920993 -0.776714 -0.928656 0.050378 -0.887893 1 2.087592 \n", + "13 -0.599356 1.893306 -1.413963 0.700251 1.361780 0 4.407259 \n", + "14 -1.178618 -1.807959 1.643023 -1.367653 -1.300478 0 12.031568 \n", + "15 0.012359 0.674245 -1.080136 -0.569172 -1.401835 0 4.597730 \n", + "16 -1.161135 -1.765019 1.158920 1.827956 -1.107484 0 8.805522 \n", + "17 -1.750851 0.165063 -1.370027 0.828754 -0.325600 0 7.730277 \n", + "18 -1.695762 1.902611 -1.071329 -0.927217 -0.700890 1 7.765145 \n", + "19 -0.528461 -1.958166 0.104057 0.300671 -1.417810 0 9.545906 \n", "\n", " f_1 f_2 \n", - "0 acceptable -0.310070 \n", - "1 acceptable -1.485448 \n", - "2 ideal 1.467232 \n", - "3 unacceptable -1.779890 \n", - "4 acceptable 0.668620 \n", - "5 ideal -1.491068 \n", - "6 unacceptable -1.542496 \n", - "7 ideal -0.703393 \n", - "8 ideal -0.142366 \n", - "9 ideal -1.429517 \n", - "10 acceptable -0.671859 \n", - "11 ideal -0.930336 \n", - "12 unacceptable -1.757165 \n", - "13 ideal -1.857675 \n", - "14 acceptable -0.972805 \n", - "15 ideal 0.710066 \n", - "16 ideal -0.255986 \n", - "17 acceptable 1.951420 \n", - "18 ideal 1.356302 \n", - "19 acceptable -1.092528 " + "0 ideal 0.302020 \n", + "1 unacceptable -1.188586 \n", + "2 ideal -0.189195 \n", + "3 unacceptable 0.424980 \n", + "4 unacceptable -1.826852 \n", + "5 unacceptable 0.080055 \n", + "6 acceptable 0.854993 \n", + "7 acceptable -0.724770 \n", + "8 unacceptable -0.778318 \n", + "9 unacceptable -1.797201 \n", + "10 acceptable -0.345705 \n", + "11 acceptable 1.189100 \n", + "12 acceptable 0.929757 \n", + "13 acceptable -0.593129 \n", + "14 unacceptable -1.170357 \n", + "15 acceptable 0.022354 \n", + "16 unacceptable -1.157363 \n", + "17 acceptable -1.742488 \n", + "18 acceptable -1.693954 \n", + "19 acceptable -0.525143 " ] }, "execution_count": 3, @@ -511,16 +511,16 @@ " \n", " \n", " 0\n", - " 0.76\n", - " 0.76\n", + " 0.795\n", + " 0.795\n", " \n", " \n", "\n", "" ], "text/plain": [ - " ACCURACY F1\n", - "0 0.76 0.76" + " ACCURACY F1\n", + "0 0.795 0.795" ] }, "execution_count": 5, @@ -566,8 +566,8 @@ " \n", " \n", " 0\n", - " 0.18\n", - " 0.18\n", + " 0.52\n", + " 0.52\n", " \n", " \n", "\n", @@ -575,7 +575,7 @@ ], "text/plain": [ " ACCURACY F1\n", - "0 0.18 0.18" + "0 0.52 0.52" ] }, "execution_count": 6, @@ -685,243 +685,243 @@ " \n", " \n", " 0\n", - " 0.012650\n", - " -0.082007\n", - " 0.196167\n", - " -0.068576\n", - " 2.000000\n", - " 0\n", - " ideal\n", - " 0.585006\n", - " -0.178305\n", - " 0.017555\n", + " 0.481400\n", + " 0.586632\n", + " 0.063636\n", + " 0.087314\n", + " 0.424632\n", + " 1\n", + " acceptable\n", + " 0.0\n", + " -2.807002\n", + " 0.487184\n", " ...\n", - " 0.302352\n", - " 0.696724\n", - " 0.486615\n", - " 0.003456\n", - " 0.001530\n", - " 0.413975\n", - " 0.413346\n", - " 0.178305\n", - " 0.497806\n", - " 0.999076\n", + " 0.974527\n", + " 0.007553\n", + " 0.667668\n", + " 0.003238\n", + " 0.017963\n", + " 0.029889\n", + " 0.014268\n", + " 2.807002\n", + " 0.439401\n", + " 0.982080\n", " \n", " \n", " 1\n", - " 0.006090\n", - " -0.110946\n", - " 0.036496\n", - " -0.075202\n", - " 1.376258\n", + " 0.205385\n", + " 0.435466\n", + " 0.386084\n", + " 0.820992\n", + " 0.562380\n", " 0\n", - " ideal\n", - " 0.581150\n", - " -0.033201\n", - " 0.011008\n", + " unacceptable\n", + " 0.0\n", + " -2.478325\n", + " 0.210586\n", " ...\n", - " 0.285243\n", - " 0.714701\n", - " 0.317448\n", - " 0.003389\n", - " 0.000087\n", - " 0.410947\n", - " 0.410923\n", - " 0.033201\n", - " 0.498624\n", - " 0.999944\n", + " 0.408227\n", + " 0.001115\n", + " 0.736828\n", + " 0.003263\n", + " 0.539504\n", + " 0.540456\n", + " 0.002281\n", + " 2.478325\n", + " 0.473701\n", + " 0.409342\n", " \n", " \n", " 2\n", - " 0.161220\n", - " -0.056338\n", - " -0.295114\n", - " -0.061665\n", - " 0.096058\n", - " 0\n", + " 0.578421\n", + " 0.790867\n", + " -0.054982\n", + " 0.602806\n", + " 0.624962\n", + " 1\n", " acceptable\n", - " 0.745874\n", - " 0.011562\n", - " 0.166371\n", + " 0.0\n", + " -2.495759\n", + " 0.584149\n", " ...\n", - " 0.591001\n", - " 0.408996\n", - " 0.271178\n", - " 0.003351\n", - " 0.000005\n", - " 0.527414\n", - " 0.527411\n", - " -0.011562\n", - " 0.479216\n", - " 0.999998\n", + " 0.965720\n", + " 0.007595\n", + " 0.851270\n", + " 0.003274\n", + " 0.024524\n", + " 0.033772\n", + " 0.014316\n", + " 2.495759\n", + " 0.427496\n", + " 0.973315\n", " \n", " \n", " 3\n", - " -0.027612\n", - " 0.391135\n", - " 0.121035\n", - " 1.256868\n", - " 1.784862\n", - " 0\n", - " ideal\n", - " 0.599128\n", - " 0.209979\n", - " -0.022735\n", + " 0.225033\n", + " 0.648267\n", + " 0.115931\n", + " 0.153590\n", + " 0.527780\n", + " 1\n", + " acceptable\n", + " 0.0\n", + " -2.546250\n", + " 0.230460\n", " ...\n", - " 0.297190\n", - " 0.702338\n", - " 0.469250\n", - " 0.003467\n", - " 0.000979\n", - " 0.423850\n", - " 0.423444\n", - " -0.209979\n", - " 0.502842\n", - " 0.999528\n", + " 0.860037\n", + " 0.012140\n", + " 0.717848\n", + " 0.003234\n", + " 0.170332\n", + " 0.177282\n", + " 0.021604\n", + " 2.546250\n", + " 0.471224\n", + " 0.872177\n", " \n", " \n", " 4\n", - " -0.031088\n", - " 1.118860\n", - " 0.087444\n", - " -0.012355\n", - " 2.000000\n", - " 0\n", + " 0.342310\n", + " 0.498711\n", + " 0.047750\n", + " 0.008932\n", + " 0.668452\n", + " 1\n", " acceptable\n", - " 0.683969\n", - " 0.085227\n", - " -0.026003\n", + " 0.0\n", + " -2.549582\n", + " 0.347790\n", " ...\n", - " 0.502813\n", - " 0.468748\n", - " 0.500849\n", - " 0.003525\n", - " 0.053980\n", - " 0.498330\n", - " 0.465368\n", - " -0.085227\n", - " 0.503250\n", - " 0.971561\n", + " 0.932932\n", + " 0.008269\n", + " 0.655287\n", + " 0.003230\n", + " 0.099837\n", + " 0.101100\n", + " 0.014011\n", + " 2.549582\n", + " 0.456635\n", + " 0.941201\n", " \n", " \n", " 5\n", - " -0.015485\n", - " -0.106424\n", - " 0.251205\n", - " -0.105387\n", - " 1.627490\n", - " 0\n", - " ideal\n", - " 0.596716\n", - " -0.075283\n", - " -0.010639\n", + " 0.380575\n", + " 1.040915\n", + " 0.214337\n", + " 0.069153\n", + " 0.753324\n", + " 1\n", + " acceptable\n", + " 0.0\n", + " -2.368879\n", + " 0.386290\n", " ...\n", - " 0.242466\n", - " 0.757293\n", - " 0.352754\n", - " 0.003411\n", - " 0.000412\n", - " 0.422019\n", - " 0.421865\n", - " 0.075283\n", - " 0.501330\n", - " 0.999759\n", + " 0.983611\n", + " 0.005720\n", + " 0.795297\n", + " 0.003277\n", + " 0.013258\n", + " 0.023035\n", + " 0.012126\n", + " 2.368879\n", + " 0.451863\n", + " 0.989332\n", " \n", " \n", " 6\n", - " -0.177911\n", - " -0.006491\n", - " 0.004730\n", - " -0.100077\n", - " 2.000000\n", - " 0\n", - " ideal\n", - " 0.629016\n", - " 0.015120\n", - " -0.173160\n", + " 0.080792\n", + " 0.260320\n", + " 0.331997\n", + " 0.669731\n", + " 0.425205\n", + " 1\n", + " unacceptable\n", + " 0.0\n", + " -2.326501\n", + " 0.085955\n", " ...\n", - " 0.391062\n", - " 0.608024\n", - " 0.520686\n", - " 0.003441\n", - " 0.001167\n", - " 0.444837\n", - " 0.444724\n", - " -0.015120\n", - " 0.521632\n", - " 0.999086\n", + " 0.210872\n", + " 0.001047\n", + " 0.763624\n", + " 0.003236\n", + " 0.440892\n", + " 0.441507\n", + " 0.001058\n", + " 2.326501\n", + " 0.489257\n", + " 0.211920\n", " \n", " \n", " 7\n", - " 0.019420\n", - " 1.192449\n", - " 0.083663\n", - " -0.014620\n", - " 1.388673\n", - " 0\n", - " ideal\n", - " 0.730330\n", - " 0.186320\n", - " 0.024592\n", + " 0.243598\n", + " 0.698221\n", + " -0.014641\n", + " 0.092352\n", + " 0.708394\n", + " 1\n", + " acceptable\n", + " 0.0\n", + " -2.445781\n", + " 0.248967\n", " ...\n", - " 0.430550\n", - " 0.567086\n", - " 0.322597\n", - " 0.003467\n", - " 0.004309\n", - " 0.517441\n", - " 0.515383\n", - " -0.186320\n", - " 0.496926\n", - " 0.997636\n", + " 0.943763\n", + " 0.006818\n", + " 0.699401\n", + " 0.003234\n", + " 0.090640\n", + " 0.091162\n", + " 0.012202\n", + " 2.445781\n", + " 0.468919\n", + " 0.950581\n", " \n", " \n", " 8\n", - " 0.735010\n", - " -0.009883\n", - " 2.000000\n", - " -0.052121\n", - " 2.000000\n", - " 0\n", - " ideal\n", - " 0.702506\n", - " 1.088985\n", - " 0.740256\n", + " 0.182756\n", + " -0.921039\n", + " -0.549821\n", + " 0.689784\n", + " -0.655921\n", + " 1\n", + " unacceptable\n", + " 0.0\n", + " 1.978269\n", + " 0.187628\n", " ...\n", - " 0.145336\n", - " 0.659978\n", - " 1.068470\n", - " 0.003609\n", - " 0.435210\n", - " 0.304475\n", - " 0.459785\n", - " -1.088985\n", - " 0.408510\n", - " 0.805314\n", + " 0.223373\n", + " 0.004275\n", + " 0.536589\n", + " 0.003255\n", + " 0.389221\n", + " 0.391444\n", + " 0.007985\n", + " -1.978269\n", + " 0.476564\n", + " 0.227648\n", " \n", " \n", " 9\n", - " 0.028036\n", - " 2.000000\n", - " -0.318588\n", - " 2.000000\n", - " 0.310111\n", - " 0\n", - " ideal\n", - " 0.603002\n", - " 1.676667\n", - " 0.033345\n", + " -0.297385\n", + " -0.422154\n", + " 1.407541\n", + " 0.605559\n", + " 1.597342\n", + " 1\n", + " unacceptable\n", + " 0.0\n", + " 1.351191\n", + " -0.292334\n", " ...\n", - " 0.475370\n", - " 0.524512\n", - " 1.170120\n", - " 0.003564\n", - " 0.000234\n", - " 0.426380\n", - " 0.426393\n", - " -1.676667\n", - " 0.495832\n", - " 0.999883\n", + " 0.198005\n", + " 0.012080\n", + " 0.576409\n", + " 0.003309\n", + " 0.436782\n", + " 0.442695\n", + " 0.026803\n", + " -1.351191\n", + " 0.536477\n", + " 0.210085\n", " \n", " \n", "\n", @@ -929,53 +929,53 @@ "" ], "text/plain": [ - " x_0 x_1 x_2 x_3 x_4 x_5 f_1_pred f_1_sd \\\n", - "0 0.012650 -0.082007 0.196167 -0.068576 2.000000 0 ideal 0.585006 \n", - "1 0.006090 -0.110946 0.036496 -0.075202 1.376258 0 ideal 0.581150 \n", - "2 0.161220 -0.056338 -0.295114 -0.061665 0.096058 0 acceptable 0.745874 \n", - "3 -0.027612 0.391135 0.121035 1.256868 1.784862 0 ideal 0.599128 \n", - "4 -0.031088 1.118860 0.087444 -0.012355 2.000000 0 acceptable 0.683969 \n", - "5 -0.015485 -0.106424 0.251205 -0.105387 1.627490 0 ideal 0.596716 \n", - "6 -0.177911 -0.006491 0.004730 -0.100077 2.000000 0 ideal 0.629016 \n", - "7 0.019420 1.192449 0.083663 -0.014620 1.388673 0 ideal 0.730330 \n", - "8 0.735010 -0.009883 2.000000 -0.052121 2.000000 0 ideal 0.702506 \n", - "9 0.028036 2.000000 -0.318588 2.000000 0.310111 0 ideal 0.603002 \n", + " x_0 x_1 x_2 x_3 x_4 x_5 f_1_pred f_1_sd \\\n", + "0 0.481400 0.586632 0.063636 0.087314 0.424632 1 acceptable 0.0 \n", + "1 0.205385 0.435466 0.386084 0.820992 0.562380 0 unacceptable 0.0 \n", + "2 0.578421 0.790867 -0.054982 0.602806 0.624962 1 acceptable 0.0 \n", + "3 0.225033 0.648267 0.115931 0.153590 0.527780 1 acceptable 0.0 \n", + "4 0.342310 0.498711 0.047750 0.008932 0.668452 1 acceptable 0.0 \n", + "5 0.380575 1.040915 0.214337 0.069153 0.753324 1 acceptable 0.0 \n", + "6 0.080792 0.260320 0.331997 0.669731 0.425205 1 unacceptable 0.0 \n", + "7 0.243598 0.698221 -0.014641 0.092352 0.708394 1 acceptable 0.0 \n", + "8 0.182756 -0.921039 -0.549821 0.689784 -0.655921 1 unacceptable 0.0 \n", + "9 -0.297385 -0.422154 1.407541 0.605559 1.597342 1 unacceptable 0.0 \n", "\n", " f_0_pred f_2_pred ... f_1_acceptable_prob f_1_ideal_prob f_0_sd \\\n", - "0 -0.178305 0.017555 ... 0.302352 0.696724 0.486615 \n", - "1 -0.033201 0.011008 ... 0.285243 0.714701 0.317448 \n", - "2 0.011562 0.166371 ... 0.591001 0.408996 0.271178 \n", - "3 0.209979 -0.022735 ... 0.297190 0.702338 0.469250 \n", - "4 0.085227 -0.026003 ... 0.502813 0.468748 0.500849 \n", - "5 -0.075283 -0.010639 ... 0.242466 0.757293 0.352754 \n", - "6 0.015120 -0.173160 ... 0.391062 0.608024 0.520686 \n", - "7 0.186320 0.024592 ... 0.430550 0.567086 0.322597 \n", - "8 1.088985 0.740256 ... 0.145336 0.659978 1.068470 \n", - "9 1.676667 0.033345 ... 0.475370 0.524512 1.170120 \n", + "0 -2.807002 0.487184 ... 0.974527 0.007553 0.667668 \n", + "1 -2.478325 0.210586 ... 0.408227 0.001115 0.736828 \n", + "2 -2.495759 0.584149 ... 0.965720 0.007595 0.851270 \n", + "3 -2.546250 0.230460 ... 0.860037 0.012140 0.717848 \n", + "4 -2.549582 0.347790 ... 0.932932 0.008269 0.655287 \n", + "5 -2.368879 0.386290 ... 0.983611 0.005720 0.795297 \n", + "6 -2.326501 0.085955 ... 0.210872 0.001047 0.763624 \n", + "7 -2.445781 0.248967 ... 0.943763 0.006818 0.699401 \n", + "8 1.978269 0.187628 ... 0.223373 0.004275 0.536589 \n", + "9 1.351191 -0.292334 ... 0.198005 0.012080 0.576409 \n", "\n", " f_2_sd f_1_unacceptable_sd f_1_acceptable_sd f_1_ideal_sd f_0_des \\\n", - "0 0.003456 0.001530 0.413975 0.413346 0.178305 \n", - "1 0.003389 0.000087 0.410947 0.410923 0.033201 \n", - "2 0.003351 0.000005 0.527414 0.527411 -0.011562 \n", - "3 0.003467 0.000979 0.423850 0.423444 -0.209979 \n", - "4 0.003525 0.053980 0.498330 0.465368 -0.085227 \n", - "5 0.003411 0.000412 0.422019 0.421865 0.075283 \n", - "6 0.003441 0.001167 0.444837 0.444724 -0.015120 \n", - "7 0.003467 0.004309 0.517441 0.515383 -0.186320 \n", - "8 0.003609 0.435210 0.304475 0.459785 -1.088985 \n", - "9 0.003564 0.000234 0.426380 0.426393 -1.676667 \n", + "0 0.003238 0.017963 0.029889 0.014268 2.807002 \n", + "1 0.003263 0.539504 0.540456 0.002281 2.478325 \n", + "2 0.003274 0.024524 0.033772 0.014316 2.495759 \n", + "3 0.003234 0.170332 0.177282 0.021604 2.546250 \n", + "4 0.003230 0.099837 0.101100 0.014011 2.549582 \n", + "5 0.003277 0.013258 0.023035 0.012126 2.368879 \n", + "6 0.003236 0.440892 0.441507 0.001058 2.326501 \n", + "7 0.003234 0.090640 0.091162 0.012202 2.445781 \n", + "8 0.003255 0.389221 0.391444 0.007985 -1.978269 \n", + "9 0.003309 0.436782 0.442695 0.026803 -1.351191 \n", "\n", " f_2_des f_1_des \n", - "0 0.497806 0.999076 \n", - "1 0.498624 0.999944 \n", - "2 0.479216 0.999998 \n", - "3 0.502842 0.999528 \n", - "4 0.503250 0.971561 \n", - "5 0.501330 0.999759 \n", - "6 0.521632 0.999086 \n", - "7 0.496926 0.997636 \n", - "8 0.408510 0.805314 \n", - "9 0.495832 0.999883 \n", + "0 0.439401 0.982080 \n", + "1 0.473701 0.409342 \n", + "2 0.427496 0.973315 \n", + "3 0.471224 0.872177 \n", + "4 0.456635 0.941201 \n", + "5 0.451863 0.989332 \n", + "6 0.489257 0.211920 \n", + "7 0.468919 0.950581 \n", + "8 0.476564 0.227648 \n", + "9 0.536477 0.210085 \n", "\n", "[10 rows x 21 columns]" ] @@ -1001,7 +1001,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -1013,7 +1013,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -1044,12 +1044,12 @@ " \n", " \n", " 0\n", - " ideal\n", + " acceptable\n", " unacceptable\n", " \n", " \n", " 1\n", - " ideal\n", + " unacceptable\n", " unacceptable\n", " \n", " \n", @@ -1059,8 +1059,8 @@ " \n", " \n", " 3\n", - " ideal\n", " acceptable\n", + " unacceptable\n", " \n", " \n", " 4\n", @@ -1069,48 +1069,48 @@ " \n", " \n", " 5\n", - " ideal\n", + " acceptable\n", " unacceptable\n", " \n", " \n", " 6\n", - " ideal\n", + " unacceptable\n", " unacceptable\n", " \n", " \n", " 7\n", - " ideal\n", + " acceptable\n", " unacceptable\n", " \n", " \n", " 8\n", - " ideal\n", - " ideal\n", + " unacceptable\n", + " unacceptable\n", " \n", " \n", " 9\n", - " ideal\n", - " ideal\n", + " unacceptable\n", + " acceptable\n", " \n", " \n", "\n", "" ], "text/plain": [ - " f_1_pred f_1_true\n", - "0 ideal unacceptable\n", - "1 ideal unacceptable\n", - "2 acceptable unacceptable\n", - "3 ideal acceptable\n", - "4 acceptable unacceptable\n", - "5 ideal unacceptable\n", - "6 ideal unacceptable\n", - "7 ideal unacceptable\n", - "8 ideal ideal\n", - "9 ideal ideal" + " f_1_pred f_1_true\n", + "0 acceptable unacceptable\n", + "1 unacceptable unacceptable\n", + "2 acceptable unacceptable\n", + "3 acceptable unacceptable\n", + "4 acceptable unacceptable\n", + "5 acceptable unacceptable\n", + "6 unacceptable unacceptable\n", + "7 acceptable unacceptable\n", + "8 unacceptable unacceptable\n", + "9 unacceptable acceptable" ] }, - "execution_count": 12, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } diff --git a/tutorials/models_serial.ipynb b/tutorials/models_serial.ipynb index d3f73fca3..ecd206d64 100644 --- a/tutorials/models_serial.ipynb +++ b/tutorials/models_serial.ipynb @@ -25,7 +25,7 @@ "outputs": [], "source": [ "from bofire.data_models.domain.api import Outputs\n", - "from bofire.data_models.surrogates.api import SingleTaskGPSurrogate, RandomForestSurrogate, MixedSingleTaskGPSurrogate, AnySurrogate, RandomForestSurrogate, EmpiricalSurrogate, MLPEnsemble\n", + "from bofire.data_models.surrogates.api import SingleTaskGPSurrogate, RandomForestSurrogate, MixedSingleTaskGPSurrogate, AnySurrogate, RandomForestSurrogate, EmpiricalSurrogate, RegressionMLPEnsemble\n", "from bofire.benchmarks.single import Himmelblau\n", "from bofire.benchmarks.multi import CrossCoupling\n", "import bofire.surrogates.api as surrogates\n", @@ -521,7 +521,7 @@ ], "source": [ "# we setup the data model, here a Single Task GP\n", - "surrogate_data = MLPEnsemble(\n", + "surrogate_data = RegressionMLPEnsemble(\n", " inputs=input_features,\n", " outputs=output_features,\n", " n_estimators=2\n", From 4f756949011ed6fb1ff978a59ebe4613f15d10be Mon Sep 17 00:00:00 2001 From: gmancino Date: Fri, 26 Jan 2024 13:40:30 -0500 Subject: [PATCH 12/31] Fix type changes and tutorials --- bofire/strategies/samplers/universal_constraint.py | 2 +- bofire/surrogates/mlp.py | 2 +- bofire/surrogates/surrogate.py | 9 +++++---- .../benchmarks/007-Benchmark_outlier_detection.ipynb | 10 ++++++++++ 4 files changed, 17 insertions(+), 6 deletions(-) diff --git a/bofire/strategies/samplers/universal_constraint.py b/bofire/strategies/samplers/universal_constraint.py index 8251827f6..f4990688a 100644 --- a/bofire/strategies/samplers/universal_constraint.py +++ b/bofire/strategies/samplers/universal_constraint.py @@ -52,4 +52,4 @@ def _ask(self, candidate_count: int) -> pd.DataFrame: return samples def has_sufficient_experiments(self) -> bool: - return True \ No newline at end of file + return True diff --git a/bofire/surrogates/mlp.py b/bofire/surrogates/mlp.py index 9e42b14ae..4839f48b6 100644 --- a/bofire/surrogates/mlp.py +++ b/bofire/surrogates/mlp.py @@ -136,7 +136,7 @@ def fit_mlp( lr: float = 1e-4, shuffle: bool = True, weight_decay: float = 0.0, - loss_function: Union[nn.L1Loss, nn.CrossEntropyLoss] = nn.L1Loss, + loss_function=nn.L1Loss, ): """Fit a MLP to a dataset. diff --git a/bofire/surrogates/surrogate.py b/bofire/surrogates/surrogate.py index 750bf07c7..468cb1f90 100644 --- a/bofire/surrogates/surrogate.py +++ b/bofire/surrogates/surrogate.py @@ -49,7 +49,7 @@ def predict(self, X: pd.DataFrame) -> pd.DataFrame: pred_cols = [] sd_cols = [] for featkey in self.outputs.get_keys(): - if isinstance(self.outputs.get_by_key(featkey), CategoricalOutput): + if hasattr(self.outputs.get_by_key(featkey), "categories"): pred_cols = pred_cols + [ f"{featkey}_{cat}_prob" for cat in self.outputs.get_by_key(featkey).categories @@ -90,8 +90,9 @@ def predict(self, X: pd.DataFrame) -> pd.DataFrame: def validate_predictions(self, predictions: pd.DataFrame) -> pd.DataFrame: expected_cols = [] + check_columns = [] for featkey in self.outputs.get_keys(): - if isinstance(self.outputs.get_by_key(featkey), CategoricalOutput): + if hasattr(self.outputs.get_by_key(featkey), "categories"): expected_cols = ( expected_cols + [f"{featkey}_{t}" for t in ["pred", "sd"]] @@ -104,14 +105,14 @@ def validate_predictions(self, predictions: pd.DataFrame) -> pd.DataFrame: for cat in self.outputs.get_by_key(featkey).categories ] ) - check_columns = [ + check_columns = check_columns + [ col for col in expected_cols if col != f"{featkey}_pred" ] else: expected_cols = expected_cols + [ f"{featkey}_{t}" for t in ["pred", "sd"] ] - check_columns = expected_cols + check_columns = check_columns + expected_cols if sorted(predictions.columns) != sorted(expected_cols): raise ValueError( f"Predictions are ill-formatted. Expected: {expected_cols}, got: {list(predictions.columns)}." diff --git a/tutorials/benchmarks/007-Benchmark_outlier_detection.ipynb b/tutorials/benchmarks/007-Benchmark_outlier_detection.ipynb index fa3bff9d1..7b65d776c 100644 --- a/tutorials/benchmarks/007-Benchmark_outlier_detection.ipynb +++ b/tutorials/benchmarks/007-Benchmark_outlier_detection.ipynb @@ -123,6 +123,16 @@ " noise_prior: AnyPrior = Field(default_factory=lambda: BOTORCH_NOISE_PRIOR())\n", " scaler: ScalerEnum = ScalerEnum.NORMALIZE\n", "\n", + " @classmethod\n", + " def is_output_implemented(cls, my_type) -> bool:\n", + " \"\"\"Abstract method to check output type for surrogate models\n", + " Args:\n", + " my_type: continuous or categorical output\n", + " Returns:\n", + " bool: True if the output type is valid for the surrogate chosen, False otherwise\n", + " \"\"\"\n", + " return isinstance(my_type, ContinuousOutput)\n", + "\n", "\n", "class SingleTaskVariationalGPSurrogate(BotorchSurrogate1, TrainableSurrogate):\n", " def __init__(\n", From dbeb6b8bb78f0f80a21e5fe2ba810ce8479f8251 Mon Sep 17 00:00:00 2001 From: gmancino Date: Tue, 30 Jan 2024 10:44:31 -0500 Subject: [PATCH 13/31] Fix tests --- bofire/data_models/features/categorical.py | 11 +- .../samplers/universal_constraint.py | 4 +- bofire/surrogates/mlp.py | 4 +- tests/bofire/data_models/specs/features.py | 2 +- tests/bofire/strategies/doe/test_objective.py | 28 +- .../Unknown_Constraint_Classification.ipynb | 908 +++++++++--------- 6 files changed, 477 insertions(+), 480 deletions(-) diff --git a/bofire/data_models/features/categorical.py b/bofire/data_models/features/categorical.py index 14ebfb3e3..eacf476a1 100644 --- a/bofire/data_models/features/categorical.py +++ b/bofire/data_models/features/categorical.py @@ -382,11 +382,8 @@ def validate_categories_unique(cls, categories: List[str]) -> List["str"]: raise ValueError("categories must be unique") return categories - @field_validator("objective") - @classmethod - def validate_objectives_unique( - cls, objective: AnyCategoricalObjective, info - ) -> AnyCategoricalObjective: + @model_validator(mode="after") + def validate_objectives_unique(self): """validates that categories have unique names Args: @@ -398,9 +395,9 @@ def validate_objectives_unique( Returns: Tuple[str]: Tuple of the categories """ - if objective.categories != info.data["categories"]: + if self.objective.categories != self.categories: raise ValueError("categories must match to objective categories") - return objective + return self @classmethod def from_objective( diff --git a/bofire/strategies/samplers/universal_constraint.py b/bofire/strategies/samplers/universal_constraint.py index f4990688a..fff8d524a 100644 --- a/bofire/strategies/samplers/universal_constraint.py +++ b/bofire/strategies/samplers/universal_constraint.py @@ -37,9 +37,7 @@ def _ask(self, candidate_count: int) -> pd.DataFrame: fixed_experiments=self.candidates, ) - samples = samples.iloc[ - self.num_candidates :, - ] + samples = samples.iloc[self.num_candidates :,] samples = samples.sample( n=candidate_count, replace=False, diff --git a/bofire/surrogates/mlp.py b/bofire/surrogates/mlp.py index 4839f48b6..0d8e54c47 100644 --- a/bofire/surrogates/mlp.py +++ b/bofire/surrogates/mlp.py @@ -136,7 +136,9 @@ def fit_mlp( lr: float = 1e-4, shuffle: bool = True, weight_decay: float = 0.0, - loss_function=nn.L1Loss, + loss_function: Union[ + nn.modules.loss.L1Loss, nn.modules.loss.CrossEntropyLoss + ] = nn.L1Loss, ): """Fit a MLP to a dataset. diff --git a/tests/bofire/data_models/specs/features.py b/tests/bofire/data_models/specs/features.py index ce920ac7a..478780ff5 100644 --- a/tests/bofire/data_models/specs/features.py +++ b/tests/bofire/data_models/specs/features.py @@ -128,7 +128,7 @@ "categories": ["a", "b", "c"], "objective": ConstrainedCategoricalObjective( categories=["a", "b", "c"], desirability=[True, True, False] - ), + ).model_dump(), }, ) specs.add_valid( diff --git a/tests/bofire/strategies/doe/test_objective.py b/tests/bofire/strategies/doe/test_objective.py index 4a1761509..043a87782 100644 --- a/tests/bofire/strategies/doe/test_objective.py +++ b/tests/bofire/strategies/doe/test_objective.py @@ -68,9 +68,9 @@ def test_Objective_model_jacobian_t(): "x1", "x2", "x3", - "x1**2", - "x2**2", - "x3**2", + "x1 ** 2", + "x2 ** 2", + "x3 ** 2", "x1:x2", "x1:x3", "x2:x3", @@ -323,20 +323,20 @@ def test_Objective_model_jacobian_t(): columns=[ "1", "x1", - "x1**2", - "x1**3", + "x1 ** 2", + "x1 ** 3", "x2", - "x2**2", - "x2**3", + "x2 ** 2", + "x2 ** 3", "x3", - "x3**2", - "x3**3", + "x3 ** 2", + "x3 ** 3", "x4", - "x4**2", - "x4**3", + "x4 ** 2", + "x4 ** 3", "x5", - "x5**2", - "x5**3", + "x5 ** 2", + "x5 ** 3", "x2:x1", "x3:x1", "x4:x1", @@ -419,7 +419,7 @@ def test_DOptimality_instantiation(): assert isinstance(d_optimality.model, Formula) assert all( np.array(d_optimality.model, dtype=str) - == np.array(["1", "x1", "x2", "x3", "x3**2", "x1:x2"]) + == np.array(["1", "x1", "x2", "x3", "x3 ** 2", "x1:x2"]) ) x = np.array([[1, 2, 3], [1, 2, 3]]) diff --git a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb index 48d1f6ae6..61c0c22e2 100644 --- a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb +++ b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb @@ -102,243 +102,243 @@ " \n", " \n", " 0\n", - " 0.296884\n", - " 1.339136\n", - " -0.699016\n", - " -1.483660\n", - " -1.619869\n", + " -1.147520\n", + " 0.352595\n", + " 1.233487\n", + " 0.764528\n", + " 0.728527\n", " 0\n", - " 7.435109\n", - " ideal\n", - " 0.302020\n", + " 2.600000\n", + " acceptable\n", + " -1.140822\n", " \n", " \n", " 1\n", - " -1.198177\n", - " -1.871994\n", - " -0.096251\n", - " 1.753299\n", - " -1.822007\n", + " -0.541992\n", + " 1.030834\n", + " 0.544503\n", + " -0.451754\n", + " -1.379001\n", " 0\n", - " 13.140989\n", - " unacceptable\n", - " -1.188586\n", + " 3.722210\n", + " acceptable\n", + " -0.533687\n", " \n", " \n", " 2\n", - " -0.197863\n", - " 0.134761\n", - " -1.559977\n", - " -1.984249\n", - " 0.769223\n", - " 0\n", - " 10.049967\n", + " 1.620988\n", + " 1.569770\n", + " 0.510210\n", + " 0.478663\n", + " 0.577311\n", + " 1\n", + " 0.891793\n", " ideal\n", - " -0.189195\n", + " 1.626114\n", " \n", " \n", " 3\n", - " 0.423355\n", - " -0.333732\n", - " -0.992391\n", - " 0.021709\n", - " 1.187608\n", + " 0.603489\n", + " 1.560799\n", + " 1.147270\n", + " 0.860127\n", + " -0.984688\n", " 0\n", - " 1.452937\n", - " unacceptable\n", - " 0.424980\n", + " 1.315285\n", + " ideal\n", + " 0.606293\n", " \n", " \n", " 4\n", - " -1.827201\n", - " -1.653218\n", - " -1.266475\n", - " 1.432659\n", - " -1.158900\n", + " 1.115889\n", + " 0.675806\n", + " 1.156867\n", + " 0.238840\n", + " -0.270817\n", " 0\n", - " 12.936642\n", + " 0.259161\n", " unacceptable\n", - " -1.826852\n", + " 1.121470\n", " \n", " \n", " 5\n", - " 0.073122\n", - " -1.789563\n", - " 1.670582\n", - " 1.867550\n", - " 1.759474\n", + " 0.947684\n", + " -1.348468\n", + " 1.595431\n", + " -1.910428\n", + " 0.906305\n", " 0\n", - " 6.432953\n", + " 7.925056\n", " unacceptable\n", - " 0.080055\n", + " 0.956881\n", " \n", " \n", " 6\n", - " 0.847212\n", - " -1.236811\n", - " -0.533595\n", - " -1.842110\n", - " 1.425404\n", + " -1.560800\n", + " 1.897495\n", + " -0.659937\n", + " -0.922824\n", + " 1.988474\n", " 1\n", - " 7.133379\n", - " acceptable\n", - " 0.854993\n", + " 6.471322\n", + " unacceptable\n", + " -1.554441\n", " \n", " \n", " 7\n", - " -0.728056\n", - " 1.114750\n", - " 1.837209\n", - " 0.768175\n", - " 1.179493\n", - " 1\n", - " 1.895683\n", + " 1.765355\n", + " 0.208224\n", + " 0.475951\n", + " -0.412166\n", + " 1.013625\n", + " 0\n", + " 1.024378\n", " acceptable\n", - " -0.724770\n", + " 1.774402\n", " \n", " \n", " 8\n", - " -0.787473\n", - " -1.280905\n", - " 1.964343\n", - " 1.051243\n", - " -1.969924\n", - " 1\n", - " 9.828818\n", + " 0.817979\n", + " -1.398751\n", + " -1.613844\n", + " -1.626274\n", + " 1.718574\n", + " 0\n", + " 10.271230\n", " unacceptable\n", - " -0.778318\n", + " 0.821963\n", " \n", " \n", " 9\n", - " -1.805895\n", - " 1.544210\n", - " -0.926871\n", - " 1.780826\n", - " 1.477554\n", + " 1.624639\n", + " 0.591260\n", + " 1.056660\n", + " 1.964813\n", + " -0.609677\n", " 1\n", - " 6.990552\n", - " unacceptable\n", - " -1.797201\n", + " 1.434253\n", + " acceptable\n", + " 1.633391\n", " \n", " \n", " 10\n", - " -0.353654\n", - " 1.743341\n", - " -1.336186\n", - " -1.700009\n", - " 0.440015\n", - " 1\n", - " 7.701678\n", - " acceptable\n", - " -0.345705\n", + " 0.659178\n", + " -1.488256\n", + " -0.769463\n", + " -0.498631\n", + " 1.641154\n", + " 0\n", + " 4.016106\n", + " ideal\n", + " 0.666231\n", " \n", " \n", " 11\n", - " 1.181500\n", - " 1.070489\n", - " -0.555974\n", - " -0.604853\n", - " -0.795902\n", - " 0\n", - " 1.056973\n", + " 0.843863\n", + " -0.778241\n", + " -0.742882\n", + " 1.041363\n", + " -0.462427\n", + " 1\n", + " 1.118582\n", " acceptable\n", - " 1.189100\n", + " 0.845148\n", " \n", " \n", " 12\n", - " 0.920993\n", - " -0.776714\n", - " -0.928656\n", - " 0.050378\n", - " -0.887893\n", - " 1\n", - " 2.087592\n", + " 1.713291\n", + " -0.327819\n", + " 1.934123\n", + " -1.535219\n", + " -0.263203\n", + " 0\n", + " 4.643351\n", " acceptable\n", - " 0.929757\n", + " 1.715253\n", " \n", " \n", " 13\n", - " -0.599356\n", - " 1.893306\n", - " -1.413963\n", - " 0.700251\n", - " 1.361780\n", - " 0\n", - " 4.407259\n", + " 0.425012\n", + " -0.428886\n", + " 0.678414\n", + " 0.867640\n", + " 1.947507\n", + " 1\n", + " 1.058468\n", " acceptable\n", - " -0.593129\n", + " 0.426804\n", " \n", " \n", " 14\n", - " -1.178618\n", - " -1.807959\n", - " 1.643023\n", - " -1.367653\n", - " -1.300478\n", + " -1.783885\n", + " 0.597387\n", + " 0.018537\n", + " -0.867099\n", + " 1.208851\n", " 0\n", - " 12.031568\n", - " unacceptable\n", - " -1.170357\n", + " 6.139702\n", + " acceptable\n", + " -1.775463\n", " \n", " \n", " 15\n", - " 0.012359\n", - " 0.674245\n", - " -1.080136\n", - " -0.569172\n", - " -1.401835\n", + " 0.840074\n", + " -0.300067\n", + " -1.617072\n", + " -1.733238\n", + " -1.571214\n", " 0\n", - " 4.597730\n", + " 11.558383\n", " acceptable\n", - " 0.022354\n", + " 0.841770\n", " \n", " \n", " 16\n", - " -1.161135\n", - " -1.765019\n", - " 1.158920\n", - " 1.827956\n", - " -1.107484\n", - " 0\n", - " 8.805522\n", - " unacceptable\n", - " -1.157363\n", + " 1.666942\n", + " -0.941869\n", + " 0.904134\n", + " 0.685048\n", + " 1.717750\n", + " 1\n", + " 1.776970\n", + " acceptable\n", + " 1.668979\n", " \n", " \n", " 17\n", - " -1.750851\n", - " 0.165063\n", - " -1.370027\n", - " 0.828754\n", - " -0.325600\n", - " 0\n", - " 7.730277\n", - " acceptable\n", - " -1.742488\n", + " 1.396363\n", + " 1.807155\n", + " 0.594131\n", + " -1.607430\n", + " 1.190373\n", + " 1\n", + " 4.328181\n", + " unacceptable\n", + " 1.402815\n", " \n", " \n", " 18\n", - " -1.695762\n", - " 1.902611\n", - " -1.071329\n", - " -0.927217\n", - " -0.700890\n", - " 1\n", - " 7.765145\n", - " acceptable\n", - " -1.693954\n", + " -0.583546\n", + " 1.973070\n", + " -1.911667\n", + " -1.225064\n", + " -1.323822\n", + " 0\n", + " 11.085039\n", + " unacceptable\n", + " -0.574170\n", " \n", " \n", " 19\n", - " -0.528461\n", - " -1.958166\n", - " 0.104057\n", - " 0.300671\n", - " -1.417810\n", - " 0\n", - " 9.545906\n", - " acceptable\n", - " -0.525143\n", + " -0.133564\n", + " -0.037052\n", + " 1.349117\n", + " 0.108222\n", + " -0.739183\n", + " 1\n", + " 1.385558\n", + " unacceptable\n", + " -0.129039\n", " \n", " \n", "\n", @@ -346,48 +346,48 @@ ], "text/plain": [ " x_0 x_1 x_2 x_3 x_4 x_5 f_0 \\\n", - "0 0.296884 1.339136 -0.699016 -1.483660 -1.619869 0 7.435109 \n", - "1 -1.198177 -1.871994 -0.096251 1.753299 -1.822007 0 13.140989 \n", - "2 -0.197863 0.134761 -1.559977 -1.984249 0.769223 0 10.049967 \n", - "3 0.423355 -0.333732 -0.992391 0.021709 1.187608 0 1.452937 \n", - "4 -1.827201 -1.653218 -1.266475 1.432659 -1.158900 0 12.936642 \n", - "5 0.073122 -1.789563 1.670582 1.867550 1.759474 0 6.432953 \n", - "6 0.847212 -1.236811 -0.533595 -1.842110 1.425404 1 7.133379 \n", - "7 -0.728056 1.114750 1.837209 0.768175 1.179493 1 1.895683 \n", - "8 -0.787473 -1.280905 1.964343 1.051243 -1.969924 1 9.828818 \n", - "9 -1.805895 1.544210 -0.926871 1.780826 1.477554 1 6.990552 \n", - "10 -0.353654 1.743341 -1.336186 -1.700009 0.440015 1 7.701678 \n", - "11 1.181500 1.070489 -0.555974 -0.604853 -0.795902 0 1.056973 \n", - "12 0.920993 -0.776714 -0.928656 0.050378 -0.887893 1 2.087592 \n", - "13 -0.599356 1.893306 -1.413963 0.700251 1.361780 0 4.407259 \n", - "14 -1.178618 -1.807959 1.643023 -1.367653 -1.300478 0 12.031568 \n", - "15 0.012359 0.674245 -1.080136 -0.569172 -1.401835 0 4.597730 \n", - "16 -1.161135 -1.765019 1.158920 1.827956 -1.107484 0 8.805522 \n", - "17 -1.750851 0.165063 -1.370027 0.828754 -0.325600 0 7.730277 \n", - "18 -1.695762 1.902611 -1.071329 -0.927217 -0.700890 1 7.765145 \n", - "19 -0.528461 -1.958166 0.104057 0.300671 -1.417810 0 9.545906 \n", + "0 -1.147520 0.352595 1.233487 0.764528 0.728527 0 2.600000 \n", + "1 -0.541992 1.030834 0.544503 -0.451754 -1.379001 0 3.722210 \n", + "2 1.620988 1.569770 0.510210 0.478663 0.577311 1 0.891793 \n", + "3 0.603489 1.560799 1.147270 0.860127 -0.984688 0 1.315285 \n", + "4 1.115889 0.675806 1.156867 0.238840 -0.270817 0 0.259161 \n", + "5 0.947684 -1.348468 1.595431 -1.910428 0.906305 0 7.925056 \n", + "6 -1.560800 1.897495 -0.659937 -0.922824 1.988474 1 6.471322 \n", + "7 1.765355 0.208224 0.475951 -0.412166 1.013625 0 1.024378 \n", + "8 0.817979 -1.398751 -1.613844 -1.626274 1.718574 0 10.271230 \n", + "9 1.624639 0.591260 1.056660 1.964813 -0.609677 1 1.434253 \n", + "10 0.659178 -1.488256 -0.769463 -0.498631 1.641154 0 4.016106 \n", + "11 0.843863 -0.778241 -0.742882 1.041363 -0.462427 1 1.118582 \n", + "12 1.713291 -0.327819 1.934123 -1.535219 -0.263203 0 4.643351 \n", + "13 0.425012 -0.428886 0.678414 0.867640 1.947507 1 1.058468 \n", + "14 -1.783885 0.597387 0.018537 -0.867099 1.208851 0 6.139702 \n", + "15 0.840074 -0.300067 -1.617072 -1.733238 -1.571214 0 11.558383 \n", + "16 1.666942 -0.941869 0.904134 0.685048 1.717750 1 1.776970 \n", + "17 1.396363 1.807155 0.594131 -1.607430 1.190373 1 4.328181 \n", + "18 -0.583546 1.973070 -1.911667 -1.225064 -1.323822 0 11.085039 \n", + "19 -0.133564 -0.037052 1.349117 0.108222 -0.739183 1 1.385558 \n", "\n", " f_1 f_2 \n", - "0 ideal 0.302020 \n", - "1 unacceptable -1.188586 \n", - "2 ideal -0.189195 \n", - "3 unacceptable 0.424980 \n", - "4 unacceptable -1.826852 \n", - "5 unacceptable 0.080055 \n", - "6 acceptable 0.854993 \n", - "7 acceptable -0.724770 \n", - "8 unacceptable -0.778318 \n", - "9 unacceptable -1.797201 \n", - "10 acceptable -0.345705 \n", - "11 acceptable 1.189100 \n", - "12 acceptable 0.929757 \n", - "13 acceptable -0.593129 \n", - "14 unacceptable -1.170357 \n", - "15 acceptable 0.022354 \n", - "16 unacceptable -1.157363 \n", - "17 acceptable -1.742488 \n", - "18 acceptable -1.693954 \n", - "19 acceptable -0.525143 " + "0 acceptable -1.140822 \n", + "1 acceptable -0.533687 \n", + "2 ideal 1.626114 \n", + "3 ideal 0.606293 \n", + "4 unacceptable 1.121470 \n", + "5 unacceptable 0.956881 \n", + "6 unacceptable -1.554441 \n", + "7 acceptable 1.774402 \n", + "8 unacceptable 0.821963 \n", + "9 acceptable 1.633391 \n", + "10 ideal 0.666231 \n", + "11 acceptable 0.845148 \n", + "12 acceptable 1.715253 \n", + "13 acceptable 0.426804 \n", + "14 acceptable -1.775463 \n", + "15 acceptable 0.841770 \n", + "16 acceptable 1.668979 \n", + "17 unacceptable 1.402815 \n", + "18 unacceptable -0.574170 \n", + "19 unacceptable -0.129039 " ] }, "execution_count": 3, @@ -511,16 +511,16 @@ " \n", " \n", " 0\n", - " 0.795\n", - " 0.795\n", + " 0.77\n", + " 0.77\n", " \n", " \n", "\n", "" ], "text/plain": [ - " ACCURACY F1\n", - "0 0.795 0.795" + " ACCURACY F1\n", + "0 0.77 0.77" ] }, "execution_count": 5, @@ -566,8 +566,8 @@ " \n", " \n", " 0\n", - " 0.52\n", - " 0.52\n", + " 0.46\n", + " 0.46\n", " \n", " \n", "\n", @@ -575,7 +575,7 @@ ], "text/plain": [ " ACCURACY F1\n", - "0 0.52 0.52" + "0 0.46 0.46" ] }, "execution_count": 6, @@ -685,243 +685,243 @@ " \n", " \n", " 0\n", - " 0.481400\n", - " 0.586632\n", - " 0.063636\n", - " 0.087314\n", - " 0.424632\n", - " 1\n", + " 0.098607\n", + " 1.274078\n", + " 2.000000\n", + " 0.749040\n", + " -0.156863\n", + " 0\n", " acceptable\n", " 0.0\n", - " -2.807002\n", - " 0.487184\n", + " -0.598172\n", + " 0.102933\n", " ...\n", - " 0.974527\n", - " 0.007553\n", - " 0.667668\n", - " 0.003238\n", - " 0.017963\n", - " 0.029889\n", - " 0.014268\n", - " 2.807002\n", - " 0.439401\n", - " 0.982080\n", + " 0.999879\n", + " 2.749825e-07\n", + " 0.830174\n", + " 0.003029\n", + " 0.000270\n", + " 0.000271\n", + " 6.091118e-07\n", + " 0.598172\n", + " 0.487136\n", + " 0.999879\n", " \n", " \n", " 1\n", - " 0.205385\n", - " 0.435466\n", - " 0.386084\n", - " 0.820992\n", - " 0.562380\n", + " 0.646990\n", + " 0.604150\n", + " -0.341737\n", + " 1.307168\n", + " -0.151535\n", " 0\n", - " unacceptable\n", + " acceptable\n", " 0.0\n", - " -2.478325\n", - " 0.210586\n", + " -0.465931\n", + " 0.650499\n", " ...\n", - " 0.408227\n", - " 0.001115\n", - " 0.736828\n", - " 0.003263\n", - " 0.539504\n", - " 0.540456\n", - " 0.002281\n", - " 2.478325\n", - " 0.473701\n", - " 0.409342\n", + " 0.979470\n", + " 1.731224e-02\n", + " 0.241746\n", + " 0.002957\n", + " 0.007131\n", + " 0.045842\n", + " 3.871093e-02\n", + " 0.465931\n", + " 0.419397\n", + " 0.996782\n", " \n", " \n", " 2\n", - " 0.578421\n", - " 0.790867\n", - " -0.054982\n", - " 0.602806\n", - " 0.624962\n", - " 1\n", + " -0.268094\n", + " 2.000000\n", + " 2.000000\n", + " 0.417189\n", + " -0.160303\n", + " 0\n", " acceptable\n", " 0.0\n", - " -2.495759\n", - " 0.584149\n", + " -0.095701\n", + " -0.263037\n", " ...\n", - " 0.965720\n", - " 0.007595\n", - " 0.851270\n", - " 0.003274\n", - " 0.024524\n", - " 0.033772\n", - " 0.014316\n", - " 2.495759\n", - " 0.427496\n", - " 0.973315\n", + " 0.999936\n", + " 9.786710e-08\n", + " 1.129313\n", + " 0.003170\n", + " 0.000139\n", + " 0.000139\n", + " 2.182845e-07\n", + " 0.095701\n", + " 0.532832\n", + " 0.999937\n", " \n", " \n", " 3\n", - " 0.225033\n", - " 0.648267\n", - " 0.115931\n", - " 0.153590\n", - " 0.527780\n", + " -0.407627\n", + " 2.000000\n", + " 2.000000\n", + " 2.000000\n", + " -0.147082\n", " 1\n", " acceptable\n", " 0.0\n", - " -2.546250\n", - " 0.230460\n", + " 1.523601\n", + " -0.403380\n", " ...\n", - " 0.860037\n", - " 0.012140\n", - " 0.717848\n", - " 0.003234\n", - " 0.170332\n", - " 0.177282\n", - " 0.021604\n", - " 2.546250\n", - " 0.471224\n", - " 0.872177\n", + " 0.808412\n", + " 1.612794e-07\n", + " 1.637777\n", + " 0.003359\n", + " 0.416360\n", + " 0.416360\n", + " 2.241354e-07\n", + " -1.523601\n", + " 0.550252\n", + " 0.808412\n", " \n", " \n", " 4\n", - " 0.342310\n", - " 0.498711\n", - " 0.047750\n", - " 0.008932\n", - " 0.668452\n", - " 1\n", + " 0.928282\n", + " 0.986890\n", + " -0.608701\n", + " 1.760348\n", + " -0.144849\n", + " 0\n", " acceptable\n", " 0.0\n", - " -2.549582\n", - " 0.347790\n", + " -0.305604\n", + " 0.931805\n", " ...\n", - " 0.932932\n", - " 0.008269\n", - " 0.655287\n", - " 0.003230\n", - " 0.099837\n", - " 0.101100\n", - " 0.014011\n", - " 2.549582\n", - " 0.456635\n", - " 0.941201\n", + " 0.967507\n", + " 2.830622e-02\n", + " 0.478486\n", + " 0.003040\n", + " 0.009351\n", + " 0.072646\n", + " 6.329445e-02\n", + " 0.305604\n", + " 0.385587\n", + " 0.995813\n", " \n", " \n", " 5\n", - " 0.380575\n", - " 1.040915\n", - " 0.214337\n", - " 0.069153\n", - " 0.753324\n", - " 1\n", + " -0.171500\n", + " 1.702729\n", + " 2.000000\n", + " 1.147333\n", + " -0.205059\n", + " 0\n", " acceptable\n", " 0.0\n", - " -2.368879\n", - " 0.386290\n", + " -0.176775\n", + " -0.167058\n", " ...\n", - " 0.983611\n", - " 0.005720\n", - " 0.795297\n", - " 0.003277\n", - " 0.013258\n", - " 0.023035\n", - " 0.012126\n", - " 2.368879\n", - " 0.451863\n", - " 0.989332\n", + " 0.999999\n", + " 8.229916e-09\n", + " 1.045623\n", + " 0.003128\n", + " 0.000003\n", + " 0.000003\n", + " 1.748655e-08\n", + " 0.176775\n", + " 0.520870\n", + " 0.999999\n", " \n", " \n", " 6\n", - " 0.080792\n", - " 0.260320\n", - " 0.331997\n", - " 0.669731\n", - " 0.425205\n", - " 1\n", - " unacceptable\n", + " 0.348453\n", + " 0.749030\n", + " 2.000000\n", + " 0.805095\n", + " -0.119394\n", + " 0\n", + " acceptable\n", " 0.0\n", - " -2.326501\n", - " 0.085955\n", + " -0.598377\n", + " 0.352705\n", " ...\n", - " 0.210872\n", - " 0.001047\n", - " 0.763624\n", - " 0.003236\n", - " 0.440892\n", - " 0.441507\n", - " 0.001058\n", - " 2.326501\n", - " 0.489257\n", - " 0.211920\n", + " 0.999376\n", + " 1.184046e-06\n", + " 0.710733\n", + " 0.002993\n", + " 0.001392\n", + " 0.001395\n", + " 2.614105e-06\n", + " 0.598377\n", + " 0.456026\n", + " 0.999377\n", " \n", " \n", " 7\n", - " 0.243598\n", - " 0.698221\n", - " -0.014641\n", - " 0.092352\n", - " 0.708394\n", - " 1\n", + " -2.000000\n", + " 0.457609\n", + " 1.675461\n", + " 1.207754\n", + " 0.243209\n", + " 0\n", " acceptable\n", " 0.0\n", - " -2.445781\n", - " 0.248967\n", + " 4.752548\n", + " -1.994210\n", " ...\n", - " 0.943763\n", - " 0.006818\n", - " 0.699401\n", - " 0.003234\n", - " 0.090640\n", - " 0.091162\n", - " 0.012202\n", - " 2.445781\n", - " 0.468919\n", - " 0.950581\n", + " 0.949335\n", + " 1.139145e-04\n", + " 0.807708\n", + " 0.005353\n", + " 0.113034\n", + " 0.113289\n", + " 2.547206e-04\n", + " -4.752548\n", + " 0.730489\n", + " 0.949449\n", " \n", " \n", " 8\n", - " 0.182756\n", - " -0.921039\n", - " -0.549821\n", - " 0.689784\n", - " -0.655921\n", - " 1\n", - " unacceptable\n", + " 0.126459\n", + " 1.752512\n", + " 2.000000\n", + " 0.269535\n", + " -0.025944\n", + " 0\n", + " acceptable\n", " 0.0\n", - " 1.978269\n", - " 0.187628\n", + " -0.373769\n", + " 0.130896\n", " ...\n", - " 0.223373\n", - " 0.004275\n", - " 0.536589\n", - " 0.003255\n", - " 0.389221\n", - " 0.391444\n", - " 0.007985\n", - " -1.978269\n", - " 0.476564\n", - " 0.227648\n", + " 0.998390\n", + " 1.199743e-06\n", + " 1.007492\n", + " 0.003081\n", + " 0.003593\n", + " 0.003596\n", + " 2.680733e-06\n", + " 0.373769\n", + " 0.483644\n", + " 0.998391\n", " \n", " \n", " 9\n", - " -0.297385\n", - " -0.422154\n", - " 1.407541\n", - " 0.605559\n", - " 1.597342\n", - " 1\n", - " unacceptable\n", + " 0.106551\n", + " 1.362876\n", + " 2.000000\n", + " 1.558184\n", + " -0.132843\n", + " 0\n", + " acceptable\n", " 0.0\n", - " 1.351191\n", - " -0.292334\n", + " -0.100465\n", + " 0.110338\n", " ...\n", - " 0.198005\n", - " 0.012080\n", - " 0.576409\n", - " 0.003309\n", - " 0.436782\n", - " 0.442695\n", - " 0.026803\n", - " -1.351191\n", - " 0.536477\n", - " 0.210085\n", + " 0.999999\n", + " 6.511140e-09\n", + " 1.039770\n", + " 0.003075\n", + " 0.000001\n", + " 0.000001\n", + " 1.175186e-08\n", + " 0.100465\n", + " 0.486211\n", + " 0.999999\n", " \n", " \n", "\n", @@ -929,53 +929,53 @@ "" ], "text/plain": [ - " x_0 x_1 x_2 x_3 x_4 x_5 f_1_pred f_1_sd \\\n", - "0 0.481400 0.586632 0.063636 0.087314 0.424632 1 acceptable 0.0 \n", - "1 0.205385 0.435466 0.386084 0.820992 0.562380 0 unacceptable 0.0 \n", - "2 0.578421 0.790867 -0.054982 0.602806 0.624962 1 acceptable 0.0 \n", - "3 0.225033 0.648267 0.115931 0.153590 0.527780 1 acceptable 0.0 \n", - "4 0.342310 0.498711 0.047750 0.008932 0.668452 1 acceptable 0.0 \n", - "5 0.380575 1.040915 0.214337 0.069153 0.753324 1 acceptable 0.0 \n", - "6 0.080792 0.260320 0.331997 0.669731 0.425205 1 unacceptable 0.0 \n", - "7 0.243598 0.698221 -0.014641 0.092352 0.708394 1 acceptable 0.0 \n", - "8 0.182756 -0.921039 -0.549821 0.689784 -0.655921 1 unacceptable 0.0 \n", - "9 -0.297385 -0.422154 1.407541 0.605559 1.597342 1 unacceptable 0.0 \n", + " x_0 x_1 x_2 x_3 x_4 x_5 f_1_pred f_1_sd \\\n", + "0 0.098607 1.274078 2.000000 0.749040 -0.156863 0 acceptable 0.0 \n", + "1 0.646990 0.604150 -0.341737 1.307168 -0.151535 0 acceptable 0.0 \n", + "2 -0.268094 2.000000 2.000000 0.417189 -0.160303 0 acceptable 0.0 \n", + "3 -0.407627 2.000000 2.000000 2.000000 -0.147082 1 acceptable 0.0 \n", + "4 0.928282 0.986890 -0.608701 1.760348 -0.144849 0 acceptable 0.0 \n", + "5 -0.171500 1.702729 2.000000 1.147333 -0.205059 0 acceptable 0.0 \n", + "6 0.348453 0.749030 2.000000 0.805095 -0.119394 0 acceptable 0.0 \n", + "7 -2.000000 0.457609 1.675461 1.207754 0.243209 0 acceptable 0.0 \n", + "8 0.126459 1.752512 2.000000 0.269535 -0.025944 0 acceptable 0.0 \n", + "9 0.106551 1.362876 2.000000 1.558184 -0.132843 0 acceptable 0.0 \n", "\n", " f_0_pred f_2_pred ... f_1_acceptable_prob f_1_ideal_prob f_0_sd \\\n", - "0 -2.807002 0.487184 ... 0.974527 0.007553 0.667668 \n", - "1 -2.478325 0.210586 ... 0.408227 0.001115 0.736828 \n", - "2 -2.495759 0.584149 ... 0.965720 0.007595 0.851270 \n", - "3 -2.546250 0.230460 ... 0.860037 0.012140 0.717848 \n", - "4 -2.549582 0.347790 ... 0.932932 0.008269 0.655287 \n", - "5 -2.368879 0.386290 ... 0.983611 0.005720 0.795297 \n", - "6 -2.326501 0.085955 ... 0.210872 0.001047 0.763624 \n", - "7 -2.445781 0.248967 ... 0.943763 0.006818 0.699401 \n", - "8 1.978269 0.187628 ... 0.223373 0.004275 0.536589 \n", - "9 1.351191 -0.292334 ... 0.198005 0.012080 0.576409 \n", + "0 -0.598172 0.102933 ... 0.999879 2.749825e-07 0.830174 \n", + "1 -0.465931 0.650499 ... 0.979470 1.731224e-02 0.241746 \n", + "2 -0.095701 -0.263037 ... 0.999936 9.786710e-08 1.129313 \n", + "3 1.523601 -0.403380 ... 0.808412 1.612794e-07 1.637777 \n", + "4 -0.305604 0.931805 ... 0.967507 2.830622e-02 0.478486 \n", + "5 -0.176775 -0.167058 ... 0.999999 8.229916e-09 1.045623 \n", + "6 -0.598377 0.352705 ... 0.999376 1.184046e-06 0.710733 \n", + "7 4.752548 -1.994210 ... 0.949335 1.139145e-04 0.807708 \n", + "8 -0.373769 0.130896 ... 0.998390 1.199743e-06 1.007492 \n", + "9 -0.100465 0.110338 ... 0.999999 6.511140e-09 1.039770 \n", "\n", " f_2_sd f_1_unacceptable_sd f_1_acceptable_sd f_1_ideal_sd f_0_des \\\n", - "0 0.003238 0.017963 0.029889 0.014268 2.807002 \n", - "1 0.003263 0.539504 0.540456 0.002281 2.478325 \n", - "2 0.003274 0.024524 0.033772 0.014316 2.495759 \n", - "3 0.003234 0.170332 0.177282 0.021604 2.546250 \n", - "4 0.003230 0.099837 0.101100 0.014011 2.549582 \n", - "5 0.003277 0.013258 0.023035 0.012126 2.368879 \n", - "6 0.003236 0.440892 0.441507 0.001058 2.326501 \n", - "7 0.003234 0.090640 0.091162 0.012202 2.445781 \n", - "8 0.003255 0.389221 0.391444 0.007985 -1.978269 \n", - "9 0.003309 0.436782 0.442695 0.026803 -1.351191 \n", + "0 0.003029 0.000270 0.000271 6.091118e-07 0.598172 \n", + "1 0.002957 0.007131 0.045842 3.871093e-02 0.465931 \n", + "2 0.003170 0.000139 0.000139 2.182845e-07 0.095701 \n", + "3 0.003359 0.416360 0.416360 2.241354e-07 -1.523601 \n", + "4 0.003040 0.009351 0.072646 6.329445e-02 0.305604 \n", + "5 0.003128 0.000003 0.000003 1.748655e-08 0.176775 \n", + "6 0.002993 0.001392 0.001395 2.614105e-06 0.598377 \n", + "7 0.005353 0.113034 0.113289 2.547206e-04 -4.752548 \n", + "8 0.003081 0.003593 0.003596 2.680733e-06 0.373769 \n", + "9 0.003075 0.000001 0.000001 1.175186e-08 0.100465 \n", "\n", " f_2_des f_1_des \n", - "0 0.439401 0.982080 \n", - "1 0.473701 0.409342 \n", - "2 0.427496 0.973315 \n", - "3 0.471224 0.872177 \n", - "4 0.456635 0.941201 \n", - "5 0.451863 0.989332 \n", - "6 0.489257 0.211920 \n", - "7 0.468919 0.950581 \n", - "8 0.476564 0.227648 \n", - "9 0.536477 0.210085 \n", + "0 0.487136 0.999879 \n", + "1 0.419397 0.996782 \n", + "2 0.532832 0.999937 \n", + "3 0.550252 0.808412 \n", + "4 0.385587 0.995813 \n", + "5 0.520870 0.999999 \n", + "6 0.456026 0.999377 \n", + "7 0.730489 0.949449 \n", + "8 0.483644 0.998391 \n", + "9 0.486211 0.999999 \n", "\n", "[10 rows x 21 columns]" ] @@ -1045,17 +1045,17 @@ " \n", " 0\n", " acceptable\n", - " unacceptable\n", + " acceptable\n", " \n", " \n", " 1\n", - " unacceptable\n", + " acceptable\n", " unacceptable\n", " \n", " \n", " 2\n", " acceptable\n", - " unacceptable\n", + " ideal\n", " \n", " \n", " 3\n", @@ -1065,49 +1065,49 @@ " \n", " 4\n", " acceptable\n", - " unacceptable\n", + " acceptable\n", " \n", " \n", " 5\n", " acceptable\n", - " unacceptable\n", + " ideal\n", " \n", " \n", " 6\n", - " unacceptable\n", - " unacceptable\n", + " acceptable\n", + " acceptable\n", " \n", " \n", " 7\n", " acceptable\n", - " unacceptable\n", + " acceptable\n", " \n", " \n", " 8\n", - " unacceptable\n", - " unacceptable\n", + " acceptable\n", + " acceptable\n", " \n", " \n", " 9\n", - " unacceptable\n", " acceptable\n", + " ideal\n", " \n", " \n", "\n", "" ], "text/plain": [ - " f_1_pred f_1_true\n", - "0 acceptable unacceptable\n", - "1 unacceptable unacceptable\n", - "2 acceptable unacceptable\n", - "3 acceptable unacceptable\n", - "4 acceptable unacceptable\n", - "5 acceptable unacceptable\n", - "6 unacceptable unacceptable\n", - "7 acceptable unacceptable\n", - "8 unacceptable unacceptable\n", - "9 unacceptable acceptable" + " f_1_pred f_1_true\n", + "0 acceptable acceptable\n", + "1 acceptable unacceptable\n", + "2 acceptable ideal\n", + "3 acceptable unacceptable\n", + "4 acceptable acceptable\n", + "5 acceptable ideal\n", + "6 acceptable acceptable\n", + "7 acceptable acceptable\n", + "8 acceptable acceptable\n", + "9 acceptable ideal" ] }, "execution_count": 10, From 2021d73fca51e7c1ad0e65e1492c6bb7ceeae3ba Mon Sep 17 00:00:00 2001 From: gmancino Date: Tue, 30 Jan 2024 11:18:35 -0500 Subject: [PATCH 14/31] Update Tanimoto GP --- bofire/data_models/surrogates/mixed_tanimoto_gp.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/bofire/data_models/surrogates/mixed_tanimoto_gp.py b/bofire/data_models/surrogates/mixed_tanimoto_gp.py index 57a3fbe23..445cafddc 100644 --- a/bofire/data_models/surrogates/mixed_tanimoto_gp.py +++ b/bofire/data_models/surrogates/mixed_tanimoto_gp.py @@ -1,8 +1,9 @@ -from typing import Literal +from typing import Literal, Type from pydantic import Field, validator # from bofire.data_models.enum import MolecularEncodingEnum +from bofire.data_models.features.api import AnyOutput, ContinuousOutput from bofire.data_models.kernels.api import ( AnyCategoricalKernal, AnyContinuousKernel, @@ -45,6 +46,16 @@ class MixedTanimotoGPSurrogate(TrainableBotorchSurrogate): scaler: ScalerEnum = ScalerEnum.NORMALIZE noise_prior: AnyPrior = Field(default_factory=lambda: BOTORCH_NOISE_PRIOR()) + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models + Args: + my_type: continuous or categorical output + Returns: + bool: True if the output type is valid for the surrogate chosen, False otherwise + """ + return isinstance(my_type, ContinuousOutput) + @validator("input_preprocessing_specs") def validate_moleculars(cls, v, values): """Checks that at least one of fingerprints, fragments, or fingerprintsfragments features are present.""" From 20ec136ffdf8bf662c9f332aac2f25bdfbf85116 Mon Sep 17 00:00:00 2001 From: gmancino Date: Tue, 30 Jan 2024 11:25:13 -0500 Subject: [PATCH 15/31] Fix black version --- bofire/benchmarks/single.py | 50 +++++++++++-------- bofire/data_models/domain/features.py | 16 +++--- bofire/data_models/features/molecular.py | 8 +-- .../base_fingerprint_kernel.py | 1 + .../fingerprint_kernels/tanimoto_kernel.py | 2 + bofire/kernels/mapper.py | 40 +++++++++------ bofire/runners/hyperoptimize.py | 8 +-- bofire/strategies/predictives/botorch.py | 16 +++--- bofire/strategies/predictives/mobo.py | 8 +-- bofire/strategies/predictives/predictive.py | 24 +++++---- bofire/strategies/predictives/sobo.py | 24 +++++---- bofire/strategies/samplers/polytope.py | 6 +-- .../samplers/universal_constraint.py | 4 +- bofire/surrogates/diagnostics.py | 8 +-- bofire/surrogates/fully_bayesian.py | 8 +-- bofire/surrogates/mixed_single_task_gp.py | 8 +-- bofire/surrogates/single_task_gp.py | 8 +-- tests/bofire/benchmarks/test_single.py | 2 +- .../domain/test_domain_validators.py | 8 +-- tests/bofire/strategies/test_qehvi.py | 8 +-- tests/bofire/utils/test_torch_tools.py | 22 ++++---- 21 files changed, 170 insertions(+), 109 deletions(-) diff --git a/bofire/benchmarks/single.py b/bofire/benchmarks/single.py index def5ea3f7..c5d32fb4e 100644 --- a/bofire/benchmarks/single.py +++ b/bofire/benchmarks/single.py @@ -175,18 +175,20 @@ def __init__(self, dim: int = 6, allowed_k: Optional[int] = None, **kwargs) -> N outputs=Outputs( features=[ContinuousOutput(key="y", objective=MinimizeObjective())] ), - constraints=Constraints( - constraints=[ - NChooseKConstraint( - features=[f"x_{i}" for i in range(dim)], - min_count=0, - max_count=allowed_k, - none_also_valid=True, - ) - ] - ) - if allowed_k - else Constraints(), + constraints=( + Constraints( + constraints=[ + NChooseKConstraint( + features=[f"x_{i}" for i in range(dim)], + min_count=0, + max_count=allowed_k, + none_also_valid=True, + ) + ] + ) + if allowed_k + else Constraints() + ), ) self._hartmann = botorch_hartmann(dim=dim) @@ -227,21 +229,25 @@ def __init__(self, locality_factor: Optional[float] = None, **kwargs) -> None: key="x_1", bounds=(-5.0, 10), local_relative_bounds=( - 0.5 * locality_factor, - 0.5 * locality_factor, - ) - if locality_factor is not None - else (math.inf, math.inf), + ( + 0.5 * locality_factor, + 0.5 * locality_factor, + ) + if locality_factor is not None + else (math.inf, math.inf) + ), ), ContinuousInput( key="x_2", bounds=(0.0, 15.0), local_relative_bounds=( - 1.5 * locality_factor, - 1.5 * locality_factor, - ) - if locality_factor is not None - else (math.inf, math.inf), + ( + 1.5 * locality_factor, + 1.5 * locality_factor, + ) + if locality_factor is not None + else (math.inf, math.inf) + ), ), ] ), diff --git a/bofire/data_models/domain/features.py b/bofire/data_models/domain/features.py index a0af995af..e48e65596 100644 --- a/bofire/data_models/domain/features.py +++ b/bofire/data_models/domain/features.py @@ -570,9 +570,11 @@ def get_bounds( lo, up = feat.get_bounds( transform_type=specs.get(feat.key), # type: ignore values=experiments[feat.key] if experiments is not None else None, - reference_value=reference_experiment[feat.key] # type: ignore - if reference_experiment is not None - else None, + reference_value=( + reference_experiment[feat.key] # type: ignore + if reference_experiment is not None + else None + ), ) lower += lo upper += up @@ -697,9 +699,11 @@ def __call__( and not isinstance(feat, CategoricalOutput) ] + [ - pd.Series(data=feat(experiments.filter(regex=f"{feat.key}(.*)_prob")), name=f"{feat.key}_pred") # type: ignore - if predictions - else experiments[feat.key] + ( + pd.Series(data=feat(experiments.filter(regex=f"{feat.key}(.*)_prob")), name=f"{feat.key}_pred") # type: ignore + if predictions + else experiments[feat.key] + ) for feat in self.features if feat.objective is not None and isinstance(feat, CategoricalOutput) ], diff --git a/bofire/data_models/features/molecular.py b/bofire/data_models/features/molecular.py index 3146c2b3c..b47a7cab8 100644 --- a/bofire/data_models/features/molecular.py +++ b/bofire/data_models/features/molecular.py @@ -163,9 +163,11 @@ def get_bounds( # else we return the complete bounds data = self.to_descriptor_encoding( transform_type=transform_type, - values=pd.Series(self.get_allowed_categories()) - if values is None - else pd.Series(self.categories), + values=( + pd.Series(self.get_allowed_categories()) + if values is None + else pd.Series(self.categories) + ), ) lower = data.min(axis=0).values.tolist() upper = data.max(axis=0).values.tolist() diff --git a/bofire/kernels/fingerprint_kernels/base_fingerprint_kernel.py b/bofire/kernels/fingerprint_kernels/base_fingerprint_kernel.py index 5cfee3f9e..69d1ce32e 100644 --- a/bofire/kernels/fingerprint_kernels/base_fingerprint_kernel.py +++ b/bofire/kernels/fingerprint_kernels/base_fingerprint_kernel.py @@ -2,6 +2,7 @@ Module for test_kernels that operate on fingerprint representations (bit vectors or count vectors). Author: Ryan-Rhys Griffiths and Austin Tripp 2022 """ + # This code was copied from GAUCHE: https://github.com/leojklarner/gauche/blob/main/gauche/kernels/fingerprint_kernels/base_fingerprint_kernel.py import torch diff --git a/bofire/kernels/fingerprint_kernels/tanimoto_kernel.py b/bofire/kernels/fingerprint_kernels/tanimoto_kernel.py index bca339bd3..f5fceab3e 100644 --- a/bofire/kernels/fingerprint_kernels/tanimoto_kernel.py +++ b/bofire/kernels/fingerprint_kernels/tanimoto_kernel.py @@ -2,6 +2,7 @@ Tanimoto Kernel. Operates on representations including bit vectors e.g. Morgan/ECFP6 fingerprints count vectors e.g. RDKit fragment features. """ + # This code was copied from GAUCHE: https://github.com/leojklarner/gauche/blob/main/gauche/kernels/fingerprint_kernels/tanimoto_kernel.py import torch @@ -36,6 +37,7 @@ class TanimotoKernel(BitKernel): >>> covar_module = gpytorch.kernels.ScaleKernel(TanimotoKernel()) >>> covar = covar_module(batch_x) # Output: LazyTensor of size (2 x 10 x 10) """ + is_stationary = False has_lengthscale = False diff --git a/bofire/kernels/mapper.py b/bofire/kernels/mapper.py index 87e6b7999..de9e225db 100644 --- a/bofire/kernels/mapper.py +++ b/bofire/kernels/mapper.py @@ -20,9 +20,11 @@ def map_RBFKernel( batch_shape=batch_shape, ard_num_dims=len(active_dims) if data_model.ard else None, active_dims=active_dims, # type: ignore - lengthscale_prior=priors.map(data_model.lengthscale_prior) - if data_model.lengthscale_prior is not None - else None, + lengthscale_prior=( + priors.map(data_model.lengthscale_prior) + if data_model.lengthscale_prior is not None + else None + ), ) @@ -37,9 +39,11 @@ def map_MaternKernel( ard_num_dims=len(active_dims) if data_model.ard else None, active_dims=active_dims, nu=data_model.nu, - lengthscale_prior=priors.map(data_model.lengthscale_prior) - if data_model.lengthscale_prior is not None - else None, + lengthscale_prior=( + priors.map(data_model.lengthscale_prior) + if data_model.lengthscale_prior is not None + else None + ), ) @@ -52,9 +56,11 @@ def map_LinearKernel( return gpytorch.kernels.LinearKernel( batch_shape=batch_shape, active_dims=active_dims, - variance_prior=priors.map(data_model.variance_prior) - if data_model.variance_prior is not None - else None, + variance_prior=( + priors.map(data_model.variance_prior) + if data_model.variance_prior is not None + else None + ), ) @@ -68,9 +74,11 @@ def map_PolynomialKernel( batch_shape=batch_shape, active_dims=active_dims, power=data_model.power, - offset_prior=priors.map(data_model.offset_prior) - if data_model.offset_prior is not None - else None, + offset_prior=( + priors.map(data_model.offset_prior) + if data_model.offset_prior is not None + else None + ), ) @@ -125,9 +133,11 @@ def map_ScaleKernel( ard_num_dims=ard_num_dims, active_dims=active_dims, ), - outputscale_prior=priors.map(data_model.outputscale_prior) - if data_model.outputscale_prior is not None - else None, + outputscale_prior=( + priors.map(data_model.outputscale_prior) + if data_model.outputscale_prior is not None + else None + ), ) diff --git a/bofire/runners/hyperoptimize.py b/bofire/runners/hyperoptimize.py index f17b610b2..1e66324a9 100644 --- a/bofire/runners/hyperoptimize.py +++ b/bofire/runners/hyperoptimize.py @@ -78,9 +78,11 @@ def sample(domain): # analyze the results and get the best experiments = experiments.sort_values( by=benchmark.target_metric.name, - ascending=True - if isinstance(benchmark.domain.outputs[0].objective, MinimizeObjective) - else False, + ascending=( + True + if isinstance(benchmark.domain.outputs[0].objective, MinimizeObjective) + else False + ), ) surrogate_data.update_hyperparameters(experiments.iloc[0]) diff --git a/bofire/strategies/predictives/botorch.py b/bofire/strategies/predictives/botorch.py index 8bcfd0b26..706296da5 100644 --- a/bofire/strategies/predictives/botorch.py +++ b/bofire/strategies/predictives/botorch.py @@ -136,13 +136,15 @@ def _fit(self, experiments: pd.DataFrame): from bofire.runners.hyperoptimize import hyperoptimize self.surrogate_specs.surrogates = [ # type: ignore - hyperoptimize( - surrogate_data=surrogate_data, # type: ignore - training_data=experiments, - folds=self.folds, - )[0] - if isinstance(surrogate_data, get_args(AnyTrainableSurrogate)) - else surrogate_data + ( + hyperoptimize( + surrogate_data=surrogate_data, # type: ignore + training_data=experiments, + folds=self.folds, + )[0] + if isinstance(surrogate_data, get_args(AnyTrainableSurrogate)) + else surrogate_data + ) for surrogate_data in self.surrogate_specs.surrogates # type: ignore ] diff --git a/bofire/strategies/predictives/mobo.py b/bofire/strategies/predictives/mobo.py index a7d1a547b..839ddb234 100644 --- a/bofire/strategies/predictives/mobo.py +++ b/bofire/strategies/predictives/mobo.py @@ -77,9 +77,11 @@ def _get_acqfs(self, n) -> List[AcquisitionFunction]: mc_samples=self.num_sobol_samples, cache_root=True if isinstance(self.model, GPyTorchModel) else False, alpha=self.acquisition_function.alpha, - prune_baseline=self.acquisition_function.prune_baseline - if isinstance(self.acquisition_function, (qLogNEHVI, qNEHVI)) - else True, + prune_baseline=( + self.acquisition_function.prune_baseline + if isinstance(self.acquisition_function, (qLogNEHVI, qNEHVI)) + else True + ), Y=Y, ) return [acqf] diff --git a/bofire/strategies/predictives/predictive.py b/bofire/strategies/predictives/predictive.py index cb7e39d75..3fb74c190 100644 --- a/bofire/strategies/predictives/predictive.py +++ b/bofire/strategies/predictives/predictive.py @@ -106,9 +106,11 @@ def predict(self, experiments: pd.DataFrame) -> pd.DataFrame: column_names = list( itertools.chain( *[ - [f"{feat.key}_pred"] - if not isinstance(feat, CategoricalOutput) - else [f"{feat.key}_{cat}_prob" for cat in feat.categories] + ( + [f"{feat.key}_pred"] + if not isinstance(feat, CategoricalOutput) + else [f"{feat.key}_{cat}_prob" for cat in feat.categories] + ) for feat in self.domain.outputs.get() ] ) @@ -120,9 +122,11 @@ def predict(self, experiments: pd.DataFrame) -> pd.DataFrame: + list( itertools.chain( *[ - [f"{feat.key}_sd"] - if not isinstance(feat, CategoricalOutput) - else [f"{feat.key}_{cat}_sd" for cat in feat.categories] + ( + [f"{feat.key}_sd"] + if not isinstance(feat, CategoricalOutput) + else [f"{feat.key}_{cat}_sd" for cat in feat.categories] + ) for feat in self.domain.outputs.get() ] ) @@ -193,9 +197,11 @@ def to_candidates(self, candidates: pd.DataFrame) -> List[Candidate]: feat.key: OutputValue( predictedValue=str(row[f"{feat.key}_pred"]), standardDeviation=row[f"{feat.key}_sd"], - objective=row[f"{feat.key}_des"] - if feat.objective is not None # type: ignore - else 1.0, + objective=( + row[f"{feat.key}_des"] + if feat.objective is not None # type: ignore + else 1.0 + ), ) for feat in self.domain.outputs.get() }, diff --git a/bofire/strategies/predictives/sobo.py b/bofire/strategies/predictives/sobo.py index 5a05b33c3..2ec1c2f3b 100644 --- a/bofire/strategies/predictives/sobo.py +++ b/bofire/strategies/predictives/sobo.py @@ -64,17 +64,23 @@ def _get_acqfs(self, n) -> List[AcquisitionFunction]: X_pending=X_pending, constraints=constraint_callables, mc_samples=self.num_sobol_samples, - beta=self.acquisition_function.beta - if isinstance(self.acquisition_function, qUCB) - else 0.2, - tau=self.acquisition_function.tau - if isinstance(self.acquisition_function, qPI) - else 1e-3, + beta=( + self.acquisition_function.beta + if isinstance(self.acquisition_function, qUCB) + else 0.2 + ), + tau=( + self.acquisition_function.tau + if isinstance(self.acquisition_function, qPI) + else 1e-3 + ), eta=torch.tensor(etas).to(**tkwargs), cache_root=True if isinstance(self.model, GPyTorchModel) else False, - prune_baseline=self.acquisition_function.prune_baseline - if isinstance(self.acquisition_function, (qNEI, qLogNEI)) - else True, + prune_baseline=( + self.acquisition_function.prune_baseline + if isinstance(self.acquisition_function, (qNEI, qLogNEI)) + else True + ), ) return [acqf] diff --git a/bofire/strategies/samplers/polytope.py b/bofire/strategies/samplers/polytope.py index 18767b362..cc7b8a84a 100644 --- a/bofire/strategies/samplers/polytope.py +++ b/bofire/strategies/samplers/polytope.py @@ -138,9 +138,9 @@ def _ask(self, n: int) -> pd.DataFrame: n=1, q=n, bounds=bounds.to(**tkwargs), - inequality_constraints=unfixed_ineqs - if len(unfixed_ineqs) > 0 # type: ignore - else None, + inequality_constraints=( + unfixed_ineqs if len(unfixed_ineqs) > 0 else None # type: ignore + ), equality_constraints=combined_eqs if len(combined_eqs) > 0 else None, n_burnin=self.n_burnin, thinning=self.n_thinning, diff --git a/bofire/strategies/samplers/universal_constraint.py b/bofire/strategies/samplers/universal_constraint.py index fff8d524a..f4990688a 100644 --- a/bofire/strategies/samplers/universal_constraint.py +++ b/bofire/strategies/samplers/universal_constraint.py @@ -37,7 +37,9 @@ def _ask(self, candidate_count: int) -> pd.DataFrame: fixed_experiments=self.candidates, ) - samples = samples.iloc[self.num_candidates :,] + samples = samples.iloc[ + self.num_candidates :, + ] samples = samples.sample( n=candidate_count, replace=False, diff --git a/bofire/surrogates/diagnostics.py b/bofire/surrogates/diagnostics.py index a4b260545..e5b60875f 100644 --- a/bofire/surrogates/diagnostics.py +++ b/bofire/surrogates/diagnostics.py @@ -763,9 +763,11 @@ def CvResults2CrossValidationValues( CrossValidationValues( observed=fold.observed.tolist(), predicted=fold.predicted.tolist(), - standardDeviation=fold.standard_deviation.tolist() - if fold.standard_deviation is not None - else None, + standardDeviation=( + fold.standard_deviation.tolist() + if fold.standard_deviation is not None + else None + ), metrics=metrics.loc[i].to_dict() if fold.n_samples > 1 else None, ) ) diff --git a/bofire/surrogates/fully_bayesian.py b/bofire/surrogates/fully_bayesian.py index 8f87f60eb..00f288b21 100644 --- a/bofire/surrogates/fully_bayesian.py +++ b/bofire/surrogates/fully_bayesian.py @@ -44,9 +44,11 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame, disable_progbar: bool = True): self.model = SaasFullyBayesianSingleTaskGP( train_X=tX, train_Y=tY, - outcome_transform=Standardize(m=1) - if self.output_scaler == ScalerEnum.STANDARDIZE - else None, + outcome_transform=( + Standardize(m=1) + if self.output_scaler == ScalerEnum.STANDARDIZE + else None + ), input_transform=scaler, ) fit_fully_bayesian_model_nuts( diff --git a/bofire/surrogates/mixed_single_task_gp.py b/bofire/surrogates/mixed_single_task_gp.py index f3bc8dcc8..4a5b992d0 100644 --- a/bofire/surrogates/mixed_single_task_gp.py +++ b/bofire/surrogates/mixed_single_task_gp.py @@ -90,9 +90,11 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): cat_dims=cat_dims, # cont_kernel_factory=self.continuous_kernel.to_gpytorch, cont_kernel_factory=partial(kernels.map, data_model=self.continuous_kernel), - outcome_transform=Standardize(m=tY.shape[-1]) - if self.output_scaler == ScalerEnum.STANDARDIZE - else None, + outcome_transform=( + Standardize(m=tY.shape[-1]) + if self.output_scaler == ScalerEnum.STANDARDIZE + else None + ), input_transform=tf, ) self.model.likelihood.noise_covar.noise_prior = priors.map(self.noise_prior) # type: ignore diff --git a/bofire/surrogates/single_task_gp.py b/bofire/surrogates/single_task_gp.py index 31a542b89..06e6cab49 100644 --- a/bofire/surrogates/single_task_gp.py +++ b/bofire/surrogates/single_task_gp.py @@ -53,9 +53,11 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): active_dims=list(range(tX.shape[1])), ard_num_dims=1, # this keyword is ingored ), - outcome_transform=Standardize(m=tY.shape[-1]) - if self.output_scaler == ScalerEnum.STANDARDIZE - else None, + outcome_transform=( + Standardize(m=tY.shape[-1]) + if self.output_scaler == ScalerEnum.STANDARDIZE + else None + ), input_transform=scaler, ) diff --git a/tests/bofire/benchmarks/test_single.py b/tests/bofire/benchmarks/test_single.py index aa4e9c333..8ce3ad88f 100644 --- a/tests/bofire/benchmarks/test_single.py +++ b/tests/bofire/benchmarks/test_single.py @@ -40,7 +40,7 @@ def test_hartmann(): (Branin, True, {}), (Branin, False, {}), (Branin30, True, {}), - (Branin30, False, {}) + (Branin30, False, {}), # TO DO: Implement feature that tests Ackley for categorical and descriptive inputs. # (Ackley, {"categorical": True}), # (Ackley, {"descriptor": True}), diff --git a/tests/bofire/data_models/domain/test_domain_validators.py b/tests/bofire/data_models/domain/test_domain_validators.py index b3b151cbc..845ef9918 100644 --- a/tests/bofire/data_models/domain/test_domain_validators.py +++ b/tests/bofire/data_models/domain/test_domain_validators.py @@ -72,9 +72,11 @@ def generate_experiments( ] }, **{ - f.key: random.choice(f.categories) # type: ignore - if not only_allowed_categories - else random.choice(f.get_allowed_categories()) # type: ignore + f.key: ( + random.choice(f.categories) # type: ignore + if not only_allowed_categories + else random.choice(f.get_allowed_categories()) + ) # type: ignore for f in domain.get_features(CategoricalInput) }, } diff --git a/tests/bofire/strategies/test_qehvi.py b/tests/bofire/strategies/test_qehvi.py index a07818197..dfa3cea7c 100644 --- a/tests/bofire/strategies/test_qehvi.py +++ b/tests/bofire/strategies/test_qehvi.py @@ -160,9 +160,11 @@ def test_qehvi(strategy, use_ref_point, num_test_candidates): assert isinstance(acqf.objective, GenericMCMultiOutputObjective) assert isinstance( acqf, - qExpectedHypervolumeImprovement - if strategy == data_models.QehviStrategy - else qNoisyExpectedHypervolumeImprovement, + ( + qExpectedHypervolumeImprovement + if strategy == data_models.QehviStrategy + else qNoisyExpectedHypervolumeImprovement + ), ) # test acqf calc # acqf_vals = my_strategy._choose_from_pool(experiments_test, num_test_candidates) diff --git a/tests/bofire/utils/test_torch_tools.py b/tests/bofire/utils/test_torch_tools.py index 67e6f413a..19cc2a8dc 100644 --- a/tests/bofire/utils/test_torch_tools.py +++ b/tests/bofire/utils/test_torch_tools.py @@ -162,12 +162,14 @@ def test_get_custom_botorch_objective(f, exclude_constraints): reward3 = obj3(a_samples[:, 2]) # do the comparison assert np.allclose( - (reward1**obj1.w + reward3**obj3.w) - * (reward1**obj1.w * reward3**obj3.w) - if exclude_constraints - else (reward1**obj1.w + reward2**obj2.w) - * (reward1**obj1.w * reward2**obj2.w) - * (reward1**obj1.w * reward3**obj3.w), + ( + (reward1**obj1.w + reward3**obj3.w) + * (reward1**obj1.w * reward3**obj3.w) + if exclude_constraints + else (reward1**obj1.w + reward2**obj2.w) + * (reward1**obj1.w * reward2**obj2.w) + * (reward1**obj1.w * reward3**obj3.w) + ), objective_forward.detach().numpy(), rtol=1e-06, ) @@ -266,9 +268,11 @@ def test_get_additive_botorch_objective(exclude_constraints): # do the comparison assert np.allclose( # objective.reward(samples, desFunc)[0].detach().numpy(), - reward1 * obj1.w + reward3 * obj3.w - if exclude_constraints - else reward1 * obj1.w + reward3 * obj3.w + reward2 * obj2.w, + ( + reward1 * obj1.w + reward3 * obj3.w + if exclude_constraints + else reward1 * obj1.w + reward3 * obj3.w + reward2 * obj2.w + ), objective_forward.detach().numpy(), rtol=1e-06, ) From b8a99fcce7950b7b8b988e157e2f0a9554c09ed9 Mon Sep 17 00:00:00 2001 From: gmancino Date: Thu, 1 Feb 2024 08:36:19 -0500 Subject: [PATCH 16/31] Update DOE --- bofire/strategies/doe/utils.py | 4 ++-- tests/bofire/strategies/doe/test_utils.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/bofire/strategies/doe/utils.py b/bofire/strategies/doe/utils.py index ff673b341..4e9a1b5d2 100644 --- a/bofire/strategies/doe/utils.py +++ b/bofire/strategies/doe/utils.py @@ -21,7 +21,7 @@ from bofire.data_models.strategies.api import ( PolytopeSampler as PolytopeSamplerDataModel, ) -from bofire.strategies.samplers.polytope import PolytopeSampler +from bofire.strategies.polytope import PolytopeSampler def get_formula_from_string( @@ -540,4 +540,4 @@ def nchoosek_constraints_as_bounds( # convert bounds to list of tuples bounds = [(b[0], b[1]) for b in bounds] - return bounds + return bounds \ No newline at end of file diff --git a/tests/bofire/strategies/doe/test_utils.py b/tests/bofire/strategies/doe/test_utils.py index 5e9284355..06d2c84d1 100644 --- a/tests/bofire/strategies/doe/test_utils.py +++ b/tests/bofire/strategies/doe/test_utils.py @@ -74,7 +74,7 @@ def test_get_formula_from_string(): assert all(term in np.array(model_formula, dtype=str) for term in terms) # linear and quadratic - terms = ["1", "x0", "x1", "x2", "x0**2", "x1**2", "x2**2"] + terms = ["1", "x0", "x1", "x2", "x0 ** 2", "x1 ** 2", "x2 ** 2"] model_formula = get_formula_from_string( domain=domain, model_type="linear-and-quadratic" ) @@ -90,9 +90,9 @@ def test_get_formula_from_string(): "x0:x1", "x0:x2", "x1:x2", - "x0**2", - "x1**2", - "x2**2", + "x0 ** 2", + "x1 ** 2", + "x2 ** 2", ] model_formula = get_formula_from_string(domain=domain, model_type="fully-quadratic") assert all(term in terms for term in model_formula) @@ -100,7 +100,7 @@ def test_get_formula_from_string(): # custom model terms_lhs = ["y"] - terms_rhs = ["1", "x0", "x0**2", "x0:x1"] + terms_rhs = ["1", "x0", "x0 ** 2", "x0:x1"] model_formula = get_formula_from_string( domain=domain, model_type="y ~ 1 + x0 + x0:x1 + {x0**2}", @@ -702,4 +702,4 @@ def test_nchoosek_constraints_as_bounds(): # ] # assert len(bounds) == 20 # for i in range(20): - # assert _bounds[i] == bounds[i] + # assert _bounds[i] == bounds[i] \ No newline at end of file From a2287d82996632f70f245c307244f65932b95a83 Mon Sep 17 00:00:00 2001 From: gmancino Date: Thu, 1 Feb 2024 08:57:47 -0500 Subject: [PATCH 17/31] Fix type checks --- bofire/strategies/doe/utils.py | 2 +- bofire/surrogates/mlp.py | 2 +- bofire/surrogates/surrogate.py | 95 +- bofire/surrogates/trainable.py | 12 +- tests/bofire/strategies/doe/test_utils.py | 2 +- .../Unknown_Constraint_Classification.ipynb | 864 +++++++++--------- 6 files changed, 486 insertions(+), 491 deletions(-) diff --git a/bofire/strategies/doe/utils.py b/bofire/strategies/doe/utils.py index 4e9a1b5d2..992e9adae 100644 --- a/bofire/strategies/doe/utils.py +++ b/bofire/strategies/doe/utils.py @@ -540,4 +540,4 @@ def nchoosek_constraints_as_bounds( # convert bounds to list of tuples bounds = [(b[0], b[1]) for b in bounds] - return bounds \ No newline at end of file + return bounds diff --git a/bofire/surrogates/mlp.py b/bofire/surrogates/mlp.py index d409c19da..e843cd518 100644 --- a/bofire/surrogates/mlp.py +++ b/bofire/surrogates/mlp.py @@ -138,7 +138,7 @@ def fit_mlp( weight_decay: float = 0.0, loss_function: Union[ nn.modules.loss.L1Loss, nn.modules.loss.CrossEntropyLoss - ] = nn.L1Loss, + ] = nn.L1Loss, # type: ignore ): """Fit a MLP to a dataset. diff --git a/bofire/surrogates/surrogate.py b/bofire/surrogates/surrogate.py index 468cb1f90..06ecfe670 100644 --- a/bofire/surrogates/surrogate.py +++ b/bofire/surrogates/surrogate.py @@ -5,7 +5,7 @@ import pandas as pd from bofire.data_models.domain.domain import is_numeric -from bofire.data_models.features.api import CategoricalOutput +from bofire.data_models.features.api import CategoricalOutput, ContinuousOutput from bofire.data_models.surrogates.api import Surrogate as DataModel from bofire.surrogates.values import PredictedValue @@ -48,41 +48,39 @@ def predict(self, X: pd.DataFrame) -> pd.DataFrame: # set up column names pred_cols = [] sd_cols = [] - for featkey in self.outputs.get_keys(): - if hasattr(self.outputs.get_by_key(featkey), "categories"): - pred_cols = pred_cols + [ - f"{featkey}_{cat}_prob" - for cat in self.outputs.get_by_key(featkey).categories - ] - sd_cols = sd_cols + [ - f"{featkey}_{cat}_sd" - for cat in self.outputs.get_by_key(featkey).categories - ] - else: - pred_cols = pred_cols + [f"{featkey}_pred"] - sd_cols = sd_cols + [f"{featkey}_sd"] + for featkey in self.outputs.get_keys(CategoricalOutput): + pred_cols = pred_cols + [ + f"{featkey}_{cat}_prob" + for cat in self.outputs.get_by_key(featkey).categories + ] + sd_cols = sd_cols + [ + f"{featkey}_{cat}_sd" + for cat in self.outputs.get_by_key(featkey).categories + ] + for featkey in self.outputs.get_keys(ContinuousOutput): + pred_cols = pred_cols + [f"{featkey}_pred"] + sd_cols = sd_cols + [f"{featkey}_sd"] # postprocess predictions = pd.DataFrame( data=np.hstack((preds, stds)), columns=pred_cols + sd_cols, ) # append predictions for categorical cases - for feat in self.outputs.get(): - if isinstance(feat, CategoricalOutput): - predictions.insert( - loc=0, - column=f"{feat.key}_pred", - value=predictions.filter(regex=f"{feat.key}(.*)_prob") - .idxmax(1) - .str.replace(f"{feat.key}_", "") - .str.replace("_prob", "") - .values, - ) - predictions.insert( - loc=1, - column=f"{feat.key}_sd", - value=0.0, - ) + for feat in self.outputs.get(CategoricalOutput): + predictions.insert( + loc=0, + column=f"{feat.key}_pred", + value=predictions.filter(regex=f"{feat.key}(.*)_prob") + .idxmax(1) + .str.replace(f"{feat.key}_", "") + .str.replace("_prob", "") + .values, + ) + predictions.insert( + loc=1, + column=f"{feat.key}_sd", + value=0.0, + ) # validate self.validate_predictions(predictions=predictions) # return @@ -91,28 +89,25 @@ def predict(self, X: pd.DataFrame) -> pd.DataFrame: def validate_predictions(self, predictions: pd.DataFrame) -> pd.DataFrame: expected_cols = [] check_columns = [] - for featkey in self.outputs.get_keys(): - if hasattr(self.outputs.get_by_key(featkey), "categories"): - expected_cols = ( - expected_cols - + [f"{featkey}_{t}" for t in ["pred", "sd"]] - + [ - f"{featkey}_{cat}_prob" - for cat in self.outputs.get_by_key(featkey).categories - ] - + [ - f"{featkey}_{cat}_sd" - for cat in self.outputs.get_by_key(featkey).categories - ] - ) - check_columns = check_columns + [ - col for col in expected_cols if col != f"{featkey}_pred" + for featkey in self.outputs.get_keys(CategoricalOutput): + expected_cols = ( + expected_cols + + [f"{featkey}_{t}" for t in ["pred", "sd"]] + + [ + f"{featkey}_{cat}_prob" + for cat in self.outputs.get_by_key(featkey).categories ] - else: - expected_cols = expected_cols + [ - f"{featkey}_{t}" for t in ["pred", "sd"] + + [ + f"{featkey}_{cat}_sd" + for cat in self.outputs.get_by_key(featkey).categories ] - check_columns = check_columns + expected_cols + ) + check_columns = check_columns + [ + col for col in expected_cols if col != f"{featkey}_pred" + ] + for featkey in self.outputs.get_keys(ContinuousOutput): + expected_cols = expected_cols + [f"{featkey}_{t}" for t in ["pred", "sd"]] + check_columns = check_columns + expected_cols if sorted(predictions.columns) != sorted(expected_cols): raise ValueError( f"Predictions are ill-formatted. Expected: {expected_cols}, got: {list(predictions.columns)}." diff --git a/bofire/surrogates/trainable.py b/bofire/surrogates/trainable.py index a845b196e..2c66a26b1 100644 --- a/bofire/surrogates/trainable.py +++ b/bofire/surrogates/trainable.py @@ -184,16 +184,18 @@ def cross_validate( y_train_pred = self.predict(X_train) # type: ignore # Convert to categorical if applicable - if isinstance(self.outputs[0].objective, ConstrainedCategoricalObjective): + if isinstance(self.outputs.get_by_key(key).objective, ConstrainedCategoricalObjective): # type: ignore y_test_pred[f"{key}_pred"] = y_test_pred[f"{key}_pred"].map( - self.outputs[0].objective.to_dict_label() + self.outputs.get_by_key(key).objective.to_dict_label() ) y_train_pred[f"{key}_pred"] = y_train_pred[f"{key}_pred"].map( - self.outputs[0].objective.to_dict_label() + self.outputs.get_by_key(key).objective.to_dict_label() + ) + y_test[key] = y_test[key].map( + self.outputs.get_by_key(key).objective.to_dict_label() ) - y_test[key] = y_test[key].map(self.outputs[0].objective.to_dict_label()) y_train[key] = y_train[key].map( - self.outputs[0].objective.to_dict_label() + self.outputs.get_by_key(key).objective.to_dict_label() ) # now store the results diff --git a/tests/bofire/strategies/doe/test_utils.py b/tests/bofire/strategies/doe/test_utils.py index 06d2c84d1..c3aa45d37 100644 --- a/tests/bofire/strategies/doe/test_utils.py +++ b/tests/bofire/strategies/doe/test_utils.py @@ -702,4 +702,4 @@ def test_nchoosek_constraints_as_bounds(): # ] # assert len(bounds) == 20 # for i in range(20): - # assert _bounds[i] == bounds[i] \ No newline at end of file + # assert _bounds[i] == bounds[i] diff --git a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb index 61c0c22e2..fb39ddd4a 100644 --- a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb +++ b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb @@ -25,9 +25,7 @@ "output_type": "stream", "text": [ "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\pydantic\\_migration.py:283: UserWarning: `pydantic.error_wrappers:ValidationError` has been moved to `pydantic:ValidationError`.\n", - " warnings.warn(f'`{import_path}` has been moved to `{new_location}`.')\n" + " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], @@ -102,243 +100,243 @@ " \n", " \n", " 0\n", - " -1.147520\n", - " 0.352595\n", - " 1.233487\n", - " 0.764528\n", - " 0.728527\n", + " -0.267947\n", + " 1.021106\n", + " 1.623060\n", + " 0.627026\n", + " 1.206148\n", " 0\n", - " 2.600000\n", - " acceptable\n", - " -1.140822\n", + " 1.140700\n", + " ideal\n", + " -0.265670\n", " \n", " \n", " 1\n", - " -0.541992\n", - " 1.030834\n", - " 0.544503\n", - " -0.451754\n", - " -1.379001\n", + " -0.988891\n", + " -0.747924\n", + " -0.077380\n", + " -1.366909\n", + " 0.122713\n", " 0\n", - " 3.722210\n", - " acceptable\n", - " -0.533687\n", + " 4.933869\n", + " unacceptable\n", + " -0.981222\n", " \n", " \n", " 2\n", - " 1.620988\n", - " 1.569770\n", - " 0.510210\n", - " 0.478663\n", - " 0.577311\n", - " 1\n", - " 0.891793\n", - " ideal\n", - " 1.626114\n", + " 0.171629\n", + " -0.592879\n", + " 0.188483\n", + " -0.706858\n", + " -0.710738\n", + " 0\n", + " 1.585117\n", + " unacceptable\n", + " 0.173409\n", " \n", " \n", " 3\n", - " 0.603489\n", - " 1.560799\n", - " 1.147270\n", - " 0.860127\n", - " -0.984688\n", - " 0\n", - " 1.315285\n", - " ideal\n", - " 0.606293\n", + " -1.745543\n", + " 1.718025\n", + " -0.866392\n", + " -0.236673\n", + " 1.936513\n", + " 1\n", + " 6.809027\n", + " unacceptable\n", + " -1.740944\n", " \n", " \n", " 4\n", - " 1.115889\n", - " 0.675806\n", - " 1.156867\n", - " 0.238840\n", - " -0.270817\n", - " 0\n", - " 0.259161\n", - " unacceptable\n", - " 1.121470\n", + " 0.461210\n", + " -0.087926\n", + " 0.883033\n", + " -0.319482\n", + " -1.813779\n", + " 1\n", + " 5.272989\n", + " acceptable\n", + " 0.461377\n", " \n", " \n", " 5\n", - " 0.947684\n", - " -1.348468\n", - " 1.595431\n", - " -1.910428\n", - " 0.906305\n", + " 1.386029\n", + " -0.098083\n", + " 0.174276\n", + " 0.804043\n", + " 0.321967\n", " 0\n", - " 7.925056\n", + " 0.470671\n", " unacceptable\n", - " 0.956881\n", + " 1.391563\n", " \n", " \n", " 6\n", - " -1.560800\n", - " 1.897495\n", - " -0.659937\n", - " -0.922824\n", - " 1.988474\n", - " 1\n", - " 6.471322\n", + " -1.946305\n", + " -0.055122\n", + " 1.468864\n", + " -1.514258\n", + " -1.641203\n", + " 0\n", + " 13.303294\n", " unacceptable\n", - " -1.554441\n", + " -1.938876\n", " \n", " \n", " 7\n", - " 1.765355\n", - " 0.208224\n", - " 0.475951\n", - " -0.412166\n", - " 1.013625\n", + " 0.381642\n", + " 1.842479\n", + " -1.500184\n", + " 0.353710\n", + " 0.508349\n", " 0\n", - " 1.024378\n", - " acceptable\n", - " 1.774402\n", + " 3.959427\n", + " ideal\n", + " 0.390396\n", " \n", " \n", " 8\n", - " 0.817979\n", - " -1.398751\n", - " -1.613844\n", - " -1.626274\n", - " 1.718574\n", + " -0.428826\n", + " -1.958059\n", + " -0.221027\n", + " -1.372467\n", + " 1.275357\n", " 0\n", - " 10.271230\n", - " unacceptable\n", - " 0.821963\n", + " 9.170777\n", + " ideal\n", + " -0.419537\n", " \n", " \n", " 9\n", - " 1.624639\n", - " 0.591260\n", - " 1.056660\n", - " 1.964813\n", - " -0.609677\n", + " -0.473310\n", + " -0.402426\n", + " -0.608333\n", + " -0.403410\n", + " 1.495413\n", " 1\n", - " 1.434253\n", - " acceptable\n", - " 1.633391\n", + " 1.713585\n", + " unacceptable\n", + " -0.465823\n", " \n", " \n", " 10\n", - " 0.659178\n", - " -1.488256\n", - " -0.769463\n", - " -0.498631\n", - " 1.641154\n", + " -1.961280\n", + " -1.586803\n", + " -0.561849\n", + " 0.064580\n", + " -1.416727\n", " 0\n", - " 4.016106\n", - " ideal\n", - " 0.666231\n", + " 12.654743\n", + " acceptable\n", + " -1.958333\n", " \n", " \n", " 11\n", - " 0.843863\n", - " -0.778241\n", - " -0.742882\n", - " 1.041363\n", - " -0.462427\n", + " -0.235274\n", + " -0.638830\n", + " -0.107924\n", + " 1.769868\n", + " 1.108539\n", " 1\n", - " 1.118582\n", + " 1.582444\n", " acceptable\n", - " 0.845148\n", + " -0.227187\n", " \n", " \n", " 12\n", - " 1.713291\n", - " -0.327819\n", - " 1.934123\n", - " -1.535219\n", - " -0.263203\n", + " 0.617453\n", + " 0.186646\n", + " -0.328716\n", + " 1.212407\n", + " -0.965701\n", " 0\n", - " 4.643351\n", - " acceptable\n", - " 1.715253\n", + " 1.254832\n", + " unacceptable\n", + " 0.625534\n", " \n", " \n", " 13\n", - " 0.425012\n", - " -0.428886\n", - " 0.678414\n", - " 0.867640\n", - " 1.947507\n", + " -0.004329\n", + " -0.521047\n", + " -1.408930\n", + " -0.211512\n", + " -0.027510\n", " 1\n", - " 1.058468\n", - " acceptable\n", - " 0.426804\n", + " 3.536682\n", + " unacceptable\n", + " -0.004252\n", " \n", " \n", " 14\n", - " -1.783885\n", - " 0.597387\n", - " 0.018537\n", - " -0.867099\n", - " 1.208851\n", + " 0.473978\n", + " -1.578219\n", + " 0.083191\n", + " -0.357965\n", + " 1.037494\n", " 0\n", - " 6.139702\n", + " 3.895441\n", " acceptable\n", - " -1.775463\n", + " 0.478809\n", " \n", " \n", " 15\n", - " 0.840074\n", - " -0.300067\n", - " -1.617072\n", - " -1.733238\n", - " -1.571214\n", + " -0.653601\n", + " -0.407954\n", + " -1.939297\n", + " 1.055724\n", + " 0.830087\n", " 0\n", - " 11.558383\n", - " acceptable\n", - " 0.841770\n", + " 6.974997\n", + " ideal\n", + " -0.651507\n", " \n", " \n", " 16\n", - " 1.666942\n", - " -0.941869\n", - " 0.904134\n", - " 0.685048\n", - " 1.717750\n", - " 1\n", - " 1.776970\n", + " 0.285208\n", + " -1.327075\n", + " 0.705723\n", + " 1.284219\n", + " -0.324223\n", + " 0\n", + " 2.739220\n", " acceptable\n", - " 1.668979\n", + " 0.286143\n", " \n", " \n", " 17\n", - " 1.396363\n", - " 1.807155\n", - " 0.594131\n", - " -1.607430\n", - " 1.190373\n", + " 0.462940\n", + " -1.584694\n", + " 1.114151\n", + " 0.626237\n", + " -1.071201\n", " 1\n", - " 4.328181\n", - " unacceptable\n", - " 1.402815\n", + " 4.968933\n", + " ideal\n", + " 0.466059\n", " \n", " \n", " 18\n", - " -0.583546\n", - " 1.973070\n", - " -1.911667\n", - " -1.225064\n", - " -1.323822\n", + " 1.746916\n", + " 1.093672\n", + " 1.004508\n", + " -1.620785\n", + " -1.880987\n", " 0\n", - " 11.085039\n", + " 9.614141\n", " unacceptable\n", - " -0.574170\n", + " 1.748456\n", " \n", " \n", " 19\n", - " -0.133564\n", - " -0.037052\n", - " 1.349117\n", - " 0.108222\n", - " -0.739183\n", + " 1.276677\n", + " -0.126877\n", + " -1.214876\n", + " -0.465293\n", + " -0.207115\n", " 1\n", - " 1.385558\n", + " 2.140485\n", " unacceptable\n", - " -0.129039\n", + " 1.283395\n", " \n", " \n", "\n", @@ -346,48 +344,48 @@ ], "text/plain": [ " x_0 x_1 x_2 x_3 x_4 x_5 f_0 \\\n", - "0 -1.147520 0.352595 1.233487 0.764528 0.728527 0 2.600000 \n", - "1 -0.541992 1.030834 0.544503 -0.451754 -1.379001 0 3.722210 \n", - "2 1.620988 1.569770 0.510210 0.478663 0.577311 1 0.891793 \n", - "3 0.603489 1.560799 1.147270 0.860127 -0.984688 0 1.315285 \n", - "4 1.115889 0.675806 1.156867 0.238840 -0.270817 0 0.259161 \n", - "5 0.947684 -1.348468 1.595431 -1.910428 0.906305 0 7.925056 \n", - "6 -1.560800 1.897495 -0.659937 -0.922824 1.988474 1 6.471322 \n", - "7 1.765355 0.208224 0.475951 -0.412166 1.013625 0 1.024378 \n", - "8 0.817979 -1.398751 -1.613844 -1.626274 1.718574 0 10.271230 \n", - "9 1.624639 0.591260 1.056660 1.964813 -0.609677 1 1.434253 \n", - "10 0.659178 -1.488256 -0.769463 -0.498631 1.641154 0 4.016106 \n", - "11 0.843863 -0.778241 -0.742882 1.041363 -0.462427 1 1.118582 \n", - "12 1.713291 -0.327819 1.934123 -1.535219 -0.263203 0 4.643351 \n", - "13 0.425012 -0.428886 0.678414 0.867640 1.947507 1 1.058468 \n", - "14 -1.783885 0.597387 0.018537 -0.867099 1.208851 0 6.139702 \n", - "15 0.840074 -0.300067 -1.617072 -1.733238 -1.571214 0 11.558383 \n", - "16 1.666942 -0.941869 0.904134 0.685048 1.717750 1 1.776970 \n", - "17 1.396363 1.807155 0.594131 -1.607430 1.190373 1 4.328181 \n", - "18 -0.583546 1.973070 -1.911667 -1.225064 -1.323822 0 11.085039 \n", - "19 -0.133564 -0.037052 1.349117 0.108222 -0.739183 1 1.385558 \n", + "0 -0.267947 1.021106 1.623060 0.627026 1.206148 0 1.140700 \n", + "1 -0.988891 -0.747924 -0.077380 -1.366909 0.122713 0 4.933869 \n", + "2 0.171629 -0.592879 0.188483 -0.706858 -0.710738 0 1.585117 \n", + "3 -1.745543 1.718025 -0.866392 -0.236673 1.936513 1 6.809027 \n", + "4 0.461210 -0.087926 0.883033 -0.319482 -1.813779 1 5.272989 \n", + "5 1.386029 -0.098083 0.174276 0.804043 0.321967 0 0.470671 \n", + "6 -1.946305 -0.055122 1.468864 -1.514258 -1.641203 0 13.303294 \n", + "7 0.381642 1.842479 -1.500184 0.353710 0.508349 0 3.959427 \n", + "8 -0.428826 -1.958059 -0.221027 -1.372467 1.275357 0 9.170777 \n", + "9 -0.473310 -0.402426 -0.608333 -0.403410 1.495413 1 1.713585 \n", + "10 -1.961280 -1.586803 -0.561849 0.064580 -1.416727 0 12.654743 \n", + "11 -0.235274 -0.638830 -0.107924 1.769868 1.108539 1 1.582444 \n", + "12 0.617453 0.186646 -0.328716 1.212407 -0.965701 0 1.254832 \n", + "13 -0.004329 -0.521047 -1.408930 -0.211512 -0.027510 1 3.536682 \n", + "14 0.473978 -1.578219 0.083191 -0.357965 1.037494 0 3.895441 \n", + "15 -0.653601 -0.407954 -1.939297 1.055724 0.830087 0 6.974997 \n", + "16 0.285208 -1.327075 0.705723 1.284219 -0.324223 0 2.739220 \n", + "17 0.462940 -1.584694 1.114151 0.626237 -1.071201 1 4.968933 \n", + "18 1.746916 1.093672 1.004508 -1.620785 -1.880987 0 9.614141 \n", + "19 1.276677 -0.126877 -1.214876 -0.465293 -0.207115 1 2.140485 \n", "\n", " f_1 f_2 \n", - "0 acceptable -1.140822 \n", - "1 acceptable -0.533687 \n", - "2 ideal 1.626114 \n", - "3 ideal 0.606293 \n", - "4 unacceptable 1.121470 \n", - "5 unacceptable 0.956881 \n", - "6 unacceptable -1.554441 \n", - "7 acceptable 1.774402 \n", - "8 unacceptable 0.821963 \n", - "9 acceptable 1.633391 \n", - "10 ideal 0.666231 \n", - "11 acceptable 0.845148 \n", - "12 acceptable 1.715253 \n", - "13 acceptable 0.426804 \n", - "14 acceptable -1.775463 \n", - "15 acceptable 0.841770 \n", - "16 acceptable 1.668979 \n", - "17 unacceptable 1.402815 \n", - "18 unacceptable -0.574170 \n", - "19 unacceptable -0.129039 " + "0 ideal -0.265670 \n", + "1 unacceptable -0.981222 \n", + "2 unacceptable 0.173409 \n", + "3 unacceptable -1.740944 \n", + "4 acceptable 0.461377 \n", + "5 unacceptable 1.391563 \n", + "6 unacceptable -1.938876 \n", + "7 ideal 0.390396 \n", + "8 ideal -0.419537 \n", + "9 unacceptable -0.465823 \n", + "10 acceptable -1.958333 \n", + "11 acceptable -0.227187 \n", + "12 unacceptable 0.625534 \n", + "13 unacceptable -0.004252 \n", + "14 acceptable 0.478809 \n", + "15 ideal -0.651507 \n", + "16 acceptable 0.286143 \n", + "17 ideal 0.466059 \n", + "18 unacceptable 1.748456 \n", + "19 unacceptable 1.283395 " ] }, "execution_count": 3, @@ -511,16 +509,16 @@ " \n", " \n", " 0\n", - " 0.77\n", - " 0.77\n", + " 0.745\n", + " 0.745\n", " \n", " \n", "\n", "" ], "text/plain": [ - " ACCURACY F1\n", - "0 0.77 0.77" + " ACCURACY F1\n", + "0 0.745 0.745" ] }, "execution_count": 5, @@ -566,8 +564,8 @@ " \n", " \n", " 0\n", - " 0.46\n", - " 0.46\n", + " 0.32\n", + " 0.32\n", " \n", " \n", "\n", @@ -575,7 +573,7 @@ ], "text/plain": [ " ACCURACY F1\n", - "0 0.46 0.46" + "0 0.32 0.32" ] }, "execution_count": 6, @@ -685,243 +683,243 @@ " \n", " \n", " 0\n", - " 0.098607\n", - " 1.274078\n", + " 0.744661\n", + " -0.014262\n", + " -0.451975\n", " 2.000000\n", - " 0.749040\n", - " -0.156863\n", - " 0\n", + " 0.264316\n", + " 1\n", " acceptable\n", " 0.0\n", - " -0.598172\n", - " 0.102933\n", + " -0.794483\n", + " 0.751355\n", " ...\n", - " 0.999879\n", - " 2.749825e-07\n", - " 0.830174\n", - " 0.003029\n", - " 0.000270\n", - " 0.000271\n", - " 6.091118e-07\n", - " 0.598172\n", - " 0.487136\n", - " 0.999879\n", + " 0.781416\n", + " 0.197786\n", + " 0.749730\n", + " 0.003478\n", + " 0.011921\n", + " 0.436942\n", + " 0.438615\n", + " 0.794483\n", + " 0.407170\n", + " 0.979202\n", " \n", " \n", " 1\n", - " 0.646990\n", - " 0.604150\n", - " -0.341737\n", - " 1.307168\n", - " -0.151535\n", - " 0\n", + " -0.025971\n", + " 0.538385\n", + " 0.445088\n", + " -0.029734\n", + " 0.207444\n", + " 1\n", " acceptable\n", " 0.0\n", - " -0.465931\n", - " 0.650499\n", + " -0.061202\n", + " -0.021137\n", " ...\n", - " 0.979470\n", - " 1.731224e-02\n", - " 0.241746\n", - " 0.002957\n", - " 0.007131\n", - " 0.045842\n", - " 3.871093e-02\n", - " 0.465931\n", - " 0.419397\n", - " 0.996782\n", + " 0.878746\n", + " 0.119988\n", + " 0.471483\n", + " 0.003391\n", + " 0.001982\n", + " 0.270151\n", + " 0.268299\n", + " 0.061202\n", + " 0.502642\n", + " 0.998734\n", " \n", " \n", " 2\n", - " -0.268094\n", + " 0.643000\n", + " 0.170813\n", " 2.000000\n", " 2.000000\n", - " 0.417189\n", - " -0.160303\n", - " 0\n", + " 0.250742\n", + " 1\n", " acceptable\n", " 0.0\n", - " -0.095701\n", - " -0.263037\n", + " -0.536501\n", + " 0.649062\n", " ...\n", - " 0.999936\n", - " 9.786710e-08\n", - " 1.129313\n", - " 0.003170\n", - " 0.000139\n", - " 0.000139\n", - " 2.182845e-07\n", - " 0.095701\n", - " 0.532832\n", - " 0.999937\n", + " 0.947750\n", + " 0.050431\n", + " 1.069920\n", + " 0.003621\n", + " 0.004066\n", + " 0.116790\n", + " 0.112724\n", + " 0.536501\n", + " 0.419572\n", + " 0.998181\n", " \n", " \n", " 3\n", - " -0.407627\n", + " 1.046492\n", + " 1.009127\n", + " 1.693894\n", " 2.000000\n", " 2.000000\n", - " 2.000000\n", - " -0.147082\n", " 1\n", " acceptable\n", " 0.0\n", - " 1.523601\n", - " -0.403380\n", + " 1.483338\n", + " 1.054340\n", " ...\n", - " 0.808412\n", - " 1.612794e-07\n", - " 1.637777\n", - " 0.003359\n", - " 0.416360\n", - " 0.416360\n", - " 2.241354e-07\n", - " -1.523601\n", - " 0.550252\n", - " 0.808412\n", + " 0.799970\n", + " 0.199863\n", + " 1.958267\n", + " 0.003926\n", + " 0.000285\n", + " 0.447182\n", + " 0.446908\n", + " -1.483338\n", + " 0.371177\n", + " 0.999833\n", " \n", " \n", " 4\n", - " 0.928282\n", - " 0.986890\n", - " -0.608701\n", - " 1.760348\n", - " -0.144849\n", - " 0\n", + " 0.249543\n", + " 0.448522\n", + " 2.000000\n", + " -0.021518\n", + " 0.143814\n", + " 1\n", " acceptable\n", " 0.0\n", - " -0.305604\n", - " 0.931805\n", + " -0.015632\n", + " 0.253437\n", " ...\n", - " 0.967507\n", - " 2.830622e-02\n", - " 0.478486\n", - " 0.003040\n", - " 0.009351\n", - " 0.072646\n", - " 6.329445e-02\n", - " 0.305604\n", - " 0.385587\n", - " 0.995813\n", + " 0.997194\n", + " 0.002804\n", + " 0.752782\n", + " 0.003477\n", + " 0.000004\n", + " 0.006273\n", + " 0.006269\n", + " 0.015632\n", + " 0.468363\n", + " 0.999998\n", " \n", " \n", " 5\n", - " -0.171500\n", - " 1.702729\n", - " 2.000000\n", - " 1.147333\n", - " -0.205059\n", - " 0\n", + " 0.854194\n", + " 0.464335\n", + " 1.579664\n", + " 0.782255\n", + " 1.068275\n", + " 1\n", " acceptable\n", " 0.0\n", - " -0.176775\n", - " -0.167058\n", + " 0.071247\n", + " 0.860109\n", " ...\n", - " 0.999999\n", - " 8.229916e-09\n", - " 1.045623\n", - " 0.003128\n", - " 0.000003\n", - " 0.000003\n", - " 1.748655e-08\n", - " 0.176775\n", - " 0.520870\n", - " 0.999999\n", + " 0.801522\n", + " 0.197997\n", + " 0.378974\n", + " 0.003551\n", + " 0.001074\n", + " 0.443808\n", + " 0.442734\n", + " -0.071247\n", + " 0.394113\n", + " 0.999519\n", " \n", " \n", " 6\n", - " 0.348453\n", - " 0.749030\n", + " 0.884042\n", + " 0.406672\n", + " 1.494174\n", " 2.000000\n", - " 0.805095\n", - " -0.119394\n", - " 0\n", + " 0.157301\n", + " 1\n", " acceptable\n", " 0.0\n", - " -0.598377\n", - " 0.352705\n", + " -0.616233\n", + " 0.891074\n", " ...\n", - " 0.999376\n", - " 1.184046e-06\n", - " 0.710733\n", - " 0.002993\n", - " 0.001392\n", - " 0.001395\n", - " 2.614105e-06\n", - " 0.598377\n", - " 0.456026\n", - " 0.999377\n", + " 0.860195\n", + " 0.136536\n", + " 0.916442\n", + " 0.003627\n", + " 0.007275\n", + " 0.312579\n", + " 0.305304\n", + " 0.616233\n", + " 0.390422\n", + " 0.996731\n", " \n", " \n", " 7\n", - " -2.000000\n", - " 0.457609\n", - " 1.675461\n", - " 1.207754\n", - " 0.243209\n", - " 0\n", + " 0.143312\n", + " 0.679986\n", + " -0.116402\n", + " 0.081383\n", + " 0.247956\n", + " 1\n", " acceptable\n", " 0.0\n", - " 4.752548\n", - " -1.994210\n", + " -0.067516\n", + " 0.147784\n", " ...\n", - " 0.949335\n", - " 1.139145e-04\n", - " 0.807708\n", - " 0.005353\n", - " 0.113034\n", - " 0.113289\n", - " 2.547206e-04\n", - " -4.752548\n", - " 0.730489\n", - " 0.949449\n", + " 0.795419\n", + " 0.189967\n", + " 0.470200\n", + " 0.003365\n", + " 0.009657\n", + " 0.426349\n", + " 0.424642\n", + " 0.067516\n", + " 0.481535\n", + " 0.985387\n", " \n", " \n", " 8\n", - " 0.126459\n", - " 1.752512\n", - " 2.000000\n", - " 0.269535\n", - " -0.025944\n", - " 0\n", + " -0.293373\n", + " 0.876231\n", + " 0.524355\n", + " 0.052708\n", + " 0.143248\n", + " 1\n", " acceptable\n", " 0.0\n", - " -0.373769\n", - " 0.130896\n", + " 0.195204\n", + " -0.286897\n", " ...\n", - " 0.998390\n", - " 1.199743e-06\n", - " 1.007492\n", - " 0.003081\n", - " 0.003593\n", - " 0.003596\n", - " 2.680733e-06\n", - " 0.373769\n", - " 0.483644\n", - " 0.998391\n", + " 0.914762\n", + " 0.084724\n", + " 0.575528\n", + " 0.003441\n", + " 0.000992\n", + " 0.189307\n", + " 0.189448\n", + " -0.195204\n", + " 0.535801\n", + " 0.999486\n", " \n", " \n", " 9\n", - " 0.106551\n", - " 1.362876\n", - " 2.000000\n", - " 1.558184\n", - " -0.132843\n", - " 0\n", + " 0.395178\n", + " 0.358173\n", + " 1.405300\n", + " -0.040535\n", + " 0.223360\n", + " 1\n", " acceptable\n", " 0.0\n", - " -0.100465\n", - " 0.110338\n", + " -0.125187\n", + " 0.399090\n", " ...\n", - " 0.999999\n", - " 6.511140e-09\n", - " 1.039770\n", - " 0.003075\n", - " 0.000001\n", - " 0.000001\n", - " 1.175186e-08\n", - " 0.100465\n", - " 0.486211\n", - " 0.999999\n", + " 0.973413\n", + " 0.026521\n", + " 0.356968\n", + " 0.003403\n", + " 0.000143\n", + " 0.059446\n", + " 0.059303\n", + " 0.125187\n", + " 0.450279\n", + " 0.999934\n", " \n", " \n", "\n", @@ -930,52 +928,52 @@ ], "text/plain": [ " x_0 x_1 x_2 x_3 x_4 x_5 f_1_pred f_1_sd \\\n", - "0 0.098607 1.274078 2.000000 0.749040 -0.156863 0 acceptable 0.0 \n", - "1 0.646990 0.604150 -0.341737 1.307168 -0.151535 0 acceptable 0.0 \n", - "2 -0.268094 2.000000 2.000000 0.417189 -0.160303 0 acceptable 0.0 \n", - "3 -0.407627 2.000000 2.000000 2.000000 -0.147082 1 acceptable 0.0 \n", - "4 0.928282 0.986890 -0.608701 1.760348 -0.144849 0 acceptable 0.0 \n", - "5 -0.171500 1.702729 2.000000 1.147333 -0.205059 0 acceptable 0.0 \n", - "6 0.348453 0.749030 2.000000 0.805095 -0.119394 0 acceptable 0.0 \n", - "7 -2.000000 0.457609 1.675461 1.207754 0.243209 0 acceptable 0.0 \n", - "8 0.126459 1.752512 2.000000 0.269535 -0.025944 0 acceptable 0.0 \n", - "9 0.106551 1.362876 2.000000 1.558184 -0.132843 0 acceptable 0.0 \n", + "0 0.744661 -0.014262 -0.451975 2.000000 0.264316 1 acceptable 0.0 \n", + "1 -0.025971 0.538385 0.445088 -0.029734 0.207444 1 acceptable 0.0 \n", + "2 0.643000 0.170813 2.000000 2.000000 0.250742 1 acceptable 0.0 \n", + "3 1.046492 1.009127 1.693894 2.000000 2.000000 1 acceptable 0.0 \n", + "4 0.249543 0.448522 2.000000 -0.021518 0.143814 1 acceptable 0.0 \n", + "5 0.854194 0.464335 1.579664 0.782255 1.068275 1 acceptable 0.0 \n", + "6 0.884042 0.406672 1.494174 2.000000 0.157301 1 acceptable 0.0 \n", + "7 0.143312 0.679986 -0.116402 0.081383 0.247956 1 acceptable 0.0 \n", + "8 -0.293373 0.876231 0.524355 0.052708 0.143248 1 acceptable 0.0 \n", + "9 0.395178 0.358173 1.405300 -0.040535 0.223360 1 acceptable 0.0 \n", "\n", " f_0_pred f_2_pred ... f_1_acceptable_prob f_1_ideal_prob f_0_sd \\\n", - "0 -0.598172 0.102933 ... 0.999879 2.749825e-07 0.830174 \n", - "1 -0.465931 0.650499 ... 0.979470 1.731224e-02 0.241746 \n", - "2 -0.095701 -0.263037 ... 0.999936 9.786710e-08 1.129313 \n", - "3 1.523601 -0.403380 ... 0.808412 1.612794e-07 1.637777 \n", - "4 -0.305604 0.931805 ... 0.967507 2.830622e-02 0.478486 \n", - "5 -0.176775 -0.167058 ... 0.999999 8.229916e-09 1.045623 \n", - "6 -0.598377 0.352705 ... 0.999376 1.184046e-06 0.710733 \n", - "7 4.752548 -1.994210 ... 0.949335 1.139145e-04 0.807708 \n", - "8 -0.373769 0.130896 ... 0.998390 1.199743e-06 1.007492 \n", - "9 -0.100465 0.110338 ... 0.999999 6.511140e-09 1.039770 \n", + "0 -0.794483 0.751355 ... 0.781416 0.197786 0.749730 \n", + "1 -0.061202 -0.021137 ... 0.878746 0.119988 0.471483 \n", + "2 -0.536501 0.649062 ... 0.947750 0.050431 1.069920 \n", + "3 1.483338 1.054340 ... 0.799970 0.199863 1.958267 \n", + "4 -0.015632 0.253437 ... 0.997194 0.002804 0.752782 \n", + "5 0.071247 0.860109 ... 0.801522 0.197997 0.378974 \n", + "6 -0.616233 0.891074 ... 0.860195 0.136536 0.916442 \n", + "7 -0.067516 0.147784 ... 0.795419 0.189967 0.470200 \n", + "8 0.195204 -0.286897 ... 0.914762 0.084724 0.575528 \n", + "9 -0.125187 0.399090 ... 0.973413 0.026521 0.356968 \n", "\n", " f_2_sd f_1_unacceptable_sd f_1_acceptable_sd f_1_ideal_sd f_0_des \\\n", - "0 0.003029 0.000270 0.000271 6.091118e-07 0.598172 \n", - "1 0.002957 0.007131 0.045842 3.871093e-02 0.465931 \n", - "2 0.003170 0.000139 0.000139 2.182845e-07 0.095701 \n", - "3 0.003359 0.416360 0.416360 2.241354e-07 -1.523601 \n", - "4 0.003040 0.009351 0.072646 6.329445e-02 0.305604 \n", - "5 0.003128 0.000003 0.000003 1.748655e-08 0.176775 \n", - "6 0.002993 0.001392 0.001395 2.614105e-06 0.598377 \n", - "7 0.005353 0.113034 0.113289 2.547206e-04 -4.752548 \n", - "8 0.003081 0.003593 0.003596 2.680733e-06 0.373769 \n", - "9 0.003075 0.000001 0.000001 1.175186e-08 0.100465 \n", + "0 0.003478 0.011921 0.436942 0.438615 0.794483 \n", + "1 0.003391 0.001982 0.270151 0.268299 0.061202 \n", + "2 0.003621 0.004066 0.116790 0.112724 0.536501 \n", + "3 0.003926 0.000285 0.447182 0.446908 -1.483338 \n", + "4 0.003477 0.000004 0.006273 0.006269 0.015632 \n", + "5 0.003551 0.001074 0.443808 0.442734 -0.071247 \n", + "6 0.003627 0.007275 0.312579 0.305304 0.616233 \n", + "7 0.003365 0.009657 0.426349 0.424642 0.067516 \n", + "8 0.003441 0.000992 0.189307 0.189448 -0.195204 \n", + "9 0.003403 0.000143 0.059446 0.059303 0.125187 \n", "\n", " f_2_des f_1_des \n", - "0 0.487136 0.999879 \n", - "1 0.419397 0.996782 \n", - "2 0.532832 0.999937 \n", - "3 0.550252 0.808412 \n", - "4 0.385587 0.995813 \n", - "5 0.520870 0.999999 \n", - "6 0.456026 0.999377 \n", - "7 0.730489 0.949449 \n", - "8 0.483644 0.998391 \n", - "9 0.486211 0.999999 \n", + "0 0.407170 0.979202 \n", + "1 0.502642 0.998734 \n", + "2 0.419572 0.998181 \n", + "3 0.371177 0.999833 \n", + "4 0.468363 0.999998 \n", + "5 0.394113 0.999519 \n", + "6 0.390422 0.996731 \n", + "7 0.481535 0.985387 \n", + "8 0.535801 0.999486 \n", + "9 0.450279 0.999934 \n", "\n", "[10 rows x 21 columns]" ] @@ -1045,7 +1043,7 @@ " \n", " 0\n", " acceptable\n", - " acceptable\n", + " unacceptable\n", " \n", " \n", " 1\n", @@ -1065,7 +1063,7 @@ " \n", " 4\n", " acceptable\n", - " acceptable\n", + " unacceptable\n", " \n", " \n", " 5\n", @@ -1075,22 +1073,22 @@ " \n", " 6\n", " acceptable\n", - " acceptable\n", + " ideal\n", " \n", " \n", " 7\n", " acceptable\n", - " acceptable\n", + " unacceptable\n", " \n", " \n", " 8\n", " acceptable\n", - " acceptable\n", + " unacceptable\n", " \n", " \n", " 9\n", " acceptable\n", - " ideal\n", + " unacceptable\n", " \n", " \n", "\n", @@ -1098,16 +1096,16 @@ ], "text/plain": [ " f_1_pred f_1_true\n", - "0 acceptable acceptable\n", + "0 acceptable unacceptable\n", "1 acceptable unacceptable\n", "2 acceptable ideal\n", "3 acceptable unacceptable\n", - "4 acceptable acceptable\n", + "4 acceptable unacceptable\n", "5 acceptable ideal\n", - "6 acceptable acceptable\n", - "7 acceptable acceptable\n", - "8 acceptable acceptable\n", - "9 acceptable ideal" + "6 acceptable ideal\n", + "7 acceptable unacceptable\n", + "8 acceptable unacceptable\n", + "9 acceptable unacceptable" ] }, "execution_count": 10, From 6cb4676c9423508d384a3f2434ed33c089ec5b4d Mon Sep 17 00:00:00 2001 From: gmancino Date: Thu, 1 Feb 2024 10:12:56 -0500 Subject: [PATCH 18/31] More type fixes --- bofire/surrogates/mlp.py | 6 ++---- bofire/surrogates/surrogate.py | 8 ++++---- bofire/surrogates/trainable.py | 8 ++++---- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/bofire/surrogates/mlp.py b/bofire/surrogates/mlp.py index e843cd518..943da2402 100644 --- a/bofire/surrogates/mlp.py +++ b/bofire/surrogates/mlp.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Literal, Optional, Sequence, Union +from typing import Literal, Optional, Sequence import numpy as np import pandas as pd @@ -136,9 +136,7 @@ def fit_mlp( lr: float = 1e-4, shuffle: bool = True, weight_decay: float = 0.0, - loss_function: Union[ - nn.modules.loss.L1Loss, nn.modules.loss.CrossEntropyLoss - ] = nn.L1Loss, # type: ignore + loss_function=nn.L1Loss, # type: ignore ): """Fit a MLP to a dataset. diff --git a/bofire/surrogates/surrogate.py b/bofire/surrogates/surrogate.py index 06ecfe670..98d090963 100644 --- a/bofire/surrogates/surrogate.py +++ b/bofire/surrogates/surrogate.py @@ -51,11 +51,11 @@ def predict(self, X: pd.DataFrame) -> pd.DataFrame: for featkey in self.outputs.get_keys(CategoricalOutput): pred_cols = pred_cols + [ f"{featkey}_{cat}_prob" - for cat in self.outputs.get_by_key(featkey).categories + for cat in self.outputs.get_by_key(featkey).categories # type: ignore ] sd_cols = sd_cols + [ f"{featkey}_{cat}_sd" - for cat in self.outputs.get_by_key(featkey).categories + for cat in self.outputs.get_by_key(featkey).categories # type: ignore ] for featkey in self.outputs.get_keys(ContinuousOutput): pred_cols = pred_cols + [f"{featkey}_pred"] @@ -95,11 +95,11 @@ def validate_predictions(self, predictions: pd.DataFrame) -> pd.DataFrame: + [f"{featkey}_{t}" for t in ["pred", "sd"]] + [ f"{featkey}_{cat}_prob" - for cat in self.outputs.get_by_key(featkey).categories + for cat in self.outputs.get_by_key(featkey).categories # type: ignore ] + [ f"{featkey}_{cat}_sd" - for cat in self.outputs.get_by_key(featkey).categories + for cat in self.outputs.get_by_key(featkey).categories # type: ignore ] ) check_columns = check_columns + [ diff --git a/bofire/surrogates/trainable.py b/bofire/surrogates/trainable.py index 2c66a26b1..4f180c441 100644 --- a/bofire/surrogates/trainable.py +++ b/bofire/surrogates/trainable.py @@ -186,16 +186,16 @@ def cross_validate( # Convert to categorical if applicable if isinstance(self.outputs.get_by_key(key).objective, ConstrainedCategoricalObjective): # type: ignore y_test_pred[f"{key}_pred"] = y_test_pred[f"{key}_pred"].map( - self.outputs.get_by_key(key).objective.to_dict_label() + self.outputs.get_by_key(key).objective.to_dict_label() # type: ignore ) y_train_pred[f"{key}_pred"] = y_train_pred[f"{key}_pred"].map( - self.outputs.get_by_key(key).objective.to_dict_label() + self.outputs.get_by_key(key).objective.to_dict_label() # type: ignore ) y_test[key] = y_test[key].map( - self.outputs.get_by_key(key).objective.to_dict_label() + self.outputs.get_by_key(key).objective.to_dict_label() # type: ignore ) y_train[key] = y_train[key].map( - self.outputs.get_by_key(key).objective.to_dict_label() + self.outputs.get_by_key(key).objective.to_dict_label() # type: ignore ) # now store the results From 93d567576a7d567960b991ffdc8b33841c7de74a Mon Sep 17 00:00:00 2001 From: gmancino Date: Thu, 1 Feb 2024 10:18:22 -0500 Subject: [PATCH 19/31] Fix MLP loss function issue --- bofire/surrogates/mlp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bofire/surrogates/mlp.py b/bofire/surrogates/mlp.py index 943da2402..02c256962 100644 --- a/bofire/surrogates/mlp.py +++ b/bofire/surrogates/mlp.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Literal, Optional, Sequence +from typing import Literal, Optional, Sequence, Union import numpy as np import pandas as pd @@ -136,7 +136,7 @@ def fit_mlp( lr: float = 1e-4, shuffle: bool = True, weight_decay: float = 0.0, - loss_function=nn.L1Loss, # type: ignore + loss_function: nn.modules.loss = nn.L1Loss, # type: ignore ): """Fit a MLP to a dataset. @@ -307,7 +307,7 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame): lr=self.lr, shuffle=self.shuffle, weight_decay=self.weight_decay, - loss_function=nn.CrossEntropyLoss, # utilizes logits as input + loss_function=nn.CrossEntropyLoss, # type: ignore ) mlps.append(mlp) self.model = _MLPEnsemble(mlps=mlps) From 3ca1d3481b8580633d226dd9bb526403b0108acb Mon Sep 17 00:00:00 2001 From: gmancino Date: Thu, 1 Feb 2024 10:19:46 -0500 Subject: [PATCH 20/31] Formatting --- bofire/surrogates/mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bofire/surrogates/mlp.py b/bofire/surrogates/mlp.py index 02c256962..854b4e018 100644 --- a/bofire/surrogates/mlp.py +++ b/bofire/surrogates/mlp.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Literal, Optional, Sequence, Union +from typing import Literal, Optional, Sequence import numpy as np import pandas as pd From 5ed04e95ead0599ca052ddde722b30592e5adc22 Mon Sep 17 00:00:00 2001 From: gmancino Date: Thu, 1 Feb 2024 10:27:53 -0500 Subject: [PATCH 21/31] Format MLP loss function --- bofire/surrogates/mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bofire/surrogates/mlp.py b/bofire/surrogates/mlp.py index 854b4e018..2c2926534 100644 --- a/bofire/surrogates/mlp.py +++ b/bofire/surrogates/mlp.py @@ -136,7 +136,7 @@ def fit_mlp( lr: float = 1e-4, shuffle: bool = True, weight_decay: float = 0.0, - loss_function: nn.modules.loss = nn.L1Loss, # type: ignore + loss_function=nn.L1Loss, # type: ignore ): """Fit a MLP to a dataset. From 417fb0f9c4f1f7439bc1849ea53140038580bee7 Mon Sep 17 00:00:00 2001 From: gmancino Date: Thu, 1 Feb 2024 13:27:49 -0500 Subject: [PATCH 22/31] Type checking fix --- bofire/data_models/features/categorical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bofire/data_models/features/categorical.py b/bofire/data_models/features/categorical.py index 226882157..8107a8c59 100644 --- a/bofire/data_models/features/categorical.py +++ b/bofire/data_models/features/categorical.py @@ -398,7 +398,7 @@ def validate_objectives_unique(self): Returns: Tuple[str]: Tuple of the categories """ - if self.objective.categories != self.categories: + if self.objective.categories != self.categories: # type: ignore raise ValueError("categories must match to objective categories") return self From dc7dac0c91ae9be068781b697ed2075a3f092a91 Mon Sep 17 00:00:00 2001 From: gmancino Date: Thu, 8 Feb 2024 11:30:44 -0500 Subject: [PATCH 23/31] Start fixes --- bofire/data_models/features/categorical.py | 2 +- bofire/data_models/objectives/categorical.py | 27 +- bofire/data_models/surrogates/api.py | 16 + bofire/surrogates/mlp.py | 2 - tests/bofire/data_models/specs/objectives.py | 13 + tests/bofire/data_models/specs/outputs.py | 15 +- .../Unknown_Constraint_Classification.ipynb | 2110 ++++++++++++----- 7 files changed, 1584 insertions(+), 601 deletions(-) diff --git a/bofire/data_models/features/categorical.py b/bofire/data_models/features/categorical.py index 8107a8c59..5c686e3d8 100644 --- a/bofire/data_models/features/categorical.py +++ b/bofire/data_models/features/categorical.py @@ -369,7 +369,7 @@ class CategoricalOutput(Output): @field_validator("categories") @classmethod - def validate_categories_unique(cls, categories: List[str]) -> List["str"]: + def validate_categories_unique(cls, categories: List[str]) -> List[str]: """validates that categories have unique names Args: diff --git a/bofire/data_models/objectives/categorical.py b/bofire/data_models/objectives/categorical.py index 960a481d7..870e76e68 100644 --- a/bofire/data_models/objectives/categorical.py +++ b/bofire/data_models/objectives/categorical.py @@ -36,19 +36,40 @@ class ConstrainedCategoricalObjective( type: Literal["ConstrainedCategoricalObjective"] = "ConstrainedCategoricalObjective" @field_validator( - "desirability", + "categories", ) - def validate_categories_unique(cls, desirability: List[bool], info) -> List[bool]: + def validate_categories_unique(cls, categories: List[str]) -> List[bool]: """validates that desirabilities match the categories Args: categories (List[str]): List or tuple of category names + Raises: + ValueError: when categories are not unique + + Returns: + List[str]: List of categories + """ + if len(categories) != len(set(categories)): + raise ValueError( + "Categories are not unique" + ) + return categories + + @field_validator( + "desirability", + ) + def validate_desirability(cls, desirability: List[bool], info) -> List[bool]: + """validates that desirabilities match the categories + + Args: + desireability (List[str]): List or tuple of desirabilities + Raises: ValueError: when desirability count is not equal to category count Returns: - Tuple[bool]: Tuple of the desirability + List[bool]: List of the desirability """ if len(desirability) != len(info.data["categories"]): raise ValueError( diff --git a/bofire/data_models/surrogates/api.py b/bofire/data_models/surrogates/api.py index e2f46583d..06de3dbde 100644 --- a/bofire/data_models/surrogates/api.py +++ b/bofire/data_models/surrogates/api.py @@ -63,6 +63,22 @@ PolynomialSurrogate, TanimotoGPSurrogate, ] + + AnyRegressionSurrogate = Union[ + EmpiricalSurrogate, + RandomForestSurrogate, + SingleTaskGPSurrogate, + MixedSingleTaskGPSurrogate, + MixedTanimotoGPSurrogate, + RegressionMLPEnsemble, + SaasSingleTaskGPSurrogate, + XGBoostSurrogate, + LinearSurrogate, + PolynomialSurrogate, + TanimotoGPSurrogate, + ] + + AnyClassificationSurrogate = ClassificationMLPEnsemble except ImportError: # with the minimal installationwe don't have botorch pass diff --git a/bofire/surrogates/mlp.py b/bofire/surrogates/mlp.py index 2c2926534..83e9d0529 100644 --- a/bofire/surrogates/mlp.py +++ b/bofire/surrogates/mlp.py @@ -165,8 +165,6 @@ def fit_mlp( loss_function, nn.CrossEntropyLoss ): targets = targets.reshape((targets.shape[0], 1)) - else: - pass # Zero the gradients optimizer.zero_grad() diff --git a/tests/bofire/data_models/specs/objectives.py b/tests/bofire/data_models/specs/objectives.py index 85f306bec..bed2e9902 100644 --- a/tests/bofire/data_models/specs/objectives.py +++ b/tests/bofire/data_models/specs/objectives.py @@ -56,6 +56,7 @@ "eta": 1.0, }, ) + specs.add_invalid( objectives.ConstrainedCategoricalObjective, lambda: { @@ -67,3 +68,15 @@ error=ValueError, message="number of categories differs from number of desirabilities", ) + +specs.add_invalid( + objectives.ConstrainedCategoricalObjective, + lambda: { + "w": 1.0, + "categories": ["green", "red", "blue", "blue"], + "desirability": [True, False, True, False], + "eta": 1.0, + }, + error=ValueError, + message="categories must be unique", +) \ No newline at end of file diff --git a/tests/bofire/data_models/specs/outputs.py b/tests/bofire/data_models/specs/outputs.py index c6445f506..b4ffc60fc 100644 --- a/tests/bofire/data_models/specs/outputs.py +++ b/tests/bofire/data_models/specs/outputs.py @@ -1,5 +1,6 @@ from bofire.data_models.domain.api import Outputs -from bofire.data_models.features.api import CategoricalInput, ContinuousOutput +from bofire.data_models.features.api import CategoricalInput, ContinuousOutput, CategoricalOutput +from bofire.data_models.objectives.api import ConstrainedCategoricalObjective from tests.bofire.data_models.specs.specs import Specs specs = Specs([]) @@ -14,7 +15,6 @@ }, ) - specs.add_invalid( Outputs, lambda: { @@ -37,3 +37,14 @@ error=ValueError, message="Feature keys are not unique.", ) + +specs.add_invalid( + Outputs, + lambda: { + "features": [ + CategoricalOutput(key="b", categories=["a", "b"], objective=ConstrainedCategoricalObjective(categories=["c", "d"], desirability=[True, True])), + ], + }, + error=ValueError, + message="categories must match to objective categories", +) diff --git a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb index fb39ddd4a..7ac093ea9 100644 --- a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb +++ b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb @@ -10,9 +10,7 @@ "\n", "This involves new models that produce `CategoricalOutput`'s rather than continuous outputs. Mathematically, if $g_{\\theta}:\\mathbb{R}^d\\to[0,1]^c$ represents the function governed by learnable parameters $\\theta$ which outputs a probability vector over $c$ potential classes (i.e. for input $x\\in\\mathbb{R}^d$, $g_{\\theta}(x)^\\top\\mathbf{1}=1$ where $\\mathbf{1}$ is the vector of all 1's) and we have acceptibility criteria for the corresponding classes given by $a\\in\\{0,1\\}^c$, we can compute the scalar output $g_{\\theta}(x)^\\top a\\in[0,1]$ which represents the expected value of acceptance as an objective value to be passed in as a constrained function.\n", "\n", - "In this script, we look at a modified and constrained version of the optimization problem associated with the [Levy function](https://www.sfu.ca/~ssurjano/levy.html), which has a global minima at $x^*=\\mathbf{1}$. We classify constraints for three classes: 'acceptable', 'unacceptable', and 'ideal' based on how close we are to the optimal decision variable; obviously, this value is unknown in a real-world setting, but this serves as a reasonable example.\n", - "\n", - "Initially, this script contains an example of JUST training the classification surrogate on the generated data." + "In this script, we look at the [Rosenbrock function constrained to a disk](https://en.wikipedia.org/wiki/Test_functions_for_optimization#cite_note-12) which attains a global minima at $(x_0^*,x_1^*)=(1.0, 1.0)$. To facilitate testing the functionality offered by BoFire, we label all points inside of the circle $x_0^2+x_1^2\\le2$ as 'acceptable' and futher label anything inside of the interesction of this circle and the circle $(x_0-1)^2+(x_1-1)^2\\le1.0$ as 'ideal'; points lying outside of these two locations are labeled as \"unacceptable.\"" ] }, { @@ -51,18 +49,33 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "# Write a function which scales the inputs according to the Levy function - i.e. computes $w_i$\n", - "def scale_inputs(x: pd.Series) -> pd.Series:\n", - " return 1 + (x - 1) / 4" + "# Write helper functions which give the objective and the constraints\n", + "def rosenbrock(x: pd.Series) -> pd.Series:\n", + " assert \"x_0\" in x.columns\n", + " assert \"x_1\" in x.columns\n", + " return (1 - x[\"x_0\"]) ** 2 + 100 * (x[\"x_1\"] - x[\"x_0\"] ** 2) ** 2\n", + "\n", + "def constraints(x: pd.Series) -> pd.Series:\n", + " assert \"x_0\" in x.columns\n", + " assert \"x_1\" in x.columns\n", + " feasiblity_vector = []\n", + " for _, row in x.iterrows():\n", + " if (row[\"x_0\"] ** 2 + row[\"x_1\"] ** 2 <= 2.0) and ((row[\"x_0\"] - 1.0) ** 2 + (row[\"x_1\"] - 1.0) ** 2 <= 1.0):\n", + " feasiblity_vector.append(\"ideal\")\n", + " elif row[\"x_0\"] ** 2 + row[\"x_1\"] ** 2 <= 2.0:\n", + " feasiblity_vector.append(\"acceptable\")\n", + " else:\n", + " feasiblity_vector.append(\"unacceptable\")\n", + " return feasiblity_vector" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 158, "metadata": {}, "outputs": [ { @@ -88,10 +101,7 @@ " \n", " x_0\n", " x_1\n", - " x_2\n", " x_3\n", - " x_4\n", - " x_5\n", " f_0\n", " f_1\n", " f_2\n", @@ -100,302 +110,70 @@ " \n", " \n", " 0\n", - " -0.267947\n", - " 1.021106\n", - " 1.623060\n", - " 0.627026\n", - " 1.206148\n", - " 0\n", - " 1.140700\n", - " ideal\n", - " -0.265670\n", - " \n", - " \n", - " 1\n", - " -0.988891\n", - " -0.747924\n", - " -0.077380\n", - " -1.366909\n", - " 0.122713\n", - " 0\n", - " 4.933869\n", - " unacceptable\n", - " -0.981222\n", - " \n", - " \n", - " 2\n", - " 0.171629\n", - " -0.592879\n", - " 0.188483\n", - " -0.706858\n", - " -0.710738\n", - " 0\n", - " 1.585117\n", - " unacceptable\n", - " 0.173409\n", - " \n", - " \n", - " 3\n", - " -1.745543\n", - " 1.718025\n", - " -0.866392\n", - " -0.236673\n", - " 1.936513\n", + " -0.149953\n", + " -0.704143\n", " 1\n", - " 6.809027\n", - " unacceptable\n", - " -1.740944\n", - " \n", - " \n", - " 4\n", - " 0.461210\n", - " -0.087926\n", - " 0.883033\n", - " -0.319482\n", - " -1.813779\n", - " 1\n", - " 5.272989\n", + " 54.121306\n", " acceptable\n", - " 0.461377\n", - " \n", - " \n", - " 5\n", - " 1.386029\n", - " -0.098083\n", - " 0.174276\n", - " 0.804043\n", - " 0.321967\n", - " 0\n", - " 0.470671\n", - " unacceptable\n", - " 1.391563\n", - " \n", - " \n", - " 6\n", - " -1.946305\n", - " -0.055122\n", - " 1.468864\n", - " -1.514258\n", - " -1.641203\n", - " 0\n", - " 13.303294\n", - " unacceptable\n", - " -1.938876\n", - " \n", - " \n", - " 7\n", - " 0.381642\n", - " 1.842479\n", - " -1.500184\n", - " 0.353710\n", - " 0.508349\n", - " 0\n", - " 3.959427\n", - " ideal\n", - " 0.390396\n", - " \n", - " \n", - " 8\n", - " -0.428826\n", - " -1.958059\n", - " -0.221027\n", - " -1.372467\n", - " 1.275357\n", - " 0\n", - " 9.170777\n", - " ideal\n", - " -0.419537\n", + " 1.009579\n", " \n", " \n", - " 9\n", - " -0.473310\n", - " -0.402426\n", - " -0.608333\n", - " -0.403410\n", - " 1.495413\n", - " 1\n", - " 1.713585\n", - " unacceptable\n", - " -0.465823\n", - " \n", - " \n", - " 10\n", - " -1.961280\n", - " -1.586803\n", - " -0.561849\n", - " 0.064580\n", - " -1.416727\n", + " 1\n", + " -0.625311\n", + " -0.463190\n", " 0\n", - " 12.654743\n", - " acceptable\n", - " -1.958333\n", - " \n", - " \n", - " 11\n", - " -0.235274\n", - " -0.638830\n", - " -0.107924\n", - " 1.769868\n", - " 1.108539\n", - " 1\n", - " 1.582444\n", + " 75.608036\n", " acceptable\n", - " -0.227187\n", + " 0.001819\n", " \n", " \n", - " 12\n", - " 0.617453\n", - " 0.186646\n", - " -0.328716\n", - " 1.212407\n", - " -0.965701\n", - " 0\n", - " 1.254832\n", - " unacceptable\n", - " 0.625534\n", - " \n", - " \n", - " 13\n", - " -0.004329\n", - " -0.521047\n", - " -1.408930\n", - " -0.211512\n", - " -0.027510\n", + " 2\n", + " -0.765853\n", + " 0.927654\n", " 1\n", - " 3.536682\n", - " unacceptable\n", - " -0.004252\n", - " \n", - " \n", - " 14\n", - " 0.473978\n", - " -1.578219\n", - " 0.083191\n", - " -0.357965\n", - " 1.037494\n", - " 0\n", - " 3.895441\n", - " acceptable\n", - " 0.478809\n", - " \n", - " \n", - " 15\n", - " -0.653601\n", - " -0.407954\n", - " -1.939297\n", - " 1.055724\n", - " 0.830087\n", - " 0\n", - " 6.974997\n", - " ideal\n", - " -0.651507\n", - " \n", - " \n", - " 16\n", - " 0.285208\n", - " -1.327075\n", - " 0.705723\n", - " 1.284219\n", - " -0.324223\n", - " 0\n", - " 2.739220\n", + " 14.754710\n", " acceptable\n", - " 0.286143\n", + " 1.006574\n", " \n", " \n", - " 17\n", - " 0.462940\n", - " -1.584694\n", - " 1.114151\n", - " 0.626237\n", - " -1.071201\n", + " 3\n", + " -1.447047\n", + " -0.059688\n", " 1\n", - " 4.968933\n", - " ideal\n", - " 0.466059\n", - " \n", - " \n", - " 18\n", - " 1.746916\n", - " 1.093672\n", - " 1.004508\n", - " -1.620785\n", - " -1.880987\n", - " 0\n", - " 9.614141\n", + " 469.801324\n", " unacceptable\n", - " 1.748456\n", + " 1.002428\n", " \n", " \n", - " 19\n", - " 1.276677\n", - " -0.126877\n", - " -1.214876\n", - " -0.465293\n", - " -0.207115\n", + " 4\n", + " -0.540554\n", + " 1.090780\n", " 1\n", - " 2.140485\n", - " unacceptable\n", - " 1.283395\n", + " 66.146436\n", + " acceptable\n", + " 1.004633\n", " \n", " \n", "\n", "" ], "text/plain": [ - " x_0 x_1 x_2 x_3 x_4 x_5 f_0 \\\n", - "0 -0.267947 1.021106 1.623060 0.627026 1.206148 0 1.140700 \n", - "1 -0.988891 -0.747924 -0.077380 -1.366909 0.122713 0 4.933869 \n", - "2 0.171629 -0.592879 0.188483 -0.706858 -0.710738 0 1.585117 \n", - "3 -1.745543 1.718025 -0.866392 -0.236673 1.936513 1 6.809027 \n", - "4 0.461210 -0.087926 0.883033 -0.319482 -1.813779 1 5.272989 \n", - "5 1.386029 -0.098083 0.174276 0.804043 0.321967 0 0.470671 \n", - "6 -1.946305 -0.055122 1.468864 -1.514258 -1.641203 0 13.303294 \n", - "7 0.381642 1.842479 -1.500184 0.353710 0.508349 0 3.959427 \n", - "8 -0.428826 -1.958059 -0.221027 -1.372467 1.275357 0 9.170777 \n", - "9 -0.473310 -0.402426 -0.608333 -0.403410 1.495413 1 1.713585 \n", - "10 -1.961280 -1.586803 -0.561849 0.064580 -1.416727 0 12.654743 \n", - "11 -0.235274 -0.638830 -0.107924 1.769868 1.108539 1 1.582444 \n", - "12 0.617453 0.186646 -0.328716 1.212407 -0.965701 0 1.254832 \n", - "13 -0.004329 -0.521047 -1.408930 -0.211512 -0.027510 1 3.536682 \n", - "14 0.473978 -1.578219 0.083191 -0.357965 1.037494 0 3.895441 \n", - "15 -0.653601 -0.407954 -1.939297 1.055724 0.830087 0 6.974997 \n", - "16 0.285208 -1.327075 0.705723 1.284219 -0.324223 0 2.739220 \n", - "17 0.462940 -1.584694 1.114151 0.626237 -1.071201 1 4.968933 \n", - "18 1.746916 1.093672 1.004508 -1.620785 -1.880987 0 9.614141 \n", - "19 1.276677 -0.126877 -1.214876 -0.465293 -0.207115 1 2.140485 \n", - "\n", - " f_1 f_2 \n", - "0 ideal -0.265670 \n", - "1 unacceptable -0.981222 \n", - "2 unacceptable 0.173409 \n", - "3 unacceptable -1.740944 \n", - "4 acceptable 0.461377 \n", - "5 unacceptable 1.391563 \n", - "6 unacceptable -1.938876 \n", - "7 ideal 0.390396 \n", - "8 ideal -0.419537 \n", - "9 unacceptable -0.465823 \n", - "10 acceptable -1.958333 \n", - "11 acceptable -0.227187 \n", - "12 unacceptable 0.625534 \n", - "13 unacceptable -0.004252 \n", - "14 acceptable 0.478809 \n", - "15 ideal -0.651507 \n", - "16 acceptable 0.286143 \n", - "17 ideal 0.466059 \n", - "18 unacceptable 1.748456 \n", - "19 unacceptable 1.283395 " + " x_0 x_1 x_3 f_0 f_1 f_2\n", + "0 -0.149953 -0.704143 1 54.121306 acceptable 1.009579\n", + "1 -0.625311 -0.463190 0 75.608036 acceptable 0.001819\n", + "2 -0.765853 0.927654 1 14.754710 acceptable 1.006574\n", + "3 -1.447047 -0.059688 1 469.801324 unacceptable 1.002428\n", + "4 -0.540554 1.090780 1 66.146436 acceptable 1.004633" ] }, - "execution_count": 3, + "execution_count": 158, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Set-up the inputs and outputs, use categorical domain just as an example\n", - "input_features = Inputs(features=[ContinuousInput(key=f\"x_{i}\", bounds=(-2, 2)) for i in range(5)] + [CategoricalInput(key=f\"x_5\", categories=[\"0\", \"1\"], allowed=[True, True])])\n", + "input_features = Inputs(features=[ContinuousInput(key=f\"x_{i}\", bounds=(-1.75, 1.75)) for i in range(2)] + [CategoricalInput(key=f\"x_3\", categories=[\"0\", \"1\"], allowed=[True, True])])\n", "\n", "# here the minimize objective is used, if you want to maximize you have to use the maximize objective.\n", "output_features = Outputs(features=[\n", @@ -409,16 +187,1190 @@ "domain1 = Domain(inputs=input_features, outputs=output_features)\n", "\n", "# Sample random points\n", - "sample_df = domain1.inputs.sample(50)\n", + "sample_df = domain1.inputs.sample(100)\n", "\n", "# Write a function which outputs one continuous variable and another discrete based on some logic\n", - "sample_df[\"f_0\"] = np.sin(np.pi * scale_inputs(sample_df[\"x_0\"])) ** 2 + sum([(scale_inputs(sample_df[col]) - 1) ** 2 * (1 + 10 * np.sin(np.pi * scale_inputs(sample_df[col]) + 1) ** 2 if ind < len(sample_df.columns) else 1 + np.sin(2 * np.pi * scale_inputs(sample_df[col])) ** 2) for ind, col in enumerate(sample_df.columns) if not sample_df[col].dtype == \"O\"])\n", - "sample_df[\"f_1\"] = \"unacceptable\"\n", - "sample_df.loc[(sample_df[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) <= 6.5) * (sample_df[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) >= 3.5), \"f_1\"] = \"acceptable\"\n", - "sample_df.loc[(sample_df[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) <= 5.5) * (sample_df[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) >= 4.5), \"f_1\"] = \"ideal\"\n", - "sample_df[\"f_2\"] = sample_df[\"x_0\"] + 1e-2 * np.random.uniform(size=(len(sample_df),))\n", - "\n", - "sample_df.head(20)" + "sample_df[\"f_0\"] = rosenbrock(x=sample_df)\n", + "sample_df[\"f_1\"] = constraints(x=sample_df)\n", + "sample_df[\"f_2\"] = sample_df[\"x_3\"].astype(float) + 1e-2 * np.random.uniform(size=(len(sample_df),))\n", + "sample_df.head(5)" + ] + }, + { + "cell_type": "code", + "execution_count": 159, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "hovertemplate": "f_1=acceptable
x_0=%{x}
x_1=%{y}", + "legendgroup": "acceptable", + "marker": { + "color": "#636efa", + "symbol": "circle" + }, + "mode": "markers", + "name": "acceptable", + "orientation": "v", + "showlegend": true, + "type": "scatter", + "x": [ + -0.14995258503883457, + -0.6253111311588586, + -0.765853028108006, + -0.5405541311965942, + 0.9894994728281232, + -0.865978521372166, + -1.0217962106629315, + -0.4014156739458248, + -0.12059424146260334, + -1.0084908064927265, + -0.9349526862095585, + 0.9458316530354285, + -0.5320149984993212, + -0.4689036742689612, + -0.07833150213004725, + -0.9793104434929384, + 0.5788385198054735, + -0.9481124505558867, + -1.1878657259442678, + 0.7035953005643529, + 0.2579234370307777, + -1.2867733344241161, + -0.3218188759641296, + -0.29215910785338806, + -0.8088505843334085, + -0.16053526753142844, + 0.6219324521230791, + 0.03371532152549417, + -0.18044427816549646, + 0.6672898005757197, + 0.36287033501092125, + -0.5410845636765149, + 0.6184934671940714, + 0.79958184162304, + 0.6854423654922401, + -0.32974185478696016, + -0.5158328238096581, + -0.8267714092240124, + -0.9304400399308963, + 0.29751211472578376, + -1.0674027719621937, + -0.7640154931359711 + ], + "xaxis": "x", + "y": [ + -0.7041428419544369, + -0.4631897120601236, + 0.9276536133420858, + 1.0907795664450228, + -0.20189922968971485, + 0.8873233761755905, + 0.7368573002639325, + 0.7110972468528676, + -0.43418666628890823, + 0.08216578094828186, + -0.5657310098005697, + -1.0505195679349209, + 0.4106824576035617, + -1.1004617179921463, + 0.9644868021314426, + 0.8813719192174521, + 0.006476845776403506, + 0.2047160309884497, + -0.48721762585284756, + -0.8694642189151136, + -0.8858761844683214, + -0.26450926092801574, + 0.7700394750617825, + 0.7931467073831526, + -1.099688304768781, + 1.2163564957070103, + -0.021834260403785777, + 0.5629070883753866, + -0.8868851959550068, + -1.2323026418478094, + -0.9954921557404812, + -0.7890090665021279, + -1.1629674424244054, + -0.4567118757180961, + -0.634972965145935, + 0.6601685944241988, + -0.6391884645670298, + -0.9719722439473206, + -0.744901749184502, + -0.9803760720872019, + -0.5942335372991383, + 0.23621357099645968 + ], + "yaxis": "y" + }, + { + "hovertemplate": "f_1=unacceptable
x_0=%{x}
x_1=%{y}", + "legendgroup": "unacceptable", + "marker": { + "color": "#EF553B", + "symbol": "circle" + }, + "mode": "markers", + "name": "unacceptable", + "orientation": "v", + "showlegend": true, + "type": "scatter", + "x": [ + -1.4470467658262591, + -1.6009504802388597, + 1.4204895481210702, + 1.3489227092112688, + -0.639284790399812, + 1.575648155774772, + 1.7270515387294698, + -1.1378584773675158, + -0.4870902364848939, + 1.3433379011252464, + 1.6525833048685383, + 0.052856077835125026, + -1.718194617060083, + 1.2519082611394308, + -1.3754185446920677, + -0.18322045819850907, + 1.327292051993207, + 1.0733856862440985, + 0.9008441174905695, + 0.47167063168201295, + -1.28379343444775, + -1.225041501040866, + -0.9879804926500095, + -1.7123070859149068, + -0.8222878848033226, + 1.698340355837336, + -0.40510921927474874, + -1.139666414607729, + 1.669256729311793, + -0.5749235839074829, + 1.3532335670733184, + 1.410971842195503, + 1.2552385529823566, + -1.0402317158880352, + -1.400659402027841, + 1.6382339166896722, + -1.3547647166845345, + 1.1865185162807816, + -1.4333442355901473, + 1.646587754792467, + 0.582077673409676, + 1.380227103874569, + 1.3622158395068897, + 1.1921850110333656, + -1.2624798548187615, + 1.70645818280738 + ], + "xaxis": "x", + "y": [ + -0.05968813763014724, + -0.08203263567781471, + -1.7256511732407582, + -1.3062738654746688, + -1.380758573136951, + 1.3698035388942658, + -0.14646898218101168, + 1.1472779541835205, + -1.546055723303492, + 1.2777644377843878, + -1.0947398765543759, + -1.691760325948563, + 0.2847825778795654, + -1.7156770939939636, + -1.3464909491716763, + 1.689598233686933, + -0.7777149658959187, + -1.457321100538615, + 1.7237444956486576, + -1.7222531301544601, + 1.4804621415248884, + -1.0312548094743161, + 1.464230692511908, + 1.727604973119779, + -1.4420228070981478, + 1.026409514853559, + 1.4523235910972625, + 1.6510720743055476, + 0.5237495796359455, + 1.6473828157610697, + -0.6903014501907374, + 1.3866811498163663, + -1.1832663015708027, + 1.3995479124379653, + 0.4939162310211662, + 1.6967495852807661, + 0.4613230281774947, + -1.1777089751033198, + 0.09067972703617566, + 0.5611805367913849, + -1.5872277904858516, + 0.39999239270670284, + -0.6292457424969355, + 1.2082895149675292, + -1.6847143997552474, + -1.454374296036075 + ], + "yaxis": "y" + }, + { + "hovertemplate": "f_1=ideal
x_0=%{x}
x_1=%{y}", + "legendgroup": "ideal", + "marker": { + "color": "#00cc96", + "symbol": "circle" + }, + "mode": "markers", + "name": "ideal", + "orientation": "v", + "showlegend": true, + "type": "scatter", + "x": [ + 0.9061821719090513, + 0.5162826776143459, + 0.5094335938756926, + 0.6859698250145936, + 0.5823600872410029, + 1.0858917598939684, + 1.3205055890138766, + 0.14878500290037744, + 0.7581455463744304, + 1.1678712726745357, + 0.6653414272221312, + 0.2300548458979994 + ], + "xaxis": "x", + "y": [ + 0.886718714624303, + 0.39576812610714196, + 0.5122034490809124, + 0.7226512509233372, + 0.7116114473585995, + 0.7731429665448246, + 0.40197976993092377, + 0.5202938254854113, + 0.24505092339428747, + 0.5559493377544111, + 0.28496872065832335, + 0.8280874889772782 + ], + "yaxis": "y" + } + ], + "layout": { + "height": 525, + "legend": { + "title": { + "text": "f_1" + }, + "tracegroupgap": 0 + }, + "shapes": [ + { + "fillcolor": "red", + "line": { + "color": "red" + }, + "opacity": 0.1, + "type": "circle", + "x0": -1.4142135623730951, + "x1": 1.4142135623730951, + "xref": "x", + "y0": -1.4142135623730951, + "y1": 1.4142135623730951, + "yref": "y" + }, + { + "fillcolor": "LightSeaGreen", + "line": { + "color": "LightSeaGreen" + }, + "opacity": 0.2, + "type": "circle", + "x0": 0, + "x1": 2, + "xref": "x", + "y0": 0, + "y1": 2, + "yref": "y" + } + ], + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Samples with labels" + }, + "width": 550, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "title": { + "text": "x_0" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "title": { + "text": "x_1" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot the sample df\n", + "import math\n", + "import plotly.express as px \n", + "fig = px.scatter(sample_df, x=\"x_0\", y=\"x_1\", color=\"f_1\", width=550, height=525, title=\"Samples with labels\")\n", + "fig.add_shape(type=\"circle\",\n", + " xref=\"x\", yref=\"y\",\n", + " opacity=0.1,\n", + " fillcolor=\"red\",\n", + " x0=-math.sqrt(2), y0=-math.sqrt(2), x1=math.sqrt(2), y1=math.sqrt(2),\n", + " line_color=\"red\",\n", + ")\n", + "fig.add_shape(type=\"circle\",\n", + " xref=\"x\", yref=\"y\",\n", + " opacity=0.2,\n", + " fillcolor=\"LightSeaGreen\",\n", + " x0=0, y0=0, x1=2, y1=2,\n", + " line_color=\"LightSeaGreen\",\n", + ")\n", + "fig.show()" ] }, { @@ -430,33 +1382,37 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 163, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n" + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning:\n", + "\n", + "Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + "\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning:\n", + "\n", + "Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + "\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning:\n", + "\n", + "Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + "\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning:\n", + "\n", + "Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + "\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning:\n", + "\n", + "Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + "\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning:\n", + "\n", + "Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + "\n" ] } ], @@ -467,18 +1423,18 @@ "from bofire.surrogates.diagnostics import ClassificationMetricsEnum\n", "\n", "# Instantiate the surrogate model \n", - "model = ClassificationMLPEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_1\")]), lr=0.01, n_epochs=100, hidden_layer_sizes=(20,10,))\n", + "model = ClassificationMLPEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_1\")]), lr=0.03, n_epochs=100, hidden_layer_sizes=(4,2,), weight_decay=0.0, batch_size=10, activation=\"tanh\")\n", "surrogate = surrogates.map(model)\n", "\n", "# Fit the model to the classification data\n", "cv_df = sample_df.drop([\"f_0\", \"f_2\"], axis=1)\n", "cv_df[\"valid_f_1\"] = 1\n", - "cv = surrogate.cross_validate(cv_df, folds=5)\n" + "cv = surrogate.cross_validate(cv_df, folds=3)\n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 164, "metadata": {}, "outputs": [ { @@ -509,19 +1465,19 @@ " \n", " \n", " 0\n", - " 0.745\n", - " 0.745\n", + " 0.68\n", + " 0.68\n", " \n", " \n", "\n", "" ], "text/plain": [ - " ACCURACY F1\n", - "0 0.745 0.745" + " ACCURACY F1\n", + "0 0.68 0.68" ] }, - "execution_count": 5, + "execution_count": 164, "metadata": {}, "output_type": "execute_result" } @@ -533,7 +1489,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 165, "metadata": {}, "outputs": [ { @@ -564,8 +1520,8 @@ " \n", " \n", " 0\n", - " 0.32\n", - " 0.32\n", + " 0.54\n", + " 0.54\n", " \n", " \n", "\n", @@ -573,10 +1529,10 @@ ], "text/plain": [ " ACCURACY F1\n", - "0 0.32 0.32" + "0 0.54 0.54" ] }, - "execution_count": 6, + "execution_count": 165, "metadata": {}, "output_type": "execute_result" } @@ -595,7 +1551,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 166, "metadata": {}, "outputs": [], "source": [ @@ -608,7 +1564,7 @@ " acquisition_function=qEI(), \n", " surrogate_specs=BotorchSurrogates(surrogates=\n", " [\n", - " ClassificationMLPEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_1\")]), lr=0.01, n_epochs=100, hidden_layer_sizes=(20,10,)),\n", + " ClassificationMLPEnsemble(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_1\")]), lr=0.03, n_epochs=100, hidden_layer_sizes=(4,2,), weight_decay=0.0, batch_size=10, activation=\"tanh\"),\n", " MixedSingleTaskGPSurrogate(inputs=domain1.inputs, outputs=Outputs(features=[domain1.outputs.get_by_key(\"f_2\")]))\n", " ]\n", " )\n", @@ -621,19 +1577,25 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 167, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n", - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning: Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", - " warnings.warn(\n" + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning:\n", + "\n", + "Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + "\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning:\n", + "\n", + "Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + "\n", + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\models\\model.py:230: RuntimeWarning:\n", + "\n", + "Could not update `train_inputs` with transformed inputs since _MLPEnsemble does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.\n", + "\n" ] }, { @@ -659,15 +1621,12 @@ " \n", " x_0\n", " x_1\n", - " x_2\n", " x_3\n", - " x_4\n", - " x_5\n", " f_1_pred\n", " f_1_sd\n", " f_0_pred\n", " f_2_pred\n", - " ...\n", + " f_1_unacceptable_prob\n", " f_1_acceptable_prob\n", " f_1_ideal_prob\n", " f_0_sd\n", @@ -683,302 +1642,269 @@ " \n", " \n", " 0\n", - " 0.744661\n", - " -0.014262\n", - " -0.451975\n", - " 2.000000\n", - " 0.264316\n", - " 1\n", + " 0.403369\n", + " 0.161374\n", + " 0\n", " acceptable\n", " 0.0\n", - " -0.794483\n", - " 0.751355\n", - " ...\n", - " 0.781416\n", - " 0.197786\n", - " 0.749730\n", - " 0.003478\n", - " 0.011921\n", - " 0.436942\n", - " 0.438615\n", - " 0.794483\n", - " 0.407170\n", - " 0.979202\n", + " 0.092455\n", + " 0.004524\n", + " 0.399619\n", + " 0.599358\n", + " 0.001023\n", + " 2.192848\n", + " 0.002932\n", + " 0.542024\n", + " 0.541092\n", + " 0.001150\n", + " -0.092455\n", + " 0.499435\n", + " 0.600381\n", " \n", " \n", " 1\n", - " -0.025971\n", - " 0.538385\n", - " 0.445088\n", - " -0.029734\n", - " 0.207444\n", + " 0.291096\n", + " 0.091716\n", " 1\n", " acceptable\n", " 0.0\n", - " -0.061202\n", - " -0.021137\n", - " ...\n", - " 0.878746\n", - " 0.119988\n", - " 0.471483\n", - " 0.003391\n", - " 0.001982\n", - " 0.270151\n", - " 0.268299\n", - " 0.061202\n", - " 0.502642\n", - " 0.998734\n", + " 0.338534\n", + " 1.005305\n", + " 0.112541\n", + " 0.886353\n", + " 0.001106\n", + " 2.772511\n", + " 0.002915\n", + " 0.197303\n", + " 0.196632\n", + " 0.001800\n", + " -0.338534\n", + " 0.376918\n", + " 0.887459\n", " \n", " \n", " 2\n", - " 0.643000\n", - " 0.170813\n", - " 2.000000\n", - " 2.000000\n", - " 0.250742\n", + " 1.319030\n", + " 1.750000\n", " 1\n", - " acceptable\n", + " unacceptable\n", " 0.0\n", - " -0.536501\n", - " 0.649062\n", - " ...\n", - " 0.947750\n", - " 0.050431\n", - " 1.069920\n", - " 0.003621\n", - " 0.004066\n", - " 0.116790\n", - " 0.112724\n", - " 0.536501\n", - " 0.419572\n", - " 0.998181\n", + " -2.338611\n", + " 1.003526\n", + " 0.688225\n", + " 0.110699\n", + " 0.201075\n", + " 4.714258\n", + " 0.003410\n", + " 0.449328\n", + " 0.240964\n", + " 0.445632\n", + " 2.338611\n", + " 0.377126\n", + " 0.311775\n", " \n", " \n", " 3\n", - " 1.046492\n", - " 1.009127\n", - " 1.693894\n", - " 2.000000\n", - " 2.000000\n", - " 1\n", + " 0.086516\n", + " -0.000797\n", + " 0\n", " acceptable\n", " 0.0\n", - " 1.483338\n", - " 1.054340\n", - " ...\n", - " 0.799970\n", - " 0.199863\n", - " 1.958267\n", - " 0.003926\n", - " 0.000285\n", - " 0.447182\n", - " 0.446908\n", - " -1.483338\n", - " 0.371177\n", - " 0.999833\n", + " 0.900832\n", + " 0.004706\n", + " 0.398625\n", + " 0.600412\n", + " 0.000963\n", + " 2.342997\n", + " 0.002935\n", + " 0.542560\n", + " 0.541685\n", + " 0.001131\n", + " -0.900832\n", + " 0.499412\n", + " 0.601375\n", " \n", " \n", " 4\n", - " 0.249543\n", - " 0.448522\n", - " 2.000000\n", - " -0.021518\n", - " 0.143814\n", - " 1\n", + " -0.228315\n", + " 0.046514\n", + " 0\n", " acceptable\n", " 0.0\n", - " -0.015632\n", - " 0.253437\n", - " ...\n", - " 0.997194\n", - " 0.002804\n", - " 0.752782\n", - " 0.003477\n", - " 0.000004\n", - " 0.006273\n", - " 0.006269\n", - " 0.015632\n", - " 0.468363\n", - " 0.999998\n", + " 1.500286\n", + " 0.004834\n", + " 0.397428\n", + " 0.601626\n", + " 0.000946\n", + " 2.390450\n", + " 0.002935\n", + " 0.541194\n", + " 0.540336\n", + " 0.001115\n", + " -1.500286\n", + " 0.499396\n", + " 0.602572\n", " \n", " \n", " 5\n", - " 0.854194\n", - " 0.464335\n", - " 1.579664\n", - " 0.782255\n", - " 1.068275\n", + " 0.076667\n", + " -0.004871\n", " 1\n", " acceptable\n", " 0.0\n", - " 0.071247\n", - " 0.860109\n", - " ...\n", - " 0.801522\n", - " 0.197997\n", - " 0.378974\n", - " 0.003551\n", - " 0.001074\n", - " 0.443808\n", - " 0.442734\n", - " -0.071247\n", - " 0.394113\n", - " 0.999519\n", + " 0.979306\n", + " 1.005360\n", + " 0.198561\n", + " 0.800833\n", + " 0.000606\n", + " 2.372284\n", + " 0.002916\n", + " 0.237685\n", + " 0.237178\n", + " 0.000844\n", + " -0.979306\n", + " 0.376911\n", + " 0.801439\n", " \n", " \n", " 6\n", - " 0.884042\n", - " 0.406672\n", - " 1.494174\n", - " 2.000000\n", - " 0.157301\n", - " 1\n", + " 0.308219\n", + " 0.087674\n", + " 0\n", " acceptable\n", " 0.0\n", - " -0.616233\n", - " 0.891074\n", - " ...\n", - " 0.860195\n", - " 0.136536\n", - " 0.916442\n", - " 0.003627\n", - " 0.007275\n", - " 0.312579\n", - " 0.305304\n", - " 0.616233\n", - " 0.390422\n", - " 0.996731\n", + " 0.222725\n", + " 0.004599\n", + " 0.399124\n", + " 0.599887\n", + " 0.000989\n", + " 2.834773\n", + " 0.002933\n", + " 0.542427\n", + " 0.541527\n", + " 0.001137\n", + " -0.222725\n", + " 0.499425\n", + " 0.600876\n", " \n", " \n", " 7\n", - " 0.143312\n", - " 0.679986\n", - " -0.116402\n", - " 0.081383\n", - " 0.247956\n", - " 1\n", - " acceptable\n", + " 0.829716\n", + " 0.681121\n", + " 0\n", + " unacceptable\n", " 0.0\n", - " -0.067516\n", - " 0.147784\n", - " ...\n", - " 0.795419\n", - " 0.189967\n", - " 0.470200\n", - " 0.003365\n", - " 0.009657\n", - " 0.426349\n", - " 0.424642\n", - " 0.067516\n", - " 0.481535\n", - " 0.985387\n", + " -0.348330\n", + " 0.003866\n", + " 0.510435\n", + " 0.418242\n", + " 0.071323\n", + " 1.065240\n", + " 0.002941\n", + " 0.496067\n", + " 0.425085\n", + " 0.157438\n", + " 0.348330\n", + " 0.499517\n", + " 0.489565\n", " \n", " \n", " 8\n", - " -0.293373\n", - " 0.876231\n", - " 0.524355\n", - " 0.052708\n", - " 0.143248\n", + " -0.219844\n", + " 0.044418\n", " 1\n", " acceptable\n", " 0.0\n", - " 0.195204\n", - " -0.286897\n", - " ...\n", - " 0.914762\n", - " 0.084724\n", - " 0.575528\n", - " 0.003441\n", - " 0.000992\n", - " 0.189307\n", - " 0.189448\n", - " -0.195204\n", - " 0.535801\n", - " 0.999486\n", + " 1.494160\n", + " 1.005372\n", + " 0.239904\n", + " 0.759543\n", + " 0.000553\n", + " 2.332408\n", + " 0.002915\n", + " 0.304143\n", + " 0.303688\n", + " 0.000761\n", + " -1.494160\n", + " 0.376910\n", + " 0.760096\n", " \n", " \n", " 9\n", - " 0.395178\n", - " 0.358173\n", - " 1.405300\n", - " -0.040535\n", - " 0.223360\n", - " 1\n", + " -0.063892\n", + " 0.013738\n", + " 0\n", " acceptable\n", " 0.0\n", - " -0.125187\n", - " 0.399090\n", - " ...\n", - " 0.973413\n", - " 0.026521\n", - " 0.356968\n", - " 0.003403\n", - " 0.000143\n", - " 0.059446\n", - " 0.059303\n", - " 0.125187\n", - " 0.450279\n", - " 0.999934\n", + " 1.097156\n", + " 0.004757\n", + " 0.398289\n", + " 0.600757\n", + " 0.000954\n", + " 1.653161\n", + " 0.002935\n", + " 0.542274\n", + " 0.541407\n", + " 0.001125\n", + " -1.097156\n", + " 0.499405\n", + " 0.601711\n", " \n", " \n", "\n", - "

10 rows × 21 columns

\n", "" ], "text/plain": [ - " x_0 x_1 x_2 x_3 x_4 x_5 f_1_pred f_1_sd \\\n", - "0 0.744661 -0.014262 -0.451975 2.000000 0.264316 1 acceptable 0.0 \n", - "1 -0.025971 0.538385 0.445088 -0.029734 0.207444 1 acceptable 0.0 \n", - "2 0.643000 0.170813 2.000000 2.000000 0.250742 1 acceptable 0.0 \n", - "3 1.046492 1.009127 1.693894 2.000000 2.000000 1 acceptable 0.0 \n", - "4 0.249543 0.448522 2.000000 -0.021518 0.143814 1 acceptable 0.0 \n", - "5 0.854194 0.464335 1.579664 0.782255 1.068275 1 acceptable 0.0 \n", - "6 0.884042 0.406672 1.494174 2.000000 0.157301 1 acceptable 0.0 \n", - "7 0.143312 0.679986 -0.116402 0.081383 0.247956 1 acceptable 0.0 \n", - "8 -0.293373 0.876231 0.524355 0.052708 0.143248 1 acceptable 0.0 \n", - "9 0.395178 0.358173 1.405300 -0.040535 0.223360 1 acceptable 0.0 \n", + " x_0 x_1 x_3 f_1_pred f_1_sd f_0_pred f_2_pred \\\n", + "0 0.403369 0.161374 0 acceptable 0.0 0.092455 0.004524 \n", + "1 0.291096 0.091716 1 acceptable 0.0 0.338534 1.005305 \n", + "2 1.319030 1.750000 1 unacceptable 0.0 -2.338611 1.003526 \n", + "3 0.086516 -0.000797 0 acceptable 0.0 0.900832 0.004706 \n", + "4 -0.228315 0.046514 0 acceptable 0.0 1.500286 0.004834 \n", + "5 0.076667 -0.004871 1 acceptable 0.0 0.979306 1.005360 \n", + "6 0.308219 0.087674 0 acceptable 0.0 0.222725 0.004599 \n", + "7 0.829716 0.681121 0 unacceptable 0.0 -0.348330 0.003866 \n", + "8 -0.219844 0.044418 1 acceptable 0.0 1.494160 1.005372 \n", + "9 -0.063892 0.013738 0 acceptable 0.0 1.097156 0.004757 \n", "\n", - " f_0_pred f_2_pred ... f_1_acceptable_prob f_1_ideal_prob f_0_sd \\\n", - "0 -0.794483 0.751355 ... 0.781416 0.197786 0.749730 \n", - "1 -0.061202 -0.021137 ... 0.878746 0.119988 0.471483 \n", - "2 -0.536501 0.649062 ... 0.947750 0.050431 1.069920 \n", - "3 1.483338 1.054340 ... 0.799970 0.199863 1.958267 \n", - "4 -0.015632 0.253437 ... 0.997194 0.002804 0.752782 \n", - "5 0.071247 0.860109 ... 0.801522 0.197997 0.378974 \n", - "6 -0.616233 0.891074 ... 0.860195 0.136536 0.916442 \n", - "7 -0.067516 0.147784 ... 0.795419 0.189967 0.470200 \n", - "8 0.195204 -0.286897 ... 0.914762 0.084724 0.575528 \n", - "9 -0.125187 0.399090 ... 0.973413 0.026521 0.356968 \n", + " f_1_unacceptable_prob f_1_acceptable_prob f_1_ideal_prob f_0_sd \\\n", + "0 0.399619 0.599358 0.001023 2.192848 \n", + "1 0.112541 0.886353 0.001106 2.772511 \n", + "2 0.688225 0.110699 0.201075 4.714258 \n", + "3 0.398625 0.600412 0.000963 2.342997 \n", + "4 0.397428 0.601626 0.000946 2.390450 \n", + "5 0.198561 0.800833 0.000606 2.372284 \n", + "6 0.399124 0.599887 0.000989 2.834773 \n", + "7 0.510435 0.418242 0.071323 1.065240 \n", + "8 0.239904 0.759543 0.000553 2.332408 \n", + "9 0.398289 0.600757 0.000954 1.653161 \n", "\n", " f_2_sd f_1_unacceptable_sd f_1_acceptable_sd f_1_ideal_sd f_0_des \\\n", - "0 0.003478 0.011921 0.436942 0.438615 0.794483 \n", - "1 0.003391 0.001982 0.270151 0.268299 0.061202 \n", - "2 0.003621 0.004066 0.116790 0.112724 0.536501 \n", - "3 0.003926 0.000285 0.447182 0.446908 -1.483338 \n", - "4 0.003477 0.000004 0.006273 0.006269 0.015632 \n", - "5 0.003551 0.001074 0.443808 0.442734 -0.071247 \n", - "6 0.003627 0.007275 0.312579 0.305304 0.616233 \n", - "7 0.003365 0.009657 0.426349 0.424642 0.067516 \n", - "8 0.003441 0.000992 0.189307 0.189448 -0.195204 \n", - "9 0.003403 0.000143 0.059446 0.059303 0.125187 \n", + "0 0.002932 0.542024 0.541092 0.001150 -0.092455 \n", + "1 0.002915 0.197303 0.196632 0.001800 -0.338534 \n", + "2 0.003410 0.449328 0.240964 0.445632 2.338611 \n", + "3 0.002935 0.542560 0.541685 0.001131 -0.900832 \n", + "4 0.002935 0.541194 0.540336 0.001115 -1.500286 \n", + "5 0.002916 0.237685 0.237178 0.000844 -0.979306 \n", + "6 0.002933 0.542427 0.541527 0.001137 -0.222725 \n", + "7 0.002941 0.496067 0.425085 0.157438 0.348330 \n", + "8 0.002915 0.304143 0.303688 0.000761 -1.494160 \n", + "9 0.002935 0.542274 0.541407 0.001125 -1.097156 \n", "\n", " f_2_des f_1_des \n", - "0 0.407170 0.979202 \n", - "1 0.502642 0.998734 \n", - "2 0.419572 0.998181 \n", - "3 0.371177 0.999833 \n", - "4 0.468363 0.999998 \n", - "5 0.394113 0.999519 \n", - "6 0.390422 0.996731 \n", - "7 0.481535 0.985387 \n", - "8 0.535801 0.999486 \n", - "9 0.450279 0.999934 \n", - "\n", - "[10 rows x 21 columns]" + "0 0.499435 0.600381 \n", + "1 0.376918 0.887459 \n", + "2 0.377126 0.311775 \n", + "3 0.499412 0.601375 \n", + "4 0.499396 0.602572 \n", + "5 0.376911 0.801439 \n", + "6 0.499425 0.600876 \n", + "7 0.499517 0.489565 \n", + "8 0.376910 0.760096 \n", + "9 0.499405 0.601711 " ] }, - "execution_count": 8, + "execution_count": 167, "metadata": {}, "output_type": "execute_result" } @@ -999,19 +1925,17 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 168, "metadata": {}, "outputs": [], "source": [ "# Append to the candidates\n", - "candidates[\"f_1_true\"] = \"unacceptable\"\n", - "candidates.loc[(candidates[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) <= 6.5) * (candidates[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) >= 3.5), \"f_1_true\"] = \"acceptable\"\n", - "candidates.loc[(candidates[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) <= 5.5) * (candidates[input_features.get_keys(includes=ContinuousInput, excludes=CategoricalInput)].abs().sum(1) >= 4.5), \"f_1_true\"] = \"ideal\"" + "candidates[\"f_1_true\"] = constraints(x=candidates)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 169, "metadata": {}, "outputs": [ { @@ -1043,72 +1967,72 @@ " \n", " 0\n", " acceptable\n", - " unacceptable\n", + " acceptable\n", " \n", " \n", " 1\n", " acceptable\n", - " unacceptable\n", + " acceptable\n", " \n", " \n", " 2\n", - " acceptable\n", - " ideal\n", + " unacceptable\n", + " unacceptable\n", " \n", " \n", " 3\n", " acceptable\n", - " unacceptable\n", + " acceptable\n", " \n", " \n", " 4\n", " acceptable\n", - " unacceptable\n", + " acceptable\n", " \n", " \n", " 5\n", " acceptable\n", - " ideal\n", + " acceptable\n", " \n", " \n", " 6\n", " acceptable\n", - " ideal\n", + " acceptable\n", " \n", " \n", " 7\n", - " acceptable\n", " unacceptable\n", + " ideal\n", " \n", " \n", " 8\n", " acceptable\n", - " unacceptable\n", + " acceptable\n", " \n", " \n", " 9\n", " acceptable\n", - " unacceptable\n", + " acceptable\n", " \n", " \n", "\n", "" ], "text/plain": [ - " f_1_pred f_1_true\n", - "0 acceptable unacceptable\n", - "1 acceptable unacceptable\n", - "2 acceptable ideal\n", - "3 acceptable unacceptable\n", - "4 acceptable unacceptable\n", - "5 acceptable ideal\n", - "6 acceptable ideal\n", - "7 acceptable unacceptable\n", - "8 acceptable unacceptable\n", - "9 acceptable unacceptable" + " f_1_pred f_1_true\n", + "0 acceptable acceptable\n", + "1 acceptable acceptable\n", + "2 unacceptable unacceptable\n", + "3 acceptable acceptable\n", + "4 acceptable acceptable\n", + "5 acceptable acceptable\n", + "6 acceptable acceptable\n", + "7 unacceptable ideal\n", + "8 acceptable acceptable\n", + "9 acceptable acceptable" ] }, - "execution_count": 10, + "execution_count": 169, "metadata": {}, "output_type": "execute_result" } From 3d2465e11a8cadd28deaa7d8f6c0f82675903620 Mon Sep 17 00:00:00 2001 From: gmancino Date: Fri, 9 Feb 2024 17:13:16 -0500 Subject: [PATCH 24/31] Update PR to include type changes and tests --- bofire/data_models/features/categorical.py | 48 +- bofire/data_models/objectives/categorical.py | 42 +- .../data_models/surrogates/fully_bayesian.py | 2 +- bofire/data_models/surrogates/linear.py | 2 +- .../surrogates/mixed_single_task_gp.py | 2 +- .../surrogates/mixed_tanimoto_gp.py | 2 +- bofire/data_models/surrogates/mlp.py | 22 +- bofire/data_models/surrogates/polynomial.py | 2 +- .../data_models/surrogates/random_forest.py | 2 +- .../data_models/surrogates/single_task_gp.py | 2 +- bofire/data_models/surrogates/surrogate.py | 5 +- bofire/data_models/surrogates/tanimoto_gp.py | 2 +- bofire/data_models/surrogates/xgb.py | 2 +- bofire/strategies/predictives/predictive.py | 34 +- bofire/surrogates/surrogate.py | 41 +- bofire/utils/naming_conventions.py | 23 + bofire/utils/torch_tools.py | 12 +- .../data_models/features/test_categorical.py | 15 + tests/bofire/data_models/specs/objectives.py | 7 +- tests/bofire/data_models/specs/outputs.py | 14 +- tests/bofire/data_models/specs/surrogates.py | 2 - tests/bofire/surrogates/test_diagnostics.py | 36 + tests/bofire/utils/test_torch_tools.py | 21 + .../Unknown_Constraint_Classification.ipynb | 990 +++++++++--------- .../007-Benchmark_outlier_detection.ipynb | 2 +- 25 files changed, 672 insertions(+), 660 deletions(-) create mode 100644 bofire/utils/naming_conventions.py diff --git a/bofire/data_models/features/categorical.py b/bofire/data_models/features/categorical.py index 295349933..6e93e1352 100644 --- a/bofire/data_models/features/categorical.py +++ b/bofire/data_models/features/categorical.py @@ -1,4 +1,4 @@ -from typing import ClassVar, List, Literal, Optional, Tuple, Union +from typing import Annotated, ClassVar, List, Literal, Optional, Tuple, Union import numpy as np import pandas as pd @@ -6,11 +6,8 @@ from bofire.data_models.enum import CategoricalEncodingEnum from bofire.data_models.features.feature import _CAT_SEP, Input, Output, TTransform -from bofire.data_models.types import TCategoryVals from bofire.data_models.objectives.api import AnyCategoricalObjective -from bofire.data_models.objectives.categorical import ( - ConstrainedCategoricalObjective, -) +from bofire.data_models.types import TCategoryVals class CategoricalInput(Input): @@ -339,55 +336,22 @@ class CategoricalOutput(Output): order_id: ClassVar[int] = 9 categories: TCategoryVals - objective: Optional[AnyCategoricalObjective] = Field( - default_factory=lambda: ConstrainedCategoricalObjective( - w=1.0, categories=["a", "b"], desirability=[True, False] - ) - ) - - @field_validator("categories") - @classmethod - def validate_categories_unique(cls, categories: List[str]) -> List[str]: - """validates that categories have unique names - - Args: - categories (List[str]): List or tuple of category names - - Raises: - ValueError: when categories have non-unique names - - Returns: - Tuple[str]: Tuple of the categories - """ - if len(categories) != len(set(categories)): - raise ValueError("categories must be unique") - return categories + objective: AnyCategoricalObjective @model_validator(mode="after") - def validate_objectives_unique(self): - """validates that categories have unique names - - Args: - categories (List[str]): List or tuple of category names + def validate_objective_categories(self): + """validates that objective categories match the output categories Raises: ValueError: when categories do not match objective categories Returns: - Tuple[str]: Tuple of the categories + self """ if self.objective.categories != self.categories: # type: ignore raise ValueError("categories must match to objective categories") return self - @classmethod - def from_objective( - cls, - key: str, - objective: ConstrainedCategoricalObjective, - ): - return cls(key=key, objective=objective, categories=objective.categories) - def __call__(self, values: pd.Series) -> pd.Series: if self.objective is None: return pd.Series( diff --git a/bofire/data_models/objectives/categorical.py b/bofire/data_models/objectives/categorical.py index 870e76e68..92e24e5f0 100644 --- a/bofire/data_models/objectives/categorical.py +++ b/bofire/data_models/objectives/categorical.py @@ -2,14 +2,14 @@ import numpy as np import pandas as pd -from pydantic import field_validator +from pydantic import model_validator -from bofire.data_models.features.feature import TCategoryVals from bofire.data_models.objectives.objective import ( ConstrainedObjective, Objective, TWeight, ) +from bofire.data_models.types import TCategoryVals class CategoricalObjective: @@ -32,50 +32,26 @@ class ConstrainedCategoricalObjective( w: TWeight = 1.0 categories: TCategoryVals desirability: List[bool] - eta: float = 1.0 type: Literal["ConstrainedCategoricalObjective"] = "ConstrainedCategoricalObjective" - @field_validator( - "categories", - ) - def validate_categories_unique(cls, categories: List[str]) -> List[bool]: - """validates that desirabilities match the categories + @model_validator(mode="after") + def validate_desireability(self): + """validates that categories have unique names Args: categories (List[str]): List or tuple of category names Raises: - ValueError: when categories are not unique + ValueError: when categories do not match objective categories Returns: - List[str]: List of categories + Tuple[str]: Tuple of the categories """ - if len(categories) != len(set(categories)): - raise ValueError( - "Categories are not unique" - ) - return categories - - @field_validator( - "desirability", - ) - def validate_desirability(cls, desirability: List[bool], info) -> List[bool]: - """validates that desirabilities match the categories - - Args: - desireability (List[str]): List or tuple of desirabilities - - Raises: - ValueError: when desirability count is not equal to category count - - Returns: - List[bool]: List of the desirability - """ - if len(desirability) != len(info.data["categories"]): + if len(self.desirability) != len(self.categories): raise ValueError( "number of categories differs from number of desirabilities" ) - return desirability + return self def to_dict(self) -> Dict: """Returns the categories and corresponding objective values as dictionary""" diff --git a/bofire/data_models/surrogates/fully_bayesian.py b/bofire/data_models/surrogates/fully_bayesian.py index 96e15215e..e114bdb02 100644 --- a/bofire/data_models/surrogates/fully_bayesian.py +++ b/bofire/data_models/surrogates/fully_bayesian.py @@ -28,4 +28,4 @@ def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return isinstance(my_type, ContinuousOutput) + return isinstance(my_type, type(ContinuousOutput)) diff --git a/bofire/data_models/surrogates/linear.py b/bofire/data_models/surrogates/linear.py index 49b6f3980..211b6283f 100644 --- a/bofire/data_models/surrogates/linear.py +++ b/bofire/data_models/surrogates/linear.py @@ -28,4 +28,4 @@ def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return isinstance(my_type, ContinuousOutput) + return isinstance(my_type, type(ContinuousOutput)) diff --git a/bofire/data_models/surrogates/mixed_single_task_gp.py b/bofire/data_models/surrogates/mixed_single_task_gp.py index 3a767d1e7..8d36e7b7e 100644 --- a/bofire/data_models/surrogates/mixed_single_task_gp.py +++ b/bofire/data_models/surrogates/mixed_single_task_gp.py @@ -45,4 +45,4 @@ def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return isinstance(my_type, ContinuousOutput) + return isinstance(my_type, type(ContinuousOutput)) diff --git a/bofire/data_models/surrogates/mixed_tanimoto_gp.py b/bofire/data_models/surrogates/mixed_tanimoto_gp.py index 445cafddc..d2fb98dd4 100644 --- a/bofire/data_models/surrogates/mixed_tanimoto_gp.py +++ b/bofire/data_models/surrogates/mixed_tanimoto_gp.py @@ -54,7 +54,7 @@ def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return isinstance(my_type, ContinuousOutput) + return isinstance(my_type, type(ContinuousOutput)) @validator("input_preprocessing_specs") def validate_moleculars(cls, v, values): diff --git a/bofire/data_models/surrogates/mlp.py b/bofire/data_models/surrogates/mlp.py index 9df107234..451aa4f7a 100644 --- a/bofire/data_models/surrogates/mlp.py +++ b/bofire/data_models/surrogates/mlp.py @@ -25,15 +25,15 @@ class MLPEnsemble(TrainableBotorchSurrogate): shuffle: bool = True scaler: ScalerEnum = ScalerEnum.NORMALIZE - @classmethod - def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: - """Abstract method to check output type for surrogate models - Args: - my_type: continuous or categorical output - Returns: - bool: True if the output type is valid for the surrogate chosen, False otherwise - """ - return isinstance(my_type, (CategoricalOutput, ContinuousOutput)) + # @classmethod + # def is_output_implemented(cls, my_type: str) -> bool: + # """Abstract method to check output type for surrogate models + # Args: + # my_type: continuous or categorical output + # Returns: + # bool: True if the output type is valid for the surrogate chosen, False otherwise + # """ + # return isinstance(my_type, (CategoricalOutput, ContinuousOutput)) class RegressionMLPEnsemble(MLPEnsemble): @@ -48,7 +48,7 @@ def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return isinstance(my_type, ContinuousOutput) + return isinstance(my_type, type(ContinuousOutput)) class ClassificationMLPEnsemble(MLPEnsemble): @@ -63,4 +63,4 @@ def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return isinstance(my_type, CategoricalOutput) + return isinstance(my_type, type(CategoricalOutput)) diff --git a/bofire/data_models/surrogates/polynomial.py b/bofire/data_models/surrogates/polynomial.py index 63e4539ad..42533dc49 100644 --- a/bofire/data_models/surrogates/polynomial.py +++ b/bofire/data_models/surrogates/polynomial.py @@ -33,4 +33,4 @@ def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return isinstance(my_type, ContinuousOutput) + return isinstance(my_type, type(ContinuousOutput)) diff --git a/bofire/data_models/surrogates/random_forest.py b/bofire/data_models/surrogates/random_forest.py index 650f5069f..93711ae96 100644 --- a/bofire/data_models/surrogates/random_forest.py +++ b/bofire/data_models/surrogates/random_forest.py @@ -39,4 +39,4 @@ def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return isinstance(my_type, ContinuousOutput) + return isinstance(my_type, type(ContinuousOutput)) diff --git a/bofire/data_models/surrogates/single_task_gp.py b/bofire/data_models/surrogates/single_task_gp.py index 594a9b8cf..b943db173 100644 --- a/bofire/data_models/surrogates/single_task_gp.py +++ b/bofire/data_models/surrogates/single_task_gp.py @@ -121,4 +121,4 @@ def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return isinstance(my_type, ContinuousOutput) + return isinstance(my_type, type(ContinuousOutput)) diff --git a/bofire/data_models/surrogates/surrogate.py b/bofire/data_models/surrogates/surrogate.py index 2270da9c6..da0063dd7 100644 --- a/bofire/data_models/surrogates/surrogate.py +++ b/bofire/data_models/surrogates/surrogate.py @@ -5,7 +5,8 @@ from bofire.data_models.base import BaseModel from bofire.data_models.domain.api import Inputs, Outputs -from bofire.data_models.features.api import AnyOutput, TInputTransformSpecs +from bofire.data_models.features.api import AnyOutput +from bofire.data_models.types import TInputTransformSpecs class Surrogate(BaseModel): @@ -32,7 +33,7 @@ def validate_outputs(cls, outputs, info): if len(outputs) == 0: raise ValueError("At least one output feature has to be provided.") for o in outputs: - if not cls.is_output_implemented(o): + if not cls.is_output_implemented(type(o)): raise ValueError("Invalid output type passed.") return outputs diff --git a/bofire/data_models/surrogates/tanimoto_gp.py b/bofire/data_models/surrogates/tanimoto_gp.py index 3ab1c5dda..51f6eb1ff 100644 --- a/bofire/data_models/surrogates/tanimoto_gp.py +++ b/bofire/data_models/surrogates/tanimoto_gp.py @@ -41,7 +41,7 @@ def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return isinstance(my_type, ContinuousOutput) + return isinstance(my_type, type(ContinuousOutput)) # TanimotoGP will be used when at least one of fingerprints, fragments, or fingerprintsfragments are present @validator("input_preprocessing_specs") diff --git a/bofire/data_models/surrogates/xgb.py b/bofire/data_models/surrogates/xgb.py index 262ca1649..99821182b 100644 --- a/bofire/data_models/surrogates/xgb.py +++ b/bofire/data_models/surrogates/xgb.py @@ -87,4 +87,4 @@ def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: Returns: bool: True if the output type is valid for the surrogate chosen, False otherwise """ - return isinstance(my_type, ContinuousOutput) + return isinstance(my_type, type(ContinuousOutput)) diff --git a/bofire/strategies/predictives/predictive.py b/bofire/strategies/predictives/predictive.py index ec9875fed..5365edee2 100644 --- a/bofire/strategies/predictives/predictive.py +++ b/bofire/strategies/predictives/predictive.py @@ -1,4 +1,3 @@ -import itertools from abc import abstractmethod from typing import List, Optional, Tuple @@ -6,12 +5,13 @@ import pandas as pd from pydantic import PositiveInt -from bofire.data_models.features.api import CategoricalOutput, TInputTransformSpecs +from bofire.data_models.features.api import CategoricalOutput from bofire.data_models.strategies.api import Strategy as DataModel from bofire.data_models.types import TInputTransformSpecs from bofire.strategies.data_models.candidate import Candidate from bofire.strategies.data_models.values import InputValue, OutputValue from bofire.strategies.strategy import Strategy +from bofire.utils.naming_conventions import get_column_names class PredictiveStrategy(Strategy): @@ -104,39 +104,15 @@ def predict(self, experiments: pd.DataFrame) -> pd.DataFrame: experiments=experiments, specs=self.input_preprocessing_specs ) preds, stds = self._predict(transformed) - column_names = list( - itertools.chain( - *[ - ( - [f"{feat.key}_pred"] - if not isinstance(feat, CategoricalOutput) - else [f"{feat.key}_{cat}_prob" for cat in feat.categories] - ) - for feat in self.domain.outputs.get() - ] - ) - ) + pred_cols, sd_cols = get_column_names(self.domain.outputs) if stds is not None: predictions = pd.DataFrame( - data=np.hstack((preds, stds)), - columns=column_names - + list( - itertools.chain( - *[ - ( - [f"{feat.key}_sd"] - if not isinstance(feat, CategoricalOutput) - else [f"{feat.key}_{cat}_sd" for cat in feat.categories] - ) - for feat in self.domain.outputs.get() - ] - ) - ), + data=np.hstack((preds, stds)), columns=pred_cols + sd_cols ) else: predictions = pd.DataFrame( data=preds, - columns=column_names, + columns=pred_cols, ) for feat in self.domain.outputs.get(): if isinstance(feat, CategoricalOutput): diff --git a/bofire/surrogates/surrogate.py b/bofire/surrogates/surrogate.py index 98d090963..fddaad2d5 100644 --- a/bofire/surrogates/surrogate.py +++ b/bofire/surrogates/surrogate.py @@ -5,9 +5,10 @@ import pandas as pd from bofire.data_models.domain.domain import is_numeric -from bofire.data_models.features.api import CategoricalOutput, ContinuousOutput +from bofire.data_models.features.api import CategoricalOutput from bofire.data_models.surrogates.api import Surrogate as DataModel from bofire.surrogates.values import PredictedValue +from bofire.utils.naming_conventions import get_column_names class Surrogate(ABC): @@ -46,20 +47,7 @@ def predict(self, X: pd.DataFrame) -> pd.DataFrame: # predict preds, stds = self._predict(Xt) # set up column names - pred_cols = [] - sd_cols = [] - for featkey in self.outputs.get_keys(CategoricalOutput): - pred_cols = pred_cols + [ - f"{featkey}_{cat}_prob" - for cat in self.outputs.get_by_key(featkey).categories # type: ignore - ] - sd_cols = sd_cols + [ - f"{featkey}_{cat}_sd" - for cat in self.outputs.get_by_key(featkey).categories # type: ignore - ] - for featkey in self.outputs.get_keys(ContinuousOutput): - pred_cols = pred_cols + [f"{featkey}_pred"] - sd_cols = sd_cols + [f"{featkey}_sd"] + pred_cols, sd_cols = get_column_names(self.outputs) # postprocess predictions = pd.DataFrame( data=np.hstack((preds, stds)), @@ -87,27 +75,12 @@ def predict(self, X: pd.DataFrame) -> pd.DataFrame: return predictions def validate_predictions(self, predictions: pd.DataFrame) -> pd.DataFrame: - expected_cols = [] - check_columns = [] + # Get the column names + pred_cols, sd_cols = get_column_names(self.outputs) + expected_cols = pred_cols + sd_cols + check_columns = list(expected_cols) for featkey in self.outputs.get_keys(CategoricalOutput): - expected_cols = ( - expected_cols - + [f"{featkey}_{t}" for t in ["pred", "sd"]] - + [ - f"{featkey}_{cat}_prob" - for cat in self.outputs.get_by_key(featkey).categories # type: ignore - ] - + [ - f"{featkey}_{cat}_sd" - for cat in self.outputs.get_by_key(featkey).categories # type: ignore - ] - ) - check_columns = check_columns + [ - col for col in expected_cols if col != f"{featkey}_pred" - ] - for featkey in self.outputs.get_keys(ContinuousOutput): expected_cols = expected_cols + [f"{featkey}_{t}" for t in ["pred", "sd"]] - check_columns = check_columns + expected_cols if sorted(predictions.columns) != sorted(expected_cols): raise ValueError( f"Predictions are ill-formatted. Expected: {expected_cols}, got: {list(predictions.columns)}." diff --git a/bofire/utils/naming_conventions.py b/bofire/utils/naming_conventions.py new file mode 100644 index 000000000..5d1430afb --- /dev/null +++ b/bofire/utils/naming_conventions.py @@ -0,0 +1,23 @@ +from bofire.data_models.features.api import ( + AnyOutput, + CategoricalOutput, + ContinuousOutput, +) + + +def get_column_names(outputs: AnyOutput): + pred_cols, sd_cols = [], [] + for featkey in outputs.get_keys(CategoricalOutput): + pred_cols = pred_cols + [ + f"{featkey}_{cat}_prob" + for cat in outputs.get_by_key(featkey).categories # type: ignore + ] + sd_cols = sd_cols + [ + f"{featkey}_{cat}_sd" + for cat in outputs.get_by_key(featkey).categories # type: ignore + ] + for featkey in outputs.get_keys(ContinuousOutput): + pred_cols = pred_cols + [f"{featkey}_pred"] + sd_cols = sd_cols + [f"{featkey}_sd"] + + return pred_cols, sd_cols diff --git a/bofire/utils/torch_tools.py b/bofire/utils/torch_tools.py index 1e6ec8f03..e2dc50ac2 100644 --- a/bofire/utils/torch_tools.py +++ b/bofire/utils/torch_tools.py @@ -269,19 +269,19 @@ def constrained_objective2botorch( return ( [ lambda Z: torch.log( - torch.clamp( - 1 - / ( + 1 + / torch.clamp( + ( Z[..., idx : idx + len(objective.desirability)] * torch.tensor(objective.desirability).to(**tkwargs) - ).sum(-1) - - 1, + ).sum(-1), min=eps, max=1 - eps, ) + - 1, ) ], - [objective.eta], + [1.0], idx + len(objective.desirability), ) else: diff --git a/tests/bofire/data_models/features/test_categorical.py b/tests/bofire/data_models/features/test_categorical.py index c3e91de2e..d7699d9d0 100644 --- a/tests/bofire/data_models/features/test_categorical.py +++ b/tests/bofire/data_models/features/test_categorical.py @@ -10,7 +10,9 @@ from bofire.data_models.features.api import ( CategoricalDescriptorInput, CategoricalInput, + CategoricalOutput, ) +from bofire.data_models.objectives.api import ConstrainedCategoricalObjective @pytest.mark.parametrize( @@ -460,3 +462,16 @@ def test_categorical_input_feature_allowed_categories(input_feature, expected): ) def test_categorical_input_feature_forbidden_categories(input_feature, expected): assert input_feature.get_forbidden_categories() == expected + + +def test_categorical_output_call(): + test_df = pd.DataFrame(data=[[0.7, 0.3], [0.2, 0.8]], columns=["c1", "c2"]) + categorical_output = CategoricalOutput( + key="a", + categories=["c1", "c2"], + objective=ConstrainedCategoricalObjective( + categories=["c1", "c2"], desirability=[True, False] + ), + ) + output = categorical_output(test_df) + assert output.tolist() == test_df["c1"].tolist() diff --git a/tests/bofire/data_models/specs/objectives.py b/tests/bofire/data_models/specs/objectives.py index bed2e9902..4544811bc 100644 --- a/tests/bofire/data_models/specs/objectives.py +++ b/tests/bofire/data_models/specs/objectives.py @@ -53,7 +53,6 @@ "w": 1.0, "categories": ["green", "red", "blue"], "desirability": [True, False, True], - "eta": 1.0, }, ) @@ -63,7 +62,6 @@ "w": 1.0, "categories": ["green", "red", "blue"], "desirability": [True, False, True, False], - "eta": 1.0, }, error=ValueError, message="number of categories differs from number of desirabilities", @@ -75,8 +73,7 @@ "w": 1.0, "categories": ["green", "red", "blue", "blue"], "desirability": [True, False, True, False], - "eta": 1.0, }, error=ValueError, - message="categories must be unique", -) \ No newline at end of file + message="Categories must be unique", +) diff --git a/tests/bofire/data_models/specs/outputs.py b/tests/bofire/data_models/specs/outputs.py index b4ffc60fc..0387dac74 100644 --- a/tests/bofire/data_models/specs/outputs.py +++ b/tests/bofire/data_models/specs/outputs.py @@ -1,5 +1,9 @@ from bofire.data_models.domain.api import Outputs -from bofire.data_models.features.api import CategoricalInput, ContinuousOutput, CategoricalOutput +from bofire.data_models.features.api import ( + CategoricalInput, + CategoricalOutput, + ContinuousOutput, +) from bofire.data_models.objectives.api import ConstrainedCategoricalObjective from tests.bofire.data_models.specs.specs import Specs @@ -42,7 +46,13 @@ Outputs, lambda: { "features": [ - CategoricalOutput(key="b", categories=["a", "b"], objective=ConstrainedCategoricalObjective(categories=["c", "d"], desirability=[True, True])), + CategoricalOutput( + key="b", + categories=["a", "b"], + objective=ConstrainedCategoricalObjective( + categories=["c", "d"], desirability=[True, True] + ), + ), ], }, error=ValueError, diff --git a/tests/bofire/data_models/specs/surrogates.py b/tests/bofire/data_models/specs/surrogates.py index 7906e80b2..be8c9ebe3 100644 --- a/tests/bofire/data_models/specs/surrogates.py +++ b/tests/bofire/data_models/specs/surrogates.py @@ -222,7 +222,6 @@ "hyperconfig": None, }, error=ValueError, - message="Invalid output type passed.", ) specs.add_valid( @@ -289,7 +288,6 @@ "hyperconfig": None, }, error=ValueError, - message="Invalid output type passed.", ) specs.add_valid( diff --git a/tests/bofire/surrogates/test_diagnostics.py b/tests/bofire/surrogates/test_diagnostics.py index 9f4e8a58a..2dba72630 100644 --- a/tests/bofire/surrogates/test_diagnostics.py +++ b/tests/bofire/surrogates/test_diagnostics.py @@ -3,6 +3,8 @@ import pytest from scipy.stats import pearsonr, spearmanr from sklearn.metrics import ( + accuracy_score, + f1_score, mean_absolute_error, mean_absolute_percentage_error, mean_squared_error, @@ -16,7 +18,9 @@ CvResults, CvResults2CrossValidationValues, UQ_metrics, + _accuracy_score, _CVPPDiagram, + _f1_score, _mean_absolute_error, _mean_absolute_percentage_error, _mean_squared_error, @@ -86,6 +90,38 @@ def test_sklearn_metrics(bofire, sklearn): assert bofire(observed, predicted) == sklearn(observed, predicted) +@pytest.mark.parametrize( + "bofire, sklearn", + [ + (_accuracy_score, accuracy_score), + ], +) +def test_sklearn_metrics_accuracy(bofire, sklearn): + n_samples = 20 + observed = np.random.choice([0, 1, 2, 3], size=(n_samples,)) + predicted = np.random.choice([0, 1, 2, 3], size=(n_samples,)) + sd = None + assert bofire(observed, predicted, sd) == sklearn(observed, predicted) + assert bofire(observed, predicted) == sklearn(observed, predicted) + + +@pytest.mark.parametrize( + "bofire, sklearn", + [ + (_f1_score, f1_score), + ], +) +def test_sklearn_metrics_f1(bofire, sklearn): + n_samples = 20 + observed = np.random.choice([0, 1, 2, 3], size=(n_samples,)) + predicted = np.random.choice([0, 1, 2, 3], size=(n_samples,)) + sd = None + assert bofire(observed, predicted, sd) == sklearn( + observed, predicted, average="micro" + ) + assert bofire(observed, predicted) == sklearn(observed, predicted, average="micro") + + @pytest.mark.parametrize( "bofire, scipy", [ diff --git a/tests/bofire/utils/test_torch_tools.py b/tests/bofire/utils/test_torch_tools.py index 1f06b4bea..6f337c60c 100644 --- a/tests/bofire/utils/test_torch_tools.py +++ b/tests/bofire/utils/test_torch_tools.py @@ -24,6 +24,7 @@ ) from bofire.data_models.objectives.api import ( CloseToTargetObjective, + ConstrainedCategoricalObjective, MaximizeObjective, MaximizeSigmoidObjective, MinimizeObjective, @@ -831,3 +832,23 @@ def test_constrained_objective2botorch(objective): y *= soft_eval_constraint(xtt, eta) assert np.allclose(objective.__call__(np.linspace(0, 30, 500)), y.numpy().ravel()) + + +def test_constrained_objective(): + desirability = [True, False, False] + obj1 = ConstrainedCategoricalObjective( + categories=["c1", "c2", "c3"], desirability=desirability + ) + cs, etas, _ = constrained_objective2botorch(idx=0, objective=obj1) + + x = torch.rand((50, 3)) + x /= x.sum(1).unsqueeze(1) # Convert to probabilities + true_y = (x * torch.tensor(desirability)).sum(-1) + transformed_y = torch.log(1 / torch.clamp(true_y, 1e-6, 1 - 1e-6) - 1) + + assert len(cs) == 1 + assert etas[0] == 1.0 + + y_hat = cs[0](x) + assert np.allclose(y_hat.numpy(), transformed_y.numpy()) + assert np.allclose(np.exp(-np.log(np.exp(y_hat.numpy()) + 1)), true_y) diff --git a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb index 7ac093ea9..1ac43bf71 100644 --- a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb +++ b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb @@ -49,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -75,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 158, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -110,48 +110,48 @@ " \n", " \n", " 0\n", - " -0.149953\n", - " -0.704143\n", + " -0.350619\n", + " 0.504052\n", " 1\n", - " 54.121306\n", + " 16.349300\n", " acceptable\n", - " 1.009579\n", + " 1.009714\n", " \n", " \n", " 1\n", - " -0.625311\n", - " -0.463190\n", + " 0.115844\n", + " 0.947509\n", " 0\n", - " 75.608036\n", - " acceptable\n", - " 0.001819\n", + " 88.034069\n", + " ideal\n", + " 0.005132\n", " \n", " \n", " 2\n", - " -0.765853\n", - " 0.927654\n", + " -0.556177\n", + " -0.459197\n", " 1\n", - " 14.754710\n", + " 61.485379\n", " acceptable\n", - " 1.006574\n", + " 1.000490\n", " \n", " \n", " 3\n", - " -1.447047\n", - " -0.059688\n", + " -1.635584\n", + " 0.905708\n", " 1\n", - " 469.801324\n", + " 320.033865\n", " unacceptable\n", - " 1.002428\n", + " 1.006897\n", " \n", " \n", " 4\n", - " -0.540554\n", - " 1.090780\n", - " 1\n", - " 66.146436\n", - " acceptable\n", - " 1.004633\n", + " -1.474244\n", + " 0.855846\n", + " 0\n", + " 179.715811\n", + " unacceptable\n", + " 0.005518\n", " \n", " \n", "\n", @@ -159,14 +159,14 @@ ], "text/plain": [ " x_0 x_1 x_3 f_0 f_1 f_2\n", - "0 -0.149953 -0.704143 1 54.121306 acceptable 1.009579\n", - "1 -0.625311 -0.463190 0 75.608036 acceptable 0.001819\n", - "2 -0.765853 0.927654 1 14.754710 acceptable 1.006574\n", - "3 -1.447047 -0.059688 1 469.801324 unacceptable 1.002428\n", - "4 -0.540554 1.090780 1 66.146436 acceptable 1.004633" + "0 -0.350619 0.504052 1 16.349300 acceptable 1.009714\n", + "1 0.115844 0.947509 0 88.034069 ideal 0.005132\n", + "2 -0.556177 -0.459197 1 61.485379 acceptable 1.000490\n", + "3 -1.635584 0.905708 1 320.033865 unacceptable 1.006897\n", + "4 -1.474244 0.855846 0 179.715811 unacceptable 0.005518" ] }, - "execution_count": 158, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -198,7 +198,7 @@ }, { "cell_type": "code", - "execution_count": 159, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -221,247 +221,247 @@ "showlegend": true, "type": "scatter", "x": [ - -0.14995258503883457, - -0.6253111311588586, - -0.765853028108006, - -0.5405541311965942, - 0.9894994728281232, - -0.865978521372166, - -1.0217962106629315, - -0.4014156739458248, - -0.12059424146260334, - -1.0084908064927265, - -0.9349526862095585, - 0.9458316530354285, - -0.5320149984993212, - -0.4689036742689612, - -0.07833150213004725, - -0.9793104434929384, - 0.5788385198054735, - -0.9481124505558867, - -1.1878657259442678, - 0.7035953005643529, - 0.2579234370307777, - -1.2867733344241161, - -0.3218188759641296, - -0.29215910785338806, - -0.8088505843334085, - -0.16053526753142844, - 0.6219324521230791, - 0.03371532152549417, - -0.18044427816549646, - 0.6672898005757197, - 0.36287033501092125, - -0.5410845636765149, - 0.6184934671940714, - 0.79958184162304, - 0.6854423654922401, - -0.32974185478696016, - -0.5158328238096581, - -0.8267714092240124, - -0.9304400399308963, - 0.29751211472578376, - -1.0674027719621937, - -0.7640154931359711 + -0.35061926427297285, + -0.5561765809009358, + 0.42159282278448273, + -0.14506555699220236, + 1.0353008249274431, + 0.3454444836674355, + -0.7680788452262832, + -0.9933224816259388, + -1.0750262968274997, + 0.7461911958641041, + -0.7476273482445326, + 0.4462467901093472, + -0.07013883971177926, + -0.7613373765433603, + -0.20793189372038778, + -0.7126880256050545, + -0.4298026069294787, + -0.2926997442891148, + -0.48304637342466483, + 0.08633019936772834, + -0.275665340089877, + 0.13163143755446582, + -0.032200237982807245, + 0.630156543011732, + -0.16589247459195278, + -0.11312470103711392, + 1.0905345858791655, + 0.20275804836531774, + 0.3338996521298583, + 0.424736503060791, + 0.09844039335599075 ], "xaxis": "x", "y": [ - -0.7041428419544369, - -0.4631897120601236, - 0.9276536133420858, - 1.0907795664450228, - -0.20189922968971485, - 0.8873233761755905, - 0.7368573002639325, - 0.7110972468528676, - -0.43418666628890823, - 0.08216578094828186, - -0.5657310098005697, - -1.0505195679349209, - 0.4106824576035617, - -1.1004617179921463, - 0.9644868021314426, - 0.8813719192174521, - 0.006476845776403506, - 0.2047160309884497, - -0.48721762585284756, - -0.8694642189151136, - -0.8858761844683214, - -0.26450926092801574, - 0.7700394750617825, - 0.7931467073831526, - -1.099688304768781, - 1.2163564957070103, - -0.021834260403785777, - 0.5629070883753866, - -0.8868851959550068, - -1.2323026418478094, - -0.9954921557404812, - -0.7890090665021279, - -1.1629674424244054, - -0.4567118757180961, - -0.634972965145935, - 0.6601685944241988, - -0.6391884645670298, - -0.9719722439473206, - -0.744901749184502, - -0.9803760720872019, - -0.5942335372991383, - 0.23621357099645968 + 0.5040523211718617, + -0.45919668429740934, + -0.43223105338024803, + -0.5989964538201493, + -0.5649500247669095, + -0.5355390267695426, + 0.9550083378879219, + 0.39901716758574235, + -0.6999593792729817, + -0.6234513037416309, + 0.6181801617811384, + -0.8593732637207796, + -0.7322677185785165, + -1.1687864901526548, + 0.20233627906121776, + -1.1278268495317838, + 0.5197761747103247, + 1.0172863322035508, + -0.23399463937541132, + 0.31703732293322284, + -1.2846423115547891, + 0.4365633647242215, + 0.8598355161979865, + -1.2156493776606054, + -0.6155534138636449, + -0.15102340108948908, + -0.007804401336033662, + -1.1909885140068746, + -1.0674680087818904, + -0.9005255478079787, + -1.3402079733707999 ], "yaxis": "y" }, { - "hovertemplate": "f_1=unacceptable
x_0=%{x}
x_1=%{y}", - "legendgroup": "unacceptable", + "hovertemplate": "f_1=ideal
x_0=%{x}
x_1=%{y}", + "legendgroup": "ideal", "marker": { "color": "#EF553B", "symbol": "circle" }, "mode": "markers", - "name": "unacceptable", + "name": "ideal", "orientation": "v", "showlegend": true, "type": "scatter", "x": [ - -1.4470467658262591, - -1.6009504802388597, - 1.4204895481210702, - 1.3489227092112688, - -0.639284790399812, - 1.575648155774772, - 1.7270515387294698, - -1.1378584773675158, - -0.4870902364848939, - 1.3433379011252464, - 1.6525833048685383, - 0.052856077835125026, - -1.718194617060083, - 1.2519082611394308, - -1.3754185446920677, - -0.18322045819850907, - 1.327292051993207, - 1.0733856862440985, - 0.9008441174905695, - 0.47167063168201295, - -1.28379343444775, - -1.225041501040866, - -0.9879804926500095, - -1.7123070859149068, - -0.8222878848033226, - 1.698340355837336, - -0.40510921927474874, - -1.139666414607729, - 1.669256729311793, - -0.5749235839074829, - 1.3532335670733184, - 1.410971842195503, - 1.2552385529823566, - -1.0402317158880352, - -1.400659402027841, - 1.6382339166896722, - -1.3547647166845345, - 1.1865185162807816, - -1.4333442355901473, - 1.646587754792467, - 0.582077673409676, - 1.380227103874569, - 1.3622158395068897, - 1.1921850110333656, - -1.2624798548187615, - 1.70645818280738 + 0.11584353456460073, + 0.1724670780831814, + 0.539818540333588, + 0.15873339655513652, + 1.2444600831177084, + 0.999781463467607, + 1.1787431301713105, + 0.20622111629978224, + 1.1595275447276538 ], "xaxis": "x", "y": [ - -0.05968813763014724, - -0.08203263567781471, - -1.7256511732407582, - -1.3062738654746688, - -1.380758573136951, - 1.3698035388942658, - -0.14646898218101168, - 1.1472779541835205, - -1.546055723303492, - 1.2777644377843878, - -1.0947398765543759, - -1.691760325948563, - 0.2847825778795654, - -1.7156770939939636, - -1.3464909491716763, - 1.689598233686933, - -0.7777149658959187, - -1.457321100538615, - 1.7237444956486576, - -1.7222531301544601, - 1.4804621415248884, - -1.0312548094743161, - 1.464230692511908, - 1.727604973119779, - -1.4420228070981478, - 1.026409514853559, - 1.4523235910972625, - 1.6510720743055476, - 0.5237495796359455, - 1.6473828157610697, - -0.6903014501907374, - 1.3866811498163663, - -1.1832663015708027, - 1.3995479124379653, - 0.4939162310211662, - 1.6967495852807661, - 0.4613230281774947, - -1.1777089751033198, - 0.09067972703617566, - 0.5611805367913849, - -1.5872277904858516, - 0.39999239270670284, - -0.6292457424969355, - 1.2082895149675292, - -1.6847143997552474, - -1.454374296036075 + 0.9475093155136509, + 1.1526019287770612, + 0.3540271307917444, + 1.0991475451214328, + 0.618585769140207, + 0.135534043926683, + 0.676481578469553, + 1.2930076238613046, + 0.09014125654752259 ], "yaxis": "y" }, { - "hovertemplate": "f_1=ideal
x_0=%{x}
x_1=%{y}", - "legendgroup": "ideal", + "hovertemplate": "f_1=unacceptable
x_0=%{x}
x_1=%{y}", + "legendgroup": "unacceptable", "marker": { "color": "#00cc96", "symbol": "circle" }, "mode": "markers", - "name": "ideal", + "name": "unacceptable", "orientation": "v", "showlegend": true, "type": "scatter", "x": [ - 0.9061821719090513, - 0.5162826776143459, - 0.5094335938756926, - 0.6859698250145936, - 0.5823600872410029, - 1.0858917598939684, - 1.3205055890138766, - 0.14878500290037744, - 0.7581455463744304, - 1.1678712726745357, - 0.6653414272221312, - 0.2300548458979994 + -1.635584388534313, + -1.4742444296046657, + -0.15729314694102903, + 1.5070853560854864, + 1.2101965364760718, + -1.5831812018666418, + -1.3957357203071692, + 1.613078214961829, + 1.4254239036801257, + -1.2787167270583444, + -1.4023268999370018, + 1.7378096362285174, + 0.6296633849190485, + 1.6560265107086973, + -1.4604637134929277, + 1.3954194906359776, + -0.3273878127686993, + 1.5426483753054665, + 0.7725626061807955, + 1.5974661747883538, + 1.688121615826073, + -1.5816687013750201, + -1.3669080002392295, + -0.933603434080167, + 0.4822506956934074, + -1.4195045992551758, + -1.5862157185428094, + -0.6113354136054985, + -0.04870290118977216, + -1.3475465180331827, + 1.4165738521389044, + -0.7433524285306379, + -0.9790834992740071, + 0.4177249330842745, + -1.5065150315297096, + 0.8379155379744967, + -1.6748335190020076, + 1.732364494283424, + 1.6746380993561396, + -1.6486196321522595, + -1.6923451290787777, + 1.0735234059299241, + 1.6026120196380926, + 1.6368331859404996, + 1.344210315200074, + 0.5756493888406227, + 1.0955596911970198, + 1.1989776196979713, + 1.6443171102633003, + -1.2980309445061669, + 1.0537367958501087, + -1.2500086281626677, + 0.38230323278127143, + -1.6985134148647894, + 1.7156174130125836, + 1.6853737833769693, + -0.294642574709699, + 1.1650211575724394, + -0.6805240325421535, + -0.8809593972969152 ], "xaxis": "x", "y": [ - 0.886718714624303, - 0.39576812610714196, - 0.5122034490809124, - 0.7226512509233372, - 0.7116114473585995, - 0.7731429665448246, - 0.40197976993092377, - 0.5202938254854113, - 0.24505092339428747, - 0.5559493377544111, - 0.28496872065832335, - 0.8280874889772782 + 0.905708248434292, + 0.8558461621953355, + -1.682916252803905, + -0.3787532353125016, + 0.8582593276903339, + -0.5950896738838634, + -0.5940137871920919, + -1.1106762625595035, + 1.749844877144934, + -0.7745816878526347, + -1.0080154738799716, + 0.26464782707414836, + 1.469231740752055, + -1.6208934974410696, + -0.07648992450843961, + -1.0433090911846108, + -1.4790774655942789, + -1.0982882152772908, + -1.241192967822959, + -0.5220349756706821, + -0.3558204214616887, + 0.2092233051309027, + -1.3578044601398707, + 1.7264947346866033, + 1.5894671117819081, + -1.744256531724441, + -0.7461734234273045, + 1.7424015284023127, + -1.5287936583164727, + -0.6814323944749845, + 1.266759433319939, + 1.5399709474328813, + -1.5645549489253556, + -1.5196196260358157, + 0.44194322574419287, + 1.3477204637842641, + 1.7295610322135961, + -1.1974130031498986, + 0.49951459600915626, + 0.11526237063394973, + -1.7248874050370138, + 1.5908887950885897, + 0.27689415663406214, + -0.31530244091173354, + -1.6932092535195582, + 1.3482667568433095, + 1.5726556777198932, + -0.9488385694367101, + 0.2919559196715964, + 1.7065717592720269, + 1.5139353043217532, + -1.6001644269256965, + -1.7115693621588608, + -0.38867144542644105, + 0.6923493546148185, + 0.1427173148876928, + 1.4870813273250616, + -0.9737357252558907, + 1.2417204272219675, + 1.740049552575495 ], "yaxis": "y" } @@ -1382,7 +1382,7 @@ }, { "cell_type": "code", - "execution_count": 163, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -1434,7 +1434,7 @@ }, { "cell_type": "code", - "execution_count": 164, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -1465,8 +1465,8 @@ " \n", " \n", " 0\n", - " 0.68\n", - " 0.68\n", + " 0.72\n", + " 0.72\n", " \n", " \n", "\n", @@ -1474,10 +1474,10 @@ ], "text/plain": [ " ACCURACY F1\n", - "0 0.68 0.68" + "0 0.72 0.72" ] }, - "execution_count": 164, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -1489,7 +1489,7 @@ }, { "cell_type": "code", - "execution_count": 165, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -1520,8 +1520,8 @@ " \n", " \n", " 0\n", - " 0.54\n", - " 0.54\n", + " 0.65\n", + " 0.65\n", " \n", " \n", "\n", @@ -1529,10 +1529,10 @@ ], "text/plain": [ " ACCURACY F1\n", - "0 0.54 0.54" + "0 0.65 0.65" ] }, - "execution_count": 165, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -1551,7 +1551,7 @@ }, { "cell_type": "code", - "execution_count": 166, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -1577,7 +1577,7 @@ }, { "cell_type": "code", - "execution_count": 167, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -1624,16 +1624,16 @@ " x_3\n", " f_1_pred\n", " f_1_sd\n", - " f_0_pred\n", - " f_2_pred\n", " f_1_unacceptable_prob\n", " f_1_acceptable_prob\n", " f_1_ideal_prob\n", - " f_0_sd\n", - " f_2_sd\n", + " f_0_pred\n", + " f_2_pred\n", " f_1_unacceptable_sd\n", " f_1_acceptable_sd\n", " f_1_ideal_sd\n", + " f_0_sd\n", + " f_2_sd\n", " f_0_des\n", " f_2_des\n", " f_1_des\n", @@ -1642,269 +1642,269 @@ " \n", " \n", " 0\n", - " 0.403369\n", - " 0.161374\n", - " 0\n", + " 0.154314\n", + " -0.001118\n", + " 1\n", " acceptable\n", " 0.0\n", - " 0.092455\n", - " 0.004524\n", - " 0.399619\n", - " 0.599358\n", - " 0.001023\n", - " 2.192848\n", - " 0.002932\n", - " 0.542024\n", - " 0.541092\n", - " 0.001150\n", - " -0.092455\n", - " 0.499435\n", - " 0.600381\n", + " 0.425532\n", + " 1.004986\n", + " 0.639645\n", + " 0.357941\n", + " 0.002414\n", + " 6.588467\n", + " 0.003142\n", + " 0.476779\n", + " 0.473279\n", + " 0.004336\n", + " -0.357941\n", + " 0.499698\n", + " 1.644631\n", " \n", " \n", " 1\n", - " 0.291096\n", - " 0.091716\n", - " 1\n", - " acceptable\n", + " 0.149196\n", + " 0.013251\n", + " 0\n", + " ideal\n", " 0.0\n", - " 0.338534\n", - " 1.005305\n", - " 0.112541\n", - " 0.886353\n", - " 0.001106\n", - " 2.772511\n", - " 0.002915\n", - " 0.197303\n", - " 0.196632\n", - " 0.001800\n", - " -0.338534\n", - " 0.376918\n", - " 0.887459\n", + " 0.465530\n", + " 0.005402\n", + " 0.795662\n", + " 0.204113\n", + " 0.000226\n", + " 3.777311\n", + " 0.003156\n", + " 0.430231\n", + " 0.430007\n", + " 0.000242\n", + " -0.204113\n", + " 0.499972\n", + " 0.801064\n", " \n", " \n", " 2\n", - " 1.319030\n", + " -1.333767\n", " 1.750000\n", - " 1\n", - " unacceptable\n", + " 0\n", + " ideal\n", " 0.0\n", - " -2.338611\n", - " 1.003526\n", - " 0.688225\n", - " 0.110699\n", - " 0.201075\n", - " 4.714258\n", - " 0.003410\n", - " 0.449328\n", - " 0.240964\n", - " 0.445632\n", - " 2.338611\n", - " 0.377126\n", - " 0.311775\n", + " 0.623769\n", + " 0.007608\n", + " 0.795045\n", + " 0.204738\n", + " 0.000217\n", + " 11.252150\n", + " 0.003407\n", + " 0.431594\n", + " 0.431392\n", + " 0.000222\n", + " -0.204738\n", + " 0.499973\n", + " 0.802653\n", " \n", " \n", " 3\n", - " 0.086516\n", - " -0.000797\n", - " 0\n", - " acceptable\n", + " -0.630759\n", + " 0.390813\n", + " 1\n", + " unacceptable\n", " 0.0\n", - " 0.900832\n", - " 0.004706\n", - " 0.398625\n", - " 0.600412\n", - " 0.000963\n", - " 2.342997\n", - " 0.002935\n", - " 0.542560\n", - " 0.541685\n", - " 0.001131\n", - " -0.900832\n", - " 0.499412\n", - " 0.601375\n", + " 2.985411\n", + " 1.005220\n", + " 0.638927\n", + " 0.359041\n", + " 0.002033\n", + " 6.575972\n", + " 0.003180\n", + " 0.478100\n", + " 0.475209\n", + " 0.003490\n", + " -0.359041\n", + " 0.499746\n", + " 1.644147\n", " \n", " \n", " 4\n", - " -0.228315\n", - " 0.046514\n", + " -1.114157\n", + " 1.209882\n", " 0\n", - " acceptable\n", + " unacceptable\n", " 0.0\n", - " 1.500286\n", - " 0.004834\n", - " 0.397428\n", - " 0.601626\n", - " 0.000946\n", - " 2.390450\n", - " 0.002935\n", - " 0.541194\n", - " 0.540336\n", - " 0.001115\n", - " -1.500286\n", - " 0.499396\n", - " 0.602572\n", + " 8.030368\n", + " 0.006634\n", + " 0.795048\n", + " 0.204736\n", + " 0.000216\n", + " 14.069044\n", + " 0.003171\n", + " 0.431596\n", + " 0.431394\n", + " 0.000222\n", + " -0.204736\n", + " 0.499973\n", + " 0.801682\n", " \n", " \n", " 5\n", - " 0.076667\n", - " -0.004871\n", + " -0.358540\n", + " 0.102376\n", " 1\n", - " acceptable\n", + " unacceptable\n", " 0.0\n", - " 0.979306\n", - " 1.005360\n", - " 0.198561\n", - " 0.800833\n", - " 0.000606\n", - " 2.372284\n", - " 0.002916\n", - " 0.237685\n", - " 0.237178\n", - " 0.000844\n", - " -0.979306\n", - " 0.376911\n", - " 0.801439\n", + " 1.546216\n", + " 1.004791\n", + " 0.638009\n", + " 0.360123\n", + " 0.001868\n", + " 4.052824\n", + " 0.003159\n", + " 0.479443\n", + " 0.476816\n", + " 0.003129\n", + " -0.360123\n", + " 0.499766\n", + " 1.642799\n", " \n", " \n", " 6\n", - " 0.308219\n", - " 0.087674\n", - " 0\n", + " 0.736046\n", + " 0.516076\n", + " 1\n", " acceptable\n", " 0.0\n", - " 0.222725\n", - " 0.004599\n", - " 0.399124\n", - " 0.599887\n", - " 0.000989\n", - " 2.834773\n", - " 0.002933\n", - " 0.542427\n", - " 0.541527\n", - " 0.001137\n", - " -0.222725\n", - " 0.499425\n", - " 0.600876\n", + " -0.240881\n", + " 1.004885\n", + " 0.819717\n", + " 0.173096\n", + " 0.007188\n", + " 8.248212\n", + " 0.003162\n", + " 0.331533\n", + " 0.333793\n", + " 0.014980\n", + " -0.173096\n", + " 0.499102\n", + " 1.824602\n", " \n", " \n", " 7\n", - " 0.829716\n", - " 0.681121\n", + " -0.558970\n", + " 0.322902\n", " 0\n", " unacceptable\n", " 0.0\n", - " -0.348330\n", - " 0.003866\n", - " 0.510435\n", - " 0.418242\n", - " 0.071323\n", - " 1.065240\n", - " 0.002941\n", - " 0.496067\n", - " 0.425085\n", - " 0.157438\n", - " 0.348330\n", - " 0.499517\n", - " 0.489565\n", + " 2.070167\n", + " 0.005272\n", + " 0.795045\n", + " 0.204738\n", + " 0.000217\n", + " 4.657032\n", + " 0.003146\n", + " 0.431604\n", + " 0.431400\n", + " 0.000224\n", + " -0.204738\n", + " 0.499973\n", + " 0.800317\n", " \n", " \n", " 8\n", - " -0.219844\n", - " 0.044418\n", + " 0.303819\n", + " 0.044806\n", " 1\n", " acceptable\n", " 0.0\n", - " 1.494160\n", - " 1.005372\n", - " 0.239904\n", - " 0.759543\n", - " 0.000553\n", - " 2.332408\n", - " 0.002915\n", - " 0.304143\n", - " 0.303688\n", - " 0.000761\n", - " -1.494160\n", - " 0.376910\n", - " 0.760096\n", + " 0.656479\n", + " 1.005106\n", + " 0.644744\n", + " 0.351593\n", + " 0.003664\n", + " 6.172526\n", + " 0.003141\n", + " 0.468581\n", + " 0.463146\n", + " 0.007114\n", + " -0.351593\n", + " 0.499542\n", + " 1.649849\n", " \n", " \n", " 9\n", - " -0.063892\n", - " 0.013738\n", + " -0.992659\n", + " 0.943164\n", " 0\n", - " acceptable\n", + " unacceptable\n", " 0.0\n", - " 1.097156\n", - " 0.004757\n", - " 0.398289\n", - " 0.600757\n", - " 0.000954\n", - " 1.653161\n", - " 0.002935\n", - " 0.542274\n", - " 0.541407\n", - " 0.001125\n", - " -1.097156\n", - " 0.499405\n", - " 0.601711\n", + " 6.008622\n", + " 0.006171\n", + " 0.795048\n", + " 0.204736\n", + " 0.000216\n", + " 10.497928\n", + " 0.003149\n", + " 0.431597\n", + " 0.431395\n", + " 0.000222\n", + " -0.204736\n", + " 0.499973\n", + " 0.801219\n", " \n", " \n", "\n", "" ], "text/plain": [ - " x_0 x_1 x_3 f_1_pred f_1_sd f_0_pred f_2_pred \\\n", - "0 0.403369 0.161374 0 acceptable 0.0 0.092455 0.004524 \n", - "1 0.291096 0.091716 1 acceptable 0.0 0.338534 1.005305 \n", - "2 1.319030 1.750000 1 unacceptable 0.0 -2.338611 1.003526 \n", - "3 0.086516 -0.000797 0 acceptable 0.0 0.900832 0.004706 \n", - "4 -0.228315 0.046514 0 acceptable 0.0 1.500286 0.004834 \n", - "5 0.076667 -0.004871 1 acceptable 0.0 0.979306 1.005360 \n", - "6 0.308219 0.087674 0 acceptable 0.0 0.222725 0.004599 \n", - "7 0.829716 0.681121 0 unacceptable 0.0 -0.348330 0.003866 \n", - "8 -0.219844 0.044418 1 acceptable 0.0 1.494160 1.005372 \n", - "9 -0.063892 0.013738 0 acceptable 0.0 1.097156 0.004757 \n", + " x_0 x_1 x_3 f_1_pred f_1_sd f_1_unacceptable_prob \\\n", + "0 0.154314 -0.001118 1 acceptable 0.0 0.425532 \n", + "1 0.149196 0.013251 0 ideal 0.0 0.465530 \n", + "2 -1.333767 1.750000 0 ideal 0.0 0.623769 \n", + "3 -0.630759 0.390813 1 unacceptable 0.0 2.985411 \n", + "4 -1.114157 1.209882 0 unacceptable 0.0 8.030368 \n", + "5 -0.358540 0.102376 1 unacceptable 0.0 1.546216 \n", + "6 0.736046 0.516076 1 acceptable 0.0 -0.240881 \n", + "7 -0.558970 0.322902 0 unacceptable 0.0 2.070167 \n", + "8 0.303819 0.044806 1 acceptable 0.0 0.656479 \n", + "9 -0.992659 0.943164 0 unacceptable 0.0 6.008622 \n", "\n", - " f_1_unacceptable_prob f_1_acceptable_prob f_1_ideal_prob f_0_sd \\\n", - "0 0.399619 0.599358 0.001023 2.192848 \n", - "1 0.112541 0.886353 0.001106 2.772511 \n", - "2 0.688225 0.110699 0.201075 4.714258 \n", - "3 0.398625 0.600412 0.000963 2.342997 \n", - "4 0.397428 0.601626 0.000946 2.390450 \n", - "5 0.198561 0.800833 0.000606 2.372284 \n", - "6 0.399124 0.599887 0.000989 2.834773 \n", - "7 0.510435 0.418242 0.071323 1.065240 \n", - "8 0.239904 0.759543 0.000553 2.332408 \n", - "9 0.398289 0.600757 0.000954 1.653161 \n", + " f_1_acceptable_prob f_1_ideal_prob f_0_pred f_2_pred \\\n", + "0 1.004986 0.639645 0.357941 0.002414 \n", + "1 0.005402 0.795662 0.204113 0.000226 \n", + "2 0.007608 0.795045 0.204738 0.000217 \n", + "3 1.005220 0.638927 0.359041 0.002033 \n", + "4 0.006634 0.795048 0.204736 0.000216 \n", + "5 1.004791 0.638009 0.360123 0.001868 \n", + "6 1.004885 0.819717 0.173096 0.007188 \n", + "7 0.005272 0.795045 0.204738 0.000217 \n", + "8 1.005106 0.644744 0.351593 0.003664 \n", + "9 0.006171 0.795048 0.204736 0.000216 \n", "\n", - " f_2_sd f_1_unacceptable_sd f_1_acceptable_sd f_1_ideal_sd f_0_des \\\n", - "0 0.002932 0.542024 0.541092 0.001150 -0.092455 \n", - "1 0.002915 0.197303 0.196632 0.001800 -0.338534 \n", - "2 0.003410 0.449328 0.240964 0.445632 2.338611 \n", - "3 0.002935 0.542560 0.541685 0.001131 -0.900832 \n", - "4 0.002935 0.541194 0.540336 0.001115 -1.500286 \n", - "5 0.002916 0.237685 0.237178 0.000844 -0.979306 \n", - "6 0.002933 0.542427 0.541527 0.001137 -0.222725 \n", - "7 0.002941 0.496067 0.425085 0.157438 0.348330 \n", - "8 0.002915 0.304143 0.303688 0.000761 -1.494160 \n", - "9 0.002935 0.542274 0.541407 0.001125 -1.097156 \n", + " f_1_unacceptable_sd f_1_acceptable_sd f_1_ideal_sd f_0_sd f_2_sd \\\n", + "0 6.588467 0.003142 0.476779 0.473279 0.004336 \n", + "1 3.777311 0.003156 0.430231 0.430007 0.000242 \n", + "2 11.252150 0.003407 0.431594 0.431392 0.000222 \n", + "3 6.575972 0.003180 0.478100 0.475209 0.003490 \n", + "4 14.069044 0.003171 0.431596 0.431394 0.000222 \n", + "5 4.052824 0.003159 0.479443 0.476816 0.003129 \n", + "6 8.248212 0.003162 0.331533 0.333793 0.014980 \n", + "7 4.657032 0.003146 0.431604 0.431400 0.000224 \n", + "8 6.172526 0.003141 0.468581 0.463146 0.007114 \n", + "9 10.497928 0.003149 0.431597 0.431395 0.000222 \n", "\n", - " f_2_des f_1_des \n", - "0 0.499435 0.600381 \n", - "1 0.376918 0.887459 \n", - "2 0.377126 0.311775 \n", - "3 0.499412 0.601375 \n", - "4 0.499396 0.602572 \n", - "5 0.376911 0.801439 \n", - "6 0.499425 0.600876 \n", - "7 0.499517 0.489565 \n", - "8 0.376910 0.760096 \n", - "9 0.499405 0.601711 " + " f_0_des f_2_des f_1_des \n", + "0 -0.357941 0.499698 1.644631 \n", + "1 -0.204113 0.499972 0.801064 \n", + "2 -0.204738 0.499973 0.802653 \n", + "3 -0.359041 0.499746 1.644147 \n", + "4 -0.204736 0.499973 0.801682 \n", + "5 -0.360123 0.499766 1.642799 \n", + "6 -0.173096 0.499102 1.824602 \n", + "7 -0.204738 0.499973 0.800317 \n", + "8 -0.351593 0.499542 1.649849 \n", + "9 -0.204736 0.499973 0.801219 " ] }, - "execution_count": 167, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -1925,7 +1925,7 @@ }, { "cell_type": "code", - "execution_count": 168, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -1935,7 +1935,7 @@ }, { "cell_type": "code", - "execution_count": 169, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -1959,6 +1959,8 @@ " \n", " \n", " \n", + " x_0\n", + " x_1\n", " f_1_pred\n", " f_1_true\n", " \n", @@ -1966,52 +1968,72 @@ " \n", " \n", " 0\n", + " 0.154314\n", + " -0.001118\n", " acceptable\n", " acceptable\n", " \n", " \n", " 1\n", - " acceptable\n", + " 0.149196\n", + " 0.013251\n", + " ideal\n", " acceptable\n", " \n", " \n", " 2\n", - " unacceptable\n", + " -1.333767\n", + " 1.750000\n", + " ideal\n", " unacceptable\n", " \n", " \n", " 3\n", - " acceptable\n", + " -0.630759\n", + " 0.390813\n", + " unacceptable\n", " acceptable\n", " \n", " \n", " 4\n", - " acceptable\n", - " acceptable\n", + " -1.114157\n", + " 1.209882\n", + " unacceptable\n", + " unacceptable\n", " \n", " \n", " 5\n", - " acceptable\n", + " -0.358540\n", + " 0.102376\n", + " unacceptable\n", " acceptable\n", " \n", " \n", " 6\n", + " 0.736046\n", + " 0.516076\n", " acceptable\n", - " acceptable\n", + " ideal\n", " \n", " \n", " 7\n", + " -0.558970\n", + " 0.322902\n", " unacceptable\n", - " ideal\n", + " acceptable\n", " \n", " \n", " 8\n", + " 0.303819\n", + " 0.044806\n", " acceptable\n", " acceptable\n", " \n", " \n", " 9\n", - " acceptable\n", + " -0.992659\n", + " 0.943164\n", + " unacceptable\n", " acceptable\n", " \n", " \n", @@ -2019,27 +2041,27 @@ "" ], "text/plain": [ - " f_1_pred f_1_true\n", - "0 acceptable acceptable\n", - "1 acceptable acceptable\n", - "2 unacceptable unacceptable\n", - "3 acceptable acceptable\n", - "4 acceptable acceptable\n", - "5 acceptable acceptable\n", - "6 acceptable acceptable\n", - "7 unacceptable ideal\n", - "8 acceptable acceptable\n", - "9 acceptable acceptable" + " x_0 x_1 f_1_pred f_1_true\n", + "0 0.154314 -0.001118 acceptable acceptable\n", + "1 0.149196 0.013251 ideal acceptable\n", + "2 -1.333767 1.750000 ideal unacceptable\n", + "3 -0.630759 0.390813 unacceptable acceptable\n", + "4 -1.114157 1.209882 unacceptable unacceptable\n", + "5 -0.358540 0.102376 unacceptable acceptable\n", + "6 0.736046 0.516076 acceptable ideal\n", + "7 -0.558970 0.322902 unacceptable acceptable\n", + "8 0.303819 0.044806 acceptable acceptable\n", + "9 -0.992659 0.943164 unacceptable acceptable" ] }, - "execution_count": 169, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Print results\n", - "candidates[[\"f_1_pred\", \"f_1_true\"]]" + "candidates[[\"x_0\", \"x_1\", \"f_1_pred\", \"f_1_true\"]]" ] } ], diff --git a/tutorials/benchmarks/007-Benchmark_outlier_detection.ipynb b/tutorials/benchmarks/007-Benchmark_outlier_detection.ipynb index 7b65d776c..cf211a0f5 100644 --- a/tutorials/benchmarks/007-Benchmark_outlier_detection.ipynb +++ b/tutorials/benchmarks/007-Benchmark_outlier_detection.ipynb @@ -131,7 +131,7 @@ " Returns:\n", " bool: True if the output type is valid for the surrogate chosen, False otherwise\n", " \"\"\"\n", - " return isinstance(my_type, ContinuousOutput)\n", + " return isinstance(my_type, type(ContinuousOutput))\n", "\n", "\n", "class SingleTaskVariationalGPSurrogate(BotorchSurrogate1, TrainableSurrogate):\n", From 18d0f138442056b8db3c7a5a98ce53dd4b2b2dc5 Mon Sep 17 00:00:00 2001 From: gmancino Date: Mon, 12 Feb 2024 09:00:31 -0500 Subject: [PATCH 25/31] Type checking --- bofire/strategies/predictives/predictive.py | 2 +- bofire/surrogates/surrogate.py | 4 ++-- bofire/utils/naming_conventions.py | 9 +++++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/bofire/strategies/predictives/predictive.py b/bofire/strategies/predictives/predictive.py index 5365edee2..0edadfe0a 100644 --- a/bofire/strategies/predictives/predictive.py +++ b/bofire/strategies/predictives/predictive.py @@ -104,7 +104,7 @@ def predict(self, experiments: pd.DataFrame) -> pd.DataFrame: experiments=experiments, specs=self.input_preprocessing_specs ) preds, stds = self._predict(transformed) - pred_cols, sd_cols = get_column_names(self.domain.outputs) + pred_cols, sd_cols = get_column_names(self.domain.outputs) # type: ignore if stds is not None: predictions = pd.DataFrame( data=np.hstack((preds, stds)), columns=pred_cols + sd_cols diff --git a/bofire/surrogates/surrogate.py b/bofire/surrogates/surrogate.py index fddaad2d5..be61eb459 100644 --- a/bofire/surrogates/surrogate.py +++ b/bofire/surrogates/surrogate.py @@ -47,7 +47,7 @@ def predict(self, X: pd.DataFrame) -> pd.DataFrame: # predict preds, stds = self._predict(Xt) # set up column names - pred_cols, sd_cols = get_column_names(self.outputs) + pred_cols, sd_cols = get_column_names(self.outputs) # type: ignore # postprocess predictions = pd.DataFrame( data=np.hstack((preds, stds)), @@ -76,7 +76,7 @@ def predict(self, X: pd.DataFrame) -> pd.DataFrame: def validate_predictions(self, predictions: pd.DataFrame) -> pd.DataFrame: # Get the column names - pred_cols, sd_cols = get_column_names(self.outputs) + pred_cols, sd_cols = get_column_names(self.outputs) # type: ignore expected_cols = pred_cols + sd_cols check_columns = list(expected_cols) for featkey in self.outputs.get_keys(CategoricalOutput): diff --git a/bofire/utils/naming_conventions.py b/bofire/utils/naming_conventions.py index 5d1430afb..35603dad5 100644 --- a/bofire/utils/naming_conventions.py +++ b/bofire/utils/naming_conventions.py @@ -1,13 +1,14 @@ +from typing import List, Tuple + from bofire.data_models.features.api import ( - AnyOutput, CategoricalOutput, ContinuousOutput, ) -def get_column_names(outputs: AnyOutput): +def get_column_names(outputs) -> Tuple[List[str], List[str]]: pred_cols, sd_cols = [], [] - for featkey in outputs.get_keys(CategoricalOutput): + for featkey in outputs.get_keys(CategoricalOutput): # type: ignore pred_cols = pred_cols + [ f"{featkey}_{cat}_prob" for cat in outputs.get_by_key(featkey).categories # type: ignore @@ -16,7 +17,7 @@ def get_column_names(outputs: AnyOutput): f"{featkey}_{cat}_sd" for cat in outputs.get_by_key(featkey).categories # type: ignore ] - for featkey in outputs.get_keys(ContinuousOutput): + for featkey in outputs.get_keys(ContinuousOutput): # type: ignore pred_cols = pred_cols + [f"{featkey}_pred"] sd_cols = sd_cols + [f"{featkey}_sd"] From c19411cc7afcee005eb565de0490975d13abd353 Mon Sep 17 00:00:00 2001 From: gmancino Date: Mon, 12 Feb 2024 11:08:35 -0500 Subject: [PATCH 26/31] Fix constraint tests --- bofire/utils/torch_tools.py | 2 +- tests/bofire/utils/test_torch_tools.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/bofire/utils/torch_tools.py b/bofire/utils/torch_tools.py index e2dc50ac2..df973988d 100644 --- a/bofire/utils/torch_tools.py +++ b/bofire/utils/torch_tools.py @@ -223,7 +223,7 @@ def get_nonlinear_constraints(domain: Domain) -> List[Callable[[Tensor], float]] def constrained_objective2botorch( - idx: int, objective: ConstrainedObjective, eps: float = 1e-6 + idx: int, objective: ConstrainedObjective, eps: float = 1e-8 ) -> Tuple[List[Callable[[Tensor], Tensor]], List[float], int]: """Create a callable that can be used by `botorch.utils.objective.apply_constraints` to setup ouput constrained optimizations. diff --git a/tests/bofire/utils/test_torch_tools.py b/tests/bofire/utils/test_torch_tools.py index 6f337c60c..9cac4991d 100644 --- a/tests/bofire/utils/test_torch_tools.py +++ b/tests/bofire/utils/test_torch_tools.py @@ -841,14 +841,19 @@ def test_constrained_objective(): ) cs, etas, _ = constrained_objective2botorch(idx=0, objective=obj1) - x = torch.rand((50, 3)) - x /= x.sum(1).unsqueeze(1) # Convert to probabilities + x = torch.zeros((50, 3)) + x[:, 0] = torch.arange(50) / 50 true_y = (x * torch.tensor(desirability)).sum(-1) - transformed_y = torch.log(1 / torch.clamp(true_y, 1e-6, 1 - 1e-6) - 1) + transformed_y = torch.log(1 / torch.clamp(true_y, 1e-8, 1 - 1e-8) - 1) assert len(cs) == 1 assert etas[0] == 1.0 y_hat = cs[0](x) assert np.allclose(y_hat.numpy(), transformed_y.numpy()) - assert np.allclose(np.exp(-np.log(np.exp(y_hat.numpy()) + 1)), true_y) + assert ( + np.linalg.norm( + np.exp(-np.log(np.exp(y_hat.numpy()) + 1)) - true_y.numpy(), ord=np.inf + ) + <= 1e-8 + ) From 8266240ed2fd3bd15668b17f27a723c3fa78e0f4 Mon Sep 17 00:00:00 2001 From: gmancino Date: Fri, 16 Feb 2024 13:13:05 -0500 Subject: [PATCH 27/31] Fix tests and update naming convention script --- bofire/strategies/predictives/predictive.py | 23 +- bofire/surrogates/surrogate.py | 21 +- bofire/utils/naming_conventions.py | 44 +- tests/bofire/surrogates/test_mlp.py | 52 +- tests/bofire/utils/test_naming_conventions.py | 139 +++ .../Unknown_Constraint_Classification.ipynb | 1030 +++++++++-------- 6 files changed, 758 insertions(+), 551 deletions(-) create mode 100644 tests/bofire/utils/test_naming_conventions.py diff --git a/bofire/strategies/predictives/predictive.py b/bofire/strategies/predictives/predictive.py index 0edadfe0a..bd8cd6f64 100644 --- a/bofire/strategies/predictives/predictive.py +++ b/bofire/strategies/predictives/predictive.py @@ -5,13 +5,15 @@ import pandas as pd from pydantic import PositiveInt -from bofire.data_models.features.api import CategoricalOutput from bofire.data_models.strategies.api import Strategy as DataModel from bofire.data_models.types import TInputTransformSpecs from bofire.strategies.data_models.candidate import Candidate from bofire.strategies.data_models.values import InputValue, OutputValue from bofire.strategies.strategy import Strategy -from bofire.utils.naming_conventions import get_column_names +from bofire.utils.naming_conventions import ( + get_column_names, + postprocess_categorical_predictions, +) class PredictiveStrategy(Strategy): @@ -114,22 +116,7 @@ def predict(self, experiments: pd.DataFrame) -> pd.DataFrame: data=preds, columns=pred_cols, ) - for feat in self.domain.outputs.get(): - if isinstance(feat, CategoricalOutput): - predictions.insert( - loc=0, - column=f"{feat.key}_pred", - value=predictions.filter(regex=f"{feat.key}(.*)_prob") - .idxmax(1) - .str.replace(f"{feat.key}_", "") - .str.replace("_prob", "") - .values, - ) - predictions.insert( - loc=1, - column=f"{feat.key}_sd", - value=0.0, - ) + predictions = postprocess_categorical_predictions(predictions=predictions, outputs=self.domain.outputs) # type: ignore desis = self.domain.outputs(predictions, predictions=True) predictions = pd.concat((predictions, desis), axis=1) predictions.index = experiments.index diff --git a/bofire/surrogates/surrogate.py b/bofire/surrogates/surrogate.py index be61eb459..10849e66c 100644 --- a/bofire/surrogates/surrogate.py +++ b/bofire/surrogates/surrogate.py @@ -8,7 +8,10 @@ from bofire.data_models.features.api import CategoricalOutput from bofire.data_models.surrogates.api import Surrogate as DataModel from bofire.surrogates.values import PredictedValue -from bofire.utils.naming_conventions import get_column_names +from bofire.utils.naming_conventions import ( + get_column_names, + postprocess_categorical_predictions, +) class Surrogate(ABC): @@ -54,21 +57,7 @@ def predict(self, X: pd.DataFrame) -> pd.DataFrame: columns=pred_cols + sd_cols, ) # append predictions for categorical cases - for feat in self.outputs.get(CategoricalOutput): - predictions.insert( - loc=0, - column=f"{feat.key}_pred", - value=predictions.filter(regex=f"{feat.key}(.*)_prob") - .idxmax(1) - .str.replace(f"{feat.key}_", "") - .str.replace("_prob", "") - .values, - ) - predictions.insert( - loc=1, - column=f"{feat.key}_sd", - value=0.0, - ) + predictions = postprocess_categorical_predictions(predictions=predictions, outputs=self.outputs) # type: ignore # validate self.validate_predictions(predictions=predictions) # return diff --git a/bofire/utils/naming_conventions.py b/bofire/utils/naming_conventions.py index 35603dad5..fc9eaf653 100644 --- a/bofire/utils/naming_conventions.py +++ b/bofire/utils/naming_conventions.py @@ -1,12 +1,24 @@ from typing import List, Tuple +import pandas as pd + +from bofire.data_models.domain.api import Outputs from bofire.data_models.features.api import ( CategoricalOutput, ContinuousOutput, ) -def get_column_names(outputs) -> Tuple[List[str], List[str]]: +def get_column_names(outputs: Outputs) -> Tuple[List[str], List[str]]: + """ + Specifies column names for given Outputs type. + + Args: + outputs (Outputs): The Outputs object containing the individual outputs. + + Returns: + Tuple[List[str], List[str]]: A tuple containing the prediction column names and the standard deviation column names + """ pred_cols, sd_cols = [], [] for featkey in outputs.get_keys(CategoricalOutput): # type: ignore pred_cols = pred_cols + [ @@ -22,3 +34,33 @@ def get_column_names(outputs) -> Tuple[List[str], List[str]]: sd_cols = sd_cols + [f"{featkey}_sd"] return pred_cols, sd_cols + + +def postprocess_categorical_predictions(predictions: pd.DataFrame, outputs: Outputs) -> pd.DataFrame: # type: ignore + """ + Postprocess categorical predictions by finding the maximum probability location + + Args: + predictions (pd.DataFrame): The dataframe containing the predictions. + outputs (Outputs): The Outputs object containing the individual outputs. + + Returns: + predictions (pd.DataFrame): The (potentially modified) original dataframe with categorical predictions added + """ + for feat in outputs.get(): + if isinstance(feat, CategoricalOutput): # type: ignore + predictions.insert( + loc=0, + column=f"{feat.key}_pred", + value=predictions.filter(regex=f"{feat.key}(.*)_prob") + .idxmax(1) + .str.replace(f"{feat.key}_", "") + .str.replace("_prob", "") + .values, + ) + predictions.insert( + loc=1, + column=f"{feat.key}_sd", + value=0.0, + ) + return predictions diff --git a/tests/bofire/surrogates/test_mlp.py b/tests/bofire/surrogates/test_mlp.py index 67708aff1..ffe71408f 100644 --- a/tests/bofire/surrogates/test_mlp.py +++ b/tests/bofire/surrogates/test_mlp.py @@ -7,13 +7,19 @@ import bofire.surrogates.api as surrogates from bofire.benchmarks.single import Himmelblau -from bofire.data_models.domain.api import Inputs, Outputs +from bofire.data_models.domain.api import Domain, Inputs, Outputs from bofire.data_models.features.api import ( CategoricalInput, + CategoricalOutput, ContinuousInput, ContinuousOutput, ) -from bofire.data_models.surrogates.api import RegressionMLPEnsemble, ScalerEnum +from bofire.data_models.objectives.api import ConstrainedCategoricalObjective +from bofire.data_models.surrogates.api import ( + ClassificationMLPEnsemble, + RegressionMLPEnsemble, + ScalerEnum, +) from bofire.surrogates.mlp import MLP, MLPDataset, _MLPEnsemble, fit_mlp from bofire.utils.torch_tools import tkwargs @@ -259,3 +265,45 @@ def test_mlp_ensemble_fit_categorical(scaler): surrogate2.loads(dump) preds2 = surrogate2.predict(experiments) assert_frame_equal(preds, preds2) + + +def test_mlp_classification_ensemble_fit(): + # Define toy problem + inputs = Inputs( + features=[ + ContinuousInput(key="x_1", bounds=(-1, 1)), + ContinuousInput(key="x_2", bounds=(-1, 1)), + ] + ) + outputs = Outputs( + features=[ + CategoricalOutput( + key="f_1", + categories=["unacceptable", "acceptable"], + objective=ConstrainedCategoricalObjective( + categories=["unacceptable", "acceptable"], + desirability=[False, True], + ), + ) + ] + ) + domain = Domain(inputs=inputs, outputs=outputs) + samples = domain.inputs.sample(10) + samples["f_1"] = "unacceptable" + samples.loc[samples["x_1"] < 0.0, "f_1"] = "acceptable" + samples["valid_f_1"] = 1 + ens = ClassificationMLPEnsemble( + inputs=domain.inputs, + outputs=domain.outputs, + n_estimators=2, + n_epochs=5, + ) + surrogate = surrogates.map(ens) + surrogate.fit(experiments=samples) + + preds = surrogate.predict(samples) + dump = surrogate.dumps() + surrogate2 = surrogates.map(ens) + surrogate2.loads(dump) + preds2 = surrogate2.predict(samples) + assert_frame_equal(preds, preds2) diff --git a/tests/bofire/utils/test_naming_conventions.py b/tests/bofire/utils/test_naming_conventions.py new file mode 100644 index 000000000..a86f5637b --- /dev/null +++ b/tests/bofire/utils/test_naming_conventions.py @@ -0,0 +1,139 @@ +import pandas as pd +import pytest + +from bofire.data_models.domain.api import Outputs +from bofire.data_models.features.api import CategoricalOutput, ContinuousOutput +from bofire.data_models.objectives.api import ( + ConstrainedCategoricalObjective, + MinimizeObjective, +) +from bofire.utils.naming_conventions import ( + get_column_names, + postprocess_categorical_predictions, +) + +continuous_output = ContinuousOutput(key="cont", objective=MinimizeObjective(w=1)) +categorical_output = CategoricalOutput( + key="cat", + categories=["alpha", "beta"], + objective=ConstrainedCategoricalObjective( + categories=["alpha", "beta"], desirability=[True, False] + ), +) +predictions = pd.DataFrame( + data=[[0.8, 0.2, 1.5, 1e-3, 1e-2, 1e-1]], + columns=[ + "cat_alpha_prob", + "cat_beta_prob", + "cont_pred", + "cat_alpha_sd", + "cat_beta_sd", + "cont_sd", + ], +) + + +@pytest.mark.parametrize( + "output_features, expected_names", + [ + ([continuous_output], ["cont_pred", "cont_sd"]), + ( + [categorical_output], + ["cat_alpha_prob", "cat_beta_prob", "cat_alpha_sd", "cat_beta_sd"], + ), + ( + [continuous_output, categorical_output], + [ + "cat_alpha_prob", + "cat_beta_prob", + "cont_pred", + "cat_alpha_sd", + "cat_beta_sd", + "cont_sd", + ], + ), + ( + [categorical_output, continuous_output], + [ + "cat_alpha_prob", + "cat_beta_prob", + "cont_pred", + "cat_alpha_sd", + "cat_beta_sd", + "cont_sd", + ], + ), + ], +) +def test_get_column_names(output_features, expected_names): + test_outputs = Outputs(features=output_features) + pred_cols, sd_cols = get_column_names(test_outputs) + assert pred_cols + sd_cols == expected_names + + +@pytest.mark.parametrize( + "output_features, input_names, final_names", + [ + ([continuous_output], ["cont_pred", "cont_sd"], ["cont_pred", "cont_sd"]), + ( + [categorical_output], + ["cat_alpha_prob", "cat_beta_prob", "cat_alpha_sd", "cat_beta_sd"], + [ + "cat_pred", + "cat_sd", + "cat_alpha_prob", + "cat_beta_prob", + "cat_alpha_sd", + "cat_beta_sd", + ], + ), + ( + [continuous_output, categorical_output], + [ + "cat_alpha_prob", + "cat_beta_prob", + "cont_pred", + "cat_alpha_sd", + "cat_beta_sd", + "cont_sd", + ], + [ + "cat_pred", + "cat_sd", + "cat_alpha_prob", + "cat_beta_prob", + "cont_pred", + "cat_alpha_sd", + "cat_beta_sd", + "cont_sd", + ], + ), + ( + [categorical_output, continuous_output], + [ + "cat_alpha_prob", + "cat_beta_prob", + "cont_pred", + "cat_alpha_sd", + "cat_beta_sd", + "cont_sd", + ], + [ + "cat_pred", + "cat_sd", + "cat_alpha_prob", + "cat_beta_prob", + "cont_pred", + "cat_alpha_sd", + "cat_beta_sd", + "cont_sd", + ], + ), + ], +) +def test_postprocess_categorical_predictions(output_features, input_names, final_names): + test_outputs = Outputs(features=output_features) + updated_preds = postprocess_categorical_predictions( + predictions=predictions[input_names], outputs=test_outputs + ) + assert updated_preds.columns.tolist() == final_names diff --git a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb index 1ac43bf71..446a4cb1e 100644 --- a/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb +++ b/tutorials/basic_examples/Unknown_Constraint_Classification.ipynb @@ -15,18 +15,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 12, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ "# Import packages\n", "import bofire.strategies.api as strategies\n", @@ -49,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -75,7 +66,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -110,63 +101,63 @@ " \n", " \n", " 0\n", - " -0.350619\n", - " 0.504052\n", - " 1\n", - " 16.349300\n", + " -0.073213\n", + " -0.173816\n", + " 0\n", + " 4.362194\n", " acceptable\n", - " 1.009714\n", + " 0.007990\n", " \n", " \n", " 1\n", - " 0.115844\n", - " 0.947509\n", + " -0.558842\n", + " 0.940771\n", " 0\n", - " 88.034069\n", - " ideal\n", - " 0.005132\n", + " 41.927069\n", + " acceptable\n", + " 0.005296\n", " \n", " \n", " 2\n", - " -0.556177\n", - " -0.459197\n", - " 1\n", - " 61.485379\n", - " acceptable\n", - " 1.000490\n", + " -1.663016\n", + " -0.422555\n", + " 0\n", + " 1023.539326\n", + " unacceptable\n", + " 0.005456\n", " \n", " \n", " 3\n", - " -1.635584\n", - " 0.905708\n", - " 1\n", - " 320.033865\n", + " -1.286946\n", + " 1.150658\n", + " 0\n", + " 30.790510\n", " unacceptable\n", - " 1.006897\n", + " 0.007460\n", " \n", " \n", " 4\n", - " -1.474244\n", - " 0.855846\n", + " 0.756959\n", + " 0.342372\n", " 0\n", - " 179.715811\n", - " unacceptable\n", - " 0.005518\n", + " 5.377383\n", + " ideal\n", + " 0.003363\n", " \n", " \n", "\n", "" ], "text/plain": [ - " x_0 x_1 x_3 f_0 f_1 f_2\n", - "0 -0.350619 0.504052 1 16.349300 acceptable 1.009714\n", - "1 0.115844 0.947509 0 88.034069 ideal 0.005132\n", - "2 -0.556177 -0.459197 1 61.485379 acceptable 1.000490\n", - "3 -1.635584 0.905708 1 320.033865 unacceptable 1.006897\n", - "4 -1.474244 0.855846 0 179.715811 unacceptable 0.005518" + " x_0 x_1 x_3 f_0 f_1 f_2\n", + "0 -0.073213 -0.173816 0 4.362194 acceptable 0.007990\n", + "1 -0.558842 0.940771 0 41.927069 acceptable 0.005296\n", + "2 -1.663016 -0.422555 0 1023.539326 unacceptable 0.005456\n", + "3 -1.286946 1.150658 0 30.790510 unacceptable 0.007460\n", + "4 0.756959 0.342372 0 5.377383 ideal 0.003363" ] }, - "execution_count": 3, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -198,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -221,247 +212,247 @@ "showlegend": true, "type": "scatter", "x": [ - -0.35061926427297285, - -0.5561765809009358, - 0.42159282278448273, - -0.14506555699220236, - 1.0353008249274431, - 0.3454444836674355, - -0.7680788452262832, - -0.9933224816259388, - -1.0750262968274997, - 0.7461911958641041, - -0.7476273482445326, - 0.4462467901093472, - -0.07013883971177926, - -0.7613373765433603, - -0.20793189372038778, - -0.7126880256050545, - -0.4298026069294787, - -0.2926997442891148, - -0.48304637342466483, - 0.08633019936772834, - -0.275665340089877, - 0.13163143755446582, - -0.032200237982807245, - 0.630156543011732, - -0.16589247459195278, - -0.11312470103711392, - 1.0905345858791655, - 0.20275804836531774, - 0.3338996521298583, - 0.424736503060791, - 0.09844039335599075 + -0.07321318049019965, + -0.5588419893731316, + -0.8985230467616341, + -1.0292058348995758, + 0.35244862279577305, + 0.9528048806523377, + -1.1646259264294618, + -0.042087482980167845, + 0.4039809383037536, + -1.0307240210015078, + -0.1560715060200324, + -1.1049138766739544, + -0.06810459837939531, + -0.34754085919972844, + 0.9283633902161212, + -1.0436371217335334, + -1.0546493075448813, + 0.02494438460292603, + -0.1710270775439422, + -0.0013368667688371527, + -0.392453180338423, + -0.8647353336778001, + -1.0847911352150188, + -0.42142878734689226, + 1.2496701479105563, + -1.2948654582931687, + -0.3051286802687181, + 0.3662905154646685, + 0.3074010270419505, + -0.7803326707031792, + 0.03260906153286447, + 0.6109573086447044, + -0.43864798001160876, + 0.5438019840565453, + -0.8265360983870151, + 0.37264481028879093, + -0.3281250548835608, + -0.4167203675788893, + -0.0657892829207547 ], "xaxis": "x", "y": [ - 0.5040523211718617, - -0.45919668429740934, - -0.43223105338024803, - -0.5989964538201493, - -0.5649500247669095, - -0.5355390267695426, - 0.9550083378879219, - 0.39901716758574235, - -0.6999593792729817, - -0.6234513037416309, - 0.6181801617811384, - -0.8593732637207796, - -0.7322677185785165, - -1.1687864901526548, - 0.20233627906121776, - -1.1278268495317838, - 0.5197761747103247, - 1.0172863322035508, - -0.23399463937541132, - 0.31703732293322284, - -1.2846423115547891, - 0.4365633647242215, - 0.8598355161979865, - -1.2156493776606054, - -0.6155534138636449, - -0.15102340108948908, - -0.007804401336033662, - -1.1909885140068746, - -1.0674680087818904, - -0.9005255478079787, - -1.3402079733707999 + -0.17381593448062027, + 0.9407713973498608, + 0.020726174078847137, + 0.913310979582409, + -0.3661492308705969, + -0.09940322024898673, + 0.5508334055831456, + -0.05418903548597043, + -0.8227839556914979, + 0.1457624938904325, + 1.147833506871621, + -0.6869957196475109, + -1.4056893033808622, + 0.0718887572832887, + -0.6790577296916382, + -0.15584601306358925, + 0.4876389137502479, + 0.7475805215463804, + -0.6273661151308012, + 0.03222974761832331, + -0.5555262316869118, + 0.8877236548358565, + -0.2629103743044927, + 0.19134278742538058, + 0.013239929092230263, + 0.5397028293175845, + 0.2035004923467727, + -1.2377668011422878, + -0.6462284326270806, + 0.05582182979174766, + 0.236334433477269, + -1.103983091579659, + 1.0587162308079088, + -1.0355728036529377, + -0.7568147674837169, + -0.45150202629320724, + -0.9243292819936946, + -1.1460364580489035, + 1.0692357670790393 ], "yaxis": "y" }, { - "hovertemplate": "f_1=ideal
x_0=%{x}
x_1=%{y}", - "legendgroup": "ideal", + "hovertemplate": "f_1=unacceptable
x_0=%{x}
x_1=%{y}", + "legendgroup": "unacceptable", "marker": { "color": "#EF553B", "symbol": "circle" }, "mode": "markers", - "name": "ideal", + "name": "unacceptable", "orientation": "v", "showlegend": true, "type": "scatter", "x": [ - 0.11584353456460073, - 0.1724670780831814, - 0.539818540333588, - 0.15873339655513652, - 1.2444600831177084, - 0.999781463467607, - 1.1787431301713105, - 0.20622111629978224, - 1.1595275447276538 + -1.6630162760045795, + -1.286946137400366, + -1.1931846597891524, + 1.2146613347296635, + 1.7000891922760921, + 1.4872678080585287, + -1.4804637110018204, + 1.1699124133484027, + 1.516602066498184, + -0.9909446226396376, + -1.5378141242498566, + 1.5954528965403396, + -1.5294381480146355, + 1.4594978431768837, + -0.3577423022945294, + 1.620229456555938, + -1.664431954673963, + 1.4160944727551756, + -1.5482048479807955, + 1.1924918084084961, + -1.731761227261967, + 1.404936688277174, + 1.0947766973891762, + -0.519276259387774, + 0.6730885610273702, + 0.915988953919344, + 1.1752917242186087, + 1.3637226302449381, + 1.6785214915139632, + 1.58105290419071, + 1.4691098470484172, + -1.6563050260430705, + -1.270415941962483, + 0.2099876083549137, + -1.1662504899545398, + 1.518607456866964, + 0.8441218222783702, + -1.559056742958962, + -1.5939434136412012, + 1.3689384384883678, + -1.6053436462630795, + 0.008220829197049806 ], "xaxis": "x", "y": [ - 0.9475093155136509, - 1.1526019287770612, - 0.3540271307917444, - 1.0991475451214328, - 0.618585769140207, - 0.135534043926683, - 0.676481578469553, - 1.2930076238613046, - 0.09014125654752259 + -0.422554511553753, + 1.150657542953963, + -1.1410548182715137, + -1.3472181218753596, + 0.11396236358195266, + 1.0691736792669388, + -1.5004165302144574, + -1.4892754744616727, + 1.3480308717512366, + 1.3080579840997544, + -0.685289399058566, + 1.2498147025055397, + 0.9462210918878791, + -1.1355527818651603, + 1.7213059768850618, + -0.7557851364023072, + -0.6890348617836524, + 1.6603265875448114, + 1.6760989802647415, + -1.2484894932822872, + 1.1512686965592982, + -0.20575815864446456, + -1.703770113900696, + -1.3505322003913427, + -1.248301135898829, + 1.0888210192742083, + 1.6211428734674747, + 0.7952815061305452, + -0.04727366650452258, + 0.9285271910202479, + -0.16942296469152174, + 0.6418882271459929, + -0.7121367703504429, + -1.4394070766626794, + -1.0653570065486102, + -1.7155959760444761, + 1.6390645504126615, + 0.7405404555502035, + 0.7745426501661656, + -0.892754487596001, + -0.20459051213543078, + -1.7000574034009923 ], "yaxis": "y" }, { - "hovertemplate": "f_1=unacceptable
x_0=%{x}
x_1=%{y}", - "legendgroup": "unacceptable", + "hovertemplate": "f_1=ideal
x_0=%{x}
x_1=%{y}", + "legendgroup": "ideal", "marker": { "color": "#00cc96", "symbol": "circle" }, "mode": "markers", - "name": "unacceptable", + "name": "ideal", "orientation": "v", "showlegend": true, "type": "scatter", "x": [ - -1.635584388534313, - -1.4742444296046657, - -0.15729314694102903, - 1.5070853560854864, - 1.2101965364760718, - -1.5831812018666418, - -1.3957357203071692, - 1.613078214961829, - 1.4254239036801257, - -1.2787167270583444, - -1.4023268999370018, - 1.7378096362285174, - 0.6296633849190485, - 1.6560265107086973, - -1.4604637134929277, - 1.3954194906359776, - -0.3273878127686993, - 1.5426483753054665, - 0.7725626061807955, - 1.5974661747883538, - 1.688121615826073, - -1.5816687013750201, - -1.3669080002392295, - -0.933603434080167, - 0.4822506956934074, - -1.4195045992551758, - -1.5862157185428094, - -0.6113354136054985, - -0.04870290118977216, - -1.3475465180331827, - 1.4165738521389044, - -0.7433524285306379, - -0.9790834992740071, - 0.4177249330842745, - -1.5065150315297096, - 0.8379155379744967, - -1.6748335190020076, - 1.732364494283424, - 1.6746380993561396, - -1.6486196321522595, - -1.6923451290787777, - 1.0735234059299241, - 1.6026120196380926, - 1.6368331859404996, - 1.344210315200074, - 0.5756493888406227, - 1.0955596911970198, - 1.1989776196979713, - 1.6443171102633003, - -1.2980309445061669, - 1.0537367958501087, - -1.2500086281626677, - 0.38230323278127143, - -1.6985134148647894, - 1.7156174130125836, - 1.6853737833769693, - -0.294642574709699, - 1.1650211575724394, - -0.6805240325421535, - -0.8809593972969152 + 0.7569591284188575, + 0.93010764779936, + 0.5603523538456492, + 0.9792953410787515, + 0.10392338145894309, + 0.8683091101696028, + 0.26744273616884895, + 0.6855353177932879, + 0.10899662399819876, + 0.6268617084110515, + 1.153891059773529, + 0.13595914708052836, + 0.6246997131763035, + 0.2709624813181275, + 1.1659237387592172, + 0.2416269711683232, + 0.25385854534216046, + 0.2922839428499806, + 0.07904362548010702 ], "xaxis": "x", "y": [ - 0.905708248434292, - 0.8558461621953355, - -1.682916252803905, - -0.3787532353125016, - 0.8582593276903339, - -0.5950896738838634, - -0.5940137871920919, - -1.1106762625595035, - 1.749844877144934, - -0.7745816878526347, - -1.0080154738799716, - 0.26464782707414836, - 1.469231740752055, - -1.6208934974410696, - -0.07648992450843961, - -1.0433090911846108, - -1.4790774655942789, - -1.0982882152772908, - -1.241192967822959, - -0.5220349756706821, - -0.3558204214616887, - 0.2092233051309027, - -1.3578044601398707, - 1.7264947346866033, - 1.5894671117819081, - -1.744256531724441, - -0.7461734234273045, - 1.7424015284023127, - -1.5287936583164727, - -0.6814323944749845, - 1.266759433319939, - 1.5399709474328813, - -1.5645549489253556, - -1.5196196260358157, - 0.44194322574419287, - 1.3477204637842641, - 1.7295610322135961, - -1.1974130031498986, - 0.49951459600915626, - 0.11526237063394973, - -1.7248874050370138, - 1.5908887950885897, - 0.27689415663406214, - -0.31530244091173354, - -1.6932092535195582, - 1.3482667568433095, - 1.5726556777198932, - -0.9488385694367101, - 0.2919559196715964, - 1.7065717592720269, - 1.5139353043217532, - -1.6001644269256965, - -1.7115693621588608, - -0.38867144542644105, - 0.6923493546148185, - 0.1427173148876928, - 1.4870813273250616, - -0.9737357252558907, - 1.2417204272219675, - 1.740049552575495 + 0.3423724207639953, + 0.6464112742365464, + 1.1292454853281328, + 0.3232307548157345, + 0.6755389021646168, + 0.5375652067425829, + 1.2131514639239276, + 0.9910921631319534, + 0.6557939575342395, + 1.088247143412076, + 0.2244959943276288, + 0.8790756393547463, + 0.3947077652057054, + 0.5872178792073877, + 0.5167616872242884, + 1.2043027308846912, + 0.9056486121924587, + 1.1191673592578, + 1.0616664303067846 ], "yaxis": "y" } @@ -1382,7 +1373,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -1434,7 +1425,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -1465,19 +1456,19 @@ " \n", " \n", " 0\n", - " 0.72\n", - " 0.72\n", + " 0.725\n", + " 0.725\n", " \n", " \n", "\n", "" ], "text/plain": [ - " ACCURACY F1\n", - "0 0.72 0.72" + " ACCURACY F1\n", + "0 0.725 0.725" ] }, - "execution_count": 6, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -1489,7 +1480,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -1520,8 +1511,8 @@ " \n", " \n", " 0\n", - " 0.65\n", - " 0.65\n", + " 0.56\n", + " 0.56\n", " \n", " \n", "\n", @@ -1529,10 +1520,10 @@ ], "text/plain": [ " ACCURACY F1\n", - "0 0.65 0.65" + "0 0.56 0.56" ] }, - "execution_count": 7, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -1551,9 +1542,20 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 19, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\ProgramData\\Anaconda3\\envs\\bofire-env\\lib\\site-packages\\botorch\\optim\\fit.py:102: OptimizationWarning:\n", + "\n", + "`scipy_minimize` terminated with status 3, displaying original message from `scipy.optimize.minimize`: ABNORMAL_TERMINATION_IN_LNSRCH\n", + "\n" + ] + } + ], "source": [ "from bofire.data_models.acquisition_functions.api import qEI\n", "from bofire.data_models.strategies.api import SoboStrategy\n", @@ -1577,7 +1579,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -1642,213 +1644,213 @@ " \n", " \n", " 0\n", - " 0.154314\n", - " -0.001118\n", + " 0.445121\n", + " 0.203488\n", " 1\n", " acceptable\n", " 0.0\n", - " 0.425532\n", - " 1.004986\n", - " 0.639645\n", - " 0.357941\n", - " 0.002414\n", - " 6.588467\n", - " 0.003142\n", - " 0.476779\n", - " 0.473279\n", - " 0.004336\n", - " -0.357941\n", - " 0.499698\n", - " 1.644631\n", + " -0.193104\n", + " 1.006308\n", + " 0.104767\n", + " 0.144256\n", + " 0.750977\n", + " 4.968438\n", + " 0.003073\n", + " 0.203727\n", + " 0.219485\n", + " 0.421818\n", + " -0.144256\n", + " 0.407216\n", + " 1.111075\n", " \n", " \n", " 1\n", - " 0.149196\n", - " 0.013251\n", + " -1.309524\n", + " 1.750000\n", " 0\n", - " ideal\n", + " unacceptable\n", " 0.0\n", - " 0.465530\n", - " 0.005402\n", - " 0.795662\n", - " 0.204113\n", - " 0.000226\n", - " 3.777311\n", - " 0.003156\n", - " 0.430231\n", - " 0.430007\n", - " 0.000242\n", - " -0.204113\n", - " 0.499972\n", - " 0.801064\n", + " 0.730642\n", + " 0.006611\n", + " 0.368515\n", + " 0.630446\n", + " 0.001039\n", + " 16.560211\n", + " 0.003667\n", + " 0.348812\n", + " 0.347461\n", + " 0.001424\n", + " -0.630446\n", + " 0.499870\n", + " 0.375126\n", " \n", " \n", " 2\n", - " -1.333767\n", + " -1.258058\n", " 1.750000\n", - " 0\n", - " ideal\n", + " 1\n", + " acceptable\n", " 0.0\n", - " 0.623769\n", - " 0.007608\n", - " 0.795045\n", - " 0.204738\n", - " 0.000217\n", - " 11.252150\n", - " 0.003407\n", - " 0.431594\n", - " 0.431392\n", - " 0.000222\n", - " -0.204738\n", - " 0.499973\n", - " 0.802653\n", + " -10.671111\n", + " 1.006274\n", + " 0.339735\n", + " 0.471522\n", + " 0.188743\n", + " 19.401399\n", + " 0.003533\n", + " 0.410371\n", + " 0.431353\n", + " 0.417536\n", + " -0.471522\n", + " 0.476425\n", + " 1.346009\n", " \n", " \n", " 3\n", - " -0.630759\n", - " 0.390813\n", + " 1.292107\n", + " 1.750000\n", " 1\n", - " unacceptable\n", + " acceptable\n", " 0.0\n", - " 2.985411\n", - " 1.005220\n", - " 0.638927\n", - " 0.359041\n", - " 0.002033\n", - " 6.575972\n", - " 0.003180\n", - " 0.478100\n", - " 0.475209\n", - " 0.003490\n", - " -0.359041\n", - " 0.499746\n", - " 1.644147\n", + " -5.226207\n", + " 1.005614\n", + " 0.263137\n", + " 0.002680\n", + " 0.734182\n", + " 6.172339\n", + " 0.003671\n", + " 0.427584\n", + " 0.003406\n", + " 0.425910\n", + " -0.002680\n", + " 0.409244\n", + " 1.268751\n", " \n", " \n", " 4\n", - " -1.114157\n", - " 1.209882\n", + " 0.498312\n", + " 0.254918\n", " 0\n", - " unacceptable\n", + " ideal\n", " 0.0\n", - " 8.030368\n", - " 0.006634\n", - " 0.795048\n", - " 0.204736\n", - " 0.000216\n", - " 14.069044\n", - " 0.003171\n", - " 0.431596\n", - " 0.431394\n", - " 0.000222\n", - " -0.204736\n", - " 0.499973\n", - " 0.801682\n", + " 0.202863\n", + " 0.005522\n", + " 0.363525\n", + " 0.237623\n", + " 0.398852\n", + " 5.107777\n", + " 0.003061\n", + " 0.343516\n", + " 0.236302\n", + " 0.546131\n", + " -0.237623\n", + " 0.450308\n", + " 0.369047\n", " \n", " \n", " 5\n", - " -0.358540\n", - " 0.102376\n", - " 1\n", + " -1.430995\n", + " 1.750000\n", + " 0\n", " unacceptable\n", " 0.0\n", - " 1.546216\n", - " 1.004791\n", - " 0.638009\n", - " 0.360123\n", - " 0.001868\n", - " 4.052824\n", - " 0.003159\n", - " 0.479443\n", - " 0.476816\n", - " 0.003129\n", - " -0.360123\n", - " 0.499766\n", - " 1.642799\n", + " 15.038112\n", + " 0.007034\n", + " 0.367135\n", + " 0.631857\n", + " 0.001008\n", + " 22.077586\n", + " 0.003766\n", + " 0.347731\n", + " 0.346423\n", + " 0.001392\n", + " -0.631857\n", + " 0.499874\n", + " 0.374169\n", " \n", " \n", " 6\n", - " 0.736046\n", - " 0.516076\n", - " 1\n", - " acceptable\n", + " -1.501348\n", + " 1.627440\n", + " 0\n", + " unacceptable\n", " 0.0\n", - " -0.240881\n", - " 1.004885\n", - " 0.819717\n", - " 0.173096\n", - " 0.007188\n", - " 8.248212\n", - " 0.003162\n", - " 0.331533\n", - " 0.333793\n", - " 0.014980\n", - " -0.173096\n", - " 0.499102\n", - " 1.824602\n", + " 49.375552\n", + " 0.007131\n", + " 0.364155\n", + " 0.634857\n", + " 0.000988\n", + " 21.504947\n", + " 0.003674\n", + " 0.345456\n", + " 0.344176\n", + " 0.001372\n", + " -0.634857\n", + " 0.499877\n", + " 0.371287\n", " \n", " \n", " 7\n", - " -0.558970\n", - " 0.322902\n", + " -1.468474\n", + " 1.695817\n", " 0\n", " unacceptable\n", " 0.0\n", - " 2.070167\n", - " 0.005272\n", - " 0.795045\n", - " 0.204738\n", - " 0.000217\n", - " 4.657032\n", - " 0.003146\n", - " 0.431604\n", - " 0.431400\n", - " 0.000224\n", - " -0.204738\n", - " 0.499973\n", - " 0.800317\n", + " 29.651239\n", + " 0.007100\n", + " 0.365440\n", + " 0.633564\n", + " 0.000996\n", + " 22.143859\n", + " 0.003728\n", + " 0.346425\n", + " 0.345133\n", + " 0.001380\n", + " -0.633564\n", + " 0.499875\n", + " 0.372539\n", " \n", " \n", " 8\n", - " 0.303819\n", - " 0.044806\n", - " 1\n", - " acceptable\n", + " -1.518190\n", + " 1.517966\n", + " 0\n", + " unacceptable\n", " 0.0\n", - " 0.656479\n", - " 1.005106\n", - " 0.644744\n", - " 0.351593\n", - " 0.003664\n", - " 6.172526\n", - " 0.003141\n", - " 0.468581\n", - " 0.463146\n", - " 0.007114\n", - " -0.351593\n", - " 0.499542\n", - " 1.649849\n", + " 72.020318\n", + " 0.007057\n", + " 0.363156\n", + " 0.635863\n", + " 0.000981\n", + " 18.759220\n", + " 0.003571\n", + " 0.344707\n", + " 0.343437\n", + " 0.001366\n", + " -0.635863\n", + " 0.499877\n", + " 0.370213\n", " \n", " \n", " 9\n", - " -0.992659\n", - " 0.943164\n", + " -1.459283\n", + " 1.442226\n", " 0\n", " unacceptable\n", " 0.0\n", - " 6.008622\n", - " 0.006171\n", - " 0.795048\n", - " 0.204736\n", - " 0.000216\n", - " 10.497928\n", - " 0.003149\n", - " 0.431597\n", - " 0.431395\n", - " 0.000222\n", - " -0.204736\n", - " 0.499973\n", - " 0.801219\n", + " 54.177939\n", + " 0.006770\n", + " 0.362903\n", + " 0.636116\n", + " 0.000981\n", + " 13.822711\n", + " 0.003457\n", + " 0.344511\n", + " 0.343241\n", + " 0.001366\n", + " -0.636116\n", + " 0.499877\n", + " 0.369673\n", " \n", " \n", "\n", @@ -1856,55 +1858,55 @@ ], "text/plain": [ " x_0 x_1 x_3 f_1_pred f_1_sd f_1_unacceptable_prob \\\n", - "0 0.154314 -0.001118 1 acceptable 0.0 0.425532 \n", - "1 0.149196 0.013251 0 ideal 0.0 0.465530 \n", - "2 -1.333767 1.750000 0 ideal 0.0 0.623769 \n", - "3 -0.630759 0.390813 1 unacceptable 0.0 2.985411 \n", - "4 -1.114157 1.209882 0 unacceptable 0.0 8.030368 \n", - "5 -0.358540 0.102376 1 unacceptable 0.0 1.546216 \n", - "6 0.736046 0.516076 1 acceptable 0.0 -0.240881 \n", - "7 -0.558970 0.322902 0 unacceptable 0.0 2.070167 \n", - "8 0.303819 0.044806 1 acceptable 0.0 0.656479 \n", - "9 -0.992659 0.943164 0 unacceptable 0.0 6.008622 \n", + "0 0.445121 0.203488 1 acceptable 0.0 -0.193104 \n", + "1 -1.309524 1.750000 0 unacceptable 0.0 0.730642 \n", + "2 -1.258058 1.750000 1 acceptable 0.0 -10.671111 \n", + "3 1.292107 1.750000 1 acceptable 0.0 -5.226207 \n", + "4 0.498312 0.254918 0 ideal 0.0 0.202863 \n", + "5 -1.430995 1.750000 0 unacceptable 0.0 15.038112 \n", + "6 -1.501348 1.627440 0 unacceptable 0.0 49.375552 \n", + "7 -1.468474 1.695817 0 unacceptable 0.0 29.651239 \n", + "8 -1.518190 1.517966 0 unacceptable 0.0 72.020318 \n", + "9 -1.459283 1.442226 0 unacceptable 0.0 54.177939 \n", "\n", " f_1_acceptable_prob f_1_ideal_prob f_0_pred f_2_pred \\\n", - "0 1.004986 0.639645 0.357941 0.002414 \n", - "1 0.005402 0.795662 0.204113 0.000226 \n", - "2 0.007608 0.795045 0.204738 0.000217 \n", - "3 1.005220 0.638927 0.359041 0.002033 \n", - "4 0.006634 0.795048 0.204736 0.000216 \n", - "5 1.004791 0.638009 0.360123 0.001868 \n", - "6 1.004885 0.819717 0.173096 0.007188 \n", - "7 0.005272 0.795045 0.204738 0.000217 \n", - "8 1.005106 0.644744 0.351593 0.003664 \n", - "9 0.006171 0.795048 0.204736 0.000216 \n", + "0 1.006308 0.104767 0.144256 0.750977 \n", + "1 0.006611 0.368515 0.630446 0.001039 \n", + "2 1.006274 0.339735 0.471522 0.188743 \n", + "3 1.005614 0.263137 0.002680 0.734182 \n", + "4 0.005522 0.363525 0.237623 0.398852 \n", + "5 0.007034 0.367135 0.631857 0.001008 \n", + "6 0.007131 0.364155 0.634857 0.000988 \n", + "7 0.007100 0.365440 0.633564 0.000996 \n", + "8 0.007057 0.363156 0.635863 0.000981 \n", + "9 0.006770 0.362903 0.636116 0.000981 \n", "\n", " f_1_unacceptable_sd f_1_acceptable_sd f_1_ideal_sd f_0_sd f_2_sd \\\n", - "0 6.588467 0.003142 0.476779 0.473279 0.004336 \n", - "1 3.777311 0.003156 0.430231 0.430007 0.000242 \n", - "2 11.252150 0.003407 0.431594 0.431392 0.000222 \n", - "3 6.575972 0.003180 0.478100 0.475209 0.003490 \n", - "4 14.069044 0.003171 0.431596 0.431394 0.000222 \n", - "5 4.052824 0.003159 0.479443 0.476816 0.003129 \n", - "6 8.248212 0.003162 0.331533 0.333793 0.014980 \n", - "7 4.657032 0.003146 0.431604 0.431400 0.000224 \n", - "8 6.172526 0.003141 0.468581 0.463146 0.007114 \n", - "9 10.497928 0.003149 0.431597 0.431395 0.000222 \n", + "0 4.968438 0.003073 0.203727 0.219485 0.421818 \n", + "1 16.560211 0.003667 0.348812 0.347461 0.001424 \n", + "2 19.401399 0.003533 0.410371 0.431353 0.417536 \n", + "3 6.172339 0.003671 0.427584 0.003406 0.425910 \n", + "4 5.107777 0.003061 0.343516 0.236302 0.546131 \n", + "5 22.077586 0.003766 0.347731 0.346423 0.001392 \n", + "6 21.504947 0.003674 0.345456 0.344176 0.001372 \n", + "7 22.143859 0.003728 0.346425 0.345133 0.001380 \n", + "8 18.759220 0.003571 0.344707 0.343437 0.001366 \n", + "9 13.822711 0.003457 0.344511 0.343241 0.001366 \n", "\n", " f_0_des f_2_des f_1_des \n", - "0 -0.357941 0.499698 1.644631 \n", - "1 -0.204113 0.499972 0.801064 \n", - "2 -0.204738 0.499973 0.802653 \n", - "3 -0.359041 0.499746 1.644147 \n", - "4 -0.204736 0.499973 0.801682 \n", - "5 -0.360123 0.499766 1.642799 \n", - "6 -0.173096 0.499102 1.824602 \n", - "7 -0.204738 0.499973 0.800317 \n", - "8 -0.351593 0.499542 1.649849 \n", - "9 -0.204736 0.499973 0.801219 " + "0 -0.144256 0.407216 1.111075 \n", + "1 -0.630446 0.499870 0.375126 \n", + "2 -0.471522 0.476425 1.346009 \n", + "3 -0.002680 0.409244 1.268751 \n", + "4 -0.237623 0.450308 0.369047 \n", + "5 -0.631857 0.499874 0.374169 \n", + "6 -0.634857 0.499877 0.371287 \n", + "7 -0.633564 0.499875 0.372539 \n", + "8 -0.635863 0.499877 0.370213 \n", + "9 -0.636116 0.499877 0.369673 " ] }, - "execution_count": 9, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -1925,7 +1927,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -1935,7 +1937,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -1968,73 +1970,73 @@ " \n", " \n", " 0\n", - " 0.154314\n", - " -0.001118\n", - " acceptable\n", + " 0.445121\n", + " 0.203488\n", " acceptable\n", + " ideal\n", " \n", " \n", " 1\n", - " 0.149196\n", - " 0.013251\n", - " ideal\n", - " acceptable\n", + " -1.309524\n", + " 1.750000\n", + " unacceptable\n", + " unacceptable\n", " \n", " \n", " 2\n", - " -1.333767\n", + " -1.258058\n", " 1.750000\n", - " ideal\n", + " acceptable\n", " unacceptable\n", " \n", " \n", " 3\n", - " -0.630759\n", - " 0.390813\n", - " unacceptable\n", + " 1.292107\n", + " 1.750000\n", " acceptable\n", + " unacceptable\n", " \n", " \n", " 4\n", - " -1.114157\n", - " 1.209882\n", - " unacceptable\n", - " unacceptable\n", + " 0.498312\n", + " 0.254918\n", + " ideal\n", + " ideal\n", " \n", " \n", " 5\n", - " -0.358540\n", - " 0.102376\n", + " -1.430995\n", + " 1.750000\n", + " unacceptable\n", " unacceptable\n", - " acceptable\n", " \n", " \n", " 6\n", - " 0.736046\n", - " 0.516076\n", - " acceptable\n", - " ideal\n", + " -1.501348\n", + " 1.627440\n", + " unacceptable\n", + " unacceptable\n", " \n", " \n", " 7\n", - " -0.558970\n", - " 0.322902\n", + " -1.468474\n", + " 1.695817\n", + " unacceptable\n", " unacceptable\n", - " acceptable\n", " \n", " \n", " 8\n", - " 0.303819\n", - " 0.044806\n", - " acceptable\n", - " acceptable\n", + " -1.518190\n", + " 1.517966\n", + " unacceptable\n", + " unacceptable\n", " \n", " \n", " 9\n", - " -0.992659\n", - " 0.943164\n", + " -1.459283\n", + " 1.442226\n", + " unacceptable\n", " unacceptable\n", - " acceptable\n", " \n", " \n", "\n", @@ -2042,19 +2044,19 @@ ], "text/plain": [ " x_0 x_1 f_1_pred f_1_true\n", - "0 0.154314 -0.001118 acceptable acceptable\n", - "1 0.149196 0.013251 ideal acceptable\n", - "2 -1.333767 1.750000 ideal unacceptable\n", - "3 -0.630759 0.390813 unacceptable acceptable\n", - "4 -1.114157 1.209882 unacceptable unacceptable\n", - "5 -0.358540 0.102376 unacceptable acceptable\n", - "6 0.736046 0.516076 acceptable ideal\n", - "7 -0.558970 0.322902 unacceptable acceptable\n", - "8 0.303819 0.044806 acceptable acceptable\n", - "9 -0.992659 0.943164 unacceptable acceptable" + "0 0.445121 0.203488 acceptable ideal\n", + "1 -1.309524 1.750000 unacceptable unacceptable\n", + "2 -1.258058 1.750000 acceptable unacceptable\n", + "3 1.292107 1.750000 acceptable unacceptable\n", + "4 0.498312 0.254918 ideal ideal\n", + "5 -1.430995 1.750000 unacceptable unacceptable\n", + "6 -1.501348 1.627440 unacceptable unacceptable\n", + "7 -1.468474 1.695817 unacceptable unacceptable\n", + "8 -1.518190 1.517966 unacceptable unacceptable\n", + "9 -1.459283 1.442226 unacceptable unacceptable" ] }, - "execution_count": 12, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } From e0095091693b94f833887a8c3e367612adb1e1c5 Mon Sep 17 00:00:00 2001 From: gmancino Date: Mon, 26 Feb 2024 09:33:32 -0500 Subject: [PATCH 28/31] Remove comments from MLP file --- bofire/data_models/surrogates/mlp.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/bofire/data_models/surrogates/mlp.py b/bofire/data_models/surrogates/mlp.py index 451aa4f7a..a73327376 100644 --- a/bofire/data_models/surrogates/mlp.py +++ b/bofire/data_models/surrogates/mlp.py @@ -23,17 +23,7 @@ class MLPEnsemble(TrainableBotorchSurrogate): weight_decay: Annotated[float, Field(ge=0.0)] = 0.0 subsample_fraction: Annotated[float, Field(gt=0.0)] = 1.0 shuffle: bool = True - scaler: ScalerEnum = ScalerEnum.NORMALIZE - - # @classmethod - # def is_output_implemented(cls, my_type: str) -> bool: - # """Abstract method to check output type for surrogate models - # Args: - # my_type: continuous or categorical output - # Returns: - # bool: True if the output type is valid for the surrogate chosen, False otherwise - # """ - # return isinstance(my_type, (CategoricalOutput, ContinuousOutput)) + scalar: Literal[ScalerEnum.IDENTITY] = ScalerEnum.IDENTITY class RegressionMLPEnsemble(MLPEnsemble): From f025988c6da789270e190e78ff7f32a06efad1a5 Mon Sep 17 00:00:00 2001 From: gmancino Date: Mon, 26 Feb 2024 09:39:36 -0500 Subject: [PATCH 29/31] Fix tests --- bofire/data_models/surrogates/mlp.py | 3 ++- tests/bofire/data_models/specs/surrogates.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/bofire/data_models/surrogates/mlp.py b/bofire/data_models/surrogates/mlp.py index a73327376..2d5759ed5 100644 --- a/bofire/data_models/surrogates/mlp.py +++ b/bofire/data_models/surrogates/mlp.py @@ -23,7 +23,8 @@ class MLPEnsemble(TrainableBotorchSurrogate): weight_decay: Annotated[float, Field(ge=0.0)] = 0.0 subsample_fraction: Annotated[float, Field(gt=0.0)] = 1.0 shuffle: bool = True - scalar: Literal[ScalerEnum.IDENTITY] = ScalerEnum.IDENTITY + scaler: Literal[ScalerEnum.IDENTITY] = ScalerEnum.IDENTITY + output_scaler: Literal[ScalerEnum.IDENTITY] = ScalerEnum.IDENTITY class RegressionMLPEnsemble(MLPEnsemble): diff --git a/tests/bofire/data_models/specs/surrogates.py b/tests/bofire/data_models/specs/surrogates.py index be8c9ebe3..6c8b8a9fa 100644 --- a/tests/bofire/data_models/specs/surrogates.py +++ b/tests/bofire/data_models/specs/surrogates.py @@ -183,8 +183,8 @@ "weight_decay": 0.0, "subsample_fraction": 1.0, "shuffle": True, - "scaler": ScalerEnum.NORMALIZE, - "output_scaler": ScalerEnum.STANDARDIZE, + "scaler": ScalerEnum.IDENTITY, + "output_scaler": ScalerEnum.IDENTITY, "input_preprocessing_specs": {}, "dump": None, "hyperconfig": None, @@ -215,8 +215,8 @@ "weight_decay": 0.0, "subsample_fraction": 1.0, "shuffle": True, - "scaler": ScalerEnum.NORMALIZE, - "output_scaler": ScalerEnum.STANDARDIZE, + "scaler": ScalerEnum.IDENTITY, + "output_scaler": ScalerEnum.IDENTITY, "input_preprocessing_specs": {}, "dump": None, "hyperconfig": None, @@ -249,8 +249,8 @@ "weight_decay": 0.0, "subsample_fraction": 1.0, "shuffle": True, - "scaler": ScalerEnum.NORMALIZE, - "output_scaler": ScalerEnum.STANDARDIZE, + "scaler": ScalerEnum.IDENTITY, + "output_scaler": ScalerEnum.IDENTITY, "input_preprocessing_specs": {}, "dump": None, "hyperconfig": None, @@ -281,8 +281,8 @@ "weight_decay": 0.0, "subsample_fraction": 1.0, "shuffle": True, - "scaler": ScalerEnum.NORMALIZE, - "output_scaler": ScalerEnum.STANDARDIZE, + "scaler": ScalerEnum.IDENTITY, + "output_scaler": ScalerEnum.IDENTITY, "input_preprocessing_specs": {}, "dump": None, "hyperconfig": None, From 8dc5ee074561019e79d9ea7ef2862d8ae71177c1 Mon Sep 17 00:00:00 2001 From: gmancino Date: Mon, 26 Feb 2024 09:53:13 -0500 Subject: [PATCH 30/31] Fix MLP scalers --- bofire/data_models/surrogates/mlp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/bofire/data_models/surrogates/mlp.py b/bofire/data_models/surrogates/mlp.py index 2d5759ed5..73059dc98 100644 --- a/bofire/data_models/surrogates/mlp.py +++ b/bofire/data_models/surrogates/mlp.py @@ -23,13 +23,13 @@ class MLPEnsemble(TrainableBotorchSurrogate): weight_decay: Annotated[float, Field(ge=0.0)] = 0.0 subsample_fraction: Annotated[float, Field(gt=0.0)] = 1.0 shuffle: bool = True - scaler: Literal[ScalerEnum.IDENTITY] = ScalerEnum.IDENTITY - output_scaler: Literal[ScalerEnum.IDENTITY] = ScalerEnum.IDENTITY class RegressionMLPEnsemble(MLPEnsemble): type: Literal["RegressionMLPEnsemble"] = "RegressionMLPEnsemble" final_activation: Literal["identity"] = "identity" + scaler: ScalerEnum = ScalerEnum.IDENTITY + output_scaler: ScalerEnum = ScalerEnum.IDENTITY @classmethod def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: @@ -45,6 +45,8 @@ def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: class ClassificationMLPEnsemble(MLPEnsemble): type: Literal["ClassificationMLPEnsemble"] = "ClassificationMLPEnsemble" final_activation: Literal["softmax"] = "softmax" + scaler: Literal[ScalerEnum.IDENTITY] = ScalerEnum.IDENTITY + output_scaler: Literal[ScalerEnum.IDENTITY] = ScalerEnum.IDENTITY @classmethod def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: From dd341d632479f0b165de6bdf4c7cb0e70c3f3b08 Mon Sep 17 00:00:00 2001 From: gmancino Date: Mon, 26 Feb 2024 15:51:11 -0500 Subject: [PATCH 31/31] Remove CategoricalObjective --- bofire/data_models/domain/features.py | 4 ++-- bofire/data_models/objectives/api.py | 2 -- bofire/data_models/objectives/categorical.py | 8 +------- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/bofire/data_models/domain/features.py b/bofire/data_models/domain/features.py index 794cbf6d6..9ff8a80c9 100644 --- a/bofire/data_models/domain/features.py +++ b/bofire/data_models/domain/features.py @@ -32,7 +32,7 @@ from bofire.data_models.molfeatures.api import MolFeatures from bofire.data_models.objectives.api import ( AbstractObjective, - CategoricalObjective, + ConstrainedCategoricalObjective, Objective, ) from bofire.data_models.types import TInputTransformSpecs @@ -760,7 +760,7 @@ def validate_candidates(self, candidates: pd.DataFrame) -> pd.DataFrame: [ [f"{feat.key}_pred", f"{feat.key}_sd", f"{feat.key}_des"] for feat in self.get_by_objective( - includes=Objective, excludes=CategoricalObjective + includes=Objective, excludes=ConstrainedCategoricalObjective ) ] + [ diff --git a/bofire/data_models/objectives/api.py b/bofire/data_models/objectives/api.py index de1a5381f..1d50e03e2 100644 --- a/bofire/data_models/objectives/api.py +++ b/bofire/data_models/objectives/api.py @@ -1,7 +1,6 @@ from typing import Union from bofire.data_models.objectives.categorical import ( - CategoricalObjective, ConstrainedCategoricalObjective, ) from bofire.data_models.objectives.identity import ( @@ -26,7 +25,6 @@ IdentityObjective, SigmoidObjective, ConstrainedObjective, - CategoricalObjective, ] AnyCategoricalObjective = ConstrainedCategoricalObjective diff --git a/bofire/data_models/objectives/categorical.py b/bofire/data_models/objectives/categorical.py index 92e24e5f0..595cab0cd 100644 --- a/bofire/data_models/objectives/categorical.py +++ b/bofire/data_models/objectives/categorical.py @@ -12,13 +12,7 @@ from bofire.data_models.types import TCategoryVals -class CategoricalObjective: - """Abstract categorical objective class""" - - -class ConstrainedCategoricalObjective( - ConstrainedObjective, CategoricalObjective, Objective -): +class ConstrainedCategoricalObjective(ConstrainedObjective, Objective): """Compute the categorical objective value as: Po where P is an [n, c] matrix where each row is a probability vector