Add codegen unittests (#3348)
* add codegen unittests

* fix codegen

* update
FrostML authored Sep 29, 2022
1 parent fe2543d commit 131750a
Showing 6 changed files with 868 additions and 14 deletions.
45 changes: 34 additions & 11 deletions paddlenlp/transformers/codegen/modeling.py
@@ -21,13 +21,28 @@
from ..nezha.modeling import ACT2FN
from .. import PretrainedModel, register_base_model

CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
"Salesforce/codegen-350M-nl",
"Salesforce/codegen-350M-multi",
"Salesforce/codegen-350M-mono",
"Salesforce/codegen-2B-nl",
"Salesforce/codegen-2B-multi",
"Salesforce/codegen-2B-mono",
"Salesforce/codegen-6B-nl",
"Salesforce/codegen-6B-multi",
"Salesforce/codegen-6B-mono",
"Salesforce/codegen-16B-nl",
"Salesforce/codegen-16B-multi",
"Salesforce/codegen-16B-mono",
]


def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
dim = x.shape[-1]
if seq_len is None:
seq_len = x.shape[seq_dim]
inv_freq = 1.0 / (10000**(paddle.arange(0, dim, 2) / dim))
sinusoid_inp = (paddle.einsum("i , j -> i j",
sinusoid_inp = (paddle.einsum("i,j->ij",
paddle.arange(seq_len, dtype="float32"),
inv_freq))
return paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)
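
The einsum above computes an outer product of position indices and inverse frequencies. A minimal sketch of the equivalent computation without einsum (the helper name and shapes are illustrative, not part of the commit; assumes a working Paddle install):

import paddle

def fixed_pos_embedding_sketch(dim, seq_len):
    # inverse frequencies for the even channels of the rotary dimension
    inv_freq = 1.0 / (10000 ** (paddle.arange(0, dim, 2, dtype="float32") / dim))
    positions = paddle.arange(seq_len, dtype="float32")
    # outer product: (seq_len,) x (dim // 2,) -> (seq_len, dim // 2)
    sinusoid_inp = positions.unsqueeze(-1) * inv_freq.unsqueeze(0)
    return paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)

sin, cos = fixed_pos_embedding_sketch(dim=64, seq_len=8)
print(sin.shape)  # [8, 32]
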
@@ -59,13 +74,10 @@ def __init__(self, embed_dim, rotary_dim, num_attention_heads,
max_positions, attn_pdrop, resid_pdrop):
super().__init__()

self.register_buffer(
"causal_mask",
paddle.tril(
paddle.ones((max_positions, max_positions),
dtype=paddle.get_default_dtype())).reshape(
(1, 1, max_positions, max_positions)),
)
self.causal_mask = paddle.tril(
paddle.ones((max_positions, max_positions),
dtype=paddle.get_default_dtype())).reshape(
(1, 1, max_positions, max_positions))

self.attn_dropout = nn.Dropout(attn_pdrop)
self.resid_dropout = nn.Dropout(resid_pdrop)
@@ -475,7 +487,7 @@ def forward(
"specified when generating attention_mask"
if batch_size == 1 and past_length != 0:
batch_size, seq_len = input_shape
attention_mask = paddle.ones(
attention_mask = paddle.zeros(
[batch_size, 1, 1, seq_len + past_length],
dtype=paddle.get_default_dtype())
else:
@@ -487,7 +499,13 @@
attention_mask = paddle.unsqueeze(
attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
attention_mask = (1.0 - attention_mask) * -1e4
attention_mask.stop_gradient = True
attention_mask.stop_gradient = True
# TODO: the CodeGen attention mask conventions are confusing.
# When it's 2D, it must be int and is denoted by 1/0.
# When using model.generate() without providing an attention mask,
# or when using a 4D attention mask,
# the attention mask's dtype must be float and it is denoted by 0/-inf.
# Moreover, 3D attention masks are not supported.

inputs_embeds = self.wte(input_ids)
if token_type_ids is not None:
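
Following the TODO note above, a short sketch of the two mask layouts the model accepts; the tensors here are illustrative and not taken from the committed tests:

import paddle

batch_size, seq_len = 2, 5

# 2D mask: integer 1/0 per token; the model expands it to 4D and converts
# it to the additive form with (1.0 - mask) * -1e4
attention_mask_2d = paddle.ones([batch_size, seq_len], dtype="int64")

# 4D mask: already additive and float -- 0.0 where attention is allowed,
# a large negative value where it is masked (as with the zeros created
# above when a cache is present and no mask was passed)
attention_mask_4d = paddle.zeros([batch_size, 1, 1, seq_len],
                                 dtype=paddle.get_default_dtype())
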
@@ -521,7 +539,7 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
r"""
CodeGen Model with a `language modeling` head on top.
Args:
bart (:class:`CodeGenModel`):
transformer (:class:`CodeGenModel`):
An instance of CodeGenModel.
"""
_keys_to_ignore_on_load_missing = [
@@ -572,8 +590,12 @@ def prepare_faster_entry(self, kwargs):

def prepare_inputs_for_generation(self, input_ids, cache=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
token_type_ids = kwargs.get("token_type_ids", None)

if cache:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
@@ -585,6 +607,7 @@ def prepare_inputs_for_generation(self, input_ids, cache=None, **kwargs):
"cache": cache,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}

def forward(self,
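
The prepare_inputs_for_generation change above makes token_type_ids track input_ids during cached decoding: once a cache exists, only the last position of each is fed back. A toy sketch of that step (hypothetical tensors, not from the commit):

import paddle

input_ids = paddle.to_tensor([[10, 11, 12, 13]])
token_type_ids = paddle.to_tensor([[0, 0, 1, 1]])
cache = True  # stands in for a non-empty decoder cache

if cache:
    input_ids = input_ids[:, -1].unsqueeze(-1)                 # [[13]]
    if token_type_ids is not None:
        token_type_ids = token_type_ids[:, -1].unsqueeze(-1)   # [[1]]

print(input_ids.shape, token_type_ids.shape)  # [1, 1] [1, 1]
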
5 changes: 5 additions & 0 deletions paddlenlp/transformers/codegen/tokenizer.py
@@ -18,6 +18,11 @@

__all__ = ['CodeGenTokenizer']

VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
}


class CodeGenTokenizer(GPTTokenizer):

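
The VOCAB_FILES_NAMES mapping above names the two resource files (BPE vocabulary and merges) that CodeGenTokenizer, like its GPT parent class, loads. A hedged usage sketch, assuming the Salesforce checkpoints listed earlier are available through PaddleNLP's model hub:

from paddlenlp.transformers import CodeGenTokenizer

tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
encoded = tokenizer("def hello_world():")
print(encoded["input_ids"])
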
6 changes: 3 additions & 3 deletions paddlenlp/transformers/gpt/modeling.py
@@ -833,16 +833,16 @@ def forward(self,
length = length + cache_length
else:
cache_length = 0
casual_mask = self.bias[:, :, cache_length:length, :length]
causal_mask = self.bias[:, :, cache_length:length, :length]

if attention_mask is not None:
if attention_mask.dtype != paddle.int64:
attention_mask = paddle.cast(attention_mask, dtype=paddle.int64)
if len(attention_mask.shape) == 2:
attention_mask = attention_mask[:, None, None, :]
attention_mask = (1.0 - (attention_mask & casual_mask)) * -1e4
attention_mask = (1.0 - (attention_mask & causal_mask)) * -1e4
else:
attention_mask = (1.0 - casual_mask) * -1e4
attention_mask = (1.0 - causal_mask) * -1e4
# The tensor returned by triu not in static graph.
attention_mask.stop_gradient = True

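
The rename above (casual_mask to causal_mask) does not change behavior: the padding mask is cast to int64, broadcast to 4D, combined with the lower-triangular mask by bitwise AND, and mapped to the additive 0 / -1e4 form. A minimal sketch of that combination with toy sizes (the explicit cast is added here for clarity and is not part of the commit):

import paddle

length = 4
causal_mask = paddle.tril(paddle.ones([1, 1, length, length], dtype="int64"))
attention_mask = paddle.to_tensor([[1, 1, 1, 0]], dtype="int64")  # last token padded
attention_mask = attention_mask[:, None, None, :]                 # -> [1, 1, 1, 4]

combined = (attention_mask & causal_mask).astype(paddle.get_default_dtype())
additive = (1.0 - combined) * -1e4   # 0 where attention is allowed, -1e4 where masked
print(additive.shape)  # [1, 1, 4, 4]
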
13 changes: 13 additions & 0 deletions tests/transformers/codegen/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
