Commit

Merge branch 'fix-agem' of https://github.com/JonasFrey96/avalanche into agem
AntonioCarta committed Jul 22, 2021
2 parents 981edc5 + f6e3919 commit 4e3cf68
Showing 2 changed files with 44 additions and 38 deletions.
78 changes: 42 additions & 36 deletions avalanche/training/plugins/agem.py
@@ -53,22 +53,24 @@ def after_backward(self, strategy, **kwargs):
"""
Project gradient based on reference gradients
"""

if self.memory_x is not None:
for (n1, p1), (n2, refg) in zip(strategy.model.named_parameters(),
self.reference_gradients):

assert n1 == n2, "Different model parameters in AGEM projection"
assert (p1.grad is not None and refg is not None) or \
(p1.grad is None and refg is None)

if refg is None:
continue

dotg = torch.dot(p1.grad.view(-1), refg.view(-1))
dotref = torch.dot(refg.view(-1), refg.view(-1))
if dotg < 0:
p1.grad -= (dotg / dotref) * refg
current_gradients = [p.grad.view(-1)
for n, p in strategy.model.named_parameters() if p.requires_grad]
current_gradients = torch.cat(current_gradients)

assert current_gradients.shape == self.reference_gradients.shape , "Different model parameters in AGEM projection"

dotg = torch.dot( current_gradients, self.reference_gradients)
if dotg < 0:
alpha2 = dotg / torch.dot(self.reference_gradients, self.reference_gradients)
grad_proj = current_gradients - self.reference_gradients * alpha2

count = 0
for n, p in strategy.model.named_parameters():
if p.requires_grad:
n_param = p.numel()
p.grad.copy_( grad_proj[count:count+n_param].view_as(p) )
count += n_param

def after_training_exp(self, strategy, **kwargs):
"""
@@ -97,26 +97,30 @@ def update_memory(self, dataloader):
"""
Update replay memory with patterns from current experience.
"""

tot = 0
for mbatch in dataloader:
x, y, _ = mbatch
if tot + x.size(0) <= self.patterns_per_experience:
if self.memory_x is None:
self.memory_x = x.clone()
self.memory_y = y.clone()
else:
self.memory_x = torch.cat((self.memory_x, x), dim=0)
self.memory_y = torch.cat((self.memory_y, y), dim=0)
else:
diff = self.patterns_per_experience - tot
if self.memory_x is None:
self.memory_x = x[:diff].clone()
self.memory_y = y[:diff].clone()
done = False
for batches in dataloader:
for _, (x, y) in batches.items():
if tot + x.size(0) <= self.patterns_per_experience:
if self.memory_x is None:
self.memory_x = x.clone()
self.memory_y = y.clone()
else:
self.memory_x = torch.cat((self.memory_x, x), dim=0)
self.memory_y = torch.cat((self.memory_y, y), dim=0)
tot += x.size(0)
else:
self.memory_x = torch.cat((self.memory_x,
x[:diff]), dim=0)
self.memory_y = torch.cat((self.memory_y,
y[:diff]), dim=0)
break
tot += x.size(0)
diff = self.patterns_per_experience - tot
if self.memory_x is None:
self.memory_x = x[:diff].clone()
self.memory_y = y[:diff].clone()
else:
self.memory_x = torch.cat((self.memory_x,
x[:diff]), dim=0)
self.memory_y = torch.cat((self.memory_y,
y[:diff]), dim=0)
tot += diff
done = True

if done: break
if done: break
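The update_memory rewrite above iterates a dataloader whose batches are dictionaries of (x, y) mini-batches (hence the batches.items() loop) and stops once patterns_per_experience samples are stored. Below is a condensed sketch of the same capping behaviour over plain (x, y) batches, equivalent in effect though not line-for-line (fill_memory and the toy data are illustrative only, not part of Avalanche):

    import torch

    def fill_memory(batches, memory_x, memory_y, patterns_per_experience):
        # Append samples batch by batch until the per-experience cap is reached.
        tot = 0
        for x, y in batches:
            take = min(x.size(0), patterns_per_experience - tot)
            if take <= 0:
                break
            if memory_x is None:
                memory_x, memory_y = x[:take].clone(), y[:take].clone()
            else:
                memory_x = torch.cat((memory_x, x[:take]), dim=0)
                memory_y = torch.cat((memory_y, y[:take]), dim=0)
            tot += take
        return memory_x, memory_y

    # Two batches of 3 samples with a cap of 4 leave exactly 4 patterns stored.
    batches = [(torch.randn(3, 2), torch.zeros(3)),
               (torch.randn(3, 2), torch.ones(3))]
    mx, my = fill_memory(batches, None, None, patterns_per_experience=4)
    assert mx.size(0) == 4 and my.size(0) == 4

The plugin's version additionally handles the dict-of-batches format produced by Avalanche's dataloader, which the sketch omits.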
4 changes: 2 additions & 2 deletions examples/gem_agem_mnist.py
@@ -32,7 +32,7 @@
 AGEM-PMNIST (5 experiences):
 Patterns per experience = sample size: 256. 256 hidden size, 1 training epoch.
 Average Accuracy over all experiences at the end of training on the last
-experience: 51.4%
+experience: 83.5%
 AGEM-SMNIST:
 Patterns per experience = sample size: 256, 512, 1024. Performance on previous
@@ -41,7 +41,7 @@
 Hidden size 256.
 Results for 1024 patterns per experience and sample size, 1 training epoch.
 Average Accuracy over all experiences at the end of training on the last
-experience: 23.5%
+experience: 67.0%
 """

