Merge branch 'main' into feature/damian/simplify_serialization
dbogunowicz authored Nov 30, 2023
2 parents f405947 + ce60541 commit db74b56
Showing 3 changed files with 86 additions and 7 deletions.
8 changes: 4 additions & 4 deletions src/deepsparse/transformers/README.md
@@ -163,7 +163,7 @@ Spinning up:
```bash
deepsparse.server \
task text-generation \
--model_path zoo:llama2-7b-open_platypus_orca_llama2_pretrain-pruned50_quantized
--model_path zoo:opt-1.3b-opt_pretrain-pruned50_quantW8A8
```

Making a request:
@@ -172,12 +172,12 @@ import requests

url = "http://localhost:5543/v2/models/text_generation/infer" # Server's port default to 5543

obj = {"prompt": "Who is the president of the United States?"}
obj = {"prompt": "Large language models are"}

response = requests.post(url, json=obj)
print(response.json().text)
print(response.json()["generations"][0]["text"])

>> 'The president of the United States is the head of the executive branch of government...'
>> ' often used to model the language of a large number of users...'
```
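(Not part of this diff.) For reference, an equivalent request can be sent from the shell; a minimal sketch, assuming the server is running locally on the default port:

```bash
curl -X POST http://localhost:5543/v2/models/text_generation/infer \
  -H "Content-Type: application/json" \
  -d '{"prompt": "Large language models are"}'
```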

### Sentiment Analysis
59 changes: 56 additions & 3 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import datetime
import logging
import os
@@ -580,10 +581,24 @@ def _stream_engine_outputs(
self, engine_outputs, prompts, generation_config, **kwargs
):
for output in engine_outputs:
generated_tokens, generated_logits, finished_reason = output
(
generated_tokens,
generated_logits,
finished_reason,
past_tokens_queue,
) = output
logits = generated_logits if generation_config.output_scores else None
from transformers import LlamaTokenizer, LlamaTokenizerFast

if isinstance(self.tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
# temporary fix for Llama2/Mistral/... models
generated_string = self._generate_streamed_text_from_past_tokens(
generated_tokens, past_tokens_queue
)
else:
generated_string = self.tokenizer.batch_decode(generated_tokens)[0]
generation = self._create_generated_text_output(
self.tokenizer.batch_decode(generated_tokens)[0],
generated_string,
finished_reason[0],
logits,
)
@@ -601,6 +616,33 @@ def _stream_engine_outputs(
**schema_kwargs,
)

def _generate_streamed_text_from_past_tokens(
self, generated_tokens: numpy.ndarray, past_tokens_queue: List[int]
) -> str:
"""
An auxiliary method that helps to properly generate the streamed text.
Some models, such as Llama 2 and Mistral, use LlamaTokenizer, which is
based on the SentencePiece tokenizer. This tokenizer does not emit the
appropriate prefix spaces when decoding token by token. Decoding works
correctly if the previously generated tokens are included, which lets
the tokenizer infer the appropriate spaces from the last n consecutive
tokens.
:param generated_tokens: the generated tokens from the engine
:param past_tokens_queue: the queue of last n tokens (n is the
original prompt length in tokens)
:return: the generated string
"""
string_from_n_tokens = self.tokenizer.decode(
past_tokens_queue, skip_special_tokens=True
)
past_tokens_queue.append(generated_tokens[0])
string_from_n_plus_1_tokens = self.tokenizer.decode(
past_tokens_queue, skip_special_tokens=True
)
past_tokens_queue.pop(0)
return string_from_n_plus_1_tokens[len(string_from_n_tokens) :]

def process_engine_outputs(
self, engine_outputs: List[Union[numpy.ndarray, FinishReason]], **kwargs
) -> TextGenerationOutput:
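To illustrate the trick used by `_generate_streamed_text_from_past_tokens` above, here is a minimal standalone sketch (not part of this diff) of the same prefix-diff decoding, assuming a generic SentencePiece-based tokenizer loaded through Hugging Face's `AutoTokenizer`; the checkpoint path and the prompt are placeholders:

```python
from transformers import AutoTokenizer

# Placeholder path; any Llama-style (SentencePiece-based) tokenizer shows the
# missing-prefix-space behavior when single tokens are decoded in isolation.
tokenizer = AutoTokenizer.from_pretrained("path/to/llama-style-tokenizer")

# Seed the sliding window with the prompt tokens, mirroring
# `past_tokens_queue = copy.copy(tokens)` in `engine_forward`.
past_tokens = tokenizer("Large language models are", add_special_tokens=False).input_ids


def decode_incrementally(new_token_id: int) -> str:
    """Return only the text contributed by `new_token_id`, leading space included."""
    # Decode the window without and then with the new token; the suffix that
    # appears is exactly the newly generated text.
    before = tokenizer.decode(past_tokens, skip_special_tokens=True)
    past_tokens.append(new_token_id)
    after = tokenizer.decode(past_tokens, skip_special_tokens=True)
    past_tokens.pop(0)  # keep the window length equal to the prompt length
    return after[len(before):]
```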
@@ -738,6 +780,9 @@ def engine_forward(
prompt_logits, session = self.prompt_inference(engine_inputs)

tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()
# copy the tokens so that we can use them for streaming
past_tokens_queue = copy.copy(tokens)

token_generator = TokenGenerator(
logits_shape=prompt_logits[-1].shape[-1],
tokens=tokens,
@@ -776,6 +821,7 @@ def engine_forward(
numpy.array([generated_tokens[-1]]),
numpy.array([generated_logits[-1]]),
[None],
past_tokens_queue,
)

while len(generated_tokens) < max_tokens:
@@ -816,7 +862,12 @@ def engine_forward(
break

if streaming:
yield (numpy.array([token]), numpy.array([logits]), [None])
yield (
numpy.array([token]),
numpy.array([logits]),
[None],
past_tokens_queue,
)

# Run the autoregressive inference only to put the
# kv cache entry for the last generated token into the
@@ -831,12 +882,14 @@ def engine_forward(
numpy.array([generated_tokens]),
numpy.concatenate(generated_logits, axis=1),
[FinishReason.LENGTH],
past_tokens_queue,
)
else:
yield (
numpy.array([token]),
numpy.array([logits]),
[finished_reason[-1]],
past_tokens_queue,
)

if not streaming:
26 changes: 26 additions & 0 deletions tests/deepsparse/transformers/pipelines/test_text_generation.py
@@ -130,3 +130,29 @@ def test_streaming_mode_returns_generator(pipeline, prompt):
isinstance(response, pipeline.output_schema) for response in response_generator
), "Pipeline should return a generator of output_schema \
objects in streaming mode"


def test_streaming_with_several_prompts(pipeline, prompt):
additional_prompt = "Never gonna run around and desert you"
prompts = [prompt, additional_prompt]

generations_first_prompt_only = list(pipeline(prompt=prompts[0], streaming=True))
generations_second_prompt_only = list(pipeline(prompt=prompts[1], streaming=True))

bag_of_words_first_prompt = [
g.generations[0].text for g in generations_first_prompt_only
]
bag_of_words_second_prompt = [
g.generations[0].text for g in generations_second_prompt_only
]

generations = pipeline(prompt=prompts, streaming=True)
bag_of_words_shared = []
for r in generations:
for gen in r.generations:
text = gen.text
bag_of_words_shared.append(text)

assert sorted(bag_of_words_first_prompt + bag_of_words_second_prompt) == sorted(
bag_of_words_shared
)
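For context, a rough usage sketch (not part of this change) of the streaming mode these tests exercise; `Pipeline.create` and the model stub are assumptions carried over from the README rather than anything defined in this diff:

```python
from deepsparse import Pipeline

# Model stub borrowed from the README example above; any supported
# text-generation model should behave the same way.
pipeline = Pipeline.create(
    task="text-generation",
    model_path="zoo:opt-1.3b-opt_pretrain-pruned50_quantW8A8",
)

# In streaming mode the pipeline returns a generator of output-schema
# objects; each one carries the text decoded since the previous chunk.
for chunk in pipeline(prompt="Large language models are", streaming=True):
    print(chunk.generations[0].text, end="", flush=True)
print()
```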
