[Text-Generation] Set kv cache inputs to empty arrays (size 0) when running internally (#1195)

* fix kv cache

* refactor

* add validation pathway

* avx2 support

* initial commit

* initial commit

* initial implementation

* problems with multitoken prefill

* it's working

* Create test_nl_decoder_engine.py

* almost there...

* finally all tests pass

* just need to change to stub

* fix bad merge

* added some tests

* ready for review

* [Text Generation][Tests] DecoderKVCache (#1154)

* [Text Generation][Tests] NLDecoderEngine (#1155)

* initial commit

* initial commit

* [Text Generation][Tests] Text Generation Pipeline (#1162)

* initial implementation

* problems with multitoken prefill

* almost there...

* finally all tests pass

* just need to change to stub

* fix bad merge

* Make tests work with stub (as much as possible), clean up test names, disable heavy tests, include patch for running without causal mask

* initial commit

* use patch from unittest library - remove additional dependency

* improved logic

* additional improvements

* Update src/deepsparse/transformers/pipelines/text_generation.py

* Update src/deepsparse/utils/onnx.py

Co-authored-by: Benjamin Fineran <[email protected]>

* Update src/deepsparse/utils/onnx.py

Co-authored-by: Benjamin Fineran <[email protected]>

* response to Ben's comments

* finish rebasing

* full support

* Update tests/deepsparse/transformers/pipelines/test_text_generation.py

* initial commit

* clarify todo comment

* update user messages + add assertion for safety

* [Text Generation]  KV Cache internal Deepsparse support (#1135)

* fix kv cache

* refactor

* add validation pathway

* avx2 support

* initial commit

* initial commit

* initial implementation

* problems with multitoken prefill

* it's working

* almost there...

* finally all tests pass

* just need to change to stub

* fix bad merge

* added some tests

* ready for review

* full support

---------

Co-authored-by: dbogunowicz <[email protected]>
Co-authored-by: Damian <[email protected]>

* minor improvements before landing

* Fix the helper function that has been broken after a merge

* incomplete string in parametrize

* few nits before the merge

* pass dummy cache if internal cache management supported

* Apply suggestions from code review

* add missing property

* cleaner func

* PR ready

* initial commit

* code review comments

* set kv cache inputs to empty arrays (size 0) when running internally

* TEMP: removing inputs filtering by name

* remove obsolete argument

* trying to find a solution

* this is working

* improve documentation

* review comments

* inline comment instead of warning

---------

Co-authored-by: Sage Moore <[email protected]>
Co-authored-by: dbogunowicz <[email protected]>
Co-authored-by: Damian <[email protected]>
Co-authored-by: Luka Govedic <[email protected]>
5 people authored Sep 13, 2023
1 parent c3f07e3 commit 1439359
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -84,7 +84,7 @@ def __init__(
)

kv_cache_enabled = False
-        if sum(output_indices_to_be_cached):
+        if any(output_indices_to_be_cached):
kv_cache_enabled = True
self.kv_cache_data_type = kv_cache_data_type
if internal_kv_cache and engine_type == DEEPSPARSE_ENGINE:
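The `sum` → `any` change in the hunk above can be illustrated with a short sketch (the flag list here is a hypothetical example value, not taken from the repository):

```python
# For a list of 0/1 flags marking which model outputs should be cached,
# `any` expresses the boolean intent directly and short-circuits at the
# first truthy element, while `sum` always traverses the whole list
# before the truthiness check.
output_indices_to_be_cached = [0, 0, 1, 1]

kv_cache_enabled = any(output_indices_to_be_cached)

print(kv_cache_enabled)  # True: at least one output is cached
print(any([0, 0, 0]))    # False: no cached outputs, cache stays disabled
```

Both spellings give the same result here; `any` simply states the intent ("is at least one output cached?") without computing a count.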
@@ -157,18 +157,23 @@ def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]
If the self.internal_cache_active=True, the internal
deepsparse kv cache management is enabled. In this case
the LIB.kv_cache class object will be passed to the engine
-        call as well.
+        call as well. In this scenario the inputs will also not be
+        validated, even if val_inp=True, because we want to pass
+        the empty kv cache inputs (batch_size=0) to the engine.
:param inputs: The inputs to run the engine with
:param val_inp: Whether the input is for validation or not
:return: The output of the engine
"""

if self.internal_cache_active:
-            # validate the inputs if needed
-            if val_inp:
-                self.engine._validate_inputs(inputs)
-            # run the engine with the LIB.kv_cache object
+            # conventionally, before dispatching
+            # inputs to the engine, we validate them
+            # if val_inp=True. However, in this case
+            # we want to pass the empty kv cache inputs
+            # (batch_size=0) to the engine. Therefore,
+            # we skip the validation
return self.engine._eng_net.execute_list_out(
inputs, self.kv_cache.engine_internal_cache
)
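The reason validation is skipped on this path can be sketched with plain NumPy (the shape below is a hypothetical example; the real dimensions come from the ONNX model):

```python
import numpy

# Hypothetical kv cache input of shape (batch, heads, length, hidden).
# With internal cache management the batch dimension is set to 0: the
# array still carries the expected rank and trailing dimensions, but
# holds no data, since the engine tracks cache values internally.
empty_cache_input = numpy.zeros((0, 16, 128, 64), dtype=numpy.float32)

print(empty_cache_input.size)      # 0 -> nothing to copy into the engine
print(empty_cache_input.shape[0])  # 0 -> batch dimension is empty
```

An input validator that expects `batch_size >= 1` would reject such arrays, which is why the engine call bypasses validation when the internal cache is active.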
@@ -266,7 +271,7 @@ def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray]
Takes the input and adds the past kv cache state to it.
If the internal kv cache is enabled, the kv cache state
-        will always be reinitialized to zeros. This is just to make sure
+        will always be an empty array. This is just to make sure
that the input shapes of the kv cache arrays to the
model are correct, the actual values are
being tracked internally inside the engine.
@@ -280,7 +285,9 @@ def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray]
:return The input with the kv cache state added to it
"""
if self.internal_cache_active:
-            kv_cache_state = self._initialize_kv_cache_state(self.cache_length)
+            kv_cache_state = self._initialize_kv_cache_state(
+                self.cache_length, empty=True
+            )
else:
kv_cache_state = self.kv_cache.cached_inputs
if kv_cache_state is None:
@@ -326,9 +333,13 @@ def update_kv_cache(
input_ids_len=input_ids_len,
)

-    def _initialize_kv_cache_state(self, length: int) -> Dict[str, numpy.ndarray]:
+    def _initialize_kv_cache_state(
+        self, length: int, empty: bool = False
+    ) -> Dict[str, numpy.ndarray]:
         # initialize empty kv cache of size
         # (batch_size, num_attention_heads, length, hidden_dims)
+        # if empty is True, we initialize empty kv_cache
+        # and set the batch_size to 0

cache_engine_input_index = next(
i
@@ -340,7 +351,12 @@ def _initialize_kv_cache_state(self, length: int) -> Dict[str, numpy.ndarray]:
]

empty_kv_cache_tensor = numpy.zeros(
-            (batch_size, num_attention_heads, length, hidden_dims),
+            (
+                batch_size if not empty else 0,
+                num_attention_heads,
+                length,
+                hidden_dims,
+            ),
dtype=self.kv_cache_data_type,
)
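A simplified, self-contained sketch of this helper's behavior follows. The function name, parameters, and cache-input names below are illustrative assumptions; the real method derives the cache shapes from the ONNX model's cache inputs rather than taking them as arguments.

```python
from typing import Dict, List

import numpy


def initialize_kv_cache_state(
    cache_input_names: List[str],
    batch_size: int,
    num_attention_heads: int,
    length: int,
    hidden_dims: int,
    empty: bool = False,
) -> Dict[str, numpy.ndarray]:
    # one zero tensor per cache input; when `empty` is True the batch
    # dimension becomes 0, so the arrays act purely as shape
    # placeholders for the engine (hypothetical simplified helper)
    shape = (0 if empty else batch_size, num_attention_heads, length, hidden_dims)
    return {name: numpy.zeros(shape, dtype=numpy.float32) for name in cache_input_names}


state = initialize_kv_cache_state(
    ["past_key_values.0.key", "past_key_values.0.value"],
    batch_size=1,
    num_attention_heads=16,
    length=128,
    hidden_dims=64,
    empty=True,
)
print(state["past_key_values.0.key"].shape)  # (0, 16, 128, 64)
```

With `empty=False` the same call produces regular zero-filled tensors with the requested batch size, matching the external cache path.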
