[Text Generation] Turn off the (currently) inefficient external KV cache logic when internal KV cache management enabled #1175

Merged
merged 99 commits into from
Aug 28, 2023
0bcf1ea
fix kv cache
SageMoore Jul 20, 2023
d9487bc
Merge branch 'main' into kv-cache-fixes
dbogunowicz Jul 24, 2023
1485478
refactor
dbogunowicz Jul 24, 2023
23c97c8
add validation pathway
dbogunowicz Jul 24, 2023
2deb33b
avx2 support
dbogunowicz Jul 25, 2023
4526499
add import
dbogunowicz Jul 25, 2023
1898a56
Merge remote-tracking branch 'origin/main' into kv-cache-fixes
dbogunowicz Jul 31, 2023
f41689a
initial commit
dbogunowicz Jul 31, 2023
4ed646a
initial commit
dbogunowicz Jul 31, 2023
7f34062
initial implementation
dbogunowicz Aug 1, 2023
817a1fa
Merge branch 'main' into kv-cache-fixes
dbogunowicz Aug 1, 2023
95b0082
problems with multitoken prefill
dbogunowicz Aug 1, 2023
db566c9
Merge branch 'main' into kv-cache-fixes
dbogunowicz Aug 2, 2023
ebf76fc
its working
dbogunowicz Aug 2, 2023
c8e54b8
Merge branch 'kv-cache-fixes' of https://github.com/neuralmagic/deeps…
dbogunowicz Aug 2, 2023
bc0e010
Create test_nl_decoder_engine.py
dbogunowicz Aug 3, 2023
36c6664
Merge branch 'feature/damian/fb_testing' into tests/damian/decoder_kv…
dbogunowicz Aug 3, 2023
2353cb2
almost there...
dbogunowicz Aug 3, 2023
124a922
Merge branch 'main' into kv-cache-fixes
dbogunowicz Aug 3, 2023
aac5b85
finally all tests pass
dbogunowicz Aug 4, 2023
7ee2577
just need to change to stub
dbogunowicz Aug 4, 2023
ef38160
Merge remote-tracking branch 'origin/tests/damian/decoder_kv_cache' i…
dbogunowicz Aug 4, 2023
21b9456
Merge remote-tracking branch 'origin/tests/feature/nl_dec_engine' int…
dbogunowicz Aug 4, 2023
caef2f7
fix bad merge
dbogunowicz Aug 4, 2023
ffdc7fb
Merge branch 'main' into kv-cache-fixes
dbogunowicz Aug 7, 2023
8d98f73
Merge remote-tracking branch 'origin/tests/damian/llms' into kv-cache…
dbogunowicz Aug 7, 2023
f6b6807
added some tests
dbogunowicz Aug 7, 2023
a80c46d
ready for review
dbogunowicz Aug 7, 2023
d752f53
Merge branch 'main' into feature/damian/fb_testing
dbogunowicz Aug 7, 2023
c055873
Merge branch 'feature/damian/fb_testing' into tests/damian/decoder_kv…
dbogunowicz Aug 7, 2023
d851ab4
Merge branch 'tests/damian/decoder_kv_cache' into tests/feature/nl_de…
dbogunowicz Aug 7, 2023
0a42d3f
Merge branch 'tests/feature/nl_dec_engine' into tests/damian/llms
dbogunowicz Aug 7, 2023
9e6ea03
Merge branch 'tests/damian/llms' into kv-cache-fixes
dbogunowicz Aug 7, 2023
bfceb29
[Text Generation][Tests] DecoderKVCache (#1154)
dbogunowicz Aug 8, 2023
b7a55f9
[Text Generation][Tests] NLDecoderEngine (#1155)
dbogunowicz Aug 8, 2023
fef5df0
Make tests work with stub (as much as possible), cleanup test names, …
dbogunowicz Aug 8, 2023
5c48ba5
initial commit
dbogunowicz Aug 8, 2023
09a4847
use patch from unittest library - remove additional dependency
dbogunowicz Aug 8, 2023
ce42fc5
improved logic
dbogunowicz Aug 8, 2023
6bd55cd
additional improvements
dbogunowicz Aug 8, 2023
ebf3c92
Update src/deepsparse/transformers/pipelines/text_generation.py
dbogunowicz Aug 8, 2023
23f764b
Merge remote-tracking branch 'origin/feature/damian/backward_comp_no_…
dbogunowicz Aug 8, 2023
486e041
Merge branch 'feature/damian/backward_comp_no_causal_mask_models' int…
dbogunowicz Aug 8, 2023
c7ffdd1
Merge branch 'main' into feature/damian/backward_comp_no_causal_mask_…
dbogunowicz Aug 8, 2023
4357e8f
Merge branch 'feature/damian/backward_comp_no_causal_mask_models' int…
dbogunowicz Aug 8, 2023
0cbb1a3
finish rebase
dbogunowicz Aug 8, 2023
bafbccd
Update src/deepsparse/utils/onnx.py
dbogunowicz Aug 9, 2023
349197e
Update src/deepsparse/utils/onnx.py
dbogunowicz Aug 9, 2023
2ce0daf
response to Ben's comments
dbogunowicz Aug 9, 2023
4524442
Merge branch 'feature/damian/backward_comp_no_causal_mask_models' of …
dbogunowicz Aug 9, 2023
b952107
finish rebasing
dbogunowicz Aug 9, 2023
2f9fe9d
Merge branch 'main' into feature/damian/backward_comp_no_causal_mask_…
dbogunowicz Aug 9, 2023
663e05f
Merge branch 'feature/damian/backward_comp_no_causal_mask_models' int…
dbogunowicz Aug 9, 2023
1a524c7
Merge remote-tracking branch 'origin/feature/damian/fb_testing' into …
dbogunowicz Aug 9, 2023
098c7cf
full support
dbogunowicz Aug 9, 2023
9fd5e56
Update tests/deepsparse/transformers/pipelines/test_text_generation.py
dbogunowicz Aug 9, 2023
b0a7c96
Merge branch 'feature/damian/fb_testing' into kv-cache-fixes
dbogunowicz Aug 9, 2023
4c12c38
initial commit
dbogunowicz Aug 9, 2023
caf4b42
clarify todo comment
bfineran Aug 9, 2023
1b657bc
update user messages + add assertion for safety
bfineran Aug 9, 2023
e1e10b3
[Text Generation] KV Cache internal Deepsparse support (#1135)
SageMoore Aug 9, 2023
14a5197
minor improvements before landing
dbogunowicz Aug 10, 2023
d41e8ea
Merge remote-tracking branch 'origin/main' into feature/damian/backwa…
dbogunowicz Aug 10, 2023
4881d30
Fix the helper function that has been broken after a merge
dbogunowicz Aug 10, 2023
5f8b1eb
Merge branch 'feature/damian/backward_comp_no_causal_mask_models' int…
dbogunowicz Aug 10, 2023
6ff9cf5
incomplete string in parametrize
dbogunowicz Aug 10, 2023
1741012
few nits before the merge
dbogunowicz Aug 10, 2023
4528295
Merge branch 'feature/damian/fb_testing' into feature/damian/optimize…
dbogunowicz Aug 10, 2023
617147c
pass dummy cache if internal cache management supported
dbogunowicz Aug 10, 2023
49b8a9c
Merge branch 'feature/damian/optimize_update_kv_cache' of https://git…
dbogunowicz Aug 10, 2023
8fad0e0
Apply suggestions from code review
dbogunowicz Aug 10, 2023
144e3e0
add missing property
dbogunowicz Aug 10, 2023
b06bac7
cleaner func
dbogunowicz Aug 10, 2023
cf6540c
Merge remote-tracking branch 'origin/main' into feature/damian/optimi…
dbogunowicz Aug 16, 2023
22b81a2
PR ready
dbogunowicz Aug 16, 2023
90f767f
add timing for KV cache update
Aug 10, 2023
51c1eaa
initial commit
dbogunowicz Aug 17, 2023
fbb8135
Merge branch 'main' into feature/damian/optimize_decoder
dbogunowicz Aug 18, 2023
1cb41cb
Merge remote-tracking branch 'origin/feature/damian/optimize_decoder'…
dbogunowicz Aug 18, 2023
7dbb44d
Merge branch 'main' into feature/damian/optimize_update_kv_cache
dbogunowicz Aug 18, 2023
6bf740f
Merge branch 'feature/damian/optimize_update_kv_cache' of https://git…
dbogunowicz Aug 22, 2023
d434edd
code review comments
dbogunowicz Aug 22, 2023
96eb68a
Nit: docstring typo
dbogunowicz Aug 23, 2023
69e0a76
nit: docstring style
dbogunowicz Aug 23, 2023
4bdf82b
Merge branch 'feature/damian/optimize_decoder' into feature/damian/op…
dbogunowicz Aug 23, 2023
66b3803
Merge branch 'kv-cache-update-improvement' into feature/damian/optimi…
dbogunowicz Aug 23, 2023
aa31d5e
Merge branch 'main' into feature/damian/optimize_decoder
dbogunowicz Aug 23, 2023
26f1e64
Merge branch 'feature/damian/optimize_decoder' into feature/damian/op…
dbogunowicz Aug 23, 2023
24f3a87
fix style
dbogunowicz Aug 23, 2023
ce150cc
Merge branch 'feature/damian/optimize_decoder' into feature/damian/op…
dbogunowicz Aug 23, 2023
d1b3f5e
Merge branch 'feature/damian/optimize_update_kv_cache' of https://git…
dbogunowicz Aug 23, 2023
921f2f5
Merge branch 'main' into feature/damian/optimize_update_kv_cache
dbogunowicz Aug 24, 2023
c90a06b
Merge branch 'main' into feature/damian/optimize_update_kv_cache
dbogunowicz Aug 24, 2023
4c34f1d
Merge branch 'feature/damian/optimize_update_kv_cache' of https://git…
dbogunowicz Aug 24, 2023
519bf1b
fix broken test
dbogunowicz Aug 24, 2023
7da933b
fixing bad rebase
dbogunowicz Aug 25, 2023
41ef43c
Merge remote-tracking branch 'origin/main' into HEAD
dbogunowicz Aug 25, 2023
a43b063
Merge branch 'main' into feature/damian/optimize_update_kv_cache
dbogunowicz Aug 28, 2023
b7f03fc
Merge branch 'main' into feature/damian/optimize_update_kv_cache
dbogunowicz Aug 28, 2023
77 changes: 52 additions & 25 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -64,7 +64,7 @@ def __init__(
         sampling_temperature: float = 1.0,
         deterministic: bool = True,
         engine_context: Optional[Context] = None,
-        use_deepsparse_cache=False,
+        use_deepsparse_cache: bool = False,
     ):
         # flag to indicate if the model is quantized or not
         self.kv_cache_data_type = None
@@ -93,10 +93,12 @@ def __init__(
             engine_args=engine_args,
             context=engine_context,
         )
+
         self.sequence_length = sequence_length
         self.sampling_temperature = sampling_temperature
         self.deterministic = deterministic
         self.input_ids_length = input_ids_length
+        self.cache_length = sequence_length - input_ids_length
         self.kv_cache_enabled = kv_cache_enabled
         self.kv_cache = (
             DecoderKVCache(use_deepsparse_cache) if kv_cache_enabled else None
@@ -134,35 +136,41 @@ def onnx_input_names_no_cache(self) -> List[str]:
     @property
     def num_non_blank_cache_entries(self) -> int:
         """
-        :return a number of non-blank entries in the
-        kv cache
+        :return A number of non-blank entries in the
+            kv cache
         """
         return self.kv_cache.num_non_blank_entries

+    @property
+    def internal_cache_active(self) -> bool:
+        """
+        :return: Whether the internal kv cache is active
+        """
+        return self.kv_cache_enabled and self.kv_cache.engine_internal_cache is not None
+
     def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]:
         """
         Run the engine with the given inputs.

-        If the internal deepsparse kv cache management is enable,
+        If the self.internal_cache_active=True, the internal
+        deepsparse kv cache management is enabled. In this case
         the LIB.kv_cache class object will be passed to the engine
         call as well.

         :param inputs: The inputs to run the engine with
         :param val_inp: Whether the input is for validation or not

         :return: The output of the engine
         """
-
-        if self.kv_cache is not None:
-            if self.kv_cache._kv_cache is not None:
-                if val_inp:
-                    self.engine._validate_inputs(inputs)
-                # model has kv cache support, as well as deepsparse
-                # internal management of the kv cache
-                return self.engine._eng_net.execute_list_out(
-                    inputs, self.kv_cache._kv_cache
-                )
-
+        if self.internal_cache_active:
+            # validate the inputs if needed
+            if val_inp:
+                self.engine._validate_inputs(inputs)
+            # run the engine with the LIB.kv_cache object
+            return self.engine._eng_net.execute_list_out(
+                inputs, self.kv_cache.engine_internal_cache
+            )
+        # run the engine without the LIB.kv_cache object
         return self.engine.run(inputs, val_inp)

     def __call__(
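Reviewer note: the dispatch introduced in this hunk can be sketched in isolation. This is a minimal sketch under stated assumptions, not the deepsparse implementation: `StubEngine`, `StubKVCache`, `DecoderSketch`, `run_with_cache` and `demo_dispatch` are hypothetical stand-ins, and only the `internal_cache_active` condition is copied from the diff.

```python
from typing import Any, List, Optional


class StubEngine:
    """Hypothetical stand-in for the compiled engine (not the deepsparse API)."""

    def run(self, inputs: List[Any], val_inp: bool) -> List[Any]:
        # path taken when the engine does not manage the kv cache itself
        return ["external", inputs]

    def run_with_cache(self, inputs: List[Any], cache: Any) -> List[Any]:
        # path taken when the engine-side cache object is passed along
        return ["internal", inputs]


class StubKVCache:
    """Carries only the `engine_internal_cache` attribute named in the diff."""

    def __init__(self, engine_internal_cache: Optional[Any] = None):
        self.engine_internal_cache = engine_internal_cache


class DecoderSketch:
    def __init__(self, kv_cache_enabled: bool, kv_cache: Optional[StubKVCache]):
        self.engine = StubEngine()
        self.kv_cache_enabled = kv_cache_enabled
        self.kv_cache = kv_cache

    @property
    def internal_cache_active(self) -> bool:
        # same condition as in the diff: the model has a kv cache AND the
        # engine-side cache object has been created
        return self.kv_cache_enabled and self.kv_cache.engine_internal_cache is not None

    def run(self, inputs: List[Any], val_inp: bool = False) -> List[Any]:
        if self.internal_cache_active:
            return self.engine.run_with_cache(
                inputs, self.kv_cache.engine_internal_cache
            )
        return self.engine.run(inputs, val_inp)


def demo_dispatch() -> List[str]:
    internal = DecoderSketch(True, StubKVCache(engine_internal_cache=object()))
    external = DecoderSketch(True, StubKVCache())
    no_cache = DecoderSketch(False, None)
    return [decoder.run([1, 2, 3])[0] for decoder in (internal, external, no_cache)]
```

Because `kv_cache_enabled` is checked first, the property short-circuits and never touches `self.kv_cache` when the model has no kv cache at all, which is why the third case above does not fail on `None`.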
@@ -180,8 +188,8 @@ def __call__(
         :return: The generated token and corresponding logits
         """
         if self.kv_cache:
-            # if kv cache is enabled, we need to add the kv cache state
-            # to the input
+            # if model has kv cache enabled, we need
+            # to add the kv cache state to the input
             inp = self.add_kv_cache_to_input(inp)

         out = self.run(inp, val_inp)
@@ -217,8 +225,7 @@ def transfer_cache_state(self, cache: DecoderKVCache):
         :param cache: The `DecoderKVCache` object to transfer to the engine
             from
         """
-        target_cache_capacity = self.sequence_length - self.input_ids_length
-        cache.set_capacity(target_cache_capacity)
+        cache.set_capacity(self.cache_length)
         self.kv_cache = cache

     def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
@@ -241,9 +248,7 @@ def reset_kv_cache(self):
         """
         Resets the kv cache state.
         """
-        kv_cache_state = self._initialize_kv_cache_state(
-            self.sequence_length - self.input_ids_length
-        )
+        kv_cache_state = self._initialize_kv_cache_state(self.cache_length)
         self.kv_cache.setup(
             session_id=self._session_id,
             state=kv_cache_state,
@@ -255,13 +260,27 @@ def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray]
         """
         Takes the input and adds the past kv cache state to it.

+        If the internal kv cache is enabled, the kv cache state
+        will always be reinitialized to zeros. This is just to make sure
+        that the input shapes of the kv cache arrays to the
+        model are correct, the actual values are
+        being tracked internally inside the engine.
+
+        If the internal kv cache is disabled, we need to
+        fetch the kv cache state as numpy arrays
+        from the current session, or initialize it if required.
+
+
         :param inp: The input to the model
         :return The input with the kv cache state added to it
         """
-        kv_cache_state = self.kv_cache.cached_inputs
-        if kv_cache_state is None:
-            self.reset_kv_cache()
+        if self.internal_cache_active:
+            kv_cache_state = self._initialize_kv_cache_state(self.cache_length)
+        else:
+            kv_cache_state = self.kv_cache.cached_inputs
+            if kv_cache_state is None:
+                self.reset_kv_cache()
             kv_cache_state = self.kv_cache.cached_inputs

         for idx, input_name in enumerate(self.onnx_input_names_no_cache):
             kv_cache_state[input_name] = inp[idx]
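Reviewer note: the two branches described in the docstring of this hunk could look roughly like the sketch below. It is an illustration under assumptions, not the code in this PR: `placeholder_cache_state` is a hypothetical stand-in for `_initialize_kv_cache_state`, plain lists replace numpy arrays, and the reset step of the external path is collapsed into a single call.

```python
from typing import Dict, List, Optional

CacheState = Dict[str, List[float]]


def placeholder_cache_state(cache_names: List[str], cache_length: int) -> CacheState:
    # hypothetical stand-in for `_initialize_kv_cache_state`: all zeros,
    # one entry per cache input, each of length `cache_length`
    return {name: [0.0] * cache_length for name in cache_names}


def select_cache_state(
    internal_cache_active: bool,
    cached_inputs: Optional[CacheState],
    cache_names: List[str],
    cache_length: int,
) -> CacheState:
    if internal_cache_active:
        # the engine tracks the real values; the model inputs only need
        # arrays of the right shape, so zeros are passed on every call
        return placeholder_cache_state(cache_names, cache_length)
    if cached_inputs is None:
        # external management, first call of a session: start from blank
        return placeholder_cache_state(cache_names, cache_length)
    # external management: reuse the state kept on the Python side
    return cached_inputs


names = ["past_key_values_1"]
stored = {"past_key_values_1": [0.5, 0.25]}
internal_state = select_cache_state(True, stored, names, cache_length=2)
external_state = select_cache_state(False, stored, names, cache_length=2)
fresh_state = select_cache_state(False, None, names, cache_length=2)
```

With the internal cache active, the state stored on the Python side is ignored entirely; that is the behaviour that makes the external bookkeeping unnecessary in this mode.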
@@ -277,9 +296,17 @@ def update_kv_cache(
         """
         Updates the state of the kv cache

+        If the internal kv cache is enabled, we refrain from
+        updating the kv cache state as it is being tracked internally
+        inside the engine. We only update the number of tokens processed.
+
         :param kv_cache_state: The state of the kv cache storage
         :param input_ids_len: The length of input_ids
         """
+        if self.internal_cache_active:
+            self.kv_cache.total_num_processed_tokens += input_ids_len
+            return
+
         cache_onnx_names = [
             name
             for name in self.engine.input_names
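Reviewer note: the early return added to `update_kv_cache` can be illustrated with a small stand-alone sketch. `CacheSketch` and the free function `update_kv_cache_sketch` are hypothetical; the external branch is an assumption about what a full update does (store the new state and advance the counter), since the rest of the method is not shown in this diff.

```python
from typing import Dict, List


class CacheSketch:
    """Hypothetical stand-in for DecoderKVCache with only the fields used here."""

    def __init__(self) -> None:
        self.total_num_processed_tokens = 0
        self.state: Dict[str, List[float]] = {}


def update_kv_cache_sketch(
    cache: CacheSketch,
    kv_cache_state: Dict[str, List[float]],
    input_ids_len: int,
    internal_cache_active: bool,
) -> None:
    if internal_cache_active:
        # the engine owns the key/value buffers, so only the token counter
        # is advanced on the Python side
        cache.total_num_processed_tokens += input_ids_len
        return
    # assumed external path: keep a copy of the new state and advance the counter
    cache.state = dict(kv_cache_state)
    cache.total_num_processed_tokens += input_ids_len


cache = CacheSketch()
update_kv_cache_sketch(cache, {"past_key_values_1": [1.0]}, 4, internal_cache_active=True)
tokens_after_internal = cache.total_num_processed_tokens
state_after_internal = dict(cache.state)
update_kv_cache_sketch(cache, {"past_key_values_1": [1.0]}, 1, internal_cache_active=False)
```

Skipping the state copy is where the saving comes from: in the internal mode no key/value arrays are moved around in Python after each engine call.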
11 changes: 7 additions & 4 deletions src/deepsparse/transformers/utils/decoder_kv_cache.py
@@ -31,8 +31,9 @@ def __init__(self, use_deepsparse_cache: bool = False):
         The goal this object is to handle the manipulation
         of the key value cache.

-        :param use_deepsparse_cache: If set to True, the `kv_cache` object
-            from the deepsparse.LIB will be loaded as an attribute.
+        :param use_deepsparse_cache: If set to True, the
+            `kv_cache` object from the deepsparse.LIB will
+            be loaded as an engine_internal_cache attribute.
             This object is used to handle the manipulation of the
             key/value buffers on the DeepSparse engine side.
         """
@@ -45,7 +46,7 @@ def __init__(self, use_deepsparse_cache: bool = False):
         self._session_id = None
         self._freeze_first_position = None
         self._state = None
-        self._kv_cache = None
+        self.engine_internal_cache = None

     def setup(
         self,
@@ -82,7 +83,9 @@ def setup(
         if self._use_deepsparse_cache:
             prev_num_tokens = self.total_num_processed_tokens
             num_frozen_tokens = int(self._freeze_first_position)
-            self._kv_cache = LIB.kv_cache(prev_num_tokens, num_frozen_tokens)
+            self.engine_internal_cache = LIB.kv_cache(
+                prev_num_tokens, num_frozen_tokens
+            )

     def update(
         self,
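Reviewer note: the rename from the private `_kv_cache` to the public `engine_internal_cache` means callers can check whether the engine-side cache exists without reaching into private state. A minimal sketch, assuming a fake library object: `FakeLib` is a hypothetical stand-in for `deepsparse.LIB`, and only the two constructor arguments shown in the diff are modelled.

```python
class FakeLib:
    """Hypothetical stand-in for `deepsparse.LIB`; not the real binding."""

    class kv_cache:
        def __init__(self, prev_num_tokens: int, num_frozen_tokens: int):
            self.prev_num_tokens = prev_num_tokens
            self.num_frozen_tokens = num_frozen_tokens


class CacheSetupSketch:
    def __init__(self, use_deepsparse_cache: bool = False):
        self._use_deepsparse_cache = use_deepsparse_cache
        self._freeze_first_position = None
        self.total_num_processed_tokens = 0
        # public attribute, stays None until setup() creates the engine-side cache
        self.engine_internal_cache = None

    def setup(self, freeze_first_position: bool = False) -> None:
        self._freeze_first_position = freeze_first_position
        if self._use_deepsparse_cache:
            prev_num_tokens = self.total_num_processed_tokens
            num_frozen_tokens = int(self._freeze_first_position)
            self.engine_internal_cache = FakeLib.kv_cache(
                prev_num_tokens, num_frozen_tokens
            )


external = CacheSetupSketch(use_deepsparse_cache=False)
external.setup()
internal = CacheSetupSketch(use_deepsparse_cache=True)
internal.setup(freeze_first_position=True)
```

In this sketch the attribute is `None` both before `setup()` and whenever `use_deepsparse_cache` is off, so a check such as `engine_internal_cache is not None` is enough to tell the two modes apart.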
@@ -25,6 +25,7 @@ class DummyKVCacheDecoder:
         "past_key_values_1": np.array([10, 11, 12]),
         "past_key_values_2": np.array([13, 14, 15]),
     }
+    engine_internal_cache = None


 class DummyEngine:
@@ -62,6 +63,7 @@ def test_add_kv_cache_to_input():
     nl_decoder_engine = NLDecoderEngine(None, None)
     nl_decoder_engine.engine = DummyEngine()
     nl_decoder_engine.kv_cache = DummyKVCacheDecoder()
+    nl_decoder_engine.kv_cache_enabled = True
     result = nl_decoder_engine.add_kv_cache_to_input(inp)

     for (x, y) in zip(result, expected_result):
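Reviewer note: the two test additions above follow from the new property, which reads both `kv_cache_enabled` and `kv_cache.engine_internal_cache`. The sketch below shows why an object built without running its constructor needs them set by hand. It assumes the constructor is bypassed with `unittest.mock.patch.object`; how the actual test does this is not visible in the diff, and `EngineUnderTest` and `build_engine` are hypothetical names.

```python
from unittest.mock import patch


class EngineUnderTest:
    """Hypothetical stand-in for NLDecoderEngine with only the new property."""

    def __init__(self, onnx_file_path, engine_type):
        raise RuntimeError("the real constructor would compile a model")

    @property
    def internal_cache_active(self) -> bool:
        return self.kv_cache_enabled and self.kv_cache.engine_internal_cache is not None


class DummyKVCacheDecoder:
    cached_inputs = {"past_key_values_1": [10, 11, 12]}
    engine_internal_cache = None


def build_engine(set_flag: bool) -> EngineUnderTest:
    # replace __init__ with a no-op only while the object is being created
    with patch.object(EngineUnderTest, "__init__", lambda self, *args: None):
        engine = EngineUnderTest(None, None)
    engine.kv_cache = DummyKVCacheDecoder()
    if set_flag:
        engine.kv_cache_enabled = True
    return engine


def flag_missing_raises() -> bool:
    try:
        build_engine(set_flag=False).internal_cache_active
    except AttributeError:
        return True
    return False
```

With the flag set and `engine_internal_cache = None` on the dummy cache, the property evaluates to `False`, so a test built this way exercises the external code path of `add_kv_cache_to_input`.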