-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathlookahead.py
64 lines (52 loc) · 2.52 KB
/
lookahead.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#! -*- coding: utf-8 -*-
from keras import backend as K
class Lookahead(object):
    """Add the [Lookahead Optimizer](https://arxiv.org/abs/1907.08610) functionality for [keras](https://keras.io/).

    The model's own compiled optimizer produces the "fast" weight updates.
    Every ``k`` training calls, a set of "slow" weights is moved a fraction
    ``alpha`` of the way towards the current fast weights. The fast
    (trainable) weights are then overwritten with the new slow weights.

    Call ``inject(model)`` after ``model.compile(...)`` and before the train
    function is first built. ``inject`` only installs its own train function
    when ``model.train_function`` is still ``None``.

    Args:
        k: number of fast steps between slow-weight synchronisations; must
            be >= 1.
        alpha: slow-weight step size; must be in (0, 1].

    Raises:
        ValueError: if ``k`` or ``alpha`` is out of range.
    """

    def __init__(self, k=5, alpha=0.5):
        # Fail early: k == 0 would otherwise only surface as a
        # ZeroDivisionError inside the training step (``count % k``).
        if k < 1:
            raise ValueError('k must be >= 1, got %r' % (k,))
        if not 0 < alpha <= 1:
            raise ValueError('alpha must be in (0, 1], got %r' % (alpha,))
        self.k = k
        self.alpha = alpha
        # Number of calls made to the injected train function so far.
        self.count = 0

    def inject(self, model):
        """Inject the Lookahead algorithm for the given model.

        The following code is modified from keras's _make_train_function method.
        See: https://github.com/keras-team/keras/blob/master/keras/engine/training.py#L497

        Builds ``model.train_function``. Each call runs one step of the
        model's compiled optimizer and returns the loss and metric values.
        Every ``self.k``-th call, it then also applies the slow-weight update
        and copies the slow weights back into the trainable weights.

        Relies on private Keras ``Model`` attributes (``_feed_inputs``,
        ``_collected_trainable_weights``, ``metrics_tensors``,
        ``_function_kwargs``, ...). These are presumably the Keras 2.2.x
        internals; TODO confirm the supported Keras versions.

        Args:
            model: a compiled Keras model.

        Raises:
            RuntimeError: if ``model`` has not been compiled.
        """
        if not hasattr(model, 'train_function'):
            raise RuntimeError('You must compile your model before using it.')
        model._check_trainable_weights_consistency()

        if model.train_function is not None:
            # The train function already exists, e.g. because fit() ran
            # before inject(). Replacing it here is out of scope, but
            # silently doing nothing would leave the model training
            # without Lookahead, so say so.
            import warnings
            warnings.warn(
                'model.train_function has already been built; Lookahead was '
                'not injected. Call inject() before training the model.')
            return

        inputs = (model._feed_inputs +
                  model._feed_targets +
                  model._feed_sample_weights)
        if model._uses_dynamic_learning_phase():
            inputs += [K.learning_phase()]
        fast_params = model._collected_trainable_weights

        with K.name_scope('training'):
            with K.name_scope(model.optimizer.__class__.__name__):
                training_updates = model.optimizer.get_updates(
                    params=fast_params,
                    loss=model.total_loss)
                # One slow copy per trainable weight, initialised from it.
                slow_params = [K.variable(p) for p in fast_params]
            fast_updates = (model.updates +
                            training_updates +
                            model.metrics_updates)

            # slow <- slow + alpha * (fast - slow), then fast <- slow.
            slow_updates, copy_updates = [], []
            for p, q in zip(fast_params, slow_params):
                slow_updates.append(K.update(q, q + self.alpha * (p - q)))
                copy_updates.append(K.update(p, q))

            # Gets loss and metrics. Updates weights at each call.
            fast_train_function = K.function(
                inputs,
                [model.total_loss] + model.metrics_tensors,
                updates=fast_updates,
                name='fast_train_function',
                **model._function_kwargs)

            def F(ins):
                """Run one fast step; every k-th call also sync slow/fast weights."""
                self.count += 1
                outputs = fast_train_function(ins)
                if self.count % self.k == 0:
                    # Kept as two separate calls so the copy reads the slow
                    # weights only after they have been updated.
                    K.batch_get_value(slow_updates)
                    K.batch_get_value(copy_updates)
                return outputs

            model.train_function = F