[Feature Branch][LLM Testing] LLM Testing Suite (#1227)
* Update README.md

* Update src/deepsparse/yolov8/README.md

* Update text_generation.py

* quality

* readability

* all tests passing

* added some full kv cache tests

* initial commit

* ready for review

* Delete tests/deepsparse/transformers/pipelines/proposal_text_generation_tests.md
dbogunowicz authored Sep 7, 2023
1 parent d04d6ee commit 8985a9b
Showing 3 changed files with 385 additions and 88 deletions.
17 changes: 15 additions & 2 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -291,6 +291,18 @@ def initialize_engines(
             self.cache_support_enabled and self.enable_multitoken_prefill
         ) or not self.cache_support_enabled:

+            # input_ids_length for the multitoken engine is either:
+            # - the prompt_processing_sequence_length if cache support is enabled
+            #   (the prompt is processed sequentially at the predefined processing length)
+            # - the full sequence_length if cache support is disabled
+            #   (the prompt is processed in a single pass, so the prompt length is fixed
+            #   at sequence_length)
+            input_ids_length = (
+                self.prompt_processing_sequence_length
+                if self.cache_support_enabled
+                else self.sequence_length
+            )
+
             multitoken_engine = NLDecoderEngine(
                 onnx_file_path=self.onnx_file_path,
                 engine_type=self.engine_type,
@@ -299,7 +311,7 @@ def initialize_engines(
                 sampling_temperature=self.sampling_temperature,
                 deterministic=self.deterministic,
                 sequence_length=self.sequence_length,
-                input_ids_length=self.prompt_processing_sequence_length,
+                input_ids_length=input_ids_length,
                 tokenizer=self.tokenizer,
                 use_deepsparse_cache=self.use_deepsparse_cache,
             )
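
For illustration, a minimal standalone sketch of the selection made in the hunk above. It only mirrors the conditional; the helper name and the concrete values 16 and 512 are assumptions chosen for the example, not values taken from the pipeline.

def select_input_ids_length(
    cache_support_enabled: bool,
    prompt_processing_sequence_length: int,
    sequence_length: int,
) -> int:
    # With KV cache support the prompt is chunked and prefilled at the shorter
    # processing length; without it the whole prompt must fit in a single
    # sequence_length-wide forward pass.
    return (
        prompt_processing_sequence_length
        if cache_support_enabled
        else sequence_length
    )

assert select_input_ids_length(True, 16, 512) == 16    # cache on: chunked prefill
assert select_input_ids_length(False, 16, 512) == 512  # cache off: single full pass
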
@@ -552,10 +564,11 @@ def prompt_inference(
                 num_tokens_processed += self.prompt_processing_sequence_length
                 prompt_logits.append(new_logits)

-        self.engine.reset_kv_cache()
         if num_tokens_processed:
             # transfer the cache state from the multi-token engine to the main engine
             self.engine.transfer_cache_state(cache=self.multitoken_engine.kv_cache)
+        else:
+            self.engine.reset_kv_cache()

         # prompt size is small, run autoregressive inference to populate kv cache
         run_tokens = [] if num_tokens_processed == 0 else tokens[:num_tokens_processed]
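
The corrected handoff above only resets the KV cache when no prompt tokens went through the multi-token engine. A minimal sketch of how that behavior could be unit-tested with mocked engines follows; the helper finalize_prompt_cache and both test names are hypothetical and are not the tests added in this PR.

from unittest.mock import MagicMock


def finalize_prompt_cache(engine, multitoken_engine, num_tokens_processed):
    # Mirrors the decision in prompt_inference: reuse the multi-token engine's
    # KV cache when any prompt tokens were processed, otherwise start fresh.
    if num_tokens_processed:
        engine.transfer_cache_state(cache=multitoken_engine.kv_cache)
    else:
        engine.reset_kv_cache()


def test_cache_transferred_when_prompt_was_processed():
    engine, multitoken_engine = MagicMock(), MagicMock()
    finalize_prompt_cache(engine, multitoken_engine, num_tokens_processed=32)
    engine.transfer_cache_state.assert_called_once_with(
        cache=multitoken_engine.kv_cache
    )
    engine.reset_kv_cache.assert_not_called()


def test_cache_reset_for_short_prompts():
    engine, multitoken_engine = MagicMock(), MagicMock()
    finalize_prompt_cache(engine, multitoken_engine, num_tokens_processed=0)
    engine.reset_kv_cache.assert_called_once()
    engine.transfer_cache_state.assert_not_called()
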

tests/deepsparse/transformers/pipelines/proposal_text_generation_tests.md
This file was deleted.

