-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathoptim_weight_ema.py
25 lines (19 loc) · 1.03 KB
/
optim_weight_ema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
class EMAWeightOptimizer (object):
    """Maintain an exponential moving average (EMA) of a source network's
    weights in a target network (mean-teacher style weight averaging).

    On construction the target network is initialised as an exact copy of
    the source network. Each call to :meth:`step` blends the target's
    float tensors toward the source's:

        target = ema_alpha * target + (1 - ema_alpha) * source

    Parameters
    ----------
    target_net : torch.nn.Module
        Network whose weights are overwritten with the EMA.
    source_net : torch.nn.Module
        Network being trained; provides the fresh weights.
    ema_alpha : float
        Smoothing factor in [0, 1]; larger values mean slower tracking.

    Raises
    ------
    ValueError
        If the two networks' state dicts do not have identical keys.
    """

    def __init__(self, target_net, source_net, ema_alpha):
        # Validate architectures BEFORE touching any weights. (Checking
        # afterwards would let zip() silently truncate and overwrite
        # target tensors with mismatched data before raising.)
        target_keys = set(target_net.state_dict().keys())
        source_keys = set(source_net.state_dict().keys())
        if target_keys != source_keys:
            raise ValueError('Source and target networks do not have the same state dict keys; do they have different architectures?')

        self.target_net = target_net
        self.source_net = source_net
        self.ema_alpha = ema_alpha
        # Only float tensors take part in the EMA; integer buffers
        # (e.g. BatchNorm's num_batches_tracked) are deliberately skipped.
        # NOTE(review): relies on state_dict() preserving insertion order so
        # the two filtered lists stay aligned — true for identical key sets.
        self.target_params = [p for p in target_net.state_dict().values() if p.dtype == torch.float]
        self.source_params = [p for p in source_net.state_dict().values() if p.dtype == torch.float]

        # Start the target as an exact copy of the source.
        for tgt_p, src_p in zip(self.target_params, self.source_params):
            tgt_p.copy_(src_p)

    def step(self):
        """Apply one EMA update: target = alpha*target + (1-alpha)*source."""
        one_minus_alpha = 1.0 - self.ema_alpha
        for tgt_p, src_p in zip(self.target_params, self.source_params):
            tgt_p.mul_(self.ema_alpha)
            # Fused in-place add avoids allocating the temporary
            # `src_p * one_minus_alpha` tensor on every step.
            tgt_p.add_(src_p, alpha=one_minus_alpha)