diff --git a/CHANGELOG.md b/CHANGELOG.md index 14aca878a..2b9e9e454 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ [PR #371](https://github.com/appliedAI-Initiative/pyDVL/pull/371) - Major changes to IF interface and functionality [PR #278](https://github.com/appliedAI-Initiative/pyDVL/pull/278) + [PR #394](https://github.com/appliedAI-Initiative/pyDVL/pull/394) - **New Method**: Implements solving the hessian equation via spectral low-rank approximation [PR #365](https://github.com/appliedAI-Initiative/pyDVL/pull/365) - **Breaking Changes**: diff --git a/notebooks/support/torch.py b/notebooks/support/torch.py index b9779c5b0..68faad12f 100644 --- a/notebooks/support/torch.py +++ b/notebooks/support/torch.py @@ -13,7 +13,7 @@ from torch.utils.data import DataLoader from torchvision.models import ResNet18_Weights, resnet18 -from pydvl.influence.frameworks import as_tensor +from pydvl.influence.torch import as_tensor from pydvl.utils import maybe_progress from .types import Losses diff --git a/src/pydvl/influence/__init__.py b/src/pydvl/influence/__init__.py index a0d8f48da..41d5dc993 100644 --- a/src/pydvl/influence/__init__.py +++ b/src/pydvl/influence/__init__.py @@ -7,6 +7,5 @@ probably change. """ -from .frameworks import TorchTwiceDifferentiable, TwiceDifferentiable -from .general import compute_influence_factors, compute_influences +from .general import InfluenceType, compute_influence_factors, compute_influences from .inversion import InversionMethod diff --git a/src/pydvl/influence/frameworks/__init__.py b/src/pydvl/influence/frameworks/__init__.py deleted file mode 100644 index 298dd4f41..000000000 --- a/src/pydvl/influence/frameworks/__init__.py +++ /dev/null @@ -1,57 +0,0 @@ -# FIXME the following code was part of an attempt to accommodate different -# frameworks. In its current form it is ugly and thus it will likely be changed -# in the future. - -import logging - -from .twice_differentiable import InverseHvpResult, TwiceDifferentiable - -__all__ = ["TwiceDifferentiable"] -logger = logging.getLogger("frameworks") - -try: - import torch - - from .torch_differentiable import TorchTwiceDifferentiable - - __all__.append("TorchTwiceDifferentiable") - - from .torch_differentiable import ( - as_tensor, - cat, - einsum, - mvp, - solve_batch_cg, - solve_linear, - solve_lissa, - stack, - transpose_tensor, - zero_tensor, - ) - - TensorType = torch.Tensor - DataLoaderType = torch.utils.data.DataLoader - ModelType = torch.nn.Module - - __all__.extend( - [ - "TensorType", - "ModelType", - "InverseHvpResult", - "solve_linear", - "solve_batch_cg", - "solve_lissa", - "as_tensor", - "stack", - "cat", - "zero_tensor", - "transpose_tensor", - "einsum", - "mvp", - ] - ) - -except ImportError: - logger.info( - "No compatible framework found. Influence function computation disabled." 
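# Editor's note with a minimal sketch (not part of the diff): the torch-specific code moves
# from the removed `pydvl.influence.frameworks` package to `pydvl.influence.torch`, so user
# imports change accordingly, e.g.
#
#   from pydvl.influence.frameworks import TorchTwiceDifferentiable, as_tensor   # before
from pydvl.influence.torch import TorchTwiceDifferentiable, as_tensor            # after this PR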
- ) diff --git a/src/pydvl/influence/frameworks/twice_differentiable.py b/src/pydvl/influence/frameworks/twice_differentiable.py deleted file mode 100644 index 26e9e430d..000000000 --- a/src/pydvl/influence/frameworks/twice_differentiable.py +++ /dev/null @@ -1,90 +0,0 @@ -from abc import ABC -from dataclasses import dataclass -from typing import Any, Callable, Dict, Generic, List, Sequence, Tuple, TypeVar - -TensorType = TypeVar("TensorType", bound=Sequence) -ModelType = TypeVar("ModelType") -DeviceType = TypeVar("DeviceType") - - -@dataclass(frozen=True) -class InverseHvpResult(Generic[TensorType]): - x: TensorType - info: Dict[str, Any] - - def __iter__(self): - return iter((self.x, self.info)) - - -class TwiceDifferentiable(ABC, Generic[TensorType, ModelType, DeviceType]): - """ - Wraps a differentiable model and loss and provides methods to compute the - second derivative of the loss wrt. the model parameters. - """ - - def __init__( - self, - model: ModelType, - loss: Callable[[TensorType, TensorType], TensorType], - device: DeviceType, - ): - self.device = device - pass - - @property - def num_params(self) -> int: - """Returns the number of parameters of the model""" - pass - - @property - def parameters(self) -> List[TensorType]: - """Returns all the model parameters that require differentiation""" - pass - - def split_grad( - self, x: TensorType, y: TensorType, *, progress: bool = False - ) -> TensorType: - """ - Calculate the gradient of the model wrt each input x and labels y. - The output is therefore of size [Nxp], with N the amount of points (the - length of x and y) and P the number of parameters. - - :param x: An array representing the features $x_i$. - :param y: An array representing the predicted target values $y_i$. - :param progress: ``True`` to display progress. - :return: An array representing the gradients wrt. the parameters of the - model. - """ - pass - - def grad( - self, x: TensorType, y: TensorType, *, x_requires_grad: bool = False - ) -> Tuple[TensorType, TensorType]: - """ - Calculates gradient of model parameters wrt. the model parameters. - - :param x: A matrix representing the features $x_i$. - :param y: A matrix representing the target values $y_i$. - :param x_requires_grad: If True, the input $x$ is marked as requiring - gradients. This is important for further differentiation on input - parameters. - :return: A tuple where: the first element is an array with the - gradients of the model, and the second element is the input to the - model as a grad parameters. This can be used for further - differentiation. - """ - pass - - def hessian( - self, x: TensorType, y: TensorType, *, progress: bool = False - ) -> TensorType: - """Calculates the full Hessian of $L(f(x),y)$ with respect to the model - parameters given data ($x$ and $y$). - - :param x: An array representing the features $x_i$. - :param y: An array representing the target values $y_i$. - :param progress: ``True`` to display progress. - :return: The hessian of the model, i.e. the second derivative wrt. the - model parameters. - """ - pass diff --git a/src/pydvl/influence/general.py b/src/pydvl/influence/general.py index 304910f5c..b581514ee 100644 --- a/src/pydvl/influence/general.py +++ b/src/pydvl/influence/general.py @@ -2,24 +2,24 @@ This module contains parallelized influence calculation functions for general models, as introduced in :footcite:t:`koh_understanding_2017`. 
""" +import logging from copy import deepcopy from enum import Enum -from typing import Any, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Generator, Optional, Type from ..utils import maybe_progress -from .frameworks import ( +from .inversion import InverseHvpResult, InversionMethod, solve_hvp +from .twice_differentiable import ( DataLoaderType, TensorType, + TensorUtilities, TwiceDifferentiable, - einsum, - mvp, - transpose_tensor, - zero_tensor, ) -from .inversion import InverseHvpResult, InversionMethod, solve_hvp __all__ = ["compute_influences", "InfluenceType", "compute_influence_factors"] +logger = logging.getLogger(__name__) + class InfluenceType(str, Enum): """ @@ -56,32 +56,52 @@ def compute_influence_factors( :param model: A model wrapped in the TwiceDifferentiable interface. :param training_data: A DataLoader containing the training data. :param test_data: A DataLoader containing the test data. - :param inversion_func: function to use to invert the product of hvp (hessian - vector product) and the gradient of the loss (s_test in the paper). + :param inversion_method: name of method for computing inverse hessian vector + products. :param hessian_perturbation: regularization of the hessian :param progress: If True, display progress bars. :returns: An array of size (N, D) containing the influence factors for each dimension (D) and test sample (N). """ - test_grads = zero_tensor( - shape=(len(test_data.dataset), model.num_params), - dtype=test_data.dataset[0][0].dtype, - device=model.device, + tensor_util: Type[TensorUtilities] = TensorUtilities.from_twice_differentiable( + model ) - for batch_idx, (x_test, y_test) in enumerate( - maybe_progress(test_data, progress, desc="Batch Test Gradients") - ): - idx = batch_idx * test_data.batch_size - test_grads[idx : idx + test_data.batch_size] = model.split_grad( - x_test, y_test, progress=False + + stack = tensor_util.stack + unsqueeze = tensor_util.unsqueeze + cat_gen = tensor_util.cat_gen + cat = tensor_util.cat + + def test_grads() -> Generator[TensorType, None, None]: + for x_test, y_test in maybe_progress( + test_data, progress, desc="Batch Test Gradients" + ): + yield stack( + [ + model.grad(inpt, target) + for inpt, target in zip(unsqueeze(x_test, 1), y_test) + ] + ) # type:ignore + + try: + # if provided input_data implements __len__, pre-allocate the result tensor to reduce memory consumption + resulting_shape = (len(test_data), model.num_params) # type:ignore + rhs = cat_gen( + test_grads(), resulting_shape, model # type:ignore + ) # type:ignore + except Exception as e: + logger.warning( + f"Failed to pre-allocate result tensor: {e}\n" + f"Evaluate all resulting tensor and concatenate" ) + rhs = cat(list(test_grads())) + return solve_hvp( inversion_method, model, training_data, - test_grads, + rhs, hessian_perturbation=hessian_perturbation, - progress=progress, **kwargs, ) @@ -110,19 +130,39 @@ def compute_influences_up( :returns: An array of size [NxM], where N is number of influence factors, M number of input points. 
""" - grads = zero_tensor( - shape=(len(input_data.dataset), model.num_params), - dtype=input_data.dataset[0][0].dtype, - device=model.device, + + tensor_util: Type[TensorUtilities] = TensorUtilities.from_twice_differentiable( + model ) - for batch_idx, (x, y) in enumerate( - maybe_progress(input_data, progress, desc="Batch Split Input Gradients") - ): - idx = batch_idx * input_data.batch_size - grads[idx : idx + input_data.batch_size] = model.split_grad( - x, y, progress=False + + stack = tensor_util.stack + unsqueeze = tensor_util.unsqueeze + cat_gen = tensor_util.cat_gen + cat = tensor_util.cat + einsum = tensor_util.einsum + + def train_grads() -> Generator[TensorType, None, None]: + for x, y in maybe_progress( + input_data, progress, desc="Batch Split Input Gradients" + ): + yield stack( + [model.grad(inpt, target) for inpt, target in zip(unsqueeze(x, 1), y)] + ) # type:ignore + + try: + # if provided input_data implements __len__, pre-allocate the result tensor to reduce memory consumption + resulting_shape = (len(input_data), model.num_params) # type:ignore + train_grad_tensor = cat_gen( + train_grads(), resulting_shape, model # type:ignore + ) # type:ignore + except Exception as e: + logger.warning( + f"Failed to pre-allocate result tensor: {e}\n" + f"Evaluate all resulting tensor and concatenate" ) - return einsum("ta,va->tv", influence_factors, grads) + train_grad_tensor = cat([x for x in train_grads()]) # type:ignore + + return einsum("ta,va->tv", influence_factors, train_grad_tensor) # type:ignore def compute_influences_pert( @@ -149,34 +189,38 @@ def compute_influences_pert( :returns: An array of size [NxMxP], where N is the number of influence factors, M the number of input data, and P the number of features. """ - input_x = input_data.dataset[0][0] - all_pert_influences = zero_tensor( - shape=(len(input_data.dataset), len(influence_factors), *input_x.shape), - dtype=input_x.dtype, - device=model.device, + + tensor_util: Type[TensorUtilities] = TensorUtilities.from_twice_differentiable( + model ) - for batch_idx, (x, y) in enumerate( - maybe_progress( - input_data, - progress, - desc="Batch Influence Perturbation", - ) + stack = tensor_util.stack + tu_slice = tensor_util.slice + reshape = tensor_util.reshape + get_element = tensor_util.get_element + shape = tensor_util.shape + + all_pert_influences = [] + for x, y in maybe_progress( + input_data, + progress, + desc="Batch Influence Perturbation", ): for i in range(len(x)): - grad_xy, tensor_x = model.grad(x[i : i + 1], y[i], x_requires_grad=True) - perturbation_influences = mvp( + tensor_x = tu_slice(x, i, i + 1) + grad_xy = model.grad(tensor_x, get_element(y, i), create_graph=True) + perturbation_influences = model.mvp( grad_xy, influence_factors, backprop_on=tensor_x, ) - all_pert_influences[ - batch_idx * input_data.batch_size + i - ] = perturbation_influences.reshape((-1, *x[i].shape)) + all_pert_influences.append( + reshape(perturbation_influences, (-1, *shape(get_element(x, i)))) + ) - return transpose_tensor(all_pert_influences, 0, 1) + return stack(all_pert_influences, axis=1) # type:ignore -influence_type_registry = { +influence_type_registry: Dict[InfluenceType, Callable[..., TensorType]] = { InfluenceType.Up: compute_influences_up, InfluenceType.Perturbation: compute_influences_pert, } @@ -193,7 +237,7 @@ def compute_influences( hessian_regularization: float = 0.0, progress: bool = False, **kwargs: Any, -) -> TensorType: +) -> TensorType: # type: ignore # ToDO fix typing r""" Calculates the influence of the 
input_data point j on the test points i. First it calculates the influence factors of all test points with respect @@ -234,9 +278,8 @@ def compute_influences( progress=progress, **kwargs, ) - compute_influence_type = influence_type_registry[influence_type] - return compute_influence_type( + return influence_type_registry[influence_type]( differentiable_model, input_data, influence_factors, diff --git a/src/pydvl/influence/inversion.py b/src/pydvl/influence/inversion.py index b87e415c9..27cf7ff60 100644 --- a/src/pydvl/influence/inversion.py +++ b/src/pydvl/influence/inversion.py @@ -1,24 +1,26 @@ +"""Contains methods to invert the hessian vector product. """ -Contains methods to invert the hessian vector product. -""" +import functools +import inspect import logging +import warnings from enum import Enum -from typing import Any +from typing import Any, Callable, Dict, Tuple, Type + +__all__ = [ + "solve_hvp", + "InversionMethod", + "InversionRegistry", + "InverseHvpResult", +] -from .frameworks import ( +from .twice_differentiable import ( DataLoaderType, InverseHvpResult, TensorType, TwiceDifferentiable, - solve_batch_cg, - solve_linear, - solve_lissa, ) -__all__ = ["solve_hvp", "InversionMethod"] - -from .frameworks.torch_differentiable import solve_arnoldi - logger = logging.getLogger(__name__) @@ -40,7 +42,6 @@ def solve_hvp( b: TensorType, *, hessian_perturbation: float = 0.0, - progress: bool = False, **kwargs: Any, ) -> InverseHvpResult: """ @@ -53,51 +54,144 @@ def solve_hvp( :param inversion_method: :param model: A model wrapped in the TwiceDifferentiable interface. - :param x: An array containing the features of the input data points. - :param y: labels for x + :param training_data: :param b: Array as the right hand side of the equation $Ax = b$ - :param kwargs: kwargs to pass to the inversion method :param hessian_perturbation: regularization of the hessian - :param progress: If True, display progress bars. + :param kwargs: kwargs to pass to the inversion method - :return: An object that containes an array that solves the inverse problem, + :return: An object that contains an array that solves the inverse problem, i.e. it returns $x$ such that $Ax = b$, and a dictionary containing information about the inversion process. """ - if inversion_method == InversionMethod.Direct: - return solve_linear( - model, - training_data, - b, - **kwargs, - hessian_perturbation=hessian_perturbation, - progress=progress, - ) - elif inversion_method == InversionMethod.Cg: - return solve_batch_cg( - model, - training_data, - b, - **kwargs, - hessian_perturbation=hessian_perturbation, - progress=progress, - ) - elif inversion_method == InversionMethod.Lissa: - return solve_lissa( - model, - training_data, - b, - **kwargs, - hessian_perturbation=hessian_perturbation, - progress=progress, - ) - elif inversion_method == InversionMethod.Arnoldi: - return solve_arnoldi( - model, # type: ignore # TODO the interface TwiceDifferentiable is not used properly anyhow - training_data, - b, - **kwargs, - hessian_perturbation=hessian_perturbation, + return InversionRegistry.call( + inversion_method, + model, + training_data, + b, + hessian_perturbation=hessian_perturbation, + **kwargs, + ) + + +class InversionRegistry: + """ + A registry to hold inversion methods for different models. 
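# Editor's sketch of the public API after this PR (keyword names for the data loaders follow
# the docstrings above but are assumptions where not visible in this diff; model and data are
# toy placeholders): wrapping a torch model and computing up-weighting influences.
import torch
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence import InfluenceType, InversionMethod, compute_influences
from pydvl.influence.torch import TorchTwiceDifferentiable

model = torch.nn.Linear(5, 1)
model.eval()  # the wrapper warns if the model is left in training mode
wrapped = TorchTwiceDifferentiable(model, torch.nn.functional.mse_loss)

train_loader = DataLoader(TensorDataset(torch.randn(64, 5), torch.randn(64, 1)), batch_size=8)
test_loader = DataLoader(TensorDataset(torch.randn(16, 5), torch.randn(16, 1)), batch_size=8)

influences = compute_influences(
    wrapped,
    training_data=train_loader,
    test_data=test_loader,
    inversion_method=InversionMethod.Direct,
    influence_type=InfluenceType.Up,
    hessian_regularization=0.01,
)  # up-weighting influences, one row per test point and one column per training point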
+ """ + + registry: Dict[Tuple[Type[TwiceDifferentiable], InversionMethod], Callable] = {} + + @classmethod + def register( + cls, + model_type: Type[TwiceDifferentiable], + inversion_method: InversionMethod, + overwrite: bool = False, + ): + """ + Register a function for a specific model type and inversion method. + + The function to be registered must conform to the following signature: + `(model: TwiceDifferentiable, training_data: DataLoaderType, b: TensorType, + hessian_perturbation: float = 0.0, ...)`. + + :param model_type: The type of the model the function should be registered for. + :param inversion_method: The inversion method the function should be + registered for. + :param overwrite: If ``True``, allows overwriting of an existing registered + function for the same model type and inversion method. If ``False``, + logs a warning when attempting to register a function for an already + registered model type and inversion method. + + :raises TypeError: If the provided model_type or inversion_method are of the wrong type. + :raises ValueError: If the function to be registered does not match the required signature. + :return: A decorator for registering a function. + """ + + if not isinstance(model_type, type): + raise TypeError( + f"'model_type' is of type {type(model_type)} but should be a Type[TwiceDifferentiable]" + ) + + if not isinstance(inversion_method, InversionMethod): + raise TypeError( + f"'inversion_method' must be an 'InversionMethod' " + f"but has type {type(inversion_method)} instead." + ) + + key = (model_type, inversion_method) + + def decorator(func): + if not overwrite and key in cls.registry: + warnings.warn( + f"There is already a function registered for model type {model_type} " + f"and inversion method {inversion_method}. " + f"To overwrite the existing function {cls.registry.get(key)} with {func}, set overwrite to True." + ) + sig = inspect.signature(func) + params = list(sig.parameters.values()) + + expected_args = [ + ("model", model_type), + ("training_data", DataLoaderType.__bound__), + ("b", model_type.tensor_type()), + ("hessian_perturbation", float), + ] + + for (name, typ), param in zip(expected_args, params): + if not ( + isinstance(param.annotation, typ) + or issubclass(param.annotation, typ) + ): + raise ValueError( + f'Parameter "{name}" must be of type "{typ.__name__}"' + ) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + cls.registry[key] = wrapper + return wrapper + + return decorator + + @classmethod + def get( + cls, model_type: Type[TwiceDifferentiable], inversion_method: InversionMethod + ) -> Callable[ + [TwiceDifferentiable, DataLoaderType, TensorType, float], InverseHvpResult + ]: + key = (model_type, inversion_method) + method = cls.registry.get(key, None) + if method is None: + raise ValueError(f"No function registered for {key}") + return method + + @classmethod + def call( + cls, + inversion_method: InversionMethod, + model: TwiceDifferentiable, + training_data: DataLoaderType, + b: TensorType, + hessian_perturbation, + **kwargs, + ) -> InverseHvpResult: + """ + Call a registered function with the provided parameters. + + :param inversion_method: The inversion method to use. + :param model: A model wrapped in the TwiceDifferentiable interface. + :param training_data: The training data to use. + :param b: Array as the right hand side of the equation $Ax = b$. + :param hessian_perturbation: Regularization of the hessian. 
+ :param kwargs: Additional keyword arguments to pass to the inversion method. + + :return: An object that contains an array that solves the inverse problem, + i.e. it returns $x$ such that $Ax = b$, and a dictionary containing + information about the inversion process. + """ + + return cls.get(type(model), inversion_method)( + model, training_data, b, hessian_perturbation, **kwargs ) - else: - raise ValueError(f"Unknown inversion method: {inversion_method}") diff --git a/src/pydvl/influence/torch/__init__.py b/src/pydvl/influence/torch/__init__.py new file mode 100644 index 000000000..1f431d57b --- /dev/null +++ b/src/pydvl/influence/torch/__init__.py @@ -0,0 +1,9 @@ +from .torch_differentiable import ( + TorchTwiceDifferentiable, + as_tensor, + model_hessian_low_rank, + solve_arnoldi, + solve_batch_cg, + solve_linear, + solve_lissa, +) diff --git a/src/pydvl/influence/frameworks/functional.py b/src/pydvl/influence/torch/functional.py similarity index 83% rename from src/pydvl/influence/frameworks/functional.py rename to src/pydvl/influence/torch/functional.py index 4951ac224..f1b042032 100644 --- a/src/pydvl/influence/frameworks/functional.py +++ b/src/pydvl/influence/torch/functional.py @@ -1,11 +1,15 @@ -from functools import partial from typing import Callable, Dict, Generator, Iterable import torch from torch.func import functional_call, grad, jvp, vjp from torch.utils.data import DataLoader -from .util import TorchTensorContainerType, align_structure, to_model_device +from .util import ( + TorchTensorContainerType, + align_structure, + flatten_tensors_to_vector, + to_model_device, +) __all__ = [ "get_hvp_function", @@ -13,7 +17,7 @@ def hvp( - func: Callable[[TorchTensorContainerType], TorchTensorContainerType], + func: Callable[[TorchTensorContainerType], torch.Tensor], params: TorchTensorContainerType, vec: TorchTensorContainerType, reverse_only: bool = True, @@ -24,7 +28,7 @@ def hvp( forward- and reverse-mode autodiff. - :param func: The function for which the HVP is computed. + :param func: The scalar-valued function for which the HVP is computed. :param params: The parameters at which the HVP is computed. :param vec: The vector with which the Hessian is multiplied. :param reverse_only: Whether to use only reverse-mode autodiff @@ -56,9 +60,7 @@ def batch_hvp_gen( loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], data_loader: DataLoader, reverse_only: bool = True, -) -> Generator[ - Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], None, None -]: +) -> Generator[Callable[[torch.Tensor], torch.Tensor], None, None]: """ Generates a sequence of batch Hessian-vector product (HVP) computations for the provided model, loss function, and data loader. 
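# Editor's sketch (standalone, not the library code): the `hvp` helper above composes
# `torch.func.grad` with either `vjp` (reverse-over-reverse) or `jvp` (forward-over-reverse)
# to form Hessian-vector products without materializing the Hessian. For a scalar function of
# a parameter dict the two modes look roughly like this; for f(w) = sum(w**2) the Hessian is
# 2*I, so both must return 2*vec.
import torch
from torch.func import grad, jvp, vjp


def f(params: dict) -> torch.Tensor:
    return (params["w"] ** 2).sum()


params = {"w": torch.randn(3)}
vec = {"w": torch.randn(3)}

# reverse-over-reverse: vector-Jacobian product of the gradient function
_, vjp_fn = vjp(grad(f), params)
hvp_reverse = vjp_fn(vec)[0]

# forward-over-reverse: Jacobian-vector product of the gradient function
_, hvp_forward = jvp(grad(f), (params,), (vec,))

assert torch.allclose(hvp_reverse["w"], 2 * vec["w"])
assert torch.allclose(hvp_forward["w"], 2 * vec["w"])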
@@ -75,14 +77,26 @@ def batch_hvp_gen( for inputs, targets in iter(data_loader): batch_loss = batch_loss_function(model, loss, inputs, targets) - yield partial(hvp, batch_loss, dict(model.named_parameters()), reverse_only=reverse_only) # type: ignore + model_params = dict(model.named_parameters()) + + def batch_hvp(vec: torch.Tensor): + return flatten_tensors_to_vector( + hvp( + batch_loss, + model_params, + align_structure(model_params, vec), + reverse_only=reverse_only, + ).values() + ) + + yield batch_hvp def empirical_loss_function( model: torch.nn.Module, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], data_loader: DataLoader, -) -> Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]: +) -> Callable[[Dict[str, torch.Tensor]], torch.Tensor]: """ Creates a function to compute the empirical loss of a given model on a given dataset. If we denote the model parameters with $\theta$, the resulting function approximates @@ -151,7 +165,7 @@ def get_hvp_function( use_hessian_avg: bool = True, reverse_only: bool = True, track_gradients: bool = False, -) -> Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]: +) -> Callable[[torch.Tensor], torch.Tensor]: """ Returns a function that calculates the approximate Hessian-vector product for a given vector. If you want to compute the exact hessian, i.e. pulling all data into memory and compute a full gradient computation, use @@ -180,27 +194,25 @@ def get_hvp_function( k: p if track_gradients else p.detach() for k, p in model.named_parameters() } - def hvp_function(vec: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def hvp_function(vec: torch.Tensor) -> torch.Tensor: v = align_structure(params, vec) empirical_loss = empirical_loss_function(model, loss, data_loader) - return hvp(empirical_loss, params, v, reverse_only=reverse_only) + return flatten_tensors_to_vector( + hvp(empirical_loss, params, v, reverse_only=reverse_only).values() + ) - def avg_hvp_function(vec: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def avg_hvp_function(vec: torch.Tensor) -> torch.Tensor: v = align_structure(params, vec) - batch_hessians: Iterable[Dict[str, torch.Tensor]] = map( + batch_hessians_vector_products: Iterable[torch.Tensor] = map( lambda x: x(v), batch_hvp_gen(model, loss, data_loader, reverse_only) ) - result_dict = { - key: to_model_device(torch.zeros_like(p), model) - for key, p in params.items() - } num_batches = len(data_loader) + avg_hessian = to_model_device(torch.zeros_like(vec), model) - for batch_dict in batch_hessians: - for key, value in batch_dict.items(): - result_dict[key] += value + for batch_hvp in batch_hessians_vector_products: + avg_hessian += batch_hvp - return {key: value / num_batches for key, value in result_dict.items()} + return avg_hessian / float(num_batches) return avg_hvp_function if use_hessian_avg else hvp_function diff --git a/src/pydvl/influence/frameworks/torch_differentiable.py b/src/pydvl/influence/torch/torch_differentiable.py similarity index 72% rename from src/pydvl/influence/frameworks/torch_differentiable.py rename to src/pydvl/influence/torch/torch_differentiable.py index 41688f465..1515f3f96 100644 --- a/src/pydvl/influence/frameworks/torch_differentiable.py +++ b/src/pydvl/influence/torch/torch_differentiable.py @@ -7,9 +7,8 @@ import logging from dataclasses import dataclass from functools import partial -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Callable, Generator, List, Optional, Sequence, Tuple -import numpy as np 
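# Editor's sketch (toy model, illustrative only): after this change `get_hvp_function` maps a
# flat parameter-sized vector to a flat vector (previously dicts of named parameters), and with
# `use_hessian_avg=True` the Hessian-vector product is averaged over the batches of the loader.
import torch
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence.torch.functional import get_hvp_function

model = torch.nn.Linear(4, 1)
loss = torch.nn.functional.mse_loss
data = DataLoader(TensorDataset(torch.randn(16, 4), torch.randn(16, 1)), batch_size=4)

hvp_fn = get_hvp_function(model, loss, data, use_hessian_avg=True)

num_params = sum(p.numel() for p in model.parameters())
v = torch.randn(num_params)
print(hvp_fn(v).shape)  # torch.Size([5]) for Linear(4, 1): 4 weights + 1 bias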
import torch import torch.nn as nn from numpy.typing import NDArray @@ -19,42 +18,435 @@ from torch.utils.data import DataLoader from ...utils import maybe_progress +from ..inversion import InversionMethod, InversionRegistry +from ..twice_differentiable import ( + InverseHvpResult, + TensorUtilities, + TwiceDifferentiable, +) from .functional import get_hvp_function -from .twice_differentiable import InverseHvpResult, TwiceDifferentiable -from .util import align_structure, flatten_tensors_to_vector +from .util import align_structure, as_tensor, flatten_tensors_to_vector __all__ = [ "TorchTwiceDifferentiable", "solve_linear", "solve_batch_cg", "solve_lissa", - "as_tensor", - "stack", - "cat", - "zero_tensor", - "transpose_tensor", - "einsum", - "mvp", + "solve_arnoldi", "lanzcos_low_rank_hessian_approx", + "as_tensor", + "model_hessian_low_rank", ] logger = logging.getLogger(__name__) -def flatten_all(grad: torch.Tensor) -> torch.Tensor: +class TorchTwiceDifferentiable(TwiceDifferentiable[torch.Tensor]): + def __init__( + self, + model: nn.Module, + loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + ): + r""" + :param model: A (differentiable) function. + :param loss: A differentiable scalar loss $L(\hat{y}, y)$, + mapping a prediction and a target to a real value. + """ + if model.training: + logger.warning( + "Passed model not in evaluation mode. This can create several issues in influence " + "computation, e.g. due to batch normalization. Please call model.eval() before " + "computing influences." + ) + self.loss = loss + self.model = model + first_param = next(model.parameters()) + self.device = first_param.device + self.dtype = first_param.dtype + + @classmethod + def tensor_type(cls): + return torch.Tensor + + @property + def parameters(self) -> List[torch.Tensor]: + """Returns all the model parameters that require differentiating""" + return [param for param in self.model.parameters() if param.requires_grad] + + @property + def num_params(self) -> int: + """ + Get number of parameters of model f. + :returns: Number of parameters as integer. + """ + return sum([p.numel() for p in self.parameters]) + + def grad( + self, x: torch.Tensor, y: torch.Tensor, create_graph: bool = False + ) -> torch.Tensor: + """ + Calculates gradient of model parameters wrt the model parameters. + + :param x: A matrix [NxD] representing the features $x_i$. + :param y: A matrix [NxK] representing the target values $y_i$. + :param create_graph: If True, the resulting gradient tensor, can be used for further differentiation + :returns: An array [P] with the gradients of the model. + """ + x = x.to(self.device) + y = y.to(self.device) + + if create_graph and not x.requires_grad: + x = x.requires_grad_(True) + + loss_value = self.loss(torch.squeeze(self.model(x)), torch.squeeze(y)) + grad_f = torch.autograd.grad( + loss_value, self.parameters, create_graph=create_graph + ) + return flatten_tensors_to_vector(grad_f) + + def hessian(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Calculates the explicit hessian of model parameters given data ($x$ and $y$). + :param x: A matrix [NxD] representing the features $x_i$. + :param y: A matrix [NxK] representing the target values $y_i$. + :returns: A tensor representing the hessian of the loss wrt. the model parameters. 
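# Editor's sketch (tiny model, illustrative only): with the interface below, `grad` returns a
# single flattened gradient vector of length `num_params` (the former `split_grad` and the
# tuple-returning `grad` are gone), and `hessian` returns the explicit
# [num_params x num_params] matrix.
import torch

from pydvl.influence.torch import TorchTwiceDifferentiable

model = torch.nn.Linear(3, 1, bias=False)
model.eval()
wrapped = TorchTwiceDifferentiable(model, torch.nn.functional.mse_loss)

x, y = torch.randn(10, 3), torch.randn(10, 1)
g = wrapped.grad(x, y)     # flat gradient, shape [3]
H = wrapped.hessian(x, y)  # explicit Hessian, shape [3, 3]
assert g.shape == (wrapped.num_params,)
assert H.shape == (wrapped.num_params, wrapped.num_params)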
+ """ + + def model_func(param): + outputs = torch.func.functional_call( + self.model, + align_structure( + {k: p for k, p in self.model.named_parameters() if p.requires_grad}, + param, + ), + (x.to(self.device),), + strict=True, + ) + return self.loss(outputs, y.to(self.device)) + + params = flatten_tensors_to_vector( + p.detach() for p in self.model.parameters() if p.requires_grad + ) + return torch.func.hessian(model_func)(params) + + @staticmethod + def mvp( + grad_xy: torch.Tensor, + v: torch.Tensor, + backprop_on: torch.Tensor, + *, + progress: bool = False, + ) -> torch.Tensor: + """ + Calculates second order derivative of the model along directions v. + This second order derivative can be selected through the backprop_on argument. + + :param grad_xy: an array [P] holding the gradients of the model + parameters wrt input $x$ and labels $y$, where P is the number of + parameters of the model. It is typically obtained through + self.grad. + :param v: An array ([DxP] or even one dimensional [D]) which + multiplies the matrix, where D is the number of directions. + :param progress: True, iff progress shall be printed. + :param backprop_on: tensor used in the second backpropagation (the first + one is along $x$ and $y$ as defined via grad_xy). + :returns: A matrix representing the implicit matrix vector product + of the model along the given directions. Output shape is [DxP] if + backprop_on is None, otherwise [DxM], with M the number of elements + of backprop_on. + """ + device = grad_xy.device + v = as_tensor(v, warn=False).to(device) + if v.ndim == 1: + v = v.unsqueeze(0) + + z = (grad_xy * Variable(v)).sum(dim=1) + + mvp = [] + for i in maybe_progress(range(len(z)), progress, desc="MVP"): + mvp.append( + flatten_tensors_to_vector( + autograd.grad(z[i], backprop_on, retain_graph=True) + ) + ) + return torch.stack([grad.contiguous().view(-1) for grad in mvp]).detach() + + +@dataclass +class LowRankProductRepresentation: + """ + Representation of a low rank product of the form $H = V D V^T$, where D is a diagonal matrix and + V is orthogonal + :param eigen_vals: diagonal of D + :param projections: the matrix V + """ + + eigen_vals: torch.Tensor + projections: torch.Tensor + + @property + def device(self) -> torch.device: + return ( + self.eigen_vals.device + if hasattr(self.eigen_vals, "device") + else torch.device("cpu") + ) + + def to(self, device: torch.device): + """ + Move the representing tensors to a device + """ + return LowRankProductRepresentation( + self.eigen_vals.to(device), self.projections.to(device) + ) + + def __post_init__(self): + if self.eigen_vals.device != self.projections.device: + raise ValueError("eigen_vals and projections must be on the same device.") + + +def lanzcos_low_rank_hessian_approx( + hessian_vp: Callable[[torch.Tensor], torch.Tensor], + matrix_shape: Tuple[int, int], + hessian_perturbation: float = 0.0, + rank_estimate: int = 10, + krylov_dimension: Optional[int] = None, + tol: float = 1e-6, + max_iter: Optional[int] = None, + device: Optional[torch.device] = None, + eigen_computation_on_gpu: bool = False, + torch_dtype: torch.dtype = None, +) -> LowRankProductRepresentation: + """ + Calculates a low-rank approximation of the Hessian matrix of a scalar-valued + function using the implicitly restarted Lanczos algorithm. + + :param hessian_vp: A function that takes a vector and returns the product of + the Hessian of the loss function. + :param matrix_shape: The shape of the matrix, represented by hessian vector + product. 
+ :param hessian_perturbation: Optional regularization parameter added to the + Hessian-vector product for numerical stability. + :param rank_estimate: The number of eigenvalues and corresponding eigenvectors + to compute. Represents the desired rank of the Hessian approximation. + :param krylov_dimension: The number of Krylov vectors to use for the Lanczos + method. If not provided, it defaults to + $min(model.num_parameters, max(2*rank_estimate + 1, 20))$. + :param tol: The stopping criteria for the Lanczos algorithm, which stops when + the difference in the approximated eigenvalue is less than ``tol``. + Defaults to 1e-6. + :param max_iter: The maximum number of iterations for the Lanczos method. If + not provided, it defaults to ``10 * model.num_parameters``. + :param device: The device to use for executing the hessian vector product. + :param eigen_computation_on_gpu: If ``True``, tries to execute the eigen pair + approximation on the provided device via `cupy `_ + implementation. Make sure that either your model is small enough, or you + use a small rank_estimate to fit your device's memory. If ``False``, the + eigen pair approximation is executed on the CPU with scipy's wrapper to + ARPACK. + :param torch_dtype: if not provided, current torch default dtype is used for + conversion to torch. + + :return: An object that contains the top- ``rank_estimate`` eigenvalues and + corresponding eigenvectors of the Hessian. + """ + + torch_dtype = torch.get_default_dtype() if torch_dtype is None else torch_dtype + + if eigen_computation_on_gpu: + try: + import cupy as cp + from cupyx.scipy.sparse.linalg import LinearOperator, eigsh + from torch.utils.dlpack import from_dlpack, to_dlpack + except ImportError as e: + raise ImportError( + f"Try to install missing dependencies or set eigen_computation_on_gpu to False: {e}" + ) + + if device is None: + raise ValueError( + "Without setting an explicit device, cupy is not supported" + ) + + def to_torch_conversion_function(x): + return from_dlpack(x.toDlpack()).to(torch_dtype) + + def mv(x): + x = to_torch_conversion_function(x) + y = hessian_vp(x) + hessian_perturbation * x + return cp.from_dlpack(to_dlpack(y)) + + else: + from scipy.sparse.linalg import LinearOperator, eigsh + + def mv(x): + x_torch = torch.as_tensor(x, device=device, dtype=torch_dtype) + y: NDArray = ( + (hessian_vp(x_torch) + hessian_perturbation * x_torch) + .detach() + .cpu() + .numpy() + ) + return y + + to_torch_conversion_function = partial(torch.as_tensor, dtype=torch_dtype) + + try: + eigen_vals, eigen_vecs = eigsh( + LinearOperator(matrix_shape, matvec=mv), + k=rank_estimate, + maxiter=max_iter, + tol=tol, + ncv=krylov_dimension, + return_eigenvectors=True, + ) + + except ArpackNoConvergence as e: + logger.warning( + f"ARPACK did not converge for parameters {max_iter=}, {tol=}, {krylov_dimension=}, " + f"{rank_estimate=}. \n Returning the best approximation found so far. 
Use those with care or " + f"modify parameters.\n Original error: {e}" + ) + + eigen_vals, eigen_vecs = e.eigenvalues, e.eigenvectors + + eigen_vals = to_torch_conversion_function(eigen_vals) + eigen_vecs = to_torch_conversion_function(eigen_vecs) + + return LowRankProductRepresentation(eigen_vals, eigen_vecs) + + +def model_hessian_low_rank( + model: TorchTwiceDifferentiable, + training_data: DataLoader, + hessian_perturbation: float = 0.0, + rank_estimate: int = 10, + krylov_dimension: Optional[int] = None, + tol: float = 1e-6, + max_iter: Optional[int] = None, + eigen_computation_on_gpu: bool = False, +) -> LowRankProductRepresentation: """ - Simple function to flatten a pyTorch gradient for use in subsequent calculation + Calculates a low-rank approximation of the Hessian matrix of the model's loss function using the implicitly + restarted Lanczos algorithm. + + :param model: A PyTorch model instance that is twice differentiable, wrapped into :class:`TorchTwiceDifferential`. + The Hessian will be calculated with respect to this model's parameters. + :param training_data: A DataLoader instance that provides the model's training data. + Used in calculating the Hessian-vector products. + :param hessian_perturbation: Optional regularization parameter added to the Hessian-vector product + for numerical stability. + :param rank_estimate: The number of eigenvalues and corresponding eigenvectors to compute. + Represents the desired rank of the Hessian approximation. + :param krylov_dimension: The number of Krylov vectors to use for the Lanczos method. + If not provided, it defaults to $min(model.num_parameters, max(2*rank_estimate + 1, 20))$. + :param tol: The stopping criteria for the Lanczos algorithm, which stops when the difference + in the approximated eigenvalue is less than `tol`. Defaults to 1e-6. + :param max_iter: The maximum number of iterations for the Lanczos method. If not provided, it defaults to + $10*model.num_parameters$ + :param eigen_computation_on_gpu: If True, tries to execute the eigen pair approximation on the provided + device via cupy implementation. + Make sure, that either your model is small enough or you use a + small rank_estimate to fit your device's memory. + If False, the eigen pair approximation is executed on the CPU by scipy wrapper to + ARPACK. + :return: A `LowRankProductRepresentation` instance that contains the top (up until rank_estimate) eigenvalues + and corresponding eigenvectors of the Hessian. """ - return torch.cat([el.reshape(-1) for el in grad]) + raw_hvp = get_hvp_function( + model.model, model.loss, training_data, use_hessian_avg=True + ) + + return lanzcos_low_rank_hessian_approx( + hessian_vp=raw_hvp, + matrix_shape=(model.num_params, model.num_params), + hessian_perturbation=hessian_perturbation, + rank_estimate=rank_estimate, + krylov_dimension=krylov_dimension, + tol=tol, + max_iter=max_iter, + device=model.device if hasattr(model, "device") else None, + eigen_computation_on_gpu=eigen_computation_on_gpu, + ) + + +class TorchTensorUtilities(TensorUtilities[torch.Tensor, TorchTwiceDifferentiable]): + twice_differentiable_type = TorchTwiceDifferentiable + + @staticmethod + def einsum(equation: str, *operands) -> torch.Tensor: + """Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation + based on the Einstein summation convention. 
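# Editor's sketch (assumes SciPy is available and uses a toy model; keyword names follow the
# signatures above, but the shape convention for `b` is an assumption): computing a low-rank
# eigendecomposition of the loss Hessian once with `model_hessian_low_rank` and reusing it in
# `solve_arnoldi` via `low_rank_representation`.
import torch
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence.torch import (
    TorchTwiceDifferentiable,
    model_hessian_low_rank,
    solve_arnoldi,
)

model = torch.nn.Linear(10, 1)
model.eval()
wrapped = TorchTwiceDifferentiable(model, torch.nn.functional.mse_loss)
train_loader = DataLoader(TensorDataset(torch.randn(32, 10), torch.randn(32, 1)), batch_size=8)

low_rank = model_hessian_low_rank(wrapped, train_loader, rank_estimate=5)
b = torch.randn(3, wrapped.num_params)  # e.g. a batch of test gradients
x, info = solve_arnoldi(wrapped, train_loader, b, low_rank_representation=low_rank)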
+ """ + return torch.einsum(equation, *operands) + + @staticmethod + def cat(a: Sequence[torch.Tensor], **kwargs) -> torch.Tensor: + """Concatenates a sequence of tensors into a single torch tensor""" + return torch.cat(a, **kwargs) + + @staticmethod + def stack(a: Sequence[torch.Tensor], **kwargs) -> torch.Tensor: + """Stacks a sequence of tensors into a single torch tensor""" + return torch.stack(a, **kwargs) + + @staticmethod + def unsqueeze(x: torch.Tensor, dim: int) -> torch.Tensor: + """ + Add a singleton dimension at a specified position in a tensor. + + :param x: A PyTorch tensor. + :param dim: The position at which to add the singleton dimension. Zero-based indexing. + :return: A new tensor with an additional singleton dimension. + """ + return x.unsqueeze(dim) + + @staticmethod + def get_element(x: torch.Tensor, idx: int) -> torch.Tensor: + return x[idx] + + @staticmethod + def slice(x: torch.Tensor, start: int, stop: int, axis: int = 0) -> torch.Tensor: + slicer = [slice(None) for _ in x.shape] + slicer[axis] = slice(start, stop) + return x[tuple(slicer)] + + @staticmethod + def shape(x: torch.Tensor) -> Tuple[int, ...]: + return x.shape # type:ignore + + @staticmethod + def reshape(x: torch.Tensor, shape: Tuple[int, ...]) -> torch.Tensor: + return x.reshape(shape) + + @staticmethod + def cat_gen( + a: Generator[torch.Tensor, None, None], + resulting_shape: Tuple[int, ...], + model: TorchTwiceDifferentiable, + axis: int = 0, + ) -> torch.Tensor: + result = torch.empty(resulting_shape, dtype=model.dtype, device=model.device) + + start_idx = 0 + for x in a: + stop_idx = start_idx + x.shape[axis] + + slicer = [slice(None) for _ in resulting_shape] + slicer[axis] = slice(start_idx, stop_idx) + + result[tuple(slicer)] = x + start_idx = stop_idx + return result + + +@InversionRegistry.register(TorchTwiceDifferentiable, InversionMethod.Direct) def solve_linear( - model: TwiceDifferentiable, + model: TorchTwiceDifferentiable, training_data: DataLoader, b: torch.Tensor, - *, hessian_perturbation: float = 0.0, - progress: bool = False, ) -> InverseHvpResult: """Given a model and training data, it finds x s.t. $Hx = b$, with $H$ being the model hessian. @@ -63,7 +455,6 @@ def solve_linear( :param training_data: A DataLoader containing the training data. :param b: a vector or matrix, the right hand side of the equation $Hx = b$. :param hessian_perturbation: regularization of the hessian - :param progress: If True, display progress bars. :return: An array that solves the inverse problem, i.e. 
it returns $x$ such that $Hx = b$, and a dictionary containing @@ -74,22 +465,21 @@ def solve_linear( for x, y in training_data: all_x.append(x) all_y.append(y) - all_x = cat(all_x) - all_y = cat(all_y) - hessian = model.hessian(all_x, all_y, progress=progress) - matrix = hessian + hessian_perturbation * identity_tensor( + hessian = model.hessian(torch.cat(all_x), torch.cat(all_y)) + matrix = hessian + hessian_perturbation * torch.eye( model.num_params, device=model.device ) info = {"hessian": hessian} return InverseHvpResult(x=torch.linalg.solve(matrix, b.T).T, info=info) +@InversionRegistry.register(TorchTwiceDifferentiable, InversionMethod.Cg) def solve_batch_cg( - model: TwiceDifferentiable, + model: TorchTwiceDifferentiable, training_data: DataLoader, b: torch.Tensor, - *, hessian_perturbation: float = 0.0, + *, x0: Optional[torch.Tensor] = None, rtol: float = 1e-7, atol: float = 1e-7, @@ -119,11 +509,11 @@ def solve_batch_cg( total_grad_xy = 0 total_points = 0 for x, y in maybe_progress(training_data, progress, desc="Batch Train Gradients"): - grad_xy, _ = model.grad(x, y) + grad_xy = model.grad(x, y, create_graph=True) total_grad_xy += grad_xy * len(x) total_points += len(x) backprop_on = model.parameters - reg_hvp = lambda v: mvp( + reg_hvp = lambda v: model.mvp( total_grad_xy / total_points, v, backprop_on ) + hessian_perturbation * v.type(torch.float64) batch_cg = torch.zeros_like(b) @@ -188,12 +578,13 @@ def solve_cg( return InverseHvpResult(x=x, info=info) +@InversionRegistry.register(TorchTwiceDifferentiable, InversionMethod.Lissa) def solve_lissa( - model: TwiceDifferentiable, + model: TorchTwiceDifferentiable, training_data: DataLoader, b: torch.Tensor, - *, hessian_perturbation: float = 0.0, + *, maxiter: int = 1000, dampen: float = 0.0, scale: float = 10.0, @@ -249,8 +640,10 @@ def lissa_step( for _ in maybe_progress(range(maxiter), progress, desc="Lissa"): x, y = next(iter(shuffled_training_data)) - grad_xy, _ = model.grad(x, y) - reg_hvp = lambda v: mvp(grad_xy, v, model.parameters) + hessian_perturbation * v + grad_xy = model.grad(x, y, create_graph=True) + reg_hvp = ( + lambda v: model.mvp(grad_xy, v, model.parameters) + hessian_perturbation * v + ) residual = lissa_step(h_estimate, reg_hvp) - h_estimate h_estimate += residual if torch.isnan(h_estimate).any(): @@ -270,244 +663,13 @@ def lissa_step( return InverseHvpResult(x=h_estimate / scale, info=info) -def as_tensor(a: Any, warn=True, **kwargs) -> torch.Tensor: - """Converts an array into a torch tensor - - :param a: array to convert to tensor - :param warn: if True, warns that a will be converted - """ - if warn and not isinstance(a, torch.Tensor): - logger.warning("Converting tensor to type torch.Tensor.") - return torch.as_tensor(a, **kwargs) - - -def stack(a: Sequence[torch.Tensor], **kwargs) -> torch.Tensor: - """Stacks a sequence of tensors into a single torch tensor""" - return torch.stack(a, **kwargs) - - -def cat(a: Sequence[torch.Tensor], **kwargs) -> torch.Tensor: - """Concatenates a sequence of tensors into a single torch tensor""" - return torch.cat(a, **kwargs) - - -def zero_tensor( - shape: Sequence[int], dtype: Union[np.dtype, torch.dtype], **kwargs -) -> torch.Tensor: - """Returns a tensor of shape :attr:`shape` filled with zeros.""" - if isinstance(dtype, np.dtype): - dtype = getattr(torch, dtype.name) - return torch.zeros(shape, dtype=dtype, **kwargs) - - -def transpose_tensor(a: torch.Tensor, dim0: int, dim1: int) -> torch.Tensor: - return torch.transpose(a, dim0, dim1) - - -def einsum(equation, 
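# Editor's sketch (usage fragment; `wrapped`, `train_loader` and `b` as in the earlier sketches):
# with the registrations above, `solve_hvp` dispatches on `(type(model), inversion_method)`
# instead of the old if/elif chain, so the call below ends up in the registered `solve_batch_cg`.
from pydvl.influence.inversion import InversionMethod, solve_hvp

x, info = solve_hvp(
    InversionMethod.Cg,
    wrapped,
    train_loader,
    b,
    hessian_perturbation=0.01,
    rtol=1e-6,  # forwarded to solve_batch_cg via **kwargs
)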
*operands) -> torch.Tensor: - """Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation - based on the Einstein summation convention. - """ - return torch.einsum(equation, *operands) - - -def identity_tensor(dim: int, **kwargs) -> torch.Tensor: - return torch.eye(dim, dim, **kwargs) - - -def mvp( - grad_xy: torch.Tensor, - v: torch.Tensor, - backprop_on: torch.Tensor, - *, - progress: bool = False, -) -> torch.Tensor: - """ - Calculates second order derivative of the model along directions v. - This second order derivative can be selected through the backprop_on argument. - - :param grad_xy: an array [P] holding the gradients of the model - parameters wrt input $x$ and labels $y$, where P is the number of - parameters of the model. It is typically obtained through - self.grad. - :param v: An array ([DxP] or even one dimensional [D]) which - multiplies the matrix, where D is the number of directions. - :param progress: True, iff progress shall be printed. - :param backprop_on: tensor used in the second backpropagation (the first - one is along $x$ and $y$ as defined via grad_xy). - :returns: A matrix representing the implicit matrix vector product - of the model along the given directions. Output shape is [DxP] if - backprop_on is None, otherwise [DxM], with M the number of elements - of backprop_on. - """ - device = grad_xy.device - v = as_tensor(v, warn=False).to(device) - if v.ndim == 1: - v = v.unsqueeze(0) - - z = (grad_xy * Variable(v)).sum(dim=1) - - mvp = [] - for i in maybe_progress(range(len(z)), progress, desc="MVP"): - mvp.append(flatten_all(autograd.grad(z[i], backprop_on, retain_graph=True))) - mvp = torch.stack([grad.contiguous().view(-1) for grad in mvp]) - return mvp.detach() # type: ignore - - -class TorchTwiceDifferentiable( - TwiceDifferentiable[torch.Tensor, nn.Module, torch.device] -): - def __init__( - self, - model: nn.Module, - loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], - *, - device: torch.device = torch.device( - "cuda" if torch.cuda.is_available() else "cpu" - ), - ): - r""" - :param model: A (differentiable) function. - :param loss: :param loss: A differentiable scalar loss $L(\hat{y}, y)$, - mapping a prediction and a target to a real value. - :param device: device to use for computations. Defaults to cuda if available. - """ - if model.training: - logger.warning( - "Passed model not in evaluation mode. This can create several issues in influence " - "computation, e.g. due to batch normalization. Please call model.eval() before " - "computing influences." - ) - self.model = model.to(device) - self.loss = loss - self.device = device - - @property - def parameters(self) -> List[torch.Tensor]: - """Returns all the model parameters that require differentiating""" - return [ - param for param in self.model.parameters() if param.requires_grad == True - ] - - @property - def num_params(self) -> int: - """ - Get number of parameters of model f. - :returns: Number of parameters as integer. - """ - return sum([np.prod(p.size()) for p in self.parameters]) - - def split_grad( - self, x: torch.Tensor, y: torch.Tensor, *, progress: bool = False - ) -> torch.Tensor: - """ - Calculates gradient of model parameters wrt each $x[i]$ and $y[i]$ and then - returns a array of size [N, P] with N number of points (length of x and y) and P - number of parameters of the model. - - :param x: An array [NxD] representing the features $x_i$. - :param y: An array [NxK] representing the predicted target values $y_i$. 
- :param progress: True, iff progress shall be printed. - :returns: An array [NxP] representing the gradients with respect to - all parameters of the model. - """ - x = as_tensor(x, warn=False).to(self.device).unsqueeze(1) - y = as_tensor(y, warn=False).to(self.device) - - grads = [] - for i in maybe_progress(range(len(x)), progress, desc="Split Gradient"): - grads.append( - flatten_all( - autograd.grad( - self.loss(torch.squeeze(self.model(x[i])), torch.squeeze(y[i])), - self.parameters, - ) - ).detach() - ) - - return torch.stack(grads, axis=0) - - def grad( - self, x: torch.Tensor, y: torch.Tensor, *, x_requires_grad: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Calculates gradient of model parameters wrt the model parameters. - - :param x: A matrix [NxD] representing the features $x_i$. - :param y: A matrix [NxK] representing the target values $y_i$. - :param x_requires_grad: If True, the input $x$ is marked as requiring - gradients. This is important for further differentiation on input - parameters. - :returns: A tuple where the first element is an array [P] with the - gradients of the model and second element is the input to the model - as a grad parameters. This can be used for further differentiation. - """ - x = as_tensor(x, warn=False).to(self.device).requires_grad_(x_requires_grad) - y = as_tensor(y, warn=False).to(self.device) - - loss_value = self.loss(torch.squeeze(self.model(x)), torch.squeeze(y)) - grad_f = torch.autograd.grad(loss_value, self.parameters, create_graph=True) - return flatten_all(grad_f), x - - def hessian( - self, x: torch.Tensor, y: torch.Tensor, *, progress: bool = False - ) -> torch.Tensor: - """Calculates the explicit hessian of model parameters given data ($x$ and $y$). - :param x: A matrix [NxD] representing the features $x_i$. - :param y: A matrix [NxK] representing the target values $y_i$. - :param progress: ``True`` to display progress. - :returns: the hessian of the model, i.e. the second derivative wrt. the model parameters. 
- """ - x = x.to(self.device) - y = y.to(self.device) - grad_xy, _ = self.grad(x, y) - return mvp( - grad_xy, - torch.eye(self.num_params, self.num_params, device=self.device), - self.parameters, - progress=progress, - ) - - -@dataclass -class LowRankProductRepresentation: - """ - Representation of a low rank product of the form $H = V D V^T$, where D is a diagonal matrix and - V is orthogonal - :param eigen_vals: diagonal of D - :param projections: the matrix V - """ - - eigen_vals: torch.Tensor - projections: torch.Tensor - - @property - def device(self) -> torch.device: - return ( - self.eigen_vals.device - if hasattr(self.eigen_vals, "device") - else torch.device("cpu") - ) - - def to(self, device: torch.device): - """ - Move the representing tensors to a device - """ - return LowRankProductRepresentation( - self.eigen_vals.to(device), self.projections.to(device) - ) - - def __post_init__(self): - if self.eigen_vals.device != self.projections.device: - raise ValueError("eigen_vals and projections must be on the same device.") - - +@InversionRegistry.register(TorchTwiceDifferentiable, InversionMethod.Arnoldi) def solve_arnoldi( model: TorchTwiceDifferentiable, training_data: DataLoader, b: torch.Tensor, - *, hessian_perturbation: float = 0.0, + *, rank_estimate: int = 10, krylov_dimension: Optional[int] = None, low_rank_representation: Optional[LowRankProductRepresentation] = None, @@ -515,7 +677,6 @@ def solve_arnoldi( max_iter: Optional[int] = None, eigen_computation_on_gpu: bool = False, ) -> InverseHvpResult: - """ Solves the linear system Hx = b, where H is the Hessian of the model's loss function and b is the given right-hand side vector. The Hessian is approximated using a low-rank representation. @@ -552,7 +713,6 @@ def solve_arnoldi( b_device = b.device if hasattr(b, "device") else torch.device("cpu") if low_rank_representation is None: - if b_device.type == "cuda" and not eigen_computation_on_gpu: raise ValueError( "Using 'eigen_computation_on_gpu=False' while 'b' is on a 'cuda' device is not supported. " @@ -598,169 +758,3 @@ def solve_arnoldi( "eigenvectors": low_rank_representation.projections, }, ) - - -def lanzcos_low_rank_hessian_approx( - hessian_vp: Callable[[torch.Tensor], torch.Tensor], - matrix_shape: Tuple[int, int], - hessian_perturbation: float = 0.0, - rank_estimate: int = 10, - krylov_dimension: Optional[int] = None, - tol: float = 1e-6, - max_iter: Optional[int] = None, - device: Optional[torch.device] = None, - eigen_computation_on_gpu: bool = False, - torch_dtype: torch.dtype = None, -) -> LowRankProductRepresentation: - """ - Calculates a low-rank approximation of the Hessian matrix of a scalar-valued function using the implicitly - restarted Lanczos algorithm. - - - :param hessian_vp: A function that takes a vector and returns the product of the Hessian of the loss function - :param matrix_shape: The shape of the matrix, represented by hessian vector product. - :param hessian_perturbation: Optional regularization parameter added to the Hessian-vector product - for numerical stability. - :param rank_estimate: The number of eigenvalues and corresponding eigenvectors to compute. - Represents the desired rank of the Hessian approximation. - :param krylov_dimension: The number of Krylov vectors to use for the Lanczos method. - If not provided, it defaults to $min(model.num_parameters, max(2*rank_estimate + 1, 20))$. - :param tol: The stopping criteria for the Lanczos algorithm, which stops when the difference - in the approximated eigenvalue is less than `tol`. 
Defaults to 1e-6. - :param max_iter: The maximum number of iterations for the Lanczos method. If not provided, it defaults to - $10*model.num_parameters$ - :param device: The device to use for executing the hessian vector product. - :param eigen_computation_on_gpu: If True, tries to execute the eigen pair approximation on the provided - device via cupy implementation. - Make sure, that either your model is small enough or you use a - small rank_estimate to fit your device's memory. - If False, the eigen pair approximation is executed on the CPU by scipy wrapper to - ARPACK. - :param torch_dtype: if not provided, current torch default dtype is used for conversion to torch - :return: A `LowRankProductRepresentation` instance that contains the top (up until rank_estimate) eigenvalues - and corresponding eigenvectors of the Hessian. - """ - - torch_dtype = torch.get_default_dtype() if torch_dtype is None else torch_dtype - - if eigen_computation_on_gpu: - try: - import cupy as cp - from cupyx.scipy.sparse.linalg import LinearOperator, eigsh - from torch.utils.dlpack import from_dlpack, to_dlpack - except ImportError as e: - raise ImportError( - f"Try to install missing dependencies or set eigen_computation_on_gpu to False: {e}" - ) - - if device is None: - raise ValueError( - "Without setting an explicit device, cupy is not supported" - ) - - def to_torch_conversion_function(x): - return from_dlpack(x.toDlpack()).to(torch_dtype) - - def mv(x): - x = to_torch_conversion_function(x) - y = hessian_vp(x) + hessian_perturbation * x - return cp.from_dlpack(to_dlpack(y)) - - else: - from scipy.sparse.linalg import LinearOperator, eigsh - - def mv(x): - x_torch = torch.as_tensor(x, device=device, dtype=torch_dtype) - y: NDArray = ( - (hessian_vp(x_torch) + hessian_perturbation * x_torch) - .detach() - .cpu() - .numpy() - ) - return y - - to_torch_conversion_function = partial(torch.as_tensor, dtype=torch_dtype) - - try: - - eigen_vals, eigen_vecs = eigsh( - LinearOperator(matrix_shape, matvec=mv), - k=rank_estimate, - maxiter=max_iter, - tol=tol, - ncv=krylov_dimension, - return_eigenvectors=True, - ) - - except ArpackNoConvergence as e: - logger.warning( - f"ARPACK did not converge for parameters {max_iter=}, {tol=}, {krylov_dimension=}, " - f"{rank_estimate=}. \n Returning the best approximation found so far. Use those with care or " - f"modify parameters.\n Original error: {e}" - ) - - eigen_vals, eigen_vecs = e.eigenvalues, e.eigenvectors - - eigen_vals = to_torch_conversion_function(eigen_vals) - eigen_vecs = to_torch_conversion_function(eigen_vecs) - - return LowRankProductRepresentation(eigen_vals, eigen_vecs) - - -def model_hessian_low_rank( - model: TorchTwiceDifferentiable, - training_data: DataLoader, - hessian_perturbation: float = 0.0, - rank_estimate: int = 10, - krylov_dimension: Optional[int] = None, - tol: float = 1e-6, - max_iter: Optional[int] = None, - eigen_computation_on_gpu: bool = False, -) -> LowRankProductRepresentation: - """ - Calculates a low-rank approximation of the Hessian matrix of the model's loss function using the implicitly - restarted Lanczos algorithm. - - :param model: A PyTorch model instance that is twice differentiable, wrapped into :class:`TorchTwiceDifferential`. - The Hessian will be calculated with respect to this model's parameters. - :param training_data: A DataLoader instance that provides the model's training data. - Used in calculating the Hessian-vector products. 
diff --git a/src/pydvl/influence/frameworks/util.py b/src/pydvl/influence/torch/util.py
similarity index 92%
rename from src/pydvl/influence/frameworks/util.py
rename to src/pydvl/influence/torch/util.py
index 540df53bc..b758458a3 100644
--- a/src/pydvl/influence/frameworks/util.py
+++ b/src/pydvl/influence/torch/util.py
@@ -1,6 +1,6 @@
 import logging
 import math
-from typing import Dict, Iterable, Tuple, TypeVar
+from typing import Any, Dict, Iterable, Tuple, TypeVar

 import torch

@@ -142,3 +142,14 @@ def align_structure(
         raise ValueError(f"'target' is of type {type(target)} which is not supported.")

     return tangent_dict
+
+
+def as_tensor(a: Any, warn=True, **kwargs) -> torch.Tensor:
+    """Converts an array into a torch tensor
+
+    :param a: array to convert to tensor
+    :param warn: if True, warns that a will be converted
+    """
+    if warn and not isinstance(a, torch.Tensor):
+        logger.warning("Converting tensor to type torch.Tensor.")
+    return torch.as_tensor(a, **kwargs)
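Aside, not part of the patch: a quick sketch of the relocated `as_tensor` helper added above; the array values are arbitrary.

import numpy as np
from pydvl.influence.torch.util import as_tensor

t = as_tensor(np.arange(4.0))  # not a torch.Tensor yet: logs a warning and converts
t2 = as_tensor(t)              # already a tensor: returned without a warning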
TypeVar("DataLoaderType", bound=Iterable) + + +@dataclass(frozen=True) +class InverseHvpResult(Generic[TensorType]): + x: TensorType + info: Dict[str, Any] + + def __iter__(self): + return iter((self.x, self.info)) + + +class TwiceDifferentiable(ABC, Generic[TensorType]): + """ + Wraps a differentiable model and loss and provides methods to compute gradients and + second derivative of the loss wrt. the model parameters + """ + + @classmethod + @abstractmethod + def tensor_type(cls): + pass + + @property + @abstractmethod + def num_params(self) -> int: + """Returns the number of parameters of the model""" + pass + + @property + @abstractmethod + def parameters(self) -> List[TensorType]: + """Returns all the model parameters that require differentiation""" + pass + + def grad( + self, x: TensorType, y: TensorType, create_graph: bool = False + ) -> TensorType: + """ + Calculates gradient of model parameters wrt. the model parameters. + + :param x: A matrix representing the features $x_i$. + :param y: A matrix representing the target values $y_i$. + gradients. This is important for further differentiation on input + parameters. + :param create_graph: + :return: A tuple where: the first element is an array with the + gradients of the model, and the second element is the input to the + model as a grad parameters. This can be used for further + differentiation. + """ + pass + + def hessian(self, x: TensorType, y: TensorType) -> TensorType: + """Calculates the full Hessian of $L(f(x),y)$ with respect to the model + parameters given data ($x$ and $y$). + + :param x: An array representing the features $x_i$. + :param y: An array representing the target values $y_i$. + :return: The hessian of the model, i.e. the second derivative wrt. the + model parameters. + """ + pass + + @staticmethod + @abstractmethod + def mvp( + grad_xy: TensorType, + v: TensorType, + backprop_on: TensorType, + *, + progress: bool = False, + ) -> TensorType: + """ + Calculates second order derivative of the model along directions v. + This second order derivative can be selected through the backprop_on argument. + + :param grad_xy: an array [P] holding the gradients of the model + parameters wrt input $x$ and labels $y$, where P is the number of + parameters of the model. It is typically obtained through + self.grad. + :param v: An array ([DxP] or even one dimensional [D]) which + multiplies the matrix, where D is the number of directions. + :param progress: True, iff progress shall be printed. + :param backprop_on: tensor used in the second backpropagation (the first + one is along $x$ and $y$ as defined via grad_xy). + :returns: A matrix representing the implicit matrix vector product + of the model along the given directions. Output shape is [DxP] if + backprop_on is None, otherwise [DxM], with M the number of elements + of backprop_on. + """ + + +class TensorUtilities(Generic[TensorType, ModelType], ABC): + twice_differentiable_type: Type[TwiceDifferentiable] + registry: Dict[Type[TwiceDifferentiable], Type["TensorUtilities"]] = {} + + def __init_subclass__(cls, **kwargs): + """ + Automatically registers non-abstract subclasses in the registry. + + Checks if `twice_differentiable_type` is defined in the subclass and + is of correct type. Raises `TypeError` if either attribute is missing or incorrect. + + :param kwargs: Additional keyword arguments. + :raise TypeError: If the subclass does not define `twice_differentiable_type`, + or if it is not of correct type. 
+ """ + if not hasattr(cls, "twice_differentiable_type") or not isinstance( + cls.twice_differentiable_type, type + ): + raise TypeError( + f"'twice_differentiable_type' must be a Type[TwiceDifferentiable]" + ) + + cls.registry[cls.twice_differentiable_type] = cls + + super().__init_subclass__(**kwargs) + + @staticmethod + @abstractmethod + def einsum(equation, *operands) -> TensorType: + """Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation + based on the Einstein summation convention. + """ + + @staticmethod + @abstractmethod + def cat(a: Sequence[TensorType], **kwargs) -> TensorType: + """Concatenates a sequence of tensors into a single torch tensor""" + + @staticmethod + @abstractmethod + def stack(a: Sequence[TensorType], **kwargs) -> TensorType: + """Stacks a sequence of tensors into a single torch tensor""" + + @staticmethod + @abstractmethod + def unsqueeze(x: TensorType, dim: int) -> TensorType: + """Add a singleton dimension at a specified position in a tensor""" + + @staticmethod + @abstractmethod + def get_element(x: TensorType, idx: int) -> TensorType: + """Get the tensor element x[i] from the first non-singular dimension""" + + @staticmethod + @abstractmethod + def slice(x: TensorType, start: int, stop: int, axis: int = 0) -> TensorType: + """Slice a tensor in the provided axis""" + + @staticmethod + @abstractmethod + def shape(x: TensorType) -> Tuple[int, ...]: + """Slice a tensor in the provided axis""" + + @staticmethod + @abstractmethod + def reshape(x: TensorType, shape: Tuple[int, ...]) -> TensorType: + """Reshape a tensor to the provided shape""" + + @staticmethod + @abstractmethod + def cat_gen( + a: Generator[TensorType, None, None], + resulting_shape: Tuple[int, ...], + model: ModelType, + ) -> TensorType: + """Concatenate tensors from a generator. Resulting tensor is of shape resulting_shape + and compatible to model + """ + + @classmethod + def from_twice_differentiable( + cls, + twice_diff: TwiceDifferentiable, + ) -> Type["TensorUtilities"]: + """ + Factory method to create an instance of `TensorUtilities` from an instance of `TwiceDifferentiable`. + + :param twice_diff: An instance of `TwiceDifferentiable` + for which a corresponding `TensorUtilities` object is required. + :return: An instance of `TensorUtilities` corresponding to the provided `TwiceDifferentiable` object. + :raises KeyError: If there's no registered `TensorUtilities` for the provided `TwiceDifferentiable` type. 
+ """ + tu = cls.registry.get(type(twice_diff), None) + + if tu is None: + raise KeyError( + f"No registered TensorUtilities for the type {type(twice_diff).__name__}" + ) + + return tu diff --git a/tests/influence/test_influences.py b/tests/influence/test_influences.py index 5e42f9547..ca953f43e 100644 --- a/tests/influence/test_influences.py +++ b/tests/influence/test_influences.py @@ -12,9 +12,8 @@ from torch.optim import LBFGS from torch.utils.data import DataLoader, TensorDataset -from pydvl.influence import TorchTwiceDifferentiable, compute_influences -from pydvl.influence.frameworks.torch_differentiable import model_hessian_low_rank -from pydvl.influence.general import InfluenceType, InversionMethod +from pydvl.influence import InfluenceType, InversionMethod, compute_influences +from pydvl.influence.torch import TorchTwiceDifferentiable, model_hessian_low_rank from .conftest import ( add_noise_to_linear_model, @@ -138,7 +137,7 @@ def test_influence_linear_model( ) influence_values = compute_influences( - TorchTwiceDifferentiable(linear_layer, loss, device=torch.device("cpu")), + TorchTwiceDifferentiable(linear_layer, loss), training_data=train_data_loader, test_data=test_data_loader, input_data=input_data, @@ -336,7 +335,7 @@ def test_influences_nn( model = module_factory() model.eval() - model = TorchTwiceDifferentiable(model, loss, device=torch.device("cpu")) + model = TorchTwiceDifferentiable(model, loss) direct_influence = compute_influences( model, @@ -437,7 +436,7 @@ def test_influences_arnoldi( ) nn_architecture = nn_architecture.eval() - model = TorchTwiceDifferentiable(nn_architecture, loss, device=torch.device("cpu")) + model = TorchTwiceDifferentiable(nn_architecture, loss) direct_influence = compute_influences( model, diff --git a/tests/influence/test_torch_differentiable.py b/tests/influence/test_torch_differentiable.py index a949eda40..621288d6f 100644 --- a/tests/influence/test_torch_differentiable.py +++ b/tests/influence/test_torch_differentiable.py @@ -23,9 +23,8 @@ from torch import nn from torch.utils.data import DataLoader -from pydvl.influence.frameworks.torch_differentiable import ( +from pydvl.influence.torch import ( TorchTwiceDifferentiable, - mvp, solve_batch_cg, solve_linear, solve_lissa, @@ -48,7 +47,7 @@ def linear_mvp_model(A, b): model.weight.data = torch.as_tensor(A) model.bias.data = torch.as_tensor(b) loss = F.mse_loss - return TorchTwiceDifferentiable(model=model, loss=loss, device=torch.device("cpu")) + return TorchTwiceDifferentiable(model=model, loss=loss) @pytest.mark.torch @@ -74,7 +73,13 @@ def test_linear_grad( mvp_model = linear_mvp_model(A, b) train_grads_analytical = linear_derivative_analytical((A, b), train_x, train_y) - train_grads_autograd = mvp_model.split_grad(train_x, train_y) + train_x = torch.as_tensor(train_x).unsqueeze(1) + train_y = torch.as_tensor(train_y) + + train_grads_autograd = torch.stack( + [mvp_model.grad(inpt, target) for inpt, target in zip(train_x, train_y)] + ) + assert np.allclose(train_grads_analytical, train_grads_autograd, rtol=1e-5) @@ -100,10 +105,12 @@ def test_linear_hessian( mvp_model = linear_mvp_model(A, b) test_hessian_analytical = linear_hessian_analytical((A, b), train_x) - grad_xy, _ = mvp_model.grad(train_x, train_y) - estimated_hessian = mvp( + grad_xy = mvp_model.grad( + torch.as_tensor(train_x), torch.as_tensor(train_y), create_graph=True + ) + estimated_hessian = mvp_model.mvp( grad_xy, - np.eye((input_dimension + 1) * output_dimension), + torch.as_tensor(np.eye((input_dimension + 1) * 
diff --git a/tests/influence/test_torch_differentiable.py b/tests/influence/test_torch_differentiable.py
index a949eda40..621288d6f 100644
--- a/tests/influence/test_torch_differentiable.py
+++ b/tests/influence/test_torch_differentiable.py
@@ -23,9 +23,8 @@
 from torch import nn
 from torch.utils.data import DataLoader

-from pydvl.influence.frameworks.torch_differentiable import (
+from pydvl.influence.torch import (
     TorchTwiceDifferentiable,
-    mvp,
     solve_batch_cg,
     solve_linear,
     solve_lissa,
@@ -48,7 +47,7 @@ def linear_mvp_model(A, b):
     model.weight.data = torch.as_tensor(A)
     model.bias.data = torch.as_tensor(b)
     loss = F.mse_loss
-    return TorchTwiceDifferentiable(model=model, loss=loss, device=torch.device("cpu"))
+    return TorchTwiceDifferentiable(model=model, loss=loss)


 @pytest.mark.torch
@@ -74,7 +73,13 @@ def test_linear_grad(
     mvp_model = linear_mvp_model(A, b)

     train_grads_analytical = linear_derivative_analytical((A, b), train_x, train_y)
-    train_grads_autograd = mvp_model.split_grad(train_x, train_y)
+    train_x = torch.as_tensor(train_x).unsqueeze(1)
+    train_y = torch.as_tensor(train_y)
+
+    train_grads_autograd = torch.stack(
+        [mvp_model.grad(inpt, target) for inpt, target in zip(train_x, train_y)]
+    )
+
     assert np.allclose(train_grads_analytical, train_grads_autograd, rtol=1e-5)


@@ -100,10 +105,12 @@ def test_linear_hessian(
     mvp_model = linear_mvp_model(A, b)

     test_hessian_analytical = linear_hessian_analytical((A, b), train_x)
-    grad_xy, _ = mvp_model.grad(train_x, train_y)
-    estimated_hessian = mvp(
+    grad_xy = mvp_model.grad(
+        torch.as_tensor(train_x), torch.as_tensor(train_y), create_graph=True
+    )
+    estimated_hessian = mvp_model.mvp(
         grad_xy,
-        np.eye((input_dimension + 1) * output_dimension),
+        torch.as_tensor(np.eye((input_dimension + 1) * output_dimension)),
         mvp_model.parameters,
     )
     assert np.allclose(test_hessian_analytical, estimated_hessian, rtol=1e-5)
@@ -138,9 +145,11 @@ def test_linear_mixed_derivative(
     )
     model_mvp = []
     for i in range(len(train_x)):
-        grad_xy, tensor_x = mvp_model.grad(train_x[i], train_y[i], x_requires_grad=True)
+        tensor_x = torch.as_tensor(train_x[i]).requires_grad_(True)
+        tensor_y = torch.as_tensor(train_y[i])
+        grad_xy = mvp_model.grad(tensor_x, tensor_y, create_graph=True)
         model_mvp.append(
-            mvp(
+            mvp_model.mvp(
                 grad_xy,
                 np.eye((input_dimension + 1) * output_dimension),
                 backprop_on=tensor_x,
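Aside, not part of the patch: the `grad`/`mvp` pattern exercised by the updated Hessian test above, written out as a standalone sketch with a toy linear model.

import torch
from torch.nn.functional import mse_loss

from pydvl.influence.torch import TorchTwiceDifferentiable

model = TorchTwiceDifferentiable(torch.nn.Linear(2, 1), mse_loss)
x, y = torch.rand(10, 2), torch.rand(10, 1)

# Gradient of the batch loss wrt. the parameters, kept on the autograd graph.
grad_xy = model.grad(x, y, create_graph=True)
# Multiplying by the identity recovers the full Hessian, shape [num_params, num_params].
hessian = model.mvp(grad_xy, torch.eye(model.num_params), model.parameters)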
diff --git a/tests/influence/test_util.py b/tests/influence/test_util.py
index 7828c470a..b381426a9 100644
--- a/tests/influence/test_util.py
+++ b/tests/influence/test_util.py
@@ -9,15 +9,9 @@
 from torch.nn.functional import mse_loss
 from torch.utils.data import DataLoader, TensorDataset

-from pydvl.influence.frameworks.functional import (
-    batch_loss_function,
-    get_hvp_function,
-    hvp,
-)
-from pydvl.influence.frameworks.torch_differentiable import (
-    lanzcos_low_rank_hessian_approx,
-)
-from pydvl.influence.frameworks.util import (
+from pydvl.influence.torch.functional import batch_loss_function, get_hvp_function, hvp
+from pydvl.influence.torch.torch_differentiable import lanzcos_low_rank_hessian_approx
+from pydvl.influence.torch.util import (
     TorchTensorContainerType,
     align_structure,
     flatten_tensors_to_vector,
@@ -100,11 +94,13 @@ def model_data(request):
     x = torch.rand(train_size, dimension[-1])
     y = torch.rand(train_size, dimension[0])
     torch_model = linear_torch_model_from_numpy(A, b)
-    vec = {
-        name: torch.rand(*p.shape)
-        for name, p in torch_model.named_parameters()
-        if p.requires_grad
-    }
+    vec = flatten_tensors_to_vector(
+        tuple(
+            torch.rand(*p.shape)
+            for name, p in torch_model.named_parameters()
+            if p.requires_grad
+        )
+    )
     H_analytical = linear_hessian_analytical((A, b), x.numpy())
     H_analytical = torch.as_tensor(H_analytical)
     return torch_model, x, y, vec, H_analytical.to(torch.float32)
@@ -118,7 +114,6 @@ def model_data(request):
 )
 def test_hvp(model_data, tol: float):
     torch_model, x, y, vec, H_analytical = model_data
-    vec_flat = flatten_tensors_to_vector(vec.values())

     params = dict(torch_model.named_parameters())

@@ -127,7 +122,7 @@ def test_hvp(model_data, tol: float):
     Hvp_autograd = hvp(f, params, align_structure(params, vec))

     flat_Hvp_autograd = flatten_tensors_to_vector(Hvp_autograd.values())
-    assert torch.allclose(flat_Hvp_autograd, H_analytical @ vec_flat, rtol=tol)
+    assert torch.allclose(flat_Hvp_autograd, H_analytical @ vec, rtol=tol)


 @pytest.mark.torch
@@ -146,10 +141,8 @@ def test_get_hvp_function(model_data, tol: float, use_avg: bool, batch_size: int
     Hvp_autograd = get_hvp_function(
         torch_model, mse_loss, data_loader, use_hessian_avg=use_avg
     )(vec)
-    vec_flat = flatten_tensors_to_vector(vec.values())
-    flat_Hvp_autograd = flatten_tensors_to_vector(Hvp_autograd.values())

-    assert torch.allclose(flat_Hvp_autograd, H_analytical @ vec_flat, rtol=tol)
+    assert torch.allclose(Hvp_autograd, H_analytical @ vec, rtol=tol)


 @pytest.mark.torch
@@ -163,8 +156,6 @@ def test_lanzcos_low_rank_hessian_approx(
 ):
     _, _, _, vec, H_analytical = model_data

-    vec_flat = flatten_tensors_to_vector(vec.values())
-
     reg_H_analytical = H_analytical + regularization * torch.eye(H_analytical.shape[0])
     low_rank_approx = lanzcos_low_rank_hessian_approx(
         lambda z: reg_H_analytical @ z,
@@ -173,9 +164,9 @@ def test_lanzcos_low_rank_hessian_approx(
     )
     approx_result = low_rank_approx.projections @ (
         torch.diag_embed(low_rank_approx.eigen_vals)
-        @ (low_rank_approx.projections.t() @ vec_flat.t())
+        @ (low_rank_approx.projections.t() @ vec.t())
     )
-    assert torch.allclose(approx_result, reg_H_analytical @ vec_flat, rtol=1e-1)
+    assert torch.allclose(approx_result, reg_H_analytical @ vec, rtol=1e-1)


 @pytest.mark.torch