Skip to content

Commit

Permalink
Minor patch to MVNXPB (#1933)
Browse files Browse the repository at this point in the history
Summary:
### Overview
This PR introduces a minor fix to the MVNXPB algorithm, so that bounds are properly whitened when computing initial plug-in estimators (which can impact the pivot order).

Pull Request resolved: #1933

Test Plan: I have run the existing unit tests locally.

Reviewed By: esantorella

Differential Revision: D47368553

Pulled By: Balandat

fbshipit-source-id: ac2dfe336a5d3ecf1dad14f11752b8fe8ce69d13
  • Loading branch information
j-wilson authored and facebook-github-bot committed Jul 12, 2023
1 parent 36c8c02 commit a1b38fc
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions botorch/utils/probability/mvnxpb.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,18 +176,20 @@ def solve(self, num_steps: Optional[int] = None, eps: float = 1e-10) -> Tensor:
if pivot is not None and torch.any(pivot > i):
self.pivot_(pivot=pivot)

# Initialize `i`-th plug-in value as univariate conditional expectation
# Compute whitened bounds conditional on preceding plug-ins
Lii = L[..., i, i].clone()
if should_update_chol:
Lii = Lii.clip(min=0).sqrt()
Lii = Lii.clip(min=0).sqrt() # conditional stddev
inv_Lii = Lii.reciprocal()
if i == 0:
lb, ub = bounds[..., i, :].clone().unbind(dim=-1)
else:
db = (L[..., i, :i].clone() * y[..., :i].clone()).sum(-1, keepdim=True)
lb, ub = (bounds[..., i, :].clone() - db).unbind(dim=-1)
bounds_i = bounds[..., i, :].clone()
if i != 0:
bounds_i = bounds_i - torch.sum(
L[..., i, :i].clone() * y[..., :i].clone(), dim=-1, keepdim=True
)
lb, ub = (inv_Lii.unsqueeze(-1) * bounds_i).unbind(dim=-1)

Phi_i = Phi(inv_Lii * ub) - Phi(inv_Lii * lb)
# Initialize `i`-th plug-in value as univariate conditional expectation
Phi_i = Phi(ub) - Phi(lb)
small = Phi_i <= i * eps
y[..., i] = case_dispatcher( # used to select next pivot
out=(phi(lb) - phi(ub)) / Phi_i,
Expand Down Expand Up @@ -224,7 +226,7 @@ def solve(self, num_steps: Optional[int] = None, eps: float = 1e-10) -> Tensor:
# Replace 1D expectations with 2D ones `L[blk, blk]^{-1} y[..., blk]`
mask = blk_prob > zero
y[..., h] = torch.where(mask, zh, zero)
y[..., i] = torch.where(mask, (std_i * zi - Lih * zh) / Lii, zero)
y[..., i] = torch.where(mask, inv_Lii * (std_i * zi - Lih * zh), zero)

# Update running approximation to log probability
self.log_prob = self.log_prob + safe_log(blk_prob)
Expand Down

0 comments on commit a1b38fc

Please sign in to comment.