Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring/influence #394

Merged
merged 18 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
[PR #371](https://github.com/appliedAI-Initiative/pyDVL/pull/371)
- Major changes to IF interface and functionality
[PR #278](https://github.com/appliedAI-Initiative/pyDVL/pull/278)
[PR #394](https://github.com/appliedAI-Initiative/pyDVL/pull/394)
- **New Method**: Implements solving the Hessian equation via spectral low-rank approximation
[PR #365](https://github.com/appliedAI-Initiative/pyDVL/pull/365)
- **Breaking Changes**:
Expand Down
2 changes: 1 addition & 1 deletion notebooks/support/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.utils.data import DataLoader
from torchvision.models import ResNet18_Weights, resnet18

from pydvl.influence.frameworks import as_tensor
from pydvl.influence.torch import as_tensor
from pydvl.utils import maybe_progress

from .types import Losses
Expand Down
3 changes: 1 addition & 2 deletions src/pydvl/influence/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,5 @@
probably change.

"""
from .frameworks import TorchTwiceDifferentiable, TwiceDifferentiable
from .general import compute_influence_factors, compute_influences
from .general import InfluenceType, compute_influence_factors, compute_influences
from .inversion import InversionMethod
57 changes: 0 additions & 57 deletions src/pydvl/influence/frameworks/__init__.py

This file was deleted.

90 changes: 0 additions & 90 deletions src/pydvl/influence/frameworks/twice_differentiable.py

This file was deleted.

113 changes: 58 additions & 55 deletions src/pydvl/influence/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,11 @@
"""
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Type

from ..utils import maybe_progress
from .frameworks import (
DataLoaderType,
TensorType,
TwiceDifferentiable,
einsum,
mvp,
transpose_tensor,
zero_tensor,
)
from .inversion import InverseHvpResult, InversionMethod, solve_hvp
from .inversion import DataLoaderType, InverseHvpResult, InversionMethod, solve_hvp
from .twice_differentiable import TensorType, TensorUtilities, TwiceDifferentiable

__all__ = ["compute_influences", "InfluenceType", "compute_influence_factors"]

Expand Down Expand Up @@ -56,32 +48,37 @@ def compute_influence_factors(
:param model: A model wrapped in the TwiceDifferentiable interface.
:param training_data: A DataLoader containing the training data.
:param test_data: A DataLoader containing the test data.
:param inversion_func: function to use to invert the product of hvp (hessian
vector product) and the gradient of the loss (s_test in the paper).
:param inversion_method: name of method for computing inverse hessian vector
products.
:param hessian_perturbation: regularization of the hessian
:param progress: If True, display progress bars.
:returns: An array of size (N, D) containing the influence factors for each
dimension (D) and test sample (N).
"""
test_grads = zero_tensor(
shape=(len(test_data.dataset), model.num_params),
dtype=test_data.dataset[0][0].dtype,
device=model.device,
tensor_util: Type[TensorUtilities] = TensorUtilities.from_twice_differentiable(
model
)
for batch_idx, (x_test, y_test) in enumerate(
maybe_progress(test_data, progress, desc="Batch Test Gradients")
stack = tensor_util.stack
unsqueeze = tensor_util.unsqueeze
cat = tensor_util.cat

test_grads = []
schroedk marked this conversation as resolved.
Show resolved Hide resolved
for x_test, y_test in maybe_progress(
test_data, progress, desc="Batch Test Gradients"
):
idx = batch_idx * test_data.batch_size
test_grads[idx : idx + test_data.batch_size] = model.split_grad(
x_test, y_test, progress=False
test_grad = stack(
[
model.grad(inpt, target)
for inpt, target in zip(unsqueeze(x_test, 1), y_test)
]
)
test_grads.append(test_grad)
return solve_hvp(
inversion_method,
model,
training_data,
test_grads,
cat(test_grads),
hessian_perturbation=hessian_perturbation,
progress=progress,
**kwargs,
)

Expand Down Expand Up @@ -110,19 +107,25 @@ def compute_influences_up(
:returns: An array of size [NxM], where N is the number of influence factors
and M is the number of input points.
"""
grads = zero_tensor(
shape=(len(input_data.dataset), model.num_params),
dtype=input_data.dataset[0][0].dtype,
device=model.device,

tensor_util: Type[TensorUtilities] = TensorUtilities.from_twice_differentiable(
model
)
for batch_idx, (x, y) in enumerate(
maybe_progress(input_data, progress, desc="Batch Split Input Gradients")
stack = tensor_util.stack
unsqueeze = tensor_util.unsqueeze
cat = tensor_util.cat
einsum = tensor_util.einsum

train_grads = []
for x, y in maybe_progress(
input_data, progress, desc="Batch Split Input Gradients"
):
idx = batch_idx * input_data.batch_size
grads[idx : idx + input_data.batch_size] = model.split_grad(
x, y, progress=False
train_grad = stack(
[model.grad(inpt, target) for inpt, target in zip(unsqueeze(x, 1), y)]
)
return einsum("ta,va->tv", influence_factors, grads)
train_grads.append(train_grad)

return einsum("ta,va->tv", influence_factors, cat(train_grads)) # type: ignore # ToDO fix typing


def compute_influences_pert(
Expand All @@ -149,34 +152,35 @@ def compute_influences_pert(
:returns: An array of size [NxMxP], where N is the number of influence factors, M
the number of input data, and P the number of features.
"""
input_x = input_data.dataset[0][0]
all_pert_influences = zero_tensor(
shape=(len(input_data.dataset), len(influence_factors), *input_x.shape),
dtype=input_x.dtype,
device=model.device,

tensor_util: Type[TensorUtilities] = TensorUtilities.from_twice_differentiable(
model
)
for batch_idx, (x, y) in enumerate(
maybe_progress(
input_data,
progress,
desc="Batch Influence Perturbation",
)
stack = tensor_util.stack

all_pert_influences = []
for x, y in maybe_progress(
input_data,
progress,
desc="Batch Influence Perturbation",
):
for i in range(len(x)):
grad_xy, tensor_x = model.grad(x[i : i + 1], y[i], x_requires_grad=True)
perturbation_influences = mvp(
tensor_x = x[i : i + 1]
tensor_x = tensor_x.requires_grad_(True)
grad_xy = model.grad(tensor_x, y[i], create_graph=True)
perturbation_influences = model.mvp(
grad_xy,
influence_factors,
backprop_on=tensor_x,
)
all_pert_influences[
batch_idx * input_data.batch_size + i
] = perturbation_influences.reshape((-1, *x[i].shape))
all_pert_influences.append(
perturbation_influences.reshape((-1, *x[i].shape))
)

return transpose_tensor(all_pert_influences, 0, 1)
return stack(all_pert_influences, axis=1) # type: ignore # ToDO fix typing


influence_type_registry = {
influence_type_registry: Dict[InfluenceType, Callable[..., TensorType]] = {
InfluenceType.Up: compute_influences_up,
InfluenceType.Perturbation: compute_influences_pert,
}
Expand All @@ -193,7 +197,7 @@ def compute_influences(
hessian_regularization: float = 0.0,
progress: bool = False,
**kwargs: Any,
) -> TensorType:
) -> TensorType: # type: ignore # ToDO fix typing
r"""
Calculates the influence of the input_data point j on the test points i.
First it calculates the influence factors of all test points with respect
Expand Down Expand Up @@ -234,9 +238,8 @@ def compute_influences(
progress=progress,
**kwargs,
)
compute_influence_type = influence_type_registry[influence_type]

return compute_influence_type(
return influence_type_registry[influence_type](
differentiable_model,
input_data,
influence_factors,
Expand Down
Loading