Add Multi Information Source Augmented GP (#2152)
Summary:

## Motivation

This pull request introduces the implementation of the **Augmented Gaussian Process** (AGP) and a related acquisition function based on UCB, namely **Augmented UCB** (AUCB), for multi information source problems.

> Candelieri, A., Archetti, F. Sparsifying to optimize over multiple information sources: an augmented Gaussian process based algorithm. Struct Multidisc Optim 64, 239–255 (2021). [https://doi.org/10.1007/s00158-021-02882-7](https://doi.org/10.1007/s00158-021-02882-7)

### AGP and AUCB in a nutshell
The key idea of the AGP is to fit a GP model to each information source and *augment* the observations of the high-fidelity source with those from *cheaper* sources that can be considered *reliable*. The GP model fitted on this *augmented* set of observations is the AGP.
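The augmentation step can be sketched as follows (a minimal sketch in which the reliability rule, based on the high-fidelity GP's predictive interval, is an illustrative assumption; see the paper above for the exact criterion):

```python
import torch
from botorch.models import SingleTaskGP


def augment_observations(hf_model: SingleTaskGP, X_cheap, Y_cheap, k: float = 1.0):
    """Keep the cheap-source observations deemed reliable, i.e., those falling
    inside the high-fidelity GP's predictive interval (illustrative rule)."""
    post = hf_model.posterior(X_cheap)
    mean = post.mean.squeeze(-1)
    std = post.variance.clamp_min(1e-12).sqrt().squeeze(-1)
    reliable = (Y_cheap.squeeze(-1) - mean).abs() <= k * std
    return X_cheap[reliable], Y_cheap[reliable]
```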

The AUCB is a modification of the standard UCB, computed on the AGP, that also accounts for the source-specific query cost.
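Concretely, using the notation of the acquisition function's docstring in the diff below:

$$\mathrm{AUCB}(x, s, y^+) = \frac{\mu(x) + \sqrt{\beta}\,\sigma(x) - y^+}{c(s)\,\bigl(1 + \lvert \mu(x) - \mu_s(x) \rvert\bigr)}$$

where $\mu$ and $\sigma$ are the posterior mean and standard deviation of the AGP, $\mu_s$ is the posterior mean of the GP modelling the $s$-th source, and $c(s)$ is the cost of querying source $s$.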

### Example
This is what the AGP and AUCB look like on the Forrester problem considering 2 sources.
<p align="center">
<img src="https://github.com/pytorch/botorch/assets/59694427/72c43b56-e08b-4b47-aea9-00925345890c" height="500">
</p>

### Some implementation details
The Augmented GP implementation is based on the `SingleTaskGP`.
Each source is modelled as an independent `SingleTaskGP`, and all the _reliable_ observations are used to fit the `SingleTaskGP` representing the AGP, namely `SingleTaskAugmentedGP`. A key difference from the Multi-Fidelity approaches in BoTorch is that the dimension representing the source is not directly modelled by the AGP.
In addition, a Fixed-Noise version of the AGP has been implemented based on the `FixedNoiseGP`.
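A minimal usage sketch (the module path and constructor arguments here are assumptions, mirroring `SingleTaskGP`; the last column of the training inputs encodes the source index):

```python
import torch

# Assumed import path; see this PR's file tree for the actual location.
from botorch_community.models.gp_regression_multisource import SingleTaskAugmentedGP

n, d = 20, 2
train_X = torch.rand(n, d, dtype=torch.double)
# Source index as the last input column: 0 = cheap source, 1 = high fidelity.
source = torch.randint(0, 2, (n, 1)).to(train_X)
train_Y = torch.randn(n, 1, dtype=torch.double)

model = SingleTaskAugmentedGP(
    train_X=torch.cat([train_X, source], dim=-1),
    train_Y=train_Y,
)
```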

The Augmented UCB implementation is based on `UpperConfidenceBound`, but the improvement is penalized by the query cost of the source and by the discrepancy between that source's GP and the AGP.
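Continuing the sketch above, the acquisition function can be constructed and evaluated as follows (the constructor signature is the one from the diff below; the candidate values are illustrative):

```python
from botorch_community.acquisition.augmented_multisource import (
    AugmentedUpperConfidenceBound,
)

acqf = AugmentedUpperConfidenceBound(
    model=model,              # the fitted AGP from the sketch above
    cost={0: 1.0, 1: 10.0},   # per-source query costs
    best_f=train_Y.max(),
    beta=3.0,
)

# Candidates have shape `b x q=1 x (d + 1)`; the last column is the source.
X_cand = torch.cat(
    [torch.rand(5, 1, d, dtype=torch.double), torch.ones(5, 1, 1, dtype=torch.double)],
    dim=-1,
)
aucb_values = acqf(X_cand)  # one value per candidate, shape `b`
```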

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes, I have read the Contributing Guidelines on pull requests.

Pull Request resolved: #2152

Test Plan:
Unit tests for both the AGP and the AUCB have been implemented, inspired by the `SingleTaskGP` and `UpperConfidenceBound` tests, respectively.
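As an illustration, here is a minimal sketch of one such test (the test name is hypothetical), exercising the multi-source check in the constructor shown in the diff below:

```python
import torch
from botorch.exceptions import UnsupportedError
from botorch.models import SingleTaskGP
from botorch_community.acquisition.augmented_multisource import (
    AugmentedUpperConfidenceBound,
)


def test_aucb_rejects_single_source_model():
    # A plain SingleTaskGP has no `.models` attribute, so the AUCB
    # constructor is expected to raise an UnsupportedError.
    model = SingleTaskGP(
        torch.rand(5, 2, dtype=torch.double), torch.randn(5, 1, dtype=torch.double)
    )
    try:
        AugmentedUpperConfidenceBound(model=model, cost={0: 1.0}, best_f=0.0, beta=1.0)
    except UnsupportedError:
        pass
    else:
        raise AssertionError("expected UnsupportedError")
```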

In addition, a notebook tutorial shows how to use the AGP model along with the AUCB acquisition function, and compares them with the Multi-Fidelity MES on the Augmented Branin test function.

Finally, we plan to release an arXiv preprint soon, with an extensive comparison of Multi-Fidelity and Multi Information Source approaches in BoTorch. Here are some preliminary results on the Augmented Hartmann test function, considering three sources (with fidelities $[0.50, 0.75, 1.00]$). The AGP has been compared with the discrete Multi-Fidelity versions of the Knowledge Gradient (KG), Max-value Entropy Search (MES) and GIBBON. The figure on the left shows the best seen value with respect to the query cost, while the figure on the right shows the best seen value with respect to the wall-clock time.
<p align="center">
<img src="https://github.com/pytorch/botorch/assets/59694427/77cfcebc-8123-4802-8045-5d9b359775ec" height="300">
</p>

## Related PRs
The docs have been updated in this PR.

Reviewed By: Balandat

Differential Revision: D52256404

Pulled By: sdaulton

fbshipit-source-id: 863437b488dcee6b37306dcd6c1ee6b63ca9c55f
andreaponti5 authored and facebook-github-bot committed Feb 20, 2024
1 parent 5161bb5 commit defe657
Showing 6 changed files with 1,692 additions and 1 deletion.
botorch_community/acquisition/augmented_multisource.py (new file, 150 additions, 0 deletions)
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Multi-Source Upper Confidence Bound.
References:
.. [Ca2021ms]
Candelieri, A., & Archetti, F. (2021).
Sparsifying to optimize over multiple information sources:
an augmented Gaussian process based algorithm.
Structural and Multidisciplinary Optimization, 64, 239-255.
Contributor: andreaponti5
"""

from __future__ import annotations

from typing import Dict, Optional, Tuple, Union

import torch

from botorch.acquisition import UpperConfidenceBound
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions import UnsupportedError
from botorch.models.model import Model
from botorch.utils.transforms import t_batch_mode_transform
from gpytorch.models import ExactGP
from torch import Tensor


class AugmentedUpperConfidenceBound(UpperConfidenceBound):
r"""Single-outcome Multi-Source Upper Confidence Bound (UCB).
    A modified version of the UCB for Multi Information Source that considers
    the most optimistic improvement with respect to the best value observed so
    far. The improvement is then penalized depending on the source's cost and
    on the discrepancy between the GP associated with the source and the AGP.
`AUCB(x, s, y^+) = ((mu(x) + sqrt(beta) * sigma(x)) - y^+)
/ (c(s) (1 + abs(mu(x) - mu_s(x))))`,
where `mu` and `sigma` are the posterior mean and standard deviation of the AGP,
`mu_s` is the posterior mean of the GP modelling the s-th source and
c(s) is the cost of the source s.
"""

def __init__(
self,
model: Model,
cost: Dict,
best_f: Union[float, Tensor],
beta: Union[float, Tensor],
posterior_transform: Optional[PosteriorTransform] = None,
maximize: bool = True,
) -> None:
r"""Single-outcome Augmented Upper Confidence Bound.
Args:
model: A fitted single-outcome Augmented GP model.
            cost: A dictionary containing the cost of querying each source.
            best_f: Either a scalar or a `b`-dim Tensor (batch mode) representing
                the best function value observed so far (assumed noiseless).
            beta: Either a scalar or a one-dim tensor with `b` elements (batch mode)
                representing the trade-off parameter between mean and covariance.
posterior_transform: A PosteriorTransform. If using a multi-output model,
a PosteriorTransform that transforms the multi-output posterior into a
single-output posterior is required.
maximize: If True, consider the problem a maximization problem.
"""
if not hasattr(model, "models"):
            raise UnsupportedError("Model must be multi-source.")
super().__init__(
model=model,
beta=beta,
maximize=maximize,
posterior_transform=posterior_transform,
)
self.cost = cost
self.best_f = best_f

@t_batch_mode_transform(expected_q=1)
def forward(self, X: Tensor) -> Tensor:
r"""Evaluate the Upper Confidence Bound on the candidate set X.
Args:
X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.
Returns:
A `(b1 x ... bk)`-dim tensor of Augmented Upper Confidence Bound values at
the given design points `X`.
"""
        alpha = torch.zeros(X.shape[0], dtype=X.dtype, device=X.device)
        # Posterior mean / std of the AGP; the last input column is the source.
        agp_mean, agp_sigma = self._mean_and_sigma(X[..., :-1])
        # Optimistic improvement over the best value observed so far:
        # `(mu(x) + sqrt(beta) * sigma(x)) - y^+` (signs flipped if minimizing).
        cb = (
            (agp_mean if self.maximize else -agp_mean)
            + self.beta.sqrt() * agp_sigma
            - (self.best_f if self.maximize else -self.best_f)
        )
        # Group candidate indices by their (rounded) source index.
        source_idxs = {
            s.item(): torch.where(torch.round(X[..., -1], decimals=0) == s)[0]
            for s in torch.round(X[..., -1], decimals=0).unique().int()
        }
        # Penalize the improvement by the query cost of the source and by the
        # discrepancy between the source's GP and the AGP (docstring formula).
        for s in source_idxs:
            mean, _ = self._mean_and_sigma(
                X[source_idxs[s], :, :-1], self.model.models[s]
            )
            alpha[source_idxs[s]] = cb[source_idxs[s]] / (
                self.cost[s] * (1 + torch.abs(agp_mean[source_idxs[s]] - mean))
            )
        return alpha

def _mean_and_sigma(
self,
X: Tensor,
        model: Optional[ExactGP] = None,
compute_sigma: bool = True,
min_var: float = 1e-12,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""Computes the first and second moments of the model posterior.
Args:
X: `batch_shape x q x d`-dim Tensor of model inputs.
            model: The model to use. If None, `self.model` is used.
            compute_sigma: Boolean indicating whether to compute the second
                moment (default: True).
            min_var: The minimum value the variance is clamped to. Should be positive.
        Returns:
            A tuple of tensors containing the first and second moments of the model
            posterior. Removes the last two dimensions if they have size one. Only
            returns a single tensor of means if `compute_sigma` is False.
"""
        self.to(device=X.device)
        # Use the AGP (`self.model`) unless a source-specific GP is given.
        model = self.model if model is None else model
        posterior = model.posterior(
            X=X, posterior_transform=self.posterior_transform
        )
mean = posterior.mean.squeeze(-2).squeeze(-1) # removing redundant dimensions
if not compute_sigma:
return mean, None
sigma = posterior.variance.clamp_min(min_var).sqrt().view(mean.shape)
return mean, sigma
