-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtask.py
55 lines (39 loc) · 1.54 KB
/
task.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
from trainer import Trainer
from data_loader import DataLoader
from directories import Directories
from network import Network
class Task:
    """Wire together the directories, data loader, network and trainer.

    Construction builds a ``Directories`` instance, a ``DataLoader`` and a
    ``Network``. The ``Trainer`` is only built when :meth:`run` is called.
    """

    # Value passed to ``DataLoader(total_num_examples=...)``.
    # NOTE(review): presumably the number of examples in this task's
    # dataset — confirm against the data actually on disk.
    TOTAL_NUM_EXAMPLES = 22872

    def __init__(self):
        self._dirs = Directories()
        # The data loader must exist before the network, which is given it.
        self._init_data_loader()
        self._init_network()

    def _update_dirs(self, log_folder):
        """Point ``self._dirs`` at this file's folder and the network log folder.

        Args:
            log_folder: suffix appended with ``+`` to ``Network.directory()``.
                It is concatenated directly, so it presumably needs its own
                leading separator (the default in ``run`` is ``'/logged_data'``).
        """
        current_dir = os.path.dirname(os.path.realpath(__file__))
        network_dir = Network.directory() + log_folder
        self._dirs.update_for_task(current_dir, network_dir)

    def _init_network(self):
        """Create the network from the data loader and shared directories."""
        self.__network = Network(data_loader=self.__data_loader,
                                 dirs=self._dirs)

    def _init_data_loader(self):
        """Create the data loader for ``TOTAL_NUM_EXAMPLES`` examples."""
        self.__data_loader = DataLoader(dirs=self._dirs,
                                        total_num_examples=self.TOTAL_NUM_EXAMPLES)

    def _init_trainer(self):
        """Create, store and return the trainer.

        The checkpoint/log/visualisation settings are passed straight to
        ``Trainer``. Their exact meaning (e.g. whether ``ld_chkpt=True``
        loads an existing checkpoint) is defined there — confirm in trainer.py.
        """
        self.__trainer = Trainer(data_loader=self.__data_loader,
                                 network=self.__network,
                                 dirs=self._dirs,
                                 ld_chkpt=True,
                                 save_freq=10000,
                                 log_freq=200,
                                 vis_freq=5000)
        return self.__trainer

    def run(self, train, log_folder='/logged_data'):
        """Build the trainer, then either train or visualise.

        Args:
            train: if truthy, call ``run_trainer()`` on the trainer;
                otherwise call ``run_visualiser()``.
            log_folder: suffix appended to ``Network.directory()`` to form the
                network directory handed to ``Directories.update_for_task``.
                Defaults to ``'/logged_data'``, the previously hard-coded value.

        The trainer's ``sess`` is closed when the run finishes, including when
        training or visualisation raises.
        """
        self._init_trainer()
        # NOTE(review): the trainer is built before the directories are
        # updated. It holds ``self._dirs`` by reference, so if
        # ``update_for_task`` mutates in place, later reads see the new paths.
        # Anything the Trainer reads inside its constructor still sees the old
        # paths — confirm this ordering is intended.
        self._update_dirs(log_folder)
        # The network is re-pointed explicitly, presumably because it captured
        # its checkpoint directory when it was constructed in __init__.
        self.__network.chkpt_dir = self._dirs.chkpt_dir
        try:
            if train:
                self.__trainer.run_trainer()
            else:
                self.__trainer.run_visualiser()
        finally:
            # Release the session even if the run above raised.
            self.__trainer.sess.close()