diff --git a/avalanche/training/plugins/agem.py b/avalanche/training/plugins/agem.py
index ef4681631..607d0cb4c 100644
--- a/avalanche/training/plugins/agem.py
+++ b/avalanche/training/plugins/agem.py
@@ -53,22 +53,24 @@ def after_backward(self, strategy, **kwargs):
         """
         Project gradient based on reference gradients
         """
-        if self.memory_x is not None:
-            for (n1, p1), (n2, refg) in zip(strategy.model.named_parameters(),
-                                            self.reference_gradients):
-
-                assert n1 == n2, "Different model parameters in AGEM projection"
-                assert (p1.grad is not None and refg is not None) or \
-                       (p1.grad is None and refg is None)
-
-                if refg is None:
-                    continue
-
-                dotg = torch.dot(p1.grad.view(-1), refg.view(-1))
-                dotref = torch.dot(refg.view(-1), refg.view(-1))
-                if dotg < 0:
-                    p1.grad -= (dotg / dotref) * refg
+        current_gradients = [p.grad.view(-1)
+                             for n, p in strategy.model.named_parameters()
+                             if p.requires_grad]
+        current_gradients = torch.cat(current_gradients)
+
+        assert current_gradients.shape == self.reference_gradients.shape, \
+            "Different model parameters in AGEM projection"
+
+        dotg = torch.dot(current_gradients, self.reference_gradients)
+        if dotg < 0:
+            alpha2 = dotg / torch.dot(self.reference_gradients,
+                                      self.reference_gradients)
+            grad_proj = current_gradients - self.reference_gradients * alpha2
+
+            count = 0
+            for n, p in strategy.model.named_parameters():
+                if p.requires_grad:
+                    n_param = p.numel()
+                    p.grad.copy_(grad_proj[count:count + n_param].view_as(p))
+                    count += n_param
 
     def after_training_exp(self, strategy, **kwargs):
         """
@@ -97,26 +99,30 @@ def update_memory(self, dataloader):
         """
         Update replay memory with patterns from current experience.
         """
         tot = 0
-        for mbatch in dataloader:
-            x, y, _ = mbatch
-            if tot + x.size(0) <= self.patterns_per_experience:
-                if self.memory_x is None:
-                    self.memory_x = x.clone()
-                    self.memory_y = y.clone()
-                else:
-                    self.memory_x = torch.cat((self.memory_x, x), dim=0)
-                    self.memory_y = torch.cat((self.memory_y, y), dim=0)
-            else:
-                diff = self.patterns_per_experience - tot
-                if self.memory_x is None:
-                    self.memory_x = x[:diff].clone()
-                    self.memory_y = y[:diff].clone()
+        done = False
+        for batches in dataloader:
+            for _, (x, y) in batches.items():
+                if tot + x.size(0) <= self.patterns_per_experience:
+                    if self.memory_x is None:
+                        self.memory_x = x.clone()
+                        self.memory_y = y.clone()
+                    else:
+                        self.memory_x = torch.cat((self.memory_x, x), dim=0)
+                        self.memory_y = torch.cat((self.memory_y, y), dim=0)
+                    tot += x.size(0)
                 else:
-                    self.memory_x = torch.cat((self.memory_x,
-                                               x[:diff]), dim=0)
-                    self.memory_y = torch.cat((self.memory_y,
-                                               y[:diff]), dim=0)
-                break
-            tot += x.size(0)
+                    diff = self.patterns_per_experience - tot
+                    if self.memory_x is None:
+                        self.memory_x = x[:diff].clone()
+                        self.memory_y = y[:diff].clone()
+                    else:
+                        self.memory_x = torch.cat((self.memory_x,
+                                                   x[:diff]), dim=0)
+                        self.memory_y = torch.cat((self.memory_y,
+                                                   y[:diff]), dim=0)
+                    tot += diff
+                    done = True
+
+                if done: break
+            if done: break
diff --git a/examples/gem_agem_mnist.py b/examples/gem_agem_mnist.py
index ba4bcbed8..0ca7c591a 100644
--- a/examples/gem_agem_mnist.py
+++ b/examples/gem_agem_mnist.py
@@ -32,7 +32,7 @@
 AGEM-PMNIST (5 experiences):
 Patterns per experience = sample size: 256. 256 hidden size, 1 training epoch.
 Average Accuracy over all experiences at the end of training on the last
-experience: 51.4%
+experience: 83.5%
 
 AGEM-SMNIST:
 Patterns per experience = sample size: 256, 512, 1024. Performance on previous
@@ -41,7 +41,7 @@
 Hidden size 256.
 Results for 1024 patterns per experience and sample size, 1 training epoch.
 Average Accuracy over all experiences at the end of training on the last
-experience: 23.5%
+experience: 67.0%
 """
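
For reference, below is a minimal standalone sketch of the flattened-gradient projection introduced in the first hunk: if the current gradient g conflicts with the reference gradient g_ref computed on the replay memory (g . g_ref < 0), g is replaced by g - (g . g_ref / g_ref . g_ref) * g_ref. This is an illustration of the A-GEM update, not the plugin code itself; the helper name agem_project_ and the toy model are made up for the example.

import torch
import torch.nn as nn


def agem_project_(model: nn.Module, reference_gradients: torch.Tensor) -> None:
    """Project the model's gradients in place, A-GEM style.

    Removes from the flattened current gradient the component that points
    against the reference gradient, i.e. the component that would increase
    the loss on the replayed memory.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    g = torch.cat([p.grad.view(-1) for p in params])
    assert g.shape == reference_gradients.shape

    dotg = torch.dot(g, reference_gradients)
    if dotg < 0:
        alpha = dotg / torch.dot(reference_gradients, reference_gradients)
        g = g - alpha * reference_gradients
        # Copy the projected flat gradient back into each parameter's .grad.
        offset = 0
        for p in params:
            n = p.numel()
            p.grad.copy_(g[offset:offset + n].view_as(p))
            offset += n


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Linear(4, 2)

    # Reference gradient from a "memory" batch.
    x_mem, y_mem = torch.randn(8, 4), torch.randn(8, 2)
    model.zero_grad()
    nn.functional.mse_loss(model(x_mem), y_mem).backward()
    ref = torch.cat([p.grad.view(-1) for p in model.parameters()]).clone()

    # Current gradient from a new batch, then project against the reference.
    x_new, y_new = torch.randn(8, 4), torch.randn(8, 2)
    model.zero_grad()
    nn.functional.mse_loss(model(x_new), y_new).backward()
    agem_project_(model, ref)

    g = torch.cat([p.grad.view(-1) for p in model.parameters()])
    # After projection the gradient no longer points against the reference.
    print("dot(g, g_ref) >= 0:", torch.dot(g, ref).item() >= -1e-6)

Flattening all parameter gradients into one vector makes the conflict test a single dot product over the whole model, matching the A-GEM formulation, instead of the previous per-parameter check.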