Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Text Generation] Internal KV Cache Support + Initial Testing Framework #1163

Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 13 additions & 76 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@
from typing import Any, Dict, List, Optional, Tuple

import numpy
import onnx
from transformers import AutoTokenizer

from deepsparse.engine import Context
from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
from deepsparse.transformers.utils.helpers import generate_session_id
from deepsparse.transformers.utils.helpers import (
generate_session_id,
overwrite_onnx_model_inputs,
)
from deepsparse.utils.data import numpy_softmax
from deepsparse.utils.onnx import translate_onnx_type_to_numpy
from sparsezoo.utils.onnx import save_onnx


_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -71,7 +71,11 @@ def __init__(
# flag to indicate if the model is quantized or not
self.kv_cache_data_type = None

onnx_file_path, output_indices_to_be_cached = self.overwrite_onnx_model_inputs(
(
onnx_file_path,
output_indices_to_be_cached,
kv_cache_data_type,
) = overwrite_onnx_model_inputs(
onnx_file_path=onnx_file_path,
batch_size=engine_args.get("batch_size", 1),
sequence_length=sequence_length,
Expand All @@ -80,6 +84,7 @@ def __init__(
kv_cache_enabled = False
if sum(output_indices_to_be_cached):
kv_cache_enabled = True
self.kv_cache_data_type = kv_cache_data_type
if use_deepsparse_cache and engine_type == DEEPSPARSE_ENGINE:
# inform the engine, that are using the kv cache
engine_args["cache_output_bools"] = output_indices_to_be_cached
Expand Down Expand Up @@ -192,74 +197,6 @@ def transfer_cache_state(self, cache: DecoderKVCache):
cache_to_copy.set_capacity(target_cache_capacity)
self.kv_cache = cache_to_copy

def overwrite_onnx_model_inputs(
self,
onnx_file_path: str,
sequence_length: int,
input_ids_length: int,
batch_size: int = 1,
) -> Tuple[str, List[int]]:
"""
Enforces the appropriate input shapes for the onnx model, as well as
checks whether kv cache is enabled or not.

:param onnx_file_path: The path to the onnx model file that will be
overwritten with the new input shapes
:param batch_size: The batch size to use for the input
:param sequence_length: The sequence length to use for the input
:param input_ids_length: The length of input_ids
:return: The path to the onnx model file that has been overwritten
with the new input shapes, as well as the indices of the inputs
that should be cached
"""
model = onnx.load(onnx_file_path, load_external_data=False)
initializer_input_names = set(node.name for node in model.graph.initializer)
external_inputs = [
inp for inp in model.graph.input if inp.name not in initializer_input_names
]
for external_input in external_inputs:
# overwrite the batch size for all the inputs
external_input.type.tensor_type.shape.dim[0].dim_value = batch_size

if external_input.name in ["input_ids", "positions"]:
external_input.type.tensor_type.shape.dim[
1
].dim_value = input_ids_length
elif external_input.name == "attention_mask":
external_input.type.tensor_type.shape.dim[1].dim_value = sequence_length
elif external_input.name.startswith(_CACHE_INPUT_NAME):
external_input.type.tensor_type.shape.dim[2].dim_value = (
sequence_length - input_ids_length
)
elif external_input.name.startswith("causal_mask"):
external_input.type.tensor_type.shape.dim[
2
].dim_value = input_ids_length
external_input.type.tensor_type.shape.dim[3].dim_value = sequence_length
else:
raise ValueError(
f"Unexpected external input name: {external_input.name}"
)

_LOGGER.info(
"Overwriting in-place the input shapes "
f"of the transformer model at {onnx_file_path}"
)
save_onnx(model, onnx_file_path)

output_indices_to_be_cached = [
1 if inp.name.startswith("present") else 0 for inp in model.graph.output
]
if any(output_indices_to_be_cached):
kv_cache_elem_type = next(
inp
for inp in model.graph.input
if inp.name.startswith(_CACHE_INPUT_NAME)
).type.tensor_type.elem_type
self.kv_cache_data_type = translate_onnx_type_to_numpy(kv_cache_elem_type)

return onnx_file_path, output_indices_to_be_cached

def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
"""
Samples a token from the logits using the sampling temperature.
Expand All @@ -283,7 +220,7 @@ def reset_kv_cache(self):
kv_cache_state = self._initialize_kv_cache_state(
self.sequence_length - self.input_ids_length
)
self.kv_cache.setup_session(
self.kv_cache.setup(
session_id=self._session_id,
state=kv_cache_state,
num_processed_tokens=0,
Expand Down Expand Up @@ -328,7 +265,7 @@ def update_kv_cache(
name: array for name, array in zip(cache_onnx_names, kv_cache_state)
}

self.kv_cache.update_session(
self.kv_cache.update(
state=kv_cache_state,
input_ids_len=input_ids_len,
)
Expand Down Expand Up @@ -364,6 +301,6 @@ def _should_freeze_first_position(tokenizer) -> bool:
# (True if tokenizer has a prefix for a BOS token)
if tokenizer is None:
return False
if hasattr(tokenizer, "bos_token"):
if hasattr(tokenizer, "add_bos_token"):
return True
return False
34 changes: 16 additions & 18 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,9 +489,8 @@ def prompt_inference(
with self.timer_manager.current.time(
_TextGenerationTimings.PROMPT_PREFILL_SINGLE
):
new_token, new_logits = self.autoregressive_inference(
run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
)
new_token, new_logits = self.autoregressive_inference(run_tokens)

prompt_logits.append(new_logits)

tokens.append(new_token)
Expand All @@ -501,16 +500,12 @@ def prompt_inference(
def autoregressive_inference(
self,
tokens: List[int],
shift_positions_by_one: bool = False,
) -> Tuple[int, numpy.ndarray]:
"""
An inference run that processes the last token to generate
a new token and new logits.

:param tokens: The current context (prompt + generated tokens so far)
:param shift_positions_by_one: Whether to shift the positions
by one. Used if we are processing the prompt from the scratch
(i.e. not using the multitoken engine)
:return: The new, generated token and the logits for the new token
(with dimensions ['batch_size', 'num_tokens', 'vocab_size'])
"""
Expand All @@ -522,8 +517,7 @@ def autoregressive_inference(
num_tokens_processed = min(len(tokens), self.sequence_length) # cap by seq len
attention_mask[:, -num_tokens_processed:] = 1
positions = numpy.array([[len(tokens)]], dtype=numpy.int64)
if shift_positions_by_one:
positions -= 1
positions -= 1
input_ids = numpy.array([[new_token]])
causal_mask = create_causal_mask(input_ids, attention_mask)

Expand Down Expand Up @@ -580,28 +574,28 @@ def engine_inputs_for_prefill(
num_batches = len(tokens) // self.prompt_processing_sequence_length

token_batches = [
tokens[i : i + self.prompt_processing_sequence_length]
for i in range(num_batches)
tokens[
i
* self.prompt_processing_sequence_length : (i + 1)
* self.prompt_processing_sequence_length
]
for i in range(0, num_batches)
]

for idx, token_batch in enumerate(token_batches):
engine_inputs = []

num_cached_entries = self.multitoken_engine.num_non_blank_cache_entries
for name in self.multitoken_engine.onnx_input_names_no_cache:
if name == "input_ids":
engine_input = numpy.array([token_batch])

elif name == "attention_mask":
num_cached_entries = (
self.multitoken_engine.num_non_blank_cache_entries
)

# create an empty attention mask
engine_input = numpy.zeros(
(1, self.sequence_length), dtype=numpy.int64
)
# fill it out with 1s (from the right), so that the number
# of unmaksed entries is equal to the sum of:
# of unmasked entries is equal to the sum of:
engine_input[
:,
-(
Expand All @@ -621,7 +615,11 @@ def engine_inputs_for_prefill(
engine_input = numpy.array([[idx]], dtype=numpy.int64)
else:
engine_input = (
numpy.arange(self.prompt_processing_sequence_length)
numpy.arange(
num_cached_entries,
num_cached_entries
+ self.prompt_processing_sequence_length,
)
.reshape(1, -1)
.astype(numpy.int64)
)
Expand Down
10 changes: 5 additions & 5 deletions src/deepsparse/transformers/utils/decoder_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, use_deepsparse_cache: bool = False):
self._state = None
self._kv_cache = None

def setup_session(
def setup(
self,
session_id: str,
state: Dict[str, Any],
Expand Down Expand Up @@ -80,7 +80,7 @@ def setup_session(
if self._use_deepsparse_cache:
raise NotImplementedError("DeepSparse cache is not supported yet.")

def update_session(
def update(
self,
state: Dict[str, Any],
input_ids_len: int,
Expand Down Expand Up @@ -233,7 +233,7 @@ def _add_entries(
return state

@property
def session_id(self):
def id(self):
if self._session_id is None:
raise ValueError("Attempted to access session_id before setting up session")
return self._session_id
Expand All @@ -259,8 +259,8 @@ def capacity(self) -> int:
self._sequence_len_axis
]

@session_id.setter
def session_id(self, session_id: str):
@id.setter
def id(self, session_id: str):
self._session_id = session_id

@property
Expand Down
77 changes: 76 additions & 1 deletion src/deepsparse/transformers/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,93 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import uuid
from typing import List, Union
from typing import List, Tuple, Union

import numpy
import onnx

from deepsparse.utils.onnx import translate_onnx_type_to_numpy
from sparsezoo.utils import save_onnx


__all__ = [
"generate_session_id",
"pad_to_fixed_length",
"create_causal_mask",
"overwrite_onnx_model_inputs",
]

_LOGGER = logging.getLogger(__name__)


def overwrite_onnx_model_inputs(
onnx_file_path: str,
sequence_length: int,
input_ids_length: int,
batch_size: int = 1,
) -> Tuple[str, List[int]]:
"""
Enforces the appropriate input shapes for the onnx model, as well as
checks whether kv cache is enabled or not.

:param onnx_file_path: The path to the onnx model file that will be
overwritten with the new input shapes
:param batch_size: The batch size to use for the input
:param sequence_length: The sequence length to use for the input
:param input_ids_length: The length of input_ids
:return: A tuple that contains:
- the path to the onnx model file that has been overwritten
with the new input shapes
- boolean list, where elements are set to True if the
corresponding model output should be cached or False
if not.
- the data type of the kv cache. If the model does not
use kv cache, then the data type is None
"""
model = onnx.load(onnx_file_path, load_external_data=False)
initializer_input_names = set(node.name for node in model.graph.initializer)
external_inputs = [
inp for inp in model.graph.input if inp.name not in initializer_input_names
]
for external_input in external_inputs:
# overwrite the batch size for all the inputs
external_input.type.tensor_type.shape.dim[0].dim_value = batch_size

if external_input.name in ["input_ids", "positions"]:
external_input.type.tensor_type.shape.dim[1].dim_value = input_ids_length
elif external_input.name == "attention_mask":
external_input.type.tensor_type.shape.dim[1].dim_value = sequence_length
elif external_input.name.startswith("past_key_values"):
external_input.type.tensor_type.shape.dim[2].dim_value = (
sequence_length - input_ids_length
)
elif external_input.name.startswith("causal_mask"):
external_input.type.tensor_type.shape.dim[2].dim_value = input_ids_length
external_input.type.tensor_type.shape.dim[3].dim_value = sequence_length
else:
raise ValueError(f"Unexpected external input name: {external_input.name}")

_LOGGER.info(
"Overwriting in-place the input shapes "
f"of the transformer model at {onnx_file_path}"
)
save_onnx(model, onnx_file_path)

output_indices_to_be_cached = [
1 if inp.name.startswith("present") else 0 for inp in model.graph.output
]

kv_cache_data_type = None
if any(output_indices_to_be_cached):
kv_cache_elem_type = next(
inp for inp in model.graph.input if inp.name.startswith("past_key_values")
).type.tensor_type.elem_type
kv_cache_data_type = translate_onnx_type_to_numpy(kv_cache_elem_type)

return onnx_file_path, output_indices_to_be_cached, kv_cache_data_type


def generate_session_id() -> str:
"""
Expand Down
13 changes: 13 additions & 0 deletions tests/deepsparse/transformers/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading