-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathengine.py
74 lines (64 loc) · 2.3 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
class Engine(object):
    """Hook-driven training / evaluation loop.

    Callbacks are registered by assigning into ``self.hooks``, a dict that
    maps an event name to a callable taking the shared ``state`` dict. Events
    without a registered callback are silently skipped.

    Training events (in firing order): ``on_start``, ``on_start_epoch``,
    ``on_sample``, ``on_forward``, ``on_update``, ``on_end_epoch``,
    ``on_end``.
    Evaluation events: ``on_test_start``, ``on_test_sample``,
    ``on_test_forward``, ``on_test_end``.
    """

    def __init__(self):
        # event name -> callable(state); populated by the caller.
        self.hooks = {}

    def hook(self, name, state):
        """Call the callback registered under ``name`` with ``state``, if any."""
        if name in self.hooks:
            self.hooks[name](state)

    def train(self, network, iterator, maxepoch, optimizer, scheduler):
        """Run ``maxepoch`` epochs of training and return the final state.

        Args:
            network: callable mapping a sample to a ``(loss, output)`` pair;
                ``loss`` must provide ``backward()``.
            iterator: iterable of samples. It is iterated afresh every epoch,
                so a one-shot iterator/generator yields nothing after the
                first epoch.
            maxepoch: number of epochs to run.
            optimizer: object exposing ``zero_grad()`` and ``step(closure)``.
            scheduler: stored in ``state['scheduler']`` only; this method
                never calls it (presumably a hook such as ``on_end_epoch``
                steps it -- confirm against callers).

        Returns:
            The ``state`` dict shared with every hook. ``state['epoch']`` is
            the number of completed epochs and ``state['t']`` the number of
            optimizer steps taken.
        """
        state = {
            'network': network,
            'iterator': iterator,
            'maxepoch': maxepoch,
            'optimizer': optimizer,
            'scheduler': scheduler,
            'epoch': 0,
            't': 0,
            'train': True,
        }

        self.hook('on_start', state)
        while state['epoch'] < state['maxepoch']:
            self.hook('on_start_epoch', state)
            for sample in state['iterator']:
                state['sample'] = sample
                self.hook('on_sample', state)

                def closure():
                    # Clear gradients inside the closure: optimizers such as
                    # LBFGS evaluate it several times per step, and zeroing
                    # only once before step() would let gradients accumulate
                    # across those evaluations.
                    state['optimizer'].zero_grad()
                    loss, output = state['network'](state['sample'])
                    state['output'] = output
                    state['loss'] = loss
                    loss.backward()
                    # Fires once per closure evaluation, i.e. possibly more
                    # than once per sample.
                    self.hook('on_forward', state)
                    # Drop the references held in state so their memory
                    # (including anything kept via save_for_backward) can be
                    # reclaimed before the next sample.
                    state['output'] = None
                    state['loss'] = None
                    return loss

                state['optimizer'].step(closure)
                self.hook('on_update', state)
                state['t'] += 1
            state['epoch'] += 1
            self.hook('on_end_epoch', state)
        self.hook('on_end', state)
        return state

    def test(self, network, iterator):
        """Run ``network`` over every sample in ``iterator`` without updates.

        Args:
            network: callable mapping a sample to a ``(loss, output)`` pair.
            iterator: iterable of samples, consumed once.

        Returns:
            The ``state`` dict shared with every hook; ``state['t']`` is the
            number of samples processed.

        NOTE(review): gradient tracking is not disabled here; callers that
        want ``torch.no_grad()`` semantics must arrange it themselves.
        """
        state = {
            'network': network,
            'iterator': iterator,
            't': 0,
            'train': False,
        }

        self.hook('on_test_start', state)
        for sample in state['iterator']:
            state['sample'] = sample
            self.hook('on_test_sample', state)

            loss, output = state['network'](state['sample'])
            state['output'] = output
            state['loss'] = loss
            self.hook('on_test_forward', state)
            # Drop the references held in state so their memory can be
            # reclaimed before the next sample.
            state['output'] = None
            state['loss'] = None

            state['t'] += 1
        self.hook('on_test_end', state)
        return state