This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[FEATURE] Horovod support for training transformer (PART 2) #1301

Merged
merged 44 commits on Aug 20, 2020
Commits
44 commits
c564420
set default shuffle=True for boundedbudgetsampler
Jul 28, 2020
7f27b85
fix
Jul 28, 2020
20d2fe1
fix log condition
Jul 28, 2020
54a1abf
use horovod to train transformer
Jul 29, 2020
5b39c75
Merge pull request #1 from dmlc/numpy
hutao965 Jul 30, 2020
5685601
fix
Jul 30, 2020
5815f12
add mirror wmt dataset
Jul 30, 2020
4001821
fix
Jul 30, 2020
e280a65
Merge pull request #2 from dmlc/numpy
hutao965 Aug 2, 2020
5ba3789
rename wmt.txt to wmt.json and remove part of urls
Aug 2, 2020
c99503f
fix
Aug 2, 2020
cf8bcd3
tuning params
Aug 2, 2020
3c1c5c0
Merge branch 'numpy' of https://github.com/hymzoque/gluon-nlp into numpy
Aug 2, 2020
cc760d4
use get_repo_url()
Aug 2, 2020
93243de
update average checkpoint cli
Aug 4, 2020
73b942e
paste result of transformer large
Aug 4, 2020
3a969d5
fix
Aug 4, 2020
48d1fb9
fix logging in train_transformer
Aug 6, 2020
983d1ab
fix
Aug 6, 2020
9f7c087
fix
Aug 6, 2020
7dae5ad
fix
Aug 6, 2020
af18aca
add transformer base config
Aug 7, 2020
eadf4db
fix
Aug 7, 2020
192da91
Merge branch 'numpy' of https://github.com/dmlc/gluon-nlp into numpy
Aug 8, 2020
cfe7705
change to wmt14/full
Aug 8, 2020
e200440
print more sacrebleu info
Aug 9, 2020
a84a2d0
fix
Aug 10, 2020
681dc00
add test for num_parts and update behavior of boundedbudgetsampler wi…
Aug 16, 2020
179b8db
fix
Aug 17, 2020
eccceb7
fix
Aug 17, 2020
adc4b34
fix
Aug 17, 2020
b03faf2
fix logging when using horovod
Aug 17, 2020
bc9ad31
update doc of train transformer
Aug 17, 2020
28fad6d
Merge branch 'master' of https://github.com/dmlc/gluon-nlp into numpy
Aug 17, 2020
3377e6b
add test case for fail downloading
Aug 19, 2020
8fb28d7
add a ShardedIterator
Aug 19, 2020
c347dec
fix
Aug 19, 2020
cdbe846
fix
Aug 19, 2020
050478f
fix
Aug 19, 2020
6f94fc7
change mpirun to horovodrun
Aug 19, 2020
36cccca
make the horovod command complete
Aug 19, 2020
ddb174f
use print(sampler) to cover the codes of __repr__ func
Aug 19, 2020
a436c09
empty commit
Aug 20, 2020
db8afdb
add test case test_sharded_iterator_even_size
Aug 20, 2020
5 changes: 3 additions & 2 deletions scripts/datasets/machine_translation/wmt2014_ende.sh
@@ -12,8 +12,8 @@ nlp_data prepare_wmt \
# We use sacrebleu to fetch the dev set (newstest2013) and test set (newstest2014)
sacrebleu -t wmt13 -l ${SRC}-${TGT} --echo src > ${SAVE_PATH}/dev.raw.${SRC}
sacrebleu -t wmt13 -l ${SRC}-${TGT} --echo ref > ${SAVE_PATH}/dev.raw.${TGT}
sacrebleu -t wmt14 -l ${SRC}-${TGT} --echo src > ${SAVE_PATH}/test.raw.${SRC}
sacrebleu -t wmt14 -l ${SRC}-${TGT} --echo ref > ${SAVE_PATH}/test.raw.${TGT}
sacrebleu -t wmt14/full -l ${SRC}-${TGT} --echo src > ${SAVE_PATH}/test.raw.${SRC}
sacrebleu -t wmt14/full -l ${SRC}-${TGT} --echo ref > ${SAVE_PATH}/test.raw.${TGT}


# Clean and tokenize the training + dev corpus
@@ -34,6 +34,7 @@ nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \
--tgt-corpus dev.raw.${TGT} \
--min-num-words 1 \
--max-num-words 100 \
--max-ratio 1.5 \
--src-save-path dev.tok.${SRC} \
--tgt-save-path dev.tok.${TGT}

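The new `--max-ratio 1.5` flag adds a length-ratio filter to the cleaning step. A small sketch of what such a filter typically does, assuming Moses-style clean-corpus semantics (drop pairs whose token-length ratio exceeds the cap in either direction); this is an illustration, not the `nlp_preprocess` implementation:

```python
# Illustrative sketch of a clean-corpus filter with --min-num-words 1,
# --max-num-words 100 and --max-ratio 1.5 (assumed semantics).
def keep_pair(src_tokens, tgt_tokens, max_ratio=1.5,
              min_num_words=1, max_num_words=100):
    ns, nt = len(src_tokens), len(tgt_tokens)
    if not (min_num_words <= ns <= max_num_words and
            min_num_words <= nt <= max_num_words):
        return False
    return max(ns, nt) / min(ns, nt) <= max_ratio

print(keep_pair("a b c".split(), "x y".split()))    # True  (ratio 1.5)
print(keep_pair("a b c d".split(), "x y".split()))  # False (ratio 2.0)
```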
50 changes: 38 additions & 12 deletions scripts/machine_translation/README.md
@@ -30,9 +30,36 @@ python3 train_transformer.py \
--save_dir transformer_base_wmt2014_en_de_${SUBWORD_ALGO} \
--cfg transformer_base \
--lr 0.002 \
--batch_size 2700 \
--num_averages 5 \
--warmup_steps 4000 \
--sampler BoundedBudgetSampler \
--max_num_tokens 2700 \
--max_update 15000 \
--save_interval_update 500 \
--warmup_steps 6000 \
--warmup_init_lr 0.0 \
--seed 123 \
--gpus 0,1,2,3
```
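In this command, `--sampler BoundedBudgetSampler` with `--max_num_tokens 2700` replaces the old fixed `--batch_size`: batches are packed until a token budget is reached rather than to a fixed sentence count. A simplified, hypothetical sketch of the idea (not the gluonnlp implementation, which also handles source/target length pairs, shuffling and sharding):

```python
# Simplified sketch of token-budget batching: sort by length, then greedily
# fill each batch until adding one more sentence would exceed max_num_tokens
# (counting padding to the longest sentence in the batch).
def bounded_budget_batches(lengths, max_num_tokens=2700):
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, batch, max_len = [], [], 0
    for idx in order:
        new_max = max(max_len, lengths[idx])
        if batch and new_max * (len(batch) + 1) > max_num_tokens:
            batches.append(batch)
            batch, new_max = [], lengths[idx]
        batch.append(idx)
        max_len = new_max
    if batch:
        batches.append(batch)
    return batches

print(bounded_budget_batches([5, 12, 7, 30, 9], max_num_tokens=40))
# [[0, 2, 4], [1], [3]]
```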

Or train the model with Horovod:
```
horovodrun -np 4 -H localhost:4 python3 train_transformer.py \
--comm_backend horovod \
--train_src_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${SRC} \
--train_tgt_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${TGT} \
--dev_src_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${SRC} \
--dev_tgt_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${TGT} \
--src_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
--src_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
--tgt_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
--tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
--save_dir transformer_base_wmt2014_en_de_${SUBWORD_ALGO} \
--cfg transformer_base \
--lr 0.002 \
--sampler BoundedBudgetSampler \
--max_num_tokens 2700 \
--max_update 15000 \
--save_interval_update 500 \
--warmup_steps 6000 \
--warmup_init_lr 0.0 \
--seed 123 \
--gpus 0,1,2,3
@@ -42,18 +69,16 @@ Use the average_checkpoint cli to average the last 10 checkpoints

```bash
gluon_average_checkpoint --checkpoints transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/epoch*.params \
--begin 21 \
--end 30 \
--save-path transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/avg_21_30.params
--begin 30 \
--end 39 \
--save-path transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/epoch_avg_30_39.params
```
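Conceptually, checkpoint averaging takes the elementwise mean of each parameter across the selected checkpoints; a minimal numpy sketch of that idea (an assumption about what the CLI does, not its actual code):

```python
# Minimal sketch of checkpoint averaging: elementwise mean of every parameter
# across the selected checkpoints (assumed behavior of gluon_average_checkpoint).
import numpy as np

def average_checkpoints(checkpoints):
    """checkpoints: list of dicts mapping parameter name -> ndarray."""
    return {name: np.mean([ckpt[name] for ckpt in checkpoints], axis=0)
            for name in checkpoints[0]}

ckpts = [{'w': np.array([1.0, 2.0])}, {'w': np.array([3.0, 4.0])}]
print(average_checkpoints(ckpts))  # {'w': array([2., 3.])}
```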


Use the following command to inference/evaluate the Transformer model:

```bash
SUBWORD_MODEL=yttm
python3 evaluate_transformer.py \
--param_path transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/average_21_30.params \
--param_path transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/epoch_avg_30_39.params \
--src_lang en \
--tgt_lang de \
--cfg transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/config.yml \
@@ -110,7 +135,6 @@ gluon_average_checkpoint --checkpoints transformer_big_wmt2014_en_de_${SUBWORD_A
Use the following command to inference/evaluate the Transformer model:

```bash
SUBWORD_MODEL=yttm
python3 evaluate_transformer.py \
--param_path transformer_big_wmt2014_en_de_${SUBWORD_MODEL}/average_21_30.params \
--src_lang en \
@@ -131,16 +155,18 @@ Test BLEU score with 3 seeds (evaluated via sacre BLEU):

- transformer_base

(test BLEU / valid BLEU)
| Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | Mean±std |
|---------------|------------|-------------|-------------|--------------|-------------|
| yttm | | - | - | - | - |
| yttm | | 26.50/26.29 | - | - | - |
| hf_bpe | | - | - | - | - |
| spm | | - | - | - | - |

- transformer_wmt_en_de_big

(test BLEU / valid BLEU)
| Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | Mean±std |
|---------------|------------|-------------|-------------|--------------|-------------|
| yttm | | 27.99 | - | - | - |
| yttm | | 27.93/26.82 | - | - | - |
| hf_bpe | | - | - | - | - |
| spm | | - | - | - | - |
13 changes: 10 additions & 3 deletions scripts/machine_translation/evaluate_transformer.py
@@ -247,10 +247,17 @@ def evaluate(args):
of.write('\n'.join(pred_sentences))
of.write('\n')

sacrebleu_out = sacrebleu.corpus_bleu(sys_stream=pred_sentences, ref_streams=[all_tgt_lines])
logging.info('Time Spent: {}, #Sent={}, SacreBlEU={} Avg NLL={}, Perplexity={}'
sacrebleu_out = sacrebleu.corpus_bleu(sys_stream=pred_sentences, ref_streams=[all_tgt_lines])
logging.info('Time Spent: {}, #Sent={}, SacreBlEU={} '
'({:2.1f} {:2.1f} {:2.1f} {:2.1f}) '
'(BP={:.3f}, ratio={:.3f}, syslen={}, reflen={}), '
'Avg NLL={}, Perplexity={}'
.format(end_eval_time - start_eval_time, len(all_tgt_lines),
sacrebleu_out.score, avg_nll_loss, np.exp(avg_nll_loss)))
sacrebleu_out.score,
*sacrebleu_out.precisions,
sacrebleu_out.bp, sacrebleu_out.sys_len / sacrebleu_out.ref_len,
sacrebleu_out.sys_len, sacrebleu_out.ref_len,
avg_nll_loss, np.exp(avg_nll_loss)))
# inference only
else:
with open(os.path.join(args.save_dir, 'pred_sentences.txt'), 'w', encoding='utf-8') as of:
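The fields used in the new logging call (`precisions`, `bp`, `sys_len`, `ref_len`) all come from the BLEU object returned by `sacrebleu.corpus_bleu`. A standalone sketch, assuming the sacrebleu 1.x keyword API that this script uses:

```python
# Standalone sketch of the richer SacreBLEU logging, assuming the
# sacrebleu 1.x API (corpus_bleu with sys_stream/ref_streams keywords).
import sacrebleu

pred_sentences = ['the cat sat on the mat', 'hello world']
all_tgt_lines = ['the cat sat on the mat', 'hello there world']

bleu = sacrebleu.corpus_bleu(sys_stream=pred_sentences,
                             ref_streams=[all_tgt_lines])
print('SacreBLEU={:.2f} ({:2.1f} {:2.1f} {:2.1f} {:2.1f}) '
      '(BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})'
      .format(bleu.score, *bleu.precisions, bleu.bp,
              bleu.sys_len / bleu.ref_len, bleu.sys_len, bleu.ref_len))
```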
45 changes: 27 additions & 18 deletions scripts/machine_translation/train_transformer.py
@@ -50,7 +50,8 @@
LinearWidthBucket,
ExpWidthBucket,
FixedBucketSampler,
BoundedBudgetSampler
BoundedBudgetSampler,
ShardedIterator
)
import gluonnlp.data.batchify as bf
from gluonnlp.data import Vocab
@@ -179,6 +180,7 @@ def parse_args():
logging.info(args)
return args


def validation(model, data_loader, ctx_l):
"""Validate the model on the dataset

@@ -231,14 +233,16 @@ def load_dataset_with_cache(src_corpus_path: str,
tgt_corpus_path: str,
src_tokenizer: BaseTokenizerWithVocab,
tgt_tokenizer: BaseTokenizerWithVocab,
overwrite_cache: bool):
overwrite_cache: bool,
local_rank: int):
# TODO online h5py multi processing encode (Tao)
src_md5sum = md5sum(src_corpus_path)
tgt_md5sum = md5sum(tgt_corpus_path)
cache_filepath = os.path.join(CACHE_PATH,
'{}_{}.cache.npz'.format(src_md5sum[:6], tgt_md5sum[:6]))
if os.path.exists(cache_filepath) and not overwrite_cache:
logging.info('Load cache from {}'.format(cache_filepath))
if local_rank == 0:
logging.info('Load cache from {}'.format(cache_filepath))
npz_data = np.load(cache_filepath, allow_pickle=True)
src_data, tgt_data = npz_data['src_data'][:], npz_data['tgt_data'][:]
else:
@@ -288,7 +292,7 @@ def create_tokenizer(tokenizer_type, model_path, vocab_path):


def train(args):
store, num_parts, rank, local_rank, is_master_node, ctx_l = init_comm(
_, num_parts, rank, local_rank, _, ctx_l = init_comm(
args.comm_backend, args.gpus)
src_tokenizer = create_tokenizer(args.src_tokenizer,
args.src_subword_model_path,
@@ -302,12 +306,14 @@ def train(args):
args.train_tgt_corpus,
src_tokenizer,
tgt_tokenizer,
args.overwrite_cache)
args.overwrite_cache,
local_rank)
dev_src_data, dev_tgt_data = load_dataset_with_cache(args.dev_src_corpus,
args.dev_tgt_corpus,
src_tokenizer,
tgt_tokenizer,
args.overwrite_cache)
args.overwrite_cache,
local_rank)
data_train = gluon.data.SimpleDataset(
[(src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
for i, (src_tokens, tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data))])
@@ -363,9 +369,9 @@ train_batch_sampler = BoundedBudgetSampler(lengths=[(ele[2], ele[3]) for ele in data_train],
train_batch_sampler = BoundedBudgetSampler(lengths=[(ele[2], ele[3]) for ele in data_train],
max_num_tokens=args.max_num_tokens,
max_num_sentences=args.max_num_sentences,
seed=args.seed,
num_parts=num_parts,
part_index=rank)
seed=args.seed)
if num_parts > 1:
train_batch_sampler = ShardedIterator(train_batch_sampler, num_parts=num_parts, part_index=rank)
elif args.sampler == 'FixedBucketSampler':
if args.comm_backend == 'horovod':
raise NotImplementedError('FixedBucketSampler does not support horovod at present')
@@ -390,8 +396,7 @@
else:
raise NotImplementedError

if local_rank == 0:
logging.info(train_batch_sampler)
logging.info(train_batch_sampler)

batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack())
train_data_loader = gluon.data.DataLoader(data_train,
@@ -483,27 +488,31 @@ def train(args):
log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy()
logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
'throughput={:.2f}K wps, wc={:.2f}K, LR={}'
.format(epoch_id, processed_batch_num * num_parts, len(train_data_loader),
log_avg_loss, np.exp(log_avg_loss),
.format(epoch_id, processed_batch_num * num_parts,
len(train_data_loader), log_avg_loss, np.exp(log_avg_loss),
wps / 1000, log_wc / 1000, trainer.learning_rate))
log_start_time = time.time()
log_avg_loss = 0
log_loss_denom = 0
log_wc = 0
if local_rank == 0 and \
(args.max_update > 0 and n_train_iters % args.save_interval_update == 0):
n_update = n_train_iters // args.save_interval_update
model.save_parameters(os.path.join(args.save_dir,
'update{:d}.params'.format(n_train_iters // args.save_interval_update)),
'update{:d}.params'.format(n_update)),
deduplicate=True)
avg_valid_loss = validation(model, val_data_loader, ctx_l)
logging.info('[Update {}] validation loss/ppl={:.4f}/{:.4f}'
.format(n_update, avg_valid_loss, np.exp(avg_valid_loss)))
if args.max_update > 0 and n_train_iters >= args.max_update:
break
if local_rank == 0 and args.epochs > 0:
if local_rank == 0:
model.save_parameters(os.path.join(args.save_dir,
'epoch{:d}.params'.format(epoch_id)),
deduplicate=True)
avg_valid_loss = validation(model, val_data_loader, ctx_l)
logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'
.format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))
avg_valid_loss = validation(model, val_data_loader, ctx_l)
logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'
.format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))

if args.max_update > 0 and n_train_iters >= args.max_update:
break
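The `ShardedIterator` added here wraps the batch sampler so that each Horovod worker only consumes its own shard of batches. A simplified, hypothetical sketch of the idea (the real gluonnlp class may differ, e.g. it can pad shards to an even size, as exercised by `test_sharded_iterator_even_size`):

```python
# Simplified, hypothetical sketch of batch sharding: worker `part_index`
# keeps every num_parts-th batch produced by the wrapped sampler.
class SimpleShardedIterator:
    def __init__(self, sampler, num_parts, part_index):
        assert 0 <= part_index < num_parts
        self._sampler = sampler
        self._num_parts = num_parts
        self._part_index = part_index

    def __iter__(self):
        for i, batch in enumerate(self._sampler):
            if i % self._num_parts == self._part_index:
                yield batch

# 10 batches split across 4 workers; worker 1 gets batches 1, 5 and 9.
batches = [[j] for j in range(10)]
print(list(SimpleShardedIterator(batches, num_parts=4, part_index=1)))
```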
2 changes: 1 addition & 1 deletion scripts/question_answering/README.md
@@ -74,7 +74,7 @@ We could speed up multi-GPU training via horovod.
Compared to KVStore, training RoBERTa Large model on SQuAD 2.0 with 3 epochs will save roughly 1/4 training resources (8.48 vs 11.32 hours). Results may vary depending on the training instances.

```bash
mpirun -np 4 -H localhost:4 python3 run_squad.py \
horovodrun -np 4 -H localhost:4 python3 run_squad.py \
--comm_backend horovod \
...
```