diff --git a/botorch/utils/probability/mvnxpb.py b/botorch/utils/probability/mvnxpb.py
index ee479fa40f..4d6edb3539 100644
--- a/botorch/utils/probability/mvnxpb.py
+++ b/botorch/utils/probability/mvnxpb.py
@@ -176,18 +176,20 @@ def solve(self, num_steps: Optional[int] = None, eps: float = 1e-10) -> Tensor:
             if pivot is not None and torch.any(pivot > i):
                 self.pivot_(pivot=pivot)
 
-            # Initialize `i`-th plug-in value as univariate conditional expectation
+            # Compute whitened bounds conditional on preceding plug-ins
             Lii = L[..., i, i].clone()
             if should_update_chol:
-                Lii = Lii.clip(min=0).sqrt()
+                Lii = Lii.clip(min=0).sqrt()  # conditional stddev
             inv_Lii = Lii.reciprocal()
-            if i == 0:
-                lb, ub = bounds[..., i, :].clone().unbind(dim=-1)
-            else:
-                db = (L[..., i, :i].clone() * y[..., :i].clone()).sum(-1, keepdim=True)
-                lb, ub = (bounds[..., i, :].clone() - db).unbind(dim=-1)
+            bounds_i = bounds[..., i, :].clone()
+            if i != 0:
+                bounds_i = bounds_i - torch.sum(
+                    L[..., i, :i].clone() * y[..., :i].clone(), dim=-1, keepdim=True
+                )
+            lb, ub = (inv_Lii.unsqueeze(-1) * bounds_i).unbind(dim=-1)
 
-            Phi_i = Phi(inv_Lii * ub) - Phi(inv_Lii * lb)
+            # Initialize `i`-th plug-in value as univariate conditional expectation
+            Phi_i = Phi(ub) - Phi(lb)
             small = Phi_i <= i * eps
             y[..., i] = case_dispatcher(  # used to select next pivot
                 out=(phi(lb) - phi(ub)) / Phi_i,
@@ -224,7 +226,7 @@ def solve(self, num_steps: Optional[int] = None, eps: float = 1e-10) -> Tensor:
                 # Replace 1D expectations with 2D ones `L[blk, blk]^{-1} y[..., blk]`
                 mask = blk_prob > zero
                 y[..., h] = torch.where(mask, zh, zero)
-                y[..., i] = torch.where(mask, (std_i * zi - Lih * zh) / Lii, zero)
+                y[..., i] = torch.where(mask, inv_Lii * (std_i * zi - Lih * zh), zero)
 
                 # Update running approximation to log probability
                 self.log_prob = self.log_prob + safe_log(blk_prob)
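
For reference, the plug-in value `y[..., i]` computed in the first hunk is the mean of a standard normal variable truncated to the whitened bounds `(lb, ub)`, i.e. `(phi(lb) - phi(ub)) / (Phi(ub) - Phi(lb))`; the patch whitens the bounds once (via `inv_Lii`) before both `Phi` and `phi` are applied. Below is a minimal, self-contained sketch of that formula, independent of the solver internals; the helper name is chosen here for illustration only.

```python
import math

import torch


def truncated_std_normal_mean(lb: torch.Tensor, ub: torch.Tensor) -> torch.Tensor:
    """E[z | lb < z < ub] for z ~ N(0, 1), mirroring the patched plug-in formula."""
    Phi = torch.special.ndtr  # standard normal CDF
    phi = lambda x: torch.exp(-0.5 * x**2) / math.sqrt(2 * math.pi)  # standard normal PDF
    return (phi(lb) - phi(ub)) / (Phi(ub) - Phi(lb))


# Example with already-whitened bounds, e.g. inv_Lii * (bounds_i - sum(L[i, :i] * y[:i]))
lb = torch.tensor([-1.0, 0.5])
ub = torch.tensor([2.0, 3.0])
print(truncated_std_normal_mean(lb, ub))
```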