-
Notifications
You must be signed in to change notification settings - Fork 56
/
create_solver.py
executable file
·60 lines (51 loc) · 1.66 KB
/
create_solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#!/usr/bin/env python
"""
Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
Author: Varun Jampani
"""
import tempfile
from caffe.proto import caffe_pb2 as PB
from config import *
from init_caffe import *
def create_solver(solver_param, file_name=""):
    """Write a solver definition to disk and build a Caffe solver from it.

    Args:
        solver_param: solver description. ``str(solver_param)`` is what gets
            written, so a ``caffe_pb2.SolverParameter`` produces its
            text-format prototxt.
        file_name: path to write the prototxt to; the file is kept. When
            empty (the default), a temporary file is used instead and it is
            removed once the solver has been built.

    Returns:
        Whatever ``caffe.get_solver`` returns for the written file.
    """
    import os

    text = str(solver_param)

    if file_name:
        # Caller asked for a persistent copy of the prototxt: keep it.
        # ``with`` guarantees the handle is closed even if write() fails.
        with open(file_name, 'w') as f:
            f.write(text)
        return caffe.get_solver(file_name)

    # delete=False so the file survives close() and caffe can open it by name.
    with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
        f.write(text)
        tmp_name = f.name
    try:
        # NOTE(review): assumes caffe.get_solver fully parses the prototxt
        # before returning (pycaffe's GetSolverFromFile reads it once), so the
        # temporary file can be deleted afterwards instead of leaking.
        return caffe.get_solver(tmp_name)
    finally:
        os.remove(tmp_name)
def create_solver_proto(train_net,
                        test_net,
                        lr,
                        prefix,
                        test_iter=100,
                        test_interval=1000,
                        max_iter=1e5,
                        iter_size=1,
                        snapshot=1000,
                        display=1,
                        debug_info=False):
    """Build an Adam ``SolverParameter`` with a constant learning rate.

    Args:
        train_net: value for ``SolverParameter.train_net`` (in Caffe, the
            train network's prototxt filename).
        test_net: single entry placed in the repeated ``test_net`` field.
        lr: base learning rate (``base_lr``).
        prefix: ``snapshot_prefix`` used when writing snapshots.
        test_iter: single entry placed in the repeated ``test_iter`` field.
        test_interval: value for ``test_interval``.
        max_iter: value for ``max_iter``. Converted with ``int()`` because
            the field is an integer and the default ``1e5`` is a float.
        iter_size: value for ``iter_size``.
        snapshot: value for ``snapshot``.
        display: value for ``display``.
        debug_info: value for ``debug_info``.

    Returns:
        The populated ``caffe_pb2.SolverParameter``.
    """
    solver = PB.SolverParameter()

    # Networks and evaluation schedule.
    solver.train_net = train_net
    solver.test_net.append(test_net)
    solver.test_iter.append(int(test_iter))
    solver.test_interval = int(test_interval)

    # Training length and bookkeeping. These are integer proto fields;
    # protobuf's Python type checker rejects (or warns about) assigning a
    # float such as the 1e5 default, so normalise them with int().
    solver.max_iter = int(max_iter)
    solver.iter_size = int(iter_size)
    solver.display = int(display)
    solver.average_loss = 20
    solver.snapshot = int(snapshot)
    solver.snapshot_prefix = prefix
    solver.random_seed = RAND_SEED
    solver.debug_info = debug_info

    # Optimizer: Adam on the GPU with a fixed learning rate.
    solver.solver_mode = PB.SolverParameter.GPU
    solver.solver_type = PB.SolverParameter.ADAM
    solver.base_lr = lr
    solver.lr_policy = "fixed"
    # power is only read by the "inv"/"poly" lr policies, so it has no effect
    # with "fixed"; kept so the generated prototxt is unchanged.
    solver.power = 0.9
    solver.momentum = 0.9      # Adam beta1
    solver.momentum2 = 0.999   # Adam beta2
    return solver