#!/usr/bin/env python
# -*- coding: utf-8 -*-
""" Main training workflow """
from __future__ import print_function
from __future__ import division
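
# Example invocation (illustrative only): the accepted flags are defined in
# onmt.opts.train_opts for this epoch-based snapshot of OpenNMT-py, and the
# data prefix / model name below are placeholders, not files in this repo.
#
#   python train.py -data data/demo -save_model demo-model -epochs 13 -gpuid 0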
import argparse
import os
import sys
import random

import torch
from torch import cuda

import onmt.opts as opts
from onmt.inputters.inputter import build_dataset_iter, lazily_load_dataset, \
    _load_fields, _collect_report_features
from onmt.model_builder import build_model
from onmt.utils.optimizers import build_optim
from onmt.trainer import build_trainer
from onmt.models import build_model_saver


def _check_save_model_path():
    # Ensure the directory for model checkpoints exists (reads the global opt).
    save_model_path = os.path.abspath(opt.save_model)
    model_dirname = os.path.dirname(save_model_path)
    if not os.path.exists(model_dirname):
        os.makedirs(model_dirname)


def _tally_parameters(model):
    # Report the total parameter count, then split it into encoder and
    # decoder/generator parameters.
    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)
    enc = 0
    dec = 0
    for name, param in model.named_parameters():
        if 'encoder' in name:
            enc += param.nelement()
        elif 'decoder' in name or 'generator' in name:
            dec += param.nelement()
    print('encoder: ' + str(enc))
    print('decoder: ' + str(dec))


def training_opt_postprocessing(opt):
    if opt.word_vec_size != -1:
        opt.src_word_vec_size = opt.word_vec_size
        opt.tgt_word_vec_size = opt.word_vec_size
    if opt.layers != -1:
        opt.enc_layers = opt.layers
        opt.dec_layers = opt.layers

    opt.brnn = (opt.encoder_type == "brnn")

    if opt.seed > 0:
        random.seed(opt.seed)
        torch.manual_seed(opt.seed)

    if opt.rnn_type == "SRU" and not opt.gpuid:
        raise AssertionError("Using SRU requires -gpuid set.")

    if torch.cuda.is_available() and not opt.gpuid:
        print("WARNING: You have a CUDA device, should run with -gpuid 0")

    if opt.gpuid:
        cuda.set_device(opt.gpuid[0])
        if opt.seed > 0:
            torch.cuda.manual_seed(opt.seed)

    if len(opt.gpuid) > 1:
        sys.stderr.write("Sorry, multigpu isn't supported yet, coming soon!\n")
        sys.exit(1)

    return opt


def main(opt):
    opt = training_opt_postprocessing(opt)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        print('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        # I don't like reassigning attributes of opt: it's not clear.
        opt.start_epoch = checkpoint['epoch'] + 1
    else:
        checkpoint = None
        model_opt = opt

    # Peek at the first dataset to determine the data_type.
    # (All datasets have the same data_type.)
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    _collect_report_features(fields)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    _tally_parameters(model)
    _check_save_model_path()

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, model, fields, optim, data_type, model_saver=model_saver)

    def train_iter_fct(): return build_dataset_iter(
        lazily_load_dataset("train", opt), fields, opt)

    def valid_iter_fct(): return build_dataset_iter(
        lazily_load_dataset("valid", opt), fields, opt)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.start_epoch, opt.epochs)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='train.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    opts.add_md_help_argument(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)

    opt = parser.parse_args()
    main(opt)