Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TextGeneration] Samling arguments for generation #1225

Merged
merged 15 commits into from
Sep 15, 2023
6 changes: 4 additions & 2 deletions src/deepsparse/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def __call__(self, *args, **kwargs) -> BaseModel:
f"Inputs parsed to {type(pipeline_inputs)}"
)
# batch size of the inputs may be `> self._batch_size` at this point
engine_inputs: List[numpy.ndarray] = self.process_inputs(pipeline_inputs)
engine_inputs = self.process_inputs(pipeline_inputs)
if isinstance(engine_inputs, tuple):
engine_inputs, context = engine_inputs
else:
Expand Down Expand Up @@ -494,7 +494,9 @@ def split_engine_inputs(
return split_engine_inputs(items, batch_size)

def engine_forward(
self, engine_inputs: List[numpy.ndarray], context: Dict = {}
self,
engine_inputs: List[numpy.ndarray],
context: Dict = {},
) -> List[numpy.ndarray]:
"""
:param engine_inputs: list of numpy inputs to Pipeline engine forward
Expand Down
27 changes: 3 additions & 24 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional

import numpy
from transformers import AutoTokenizer
Expand All @@ -24,7 +24,6 @@
generate_session_id,
overwrite_onnx_model_inputs_for_kv_cache_models,
)
from deepsparse.utils.data import numpy_softmax
from deepsparse.utils.onnx import CACHE_INPUT_PREFIX, CACHE_OUTPUT_PREFIX


Expand Down Expand Up @@ -177,7 +176,7 @@ def __call__(
self,
inp: List[numpy.ndarray],
val_inp: bool = True,
) -> Tuple[numpy.ndarray, numpy.ndarray]:
) -> numpy.ndarray:
"""
The main entry point for running the engine.

Expand All @@ -193,7 +192,6 @@ def __call__(
inp = self.add_kv_cache_to_input(inp)

out = self.run(inp, val_inp)

if self.kv_cache:
logits, *kv_cache_state = out
self.update_kv_cache(
Expand All @@ -202,10 +200,7 @@ def __call__(
else:
logits = out[0]

# select batch idx 0, batch is always 1
token = self.generate_token(logits=logits[0, -1, :])

return token, logits
return logits
horheynm marked this conversation as resolved.
Show resolved Hide resolved

def __str__(self):
return f"{self.__class__.__name__}: {self.engine}"
Expand All @@ -228,22 +223,6 @@ def transfer_cache_state(self, cache: DecoderKVCache):
cache.set_capacity(self.cache_length)
self.kv_cache = cache

def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
"""
Samples a token from the logits using the sampling temperature.

:param logits: the logits from the model with shape (vocab_size,)
:return: the sampled token
"""
if self.deterministic:
return numpy.argmax(logits)

logits /= self.sampling_temperature

probs = numpy_softmax(logits)

return numpy.random.choice(len(probs), p=probs)

def reset_kv_cache(self):
"""
Resets the kv cache state.
Expand Down
83 changes: 65 additions & 18 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
create_causal_mask,
pad_to_fixed_length,
)
from deepsparse.transformers.utils.token_generator import TokenGenerator
from deepsparse.utils.onnx import default_cached_outputs


Expand Down Expand Up @@ -115,6 +116,29 @@ class Config:
" tokens is generated). Set to `None` to ignore this parameter."
" Default is `None`.",
)
top_p: Optional[float] = Field(
horheynm marked this conversation as resolved.
Show resolved Hide resolved
default=0.0,
description="Used for filtering generated tokens. Keep the"
" tokens where its cumulative probability is >= top_p"
" Default set to 0.0",
)
top_k: Optional[int] = Field(
default=0,
description="Used for filtering generated tokens. Keep"
" top_k generated tokens. Default set to 0",
)
presence_penalty: Optional[float] = Field(
default=0.0,
description="Penalty applied for generating new token. Any existing"
" token results in the subtraction of its corresponding logit value."
" Default set to 0.0",
)
frquency_peanlty: Optional[float] = Field(
horheynm marked this conversation as resolved.
Show resolved Hide resolved
default=0.0,
description="Penalty applied for generating new token. Existing"
" token frequencies summed to subtraction the logit of its"
" corresponding logit value. Default set to 0.0.",
)


class TextGenerationOutput(BaseModel):
Expand Down Expand Up @@ -290,7 +314,6 @@ def initialize_engines(
if (
self.cache_support_enabled and self.enable_multitoken_prefill
) or not self.cache_support_enabled:

multitoken_engine = NLDecoderEngine(
onnx_file_path=self.onnx_file_path,
engine_type=self.engine_type,
Expand Down Expand Up @@ -414,7 +437,12 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
include_prompt_logits=inputs.include_prompt_logits,
callback=inputs.callback,
stop=inputs.stop,
top_p=inputs.top_p,
top_k=inputs.top_k,
presence_penalty=inputs.presence_penalty,
frequency_penalty=inputs.presence_penalty,
horheynm marked this conversation as resolved.
Show resolved Hide resolved
)

return engine_input, postprocessing_kwargs

def process_engine_outputs(
Expand All @@ -435,7 +463,9 @@ def process_engine_outputs(
return TextGenerationOutput(sequences=sequences, logits=logits)

def engine_forward(
self, engine_inputs: List[numpy.ndarray], context: Dict
self,
engine_inputs: List[numpy.ndarray],
context: Dict,
) -> Tuple[numpy.ndarray, numpy.ndarray]:
"""
Run the forward pass on the engine.
Expand All @@ -450,20 +480,37 @@ def engine_forward(
# as such, a new context needs to be created since we are no longer in the
# main thread. That is why `engine_` is prepended to each of the timer phase
# names in this context

with self.timer_manager.new_timer_context(total_inference=False) as timer:
streamer = context.get("streamer")

if not self.cache_support_enabled:
tokens, prompt_logits = self.multitoken_engine(engine_inputs)
return numpy.array([tokens]), prompt_logits
prompt_logits = self.multitoken_engine(engine_inputs)
token_generator = TokenGenerator(
logits_shape=prompt_logits[-1].shape[-1],
deterministic=self.deterministic,
**context,
)
for prompt_logit in prompt_logits:
token_generator.generate(prompt_logit)
return numpy.array([self.tokens]), prompt_logits

else:
# run the prompt through
with timer.time(_TextGenerationTimings.PROMPT_PREFILL):
tokens, prompt_logits = self.prompt_inference(engine_inputs)
prompt_logits = self.prompt_inference(engine_inputs)

tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()
horheynm marked this conversation as resolved.
Show resolved Hide resolved
token_generator = TokenGenerator(
logits_shape=prompt_logits[-1].shape[-1],
tokens=tokens,
deterministic=self.deterministic,
**context,
)
token_generator.generate(prompt_logits[-1][0, -1, :])

if streamer is not None:
streamer.put(numpy.array(tokens))
streamer.put(numpy.array(token_generator.tokens))

# create the generated output
max_tokens = (
Expand All @@ -474,7 +521,7 @@ def engine_forward(

# last prompt token is the first generated token
# add it to generated tokens, and the logits
generated_tokens = [tokens[-1]]
generated_tokens = [token_generator.tokens[-1]]
generated_logits = (
prompt_logits
if context.get("include_prompt_logits")
Expand All @@ -486,8 +533,10 @@ def engine_forward(
with timer.time(_TextGenerationTimings.TOKEN_GENERATION):
while len(generated_tokens) < max_tokens:
with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE):
token, logits = self.autoregressive_inference(tokens)
tokens.append(token)
logits = self.autoregressive_inference(
tokens=token_generator.tokens
)
token = token_generator.generate(logits=logits[0, -1, :])
generated_tokens.append(token)
generated_logits.append(logits)

Expand Down Expand Up @@ -522,7 +571,8 @@ def engine_forward(
)

def prompt_inference(
self, engine_inputs: List[numpy.ndarray]
self,
engine_inputs: List[numpy.ndarray],
) -> Tuple[List[int], List[numpy.ndarray]]:
"""
An inference run that processes the prompt through the
Expand All @@ -539,7 +589,6 @@ def prompt_inference(
tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()

prompt_logits = []
new_token = None
num_tokens_processed = 0

if (
Expand All @@ -548,7 +597,7 @@ def prompt_inference(
):
self.multitoken_engine.reset_kv_cache()
for engine_inputs in self.engine_inputs_for_prefill(tokens):
new_token, new_logits = self.multitoken_engine(engine_inputs)
new_logits = self.multitoken_engine(engine_inputs)
num_tokens_processed += self.prompt_processing_sequence_length
bfineran marked this conversation as resolved.
Show resolved Hide resolved
prompt_logits.append(new_logits)

Expand All @@ -565,13 +614,11 @@ def prompt_inference(
with self.timer_manager.current.time(
_TextGenerationTimings.PROMPT_PREFILL_SINGLE
):
new_token, new_logits = self.autoregressive_inference(run_tokens)
new_logits = self.autoregressive_inference(run_tokens)

prompt_logits.append(new_logits)

tokens.append(new_token)

return tokens, prompt_logits
return prompt_logits

def autoregressive_inference(
self,
Expand Down Expand Up @@ -608,9 +655,9 @@ def autoregressive_inference(
engine_inputs_map[name] for name in self.engine.onnx_input_names_no_cache
]

generated_token, generated_logits = self.engine(engine_inputs)
generated_logits = self.engine(engine_inputs)

return generated_token, generated_logits
return generated_logits

def engine_inputs_for_prefill(
self, tokens: List[int]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,6 @@ def _get_tag(self, entity_name: str) -> Tuple[str, str]:
return bi, tag

def _group_entities(self, entities: List[dict]) -> List[dict]:

horheynm marked this conversation as resolved.
Show resolved Hide resolved
entity_groups = []
entity_group_disagg = []

Expand Down
Loading