Skip to content

Commit

Permalink
1/25 Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
eibarolle authored Jan 25, 2025
1 parent 657151f commit 4f35e0f
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions botorch_community/acquisition/latent_information_gain.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,51 +23,60 @@
from typing import Optional

import torch
from botorch import settings
from botorch.acquisition import AcquisitionFunction
from botorch_community.models.np_regression import NeuralProcessModel
from torch import Tensor

import torch
#reference: https://arxiv.org/abs/2106.02770

class LatentInformationGain:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LatentInformationGain(AcquisitionFunction):
def __init__(
self,
context_x: torch.Tensor,
context_y: torch.Tensor,
model: NeuralProcessModel,
num_samples: int = 10,
min_std: float = 0.1,
scaler: float = 0.9
min_std: float = 0.01,
scaler: float = 0.5
) -> None:
"""
Latent Information Gain (LIG) Acquisition Function, designed for the
NeuralProcessModel.
NeuralProcessModel. This is a subclass of AcquisitionFunction.
Args:
model: Trained NeuralProcessModel.
context_x: Context input points, as a Tensor.
context_y: Context target points, as a Tensor.
num_samples (int): Number of samples for calculation, defaults to 10.
min_std: Float representing the minimum possible standardized std, defaults to 0.1.
scaler: Float scaling the std, defaults to 0.9.
"""
self.model = model
super().__init__(model=model)
self.model = model.to(device)
self.num_samples = num_samples
self.min_std = min_std
self.scaler = scaler
self.context_x = context_x.to(device)
self.context_y = context_y.to(device)

def acquisition(self, candidate_x, context_x, context_y):
def forward(self, candidate_x):
"""
Conduct the Latent Information Gain acquisition function for the inputs.
Args:
candidate_x: Candidate input points, as a Tensor.
context_x: Context input points, as a Tensor.
context_y: Context target points, as a Tensor.
Returns:
torch.Tensor: The LIG score of computed KLDs.
"""

candidate_x = candidate_x.to(device)

# Encoding and Scaling the context data
z_mu_context, z_logvar_context = self.model.data_to_z_params(context_x, context_y)
z_mu_context, z_logvar_context = self.model.data_to_z_params(self.context_x, self.context_y)
kl = 0.0
for _ in range(self.num_samples):
# Taking reparameterized samples
Expand All @@ -77,8 +86,8 @@ def acquisition(self, candidate_x, context_x, context_y):
y_pred = self.model.decoder(candidate_x, samples)

# Combining context and candidate data
combined_x = torch.cat([context_x, candidate_x], dim=0)
combined_y = torch.cat([context_y, y_pred], dim=0)
combined_x = torch.cat([self.context_x, candidate_x], dim=0).to(device)
combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device)

# Computing posterior variables
z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(combined_x, combined_y)
Expand Down

0 comments on commit 4f35e0f

Please sign in to comment.