-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathSS_train.py
82 lines (69 loc) · 3.55 KB
/
SS_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from absl import app
from absl import logging
import os
import sys
import signal
import datetime
from absl import flags
from utils.utils import program_duration
import train_utils
import SS_utils
FLAGS = flags.FLAGS

# Weight of the self-supervised loss term; forwarded to SS_utils.get_model
# as ``lambda_u``.
flags.DEFINE_float(
    "lambda_u",
    default=1.0,
    help="lambda_u : weight of self-supervised loss",
)

# Self-training variant used when --self_training is set:
#   stssc    -> self-learning on the supervised model only
#   combined -> self-learning with the combined (self-supervised) model
flags.DEFINE_enum(
    "st_type",
    default="stssc",
    enum_values=["stssc", "combined"],
    help=("self-training configuration for self-supervision: only first iteration (stssc)"
          " or all iterations (combined)"),
)
def _install_quit_prompt(on_quit):
    """Install a SIGINT (Ctrl-C) handler that asks for confirmation before quitting.

    The first Ctrl-C prompts ``Really quit? (y/n)>``. An answer starting with
    ``y`` (case-insensitive) runs ``on_quit`` and exits with status 1. Any
    other answer re-arms this handler and lets the interrupted code resume.
    A KeyboardInterrupt raised before a confirmed quit completes, or EOF on
    stdin, prints a message and exits with status 1.

    Args:
        on_quit: Zero-argument callable run before a confirmed exit.
    """
    # Capture the handler active *before* ours is installed. Looking it up
    # inside the handler would just return the handler itself, making the
    # restore below a no-op. getsignal() returns None for handlers not
    # installed from Python, so fall back to the default handler (which
    # raises KeyboardInterrupt).
    previous_handler = signal.getsignal(signal.SIGINT)
    if previous_handler is None:
        previous_handler = signal.default_int_handler

    def exit_gracefully(signum, frame):
        # While prompting, restore the previous handler. With Python's
        # default handler, a second Ctrl-C then raises KeyboardInterrupt
        # out of input() instead of re-entering this function.
        signal.signal(signal.SIGINT, previous_handler)
        try:
            if input("\nReally quit? (y/n)> ").lower().startswith('y'):
                on_quit()
                sys.exit(1)
        except (KeyboardInterrupt, EOFError):
            print("Ok ok, quitting")
            sys.exit(1)
        # Quit was declined: re-arm this handler for the next Ctrl-C.
        signal.signal(signal.SIGINT, exit_gracefully)

    signal.signal(signal.SIGINT, exit_gracefully)


def main(argv):
    """Run combined (self-supervised) training, then optional self-training.

    Steps:
      1. Set up the dataset (``train_utils.set_dataset``) and the combined and
         supervised models (``SS_utils.get_model``).
      2. Configure the absl log file and log the initial test accuracy.
      3. Install a confirm-before-quit Ctrl-C handler. On a confirmed quit it
         logs the current test accuracy and the elapsed time.
      4. Train the combined model, then log test accuracy again.
      5. If ``--self_training`` is set, set the combined model's learning rate
         to ``lr / 10`` and run self-learning:
         - ``--st_type=stssc``: ``train_utils.start_self_learning`` on the
           supervised model;
         - ``--st_type=combined``: ``SS_utils.start_combined_self_learning``.

    Args:
        argv: Leftover positional command-line arguments from absl; unused.
    """
    dt1 = datetime.datetime.now()
    del argv  # not used
    # Set before the dataset and models below are created.
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
    dso, data_config = train_utils.set_dataset(FLAGS.dataset, FLAGS.lt, FLAGS.semi)
    combined_model, supervised_model = SS_utils.get_model(
        FLAGS.network, data_config, FLAGS.weights, FLAGS.lt,
        FLAGS.opt, FLAGS.lr, lambda_u=FLAGS.lambda_u)

    # set up logging details
    log_dir, log_name = train_utils.get_log_name(FLAGS, data_config, prefix="ss-")
    # Append the self-training type to the log file name. There is no
    # separator, which keeps existing log file names unchanged.
    if FLAGS.self_training:
        log_name += FLAGS.st_type
    os.makedirs(log_dir, exist_ok=True)
    logging.get_absl_handler().use_absl_log_file(log_name, log_dir)
    logging.get_absl_handler().setFormatter(None)
    print("training logs are saved at: ", log_dir)
    logging.info(FLAGS.flag_values_dict())

    def log_test_accuracy(label):
        # Evaluate the supervised model and log the result as
        # "<label> Test accuracy : xx.xx %".
        acc = train_utils.log_accuracy(supervised_model, dso, FLAGS.lt, FLAGS.semi,
                                       labelling=FLAGS.lbl)
        logging.info("{} Test accuracy : {:.2f} %".format(label, acc))

    log_test_accuracy("init")

    def ctrl_c_accuracy():
        # Run on a confirmed Ctrl-C: record accuracy and elapsed time.
        log_test_accuracy("ctrl_c_accuracy")
        print(program_duration(dt1, 'Killed after Time'))

    _install_quit_prompt(ctrl_c_accuracy)

    SS_utils.start_combined_training(combined_model, dso, FLAGS.epochs, FLAGS.semi, FLAGS.batch_size)
    log_test_accuracy("after training")

    if FLAGS.self_training:
        # reduce lr by factor of 0.1
        # NOTE(review): only combined_model's optimizer is changed here. The
        # stssc branch trains supervised_model; whether it shares this
        # optimizer is not visible from this file -- confirm.
        from tensorflow.keras import backend as keras_backend
        keras_backend.set_value(combined_model.optimizer.learning_rate, FLAGS.lr / 10.)
        # st_type is an enum flag restricted to "stssc" / "combined".
        if FLAGS.st_type == 'stssc':
            train_utils.start_self_learning(supervised_model, dso, data_config, FLAGS.lt,
                                            FLAGS.confidence_measure, FLAGS.meta_iterations,
                                            FLAGS.epochs_per_m_iteration, FLAGS.batch_size,
                                            logger=logging)
        else:
            SS_utils.start_combined_self_learning(combined_model, supervised_model, dso,
                                                  data_config, FLAGS.lt, logging,
                                                  FLAGS.meta_iterations,
                                                  FLAGS.epochs_per_m_iteration,
                                                  FLAGS.batch_size)
    print(program_duration(dt1, 'Total Time taken'))
if __name__ == '__main__':
    # Register the shared command-line flags from the project's top-level
    # ``flags`` module (a different module from ``absl.flags``).
    from flags import setup_flags as _register_shared_flags
    _register_shared_flags()
    # Echo absl log records to stderr in addition to the log file.
    setattr(FLAGS, 'alsologtostderr', True)
    app.run(main)