diff --git a/botorch_community/acquisition/latent_information_gain.py b/botorch_community/acquisition/latent_information_gain.py
index 5cae0fc0f3..0b000053c6 100644
--- a/botorch_community/acquisition/latent_information_gain.py
+++ b/botorch_community/acquisition/latent_information_gain.py
@@ -23,51 +23,60 @@
 from typing import Optional
 
 import torch
-from botorch import settings
+from botorch.acquisition import AcquisitionFunction
 from botorch_community.models.np_regression import NeuralProcessModel
 from torch import Tensor
 
 import torch
 
 #reference: https://arxiv.org/abs/2106.02770
 
-class LatentInformationGain:
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+class LatentInformationGain(AcquisitionFunction):
     def __init__(
         self,
+        context_x: torch.Tensor,
+        context_y: torch.Tensor,
         model: NeuralProcessModel,
         num_samples: int = 10,
-        min_std: float = 0.1,
-        scaler: float = 0.9
+        min_std: float = 0.01,
+        scaler: float = 0.5
     ) -> None:
         """
         Latent Information Gain (LIG) Acquisition Function, designed for the
-            NeuralProcessModel.
+            NeuralProcessModel. This is a subclass of AcquisitionFunction.
 
         Args:
             model: Trained NeuralProcessModel.
+            context_x: Context input points, as a Tensor.
+            context_y: Context target points, as a Tensor.
             num_samples (int): Number of samples for calculation, defaults to 10.
             min_std: Float representing the minimum possible standardized std, defaults to 0.1.
             scaler: Float scaling the std, defaults to 0.9.
         """
-        self.model = model
+        super().__init__(model=model)
+        self.model = model.to(device)
         self.num_samples = num_samples
         self.min_std = min_std
         self.scaler = scaler
+        self.context_x = context_x.to(device)
+        self.context_y = context_y.to(device)
 
-    def acquisition(self, candidate_x, context_x, context_y):
+    def forward(self, candidate_x):
         """
         Conduct the Latent Information Gain acquisition function for the inputs.
 
         Args:
             candidate_x: Candidate input points, as a Tensor.
-            context_x: Context input points, as a Tensor.
-            context_y: Context target points, as a Tensor.
 
         Returns:
             torch.Tensor: The LIG score of computed KLDs.
         """
+        candidate_x = candidate_x.to(device)
+
         # Encoding and Scaling the context data
-        z_mu_context, z_logvar_context = self.model.data_to_z_params(context_x, context_y)
+        z_mu_context, z_logvar_context = self.model.data_to_z_params(self.context_x, self.context_y)
         kl = 0.0
         for _ in range(self.num_samples):
            # Taking reparameterized samples
@@ -77,8 +86,8 @@ def acquisition(self, candidate_x, context_x, context_y):
             y_pred = self.model.decoder(candidate_x, samples)
 
             # Combining context and candidate data
-            combined_x = torch.cat([context_x, candidate_x], dim=0)
-            combined_y = torch.cat([context_y, y_pred], dim=0)
+            combined_x = torch.cat([self.context_x, candidate_x], dim=0).to(device)
+            combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device)
 
             # Computing posterior variables
             z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(combined_x, combined_y)
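
Usage sketch for the new interface (illustrative, not part of the diff): after this change the context set is passed at construction time and candidates are scored through `forward`, so the instance can be called like any other `AcquisitionFunction`. The shapes below and the `np_model` placeholder are assumptions; the constructor arguments and training of `NeuralProcessModel` are model-specific and omitted here.

```python
import torch

from botorch_community.acquisition.latent_information_gain import LatentInformationGain

# Context data observed so far (shapes are illustrative).
context_x = torch.rand(20, 3)
context_y = torch.rand(20, 1)

# Assumed: a trained NeuralProcessModel; its construction and training
# are omitted because they depend on the model configuration.
np_model = ...

acqf = LatentInformationGain(
    context_x=context_x,
    context_y=context_y,
    model=np_model,
    num_samples=10,
)

# AcquisitionFunction is an nn.Module, so calling the instance dispatches
# to forward(); the result is the averaged KL-based LIG score.
candidates = torch.rand(5, 3)
lig_score = acqf(candidates)
```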