diff --git a/scripts/bert/compare_tf_gluon_model.py b/scripts/bert/conversion_tools/compare_tf_gluon_model.py similarity index 73% rename from scripts/bert/compare_tf_gluon_model.py rename to scripts/bert/conversion_tools/compare_tf_gluon_model.py index e35c216d52..11c3609254 100644 --- a/scripts/bert/compare_tf_gluon_model.py +++ b/scripts/bert/conversion_tools/compare_tf_gluon_model.py @@ -27,28 +27,37 @@ import mxnet as mx import gluonnlp as nlp +sys.path.insert(0, os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))) + parser = argparse.ArgumentParser(description='Comparison script for BERT model in Tensorflow' 'and that in Gluon. This script works with ' - 'google/bert@f39e881b') + 'google/bert@f39e881b', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--input_file', type=str, default='input.txt', - help='sample input file for testing. Default is input.txt') + help='sample input file for testing') parser.add_argument('--tf_bert_repo_dir', type=str, default='~/bert/', help='path to the original Tensorflow bert repository. ' - 'The repo should be at f39e881b. ' - 'Default is ~/bert/') + 'The repo should be at f39e881b.') parser.add_argument('--tf_model_dir', type=str, default='~/uncased_L-12_H-768_A-12/', - help='path to the original Tensorflow bert checkpoint directory. ' - 'Default is ~/uncased_L-12_H-768_A-12/') + help='path to the original Tensorflow bert checkpoint directory.') +parser.add_argument('--tf_model_prefix', type=str, + default='bert_model.ckpt', + help='name of bert checkpoint file.') +parser.add_argument('--tf_config_name', type=str, + default='bert_config.json', + help='Name of Bert config file') parser.add_argument('--cased', action='store_true', help='if not set, inputs are converted to lower case') parser.add_argument('--gluon_dataset', type=str, default='book_corpus_wiki_en_uncased', - help='gluon dataset name. Default is book_corpus_wiki_en_uncased') + help='gluon dataset name') parser.add_argument('--gluon_model', type=str, default='bert_12_768_12', - help='gluon model name. 
Default is bert_12_768_12') + help='gluon model name') parser.add_argument('--gluon_parameter_file', type=str, default=None, help='gluon parameter file name.') +parser.add_argument('--gluon_vocab_file', type=str, default=None, + help='gluon vocab file corresponding to --gluon_parameter_file.') args = parser.parse_args() @@ -56,8 +65,8 @@ tf_bert_repo_dir = os.path.expanduser(args.tf_bert_repo_dir) tf_model_dir = os.path.expanduser(args.tf_model_dir) vocab_file = os.path.join(tf_model_dir, 'vocab.txt') -bert_config_file = os.path.join(tf_model_dir, 'bert_config.json') -init_checkpoint = os.path.join(tf_model_dir, 'bert_model.ckpt') +bert_config_file = os.path.join(tf_model_dir, args.tf_config_name) +init_checkpoint = os.path.join(tf_model_dir, args.tf_model_prefix) do_lower_case = not args.cased max_length = 128 @@ -129,13 +138,18 @@ # Gluon MODEL # ############################################################################### -bert, vocabulary = nlp.model.get_model(args.gluon_model, - dataset_name=args.gluon_dataset, - pretrained=not args.gluon_parameter_file, - use_pooler=False, - use_decoder=False, - use_classifier=False) if args.gluon_parameter_file: + assert args.gluon_vocab_file, \ + 'Must specify --gluon_vocab_file when specifying --gluon_parameter_file' + with open(args.gluon_vocab_file, 'r') as f: + vocabulary = nlp.Vocab.from_json(f.read()) + bert, vocabulary = nlp.model.get_model(args.gluon_model, + dataset_name=None, + vocab=vocabulary, + pretrained=not args.gluon_parameter_file, + use_pooler=False, + use_decoder=False, + use_classifier=False) try: bert.cast('float16') bert.load_parameters(args.gluon_parameter_file, ignore_extra=True) @@ -143,6 +157,15 @@ except AssertionError: bert.cast('float32') bert.load_parameters(args.gluon_parameter_file, ignore_extra=True) +else: + assert not args.gluon_vocab_file, \ + 'Cannot specify --gluon_vocab_file without specifying --gluon_parameter_file' + bert, vocabulary = nlp.model.get_model(args.gluon_model, + dataset_name=args.gluon_dataset, + pretrained=not args.gluon_parameter_file, + use_pooler=False, + use_decoder=False, + use_classifier=False) print(bert) tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=do_lower_case) diff --git a/scripts/bert/conversion_tools/convert_pytorch_model.py b/scripts/bert/conversion_tools/convert_pytorch_model.py new file mode 100644 index 0000000000..978e377e98 --- /dev/null +++ b/scripts/bert/conversion_tools/convert_pytorch_model.py @@ -0,0 +1,180 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# 'License'); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint:disable=redefined-outer-name,logging-format-interpolation +""" Script for converting PyTorch Model to Gluon. 
""" + +import argparse +import json +import logging +import os +import sys + +import mxnet as mx +import gluonnlp as nlp +import torch +from gluonnlp.model import BERTEncoder, BERTModel +from gluonnlp.model.bert import bert_hparams + +sys.path.insert(0, os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))) +from utils import get_hash, load_text_vocab, tf_vocab_to_gluon_vocab + +parser = argparse.ArgumentParser(description='Conversion script for PyTorch BERT model', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--model', type=str, default='bert_12_768_12', + choices=['bert_12_768_12', 'bert_24_1024_16'], help='BERT model name') +parser.add_argument('--pytorch_checkpoint_dir', type=str, + help='Path to Tensorflow checkpoint folder.') +parser.add_argument('--vocab_file', type=str, help='Full path to the vocab.txt') +parser.add_argument('--gluon_pytorch_name_mapping', type=str, + default='gluon_to_pytorch_naming.json', + help='Output of infer_pytorch_gluon_parameter_name_mapping.py') +parser.add_argument('--out_dir', type=str, default=os.path.join('~', 'output'), + help='Path to output folder. The folder must exist.') +parser.add_argument('--debug', action='store_true', help='debugging mode') +args = parser.parse_args() +logging.getLogger().setLevel(logging.DEBUG if args.debug else logging.INFO) +logging.info(args) + +# convert vocabulary +vocab = tf_vocab_to_gluon_vocab(load_text_vocab(args.vocab_file)) + +# vocab serialization +tmp_file_path = os.path.expanduser(os.path.join(args.out_dir, 'tmp')) +with open(tmp_file_path, 'w') as f: + f.write(vocab.to_json()) +hash_full, hash_short = get_hash(tmp_file_path) +gluon_vocab_path = os.path.expanduser(os.path.join(args.out_dir, hash_short + '.vocab')) +with open(gluon_vocab_path, 'w') as f: + f.write(vocab.to_json()) + logging.info('vocab file saved to %s. 
hash = %s', gluon_vocab_path, hash_full) + +# Load PyTorch Model +pytorch_parameters = torch.load(os.path.join(args.pytorch_checkpoint_dir, 'pytorch_model.bin'), + map_location=lambda storage, loc: storage) +pytorch_parameters = {k: v.numpy() for k, v in pytorch_parameters.items()} + +# Make sure vocab fits to model +assert pytorch_parameters['bert.embeddings.word_embeddings.weight'].shape[0] == len( + vocab.idx_to_token) + +# Load Mapping +with open(args.gluon_pytorch_name_mapping, 'r') as f: + mapping = json.load(f) + +# BERT config +tf_config_names_to_gluon_config_names = { + 'attention_probs_dropout_prob': 'embed_dropout', + 'hidden_act': None, + 'hidden_dropout_prob': 'dropout', + 'hidden_size': 'units', + 'initializer_range': None, + 'intermediate_size': 'hidden_size', + 'max_position_embeddings': 'max_length', + 'num_attention_heads': 'num_heads', + 'num_hidden_layers': 'num_layers', + 'type_vocab_size': 'token_type_vocab_size', + 'vocab_size': None +} +predefined_args = bert_hparams[args.model] +with open(os.path.join(args.pytorch_checkpoint_dir, 'bert_config.json'), 'r') as f: + tf_config = json.load(f) + assert len(tf_config) == len(tf_config_names_to_gluon_config_names) + for tf_name, gluon_name in tf_config_names_to_gluon_config_names.items(): + if tf_name is None or gluon_name is None: + continue + assert tf_config[tf_name] == predefined_args[gluon_name] + +# BERT encoder +encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'], + num_layers=predefined_args['num_layers'], units=predefined_args['units'], + hidden_size=predefined_args['hidden_size'], + max_length=predefined_args['max_length'], + num_heads=predefined_args['num_heads'], scaled=predefined_args['scaled'], + dropout=predefined_args['dropout'], + use_residual=predefined_args['use_residual']) + +# Infer enabled BERTModel components +use_pooler = any('pooler' in n for n in pytorch_parameters) +use_decoder = any('cls.predictions.transform.dense.weight' in n for n in pytorch_parameters) +use_classifier = any('cls.seq_relationship.weight' in n for n in pytorch_parameters) + +if not use_classifier and 'classifier.weight' in pytorch_parameters and \ + pytorch_parameters['classifier.weight'].shape[0] == 2: + logging.info('Assuming classifier weights in provided Pytorch model are ' + 'from next sentence prediction task.') + use_classifier = True + +logging.info('Inferred that the pytorch model provides the following parameters:') +logging.info('- use_pooler = {}'.format(use_pooler)) +logging.info('- use_decoder = {}'.format(use_decoder)) +logging.info('- use_classifier = {}'.format(use_classifier)) + +# BERT model +bert = BERTModel(encoder, len(vocab), + token_type_vocab_size=predefined_args['token_type_vocab_size'], + units=predefined_args['units'], embed_size=predefined_args['embed_size'], + embed_dropout=predefined_args['embed_dropout'], + word_embed=predefined_args['word_embed'], use_pooler=use_pooler, + use_decoder=use_decoder, use_classifier=use_classifier) + +bert.initialize(init=mx.init.Normal(0.02)) + +ones = mx.nd.ones((2, 8)) +out = bert(ones, ones, mx.nd.array([5, 6]), mx.nd.array([[1], [2]])) +params = bert._collect_params_with_prefix() +assert len(params) == len(pytorch_parameters), "Gluon model does not match PyTorch model. " \ + "Please fix the BERTModel hyperparameters" + +# set parameter data +loaded_params = {} +for name in params: + if name not in mapping: + raise RuntimeError('Invalid json mapping file. 
' + 'The parameter {} is not described in the mapping file.'.format(name)) + pytorch_name = mapping[name] + if pytorch_name not in pytorch_parameters.keys(): + # Handle inconsistent naming in PyTorch + # The Expected names here are based on the PyTorch version of SciBert. + # The Inconsistencies were found in ClinicalBert + if 'LayerNorm' in pytorch_name: + pytorch_name = pytorch_name.replace('weight', 'gamma') + pytorch_name = pytorch_name.replace('bias', 'beta') + assert pytorch_name in pytorch_parameters.keys() + + if 'cls.seq_relationship' in pytorch_name: + pytorch_name = pytorch_name.replace('cls.seq_relationship', 'classifier') + + arr = mx.nd.array(pytorch_parameters[pytorch_name]) + + assert arr.shape == params[name].shape + params[name].set_data(arr) + loaded_params[name] = True + +if len(params) != len(loaded_params): + raise RuntimeError('The Gluon BERTModel comprises {} parameter arrays, ' + 'but {} have been extracted from the pytorch model. '.format( + len(params), len(loaded_params))) + +# param serialization +bert.save_parameters(tmp_file_path) +hash_full, hash_short = get_hash(tmp_file_path) +gluon_param_path = os.path.expanduser(os.path.join(args.out_dir, hash_short + '.params')) +logging.info('param saved to %s. hash = %s', gluon_param_path, hash_full) +bert.save_parameters(gluon_param_path) +mx.nd.waitall() diff --git a/scripts/bert/convert_tf_model.py b/scripts/bert/conversion_tools/convert_tf_model.py similarity index 62% rename from scripts/bert/convert_tf_model.py rename to scripts/bert/conversion_tools/convert_tf_model.py index 5cb0a5e7ae..6dd0806486 100644 --- a/scripts/bert/convert_tf_model.py +++ b/scripts/bert/conversion_tools/convert_tf_model.py @@ -19,26 +19,44 @@ # pylint:disable=redefined-outer-name,logging-format-interpolation """ Script for converting TF Model to Gluon. """ -import os -import logging import argparse +import json +import logging +import os +import sys + import mxnet as mx + from gluonnlp.model import BERTEncoder, BERTModel from gluonnlp.model.bert import bert_hparams -from utils import convert_vocab, get_hash, read_tf_checkpoint - -parser = argparse.ArgumentParser(description='Conversion script for Tensorflow BERT model') -parser.add_argument('--model', type=str, default='bert_12_768_12', - help='BERT model name. options are bert_12_768_12 and bert_24_1024_16.' - 'Default is bert_12_768_12') -parser.add_argument('--tf_checkpoint_dir', type=str, - default=os.path.join('~', 'cased_L-12_H-768_A-12/'), - help='Path to Tensorflow checkpoint folder. 
' - 'Default is /home/ubuntu/cased_L-12_H-768_A-12/') -parser.add_argument('--out_dir', type=str, + +sys.path.insert(0, os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))) + +from utils import (get_hash, load_text_vocab, read_tf_checkpoint, + tf_vocab_to_gluon_vocab) + + +parser = argparse.ArgumentParser( + description='Conversion script for Tensorflow BERT model', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--model', + type=str, + default='bert_12_768_12', + choices=['bert_12_768_12', 'bert_24_1024_16'], + help='BERT model name') +parser.add_argument('--tf_checkpoint_dir', + type=str, + help='Path to Tensorflow checkpoint folder.') +parser.add_argument('--tf_model_prefix', type=str, + default='bert_model.ckpt', + help='name of bert checkpoint file.') +parser.add_argument('--tf_config_name', type=str, + default='bert_config.json', + help='Name of Bert config file') +parser.add_argument('--out_dir', + type=str, default=os.path.join('~', 'output'), - help='Path to output folder. The folder must exist. ' - 'Default is /home/ubuntu/output/') + help='Path to output folder. The folder must exist.') parser.add_argument('--debug', action='store_true', help='debugging mode') args = parser.parse_args() logging.getLogger().setLevel(logging.DEBUG if args.debug else logging.INFO) @@ -46,7 +64,7 @@ # convert vocabulary vocab_path = os.path.join(args.tf_checkpoint_dir, 'vocab.txt') -vocab, reserved_token_idx_map = convert_vocab(vocab_path) +vocab = tf_vocab_to_gluon_vocab(load_text_vocab(vocab_path)) # vocab serialization tmp_file_path = os.path.expanduser(os.path.join(args.out_dir, 'tmp')) @@ -60,10 +78,19 @@ # load tf model tf_checkpoint_file = os.path.expanduser( - os.path.join(args.tf_checkpoint_dir, 'bert_model.ckpt')) + os.path.join(args.tf_checkpoint_dir, args.tf_model_prefix)) logging.info('loading Tensorflow checkpoint %s ...', tf_checkpoint_file) tf_tensors = read_tf_checkpoint(tf_checkpoint_file) tf_names = sorted(tf_tensors.keys()) + +tf_names = filter(lambda name: not name.endswith('adam_m'), tf_names) +tf_names = filter(lambda name: not name.endswith('adam_v'), tf_names) +tf_names = filter(lambda name: name != 'global_step', tf_names) +tf_names = list(tf_names) +if len(tf_tensors) != len(tf_names): + logging.info('Tensorflow model was saved with Optimizer parameters. 
' + 'Ignoring them.') + for name in tf_names: logging.debug('%s: %s', name, tf_tensors[name].shape) @@ -110,22 +137,28 @@ logging.info('warning: %s has symmetric shape %s', target_name, target.shape) logging.debug('%s: %s', target_name, target.shape) -# post processings for parameters: -# - handle tied decoder weight -# - update word embedding for reserved tokens -mx_tensors['decoder.3.weight'] = mx_tensors['word_embed.0.weight'] -embedding = mx_tensors['word_embed.0.weight'] -for source_idx, dst_idx in reserved_token_idx_map: - source = embedding[source_idx].copy() - dst = embedding[dst_idx].copy() - embedding[source_idx][:] = dst - embedding[dst_idx][:] = source -logging.info('total number of tf parameters = %d', len(tf_tensors)) -logging.info('total number of mx parameters = %d (including decoder param for weight tying)', - len(mx_tensors)) - -# XXX assume no changes in BERT configs +# BERT config +tf_config_names_to_gluon_config_names = { + 'attention_probs_dropout_prob': 'embed_dropout', + 'hidden_act': None, + 'hidden_dropout_prob': 'dropout', + 'hidden_size': 'units', + 'initializer_range': None, + 'intermediate_size': 'hidden_size', + 'max_position_embeddings': 'max_length', + 'num_attention_heads': 'num_heads', + 'num_hidden_layers': 'num_layers', + 'type_vocab_size': 'token_type_vocab_size', + 'vocab_size': None +} predefined_args = bert_hparams[args.model] +with open(os.path.join(args.tf_checkpoint_dir, args.tf_config_name), 'r') as f: + tf_config = json.load(f) + assert len(tf_config) == len(tf_config_names_to_gluon_config_names) + for tf_name, gluon_name in tf_config_names_to_gluon_config_names.items(): + if tf_name is None or gluon_name is None: + continue + assert tf_config[tf_name] == predefined_args[gluon_name] # BERT encoder encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'], @@ -138,6 +171,26 @@ dropout=predefined_args['dropout'], use_residual=predefined_args['use_residual']) +# Infer enabled BERTModel components +use_pooler = any('pooler' in n for n in mx_tensors) +use_decoder = any('decoder.0' in n for n in mx_tensors) +use_classifier = any('classifier.weight' in n for n in mx_tensors) + +logging.info('Inferred that the tensorflow model provides the following parameters:') +logging.info('- use_pooler = {}'.format(use_pooler)) +logging.info('- use_decoder = {}'.format(use_decoder)) +logging.info('- use_classifier = {}'.format(use_classifier)) + +# post processings for parameters: +# - handle tied decoder weight +logging.info('total number of tf parameters = %d', len(tf_names)) +if use_decoder: + mx_tensors['decoder.3.weight'] = mx_tensors['word_embed.0.weight'] + logging.info('total number of mx parameters = %d' + '(including decoder param for weight tying)', len(mx_tensors)) +else: + logging.info('total number of mx parameters = %d', len(mx_tensors)) + # BERT model bert = BERTModel(encoder, len(vocab), token_type_vocab_size=predefined_args['token_type_vocab_size'], @@ -145,14 +198,19 @@ embed_size=predefined_args['embed_size'], embed_dropout=predefined_args['embed_dropout'], word_embed=predefined_args['word_embed'], - use_pooler=True, use_decoder=True, - use_classifier=True) + use_pooler=use_pooler, use_decoder=use_decoder, + use_classifier=use_classifier) bert.initialize(init=mx.init.Normal(0.02)) ones = mx.nd.ones((2, 8)) out = bert(ones, ones, mx.nd.array([5, 6]), mx.nd.array([[1], [2]])) params = bert._collect_params_with_prefix() +if len(params) != len(mx_tensors): + raise RuntimeError('The Gluon BERTModel comprises {} parameter arrays, ' + 
'but {} have been extracted from the tf model. ' + 'Most likely the BERTModel hyperparameters do not match ' + 'the hyperparameters of the tf model.'.format(len(params), len(mx_tensors))) # set parameter data loaded_params = {} diff --git a/scripts/bert/conversion_tools/infer_pytorch_gluon_parameter_name_mapping.py b/scripts/bert/conversion_tools/infer_pytorch_gluon_parameter_name_mapping.py new file mode 100644 index 0000000000..7797805735 --- /dev/null +++ b/scripts/bert/conversion_tools/infer_pytorch_gluon_parameter_name_mapping.py @@ -0,0 +1,94 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# 'License'); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint:disable=redefined-outer-name,logging-format-interpolation +"""PyTorch BERT parameter naming to Gluon BERT parameter naming. + +Given a Gluon BERT model (eg. obtained with the convert_tf_gluon.py script) and +a pytorch_model.bin containing the same parameters, this script infers the +naming convention of PyTorch. + +""" + +import argparse +import json +import logging +import os +import sys + +import gluonnlp as nlp +import torch + +sys.path.insert(0, os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))) +from utils import load_text_vocab, tf_vocab_to_gluon_vocab + +parser = argparse.ArgumentParser(description='Pytorch BERT Naming Convention', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--model', type=str, default='bert_12_768_12', + choices=['bert_12_768_12', 'bert_24_1024_16'], help='BERT model name') +parser.add_argument('--dataset_name', type=str, default='scibert_scivocab_uncased', + help='Dataset name') +parser.add_argument('--pytorch_checkpoint_dir', type=str, + help='Path to Tensorflow checkpoint folder.') +parser.add_argument('--debug', action='store_true', help='debugging mode') +parser.add_argument('--out', default='gluon_to_pytorch_naming.json', + help='Output file to store gluon to pytorch name mapping.') +args = parser.parse_args() +logging.getLogger().setLevel(logging.DEBUG if args.debug else logging.INFO) +logging.info(args) + +# Load Gluon Model +bert, vocab = nlp.model.get_model(args.model, dataset_name=args.dataset_name, pretrained=True) +parameters = bert._collect_params_with_prefix() +parameters = {k: v.data().asnumpy() for k, v in parameters.items()} + +# Load PyTorch Model +pytorch_parameters = torch.load(os.path.join(args.pytorch_checkpoint_dir, 'pytorch_model.bin'), + map_location=lambda storage, loc: storage) +pytorch_vocab = tf_vocab_to_gluon_vocab( + load_text_vocab(os.path.join(args.pytorch_checkpoint_dir, 'vocab.txt'))) +pytorch_parameters = {k: v.numpy() for k, v in pytorch_parameters.items()} + +# Assert that vocabularies are equal +assert pytorch_vocab.idx_to_token == vocab.idx_to_token + +mapping = dict() + +for name, param in parameters.items(): + 
found_match = False + for pytorch_name, pytorch_param in pytorch_parameters.items(): + if param.shape == pytorch_param.shape: + if (param == pytorch_param).all(): + if found_match: + print('Found multiple matches for {}. ' + 'Ignoring new match {}'.format(name, pytorch_name)) + else: + found_match = True + mapping.update({name: pytorch_name}) + + # We don't break here, in case there are mulitple matches + + if not found_match: + raise RuntimeError('Pytorch and Gluon model do not match. ' + 'Cannot infer mapping of names.') + +assert len(mapping) == len(parameters) + +with open(args.out, 'w') as f: + json.dump(mapping, f, indent=" ") + print('Wrote mapping to {}'.format(args.out)) diff --git a/scripts/bert/input.txt b/scripts/bert/conversion_tools/input.txt similarity index 100% rename from scripts/bert/input.txt rename to scripts/bert/conversion_tools/input.txt diff --git a/scripts/bert/index.rst b/scripts/bert/index.rst index fa56754fd8..ffb3947bda 100644 --- a/scripts/bert/index.rst +++ b/scripts/bert/index.rst @@ -346,7 +346,7 @@ Command line interface .. code-block:: shell - python bert_embedding/bert.py --sentences "GluonNLP is a toolkit that enables easy text preprocessing, datasets loading and neural models building to help you speed up your Natural Language Processing (NLP) research." + python bert/embedding.py --sentences "GluonNLP is a toolkit that enables easy text preprocessing, datasets loading and neural models building to help you speed up your Natural Language Processing (NLP) research." Text: GluonNLP is a toolkit that enables easy text preprocessing, datasets loading and neural models building to help you speed up your Natural Language Processing (NLP) research. Tokens embedding: [array([-0.11881411, -0.59530115, 0.627092 , ..., 0.00648153, -0.03886228, 0.03406909], dtype=float32), array([-0.7995638 , -0.6540758 , -0.00521846, ..., -0.42272145, diff --git a/scripts/bert/utils.py b/scripts/bert/utils.py index 0549fffcb9..fe1d9abfa1 100644 --- a/scripts/bert/utils.py +++ b/scripts/bert/utils.py @@ -22,65 +22,19 @@ import collections import hashlib import io -import json import mxnet as mx -import gluonnlp +import gluonnlp as nlp -__all__ = ['convert_vocab'] +__all__ = ['tf_vocab_to_gluon_vocab', 'load_text_vocab'] -def convert_vocab(vocab_file): - """GluonNLP specific code to convert the original vocabulary to nlp.vocab.BERTVocab.""" - original_vocab = load_vocab(vocab_file) - token_to_idx = dict(original_vocab) - num_tokens = len(token_to_idx) - idx_to_token = [None] * len(original_vocab) - for word in original_vocab: - idx = int(original_vocab[word]) - idx_to_token[idx] = word - - def swap(token, target_idx, token_to_idx, idx_to_token, swap_idx): - original_idx = token_to_idx[token] - original_token = idx_to_token[target_idx] - token_to_idx[token] = target_idx - token_to_idx[original_token] = original_idx - idx_to_token[target_idx] = token - idx_to_token[original_idx] = original_token - swap_idx.append((original_idx, target_idx)) - - reserved_tokens = [gluonnlp.vocab.bert.PADDING_TOKEN, gluonnlp.vocab.bert.CLS_TOKEN, - gluonnlp.vocab.bert.SEP_TOKEN, gluonnlp.vocab.bert.MASK_TOKEN] - - unknown_token = gluonnlp.vocab.bert.UNKNOWN_TOKEN - padding_token = gluonnlp.vocab.bert.PADDING_TOKEN - swap_idx = [] - assert unknown_token in token_to_idx - assert padding_token in token_to_idx - swap(unknown_token, 0, token_to_idx, idx_to_token, swap_idx) - for i, token in enumerate(reserved_tokens): - swap(token, i + 1, token_to_idx, idx_to_token, swap_idx) - - # sanity checks - assert 
len(token_to_idx) == num_tokens - assert len(idx_to_token) == num_tokens - assert None not in idx_to_token - assert len(set(idx_to_token)) == num_tokens - - bert_vocab_dict = {} - bert_vocab_dict['idx_to_token'] = idx_to_token - bert_vocab_dict['token_to_idx'] = token_to_idx - bert_vocab_dict['reserved_tokens'] = reserved_tokens - bert_vocab_dict['unknown_token'] = unknown_token - bert_vocab_dict['padding_token'] = padding_token - bert_vocab_dict['bos_token'] = None - bert_vocab_dict['eos_token'] = None - bert_vocab_dict['mask_token'] = gluonnlp.vocab.bert.MASK_TOKEN - bert_vocab_dict['sep_token'] = gluonnlp.vocab.bert.SEP_TOKEN - bert_vocab_dict['cls_token'] = gluonnlp.vocab.bert.CLS_TOKEN - json_str = json.dumps(bert_vocab_dict) - converted_vocab = gluonnlp.vocab.BERTVocab.from_json(json_str) - return converted_vocab, swap_idx +def tf_vocab_to_gluon_vocab(tf_vocab): + special_tokens = ['[UNK]', '[PAD]', '[SEP]', '[MASK]', '[CLS]'] + assert all(t in tf_vocab for t in special_tokens) + counter = nlp.data.count_tokens(tf_vocab.keys()) + vocab = nlp.vocab.BERTVocab(counter, token_to_idx=tf_vocab) + return vocab def get_hash(filename): @@ -122,7 +76,7 @@ def profile(curr_step, start_step, end_step, profile_name='profile.json', if early_exit: exit() -def load_vocab(vocab_file): +def load_text_vocab(vocab_file): """Loads a vocabulary file into a dictionary.""" vocab = collections.OrderedDict() index = 0 diff --git a/src/gluonnlp/data/utils.py b/src/gluonnlp/data/utils.py index a435e5adcc..24369be9aa 100644 --- a/src/gluonnlp/data/utils.py +++ b/src/gluonnlp/data/utils.py @@ -227,7 +227,16 @@ def _slice_pad_length(num_items, length, overlap=0): 'wiki_cn_cased': 'ddebd8f3867bca5a61023f73326fb125cf12b4f5', 'wiki_cn': 'ddebd8f3867bca5a61023f73326fb125cf12b4f5', 'wiki_multilingual_uncased': '2b2514cc539047b9179e9d98a4e68c36db05c97a', - 'wiki_multilingual': '2b2514cc539047b9179e9d98a4e68c36db05c97a'} + 'wiki_multilingual': '2b2514cc539047b9179e9d98a4e68c36db05c97a', + 'scibert_scivocab_uncased': '2d2566bfc416790ab2646ab0ada36ba628628d60', + 'scibert_scivocab_cased': '2c714475b521ab8542cb65e46259f6bfeed8041b', + 'scibert_basevocab_uncased': '80ef760a6bdafec68c99b691c94ebbb918c90d02', + 'scibert_basevocab_cased': 'a4ff6fe1f85ba95f3010742b9abc3a818976bb2c', + 'biobert_v1.0_pmc_cased': 'a4ff6fe1f85ba95f3010742b9abc3a818976bb2c', + 'biobert_v1.0_pubmed_cased': 'a4ff6fe1f85ba95f3010742b9abc3a818976bb2c', + 'biobert_v1.0_pubmed_pmc_cased': 'a4ff6fe1f85ba95f3010742b9abc3a818976bb2c', + 'biobert_v1.1_pubmed_cased': 'a4ff6fe1f85ba95f3010742b9abc3a818976bb2c', + 'clinicalbert_uncased': '80ef760a6bdafec68c99b691c94ebbb918c90d02'} _url_format = '{repo_url}gluon/dataset/vocab/{file_name}.zip' diff --git a/src/gluonnlp/model/bert.py b/src/gluonnlp/model/bert.py index 58c568322e..742a8d190b 100644 --- a/src/gluonnlp/model/bert.py +++ b/src/gluonnlp/model/bert.py @@ -489,7 +489,17 @@ def _decode(self, sequence, masked_positions): ('885ebb9adc249a170c5576e90e88cfd1bbd98da6', 'bert_12_768_12_wiki_cn'), ('885ebb9adc249a170c5576e90e88cfd1bbd98da6', 'bert_12_768_12_wiki_cn_cased'), ('4e685a966f8bf07d533bd6b0e06c04136f23f620', 'bert_24_1024_16_book_corpus_wiki_en_cased'), - ('24551e1446180e045019a87fc4ffbf714d99c0b5', 'bert_24_1024_16_book_corpus_wiki_en_uncased') + ('24551e1446180e045019a87fc4ffbf714d99c0b5', 'bert_24_1024_16_book_corpus_wiki_en_uncased'), + ('6c82d963fc8fa79c35dd6cb3e1725d1e5b6aa7d7', 'bert_12_768_12_scibert_scivocab_uncased'), + ('adf9c81e72ac286a37b9002da8df9e50a753d98b', 
'bert_12_768_12_scibert_scivocab_cased'), + ('75acea8e8386890120533d6c0032b0b3fcb2d536', 'bert_12_768_12_scibert_basevocab_uncased'), + ('8e86e5de55d6dae99123312cd8cdd8183a75e057', 'bert_12_768_12_scibert_basevocab_cased'), + ('a07780385add682f609772e81ec64aca77c9fb05', 'bert_12_768_12_biobert_v1.0_pmc_cased'), + ('280ad1cc487db90489f86189e045e915b35e7489', 'bert_12_768_12_biobert_v1.0_pubmed_cased'), + ('8a8c75441f028a6b928b11466f3d30f4360dfff5', + 'bert_12_768_12_biobert_v1.0_pubmed_pmc_cased'), + ('55f15c5d23829f6ee87622b68711b15fef50e55b', 'bert_12_768_12_biobert_v1.1_pubmed_cased'), + ('60281c98ba3572dfdaac75131fa96e2136d70d5c', 'bert_12_768_12_clinicalbert_uncased'), ]}) bert_12_768_12_hparams = { @@ -531,8 +541,8 @@ def _decode(self, sequence, masked_positions): def bert_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(), - root=os.path.join(get_home_dir(), 'models'), use_pooler=True, - use_decoder=True, use_classifier=True, **kwargs): + root=os.path.join(get_home_dir(), 'models'), use_pooler=True, use_decoder=True, + use_classifier=True, pretrained_allow_missing=False, **kwargs): """Generic BERT BASE model. The number of layers (L) is 12, number of units (H) is 768, and the @@ -541,10 +551,20 @@ def bert_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(), Parameters ---------- dataset_name : str or None, default None - Options include 'book_corpus_wiki_en_cased', 'book_corpus_wiki_en_uncased', - 'wiki_cn_cased', 'wiki_multilingual_uncased' and 'wiki_multilingual_cased'. + If not None, the dataset name is used to load a vocabulary for the + dataset. If the `pretrained` argument is set to True, the dataset name + is further used to select the pretrained parameters to load. + The supported datasets are 'book_corpus_wiki_en_cased', + 'book_corpus_wiki_en_uncased', 'wiki_cn_cased', + 'wiki_multilingual_uncased', 'wiki_multilingual_cased', + 'scibert_scivocab_uncased', 'scibert_scivocab_cased', + 'scibert_basevocab_uncased','scibert_basevocab_cased', + 'biobert_v1.0_pmc', 'biobert_v1.0_pubmed', 'biobert_v1.0_pubmed_pmc', + 'biobert_v1.1_pubmed', + 'clinicalbert' vocab : gluonnlp.vocab.BERTVocab or None, default None - Vocabulary for the dataset. Must be provided if dataset is not specified. + Vocabulary for the dataset. Must be provided if dataset_name is not + specified. Ignored if dataset_name is specified. pretrained : bool, default True Whether to load the pretrained weights for model. ctx : Context, default CPU @@ -558,22 +578,42 @@ def bert_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(), for for segment level classification task. use_decoder : bool, default True Whether to include the decoder for masked language model prediction. + Note that + 'biobert_v1.0_pmc', 'biobert_v1.0_pubmed', 'biobert_v1.0_pubmed_pmc', + 'biobert_v1.1_pubmed', + 'clinicalbert' + do not include these parameters. use_classifier : bool, default True Whether to include the classifier for next sentence classification. + Note that + 'biobert_v1.0_pmc', 'biobert_v1.0_pubmed', 'biobert_v1.0_pubmed_pmc', + 'biobert_v1.1_pubmed' + do not include these parameters. + pretrained_allow_missing : bool, default False + Whether to ignore if any parameters for the BERTModel are missing in + the pretrained weights for model. + Some BERTModels for example do not provide decoder or classifier + weights. 
In that case it is still possible to construct a BERTModel + with use_decoder=True and/or use_classifier=True, but the respective + parameters will be missing from the pretrained file. + If pretrained_allow_missing=True, this will be ignored and the + parameters will be left uninitialized. Otherwise AssertionError is + raised. Returns ------- BERTModel, gluonnlp.vocab.BERTVocab """ - return get_bert_model(model_name='bert_12_768_12', vocab=vocab, - dataset_name=dataset_name, pretrained=pretrained, ctx=ctx, - use_pooler=use_pooler, use_decoder=use_decoder, - use_classifier=use_classifier, root=root, **kwargs) + return get_bert_model(model_name='bert_12_768_12', vocab=vocab, dataset_name=dataset_name, + pretrained=pretrained, ctx=ctx, use_pooler=use_pooler, + use_decoder=use_decoder, use_classifier=use_classifier, root=root, + pretrained_allow_missing=pretrained_allow_missing, **kwargs) -def bert_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(), - use_pooler=True, use_decoder=True, use_classifier=True, - root=os.path.join(get_home_dir(), 'models'), **kwargs): +def bert_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(), use_pooler=True, + use_decoder=True, use_classifier=True, + root=os.path.join(get_home_dir(), 'models'), + pretrained_allow_missing=False, **kwargs): """Generic BERT LARGE model. The number of layers (L) is 24, number of units (H) is 1024, and the @@ -582,9 +622,13 @@ def bert_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu() Parameters ---------- dataset_name : str or None, default None + If not None, the dataset name is used to load a vocabulary for the + dataset. If the `pretrained` argument is set to True, the dataset name + is further used to select the pretrained parameters to load. Options include 'book_corpus_wiki_en_uncased' and 'book_corpus_wiki_en_cased'. vocab : gluonnlp.vocab.BERTVocab or None, default None - Vocabulary for the dataset. Must be provided if dataset is not specified. + Vocabulary for the dataset. Must be provided if dataset_name is not + specified. Ignored if dataset_name is specified. pretrained : bool, default True Whether to load the pretrained weights for model. ctx : Context, default CPU @@ -600,23 +644,31 @@ def bert_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu() Whether to include the decoder for masked language model prediction. use_classifier : bool, default True Whether to include the classifier for next sentence classification. + pretrained_allow_missing : bool, default False + Whether to ignore if any parameters for the BERTModel are missing in + the pretrained weights for model. + Some BERTModels for example do not provide decoder or classifier + weights. In that case it is still possible to construct a BERTModel + with use_decoder=True and/or use_classifier=True, but the respective + parameters will be missing from the pretrained file. + If pretrained_allow_missing=True, this will be ignored and the + parameters will be left uninitialized. Otherwise AssertionError is + raised. 
Returns ------- BERTModel, gluonnlp.vocab.BERTVocab """ - return get_bert_model(model_name='bert_24_1024_16', vocab=vocab, - dataset_name=dataset_name, pretrained=pretrained, - ctx=ctx, use_pooler=use_pooler, - use_decoder=use_decoder, use_classifier=use_classifier, - root=root, **kwargs) - - -def get_bert_model(model_name=None, dataset_name=None, vocab=None, - pretrained=True, ctx=mx.cpu(), - use_pooler=True, use_decoder=True, use_classifier=True, - output_attention=False, output_all_encodings=False, - root=os.path.join(get_home_dir(), 'models'), **kwargs): + return get_bert_model(model_name='bert_24_1024_16', vocab=vocab, dataset_name=dataset_name, + pretrained=pretrained, ctx=ctx, use_pooler=use_pooler, + use_decoder=use_decoder, use_classifier=use_classifier, root=root, + pretrained_allow_missing=pretrained_allow_missing, **kwargs) + + +def get_bert_model(model_name=None, dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(), + use_pooler=True, use_decoder=True, use_classifier=True, output_attention=False, + output_all_encodings=False, root=os.path.join(get_home_dir(), 'models'), + pretrained_allow_missing=False, **kwargs): """Any BERT pretrained model. Parameters @@ -624,12 +676,23 @@ def get_bert_model(model_name=None, dataset_name=None, vocab=None, model_name : str or None, default None Options include 'bert_24_1024_16' and 'bert_12_768_12'. dataset_name : str or None, default None - Options include 'book_corpus_wiki_en_cased', 'book_corpus_wiki_en_uncased' - for both bert_24_1024_16 and bert_12_768_12. - 'wiki_cn_cased', 'wiki_multilingual_uncased' and 'wiki_multilingual_cased' - for bert_12_768_12 only. + If not None, the dataset name is used to load a vocabulary for the + dataset. If the `pretrained` argument is set to True, the dataset name + is further used to select the pretrained parameters to load. + The supported datasets for model_name of either bert_24_1024_16 and + bert_12_768_12 are 'book_corpus_wiki_en_cased', + 'book_corpus_wiki_en_uncased'. + For model_name bert_12_768_12 'wiki_cn_cased', + 'wiki_multilingual_uncased', 'wiki_multilingual_cased', + 'scibert_scivocab_uncased', 'scibert_scivocab_cased', + 'scibert_basevocab_uncased','scibert_basevocab_cased', + 'biobert_v1.0_pmc', 'biobert_v1.0_pubmed', 'biobert_v1.0_pubmed_pmc', + 'biobert_v1.1_pubmed', + 'clinicalbert' + are additionally supported. vocab : gluonnlp.vocab.BERTVocab or None, default None - Vocabulary for the dataset. Must be provided if dataset is not specified. + Vocabulary for the dataset. Must be provided if dataset_name is not + specified. Ignored if dataset_name is specified. pretrained : bool, default True Whether to load the pretrained weights for model. ctx : Context, default CPU @@ -643,12 +706,31 @@ def get_bert_model(model_name=None, dataset_name=None, vocab=None, for for segment level classification task. use_decoder : bool, default True Whether to include the decoder for masked language model prediction. + Note that + 'biobert_v1.0_pmc', 'biobert_v1.0_pubmed', 'biobert_v1.0_pubmed_pmc', + 'biobert_v1.1_pubmed', + 'clinicalbert' + do not include these parameters. use_classifier : bool, default True Whether to include the classifier for next sentence classification. + Note that + 'biobert_v1.0_pmc', 'biobert_v1.0_pubmed', 'biobert_v1.0_pubmed_pmc', + 'biobert_v1.1_pubmed' + do not include these parameters. output_attention : bool, default False Whether to include attention weights of each encoding cell to the output. 
output_all_encodings : bool, default False Whether to output encodings of all encoder cells. + pretrained_allow_missing : bool, default False + Whether to ignore if any parameters for the BERTModel are missing in + the pretrained weights for model. + Some BERTModels for example do not provide decoder or classifier + weights. In that case it is still possible to construct a BERTModel + with use_decoder=True and/or use_classifier=True, but the respective + parameters will be missing from the pretrained file. + If pretrained_allow_missing=True, this will be ignored and the + parameters will be left uninitialized. Otherwise AssertionError is + raised. Returns ------- @@ -689,6 +771,6 @@ def get_bert_model(model_name=None, dataset_name=None, vocab=None, use_classifier=use_classifier) if pretrained: ignore_extra = not (use_pooler and use_decoder and use_classifier) - _load_pretrained_params(net, model_name, dataset_name, root, ctx, - ignore_extra=ignore_extra) + _load_pretrained_params(net, model_name, dataset_name, root, ctx, ignore_extra=ignore_extra, + allow_missing=pretrained_allow_missing) return net, bert_vocab diff --git a/src/gluonnlp/model/utils.py b/src/gluonnlp/model/utils.py index b6b4359cf6..3b4b0c7739 100644 --- a/src/gluonnlp/model/utils.py +++ b/src/gluonnlp/model/utils.py @@ -274,7 +274,8 @@ def _load_vocab(dataset_name, vocab, root, cls=None): return vocab -def _load_pretrained_params(net, model_name, dataset_name, root, ctx, ignore_extra=False): +def _load_pretrained_params(net, model_name, dataset_name, root, ctx, ignore_extra=False, + allow_missing=False): path = '_'.join([model_name, dataset_name]) model_file = model_store.get_model_file(path, root=root) - net.load_parameters(model_file, ctx=ctx, ignore_extra=ignore_extra) + net.load_parameters(model_file, ctx=ctx, ignore_extra=ignore_extra, allow_missing=allow_missing) diff --git a/tests/unittest/test_models.py b/tests/unittest/test_models.py index 0b290cf96c..3f9e842d92 100644 --- a/tests/unittest/test_models.py +++ b/tests/unittest/test_models.py @@ -98,29 +98,80 @@ def test_transformer_models(): @pytest.mark.serial @pytest.mark.remote_required -def test_pretrained_bert_models(): +@pytest.mark.parametrize('disable_missing_parameters', [False, True]) +def test_pretrained_bert_models(disable_missing_parameters): models = ['bert_12_768_12', 'bert_24_1024_16'] pretrained = { - 'bert_12_768_12': - ['book_corpus_wiki_en_cased', 'book_corpus_wiki_en_uncased', - 'wiki_multilingual_uncased', 'wiki_multilingual_cased', 'wiki_cn_cased'], - 'bert_24_1024_16': ['book_corpus_wiki_en_uncased', 'book_corpus_wiki_en_cased']} + 'bert_12_768_12': [ + 'book_corpus_wiki_en_cased', 'book_corpus_wiki_en_uncased', 'wiki_multilingual_uncased', + 'wiki_multilingual_cased', 'wiki_cn_cased', 'scibert_scivocab_uncased', + 'scibert_scivocab_cased', 'scibert_basevocab_uncased', 'scibert_basevocab_cased', + 'biobert_v1.0_pmc_cased', 'biobert_v1.0_pubmed_cased', 'biobert_v1.0_pubmed_pmc_cased', + 'biobert_v1.1_pubmed_cased', 'clinicalbert_uncased' + ], + 'bert_24_1024_16': ['book_corpus_wiki_en_uncased', 'book_corpus_wiki_en_cased'] + } vocab_size = {'book_corpus_wiki_en_cased': 28996, 'book_corpus_wiki_en_uncased': 30522, 'wiki_multilingual_cased': 119547, 'wiki_cn_cased': 21128, - 'wiki_multilingual_uncased': 105879} + 'wiki_multilingual_uncased': 105879, + 'scibert_scivocab_uncased': 31090, + 'scibert_scivocab_cased': 31116, + 'scibert_basevocab_uncased': 30522, + 'scibert_basevocab_cased': 28996, + 'biobert_v1.0_pubmed_cased': 28996, + 
'biobert_v1.0_pmc_cased': 28996, + 'biobert_v1.0_pubmed_pmc_cased': 28996, + 'biobert_v1.1_pubmed_cased': 28996, + 'clinicalbert_uncased': 30522} special_tokens = ['[UNK]', '[PAD]', '[SEP]', '[CLS]', '[MASK]'] ones = mx.nd.ones((2, 10)) valid_length = mx.nd.ones((2,)) positions = mx.nd.zeros((2, 3)) for model_name in models: - eprint('testing forward for %s' % model_name) pretrained_datasets = pretrained.get(model_name) for dataset in pretrained_datasets: - model, vocab = nlp.model.get_model(model_name, dataset_name=dataset, - pretrained=True, - root='tests/data/model/') + has_missing_params = any(n in dataset for n in ('biobert', 'clinicalbert')) + if not has_missing_params and disable_missing_parameters: + # No parameters to disable for models pretrained on this dataset + continue + + eprint('testing forward for %s on %s' % (model_name, dataset)) + + if not has_missing_params: + model, vocab = nlp.model.get_model(model_name, dataset_name=dataset, + pretrained=True, + root='tests/data/model/') + else: + with pytest.raises(AssertionError): + model, vocab = nlp.model.get_model(model_name, dataset_name=dataset, + pretrained=True, + root='tests/data/model/') + + if not disable_missing_parameters: + model, vocab = nlp.model.get_model(model_name, dataset_name=dataset, + pretrained=True, + root='tests/data/model/', + pretrained_allow_missing=True) + elif 'biobert' in dataset: + # Biobert specific test case + model, vocab = nlp.model.get_model(model_name, dataset_name=dataset, + pretrained=True, + root='tests/data/model/', + pretrained_allow_missing=True, + use_decoder=False, + use_classifier=False) + elif 'clinicalbert' in dataset: + # Clinicalbert specific test case + model, vocab = nlp.model.get_model(model_name, dataset_name=dataset, + pretrained=True, + root='tests/data/model/', + pretrained_allow_missing=True, + use_decoder=False) + else: + assert False, "Testcase needs to be adapted." + assert len(vocab) == vocab_size[dataset] for token in special_tokens: assert token in vocab, "Token %s not found in the vocab" % token @@ -129,8 +180,14 @@ def test_pretrained_bert_models(): assert vocab.unknown_token == '[UNK]' assert vocab.bos_token is None assert vocab.eos_token is None - output = model(ones, ones, valid_length, positions) - output[0].wait_to_read() + + if has_missing_params and not disable_missing_parameters: + with pytest.raises(RuntimeError): + output = model(ones, ones, valid_length, positions) + output[0].wait_to_read() + else: + output = model(ones, ones, valid_length, positions) + output[0].wait_to_read() del model mx.nd.waitall() @@ -482,6 +539,7 @@ def forward(self, inpt): for name, param in shared_net.collect_params().items(): assert not mx.test_utils.almost_equal(grads[name].asnumpy(), param.grad().asnumpy()) + def test_gelu(): x = mx.random.uniform(shape=(3, 4, 5)) net = nlp.model.GELU()
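
The parametrized test above reflects the intended usage of the new
pretrained_allow_missing option. A minimal sketch, assuming the BioBERT
checkpoints are available under the dataset names registered in this change
(here 'biobert_v1.1_pubmed_cased'): the BioBERT parameter files do not contain
the masked-LM decoder or the next-sentence classifier, so those components are
disabled and any remaining missing parameters are tolerated via
pretrained_allow_missing=True, mirroring the test case.

import mxnet as mx
import gluonnlp as nlp

# Sketch only: load one of the newly registered BioBERT checkpoints.
# Decoder and classifier weights are absent from the pretrained file,
# so disable both and allow missing parameters (as in the test above).
model, vocab = nlp.model.get_model('bert_12_768_12',
                                   dataset_name='biobert_v1.1_pubmed_cased',
                                   pretrained=True,
                                   use_decoder=False,
                                   use_classifier=False,
                                   pretrained_allow_missing=True)

# Encoder and pooler are fully initialized and usable as before.
inputs = mx.nd.ones((2, 10))
token_types = mx.nd.ones((2, 10))
valid_length = mx.nd.ones((2,))
seq_encoding, pooled_out = model(inputs, token_types, valid_length)
print(seq_encoding.shape, pooled_out.shape)  # (2, 10, 768), (2, 768)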
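
For checkpoints converted locally with conversion_tools/convert_tf_model.py or
conversion_tools/convert_pytorch_model.py, the updated compare script shows how
the converter output is meant to be consumed: read the vocabulary JSON written
by the converter, pass it to get_model together with dataset_name=None, and
then load the parameter file. A condensed sketch of that path follows; the
file names are placeholders, since the converters name their outputs after the
hash of the serialized files (<hash>.vocab and <hash>.params in --out_dir).

import os

import gluonnlp as nlp

# Placeholder paths: substitute the actual <hash>.vocab / <hash>.params names
# printed by the conversion script.
vocab_file = os.path.expanduser('~/output/0123abcd.vocab')
param_file = os.path.expanduser('~/output/0123abcd.params')

with open(vocab_file, 'r') as f:
    vocabulary = nlp.Vocab.from_json(f.read())

# dataset_name=None and pretrained=False: both vocabulary and parameters come
# from the local conversion output rather than from the model zoo.
bert, vocabulary = nlp.model.get_model('bert_12_768_12',
                                       dataset_name=None,
                                       vocab=vocabulary,
                                       pretrained=False,
                                       use_pooler=False,
                                       use_decoder=False,
                                       use_classifier=False)
bert.load_parameters(param_file, ignore_extra=True)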