Skip to content

Commit

Permalink
Make deprecated SurrogateSpec arguments InitVars (#3128)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3128

We do not want to keep the deprecated arguments as fields of the `SurrogateSpec`. We only convert them to `ModelConfig` in `__post_init__` and ignore them after. `InitVar` is the perfect tool for doing this.

Keeping these as fields leads to issues in other diffs where the values of the deprecated attribute goes out of sync with the specified `ModelConfig`.

Reviewed By: Balandat

Differential Revision: D66553624

fbshipit-source-id: b598fe3fedc914a7e68fcf183386b8761cfd11fb
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Nov 28, 2024
1 parent 326700d commit a44cb2f
Showing 1 changed file with 67 additions and 56 deletions.
123 changes: 67 additions & 56 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
from collections import OrderedDict
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from dataclasses import dataclass, field, InitVar
from logging import Logger
from typing import Any

import numpy as np

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata
Expand Down Expand Up @@ -318,7 +317,7 @@ def _raise_deprecation_warning(
for k, v in kwargs.items():
should_raise = False
if k in default_is_dict:
if v != {}:
if v not in [{}, None]:
should_raise = True
elif (v is not None and k != "mll_class") or (
k == "mll_class" and v is not ExactMarginalLogLikelihood
Expand Down Expand Up @@ -452,72 +451,84 @@ class string names and the values are dictionaries of input transform
cross-validation.
"""

botorch_model_class: type[Model] | None = None
botorch_model_kwargs: dict[str, Any] = field(default_factory=dict)

mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood
mll_kwargs: dict[str, Any] = field(default_factory=dict)

covar_module_class: type[Kernel] | None = None
covar_module_kwargs: dict[str, Any] | None = None

likelihood_class: type[Likelihood] | None = None
likelihood_kwargs: dict[str, Any] | None = None

input_transform_classes: list[type[InputTransform]] | None = None
input_transform_options: dict[str, dict[str, Any]] | None = None

outcome_transform_classes: list[type[OutcomeTransform]] | None = None
outcome_transform_options: dict[str, dict[str, Any]] | None = None

allow_batched_models: bool = True
# pyre-ignore [16]: Pyre doesn't understand InitVars.
botorch_model_class: InitVar[type[Model] | None] = None
# pyre-ignore [16]: Pyre doesn't understand InitVars.
botorch_model_kwargs: InitVar[dict[str, Any] | None] = None
# pyre-ignore [16]: Pyre doesn't understand InitVars.
mll_class: InitVar[type[MarginalLogLikelihood]] = ExactMarginalLogLikelihood
# pyre-ignore [16]: Pyre doesn't understand InitVars.
mll_kwargs: InitVar[dict[str, Any] | None] = None
# pyre-ignore [16]: Pyre doesn't understand InitVars.
covar_module_class: InitVar[type[Kernel] | None] = None
# pyre-ignore [16]: Pyre doesn't understand InitVars.
covar_module_kwargs: InitVar[dict[str, Any] | None] = None
# pyre-ignore [16]: Pyre doesn't understand InitVars.
likelihood_class: InitVar[type[Likelihood] | None] = None
# pyre-ignore [16]: Pyre doesn't understand InitVars.
likelihood_kwargs: InitVar[dict[str, Any] | None] = None
# pyre-ignore [16]: Pyre doesn't understand InitVars.
input_transform_classes: InitVar[list[type[InputTransform]] | None] = None
# pyre-ignore [16]: Pyre doesn't understand InitVars.
input_transform_options: InitVar[dict[str, dict[str, Any]] | None] = None
# pyre-ignore [16]: Pyre doesn't understand InitVars.
outcome_transform_classes: InitVar[list[type[OutcomeTransform]] | None] = None
# pyre-ignore [16]: Pyre doesn't understand InitVars.
outcome_transform_options: InitVar[dict[str, dict[str, Any]] | None] = None

model_configs: list[ModelConfig] = field(default_factory=list)
metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict)
eval_criterion: str = RANK_CORRELATION
outcomes: list[str] = field(default_factory=list)
allow_batched_models: bool = True
use_posterior_predictive: bool = False

def __post_init__(self) -> None:
def __post_init__(
self,
botorch_model_class: type[Model] | None,
botorch_model_kwargs: dict[str, Any] | None,
mll_class: type[MarginalLogLikelihood],
mll_kwargs: dict[str, Any] | None,
covar_module_class: type[Kernel] | None,
covar_module_kwargs: dict[str, Any] | None,
likelihood_class: type[Likelihood] | None,
likelihood_kwargs: dict[str, Any] | None,
input_transform_classes: list[type[InputTransform]] | None,
input_transform_options: dict[str, dict[str, Any]] | None,
outcome_transform_classes: list[type[OutcomeTransform]] | None,
outcome_transform_options: dict[str, dict[str, Any]] | None,
) -> None:
warnings_raised = _raise_deprecation_warning(
is_surrogate=False,
botorch_model_class=self.botorch_model_class,
botorch_model_kwargs=self.botorch_model_kwargs,
mll_class=self.mll_class,
mll_kwargs=self.mll_kwargs,
outcome_transform_classes=self.outcome_transform_classes,
outcome_transform_options=self.outcome_transform_options,
input_transform_classes=self.input_transform_classes,
input_transform_options=self.input_transform_options,
covar_module_class=self.covar_module_class,
covar_module_options=self.covar_module_kwargs,
likelihood_class=self.likelihood_class,
likelihood_options=self.likelihood_kwargs,
botorch_model_class=botorch_model_class,
botorch_model_kwargs=botorch_model_kwargs,
mll_class=mll_class,
mll_kwargs=mll_kwargs,
outcome_transform_classes=outcome_transform_classes,
outcome_transform_options=outcome_transform_options,
input_transform_classes=input_transform_classes,
input_transform_options=input_transform_options,
covar_module_class=covar_module_class,
covar_module_options=covar_module_kwargs,
likelihood_class=likelihood_class,
likelihood_options=likelihood_kwargs,
)
if len(self.model_configs) == 0:
model_config = get_model_config_from_deprecated_args(
botorch_model_class=self.botorch_model_class,
model_options=self.botorch_model_kwargs,
mll_class=self.mll_class,
mll_options=self.mll_kwargs,
outcome_transform_classes=self.outcome_transform_classes,
outcome_transform_options=self.outcome_transform_options,
input_transform_classes=self.input_transform_classes,
input_transform_options=self.input_transform_options,
covar_module_class=self.covar_module_class,
covar_module_options=self.covar_module_kwargs,
likelihood_class=self.likelihood_class,
likelihood_options=self.likelihood_kwargs,
)
# re-initialize with the non-deprecated arguments
self.__init__(
allow_batched_models=self.allow_batched_models,
model_configs=[model_config],
metric_to_model_configs=self.metric_to_model_configs,
eval_criterion=self.eval_criterion,
outcomes=self.outcomes,
use_posterior_predictive=self.use_posterior_predictive,
botorch_model_class=botorch_model_class,
model_options=botorch_model_kwargs,
mll_class=mll_class,
mll_options=mll_kwargs,
outcome_transform_classes=outcome_transform_classes,
outcome_transform_options=outcome_transform_options,
input_transform_classes=input_transform_classes,
input_transform_options=input_transform_options,
covar_module_class=covar_module_class,
covar_module_options=covar_module_kwargs,
likelihood_class=likelihood_class,
likelihood_options=likelihood_kwargs,
)
object.__setattr__(self, "model_configs", [model_config])
elif warnings_raised:
raise UserInputError(
"model_configs and deprecated arguments were both specified. "
Expand Down

0 comments on commit a44cb2f

Please sign in to comment.