Integrate FT into generation api (PaddlePaddle#1154)
* Add the first draft for integrating FT into generation api.

* Add try-catch around FT usage in the generation api.

* Refine FasterTransformer integration into generation api.

* Update some checks in FT integration.

Co-authored-by: smallv0221 <[email protected]>
guoshengCS and smallv0221 authored Oct 27, 2021
1 parent 27e0c34 commit 20acd16
Showing 3 changed files with 104 additions and 12 deletions.
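
Before the per-file diffs, a minimal usage sketch of what this commit enables may help: the regular `generate` api is called as usual, and it switches to FasterTransformer internally when the checks added below pass. The model name, decoding settings, and the `mem_seq_lens` handling here are illustrative assumptions, not taken from the commit.

import paddle
from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained('bart-base')
model = BartForConditionalGeneration.from_pretrained('bart-base')
model.eval()

input_ids = paddle.to_tensor([tokenizer('Hello, world!')['input_ids']])
# `mem_seq_lens` is consumed by FasterBART; if it is absent,
# `prepare_faster_entry` below returns False and the original path runs.
output_ids, scores = model.generate(
    input_ids=input_ids,
    decode_strategy='sampling',
    top_k=4,
    max_length=20,
    mem_seq_lens=paddle.to_tensor([input_ids.shape[1]], dtype='int32'))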
11 changes: 6 additions & 5 deletions paddlenlp/ops/faster_transformer/sample/bart_decoding_sample.py
@@ -138,11 +138,12 @@ def do_predict(args):
input_ids, mem_seq_lens = prepare_input(tokenizer, sentences, pad_id)

# Define model
    faster_bart = model
    # faster_bart = FasterBART(
    #     model=model,
    #     decoding_strategy=args.decoding_strategy,
    #     decoding_lib=args.decoding_lib,
    #     use_fp16_decoding=args.use_fp16_decoding)

# Set evaluate mode
faster_bart.eval()
28 changes: 28 additions & 0 deletions paddlenlp/transformers/bart/modeling.py
@@ -759,6 +759,34 @@ def get_encoder(self):
def get_decoder(self):
return self.bart.get_decoder()

def prepare_faster_entry(self, kwargs):
    from functools import partial

    from paddlenlp.ops import FasterBART
    decoding_strategy = kwargs.get('decode_strategy')
    model_kwargs = kwargs['model_kwargs']
    use_fp16_decoding = model_kwargs.get('use_fp16_decoding', False)
    # TODO(guosheng): Currently, beam_search_v2 in FasterTransformer uses
    # t2t beam search, which differs from the beam search in the generation
    # api in how finished hypotheses are collected and in the early-stop
    # criterion, and it seems to lead to poor performance on the BART
    # CNN-summarization model; thus we disable it temporarily.
    if decoding_strategy == 'beam_search':
        return False
    # Some checks on kwargs. For example, FasterBART needs `mem_seq_lens`
    # as an input while BART does not, thus check whether `mem_seq_lens`
    # is in kwargs.
    if model_kwargs.get('mem_seq_lens', None) is None:
        return False
    # Assume the args do not change across runs, so that parameters are
    # converted only once. Additionally, bind some converted args as
    # default values of the partial rather than fixing them, to allow
    # overriding at call time.
    # TODO(guosheng): maybe use weakref for the model held by the faster model
    self._faster_entry = partial(
        FasterBART(
            self,
            decoding_strategy=decoding_strategy,
            use_fp16_decoding=use_fp16_decoding).generate,
        alpha=kwargs.get('length_penalty'),
        rel_len=False)
    return self._faster_entry
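
As a note on the binding pattern just above: `functools.partial` makes the converted values defaults rather than fixed arguments, so a later call can still override them. A tiny standalone sketch with hypothetical names:

from functools import partial

def fake_generate(alpha=0.6, rel_len=False):
    return alpha, rel_len

entry = partial(fake_generate, alpha=1.0, rel_len=False)
print(entry())           # (1.0, False): the bound defaults apply
print(entry(alpha=0.5))  # (0.5, False): a caller override still wins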

def forward(self,
input_ids,
attention_mask=None,
77 changes: 70 additions & 7 deletions paddlenlp/transformers/generation_utils.py
@@ -21,6 +21,7 @@
import paddle.nn.functional as F
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.layers.utils import map_structure
from paddlenlp.utils.log import logger

__all__ = ["GenerationMixin"]

@@ -170,7 +171,7 @@ def process(self,
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (
next_token.numpy().item() == eos_token_id):
# If beam_token does not belong to top num_beams tokens,
# it should not be added
is_beam_token_worse_than_top_num_beams = (
beam_token_rank >= self.group_size)
@@ -357,10 +358,10 @@ def expand_inputs_for_generation(input_ids,
def update_model_kwargs_for_generation(outputs,
model_kwargs,
is_encoder_decoder=False):
# Update the model inputs during generation.
# Note that if `token_type_ids` and `attention_mask` are in `model_kwargs`
# and they contain pad values, the result vectors updated by this method
# may differ from what is expected. In this case, you need to rewrite the
# method.

# update cache
@@ -433,11 +434,37 @@ def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {"input_ids": input_ids}

def adjust_logits_during_generation(self, logits):
# Implement in subclasses for custom behavior to adjust the logits in
# the generate method.

return logits

def prepare_faster_entry(self, kwargs):
    # Implemented in subclasses to provide a model-specific faster entry.
    pass

def _convert_to_faster(self, kwargs):
    # Try a general conversion; not implemented yet.
    pass

def _build_faster(self, kwargs):
    self._faster_entry = False

    # Common checks for FasterTransformer.
    if kwargs['min_length'] != 0:
        # `min_length` is not supported yet in the faster version.
        return
    if kwargs['repetition_penalty'] != 1:
        # `repetition_penalty` is not supported yet in the faster version;
        # 1.0 means no penalty, so only the default can take the faster path.
        return
    if kwargs['temperature'] != 1:
        # `temperature` is not supported yet in the faster version.
        return

    # 1. custom convert
    if not self.prepare_faster_entry(kwargs):
        # 2. try general convert
        self._convert_to_faster(kwargs)
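
For readers implementing this hook in another model, the contract that `_build_faster` relies on can be sketched as follows; all names are hypothetical and the body is only a placeholder, not part of this commit:

class ExampleFasterMixin:
    def prepare_faster_entry(self, kwargs):
        # Return False when the faster path cannot serve this request;
        # `_build_faster` then falls through to the general conversion.
        if kwargs.get('decode_strategy') == 'beam_search':
            return False
        # Otherwise cache and return a callable; `generate` will call it
        # with the remaining args and model_kwargs.
        self._faster_entry = lambda **call_kwargs: 'output_ids_placeholder'
        return self._faster_entry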

@paddle.no_grad()
def generate(self,
input_ids=None,
@@ -610,6 +637,42 @@ def generate(self,
print(response)
# ['是的', '嗯嗯']
"""
# Switch to FasterTransformer automatically when supported.
if getattr(self, '_faster_entry', None) is not False:
    # TODO(guosheng): need a better way to avoid recursive building
    if not self.__class__.__module__.endswith('faster_transformer'):
        args = locals()
        args.pop('self')
        args.pop('__class__', None)
        try:
            if not hasattr(self, '_faster_entry'):
                self._build_faster(args)
            if self._faster_entry:
                model_kwargs = args.pop('model_kwargs')
                output_ids = self._faster_entry(**args, **model_kwargs)
                # Transpose to batch-major to be consistent with the
                # results of the original version.
                if len(output_ids.shape) == 2:  # sampling
                    output_ids = paddle.transpose(output_ids, [1, 0])
                else:  # beam search
                    output_ids = paddle.transpose(output_ids, [1, 2, 0])
                    output_ids = output_ids[:, :num_return_sequences].reshape(
                        [-1, output_ids.shape[-1]])
                # Append dummy scores to be consistent with the results of
                # the original version.
                scores = None
                return output_ids, scores
            else:
                # TODO(guosheng): Maybe we can report the unsupported
                # reasons to help users enable FasterTransformer when it
                # is not supported.
                pass
        except Exception:
            logger.warning(
                "FasterTransformer is not available; "
                "the original version will be used instead.")

# params check
bos_token_id = bos_token_id if bos_token_id is not None else getattr(
@@ -778,7 +841,7 @@ def TopPProcess(probs, top_p, min_tokens_to_keep):
sorted_indices = paddle.argsort(probs, descending=True)
cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)

# Remove tokens with cumulative probs above the top_p, but keep at
# least min_tokens_to_keep tokens
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
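A toy example of the nucleus filtering here, with illustrative values: for sorted probs [0.5, 0.3, 0.15, 0.05] and top_p = 0.8, the cumulative sums are [0.5, 0.8, 0.95, 1.0], so the mask keeps the first two tokens and removes the rest, unless min_tokens_to_keep demands more.

import paddle

sorted_probs = paddle.to_tensor([[0.5, 0.3, 0.15, 0.05]])
cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)
sorted_indices_to_remove = cumulative_probs > 0.8  # top_p = 0.8
print(sorted_indices_to_remove.numpy())  # [[False False  True  True]]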
