[Text Generation] Internal KV Cache Support + Initial Testing Framework #1163

Merged
3 changes: 3 additions & 0 deletions src/deepsparse/engine.py
@@ -300,6 +300,7 @@ def __init__(
num_streams: int = None,
scheduler: Scheduler = None,
input_shapes: List[List[int]] = None,
cached_outputs: List[bool] = None,
):
BaseEngine.construct(
self, model, batch_size, num_cores, num_streams, scheduler, input_shapes
@@ -316,6 +317,7 @@ def __init__(
self._num_streams,
self._scheduler.value,
None,
cached_outputs,
)
else:
self._eng_net = LIB.deepsparse_engine(
@@ -325,6 +327,7 @@
self._num_streams,
self._scheduler.value,
None,
cached_outputs,
)

def __call__(
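As a usage sketch only: the new cached_outputs argument marks which engine outputs should be retained internally between calls. The model path and output layout below are illustrative assumptions, not values taken from this PR.

from deepsparse import Engine

# Hypothetical decoder model: output 0 is the logits tensor, the remaining
# outputs are present.* kv cache tensors the engine should manage internally.
cached_outputs = [False] + [True] * 48  # assumed output layout

engine = Engine(
    model="deployment/model.onnx",  # hypothetical path
    batch_size=1,
    cached_outputs=cached_outputs,
)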
126 changes: 44 additions & 82 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -11,21 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from typing import Any, Dict, List, Optional, Tuple

import numpy
import onnx
from transformers import AutoTokenizer

from deepsparse.engine import Context
from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
from deepsparse.transformers.utils.helpers import generate_session_id
from deepsparse.transformers.utils.helpers import (
generate_session_id,
overwrite_onnx_model_inputs,
)
from deepsparse.utils.data import numpy_softmax
from deepsparse.utils.onnx import translate_onnx_type_to_numpy
from sparsezoo.utils.onnx import save_onnx


_LOGGER = logging.getLogger(__name__)
@@ -71,7 +70,11 @@ def __init__(
# flag to indicate if the model is quantized or not
self.kv_cache_data_type = None

onnx_file_path, output_indices_to_be_cached = self.overwrite_onnx_model_inputs(
(
onnx_file_path,
output_indices_to_be_cached,
kv_cache_data_type,
) = overwrite_onnx_model_inputs(
onnx_file_path=onnx_file_path,
batch_size=engine_args.get("batch_size", 1),
sequence_length=sequence_length,
@@ -80,9 +83,10 @@ def __init__(
kv_cache_enabled = False
if sum(output_indices_to_be_cached):
kv_cache_enabled = True
self.kv_cache_data_type = kv_cache_data_type
if use_deepsparse_cache and engine_type == DEEPSPARSE_ENGINE:
# inform the engine that we are using the kv cache
engine_args["cache_output_bools"] = output_indices_to_be_cached
engine_args["cached_outputs"] = output_indices_to_be_cached

self.engine = create_engine(
onnx_file_path=onnx_file_path,
@@ -100,6 +104,7 @@ def __init__(
)
self._freeze_first_position = self._should_freeze_first_position(tokenizer)
self._session_id = generate_session_id()
self._engine_type = engine_type

@property
def session_id(self) -> str:
@@ -135,6 +140,32 @@ def num_non_blank_cache_entries(self) -> int:
"""
return self.kv_cache.num_non_blank_entries

def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]:
"""
Run the engine with the given inputs.

If the internal deepsparse kv cache management is enabled,
the LIB.kv_cache class object will be passed to the engine
call as well.

:param inputs: The inputs to run the engine with
:param val_inp: Whether to validate the inputs before running the engine

:return: The output of the engine
"""

if self.kv_cache is not None:
if self.kv_cache._kv_cache is not None:
if val_inp:
self.engine._validate_inputs(inputs)
# model has kv cache support, as well as deepsparse
# internal management of the kv cache
return self.engine._eng_net.execute_list_out(
inputs, self.kv_cache._kv_cache
)

return self.engine.run(inputs, val_inp)

def __call__(
self,
inp: List[numpy.ndarray],
@@ -154,7 +185,7 @@ def __call__(
# to the input
inp = self.add_kv_cache_to_input(inp)

out = self.engine.run(inp, val_inp)
out = self.run(inp, val_inp)

if self.kv_cache:
logits, *kv_cache_state = out
@@ -187,78 +218,9 @@ def transfer_cache_state(self, cache: DecoderKVCache):
:param cache: The `DecoderKVCache` object to transfer to the engine
from
"""
cache_to_copy = copy.deepcopy(cache)
target_cache_capacity = self.sequence_length - self.input_ids_length
cache_to_copy.set_capacity(target_cache_capacity)
self.kv_cache = cache_to_copy

def overwrite_onnx_model_inputs(
self,
onnx_file_path: str,
sequence_length: int,
input_ids_length: int,
batch_size: int = 1,
) -> Tuple[str, List[int]]:
"""
Enforces the appropriate input shapes for the onnx model, as well as
checks whether kv cache is enabled or not.

:param onnx_file_path: The path to the onnx model file that will be
overwritten with the new input shapes
:param batch_size: The batch size to use for the input
:param sequence_length: The sequence length to use for the input
:param input_ids_length: The length of input_ids
:return: The path to the onnx model file that has been overwritten
with the new input shapes, as well as the indices of the inputs
that should be cached
"""
model = onnx.load(onnx_file_path, load_external_data=False)
initializer_input_names = set(node.name for node in model.graph.initializer)
external_inputs = [
inp for inp in model.graph.input if inp.name not in initializer_input_names
]
for external_input in external_inputs:
# overwrite the batch size for all the inputs
external_input.type.tensor_type.shape.dim[0].dim_value = batch_size

if external_input.name in ["input_ids", "positions"]:
external_input.type.tensor_type.shape.dim[
1
].dim_value = input_ids_length
elif external_input.name == "attention_mask":
external_input.type.tensor_type.shape.dim[1].dim_value = sequence_length
elif external_input.name.startswith(_CACHE_INPUT_NAME):
external_input.type.tensor_type.shape.dim[2].dim_value = (
sequence_length - input_ids_length
)
elif external_input.name.startswith("causal_mask"):
external_input.type.tensor_type.shape.dim[
2
].dim_value = input_ids_length
external_input.type.tensor_type.shape.dim[3].dim_value = sequence_length
else:
raise ValueError(
f"Unexpected external input name: {external_input.name}"
)

_LOGGER.info(
"Overwriting in-place the input shapes "
f"of the transformer model at {onnx_file_path}"
)
save_onnx(model, onnx_file_path)

output_indices_to_be_cached = [
1 if inp.name.startswith("present") else 0 for inp in model.graph.output
]
if any(output_indices_to_be_cached):
kv_cache_elem_type = next(
inp
for inp in model.graph.input
if inp.name.startswith(_CACHE_INPUT_NAME)
).type.tensor_type.elem_type
self.kv_cache_data_type = translate_onnx_type_to_numpy(kv_cache_elem_type)

return onnx_file_path, output_indices_to_be_cached
cache.set_capacity(target_cache_capacity)
self.kv_cache = cache

def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
"""
@@ -283,7 +245,7 @@ def reset_kv_cache(self):
kv_cache_state = self._initialize_kv_cache_state(
self.sequence_length - self.input_ids_length
)
self.kv_cache.setup_session(
self.kv_cache.setup(
session_id=self._session_id,
state=kv_cache_state,
num_processed_tokens=0,
@@ -328,7 +290,7 @@ def update_kv_cache(
name: array for name, array in zip(cache_onnx_names, kv_cache_state)
}

self.kv_cache.update_session(
self.kv_cache.update(
state=kv_cache_state,
input_ids_len=input_ids_len,
)
@@ -364,6 +326,6 @@ def _should_freeze_first_position(tokenizer) -> bool:
# (True if tokenizer has a prefix for a BOS token)
if tokenizer is None:
return False
if hasattr(tokenizer, "bos_token"):
if hasattr(tokenizer, "add_bos_token"):
return True
return False
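The input-reshaping logic deleted above now lives in deepsparse.transformers.utils.helpers.overwrite_onnx_model_inputs and, as the new unpacking in __init__ shows, additionally returns the kv cache data type. A rough sketch of how the engine consumes it; the path and shape values here are illustrative assumptions:

from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs

onnx_file_path, output_indices_to_be_cached, kv_cache_data_type = (
    overwrite_onnx_model_inputs(
        onnx_file_path="deployment/model.onnx",  # hypothetical path
        batch_size=1,
        sequence_length=128,
        input_ids_length=1,
    )
)

# any present.* outputs flagged for caching means the model supports a kv cache
kv_cache_enabled = bool(sum(output_indices_to_be_cached))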
47 changes: 19 additions & 28 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -24,7 +24,6 @@
from transformers import TextStreamer

from deepsparse import Pipeline
from deepsparse.cpu import cpu_avx512_compatible
from deepsparse.pipeline import DEEPSPARSE_ENGINE
from deepsparse.transformers.engines import NLDecoderEngine
from deepsparse.transformers.pipelines import TransformersPipeline
@@ -146,22 +145,16 @@ def __init__(
**kwargs,
):
kwargs_engine_type = kwargs.get("engine_type", DEEPSPARSE_ENGINE)
if not cpu_avx512_compatible() and kwargs_engine_type == DEEPSPARSE_ENGINE:
warnings.warn(
"AVX512 support not detected, disabling internal management "
"of KV cache which may affect performance. To enable full "
"performance, deploy on an AVX512-compatible system."
)
use_deepsparse_cache = False

if use_deepsparse_cache:
if kwargs_engine_type != DEEPSPARSE_ENGINE:
raise ValueError(
_LOGGER.warning(
"`use_deepsparse_cache` is set to True "
"but the chosen `engine_type` "
f"is {kwargs_engine_type}. "
f"Make sure to set `engine_type` to {DEEPSPARSE_ENGINE}"
f"The optimized kv cache management is disabled."
)
use_deepsparse_cache = False

super().__init__(
**kwargs, _delay_engine_initialize=True, _delay_overwriting_inputs=True
@@ -493,9 +486,8 @@ def prompt_inference(
with self.timer_manager.current.time(
_TextGenerationTimings.PROMPT_PREFILL_SINGLE
):
new_token, new_logits = self.autoregressive_inference(
run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
)
new_token, new_logits = self.autoregressive_inference(run_tokens)

prompt_logits.append(new_logits)

tokens.append(new_token)
@@ -505,16 +497,12 @@ def autoregressive_inference(
def autoregressive_inference(
self,
tokens: List[int],
shift_positions_by_one: bool = False,
) -> Tuple[int, numpy.ndarray]:
"""
An inference run that processes the last token to generate
a new token and new logits.

:param tokens: The current context (prompt + generated tokens so far)
:param shift_positions_by_one: Whether to shift the positions
by one. Used if we are processing the prompt from the scratch
(i.e. not using the multitoken engine)
:return: The new, generated token and the logits for the new token
(with dimensions ['batch_size', 'num_tokens', 'vocab_size'])
"""
@@ -526,8 +514,7 @@ def autoregressive_inference(
num_tokens_processed = min(len(tokens), self.sequence_length) # cap by seq len
attention_mask[:, -num_tokens_processed:] = 1
positions = numpy.array([[len(tokens)]], dtype=numpy.int64)
if shift_positions_by_one:
positions -= 1
positions -= 1
input_ids = numpy.array([[new_token]])
causal_mask = create_causal_mask(input_ids, attention_mask)

@@ -584,28 +571,28 @@ def engine_inputs_for_prefill(
num_batches = len(tokens) // self.prompt_processing_sequence_length

token_batches = [
tokens[i : i + self.prompt_processing_sequence_length]
for i in range(num_batches)
tokens[
i
* self.prompt_processing_sequence_length : (i + 1)
* self.prompt_processing_sequence_length
]
for i in range(0, num_batches)
]

for idx, token_batch in enumerate(token_batches):
engine_inputs = []

num_cached_entries = self.multitoken_engine.num_non_blank_cache_entries
for name in self.multitoken_engine.onnx_input_names_no_cache:
if name == "input_ids":
engine_input = numpy.array([token_batch])

elif name == "attention_mask":
num_cached_entries = (
self.multitoken_engine.num_non_blank_cache_entries
)

# create an empty attention mask
engine_input = numpy.zeros(
(1, self.sequence_length), dtype=numpy.int64
)
# fill it out with 1s (from the right), so that the number
# of unmaksed entries is equal to the sum of:
# of unmasked entries is equal to the sum of:
engine_input[
:,
-(
@@ -625,7 +612,11 @@
engine_input = numpy.array([[idx]], dtype=numpy.int64)
else:
engine_input = (
numpy.arange(self.prompt_processing_sequence_length)
numpy.arange(
num_cached_entries,
num_cached_entries
+ self.prompt_processing_sequence_length,
)
.reshape(1, -1)
.astype(numpy.int64)
)
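To illustrate the prefill changes above: the prompt is now split into non-overlapping windows of prompt_processing_sequence_length tokens, and the positions input continues from the number of entries already held in the kv cache instead of restarting at zero. A small standalone sketch with made-up sizes:

import numpy

# made-up sizes for illustration
prompt_processing_sequence_length = 4
tokens = list(range(10))
num_batches = len(tokens) // prompt_processing_sequence_length  # -> 2

# each batch now covers a distinct window of the prompt
token_batches = [
    tokens[i * prompt_processing_sequence_length : (i + 1) * prompt_processing_sequence_length]
    for i in range(num_batches)
]
# -> [[0, 1, 2, 3], [4, 5, 6, 7]]

# positions continue from the entries already present in the kv cache
num_cached_entries = 4  # assumed: filled by the first prefill batch
positions = (
    numpy.arange(
        num_cached_entries,
        num_cached_entries + prompt_processing_sequence_length,
    )
    .reshape(1, -1)
    .astype(numpy.int64)
)
# -> array([[4, 5, 6, 7]])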