From b7133a069a8de4d40926c438c64647c4ceb5425a Mon Sep 17 00:00:00 2001 From: Damian Date: Mon, 28 Aug 2023 08:54:18 +0000 Subject: [PATCH 1/9] initial commit --- src/deepsparse/transformers/metrics.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/deepsparse/transformers/metrics.py b/src/deepsparse/transformers/metrics.py index f93756c55a..e3ed19e8df 100644 --- a/src/deepsparse/transformers/metrics.py +++ b/src/deepsparse/transformers/metrics.py @@ -118,7 +118,10 @@ def add_batch(self, predictions: List[str]): labels = encoded_batch out = self._pipeline( - sequences=predictions, return_logits=True, fixed_sequences_length=True + sequences=predictions, + return_logits=True, + fixed_sequences_length=True, + include_prompt_logits=True, ) logits = out.logits From a47f97722a55b88663e70986260a81693deac9bf Mon Sep 17 00:00:00 2001 From: Damian Date: Tue, 29 Aug 2023 14:04:21 +0000 Subject: [PATCH 2/9] initial commit --- .../proposal_text_generation_tests.md | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/deepsparse/transformers/pipelines/proposal_text_generation_tests.md diff --git a/tests/deepsparse/transformers/pipelines/proposal_text_generation_tests.md b/tests/deepsparse/transformers/pipelines/proposal_text_generation_tests.md new file mode 100644 index 0000000000..ad015eb54c --- /dev/null +++ b/tests/deepsparse/transformers/pipelines/proposal_text_generation_tests.md @@ -0,0 +1,48 @@ +# test perplexity script +# test what's going on after we go past sequence_length +# sequences need to be appropriately long to notice the divergence over time + + +# metric to evaluate logits and cache -> min absolute difference +# maybe if env var set, we can also plot graphs + +OUR_MODELS = ["opt", "codegen", "llama"] +for model in OUR_MODELS: + + # establish sources of truth + torch_target_logits = ... + ort_target_logits = ... + torch_target_cache = ... + ort_target_cache = ... + + no_kv_cache_logits = model(kv_cache=False) + + for engine_type in ["onnxruntime", "deepsparse"] + ort_no_kv_cache_logits = ... + deepsparse_no_kv_cache_logits = ... + + if kv_cache: + for kv_cache_management in ["external", "internal"] + + ort_single_token_prefill_logits = ... + ort_multi_token_prefill_logits = ... + ort_single_token_prefill_cache = ... + ort_multi_token_prefill_cache = ... + + ds_single_token_prefill_logits_external = ... + ds_multi_token_prefill_logits_external = ... + ds_single_token_prefill_cache_external = ... + ds_multi_token_prefill_cache_external = ... + + ds_single_token_prefill_logits_internal = ... + ds_multi_token_prefill_logits_internal = ... 
+ + + + + + + + + + From 04775517db3d6940256db234204e494690de9c6a Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Thu, 7 Sep 2023 11:07:44 +0200 Subject: [PATCH 3/9] [Feature Branch][LLM Testing] Create GroundTruthSource objects (#1219) * initial commit * finish creation of helper objects * Update tests/conftest.py * small refactor * [Feature Branch][LLM Testing] LLM Testing Suite (#1227) * Update README.md * Update src/deepsparse/yolov8/README.md * Update text_generation.py * quality * readability * all tests passing * added some full kv cache tests * initial commit * ready for review * Delete tests/deepsparse/transformers/pipelines/proposal_text_generation_tests.md --- .../transformers/pipelines/text_generation.py | 17 +- .../transformers/pipelines/helpers.py | 210 +++++++ .../proposal_text_generation_tests.md | 48 -- .../pipelines/test_text_generation.py | 556 +++++++++++------- 4 files changed, 570 insertions(+), 261 deletions(-) create mode 100644 tests/deepsparse/transformers/pipelines/helpers.py delete mode 100644 tests/deepsparse/transformers/pipelines/proposal_text_generation_tests.md diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 92dfe04338..c0f917446d 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -291,6 +291,18 @@ def initialize_engines( self.cache_support_enabled and self.enable_multitoken_prefill ) or not self.cache_support_enabled: + # input_ids_length for the multitoken engine is either: + # - the prompt_sequence_length if the cache support is enabled + # (the prompt is processed sequentially at predefined processing length) + # - the full sequence_length if the cache support is disabled + # (the prompt is processed in a single pass, prompts length is fixed at + # sequence_length) + input_ids_length = ( + self.prompt_sequence_length + if self.cache_support_enabled + else self.sequence_length + ) + multitoken_engine = NLDecoderEngine( onnx_file_path=self.onnx_file_path, engine_type=self.engine_type, @@ -299,7 +311,7 @@ def initialize_engines( sampling_temperature=self.sampling_temperature, deterministic=self.deterministic, sequence_length=self.sequence_length, - input_ids_length=self.prompt_sequence_length, + input_ids_length=input_ids_length, tokenizer=self.tokenizer, internal_kv_cache=self.internal_kv_cache, ) @@ -549,10 +561,11 @@ def prompt_inference( num_tokens_processed += self.prompt_sequence_length prompt_logits.append(new_logits) - self.engine.reset_kv_cache() if num_tokens_processed: # transfer the cache state from the multi-token engine to the main engine self.engine.transfer_cache_state(cache=self.multitoken_engine.kv_cache) + else: + self.engine.reset_kv_cache() # prompt size is small, run autoregressive inference to populate kv cache run_tokens = [] if num_tokens_processed == 0 else tokens[:num_tokens_processed] diff --git a/tests/deepsparse/transformers/pipelines/helpers.py b/tests/deepsparse/transformers/pipelines/helpers.py new file mode 100644 index 0000000000..1edf7b5558 --- /dev/null +++ b/tests/deepsparse/transformers/pipelines/helpers.py @@ -0,0 +1,210 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Tuple + +import numpy +import onnx +import onnxruntime +from transformers import AutoModelForCausalLM, AutoTokenizer + +from deepsparse.transformers.utils.helpers import ( + create_causal_mask, + overwrite_onnx_model_inputs_for_kv_cache_models, +) +from deepsparse.utils.onnx import CACHE_INPUT_PREFIX +from sparsezoo import Model + + +class GroundTruthSource(ABC): + def __init__(self, model_name: str): + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + self.tokenizer = tokenizer + + @abstractmethod + def tokenize(self, prompt: str) -> Dict[str, Any]: + """ + :param prompt: The prompt to tokenize + :return: A dictionary of tokenized inputs + """ + raise NotImplementedError() + + @abstractmethod + def __call__(self, prompt: str) -> Any: + """ + :param prompt: The prompt to generate from + :return: Ground truth logits / cache state + """ + raise NotImplementedError() + + +class ORTGroundTruthSource(GroundTruthSource): + """ + An object that generates ground truth logits and + cache states from a prompt. This object cannot + generate tokens in an autoregressive manner, and thus + will only output prompt logits and prompt cache state + """ + + def __init__( + self, + model_stub: str, + model_name: str, + sequence_length: int = 256, + ): + super().__init__(model_name) + + self.model_onnx_path = Model(model_stub).deployment.get_file("model.onnx").path + overwrite_onnx_model_inputs_for_kv_cache_models( + self.model_onnx_path, + sequence_length=sequence_length, + input_ids_length=sequence_length, + ) + self.sequence_length = sequence_length + self.session = onnxruntime.InferenceSession(self.model_onnx_path) + self.model_inputs = [ + x.name + for x in onnx.load( + self.model_onnx_path, load_external_data=False + ).graph.input + ] + + def tokenize(self, prompt: str): + return self.tokenizer( + prompt, + return_tensors="np", + padding="max_length", + max_length=self.sequence_length, + ) + + def __call__(self, prompt: str) -> Tuple[numpy.ndarray, List[numpy.ndarray]]: + inputs = self.tokenize(prompt) + kv_cache = self._initialize_kv_cache_state() + + onnxruntime_inputs = dict( + attention_mask=inputs["attention_mask"], + input_ids=inputs["input_ids"], + **kv_cache, + ) + + if "positions" in self.model_inputs: + attention_mask = inputs["attention_mask"] + positions = attention_mask.cumsum(1) * attention_mask - 1 + onnxruntime_inputs["positions"] = positions + + if "causal_mask" in self.model_inputs: + causal_mask = create_causal_mask( + input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"] + ) + onnxruntime_inputs["causal_mask"] = causal_mask + + # run inference and return the cache state + outputs = self.session.run(None, onnxruntime_inputs) + prompt_logits, *prompt_cache = outputs + + # remove logits that correspond to padding tokens + prompt_logits = numpy.compress( + onnxruntime_inputs["attention_mask"].flatten(), prompt_logits, axis=1 + ) # (1, prompt_length, vocab_size) + 
prompt_logits = prompt_logits[:, :-1, :] # (1, prompt_length, vocab_size) + + # remove cache that corresponds to padding tokens + prompt_cache = [ + numpy.compress( + onnxruntime_inputs["attention_mask"].flatten(), cache, axis=2 + ) + for cache in prompt_cache + ] # List[(1, num_heads, past_length, head_dim)] + + return prompt_logits, prompt_cache + + def _initialize_kv_cache_state(self, length: int = 0) -> Dict[str, numpy.ndarray]: + model = onnx.load(self.model_onnx_path, load_external_data=False) + + cache_input = next( + input + for input in model.graph.input + if input.name.startswith(CACHE_INPUT_PREFIX) + ) + # read the shape of the cache input + batch_size = cache_input.type.tensor_type.shape.dim[0].dim_value + num_attention_heads = cache_input.type.tensor_type.shape.dim[1].dim_value + hidden_dims = cache_input.type.tensor_type.shape.dim[3].dim_value + + # create a kv cache dictionary + kv_cache = { + input_.name: numpy.zeros( + (batch_size, num_attention_heads, length, hidden_dims), + dtype=numpy.float32, + ) + for input_ in model.graph.input + if input_.name.startswith(CACHE_INPUT_PREFIX) + } + return kv_cache + + +class TorchGroundTruthSource(GroundTruthSource): + """ + An object that generates ground truth logits and + cache states from a prompt. This object can + generate tokens in an autoregressive manner, and thus + will output prompt logits, generated logits, generated + sequence and prompt cache state + """ + + def __init__(self, num_tokens_to_generate: int, model_name: str): + super().__init__(model_name) + self.model = AutoModelForCausalLM.from_pretrained(model_name) + self.num_tokens_to_generate = num_tokens_to_generate + + def tokenize(self, prompt: str): + return self.tokenizer(prompt, return_tensors="pt") + + def __call__( + self, prompt: str + ) -> Tuple[numpy.ndarray, numpy.ndarray, List[numpy.ndarray], str]: + # afaik it is not possible to get 'past_key_values' from + # the generate method, so we have to run the model twice + out = self.model.generate( + self.tokenize(prompt).input_ids, + max_new_tokens=self.num_tokens_to_generate, + output_scores=True, + return_dict_in_generate=True, + use_cache=True, + ) + generated_text = self.tokenizer.decode( + out.sequences[0], skip_special_tokens=True + ) + generated_logits = numpy.concatenate( + [[score.numpy() for score in out.scores]] + ).transpose( + 1, 0, 2 + ) # (1, num_tokens_to_generate, vocab_size) + + out = self.model(**self.tokenize(prompt)) + prompt_logits = out.logits.detach().numpy()[ + :, :-1, : + ] # (1, prompt_length, vocab_size) + prompt_cache = [ + entry.detach().numpy() + for key_value_tuple in out.past_key_values + for entry in key_value_tuple + ] # List[(1, num_heads, past_length, head_dim)] + + return generated_logits, prompt_logits, prompt_cache, generated_text diff --git a/tests/deepsparse/transformers/pipelines/proposal_text_generation_tests.md b/tests/deepsparse/transformers/pipelines/proposal_text_generation_tests.md deleted file mode 100644 index ad015eb54c..0000000000 --- a/tests/deepsparse/transformers/pipelines/proposal_text_generation_tests.md +++ /dev/null @@ -1,48 +0,0 @@ -# test perplexity script -# test what's going on after we go past sequence_length -# sequences need to be appropriately long to notice the divergence over time - - -# metric to evaluate logits and cache -> min absolute difference -# maybe if env var set, we can also plot graphs - -OUR_MODELS = ["opt", "codegen", "llama"] -for model in OUR_MODELS: - - # establish sources of truth - torch_target_logits = ... 
- ort_target_logits = ... - torch_target_cache = ... - ort_target_cache = ... - - no_kv_cache_logits = model(kv_cache=False) - - for engine_type in ["onnxruntime", "deepsparse"] - ort_no_kv_cache_logits = ... - deepsparse_no_kv_cache_logits = ... - - if kv_cache: - for kv_cache_management in ["external", "internal"] - - ort_single_token_prefill_logits = ... - ort_multi_token_prefill_logits = ... - ort_single_token_prefill_cache = ... - ort_multi_token_prefill_cache = ... - - ds_single_token_prefill_logits_external = ... - ds_multi_token_prefill_logits_external = ... - ds_single_token_prefill_cache_external = ... - ds_multi_token_prefill_cache_external = ... - - ds_single_token_prefill_logits_internal = ... - ds_multi_token_prefill_logits_internal = ... - - - - - - - - - - diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index a9492cb25c..eb1d428305 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -12,262 +12,396 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional, Tuple -import numpy as np -import onnx -import onnxruntime -from transformers import AutoModelForCausalLM, AutoTokenizer +import numpy import pytest from deepsparse import Pipeline -from deepsparse.transformers.utils.helpers import create_causal_mask -from deepsparse.utils.onnx import ( - CACHE_INPUT_PREFIX, - overwrite_onnx_model_inputs_for_kv_cache_models, +from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache +from tests.deepsparse.transformers.pipelines.helpers import ( + ORTGroundTruthSource, + TorchGroundTruthSource, ) -from sparsezoo import Model - - -def _initialize_kv_cache_state(model, length=0): - # get one of the cache inputs - cache_input = next( - input - for input in model.graph.input - if input.name.startswith(CACHE_INPUT_PREFIX) - ) - # read the shape of the cache input - batch_size = cache_input.type.tensor_type.shape.dim[0].dim_value - num_attention_heads = cache_input.type.tensor_type.shape.dim[1].dim_value - hidden_dims = cache_input.type.tensor_type.shape.dim[3].dim_value - - # create a kv cache dictionary - kv_cache = { - input_.name: np.zeros( - (batch_size, num_attention_heads, length, hidden_dims), dtype=np.float32 - ) - for input_ in model.graph.input - if input_.name.startswith(CACHE_INPUT_PREFIX) - } - - return kv_cache - - -START = 0 # global variable for dummy_callback @pytest.mark.parametrize( - "internal_kv_cache", + "use_deepsparse_cache", [True, False], ) @pytest.mark.parametrize( - "model_stub, model_name, uses_bos_token", + "model_stub, model_name, uses_bos_token, logits_max_diff_kv_cache_has_been_filled", [ - ( - "zoo:nlg/text_generation/opt-1.3b/pytorch/" - "huggingface/opt_pretrain/base-none", - "facebook/opt-1.3b", - True, - ), ( "zoo:nlg/text_generation/codegen_mono-350m/pytorch/" "huggingface/bigpython_bigquery_thepile/base-none", "salesforce/codegen-350m-mono", False, + 15.5, ), + # TODO: Waiting for the model to be available + # ("zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/opt_pretrain/pruned50_quantW8A8-none", + # "facebook/opt-1.3b", + # True, + # None), ], scope="class", ) -@pytest.mark.skip( - reason="Those tests are too heavy to " "run as a normal part of the CI." 
-) class TestTextGenerationPipeline: + """ + This test suite is meant to test the main scenarios of + the text generation pipeline. + """ + @pytest.fixture - def setup(self, model_stub, model_name, uses_bos_token, internal_kv_cache): + def setup( + self, + model_stub, + model_name, + uses_bos_token, + logits_max_diff_kv_cache_has_been_filled, + use_deepsparse_cache, + ): + self.num_tokens_generate = 216 + self.prompt = """ + Didn't know what time it was, the lights were low + I leaned back on my radio + Some cat was layin' down some rock 'n' roll + "Lotta soul," he said + Then the loud sound did seem to fade + Came back like a slow voice on a wave of phase + That weren't no DJ, that was hazy cosmic jive + """ + # create torch ground source + torch_source = TorchGroundTruthSource( + num_tokens_to_generate=self.num_tokens_generate, model_name=model_name + ) + torch_ground_truth = torch_source(self.prompt) - self.max_generated_tokens = 16 - self.model = Model(model_stub) - self.internal_kv_cache = internal_kv_cache + # prompt length is expressed in number of prompt tokens + prompt_length = torch_ground_truth[1].shape[1] - pipeline = Pipeline.create( - task="text_generation", - model_path=model_stub, - sequence_length=32, - prompt_sequence_length=4, - max_generated_tokens=self.max_generated_tokens, - internal_kv_cache=self.internal_kv_cache, + # sequence_length that assures that the KV cache will not be filled up + self.sequence_length = 2 * prompt_length + self.num_tokens_generate + # sequence_length that assures that the KV cache will be filled up + self.sequence_length_short = self.num_tokens_generate + + # prompt_processing_sequence_length used for the multitoken prefill scenario + self.prompt_processing_sequence_length = 16 + + # the maximum trheshold for the difference between the logits + # when running a scenario where KV Cache buffer has been filled + self.logits_max_diff_kv_cache_has_been_filled = ( + logits_max_diff_kv_cache_has_been_filled ) - short_prompt = "this" - long_prompt = "this is a sample prompt that we will use to test the pipeline" - - # make sure that the short prompt will be only - # processed by a single token engine - # (DISABLED FOR NOW UNTIL WE HAVE ZOO CAUSAL MASK SUPPORT) - # assert ( - # len(pipeline.tokenizer.tokenize(short_prompt)) + int(uses_bos_token) - # < pipeline.prompt_sequence_length - # ) - # make sure that the long prompt will be processed by - # single token and multiple token engines - # (DISABLED FOR NOW UNTIL WE HAVE ZOO CAUSAL MASK SUPPORT) - # assert ( - # len(pipeline.tokenizer.tokenize(long_prompt)) + int(uses_bos_token) - # > pipeline.prompt_sequence_length * 3 - # ) - - yield pipeline, model_name, uses_bos_token, short_prompt, long_prompt - - def test_freeze(self, setup): - # test whether we should be "freezing" the first token after + self.use_deepsparse_cache = use_deepsparse_cache + + assert self.prompt_processing_sequence_length < prompt_length, ( + "The prompt processing sequence length " + "must be smaller than the prompt length" + ) + + yield model_stub, model_name, uses_bos_token, torch_ground_truth + + def test_freeze_first_position(self, setup): + # Test whether we should be "freezing" the first token after # the kv cache is full - pipeline, _, uses_bos_token, _, _ = setup + model_stub, _, uses_bos_token, _ = setup + pipeline = Pipeline.create(task="text_generation", model_path=model_stub) assert pipeline.engine._freeze_first_position == uses_bos_token - def test_model_output_sequences(self, setup): - # test model output against 
sources of truth - pipeline, model_name, _, short_prompt, long_prompt = setup - - output_sequences = pipeline(sequences=[short_prompt, long_prompt]) + def test_ort_model(self, setup): + # Assert that the ONNX model with KV Cache support runs + # directly in ONNXRuntime and delivers correct results + model_stub, model_name, _, torch_ground_truth = setup - # test against huggingface model - output_hugging_face = self._get_output_huggingface( - sequences=[short_prompt, long_prompt], model_name=model_name + ort_source = ORTGroundTruthSource( + model_name=model_name, + model_stub=model_stub, ) - assert short_prompt + output_sequences.sequences[0] == output_hugging_face[0] - assert long_prompt + output_sequences.sequences[1] == output_hugging_face[1] - - def test_model_output_cache(self, setup): - pipeline, model_name, _, short_prompt, long_prompt = setup - if self.internal_kv_cache: + ort_prompt_logits, ort_prompt_kv_cache = ort_source(self.prompt) + _, torch_prompt_logits, torch_prompt_cache, _ = torch_ground_truth + + # check that the prompt logits are the same + assert numpy.allclose(torch_prompt_logits, ort_prompt_logits, atol=1e-4) + # check that the prompt cache is the same + for torch_cache, ort_cache in zip(torch_prompt_cache, ort_prompt_kv_cache): + assert numpy.allclose(torch_cache, ort_cache, atol=1e-4) + + def test_ort_single_token_prefill(self, setup): + # Test the pipeline that uses ORT engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by single-token engine + # 2. The KV Cache is never filled up + # 3. KV Cache managed externally + + if self.use_deepsparse_cache: pytest.skip( - "Running pipeline with internal " - "deepsparse cache will not result " - "in meaningful cache entries." + "Cannot run ORT pipeline with the internal deepsparse cache enabled." 
) - self._test_cache_state(short_prompt, pipeline, model_name) - self._test_cache_state(long_prompt, pipeline, model_name) - - def test_callback(self, setup): - pipeline, *_ = setup - - def dummy_callback(token): - global START - START += 1 - return START < 3 - - inputs = { - "sequences": "def fib(a, b, accumulator=0)", - "callback": dummy_callback, - "return_logits": True, - } - - outs = pipeline(**inputs) - assert outs.logits.shape[1] == 3 - - def _test_cache_state(self, prompt, pipeline, model_name): - # make sure that the cache state after running a prompt - # is correct - - pipeline(sequences=prompt) - cache_state_dict = pipeline.engine.kv_cache.cached_inputs - cache_state_list = [cache_state_dict[key] for key in cache_state_dict.keys()] - - # generate ground truth from ORT - target_cache_state = self._get_cache_state_ort_kv_cache( - model_onnx_path=self.model.deployment.get_file("model.onnx").path, - sequence=prompt, - model_name=model_name, + model_stub, _, _, torch_ground_truth = setup + pipeline = Pipeline.create( + task="text_generation", + model_path=model_stub, + sequence_length=self.sequence_length, + prompt_processing_sequence_length=1, + max_generated_tokens=self.num_tokens_generate, + force_max_tokens=True, + engine_type="onnxruntime", + ) + output = pipeline( + sequences=self.prompt, return_logits=True, include_prompt_logits=True ) - # get the number of processed prompt tokens - num_prompt_tokens = len(pipeline.tokenizer.tokenize(prompt)) + int( - pipeline.engine._freeze_first_position + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens < self.sequence_length + self._test_output( + output=output, + cache_session=cache_session, + torch_ground_truth=torch_ground_truth, ) - for x, y in zip(cache_state_list, target_cache_state): - """ - x will be a cache array - [blank, blank, ..., prompt_cache_1, prompt_cache_2, ..., - gen_token_cache_1, gen_token_cache_2, ...] - we need to first remove blank entries and then keep the - remaining prompt_cache entries (remove gen_token_cache entries) - """ - first_non_blank_cache_entry = min( - i for i in range(x.shape[2]) if np.count_nonzero(x[:, :, i, :]) + def test_ort_multi_token_prefill(self, setup): + # Test the pipeline that uses ORT engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by multi-token engine + # 2. The KV Cache is never filled up + # 3. KV Cache managed externally + + if self.use_deepsparse_cache: + pytest.skip( + "Cannot run ORT pipeline with the internal deepsparse cache enabled." ) - x = x[:, :, first_non_blank_cache_entry:, :] - x = x[:, :, :num_prompt_tokens, :] - - """ - y will be a cache array - [blank, blank, ..., prompt_cache_1, prompt_cache_2, ...] 
- we need to keep the prompt_cache entries only - """ - y = y[:, :, -num_prompt_tokens:, :] - - assert np.allclose(x, y, atol=1e-4) - - def _get_output_huggingface(self, sequences, model_name): - hf_outputs = [] - # setup tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_name) - tokenizer.padding_side = "left" - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - # setup model - model = AutoModelForCausalLM.from_pretrained(model_name) - - # generate ground truth output - for prompt in sequences: - input_ids = tokenizer(prompt, return_tensors="pt").input_ids - generated_ids = model.generate( - input_ids, max_new_tokens=self.max_generated_tokens + model_stub, _, _, torch_ground_truth = setup + pipeline = Pipeline.create( + task="text_generation", + model_path=model_stub, + sequence_length=self.sequence_length, + prompt_processing_sequence_length=self.prompt_processing_sequence_length, + max_generated_tokens=self.num_tokens_generate, + force_max_tokens=True, + engine_type="onnxruntime", + ) + output = pipeline( + sequences=self.prompt, return_logits=True, include_prompt_logits=True + ) + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens < self.sequence_length + self._test_output( + output=output, + cache_session=cache_session, + torch_ground_truth=torch_ground_truth, + ) + + def test_ort_generation_after_kv_cache_has_been_filled(self, setup): + # Test the pipeline that uses ORT engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by multi-token engine + # 2. The KV Cache is filled up (old entries are removed) + # 3. KV Cache managed externally + + if self.use_deepsparse_cache: + pytest.skip( + "Cannot run ORT pipeline with the internal deepsparse cache enabled." 
) - hf_output = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - hf_outputs.append(hf_output) - return hf_outputs + model_stub, _, _, torch_ground_truth = setup + pipeline = Pipeline.create( + task="text_generation", + model_path=model_stub, + sequence_length=self.sequence_length_short, + prompt_processing_sequence_length=self.prompt_processing_sequence_length, + max_generated_tokens=self.num_tokens_generate, + force_max_tokens=True, + engine_type="onnxruntime", + ) + output = pipeline( + sequences=self.prompt, return_logits=True, include_prompt_logits=True + ) + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( + "for this scenario, the kv cache should be full: " + "the total number of processed tokens should be " + "greater than the sequence length" + ) - @staticmethod - def _get_cache_state_ort_kv_cache(model_onnx_path, sequence, model_name): - # setup tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_name) - tokenizer.padding_side = "left" - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - # setup model and session - # (run full sequence inference) - overwrite_onnx_model_inputs_for_kv_cache_models( - model_onnx_path, sequence_length=128, input_ids_length=128 + self._test_output( + output=output, + cache_session=cache_session, + torch_ground_truth=torch_ground_truth, + max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 ) - sess = onnxruntime.InferenceSession(model_onnx_path) - # get model inputs - onnx_model = onnx.load(model_onnx_path, load_external_data=False) - model_inputs = [x.name for x in onnx_model.graph.input] - kv_cache = _initialize_kv_cache_state(model=onnx_model) + def test_deepsparse_single_token_prefill(self, setup): + # Test the pipeline that uses deepsparse engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by single-token engine + # 2. The KV Cache is never filled up + # 3. KV Cache managed externally or internally - inputs = tokenizer( - sequence, return_tensors="np", padding="max_length", max_length=128 + model_stub, _, _, torch_ground_truth = setup + pipeline = Pipeline.create( + task="text_generation", + model_path=model_stub, + sequence_length=self.sequence_length, + prompt_processing_sequence_length=1, + max_generated_tokens=self.num_tokens_generate, + force_max_tokens=True, + use_deepsparse_cache=self.use_deepsparse_cache, + ) + output = pipeline( + sequences=self.prompt, return_logits=True, include_prompt_logits=True ) - onnxruntime_inputs = dict( - attention_mask=inputs["attention_mask"], - input_ids=inputs["input_ids"], - **kv_cache, + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens < self.sequence_length + self._test_output( + output=output, + cache_session=cache_session, + torch_ground_truth=torch_ground_truth, + run_cache_validation=not self.use_deepsparse_cache, ) - if "positions" in model_inputs: - attention_mask = inputs["attention_mask"] - positions = attention_mask.cumsum(1) * attention_mask - 1 - onnxruntime_inputs["positions"] = positions + def test_deepsparse_multi_token_prefill(self, setup): + # Test the pipeline that uses deepsparse engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by multi-token engine + # 2. The KV Cache is never filled up + # 3. 
KV Cache managed externally or internally - if "causal_mask" in model_inputs: - causal_mask = create_causal_mask( - input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"] - ) - onnxruntime_inputs["causal_mask"] = causal_mask + model_stub, _, _, torch_ground_truth = setup + pipeline = Pipeline.create( + task="text_generation", + model_path=model_stub, + sequence_length=self.sequence_length, + prompt_processing_sequence_length=self.prompt_processing_sequence_length, + max_generated_tokens=self.num_tokens_generate, + force_max_tokens=True, + use_deepsparse_cache=self.use_deepsparse_cache, + ) + output = pipeline( + sequences=self.prompt, return_logits=True, include_prompt_logits=True + ) + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens < self.sequence_length + self._test_output( + output=output, + cache_session=cache_session, + torch_ground_truth=torch_ground_truth, + run_cache_validation=not self.use_deepsparse_cache, + ) - # run inference and return the cache state - outputs = sess.run(None, onnxruntime_inputs) - logits, *kv_cache = outputs + def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): + # Test the pipeline that uses deepsparse engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by multi-token engine + # 2. The KV Cache is filled up (old entries are removed) + # 3. KV Cache managed externally or internally - return kv_cache + model_stub, _, _, torch_ground_truth = setup + pipeline = Pipeline.create( + task="text_generation", + model_path=model_stub, + sequence_length=self.sequence_length_short, + prompt_processing_sequence_length=self.prompt_processing_sequence_length, + max_generated_tokens=self.num_tokens_generate, + force_max_tokens=True, + use_deepsparse_cache=self.use_deepsparse_cache, + ) + output = pipeline( + sequences=self.prompt, return_logits=True, include_prompt_logits=True + ) + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( + "for this scenario, the kv cache should be full: " + "the total number of processed tokens should be " + "greater than the sequence length" + ) + + self._test_output( + output=output, + cache_session=cache_session, + torch_ground_truth=torch_ground_truth, + run_cache_validation=not self.use_deepsparse_cache, + max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 + ) + + def test_run_same_prompt_multiple_times(self, setup): + # Test the scenario, where the same prompt is run multiple times + # Every run should produce the same output + model_stub, *_ = setup + pipeline = Pipeline.create( + task="text_generation", + model_path=model_stub, + use_deepsparse_cache=self.use_deepsparse_cache, + ) + output_1 = pipeline( + sequences=self.prompt, return_logits=True, include_prompt_logits=True + ) + output_2 = pipeline( + sequences=self.prompt, return_logits=True, include_prompt_logits=True + ) + assert output_1.sequences[0] == output_2.sequences[0] + assert numpy.allclose(output_1.logits, output_2.logits, atol=1e-4) + + def _test_output( + self, + output: "TextGenerationOutput", # noqa F821 + cache_session: DecoderKVCache, + torch_ground_truth: Tuple[numpy.ndarray, ...], + max_logits_difference_threshold: Optional[float] = None, + run_cache_validation: bool = True, + ): + # extract numpy arrays from cached_inputs + kv_cache_array = list(cache_session.cached_inputs.values()) + + ( + generated_logits, + prompt_logits, + 
prompt_kv_cache, + generated_text, + ) = torch_ground_truth + + # concatenate target prompt_logits and generated_logits and check + target_logits = numpy.concatenate([prompt_logits, generated_logits], axis=1) + + if max_logits_difference_threshold: + # if comparing the output from the model where + # the kv cache has been filled, we expect the + # maximum absolute difference between the logits + # to be less than the threshold + # (the threshold is established by running the + # ONNX model in ONNXRuntime) + assert ( + abs(output.logits - target_logits).max() + < max_logits_difference_threshold + ) + else: + # otherwise, we expect the logits to be exactly the same + # as the target logits; the generated sequence should + # also be the same as the target sequence, and finally + # (if applicable) the kv cache should be the same as the + # target kv cache + assert numpy.allclose(output.logits, target_logits, atol=1e-4) + assert self.prompt + output.sequences[0] == generated_text + + if run_cache_validation: + self._test_kv_cache_state( + expected_cache=kv_cache_array, + target_cache=torch_ground_truth[2], + total_num_processed_tokens=cache_session.total_num_processed_tokens, + ) + + @staticmethod + def _test_kv_cache_state( + expected_cache: List[numpy.ndarray], + target_cache: List[numpy.ndarray], + total_num_processed_tokens: int, + ): + for x, y in zip(expected_cache, target_cache): + start_index = total_num_processed_tokens + end_index = total_num_processed_tokens - y.shape[2] + # x is (in general) composed of three arrays: + # - padding cache entries (from 0 to -start_index) + # - prompt cache entries (from -start_index to -end_index) + # - generated cache entries (from -end_index to -1) + # as target_cache only pertains to prompt cache entries, we need to + # compare only the prompt cache entries in x with y + assert numpy.allclose(x[:, :, -start_index:-end_index, :], y, atol=1e-4) From 92163b95c2d66258e8cfe8f50fa78d56fecea7a4 Mon Sep 17 00:00:00 2001 From: Damian Date: Fri, 8 Sep 2023 10:47:36 +0000 Subject: [PATCH 4/9] fix tests --- tests/deepsparse/transformers/pipelines/helpers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/deepsparse/transformers/pipelines/helpers.py b/tests/deepsparse/transformers/pipelines/helpers.py index 1edf7b5558..2f37760ff5 100644 --- a/tests/deepsparse/transformers/pipelines/helpers.py +++ b/tests/deepsparse/transformers/pipelines/helpers.py @@ -20,11 +20,11 @@ import onnxruntime from transformers import AutoModelForCausalLM, AutoTokenizer -from deepsparse.transformers.utils.helpers import ( - create_causal_mask, +from deepsparse.transformers.utils.helpers import create_causal_mask +from deepsparse.utils.onnx import ( + CACHE_INPUT_PREFIX, overwrite_onnx_model_inputs_for_kv_cache_models, ) -from deepsparse.utils.onnx import CACHE_INPUT_PREFIX from sparsezoo import Model From bfe1b62ad825a0f90aa27398cb9f6179af233441 Mon Sep 17 00:00:00 2001 From: Damian Date: Mon, 11 Sep 2023 11:58:23 +0000 Subject: [PATCH 5/9] Dipika's comments plus adjusting the script to renamed variables --- .../pipelines/test_text_generation.py | 178 ++++++++++++------ 1 file changed, 116 insertions(+), 62 deletions(-) diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index eb1d428305..d23c1b5d50 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -25,24 +25,53 @@ ) 
+NATURAL_LANGUAGE_PROMPT = """ +Didn't know what time it was, the lights were low +I leaned back on my radio +Some cat was layin' down some rock 'n' roll +"Lotta soul," he said +Then the loud sound did seem to fade +Came back like a slow voice on a wave of phase +That weren't no DJ, that was hazy cosmic jive +""" + +CODE_LANGUAGE_PROMPT = """ +def Fibonacci(n): + # Check if input is 0 then it will + # print incorrect input + if n < 0: + print("Incorrect input") + # Check if n is 0 + # then it will return 0 + elif n == 0: + return 0 +""" + + @pytest.mark.parametrize( - "use_deepsparse_cache", + "internal_kv_cache", [True, False], ) @pytest.mark.parametrize( - "model_stub, model_name, uses_bos_token, logits_max_diff_kv_cache_has_been_filled", + "model_stub, " + "model_name, " + "uses_bos_token, " + "prompt, " + "logits_max_diff_kv_cache_has_been_filled", [ ( "zoo:nlg/text_generation/codegen_mono-350m/pytorch/" "huggingface/bigpython_bigquery_thepile/base-none", "salesforce/codegen-350m-mono", False, - 15.5, + CODE_LANGUAGE_PROMPT, + 13, ), # TODO: Waiting for the model to be available # ("zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/opt_pretrain/pruned50_quantW8A8-none", # "facebook/opt-1.3b", # True, + # NATURAL_LANGUAGE_PROMPT, # None), ], scope="class", @@ -53,25 +82,38 @@ class TestTextGenerationPipeline: the text generation pipeline. """ + def get_pipeline(self, **kwargs): + if not kwargs: + # return the default pipeline + if self.default_pipeline: + return self.default_pipeline + else: + self.default_pipeline = Pipeline.create( + task="text_generation", + model_path=self.model_stub, + internal_kv_cache=self.internal_kv_cache, + prompt_sequence_length=self.prompt_sequence_length, + sequence_length=self.sequence_length, + max_generated_tokens=self.num_tokens_generate, + force_max_tokens=True, + ) + return self.default_pipeline + # return a pipeline with the given kwargs + return Pipeline.create(**kwargs) + @pytest.fixture def setup( self, model_stub, model_name, uses_bos_token, + prompt, logits_max_diff_kv_cache_has_been_filled, - use_deepsparse_cache, + internal_kv_cache, ): self.num_tokens_generate = 216 - self.prompt = """ - Didn't know what time it was, the lights were low - I leaned back on my radio - Some cat was layin' down some rock 'n' roll - "Lotta soul," he said - Then the loud sound did seem to fade - Came back like a slow voice on a wave of phase - That weren't no DJ, that was hazy cosmic jive - """ + self.model_stub = model_stub + self.prompt = prompt # create torch ground source torch_source = TorchGroundTruthSource( num_tokens_to_generate=self.num_tokens_generate, model_name=model_name @@ -86,38 +128,40 @@ def setup( # sequence_length that assures that the KV cache will be filled up self.sequence_length_short = self.num_tokens_generate - # prompt_processing_sequence_length used for the multitoken prefill scenario - self.prompt_processing_sequence_length = 16 + # prompt_sequence_length used for the multitoken prefill scenario + self.prompt_sequence_length = prompt_length // 2 - # the maximum trheshold for the difference between the logits + # the maximum threshold for the difference between the logits # when running a scenario where KV Cache buffer has been filled self.logits_max_diff_kv_cache_has_been_filled = ( logits_max_diff_kv_cache_has_been_filled ) - self.use_deepsparse_cache = use_deepsparse_cache + self.internal_kv_cache = internal_kv_cache - assert self.prompt_processing_sequence_length < prompt_length, ( + self.default_pipeline = None + + assert 
self.prompt_sequence_length < prompt_length, ( "The prompt processing sequence length " "must be smaller than the prompt length" ) - yield model_stub, model_name, uses_bos_token, torch_ground_truth + yield model_name, uses_bos_token, torch_ground_truth def test_freeze_first_position(self, setup): # Test whether we should be "freezing" the first token after # the kv cache is full - model_stub, _, uses_bos_token, _ = setup - pipeline = Pipeline.create(task="text_generation", model_path=model_stub) + _, uses_bos_token, _ = setup + pipeline = self.get_pipeline() assert pipeline.engine._freeze_first_position == uses_bos_token def test_ort_model(self, setup): # Assert that the ONNX model with KV Cache support runs # directly in ONNXRuntime and delivers correct results - model_stub, model_name, _, torch_ground_truth = setup + model_name, _, torch_ground_truth = setup ort_source = ORTGroundTruthSource( model_name=model_name, - model_stub=model_stub, + model_stub=self.model_stub, ) ort_prompt_logits, ort_prompt_kv_cache = ort_source(self.prompt) _, torch_prompt_logits, torch_prompt_cache, _ = torch_ground_truth @@ -135,16 +179,16 @@ def test_ort_single_token_prefill(self, setup): # 2. The KV Cache is never filled up # 3. KV Cache managed externally - if self.use_deepsparse_cache: + if self.internal_kv_cache: pytest.skip( "Cannot run ORT pipeline with the internal deepsparse cache enabled." ) - model_stub, _, _, torch_ground_truth = setup - pipeline = Pipeline.create( + _, _, torch_ground_truth = setup + pipeline = self.get_pipeline( task="text_generation", - model_path=model_stub, + model_path=self.model_stub, sequence_length=self.sequence_length, - prompt_processing_sequence_length=1, + prompt_sequence_length=1, max_generated_tokens=self.num_tokens_generate, force_max_tokens=True, engine_type="onnxruntime", @@ -167,16 +211,16 @@ def test_ort_multi_token_prefill(self, setup): # 2. The KV Cache is never filled up # 3. KV Cache managed externally - if self.use_deepsparse_cache: + if self.internal_kv_cache: pytest.skip( "Cannot run ORT pipeline with the internal deepsparse cache enabled." ) - model_stub, _, _, torch_ground_truth = setup - pipeline = Pipeline.create( + _, _, torch_ground_truth = setup + pipeline = self.get_pipeline( task="text_generation", - model_path=model_stub, + model_path=self.model_stub, sequence_length=self.sequence_length, - prompt_processing_sequence_length=self.prompt_processing_sequence_length, + prompt_sequence_length=self.prompt_sequence_length, max_generated_tokens=self.num_tokens_generate, force_max_tokens=True, engine_type="onnxruntime", @@ -199,16 +243,16 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup): # 2. The KV Cache is filled up (old entries are removed) # 3. KV Cache managed externally - if self.use_deepsparse_cache: + if self.internal_kv_cache: pytest.skip( "Cannot run ORT pipeline with the internal deepsparse cache enabled." ) - model_stub, _, _, torch_ground_truth = setup - pipeline = Pipeline.create( + _, _, torch_ground_truth = setup + pipeline = self.get_pipeline( task="text_generation", - model_path=model_stub, + model_path=self.model_stub, sequence_length=self.sequence_length_short, - prompt_processing_sequence_length=self.prompt_processing_sequence_length, + prompt_sequence_length=self.prompt_sequence_length, max_generated_tokens=self.num_tokens_generate, force_max_tokens=True, engine_type="onnxruntime", @@ -237,15 +281,15 @@ def test_deepsparse_single_token_prefill(self, setup): # 2. The KV Cache is never filled up # 3. 
KV Cache managed externally or internally - model_stub, _, _, torch_ground_truth = setup - pipeline = Pipeline.create( + _, _, torch_ground_truth = setup + pipeline = self.get_pipeline( task="text_generation", - model_path=model_stub, + model_path=self.model_stub, sequence_length=self.sequence_length, - prompt_processing_sequence_length=1, + prompt_sequence_length=1, max_generated_tokens=self.num_tokens_generate, force_max_tokens=True, - use_deepsparse_cache=self.use_deepsparse_cache, + internal_kv_cache=self.internal_kv_cache, ) output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True @@ -256,7 +300,7 @@ def test_deepsparse_single_token_prefill(self, setup): output=output, cache_session=cache_session, torch_ground_truth=torch_ground_truth, - run_cache_validation=not self.use_deepsparse_cache, + run_cache_validation=not self.internal_kv_cache, ) def test_deepsparse_multi_token_prefill(self, setup): @@ -266,15 +310,15 @@ def test_deepsparse_multi_token_prefill(self, setup): # 2. The KV Cache is never filled up # 3. KV Cache managed externally or internally - model_stub, _, _, torch_ground_truth = setup - pipeline = Pipeline.create( + _, _, torch_ground_truth = setup + pipeline = self.get_pipeline( task="text_generation", - model_path=model_stub, + model_path=self.model_stub, sequence_length=self.sequence_length, - prompt_processing_sequence_length=self.prompt_processing_sequence_length, + prompt_sequence_length=self.prompt_sequence_length, max_generated_tokens=self.num_tokens_generate, force_max_tokens=True, - use_deepsparse_cache=self.use_deepsparse_cache, + internal_kv_cache=self.internal_kv_cache, ) output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True @@ -285,7 +329,7 @@ def test_deepsparse_multi_token_prefill(self, setup): output=output, cache_session=cache_session, torch_ground_truth=torch_ground_truth, - run_cache_validation=not self.use_deepsparse_cache, + run_cache_validation=not self.internal_kv_cache, ) def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): @@ -295,15 +339,15 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): # 2. The KV Cache is filled up (old entries are removed) # 3. 
KV Cache managed externally or internally - model_stub, _, _, torch_ground_truth = setup - pipeline = Pipeline.create( + _, _, torch_ground_truth = setup + pipeline = self.get_pipeline( task="text_generation", - model_path=model_stub, + model_path=self.model_stub, sequence_length=self.sequence_length_short, - prompt_processing_sequence_length=self.prompt_processing_sequence_length, + prompt_sequence_length=self.prompt_sequence_length, max_generated_tokens=self.num_tokens_generate, force_max_tokens=True, - use_deepsparse_cache=self.use_deepsparse_cache, + internal_kv_cache=self.internal_kv_cache, ) output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True @@ -319,19 +363,15 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): output=output, cache_session=cache_session, torch_ground_truth=torch_ground_truth, - run_cache_validation=not self.use_deepsparse_cache, + run_cache_validation=not self.internal_kv_cache, max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 ) def test_run_same_prompt_multiple_times(self, setup): # Test the scenario, where the same prompt is run multiple times # Every run should produce the same output - model_stub, *_ = setup - pipeline = Pipeline.create( - task="text_generation", - model_path=model_stub, - use_deepsparse_cache=self.use_deepsparse_cache, - ) + pipeline = self.get_pipeline() + output_1 = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True ) @@ -341,6 +381,20 @@ def test_run_same_prompt_multiple_times(self, setup): assert output_1.sequences[0] == output_2.sequences[0] assert numpy.allclose(output_1.logits, output_2.logits, atol=1e-4) + def test_run_multiple_prompts_in_parallel(self, setup): + # Test the scenario, where multiple prompts are run in parallel + # Same two prompts should produce the same output + pipeline = self.get_pipeline() + + output = pipeline( + sequences=[self.prompt, self.prompt], + return_logits=True, + include_prompt_logits=True, + ) + + assert numpy.allclose(output.logits[0], output.logits[1], atol=1e-4) + assert output.sequences[0] == output.sequences[1] + def _test_output( self, output: "TextGenerationOutput", # noqa F821 From f2693ff50995c1c44964c532e6ccf56552e36373 Mon Sep 17 00:00:00 2001 From: Damian Date: Tue, 12 Sep 2023 05:59:16 +0000 Subject: [PATCH 6/9] remove ORT ground truth --- .../transformers/pipelines/helpers.py | 166 +++--------------- .../pipelines/test_text_generation.py | 23 +-- 2 files changed, 21 insertions(+), 168 deletions(-) diff --git a/tests/deepsparse/transformers/pipelines/helpers.py b/tests/deepsparse/transformers/pipelines/helpers.py index 2f37760ff5..0bb962a8e3 100644 --- a/tests/deepsparse/transformers/pipelines/helpers.py +++ b/tests/deepsparse/transformers/pipelines/helpers.py @@ -12,166 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from abc import ABC, abstractmethod -from typing import Any, Dict, List, Tuple +from typing import List, Tuple import numpy -import onnx -import onnxruntime from transformers import AutoModelForCausalLM, AutoTokenizer -from deepsparse.transformers.utils.helpers import create_causal_mask -from deepsparse.utils.onnx import ( - CACHE_INPUT_PREFIX, - overwrite_onnx_model_inputs_for_kv_cache_models, -) -from sparsezoo import Model - -class GroundTruthSource(ABC): - def __init__(self, model_name: str): - tokenizer = AutoTokenizer.from_pretrained(model_name) - tokenizer.padding_side = "left" - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - self.tokenizer = tokenizer - - @abstractmethod - def tokenize(self, prompt: str) -> Dict[str, Any]: - """ - :param prompt: The prompt to tokenize - :return: A dictionary of tokenized inputs - """ - raise NotImplementedError() - - @abstractmethod - def __call__(self, prompt: str) -> Any: - """ - :param prompt: The prompt to generate from - :return: Ground truth logits / cache state - """ - raise NotImplementedError() - - -class ORTGroundTruthSource(GroundTruthSource): - """ - An object that generates ground truth logits and - cache states from a prompt. This object cannot - generate tokens in an autoregressive manner, and thus - will only output prompt logits and prompt cache state - """ - - def __init__( - self, - model_stub: str, - model_name: str, - sequence_length: int = 256, - ): - super().__init__(model_name) - - self.model_onnx_path = Model(model_stub).deployment.get_file("model.onnx").path - overwrite_onnx_model_inputs_for_kv_cache_models( - self.model_onnx_path, - sequence_length=sequence_length, - input_ids_length=sequence_length, - ) - self.sequence_length = sequence_length - self.session = onnxruntime.InferenceSession(self.model_onnx_path) - self.model_inputs = [ - x.name - for x in onnx.load( - self.model_onnx_path, load_external_data=False - ).graph.input - ] - - def tokenize(self, prompt: str): - return self.tokenizer( - prompt, - return_tensors="np", - padding="max_length", - max_length=self.sequence_length, - ) - - def __call__(self, prompt: str) -> Tuple[numpy.ndarray, List[numpy.ndarray]]: - inputs = self.tokenize(prompt) - kv_cache = self._initialize_kv_cache_state() - - onnxruntime_inputs = dict( - attention_mask=inputs["attention_mask"], - input_ids=inputs["input_ids"], - **kv_cache, - ) - - if "positions" in self.model_inputs: - attention_mask = inputs["attention_mask"] - positions = attention_mask.cumsum(1) * attention_mask - 1 - onnxruntime_inputs["positions"] = positions - - if "causal_mask" in self.model_inputs: - causal_mask = create_causal_mask( - input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"] - ) - onnxruntime_inputs["causal_mask"] = causal_mask - - # run inference and return the cache state - outputs = self.session.run(None, onnxruntime_inputs) - prompt_logits, *prompt_cache = outputs - - # remove logits that correspond to padding tokens - prompt_logits = numpy.compress( - onnxruntime_inputs["attention_mask"].flatten(), prompt_logits, axis=1 - ) # (1, prompt_length, vocab_size) - prompt_logits = prompt_logits[:, :-1, :] # (1, prompt_length, vocab_size) - - # remove cache that corresponds to padding tokens - prompt_cache = [ - numpy.compress( - onnxruntime_inputs["attention_mask"].flatten(), cache, axis=2 - ) - for cache in prompt_cache - ] # List[(1, num_heads, past_length, head_dim)] - - return prompt_logits, prompt_cache - - def _initialize_kv_cache_state(self, length: int 
= 0) -> Dict[str, numpy.ndarray]: - model = onnx.load(self.model_onnx_path, load_external_data=False) - - cache_input = next( - input - for input in model.graph.input - if input.name.startswith(CACHE_INPUT_PREFIX) - ) - # read the shape of the cache input - batch_size = cache_input.type.tensor_type.shape.dim[0].dim_value - num_attention_heads = cache_input.type.tensor_type.shape.dim[1].dim_value - hidden_dims = cache_input.type.tensor_type.shape.dim[3].dim_value - - # create a kv cache dictionary - kv_cache = { - input_.name: numpy.zeros( - (batch_size, num_attention_heads, length, hidden_dims), - dtype=numpy.float32, - ) - for input_ in model.graph.input - if input_.name.startswith(CACHE_INPUT_PREFIX) - } - return kv_cache - - -class TorchGroundTruthSource(GroundTruthSource): +class TorchGroundTruthSource: """ An object that generates ground truth logits and cache states from a prompt. This object can generate tokens in an autoregressive manner, and thus - will output prompt logits, generated logits, generated - sequence and prompt cache state + will output: + - prompt logits, + - generated logits, + - prompt cache state, + - generated sequence """ def __init__(self, num_tokens_to_generate: int, model_name: str): - super().__init__(model_name) + self.model = AutoModelForCausalLM.from_pretrained(model_name) + self.tokenizer = self._create_tokenizer(model_name) + self.num_tokens_to_generate = num_tokens_to_generate + self.model_name = model_name def tokenize(self, prompt: str): return self.tokenizer(prompt, return_tensors="pt") @@ -208,3 +73,12 @@ def __call__( ] # List[(1, num_heads, past_length, head_dim)] return generated_logits, prompt_logits, prompt_cache, generated_text + + @staticmethod + def _create_tokenizer(model_name): + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + return tokenizer diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index d23c1b5d50..4a6626e741 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -19,10 +19,7 @@ import pytest from deepsparse import Pipeline from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache -from tests.deepsparse.transformers.pipelines.helpers import ( - ORTGroundTruthSource, - TorchGroundTruthSource, -) +from tests.deepsparse.transformers.pipelines.helpers import TorchGroundTruthSource NATURAL_LANGUAGE_PROMPT = """ @@ -154,24 +151,6 @@ def test_freeze_first_position(self, setup): pipeline = self.get_pipeline() assert pipeline.engine._freeze_first_position == uses_bos_token - def test_ort_model(self, setup): - # Assert that the ONNX model with KV Cache support runs - # directly in ONNXRuntime and delivers correct results - model_name, _, torch_ground_truth = setup - - ort_source = ORTGroundTruthSource( - model_name=model_name, - model_stub=self.model_stub, - ) - ort_prompt_logits, ort_prompt_kv_cache = ort_source(self.prompt) - _, torch_prompt_logits, torch_prompt_cache, _ = torch_ground_truth - - # check that the prompt logits are the same - assert numpy.allclose(torch_prompt_logits, ort_prompt_logits, atol=1e-4) - # check that the prompt cache is the same - for torch_cache, ort_cache in zip(torch_prompt_cache, ort_prompt_kv_cache): - assert numpy.allclose(torch_cache, ort_cache, atol=1e-4) - def 
test_ort_single_token_prefill(self, setup): # Test the pipeline that uses ORT engine. The test covers the # following scenario: From dd270a2f61fa43735d948af359ce3fba49b18e1b Mon Sep 17 00:00:00 2001 From: Damian Date: Tue, 12 Sep 2023 09:32:13 +0000 Subject: [PATCH 7/9] add OPT tests --- src/deepsparse/transformers/helpers.py | 26 +++++++++++++++---- .../pipelines/test_text_generation.py | 25 +++++++++++------- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/src/deepsparse/transformers/helpers.py b/src/deepsparse/transformers/helpers.py index 0d1fb3971b..44f0cfc77f 100644 --- a/src/deepsparse/transformers/helpers.py +++ b/src/deepsparse/transformers/helpers.py @@ -49,6 +49,7 @@ _MODEL_DIR_CONFIG_NAME = "config.json" _MODEL_DIR_TOKENIZER_NAME = "tokenizer.json" _MODEL_DIR_TOKENIZER_CONFIG_NAME = "tokenizer_config.json" +_OPT_TOKENIZER_FILES = ["special_tokens_map.json", "vocab.json", "merges.txt"] def get_onnx_path(model_path: str) -> str: @@ -122,14 +123,29 @@ def get_hugging_face_configs(model_path: str) -> Tuple[str, str]: config_path = _get_file_parent( zoo_model.deployment.default.get_file(_MODEL_DIR_CONFIG_NAME).path ) - tokenizer_path = _get_file_parent( - zoo_model.deployment.default.get_file(_MODEL_DIR_TOKENIZER_NAME).path + tokenizer_file = zoo_model.deployment.default.get_file( + _MODEL_DIR_TOKENIZER_NAME ) - tokenizer_config_path = zoo_model.deployment.default.get_file( + + tokenizer_config_file = zoo_model.deployment.default.get_file( _MODEL_DIR_TOKENIZER_CONFIG_NAME ) - if tokenizer_config_path is not None: - tokenizer_config_path.path # trigger download of tokenizer_config + + if tokenizer_config_file is not None: + tokenizer_config_path = _get_file_parent( + tokenizer_config_file.path + ) # trigger download of tokenizer_config + + if tokenizer_file is not None: + tokenizer_path = _get_file_parent(tokenizer_file.path) + else: + # if tokenizer_file is not present, we assume it's the OPT model + # this means that we use tokenizer_config_path instead of tokenizer_path + # and need to download the additional tokenizer files + tokenizer_path = tokenizer_config_path + for file in _OPT_TOKENIZER_FILES: + zoo_model.deployment.default.get_file(file).path + else: raise ValueError( f"model_path {model_path} is not a valid directory or zoo stub" diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index 4a6626e741..cc92fb5b76 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -22,6 +22,8 @@ from tests.deepsparse.transformers.pipelines.helpers import TorchGroundTruthSource +_PRECISION = 1e-3 + NATURAL_LANGUAGE_PROMPT = """ Didn't know what time it was, the lights were low I leaned back on my radio @@ -64,12 +66,13 @@ def Fibonacci(n): CODE_LANGUAGE_PROMPT, 13, ), - # TODO: Waiting for the model to be available - # ("zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/opt_pretrain/pruned50_quantW8A8-none", - # "facebook/opt-1.3b", - # True, - # NATURAL_LANGUAGE_PROMPT, - # None), + ( + "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/opt_pretrain/base-none", + "facebook/opt-1.3b", + True, + NATURAL_LANGUAGE_PROMPT, + 3.9, + ), ], scope="class", ) @@ -358,7 +361,7 @@ def test_run_same_prompt_multiple_times(self, setup): sequences=self.prompt, return_logits=True, include_prompt_logits=True ) assert output_1.sequences[0] == output_2.sequences[0] - assert numpy.allclose(output_1.logits, 
output_2.logits, atol=1e-4) + assert numpy.allclose(output_1.logits, output_2.logits, atol=_PRECISION) def test_run_multiple_prompts_in_parallel(self, setup): # Test the scenario, where multiple prompts are run in parallel @@ -371,7 +374,7 @@ def test_run_multiple_prompts_in_parallel(self, setup): include_prompt_logits=True, ) - assert numpy.allclose(output.logits[0], output.logits[1], atol=1e-4) + assert numpy.allclose(output.logits[0], output.logits[1], atol=_PRECISION) assert output.sequences[0] == output.sequences[1] def _test_output( @@ -412,7 +415,7 @@ def _test_output( # also be the same as the target sequence, and finally # (if applicable) the kv cache should be the same as the # target kv cache - assert numpy.allclose(output.logits, target_logits, atol=1e-4) + assert numpy.allclose(output.logits, target_logits, atol=_PRECISION) assert self.prompt + output.sequences[0] == generated_text if run_cache_validation: @@ -437,4 +440,6 @@ def _test_kv_cache_state( # - generated cache entries (from -end_index to -1) # as target_cache only pertains to prompt cache entries, we need to # compare only the prompt cache entries in x with y - assert numpy.allclose(x[:, :, -start_index:-end_index, :], y, atol=1e-4) + assert numpy.allclose( + x[:, :, -start_index:-end_index, :], y, atol=_PRECISION + ) From 646d6f510b5009fd86565d72df69e19de7eab3e9 Mon Sep 17 00:00:00 2001 From: Damian Date: Wed, 13 Sep 2023 09:03:05 +0000 Subject: [PATCH 8/9] rebase and disable tests in GHA --- .../pipelines/test_text_generation.py | 71 ++++++++++++++----- 1 file changed, 55 insertions(+), 16 deletions(-) diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index cc92fb5b76..9911ff4d8a 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -67,7 +67,8 @@ def Fibonacci(n): 13, ), ( - "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/opt_pretrain/base-none", + "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/" + "opt_pretrain/base-none", "facebook/opt-1.3b", True, NATURAL_LANGUAGE_PROMPT, @@ -76,6 +77,9 @@ def Fibonacci(n): ], scope="class", ) +@pytest.mark.skip( + reason="Those tests are too heavy to " "run as a normal part of the CI." 
+) class TestTextGenerationPipeline: """ This test suite is meant to test the main scenarios of @@ -94,7 +98,6 @@ def get_pipeline(self, **kwargs): internal_kv_cache=self.internal_kv_cache, prompt_sequence_length=self.prompt_sequence_length, sequence_length=self.sequence_length, - max_generated_tokens=self.num_tokens_generate, force_max_tokens=True, ) return self.default_pipeline @@ -171,12 +174,14 @@ def test_ort_single_token_prefill(self, setup): model_path=self.model_stub, sequence_length=self.sequence_length, prompt_sequence_length=1, - max_generated_tokens=self.num_tokens_generate, force_max_tokens=True, engine_type="onnxruntime", ) output = pipeline( - sequences=self.prompt, return_logits=True, include_prompt_logits=True + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, ) cache_session = pipeline.engine.kv_cache assert cache_session.total_num_processed_tokens < self.sequence_length @@ -203,12 +208,14 @@ def test_ort_multi_token_prefill(self, setup): model_path=self.model_stub, sequence_length=self.sequence_length, prompt_sequence_length=self.prompt_sequence_length, - max_generated_tokens=self.num_tokens_generate, force_max_tokens=True, engine_type="onnxruntime", ) output = pipeline( - sequences=self.prompt, return_logits=True, include_prompt_logits=True + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, ) cache_session = pipeline.engine.kv_cache assert cache_session.total_num_processed_tokens < self.sequence_length @@ -235,12 +242,14 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup): model_path=self.model_stub, sequence_length=self.sequence_length_short, prompt_sequence_length=self.prompt_sequence_length, - max_generated_tokens=self.num_tokens_generate, force_max_tokens=True, engine_type="onnxruntime", ) output = pipeline( - sequences=self.prompt, return_logits=True, include_prompt_logits=True + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, ) cache_session = pipeline.engine.kv_cache assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( @@ -269,12 +278,14 @@ def test_deepsparse_single_token_prefill(self, setup): model_path=self.model_stub, sequence_length=self.sequence_length, prompt_sequence_length=1, - max_generated_tokens=self.num_tokens_generate, force_max_tokens=True, internal_kv_cache=self.internal_kv_cache, ) output = pipeline( - sequences=self.prompt, return_logits=True, include_prompt_logits=True + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, ) cache_session = pipeline.engine.kv_cache assert cache_session.total_num_processed_tokens < self.sequence_length @@ -298,12 +309,14 @@ def test_deepsparse_multi_token_prefill(self, setup): model_path=self.model_stub, sequence_length=self.sequence_length, prompt_sequence_length=self.prompt_sequence_length, - max_generated_tokens=self.num_tokens_generate, force_max_tokens=True, internal_kv_cache=self.internal_kv_cache, ) output = pipeline( - sequences=self.prompt, return_logits=True, include_prompt_logits=True + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, ) cache_session = pipeline.engine.kv_cache assert cache_session.total_num_processed_tokens < self.sequence_length @@ -327,12 +340,14 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, 
setup): model_path=self.model_stub, sequence_length=self.sequence_length_short, prompt_sequence_length=self.prompt_sequence_length, - max_generated_tokens=self.num_tokens_generate, force_max_tokens=True, internal_kv_cache=self.internal_kv_cache, ) output = pipeline( - sequences=self.prompt, return_logits=True, include_prompt_logits=True + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, ) cache_session = pipeline.engine.kv_cache assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( @@ -355,10 +370,16 @@ def test_run_same_prompt_multiple_times(self, setup): pipeline = self.get_pipeline() output_1 = pipeline( - sequences=self.prompt, return_logits=True, include_prompt_logits=True + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, ) output_2 = pipeline( - sequences=self.prompt, return_logits=True, include_prompt_logits=True + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, ) assert output_1.sequences[0] == output_2.sequences[0] assert numpy.allclose(output_1.logits, output_2.logits, atol=_PRECISION) @@ -372,11 +393,29 @@ def test_run_multiple_prompts_in_parallel(self, setup): sequences=[self.prompt, self.prompt], return_logits=True, include_prompt_logits=True, + max_tokens=self.num_tokens_generate, ) assert numpy.allclose(output.logits[0], output.logits[1], atol=_PRECISION) assert output.sequences[0] == output.sequences[1] + def test_num_generated_predictions(self, setup): + # Test the scenario, where multiple predictions are generated + # from the same prompt + pipeline = self.get_pipeline() + + output_sequences = pipeline( + sequences=[self.prompt], num_generated_predictions=2 + ) + assert len(output_sequences.sequences[0]) == 2 + + output_sequences = pipeline( + sequences=[self.prompt, self.prompt], num_generated_predictions=2 + ) + assert len(output_sequences.sequences) == 2 + for sequences in output_sequences.sequences: + assert len(sequences) == 2 + def _test_output( self, output: "TextGenerationOutput", # noqa F821 From 17e63ee362399a223835b645bdf94ef3574d743b Mon Sep 17 00:00:00 2001 From: Damian Date: Wed, 13 Sep 2023 09:05:36 +0000 Subject: [PATCH 9/9] quality --- .../deepsparse/transformers/pipelines/test_text_generation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index 9911ff4d8a..33e87328ff 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -77,9 +77,7 @@ def Fibonacci(n): ], scope="class", ) -@pytest.mark.skip( - reason="Those tests are too heavy to " "run as a normal part of the CI." -) +@pytest.mark.skip(reason="Those tests are too heavy to run as a normal part of the CI.") class TestTextGenerationPipeline: """ This test suite is meant to test the main scenarios of
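Note on the tokenizer handling introduced in PATCH 7/9: the patched `get_hugging_face_configs` prefers `tokenizer.json`, and when that file is absent it assumes an OPT-style model, reuses the `tokenizer_config.json` location, and triggers download of the extra vocabulary files listed in `_OPT_TOKENIZER_FILES`. The sketch below restates that decision logic over a plain local directory instead of the sparsezoo `zoo_model.deployment` API; the `resolve_tokenizer_dir` name, the directory layout, and the injected `download` callback are illustrative assumptions, not part of the patch.

```python
import os
from typing import Callable, List

# File names mirrored from the patch; _OPT_TOKENIZER_FILES is the fallback
# list that PATCH 7/9 introduces for OPT-style models.
TOKENIZER_NAME = "tokenizer.json"
TOKENIZER_CONFIG_NAME = "tokenizer_config.json"
OPT_TOKENIZER_FILES: List[str] = ["special_tokens_map.json", "vocab.json", "merges.txt"]


def resolve_tokenizer_dir(deployment_dir: str, download: Callable[[str], None]) -> str:
    """Hypothetical helper mirroring the patched tokenizer resolution.

    Prefer tokenizer.json; if it is missing, assume an OPT-style model,
    fall back to the tokenizer_config.json location, and fetch the extra
    vocabulary files through the supplied download callback.
    """
    tokenizer_file = os.path.join(deployment_dir, TOKENIZER_NAME)
    tokenizer_config_file = os.path.join(deployment_dir, TOKENIZER_CONFIG_NAME)

    if os.path.exists(tokenizer_file):
        return os.path.dirname(tokenizer_file)

    if not os.path.exists(tokenizer_config_file):
        raise ValueError(
            f"{deployment_dir} contains neither {TOKENIZER_NAME} "
            f"nor {TOKENIZER_CONFIG_NAME}"
        )

    # OPT-style fallback: the tokenizer is reconstructed from the config plus
    # the additional vocabulary/merges files, so make sure they are present.
    for file_name in OPT_TOKENIZER_FILES:
        download(file_name)
    return os.path.dirname(tokenizer_config_file)


# Example: resolving a local deployment directory without any real download.
# resolve_tokenizer_dir("/path/to/deployment", download=lambda name: None)
```

In the actual helper the download is triggered by accessing `.path` on each sparsezoo file object; the callback stands in for that side effect so the sketch stays runnable offline.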
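The test changes in PATCH 8/9 move generation length from pipeline construction (`max_generated_tokens`) to call time (`max_tokens`) and add a `num_generated_predictions` check. A minimal usage sketch of that calling convention follows, assuming the pipeline is built with `Pipeline.create(task="text_generation", ...)` as the suite's `get_pipeline` helper is understood to do; the prompt, sequence lengths, and token counts are placeholders.

```python
from deepsparse import Pipeline

# The OPT stub matches the new test parametrization; the sequence lengths,
# prompt, and token counts are illustrative placeholders.
pipeline = Pipeline.create(
    task="text_generation",
    model_path=(
        "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/"
        "opt_pretrain/base-none"
    ),
    sequence_length=256,
    prompt_sequence_length=16,
)

# Generation length is supplied per call (max_tokens) instead of at
# construction time (the removed max_generated_tokens argument).
output = pipeline(
    sequences=["def fibonacci(n):"],
    return_logits=True,
    include_prompt_logits=True,
    max_tokens=16,
)
print(output.sequences[0])

# num_generated_predictions requests several completions per prompt; the new
# test expects two generated sequences back for a single input prompt.
multi = pipeline(sequences=["def fibonacci(n):"], num_generated_predictions=2)
assert len(multi.sequences[0]) == 2
```

The final assertion mirrors the new `test_num_generated_predictions`: one prompt with `num_generated_predictions=2` is expected to come back as a single entry holding two generated sequences.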