Add GPT ut #3133

Merged
merged 11 commits on Aug 29, 2022
Changes from 4 commits
51 changes: 39 additions & 12 deletions paddlenlp/transformers/gpt/modeling.py
@@ -707,10 +707,17 @@ def __init__(self,
super(GPTModel, self).__init__()

self.pad_token_id = pad_token_id
self.eos_token_id = eos_token_id
self.bos_token_id = bos_token_id
self.eol_token_id = eol_token_id
self.initializer_range = initializer_range
self.topo = topo
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.bias = paddle.tril(
paddle.ones(
[1, 1, max_position_embeddings, max_position_embeddings],
dtype="int64"))
Contributor:
Does this affect running in static graph mode?

Contributor Author:
It should be fine inside the program guard.
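A minimal dygraph sketch of the bias buffer added above and how the cache_length:length slice selects the query rows during incremental decoding (sizes below are made up for illustration; under static graph this would run inside a program guard, as noted in the reply):

import paddle

max_position_embeddings = 8  # illustrative size only

# Same construction as in GPTModel.__init__: a [1, 1, L, L] lower-triangular mask.
bias = paddle.tril(
    paddle.ones([1, 1, max_position_embeddings, max_position_embeddings],
                dtype="int64"))

# First pass, no cache, 5 prompt tokens -> full 5x5 causal block.
cache_length, length = 0, 5
print(bias[:, :, cache_length:length, :length].shape)  # [1, 1, 5, 5]

# Incremental step: 5 cached tokens plus 1 new token -> a single query row
# that may attend to all 6 positions.
cache_length, length = 5, 6
print(bias[:, :, cache_length:length, :length].shape)  # [1, 1, 1, 6]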


self.embeddings = GPTEmbeddings(vocab_size, hidden_size,
hidden_dropout_prob,
@@ -744,6 +751,12 @@ def __init__(self,
self.apply(self.init_weights)
self.checkpoints = []

def get_input_embeddings(self):
return self.embeddings.word_embeddings

def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value

def forward(self,
input_ids,
position_ids=None,
@@ -816,16 +829,22 @@ def forward(self,
position_ids=position_ids)

# TODO, use registered buffer
causal_mask = paddle.tensor.triu(paddle.ones(
(paddle.shape(input_ids)[-1], paddle.shape(input_ids)[-1])) * -1e4,
diagonal=1)
length = paddle.shape(input_ids)[-1]
if cache is not None:
cache_length = paddle.shape(cache[0].k)[2]
length = length + cache_length
else:
cache_length = 0
casual_mask = self.bias[:, :, cache_length:length, :length]

if attention_mask is not None:
if attention_mask.dtype != paddle.int64:
attention_mask = paddle.cast(attention_mask, dtype=paddle.int64)
if len(attention_mask.shape) == 2:
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask + causal_mask
attention_mask = (1.0 - (attention_mask & casual_mask)) * -1e9
else:
attention_mask = causal_mask
attention_mask = (1.0 - casual_mask) * -1e9
Contributor:
Let's keep -1e4 to stay consistent with the original.

Contributor:
Please also update the description of attention_mask support in the API docs accordingly.

Contributor Author:
Done.
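A toy example of the new mask arithmetic (illustrative values; it uses -1e4 as suggested in the review, where the diff above uses -1e9): the padding mask and the sliced causal mask are both 0/1 int64 tensors, the bitwise & keeps only positions allowed by both, and (1.0 - mask) * fill turns blocked positions into large negative additive biases.

import paddle

# Batch of 1, sequence length 4, last position is padding.
pad_mask = paddle.to_tensor([[1, 1, 1, 0]], dtype="int64")       # [B, L]
pad_mask = pad_mask[:, None, None, :]                            # [B, 1, 1, L]

causal = paddle.tril(paddle.ones([1, 1, 4, 4], dtype="int64"))   # [1, 1, L, L]

# A position is visible only if the key is non-padding and not in the future.
combined = pad_mask & causal

# Additive form: 0 where attention is allowed, -1e4 where it is blocked.
additive = (1.0 - paddle.cast(combined, "float32")) * -1e4
print(additive)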

# The tensor returned by triu not in static graph.
attention_mask.stop_gradient = True

@@ -1111,7 +1130,6 @@ def forward(self,
`cache_kvs` is the cache output of gpt model if `use_cache` is True.

"""

outputs = self.gpt(input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
@@ -1166,12 +1184,8 @@ def prepare_inputs_for_generation(self,
# only last token for inputs_ids if cache is defined in kwargs
position_ids = kwargs.get("position_ids", None)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if len(attention_mask.shape) == 4:
attention_mask = attention_mask[:, -1, -1, :]
if "int" in paddle.common_ops_import.convert_dtype(
attention_mask.dtype):
attention_mask = (1.0 - attention_mask) * -1e4
if attention_mask is not None and len(attention_mask.shape) == 4:
attention_mask = attention_mask[:, -1:, -1:, :]
if cache is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)
if position_ids is not None:
@@ -1184,6 +1198,19 @@
"cache": cache
}

@staticmethod
def prepare_attention_mask_for_generation(input_ids, pad_token_id,
eos_token_id):
is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(
input_ids == pad_token_id).numpy().item()
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
(eos_token_id is not None) and (pad_token_id != eos_token_id))
if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
attention_mask = (input_ids != pad_token_id).astype("int64")
else:
attention_mask = paddle.ones_like(input_ids, dtype="int64")
return paddle.unsqueeze(attention_mask, axis=[1, 2])

def __getattr__(self, name):
try:
return super().__getattr__(name)
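A standalone re-creation of what the new prepare_attention_mask_for_generation staticmethod produces (not a call into paddlenlp; pad_token_id=0 and the token ids are made-up values): when the prompt contains pad tokens and pad differs from eos, padding positions are zeroed out, and the result is expanded to [batch, 1, 1, seq_len] so forward can combine it with the causal mask.

import paddle

input_ids = paddle.to_tensor([[5, 6, 7, 0, 0]], dtype="int64")
pad_token_id, eos_token_id = 0, 50256

has_pad = paddle.any(input_ids == pad_token_id).numpy().item()
pad_differs_from_eos = (eos_token_id is None) or (pad_token_id != eos_token_id)
if has_pad and pad_differs_from_eos:
    attention_mask = (input_ids != pad_token_id).astype("int64")
else:
    attention_mask = paddle.ones_like(input_ids, dtype="int64")

attention_mask = paddle.unsqueeze(attention_mask, axis=[1, 2])
print(attention_mask.shape)  # [1, 1, 1, 5]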
23 changes: 20 additions & 3 deletions paddlenlp/transformers/gpt/tokenizer.py
@@ -369,6 +369,7 @@ def __init__(
unk_token='<|endoftext|>',
eol_token='\u010a',
add_prefix_space=False,
add_bos_token=False,
**kwargs # The token of newline.
):

@@ -382,9 +383,11 @@ def __init__(
lstrip=False, rstrip=False) if isinstance(
unk_token, str) else unk_token
self.eol_token = eol_token
self._build_special_tokens_map_extended(bos_token=pad_token,
eos_token=eos_token,
unk_token=unk_token)
self._build_special_tokens_map_extended(
bos_token=pad_token
if getattr(self, "bos_token", None) is None else self.bos_token,
eos_token=eos_token,
unk_token=unk_token)

self._vocab_file = vocab_file
self._merges_file = merges_file
@@ -410,6 +413,7 @@ def __init__(
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
self.add_prefix_space = add_prefix_space
self.add_bos_token = add_bos_token

re = try_import("regex")
self.pat = re.compile(
@@ -555,3 +559,16 @@ def prepare_for_tokenization(self,
if is_split_into_words or add_prefix_space:
text = " " + text
return (text, kwargs)

def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
if self.add_bos_token:
bos_token_ids = [self.bos_token_id]
else:
bos_token_ids = []

output = bos_token_ids + token_ids_0

if token_ids_1 is None:
return output

return output + bos_token_ids + token_ids_1
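A standalone sketch of the new build_inputs_with_special_tokens behaviour when add_bos_token is enabled (the token ids and bos id below are made-up values; no vocab or merges files are loaded):

def build_inputs_with_special_tokens(token_ids_0, token_ids_1=None,
                                     add_bos_token=True, bos_token_id=50256):
    # Mirrors the method added above: prepend bos before each segment.
    bos_token_ids = [bos_token_id] if add_bos_token else []
    output = bos_token_ids + token_ids_0
    if token_ids_1 is None:
        return output
    return output + bos_token_ids + token_ids_1

print(build_inputs_with_special_tokens([11, 12, 13]))        # [50256, 11, 12, 13]
print(build_inputs_with_special_tokens([11, 12], [21, 22]))  # [50256, 11, 12, 50256, 21, 22]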
4 changes: 2 additions & 2 deletions paddlenlp/transformers/model_utils.py
@@ -154,7 +154,7 @@ def get_input_embeddings(self):
else:
raise NotImplementedError(
f'model of {type(base_model)} has not implemented the `get_input_embeddings`'
' or `set_input_embedding` method')
' or `set_input_embeddings` method')

def set_input_embeddings(self, value):
base_model = getattr(self, self.base_model_prefix, self)
@@ -163,7 +163,7 @@ def set_input_embeddings(self, value):
else:
raise NotImplementedError(
f'model of {type(base_model)} has not implemented the `get_input_embeddings`'
' or `set_input_embedding` method')
' or `set_input_embeddings` method')

def get_output_embeddings(self):
return None # Overwrite for models with output embeddings
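A minimal sketch of the delegation pattern these corrected error messages belong to (illustrative class names, not paddlenlp code): the task model forwards get_input_embeddings to its base model and only raises NotImplementedError when no base model provides the hook.

class TinyGPT:
    def __init__(self):
        self.word_embeddings = "embedding-table-placeholder"

    def get_input_embeddings(self):
        return self.word_embeddings


class TinyGPTForGeneration:
    base_model_prefix = "gpt"

    def __init__(self):
        self.gpt = TinyGPT()

    def get_input_embeddings(self):
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        raise NotImplementedError(
            "model has not implemented the `get_input_embeddings`"
            " or `set_input_embeddings` method")


print(TinyGPTForGeneration().get_input_embeddings())  # embedding-table-placeholder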