Skip to content

Commit

Permalink
removed deprecation from strategies serialization (#375)
Browse files Browse the repository at this point in the history
* removed deprecation from strategies serialization

* removed deprection from model notebook
  • Loading branch information
bertiqwerty authored Mar 15, 2024
1 parent 86f1220 commit 39fb5fa
Show file tree
Hide file tree
Showing 2 changed files with 1,405 additions and 221 deletions.
189 changes: 35 additions & 154 deletions tutorials/models_serial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,27 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from bofire.data_models.domain.api import Outputs\n",
"from bofire.data_models.surrogates.api import SingleTaskGPSurrogate, RandomForestSurrogate, MixedSingleTaskGPSurrogate, AnySurrogate, RandomForestSurrogate, EmpiricalSurrogate, RegressionMLPEnsemble\n",
"from bofire.benchmarks.single import Himmelblau\n",
"from bofire.benchmarks.multi import CrossCoupling\n",
"import bofire.surrogates.api as surrogates\n",
"import json\n",
"from bofire.data_models.enum import CategoricalEncodingEnum\n",
"\n",
"from pydantic import parse_obj_as"
"from pydantic import TypeAdapter\n",
"\n",
"import bofire.surrogates.api as surrogates\n",
"from bofire.benchmarks.multi import CrossCoupling\n",
"from bofire.benchmarks.single import Himmelblau\n",
"from bofire.data_models.domain.api import Outputs\n",
"from bofire.data_models.enum import CategoricalEncodingEnum\n",
"from bofire.data_models.surrogates.api import (\n",
" AnySurrogate,\n",
" EmpiricalSurrogate,\n",
" MixedSingleTaskGPSurrogate,\n",
" RandomForestSurrogate,\n",
" RegressionMLPEnsemble,\n",
" SingleTaskGPSurrogate,\n",
")"
]
},
{
Expand All @@ -47,130 +55,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>x_1</th>\n",
" <th>x_2</th>\n",
" <th>y</th>\n",
" <th>valid_y</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5.010799</td>\n",
" <td>-0.612165</td>\n",
" <td>184.746878</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>-1.779981</td>\n",
" <td>-0.137665</td>\n",
" <td>140.265892</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>-5.063193</td>\n",
" <td>-3.811183</td>\n",
" <td>123.236129</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>5.114825</td>\n",
" <td>5.419270</td>\n",
" <td>1178.897832</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>-2.921467</td>\n",
" <td>-2.808005</td>\n",
" <td>31.952544</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.906090</td>\n",
" <td>4.477183</td>\n",
" <td>227.148355</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>3.319714</td>\n",
" <td>-2.211923</td>\n",
" <td>6.272053</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>3.629923</td>\n",
" <td>-0.748149</td>\n",
" <td>9.937792</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>-1.612215</td>\n",
" <td>4.451890</td>\n",
" <td>141.192994</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>0.242512</td>\n",
" <td>4.767127</td>\n",
" <td>293.096581</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" x_1 x_2 y valid_y\n",
"0 5.010799 -0.612165 184.746878 1\n",
"1 -1.779981 -0.137665 140.265892 1\n",
"2 -5.063193 -3.811183 123.236129 1\n",
"3 5.114825 5.419270 1178.897832 1\n",
"4 -2.921467 -2.808005 31.952544 1\n",
"5 0.906090 4.477183 227.148355 1\n",
"6 3.319714 -2.211923 6.272053 1\n",
"7 3.629923 -0.748149 9.937792 1\n",
"8 -1.612215 4.451890 141.192994 1\n",
"9 0.242512 4.767127 293.096581 1"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"benchmark = Himmelblau()\n",
"samples = benchmark.domain.inputs.sample(n=50)\n",
Expand Down Expand Up @@ -214,7 +101,7 @@
}
],
"source": [
"input_features.json()"
"input_features.model_dump_json()"
]
},
{
Expand All @@ -234,7 +121,7 @@
}
],
"source": [
"output_features.json()"
"output_features.model_dump_json()"
]
},
{
Expand Down Expand Up @@ -290,7 +177,7 @@
"metadata": {},
"outputs": [],
"source": [
"surrogate_data = parse_obj_as(AnySurrogate, json.loads(jspec))"
"surrogate_data = TypeAdapter(AnySurrogate).validate_json(jspec)"
]
},
{
Expand Down Expand Up @@ -390,7 +277,7 @@
}
],
"source": [
"surrogate_data = parse_obj_as(AnySurrogate, json.loads(jspec))\n",
"surrogate_data = TypeAdapter(AnySurrogate).validate_json(jspec)\n",
"surrogate = surrogates.map(surrogate_data)\n",
"surrogate.loads(dump)\n",
"\n",
Expand Down Expand Up @@ -450,8 +337,8 @@
"outputs": [],
"source": [
"# Load it from the spec\n",
"surrogate_data = parse_obj_as(AnySurrogate, json.loads(jspec))\n",
"# Map it \n",
"surrogate_data = TypeAdapter(AnySurrogate).validate_json(jspec)\n",
"# Map it\n",
"surrogate = surrogates.map(surrogate_data)\n",
"# Fit it\n",
"surrogate.fit(experiments=experiments)\n",
Expand Down Expand Up @@ -480,7 +367,7 @@
}
],
"source": [
"surrogate_data = parse_obj_as(AnySurrogate, json.loads(jspec))\n",
"surrogate_data =TypeAdapter(AnySurrogate).validate_json(jspec)\n",
"surrogate = surrogates.map(surrogate_data)\n",
"surrogate.loads(dump)\n",
"\n",
Expand Down Expand Up @@ -540,7 +427,7 @@
"outputs": [],
"source": [
"# Load it from the spec\n",
"surrogate_data = parse_obj_as(AnySurrogate, json.loads(jspec))\n",
"surrogate_data = TypeAdapter(AnySurrogate).validate_json(jspec)\n",
"# Map it \n",
"surrogate = surrogates.map(surrogate_data)\n",
"# Fit it\n",
Expand Down Expand Up @@ -570,7 +457,7 @@
}
],
"source": [
"surrogate_data = parse_obj_as(AnySurrogate, json.loads(jspec))\n",
"surrogate_data = TypeAdapter(AnySurrogate).validate_json(jspec)\n",
"surrogate = surrogates.map(surrogate_data)\n",
"surrogate.loads(dump)\n",
"\n",
Expand Down Expand Up @@ -602,6 +489,7 @@
"from botorch.models.deterministic import DeterministicModel\n",
"from torch import Tensor\n",
"\n",
"\n",
"class HimmelblauModel(DeterministicModel):\n",
" def __init__(self):\n",
" super().__init__()\n",
Expand Down Expand Up @@ -638,7 +526,7 @@
")\n",
"\n",
"# we generate the json spec\n",
"jspec = surrogate_data.json()\n",
"jspec = surrogate_data.model_dump_json()\n",
"\n",
"jspec"
]
Expand All @@ -650,8 +538,8 @@
"outputs": [],
"source": [
"# Load it from the spec\n",
"surrogate_data = parse_obj_as(AnySurrogate, json.loads(jspec))\n",
"# Map it \n",
"surrogate_data = TypeAdapter(AnySurrogate).validate_json(jspec)\n",
"# Map it\n",
"surrogate = surrogates.map(surrogate_data)\n",
"# attach the actual model to it\n",
"surrogate.model = HimmelblauModel()\n",
Expand Down Expand Up @@ -680,7 +568,7 @@
}
],
"source": [
"surrogate_data = parse_obj_as(AnySurrogate, json.loads(jspec))\n",
"surrogate_data = TypeAdapter(AnySurrogate).validate_json(jspec)\n",
"surrogate = surrogates.map(surrogate_data)\n",
"surrogate.loads(dump)\n",
"\n",
Expand Down Expand Up @@ -941,7 +829,7 @@
"outputs": [],
"source": [
"# Load it from the spec\n",
"surrogate_data = parse_obj_as(AnySurrogate, json.loads(jspec))\n",
"surrogate_data = TypeAdapter(AnySurrogate).validate_json(jspec)\n",
"# Map it \n",
"surrogate = surrogates.map(surrogate_data)\n",
"# Fit it\n",
Expand Down Expand Up @@ -971,7 +859,7 @@
}
],
"source": [
"surrogate_data = parse_obj_as(AnySurrogate, json.loads(jspec))\n",
"surrogate_data = TypeAdapter(AnySurrogate).validate_json(jspec)\n",
"surrogate = surrogates.map(surrogate_data)\n",
"surrogate.loads(dump)\n",
"\n",
Expand All @@ -983,13 +871,6 @@
"# check for equality\n",
"predictions==predictions2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -1008,7 +889,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.10.13"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
Loading

0 comments on commit 39fb5fa

Please sign in to comment.