diff --git a/ax/benchmark/methods/modular_botorch.py b/ax/benchmark/methods/modular_botorch.py index d1f27ef17ad..65d060c0679 100644 --- a/ax/benchmark/methods/modular_botorch.py +++ b/ax/benchmark/methods/modular_botorch.py @@ -5,12 +5,13 @@ # pyre-strict -from typing import Any +from typing import Any, Sequence from ax.benchmark.benchmark_method import BenchmarkMethod from ax.modelbridge.generation_node import GenerationStep from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.registry import Models +from ax.modelbridge.transforms.base import Transform from ax.models.torch.botorch_modular.surrogate import SurrogateSpec from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.analytic import LogExpectedImprovement @@ -39,6 +40,7 @@ def get_sobol_mbm_generation_strategy( model_cls: type[Model], acquisition_cls: type[AcquisitionFunction], + transforms: Sequence[type[Transform]] | None, name: str | None = None, num_sobol_trials: int = 5, model_gen_kwargs: dict[str, Any] | None = None, @@ -75,6 +77,8 @@ def get_sobol_mbm_generation_strategy( "botorch_acqf_class": acquisition_cls, "surrogate_spec": SurrogateSpec(botorch_model_class=model_cls), } + if transforms is not None: + model_kwargs["transforms"] = transforms model_name = model_names_abbrevations.get(model_cls.__name__, model_cls.__name__) acqf_name = acqf_name_abbreviations.get( @@ -109,6 +113,7 @@ def get_sobol_botorch_modular_acquisition( model_cls: type[Model], acquisition_cls: type[AcquisitionFunction], distribute_replications: bool, + transforms: Sequence[type[Transform]] | None, name: str | None = None, num_sobol_trials: int = 5, model_gen_kwargs: dict[str, Any] | None = None, @@ -162,6 +167,7 @@ def get_sobol_botorch_modular_acquisition( generation_strategy = get_sobol_mbm_generation_strategy( model_cls=model_cls, acquisition_cls=acquisition_cls, + transforms=transforms, name=name, num_sobol_trials=num_sobol_trials, model_gen_kwargs=model_gen_kwargs, diff --git a/ax/core/map_data.py b/ax/core/map_data.py index 5b2685b54d1..3ec9e8f1044 100644 --- a/ax/core/map_data.py +++ b/ax/core/map_data.py @@ -7,12 +7,14 @@ from __future__ import annotations +from bisect import bisect_right from collections.abc import Iterable, Sequence from copy import deepcopy from logging import Logger from typing import Any, Generic, TypeVar import numpy as np +import numpy.typing as npt import pandas as pd from ax.core.data import Data from ax.core.types import TMapTrialEvaluation @@ -275,15 +277,15 @@ def from_multiple_data( def df(self) -> pd.DataFrame: """Returns a Data shaped DataFrame""" - # If map_keys is empty just return the df if self._memo_df is not None: return self._memo_df + # If map_keys is empty just return the df if len(self.map_keys) == 0: return self.map_df - self._memo_df = self.map_df.sort_values(self.map_keys).drop_duplicates( - MapData.DEDUPLICATE_BY_COLUMNS, keep="last" + self._memo_df = _tail( + map_df=self.map_df, map_keys=self.map_keys, n=1, sort=True ) return self._memo_df @@ -337,6 +339,32 @@ def clone(self) -> MapData: description=self.description, ) + def latest( + self, + map_keys: list[str] | None = None, + rows_per_group: int = 1, + ) -> MapData: + """Return a new MapData with the most recently observed `rows_per_group` + rows for each (arm, metric) group, determined by the `map_key` values, + where higher implies more recent. 
+
+        This function considers only the relative ordering of the `map_key` values,
+        making it most suitable when these values are equally spaced.
+
+        If `rows_per_group` is greater than the number of rows in a given
+        (arm, metric) group, then all rows are returned.
+        """
+        if map_keys is None:
+            map_keys = self.map_keys
+
+        return MapData(
+            df=_tail(
+                map_df=self.map_df, map_keys=map_keys, n=rows_per_group, sort=True
+            ),
+            map_key_infos=self.map_key_infos,
+            description=self.description,
+        )
+
     def subsample(
         self,
         map_key: str | None = None,
@@ -345,11 +373,13 @@ def subsample(
         limit_rows_per_metric: int | None = None,
         include_first_last: bool = True,
     ) -> MapData:
-        """Subsample the `map_key` column in an equally-spaced manner (if there is
-        a `self.map_keys` is length one, then `map_key` can be set to None). The
-        values of the `map_key` column are not taken into account, so this function
-        is most reasonable when those values are equally-spaced. There are three
-        ways that this can be done:
+        """Return a new MapData that subsamples the `map_key` column in an
+        equally-spaced manner. If `self.map_keys` has a length of one, `map_key`
+        can be set to None. This function considers only the relative ordering
+        of the `map_key` values, making it most suitable when these values are
+        equally spaced.
+
+        There are three ways that this can be done:
            1. If `keep_every = k` is set, then every kth row of the DataFrame in the
               `map_key` column is kept after grouping by `DEDUPLICATE_BY_COLUMNS`.
               In other words, every kth step of each (arm, metric) will be kept.
@@ -411,6 +441,62 @@
     )
 
 
+def _ceil_divide(
+    a: int | np.int_ | npt.NDArray[np.int_], b: int | np.int_ | npt.NDArray[np.int_]
+) -> np.int_ | npt.NDArray[np.int_]:
+    return -np.floor_divide(-a, b)
+
+
+def _subsample_rate(
+    map_df: pd.DataFrame,
+    keep_every: int | None = None,
+    limit_rows_per_group: int | None = None,
+    limit_rows_per_metric: int | None = None,
+) -> int:
+    if keep_every is not None:
+        return keep_every
+
+    grouped_map_df = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS)
+    group_sizes = grouped_map_df.size()
+    max_rows = group_sizes.max()
+
+    if limit_rows_per_group is not None:
+        return _ceil_divide(max_rows, limit_rows_per_group).item()
+
+    if limit_rows_per_metric is not None:
+        # search for the `keep_every` such that when you apply it to each group,
+        # the total number of rows is smaller than `limit_rows_per_metric`.
+        ks = np.arange(max_rows, 0, -1)
+        # total sizes in ascending order
+        total_sizes = np.sum(
+            _ceil_divide(group_sizes.values, ks[..., np.newaxis]), axis=1
+        )
+        # binary search
+        i = bisect_right(total_sizes, limit_rows_per_metric)
+        if i > 0:
+            return ks[i - 1].item()
+        # if no such `k` exists, keep every row, matching the fallback of the
+        # inline implementation this helper replaces.
+        return 1
+
+    raise ValueError(
+        "at least one of `keep_every`, `limit_rows_per_group`, "
+        "or `limit_rows_per_metric` must be specified."
+    )
+
+
+def _tail(
+    map_df: pd.DataFrame,
+    map_keys: list[str],
+    n: int = 1,
+    sort: bool = True,
+) -> pd.DataFrame:
+    df = map_df.sort_values(map_keys).groupby(MapData.DEDUPLICATE_BY_COLUMNS).tail(n)
+    if sort:
+        df.sort_values(MapData.DEDUPLICATE_BY_COLUMNS, inplace=True)
+    return df
+
+
 def _subsample_one_metric(
     map_df: pd.DataFrame,
     map_key: str | None = None,
@@ -420,30 +504,21 @@ def _subsample_one_metric(
     include_first_last: bool = True,
 ) -> pd.DataFrame:
     """Helper function to subsample a dataframe that holds a single metric."""
-    derived_keep_every = 1
-    if keep_every is not None:
-        derived_keep_every = keep_every
-    elif limit_rows_per_group is not None:
-        max_rows = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().max()
-        derived_keep_every = np.ceil(max_rows / limit_rows_per_group)
-    elif limit_rows_per_metric is not None:
-        group_sizes = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().to_numpy()
-        # search for the `keep_every` such that when you apply it to each group,
-        # the total number of rows is smaller than `limit_rows_per_metric`.
-        for k in range(1, group_sizes.max() + 1):
-            if (np.ceil(group_sizes / k)).sum() <= limit_rows_per_metric:
-                derived_keep_every = k
-                break
-        # if no such `k` is found, then `derived_keep_every` stays as 1.
+
+    grouped_map_df = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS)
+
+    derived_keep_every = _subsample_rate(
+        map_df, keep_every, limit_rows_per_group, limit_rows_per_metric
+    )
 
     if derived_keep_every <= 1:
         filtered_map_df = map_df
     else:
         filtered_dfs = []
-        for _, df_g in map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS):
+        for _, df_g in grouped_map_df:
             df_g = df_g.sort_values(map_key)
             if include_first_last:
-                rows_per_group = int(np.ceil(len(df_g) / derived_keep_every))
+                rows_per_group = _ceil_divide(len(df_g), derived_keep_every)
                 linspace_idcs = np.linspace(0, len(df_g) - 1, rows_per_group)
                 idcs = np.round(linspace_idcs).astype(int)
                 filtered_df = df_g.iloc[idcs]
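
To make the `limit_rows_per_metric` search in `_subsample_rate` concrete, here is a small standalone sketch; the `group_sizes` and `limit` values are hypothetical, chosen only to illustrate the arithmetic:

    import numpy as np
    from bisect import bisect_right

    def ceil_divide(a, b):
        # Integer ceil(a / b), exactly as `_ceil_divide` above.
        return -np.floor_divide(-a, b)

    group_sizes = np.array([25, 50, 75, 100])  # rows per (arm, metric) group
    limit = 100  # hypothetical limit_rows_per_metric

    ks = np.arange(group_sizes.max(), 0, -1)  # candidate keep_every values
    # Total rows kept for each candidate k; ascending because k descends.
    total_sizes = np.sum(ceil_divide(group_sizes, ks[..., np.newaxis]), axis=1)
    i = bisect_right(total_sizes, limit)
    keep_every = ks[i - 1].item() if i > 0 else 1
    # keep_every == 3: ceil(25/3) + ceil(50/3) + ceil(75/3) + ceil(100/3)
    # == 9 + 17 + 25 + 34 == 85 <= 100, whereas keep_every == 2 would keep
    # 13 + 25 + 38 + 50 == 126 rows and exceed the limit.
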
+ ) + + +def _tail( + map_df: pd.DataFrame, + map_keys: list[str], + n: int = 1, + sort: bool = True, +) -> pd.DataFrame: + df = map_df.sort_values(map_keys).groupby(MapData.DEDUPLICATE_BY_COLUMNS).tail(n) + if sort: + df.sort_values(MapData.DEDUPLICATE_BY_COLUMNS, inplace=True) + return df + + def _subsample_one_metric( map_df: pd.DataFrame, map_key: str | None = None, @@ -420,30 +504,21 @@ def _subsample_one_metric( include_first_last: bool = True, ) -> pd.DataFrame: """Helper function to subsample a dataframe that holds a single metric.""" - derived_keep_every = 1 - if keep_every is not None: - derived_keep_every = keep_every - elif limit_rows_per_group is not None: - max_rows = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().max() - derived_keep_every = np.ceil(max_rows / limit_rows_per_group) - elif limit_rows_per_metric is not None: - group_sizes = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().to_numpy() - # search for the `keep_every` such that when you apply it to each group, - # the total number of rows is smaller than `limit_rows_per_metric`. - for k in range(1, group_sizes.max() + 1): - if (np.ceil(group_sizes / k)).sum() <= limit_rows_per_metric: - derived_keep_every = k - break - # if no such `k` is found, then `derived_keep_every` stays as 1. + + grouped_map_df = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS) + + derived_keep_every = _subsample_rate( + map_df, keep_every, limit_rows_per_group, limit_rows_per_metric + ) if derived_keep_every <= 1: filtered_map_df = map_df else: filtered_dfs = [] - for _, df_g in map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS): + for _, df_g in grouped_map_df: df_g = df_g.sort_values(map_key) if include_first_last: - rows_per_group = int(np.ceil(len(df_g) / derived_keep_every)) + rows_per_group = _ceil_divide(len(df_g), derived_keep_every) linspace_idcs = np.linspace(0, len(df_g) - 1, rows_per_group) idcs = np.round(linspace_idcs).astype(int) filtered_df = df_g.iloc[idcs] diff --git a/ax/core/observation.py b/ax/core/observation.py index 50bd1550572..94ee013d7a5 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -426,7 +426,7 @@ def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]: feature_cols = OBS_COLS.intersection(data.df.columns) # note we use this check, rather than isinstance, since # only some Modelbridges (e.g. MapTorchModelBridge) - # use observations_from_map_data, which is required + # use observations_from_data, which is required # to properly handle MapData features (e.g. fidelity). if is_map_data: data = checked_cast(MapData, data) @@ -448,174 +448,113 @@ def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]: def observations_from_data( experiment: experiment.Experiment, - data: Data, - statuses_to_include: set[TrialStatus] | None = None, - statuses_to_include_map_metric: set[TrialStatus] | None = None, -) -> list[Observation]: - """Convert Data to observations. - - Converts a Data object to a list of Observation objects. Pulls arm parameters from - from experiment. Overrides fidelity parameters in the arm with those found in the - Data object. - - Uses a diagonal covariance matrix across metric_names. - - Args: - experiment: Experiment with arm parameters. - data: Data of observations. - statuses_to_include: data from non-MapMetrics will only be included for trials - with statuses in this set. Defaults to all statuses except abandoned. - statuses_to_include_map_metric: data from MapMetrics will only be included for - trials with statuses in this set. 
-            Defaults to completed status only.
-
-    Returns:
-        List of Observation objects.
-    """
-    if statuses_to_include is None:
-        statuses_to_include = NON_ABANDONED_STATUSES
-    if statuses_to_include_map_metric is None:
-        statuses_to_include_map_metric = {TrialStatus.COMPLETED}
-    feature_cols = get_feature_cols(data)
-    observations = []
-    arm_name_only = len(feature_cols) == 1  # there will always be an arm name
-    # One DataFrame where all rows have all features.
-    isnull = data.df[feature_cols].isnull()
-    isnull_any = isnull.any(axis=1)
-    incomplete_df_cols = isnull[isnull_any].any()
-
-    # Get the incomplete_df columns that are complete, and usable as groupby keys.
-    complete_feature_cols = list(
-        OBS_COLS.intersection(incomplete_df_cols.index[~incomplete_df_cols])
-    )
-
-    if set(feature_cols) == set(complete_feature_cols):
-        complete_df = data.df
-        incomplete_df = None
-    else:
-        # The groupby and filter is expensive, so do it only if we have to.
-        grouped = data.df.groupby(by=complete_feature_cols)
-        complete_df = grouped.filter(lambda r: ~r[feature_cols].isnull().any().any())
-        incomplete_df = grouped.filter(lambda r: r[feature_cols].isnull().any().any())
-
-    # Get Observations from complete_df
-    observations.extend(
-        _observations_from_dataframe(
-            experiment=experiment,
-            df=complete_df,
-            cols=feature_cols,
-            arm_name_only=arm_name_only,
-            statuses_to_include=statuses_to_include,
-            statuses_to_include_map_metric=statuses_to_include_map_metric,
-            map_keys=[],
-        )
-    )
-    if incomplete_df is not None:
-        # Get Observations from incomplete_df
-        observations.extend(
-            _observations_from_dataframe(
-                experiment=experiment,
-                df=incomplete_df,
-                cols=complete_feature_cols,
-                arm_name_only=arm_name_only,
-                statuses_to_include=statuses_to_include,
-                statuses_to_include_map_metric=statuses_to_include_map_metric,
-                map_keys=[],
-            )
-        )
-    return observations
-
-
-def observations_from_map_data(
-    experiment: experiment.Experiment,
-    map_data: MapData,
+    data: Data | MapData,
     statuses_to_include: set[TrialStatus] | None = None,
     statuses_to_include_map_metric: set[TrialStatus] | None = None,
     map_keys_as_parameters: bool = False,
+    latest_rows_per_group: int | None = None,
     limit_rows_per_metric: int | None = None,
     limit_rows_per_group: int | None = None,
 ) -> list[Observation]:
-    """Convert MapData to observations.
+    """Convert Data (or MapData) to observations.
 
-    Converts a MapData object to a list of Observation objects. Pulls arm parameters
-    from experiment. Overrides fidelity parameters in the arm with those found in the
-    Data object.
+    Converts a Data (or MapData) object to a list of Observation objects.
+    Pulls arm parameters from the experiment. Overrides fidelity parameters
+    in the arm with those found in the Data object.
 
     Uses a diagonal covariance matrix across metric_names.
 
     Args:
         experiment: Experiment with arm parameters.
-        map_data: MapData of observations.
+        data: Data (or MapData) of observations.
         statuses_to_include: data from non-MapMetrics will only be included for trials
             with statuses in this set. Defaults to all statuses except abandoned.
         statuses_to_include_map_metric: data from MapMetrics will only be included for
             trials with statuses in this set. Defaults to all statuses except
             abandoned.
         map_keys_as_parameters: Whether map_keys should be returned as part of
             the parameters of the Observation objects.
- limit_rows_per_metric: If specified, uses MapData.subsample() with - `limit_rows_per_metric` equal to the specified value on the first - map_key (map_data.map_keys[0]) to subsample the MapData. This is - useful in, e.g., cases where learning curves are frequently - updated, leading to an intractable number of Observation objects - created. - limit_rows_per_group: If specified, uses MapData.subsample() with - `limit_rows_per_group` equal to the specified value on the first - map_key (map_data.map_keys[0]) to subsample the MapData. + latest_rows_per_group: If specified and data is an instance of MapData, + uses MapData.latest() with `rows_per_group=latest_rows_per_group` to + retrieve the most recent rows for each group. Useful in cases where + learning curves are frequently updated, preventing an excessive + number of Observation objects. Overrides `limit_rows_per_metric` + and `limit_rows_per_group`. + limit_rows_per_metric: If specified and data is an instance of MapData, + uses MapData.subsample() with `limit_rows_per_metric` on the first + map_key (map_data.map_keys[0]) to subsample the MapData. Useful for + managing the number of Observation objects when learning curves are + frequently updated. Ignored if `latest_rows_per_group` is specified. + limit_rows_per_group: If specified and data is an instance of MapData, + uses MapData.subsample() with `limit_rows_per_group` on the first + map_key (map_data.map_keys[0]) to subsample the MapData. Ignored if + `latest_rows_per_group` is specified. Returns: List of Observation objects. """ + is_map_data = isinstance(data, MapData) + if statuses_to_include is None: statuses_to_include = NON_ABANDONED_STATUSES if statuses_to_include_map_metric is None: statuses_to_include_map_metric = NON_ABANDONED_STATUSES - if limit_rows_per_metric is not None or limit_rows_per_group is not None: - map_data = map_data.subsample( - map_key=map_data.map_keys[0], - limit_rows_per_metric=limit_rows_per_metric, - limit_rows_per_group=limit_rows_per_group, - include_first_last=True, - ) - feature_cols = get_feature_cols(map_data, is_map_data=True) - observations = [] + + map_keys = [] + obs_cols = OBS_COLS + if is_map_data: + data = checked_cast(MapData, data) + + if latest_rows_per_group is not None: + data = data.latest( + map_keys=data.map_keys, rows_per_group=latest_rows_per_group + ) + else: + if limit_rows_per_metric is not None or limit_rows_per_group is not None: + data = data.subsample( + map_key=data.map_keys[0], + limit_rows_per_metric=limit_rows_per_metric, + limit_rows_per_group=limit_rows_per_group, + include_first_last=True, + ) + + map_keys.extend(data.map_keys) + obs_cols = obs_cols.union(data.map_keys) + df = data.map_df + else: + df = data.df + + feature_cols = get_feature_cols(data, is_map_data=is_map_data) + arm_name_only = len(feature_cols) == 1 # there will always be an arm name # One DataFrame where all rows have all features. - isnull = map_data.map_df[feature_cols].isnull() + isnull = df[feature_cols].isnull() isnull_any = isnull.any(axis=1) incomplete_df_cols = isnull[isnull_any].any() # Get the incomplete_df columns that are complete, and usable as groupby keys. 
- obs_cols_and_map = OBS_COLS.union(map_data.map_keys) complete_feature_cols = list( - obs_cols_and_map.intersection(incomplete_df_cols.index[~incomplete_df_cols]) + obs_cols.intersection(incomplete_df_cols.index[~incomplete_df_cols]) ) if set(feature_cols) == set(complete_feature_cols): - complete_df = map_data.map_df + complete_df = df incomplete_df = None else: # The groupby and filter is expensive, so do it only if we have to. - grouped = map_data.map_df.groupby( - by=( - complete_feature_cols - if len(complete_feature_cols) > 1 - else complete_feature_cols[0] - ) - ) + grouped = df.groupby(by=complete_feature_cols) complete_df = grouped.filter(lambda r: ~r[feature_cols].isnull().any().any()) incomplete_df = grouped.filter(lambda r: r[feature_cols].isnull().any().any()) # Get Observations from complete_df - observations.extend( - _observations_from_dataframe( - experiment=experiment, - df=complete_df, - cols=feature_cols, - arm_name_only=arm_name_only, - map_keys=map_data.map_keys, - statuses_to_include=statuses_to_include, - statuses_to_include_map_metric=statuses_to_include_map_metric, - map_keys_as_parameters=map_keys_as_parameters, - ) + observations = _observations_from_dataframe( + experiment=experiment, + df=complete_df, + cols=feature_cols, + arm_name_only=arm_name_only, + map_keys=map_keys, + statuses_to_include=statuses_to_include, + statuses_to_include_map_metric=statuses_to_include_map_metric, + map_keys_as_parameters=map_keys_as_parameters, ) if incomplete_df is not None: # Get Observations from incomplete_df @@ -625,7 +564,7 @@ def observations_from_map_data( df=incomplete_df, cols=complete_feature_cols, arm_name_only=arm_name_only, - map_keys=map_data.map_keys, + map_keys=map_keys, statuses_to_include=statuses_to_include, statuses_to_include_map_metric=statuses_to_include_map_metric, map_keys_as_parameters=map_keys_as_parameters, diff --git a/ax/core/tests/test_map_data.py b/ax/core/tests/test_map_data.py index ce0576e295c..0b4f1f5fd22 100644 --- a/ax/core/tests/test_map_data.py +++ b/ax/core/tests/test_map_data.py @@ -6,6 +6,7 @@ # pyre-strict +import numpy as np import pandas as pd from ax.core.data import Data from ax.core.map_data import MapData, MapKeyInfo @@ -236,7 +237,17 @@ def test_upcast(self) -> None: self.assertIsNotNone(fresh._memo_df) # Assert df is cached after first call - def test_subsample(self) -> None: + self.assertTrue( + fresh.df.equals( + fresh.map_df.sort_values(fresh.map_keys).drop_duplicates( + MapData.DEDUPLICATE_BY_COLUMNS, keep="last" + ) + ) + ) + + def test_latest(self) -> None: + seed = 8888 + arm_names = ["0_0", "1_0", "2_0", "3_0"] max_epochs = [25, 50, 75, 100] metric_names = ["a", "b"] @@ -259,6 +270,68 @@ def test_subsample(self) -> None: ) large_map_data = MapData(df=large_map_df, map_key_infos=self.map_key_infos) + shuffled_large_map_df = large_map_data.map_df.groupby( + MapData.DEDUPLICATE_BY_COLUMNS + ).sample(frac=1, random_state=seed) + shuffled_large_map_data = MapData( + df=shuffled_large_map_df, map_key_infos=self.map_key_infos + ) + + for rows_per_group in [1, 40]: + large_map_data_latest = large_map_data.latest(rows_per_group=rows_per_group) + + if rows_per_group == 1: + self.assertTrue( + large_map_data_latest.map_df.groupby("metric_name") + .epoch.transform(lambda col: set(col) == set(max_epochs)) + .all() + ) + + # when rows_per_group is larger than the number of rows + # actually observed in a group + actual_rows_per_group = large_map_data_latest.map_df.groupby( + MapData.DEDUPLICATE_BY_COLUMNS + ).size() + 
expected_rows_per_group = np.minimum( + large_map_data_latest.map_df.groupby( + MapData.DEDUPLICATE_BY_COLUMNS + ).epoch.max(), + rows_per_group, + ) + self.assertTrue(actual_rows_per_group.equals(expected_rows_per_group)) + + # behavior should be consistent even if map_keys are not in ascending order + shuffled_large_map_data_latest = shuffled_large_map_data.latest( + rows_per_group=rows_per_group + ) + self.assertTrue( + shuffled_large_map_data_latest.map_df.equals( + large_map_data_latest.map_df + ) + ) + + def test_subsample(self) -> None: + arm_names = ["0_0", "1_0", "2_0", "3_0"] + max_epochs = [25, 50, 75, 100] + metric_names = ["a", "b"] + large_map_df = pd.DataFrame( + [ + { + "arm_name": arm_name, + "epoch": epoch + 1, + "mean": epoch * 0.1, + "sem": 0.1, + "trial_index": trial_index, + "metric_name": metric_name, + } + for metric_name in metric_names + for trial_index, (arm_name, max_epoch) in enumerate( + zip(arm_names, max_epochs) + ) + for epoch in range(max_epoch) + ] + ) + large_map_data = MapData(df=large_map_df, map_key_infos=self.map_key_infos) large_map_df_sparse_metric = pd.DataFrame( [ { diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py index 2c304353502..849cc69ab2f 100644 --- a/ax/core/tests/test_observation.py +++ b/ax/core/tests/test_observation.py @@ -24,7 +24,6 @@ ObservationData, ObservationFeatures, observations_from_data, - observations_from_map_data, recombine_observations, separate_observations, ) @@ -475,7 +474,7 @@ def test_ObservationsFromMapData(self) -> None: MapKeyInfo(key="timestamp", default_value=0.0), ], ) - observations = observations_from_map_data(experiment, data) + observations = observations_from_data(experiment, data) self.assertEqual(len(observations), 3) diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index 54346c39136..c2374c34f8d 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -295,8 +295,10 @@ def _prepare_observations( return observations_from_data( experiment=experiment, data=data, + latest_rows_per_group=1, statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, + map_keys_as_parameters=False, ) def _transform_data( diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py index 9fa0f119147..8acbd31746d 100644 --- a/ax/modelbridge/map_torch.py +++ b/ax/modelbridge/map_torch.py @@ -20,7 +20,7 @@ Observation, ObservationData, ObservationFeatures, - observations_from_map_data, + observations_from_data, separate_observations, ) from ax.core.optimization_config import OptimizationConfig @@ -252,19 +252,16 @@ def _array_to_observation_features( def _prepare_observations( self, experiment: Experiment | None, data: Data | None ) -> list[Observation]: - """The difference b/t this method and ModelBridge._prepare_observations(...) - is that this one uses `observations_from_map_data`. - """ if experiment is None or data is None: return [] - return observations_from_map_data( + return observations_from_data( experiment=experiment, - map_data=data, # pyre-ignore[6]: Checked in __init__. 
-            map_keys_as_parameters=True,
+            data=data,
             limit_rows_per_metric=self._map_data_limit_rows_per_metric,
             limit_rows_per_group=self._map_data_limit_rows_per_group,
             statuses_to_include=self.statuses_to_fit,
             statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
+            map_keys_as_parameters=True,
         )
 
     def _compute_in_design(
diff --git a/ax/modelbridge/transforms/map_key_to_float.py b/ax/modelbridge/transforms/map_key_to_float.py
new file mode 100644
index 00000000000..1ec645aff51
--- /dev/null
+++ b/ax/modelbridge/transforms/map_key_to_float.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Any, Optional, TYPE_CHECKING
+
+from ax.core.map_metric import MapMetric
+from ax.core.observation import Observation, ObservationFeatures
+from ax.core.search_space import SearchSpace
+from ax.modelbridge.transforms.metadata_to_float import MetadataToFloat
+from ax.models.types import TConfig
+from pyre_extensions import assert_is_instance
+
+if TYPE_CHECKING:
+    # import as module to make sphinx-autodoc-typehints happy
+    from ax import modelbridge as modelbridge_module  # noqa F401
+
+
+class MapKeyToFloat(MetadataToFloat):
+    DEFAULT_LOG_SCALE: bool = True
+    DEFAULT_MAP_KEY: str = MapMetric.map_key_info.key
+
+    def __init__(
+        self,
+        search_space: SearchSpace | None = None,
+        observations: list[Observation] | None = None,
+        modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
+        config: TConfig | None = None,
+    ) -> None:
+        config = config or {}
+        self.parameters: dict[str, dict[str, Any]] = assert_is_instance(
+            config.setdefault("parameters", {}), dict
+        )
+        # TODO[tiao]: raise warning if `DEFAULT_MAP_KEY` is already in keys(?)
+        self.parameters.setdefault(self.DEFAULT_MAP_KEY, {})
+        super().__init__(
+            search_space=search_space,
+            observations=observations,
+            modelbridge=modelbridge,
+            config=config,
+        )
+
+    def _transform_observation_feature(self, obsf: ObservationFeatures) -> None:
+        if not obsf.parameters:
+            for p in self._parameter_list:
+                # TODO[tiao]: can we use p.target_value here?
+                # (not its original intended use but could be advantageous)
+                obsf.parameters[p.name] = p.upper
+            return
+        super()._transform_observation_feature(obsf)
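
A quick usage sketch for `MapKeyToFloat` (illustrative only: the observations are fabricated, `"x"` is an arbitrary parameter name, and in practice the transform is applied by the modelbridge rather than called directly):

    import numpy as np
    from ax.core.observation import Observation, ObservationData, ObservationFeatures
    from ax.modelbridge.transforms.map_key_to_float import MapKeyToFloat

    observations = [
        Observation(
            features=ObservationFeatures(
                trial_index=0,
                parameters={"x": 0.5},
                metadata={MapKeyToFloat.DEFAULT_MAP_KEY: step},
            ),
            data=ObservationData(
                metric_names=[], means=np.array([]), covariance=np.empty((0, 0))
            ),
        )
        for step in (1.0, 10.0, 100.0)
    ]

    t = MapKeyToFloat(observations=observations)
    # The map key becomes a log-scaled RangeParameter whose bounds are the
    # observed extremes (lower=1.0, upper=100.0); transforming an observation
    # moves the map-key value from metadata into parameters.
    t.transform_observation_features([obs.features for obs in observations])
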
diff --git a/ax/modelbridge/transforms/metadata_to_float.py b/ax/modelbridge/transforms/metadata_to_float.py
new file mode 100644
index 00000000000..d74af4604fe
--- /dev/null
+++ b/ax/modelbridge/transforms/metadata_to_float.py
@@ -0,0 +1,143 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from __future__ import annotations
+
+from logging import Logger
+from typing import Any, Iterable, Optional, SupportsFloat, TYPE_CHECKING
+
+from ax.core import ParameterType
+
+from ax.core.observation import Observation, ObservationFeatures
+from ax.core.parameter import RangeParameter
+from ax.core.search_space import SearchSpace
+from ax.exceptions.core import DataRequiredError
+from ax.modelbridge.transforms.base import Transform
+from ax.models.types import TConfig
+from ax.utils.common.logger import get_logger
+from pyre_extensions import assert_is_instance, none_throws
+
+if TYPE_CHECKING:
+    # import as module to make sphinx-autodoc-typehints happy
+    from ax import modelbridge as modelbridge_module  # noqa F401
+
+
+logger: Logger = get_logger(__name__)
+
+
+class MetadataToFloat(Transform):
+    """
+    This transform converts metadata from observation features into range (float)
+    parameters for a search space.
+
+    It allows the user to specify the `config` with `parameters` as the key, where
+    each entry maps a metadata key to a dictionary of keyword arguments for the
+    corresponding RangeParameter constructor.
+
+    Transform is done in-place.
+    """
+
+    DEFAULT_LOG_SCALE: bool = False
+    DEFAULT_LOGIT_SCALE: bool = False
+    DEFAULT_IS_FIDELITY: bool = False
+    ENFORCE_BOUNDS: bool = False
+
+    def __init__(
+        self,
+        search_space: SearchSpace | None = None,
+        observations: list[Observation] | None = None,
+        modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
+        config: TConfig | None = None,
+    ) -> None:
+        if observations is None or not observations:
+            raise DataRequiredError(
+                "`MetadataToFloat` transform requires non-empty data."
+            )
+        config = config or {}
+        self.parameters: dict[str, dict[str, Any]] = assert_is_instance(
+            config.get("parameters", {}), dict
+        )
+
+        self._parameter_list: list[RangeParameter] = []
+        for name in self.parameters:
+            lb = ub = None  # de facto bounds
+            for obs in observations:
+                obsf_metadata = none_throws(obs.features.metadata)
+
+                val = float(assert_is_instance(obsf_metadata[name], SupportsFloat))
+
+                lb = min(val, lb) if lb is not None else val
+                ub = max(val, ub) if ub is not None else val
+
+            lower: float = self.parameters[name].get("lower", lb)
+            upper: float = self.parameters[name].get("upper", ub)
+
+            log_scale = self.parameters[name].get("log_scale", self.DEFAULT_LOG_SCALE)
+            logit_scale = self.parameters[name].get(
+                "logit_scale", self.DEFAULT_LOGIT_SCALE
+            )
+            digits = self.parameters[name].get("digits")
+            is_fidelity = self.parameters[name].get(
+                "is_fidelity", self.DEFAULT_IS_FIDELITY
+            )
+
+            target_value = self.parameters[name].get("target_value")
+
+            parameter = RangeParameter(
+                name=name,
+                parameter_type=ParameterType.FLOAT,
+                lower=lower,
+                upper=upper,
+                log_scale=log_scale,
+                logit_scale=logit_scale,
+                digits=digits,
+                is_fidelity=is_fidelity,
+                target_value=target_value,
+            )
+            self._parameter_list.append(parameter)
+
+    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
+        for parameter in self._parameter_list:
+            search_space.add_parameter(parameter.clone())
+        return search_space
+
+    def transform_observation_features(
+        self, observation_features: list[ObservationFeatures]
+    ) -> list[ObservationFeatures]:
+        for obsf in observation_features:
+            self._transform_observation_feature(obsf)
+        return observation_features
+
+    def untransform_observation_features(
+        self, observation_features: list[ObservationFeatures]
+    ) -> list[ObservationFeatures]:
+        for obsf in observation_features:
+            obsf.metadata = obsf.metadata or {}
+            _transfer(
+                src=obsf.parameters,
+                dst=obsf.metadata,
+                keys=self.parameters.keys(),
+            )
+        return observation_features
+
+    def _transform_observation_feature(self, obsf: ObservationFeatures) -> None:
+        _transfer(
+            src=none_throws(obsf.metadata),
+            dst=obsf.parameters,
+            keys=self.parameters.keys(),
+        )
+
+
+def _transfer(
+    src: dict[str, Any],
+    dst: dict[str, Any],
+    keys: Iterable[str],
+) -> None:
+    """Transfer items in-place from one dictionary to another."""
+    for key in keys:
+        dst[key] = src.pop(key)
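
For the base transform, the `config` names which metadata keys to lift into the search space and forwards any remaining kwargs to the `RangeParameter` constructor. A hedged sketch, reusing `observations` from the sketch above but with `metadata={"bar": v}` for v in (3.0, 9.0, 15.0) — the `"bar"` key and its values are made up:

    from ax.modelbridge.transforms.metadata_to_float import MetadataToFloat

    t = MetadataToFloat(
        observations=observations,
        config={"parameters": {"bar": {"log_scale": True}}},
    )
    # "bar" becomes RangeParameter(FLOAT, lower=3.0, upper=15.0, log_scale=True):
    # bounds default to the observed min/max unless "lower"/"upper" are given
    # explicitly in the config.
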
diff --git a/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py b/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py
new file mode 100644
index 00000000000..c1ddb8c1d34
--- /dev/null
+++ b/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py
@@ -0,0 +1,160 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from copy import deepcopy
+from typing import Iterator
+
+import numpy as np
+from ax.core.observation import Observation, ObservationData, ObservationFeatures
+from ax.core.parameter import ParameterType, RangeParameter
+from ax.core.search_space import SearchSpace
+from ax.modelbridge.transforms.map_key_to_float import MapKeyToFloat
+from ax.utils.common.testutils import TestCase
+from pyre_extensions import assert_is_instance
+
+
+WIDTHS = [2.0, 4.0, 8.0]
+HEIGHTS = [4.0, 2.0, 8.0]
+STEPS_ENDS = [1, 5, 3]
+
+
+def _enumerate() -> Iterator[tuple[int, float, float, float]]:
+    yield from (
+        (trial_index, width, height, float(i + 1))
+        for trial_index, (width, height, steps_end) in enumerate(
+            zip(WIDTHS, HEIGHTS, STEPS_ENDS)
+        )
+        for i in range(steps_end)
+    )
+
+
+class MapKeyToFloatTransformTest(TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+
+        self.search_space = SearchSpace(
+            parameters=[
+                RangeParameter(
+                    name="width",
+                    parameter_type=ParameterType.FLOAT,
+                    lower=1,
+                    upper=20,
+                ),
+                RangeParameter(
+                    name="height",
+                    parameter_type=ParameterType.FLOAT,
+                    lower=1,
+                    upper=20,
+                ),
+            ]
+        )
+
+        self.observations = []
+        for trial_index, width, height, steps in _enumerate():
+            obs_feat = ObservationFeatures(
+                trial_index=trial_index,
+                parameters={"width": width, "height": height},
+                metadata={
+                    "foo": 42,
+                    MapKeyToFloat.DEFAULT_MAP_KEY: steps,
+                },
+            )
+            obs_data = ObservationData(
+                metric_names=[], means=np.array([]), covariance=np.empty((0, 0))
+            )
+            self.observations.append(Observation(features=obs_feat, data=obs_data))
+
+        # does not require explicitly specifying `config`
+        self.t = MapKeyToFloat(
+            observations=self.observations,
+        )
+
+    def test_Init(self) -> None:
+        self.assertEqual(len(self.t._parameter_list), 1)
+
+        p = self.t._parameter_list[0]
+
+        self.assertEqual(p.name, MapKeyToFloat.DEFAULT_MAP_KEY)
+        self.assertEqual(p.parameter_type, ParameterType.FLOAT)
+        self.assertEqual(p.lower, 1.0)
+        self.assertEqual(p.upper, 5.0)
+        self.assertTrue(p.log_scale)
+
+        # test that one is able to override default config
+        with self.subTest(msg="override default config"):
+            t = MapKeyToFloat(
+                observations=self.observations,
+                config={
+                    "parameters": {MapKeyToFloat.DEFAULT_MAP_KEY: {"log_scale": False}}
+                },
+            )
+            self.assertDictEqual(
+                t.parameters, {MapKeyToFloat.DEFAULT_MAP_KEY: {"log_scale": False}}
+            )
+
+            self.assertEqual(len(t._parameter_list), 1)
+
+            p = t._parameter_list[0]
+
+            self.assertEqual(p.name, MapKeyToFloat.DEFAULT_MAP_KEY)
+            self.assertEqual(p.parameter_type, ParameterType.FLOAT)
+            self.assertEqual(p.lower, 1.0)
+            self.assertEqual(p.upper, 5.0)
+            self.assertFalse(p.log_scale)
+
+    def test_TransformSearchSpace(self) -> None:
+        ss2 = deepcopy(self.search_space)
+        ss2 = self.t.transform_search_space(ss2)
+
+        self.assertSetEqual(
+            set(ss2.parameters),
+            {"height", "width", MapKeyToFloat.DEFAULT_MAP_KEY},
+        )
+
+        p = assert_is_instance(
+            ss2.parameters[MapKeyToFloat.DEFAULT_MAP_KEY], RangeParameter
+        )
+
+        self.assertEqual(p.name, MapKeyToFloat.DEFAULT_MAP_KEY)
+        self.assertEqual(p.parameter_type, ParameterType.FLOAT)
+        self.assertEqual(p.lower, 1.0)
+        self.assertEqual(p.upper, 5.0)
+        self.assertTrue(p.log_scale)
+
+    def test_TransformObservationFeatures(self) -> None:
+        observation_features = [obs.features for obs in self.observations]
+        obs_ft2 = deepcopy(observation_features)
+        obs_ft2 = self.t.transform_observation_features(obs_ft2)
+
+        self.assertEqual(
+            obs_ft2,
+            [
+                ObservationFeatures(
+                    trial_index=trial_index,
+                    parameters={
+                        "width": width,
+                        "height": height,
+                        MapKeyToFloat.DEFAULT_MAP_KEY: steps,
+                    },
+                    metadata={"foo": 42},
+                )
+                for trial_index, width, height, steps in _enumerate()
+            ],
+        )
+        obs_ft2 = self.t.untransform_observation_features(obs_ft2)
+        self.assertEqual(obs_ft2, observation_features)
+
+    def test_TransformObservationFeaturesWithEmptyParameters(self) -> None:
+        obsf = ObservationFeatures(parameters={})
+        self.t.transform_observation_features([obsf])
+
+        p = self.t._parameter_list[0]
+        self.assertEqual(
+            obsf,
+            ObservationFeatures(parameters={MapKeyToFloat.DEFAULT_MAP_KEY: p.upper}),
+        )
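
The round-trip that both this test file and the next one exercise, in miniature (hypothetical values; `t` is the `MetadataToFloat` instance from the sketch above):

    from ax.core.observation import ObservationFeatures

    obsf = ObservationFeatures(parameters={"x": 0.5}, metadata={"bar": 9.0})
    t.transform_observation_features([obsf])
    # "bar" has moved from metadata into parameters:
    assert obsf.parameters == {"x": 0.5, "bar": 9.0}
    assert obsf.metadata == {}

    t.untransform_observation_features([obsf])
    # ...and back again:
    assert obsf.parameters == {"x": 0.5}
    assert obsf.metadata == {"bar": 9.0}
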
diff --git a/ax/modelbridge/transforms/tests/test_metadata_to_float_transform.py b/ax/modelbridge/transforms/tests/test_metadata_to_float_transform.py
new file mode 100644
index 00000000000..7c49f1df099
--- /dev/null
+++ b/ax/modelbridge/transforms/tests/test_metadata_to_float_transform.py
@@ -0,0 +1,180 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from copy import deepcopy
+from typing import Iterator
+
+import numpy as np
+from ax.core.observation import Observation, ObservationData, ObservationFeatures
+from ax.core.parameter import ParameterType, RangeParameter
+from ax.core.search_space import SearchSpace
+from ax.exceptions.core import DataRequiredError
+from ax.modelbridge.transforms.metadata_to_float import MetadataToFloat
+from ax.utils.common.testutils import TestCase
+from pyre_extensions import assert_is_instance
+
+
+WIDTHS = [2.0, 4.0, 8.0]
+HEIGHTS = [4.0, 2.0, 8.0]
+STEPS_ENDS = [1, 5, 3]
+
+
+def _enumerate() -> Iterator[tuple[int, float, float, float]]:
+    yield from (
+        (trial_index, width, height, float(i + 1))
+        for trial_index, (width, height, steps_end) in enumerate(
+            zip(WIDTHS, HEIGHTS, STEPS_ENDS)
+        )
+        for i in range(steps_end)
+    )
+
+
+class MetadataToFloatTransformTest(TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+
+        self.search_space = SearchSpace(
+            parameters=[
+                RangeParameter(
+                    name="width",
+                    parameter_type=ParameterType.FLOAT,
+                    lower=1,
+                    upper=20,
+                ),
+                RangeParameter(
+                    name="height",
+                    parameter_type=ParameterType.FLOAT,
+                    lower=1,
+                    upper=20,
+                ),
+            ]
+        )
+
+        self.observations = []
+        for trial_index, width, height, steps in _enumerate():
+            obs_feat = ObservationFeatures(
+                trial_index=trial_index,
+                parameters={"width": width, "height": height},
+                metadata={
+                    "foo": 42,
+                    "bar": 3.0 * steps,
+                },
+            )
+            obs_data = ObservationData(
+                metric_names=[], means=np.array([]), covariance=np.empty((0, 0))
+            )
+            self.observations.append(Observation(features=obs_feat, data=obs_data))
+
+        self.t = MetadataToFloat(
+            observations=self.observations,
+            config={
+                "parameters": {"bar": {"log_scale": True}},
+            },
+        )
+
+    def test_Init(self) -> None:
+        self.assertEqual(len(self.t._parameter_list), 1)
+
+        p = self.t._parameter_list[0]
+
+        # check that the parameter options are specified in a sensible manner
+        # by default if the user does not specify them explicitly
+        self.assertEqual(p.name, "bar")
+        self.assertEqual(p.parameter_type, ParameterType.FLOAT)
+        self.assertEqual(p.lower, 3.0)
+        self.assertEqual(p.upper, 15.0)
+        self.assertTrue(p.log_scale)
+        self.assertFalse(p.logit_scale)
+        self.assertIsNone(p.digits)
+        self.assertFalse(p.is_fidelity)
+        self.assertIsNone(p.target_value)
+
+        with self.assertRaisesRegex(DataRequiredError, "requires non-empty data"):
+            MetadataToFloat(search_space=None, observations=None)
+        with self.assertRaisesRegex(DataRequiredError, "requires non-empty data"):
+            MetadataToFloat(search_space=None, observations=[])
+
+        with self.subTest("integer metadata values are coerced to float"):
+            observations = []
+            for trial_index, width, height, steps in _enumerate():
+                obs_feat = ObservationFeatures(
+                    trial_index=trial_index,
+                    parameters={"width": width, "height": height},
+                    metadata={
+                        "foo": 42,
+                        "bar": int(steps),
+                    },
+                )
+                obs_data = ObservationData(
+                    metric_names=[], means=np.array([]), covariance=np.empty((0, 0))
+                )
+                observations.append(Observation(features=obs_feat, data=obs_data))
+
+            t = MetadataToFloat(
+                observations=observations,
+                config={
+                    "parameters": {"bar": {}},
+                },
+            )
+            self.assertEqual(len(t._parameter_list), 1)
+
+            p = t._parameter_list[0]
+
+            self.assertEqual(p.name, "bar")
+            self.assertEqual(p.parameter_type, ParameterType.FLOAT)
+            self.assertEqual(p.lower, 1.0)
+            self.assertEqual(p.upper, 5.0)
+            self.assertFalse(p.log_scale)
+            self.assertFalse(p.logit_scale)
+            self.assertIsNone(p.digits)
+            self.assertFalse(p.is_fidelity)
self.assertIsNone(p.target_value) + + def test_TransformSearchSpace(self) -> None: + ss2 = deepcopy(self.search_space) + ss2 = self.t.transform_search_space(ss2) + + self.assertSetEqual( + set(ss2.parameters.keys()), + {"height", "width", "bar"}, + ) + + p = assert_is_instance(ss2.parameters["bar"], RangeParameter) + + self.assertEqual(p.name, "bar") + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 3.0) + self.assertEqual(p.upper, 15.0) + self.assertTrue(p.log_scale) + self.assertFalse(p.logit_scale) + self.assertIsNone(p.digits) + self.assertFalse(p.is_fidelity) + self.assertIsNone(p.target_value) + + def test_TransformObservationFeatures(self) -> None: + observation_features = [obs.features for obs in self.observations] + obs_ft2 = deepcopy(observation_features) + obs_ft2 = self.t.transform_observation_features(obs_ft2) + + self.assertEqual( + obs_ft2, + [ + ObservationFeatures( + trial_index=trial_index, + parameters={ + "width": width, + "height": height, + "bar": 3.0 * steps, + }, + metadata={"foo": 42}, + ) + for trial_index, width, height, steps in _enumerate() + ], + ) + obs_ft2 = self.t.untransform_observation_features(obs_ft2) + self.assertEqual(obs_ft2, observation_features) diff --git a/sphinx/source/modelbridge.rst b/sphinx/source/modelbridge.rst index 98a0c124cc7..c35831f380a 100644 --- a/sphinx/source/modelbridge.rst +++ b/sphinx/source/modelbridge.rst @@ -310,6 +310,24 @@ Transforms :undoc-members: :show-inheritance: + +`ax.modelbridge.transforms.metadata\_to\_float` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: ax.modelbridge.transforms.metadata_to_float + :members: + :undoc-members: + :show-inheritance: + + +`ax.modelbridge.transforms.map\_key\_to\_float` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: ax.modelbridge.transforms.map_key_to_float + :members: + :undoc-members: + :show-inheritance: + `ax.modelbridge.transforms.rounding` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
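
Tying the pieces together, a hedged sketch of how the new `transforms` argument on the benchmark helpers can carry a transform stack into a generation strategy; the model and acquisition choices here are illustrative, not prescribed by this diff:

    from ax.benchmark.methods.modular_botorch import (
        get_sobol_botorch_modular_acquisition,
    )
    from botorch.acquisition.analytic import LogExpectedImprovement
    from botorch.models.gp_regression import SingleTaskGP

    method = get_sobol_botorch_modular_acquisition(
        model_cls=SingleTaskGP,
        acquisition_cls=LogExpectedImprovement,
        distribute_replications=False,
        # New in this diff: None keeps the registry's default transform stack;
        # passing a sequence of Transform subclasses (e.g. one that includes
        # MapKeyToFloat) replaces the defaults wholesale, since the helper only
        # sets model_kwargs["transforms"] when the argument is not None.
        transforms=None,
    )
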