Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
diaomin authored Apr 6, 2018
1 parent b6ad4a3 commit 3b7a416
Showing 1 changed file with 1 addition and 51 deletions.
52 changes: 1 addition & 51 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,55 +54,5 @@ def main():
fit(network=network, data_train=data_train, data_val=data_val, metrics=metrics, args=args, hp=hp, data_names=data_names)


def main2():
args = parse_args()
hp = Hyperparams()

if args.gpu:
contexts = [mx.context.gpu(i) for i in range(args.gpu)]
else:
contexts = [mx.context.cpu(i) for i in range(args.cpu)]


init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
init_states = init_c + init_h
data_names = ['data'] + [x[0] for x in init_states]

data_train = ImageIterLstm(
args.data_root, args.train_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="train")
data_val = ImageIterLstm(
args.data_root, args.test_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="val")

head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

symbol = crnn_lstm(hp)
module = mx.mod.Module(
symbol,
data_names=data_names,
label_names=['label'],
context=contexts)

module.bind(data_shapes=data_train.provide_data, label_shapes=data_train.provide_label)

metrics = CtcMetrics(hp.seq_length)

module.fit(train_data=data_train,
eval_data=data_val,
# use metrics.accuracy or metrics.accuracy_lcs
eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True),
optimizer='AdaDelta',
optimizer_params={'learning_rate': hp.learning_rate,
# 'momentum': hp.momentum,
'wd': 0.00001,
},
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
num_epoch=hp.num_epoch,
batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50),
epoch_end_callback=mx.callback.do_checkpoint(args.prefix),
)


if __name__ == '__main__':
main()
main()

0 comments on commit 3b7a416

Please sign in to comment.