Refactoring of GPT.forward when it comes to input_pos and KV cache usage #1898
I'd also implement batched generation based on this, which seems incomplete at the moment.
Hi @mseeger, great project and thanks for looking into it! Two things, mainly from a batch perspective:
I would be very grateful if you could keep @ali-alshaar7 in the loop. Again, it's awesome to have you interested in KVCache!
Do you mean padding prompts of different lengths? My first attempt would be to prefill up to the minimum length over the prompts, then go token by token, sampling once a prompt has been fully processed and taking the next token from the prompt otherwise. At a later stage, one could prefill further (padding the shorter prompts) and then remove the KV vectors corresponding to the padding again, but that is trickier to implement.
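A minimal sketch of that first attempt, assuming a `model(idx, input_pos=...)` call in the style discussed in this issue and a placeholder `sample(logits) -> token id` helper; none of these names are existing litgpt APIs:

```python
import torch

def batched_generate(model, prompts, max_new_tokens, sample):
    # prompts: list of 1D LongTensors of (possibly) different lengths
    min_len = min(p.shape[0] for p in prompts)
    max_len = max(p.shape[0] for p in prompts)

    # Prefill up to the shortest prompt; the KV cache then holds min_len positions
    batch = torch.stack([p[:min_len] for p in prompts])      # (B, min_len)
    logits = model(batch, input_pos=torch.arange(min_len))   # (B, min_len, vocab)

    sequences = [p.clone() for p in prompts]
    # Token-by-token phase: for each sequence, either take the next token from
    # the remaining prompt, or sample once the prompt is exhausted.
    # (Shorter prompts end up with extra tokens here; per-sequence stopping
    # is omitted for brevity.)
    for pos in range(min_len, max_len + max_new_tokens):
        next_tokens = []
        for b, prompt in enumerate(prompts):
            if pos < prompt.shape[0]:
                next_tokens.append(prompt[pos])               # still inside the prompt
            else:
                next_tokens.append(sample(logits[b, -1]))     # sample from the model
        idx = torch.stack(next_tokens).unsqueeze(-1)          # (B, 1)
        logits = model(idx, input_pos=torch.tensor([pos]))
        for b in range(len(prompts)):
            if pos >= prompts[b].shape[0]:
                sequences[b] = torch.cat([sequences[b], idx[b]])
    return sequences
```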
Regarding your comment on compiling a graph for inference: I think this would still work even with dynamic KV caching like H2O, because it only uses operators like [...]. And the default KV cache, if you don't specify one, will be the dense one, which stores everything and allocates full memory up front, so the default behavior should not change.
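For illustration only (not litgpt's actual implementation), such a dense cache could look like this: full buffers allocated up front, every new key/value appended, nothing ever evicted.

```python
import torch

class DenseKVCache:
    """Illustrative dense cache: full buffers allocated up front, nothing evicted."""

    def __init__(self, batch_size, n_heads, max_seq_len, head_dim, device=None, dtype=None):
        shape = (batch_size, n_heads, max_seq_len, head_dim)
        self.k = torch.zeros(shape, device=device, dtype=dtype)
        self.v = torch.zeros(shape, device=device, dtype=dtype)
        self.next_pos = 0  # position where the next token's key/value will be written

    def update(self, k_new, v_new):
        # k_new, v_new: (batch, n_heads, T, head_dim); T > 1 for prefill, T == 1 for generation
        T = k_new.shape[2]
        self.k[:, :, self.next_pos:self.next_pos + T] = k_new
        self.v[:, :, self.next_pos:self.next_pos + T] = v_new
        self.next_pos += T
        # Return everything stored so far
        return self.k[:, :, :self.next_pos], self.v[:, :, :self.next_pos]
```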
The current `GPT.forward` in `model.py` essentially serves two use cases:

- `input_pos=None`, KV cache not used. Implicitly, `input_pos = arange(idx.shape[-1])`, and causal masking is used. Could also be used for prefill with the prompt in inference.
- `input_pos` is not `None`, KV cache is used. There seem to be two cases here: either `input_pos = arange(idx.shape[-1])` (used for prefill), or `input_pos.shape[-1] == 1` (generation of a single next token, possibly batched).

I am interested in implementing KV cache strategies, such as H2O. In inference, we really only have prefill, and then single-token generation. Inference always works like this:

1. Prefill: the prompt is processed in one forward pass, filling the KV cache.
2. Generation: tokens are produced one at a time, each step feeding only the newest token.
Most KV cache strategies only support this protocol.
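In code, that protocol looks roughly like this; a sketch against the `input_pos`-style interface described above, with `model` and `sample` as placeholders:

```python
import torch

@torch.inference_mode()
def generate(model, prompt, max_new_tokens, sample):
    T = prompt.shape[-1]
    # 1) Prefill: the whole prompt in one forward pass fills the KV cache.
    logits = model(prompt.unsqueeze(0), input_pos=torch.arange(T))
    token = sample(logits[0, -1]).view(1, 1)
    out = [token]
    # 2) Single-token generation: each step feeds only the newest token;
    #    all earlier keys/values come from the cache.
    for pos in range(T, T + max_new_tokens - 1):
        logits = model(token, input_pos=torch.tensor([pos]))
        token = sample(logits[0, -1]).view(1, 1)
        out.append(token)
    return torch.cat(out, dim=-1)
```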
My proposal would be to refactor `GPT.forward` to support two cases only:

- Prefill: the prompt is processed with causal masking, and `scaled_dot_product_attention` is called.
- Generation of a single next token (possibly batched): `idx.shape[-1] == 1`.

`input_pos` is not really needed; it would rather be `input_pos_maxp1`. The KV cache tracks the position of the next token, and it would complain if asked to do anything else. This supports everything you have right now, plus it supports advanced KV caches like H2O.
I am happy to do this in a branch in my fork and show you how it would look.