This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[FEATURE] Horovod support for training transformer (PART 2) #1301

Merged: 44 commits, Aug 20, 2020
Changes from 34 commits
Commits (44)
- `c564420` set default shuffle=True for boundedbudgetsampler (Jul 28, 2020)
- `7f27b85` fix (Jul 28, 2020)
- `20d2fe1` fix log condition (Jul 28, 2020)
- `54a1abf` use horovod to train transformer (Jul 29, 2020)
- `5b39c75` Merge pull request #1 from dmlc/numpy (hutao965, Jul 30, 2020)
- `5685601` fix (Jul 30, 2020)
- `5815f12` add mirror wmt dataset (Jul 30, 2020)
- `4001821` fix (Jul 30, 2020)
- `e280a65` Merge pull request #2 from dmlc/numpy (hutao965, Aug 2, 2020)
- `5ba3789` rename wmt.txt to wmt.json and remove part of urls (Aug 2, 2020)
- `c99503f` fix (Aug 2, 2020)
- `cf8bcd3` tuning params (Aug 2, 2020)
- `3c1c5c0` Merge branch 'numpy' of https://github.com/hymzoque/gluon-nlp into numpy (Aug 2, 2020)
- `cc760d4` use get_repo_url() (Aug 2, 2020)
- `93243de` update average checkpoint cli (Aug 4, 2020)
- `73b942e` paste result of transformer large (Aug 4, 2020)
- `3a969d5` fix (Aug 4, 2020)
- `48d1fb9` fix logging in train_transformer (Aug 6, 2020)
- `983d1ab` fix (Aug 6, 2020)
- `9f7c087` fix (Aug 6, 2020)
- `7dae5ad` fix (Aug 6, 2020)
- `af18aca` add transformer base config (Aug 7, 2020)
- `eadf4db` fix (Aug 7, 2020)
- `192da91` Merge branch 'numpy' of https://github.com/dmlc/gluon-nlp into numpy (Aug 8, 2020)
- `cfe7705` change to wmt14/full (Aug 8, 2020)
- `e200440` print more sacrebleu info (Aug 9, 2020)
- `a84a2d0` fix (Aug 10, 2020)
- `681dc00` add test for num_parts and update behavior of boundedbudgetsampler wi… (Aug 16, 2020)
- `179b8db` fix (Aug 17, 2020)
- `eccceb7` fix (Aug 17, 2020)
- `adc4b34` fix (Aug 17, 2020)
- `b03faf2` fix logging when using horovd (Aug 17, 2020)
- `bc9ad31` udpate doc of train transformer (Aug 17, 2020)
- `28fad6d` Merge branch 'master' of https://github.com/dmlc/gluon-nlp into numpy (Aug 17, 2020)
- `3377e6b` add test case for fail downloading (Aug 19, 2020)
- `8fb28d7` add a ShardedIterator (Aug 19, 2020)
- `c347dec` fix (Aug 19, 2020)
- `cdbe846` fix (Aug 19, 2020)
- `050478f` fix (Aug 19, 2020)
- `6f94fc7` change mpirun to horovodrun (Aug 19, 2020)
- `36cccca` make the horovod command complete (Aug 19, 2020)
- `ddb174f` use print(sampler) to cover the codes of __repr__ func (Aug 19, 2020)
- `a436c09` empty commit (Aug 20, 2020)
- `db8afdb` add test case test_sharded_iterator_even_size (Aug 20, 2020)
9 changes: 5 additions & 4 deletions scripts/datasets/machine_translation/wmt2014_ende.sh
@@ -12,8 +12,8 @@ nlp_data prepare_wmt \
# We use sacrebleu to fetch the dev set (newstest2013) and test set (newstest2014)
sacrebleu -t wmt13 -l ${SRC}-${TGT} --echo src > ${SAVE_PATH}/dev.raw.${SRC}
sacrebleu -t wmt13 -l ${SRC}-${TGT} --echo ref > ${SAVE_PATH}/dev.raw.${TGT}
sacrebleu -t wmt14 -l ${SRC}-${TGT} --echo src > ${SAVE_PATH}/test.raw.${SRC}
sacrebleu -t wmt14 -l ${SRC}-${TGT} --echo ref > ${SAVE_PATH}/test.raw.${TGT}
sacrebleu -t wmt14/full -l ${SRC}-${TGT} --echo src > ${SAVE_PATH}/test.raw.${SRC}
sacrebleu -t wmt14/full -l ${SRC}-${TGT} --echo ref > ${SAVE_PATH}/test.raw.${TGT}


# Clean and tokenize the training + dev corpus
@@ -23,7 +23,7 @@ nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \
--src-corpus train.raw.${SRC} \
--tgt-corpus train.raw.${TGT} \
--min-num-words 1 \
--max-num-words 100 \
--max-num-words 250 \
--max-ratio 1.5 \
--src-save-path train.tok.${SRC} \
--tgt-save-path train.tok.${TGT}
@@ -33,7 +33,8 @@ nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \
--src-corpus dev.raw.${SRC} \
--tgt-corpus dev.raw.${TGT} \
--min-num-words 1 \
--max-num-words 100 \
--max-num-words 250 \
--max-ratio 1.5 \
--src-save-path dev.tok.${SRC} \
--tgt-save-path dev.tok.${TGT}

36 changes: 23 additions & 13 deletions scripts/machine_translation/README.md
@@ -30,30 +30,38 @@ python3 train_transformer.py \
--save_dir transformer_base_wmt2014_en_de_${SUBWORD_ALGO} \
--cfg transformer_base \
--lr 0.002 \
--batch_size 2700 \
--num_averages 5 \
--warmup_steps 4000 \
--sampler BoundedBudgetSampler \
--max_num_tokens 2700 \
--max_update 15000 \
--save_interval_update 500 \
--warmup_steps 6000 \
--warmup_init_lr 0.0 \
--seed 123 \
--gpus 0,1,2,3
```

Use the average_checkpoint cli to average the last 10 checkpoints
Or training via horovod
```
mpirun -np 4 -H localhost:4 python3 train_transformer.py \
--comm_backend horovod \
...
```

Use the average_checkpoint cli to average the last 10 epoch checkpoints

```bash
gluon_average_checkpoint --checkpoints transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/epoch*.params \
--begin 21 \
--end 30 \
--save-path transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/avg_21_30.params
--begin 30 \
--end 39 \
--save-path transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/epoch_avg_30_39.params
```


Use the following command to inference/evaluate the Transformer model:

```bash
SUBWORD_MODEL=yttm
python3 evaluate_transformer.py \
--param_path transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/average_21_30.params \
python evaluate_transformer.py \
--param_path transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/epoch_avg_30_39.params \
--src_lang en \
--tgt_lang de \
--cfg transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/config.yml \
@@ -97,7 +105,7 @@ python3 train_transformer.py \
--gpus 0,1,2,3
```

Use the average_checkpoint cli to average the last 10 checkpoints
Use the average_checkpoint cli to average the last 10 update checkpoints

```bash
gluon_average_checkpoint --checkpoints transformer_big_wmt2014_en_de_${SUBWORD_ALGO}/update*.params \
@@ -131,16 +139,18 @@ Test BLEU score with 3 seeds (evaluated via sacre BLEU):

- transformer_base

(test bleu / valid bleu)
| Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | Mean±std |
|---------------|------------|-------------|-------------|--------------|-------------|
| yttm | | - | - | - | - |
| yttm | | 26.50/26.29 | - | - | - |
| hf_bpe | | - | - | - | - |
| spm | | - | - | - | - |

- transformer_wmt_en_de_big

(test bleu / valid bleu)
| Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | Mean±std |
|---------------|------------|-------------|-------------|--------------|-------------|
| yttm | | 27.99 | - | - | - |
| yttm | | 27.93/26.82 | - | - | - |
| hf_bpe | | - | - | - | - |
| spm | | - | - | - | - |
13 changes: 10 additions & 3 deletions scripts/machine_translation/evaluate_transformer.py
@@ -247,10 +247,17 @@ def evaluate(args):
of.write('\n'.join(pred_sentences))
of.write('\n')

sacrebleu_out = sacrebleu.corpus_bleu(sys_stream=pred_sentences, ref_streams=[all_tgt_lines])
logging.info('Time Spent: {}, #Sent={}, SacreBlEU={} Avg NLL={}, Perplexity={}'
sacrebleu_out = sacrebleu.corpus_bleu(sys_stream=pred_sentences, ref_streams=[all_tgt_lines])
logging.info('Time Spent: {}, #Sent={}, SacreBlEU={} '
'({:2.1f} {:2.1f} {:2.1f} {:2.1f}) '
'(BP={:.3f}, ratio={:.3f}, syslen={}, reflen={}), '
'Avg NLL={}, Perplexity={}'
.format(end_eval_time - start_eval_time, len(all_tgt_lines),
sacrebleu_out.score, avg_nll_loss, np.exp(avg_nll_loss)))
sacrebleu_out.score,
*sacrebleu_out.precisions,
sacrebleu_out.bp, sacrebleu_out.sys_len / sacrebleu_out.ref_len,
sacrebleu_out.sys_len, sacrebleu_out.ref_len,
avg_nll_loss, np.exp(avg_nll_loss)))
# inference only
else:
with open(os.path.join(args.save_dir, 'pred_sentences.txt'), 'w', encoding='utf-8') as of:
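
For readers unfamiliar with sacrebleu's return object, here is a minimal standalone sketch of the fields the expanded log line above reports. The attribute names (`score`, `precisions`, `bp`, `sys_len`, `ref_len`) and the keyword arguments are the ones the diff itself uses; the toy sentences are made up.

```python
import sacrebleu

# Toy data: two hypothesis sentences and one reference stream of equal length.
hypotheses = ['the cat sat on the mat', 'a quick brown fox jumps']
references = [['the cat sat on the mat', 'the quick brown fox jumps']]

bleu = sacrebleu.corpus_bleu(sys_stream=hypotheses, ref_streams=references)
print(bleu.score)                   # corpus-level BLEU
print(bleu.precisions)              # 1- to 4-gram precisions (the four numbers in parentheses)
print(bleu.bp)                      # brevity penalty
print(bleu.sys_len / bleu.ref_len)  # length ratio, as logged above
print(bleu.sys_len, bleu.ref_len)   # hypothesis / reference token counts
```
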
60 changes: 34 additions & 26 deletions scripts/machine_translation/train_transformer.py
@@ -175,8 +175,11 @@ def parse_args():
args = parser.parse_args()
if args.max_update > 0:
args.epochs = -1
logging_config(args.save_dir, console=True)
logging.info(args)
_, args.num_parts, args.rank, args.local_rank, _, args.ctx_l = init_comm(
args.comm_backend, args.gpus)
if args.local_rank == 0:
logging_config(args.save_dir, console=True)
logging.info(args)
return args

def validation(model, data_loader, ctx_l):
@@ -231,14 +234,16 @@ def load_dataset_with_cache(src_corpus_path: str,
tgt_corpus_path: str,
src_tokenizer: BaseTokenizerWithVocab,
tgt_tokenizer: BaseTokenizerWithVocab,
overwrite_cache: bool):
overwrite_cache: bool,
local_rank: int):
# TODO online h5py multi processing encode (Tao)
src_md5sum = md5sum(src_corpus_path)
tgt_md5sum = md5sum(tgt_corpus_path)
cache_filepath = os.path.join(CACHE_PATH,
'{}_{}.cache.npz'.format(src_md5sum[:6], tgt_md5sum[:6]))
if os.path.exists(cache_filepath) and not overwrite_cache:
logging.info('Load cache from {}'.format(cache_filepath))
if local_rank == 0:
logging.info('Load cache from {}'.format(cache_filepath))
npz_data = np.load(cache_filepath, allow_pickle=True)
src_data, tgt_data = npz_data['src_data'][:], npz_data['tgt_data'][:]
else:
@@ -288,8 +293,6 @@ def create_tokenizer(tokenizer_type, model_path, vocab_path):


def train(args):
store, num_parts, rank, local_rank, is_master_node, ctx_l = init_comm(
args.comm_backend, args.gpus)
src_tokenizer = create_tokenizer(args.src_tokenizer,
args.src_subword_model_path,
args.src_vocab_path)
@@ -302,12 +305,14 @@ def train(args):
args.train_tgt_corpus,
src_tokenizer,
tgt_tokenizer,
args.overwrite_cache)
args.overwrite_cache,
args.local_rank)
dev_src_data, dev_tgt_data = load_dataset_with_cache(args.dev_src_corpus,
args.dev_tgt_corpus,
src_tokenizer,
tgt_tokenizer,
args.overwrite_cache)
args.overwrite_cache,
args.local_rank)
data_train = gluon.data.SimpleDataset(
[(src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
for i, (src_tokens, tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data))])
@@ -328,9 +333,9 @@ def train(args):
cfg.freeze()
model = TransformerModel.from_cfg(cfg)
model.initialize(mx.init.Xavier(magnitude=args.magnitude),
ctx=ctx_l)
ctx=args.ctx_l)
model.hybridize()
if local_rank == 0:
if args.local_rank == 0:
logging.info(model)
with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f:
cfg_f.write(cfg.dump())
@@ -364,8 +369,8 @@
max_num_tokens=args.max_num_tokens,
max_num_sentences=args.max_num_sentences,
seed=args.seed,
num_parts=num_parts,
part_index=rank)
num_parts=args.num_parts,
part_index=args.rank)
elif args.sampler == 'FixedBucketSampler':
if args.comm_backend == 'horovod':
raise NotImplementedError('FixedBucketSampler does not support horovod at present')
@@ -390,8 +395,7 @@
else:
raise NotImplementedError

if local_rank == 0:
logging.info(train_batch_sampler)
logging.info(train_batch_sampler)

batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack())
train_data_loader = gluon.data.DataLoader(data_train,
@@ -422,13 +426,13 @@ def train(args):
while (args.epochs < 0 or epoch_id < args.epochs): # when args.epochs < 0, the model will keep training
n_epoch_train_iters = 0
processed_batch_num = 0
train_multi_data_loader = grouper(train_data_loader, len(ctx_l))
train_multi_data_loader = grouper(train_data_loader, len(args.ctx_l))
is_last_batch = False
sample_data_l = next(train_multi_data_loader)
while not is_last_batch:
processed_batch_num += len(sample_data_l)
loss_l = []
for sample_data, ctx in zip(sample_data_l, ctx_l):
for sample_data, ctx in zip(sample_data_l, args.ctx_l):
if sample_data is None:
continue
src_token_ids, tgt_token_ids, src_valid_length, tgt_valid_length, sample_ids = sample_data
@@ -457,7 +461,7 @@ def train(args):
sample_data_l = next(train_multi_data_loader)
except StopIteration:
is_last_batch = True
if local_rank == 0 and num_params is None:
if args.local_rank == 0 and num_params is None:
num_params, num_fixed_params = count_parameters(model.collect_params())
logging.info('Total Number of Parameters (not-fixed/fixed): {}/{}'
.format(num_params, num_fixed_params))
@@ -475,35 +479,39 @@ def train(args):
if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \
(args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update):
model_averager.step()
if local_rank == 0 and \
if args.local_rank == 0 and \
(n_epoch_train_iters % args.log_interval == 0 or is_last_batch):
log_end_time = time.time()
log_wc = log_wc.asnumpy()
wps = log_wc / (log_end_time - log_start_time)
log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy()
logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
'throughput={:.2f}K wps, wc={:.2f}K, LR={}'
.format(epoch_id, processed_batch_num * num_parts, len(train_data_loader),
log_avg_loss, np.exp(log_avg_loss),
.format(epoch_id, min(processed_batch_num * args.num_parts, len(train_data_loader)),
len(train_data_loader), log_avg_loss, np.exp(log_avg_loss),
wps / 1000, log_wc / 1000, trainer.learning_rate))
log_start_time = time.time()
log_avg_loss = 0
log_loss_denom = 0
log_wc = 0
if local_rank == 0 and \
if args.local_rank == 0 and \
(args.max_update > 0 and n_train_iters % args.save_interval_update == 0):
n_update = n_train_iters // args.save_interval_update
model.save_parameters(os.path.join(args.save_dir,
'update{:d}.params'.format(n_train_iters // args.save_interval_update)),
'update{:d}.params'.format(n_update)),
deduplicate=True)
avg_valid_loss = validation(model, val_data_loader, args.ctx_l)
logging.info('[Update {}] validation loss/ppl={:.4f}/{:.4f}'
.format(n_update, avg_valid_loss, np.exp(avg_valid_loss)))
if args.max_update > 0 and n_train_iters >= args.max_update:
break
if local_rank == 0 and args.epochs > 0:
if args.local_rank == 0:
model.save_parameters(os.path.join(args.save_dir,
'epoch{:d}.params'.format(epoch_id)),
deduplicate=True)
avg_valid_loss = validation(model, val_data_loader, ctx_l)
logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'
.format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))
avg_valid_loss = validation(model, val_data_loader, args.ctx_l)
logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'
.format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))

if args.max_update > 0 and n_train_iters >= args.max_update:
break
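
Taken together, the hunks above move `init_comm` into `parse_args`, restrict logging and checkpoint saving to `local_rank == 0`, and shard the batch sampler by `(num_parts, rank)`. A condensed sketch of the sampler wiring, assuming only the constructor arguments and the `data_train` tuple layout shown in this diff; the helper function itself is illustrative, not part of the PR:

```python
from gluonnlp.data.sampler import BoundedBudgetSampler

def build_train_sampler(args, data_train):
    # data_train items are (src_tokens, tgt_tokens, src_len, tgt_len, idx), as built above
    lengths = [(src_len, tgt_len) for _, _, src_len, tgt_len, _ in data_train]
    return BoundedBudgetSampler(lengths,
                                max_num_tokens=args.max_num_tokens,
                                max_num_sentences=args.max_num_sentences,
                                seed=args.seed,
                                num_parts=args.num_parts,  # number of workers, from init_comm
                                part_index=args.rank)      # this worker's shard of batches
```
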
55 changes: 41 additions & 14 deletions src/gluonnlp/data/sampler.py
@@ -23,7 +23,6 @@
import math
import random
import warnings
import random
import numpy as np
import abc
from typing import Union, Sequence, Optional, List
@@ -289,12 +288,16 @@ class BoundedBudgetSampler(BaseSampler):
Number of partitions which the data is split into (default: 1)
part_index
The index of the part to read from
even_size
If the number of batches is not even across all partitions, sample a few extra batches
for the ones with fewer batches.
"""
def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]],
max_num_tokens: int = -1, max_num_sentences: int = -1,
required_batch_size_multiple: int = 1,
shuffle: bool = True, seed: Optional[int] = None,
num_parts: int = 1, part_index: int = 0):
num_parts: int = 1, part_index: int = 0,
even_size: bool = False):
assert len(lengths) > 0, 'BoundedBudgetSampler does not support empty lengths.'
assert max_num_tokens > 0 or max_num_sentences > 0, \
'One of max_num_tokens and max_num_sentences must be larger than 0'
@@ -310,6 +313,7 @@ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]],
self._rng = np.random.RandomState(seed)
self._num_parts = num_parts
self._part_index = part_index
self._even_size = even_size
# sort
self._indices = self._indices[np.argsort(self._lengths, kind='mergesort')]
batch = []
@@ -335,29 +339,52 @@ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]],
)
batch.append(index)
if len(batch) > 0:
self._batches.append(np.array(batch))

self._batches.append(np.array(batch))

# split batches to parts
length = len(self._batches)
if not even_size:
part_len = length // num_parts
remaining = length % num_parts
self._start = part_len * part_index + min(part_index, remaining)
self._end = self._start + part_len + (part_index < remaining)
self._part_len = self._end - self._start
else:
part_len = int(length + num_parts - 1) // num_parts
self._start = part_len * part_index
self._end = self._start + part_len
self._start = self._start if self._start < length else length
self._end = self._end if self._end < length else length
self._part_len = part_len
self._part_batches = self._batches[self._start:self._end]
if even_size and len(self._part_batches) < self._part_len:
candidates = random.sample(self._batches, k=self._part_len-len(self._part_batches))
self._part_batches.extend(candidates)
self._part_sample_num = sum([len(b) for b in self._part_batches])

def __iter__(self):
if self._shuffle:
self._rng.shuffle(self._batches)
part_batches = []
for i in range(len(self._batches)):
if i % self._num_parts == self._part_index:
part_batches.append(self._batches[i])
for batch in part_batches:
self._rng.shuffle(self._part_batches)
for batch in self._part_batches:
yield batch

def __len__(self):
return len(self._batches)

def __repr__(self):
ret = '{name}(\n' \
' sample_num={sample_num},\n' \
' batch_num={batch_num}\n'\
' batch_num={batch_num},\n' \
' part_sample_num={part_sample_num},\n' \
' part_batch_num={part_batch_num},\n' \
' part_num={part_num},\n' \
' part_index={part_index}\n' \
')'\
.format(name=self.__class__.__name__,
sample_num=len(self._lengths),
batch_num=len(self._batches))
batch_num=len(self._batches),
part_sample_num=self._part_sample_num,
part_batch_num=len(self._part_batches),
part_num=self._num_parts,
part_index=self._part_index)
return ret


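
The split logic above is the crux of the Horovod support: in synchronous data-parallel training every worker should run the same number of batches per epoch, otherwise the gradient allreduce can stall on the worker that runs out of data early, which is what `even_size` addresses. A pure-Python sketch of the split arithmetic, using only the formulas from the diff; the helper name and the 10-batches-over-4-parts example are made up:

```python
def part_range(length, num_parts, part_index, even_size):
    """Return the [start, end) slice of the sorted batch list owned by one part."""
    if not even_size:
        part_len, remaining = divmod(length, num_parts)
        start = part_len * part_index + min(part_index, remaining)
        end = start + part_len + (part_index < remaining)
    else:
        part_len = (length + num_parts - 1) // num_parts
        start = min(part_len * part_index, length)
        end = min(start + part_len, length)
    return start, end

print([part_range(10, 4, i, even_size=False) for i in range(4)])
# [(0, 3), (3, 6), (6, 8), (8, 10)]  -> parts own 3, 3, 2, 2 batches
print([part_range(10, 4, i, even_size=True) for i in range(4)])
# [(0, 3), (3, 6), (6, 9), (9, 10)]  -> the short last part is then topped up to
#                                       3 batches by random re-sampling, per the code above
```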