Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve CodeGen #3371

Merged
merged 9 commits into from
Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/code_generation/codegen/run_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def do_train(args):
block_size)
dev_set = process_ds(dev_set, tokenizer, args.overwrite_cache, block_size)

batchify_fn = DataCollatorWithPadding(tokenizer)
batchify_fn = DataCollatorWithPadding(tokenizer, return_attention_mask=True)

train_batch_sampler = DistributedBatchSampler(
train_set, batch_size=args.train_batch_size, shuffle=True)
Expand Down
3 changes: 2 additions & 1 deletion paddlenlp/data/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ class DataCollatorWithPadding:
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pd"
return_attention_mask: Optional[bool] = None

def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
batch = self.tokenizer.pad(
Expand All @@ -200,7 +201,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
return_attention_mask=self.return_attention_mask)
if "label" in batch:
batch["labels"] = batch["label"]
del batch["label"]
Expand Down
13 changes: 13 additions & 0 deletions paddlenlp/transformers/codegen/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
use_cache=False,
cache=None,
):
Expand Down Expand Up @@ -468,7 +469,13 @@ def forward(
else:
past_length = cache[0][0].shape[-2]

batch_size, seq_len = input_shape
# Attention mask.
if batch_size == 1 and past_length != 0:
Copy link
Contributor

@FrostML FrostML Sep 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不对。把 batch size 判断放到 attention mask 为 None 里面,这样即使动转静,因一定会提供 mask,那就不会存在潜在问题

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DONE and thx

attention_mask = paddle.ones(
[batch_size, 1, 1, seq_len + past_length],
dtype=paddle.get_default_dtype())

if attention_mask is None:
assert input_ids is not None, "input_ids should be " \
"specified when generating attention_mask"
Expand All @@ -483,6 +490,10 @@ def forward(
attention_mask.stop_gradient = True

inputs_embeds = self.wte(input_ids)
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
inputs_embeds = inputs_embeds + token_type_embeds

hidden_states = self.drop(inputs_embeds)
output_shape = input_shape[:] + [hidden_states.shape[-1]]

Expand Down Expand Up @@ -579,6 +590,7 @@ def prepare_inputs_for_generation(self, input_ids, cache=None, **kwargs):
def forward(self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
use_cache=False,
cache=None):
r"""
Expand Down Expand Up @@ -613,6 +625,7 @@ def forward(self,

transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
use_cache=use_cache,
cache=cache)

Expand Down