Skip to content

Commit

Permalink
Merge pull request #572 from aai-institute/fix/571-missing-move-preco…
Browse files Browse the repository at this point in the history
…nditioner-cg

Overwrite `to` method of `CgInfluence`, add `to` method to precondito…
  • Loading branch information
schroedk authored May 3, 2024
2 parents efa56e3 + ae41cbf commit 0fc0553
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
12 changes: 7 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@

## Unreleased

### Fixed

- Fixed missing move of tensors to model device in `EkfacInfluence`
implementation [PR #570](https://github.com/aai-institute/pyDVL/pull/570)

### Added

- Add a device fixture for `pytest`, which depending on the availability and
user input (`pytest --with-cuda`) resolves to cuda device
[PR #574](https://github.com/aai-institute/pyDVL/pull/574)

### Fixed

- Fixed missing move of tensors to model device in `EkfacInfluence`
implementation [PR #570](https://github.com/aai-institute/pyDVL/pull/570)
- Missing move to device of `preconditioner` in `CgInfluence` implementation
[PR #572](https://github.com/aai-institute/pyDVL/pull/572)

## 0.9.1 - Bug fixes, logging improvement

### Fixed
Expand Down
9 changes: 8 additions & 1 deletion src/pydvl/influence/torch/influence_function_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,9 @@ def mat_mat(x: torch.Tensor):
R = (rhs - mat_mat(X)).T
Z = R if self.pre_conditioner is None else self.pre_conditioner.solve(R)
P, _, _ = torch.linalg.svd(Z, full_matrices=False)
active_indices = torch.as_tensor(list(range(X.shape[-1])), dtype=torch.long)
active_indices = torch.as_tensor(
list(range(X.shape[-1])), dtype=torch.long, device=self.model_device
)

maxiter = self.maxiter if self.maxiter is not None else len(rhs) * 10
y_norm = torch.linalg.norm(rhs, dim=1)
Expand Down Expand Up @@ -758,6 +760,11 @@ def mat_mat(x: torch.Tensor):

return X.T

def to(self, device: torch.device):
if self.pre_conditioner is not None:
self.pre_conditioner = self.pre_conditioner.to(device)
return super().to(device)


class LissaInfluence(TorchInfluenceFunctionModel):
r"""
Expand Down
17 changes: 17 additions & 0 deletions src/pydvl/influence/torch/pre_conditioner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable, Optional

Expand Down Expand Up @@ -70,6 +72,11 @@ def solve(self, rhs: torch.Tensor):
def _solve(self, rhs: torch.Tensor):
pass

@abstractmethod
def to(self, device: torch.device) -> PreConditioner:
"""Implement this to move the (potentially fitted) preconditioner to a
specific device"""


class JacobiPreConditioner(PreConditioner):
r"""
Expand Down Expand Up @@ -141,6 +148,11 @@ def _solve(self, rhs: torch.Tensor):

return rhs * inv_diag.unsqueeze(-1)

def to(self, device: torch.device) -> JacobiPreConditioner:
if self._diag is not None:
self._diag = self._diag.to(device)
return self


class NystroemPreConditioner(PreConditioner):
r"""
Expand Down Expand Up @@ -233,3 +245,8 @@ def _solve(self, rhs: torch.Tensor):
result = result.squeeze()

return result

def to(self, device: torch.device) -> NystroemPreConditioner:
if self._low_rank_approx is not None:
self._low_rank_approx = self._low_rank_approx.to(device)
return self

0 comments on commit 0fc0553

Please sign in to comment.