From 3b7a4161c3cd3a84d21179b7946cc51c1553dfb4 Mon Sep 17 00:00:00 2001 From: Wei Tang Date: Fri, 6 Apr 2018 17:16:59 +0800 Subject: [PATCH] Update train.py --- train.py | 52 +--------------------------------------------------- 1 file changed, 1 insertion(+), 51 deletions(-) diff --git a/train.py b/train.py index 9ac070c..90f15f5 100644 --- a/train.py +++ b/train.py @@ -54,55 +54,5 @@ def main(): fit(network=network, data_train=data_train, data_val=data_val, metrics=metrics, args=args, hp=hp, data_names=data_names) -def main2(): - args = parse_args() - hp = Hyperparams() - - if args.gpu: - contexts = [mx.context.gpu(i) for i in range(args.gpu)] - else: - contexts = [mx.context.cpu(i) for i in range(args.cpu)] - - - init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)] - init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)] - init_states = init_c + init_h - data_names = ['data'] + [x[0] for x in init_states] - - data_train = ImageIterLstm( - args.data_root, args.train_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="train") - data_val = ImageIterLstm( - args.data_root, args.test_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="val") - - head = '%(asctime)-15s %(message)s' - logging.basicConfig(level=logging.DEBUG, format=head) - - symbol = crnn_lstm(hp) - module = mx.mod.Module( - symbol, - data_names=data_names, - label_names=['label'], - context=contexts) - - module.bind(data_shapes=data_train.provide_data, label_shapes=data_train.provide_label) - - metrics = CtcMetrics(hp.seq_length) - - module.fit(train_data=data_train, - eval_data=data_val, - # use metrics.accuracy or metrics.accuracy_lcs - eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True), - optimizer='AdaDelta', - optimizer_params={'learning_rate': hp.learning_rate, - # 'momentum': hp.momentum, - 'wd': 0.00001, - }, - initializer=mx.init.Xavier(factor_type="in", magnitude=2.34), - num_epoch=hp.num_epoch, - batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50), - epoch_end_callback=mx.callback.do_checkpoint(args.prefix), - ) - - if __name__ == '__main__': - main() \ No newline at end of file + main()