Add parameter precompute_grad to CgInfluence init, adapt documentation #498

Merged · 4 commits · Feb 19, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@
### Fixed

- Bug in `LissaInfluence`, when not using CPU device [PR #495](https://github.com/aai-institute/pyDVL/pull/495)
- Memory issue with `CgInfluence` and `ArnoldiInfluence` [PR #498](https://github.com/aai-institute/pyDVL/pull/498)

## 0.8.1 - 🆕 🏗 New method and notebook, Games with exact Shapley values, bug fixes and cleanup

13 changes: 10 additions & 3 deletions src/pydvl/influence/torch/functional.py
@@ -265,8 +265,8 @@ def create_hvp_function(
is the model's input and the second element is the target output.
precompute_grad: If True, the full data gradient is precomputed and kept
in memory, which can speed up the hessian vector product computation.
- Set this to False, if you can't afford to keep an additional
- parameter-sized vector in memory.
+ Set this to False if you can't afford to keep the full computation graph
+ in memory.
use_average: If True, the returned function uses batch-wise computation via
[batch_loss_function][pydvl.influence.torch.functional.batch_loss_function]
and averages the results.
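For context (not part of this diff): a minimal sketch of what the two settings look like from user code. The toy model, loss, and dataloader are placeholders, and the flat-vector calling convention of the returned function is inferred from its use in `_solve_hvp` further down; only the keyword arguments of `create_hvp_function` are taken from this PR.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence.torch.functional import create_hvp_function

# Toy placeholders; any model, loss and dataloader of (input, target) batches work.
model = torch.nn.Linear(10, 1)
loss = torch.nn.functional.mse_loss
data = DataLoader(TensorDataset(torch.randn(32, 10), torch.randn(32, 1)), batch_size=8)

# Faster HVPs, at the cost of keeping the full data gradient (and its graph) in memory:
hvp_fast = create_hvp_function(model, loss, data, precompute_grad=True)

# Smaller memory footprint, recomputing the gradient on every call:
hvp_lean = create_hvp_function(model, loss, data, precompute_grad=False)

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
v = torch.randn(n_params)  # any parameter-sized vector
print(hvp_lean(v).shape)   # expect a parameter-sized vector back
```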
@@ -772,6 +772,7 @@ def model_hessian_low_rank(
tol: float = 1e-6,
max_iter: Optional[int] = None,
eigen_computation_on_gpu: bool = False,
precompute_grad: bool = False,
) -> LowRankProductRepresentation:
r"""
Calculates a low-rank approximation of the Hessian matrix of the model's
@@ -807,14 +808,20 @@ def model_hessian_low_rank(
small rank_estimate to fit your device's memory.
If False, the eigen pair approximation is executed on the CPU by the
scipy wrapper to ARPACK.
precompute_grad: If True, the full data gradient is precomputed and kept
in memory, which can speed up the hessian vector product computation.
Set this to False if you can't afford to keep the full computation graph
in memory.

Returns:
[LowRankProductRepresentation]
[pydvl.influence.torch.functional.LowRankProductRepresentation]
instance that contains the top (up until rank_estimate) eigenvalues
and corresponding eigenvectors of the Hessian.
"""
- raw_hvp = create_hvp_function(model, loss, training_data, use_average=True)
+ raw_hvp = create_hvp_function(
+     model, loss, training_data, use_average=True, precompute_grad=precompute_grad
+ )
n_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
device = next(model.parameters()).device
return lanzcos_low_rank_hessian_approx(
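The flag simply threads through to `model_hessian_low_rank`. A hedged sketch, reusing the placeholders above; passing the dataloader positionally and the `rank_estimate` keyword are assumptions based on the docstring, not shown in this diff.

```python
from pydvl.influence.torch.functional import model_hessian_low_rank

# Low-rank Hessian approximation with the memory-lean HVP variant:
low_rank = model_hessian_low_rank(
    model,
    loss,
    data,
    rank_estimate=5,        # assumed keyword, mentioned in the docstring
    precompute_grad=False,  # avoid holding the full computation graph
)
```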
20 changes: 19 additions & 1 deletion src/pydvl/influence/torch/influence_function_model.py
@@ -439,6 +439,10 @@ class CgInfluence(TorchInfluenceFunctionModel):
atol: Absolute tolerance of result.
maxiter: Maximum number of iterations. If None, defaults to 10*len(b).
progress: If True, display progress bars.
precompute_grad: If True, the full data gradient is precomputed and kept
in memory, which can speed up the hessian vector product computation.
Set this to False if you can't afford to keep the full computation graph
in memory.

"""

@@ -452,8 +456,10 @@ def __init__(
atol: float = 1e-7,
maxiter: Optional[int] = None,
progress: bool = False,
precompute_grad: bool = False,
):
super().__init__(model, loss)
self.precompute_grad = precompute_grad
self.progress = progress
self.maxiter = maxiter
self.atol = atol
@@ -525,7 +531,12 @@ def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
if len(self.train_dataloader) == 0:
raise ValueError("Training dataloader must not be empty.")

- hvp = create_hvp_function(self.model, self.loss, self.train_dataloader)
+ hvp = create_hvp_function(
+     self.model,
+     self.loss,
+     self.train_dataloader,
+     precompute_grad=self.precompute_grad,
+ )

def reg_hvp(v: torch.Tensor):
return hvp(v) + self.hessian_regularization * v.type(rhs.dtype)
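End to end, the new `CgInfluence` flag would be used roughly as below. This is a sketch, not part of the PR: `hessian_regularization` is inferred from `reg_hvp` above, the `fit`/`influences` workflow is assumed from pyDVL's influence API (only `fit` appears elsewhere in this diff), and all data names are placeholders.

```python
from pydvl.influence.torch import CgInfluence

if_model = CgInfluence(
    model,
    loss,
    hessian_regularization=1e-3,
    precompute_grad=True,  # speed up each CG iteration, at the cost of memory
)
if_model = if_model.fit(train_dataloader)

# Influence of training points on test points (assumed API):
scores = if_model.influences(x_test, y_test, x_train, y_train)
```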
@@ -749,6 +760,10 @@ class ArnoldiInfluence(TorchInfluenceFunctionModel):
is appropriate for device memory.
If False, the eigen pair approximation is executed on the CPU by the scipy
wrapper to ARPACK.
precompute_grad: If True, the full data gradient is precomputed and kept
in memory, which can speed up the hessian vector product computation.
Set this to False if you can't afford to keep the full computation graph
in memory.
"""
low_rank_representation: LowRankProductRepresentation

@@ -762,6 +777,7 @@ def __init__(
tol: float = 1e-6,
max_iter: Optional[int] = None,
eigen_computation_on_gpu: bool = False,
precompute_grad: bool = False,
):

super().__init__(model, loss)
@@ -771,6 +787,7 @@ def __init__(
self.max_iter = max_iter
self.krylov_dimension = krylov_dimension
self.eigen_computation_on_gpu = eigen_computation_on_gpu
self.precompute_grad = precompute_grad

@property
def is_fitted(self):
@@ -804,6 +821,7 @@ def fit(self, data: DataLoader) -> ArnoldiInfluence:
tol=self.tol,
max_iter=self.max_iter,
eigen_computation_on_gpu=self.eigen_computation_on_gpu,
precompute_grad=self.precompute_grad,
)
self.low_rank_representation = low_rank_representation.to(self.model_device)
return self
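The analogous sketch for `ArnoldiInfluence`, with the same caveats as above: `rank_estimate` and the `fit`/`influences` workflow are assumptions, and the data names are placeholders.

```python
from pydvl.influence.torch import ArnoldiInfluence

if_model = ArnoldiInfluence(
    model,
    loss,
    rank_estimate=5,
    precompute_grad=False,  # keep memory low while fitting the low-rank Hessian
)
if_model = if_model.fit(train_dataloader)
scores = if_model.influences(x_test, y_test, x_train, y_train)
```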