Add GPT ut #3133
Changes from 4 commits
@@ -707,10 +707,17 @@ def __init__(self,
        super(GPTModel, self).__init__()

        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id
        self.bos_token_id = bos_token_id
        self.eol_token_id = eol_token_id
        self.initializer_range = initializer_range
        self.topo = topo
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.bias = paddle.tril(
            paddle.ones(
                [1, 1, max_position_embeddings, max_position_embeddings],
                dtype="int64"))

        self.embeddings = GPTEmbeddings(vocab_size, hidden_size,
                                        hidden_dropout_prob,
@@ -744,6 +751,12 @@ def __init__(self,
        self.apply(self.init_weights)
        self.checkpoints = []

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(self,
                input_ids,
                position_ids=None,
@@ -816,16 +829,22 @@ def forward(self,
                                           position_ids=position_ids)

        # TODO, use registered buffer
        causal_mask = paddle.tensor.triu(paddle.ones(
            (paddle.shape(input_ids)[-1], paddle.shape(input_ids)[-1])) * -1e4,
            diagonal=1)
        length = paddle.shape(input_ids)[-1]
        if cache is not None:
            cache_length = paddle.shape(cache[0].k)[2]
            length = length + cache_length
        else:
            cache_length = 0
        casual_mask = self.bias[:, :, cache_length:length, :length]

        if attention_mask is not None:
            if attention_mask.dtype != paddle.int64:
                attention_mask = paddle.cast(attention_mask, dtype=paddle.int64)
            if len(attention_mask.shape) == 2:
                attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask + causal_mask
            attention_mask = (1.0 - (attention_mask & casual_mask)) * -1e9
        else:
            attention_mask = causal_mask
            attention_mask = (1.0 - casual_mask) * -1e9
Review comment: To stay consistent with the original code, keep using -1e4 here.
Review comment: Please also update the attention_mask support described in the API docs accordingly.
Reply: Done.
        # The tensor returned by triu not in static graph.
        attention_mask.stop_gradient = True
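To make the new masking path concrete: the int64 padding mask is broadcast to 4-D, AND-ed with the sliced causal buffer, and only then converted into an additive mask of large negative values. A minimal sketch with illustrative shapes (not the PR's code):

import paddle

max_pos, seq_len = 8, 4
bias = paddle.tril(
    paddle.ones([1, 1, max_pos, max_pos], dtype="int64"))

# 2-D padding mask for one sequence whose last token is padding.
attention_mask = paddle.to_tensor([[1, 1, 1, 0]], dtype="int64")
attention_mask = attention_mask[:, None, None, :]     # [batch, 1, 1, seq_len]

causal_mask = bias[:, :, 0:seq_len, :seq_len]          # no cache, so offset 0
attention_mask = (1.0 - (attention_mask & causal_mask)) * -1e9
print(attention_mask.shape)  # [1, 1, 4, 4]; -1e9 where attention is blocked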
@@ -1111,7 +1130,6 @@ def forward(self,
            `cache_kvs` is the cache output of gpt model if `use_cache` is True.

        """

        outputs = self.gpt(input_ids,
                           position_ids=position_ids,
                           attention_mask=attention_mask,
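The `cache_kvs` mentioned in the docstring are what drive the `cache_length` offset in the forward pass above: during incremental decoding only the newest token is fed in, and its causal row must still cover all cached keys. A toy illustration of the slice (assuming 3 cached positions and 1 new token):

import paddle

bias = paddle.tril(paddle.ones([1, 1, 8, 8], dtype="int64"))

cache_length = 3                 # keys already stored from previous steps
length = cache_length + 1        # plus the single new token

causal_slice = bias[:, :, cache_length:length, :length]
print(causal_slice.numpy())      # [[[[1 1 1 1]]]], one query row over four keys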
@@ -1166,12 +1184,8 @@ def prepare_inputs_for_generation(self,
        # only last token for inputs_ids if cache is defined in kwargs
        position_ids = kwargs.get("position_ids", None)
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:
            if len(attention_mask.shape) == 4:
                attention_mask = attention_mask[:, -1, -1, :]
            if "int" in paddle.common_ops_import.convert_dtype(
                    attention_mask.dtype):
                attention_mask = (1.0 - attention_mask) * -1e4
        if attention_mask is not None and len(attention_mask.shape) == 4:
            attention_mask = attention_mask[:, -1:, -1:, :]
        if cache is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)
        if position_ids is not None:
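The change from `[:, -1, -1, :]` to `[:, -1:, -1:, :]` is deliberate: slicing with `-1:` keeps the mask 4-D so the model forward can combine it with the causal buffer, while the conversion to large negative floats now happens inside the model rather than here. A small sketch of the shape difference (shapes are illustrative):

import paddle

mask_4d = paddle.ones([2, 1, 5, 5], dtype="int64")   # [batch, 1, query, key]
print(mask_4d[:, -1, -1, :].shape)    # [2, 5]        -- rank is lost
print(mask_4d[:, -1:, -1:, :].shape)  # [2, 1, 1, 5]  -- rank is preserved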
@@ -1184,6 +1198,19 @@ def prepare_inputs_for_generation(self,
            "cache": cache
        }

    @staticmethod
    def prepare_attention_mask_for_generation(input_ids, pad_token_id,
                                              eos_token_id):
        is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(
            input_ids == pad_token_id).numpy().item()
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
            (eos_token_id is not None) and (pad_token_id != eos_token_id))
        if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
            attention_mask = (input_ids != pad_token_id).astype("int64")
        else:
            attention_mask = paddle.ones_like(input_ids, dtype="int64")
        return paddle.unsqueeze(attention_mask, axis=[1, 2])

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
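To see what the new static method produces, here is a toy run of the same logic outside the class (the token ids below are made up for illustration):

import paddle

pad_token_id, eos_token_id = 0, 7
input_ids = paddle.to_tensor([[5, 6, 7, 0, 0],
                              [1, 2, 3, 4, 5]])

# Same decision as the method: mask pads only if pads exist and differ from eos.
if paddle.any(input_ids == pad_token_id).numpy().item() and pad_token_id != eos_token_id:
    attention_mask = (input_ids != pad_token_id).astype("int64")
else:
    attention_mask = paddle.ones_like(input_ids, dtype="int64")

attention_mask = paddle.unsqueeze(attention_mask, axis=[1, 2])
print(attention_mask.shape)         # [2, 1, 1, 5]
print(attention_mask.numpy()[0])    # [[[1 1 1 0 0]]]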
Review comment: Does this have any impact on running in static graph mode?
Reply: It should be fine inside a program guard.
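For reference, a hedged sketch of the reviewer's point: because the tril bias is an ordinary tensor built at layer construction and merely sliced in forward, the same ops can be recorded in a static program when they run inside a program guard (sketch only, not taken from the PR's tests):

import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    bias = paddle.tril(paddle.ones([1, 1, 8, 8], dtype="int64"))
    causal_mask = bias[:, :, 0:4, :4]

exe = paddle.static.Executor()
exe.run(startup_prog)
out, = exe.run(main_prog, fetch_list=[causal_mask])
print(out.shape)  # (1, 1, 4, 4)
paddle.disable_static()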