diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index 2b6b6c9854..1fb09cca0f 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -232,7 +232,7 @@ def __call__(self, *args, **kwargs) -> BaseModel: f"Inputs parsed to {type(pipeline_inputs)}" ) # batch size of the inputs may be `> self._batch_size` at this point - engine_inputs: List[numpy.ndarray] = self.process_inputs(pipeline_inputs) + engine_inputs = self.process_inputs(pipeline_inputs) if isinstance(engine_inputs, tuple): engine_inputs, context = engine_inputs else: @@ -494,7 +494,9 @@ def split_engine_inputs( return split_engine_inputs(items, batch_size) def engine_forward( - self, engine_inputs: List[numpy.ndarray], context: Dict = {} + self, + engine_inputs: List[numpy.ndarray], + context: Dict = {}, ) -> List[numpy.ndarray]: """ :param engine_inputs: list of numpy inputs to Pipeline engine forward diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py index 980f0773c1..223d4f0a60 100644 --- a/src/deepsparse/transformers/engines/nl_decoder_engine.py +++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import numpy from transformers import AutoTokenizer @@ -23,7 +23,6 @@ from deepsparse.transformers.utils.helpers import generate_session_id from deepsparse.transformers.utils.timings import TextGenerationTimings from deepsparse.utils import TimerManager -from deepsparse.utils.data import numpy_softmax from deepsparse.utils.onnx import ( CACHE_INPUT_PREFIX, CACHE_OUTPUT_PREFIX, @@ -184,7 +183,7 @@ def __call__( self, inp: List[numpy.ndarray], val_inp: bool = True, - ) -> Tuple[numpy.ndarray, numpy.ndarray]: + ) -> numpy.ndarray: """ The main entry point for running the engine. @@ -212,10 +211,7 @@ def __call__( else: logits = out[0] - # select batch idx 0, batch is always 1 - token = self.generate_token(logits=logits[0, -1, :]) - - return token, logits + return logits def __str__(self): return f"{self.__class__.__name__}: {self.engine}" @@ -238,22 +234,6 @@ def transfer_cache_state(self, cache: DecoderKVCache): cache.set_capacity(self.cache_length) self.kv_cache = cache - def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray: - """ - Samples a token from the logits using the sampling temperature. - - :param logits: the logits from the model with shape (vocab_size,) - :return: the sampled token - """ - if self.deterministic: - return numpy.argmax(logits) - - logits /= self.sampling_temperature - - probs = numpy_softmax(logits) - - return numpy.random.choice(len(probs), p=probs) - def reset_kv_cache(self): """ Resets the kv cache state. diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 15546ca6fa..f9ccb4b78a 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -43,6 +43,7 @@ repeat_inputs, ) from deepsparse.transformers.utils.timings import TextGenerationTimings +from deepsparse.transformers.utils.token_generator import TokenGenerator from deepsparse.utils.onnx import default_cached_outputs @@ -120,6 +121,29 @@ class Config: " tokens is generated). Set to `None` to ignore this parameter." 
" Default is `None`.", ) + top_p: Optional[float] = Field( + default=0.0, + description="Used for filtering generated tokens. Keep the" + " tokens where its cumulative probability is >= top_p" + " Default set to 0.0", + ) + top_k: Optional[int] = Field( + default=0, + description="Used for filtering generated tokens. Keep" + " top_k generated tokens. Default set to 0", + ) + presence_penalty: Optional[float] = Field( + default=0.0, + description="Penalty applied for generating new token. Any existing" + " token results in the subtraction of its corresponding logit value." + " Default set to 0.0", + ) + frequency_penalty: Optional[float] = Field( + default=0.0, + description="Penalty applied for generating new token. Existing" + " token frequencies summed to subtraction the logit of its" + " corresponding logit value. Default set to 0.0.", + ) class TextGenerationOutput(BaseModel): @@ -440,8 +464,13 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]: include_prompt_logits=inputs.include_prompt_logits, callback=inputs.callback, stop=inputs.stop, + top_p=inputs.top_p, + top_k=inputs.top_k, + presence_penalty=inputs.presence_penalty, + frequency_penalty=inputs.frequency_penalty, max_tokens=inputs.max_tokens, ) + return engine_input, context def process_engine_outputs( @@ -474,7 +503,9 @@ def process_engine_outputs( return TextGenerationOutput(sequences=sequences, logits=logits) def engine_forward( - self, engine_inputs: List[numpy.ndarray], context: Dict + self, + engine_inputs: List[numpy.ndarray], + context: Dict, ) -> Tuple[numpy.ndarray, numpy.ndarray]: """ Run the forward pass on the engine. @@ -489,20 +520,37 @@ def engine_forward( # as such, a new context needs to be created since we are no longer in the # main thread. 
         # names in this context
+
         with self.timer_manager.new_timer_context(total_inference=False) as timer:
             streamer = context.get("streamer")
 
             if not self.cache_support_enabled:
-                tokens, prompt_logits = self.multitoken_engine(engine_inputs)
-                return numpy.array([tokens]), prompt_logits
+                prompt_logits = self.multitoken_engine(engine_inputs)
+                token_generator = TokenGenerator(
+                    logits_shape=prompt_logits[-1].shape[-1],
+                    deterministic=self.deterministic,
+                    **context,
+                )
+                for prompt_logit in prompt_logits:
+                    token_generator.generate(prompt_logit)
+                return numpy.array([token_generator.tokens]), prompt_logits
             else:
                 # run the prompt through
                 with timer.time(TextGenerationTimings.PROMPT_PREFILL):
-                    tokens, prompt_logits = self.prompt_inference(engine_inputs)
+                    prompt_logits = self.prompt_inference(engine_inputs)
+
+                tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()
+                token_generator = TokenGenerator(
+                    logits_shape=prompt_logits[-1].shape[-1],
+                    tokens=tokens,
+                    deterministic=self.deterministic,
+                    **context,
+                )
+                token_generator.generate(prompt_logits[-1][0, -1, :])
 
             if streamer is not None:
-                streamer.put(numpy.array(tokens))
+                streamer.put(numpy.array(token_generator.tokens))
 
             # create the generated output
             max_tokens = context.get("max_tokens", 0)
@@ -510,7 +558,7 @@ def engine_forward(
 
             # last prompt token is the first generated token
             # add it to generated tokens, and the logits
-            generated_tokens = [tokens[-1]]
+            generated_tokens = [token_generator.tokens[-1]]
             generated_logits = (
                 prompt_logits
                 if context.get("include_prompt_logits")
@@ -522,8 +570,10 @@ def engine_forward(
             with timer.time(TextGenerationTimings.TOKEN_GENERATION):
                 while len(generated_tokens) < max_tokens:
                     with timer.time(TextGenerationTimings.TOKEN_GENERATION_SINGLE):
-                        token, logits = self.autoregressive_inference(tokens)
-                        tokens.append(token)
+                        logits = self.autoregressive_inference(
+                            tokens=token_generator.tokens
+                        )
+                        token = token_generator.generate(logits=logits[0, -1, :])
                         generated_tokens.append(token)
                         generated_logits.append(logits)
@@ -558,7 +608,8 @@ def engine_forward(
         )
 
     def prompt_inference(
-        self, engine_inputs: List[numpy.ndarray]
-    ) -> Tuple[List[int], List[numpy.ndarray]]:
+        self,
+        engine_inputs: List[numpy.ndarray],
+    ) -> List[numpy.ndarray]:
         """
         An inference run that processes the prompt through the
@@ -575,13 +626,12 @@ def prompt_inference(
 
         tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()
         prompt_logits = []
-        new_token = None
         num_tokens_processed = 0
 
         if len(tokens) > self.prompt_sequence_length and self.enable_multitoken_prefill:
             self.multitoken_engine.reset_kv_cache()
             for engine_inputs in self.engine_inputs_for_prefill(tokens):
-                new_token, new_logits = self.multitoken_engine(engine_inputs)
+                new_logits = self.multitoken_engine(engine_inputs)
                 num_tokens_processed += self.prompt_sequence_length
                 prompt_logits.append(new_logits)
@@ -599,13 +649,11 @@ def prompt_inference(
             with self.timer_manager.current.time(
                 TextGenerationTimings.PROMPT_PREFILL_SINGLE
             ):
-                new_token, new_logits = self.autoregressive_inference(run_tokens)
+                new_logits = self.autoregressive_inference(run_tokens)
 
                 prompt_logits.append(new_logits)
 
-        tokens.append(new_token)
-
-        return tokens, prompt_logits
+        return prompt_logits
 
     def autoregressive_inference(
         self,
@@ -642,9 +690,9 @@ def autoregressive_inference(
             engine_inputs_map[name] for name in self.engine.onnx_input_names_no_cache
         ]
 
-        generated_token, generated_logits = self.engine(engine_inputs)
+        generated_logits = self.engine(engine_inputs)
 
-        return generated_token, generated_logits
+        return generated_logits
 
     def engine_inputs_for_prefill(
         self, tokens: List[int]
diff --git a/src/deepsparse/transformers/pipelines/token_classification.py b/src/deepsparse/transformers/pipelines/token_classification.py
index c7f32cc301..66957fce97 100644
--- a/src/deepsparse/transformers/pipelines/token_classification.py
+++ b/src/deepsparse/transformers/pipelines/token_classification.py
@@ -522,7 +522,6 @@ def _get_tag(self, entity_name: str) -> Tuple[str, str]:
         return bi, tag
 
     def _group_entities(self, entities: List[dict]) -> List[dict]:
-
         entity_groups = []
         entity_group_disagg = []
 
diff --git a/src/deepsparse/transformers/utils/token_generator.py b/src/deepsparse/transformers/utils/token_generator.py
new file mode 100644
index 0000000000..3fc61740c2
--- /dev/null
+++ b/src/deepsparse/transformers/utils/token_generator.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+import numpy
+
+from deepsparse.utils.data import numpy_softmax
+
+
+class TokenGenerator:
+    """
+    Responsible for generating tokens. Contains the sampling and
+    filtering methods that token generation depends on.
+    """
+
+    def __init__(
+        self,
+        logits_shape: int,
+        tokens: Optional[List[int]] = None,
+        deterministic: bool = True,
+        sampling_temperature: float = 1.0,
+        top_k: int = 0,
+        top_p: float = 0.0,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
+        **kwargs,
+    ):
+        """
+        :param logits_shape: int representing the vocabulary size of the logits.
+            Note that any generated token is bounded above by this value
+        :param tokens: any previously generated tokens, used to seed the frequency
+            counts needed for the penalty calculations
+        :param deterministic: if True, always return the same output given the
+            same inputs
+        :param sampling_temperature: used to add randomness to the generated token
+        :param top_k: keep only the top_k highest logit values
+        :param top_p: keep the tokens whose cumulative probability is within top_p
+        :param frequency_penalty: subtract this value, scaled by each token's
+            frequency count, from that token's logit
+        :param presence_penalty: subtract this value from the logit of every
+            token that has already appeared
+        """
+        self.token_frequencies = numpy.zeros(logits_shape)
+
+        self.deterministic = deterministic
+        self.sampling_temperature = sampling_temperature
+        self.top_k = top_k
+        self.top_p = top_p
+        self.frequency_penalty = frequency_penalty
+        self.presence_penalty = presence_penalty
+        # avoid sharing a mutable default list across instances
+        self.tokens = tokens if tokens is not None else []
+
+        self._initialize_token_frequencies()
+
+    def generate(self, logits: numpy.ndarray) -> numpy.ndarray:
+        """
+        Samples a token from the logits. When non-deterministic, the logits that
+        tokens are generated from are a function of sampling_temperature, top_k,
+        top_p, frequency_penalty, and presence_penalty.
+
+        :param logits: the logits from the model with shape (vocab_size,)
+        :return: the sampled token
+        """
+        if self.top_k:
+            logits = self.apply_top_k(logits)
+        if self.top_p:
+            logits = self.apply_top_p(logits)
+
+        if self.deterministic:
+            token = numpy.argmax(logits)
+            self.tokens.append(token)
+            return token
+
+        if self.sampling_temperature != 1.0:
+            logits /= self.sampling_temperature
+
+        if self.frequency_penalty != 0.0:
+            logits = self.apply_frequency_penalty(logits)
+        if self.presence_penalty != 0.0:
+            logits = self.apply_presence_penalty(logits)
+
+        probs = numpy_softmax(logits)
+        token = numpy.random.choice(len(probs), p=probs)
+
+        self.tokens.append(token)
+        self._update_frequencies(token)
+
+        return token
+
+    def apply_frequency_penalty(self, logits: numpy.ndarray) -> numpy.ndarray:
+        """Apply frequency_penalty based on each token's frequency count"""
+        logits -= self.frequency_penalty * self.token_frequencies
+        return logits
+
+    def apply_presence_penalty(self, logits: numpy.ndarray) -> numpy.ndarray:
+        """
+        Apply presence_penalty to the logit of every token that has
+        already appeared
+        """
+        logits -= self.presence_penalty * (self.token_frequencies > 0)
+        return logits
+
+    # from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+    def apply_top_k(
+        self, logits: numpy.ndarray, filter_value=-float("Inf")
+    ) -> numpy.ndarray:
+        """
+        Keep the top_k logits by value. All other values are
+        overwritten with filter_value
+
+        :param filter_value: value to overwrite non-top_k values
+        """
+        logits_shape = logits.shape
+        logits = logits.reshape(logits.shape[-1])
+        top_k_indices = numpy.argpartition(logits, -self.top_k)[-self.top_k :]
+        logits[~numpy.isin(numpy.arange(len(logits)), top_k_indices)] = filter_value
+
+        return logits.reshape(logits_shape)
+
+    # from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+    def apply_top_p(
+        self,
+        logits: numpy.ndarray,
+        filter_value=-float("Inf"),
+        min_tokens_to_keep: int = 1,
+    ) -> numpy.ndarray:
+        """
+        Keep the logits whose cumulative probability is <= top_p; all other
+        logits are overwritten with filter_value
+
+        :param filter_value: value to overwrite non-top_p values
+        :param min_tokens_to_keep: minimum number of logit values to keep, to
+            avoid filtering out every logit
+        """
+        logits_shape = logits.shape
+        logits = logits.reshape(logits.shape[-1])
+
+        sorted_indices = numpy.argsort(logits)
+        sorted_logits = logits[sorted_indices]
+        logit_cumulative_probs = numpy.cumsum(numpy_softmax(sorted_logits))
+
+        # Remove tokens whose cumulative probability is above the top_p threshold
+        # (tokens with a cumulative probability of 0 are kept)
+        sorted_indices_to_remove = logit_cumulative_probs > self.top_p
+        # Keep at least min_tokens_to_keep
+        sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
+
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        logits[indices_to_remove] = filter_value
+
+        return logits.reshape(logits_shape)
+
+    def _update_frequencies(self, token: numpy.ndarray):
+        self.token_frequencies[token] += 1
+
+    def _initialize_token_frequencies(self):
+        unique_tokens, frequencies = numpy.unique(self.tokens, return_counts=True)
+        for token, frequency in zip(unique_tokens, frequencies):
+            self.token_frequencies[token] += frequency
diff --git a/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py b/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py
index f4c8cc2f97..7d80aa6ada 100644
--- a/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py
+++ b/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py
@@ -17,7 +17,6 @@
 import numpy as np
 
 from deepsparse.transformers.engines import NLDecoderEngine
-from flaky import flaky
 
 
 class DummyKVCacheDecoder:
@@ -32,20 +31,6 @@ class DummyEngine:
     input_names = ["input_1", "input_2", "past_key_values_1", "past_key_values_2"]
 
 
-@flaky(max_runs=10, min_passes=1)
-def test_generate_token():
-    logits = np.array([1.0, 11, 0.9, 0.8])
-    expected_token = 1
-
-    with patch.object(NLDecoderEngine, "__init__", lambda x, y, z: None):
-        engine = NLDecoderEngine(None, None)
-        engine.deterministic = False
-        engine.sampling_temperature = 1.0
-        token = engine.generate_token(logits)
-
-        assert expected_token == token
-
-
 def test_add_kv_cache_to_input():
     # keep only the first two inputs
     # (corresponding to "input_1" and "input_2")
diff --git a/tests/deepsparse/transformers/utils/test_token_generator.py b/tests/deepsparse/transformers/utils/test_token_generator.py
new file mode 100644
index 0000000000..7f5e4b0751
--- /dev/null
+++ b/tests/deepsparse/transformers/utils/test_token_generator.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+from typing import List, Tuple, Union
+
+import numpy
+
+import pytest
+from deepsparse.transformers.utils.token_generator import TokenGenerator
+
+
+@pytest.fixture(scope="function")
+def logits_fixture() -> numpy.array:
+    def get(shape: Tuple = (1, 1, 51200), token_max_thresh: int = 30, low: int = -30):
+        return numpy.random.uniform(low, token_max_thresh, size=shape)
+
+    return get
+
+
+@pytest.fixture(scope="function")
+def token_fixture() -> List[int]:
+    def get(shape: Union[int, Tuple] = 5, token_max_thresh: int = 51200):
+        return numpy.random.randint(0, token_max_thresh, size=shape).tolist()
+
+    return get
+
+
+class TestTokenGenerator:
+    def test_update_frequencies(
+        self, logits_fixture, token_fixture, token_max_thresh: int = 51200
+    ):
+        logits, tokens = logits_fixture(), token_fixture(
+            token_max_thresh=token_max_thresh
+        )
+        token_generator = TokenGenerator(
+            logits_shape=logits[-1].shape[-1], tokens=tokens.copy()
+        )
+
+        assert token_generator.tokens == tokens
+
+        freq = defaultdict(int)
+        for token in token_generator.tokens:
+            freq[token] += 1
+
+        for key, value in freq.items():
+            assert token_generator.token_frequencies[key] == value
+
+        # test TokenGenerator._update_frequencies
+        new_token = token_fixture(shape=1)[0]
+        token_generator.tokens.append(new_token)
+        token_generator._update_frequencies(new_token)
+
+        assert token_generator.tokens == tokens + [new_token]
+        freq[new_token] += 1
+        for key, value in freq.items():
+            assert token_generator.token_frequencies[key] == value
+
+    def test_apply_frequency_penalty(
+        self,
+        logits_fixture,
+        token_fixture,
+    ):
+        logits, tokens = logits_fixture(), token_fixture()
+        frequency_penalty = 1.0
+        token_generator = TokenGenerator(
+            logits_shape=logits[-1].shape[-1],
+            tokens=(tokens + tokens),
+            frequency_penalty=frequency_penalty,
+        )
+
+        test_logits = token_generator.token_frequencies
+        # numpy arrays are passed by reference, so penalize a copy
+        new_logits = token_generator.apply_frequency_penalty(test_logits.copy())
+        assert new_logits.shape == test_logits.shape
+        assert numpy.sum(new_logits) == 0
+
+    def test_apply_presence_penalty(
+        self,
+        logits_fixture,
+        token_fixture,
+    ):
+        logits, tokens = logits_fixture(), token_fixture()
+        presence_penalty = 1.0
+        token_generator = TokenGenerator(
+            logits_shape=logits[-1].shape[-1],
+            tokens=(tokens + tokens),
+            presence_penalty=presence_penalty,
+        )
+        test_logits = token_generator.token_frequencies
+        # numpy arrays are passed by reference, so penalize a copy
+        new_logits = token_generator.apply_presence_penalty(test_logits.copy())
+        assert new_logits.shape == test_logits.shape
+        assert numpy.sum(new_logits) == 0.5 * numpy.sum(test_logits)
+
+    def test_apply_topk(
+        self,
+    ):
+        # logits for opt usually have shape (1, 1, 51200)
+        logits = numpy.linspace(0, 1, 11).reshape((1, 1, 11))
+
+        token_generator = TokenGenerator(
+            logits_shape=logits[-1].shape[-1],
+            top_k=3,
+        )
+
+        filter_value = -float("Inf")
+        new_logits = token_generator.apply_top_k(
+            logits.copy(), filter_value=filter_value
+        )
+
+        for _ in range(token_generator.top_k):
+            curr_max, idx = numpy.max(new_logits), numpy.argmax(new_logits)
+            assert curr_max > filter_value
+            new_logits = numpy.delete(new_logits, idx)
+
+        assert numpy.all(new_logits == filter_value)
+
+    def test_apply_top_p(
+        self,
+    ):
+        # logits for opt usually have shape (1, 1, 51200)
+        logits = 0.1 * numpy.ones(10).reshape((1, 1, 10))
+
+        token_generator = TokenGenerator(
+            logits_shape=logits[-1].shape[-1],
+            top_p=0.89,
+        )
+
+        filter_value = -float("Inf")
+        new_logits = token_generator.apply_top_p(
+            logits.copy(), filter_value=filter_value
+        )
+        # with uniform probabilities of 0.1 and top_p=0.89, exactly one
+        # logit should have been filtered out
+        curr_min, idx = numpy.min(new_logits), numpy.argmin(new_logits)
+        assert curr_min == filter_value
+        new_logits = numpy.delete(new_logits, idx)
+
+        assert numpy.all(new_logits != filter_value)
+
+    def test_generate_token(
+        self,
+        logits_fixture,
+        token_fixture,
+    ):
+        logits, tokens = logits_fixture(), token_fixture()
+        token_generator = TokenGenerator(
+            logits_shape=logits[-1].shape[-1],
+            tokens=(tokens + tokens),
+            deterministic=False,
+        )
+        new_token = token_generator.generate(logits=logits[0, -1, :])
+        assert new_token == token_generator.tokens[-1]
+        assert len(token_generator.tokens) == len(tokens + tokens) + 1
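
Not part of the patch: a minimal usage sketch of the TokenGenerator class added above, showing the deterministic and sampling paths in isolation. It assumes this diff has been applied and `deepsparse` is importable; the vocabulary size, the ramp-shaped logits, and the seed tokens are made up for illustration.

import numpy

from deepsparse.transformers.utils.token_generator import TokenGenerator

vocab_size = 11
# monotonically increasing "logits" over a tiny hypothetical vocabulary,
# shaped like the (batch, sequence, vocab) outputs used in the tests above
logits = numpy.linspace(0.0, 1.0, vocab_size).reshape((1, 1, vocab_size))

# deterministic generation reduces to argmax over the (optionally filtered) logits
greedy = TokenGenerator(logits_shape=vocab_size, deterministic=True)
token = greedy.generate(logits[0, -1, :].copy())
assert token == vocab_size - 1  # argmax of an increasing ramp

# non-deterministic generation: top_k keeps the 3 highest logits, the
# temperature rescales them, and the frequency penalty is seeded from `tokens`.
# generate() and the filtering helpers mutate the logits they are given,
# hence the .copy()
sampler = TokenGenerator(
    logits_shape=vocab_size,
    tokens=[1, 1, 2],  # hypothetical previously generated tokens
    deterministic=False,
    sampling_temperature=0.8,
    top_k=3,
    frequency_penalty=0.5,
)
token = sampler.generate(logits[0, -1, :].copy())
assert token in {vocab_size - 3, vocab_size - 2, vocab_size - 1}

Because top_k masks everything outside the three highest logits to -inf before the softmax, only those three token ids can be sampled, which is what the final assertion checks.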