From c564420602499f1c56efc8024754e6a2e7744900 Mon Sep 17 00:00:00 2001 From: Hu Date: Tue, 28 Jul 2020 12:31:05 +0800 Subject: [PATCH 01/39] set default shuffle=True for boundedbudgetsampler --- src/gluonnlp/data/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gluonnlp/data/sampler.py b/src/gluonnlp/data/sampler.py index 5879ea453a..470f8bbf80 100644 --- a/src/gluonnlp/data/sampler.py +++ b/src/gluonnlp/data/sampler.py @@ -289,7 +289,7 @@ class BoundedBudgetSampler(BaseSampler): def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]], max_num_tokens: int = -1, max_num_sentences: int = -1, required_batch_size_multiple: int = 1, - shuffle: bool = False, seed: Optional[int] = None): + shuffle: bool = True, seed: Optional[int] = None): assert len(lengths) > 0, 'BoundedBudgetSampler does not support empty lengths.' assert max_num_tokens > 0 or max_num_sentences > 0, \ 'One of max_num_tokens and max_num_sentences must be larger than 0' From 7f27b8517200df0cb06406ee6e3c70862edd0cd3 Mon Sep 17 00:00:00 2001 From: Hu Date: Tue, 28 Jul 2020 12:32:04 +0800 Subject: [PATCH 02/39] fix --- scripts/datasets/machine_translation/wmt2014_ende.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/datasets/machine_translation/wmt2014_ende.sh b/scripts/datasets/machine_translation/wmt2014_ende.sh index 028b796ed2..f319db2163 100644 --- a/scripts/datasets/machine_translation/wmt2014_ende.sh +++ b/scripts/datasets/machine_translation/wmt2014_ende.sh @@ -34,7 +34,6 @@ nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \ --tgt-corpus dev.raw.${TGT} \ --min-num-words 1 \ --max-num-words 100 \ - --max-ratio 1.5 \ --src-save-path dev.tok.${SRC} \ --tgt-save-path dev.tok.${TGT} From 20d2fe15fa1a50ede837ad016915236bd1c67f3d Mon Sep 17 00:00:00 2001 From: Hu Date: Tue, 28 Jul 2020 12:34:09 +0800 Subject: [PATCH 03/39] fix log condition --- scripts/machine_translation/train_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index dfa9a74cfb..d3ca507c32 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -451,7 +451,7 @@ def train(args): if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \ (args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update): model_averager.step() - if n_epoch_train_iters % args.log_interval == 0: + if n_epoch_train_iters % args.log_interval == 0 or is_last_batch: log_end_time = time.time() log_wc = log_wc.asnumpy() wps = log_wc / (log_end_time - log_start_time) From 54a1abf328682e052037fd0ce7fb5e6608ebbd98 Mon Sep 17 00:00:00 2001 From: Hu Date: Wed, 29 Jul 2020 14:52:27 +0800 Subject: [PATCH 04/39] use horovod to train transformer --- .../machine_translation/train_transformer.py | 37 +++++++++++++++---- src/gluonnlp/data/sampler.py | 16 +++++++- 2 files changed, 44 insertions(+), 9 deletions(-) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index d3ca507c32..56c7077c4a 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -44,7 +44,7 @@ from mxnet import gluon from gluonnlp.models.transformer import TransformerNMTModel from gluonnlp.utils.misc import logging_config, AverageSGDTracker, count_parameters,\ - md5sum, grouper + md5sum, grouper, 
init_comm from gluonnlp.data.sampler import ( ConstWidthBucket, LinearWidthBucket, @@ -58,6 +58,11 @@ from gluonnlp.data.tokenizers import BaseTokenizerWithVocab from gluonnlp.lr_scheduler import InverseSquareRootScheduler from gluonnlp.loss import LabelSmoothCrossEntropyLoss +try: + import horovod.mxnet as hvd +except ImportError: + pass + mx.npx.set_np() @@ -162,6 +167,9 @@ def parse_args(): parser.add_argument('--overwrite_cache', action='store_true') parser.add_argument('--fp16', action='store_true', help='Whether to use dtype float16') + parser.add_argument('--comm_backend', type=str, default='device', + choices=['horovod', 'dist_sync_device', 'device'], + help='Communication backend.') parser.add_argument('--gpus', type=str, help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu.') args = parser.parse_args() @@ -280,6 +288,8 @@ def create_tokenizer(tokenizer_type, model_path, vocab_path): def train(args): + store, num_parts, rank, local_rank, is_master_node, ctx_l = init_comm( + args.comm_backend, args.gpus) src_tokenizer = create_tokenizer(args.src_tokenizer, args.src_subword_model_path, args.src_vocab_path) @@ -304,8 +314,6 @@ def train(args): data_val = gluon.data.SimpleDataset( [(src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i) for i, (src_tokens, tgt_tokens) in enumerate(zip(dev_src_data, dev_tgt_data))]) - ctx_l = [mx.cpu()] if args.gpus is None or args.gpus == ''\ - else [mx.gpu(int(x)) for x in args.gpus.split(',')] # Construct the model + loss function if args.cfg.endswith('.yml'): cfg = TransformerNMTModel.get_cfg().clone_merge(args.cfg) @@ -330,6 +338,10 @@ def train(args): from_logits=False) label_smooth_loss.hybridize() rescale_loss = 100.0 + + if args.comm_backend == 'horovod': + hvd.broadcast_parameters(model.collect_params(), root_rank=0) + # Construct the trainer # TODO(sxjscience) Support AMP if args.lr is None: @@ -338,16 +350,25 @@ def train(args): base_lr = args.lr lr_scheduler = InverseSquareRootScheduler(warmup_steps=args.warmup_steps, base_lr=base_lr, warmup_init_lr=args.warmup_init_lr) - trainer = gluon.Trainer(model.collect_params(), 'adam', + trainer_settings = (model.collect_params(), 'adam', {'learning_rate': args.lr, 'beta1': 0.9, 'beta2': 0.98, 'epsilon': 1e-9, 'lr_scheduler': lr_scheduler}) + if args.comm_backend == 'horovod': + trainer = hvd.DistributedTrainer(*trainer_settings) + else: + trainer = gluon.Trainer(*trainer_settings) # Load Data if args.sampler == 'BoundedBudgetSampler': train_batch_sampler = BoundedBudgetSampler(lengths=[(ele[2], ele[3]) for ele in data_train], max_tokens=args.max_tokens, max_sentences=args.max_sentences, - seed=args.seed) + seed=args.seed, + num_parts=num_parts, + part_index=rank) elif args.sampler == 'FixedBucketSampler': + if args.comm_backend == 'horovod': + raise NotImplementedError + if args.bucket_scheme == 'constant': bucket_scheme = ConstWidthBucket() elif args.bucket_scheme == 'linear': @@ -451,7 +472,8 @@ def train(args): if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \ (args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update): model_averager.step() - if n_epoch_train_iters % args.log_interval == 0 or is_last_batch: + if local_rank == 0 and \ + (n_epoch_train_iters % args.log_interval == 0 or is_last_batch): log_end_time = time.time() log_wc = log_wc.asnumpy() wps = log_wc / (log_end_time - log_start_time) @@ -465,7 +487,8 @@ def train(args): log_avg_loss = 0 log_loss_denom = 0 log_wc = 0 - if args.max_update > 0 
and n_train_iters % args.save_interval_update == 0: + if local_rank == 0 and \ + (args.max_update > 0 and n_train_iters % args.save_interval_update == 0): model.save_parameters(os.path.join(args.save_dir, '{:d}.params'.format(n_train_iters // args.save_interval_update)), deduplicate=True) diff --git a/src/gluonnlp/data/sampler.py b/src/gluonnlp/data/sampler.py index 470f8bbf80..6d817dcdef 100644 --- a/src/gluonnlp/data/sampler.py +++ b/src/gluonnlp/data/sampler.py @@ -285,14 +285,20 @@ class BoundedBudgetSampler(BaseSampler): Whether to shuffle the batches. seed The seed of the sampler + num_parts + Number of partitions which the data is split into (default: 1) + part_index + The index of the part to read from """ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]], max_num_tokens: int = -1, max_num_sentences: int = -1, required_batch_size_multiple: int = 1, - shuffle: bool = True, seed: Optional[int] = None): + shuffle: bool = True, seed: Optional[int] = None, + num_parts: int = 1, part_index: int = 0): assert len(lengths) > 0, 'BoundedBudgetSampler does not support empty lengths.' assert max_num_tokens > 0 or max_num_sentences > 0, \ 'One of max_num_tokens and max_num_sentences must be larger than 0' + assert part_index < num_parts, 'part_index should be less than num_parts' self._lengths = np.array(lengths) if self._lengths.ndim == 2: self._lengths = self._lengths.max(axis=1) @@ -302,6 +308,8 @@ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]], self._batches = [] self._shuffle = shuffle self._rng = np.random.RandomState(seed) + self._num_parts = num_parts + self._part_index = part_index # sort self._indices = self._indices[np.argsort(self._lengths, kind='mergesort')] batch = [] @@ -332,7 +340,11 @@ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]], def __iter__(self): if self._shuffle: self._rng.shuffle(self._batches) - for batch in self._batches: + part_batches = [] + for i in range(len(self._batches)): + if i % self._num_parts == self._part_index: + part_batches.append(self._batches[i]) + for batch in part_batches: yield batch def __len__(self): From 5685601fc9f7307aaa268353252ce48beefa3c31 Mon Sep 17 00:00:00 2001 From: Hu Date: Thu, 30 Jul 2020 14:43:37 +0800 Subject: [PATCH 05/39] fix --- scripts/machine_translation/train_transformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index 56c7077c4a..4154000e32 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -136,9 +136,9 @@ def parse_args(): '"exp": the width of bucket increases exponentially') parser.add_argument('--bucket_ratio', type=float, default=0.0, help='Ratio for increasing the throughput of the bucketing') - parser.add_argument('--max_tokens', type=int, default=-1, + parser.add_argument('--max_num_tokens', type=int, default=-1, help='max tokens num of each batch, applicable while using BoundedBudgetSampler') - parser.add_argument('--max_sentences', type=int, default=-1, + parser.add_argument('--max_num_sentences', type=int, default=-1, help='max sentences num of each batch, applicable while using BoundedBudgetSampler') parser.add_argument('--lr', type=float, default=0.002, help='The learning rate at the end of the warmup stage. 
' @@ -360,8 +360,8 @@ def train(args): # Load Data if args.sampler == 'BoundedBudgetSampler': train_batch_sampler = BoundedBudgetSampler(lengths=[(ele[2], ele[3]) for ele in data_train], - max_tokens=args.max_tokens, - max_sentences=args.max_sentences, + max_num_tokens=args.max_num_tokens, + max_num_sentences=args.max_num_sentences, seed=args.seed, num_parts=num_parts, part_index=rank) From 5815f128012bdcc88cb24b02e146fd00a3fbea26 Mon Sep 17 00:00:00 2001 From: Hu Date: Fri, 31 Jul 2020 05:16:56 +0800 Subject: [PATCH 06/39] add mirror wmt dataset --- .../machine_translation/prepare_wmt.py | 15 ++++- scripts/datasets/url_checksums/mirror/wmt.txt | 55 +++++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 scripts/datasets/url_checksums/mirror/wmt.txt diff --git a/scripts/datasets/machine_translation/prepare_wmt.py b/scripts/datasets/machine_translation/prepare_wmt.py index 1e910a70ca..65ade27569 100644 --- a/scripts/datasets/machine_translation/prepare_wmt.py +++ b/scripts/datasets/machine_translation/prepare_wmt.py @@ -7,6 +7,7 @@ import functools import tarfile import gzip +import json from xml.etree import ElementTree from gluonnlp.data.filtering import ProfanityFilter from gluonnlp.utils.misc import file_line_number, download, load_checksum_stats @@ -336,6 +337,8 @@ } } +with open('../../url_checksums/mirror/wmt.txt') as wmt_mirror_map_f: + _WMT_MIRROR_URL_MAP = json.load(wmt_mirror_map_f) def _clean_space(s: str): """Removes trailing and leading spaces and collapses multiple consecutive internal spaces to a single one. @@ -626,7 +629,11 @@ def fetch_mono_dataset(selection: Union[str, List[str], List[List[str]]], save_path_l = [path] + selection + [matched_lang, original_filename] else: save_path_l = [path] + selection + [original_filename] - download_fname = download(url, path=os.path.join(*save_path_l), sha1_hash=sha1_hash) + download_fname = download( + _WMT_MIRROR_URL_MAP[url] if url in _WMT_MIRROR_URL_MAP else url, + path=os.path.join(*save_path_l), + sha1_hash=sha1_hash + ) download_fname_l.append(download_fname) if len(download_fname_l) > 1: data_path = concatenate_files(download_fname_l) @@ -792,7 +799,11 @@ def fetch_wmt_parallel_dataset(selection: Union[str, List[str], List[List[str]]] save_path_l = [path] + selection + [matched_pair, original_filename] else: save_path_l = [path] + selection + [original_filename] - download_fname = download(url, path=os.path.join(*save_path_l), sha1_hash=sha1_hash) + download_fname = download( + _WMT_MIRROR_URL_MAP[url] if url in _WMT_MIRROR_URL_MAP else url, + path=os.path.join(*save_path_l), + sha1_hash=sha1_hash + ) download_fname_l.append(download_fname) if len(download_fname_l) > 1: data_path = concatenate_files(download_fname_l) diff --git a/scripts/datasets/url_checksums/mirror/wmt.txt b/scripts/datasets/url_checksums/mirror/wmt.txt new file mode 100644 index 0000000000..821c0312e9 --- /dev/null +++ b/scripts/datasets/url_checksums/mirror/wmt.txt @@ -0,0 +1,55 @@ +{ + "http://www.statmt.org/europarl/v7/cs-en.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/cs-en-28bad3e096923694fb776b6cd6ba1079546a9e58.tgz", + "http://www.statmt.org/europarl/v7/de-en.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/de-en-53bb5408d22977c89284bd755717e6bbb5b12bc5.tgz", + "http://data.statmt.org/wmt18/translation-task/training-parallel-ep-v8.tgz" : 
"https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-ep-v8-2f5c2c2c98b72921474a3f1837dc5b61dd44ba88.tgz", + "http://www.statmt.org/europarl/v9/training/europarl-v9.cs-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/europarl-v9.cs-en.tsv-e36a1bfe634379ec813b399b57a38093df2349ef.gz", + "http://www.statmt.org/europarl/v9/training/europarl-v9.de-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/europarl-v9.de-en.tsv-d553d0c8189642c1c7ae6ed3c265c847e432057c.gz", + "http://www.statmt.org/europarl/v9/training/europarl-v9.fi-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/europarl-v9.fi-en.tsv-c5d2f6aad04e88dda6ad11a110f4ca24150edca3.gz", + "http://www.statmt.org/europarl/v9/training/europarl-v9.lt-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/europarl-v9.lt-en.tsv-a6343d8fc158f44714ea7d01c0eb65b34640841d.gz", + "https://s3.amazonaws.com/web-language-models/paracrawl/release3/en-cs.bicleaner07.tmx.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/en-cs.bicleaner07.tmx-201fc692d4e730cc63e0b1274f98769eeab2faad.gz", + "https://s3.amazonaws.com/web-language-models/paracrawl/release3/en-de.bicleaner07.tmx.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/en-de.bicleaner07.tmx-7930ac4d7aa1d17467edc04a45f3ed2ffe809a30.gz", + "https://s3.amazonaws.com/web-language-models/paracrawl/release3/en-fi.bicleaner07.tmx.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/en-fi.bicleaner07.tmx-2485ce022a8027a4cec60eed0e35b989d2302e32.gz", + "https://s3.amazonaws.com/web-language-models/paracrawl/release3/en-lt.bicleaner07.tmx.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/en-lt.bicleaner07.tmx-926dfcd0aba9cc46e6e1a099047a49ee01745d63.gz", + "https://s3.amazonaws.com/web-language-models/paracrawl/release1/paracrawl-release1.en-ru.zipporah0-dedup-clean.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/paracrawl-release1.en-ru.zipporah0-dedup-clean-6a4c43a2fac39153f2320984a0f13bf4266696d8.tgz", + "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-commoncrawl-1c0ad85f0ebaf1d543acb009607205f5dae6627d.tgz", + "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-nc-v9-c7ae7f50cd45c2f3014d78ddba25a4a8a851e27a.tgz", + "http://www.statmt.org/wmt15/training-parallel-nc-v10.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-nc-v10-6c3c45b0f34d5e84a4d0b75a5edcca226ba7d6c2.tgz", + "http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-nc-v11-f51a1f03908e790d23d10001e92e09ce9555a790.tgz", + "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-nc-v12-d98afc59e1d753485530b377ff65f1f891d3bced.tgz", + 
"http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-nc-v13-cbaa7834e58d36f228336e3caee6a9056029ff5d.tgz", + "http://data.statmt.org/news-commentary/v14/training/news-commentary-v14.de-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news-commentary-v14.de-en.tsv-c1fd94c7c9ff222968cbd45100bdd8dbeb5ab2aa.gz", + "http://data.statmt.org/news-commentary/v14/training/news-commentary-v14.en-zh.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news-commentary-v14.en-zh.tsv-4ca5c01deeba5425646d42f9598d081cd662908b.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.cs-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.cs-en.tsv-6e094d218dfd8f987fa1a18ea7b4cb127cfb1763.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.cs-pl.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.cs-pl.tsv-dc93d346d151bf73e4165d6db425b903fc21a5b0.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.de-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.de-en.tsv-e141c55c43a474e06c259c3fa401288b39cd4315.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.es-pt.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.es-pt.tsv-c3bd398d57471ee4ab33323393977b8d475a368c.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.fi-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.fi-en.tsv-5668b004567ca286d1aad9c2b45862a441d79667.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.gu-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.gu-en.tsv-95b9f15b6a86bfed6dc9bc91597368fd334f436e.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.hi-ne.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.hi-ne.tsv-6d63908950c72bc8cc69ca470deccff11354afc2.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.kk-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.kk-en.tsv-56ee1e450ef98fe92ea2116c3ce7acc7c7c42b39.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.lt-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.lt-en.tsv-b8829928686727165eec6c591d2875d12d7c0cfe.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.ru-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.ru-en.tsv-16d8d231fdf6347b4cc7834654adec80153ff7a4.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.zh-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.zh-en.tsv-5829097ff7dd61752f29fb306b04d790a1a1cfd7.gz", + "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.00" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/UNv1.0.en-ru-98c4e01e16070567d27da0ab4fe401f309dd3678.tar.gz.00", + "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.01" : 
"https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/UNv1.0.en-ru-86c6013dc88f353d2d6e591928e7549060fcb949.tar.gz.01", + "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.02" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/UNv1.0.en-ru-bf6b18a33c8cafa6889fd463fa8a2850d8877d35.tar.gz.02", + "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.00" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/UNv1.0.en-zh-1bec5f10297512183e483fdd4984d207700657d1.tar.gz.00", + "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.01" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/UNv1.0.en-zh-15df2968bc69ef7662cf3029282bbb62cbf107b1.tar.gz.01", + "http://data.statmt.org/wmt17/translation-task/rapid2016.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/rapid2016-8b173ce0bc77f2a1a57c8134143e3b5ae228a6e2.tgz", + "https://s3-eu-west-1.amazonaws.com/tilde-model/rapid2019.de-en.zip" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/rapid2019.de-en-aafe431338abb98fc20951b2d6011223a1b91311.zip", + "http://data.statmt.org/wmt19/translation-task/dev.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/dev-451ce2cae815c8392212ccb3f54f5dcddb9b2b9e.tgz", + "http://data.statmt.org/wmt19/translation-task/test.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/test-ce02a36fb2cd41abfa19d36eb8c8d50241ed3346.tgz", + "https://gluonnlp-numpy-data.s3-accelerate.amazonaws.com/wmt/cwmt.tar.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/cwmt-88c2f4295169e9f0a9834bf8bff87e3fd4c04055.tar.gz", + "http://data.statmt.org/news-crawl/de/news.2007.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2007.de.shuffled.deduped-9d746b9df345f764e6e615119113c70e3fb0858c.gz", + "http://data.statmt.org/news-crawl/de/news.2008.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2008.de.shuffled.deduped-185a24e8833844486aee16cb5decf9a64da1c101.gz", + "http://data.statmt.org/news-crawl/de/news.2009.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2009.de.shuffled.deduped-9f7645fc6467de88f4205d94f483194838bad8ce.gz", + "http://data.statmt.org/news-crawl/de/news.2010.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2010.de.shuffled.deduped-f29b761194e9606f086102cfac12813931575818.gz", + "http://data.statmt.org/news-crawl/de/news.2011.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2011.de.shuffled.deduped-613b16e7a1cb8559dd428525a4c3b42c8a4dc278.gz", + "http://data.statmt.org/news-crawl/de/news.2012.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2012.de.shuffled.deduped-1bc419364ea3fe2f9ba4236947c012d4198d9282.gz", + "http://data.statmt.org/news-crawl/de/news.2013.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2013.de.shuffled.deduped-3edd84a7f105907608371c81babc7a9078f40aac.gz", + 
"http://data.statmt.org/news-crawl/de/news.2014.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2014.de.shuffled.deduped-1466c67b330c08ab5ab7d48e666c1d3a0bb4e479.gz", + "http://data.statmt.org/news-crawl/de/news.2015.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2015.de.shuffled.deduped-2c6d5ec9f8fe51e9eb762be8ff7107c6116c00c4.gz", + "http://data.statmt.org/news-crawl/de/news.2016.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2016.de.shuffled.deduped-e7d235c5d28e36dcf6382f1aa12c6ff37d4529bb.gz", + "http://data.statmt.org/news-crawl/de/news.2017.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2017.de.shuffled.deduped-f70b4a67bc04c0fdc2ec955b737fa22681e8c038.gz", + "http://data.statmt.org/news-crawl/de/news.2018.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2018.de.shuffled.deduped-43f8237de1e219276c0682255def13aa2cb80e35.gz" +} \ No newline at end of file From 40018218f7bc4e4c04035f37038ec7d6c0b655c7 Mon Sep 17 00:00:00 2001 From: Hu Date: Fri, 31 Jul 2020 05:21:35 +0800 Subject: [PATCH 07/39] fix --- scripts/datasets/machine_translation/prepare_wmt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/datasets/machine_translation/prepare_wmt.py b/scripts/datasets/machine_translation/prepare_wmt.py index 65ade27569..8acc05d1e5 100644 --- a/scripts/datasets/machine_translation/prepare_wmt.py +++ b/scripts/datasets/machine_translation/prepare_wmt.py @@ -337,7 +337,7 @@ } } -with open('../../url_checksums/mirror/wmt.txt') as wmt_mirror_map_f: +with open(os.path.join(_CURR_DIR, '..', 'url_checksums', 'mirror', 'wmt.txt')) as wmt_mirror_map_f: _WMT_MIRROR_URL_MAP = json.load(wmt_mirror_map_f) def _clean_space(s: str): From 5ba3789c575d52827ec0eb34179dcff2727cf9e1 Mon Sep 17 00:00:00 2001 From: Hu Date: Sun, 2 Aug 2020 16:46:39 +0800 Subject: [PATCH 08/39] rename wmt.txt to wmt.json and remove part of urls --- .../datasets/url_checksums/mirror/{wmt.txt => wmt.json} | 7 ------- 1 file changed, 7 deletions(-) rename scripts/datasets/url_checksums/mirror/{wmt.txt => wmt.json} (85%) diff --git a/scripts/datasets/url_checksums/mirror/wmt.txt b/scripts/datasets/url_checksums/mirror/wmt.json similarity index 85% rename from scripts/datasets/url_checksums/mirror/wmt.txt rename to scripts/datasets/url_checksums/mirror/wmt.json index 821c0312e9..88f6e3b13f 100644 --- a/scripts/datasets/url_checksums/mirror/wmt.txt +++ b/scripts/datasets/url_checksums/mirror/wmt.json @@ -6,11 +6,6 @@ "http://www.statmt.org/europarl/v9/training/europarl-v9.de-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/europarl-v9.de-en.tsv-d553d0c8189642c1c7ae6ed3c265c847e432057c.gz", "http://www.statmt.org/europarl/v9/training/europarl-v9.fi-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/europarl-v9.fi-en.tsv-c5d2f6aad04e88dda6ad11a110f4ca24150edca3.gz", "http://www.statmt.org/europarl/v9/training/europarl-v9.lt-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/europarl-v9.lt-en.tsv-a6343d8fc158f44714ea7d01c0eb65b34640841d.gz", - "https://s3.amazonaws.com/web-language-models/paracrawl/release3/en-cs.bicleaner07.tmx.gz" : 
"https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/en-cs.bicleaner07.tmx-201fc692d4e730cc63e0b1274f98769eeab2faad.gz", - "https://s3.amazonaws.com/web-language-models/paracrawl/release3/en-de.bicleaner07.tmx.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/en-de.bicleaner07.tmx-7930ac4d7aa1d17467edc04a45f3ed2ffe809a30.gz", - "https://s3.amazonaws.com/web-language-models/paracrawl/release3/en-fi.bicleaner07.tmx.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/en-fi.bicleaner07.tmx-2485ce022a8027a4cec60eed0e35b989d2302e32.gz", - "https://s3.amazonaws.com/web-language-models/paracrawl/release3/en-lt.bicleaner07.tmx.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/en-lt.bicleaner07.tmx-926dfcd0aba9cc46e6e1a099047a49ee01745d63.gz", - "https://s3.amazonaws.com/web-language-models/paracrawl/release1/paracrawl-release1.en-ru.zipporah0-dedup-clean.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/paracrawl-release1.en-ru.zipporah0-dedup-clean-6a4c43a2fac39153f2320984a0f13bf4266696d8.tgz", "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-commoncrawl-1c0ad85f0ebaf1d543acb009607205f5dae6627d.tgz", "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-nc-v9-c7ae7f50cd45c2f3014d78ddba25a4a8a851e27a.tgz", "http://www.statmt.org/wmt15/training-parallel-nc-v10.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-nc-v10-6c3c45b0f34d5e84a4d0b75a5edcca226ba7d6c2.tgz", @@ -36,10 +31,8 @@ "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.00" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/UNv1.0.en-zh-1bec5f10297512183e483fdd4984d207700657d1.tar.gz.00", "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.01" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/UNv1.0.en-zh-15df2968bc69ef7662cf3029282bbb62cbf107b1.tar.gz.01", "http://data.statmt.org/wmt17/translation-task/rapid2016.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/rapid2016-8b173ce0bc77f2a1a57c8134143e3b5ae228a6e2.tgz", - "https://s3-eu-west-1.amazonaws.com/tilde-model/rapid2019.de-en.zip" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/rapid2019.de-en-aafe431338abb98fc20951b2d6011223a1b91311.zip", "http://data.statmt.org/wmt19/translation-task/dev.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/dev-451ce2cae815c8392212ccb3f54f5dcddb9b2b9e.tgz", "http://data.statmt.org/wmt19/translation-task/test.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/test-ce02a36fb2cd41abfa19d36eb8c8d50241ed3346.tgz", - "https://gluonnlp-numpy-data.s3-accelerate.amazonaws.com/wmt/cwmt.tar.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/cwmt-88c2f4295169e9f0a9834bf8bff87e3fd4c04055.tar.gz", "http://data.statmt.org/news-crawl/de/news.2007.de.shuffled.deduped.gz" : 
"https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2007.de.shuffled.deduped-9d746b9df345f764e6e615119113c70e3fb0858c.gz", "http://data.statmt.org/news-crawl/de/news.2008.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2008.de.shuffled.deduped-185a24e8833844486aee16cb5decf9a64da1c101.gz", "http://data.statmt.org/news-crawl/de/news.2009.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2009.de.shuffled.deduped-9f7645fc6467de88f4205d94f483194838bad8ce.gz", From c99503fdf9078619f656fbe60bf28dbda3638749 Mon Sep 17 00:00:00 2001 From: Hu Date: Sun, 2 Aug 2020 16:58:06 +0800 Subject: [PATCH 09/39] fix --- scripts/machine_translation/train_transformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index 4154000e32..12842db4fe 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -61,7 +61,7 @@ try: import horovod.mxnet as hvd except ImportError: - pass + hvd = None mx.npx.set_np() @@ -156,10 +156,10 @@ def parse_args(): 'This is useful to mimic large batch training with limited gpu memory') parser.add_argument('--magnitude', type=float, default=3.0, help='Magnitude of Xavier initialization') - parser.add_argument('--num_averages', type=int, default=5, + parser.add_argument('--num_averages', type=int, default=-1, help='Perform final testing based on the ' 'average of last num_averages checkpoints. ' - 'This is only used if average_checkpoint is True') + 'Use num_average will cause extra gpu memory usage.') parser.add_argument('--log_interval', type=int, default=10, metavar='N', help='report interval') parser.add_argument('--save_dir', type=str, default='transformer_out', @@ -367,7 +367,7 @@ def train(args): part_index=rank) elif args.sampler == 'FixedBucketSampler': if args.comm_backend == 'horovod': - raise NotImplementedError + raise NotImplementedError('FixedBucketSampler does not support horovod at present') if args.bucket_scheme == 'constant': bucket_scheme = ConstWidthBucket() From cf8bcd3b9c4baf1f07d6c4eb96154b9c515b98fd Mon Sep 17 00:00:00 2001 From: Hu Date: Sun, 2 Aug 2020 16:58:48 +0800 Subject: [PATCH 10/39] tuning params --- .../wmt2014_back_translation.sh | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/scripts/machine_translation/wmt2014_back_translation.sh b/scripts/machine_translation/wmt2014_back_translation.sh index 29ad1dfd33..8597759c48 100644 --- a/scripts/machine_translation/wmt2014_back_translation.sh +++ b/scripts/machine_translation/wmt2014_back_translation.sh @@ -126,13 +126,16 @@ python train_transformer.py \ --tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \ --save_dir backtranslation_transformer_wmt2014_ende_${SUBWORD_ALGO} \ --cfg transformer_nmt_base \ - --lr 0.002 \ - --batch_size 2700 \ - --max_update 60000 \ + --lr 0.003 \ + --max_num_tokens 4096 \ + --sampler BoundedBudgetSampler \ + --comm_backend horovod \ + --max_update 30000 \ --save_interval_update 1000 \ - --warmup_steps 4000 \ + --warmup_steps 6000 \ --warmup_init_lr 0.0 \ - --seed 100 \ + --num_averages -1 \ + --seed 123 \ --gpus 0,1,2,3 # TODO nlp_average_checkpoint @@ -142,7 +145,7 @@ nlp_nmt average_checkpoint --prefix range() \ # Finally, we can evaluate the model python evaluate_transformer.py 
\ - --param_path backtranslation_transformer_wmt2014_ende_${SUBWORD_ALGO}/average.params \ + --param_path backtranslation_transformer_wmt2014_ende_${SUBWORD_ALGO}/avg_20_29.params \ --src_lang ${SRC} \ --tgt_lang ${TGT} \ --cfg transformer_nmt_base \ From cc760d43d8d3373a4d8d609b4ec6bc82975dcdb3 Mon Sep 17 00:00:00 2001 From: Hu Date: Sun, 2 Aug 2020 17:12:31 +0800 Subject: [PATCH 11/39] use get_repo_url() --- .../machine_translation/prepare_wmt.py | 19 ++-- .../datasets/url_checksums/mirror/wmt.json | 92 +++++++++---------- 2 files changed, 59 insertions(+), 52 deletions(-) diff --git a/scripts/datasets/machine_translation/prepare_wmt.py b/scripts/datasets/machine_translation/prepare_wmt.py index 8acc05d1e5..2ac5f77772 100644 --- a/scripts/datasets/machine_translation/prepare_wmt.py +++ b/scripts/datasets/machine_translation/prepare_wmt.py @@ -11,7 +11,7 @@ from xml.etree import ElementTree from gluonnlp.data.filtering import ProfanityFilter from gluonnlp.utils.misc import file_line_number, download, load_checksum_stats -from gluonnlp.base import get_data_home_dir +from gluonnlp.base import get_data_home_dir, get_repo_url from gluonnlp.registry import DATA_PARSER_REGISTRY, DATA_MAIN_REGISTRY # The datasets are provided by WMT2014-WMT2019 and can be freely used for research purposes. @@ -337,9 +337,16 @@ } } -with open(os.path.join(_CURR_DIR, '..', 'url_checksums', 'mirror', 'wmt.txt')) as wmt_mirror_map_f: +with open(os.path.join(_CURR_DIR, '..', 'url_checksums', 'mirror', 'wmt.json')) as wmt_mirror_map_f: _WMT_MIRROR_URL_MAP = json.load(wmt_mirror_map_f) +def _download_with_mirror(url, path, sha1_hash): + return download( + get_repo_url() + _WMT_MIRROR_URL_MAP[url] if url in _WMT_MIRROR_URL_MAP else url, + path=path, + sha1_hash=sha1_hash + ) + def _clean_space(s: str): """Removes trailing and leading spaces and collapses multiple consecutive internal spaces to a single one. 
This is borrowed from sacrebleu: https://github.com/mjpost/sacreBLEU/blob/069b0c88fceb29f3e24c3c19ba25342a3e7f96cb/sacrebleu.py#L1077 @@ -629,8 +636,8 @@ def fetch_mono_dataset(selection: Union[str, List[str], List[List[str]]], save_path_l = [path] + selection + [matched_lang, original_filename] else: save_path_l = [path] + selection + [original_filename] - download_fname = download( - _WMT_MIRROR_URL_MAP[url] if url in _WMT_MIRROR_URL_MAP else url, + download_fname = _download_with_mirror( + url, path=os.path.join(*save_path_l), sha1_hash=sha1_hash ) @@ -799,8 +806,8 @@ def fetch_wmt_parallel_dataset(selection: Union[str, List[str], List[List[str]]] save_path_l = [path] + selection + [matched_pair, original_filename] else: save_path_l = [path] + selection + [original_filename] - download_fname = download( - _WMT_MIRROR_URL_MAP[url] if url in _WMT_MIRROR_URL_MAP else url, + download_fname = _download_with_mirror( + url, path=os.path.join(*save_path_l), sha1_hash=sha1_hash ) diff --git a/scripts/datasets/url_checksums/mirror/wmt.json b/scripts/datasets/url_checksums/mirror/wmt.json index 88f6e3b13f..fa695f6bd9 100644 --- a/scripts/datasets/url_checksums/mirror/wmt.json +++ b/scripts/datasets/url_checksums/mirror/wmt.json @@ -1,48 +1,48 @@ { - "http://www.statmt.org/europarl/v7/cs-en.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/cs-en-28bad3e096923694fb776b6cd6ba1079546a9e58.tgz", - "http://www.statmt.org/europarl/v7/de-en.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/de-en-53bb5408d22977c89284bd755717e6bbb5b12bc5.tgz", - "http://data.statmt.org/wmt18/translation-task/training-parallel-ep-v8.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-ep-v8-2f5c2c2c98b72921474a3f1837dc5b61dd44ba88.tgz", - "http://www.statmt.org/europarl/v9/training/europarl-v9.cs-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/europarl-v9.cs-en.tsv-e36a1bfe634379ec813b399b57a38093df2349ef.gz", - "http://www.statmt.org/europarl/v9/training/europarl-v9.de-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/europarl-v9.de-en.tsv-d553d0c8189642c1c7ae6ed3c265c847e432057c.gz", - "http://www.statmt.org/europarl/v9/training/europarl-v9.fi-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/europarl-v9.fi-en.tsv-c5d2f6aad04e88dda6ad11a110f4ca24150edca3.gz", - "http://www.statmt.org/europarl/v9/training/europarl-v9.lt-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/europarl-v9.lt-en.tsv-a6343d8fc158f44714ea7d01c0eb65b34640841d.gz", - "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-commoncrawl-1c0ad85f0ebaf1d543acb009607205f5dae6627d.tgz", - "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-nc-v9-c7ae7f50cd45c2f3014d78ddba25a4a8a851e27a.tgz", - "http://www.statmt.org/wmt15/training-parallel-nc-v10.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-nc-v10-6c3c45b0f34d5e84a4d0b75a5edcca226ba7d6c2.tgz", - "http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz" : 
"https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-nc-v11-f51a1f03908e790d23d10001e92e09ce9555a790.tgz", - "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-nc-v12-d98afc59e1d753485530b377ff65f1f891d3bced.tgz", - "http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/training-parallel-nc-v13-cbaa7834e58d36f228336e3caee6a9056029ff5d.tgz", - "http://data.statmt.org/news-commentary/v14/training/news-commentary-v14.de-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news-commentary-v14.de-en.tsv-c1fd94c7c9ff222968cbd45100bdd8dbeb5ab2aa.gz", - "http://data.statmt.org/news-commentary/v14/training/news-commentary-v14.en-zh.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news-commentary-v14.en-zh.tsv-4ca5c01deeba5425646d42f9598d081cd662908b.gz", - "http://data.statmt.org/wikititles/v1/wikititles-v1.cs-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.cs-en.tsv-6e094d218dfd8f987fa1a18ea7b4cb127cfb1763.gz", - "http://data.statmt.org/wikititles/v1/wikititles-v1.cs-pl.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.cs-pl.tsv-dc93d346d151bf73e4165d6db425b903fc21a5b0.gz", - "http://data.statmt.org/wikititles/v1/wikititles-v1.de-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.de-en.tsv-e141c55c43a474e06c259c3fa401288b39cd4315.gz", - "http://data.statmt.org/wikititles/v1/wikititles-v1.es-pt.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.es-pt.tsv-c3bd398d57471ee4ab33323393977b8d475a368c.gz", - "http://data.statmt.org/wikititles/v1/wikititles-v1.fi-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.fi-en.tsv-5668b004567ca286d1aad9c2b45862a441d79667.gz", - "http://data.statmt.org/wikititles/v1/wikititles-v1.gu-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.gu-en.tsv-95b9f15b6a86bfed6dc9bc91597368fd334f436e.gz", - "http://data.statmt.org/wikititles/v1/wikititles-v1.hi-ne.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.hi-ne.tsv-6d63908950c72bc8cc69ca470deccff11354afc2.gz", - "http://data.statmt.org/wikititles/v1/wikititles-v1.kk-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.kk-en.tsv-56ee1e450ef98fe92ea2116c3ce7acc7c7c42b39.gz", - "http://data.statmt.org/wikititles/v1/wikititles-v1.lt-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.lt-en.tsv-b8829928686727165eec6c591d2875d12d7c0cfe.gz", - "http://data.statmt.org/wikititles/v1/wikititles-v1.ru-en.tsv.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.ru-en.tsv-16d8d231fdf6347b4cc7834654adec80153ff7a4.gz", - "http://data.statmt.org/wikititles/v1/wikititles-v1.zh-en.tsv.gz" : 
"https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/wikititles-v1.zh-en.tsv-5829097ff7dd61752f29fb306b04d790a1a1cfd7.gz", - "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.00" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/UNv1.0.en-ru-98c4e01e16070567d27da0ab4fe401f309dd3678.tar.gz.00", - "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.01" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/UNv1.0.en-ru-86c6013dc88f353d2d6e591928e7549060fcb949.tar.gz.01", - "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.02" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/UNv1.0.en-ru-bf6b18a33c8cafa6889fd463fa8a2850d8877d35.tar.gz.02", - "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.00" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/UNv1.0.en-zh-1bec5f10297512183e483fdd4984d207700657d1.tar.gz.00", - "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.01" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/UNv1.0.en-zh-15df2968bc69ef7662cf3029282bbb62cbf107b1.tar.gz.01", - "http://data.statmt.org/wmt17/translation-task/rapid2016.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/rapid2016-8b173ce0bc77f2a1a57c8134143e3b5ae228a6e2.tgz", - "http://data.statmt.org/wmt19/translation-task/dev.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/dev-451ce2cae815c8392212ccb3f54f5dcddb9b2b9e.tgz", - "http://data.statmt.org/wmt19/translation-task/test.tgz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/test-ce02a36fb2cd41abfa19d36eb8c8d50241ed3346.tgz", - "http://data.statmt.org/news-crawl/de/news.2007.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2007.de.shuffled.deduped-9d746b9df345f764e6e615119113c70e3fb0858c.gz", - "http://data.statmt.org/news-crawl/de/news.2008.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2008.de.shuffled.deduped-185a24e8833844486aee16cb5decf9a64da1c101.gz", - "http://data.statmt.org/news-crawl/de/news.2009.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2009.de.shuffled.deduped-9f7645fc6467de88f4205d94f483194838bad8ce.gz", - "http://data.statmt.org/news-crawl/de/news.2010.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2010.de.shuffled.deduped-f29b761194e9606f086102cfac12813931575818.gz", - "http://data.statmt.org/news-crawl/de/news.2011.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2011.de.shuffled.deduped-613b16e7a1cb8559dd428525a4c3b42c8a4dc278.gz", - "http://data.statmt.org/news-crawl/de/news.2012.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2012.de.shuffled.deduped-1bc419364ea3fe2f9ba4236947c012d4198d9282.gz", - "http://data.statmt.org/news-crawl/de/news.2013.de.shuffled.deduped.gz" : 
"https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2013.de.shuffled.deduped-3edd84a7f105907608371c81babc7a9078f40aac.gz", - "http://data.statmt.org/news-crawl/de/news.2014.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2014.de.shuffled.deduped-1466c67b330c08ab5ab7d48e666c1d3a0bb4e479.gz", - "http://data.statmt.org/news-crawl/de/news.2015.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2015.de.shuffled.deduped-2c6d5ec9f8fe51e9eb762be8ff7107c6116c00c4.gz", - "http://data.statmt.org/news-crawl/de/news.2016.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2016.de.shuffled.deduped-e7d235c5d28e36dcf6382f1aa12c6ff37d4529bb.gz", - "http://data.statmt.org/news-crawl/de/news.2017.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2017.de.shuffled.deduped-f70b4a67bc04c0fdc2ec955b737fa22681e8c038.gz", - "http://data.statmt.org/news-crawl/de/news.2018.de.shuffled.deduped.gz" : "https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/datasets/third_party_mirror/news.2018.de.shuffled.deduped-43f8237de1e219276c0682255def13aa2cb80e35.gz" + "http://www.statmt.org/europarl/v7/cs-en.tgz" : "datasets/third_party_mirror/cs-en-28bad3e096923694fb776b6cd6ba1079546a9e58.tgz", + "http://www.statmt.org/europarl/v7/de-en.tgz" : "datasets/third_party_mirror/de-en-53bb5408d22977c89284bd755717e6bbb5b12bc5.tgz", + "http://data.statmt.org/wmt18/translation-task/training-parallel-ep-v8.tgz" : "datasets/third_party_mirror/training-parallel-ep-v8-2f5c2c2c98b72921474a3f1837dc5b61dd44ba88.tgz", + "http://www.statmt.org/europarl/v9/training/europarl-v9.cs-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.cs-en.tsv-e36a1bfe634379ec813b399b57a38093df2349ef.gz", + "http://www.statmt.org/europarl/v9/training/europarl-v9.de-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.de-en.tsv-d553d0c8189642c1c7ae6ed3c265c847e432057c.gz", + "http://www.statmt.org/europarl/v9/training/europarl-v9.fi-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.fi-en.tsv-c5d2f6aad04e88dda6ad11a110f4ca24150edca3.gz", + "http://www.statmt.org/europarl/v9/training/europarl-v9.lt-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.lt-en.tsv-a6343d8fc158f44714ea7d01c0eb65b34640841d.gz", + "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz" : "datasets/third_party_mirror/training-parallel-commoncrawl-1c0ad85f0ebaf1d543acb009607205f5dae6627d.tgz", + "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz" : "datasets/third_party_mirror/training-parallel-nc-v9-c7ae7f50cd45c2f3014d78ddba25a4a8a851e27a.tgz", + "http://www.statmt.org/wmt15/training-parallel-nc-v10.tgz" : "datasets/third_party_mirror/training-parallel-nc-v10-6c3c45b0f34d5e84a4d0b75a5edcca226ba7d6c2.tgz", + "http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz" : "datasets/third_party_mirror/training-parallel-nc-v11-f51a1f03908e790d23d10001e92e09ce9555a790.tgz", + "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz" : "datasets/third_party_mirror/training-parallel-nc-v12-d98afc59e1d753485530b377ff65f1f891d3bced.tgz", + "http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz" : "datasets/third_party_mirror/training-parallel-nc-v13-cbaa7834e58d36f228336e3caee6a9056029ff5d.tgz", + 
"http://data.statmt.org/news-commentary/v14/training/news-commentary-v14.de-en.tsv.gz" : "datasets/third_party_mirror/news-commentary-v14.de-en.tsv-c1fd94c7c9ff222968cbd45100bdd8dbeb5ab2aa.gz", + "http://data.statmt.org/news-commentary/v14/training/news-commentary-v14.en-zh.tsv.gz" : "datasets/third_party_mirror/news-commentary-v14.en-zh.tsv-4ca5c01deeba5425646d42f9598d081cd662908b.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.cs-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.cs-en.tsv-6e094d218dfd8f987fa1a18ea7b4cb127cfb1763.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.cs-pl.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.cs-pl.tsv-dc93d346d151bf73e4165d6db425b903fc21a5b0.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.de-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.de-en.tsv-e141c55c43a474e06c259c3fa401288b39cd4315.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.es-pt.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.es-pt.tsv-c3bd398d57471ee4ab33323393977b8d475a368c.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.fi-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.fi-en.tsv-5668b004567ca286d1aad9c2b45862a441d79667.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.gu-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.gu-en.tsv-95b9f15b6a86bfed6dc9bc91597368fd334f436e.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.hi-ne.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.hi-ne.tsv-6d63908950c72bc8cc69ca470deccff11354afc2.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.kk-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.kk-en.tsv-56ee1e450ef98fe92ea2116c3ce7acc7c7c42b39.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.lt-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.lt-en.tsv-b8829928686727165eec6c591d2875d12d7c0cfe.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.ru-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.ru-en.tsv-16d8d231fdf6347b4cc7834654adec80153ff7a4.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.zh-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.zh-en.tsv-5829097ff7dd61752f29fb306b04d790a1a1cfd7.gz", + "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.00" : "datasets/third_party_mirror/UNv1.0.en-ru-98c4e01e16070567d27da0ab4fe401f309dd3678.tar.gz.00", + "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.01" : "datasets/third_party_mirror/UNv1.0.en-ru-86c6013dc88f353d2d6e591928e7549060fcb949.tar.gz.01", + "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.02" : "datasets/third_party_mirror/UNv1.0.en-ru-bf6b18a33c8cafa6889fd463fa8a2850d8877d35.tar.gz.02", + "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.00" : "datasets/third_party_mirror/UNv1.0.en-zh-1bec5f10297512183e483fdd4984d207700657d1.tar.gz.00", + "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.01" : "datasets/third_party_mirror/UNv1.0.en-zh-15df2968bc69ef7662cf3029282bbb62cbf107b1.tar.gz.01", + "http://data.statmt.org/wmt17/translation-task/rapid2016.tgz" : "datasets/third_party_mirror/rapid2016-8b173ce0bc77f2a1a57c8134143e3b5ae228a6e2.tgz", + "http://data.statmt.org/wmt19/translation-task/dev.tgz" : "datasets/third_party_mirror/dev-451ce2cae815c8392212ccb3f54f5dcddb9b2b9e.tgz", + "http://data.statmt.org/wmt19/translation-task/test.tgz" : 
"datasets/third_party_mirror/test-ce02a36fb2cd41abfa19d36eb8c8d50241ed3346.tgz", + "http://data.statmt.org/news-crawl/de/news.2007.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2007.de.shuffled.deduped-9d746b9df345f764e6e615119113c70e3fb0858c.gz", + "http://data.statmt.org/news-crawl/de/news.2008.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2008.de.shuffled.deduped-185a24e8833844486aee16cb5decf9a64da1c101.gz", + "http://data.statmt.org/news-crawl/de/news.2009.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2009.de.shuffled.deduped-9f7645fc6467de88f4205d94f483194838bad8ce.gz", + "http://data.statmt.org/news-crawl/de/news.2010.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2010.de.shuffled.deduped-f29b761194e9606f086102cfac12813931575818.gz", + "http://data.statmt.org/news-crawl/de/news.2011.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2011.de.shuffled.deduped-613b16e7a1cb8559dd428525a4c3b42c8a4dc278.gz", + "http://data.statmt.org/news-crawl/de/news.2012.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2012.de.shuffled.deduped-1bc419364ea3fe2f9ba4236947c012d4198d9282.gz", + "http://data.statmt.org/news-crawl/de/news.2013.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2013.de.shuffled.deduped-3edd84a7f105907608371c81babc7a9078f40aac.gz", + "http://data.statmt.org/news-crawl/de/news.2014.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2014.de.shuffled.deduped-1466c67b330c08ab5ab7d48e666c1d3a0bb4e479.gz", + "http://data.statmt.org/news-crawl/de/news.2015.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2015.de.shuffled.deduped-2c6d5ec9f8fe51e9eb762be8ff7107c6116c00c4.gz", + "http://data.statmt.org/news-crawl/de/news.2016.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2016.de.shuffled.deduped-e7d235c5d28e36dcf6382f1aa12c6ff37d4529bb.gz", + "http://data.statmt.org/news-crawl/de/news.2017.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2017.de.shuffled.deduped-f70b4a67bc04c0fdc2ec955b737fa22681e8c038.gz", + "http://data.statmt.org/news-crawl/de/news.2018.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2018.de.shuffled.deduped-43f8237de1e219276c0682255def13aa2cb80e35.gz" } \ No newline at end of file From 93243dea29be9cfaccb62036139bbedb0a1d508c Mon Sep 17 00:00:00 2001 From: Hu Date: Tue, 4 Aug 2020 08:44:05 +0800 Subject: [PATCH 12/39] update average checkpoint cli --- src/gluonnlp/cli/average_checkpoint.py | 53 ++++++++++++++------------ 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/src/gluonnlp/cli/average_checkpoint.py b/src/gluonnlp/cli/average_checkpoint.py index 8a4ce86b63..bec986f114 100644 --- a/src/gluonnlp/cli/average_checkpoint.py +++ b/src/gluonnlp/cli/average_checkpoint.py @@ -1,39 +1,44 @@ import argparse import mxnet as mx +import re mx.npx.set_np() def get_parser(): parser = argparse.ArgumentParser(description='Script to average the checkpoints') - parser.add_argument('--checkpoints', type=str, required=True, - help='path of checkpoints, use * to represent the numbers, ' - 'e.g. --checkpoints folder/epoch*.prams') - parser.add_argument('--range', type=str, nargs='+', required=True, - help='number of checkpoints, supports range and list format at present, ' - 'e.g. 
--range range(3) [4,7, 5] range(8,100,2)') + parser.add_argument('--checkpoints', type=str, required=True, nargs='+', + help='checkpoint file paths, supports two format: ' + '--checkpoints folder/epoch*.params or --checkpoints folder/*.param') + parser.add_argument('--begin', type=int, required=True, help='begin number of checkpoints') + parser.add_argument('--end', type=int, required=True, help='end number of checkpoints') parser.add_argument('--save-path', type=str, required=True, help='Path of the output file') return parser def main(args): - temp_range = [] - try: - for r in args.range: - if len(r) > 5 and r[:5] == 'range': - r = r[5:].strip()[1:-1].split(',') - r = tuple([int(n.strip()) for n in r]) - assert len(r) >= 1 and len(r) <= 3 - temp_range.extend(range(*r)) - elif r[0] == '[' and r[-1] == ']': - r = r[1:-1].split(',') - r = [int(n.strip()) for n in r] - temp_range.extend(r) - else: - raise NotImplementedError - except: - raise Exception('wrong range format') - args.range = temp_range - ckpt_paths = [args.checkpoints.replace('*', str(i)) for i in args.range] + assert args.begin >= 0 + assert args.end >= args.begin + args.range = list(range(args.begin, args.end + 1)) + + ckpt_epochs_regexp = re.compile(r'(.*\/)epoch(\d+)\.params') + ckpt_updates_regexp = re.compile(r'(.*\/)(\d+)\.params') + ckpt_path = args.checkpoints[0] + if ckpt_epochs_regexp.fullmatch(ckpt_path) is not None: + ckpt_regexp = ckpt_epochs_regexp + elif ckpt_updates_regexp.fullmatch(ckpt_path) is not None: + ckpt_regexp = ckpt_updates_regexp + else: + raise Exception('Wrong checkpoints path format') + + ckpt_paths = [] + for path in args.checkpoints: + m = ckpt_regexp.fullmatch(path) + assert m is not None, 'Wrong checkpoints path format' + num = int(m.group(2)) + if num >= args.begin and num <= args.end: + ckpt_paths.append(path) + + assert len(ckpt_paths) > 0 res = mx.npx.load(ckpt_paths[0]) keys = res.keys() for ckpt_path in ckpt_paths[1:]: From 73b942e3ae41ef11b775445f9dfe00db7f9a3f25 Mon Sep 17 00:00:00 2001 From: Hu Date: Tue, 4 Aug 2020 08:44:25 +0800 Subject: [PATCH 13/39] paste result of transformer large --- scripts/machine_translation/README.md | 44 ++++++++++++++++++--------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/scripts/machine_translation/README.md b/scripts/machine_translation/README.md index 8b5d0695f1..0c799403e5 100644 --- a/scripts/machine_translation/README.md +++ b/scripts/machine_translation/README.md @@ -15,30 +15,46 @@ Then, you can run the experiment, we use the ```bash SUBWORD_MODEL=yttm +SRC=en +TGT=de +datapath=../datasets/machine_translation python train_transformer.py \ - --train_src_corpus ../datasets/machine_translation/wmt2014_ende/train.tok.${SUBWORD_MODEL}.en \ - --train_tgt_corpus ../datasets/machine_translation/wmt2014_ende/train.tok.${SUBWORD_MODEL}.de \ - --dev_src_corpus ../datasets/machine_translation/wmt2014_ende/dev.tok.${SUBWORD_MODEL}.en \ - --dev_tgt_corpus ../datasets/machine_translation/wmt2014_ende/dev.tok.${SUBWORD_MODEL}.de \ - --src_subword_model_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.model \ - --src_vocab_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.vocab \ - --tgt_subword_model_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.model \ - --tgt_vocab_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.vocab \ - --save_dir transformer_wmt2014_ende_${SUBWORD_MODEL} \ - --cfg transformer_base \ - --lr 0.002 \ + --train_src_corpus 
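For readers skimming the patch above: the reworked CLI keeps only checkpoint files whose epoch/update number falls inside the [--begin, --end] window and then averages the loaded parameter dicts. The sketch below is a standalone illustration of that flow, not the script itself; the file names are invented and plain NumPy arrays stand in for the parameter dicts that mx.npx.load would return.

```python
import re
import numpy as np

# Same idea as the patch: pull the numeric suffix out of epoch<N>.params names.
epoch_regexp = re.compile(r'(.*/)?epoch(\d+)\.params')

def select_checkpoints(paths, begin, end):
    """Keep paths whose epoch number lies in [begin, end]."""
    kept = []
    for path in paths:
        m = epoch_regexp.fullmatch(path)
        if m is not None and begin <= int(m.group(2)) <= end:
            kept.append(path)
    return kept

paths = ['folder/epoch28.params', 'folder/epoch29.params', 'folder/epoch30.params']
chosen = select_checkpoints(paths, begin=29, end=30)

# Averaging step; random arrays stand in for the dicts mx.npx.load() would give.
param_dicts = [{'w': np.random.rand(2, 3), 'b': np.random.rand(3)} for _ in chosen]
averaged = {k: sum(d[k] for d in param_dicts) / len(param_dicts) for k in param_dicts[0]}
print(chosen, averaged['w'].shape)
```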
${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${SRC} \ + --train_tgt_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${TGT} \ + --dev_src_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${SRC} \ + --dev_tgt_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${TGT} \ + --src_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \ + --src_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \ + --tgt_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \ + --tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \ + --save_dir transformer_big_wmt2014_de_en_${SUBWORD_ALGO} \ + --cfg transformer_wmt_en_de_big \ + --lr 0.001 \ + --sampler BoundedBudgetSampler \ + --max_num_tokens 3584 \ + --max_update 15000 \ --warmup_steps 4000 \ --warmup_init_lr 0.0 \ --seed 123 \ --gpus 0,1,2,3 ``` +Use the average_checkpoint cli to average the last 10 checkpoints + +```bash +gluon_average_checkpoint --checkpoints transformer_big_wmt2014_de_en_${SUBWORD_ALGO}/epoch*.params \ + --begin 21 \ + --end 30 \ + --save-path transformer_big_wmt2014_de_en_${SUBWORD_ALGO}/avg_21_30.params +``` + + Use the following command to inference/evaluate the Transformer model: ```bash SUBWORD_MODEL=yttm python evaluate_transformer.py \ - --param_path transformer_wmt2014_ende_${SUBWORD_MODEL}/average.params \ + --param_path transformer_wmt2014_ende_${SUBWORD_MODEL}/average_21_30.params \ --src_lang en \ --tgt_lang de \ --cfg transformer_wmt2014_ende_${SUBWORD_MODEL}/config.yml \ @@ -55,10 +71,10 @@ python evaluate_transformer.py \ Test BLEU score with 3 seeds (evaluated via sacre BLEU): -- transformer_base +- transformer_wmt_en_de_big | Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | MeanĀ±std | |---------------|------------|-------------|-------------|--------------|-------------| -| yttm | | 26.63 | 26.73 | | - | +| yttm | | 27.99 | - | - | - | | hf_bpe | | - | - | - | - | | spm | | - | - | - | - | From 3a969d5d6b3ffa02b89d6667e545bfbbbb59fc96 Mon Sep 17 00:00:00 2001 From: Hu Date: Tue, 4 Aug 2020 09:07:19 +0800 Subject: [PATCH 14/39] fix --- scripts/machine_translation/train_transformer.py | 2 +- src/gluonnlp/cli/average_checkpoint.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index 20fdd2d53a..267d8a185b 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -489,7 +489,7 @@ def train(args): if local_rank == 0 and \ (args.max_update > 0 and n_train_iters % args.save_interval_update == 0): model.save_parameters(os.path.join(args.save_dir, - '{:d}.params'.format(n_train_iters // args.save_interval_update)), + 'update{:d}.params'.format(n_train_iters // args.save_interval_update)), deduplicate=True) if args.max_update > 0 and n_train_iters >= args.max_update: break diff --git a/src/gluonnlp/cli/average_checkpoint.py b/src/gluonnlp/cli/average_checkpoint.py index bec986f114..a660244bfe 100644 --- a/src/gluonnlp/cli/average_checkpoint.py +++ b/src/gluonnlp/cli/average_checkpoint.py @@ -7,8 +7,8 @@ def get_parser(): parser = argparse.ArgumentParser(description='Script to average the checkpoints') parser.add_argument('--checkpoints', type=str, required=True, nargs='+', - help='checkpoint file paths, supports two format: ' - '--checkpoints folder/epoch*.params or --checkpoints folder/*.param') + help='checkpoint file paths, supports two format, ' + 
'--checkpoints folder/epoch*.params or --checkpoints folder/update*.param') parser.add_argument('--begin', type=int, required=True, help='begin number of checkpoints') parser.add_argument('--end', type=int, required=True, help='end number of checkpoints') parser.add_argument('--save-path', type=str, required=True, @@ -20,8 +20,8 @@ def main(args): assert args.end >= args.begin args.range = list(range(args.begin, args.end + 1)) - ckpt_epochs_regexp = re.compile(r'(.*\/)epoch(\d+)\.params') - ckpt_updates_regexp = re.compile(r'(.*\/)(\d+)\.params') + ckpt_epochs_regexp = re.compile(r'(.*\/)?epoch(\d+)\.params') + ckpt_updates_regexp = re.compile(r'(.*\/)?update(\d+)\.params') ckpt_path = args.checkpoints[0] if ckpt_epochs_regexp.fullmatch(ckpt_path) is not None: ckpt_regexp = ckpt_epochs_regexp From 48d1fb9b7472f4233885e79c4482f3482028eaa5 Mon Sep 17 00:00:00 2001 From: Hu Date: Thu, 6 Aug 2020 15:20:29 +0800 Subject: [PATCH 15/39] fix logging in train_transformer --- scripts/machine_translation/train_transformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index 267d8a185b..e8fe2bfbdd 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -471,7 +471,7 @@ def train(args): if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \ (args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update): model_averager.step() - if local_rank == 0 and \ + if local_rank == num_parts - 1 and \ (n_epoch_train_iters % args.log_interval == 0 or is_last_batch): log_end_time = time.time() log_wc = log_wc.asnumpy() @@ -479,21 +479,21 @@ def train(args): log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy() logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, ' 'throughput={:.2f}K wps, wc={:.2f}K, LR={}' - .format(epoch_id, processed_batch_num, len(train_data_loader), + .format(epoch_id, processed_batch_num * num_parts, len(train_data_loader), log_avg_loss, np.exp(log_avg_loss), wps / 1000, log_wc / 1000, trainer.learning_rate)) log_start_time = time.time() log_avg_loss = 0 log_loss_denom = 0 log_wc = 0 - if local_rank == 0 and \ + if local_rank == num_parts - 1 and \ (args.max_update > 0 and n_train_iters % args.save_interval_update == 0): model.save_parameters(os.path.join(args.save_dir, 'update{:d}.params'.format(n_train_iters // args.save_interval_update)), deduplicate=True) if args.max_update > 0 and n_train_iters >= args.max_update: break - if args.epochs > 0: + if local_rank == num_parts - 1 and args.epochs > 0: model.save_parameters(os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id)), deduplicate=True) From 983d1ab7aa61a722879c74ea930a97cee65c7fef Mon Sep 17 00:00:00 2001 From: Hu Date: Thu, 6 Aug 2020 15:33:25 +0800 Subject: [PATCH 16/39] fix --- scripts/machine_translation/train_transformer.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index e8fe2bfbdd..a527bd3e7c 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -330,7 +330,8 @@ def train(args): model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l) model.hybridize() - logging.info(model) + if local_rank == num_parts - 1: + logging.info(model) with 
open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f: cfg_f.write(cfg.dump()) label_smooth_loss = LabelSmoothCrossEntropyLoss(num_labels=len(tgt_vocab), @@ -389,12 +390,15 @@ def train(args): else: raise NotImplementedError + if local_rank == num_parts - 1: + logging.info(train_batch_sampler) + batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack()) train_data_loader = gluon.data.DataLoader(data_train, batch_sampler=train_batch_sampler, batchify_fn=batchify_fn, num_workers=0) - logging.info(train_batch_sampler) + val_data_loader = gluon.data.DataLoader(data_val, batch_size=args.val_batch_size, batchify_fn=batchify_fn, @@ -453,7 +457,7 @@ def train(args): sample_data_l = next(train_multi_data_loader) except StopIteration: is_last_batch = True - if num_params is None: + if local_rank == num_parts - 1 and num_params is None: num_params, num_fixed_params = count_parameters(model.collect_params()) logging.info('Total Number of Parameters (not-fixed/fixed): {}/{}' .format(num_params, num_fixed_params)) From 9f7c087f653c6478b4a0635701a85ef7f68c0920 Mon Sep 17 00:00:00 2001 From: Hu Date: Thu, 6 Aug 2020 15:49:20 +0800 Subject: [PATCH 17/39] fix --- scripts/machine_translation/train_transformer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index a527bd3e7c..ff58dcdf72 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -330,7 +330,7 @@ def train(args): model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l) model.hybridize() - if local_rank == num_parts - 1: + if is_master_node: logging.info(model) with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f: cfg_f.write(cfg.dump()) @@ -390,7 +390,7 @@ def train(args): else: raise NotImplementedError - if local_rank == num_parts - 1: + if is_master_node: logging.info(train_batch_sampler) batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack()) @@ -457,7 +457,7 @@ def train(args): sample_data_l = next(train_multi_data_loader) except StopIteration: is_last_batch = True - if local_rank == num_parts - 1 and num_params is None: + if is_master_node and num_params is None: num_params, num_fixed_params = count_parameters(model.collect_params()) logging.info('Total Number of Parameters (not-fixed/fixed): {}/{}' .format(num_params, num_fixed_params)) @@ -475,7 +475,7 @@ def train(args): if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \ (args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update): model_averager.step() - if local_rank == num_parts - 1 and \ + if is_master_node and \ (n_epoch_train_iters % args.log_interval == 0 or is_last_batch): log_end_time = time.time() log_wc = log_wc.asnumpy() @@ -490,14 +490,14 @@ def train(args): log_avg_loss = 0 log_loss_denom = 0 log_wc = 0 - if local_rank == num_parts - 1 and \ + if is_master_node and \ (args.max_update > 0 and n_train_iters % args.save_interval_update == 0): model.save_parameters(os.path.join(args.save_dir, 'update{:d}.params'.format(n_train_iters // args.save_interval_update)), deduplicate=True) if args.max_update > 0 and n_train_iters >= args.max_update: break - if local_rank == num_parts - 1 and args.epochs > 0: + if is_master_node and args.epochs > 0: model.save_parameters(os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id)), 
deduplicate=True) From 7dae5ad94bf7687ee57294afd9173391496074bc Mon Sep 17 00:00:00 2001 From: Hu Date: Thu, 6 Aug 2020 15:59:02 +0800 Subject: [PATCH 18/39] fix --- scripts/machine_translation/train_transformer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index ff58dcdf72..2a3ef665ab 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -330,7 +330,7 @@ def train(args): model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l) model.hybridize() - if is_master_node: + if local_rank == 0: logging.info(model) with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f: cfg_f.write(cfg.dump()) @@ -390,7 +390,7 @@ def train(args): else: raise NotImplementedError - if is_master_node: + if local_rank == 0: logging.info(train_batch_sampler) batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack()) @@ -457,7 +457,7 @@ def train(args): sample_data_l = next(train_multi_data_loader) except StopIteration: is_last_batch = True - if is_master_node and num_params is None: + if local_rank == 0 and num_params is None: num_params, num_fixed_params = count_parameters(model.collect_params()) logging.info('Total Number of Parameters (not-fixed/fixed): {}/{}' .format(num_params, num_fixed_params)) @@ -475,7 +475,7 @@ def train(args): if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \ (args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update): model_averager.step() - if is_master_node and \ + if local_rank == 0 and \ (n_epoch_train_iters % args.log_interval == 0 or is_last_batch): log_end_time = time.time() log_wc = log_wc.asnumpy() @@ -490,14 +490,14 @@ def train(args): log_avg_loss = 0 log_loss_denom = 0 log_wc = 0 - if is_master_node and \ + if local_rank == 0 and \ (args.max_update > 0 and n_train_iters % args.save_interval_update == 0): model.save_parameters(os.path.join(args.save_dir, 'update{:d}.params'.format(n_train_iters // args.save_interval_update)), deduplicate=True) if args.max_update > 0 and n_train_iters >= args.max_update: break - if is_master_node and args.epochs > 0: + if local_rank == 0 and args.epochs > 0: model.save_parameters(os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id)), deduplicate=True) From af18acaab0b96c095736037a520bb0566f79a56b Mon Sep 17 00:00:00 2001 From: Hu Date: Fri, 7 Aug 2020 10:16:44 +0800 Subject: [PATCH 19/39] add transformer base config --- scripts/machine_translation/README.md | 81 ++++++++++++++++++++++++--- 1 file changed, 74 insertions(+), 7 deletions(-) diff --git a/scripts/machine_translation/README.md b/scripts/machine_translation/README.md index 0c799403e5..061ba0658d 100644 --- a/scripts/machine_translation/README.md +++ b/scripts/machine_translation/README.md @@ -10,9 +10,10 @@ You may first run the following command in [datasets/machine_translation](../dat bash wmt2014_ende.sh yttm ``` -Then, you can run the experiment, we use the -"transformer_base" configuration. +Then, you can run the experiment. 
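Patches 15-18 above move the logging and checkpointing guards around (is_master_node, local_rank == num_parts - 1, local_rank == 0); the pattern they settle on is that every worker runs the training step, but only one process per node writes logs and parameter files. A toy, self-contained sketch of that pattern (not the script's code; the training step is faked):

```python
import logging
import os

logging.basicConfig(level=logging.INFO)

def train_loop(num_epochs, local_rank, save_dir='checkpoints'):
    os.makedirs(save_dir, exist_ok=True)
    for epoch_id in range(num_epochs):
        loss = 1.0 / (epoch_id + 1)          # stand-in for the real training step
        if local_rank == 0:                  # only the leading process logs and saves
            logging.info('[Epoch %d] loss=%.4f', epoch_id, loss)
            ckpt = os.path.join(save_dir, 'epoch{:d}.params'.format(epoch_id))
            open(ckpt, 'wb').close()         # placeholder for model.save_parameters(ckpt)

train_loop(num_epochs=2, local_rank=0)
```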
+For "transformer_base" configuration +# TODO ```bash SUBWORD_MODEL=yttm SRC=en @@ -27,7 +28,65 @@ python train_transformer.py \ --src_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \ --tgt_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \ --tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \ - --save_dir transformer_big_wmt2014_de_en_${SUBWORD_ALGO} \ + --save_dir transformer_base_wmt2014_en_de_${SUBWORD_ALGO} \ + --cfg transformer_base \ + --lr 0.002 \ + --batch_size 2700 \ + --num_averages 5 \ + --warmup_steps 4000 \ + --warmup_init_lr 0.0 \ + --seed 123 \ + --gpus 0,1,2,3 +``` + +Use the average_checkpoint cli to average the last 10 checkpoints + +```bash +gluon_average_checkpoint --checkpoints transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/epoch*.params \ + --begin 21 \ + --end 30 \ + --save-path transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/avg_21_30.params +``` + + +Use the following command to inference/evaluate the Transformer model: + +```bash +SUBWORD_MODEL=yttm +python evaluate_transformer.py \ + --param_path transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/average_21_30.params \ + --src_lang en \ + --tgt_lang de \ + --cfg transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/config.yml \ + --src_tokenizer ${SUBWORD_MODEL} \ + --tgt_tokenizer ${SUBWORD_MODEL} \ + --src_subword_model_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.model \ + --tgt_subword_model_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.model \ + --src_vocab_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.vocab \ + --tgt_vocab_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.vocab \ + --src_corpus ../datasets/machine_translation/wmt2014_ende/test.raw.en \ + --tgt_corpus ../datasets/machine_translation/wmt2014_ende/test.raw.de +``` + + + +For "transformer_wmt_en_de_big" configuration + +```bash +SUBWORD_MODEL=yttm +SRC=en +TGT=de +datapath=../datasets/machine_translation +python train_transformer.py \ + --train_src_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${SRC} \ + --train_tgt_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${TGT} \ + --dev_src_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${SRC} \ + --dev_tgt_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${TGT} \ + --src_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \ + --src_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \ + --tgt_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \ + --tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \ + --save_dir transformer_big_wmt2014_en_de_${SUBWORD_ALGO} \ --cfg transformer_wmt_en_de_big \ --lr 0.001 \ --sampler BoundedBudgetSampler \ @@ -42,10 +101,10 @@ python train_transformer.py \ Use the average_checkpoint cli to average the last 10 checkpoints ```bash -gluon_average_checkpoint --checkpoints transformer_big_wmt2014_de_en_${SUBWORD_ALGO}/epoch*.params \ +gluon_average_checkpoint --checkpoints transformer_big_wmt2014_en_de_${SUBWORD_ALGO}/update*.params \ --begin 21 \ --end 30 \ - --save-path transformer_big_wmt2014_de_en_${SUBWORD_ALGO}/avg_21_30.params + --save-path transformer_big_wmt2014_en_de_${SUBWORD_ALGO}/avg_21_30.params ``` @@ -54,10 +113,10 @@ Use the following command to inference/evaluate the Transformer model: ```bash SUBWORD_MODEL=yttm python evaluate_transformer.py \ - --param_path transformer_wmt2014_ende_${SUBWORD_MODEL}/average_21_30.params \ + --param_path 
transformer_big_wmt2014_en_de_${SUBWORD_MODEL}/average_21_30.params \ --src_lang en \ --tgt_lang de \ - --cfg transformer_wmt2014_ende_${SUBWORD_MODEL}/config.yml \ + --cfg transformer_big_wmt2014_en_de_${SUBWORD_MODEL}/config.yml \ --src_tokenizer ${SUBWORD_MODEL} \ --tgt_tokenizer ${SUBWORD_MODEL} \ --src_subword_model_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.model \ @@ -71,6 +130,14 @@ python evaluate_transformer.py \ Test BLEU score with 3 seeds (evaluated via sacre BLEU): +- transformer_base + +| Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | MeanĀ±std | +|---------------|------------|-------------|-------------|--------------|-------------| +| yttm | | - | - | - | - | +| hf_bpe | | - | - | - | - | +| spm | | - | - | - | - | + - transformer_wmt_en_de_big | Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | MeanĀ±std | From eadf4dbf0ea61cdc495b762a3a9a063a563aa3de Mon Sep 17 00:00:00 2001 From: Hu Date: Fri, 7 Aug 2020 13:54:59 +0800 Subject: [PATCH 20/39] fix --- scripts/machine_translation/train_transformer.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index 2a3ef665ab..a3b2c99439 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -483,8 +483,8 @@ def train(args): log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy() logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, ' 'throughput={:.2f}K wps, wc={:.2f}K, LR={}' - .format(epoch_id, processed_batch_num * num_parts, len(train_data_loader), - log_avg_loss, np.exp(log_avg_loss), + .format(epoch_id, min(processed_batch_num * num_parts, len(train_data_loader)), + len(train_data_loader), log_avg_loss, np.exp(log_avg_loss), wps / 1000, log_wc / 1000, trainer.learning_rate)) log_start_time = time.time() log_avg_loss = 0 @@ -492,18 +492,22 @@ def train(args): log_wc = 0 if local_rank == 0 and \ (args.max_update > 0 and n_train_iters % args.save_interval_update == 0): + n_update = n_train_iters // args.save_interval_update model.save_parameters(os.path.join(args.save_dir, - 'update{:d}.params'.format(n_train_iters // args.save_interval_update)), + 'update{:d}.params'.format(n_update)), deduplicate=True) + avg_valid_loss = validation(model, val_data_loader, ctx_l) + logging.info('[Update {}] validation loss/ppl={:.4f}/{:.4f}' + .format(n_update, avg_valid_loss, np.exp(avg_valid_loss))) if args.max_update > 0 and n_train_iters >= args.max_update: break if local_rank == 0 and args.epochs > 0: model.save_parameters(os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id)), deduplicate=True) - avg_valid_loss = validation(model, val_data_loader, ctx_l) - logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}' - .format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss))) + avg_valid_loss = validation(model, val_data_loader, ctx_l) + logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}' + .format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss))) if args.max_update > 0 and n_train_iters >= args.max_update: break From cfe7705de886c9f66b0416d97eace24cf69fce10 Mon Sep 17 00:00:00 2001 From: Hu Date: Sat, 8 Aug 2020 21:47:12 +0800 Subject: [PATCH 21/39] change to wmt14/full --- scripts/datasets/machine_translation/wmt2014_ende.sh | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/scripts/datasets/machine_translation/wmt2014_ende.sh 
b/scripts/datasets/machine_translation/wmt2014_ende.sh index f319db2163..5c5d5ecbce 100644 --- a/scripts/datasets/machine_translation/wmt2014_ende.sh +++ b/scripts/datasets/machine_translation/wmt2014_ende.sh @@ -12,8 +12,8 @@ nlp_data prepare_wmt \ # We use sacrebleu to fetch the dev set (newstest2013) and test set (newstest2014) sacrebleu -t wmt13 -l ${SRC}-${TGT} --echo src > ${SAVE_PATH}/dev.raw.${SRC} sacrebleu -t wmt13 -l ${SRC}-${TGT} --echo ref > ${SAVE_PATH}/dev.raw.${TGT} -sacrebleu -t wmt14 -l ${SRC}-${TGT} --echo src > ${SAVE_PATH}/test.raw.${SRC} -sacrebleu -t wmt14 -l ${SRC}-${TGT} --echo ref > ${SAVE_PATH}/test.raw.${TGT} +sacrebleu -t wmt14/full -l ${SRC}-${TGT} --echo src > ${SAVE_PATH}/test.raw.${SRC} +sacrebleu -t wmt14/full -l ${SRC}-${TGT} --echo ref > ${SAVE_PATH}/test.raw.${TGT} # Clean and tokenize the training + dev corpus @@ -23,7 +23,7 @@ nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \ --src-corpus train.raw.${SRC} \ --tgt-corpus train.raw.${TGT} \ --min-num-words 1 \ - --max-num-words 100 \ + --max-num-words 250 \ --max-ratio 1.5 \ --src-save-path train.tok.${SRC} \ --tgt-save-path train.tok.${TGT} @@ -33,7 +33,8 @@ nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \ --src-corpus dev.raw.${SRC} \ --tgt-corpus dev.raw.${TGT} \ --min-num-words 1 \ - --max-num-words 100 \ + --max-num-words 250 \ + --max-ratio 1.5 \ --src-save-path dev.tok.${SRC} \ --tgt-save-path dev.tok.${TGT} From e20044047530f88c02493c30f81d9e0891ef1d63 Mon Sep 17 00:00:00 2001 From: Hu Date: Sun, 9 Aug 2020 13:19:26 +0800 Subject: [PATCH 22/39] print more sacrebleu info --- scripts/machine_translation/evaluate_transformer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/scripts/machine_translation/evaluate_transformer.py b/scripts/machine_translation/evaluate_transformer.py index 9010b83384..2dddfdc06b 100644 --- a/scripts/machine_translation/evaluate_transformer.py +++ b/scripts/machine_translation/evaluate_transformer.py @@ -247,10 +247,17 @@ def evaluate(args): of.write('\n'.join(pred_sentences)) of.write('\n') - sacrebleu_out = sacrebleu.corpus_bleu(sys_stream=pred_sentences, ref_streams=[all_tgt_lines]) - logging.info('Time Spent: {}, #Sent={}, SacreBlEU={} Avg NLL={}, Perplexity={}' + sacrebleu_out = sacrebleu.corpus_bleu(sys_stream=pred_sentences, ref_streams=[all_tgt_lines]) + logging.info('Time Spent: {}, #Sent={}, SacreBlEU={} ' + '({:2.1f} {:2.1f} {:2.1f} {:2.1f}) ' + '(BP={:.3f}, ratio={:.3f}, syslen={}, reflen={}), ' + 'Avg NLL={}, Perplexity={}' .format(end_eval_time - start_eval_time, len(all_tgt_lines), - sacrebleu_out.score, avg_nll_loss, np.exp(avg_nll_loss))) + sacrebleu_out.score, + *sacrebleu_out.precisions, + sacrebleu_out.bp, sacrebleu_out.sys_len / sacrebleu_out.ref_len, + sacrebleu_out.sys_len, sacrebleu_out.ref_len, + avg_nll_loss, np.exp(avg_nll_loss))) # inference only else: with open(os.path.join(args.save_dir, 'pred_sentences.txt'), 'w', encoding='utf-8') as of: From a84a2d09946a6ea4494d7e755d1a24510eec7fed Mon Sep 17 00:00:00 2001 From: Hu Date: Mon, 10 Aug 2020 16:26:47 +0800 Subject: [PATCH 23/39] fix --- scripts/machine_translation/train_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index a3b2c99439..5850a46738 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -501,7 +501,7 @@ def train(args): .format(n_update, 
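The richer log line added in PATCH 22 reads several attributes off the object returned by sacrebleu.corpus_bleu. A minimal standalone example of that call follows; the sentences are toy data, while the keyword arguments and attribute names are exactly the ones the patch uses.

```python
import sacrebleu

hyps = ['the cat sat on the mat', 'a quick brown fox']
refs = [['the cat sat on the mat', 'the quick brown fox jumps']]  # one reference stream

bleu = sacrebleu.corpus_bleu(sys_stream=hyps, ref_streams=refs)
print('BLEU={:.2f} precisions={} BP={:.3f} sys_len={} ref_len={}'.format(
    bleu.score, bleu.precisions, bleu.bp, bleu.sys_len, bleu.ref_len))
```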
avg_valid_loss, np.exp(avg_valid_loss))) if args.max_update > 0 and n_train_iters >= args.max_update: break - if local_rank == 0 and args.epochs > 0: + if local_rank == 0: model.save_parameters(os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id)), deduplicate=True) From 681dc0038d604ec09720bbfe045d9c602dcebd0b Mon Sep 17 00:00:00 2001 From: Hu Date: Mon, 17 Aug 2020 00:29:21 +0800 Subject: [PATCH 24/39] add test for num_parts and update behavior of boundedbudgetsampler with even_size --- src/gluonnlp/data/sampler.py | 54 ++++++++++++++++++++++++++---------- tests/test_data_sampler.py | 17 ++++++++---- 2 files changed, 51 insertions(+), 20 deletions(-) diff --git a/src/gluonnlp/data/sampler.py b/src/gluonnlp/data/sampler.py index 6d817dcdef..e119eb814f 100644 --- a/src/gluonnlp/data/sampler.py +++ b/src/gluonnlp/data/sampler.py @@ -23,7 +23,6 @@ import math import random import warnings -import random import numpy as np import abc from typing import Union, Sequence, Optional, List @@ -289,12 +288,16 @@ class BoundedBudgetSampler(BaseSampler): Number of partitions which the data is split into (default: 1) part_index The index of the part to read from + even_size + If the number of samples is not even across all partitions, sample a few extra samples + for the ones with fewer samples. """ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]], max_num_tokens: int = -1, max_num_sentences: int = -1, required_batch_size_multiple: int = 1, shuffle: bool = True, seed: Optional[int] = None, - num_parts: int = 1, part_index: int = 0): + num_parts: int = 1, part_index: int = 0, + even_size: bool = False): assert len(lengths) > 0, 'BoundedBudgetSampler does not support empty lengths.' assert max_num_tokens > 0 or max_num_sentences > 0, \ 'One of max_num_tokens and max_num_sentences must be larger than 0' @@ -310,6 +313,7 @@ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]], self._rng = np.random.RandomState(seed) self._num_parts = num_parts self._part_index = part_index + self._even_size = even_size # sort self._indices = self._indices[np.argsort(self._lengths, kind='mergesort')] batch = [] @@ -335,29 +339,51 @@ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]], ) batch.append(index) if len(batch) > 0: - self._batches.append(np.array(batch)) - + self._batches.append(np.array(batch)) + + # split batches to parts + # split strategy is same as the SplitSampler + length = len(lengths) + if not even_size: + part_len = length // num_parts + remaining = length % num_parts + self._start = part_len * part_index + min(part_index, remaining) + self._end = self._start + part_len + (part_index < remaining) + self._part_len = self._end - self._start + else: + part_len = int(length + num_parts - 1) // num_parts + self._start = part_len * part_index + self._end = self._start + part_len + self._start = self._start if self._start < length else length + self._end = self._end if self._end < length else length + self._part_len = part_len + self._part_batches = self._batches[self._start:self._end] + if even_size and len(self._part_batches) < self._part_len: + candidates = random.sample(self._batches, k=self._part_len-len(self._part_batches)) + self._part_batches.extend(candidates) + self._part_sample_num = sum([len(b) for b in self._part_batches]) + def __iter__(self): if self._shuffle: - self._rng.shuffle(self._batches) - part_batches = [] - for i in range(len(self._batches)): - if i % self._num_parts == self._part_index: - 
part_batches.append(self._batches[i]) - for batch in part_batches: + self._rng.shuffle(self._part_batches) + for batch in self._part_batches: yield batch def __len__(self): - return len(self._batches) + return len(self._part_batches) def __repr__(self): ret = '{name}(\n' \ ' sample_num={sample_num},\n' \ - ' batch_num={batch_num}\n'\ + ' batch_num={batch_num}\n' \ + ' part_num={part_num}\n' \ + ' part_index={part_index}\n' \ ')'\ .format(name=self.__class__.__name__, - sample_num=len(self._lengths), - batch_num=len(self._batches)) + sample_num=self._part_sample_num, + batch_num=len(self._part_batches), + part_num=self._num_parts, + part_index=self._part_index) return ret diff --git a/tests/test_data_sampler.py b/tests/test_data_sampler.py index ea920740b5..256333217c 100644 --- a/tests/test_data_sampler.py +++ b/tests/test_data_sampler.py @@ -137,13 +137,18 @@ def test_split_sampler_even_size(num_samples, num_parts): @pytest.mark.parametrize('required_batch_size_multiple', [1, 5]) @pytest.mark.parametrize('shuffle', [True, False]) @pytest.mark.parametrize('seed', [100, None]) +@pytest.mark.parametrize('num_parts', [1, 4]) def test_bounded_budget_sampler(seq_lengths, max_num_tokens, max_num_sentences, - required_batch_size_multiple, shuffle, seed): - sampler = s.BoundedBudgetSampler(seq_lengths, max_num_tokens, max_num_sentences, - required_batch_size_multiple, shuffle, seed) - print(sampler) + required_batch_size_multiple, shuffle, seed, num_parts): + samplers = [] + for index in range(num_parts): + samplers.append(s.BoundedBudgetSampler(seq_lengths, max_num_tokens, max_num_sentences, + required_batch_size_multiple, shuffle, seed, + num_parts, index)) total_sampled_ids = [] - for batch_sample_ids in sampler: - total_sampled_ids.extend(batch_sample_ids) + for sampler in samplers: + print(sampler) + for batch_sample_ids in sampler: + total_sampled_ids.extend(batch_sample_ids) assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N assert sorted(total_sampled_ids) == list(range(len(total_sampled_ids))) From 179b8db5ee6b409209e5a768623eb52891a643f1 Mon Sep 17 00:00:00 2001 From: Hu Date: Mon, 17 Aug 2020 14:54:27 +0800 Subject: [PATCH 25/39] fix --- src/gluonnlp/data/sampler.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/gluonnlp/data/sampler.py b/src/gluonnlp/data/sampler.py index e119eb814f..ed8b8572e4 100644 --- a/src/gluonnlp/data/sampler.py +++ b/src/gluonnlp/data/sampler.py @@ -370,18 +370,22 @@ def __iter__(self): yield batch def __len__(self): - return len(self._part_batches) + return len(self._batches) def __repr__(self): ret = '{name}(\n' \ ' sample_num={sample_num},\n' \ - ' batch_num={batch_num}\n' \ - ' part_num={part_num}\n' \ + ' batch_num={batch_num},\n' \ + ' part_sample_num={part_sample_num},\n' \ + ' part_batch_num={part_batch_num},\n' \ + ' part_num={part_num},\n' \ ' part_index={part_index}\n' \ ')'\ .format(name=self.__class__.__name__, - sample_num=self._part_sample_num, - batch_num=len(self._part_batches), + sample_num=self._lengths.shape[0], + batch_num=len(self._batches), + part_sample_num=self._part_sample_num, + part_batch_num=len(self._part_batches), part_num=self._num_parts, part_index=self._part_index) return ret From eccceb7b54ba687b001c5b97371182f1d73e2216 Mon Sep 17 00:00:00 2001 From: Hu Date: Mon, 17 Aug 2020 16:03:31 +0800 Subject: [PATCH 26/39] fix --- src/gluonnlp/data/sampler.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/gluonnlp/data/sampler.py 
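For context on the sampler being reshaped in these patches: BoundedBudgetSampler sorts samples by length and cuts a new batch whenever adding one more sample would blow the token budget. The snippet below is a simplified sketch of that batching rule only, not the library implementation (which also handles max_num_sentences and required_batch_size_multiple).

```python
import numpy as np

def token_budget_batches(lengths, max_num_tokens):
    """Group sample indices so that batch_size * longest_length stays within budget."""
    order = np.argsort(lengths, kind='mergesort')    # short samples first
    batches, batch, max_len = [], [], 0
    for idx in order:
        max_len = max(max_len, lengths[idx])
        if batch and (len(batch) + 1) * max_len > max_num_tokens:
            batches.append(np.array(batch))          # budget exceeded -> cut the batch
            batch, max_len = [], lengths[idx]
        batch.append(idx)
    if batch:
        batches.append(np.array(batch))
    return batches

print(token_budget_batches([7, 3, 12, 5, 9, 4], max_num_tokens=20))
```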
b/src/gluonnlp/data/sampler.py index ed8b8572e4..ee9473a67d 100644 --- a/src/gluonnlp/data/sampler.py +++ b/src/gluonnlp/data/sampler.py @@ -289,8 +289,8 @@ class BoundedBudgetSampler(BaseSampler): part_index The index of the part to read from even_size - If the number of samples is not even across all partitions, sample a few extra samples - for the ones with fewer samples. + If the number of batches is not even across all partitions, sample a few extra batches + for the ones with fewer batches. """ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]], max_num_tokens: int = -1, max_num_sentences: int = -1, @@ -342,8 +342,7 @@ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]], self._batches.append(np.array(batch)) # split batches to parts - # split strategy is same as the SplitSampler - length = len(lengths) + length = len(self._batches) if not even_size: part_len = length // num_parts remaining = length % num_parts @@ -370,20 +369,16 @@ def __iter__(self): yield batch def __len__(self): - return len(self._batches) + return len(self._part_batches) def __repr__(self): ret = '{name}(\n' \ - ' sample_num={sample_num},\n' \ - ' batch_num={batch_num},\n' \ ' part_sample_num={part_sample_num},\n' \ ' part_batch_num={part_batch_num},\n' \ ' part_num={part_num},\n' \ ' part_index={part_index}\n' \ ')'\ .format(name=self.__class__.__name__, - sample_num=self._lengths.shape[0], - batch_num=len(self._batches), part_sample_num=self._part_sample_num, part_batch_num=len(self._part_batches), part_num=self._num_parts, From adc4b348aa175dedc180f7a00d7b3f275a797af1 Mon Sep 17 00:00:00 2001 From: Hu Date: Mon, 17 Aug 2020 16:22:07 +0800 Subject: [PATCH 27/39] fix --- .../machine_translation/train_transformer.py | 48 ++++++++++--------- src/gluonnlp/data/sampler.py | 4 +- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index 5850a46738..ba6ee02f99 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -176,7 +176,10 @@ def parse_args(): if args.max_update > 0: args.epochs = -1 logging_config(args.save_dir, console=True) - logging.info(args) + _, args.num_parts, args.rank, args.local_rank, _, args.ctx_l = init_comm( + args.comm_backend, args.gpus) + if args.local_rank == 0: + logging.info(args) return args def validation(model, data_loader, ctx_l): @@ -231,14 +234,16 @@ def load_dataset_with_cache(src_corpus_path: str, tgt_corpus_path: str, src_tokenizer: BaseTokenizerWithVocab, tgt_tokenizer: BaseTokenizerWithVocab, - overwrite_cache: bool): + overwrite_cache: bool, + local_rank: int): # TODO online h5py multi processing encode (Tao) src_md5sum = md5sum(src_corpus_path) tgt_md5sum = md5sum(tgt_corpus_path) cache_filepath = os.path.join(CACHE_PATH, '{}_{}.cache.npz'.format(src_md5sum[:6], tgt_md5sum[:6])) if os.path.exists(cache_filepath) and not overwrite_cache: - logging.info('Load cache from {}'.format(cache_filepath)) + if local_rank == 0: + logging.info('Load cache from {}'.format(cache_filepath)) npz_data = np.load(cache_filepath, allow_pickle=True) src_data, tgt_data = npz_data['src_data'][:], npz_data['tgt_data'][:] else: @@ -288,8 +293,6 @@ def create_tokenizer(tokenizer_type, model_path, vocab_path): def train(args): - store, num_parts, rank, local_rank, is_master_node, ctx_l = init_comm( - args.comm_backend, args.gpus) src_tokenizer = create_tokenizer(args.src_tokenizer, 
args.src_subword_model_path, args.src_vocab_path) @@ -302,12 +305,14 @@ def train(args): args.train_tgt_corpus, src_tokenizer, tgt_tokenizer, - args.overwrite_cache) + args.overwrite_cache, + args.local_rank) dev_src_data, dev_tgt_data = load_dataset_with_cache(args.dev_src_corpus, args.dev_tgt_corpus, src_tokenizer, tgt_tokenizer, - args.overwrite_cache) + args.overwrite_cache, + args.local_rank) data_train = gluon.data.SimpleDataset( [(src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i) for i, (src_tokens, tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data))]) @@ -328,9 +333,9 @@ def train(args): cfg.freeze() model = TransformerModel.from_cfg(cfg) model.initialize(mx.init.Xavier(magnitude=args.magnitude), - ctx=ctx_l) + ctx=args.ctx_l) model.hybridize() - if local_rank == 0: + if args.local_rank == 0: logging.info(model) with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f: cfg_f.write(cfg.dump()) @@ -364,8 +369,8 @@ def train(args): max_num_tokens=args.max_num_tokens, max_num_sentences=args.max_num_sentences, seed=args.seed, - num_parts=num_parts, - part_index=rank) + num_parts=args.num_parts, + part_index=args.rank) elif args.sampler == 'FixedBucketSampler': if args.comm_backend == 'horovod': raise NotImplementedError('FixedBucketSampler does not support horovod at present') @@ -390,8 +395,7 @@ def train(args): else: raise NotImplementedError - if local_rank == 0: - logging.info(train_batch_sampler) + logging.info(train_batch_sampler) batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack()) train_data_loader = gluon.data.DataLoader(data_train, @@ -422,13 +426,13 @@ def train(args): while (args.epochs < 0 or epoch_id < args.epochs): # when args.epochs < 0, the model will keep training n_epoch_train_iters = 0 processed_batch_num = 0 - train_multi_data_loader = grouper(train_data_loader, len(ctx_l)) + train_multi_data_loader = grouper(train_data_loader, len(args.ctx_l)) is_last_batch = False sample_data_l = next(train_multi_data_loader) while not is_last_batch: processed_batch_num += len(sample_data_l) loss_l = [] - for sample_data, ctx in zip(sample_data_l, ctx_l): + for sample_data, ctx in zip(sample_data_l, args.ctx_l): if sample_data is None: continue src_token_ids, tgt_token_ids, src_valid_length, tgt_valid_length, sample_ids = sample_data @@ -457,7 +461,7 @@ def train(args): sample_data_l = next(train_multi_data_loader) except StopIteration: is_last_batch = True - if local_rank == 0 and num_params is None: + if args.local_rank == 0 and num_params is None: num_params, num_fixed_params = count_parameters(model.collect_params()) logging.info('Total Number of Parameters (not-fixed/fixed): {}/{}' .format(num_params, num_fixed_params)) @@ -475,7 +479,7 @@ def train(args): if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \ (args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update): model_averager.step() - if local_rank == 0 and \ + if args.local_rank == 0 and \ (n_epoch_train_iters % args.log_interval == 0 or is_last_batch): log_end_time = time.time() log_wc = log_wc.asnumpy() @@ -483,29 +487,29 @@ def train(args): log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy() logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, ' 'throughput={:.2f}K wps, wc={:.2f}K, LR={}' - .format(epoch_id, min(processed_batch_num * num_parts, len(train_data_loader)), + .format(epoch_id, min(processed_batch_num * args.num_parts, len(train_data_loader)), len(train_data_loader), 
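The load_dataset_with_cache helper touched above keys its .npz cache on the md5 of the raw corpus files, so repeated runs skip re-tokenization. Below is a rough standalone sketch of that caching pattern; the paths are hypothetical and a trivial split() replaces the real subword tokenizer.

```python
import hashlib
import os
import numpy as np

def md5sum(path):
    with open(path, 'rb') as f:
        return hashlib.md5(f.read()).hexdigest()

def load_with_cache(corpus_path, cache_dir='cache', overwrite_cache=False):
    os.makedirs(cache_dir, exist_ok=True)
    cache_filepath = os.path.join(cache_dir, '{}.cache.npz'.format(md5sum(corpus_path)[:6]))
    if os.path.exists(cache_filepath) and not overwrite_cache:
        return np.load(cache_filepath, allow_pickle=True)['data'][:]
    with open(corpus_path, encoding='utf-8') as f:
        data = np.array([line.split() for line in f], dtype=object)  # fake "tokenization"
    np.savez(cache_filepath, data=data)
    return data
```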
log_avg_loss, np.exp(log_avg_loss), wps / 1000, log_wc / 1000, trainer.learning_rate)) log_start_time = time.time() log_avg_loss = 0 log_loss_denom = 0 log_wc = 0 - if local_rank == 0 and \ + if args.local_rank == 0 and \ (args.max_update > 0 and n_train_iters % args.save_interval_update == 0): n_update = n_train_iters // args.save_interval_update model.save_parameters(os.path.join(args.save_dir, 'update{:d}.params'.format(n_update)), deduplicate=True) - avg_valid_loss = validation(model, val_data_loader, ctx_l) + avg_valid_loss = validation(model, val_data_loader, args.ctx_l) logging.info('[Update {}] validation loss/ppl={:.4f}/{:.4f}' .format(n_update, avg_valid_loss, np.exp(avg_valid_loss))) if args.max_update > 0 and n_train_iters >= args.max_update: break - if local_rank == 0: + if args.local_rank == 0: model.save_parameters(os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id)), deduplicate=True) - avg_valid_loss = validation(model, val_data_loader, ctx_l) + avg_valid_loss = validation(model, val_data_loader, args.ctx_l) logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}' .format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss))) diff --git a/src/gluonnlp/data/sampler.py b/src/gluonnlp/data/sampler.py index ee9473a67d..e8d9673187 100644 --- a/src/gluonnlp/data/sampler.py +++ b/src/gluonnlp/data/sampler.py @@ -369,16 +369,18 @@ def __iter__(self): yield batch def __len__(self): - return len(self._part_batches) + return len(self._batches) def __repr__(self): ret = '{name}(\n' \ + ' batch_num={batch_num},\n' \ ' part_sample_num={part_sample_num},\n' \ ' part_batch_num={part_batch_num},\n' \ ' part_num={part_num},\n' \ ' part_index={part_index}\n' \ ')'\ .format(name=self.__class__.__name__, + batch_num=len(self._batches), part_sample_num=self._part_sample_num, part_batch_num=len(self._part_batches), part_num=self._num_parts, From b03faf2062280aa1ee9cc2eaf34665d9bd2464b5 Mon Sep 17 00:00:00 2001 From: Hu Date: Mon, 17 Aug 2020 16:33:27 +0800 Subject: [PATCH 28/39] fix logging when using horovd --- scripts/machine_translation/train_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index ba6ee02f99..87345b5145 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -175,10 +175,10 @@ def parse_args(): args = parser.parse_args() if args.max_update > 0: args.epochs = -1 - logging_config(args.save_dir, console=True) _, args.num_parts, args.rank, args.local_rank, _, args.ctx_l = init_comm( args.comm_backend, args.gpus) if args.local_rank == 0: + logging_config(args.save_dir, console=True) logging.info(args) return args From bc9ad31f3009203d0dffb5b77ffb93eb606a7080 Mon Sep 17 00:00:00 2001 From: Hu Date: Mon, 17 Aug 2020 16:38:34 +0800 Subject: [PATCH 29/39] udpate doc of train transformer --- scripts/machine_translation/README.md | 35 +++++++++++++++++---------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/scripts/machine_translation/README.md b/scripts/machine_translation/README.md index 061ba0658d..81da47e9cd 100644 --- a/scripts/machine_translation/README.md +++ b/scripts/machine_translation/README.md @@ -13,7 +13,6 @@ bash wmt2014_ende.sh yttm Then, you can run the experiment. 
For "transformer_base" configuration -# TODO ```bash SUBWORD_MODEL=yttm SRC=en @@ -31,30 +30,38 @@ python train_transformer.py \ --save_dir transformer_base_wmt2014_en_de_${SUBWORD_ALGO} \ --cfg transformer_base \ --lr 0.002 \ - --batch_size 2700 \ - --num_averages 5 \ - --warmup_steps 4000 \ + --sampler BoundedBudgetSampler \ + --max_num_tokens 2700 \ + --max_update 15000 \ + --save_interval_update 500 \ + --warmup_steps 6000 \ --warmup_init_lr 0.0 \ --seed 123 \ --gpus 0,1,2,3 ``` -Use the average_checkpoint cli to average the last 10 checkpoints +Or training via horovod +``` +mpirun -np 4 -H localhost:4 python3 train_transformer.py \ + --comm_backend horovod \ + ... +``` + +Use the average_checkpoint cli to average the last 10 epoch checkpoints ```bash gluon_average_checkpoint --checkpoints transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/epoch*.params \ - --begin 21 \ - --end 30 \ - --save-path transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/avg_21_30.params + --begin 30 \ + --end 39 \ + --save-path transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/epoch_avg_30_39.params ``` - Use the following command to inference/evaluate the Transformer model: ```bash SUBWORD_MODEL=yttm python evaluate_transformer.py \ - --param_path transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/average_21_30.params \ + --param_path transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/epoch_avg_30_39.params \ --src_lang en \ --tgt_lang de \ --cfg transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/config.yml \ @@ -98,7 +105,7 @@ python train_transformer.py \ --gpus 0,1,2,3 ``` -Use the average_checkpoint cli to average the last 10 checkpoints +Use the average_checkpoint cli to average the last 10 update checkpoints ```bash gluon_average_checkpoint --checkpoints transformer_big_wmt2014_en_de_${SUBWORD_ALGO}/update*.params \ @@ -132,16 +139,18 @@ Test BLEU score with 3 seeds (evaluated via sacre BLEU): - transformer_base +(test bleu / valid bleu) | Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | MeanĀ±std | |---------------|------------|-------------|-------------|--------------|-------------| -| yttm | | - | - | - | - | +| yttm | | 26.50/26.29 | - | - | - | | hf_bpe | | - | - | - | - | | spm | | - | - | - | - | - transformer_wmt_en_de_big +(test bleu / valid bleu) | Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | MeanĀ±std | |---------------|------------|-------------|-------------|--------------|-------------| -| yttm | | 27.99 | - | - | - | +| yttm | | 27.93/26.82 | - | - | - | | hf_bpe | | - | - | - | - | | spm | | - | - | - | - | From 3377e6b3c374afe330fe3c636f5760c2ae1dca00 Mon Sep 17 00:00:00 2001 From: Hu Date: Wed, 19 Aug 2020 10:21:24 +0800 Subject: [PATCH 30/39] add test case for fail downloading --- tests/test_utils_misc.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_utils_misc.py b/tests/test_utils_misc.py index d5bc92eacd..83593cb3f6 100644 --- a/tests/test_utils_misc.py +++ b/tests/test_utils_misc.py @@ -109,6 +109,15 @@ def test_download_https(overwrite): overwrite=overwrite) +@pytest.mark.remote_required +@pytest.mark.parametrize('overwrite', [False, True]) +def test_download_non_existing(overwrite): + with pytest.raises(RuntimeError, match='Failed downloading url'): + verify_download(url='https://commoncrawl.s3.amazonaws.com/crawl-data/CC-MAIN-2014-41/non_existing', + sha1_hash='foo', + overwrite=overwrite) + + def test_logging_config(): logger = logging.getLogger(__name__) with tempfile.TemporaryDirectory() as root: From 
8fb28d72f1baf3a491dc2471c70e7674a92ecc2f Mon Sep 17 00:00:00 2001 From: Hu Date: Wed, 19 Aug 2020 11:10:57 +0800 Subject: [PATCH 31/39] add a ShardedIterator --- .../machine_translation/train_transformer.py | 44 +++---- src/gluonnlp/data/sampler.py | 118 +++++++++++------- tests/test_data_sampler.py | 38 ++++-- 3 files changed, 121 insertions(+), 79 deletions(-) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index 87345b5145..71595c5aa1 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -50,7 +50,8 @@ LinearWidthBucket, ExpWidthBucket, FixedBucketSampler, - BoundedBudgetSampler + BoundedBudgetSampler, + ShardedIterator ) import gluonnlp.data.batchify as bf from gluonnlp.data import Vocab @@ -175,11 +176,8 @@ def parse_args(): args = parser.parse_args() if args.max_update > 0: args.epochs = -1 - _, args.num_parts, args.rank, args.local_rank, _, args.ctx_l = init_comm( - args.comm_backend, args.gpus) - if args.local_rank == 0: - logging_config(args.save_dir, console=True) - logging.info(args) + logging_config(args.save_dir, console=True) + logging.info(args) return args def validation(model, data_loader, ctx_l): @@ -293,6 +291,8 @@ def create_tokenizer(tokenizer_type, model_path, vocab_path): def train(args): + _, num_parts, rank, local_rank, _, ctx_l = init_comm( + args.comm_backend, args.gpus) src_tokenizer = create_tokenizer(args.src_tokenizer, args.src_subword_model_path, args.src_vocab_path) @@ -306,13 +306,13 @@ def train(args): src_tokenizer, tgt_tokenizer, args.overwrite_cache, - args.local_rank) + local_rank) dev_src_data, dev_tgt_data = load_dataset_with_cache(args.dev_src_corpus, args.dev_tgt_corpus, src_tokenizer, tgt_tokenizer, args.overwrite_cache, - args.local_rank) + local_rank) data_train = gluon.data.SimpleDataset( [(src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i) for i, (src_tokens, tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data))]) @@ -333,9 +333,9 @@ def train(args): cfg.freeze() model = TransformerModel.from_cfg(cfg) model.initialize(mx.init.Xavier(magnitude=args.magnitude), - ctx=args.ctx_l) + ctx=ctx_l) model.hybridize() - if args.local_rank == 0: + if local_rank == 0: logging.info(model) with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f: cfg_f.write(cfg.dump()) @@ -368,9 +368,9 @@ def train(args): train_batch_sampler = BoundedBudgetSampler(lengths=[(ele[2], ele[3]) for ele in data_train], max_num_tokens=args.max_num_tokens, max_num_sentences=args.max_num_sentences, - seed=args.seed, - num_parts=args.num_parts, - part_index=args.rank) + seed=args.seed) + if num_parts > 1: + train_batch_sampler = ShardedIterator(train_batch_sampler, num_parts=num_parts, part_index=rank) elif args.sampler == 'FixedBucketSampler': if args.comm_backend == 'horovod': raise NotImplementedError('FixedBucketSampler does not support horovod at present') @@ -426,13 +426,13 @@ def train(args): while (args.epochs < 0 or epoch_id < args.epochs): # when args.epochs < 0, the model will keep training n_epoch_train_iters = 0 processed_batch_num = 0 - train_multi_data_loader = grouper(train_data_loader, len(args.ctx_l)) + train_multi_data_loader = grouper(train_data_loader, len(ctx_l)) is_last_batch = False sample_data_l = next(train_multi_data_loader) while not is_last_batch: processed_batch_num += len(sample_data_l) loss_l = [] - for sample_data, ctx in zip(sample_data_l, args.ctx_l): + for sample_data, ctx in 
zip(sample_data_l, ctx_l): if sample_data is None: continue src_token_ids, tgt_token_ids, src_valid_length, tgt_valid_length, sample_ids = sample_data @@ -461,7 +461,7 @@ def train(args): sample_data_l = next(train_multi_data_loader) except StopIteration: is_last_batch = True - if args.local_rank == 0 and num_params is None: + if local_rank == 0 and num_params is None: num_params, num_fixed_params = count_parameters(model.collect_params()) logging.info('Total Number of Parameters (not-fixed/fixed): {}/{}' .format(num_params, num_fixed_params)) @@ -479,7 +479,7 @@ def train(args): if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \ (args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update): model_averager.step() - if args.local_rank == 0 and \ + if local_rank == 0 and \ (n_epoch_train_iters % args.log_interval == 0 or is_last_batch): log_end_time = time.time() log_wc = log_wc.asnumpy() @@ -487,29 +487,29 @@ def train(args): log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy() logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, ' 'throughput={:.2f}K wps, wc={:.2f}K, LR={}' - .format(epoch_id, min(processed_batch_num * args.num_parts, len(train_data_loader)), + .format(epoch_id, processed_batch_num * num_parts, len(train_data_loader), log_avg_loss, np.exp(log_avg_loss), wps / 1000, log_wc / 1000, trainer.learning_rate)) log_start_time = time.time() log_avg_loss = 0 log_loss_denom = 0 log_wc = 0 - if args.local_rank == 0 and \ + if local_rank == 0 and \ (args.max_update > 0 and n_train_iters % args.save_interval_update == 0): n_update = n_train_iters // args.save_interval_update model.save_parameters(os.path.join(args.save_dir, 'update{:d}.params'.format(n_update)), deduplicate=True) - avg_valid_loss = validation(model, val_data_loader, args.ctx_l) + avg_valid_loss = validation(model, val_data_loader, ctx_l) logging.info('[Update {}] validation loss/ppl={:.4f}/{:.4f}' .format(n_update, avg_valid_loss, np.exp(avg_valid_loss))) if args.max_update > 0 and n_train_iters >= args.max_update: break - if args.local_rank == 0: + if local_rank == 0: model.save_parameters(os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id)), deduplicate=True) - avg_valid_loss = validation(model, val_data_loader, args.ctx_l) + avg_valid_loss = validation(model, val_data_loader, ctx_l) logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}' .format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss))) diff --git a/src/gluonnlp/data/sampler.py b/src/gluonnlp/data/sampler.py index e8d9673187..9ada91f030 100644 --- a/src/gluonnlp/data/sampler.py +++ b/src/gluonnlp/data/sampler.py @@ -284,24 +284,14 @@ class BoundedBudgetSampler(BaseSampler): Whether to shuffle the batches. seed The seed of the sampler - num_parts - Number of partitions which the data is split into (default: 1) - part_index - The index of the part to read from - even_size - If the number of batches is not even across all partitions, sample a few extra batches - for the ones with fewer batches. """ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]], max_num_tokens: int = -1, max_num_sentences: int = -1, required_batch_size_multiple: int = 1, - shuffle: bool = True, seed: Optional[int] = None, - num_parts: int = 1, part_index: int = 0, - even_size: bool = False): + shuffle: bool = True, seed: Optional[int] = None): assert len(lengths) > 0, 'BoundedBudgetSampler does not support empty lengths.' 
assert max_num_tokens > 0 or max_num_sentences > 0, \ 'One of max_num_tokens and max_num_sentences must be larger than 0' - assert part_index < num_parts, 'part_index should be less than num_parts' self._lengths = np.array(lengths) if self._lengths.ndim == 2: self._lengths = self._lengths.max(axis=1) @@ -311,9 +301,6 @@ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]], self._batches = [] self._shuffle = shuffle self._rng = np.random.RandomState(seed) - self._num_parts = num_parts - self._part_index = part_index - self._even_size = even_size # sort self._indices = self._indices[np.argsort(self._lengths, kind='mergesort')] batch = [] @@ -341,31 +328,10 @@ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]], if len(batch) > 0: self._batches.append(np.array(batch)) - # split batches to parts - length = len(self._batches) - if not even_size: - part_len = length // num_parts - remaining = length % num_parts - self._start = part_len * part_index + min(part_index, remaining) - self._end = self._start + part_len + (part_index < remaining) - self._part_len = self._end - self._start - else: - part_len = int(length + num_parts - 1) // num_parts - self._start = part_len * part_index - self._end = self._start + part_len - self._start = self._start if self._start < length else length - self._end = self._end if self._end < length else length - self._part_len = part_len - self._part_batches = self._batches[self._start:self._end] - if even_size and len(self._part_batches) < self._part_len: - candidates = random.sample(self._batches, k=self._part_len-len(self._part_batches)) - self._part_batches.extend(candidates) - self._part_sample_num = sum([len(b) for b in self._part_batches]) - def __iter__(self): if self._shuffle: - self._rng.shuffle(self._part_batches) - for batch in self._part_batches: + self._rng.shuffle(self._batches) + for batch in self._batches: yield batch def __len__(self): @@ -373,18 +339,12 @@ def __len__(self): def __repr__(self): ret = '{name}(\n' \ + ' sample_num={sample_num},\n' \ ' batch_num={batch_num},\n' \ - ' part_sample_num={part_sample_num},\n' \ - ' part_batch_num={part_batch_num},\n' \ - ' part_num={part_num},\n' \ - ' part_index={part_index}\n' \ ')'\ .format(name=self.__class__.__name__, - batch_num=len(self._batches), - part_sample_num=self._part_sample_num, - part_batch_num=len(self._part_batches), - part_num=self._num_parts, - part_index=self._part_index) + sample_num=len(self._lengths), + batch_num=len(self._batches)) return ret @@ -698,3 +658,69 @@ def __iter__(self): def __len__(self): return self._len * self._repeat + + +class ShardedIterator(BaseSampler): + r"""A sharded wrapper around an iterable (padded to length). + + Parameters + ---------- + sampler + num_parts + Number of partitions which the data is split into (default: 1) + part_index + The index of the part to read from + even_size + If the number of batches is not even across all partitions, sample a few extra batches + for the ones with fewer batches. 
+ """ + def __init__(self, sampler: BaseSampler, + num_parts: int = 1, + part_index: int = 0, + even_size: bool = False): + assert part_index < num_parts, 'part_index should be less than num_parts' + self._sampler = sampler + self._num_parts = num_parts + self._part_index = part_index + self._even_size = even_size + + length = len(sampler) + if not even_size: + part_len = length // num_parts + remaining = length % num_parts + self._start = part_len * part_index + min(part_index, remaining) + self._end = self._start + part_len + (part_index < remaining) + self._part_len = self._end - self._start + else: + part_len = int(length + num_parts - 1) // num_parts + self._start = part_len * part_index + self._end = self._start + part_len + self._start = self._start if self._start < length else length + self._end = self._end if self._end < length else length + self._part_len = part_len + + def __iter__(self): + batches = list(self._sampler) + part_batches = batches[self._start:self._end] + if self._even_size and len(part_batches) < self._part_len: + candidates = random.sample(batches, k=self._part_len-len(part_batches)) + part_batches.extend(candidates) + for batch in part_batches: + yield batch + + def __len__(self): + return self._part_len + + def __repr__(self): + ret = '{name}(\n' \ + ' batch_num={batch_num},\n' \ + ' part_batch_num={part_batch_num},\n' \ + ' num_parts={num_parts},\n' \ + ' part_index={part_index},\n' \ + ')'\ + .format(name=self.__class__.__name__, + batch_num=len(self._sampler), + part_batch_num=self._part_len, + num_parts=self._num_parts, + part_index=self.part_index) + return ret diff --git a/tests/test_data_sampler.py b/tests/test_data_sampler.py index 256333217c..7d4d11c8a4 100644 --- a/tests/test_data_sampler.py +++ b/tests/test_data_sampler.py @@ -34,7 +34,6 @@ def test_fixed_bucket_sampler(seq_lengths, ratio, shuffle, num_buckets, bucket_s ratio=ratio, shuffle=shuffle, use_average_length=use_average_length, bucket_scheme=bucket_scheme) - print(sampler) total_sampled_ids = [] for batch_sample_ids in sampler: total_sampled_ids.extend(batch_sample_ids) @@ -137,18 +136,35 @@ def test_split_sampler_even_size(num_samples, num_parts): @pytest.mark.parametrize('required_batch_size_multiple', [1, 5]) @pytest.mark.parametrize('shuffle', [True, False]) @pytest.mark.parametrize('seed', [100, None]) -@pytest.mark.parametrize('num_parts', [1, 4]) def test_bounded_budget_sampler(seq_lengths, max_num_tokens, max_num_sentences, - required_batch_size_multiple, shuffle, seed, num_parts): - samplers = [] - for index in range(num_parts): - samplers.append(s.BoundedBudgetSampler(seq_lengths, max_num_tokens, max_num_sentences, - required_batch_size_multiple, shuffle, seed, - num_parts, index)) + required_batch_size_multiple, shuffle, seed): + sampler = s.BoundedBudgetSampler(seq_lengths, max_num_tokens, max_num_sentences, + required_batch_size_multiple, shuffle, seed) + total_sampled_ids = [] + for batch_sample_ids in sampler: + total_sampled_ids.extend(batch_sample_ids) + assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N + assert sorted(total_sampled_ids) == list(range(len(total_sampled_ids))) + + +@pytest.mark.parametrize('seq_lengths', [[np.random.randint(10, 100) for _ in range(N)], + [(np.random.randint(10, 100), np.random.randint(10, 100)) for _ in range(N)]]) +@pytest.mark.parametrize('max_num_tokens', [200, 500]) +@pytest.mark.parametrize('max_num_sentences', [-1, 5]) +@pytest.mark.parametrize('required_batch_size_multiple', [1, 5]) +@pytest.mark.parametrize('shuffle', 
[True, False]) +@pytest.mark.parametrize('num_parts', [1, 4]) +@pytest.mark.parametrize('even_size', [False]) +def test_sharded_iterator(seq_lengths, max_num_tokens, max_num_sentences, + required_batch_size_multiple, shuffle, + num_parts, even_size): total_sampled_ids = [] - for sampler in samplers: - print(sampler) - for batch_sample_ids in sampler: + for part_index in range(num_parts): + # here we use independent (but same) sampler to simulate multi process situation + sampler = s.BoundedBudgetSampler(seq_lengths, max_num_tokens, max_num_sentences, + required_batch_size_multiple, shuffle, seed=100) + sharded_iter = s.ShardedIterator(sampler, num_parts, part_index, even_size) + for batch_sample_ids in sharded_iter: total_sampled_ids.extend(batch_sample_ids) assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N assert sorted(total_sampled_ids) == list(range(len(total_sampled_ids))) From c347decd0874384979376d5acebccec3aea867bf Mon Sep 17 00:00:00 2001 From: Hu Date: Wed, 19 Aug 2020 11:11:36 +0800 Subject: [PATCH 32/39] fix --- scripts/machine_translation/README.md | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/scripts/machine_translation/README.md b/scripts/machine_translation/README.md index c778f69afc..8e377e9935 100644 --- a/scripts/machine_translation/README.md +++ b/scripts/machine_translation/README.md @@ -47,7 +47,7 @@ mpirun -np 4 -H localhost:4 python3 train_transformer.py \ ... ``` -Use the average_checkpoint cli to average the last 10 epoch checkpoints +Use the average_checkpoint cli to average the last 10 checkpoints ```bash gluon_average_checkpoint --checkpoints transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/epoch*.params \ @@ -59,8 +59,7 @@ gluon_average_checkpoint --checkpoints transformer_base_wmt2014_en_de_${SUBWORD_ Use the following command to inference/evaluate the Transformer model: ```bash -SUBWORD_MODEL=yttm -python evaluate_transformer.py \ +python3 evaluate_transformer.py \ --param_path transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/epoch_avg_30_39.params \ --src_lang en \ --tgt_lang de \ @@ -105,7 +104,7 @@ python3 train_transformer.py \ --gpus 0,1,2,3 ``` -Use the average_checkpoint cli to average the last 10 update checkpoints +Use the average_checkpoint cli to average the last 10 checkpoints ```bash gluon_average_checkpoint --checkpoints transformer_big_wmt2014_en_de_${SUBWORD_ALGO}/update*.params \ @@ -118,7 +117,6 @@ gluon_average_checkpoint --checkpoints transformer_big_wmt2014_en_de_${SUBWORD_A Use the following command to inference/evaluate the Transformer model: ```bash -SUBWORD_MODEL=yttm python3 evaluate_transformer.py \ --param_path transformer_big_wmt2014_en_de_${SUBWORD_MODEL}/average_21_30.params \ --src_lang en \ From cdbe846d28f673fe3ea03cd9ab771bc45f572e66 Mon Sep 17 00:00:00 2001 From: Hu Date: Wed, 19 Aug 2020 11:11:58 +0800 Subject: [PATCH 33/39] fix --- scripts/datasets/machine_translation/wmt2014_ende.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/datasets/machine_translation/wmt2014_ende.sh b/scripts/datasets/machine_translation/wmt2014_ende.sh index 5c5d5ecbce..6557715365 100644 --- a/scripts/datasets/machine_translation/wmt2014_ende.sh +++ b/scripts/datasets/machine_translation/wmt2014_ende.sh @@ -23,7 +23,7 @@ nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \ --src-corpus train.raw.${SRC} \ --tgt-corpus train.raw.${TGT} \ --min-num-words 1 \ - --max-num-words 250 \ + --max-num-words 100 \ --max-ratio 1.5 \ --src-save-path train.tok.${SRC} \ --tgt-save-path 
train.tok.${TGT} @@ -33,7 +33,7 @@ nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \ --src-corpus dev.raw.${SRC} \ --tgt-corpus dev.raw.${TGT} \ --min-num-words 1 \ - --max-num-words 250 \ + --max-num-words 100 \ --max-ratio 1.5 \ --src-save-path dev.tok.${SRC} \ --tgt-save-path dev.tok.${TGT} From 050478f19874b9ebfe96dabf0327d9e57a98dc19 Mon Sep 17 00:00:00 2001 From: Hu Date: Wed, 19 Aug 2020 11:26:11 +0800 Subject: [PATCH 34/39] fix --- src/gluonnlp/data/sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gluonnlp/data/sampler.py b/src/gluonnlp/data/sampler.py index 9ada91f030..aabfe7a688 100644 --- a/src/gluonnlp/data/sampler.py +++ b/src/gluonnlp/data/sampler.py @@ -709,7 +709,7 @@ def __iter__(self): yield batch def __len__(self): - return self._part_len + return len(self._sampler) def __repr__(self): ret = '{name}(\n' \ @@ -722,5 +722,5 @@ def __repr__(self): batch_num=len(self._sampler), part_batch_num=self._part_len, num_parts=self._num_parts, - part_index=self.part_index) + part_index=self._part_index) return ret From 6f94fc7dc81054b031079a58657d2015846bd0d0 Mon Sep 17 00:00:00 2001 From: Hu Date: Wed, 19 Aug 2020 14:00:05 +0800 Subject: [PATCH 35/39] change mpirun to horovodrun --- scripts/machine_translation/README.md | 2 +- scripts/question_answering/README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/machine_translation/README.md b/scripts/machine_translation/README.md index 8e377e9935..92b7b384e6 100644 --- a/scripts/machine_translation/README.md +++ b/scripts/machine_translation/README.md @@ -42,7 +42,7 @@ python3 train_transformer.py \ Or training via horovod ``` -mpirun -np 4 -H localhost:4 python3 train_transformer.py \ +horovodrun -np 4 -H localhost:4 python3 train_transformer.py \ --comm_backend horovod \ ... ``` diff --git a/scripts/question_answering/README.md b/scripts/question_answering/README.md index a39e1e77b3..3c655c2ee8 100644 --- a/scripts/question_answering/README.md +++ b/scripts/question_answering/README.md @@ -74,7 +74,7 @@ We could speed up multi-GPU training via horovod. Compared to KVStore, training RoBERTa Large model on SQuAD 2.0 with 3 epochs will save roughly 1/4 training resources (8.48 vs 11.32 hours). Results may vary depending on the training instances. ```bash -mpirun -np 4 -H localhost:4 python3 run_squad.py \ +horovodrun -np 4 -H localhost:4 python3 run_squad.py \ --comm_backend horovod \ ... ``` From 36cccca880cdf2f9a04ac3adf83c6e230a516d0c Mon Sep 17 00:00:00 2001 From: Hu Date: Wed, 19 Aug 2020 17:22:01 +0800 Subject: [PATCH 36/39] make the horovod command complete --- scripts/machine_translation/README.md | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/scripts/machine_translation/README.md b/scripts/machine_translation/README.md index 92b7b384e6..402e6272eb 100644 --- a/scripts/machine_translation/README.md +++ b/scripts/machine_translation/README.md @@ -44,7 +44,25 @@ Or training via horovod ``` horovodrun -np 4 -H localhost:4 python3 train_transformer.py \ --comm_backend horovod \ - ... 
+ --train_src_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${SRC} \ + --train_tgt_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${TGT} \ + --dev_src_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${SRC} \ + --dev_tgt_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${TGT} \ + --src_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \ + --src_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \ + --tgt_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \ + --tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \ + --save_dir transformer_base_wmt2014_en_de_${SUBWORD_ALGO} \ + --cfg transformer_base \ + --lr 0.002 \ + --sampler BoundedBudgetSampler \ + --max_num_tokens 2700 \ + --max_update 15000 \ + --save_interval_update 500 \ + --warmup_steps 6000 \ + --warmup_init_lr 0.0 \ + --seed 123 \ + --gpus 0,1,2,3 ``` Use the average_checkpoint cli to average the last 10 checkpoints From ddb174f78af29187323b989e06edff55aa56d269 Mon Sep 17 00:00:00 2001 From: Hu Date: Wed, 19 Aug 2020 17:22:27 +0800 Subject: [PATCH 37/39] use print(sampler) to cover the codes of __repr__ func --- tests/test_data_sampler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_data_sampler.py b/tests/test_data_sampler.py index 7d4d11c8a4..9e73905a53 100644 --- a/tests/test_data_sampler.py +++ b/tests/test_data_sampler.py @@ -34,6 +34,8 @@ def test_fixed_bucket_sampler(seq_lengths, ratio, shuffle, num_buckets, bucket_s ratio=ratio, shuffle=shuffle, use_average_length=use_average_length, bucket_scheme=bucket_scheme) + # here we print sampler to cover the __repr__ func of the sampler + print(sampler) total_sampled_ids = [] for batch_sample_ids in sampler: total_sampled_ids.extend(batch_sample_ids) @@ -140,6 +142,7 @@ def test_bounded_budget_sampler(seq_lengths, max_num_tokens, max_num_sentences, required_batch_size_multiple, shuffle, seed): sampler = s.BoundedBudgetSampler(seq_lengths, max_num_tokens, max_num_sentences, required_batch_size_multiple, shuffle, seed) + print(sampler) total_sampled_ids = [] for batch_sample_ids in sampler: total_sampled_ids.extend(batch_sample_ids) @@ -160,10 +163,11 @@ def test_sharded_iterator(seq_lengths, max_num_tokens, max_num_sentences, num_parts, even_size): total_sampled_ids = [] for part_index in range(num_parts): - # here we use independent (but same) sampler to simulate multi process situation + # we use independent (but same) sampler to simulate multi process situation sampler = s.BoundedBudgetSampler(seq_lengths, max_num_tokens, max_num_sentences, required_batch_size_multiple, shuffle, seed=100) sharded_iter = s.ShardedIterator(sampler, num_parts, part_index, even_size) + print(sharded_iter) for batch_sample_ids in sharded_iter: total_sampled_ids.extend(batch_sample_ids) assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N From a436c09b0078a6568fbf355e8d504719b3817b19 Mon Sep 17 00:00:00 2001 From: Hu Date: Thu, 20 Aug 2020 14:50:34 +0800 Subject: [PATCH 38/39] empty commit --- scripts/machine_translation/train_transformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index 71595c5aa1..2389d84ad3 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -180,6 +180,7 @@ def parse_args(): logging.info(args) return args + def validation(model, data_loader, ctx_l): """Validate the model on 
the dataset From db8afdb4a843d242e21f76ec9497687e4ce221e7 Mon Sep 17 00:00:00 2001 From: Hu Date: Thu, 20 Aug 2020 16:12:25 +0800 Subject: [PATCH 39/39] add test case test_sharded_iterator_even_size --- tests/test_data_sampler.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/test_data_sampler.py b/tests/test_data_sampler.py index 9e73905a53..9597004be3 100644 --- a/tests/test_data_sampler.py +++ b/tests/test_data_sampler.py @@ -172,3 +172,34 @@ def test_sharded_iterator(seq_lengths, max_num_tokens, max_num_sentences, total_sampled_ids.extend(batch_sample_ids) assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N assert sorted(total_sampled_ids) == list(range(len(total_sampled_ids))) + + +@pytest.mark.parametrize('seq_lengths', [[np.random.randint(10, 100) for _ in range(N)], + [(np.random.randint(10, 100), np.random.randint(10, 100)) for _ in range(N)]]) +@pytest.mark.parametrize('max_num_tokens', [200, 500]) +@pytest.mark.parametrize('max_num_sentences', [-1, 5]) +@pytest.mark.parametrize('required_batch_size_multiple', [1, 5]) +@pytest.mark.parametrize('shuffle', [True, False]) +@pytest.mark.parametrize('num_parts', [1, 4]) +@pytest.mark.parametrize('even_size', [True]) +def test_sharded_iterator_even_size(seq_lengths, max_num_tokens, max_num_sentences, + required_batch_size_multiple, shuffle, + num_parts, even_size): + total_sampled_ids = [] + first_batch_num = None + for part_index in range(num_parts): + batch_num = 0 + # we use independent (but same) sampler to simulate multi process situation + sampler = s.BoundedBudgetSampler(seq_lengths, max_num_tokens, max_num_sentences, + required_batch_size_multiple, shuffle, seed=100) + sharded_iter = s.ShardedIterator(sampler, num_parts, part_index, even_size) + print(sharded_iter) + for batch_sample_ids in sharded_iter: + total_sampled_ids.extend(batch_sample_ids) + batch_num += 1 + # assert batch num of each parts equals + if first_batch_num is None: + first_batch_num = batch_num + else: + assert first_batch_num == batch_num + assert len(set(total_sampled_ids)) == N
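
For readers following the sampler refactor in these patches, below is a minimal, hypothetical usage sketch. It is not part of the patch series itself; it assumes `BoundedBudgetSampler` and `ShardedIterator` are importable from `gluonnlp.data.sampler` with the signatures introduced above, and it mirrors what the new tests simulate: every worker builds an identical sampler (same seed) and then iterates only over its own shard of the batches.

```python
import numpy as np
from gluonnlp.data.sampler import BoundedBudgetSampler, ShardedIterator

# Token lengths of a toy corpus; in the training script these would come from
# the tokenized source/target sentences.
lengths = [np.random.randint(10, 100) for _ in range(1000)]
num_parts = 4  # e.g. hvd.size() when training with horovod

for part_index in range(num_parts):  # e.g. hvd.rank() on each worker
    # Every worker constructs the same sampler with the same seed, so the
    # batch layout is deterministic and identical across processes.
    sampler = BoundedBudgetSampler(lengths, max_num_tokens=500,
                                   shuffle=True, seed=100)
    # Each worker then reads only its own slice of the batches;
    # even_size=True pads short shards so all ranks see the same batch count.
    shard = ShardedIterator(sampler, num_parts=num_parts,
                            part_index=part_index, even_size=True)
    for batch_sample_ids in shard:
        pass  # feed batch_sample_ids into this rank's DataLoader
```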