Skip to content

Commit

Permalink
Inherit from LogImprovementMCAcquisitionFunction; Add _sample_forward…
Browse files Browse the repository at this point in the history
… method
  • Loading branch information
SaiAakash committed Jan 30, 2025
1 parent 5223198 commit f9b7fb9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 15 deletions.
23 changes: 10 additions & 13 deletions botorch_community/acquisition/rei.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,16 @@
_scaled_improvement,
AnalyticAcquisitionFunction,
)
from botorch.acquisition.logei import _log_improvement, check_tau
from botorch.acquisition.monte_carlo import MCAcquisitionFunction
from botorch.acquisition.logei import (
_log_improvement,
check_tau,
LogImprovementMCAcquisitionFunction,
)
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.utils.safe_math import logmeanexp
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
from botorch.utils.transforms import t_batch_mode_transform
from torch import Tensor

TAU_RELU = 1e-6
Expand Down Expand Up @@ -125,7 +128,7 @@ def forward(self, X: Tensor) -> Tensor:
return logrei


class qLogRegionalExpectedImprovement(MCAcquisitionFunction):
class qLogRegionalExpectedImprovement(LogImprovementMCAcquisitionFunction):
def __init__(
self,
model: Model,
Expand Down Expand Up @@ -191,10 +194,7 @@ def __init__(
device=self.X_dev.device, dtype=self.X_dev.dtype
)

@concatenate_pending_points
@t_batch_mode_transform()
def forward(self, X: Tensor) -> Tensor:
batch_shape = X.shape[0]
def _sample_forward(self, X: Tensor) -> Tensor:
q = X.shape[1]
d = X.shape[2]

Expand All @@ -208,9 +208,6 @@ def forward(self, X: Tensor) -> Tensor:
)
samples = self.get_posterior_samples(posterior)
obj = self.objective(samples, X=Xs)
obj = _log_improvement(obj, self.best_f, self.tau_relu, self.fat).reshape(
-1, self.n_region, batch_shape, q
)
q_log_rei = obj.max(dim=-1)[0].mean(dim=(0, 1))
obj = _log_improvement(obj, self.best_f, self.tau_relu, self.fat)

return q_log_rei
return obj
2 changes: 0 additions & 2 deletions test_community/acquisition/test_rei.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,6 @@ def _test_q_log_regional_expected_improvement(self, dtype: torch.dtype) -> None:
self.assertIsNone(acqf.X_pending)
acqf.set_X_pending(X)
self.assertEqual(acqf.X_pending, X)
mm._posterior._samples = torch.zeros(1, 2, 1, **tkwargs)
res = acqf(X)
X2 = torch.zeros(1, 1, 1, **tkwargs, requires_grad=True)
with warnings.catch_warnings(record=True) as ws:
acqf.set_X_pending(X2)
Expand Down

0 comments on commit f9b7fb9

Please sign in to comment.