Refactoring of GPT.forward when it comes to input_pos and KV cache usage #1898

Status: Open · Labels: enhancement (New feature or request)
mseeger opened this issue on Jan 6, 2025 · 4 comments

@mseeger (Contributor) commented on Jan 6, 2025

The current GPT.forward in model.py essentially serves two use cases:

  • Forward pass for training: input_pos=None, the KV cache is not used. Implicitly, input_pos = arange(idx.shape[-1]), and causal masking is applied. This path could also be used for prefill with the prompt during inference.
  • Inference: input_pos is not None, the KV cache is used. There seem to be two cases here: either input_pos = arange(idx.shape[-1]) (used for prefill), or input_pos.shape[-1] == 1 (generation of a single next token, possibly batched). A sketch of both call patterns follows this list.
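
To make the two call patterns concrete, here is a minimal sketch. `model` is assumed to be an instantiated GPT with its KV cache already set up; the calls simply mirror the calling convention described above.

```python
import torch

B, T = 2, 16                            # batch size, sequence length
idx = torch.randint(0, 50_000, (B, T))

# 1) Training forward pass: no input_pos, no KV cache; positions are
#    implicitly arange(T) and causal masking is applied.
logits = model(idx)

# 2) Inference with the KV cache: either prefill over the whole prompt ...
logits = model(idx, input_pos=torch.arange(T))

# ... or decode a single next token (input_pos has length 1).
next_idx = logits[:, -1:].argmax(dim=-1)              # shape (B, 1)
logits = model(next_idx, input_pos=torch.tensor([T]))
```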

I am interested in implementing KV cache strategies such as H2O. During inference, we really only have prefill followed by single-token generation. Inference always works like this:

  • Prefill with sequence length T (minimum of prompt size and max cache size)
  • Generate token T
  • Generate token T+1
  • ...

Most KV cache strategies only support this protocol.
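
A minimal sketch of this protocol, assuming a placeholder `model` whose KV cache is already initialized and a single prompt that fits into the cache (greedy decoding just to keep the example short):

```python
import torch

prompt = torch.randint(0, 50_000, (1, 128))   # placeholder prompt, T = 128
T = prompt.shape[-1]

# Prefill: one forward pass over the prompt populates the KV cache and
# yields the logits used to generate token T.
logits = model(prompt, input_pos=torch.arange(T))
token = logits[:, -1:].argmax(dim=-1)         # token T, shape (1, 1)

# Generation: one token at a time, each step appending its K/V vectors
# to the cache.
generated = [token]
for pos in range(T, T + 32):
    logits = model(token, input_pos=torch.tensor([pos]))
    token = logits[:, -1:].argmax(dim=-1)
    generated.append(token)
```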

My proposal would be to refactor GPT.forward to support two cases only:

  • Forward pass for training: with an additional flag, this pass could also be used for prefill, in that it would initialize the KV cache with the K and V vectors that are computed as part of the forward anyway (since scaled_dot_product_attention is called on them).
  • Single-token generation: idx.shape[-1] == 1. input_pos is not really needed here; if anything, it would rather be input_pos_maxp1. The KV cache itself tracks the position of the next token, and it would complain if asked to do anything else.

This supports everything you have right now, plus it supports advanced KV caches like H2O.
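
To make the proposal concrete, here is a rough sketch of what the two cases could look like. The flag name `for_prefill` is purely illustrative, and the exact semantics of `input_pos_maxp1` would still need to be pinned down.

```python
import torch

prompt = torch.randint(0, 50_000, (1, 128))   # placeholder prompt

# Case 1: training-style forward pass, optionally doubling as prefill.
# With the (hypothetical) flag set, the K/V tensors computed for
# scaled_dot_product_attention would also initialize the KV cache.
logits = model(prompt, for_prefill=True)

# Case 2: single-token generation, idx.shape[-1] == 1. The cache tracks
# the position of the next token itself; at most an exclusive upper bound
# on the occupied positions (input_pos_maxp1) is passed in.
token = logits[:, -1:].argmax(dim=-1)
logits = model(token, input_pos_maxp1=prompt.shape[-1] + 1)
```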

I am happy to do this in a branch in my fork and show you what it would look like.

@mseeger added the enhancement label on Jan 6, 2025
@mseeger (Contributor, Author) commented on Jan 6, 2025

I'd also implement batched generation based on this; it seems incomplete at the moment.

@t-vi (Contributor) commented on Jan 7, 2025

Hi @mseeger ,

great project and thanks for looking into it!

A few things, and this is mainly from a batching perspective:

  • Currently we can batch with padding if we prefill sequences of different lengths (or have a prefill/run combination). Ideally, we would not want to lose that ability.
  • Similarly, we may have external reasons to pad things and/or to see the unused parts of the KV cache, so I would not slice the KV cache more than needed. Dynamic shapes can still be tricky with DL compilers, so it would make things easier if this padding kept working.
  • It may be interesting to keep track of the maximum input_pos somewhere; maybe @ali-alshaar7 has thoughts around that (I would have to look).

I would be very grateful if you could keep @ali-alshaar7 in the loop.

Again it's awesome to have you interested in KVCache!

@mseeger (Contributor, Author) commented on Jan 8, 2025

Do you mean padding prompts of different lengths?

My first attempt would be to prefill up to the minimum length over all prompts, and then go token by token, sampling once a prompt has been fully consumed and taking the next token from the prompt otherwise. At a later stage, one could prefill further (padding the shorter prompts) and then remove the KV vectors corresponding to the padding again, but that is trickier to implement.
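
Here is a sketch of that first-attempt control flow. `model`, `tokenized_prompts`, and `max_len` are placeholders, and greedy sampling is used only to keep the example short.

```python
import torch

prompts = [torch.tensor(p) for p in tokenized_prompts]   # list of 1D token tensors
min_len = min(p.shape[-1] for p in prompts)

# Prefill up to the shortest prompt length.
batch = torch.stack([p[:min_len] for p in prompts])      # (B, min_len)
logits = model(batch, input_pos=torch.arange(min_len))
sampled = logits[:, -1].argmax(dim=-1)                   # (B,), prediction for position min_len

# Advance all sequences one position per step: take the token from the
# prompt where it is still available, otherwise use the sampled one.
for pos in range(min_len, max_len):
    current = torch.stack([
        p[pos] if pos < p.shape[-1] else sampled[b]
        for b, p in enumerate(prompts)
    ])
    logits = model(current.unsqueeze(1), input_pos=torch.tensor([pos]))
    sampled = logits[:, -1].argmax(dim=-1)
```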

@mseeger (Contributor, Author) commented on Jan 8, 2025

Regarding your comment on compiling a graph for inference: I think this would still work even with dynamic KV caching like H2O, because it only uses operators like argmin, topk, and scatter. The sizes and shapes of all arrays are determined up front and are the same for every call. The only exception I can see is stopping when the end-of-sequence token is encountered, but that is the same without KV caching.
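
As a toy illustration of the fixed-shape argument (not the full H2O scoring logic): the eviction step below only uses argmin and scatter on tensors whose sizes are set by the cache budget, regardless of how many tokens have been generated. All names and shapes here are hypothetical.

```python
import torch

B, H, N, D = 1, 8, 256, 64            # batch, heads, cache slots (budget), head dim
k_cache = torch.zeros(B, H, N, D)
v_cache = torch.zeros(B, H, N, D)
scores = torch.rand(B, H, N)          # e.g. accumulated attention mass per slot

def insert_with_eviction(k_new, v_new):
    # Evict the slot with the smallest score and write the new K/V vectors
    # there; every tensor keeps its (B, H, N, D) / (B, H, N) shape.
    slot = scores.argmin(dim=-1)                        # (B, H)
    idx = slot[..., None, None].expand(B, H, 1, D)      # (B, H, 1, D)
    k_cache.scatter_(2, idx, k_new.unsqueeze(2))
    v_cache.scatter_(2, idx, v_new.unsqueeze(2))
    scores.scatter_(2, slot.unsqueeze(-1), torch.zeros(B, H, 1))
```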

And the default KV cache, if you don't specify any, will be the dense one which stores everything and allocates full memory up front. So the default behavior should not change.
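
A sketch of what such a dense default cache amounts to, with hypothetical names: buffers for the full maximum sequence length are allocated once, and updates simply write the new K/V vectors at the given positions.

```python
import torch

B, H, max_T, D = 1, 8, 4096, 64       # full memory allocated up front
k_cache = torch.zeros(B, H, max_T, D)
v_cache = torch.zeros(B, H, max_T, D)

def dense_update(input_pos, k_new, v_new):
    # Write the new K/V vectors at the given positions; nothing is ever
    # evicted, so the cache simply stores everything.
    k_cache.index_copy_(2, input_pos, k_new)   # k_new: (B, H, len(input_pos), D)
    v_cache.index_copy_(2, input_pos, v_new)
    return k_cache, v_cache
```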
