Add closed-form grads for ewc_on, si. Fix dualprompt #52

Merged — 1 commit, merged on Sep 4, 2024
backbone/__init__.py: 16 additions & 8 deletions

@@ -131,16 +131,24 @@ def get_grads(self) -> torch.Tensor:
         Returns:
             gradients tensor
         """
-        return torch.cat(self.get_grads_list())
+        grads = []
+        for pp in list(self.parameters()):
+            grads.append(pp.grad.view(-1))
+        return torch.cat(grads)

-    def get_grads_list(self):
+    def set_grads(self, new_grads: torch.Tensor) -> None:
         """
-        Returns a list containing the gradients (a tensor for each layer).
+        Sets the gradients of all parameters.

-        Returns:
-            gradients list
+        Args:
+            new_grads: concatenated gradient values to be set
         """
-        grads = []
+        assert new_grads.size() == self.get_params().size()
+        progress = 0
         for pp in list(self.parameters()):
-            grads.append(pp.grad.view(-1))
-        return grads
+            cand_grads = new_grads[progress: progress +
+                                   torch.tensor(pp.size()).prod()].view(pp.size())
+            progress += torch.tensor(pp.size()).prod()
+            pp.grad = cand_grads
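
The new `set_grads` is the write-side counterpart of `get_grads`: it splits a flat tensor back into per-parameter chunks and writes them into each `.grad` slot. A standalone sketch of the pair on a plain `nn.Module` (stand-in for `MammothBackbone`; helper names follow the diff above):

```python
import torch
import torch.nn as nn

def get_grads(model: nn.Module) -> torch.Tensor:
    # concatenate every parameter's gradient into one flat tensor
    return torch.cat([p.grad.view(-1) for p in model.parameters()])

def set_grads(model: nn.Module, new_grads: torch.Tensor) -> None:
    # split a flat tensor back into per-parameter chunks and write them to .grad
    assert new_grads.numel() == sum(p.numel() for p in model.parameters())
    progress = 0
    for p in model.parameters():
        p.grad = new_grads[progress:progress + p.numel()].view(p.size())
        progress += p.numel()

model = nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()

flat = get_grads(model)
set_grads(model, 2 * flat)                          # inject a scaled gradient
assert torch.allclose(get_grads(model), 2 * flat)   # round trip preserves values
```

Because `loss.backward()` accumulates into existing `.grad` buffers, a model can also call `set_grads` before the backward pass and end up with the sum of both contributions, which is what the EWC change below relies on.
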
backbone/vit.py: 7 additions & 1 deletion

@@ -56,7 +56,7 @@
 import torch.nn.functional as F
 import torch.utils.checkpoint

-from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
+from timm.layers import PatchEmbed, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
     resample_abs_pos_embed
 from timm.models._builder import build_model_with_cfg
 from timm.models._manipulate import named_apply

@@ -65,6 +65,12 @@
 from backbone import MammothBackbone
 from backbone.utils.lora_utils import LoRAAttention, LoRAMlp
 from utils.conf import warn_once
+from timm.layers import Mlp as TimmMlp
+
+
+class Mlp(TimmMlp):
+    def forward(self, x, **kwargs):
+        return super().forward(x)


 __all__ = ['VisionTransformer']  # model_registry will add each entrypoint fn to this
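
The local `Mlp` wrapper is needed because prompt-based models (the DualPrompt fix in this PR) thread extra keyword arguments through every transformer block, while `timm`'s `Mlp.forward` only accepts the input tensor. Subclassing and discarding `**kwargs` keeps the block code uniform. A small usage sketch (the keyword names here are made up for illustration):

```python
import torch
from timm.layers import Mlp as TimmMlp

class Mlp(TimmMlp):
    def forward(self, x, **kwargs):
        # timm's Mlp.forward(x) takes no extra arguments; any keyword
        # arguments threaded through the block are silently discarded
        return super().forward(x)

mlp = Mlp(in_features=8, hidden_features=16)
x = torch.randn(2, 4, 8)
y = mlp(x, prompt=None, task_id=0)  # would raise TypeError with the plain timm Mlp
print(y.shape)                      # torch.Size([2, 4, 8])
```
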
models/ewc_on.py: 7 additions & 3 deletions

@@ -35,7 +35,7 @@ def penalty(self):
         if self.checkpoint is None:
             return torch.tensor(0.0).to(self.device)
         else:
-            penalty = (self.fish * ((self.net.get_params() - self.checkpoint) ** 2)).sum()
+            penalty = self.args.e_lambda * (self.fish * ((self.net.get_params() - self.checkpoint) ** 2)).sum()
             return penalty

     def end_task(self, dataset):

@@ -64,12 +64,16 @@ def end_task(self, dataset):

         self.checkpoint = self.net.get_params().data.clone()

+    def get_penalty_grads(self):
+        return self.args.e_lambda * 2 * self.fish * (self.net.get_params().data - self.checkpoint)
+
     def observe(self, inputs, labels, not_aug_inputs, epoch=None):

         self.opt.zero_grad()
         outputs = self.net(inputs)
-        penalty = self.penalty()
-        loss = self.loss(outputs, labels) + self.args.e_lambda * penalty
+        if self.checkpoint is not None:
+            self.net.set_grads(self.get_penalty_grads())
+        loss = self.loss(outputs, labels)
         assert not torch.isnan(loss)
         loss.backward()
         self.opt.step()
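
The closed-form route works because the gradient of the quadratic EWC penalty `e_lambda * sum(fish * (params - checkpoint) ** 2)` is exactly `2 * e_lambda * fish * (params - checkpoint)`, which is what `get_penalty_grads` returns; writing it into `.grad` via `set_grads` before `loss.backward()` lets autograd accumulate the task-loss gradients on top, so no backward pass through the penalty is needed. A quick autograd check of the identity (toy tensors, illustrative names):

```python
import torch

e_lambda = 0.5
params = torch.randn(10, requires_grad=True)
checkpoint = torch.randn(10)   # parameters saved at the end of the previous task
fish = torch.rand(10)          # diagonal Fisher information estimate

# reference: differentiate the penalty with autograd
penalty = e_lambda * (fish * (params - checkpoint) ** 2).sum()
penalty.backward()

# closed form used by get_penalty_grads()
closed_form = e_lambda * 2 * fish * (params.detach() - checkpoint)
assert torch.allclose(params.grad, closed_form)
```
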
models/si.py: 10 additions & 4 deletions

@@ -49,15 +49,21 @@ def end_task(self, dataset):
         self.checkpoint = self.net.get_params().data.clone().to(self.device)
         self.small_omega = 0

+    def get_penalty_grads(self):
+        return self.args.c * 2 * self.big_omega * (self.net.get_params().data - self.checkpoint)
+
     def observe(self, inputs, labels, not_aug_inputs, epoch=None):
         self.opt.zero_grad()
         outputs = self.net(inputs)
-        penalty = self.penalty()
-        loss = self.loss(outputs, labels) + self.args.c * penalty
+        loss = self.loss(outputs, labels)
         loss.backward()
+        cur_small_omega = self.net.get_grads().data
+        if self.big_omega is not None:
+            loss_grads = self.net.get_grads()
+            self.net.set_grads(loss_grads + self.get_penalty_grads())
+        cur_small_omega *= self.args.lr * self.net.get_grads().data
+        self.small_omega += cur_small_omega
         nn.utils.clip_grad.clip_grad_value_(self.get_parameters(), 1)
         self.opt.step()

-        self.small_omega += self.args.lr * self.net.get_grads().data ** 2

         return loss.item()
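
SI follows the same pattern: the surrogate penalty `c * sum(big_omega * (params - checkpoint) ** 2)` has gradient `2 * c * big_omega * (params - checkpoint)`. Here `set_grads` is called after `loss.backward()`, so the loss-only gradients can be read first for the `small_omega` path integral (with plain SGD, `lr * loss_grad * total_grad` approximates the per-parameter loss decrease along the actual update) and then replaced with the combined gradient before the optimizer step. A toy check that injecting the closed-form penalty gradient matches backpropagating through the full objective (illustrative tensors and names):

```python
import torch

c = 0.5
params = torch.randn(10, requires_grad=True)
checkpoint = torch.randn(10)
big_omega = torch.rand(10)
target = torch.randn(10)

# reference: backpropagate through loss + c * penalty directly
full = ((params - target) ** 2).sum() + c * (big_omega * (params - checkpoint) ** 2).sum()
full.backward()
reference = params.grad.clone()

# closed-form route: backpropagate the task loss only, then add the penalty gradient
params.grad = None
loss = ((params - target) ** 2).sum()
loss.backward()
penalty_grads = c * 2 * big_omega * (params.detach() - checkpoint)
params.grad = params.grad + penalty_grads   # what set_grads(loss_grads + ...) does above

assert torch.allclose(params.grad, reference)
```
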