Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TextGeneration] Add GeneratedText and update TextGenerationOutput #1240

Merged
merged 7 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 100 additions & 28 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import logging
import os
import warnings
from enum import Enum
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -52,6 +54,12 @@
__all__ = ["TextGenerationPipeline"]


class FinishReason(Enum):
STOP = "stop"
LENGTH = "length"
TIME = "time"


class TextGenerationInput(BaseModel):
class Config:
arbitrary_types_allowed = True
Expand Down Expand Up @@ -146,15 +154,34 @@ class Config:
)


class GeneratedText(BaseModel):
text: str = Field(
description="The generated sequence for a given prompt. If "
"streaming is enabled, this will be the next generated token."
)
score: Optional[Any] = Field(
description="The score for the generated token or sequence. "
"The scores have the shape [sequence_length, vocab_size]"
)
finished: bool = Field(description="Whether generation has stopped.")
finished_reason: str = Field(
description="The reason for generation to stop. "
"Defined by FinishReason. One of stop, length, or time."
)
dsikka marked this conversation as resolved.
Show resolved Hide resolved


# TODO: Pydantic aliases allow assignment but not reference. Still need to update.
class TextGenerationOutput(BaseModel):
sequences: Union[str, List[str], List[List[str]]] = Field(
description="The generated text sequences.",
created: datetime.datetime = Field(description="Time of inference creation.")
prompts: Union[str, List[str]] = Field(
description="Prompts used for the sequence generation. For multiple input "
"prompts, a list of prompts is returned"
)
logits: Optional[Any] = Field( # numpy array, set to Any for FastAPI compatibility
default=None,
description="The logits for the generated text sequence."
"The logits have dimensions "
"[batch_size, sequence_length, vocab_size]",
generations: Union[List[GeneratedText], List[List[GeneratedText]]] = Field(
description="For a single prompt, a single list of GeneratedText is returned. "
"If multiple prompts are given, a list of GeneratedText is returned for each "
"prompt provided. If streamng is enabled, the next generated token is returned."
"Otherwise, the full generated sequence is returned."
)
session_id: Optional[str] = Field(
default=None, description="A string identifier for the kv cache session."
Expand Down Expand Up @@ -401,6 +428,7 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
# If the num_generated_predictions > 1, repeat the prompt
# num_generated_predictions times. Also, update the engine so that deterministic
# is set to False.
original_inputs = inputs.sequences
if inputs.num_generated_predictions > 1:
if isinstance(inputs.sequences, str):
inputs.sequences = [inputs.sequences]
Expand Down Expand Up @@ -457,6 +485,7 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
self.multitoken_engine.session_id = inputs.session_id

context = dict(
prompts=original_inputs,
num_generated_predictions=inputs.num_generated_predictions,
return_logits=inputs.return_logits,
streamer=inputs.streamer,
Expand All @@ -473,39 +502,71 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
return engine_input, context

def process_engine_outputs(
self, engine_outputs: List[numpy.ndarray], **kwargs
self, engine_outputs: List[Union[numpy.ndarray, FinishReason]], **kwargs
) -> TextGenerationOutput:
"""
Convert the engine outputs to the output schema for the pipeline.

:param engine_outputs: the outputs from the engine
:return: the output schema for the pipeline
"""
generated_tokens, generated_logits = engine_outputs
generated_tokens, generated_logits, finished_reason = engine_outputs
finished_reason = [f[0] for f in finished_reason]

sequences = self.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True
)
num_preds = kwargs.get("num_generated_predictions", 1)
# If the num_generated_predictions > 1, group the generated sequences and return
# the sequences as a list of lists where each list consists of the generated
prompts = kwargs.get("prompts")

def _create_generated_text_output(
sequence: str,
finish_reason: FinishReason,
logits: Optional[numpy.array] = None,
):
return GeneratedText(
text=sequence,
score=logits,
finished=True,
finished_reason=finish_reason.value,
)

logits = generated_logits if kwargs.get("return_logits") else None

if logits is not None:
generations = list(
self.executor.map(
_create_generated_text_output,
sequences,
finished_reason,
logits,
)
)
else:
generations = list(
self.executor.map(
_create_generated_text_output, sequences, finished_reason
)
)

# If the num_generated_predictions > 1, group the generations and return
# them as a list of lists where each list consists of the generated
# predictions for a given prompt, and all the lists are in the order matching
# the order that the prompts were given as inputs.
if num_preds > 1:
grouped_seq = [
sequences[n : n + num_preds]
for n in range(0, len(sequences), num_preds)
grouped_generations = [
generations[n : n + num_preds]
for n in range(0, len(generations), num_preds)
]
sequences = grouped_seq
generations = grouped_generations

logits = generated_logits if kwargs.get("return_logits") else None

return TextGenerationOutput(sequences=sequences, logits=logits)
return TextGenerationOutput(
created=datetime.datetime.now(), prompts=prompts, generations=generations
)

def engine_forward(
self,
engine_inputs: List[numpy.ndarray],
context: Dict,
) -> Tuple[numpy.ndarray, numpy.ndarray]:
self, engine_inputs: List[numpy.ndarray], context: Dict
) -> Tuple[numpy.ndarray, numpy.ndarray, List[FinishReason]]:
"""
Run the forward pass on the engine.

Expand All @@ -522,6 +583,7 @@ def engine_forward(

with self.timer_manager.new_timer_context(total_inference=False) as timer:
streamer = context.get("streamer")
finished_reason = []

if not self.cache_support_enabled:
prompt_logits = self.multitoken_engine(engine_inputs)
Expand Down Expand Up @@ -583,27 +645,35 @@ def engine_forward(
token == self.tokenizer.eos_token_id
and not self.force_max_tokens
):
finished_reason.append(FinishReason.STOP)
break

if self._stop_token_generated(token, stop_tokens=stop):
_LOGGER.debug(
"Stop token %s generated. Stopping generation."
% self.tokenizer.decode(token)
)
finished_reason.append(FinishReason.STOP)
break

# TODO: Add any generic callback reason?
if callback is not None and callback(token) is False:
_LOGGER.debug(
"callback %s returned False, stopping generation."
% callback.__qualname__
)
break

if len(generated_tokens) == max_tokens:
finished_reason.append(FinishReason.LENGTH)

if streamer is not None:
streamer.end()

return numpy.array([generated_tokens]), numpy.concatenate(
generated_logits, axis=1
return (
numpy.array([generated_tokens]),
numpy.concatenate(generated_logits, axis=1),
finished_reason,
)

def prompt_inference(
Expand Down Expand Up @@ -793,8 +863,10 @@ def is_cache_support_enabled(self) -> bool:
return any(default_cached_outputs(self.onnx_file_path))

def join_engine_outputs(
self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int
) -> List[numpy.ndarray]:
self,
batch_outputs: List[List[Union[numpy.ndarray, FinishReason]]],
orig_batch_size: int,
) -> List[Union[numpy.ndarray, FinishReason]]:
"""
Takes a list of outputs (batches) from the engine
and joins them into a single output. Asserts that
Expand All @@ -805,7 +877,7 @@ def join_engine_outputs(
:param orig_batch_size: The original batch size
:return: A list of joined outputs
"""
tokens, logits = zip(*batch_outputs)
tokens, logits, finish_reason = zip(*batch_outputs)
if self.cache_support_enabled:
# if the model has kv cache, we need to account for
# the fact that the predicted outputs may have
Expand Down Expand Up @@ -837,7 +909,7 @@ def join_engine_outputs(
tokens = numpy.concatenate(tokens, axis=0)
logits = numpy.concatenate(logits, axis=0)

return [tokens, logits]
return [tokens, logits, finish_reason]

@staticmethod
def causal_mask_input_present(model_path: str) -> bool:
Expand Down
38 changes: 24 additions & 14 deletions tests/deepsparse/transformers/pipelines/test_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,12 @@ def test_run_same_prompt_multiple_times(self, setup):
include_prompt_logits=True,
max_tokens=self.num_tokens_generate,
)
assert output_1.sequences[0] == output_2.sequences[0]
assert numpy.allclose(output_1.logits, output_2.logits, atol=_PRECISION)
assert output_1.generations[0].text == output_2.generations[0].text
assert numpy.allclose(
output_1.generations[0].score,
output_2.generations[0].score,
atol=_PRECISION,
)

def test_run_multiple_prompts_in_parallel(self, setup):
# Test the scenario, where multiple prompts are run in parallel
Expand All @@ -393,9 +397,14 @@ def test_run_multiple_prompts_in_parallel(self, setup):
include_prompt_logits=True,
max_tokens=self.num_tokens_generate,
)
logits_0 = output.generations[0].score
sequence_0 = output.generations[0].text

logits_1 = output.generations[1].score
sequence_1 = output.generations[1].text

assert numpy.allclose(output.logits[0], output.logits[1], atol=_PRECISION)
assert output.sequences[0] == output.sequences[1]
assert numpy.allclose(logits_0, logits_1, atol=_PRECISION)
assert sequence_0 == sequence_1

def test_num_generated_predictions(self, setup):
# Test the scenario, where multiple predictions are generated
Expand All @@ -405,14 +414,16 @@ def test_num_generated_predictions(self, setup):
output_sequences = pipeline(
sequences=[self.prompt], num_generated_predictions=2
)
assert len(output_sequences.sequences[0]) == 2
assert len(output_sequences.generations) == 1
assert len(output_sequences.generations[0]) == 2

output_sequences = pipeline(
sequences=[self.prompt, self.prompt], num_generated_predictions=2
)
assert len(output_sequences.sequences) == 2
for sequences in output_sequences.sequences:
assert len(sequences) == 2
assert len(output_sequences.generations) == 2

for generation in output_sequences.generations:
assert len(generation) == 2

def _test_output(
self,
Expand All @@ -434,6 +445,7 @@ def _test_output(

# concatenate target prompt_logits and generated_logits and check
target_logits = numpy.concatenate([prompt_logits, generated_logits], axis=1)
score = output.generations[0].score

if max_logits_difference_threshold:
# if comparing the output from the model where
Expand All @@ -442,18 +454,16 @@ def _test_output(
# to be less than the threshold
# (the threshold is established by running the
# ONNX model in ONNXRuntime)
assert (
abs(output.logits - target_logits).max()
< max_logits_difference_threshold
)
assert abs(score - target_logits[0]).max() < max_logits_difference_threshold
else:
# otherwise, we expect the logits to be exactly the same
# as the target logits; the generated sequence should
# also be the same as the target sequence, and finally
# (if applicable) the kv cache should be the same as the
# target kv cache
assert numpy.allclose(output.logits, target_logits, atol=_PRECISION)
assert self.prompt + output.sequences[0] == generated_text

assert numpy.allclose(score, target_logits[0], atol=_PRECISION)
assert self.prompt + output.generations[0].text == generated_text

if run_cache_validation:
self._test_kv_cache_state(
Expand Down