diff --git a/src/deepsparse/transformers/helpers.py b/src/deepsparse/transformers/helpers.py index 0d1fb3971b..44f0cfc77f 100644 --- a/src/deepsparse/transformers/helpers.py +++ b/src/deepsparse/transformers/helpers.py @@ -49,6 +49,7 @@ _MODEL_DIR_CONFIG_NAME = "config.json" _MODEL_DIR_TOKENIZER_NAME = "tokenizer.json" _MODEL_DIR_TOKENIZER_CONFIG_NAME = "tokenizer_config.json" +_OPT_TOKENIZER_FILES = ["special_tokens_map.json", "vocab.json", "merges.txt"] def get_onnx_path(model_path: str) -> str: @@ -122,14 +123,29 @@ def get_hugging_face_configs(model_path: str) -> Tuple[str, str]: config_path = _get_file_parent( zoo_model.deployment.default.get_file(_MODEL_DIR_CONFIG_NAME).path ) - tokenizer_path = _get_file_parent( - zoo_model.deployment.default.get_file(_MODEL_DIR_TOKENIZER_NAME).path + tokenizer_file = zoo_model.deployment.default.get_file( + _MODEL_DIR_TOKENIZER_NAME ) - tokenizer_config_path = zoo_model.deployment.default.get_file( + + tokenizer_config_file = zoo_model.deployment.default.get_file( _MODEL_DIR_TOKENIZER_CONFIG_NAME ) - if tokenizer_config_path is not None: - tokenizer_config_path.path # trigger download of tokenizer_config + + if tokenizer_config_file is not None: + tokenizer_config_path = _get_file_parent( + tokenizer_config_file.path + ) # trigger download of tokenizer_config + + if tokenizer_file is not None: + tokenizer_path = _get_file_parent(tokenizer_file.path) + else: + # if tokenizer_file is not present, we assume it's the OPT model + # this means that we use tokenizer_config_path instead of tokenizer_path + # and need to download the additional tokenizer files + tokenizer_path = tokenizer_config_path + for file in _OPT_TOKENIZER_FILES: + zoo_model.deployment.default.get_file(file).path + else: raise ValueError( f"model_path {model_path} is not a valid directory or zoo stub" diff --git a/tests/deepsparse/transformers/pipelines/helpers.py b/tests/deepsparse/transformers/pipelines/helpers.py new file mode 100644 index 0000000000..0bb962a8e3 --- /dev/null +++ b/tests/deepsparse/transformers/pipelines/helpers.py @@ -0,0 +1,84 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple + +import numpy +from transformers import AutoModelForCausalLM, AutoTokenizer + + +class TorchGroundTruthSource: + """ + An object that generates ground truth logits and + cache states from a prompt. 
This object can + generate tokens in an autoregressive manner, and thus + will output: + - prompt logits, + - generated logits, + - prompt cache state, + - generated sequence + """ + + def __init__(self, num_tokens_to_generate: int, model_name: str): + + self.model = AutoModelForCausalLM.from_pretrained(model_name) + self.tokenizer = self._create_tokenizer(model_name) + + self.num_tokens_to_generate = num_tokens_to_generate + self.model_name = model_name + + def tokenize(self, prompt: str): + return self.tokenizer(prompt, return_tensors="pt") + + def __call__( + self, prompt: str + ) -> Tuple[numpy.ndarray, numpy.ndarray, List[numpy.ndarray], str]: + # afaik it is not possible to get 'past_key_values' from + # the generate method, so we have to run the model twice + out = self.model.generate( + self.tokenize(prompt).input_ids, + max_new_tokens=self.num_tokens_to_generate, + output_scores=True, + return_dict_in_generate=True, + use_cache=True, + ) + generated_text = self.tokenizer.decode( + out.sequences[0], skip_special_tokens=True + ) + generated_logits = numpy.concatenate( + [[score.numpy() for score in out.scores]] + ).transpose( + 1, 0, 2 + ) # (1, num_tokens_to_generate, vocab_size) + + out = self.model(**self.tokenize(prompt)) + prompt_logits = out.logits.detach().numpy()[ + :, :-1, : + ] # (1, prompt_length, vocab_size) + prompt_cache = [ + entry.detach().numpy() + for key_value_tuple in out.past_key_values + for entry in key_value_tuple + ] # List[(1, num_heads, past_length, head_dim)] + + return generated_logits, prompt_logits, prompt_cache, generated_text + + @staticmethod + def _create_tokenizer(model_name): + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + return tokenizer diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index cf2df33314..33e87328ff 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -12,47 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import List, Optional, Tuple -import numpy as np -import onnx -import onnxruntime -from transformers import AutoModelForCausalLM, AutoTokenizer +import numpy import pytest from deepsparse import Pipeline -from deepsparse.transformers.utils.helpers import create_causal_mask -from deepsparse.utils.onnx import ( - CACHE_INPUT_PREFIX, - overwrite_onnx_model_inputs_for_kv_cache_models, -) -from sparsezoo import Model - - -def _initialize_kv_cache_state(model, length=0): - # get one of the cache inputs - cache_input = next( - input - for input in model.graph.input - if input.name.startswith(CACHE_INPUT_PREFIX) - ) - # read the shape of the cache input - batch_size = cache_input.type.tensor_type.shape.dim[0].dim_value - num_attention_heads = cache_input.type.tensor_type.shape.dim[1].dim_value - hidden_dims = cache_input.type.tensor_type.shape.dim[3].dim_value - - # create a kv cache dictionary - kv_cache = { - input_.name: np.zeros( - (batch_size, num_attention_heads, length, hidden_dims), dtype=np.float32 - ) - for input_ in model.graph.input - if input_.name.startswith(CACHE_INPUT_PREFIX) - } - - return kv_cache - - -START = 0 # global variable for dummy_callback +from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache +from tests.deepsparse.transformers.pipelines.helpers import TorchGroundTruthSource + + +_PRECISION = 1e-3 + +NATURAL_LANGUAGE_PROMPT = """ +Didn't know what time it was, the lights were low +I leaned back on my radio +Some cat was layin' down some rock 'n' roll +"Lotta soul," he said +Then the loud sound did seem to fade +Came back like a slow voice on a wave of phase +That weren't no DJ, that was hazy cosmic jive +""" + +CODE_LANGUAGE_PROMPT = """ +def Fibonacci(n): + # Check if input is 0 then it will + # print incorrect input + if n < 0: + print("Incorrect input") + # Check if n is 0 + # then it will return 0 + elif n == 0: + return 0 +""" @pytest.mark.parametrize( @@ -60,233 +52,431 @@ def _initialize_kv_cache_state(model, length=0): [True, False], ) @pytest.mark.parametrize( - "model_stub, model_name, uses_bos_token", + "model_stub, " + "model_name, " + "uses_bos_token, " + "prompt, " + "logits_max_diff_kv_cache_has_been_filled", [ - ( - "zoo:nlg/text_generation/opt-1.3b/pytorch/" - "huggingface/opt_pretrain/base-none", - "facebook/opt-1.3b", - True, - ), ( "zoo:nlg/text_generation/codegen_mono-350m/pytorch/" "huggingface/bigpython_bigquery_thepile/base-none", "salesforce/codegen-350m-mono", False, + CODE_LANGUAGE_PROMPT, + 13, + ), + ( + "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/" + "opt_pretrain/base-none", + "facebook/opt-1.3b", + True, + NATURAL_LANGUAGE_PROMPT, + 3.9, ), ], scope="class", ) -@pytest.mark.skip( - reason="Those tests are too heavy to " "run as a normal part of the CI." -) +@pytest.mark.skip(reason="Those tests are too heavy to run as a normal part of the CI.") class TestTextGenerationPipeline: + """ + This test suite is meant to test the main scenarios of + the text generation pipeline. 
+ """ + + def get_pipeline(self, **kwargs): + if not kwargs: + # return the default pipeline + if self.default_pipeline: + return self.default_pipeline + else: + self.default_pipeline = Pipeline.create( + task="text_generation", + model_path=self.model_stub, + internal_kv_cache=self.internal_kv_cache, + prompt_sequence_length=self.prompt_sequence_length, + sequence_length=self.sequence_length, + force_max_tokens=True, + ) + return self.default_pipeline + # return a pipeline with the given kwargs + return Pipeline.create(**kwargs) + @pytest.fixture - def setup(self, model_stub, model_name, uses_bos_token, internal_kv_cache): + def setup( + self, + model_stub, + model_name, + uses_bos_token, + prompt, + logits_max_diff_kv_cache_has_been_filled, + internal_kv_cache, + ): + self.num_tokens_generate = 216 + self.model_stub = model_stub + self.prompt = prompt + # create torch ground source + torch_source = TorchGroundTruthSource( + num_tokens_to_generate=self.num_tokens_generate, model_name=model_name + ) + torch_ground_truth = torch_source(self.prompt) + + # prompt length is expressed in number of prompt tokens + prompt_length = torch_ground_truth[1].shape[1] + + # sequence_length that assures that the KV cache will not be filled up + self.sequence_length = 2 * prompt_length + self.num_tokens_generate + # sequence_length that assures that the KV cache will be filled up + self.sequence_length_short = self.num_tokens_generate - self.max_generated_tokens = 16 - self.model = Model(model_stub) + # prompt_sequence_length used for the multitoken prefill scenario + self.prompt_sequence_length = prompt_length // 2 + + # the maximum threshold for the difference between the logits + # when running a scenario where KV Cache buffer has been filled + self.logits_max_diff_kv_cache_has_been_filled = ( + logits_max_diff_kv_cache_has_been_filled + ) self.internal_kv_cache = internal_kv_cache - pipeline = Pipeline.create( - task="text_generation", - model_path=model_stub, - sequence_length=32, - prompt_sequence_length=4, - internal_kv_cache=self.internal_kv_cache, + self.default_pipeline = None + + assert self.prompt_sequence_length < prompt_length, ( + "The prompt processing sequence length " + "must be smaller than the prompt length" ) - short_prompt = "this" - long_prompt = "this is a sample prompt that we will use to test the pipeline" - - # make sure that the short prompt will be only - # processed by a single token engine - # (DISABLED FOR NOW UNTIL WE HAVE ZOO CAUSAL MASK SUPPORT) - # assert ( - # len(pipeline.tokenizer.tokenize(short_prompt)) + int(uses_bos_token) - # < pipeline.prompt_sequence_length - # ) - # make sure that the long prompt will be processed by - # single token and multiple token engines - # (DISABLED FOR NOW UNTIL WE HAVE ZOO CAUSAL MASK SUPPORT) - # assert ( - # len(pipeline.tokenizer.tokenize(long_prompt)) + int(uses_bos_token) - # > pipeline.prompt_sequence_length * 3 - # ) - - yield pipeline, model_name, uses_bos_token, short_prompt, long_prompt - - def test_freeze(self, setup): - # test whether we should be "freezing" the first token after + + yield model_name, uses_bos_token, torch_ground_truth + + def test_freeze_first_position(self, setup): + # Test whether we should be "freezing" the first token after # the kv cache is full - pipeline, _, uses_bos_token, _, _ = setup + _, uses_bos_token, _ = setup + pipeline = self.get_pipeline() assert pipeline.engine._freeze_first_position == uses_bos_token - def test_model_output_sequences(self, setup): - # test model output against sources 
of truth - pipeline, model_name, _, short_prompt, long_prompt = setup + def test_ort_single_token_prefill(self, setup): + # Test the pipeline that uses ORT engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by single-token engine + # 2. The KV Cache is never filled up + # 3. KV Cache managed externally - output_sequences = pipeline( - sequences=[short_prompt, long_prompt], max_tokens=self.max_generated_tokens + if self.internal_kv_cache: + pytest.skip( + "Cannot run ORT pipeline with the internal deepsparse cache enabled." + ) + _, _, torch_ground_truth = setup + pipeline = self.get_pipeline( + task="text_generation", + model_path=self.model_stub, + sequence_length=self.sequence_length, + prompt_sequence_length=1, + force_max_tokens=True, + engine_type="onnxruntime", ) - - # test against huggingface model - output_hugging_face = self._get_output_huggingface( - sequences=[short_prompt, long_prompt], model_name=model_name + output = pipeline( + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, ) - assert short_prompt + output_sequences.sequences[0] == output_hugging_face[0] - assert long_prompt + output_sequences.sequences[1] == output_hugging_face[1] - - def test_num_generated_predictions(self, setup): - pipeline = setup[0] - short_prompt = setup[3] - - output_sequences = pipeline( - sequences=[short_prompt], num_generated_predictions=2 + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens < self.sequence_length + self._test_output( + output=output, + cache_session=cache_session, + torch_ground_truth=torch_ground_truth, ) - assert len(output_sequences.sequences[0]) == 2 + def test_ort_multi_token_prefill(self, setup): + # Test the pipeline that uses ORT engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by multi-token engine + # 2. The KV Cache is never filled up + # 3. KV Cache managed externally - output_sequences = pipeline( - sequences=[short_prompt, short_prompt], num_generated_predictions=2 + if self.internal_kv_cache: + pytest.skip( + "Cannot run ORT pipeline with the internal deepsparse cache enabled." + ) + _, _, torch_ground_truth = setup + pipeline = self.get_pipeline( + task="text_generation", + model_path=self.model_stub, + sequence_length=self.sequence_length, + prompt_sequence_length=self.prompt_sequence_length, + force_max_tokens=True, + engine_type="onnxruntime", + ) + output = pipeline( + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, + ) + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens < self.sequence_length + self._test_output( + output=output, + cache_session=cache_session, + torch_ground_truth=torch_ground_truth, ) - assert len(output_sequences.sequences) == 2 - for sequences in output_sequences.sequences: - assert len(sequences) == 2 - def test_model_output_cache(self, setup): - pipeline, model_name, _, short_prompt, long_prompt = setup + def test_ort_generation_after_kv_cache_has_been_filled(self, setup): + # Test the pipeline that uses ORT engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by multi-token engine + # 2. The KV Cache is filled up (old entries are removed) + # 3. 
KV Cache managed externally + if self.internal_kv_cache: pytest.skip( - "Running pipeline with internal " - "deepsparse cache will not result " - "in meaningful cache entries." + "Cannot run ORT pipeline with the internal deepsparse cache enabled." ) - self._test_cache_state(short_prompt, pipeline, model_name) - self._test_cache_state(long_prompt, pipeline, model_name) - - def test_callback(self, setup): - pipeline, *_ = setup - - def dummy_callback(token): - global START - START += 1 - return START < 3 - - inputs = { - "sequences": "def fib(a, b, accumulator=0)", - "callback": dummy_callback, - "return_logits": True, - "max_tokens": self.max_generated_tokens, - } - - outs = pipeline(**inputs) - assert outs.logits.shape[1] == 3 - - def _test_cache_state(self, prompt, pipeline, model_name): - # make sure that the cache state after running a prompt - # is correct - - pipeline(sequences=prompt, max_tokens=self.max_generated_tokens) - cache_state_dict = pipeline.engine.kv_cache.cached_inputs - cache_state_list = [cache_state_dict[key] for key in cache_state_dict.keys()] - - # generate ground truth from ORT - target_cache_state = self._get_cache_state_ort_kv_cache( - model_onnx_path=self.model.deployment.get_file("model.onnx").path, - sequence=prompt, - model_name=model_name, + _, _, torch_ground_truth = setup + pipeline = self.get_pipeline( + task="text_generation", + model_path=self.model_stub, + sequence_length=self.sequence_length_short, + prompt_sequence_length=self.prompt_sequence_length, + force_max_tokens=True, + engine_type="onnxruntime", + ) + output = pipeline( + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, ) - # get the number of processed prompt tokens - num_prompt_tokens = len(pipeline.tokenizer.tokenize(prompt)) + int( - pipeline.engine._freeze_first_position + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( + "for this scenario, the kv cache should be full: " + "the total number of processed tokens should be " + "greater than the sequence length" ) - for x, y in zip(cache_state_list, target_cache_state): - """ - x will be a cache array - [blank, blank, ..., prompt_cache_1, prompt_cache_2, ..., - gen_token_cache_1, gen_token_cache_2, ...] - we need to first remove blank entries and then keep the - remaining prompt_cache entries (remove gen_token_cache entries) - """ - first_non_blank_cache_entry = min( - i for i in range(x.shape[2]) if np.count_nonzero(x[:, :, i, :]) - ) - x = x[:, :, first_non_blank_cache_entry:, :] - x = x[:, :, :num_prompt_tokens, :] - - """ - y will be a cache array - [blank, blank, ..., prompt_cache_1, prompt_cache_2, ...] 
- we need to keep the prompt_cache entries only - """ - y = y[:, :, -num_prompt_tokens:, :] - - assert np.allclose(x, y, atol=1e-4) - - def _get_output_huggingface(self, sequences, model_name): - hf_outputs = [] - # setup tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_name) - tokenizer.padding_side = "left" - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - # setup model - model = AutoModelForCausalLM.from_pretrained(model_name) - - # generate ground truth output - for prompt in sequences: - input_ids = tokenizer(prompt, return_tensors="pt").input_ids - generated_ids = model.generate( - input_ids, max_new_tokens=self.max_generated_tokens - ) - hf_output = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - hf_outputs.append(hf_output) - return hf_outputs + self._test_output( + output=output, + cache_session=cache_session, + torch_ground_truth=torch_ground_truth, + max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 + ) - @staticmethod - def _get_cache_state_ort_kv_cache(model_onnx_path, sequence, model_name): - # setup tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_name) - tokenizer.padding_side = "left" - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - # setup model and session - # (run full sequence inference) - overwrite_onnx_model_inputs_for_kv_cache_models( - model_onnx_path, sequence_length=128, input_ids_length=128 + def test_deepsparse_single_token_prefill(self, setup): + # Test the pipeline that uses deepsparse engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by single-token engine + # 2. The KV Cache is never filled up + # 3. KV Cache managed externally or internally + + _, _, torch_ground_truth = setup + pipeline = self.get_pipeline( + task="text_generation", + model_path=self.model_stub, + sequence_length=self.sequence_length, + prompt_sequence_length=1, + force_max_tokens=True, + internal_kv_cache=self.internal_kv_cache, + ) + output = pipeline( + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, + ) + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens < self.sequence_length + self._test_output( + output=output, + cache_session=cache_session, + torch_ground_truth=torch_ground_truth, + run_cache_validation=not self.internal_kv_cache, ) - sess = onnxruntime.InferenceSession(model_onnx_path) - # get model inputs - onnx_model = onnx.load(model_onnx_path, load_external_data=False) - model_inputs = [x.name for x in onnx_model.graph.input] - kv_cache = _initialize_kv_cache_state(model=onnx_model) + def test_deepsparse_multi_token_prefill(self, setup): + # Test the pipeline that uses deepsparse engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by multi-token engine + # 2. The KV Cache is never filled up + # 3. 
KV Cache managed externally or internally - inputs = tokenizer( - sequence, return_tensors="np", padding="max_length", max_length=128 + _, _, torch_ground_truth = setup + pipeline = self.get_pipeline( + task="text_generation", + model_path=self.model_stub, + sequence_length=self.sequence_length, + prompt_sequence_length=self.prompt_sequence_length, + force_max_tokens=True, + internal_kv_cache=self.internal_kv_cache, + ) + output = pipeline( + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, ) - onnxruntime_inputs = dict( - attention_mask=inputs["attention_mask"], - input_ids=inputs["input_ids"], - **kv_cache, + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens < self.sequence_length + self._test_output( + output=output, + cache_session=cache_session, + torch_ground_truth=torch_ground_truth, + run_cache_validation=not self.internal_kv_cache, ) - if "positions" in model_inputs: - attention_mask = inputs["attention_mask"] - positions = attention_mask.cumsum(1) * attention_mask - 1 - onnxruntime_inputs["positions"] = positions + def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): + # Test the pipeline that uses deepsparse engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by multi-token engine + # 2. The KV Cache is filled up (old entries are removed) + # 3. KV Cache managed externally or internally - if "causal_mask" in model_inputs: - causal_mask = create_causal_mask( - input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"] - ) - onnxruntime_inputs["causal_mask"] = causal_mask + _, _, torch_ground_truth = setup + pipeline = self.get_pipeline( + task="text_generation", + model_path=self.model_stub, + sequence_length=self.sequence_length_short, + prompt_sequence_length=self.prompt_sequence_length, + force_max_tokens=True, + internal_kv_cache=self.internal_kv_cache, + ) + output = pipeline( + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, + ) + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( + "for this scenario, the kv cache should be full: " + "the total number of processed tokens should be " + "greater than the sequence length" + ) + + self._test_output( + output=output, + cache_session=cache_session, + torch_ground_truth=torch_ground_truth, + run_cache_validation=not self.internal_kv_cache, + max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 + ) - # run inference and return the cache state - outputs = sess.run(None, onnxruntime_inputs) - logits, *kv_cache = outputs + def test_run_same_prompt_multiple_times(self, setup): + # Test the scenario, where the same prompt is run multiple times + # Every run should produce the same output + pipeline = self.get_pipeline() - return kv_cache + output_1 = pipeline( + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, + ) + output_2 = pipeline( + sequences=self.prompt, + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, + ) + assert output_1.sequences[0] == output_2.sequences[0] + assert numpy.allclose(output_1.logits, output_2.logits, atol=_PRECISION) + + def test_run_multiple_prompts_in_parallel(self, setup): + # Test the scenario, where multiple prompts are run in parallel + # Same two 
prompts should produce the same output + pipeline = self.get_pipeline() + + output = pipeline( + sequences=[self.prompt, self.prompt], + return_logits=True, + include_prompt_logits=True, + max_tokens=self.num_tokens_generate, + ) + + assert numpy.allclose(output.logits[0], output.logits[1], atol=_PRECISION) + assert output.sequences[0] == output.sequences[1] + + def test_num_generated_predictions(self, setup): + # Test the scenario, where multiple predictions are generated + # from the same prompt + pipeline = self.get_pipeline() + + output_sequences = pipeline( + sequences=[self.prompt], num_generated_predictions=2 + ) + assert len(output_sequences.sequences[0]) == 2 + + output_sequences = pipeline( + sequences=[self.prompt, self.prompt], num_generated_predictions=2 + ) + assert len(output_sequences.sequences) == 2 + for sequences in output_sequences.sequences: + assert len(sequences) == 2 + + def _test_output( + self, + output: "TextGenerationOutput", # noqa F821 + cache_session: DecoderKVCache, + torch_ground_truth: Tuple[numpy.ndarray, ...], + max_logits_difference_threshold: Optional[float] = None, + run_cache_validation: bool = True, + ): + # extract numpy arrays from cached_inputs + kv_cache_array = list(cache_session.cached_inputs.values()) + + ( + generated_logits, + prompt_logits, + prompt_kv_cache, + generated_text, + ) = torch_ground_truth + + # concatenate target prompt_logits and generated_logits and check + target_logits = numpy.concatenate([prompt_logits, generated_logits], axis=1) + + if max_logits_difference_threshold: + # if comparing the output from the model where + # the kv cache has been filled, we expect the + # maximum absolute difference between the logits + # to be less than the threshold + # (the threshold is established by running the + # ONNX model in ONNXRuntime) + assert ( + abs(output.logits - target_logits).max() + < max_logits_difference_threshold + ) + else: + # otherwise, we expect the logits to be exactly the same + # as the target logits; the generated sequence should + # also be the same as the target sequence, and finally + # (if applicable) the kv cache should be the same as the + # target kv cache + assert numpy.allclose(output.logits, target_logits, atol=_PRECISION) + assert self.prompt + output.sequences[0] == generated_text + + if run_cache_validation: + self._test_kv_cache_state( + expected_cache=kv_cache_array, + target_cache=torch_ground_truth[2], + total_num_processed_tokens=cache_session.total_num_processed_tokens, + ) + + @staticmethod + def _test_kv_cache_state( + expected_cache: List[numpy.ndarray], + target_cache: List[numpy.ndarray], + total_num_processed_tokens: int, + ): + for x, y in zip(expected_cache, target_cache): + start_index = total_num_processed_tokens + end_index = total_num_processed_tokens - y.shape[2] + # x is (in general) composed of three arrays: + # - padding cache entries (from 0 to -start_index) + # - prompt cache entries (from -start_index to -end_index) + # - generated cache entries (from -end_index to -1) + # as target_cache only pertains to prompt cache entries, we need to + # compare only the prompt cache entries in x with y + assert numpy.allclose( + x[:, :, -start_index:-end_index, :], y, atol=_PRECISION + )
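
A minimal, self-contained sketch of the cache-buffer layout that `_test_kv_cache_state` relies on, since the negative indexing is easy to misread: along the sequence axis a cache entry is assumed to be laid out as [padding | prompt entries | generated entries], and the prompt region is recovered with the `-start_index:-end_index` slice. All shapes and values below are made up for illustration and are not part of this PR; the sketch also assumes at least one token has been generated, so `end_index` is non-zero.

    import numpy

    num_heads, head_dim = 2, 4
    sequence_length = 16        # size of the preallocated cache buffer
    prompt_length = 5           # number of prompt cache entries
    num_generated = 3           # cache entries added while generating
    total_num_processed_tokens = prompt_length + num_generated

    # fake torch ground-truth prompt cache: (1, num_heads, prompt_length, head_dim)
    prompt_cache = numpy.random.rand(1, num_heads, prompt_length, head_dim)

    # fake pipeline cache buffer: zero padding, then prompt entries, then generated entries
    buffer = numpy.zeros((1, num_heads, sequence_length, head_dim))
    buffer[:, :, -total_num_processed_tokens:-num_generated, :] = prompt_cache
    buffer[:, :, -num_generated:, :] = numpy.random.rand(1, num_heads, num_generated, head_dim)

    # the same slicing performed by _test_kv_cache_state: keep only the prompt region
    start_index = total_num_processed_tokens                 # 8
    end_index = total_num_processed_tokens - prompt_length   # 3
    assert numpy.allclose(buffer[:, :, -start_index:-end_index, :], prompt_cache)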
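
For context on how the new TorchGroundTruthSource helper is meant to be consumed (the same way the setup fixture uses it), here is a hedged usage sketch. The prompt and token count are arbitrary, the model name is simply the one already used in the test parametrization, running it requires downloading the Hugging Face model, and the shape comments restate the annotations inside the helper rather than adding new guarantees.

    from tests.deepsparse.transformers.pipelines.helpers import TorchGroundTruthSource

    source = TorchGroundTruthSource(
        num_tokens_to_generate=8, model_name="salesforce/codegen-350m-mono"
    )
    generated_logits, prompt_logits, prompt_cache, generated_text = source("def hello():")

    print(generated_logits.shape)  # (1, num_tokens_to_generate, vocab_size)
    print(prompt_logits.shape)     # (1, prompt_length, vocab_size)
    print(len(prompt_cache), prompt_cache[0].shape)  # one (1, num_heads, past_length, head_dim) array per cached key/value
    print(generated_text)          # prompt followed by the generated continuation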