From 39cccfbb35eb86d8b04064ea0098539174f34fd3 Mon Sep 17 00:00:00 2001
From: wawltor
Date: Sun, 7 Feb 2021 22:39:33 +0800
Subject: [PATCH] fix the ernie pretraining and a typo

---
 .../bert/{prdict_glue.py => predict_glue.py} | 61 ++++++--------
 examples/language_model/bert/run_pretrain.py | 22 ++---
 paddlenlp/transformers/ernie/modeling.py     | 84 +++++++++++++++----
 3 files changed, 107 insertions(+), 60 deletions(-)
 rename examples/language_model/bert/{prdict_glue.py => predict_glue.py} (74%)

diff --git a/examples/language_model/bert/prdict_glue.py b/examples/language_model/bert/predict_glue.py
similarity index 74%
rename from examples/language_model/bert/prdict_glue.py
rename to examples/language_model/bert/predict_glue.py
index 946ca8d54a39..eacda8771bd9 100644
--- a/examples/language_model/bert/prdict_glue.py
+++ b/examples/language_model/bert/predict_glue.py
@@ -33,43 +33,36 @@ def parse_args():
         type=str,
         required=True,
         help="The name of the task to perform predict, selected in the list: " +
-        ", ".join(TASK_CLASSES.keys()),
-    )
+        ", ".join(TASK_CLASSES.keys()), )
     parser.add_argument(
         "--model_type",
         default=None,
         type=str,
         required=True,
         help="Model type selected in the list: " +
-        ", ".join(MODEL_CLASSES.keys()),
-    )
+        ", ".join(MODEL_CLASSES.keys()), )
     parser.add_argument(
         "--model_path",
         default=None,
         type=str,
         required=True,
-        help="The path prefix of inference model to be used.",
-    )
+        help="The path prefix of inference model to be used.", )
     parser.add_argument(
         "--select_device",
         default="gpu",
         choices=["gpu", "cpu", "xpu"],
-        help="Device selected for inference.",
-    )
+        help="Device selected for inference.", )
     parser.add_argument(
         "--batch_size",
         default=32,
         type=int,
-        help="Batch size for predict.",
-    )
+        help="Batch size for predict.", )
     parser.add_argument(
         "--max_seq_length",
         default=128,
         type=int,
-        help=
-        "The maximum total input sequence length after tokenization. Sequences longer "
-        "than this will be truncated, sequences shorter will be padded.",
-    )
+        help="The maximum total input sequence length after tokenization. Sequences longer "
+        "than this will be truncated, sequences shorter will be padded.", )
     args = parser.parse_args()
     return args
@@ -108,8 +101,8 @@ def create_predictor(cls, args):

     def predict_batch(self, data):
         for input_field, input_handle in zip(data, self.input_handles):
-            input_handle.copy_from_cpu(input_field.numpy(
-            ) if isinstance(input_field, paddle.Tensor) else input_field)
+            input_handle.copy_from_cpu(input_field.numpy() if isinstance(
+                input_field, paddle.Tensor) else input_field)
         self.predictor.run()
         output = [
             output_handle.copy_to_cpu() for output_handle in self.output_handles
@@ -117,14 +110,14 @@ def predict_batch(self, data):
         return output

     def predict(self, dataset, collate_fn, batch_size=1):
-        batch_sampler = paddle.io.BatchSampler(dataset,
-                                               batch_size=batch_size,
-                                               shuffle=False)
-        data_loader = paddle.io.DataLoader(dataset=dataset,
-                                           batch_sampler=batch_sampler,
-                                           collate_fn=collate_fn,
-                                           num_workers=0,
-                                           return_list=True)
+        batch_sampler = paddle.io.BatchSampler(
+            dataset, batch_size=batch_size, shuffle=False)
+        data_loader = paddle.io.DataLoader(
+            dataset=dataset,
+            batch_sampler=batch_sampler,
+            collate_fn=collate_fn,
+            num_workers=0,
+            return_list=True)
         outputs = []
         for data in data_loader:
             output = self.predict_batch(data)
@@ -143,13 +136,14 @@ def main():
     model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

     dataset = dataset_class.get_datasets("test")
-    tokenizer = tokenizer_class.from_pretrained(os.path.dirname(
-        args.model_path))
-    transform_fn = partial(convert_example,
-                           tokenizer=tokenizer,
-                           label_list=dataset.get_labels(),
-                           max_seq_length=args.max_seq_length,
-                           is_test=True)
+    tokenizer = tokenizer_class.from_pretrained(
+        os.path.dirname(args.model_path))
+    transform_fn = partial(
+        convert_example,
+        tokenizer=tokenizer,
+        label_list=dataset.get_labels(),
+        max_seq_length=args.max_seq_length,
+        is_test=True)
     batchify_fn = lambda samples, fn=Tuple(
         Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
         Pad(axis=0, pad_val=tokenizer.pad_token_id),  # segment
@@ -157,9 +151,8 @@ def main():
     ): [data for i, data in enumerate(fn(samples)) if i != 2]

     dataset = dataset.apply(transform_fn)
-    predictor.predict(dataset,
-                      batch_size=args.batch_size,
-                      collate_fn=batchify_fn)
+    predictor.predict(
+        dataset, batch_size=args.batch_size, collate_fn=batchify_fn)


 if __name__ == "__main__":
diff --git a/examples/language_model/bert/run_pretrain.py b/examples/language_model/bert/run_pretrain.py
index 551e77b8b78c..fce75ed7b90e 100644
--- a/examples/language_model/bert/run_pretrain.py
+++ b/examples/language_model/bert/run_pretrain.py
@@ -41,8 +41,10 @@
 logger = logging.getLogger(__name__)

 MODEL_CLASSES = {
-    "bert": (BertForPretraining, BertTokenizer),
-    "ernie": (ErnieForPretraining, ErnieTokenizer)
+    "bert":
+    (BertModel, BertForPretraining, BertPretrainingCriterion, BertTokenizer),
+    "ernie":
+    (ErnieModel, ErnieForPretraining, ErniePretrainingCriterion, ErnieTokenizer)
 }


@@ -289,16 +291,16 @@ def do_train(args):
     worker_init = WorkerInitObj(args.seed + paddle.distributed.get_rank())

     args.model_type = args.model_type.lower()
-    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    base_class, model_class, criterion_class, tokenizer_class = MODEL_CLASSES[
+        args.model_type]

     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

-    model = BertForPretraining(
-        BertModel(**model_class.pretrained_init_configuration[
+    model = model_class(
+        base_class(**model_class.pretrained_init_configuration[
             args.model_name_or_path]))
-    criterion = BertPretrainingCriterion(
-        getattr(model, BertForPretraining.base_model_prefix).config[
-            "vocab_size"])
+    criterion = criterion_class(
+        getattr(model, model_class.base_model_prefix).config["vocab_size"])

     if paddle.distributed.get_world_size() > 1:
         model = paddle.DataParallel(model)
@@ -328,8 +330,8 @@ def do_train(args):
     for epoch in range(args.num_train_epochs):
         files = [
             os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
-            if os.path.isfile(os.path.join(args.input_dir, f)) and "training" in
-            f
+            if os.path.isfile(os.path.join(args.input_dir, f)) and
+            "training" in f
         ]
         files.sort()
         num_files = len(files)
diff --git a/paddlenlp/transformers/ernie/modeling.py b/paddlenlp/transformers/ernie/modeling.py
index fe032e3861ce..45f6533ecf73 100644
--- a/paddlenlp/transformers/ernie/modeling.py
+++ b/paddlenlp/transformers/ernie/modeling.py
@@ -37,6 +37,7 @@ def __init__(self,
                  type_vocab_size=2,
                  pad_token_id=0):
         super(ErnieEmbeddings, self).__init__()
+
         self.word_embeddings = nn.Embedding(
             vocab_size, hidden_size, padding_idx=pad_token_id)
         self.position_embeddings = nn.Embedding(max_position_embeddings,
@@ -55,7 +56,6 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None):
             position_ids.stop_gradient = True
         if token_type_ids is None:
             token_type_ids = paddle.zeros_like(input_ids, dtype="int64")
-
         input_embedings = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
@@ -324,6 +324,56 @@ def forward(self,
         return logits


+class ErnieLMPredictionHead(nn.Layer):
+    def __init__(self,
+                 hidden_size,
+                 vocab_size,
+                 activation,
+                 embedding_weights=None):
+        super(ErnieLMPredictionHead, self).__init__()
+        self.transform = nn.Linear(hidden_size, hidden_size)
+        self.activation = getattr(nn.functional, activation)
+        self.layer_norm = nn.LayerNorm(hidden_size)
+        self.decoder_weight = self.create_parameter(
+            shape=[hidden_size, vocab_size],
+            dtype=self.transform.weight.dtype,
+            is_bias=True) if embedding_weights is None else embedding_weights
+        self.decoder_bias = self.create_parameter(
+            shape=[vocab_size], dtype=self.decoder_weight.dtype, is_bias=True)
+
+    def forward(self, hidden_states, masked_positions=None):
+        if masked_positions is not None:
+            hidden_states = paddle.reshape(hidden_states,
+                                           [-1, hidden_states.shape[-1]])
+            hidden_states = paddle.tensor.gather(hidden_states,
+                                                 masked_positions)
+        # gathering the masked tokens first might be faster
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.activation(hidden_states)
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = paddle.tensor.matmul(
+            hidden_states, self.decoder_weight,
+            transpose_y=True) + self.decoder_bias
+        return hidden_states
+
+
+class ErniePretrainingHeads(nn.Layer):
+    def __init__(self,
+                 hidden_size,
+                 vocab_size,
+                 activation,
+                 embedding_weights=None):
+        super(ErniePretrainingHeads, self).__init__()
+        self.predictions = ErnieLMPredictionHead(hidden_size, vocab_size,
+                                                 activation, embedding_weights)
+        self.seq_relationship = nn.Linear(hidden_size, 2)
+
+    def forward(self, sequence_output, pooled_output, masked_positions=None):
+        prediction_scores = self.predictions(sequence_output, masked_positions)
+        seq_relationship_score = self.seq_relationship(pooled_output)
+        return prediction_scores, seq_relationship_score
+
+
 class ErnieForPretraining(ErniePretrainedModel):
     def __init__(self, ernie):
         super(ErnieForPretraining, self).__init__()
@@ -342,15 +392,16 @@ def forward(self,
                 position_ids=None,
                 attention_mask=None,
                 masked_positions=None):
-        outputs = self.ernie(
-            input_ids,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            attention_mask=attention_mask)
-        sequence_output, pooled_output = outputs[:2]
-        prediction_scores, seq_relationship_score = self.cls(
-            sequence_output, pooled_output, masked_positions)
-        return prediction_scores, seq_relationship_score
+        with paddle.static.amp.fp16_guard():
+            outputs = self.ernie(
+                input_ids,
+                token_type_ids=token_type_ids,
+                position_ids=position_ids,
+                attention_mask=attention_mask)
+            sequence_output, pooled_output = outputs[:2]
+            prediction_scores, seq_relationship_score = self.cls(
+                sequence_output, pooled_output, masked_positions)
+            return prediction_scores, seq_relationship_score


 class ErniePretrainingCriterion(paddle.nn.Layer):
@@ -361,9 +412,10 @@ def __init__(self, vocab_size):

     def forward(self, prediction_scores, seq_relationship_score,
                 masked_lm_labels, next_sentence_labels, masked_lm_scale):
-        masked_lm_loss = paddle.nn.functional.softmax_with_cross_entropy(
-            prediction_scores, masked_lm_labels, ignore_index=-1)
-        masked_lm_loss = masked_lm_loss / masked_lm_scale
-        next_sentence_loss = paddle.nn.functional.softmax_with_cross_entropy(
-            seq_relationship_score, next_sentence_labels)
-        return paddle.sum(masked_lm_loss) + paddle.mean(next_sentence_loss)
+        with paddle.static.amp.fp16_guard():
+            masked_lm_loss = paddle.nn.functional.softmax_with_cross_entropy(
+                prediction_scores, masked_lm_labels, ignore_index=-1)
+            masked_lm_loss = masked_lm_loss / masked_lm_scale
+            next_sentence_loss = paddle.nn.functional.softmax_with_cross_entropy(
+                seq_relationship_score, next_sentence_labels)
+            return paddle.sum(masked_lm_loss) + paddle.mean(next_sentence_loss)
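
Usage sketch (not part of the patch): the snippet below shows how the new
four-tuple MODEL_CLASSES entries are consumed to build the ERNIE pretraining
model and criterion, mirroring do_train() above. It assumes the four classes
are importable from paddlenlp.transformers and that "ernie-1.0" is a
registered key in ErnieForPretraining.pretrained_init_configuration; the
random tensors are illustrative only.

import paddle
from paddlenlp.transformers import (ErnieModel, ErnieForPretraining,
                                    ErniePretrainingCriterion, ErnieTokenizer)

# Unpack the same 4-tuple shape that MODEL_CLASSES["ernie"] now holds.
base_class, model_class, criterion_class, tokenizer_class = (
    ErnieModel, ErnieForPretraining, ErniePretrainingCriterion, ErnieTokenizer)

tokenizer = tokenizer_class.from_pretrained("ernie-1.0")  # assumed config key
# Build the backbone from its registered config, then wrap it with the
# pretraining heads, as run_pretrain.py now does for both "bert" and "ernie".
model = model_class(
    base_class(**model_class.pretrained_init_configuration["ernie-1.0"]))
vocab_size = getattr(model, model_class.base_model_prefix).config["vocab_size"]
criterion = criterion_class(vocab_size)

# masked_positions indexes into the flattened [batch * seq_len, hidden]
# sequence output, matching the gather in ErnieLMPredictionHead.forward.
input_ids = paddle.randint(low=0, high=vocab_size, shape=[2, 16])
masked_positions = paddle.to_tensor([3, 16 + 5], dtype="int64")
prediction_scores, seq_relationship_score = model(
    input_ids, masked_positions=masked_positions)

# Toy labels: one masked-LM label per gathered position, one NSP label per
# example; masked_lm_scale of 1.0 leaves the summed MLM loss unscaled.
masked_lm_labels = paddle.randint(low=0, high=vocab_size, shape=[2, 1])
next_sentence_labels = paddle.randint(low=0, high=2, shape=[2, 1])
loss = criterion(prediction_scores, seq_relationship_score, masked_lm_labels,
                 next_sentence_labels, 1.0)
print(float(loss))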