Better support for missing labels #2288

Merged
merged 31 commits into from
Sep 7, 2023

Turakar
Contributor

@Turakar Turakar commented Mar 2, 2023

Summary

Fix #1790. Fix #1881.

While GPyTorch has limited support for missing labels, e.g. in GaussianLikelihoodWithMissingObs, it lacks general support. For example, neither single-task nor multitask training or prediction allows NaN target values. NaN values are especially useful for training multitask models with partially missing observations (cf. #1790).

To fix this, this PR adds the following new functionality:

  • Introduces a new setting observation_nan_policy with values ignore, mask and fill. See documentation and implementation for details.
  • This setting is implemented in ExactMarginalLogLikelihood for exact training, in DefaultPredictionStrategy for exact prediction, and in _GaussianLikelihoodBase for variational training.
  • Where possible, lazy evaluation is used. However, this does not work with fill during exact prediction, because I need to zero-out some of the elements in the kernel matrices.
  • MultitaskMultivariateNormal can now be indexed (required for indexing the observed values).
  • MultivariateNormal indexing can deal with a superfluous ... now.
  • GaussianLikelihoodWithMissingObs is now API-equivalent to GaussianLikelihood with observation_nan_policy('fill').
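
As a rough illustration of the mask idea, the sketch below drops NaN targets together with the matching rows and columns of a covariance matrix. It uses plain Python lists instead of tensors, and the helper name mask_missing is an assumption made for this example; it is not the code added by this PR.

```python
import math


def mask_missing(targets, covar):
    """Drop NaN targets and the matching rows/columns of a dense covariance."""
    # Indices of the entries that were actually observed
    keep = [i for i, y in enumerate(targets) if not math.isnan(y)]
    observed = [targets[i] for i in keep]
    # Keep only rows and columns that belong to observed entries
    sub_covar = [[covar[i][j] for j in keep] for i in keep]
    return observed, sub_covar, keep


targets = [0.1, float("nan"), 0.3]
covar = [
    [1.0, 0.5, 0.2],
    [0.5, 1.0, 0.5],
    [0.2, 0.5, 1.0],
]
observed, sub_covar, keep = mask_missing(targets, covar)
print(observed)   # [0.1, 0.3]
print(sub_covar)  # [[1.0, 0.2], [0.2, 1.0]]
```

The marginal log likelihood would then be computed on the reduced targets and covariance only.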

Alternative to missing data support

Alternatively, one may pass the task index as an additional input to the model. However, depending on the choice of kernel, the kernel matrix may become complicated to construct. It is conceptually simpler to construct the kernel matrix for all tasks and samples at once (e.g. this allows for BlockDiagOperator and the like) and then use the NaN values to filter it before calculating the marginal log likelihood.
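
For comparison, the task-index alternative amounts to flattening the target table into one row per observed value. A minimal sketch, assuming plain Python lists and a hypothetical helper to_long_format:

```python
import math


def to_long_format(xs, ys):
    """Flatten an (n, num_tasks) target table into (x, task_index, y) rows,
    skipping entries whose target is NaN."""
    rows = []
    for x, row in zip(xs, ys):
        for task, y in enumerate(row):
            if not math.isnan(y):
                rows.append((x, task, y))
    return rows


nan = float("nan")
xs = [0.0, 0.5, 1.0]
ys = [[1.0, 2.0], [1.5, nan], [nan, 3.0]]
rows = to_long_format(xs, ys)
for row in rows:
    print(row)
```

The model then needs a kernel defined over (x, task_index) pairs, which is the construction the paragraph above describes as potentially complicated.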

Alternative to using a setting for this

The alternatives are subclassing, as GaussianLikelihoodWithMissingObs already does, or passing a keyword argument everywhere. In my opinion, a setting is far more useful, especially considering that DefaultPredictionStrategy is deeply nested and hard to reach otherwise.

Open points

  • Performance questions: Which NaN observation policy is faster in which cases? Both mask and fill are justified either way, because they differ in behavior and not only in performance.
  • How should we proceed with GaussianLikelihoodWithMissingObs()? At the moment, I just removed it, but we probably want some sort of deprecation.

Examples
The following snippets demonstrate the abilities of the proposed changes:

Single Task
# https://github.com/cornellius-gp/gpytorch/issues/1881

import math
import torch
import gpytorch
import matplotlib.pyplot as plt


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


def main():
    # train_x = torch.linspace(0, 1, 101)
    # train_y = torch.sin(train_x * (2 * math.pi) * 2) + torch.randn(train_x.size()) * math.sqrt(0.04)

    train_x = torch.linspace(0, 1, 41)
    train_y = torch.sin(2 * torch.pi * train_x).squeeze()
    train_y += torch.normal(0, 0.01, train_y.shape)

    # nan out a few train_y
    train_y[::4] = torch.nan

    # initialize likelihood and model
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = ExactGPModel(train_x, train_y, likelihood)

    training_iter = 50

    model.train()
    likelihood.train()

    # Use the adam optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

    # "Loss" for GPs - the marginal log likelihood
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    for i in range(training_iter):
        # Zero gradients from previous iteration
        optimizer.zero_grad()
        # Output from model
        output = model(train_x)
        # Calc loss and backprop gradients
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
            i + 1, training_iter, loss.item(),
            model.covar_module.base_kernel.lengthscale.item(),
            model.likelihood.noise.item()
        ))
        optimizer.step()

    # Get into evaluation (predictive posterior) mode
    model.eval()
    likelihood.eval()

    # Test points are regularly spaced along [0,1]
    # Make predictions by feeding model through likelihood
    with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.max_eager_kernel_size(0):
        test_x = torch.linspace(0, 1, 51)
        observed_pred = likelihood(model(test_x))

        # Initialize plot
        f, ax = plt.subplots(1, 1, figsize=(4, 3))

        # Get upper and lower confidence bounds
        lower, upper = observed_pred.confidence_region()
        # Plot training data as black stars
        ax.plot(train_x.numpy(), train_y.numpy(), 'k*')
        # Plot predictive means as blue line
        ax.plot(test_x.numpy(), observed_pred.mean.numpy(), 'b')
        # Shade between the lower and upper confidence bounds
        ax.fill_between(test_x.numpy(), lower.numpy(), upper.numpy(), alpha=0.5)
        ax.set_ylim([-3, 3])
        ax.legend(['Observed Data', 'Mean', 'Confidence'])

        plt.show()


if __name__ == '__main__':
    with gpytorch.settings.observation_nan_policy("mask"):
        main()
Multitask
# https://github.com/cornellius-gp/gpytorch/issues/1881

import math
import torch
import gpytorch
import matplotlib.pyplot as plt


class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=2
        )
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            gpytorch.kernels.RBFKernel(), num_tasks=2, rank=1
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)


def main():
    train_x = torch.linspace(0, 1, 100)

    train_y = torch.stack([
        torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
        torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    ], -1)

    # nan out a few train_y
    train_y[-30:, 1] = float('nan')

    # initialize likelihood and model
    likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2)
    model = MultitaskGPModel(train_x, train_y, likelihood)

    training_iter = 50

    model.train()
    likelihood.train()

    # Use the adam optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

    # "Loss" for GPs - the marginal log likelihood
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    for i in range(training_iter):
        # Zero gradients from previous iteration
        optimizer.zero_grad()
        # Output from model
        output = model(train_x)
        # Calc loss and backprop gradients
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f' % (
            i + 1, training_iter, loss.item()
        ))
        optimizer.step()

    # Get into evaluation (predictive posterior) mode
    model.eval()
    likelihood.eval()

    # Initialize plots
    f, (y1_ax, y2_ax) = plt.subplots(1, 2, figsize=(8, 3))

    # Make predictions
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        test_x = torch.linspace(0, 1, 51)
        predictions = likelihood(model(test_x))
        mean = predictions.mean
        lower, upper = predictions.confidence_region()

    # `mean`, `lower` and `upper` have shape (num_test_points, num_tasks):
    # column 0 holds the first task, column 1 the second task

    # Plot training data as black stars
    y1_ax.plot(train_x.detach().numpy(), train_y[:, 0].detach().numpy(), 'k*')
    # Predictive mean as blue line
    y1_ax.plot(test_x.numpy(), mean[:, 0].numpy(), 'b')
    # Shade in confidence
    y1_ax.fill_between(test_x.numpy(), lower[:, 0].numpy(), upper[:, 0].numpy(), alpha=0.5)
    y1_ax.set_ylim([-3, 3])
    y1_ax.legend(['Observed Data', 'Mean', 'Confidence'])
    y1_ax.set_title('Observed Values (Likelihood)')

    # Plot training data as black stars
    y2_ax.plot(train_x.detach().numpy(), train_y[:, 1].detach().numpy(), 'k*')
    # Predictive mean as blue line
    y2_ax.plot(test_x.numpy(), mean[:, 1].numpy(), 'b')
    # Shade in confidence
    y2_ax.fill_between(test_x.numpy(), lower[:, 1].numpy(), upper[:, 1].numpy(), alpha=0.5)
    y2_ax.set_ylim([-3, 3])
    y2_ax.legend(['Observed Data', 'Mean', 'Confidence'])
    y2_ax.set_title('Observed Values (Likelihood)')

    plt.show()

    print(predictions[-1].covariance_matrix)
    print(predictions[:, -1].covariance_matrix.shape)


if __name__ == '__main__':
    with gpytorch.settings.observation_nan_policy("mask"):
        main()
Variational Multitask
# https://github.com/cornellius-gp/gpytorch/issues/1881

import math
import torch
from tqdm.auto import tqdm

import gpytorch
import matplotlib.pyplot as plt

from gpytorch.likelihoods import MultitaskGaussianLikelihood


class MultitaskGPModel(gpytorch.models.VariationalGP):
    def __init__(self, num_latents, num_tasks):
        # Let's use a different set of inducing points for each latent function
        inducing_points = torch.rand(num_latents, 16, 1)

        # We have to mark the CholeskyVariationalDistribution as batch
        # so that we learn a variational distribution for each task
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
        )

        # We have to wrap the VariationalStrategy in a LMCVariationalStrategy
        # so that the output will be a MultitaskMultivariateNormal rather than a batch output
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1
        )
        super().__init__(variational_strategy)
        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
            batch_shape=torch.Size([num_latents])
        )

    def forward(self, x):
        # The forward function should be written as if we were dealing with each output
        # dimension in batch
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


def main():
    num_latents = 3

    train_x = torch.linspace(0, 1, 100)

    train_y = torch.stack([
        torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
        torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
        torch.sin(train_x * (2 * math.pi)) + 2 * torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
        -torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    ], -1)
    num_tasks = train_y.shape[-1]

    # nan out a few train_y
    train_y[-30:, 1] = float('nan')

    # initialize likelihood and model
    likelihood = MultitaskGaussianLikelihood(num_tasks=num_tasks)
    model = MultitaskGPModel(num_latents, num_tasks)

    model.train()
    likelihood.train()

    optimizer = torch.optim.Adam([
        {'params': model.parameters()},
        {'params': likelihood.parameters()},
    ], lr=0.1)

    # Our loss object. VariationalELBO (commented out) would compute the ELBO;
    # here we use the PredictiveLogLikelihood objective instead
    # mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))
    mll = gpytorch.mlls.PredictiveLogLikelihood(likelihood, model, num_data=train_y.size(0))

    epochs_iter = tqdm(range(50), desc="Epoch")
    for _ in epochs_iter:
        # Each iteration is one full-batch step over all training data
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        epochs_iter.set_postfix(loss=loss.item())
        loss.backward()
        optimizer.step()

    # Get into evaluation (predictive posterior) mode
    model.eval()
    likelihood.eval()

    # Initialize plots
    fig, axs = plt.subplots(1, num_tasks, figsize=(4 * num_tasks, 3))

    # Make predictions
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        test_x = torch.linspace(0, 1, 51)
        predictions = likelihood(model(test_x))
        mean = predictions.mean
        lower, upper = predictions.confidence_region()

    for task, ax in enumerate(axs):
        # Plot training data as black stars
        ax.plot(train_x.detach().numpy(), train_y[:, task].detach().numpy(), 'k*')
        # Predictive mean as blue line
        ax.plot(test_x.numpy(), mean[:, task].numpy(), 'b')
        # Shade in confidence
        ax.fill_between(test_x.numpy(), lower[:, task].numpy(), upper[:, task].numpy(), alpha=0.5)
        ax.set_ylim([-3, 3])
        ax.legend(['Observed Data', 'Mean', 'Confidence'])
        ax.set_title(f'Task {task + 1}')

    fig.tight_layout()

    plt.show()


if __name__ == '__main__':
    with gpytorch.settings.observation_nan_policy("fill"):
        main()

@jacobrgardner
Member

jacobrgardner commented Mar 5, 2023

Maybe we can just include the NaN handling code inside the existing ExactMarginalLogLikelihood? This comes with a - probably very small - performance reduction but makes things 'just work'.

If we wanted, we could probably incorporate the missing obs into the existing ExactMarginalLogLikelihood via a treat_nans_as_missing flag that defaults to False (i.e., the assumption is there is no missing data) and gates the isnan check. I think having users intentionally specify that nans mean missing is a good thing, because if they aren't intending to have missing labels, nans are probably more commonly some kind of data cleaning issue.

Does the lazy evaluation in GPyTorch save us from most of the performance cost associated with the NaN-values and can we improve on this aspect?

Probably yes, it should already be saving you. All we'd need to check is that the output of the likelihood call has a LazyEvaluatedKernelTensor covariance matrix before and after you index it. In the most typical GPyTorch training settings, you'd expect code like:

output_dist = model(train_x)
mll = exact_mll(output_dist, train_y)

And you'd expect output_dist to have a LazyEvaluatedKernelTensor covariance. Indexing that operator should index out data points with missing values before the forward method of the kernel is even actually called, so those points won't count towards the space and time of dealing with K_XX.
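
The effect described above can be sketched with a toy lazy matrix: selecting rows and columns only subsets the stored inputs, so the kernel function is never called for the dropped points. The class LazyKernelMatrix below is a hypothetical stand-in written for this illustration, not GPyTorch's LazyEvaluatedKernelTensor.

```python
import math


class LazyKernelMatrix:
    """Toy lazily evaluated kernel matrix (hypothetical, for illustration)."""

    def __init__(self, xs, kernel):
        self.xs = list(xs)
        self.kernel = kernel

    def select(self, indices):
        # Indexing only subsets the inputs; no kernel value is computed here
        return LazyKernelMatrix([self.xs[i] for i in indices], self.kernel)

    def evaluate(self):
        # The kernel is called only when the dense matrix is requested
        return [[self.kernel(a, b) for b in self.xs] for a in self.xs]


counter = {"calls": 0}


def rbf(a, b):
    counter["calls"] += 1
    return math.exp(-0.5 * (a - b) ** 2)


xs = [0.0, 0.25, 0.5, 0.75, 1.0]
ys = [0.0, float("nan"), 1.0, float("nan"), 0.0]
observed = [i for i, y in enumerate(ys) if not math.isnan(y)]

dense = LazyKernelMatrix(xs, rbf).select(observed).evaluate()
print(len(dense), counter["calls"])  # 3 9
```

With five inputs and two missing labels, only the 3 x 3 block is evaluated (9 kernel calls instead of 25).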

@Turakar
Contributor Author

Turakar commented Mar 6, 2023

Does anyone have an idea what the problem with the docstring is?

@Turakar
Contributor Author

Turakar commented Mar 8, 2023

Maybe we can just include the NaN handling code inside the existing ExactMarginalLogLikelihood? This comes with a - probably very small - performance reduction but makes things 'just work'.

If we wanted, we could probably incorporate the missing obs into the existing ExactMarginalLogLikelihood via a treat_nans_as_missing flag that defaults to False (i.e., the assumption is there is no missing data) and gates the isnan check. I think having users intentionally specify that nans mean missing is a good thing, because if they aren't intending to have missing labels, nans are probably more commonly some kind of data cleaning issue.

Ok, I like the idea of adding it as an option flag.

Does the lazy evaluation in GPyTorch save us from most of the performance cost associated with the NaN-values and can we improve on this aspect?

Probably yes, it should already be saving you. All we'd need to check is that the output of the likelihood call has a LazyEvaluatedKernelTensor covariance matrix before and after you index it. In the most typical GPyTorch training settings, you'd expect code like:

output_dist = model(train_x)
mll = exact_mll(output_dist, train_y)

And you'd expect output_dist to have a LazyEvaluatedKernelTensor covariance. Indexing that operator should index out data points with missing values before the forward method of the kernel is even actually called, so those points won't count towards the space and time of dealing with K_XX.

This is a LazyEvaluatedKernelTensor, so I think it's fine.

@Turakar Turakar changed the title from "[WIP] Better support for missing labels" to "Better support for missing labels" Mar 8, 2023
@Turakar Turakar marked this pull request as draft March 8, 2023 15:36
Turakar added 3 commits March 10, 2023 17:22
- Enable via gpytorch.settings
- Two modes: 'mask' and 'fill'
- Makes GaussianLikelihoodWithMissingObs obsolete
- Supports approximate GPs
@Turakar
Contributor Author

Turakar commented Mar 10, 2023

I reworked large parts of this PR. It should be ready for review now.

@Turakar Turakar marked this pull request as ready for review March 10, 2023 16:56
Member

@gpleiss gpleiss left a comment

Thanks @Turakar for the awesome (and very thorough) PR! I'm excited to get this merged in.

See below for some comments. Mostly nit-picking, but one or two questions about performance.

docs/source/likelihoods.rst (resolved)
gpytorch/settings.py (resolved)
gpytorch/likelihoods/__init__.py (outdated, resolved)
test/examples/test_missing_data.py (resolved)
@Turakar
Contributor Author

Turakar commented Apr 21, 2023

I created some initial benchmarks.

This is the benchmark code:

Benchmark code
import copy
import gc
import math
import time

import gpytorch.settings
import torch
from gpytorch import ExactMarginalLogLikelihood

from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import ScaleKernel, RBFKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP
from plotly.subplots import make_subplots
import plotly.graph_objs as go
from torch import Tensor
from tqdm import tqdm

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


class Model(ExactGP):
    def __init__(self, train_x: Tensor, train_y: Tensor):
        super().__init__(train_x, train_y, GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(RBFKernel())

    def forward(self, x: Tensor) -> MultivariateNormal:
        return MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )


def make_dataset(train_num: int, train_missing: float, val_num: int, val_missing: float) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    train_x = torch.linspace(0, 1, train_num, device=device)
    train_y = torch.sin(2 * torch.pi * train_x)
    val_x = torch.linspace(0, 1, val_num, device=device)
    val_y = torch.sin(2 * torch.pi * val_x)

    # Randomly mask out some data
    if train_missing > 0:
        train_mask = torch.bernoulli(torch.full_like(train_y, train_missing)).to(torch.bool)
        train_y[train_mask] = torch.nan
    if val_missing > 0:
        val_mask = torch.bernoulli(torch.full_like(val_y, val_missing)).to(torch.bool)
        val_y[val_mask] = torch.nan

    train_x = train_x.unsqueeze(-1)
    val_x = val_x.unsqueeze(-1)

    return train_x, train_y, val_x, val_y


def prepare_model(train_x: Tensor, train_y: Tensor, steps: int) -> Model:
    model = Model(train_x, train_y).to(device)
    model.train()

    # Pre-train model s.t. we have realistic values and convergence times
    mll = ExactMarginalLogLikelihood(model.likelihood, model).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for _ in range(steps):
        optimizer.zero_grad(set_to_none=True)
        loss = -mll(model(*model.train_inputs), model.train_targets)
        loss.backward()
        optimizer.step()

    model.cpu()

    return model


def measure(model: Model, train_x: Tensor, train_y: Tensor, val_x: Tensor) -> tuple[float, float]:
    # Create a copy of the model with the new training data
    model = copy.deepcopy(model)
    model.to(device)
    model.set_train_data(train_x, train_y, strict=False)

    # Simulate training step
    mll = ExactMarginalLogLikelihood(model.likelihood, model).to(device)
    prior_time = time.time()
    loss = mll(model(*model.train_inputs), model.train_targets)
    loss.backward()
    train_time = time.time() - prior_time

    # Simulate prediction
    with torch.no_grad():
        prior_time = time.time()
        model.eval()
        prediction: MultivariateNormal = model.likelihood(model(val_x))
        mean = prediction.mean
        covar = prediction.covariance_matrix
        val_time = time.time() - prior_time

    return train_time, val_time


def measure_multiple(models: list[Model], train_x: Tensor, train_y: Tensor, val_x: Tensor) -> tuple[float, float, float, float]:
    train_times = []
    val_times = []
    for model in tqdm(models, desc="Collecting", leave=False):
        train_time, val_time = measure(model, train_x, train_y, val_x)
        train_times.append(train_time)
        val_times.append(val_time)
        gc.collect()
    train_times = torch.tensor(train_times)
    train_mean = torch.mean(train_times).item()
    train_sem = torch.std(train_times).item() / math.sqrt(len(models))
    val_times = torch.tensor(val_times)
    val_mean = torch.mean(val_times).item()
    val_sem = torch.std(val_times).item() / math.sqrt(len(models))
    return train_mean, train_sem, val_mean, val_sem


def main():
    with gpytorch.settings.max_cholesky_size(0):
        n = 8000
        iterations = 50
        train_steps = 50
        nan_fractions = [x.item() for x in torch.linspace(0, 0.5, 12)]

        # Prepare some models
        # We will use the same models for each NaN fraction and just change their training datasets.
        sample_x, sample_y, _, _ = make_dataset(n, 0, n // 10, 0)
        models = []
        for _ in tqdm(range(iterations), desc="Preparing models"):
            models.append(prepare_model(sample_x, sample_y, train_steps))

        # Collect measurements
        measurements = []
        for nan_fraction in tqdm(nan_fractions, desc="NaN fractions"):
            with gpytorch.settings.observation_nan_policy("mask" if nan_fraction > 0 else "ignore"):
                nan_n = int(n / (1 - nan_fraction))  # Scale n s.t. we have an equal amount of observed data
                train_x, train_y, val_x, val_y = make_dataset(nan_n, nan_fraction, nan_n // 10, nan_fraction)
                measurements.append(list(measure_multiple(models, train_x, train_y, val_x)))
        measurements = torch.tensor(measurements)

        # Create a plot showing the mean and std. error of the mean for both training and prediction
        fig = make_subplots(rows=1, cols=2, column_titles=["Training step", "Prediction step"])
        fig.add_trace(go.Scatter(
            x=nan_fractions,
            y=measurements[:, 0],
            error_y=dict(
                type="data",
                array=measurements[:, 1],
            ),
            mode="lines",
        ), row=1, col=1)
        fig.add_trace(go.Scatter(
            x=nan_fractions,
            y=measurements[:, 2],
            error_y=dict(
                type="data",
                array=measurements[:, 3],
            ),
            mode="lines",
        ), row=1, col=2)
        fig.update_layout(
            title="NaN masking performance for simple RBF model",
            showlegend=False,
        )
        fig.update_xaxes(title="NaN fraction")
        fig.update_yaxes(title="Time per step (s)")
        fig.show(renderer="browser")
        fig.write_html("missing_data_performance.html")
        fig.write_image("missing_data_performance.svg")


if __name__ == '__main__':
    main()

And this is the result:

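[Figure: missing_data_performance. Time per step (s) against NaN fraction, in two panels: "Training step" and "Prediction step".]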

It would also be interesting to do this for a multitask model that uses Kronecker structure, as this is the more likely use case. Preliminarily, I think it is safe to say that masking adds a certain overhead during training that is independent of the masked fraction, while prediction time increases steadily with the number of masked points.
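
When reading these numbers, note that the benchmark grows the dataset with the NaN fraction so that the expected number of observed points stays close to n. A small worked example of that scaling, where the helper name scaled_size is an assumption for this sketch and mirrors the nan_n expression in the benchmark code:

```python
def scaled_size(n_observed, nan_fraction):
    """Total number of points such that about n_observed remain unmasked."""
    return int(n_observed / (1 - nan_fraction))


n = 8000
for nan_fraction in (0.0, 0.25, 0.5):
    total = scaled_size(n, nan_fraction)
    # Masking is random (Bernoulli), so the observed count is an expectation
    expected_observed = total * (1 - nan_fraction)
    print(nan_fraction, total, expected_observed)
```

At a NaN fraction of 0.5 the model therefore sees 16000 training inputs, of which roughly 8000 carry a label; the number of validation inputs (one tenth of the total) grows in the same way.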

@Turakar
Contributor Author

Turakar commented Apr 21, 2023

And here are the benchmarks for the Kronecker case. It seems like this operator cannot make use of indexing to accelerate training or inference.

Benchmark code
import copy
import gc
import math
import time

import gpytorch.settings
import torch
from gpytorch import ExactMarginalLogLikelihood

from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal
from gpytorch.kernels import ScaleKernel, RBFKernel, LCMKernel
from gpytorch.likelihoods import MultitaskGaussianLikelihood
from gpytorch.means import ConstantMean, MultitaskMean
from gpytorch.models import ExactGP
from plotly.subplots import make_subplots
import plotly.graph_objs as go
from torch import Tensor
from tqdm import tqdm

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


class Model(ExactGP):
    def __init__(self, train_x: Tensor, train_y: Tensor, num_latents: int):
        num_tasks = train_y.shape[-1]
        super().__init__(train_x, train_y, MultitaskGaussianLikelihood(num_tasks))
        self.mean_module = MultitaskMean(ConstantMean(), num_tasks)
        self.covar_module = LCMKernel(
            [ScaleKernel(RBFKernel()) for _ in range(num_latents)],
            num_tasks,
        )

    def forward(self, x: Tensor) -> MultitaskMultivariateNormal:
        return MultitaskMultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )


def make_dataset(train_num: int, train_missing: float, val_num: int, val_missing: float) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    def target_function(x: Tensor) -> Tensor:
        return torch.stack([
            torch.sin(2 * torch.pi * x),
            torch.sin(2 * torch.pi * x) * 0.25,
            torch.sin(3 * torch.pi * x) + torch.sin(2 * torch.pi * x),
        ], dim=1)


    train_x = torch.linspace(0, 1, train_num, device=device)
    train_y = target_function(train_x)
    val_x = torch.linspace(0, 1, val_num, device=device)
    val_y = target_function(val_x)

    # Randomly mask out some data
    if train_missing > 0:
        train_mask = torch.bernoulli(torch.full_like(train_y, train_missing)).to(torch.bool)
        train_y[train_mask] = torch.nan
    if val_missing > 0:
        val_mask = torch.bernoulli(torch.full_like(val_y, val_missing)).to(torch.bool)
        val_y[val_mask] = torch.nan

    train_x = train_x.unsqueeze(-1)
    val_x = val_x.unsqueeze(-1)

    return train_x, train_y, val_x, val_y


def prepare_model(train_x: Tensor, train_y: Tensor, num_latents: int, steps: int) -> Model:
    model = Model(train_x, train_y, num_latents).to(device)
    model.train()

    # Pre-train model s.t. we have realistic values and convergence times
    mll = ExactMarginalLogLikelihood(model.likelihood, model).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for _ in range(steps):
        optimizer.zero_grad(set_to_none=True)
        loss = -mll(model(*model.train_inputs), model.train_targets)
        loss.backward()
        optimizer.step()

    model.cpu()

    return model


def measure(model: Model, train_x: Tensor, train_y: Tensor, val_x: Tensor) -> tuple[float, float]:
    # Create a copy of the model with the new training data
    model = copy.deepcopy(model)
    model.to(device)
    model.set_train_data(train_x, train_y, strict=False)

    # Simulate training step
    mll = ExactMarginalLogLikelihood(model.likelihood, model).to(device)
    prior_time = time.time()
    loss = mll(model(*model.train_inputs), model.train_targets)
    loss.backward()
    train_time = time.time() - prior_time

    # Simulate prediction
    with torch.no_grad():
        prior_time = time.time()
        model.eval()
        prediction: MultitaskMultivariateNormal = model.likelihood(model(val_x))
        mean = prediction.mean
        covar = prediction.covariance_matrix
        val_time = time.time() - prior_time

    return train_time, val_time


def measure_multiple(models: list[Model], train_x: Tensor, train_y: Tensor, val_x: Tensor) -> tuple[float, float, float, float]:
    train_times = []
    val_times = []
    for model in tqdm(models, desc="Collecting", leave=False):
        train_time, val_time = measure(model, train_x, train_y, val_x)
        train_times.append(train_time)
        val_times.append(val_time)
        gc.collect()
    train_times = torch.tensor(train_times)
    train_mean = torch.mean(train_times).item()
    train_sem = torch.std(train_times).item() / math.sqrt(len(models))
    val_times = torch.tensor(val_times)
    val_mean = torch.mean(val_times).item()
    val_sem = torch.std(val_times).item() / math.sqrt(len(models))
    return train_mean, train_sem, val_mean, val_sem


def main():
    with gpytorch.settings.max_cholesky_size(0):
        n = 3000
        iterations = 50
        train_steps = 50
        num_latents = 2
        nan_fractions = [x.item() for x in torch.linspace(0, 0.5, 12)]

        # Prepare some models
        # We will use the same models for each NaN fraction and just change their training datasets.
        sample_x, sample_y, _, _ = make_dataset(n, 0, n // 10, 0)
        models = []
        for _ in tqdm(range(iterations), desc="Preparing models"):
            models.append(prepare_model(sample_x, sample_y, num_latents, train_steps))

        # Collect measurements
        measurements = []
        for nan_fraction in tqdm(nan_fractions, desc="NaN fractions"):
            with gpytorch.settings.observation_nan_policy("mask" if nan_fraction > 0 else "ignore"):
                nan_n = int(n / (1 - nan_fraction))  # Scale n s.t. we have an equal amount of observed data
                train_x, train_y, val_x, val_y = make_dataset(nan_n, nan_fraction, nan_n // 10, nan_fraction)
                measurements.append(list(measure_multiple(models, train_x, train_y, val_x)))
        measurements = torch.tensor(measurements)

        # Create a plot showing the mean and std. error of the mean for both training and prediction
        fig = make_subplots(rows=1, cols=2, column_titles=["Training step", "Prediction step"])
        fig.add_trace(go.Scatter(
            x=nan_fractions,
            y=measurements[:, 0],
            error_y=dict(
                type="data",
                array=measurements[:, 1],
            ),
            mode="lines",
        ), row=1, col=1)
        fig.add_trace(go.Scatter(
            x=nan_fractions,
            y=measurements[:, 2],
            error_y=dict(
                type="data",
                array=measurements[:, 3],
            ),
            mode="lines",
        ), row=1, col=2)
        fig.update_layout(
            title="NaN masking performance for simple RBF model",
            showlegend=False,
        )
        fig.update_xaxes(title="NaN fraction")
        fig.update_yaxes(title="Time per step (s)")
        fig.show(renderer="browser")
        fig.write_html("missing_data_performance.html")
        fig.write_image("missing_data_performance.svg")


if __name__ == '__main__':
    main()
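
The scaling of `n` inside the NaN-fraction loop can be checked in isolation. The sketch below assumes each label is independently set to NaN with probability `nan_fraction`; how `make_dataset` actually places the NaN values is defined earlier in the script and is not reproduced here. `scaled_size` is a hypothetical helper name used only for illustration.

```python
def scaled_size(n: int, nan_fraction: float) -> int:
    """Total number of labels needed so that roughly n of them stay observed."""
    return int(n / (1 - nan_fraction))


n = 3000
for nan_fraction in [0.0, 0.25, 0.5]:
    nan_n = scaled_size(n, nan_fraction)
    # Expected number of non-NaN labels under the i.i.d. assumption above.
    expected_observed = nan_n * (1 - nan_fraction)
    print(f"nan_fraction={nan_fraction}: nan_n={nan_n}, expected observed={expected_observed:.0f}")
```

Under that assumption the observed part of the data stays at about `n` points for every NaN fraction, so differences in the timings should mostly reflect the masking overhead rather than a change in the effective dataset size.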

And the result:

[Figure: missing_data_performance (time per training step and per prediction step vs. NaN fraction)]

@Turakar
Contributor Author

Turakar commented Apr 21, 2023

I did another test: Instead of indexing the linear operator in the kronecker case, which seems to bring no improvement, I instead attached a new MaskedLinearOperator on the outside. This seems to lower the overhead during training, but does not have an effect on prediction.

Masked Linear Operator
from typing import Optional, Union

import torch
from linear_operator import LinearOperator
from torch import Tensor


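class MaskedLinearOperator(LinearOperator):
    """Restrict a base operator to the rows and columns selected by boolean masks.

    The `is not None` checks below treat a mask of None as "keep everything".
    """

    def __init__(
        self, base: LinearOperator, row_mask: Optional[Tensor], col_mask: Optional[Tensor]
    ):
        super().__init__(base, row_mask, col_mask)
        self.base = base
        self.row_mask = row_mask
        self.col_mask = col_mask
        # Only a symmetric mask leaves a well-defined diagonal (see _diagonal).
        self.row_eq_col_mask = (
            row_mask is not None and col_mask is not None and torch.equal(row_mask, col_mask)
        )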

    def _matmul(self, rhs: Tensor) -> Tensor:
        if self.col_mask is not None:
            rhs_expanded = torch.zeros(
                *rhs.shape[:-2],
                self.base.size(-1),
                rhs.shape[-1],
                device=rhs.device,
                dtype=rhs.dtype,
            )
            rhs_expanded[..., self.col_mask, :] = rhs
            rhs = rhs_expanded

        res = self.base.matmul(rhs)

        if self.row_mask is not None:
            res = res[..., self.row_mask, :]

        return res

    def _size(self) -> torch.Size:
        base_size = list(self.base.size())
        if self.row_mask is not None:
            base_size[-2] = torch.count_nonzero(self.row_mask)
        if self.col_mask is not None:
            base_size[-1] = torch.count_nonzero(self.col_mask)
        return torch.Size(tuple(base_size))

    def _transpose_nonbatch(self) -> LinearOperator:
        return MaskedLinearOperator(self.base.mT, self.col_mask, self.row_mask)

    def _getitem(
        self,
        row_index: Union[slice, torch.LongTensor],
        col_index: Union[slice, torch.LongTensor],
        *batch_indices: tuple[Union[int, slice, torch.LongTensor], ...],
    ) -> LinearOperator:
        raise NotImplementedError(
            "Indexing with %r, %r, %r not supported." % (batch_indices, row_index, col_index)
        )

    def _get_indices(
        self,
        row_index: torch.LongTensor,
        col_index: torch.LongTensor,
        *batch_indices: tuple[torch.LongTensor, ...],
    ) -> torch.Tensor:
        def map_indices(index: torch.LongTensor, mask: Optional[Tensor], base_size: int) -> torch.LongTensor:
            if mask is None:
                return index
            # Position k in the masked operator maps to the k-th True entry of the mask.
            index_map = torch.arange(base_size, device=self.base.device)[mask]
            return index_map[index]

        if len(batch_indices) == 0:
            row_index = map_indices(row_index, self.row_mask, self.base.size(-2))
            col_index = map_indices(col_index, self.col_mask, self.base.size(-1))
            return self.base._get_indices(row_index, col_index)

        raise NotImplementedError(
            "Indexing with %r, %r, %r not supported." % (batch_indices, row_index, col_index)
        )

    def _diagonal(self) -> Tensor:
        if not self.row_eq_col_mask:
            raise NotImplementedError()
        diag = self.base.diagonal()
        return diag[self.row_mask]

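    def to_dense(self) -> torch.Tensor:
        full_dense = self.base.to_dense()
        # Indexing with None would insert a new axis, so only apply masks that are set.
        if self.row_mask is not None:
            full_dense = full_dense[..., self.row_mask, :]
        if self.col_mask is not None:
            full_dense = full_dense[..., :, self.col_mask]
        return full_dense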

    def _cholesky_solve(self, rhs, upper: bool = False) -> LinearOperator:
        raise NotImplementedError()

    def _expand_batch(self, batch_shape: torch.Size) -> LinearOperator:
        raise NotImplementedError()

    def _isclose(
        self, other, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False
    ) -> Tensor:
        raise NotImplementedError()

    def _prod_batch(self, dim: int) -> LinearOperator:
        raise NotImplementedError()

    def _sum_batch(self, dim: int) -> LinearOperator:
        raise NotImplementedError()
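
The masking trick in `_matmul` and the index mapping in `_get_indices` can be checked on plain dense tensors. This is a minimal sketch, not the operator itself: `masked_matmul` is a hypothetical stand-alone function, the base matrix and masks are made up for illustration, and it assumes boolean masks as in the class above.

```python
import torch

torch.manual_seed(0)

# A small symmetric matrix standing in for a kernel matrix (illustrative data).
base = torch.randn(5, 5, dtype=torch.float64)
base = base @ base.T
row_mask = torch.tensor([True, False, True, True, False])
col_mask = torch.tensor([True, True, False, True, False])


def masked_matmul(base, row_mask, col_mask, rhs):
    """Zero-pad rhs to the base width, multiply, then keep only the masked rows."""
    rhs_expanded = torch.zeros(
        *rhs.shape[:-2], base.size(-1), rhs.shape[-1], dtype=rhs.dtype, device=rhs.device
    )
    rhs_expanded[..., col_mask, :] = rhs
    res = base @ rhs_expanded
    return res[..., row_mask, :]


rhs = torch.randn(int(col_mask.sum()), 2, dtype=torch.float64)
via_padding = masked_matmul(base, row_mask, col_mask, rhs)
# Reference: explicitly index the dense matrix, then multiply.
via_indexing = base[row_mask][:, col_mask] @ rhs
matmul_matches = torch.allclose(via_padding, via_indexing)

# Index mapping: masked position k corresponds to the k-th True entry of the mask.
index_map = torch.arange(base.size(-2))[row_mask]
mapped = index_map[torch.tensor([0, 2])]
print(matmul_matches, mapped.tolist())
```

Padding the right-hand side with zeros means the base operator is never indexed, so any structure it has (for example a Kronecker product) is still available to its own `matmul`.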

[Figure: missing_data_performance]

@gpleiss
Member

gpleiss commented May 26, 2023

@Turakar would you be able to add a PR for MaskedLinearOperator to the linear_operator repo, and then we can merge this PR?

Member

@gpleiss gpleiss left a comment

Sorry for the delay @Turakar - I'm back online now!
Just one small change: moving MaskedLinearOperator and adding a short unit test for it. Then I'll merge!

gpytorch/utils/masked_linear_operator.py (review thread: outdated, resolved)
@Turakar
Contributor Author

Turakar commented Aug 8, 2023

Hm, I think the failing unit tests might be caused by an incompatibility with linear_operator 0.5.1.

The test fails because LazyEvaluatedKernelTensor only supports _matmul() with checkpointing, but checkpointing is deprecated.
@Turakar
Contributor Author

Turakar commented Aug 8, 2023

Considering #2342, I decided to just disable the failing test. It is specific to the missing support for _matmul() in LazyEvaluatedKernelTensor in the non-checkpointing case (and checkpointing is deprecated).

@Turakar Turakar requested a review from gpleiss August 8, 2023 12:55
gpleiss added a commit that referenced this pull request Aug 25, 2023
RTD is removing the "use system packages" feature on 29 Aug 2023.
This PR ensures that our docs will still build.

Moreover, the linear_operator requirement needs to be updated for #2288.
@Turakar
Contributor Author

Turakar commented Sep 5, 2023

I fixed the merge conflicts.

@gpleiss gpleiss merged commit 981edd8 into cornellius-gp:master Sep 7, 2023
@gpleiss
Member

gpleiss commented Sep 7, 2023

Finally merged! Thanks for the patience @Turakar !

@Turakar Turakar deleted the missing-data branch September 7, 2023 15:04
@Turakar
Contributor Author

Turakar commented Sep 7, 2023

I am happy it's merged 🙂
