
Torch engine prefix caching #1393

Closed
grimoire wants to merge 9 commits

Conversation

grimoire
Collaborator

@grimoire grimoire commented Apr 4, 2024

Enable by setting shared_cache=True.
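A minimal usage sketch, assuming the flag is exposed through PytorchEngineConfig (the exact entry point wired up in this PR may differ):

```python
# Hypothetical usage sketch: assumes `shared_cache` is exposed via
# PytorchEngineConfig; the actual wiring in this PR may differ.
from lmdeploy import pipeline, PytorchEngineConfig

backend_config = PytorchEngineConfig(shared_cache=True)  # enable prefix caching
pipe = pipeline('internlm/internlm2-chat-7b', backend_config=backend_config)
print(pipe(['Hi, please introduce yourself.']))
```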

@zhyncs
Collaborator

zhyncs commented Apr 7, 2024

Hi @grimoire @lvhan028, why did you choose the radix tree implementation? Have you considered using a hash table implementation? What factors did you consider, such as scalability or performance? Thanks.

@grimoire
Collaborator Author

grimoire commented Apr 7, 2024

Any details about the hash table implementation?
Honestly, I do not like my radix tree implementation in this PR.

@zhyncs
Collaborator

zhyncs commented Apr 7, 2024

Any details about the hash table implementation? Honestly, I do not like my radix tree implementation in this PR.

@ispobock may follow up. So far we have researched the implementations of vLLM, RTP-LLM, and SGLang.

@ispobock
Contributor

ispobock commented Apr 7, 2024

@grimoire We compared the prefix cache implementations of other projects:

  • vllm

    • Hash Table
    • computes a hash key for each block: hash(prefix tokens, tokens in this block)
    • block-level reuse: if seq1 is xxxxyyyy, seq2 is xxxxzzzz, seq3 is xxxxyyyyzzzz, and each block contains 4 tokens, then seq2 can reuse the first block of seq1 and seq3 can reuse 2 blocks of seq1
    • currently only supports prefix cache (xxxxxoooo), but plans to support general cache (xxxoooxxxooo) in the future
    • hash collisions may need to be considered
    • Complexity:
      • Assume N is the number of sequences and L is the length of a sequence
      • Time (Find & Insert): O(N*(L^2)), because computing the hash keys needs O(L^2), as mentioned here
      • Space: O(N*L)
  • rtp-llm

    • Hash Table
    • computes a hash key for each sequence: hash(tokens in sequence)
    • block-level reuse, like vllm
    • Complexity:
      • Time
        • Find: O((N^2)*L), due to token-level matching
        • Insert: O(N*L)
      • Space: O(N*L)
  • sglang

    • Radix Tree
    • can only support prefix cache (xxxxxoooo), not general cache (xxxoooxxxooo)
    • Complexity:
      • Time (Find & Insert): O(N*L)
      • Space: worst O(N*L), if no shared part
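
A minimal sketch of the vLLM-style block hashing described in the first bullet group above (illustrative only, not vLLM's actual code). Each block's key hashes every token up to and including that block, so sequences sharing a prefix map to the same keys:

```python
# Illustrative block-level hash prefix cache; not vLLM's actual implementation.
from typing import Dict, List, Tuple

BLOCK_SIZE = 4

def block_keys(tokens: List[int]) -> List[int]:
    """One key per full block: hash(prefix tokens + tokens in this block)."""
    return [hash(tuple(tokens[:end]))
            for end in range(BLOCK_SIZE, len(tokens) + 1, BLOCK_SIZE)]

class HashPrefixCache:
    def __init__(self) -> None:
        self.key_to_block: Dict[int, int] = {}
        self.next_block_id = 0

    def match_or_allocate(self, tokens: List[int]) -> Tuple[List[int], int]:
        """Return (block ids for this sequence, number of reused blocks)."""
        block_ids, reused = [], 0
        for key in block_keys(tokens):
            if key in self.key_to_block:
                reused += 1
            else:
                self.key_to_block[key] = self.next_block_id
                self.next_block_id += 1
            block_ids.append(self.key_to_block[key])
        return block_ids, reused

cache = HashPrefixCache()
seq1 = [1, 1, 1, 1, 2, 2, 2, 2]               # xxxxyyyy
seq2 = [1, 1, 1, 1, 3, 3, 3, 3]               # xxxxzzzz
seq3 = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]   # xxxxyyyyzzzz
print(cache.match_or_allocate(seq1))  # ([0, 1], 0): two new blocks
print(cache.match_or_allocate(seq2))  # ([0, 2], 1): reuses the first block of seq1
print(cache.match_or_allocate(seq3))  # ([0, 1, 3], 2): reuses two blocks of seq1
```

Hashing the growing prefix for every block is what gives the O(L^2) per-sequence cost noted above.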

@grimoire
Collaborator Author

grimoire commented Apr 7, 2024

When do we need a general cache?

@grimoire
Collaborator Author

grimoire commented Apr 7, 2024

@ispobock Do they support window attention? How do they evict blocks? Would it take a long time if we have a large number of blocks?

S-LoRA would increase the number of blocks (by using a small block size), and window attention would make block eviction more complex. I have failed to find a good solution.

@zhyncs
Collaborator

zhyncs commented Apr 7, 2024

@ispobock Do they support window attention? How do they evict blocks? Would it take a long time if we have a large number of blocks?

S-LoRA would increase the number of blocks (by using a small block size), and window attention would make block eviction more complex. I have failed to find a good solution.

In mistralai-sf24/hackathon, the sliding window has been removed: https://x.com/mistralailabs/status/1771670765521281370

@zhyncs
Collaborator

zhyncs commented Apr 7, 2024

And I think this approach is acceptable for now.

```python
if self.window_size > 1 and self.shared_cache:
    logger.warning(
        'Shared cache is not available for window attention.')
    self.shared_cache = False
```

@ispobock
Contributor

ispobock commented Apr 7, 2024

@grimoire

When do we need a general cache?

For example, with seq1: xxxxyyyyzzzz and seq2: yyyyzzzz, at 4 tokens per block, a general cache would let seq2 reuse the last 2 cached blocks of seq1.
It's mentioned in vllm's design, but I'm not sure about the real usage and implementation.

How do they evict blocks? Would it take a long time if we have a large number of blocks?

It seems all of them use reference counting + LRU as the eviction policy.
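
A minimal sketch of that reference-count + LRU idea (illustrative only; the block and step bookkeeping here is hypothetical and not taken from any of the projects above):

```python
# Illustrative reference-count + LRU block eviction.
from collections import OrderedDict

class BlockPool:
    def __init__(self, num_blocks: int) -> None:
        self.free = list(range(num_blocks))
        self.ref_count = {}             # block id -> number of sequences using it
        self.evictable = OrderedDict()  # unreferenced cached blocks, oldest first

    def allocate(self) -> int:
        if not self.free:
            # Evict the least recently released, unreferenced block.
            # Raises KeyError if every block is still referenced (cache full).
            victim, _ = self.evictable.popitem(last=False)
            self.free.append(victim)
        block = self.free.pop()
        self.ref_count[block] = 1
        return block

    def share(self, block: int) -> None:
        """A new sequence reuses a cached block (prefix hit)."""
        self.evictable.pop(block, None)  # referenced blocks must not be evicted
        self.ref_count[block] = self.ref_count.get(block, 0) + 1

    def release(self, block: int, step: int) -> None:
        """A sequence finishes; keep the block's KV cache around for reuse."""
        self.ref_count[block] -= 1
        if self.ref_count[block] == 0:
            self.evictable[block] = step  # becomes an LRU eviction candidate
```

Referenced blocks are never evicted; only cached-but-unreferenced blocks compete under LRU.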

@zhyncs
Collaborator

zhyncs commented Apr 7, 2024

And I think this approach is acceptable for now.

```python
if self.window_size > 1 and self.shared_cache:
    logger.warning(
        'Shared cache is not available for window attention.')
    self.shared_cache = False
```

ref https://github.com/vllm-project/vllm/pull/2762/files#r1495331586

@grimoire
Collaborator Author

grimoire commented Apr 7, 2024

Sure, let's ignore the sliding window for now.

It seems that the hash map does not bring much benefit to prefix matching. Eviction by blocks takes more time than eviction by node (sort by visit time, update ref-count/visit-time, update sequence status, ...).

But adding the new node concept into the scheduler made the code error-prone and hard to maintain.
Any advice?

@ispobock
Contributor

ispobock commented Apr 7, 2024

vllm didn't take the radix tree implementation due to the maintenance burden:

Major benefits of this design over a KV block Trie

  • Sometimes, caching is not limited to prefix caching:
    • With Mistral's sliding window attention, we only need to cache the last tokens in the sliding window.
    • With attention sinks, we need to cache the first few tokens and the latest tokens.
  • Maintaining hash table is simpler than maintaining a tree.
  • Extensible to more advanced caching policy (the one above is just an example).

In sglang, there is actually no block concept because each page holds exactly one token, which simplifies the implementation.
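
With one token per page, the radix tree degenerates to a token-level trie for prefix matching. A minimal sketch of that idea (illustrative only, not sglang's actual RadixCache, which compresses edges and tracks eviction metadata):

```python
# Token-level trie sketch of sglang-style prefix matching (illustrative only).
class TrieNode:
    def __init__(self) -> None:
        self.children = {}   # token id -> TrieNode
        self.kv_slot = None  # hypothetical handle to this token's cached KV

class TokenTrie:
    def __init__(self) -> None:
        self.root = TrieNode()

    def match_prefix(self, tokens):
        """Length of the leading run of tokens whose KV is already cached."""
        node, matched = self.root, 0
        for tok in tokens:
            if tok not in node.children:
                break
            node = node.children[tok]
            matched += 1
        return matched

    def insert(self, tokens, kv_slots):
        """Record cached KV slots for a finished sequence, token by token."""
        node = self.root
        for tok, slot in zip(tokens, kv_slots):
            node = node.children.setdefault(tok, TrieNode())
            node.kv_slot = slot

trie = TokenTrie()
trie.insert([1, 1, 2, 2], range(4))
print(trie.match_prefix([1, 1, 3]))  # -> 2: only the shared prefix is reused
```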

@lzhangzz
Collaborator

lzhangzz commented Apr 7, 2024

For example, with seq1: xxxxyyyyzzzz and seq2: yyyyzzzz, at 4 tokens per block, a general cache would let seq2 reuse the last 2 cached blocks of seq1.

In this case

  1. The positional embedding used for yyyyzzzz is offset by 4 positions (instead of starting from 0).
  2. xxxx, which was involved in the computation of xxxxyyyyzzzz, is ignored.

The result will differ from computing yyyyzzzz directly. The outcome may be similar, but there is no guarantee.
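
A tiny numeric illustration of the mismatch (hypothetical positions, 4 tokens per block):

```python
# Why reusing non-prefix blocks changes the result (illustrative only).
# In seq1 = xxxxyyyyzzzz, the cached KV for "yyyyzzzz" was computed at positions
# 4..11 and attended to the leading "xxxx"; computing seq2 = yyyyzzzz from
# scratch uses positions 0..7 and has no "xxxx" to attend to.
cached_positions = list(range(4, 12))   # positions baked into seq1's cached blocks
fresh_positions = list(range(8))        # positions seq2 actually needs
assert cached_positions != fresh_positions
```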

@zhyncs
Collaborator

zhyncs commented Apr 8, 2024

vllm didn't take the radix tree implementation due to the maintenance burden:

Major benefits of this design over a KV block Trie

  • Sometimes, caching is not limited to prefix caching:

    • With Mistral's sliding window attention, we only need to cache the last tokens in the sliding window.
    • With attention sinks, we need to cache the first few tokens and the latest tokens.
  • Maintaining hash table is simpler than maintaining a tree.

  • Extensible to more advanced caching policy (the one above is just an example).

In sglang, there is actually no block concept because each page holds exactly one token, which simplifies the implementation.

Hi @grimoire, do you have any suggestions?

@grimoire
Collaborator Author

grimoire commented Apr 8, 2024

Maintaining hash table is simpler than maintaining a tree.

That's true, especially when the block size is not 1. In this PR, a node is a wrapper around a sequence with meta info. I wanted to share the same block management code to ease the implementation, but it ... sucks.

I want to try the block-based strategy. I guess it would take a long time to design and prototype since I don't want to break any features that already exist.

@zhyncs
Collaborator

zhyncs commented Apr 8, 2024

Hi @grimoire, I would like to know: is this PR currently ready for normal use? Thanks.

@grimoire
Collaborator Author

grimoire commented Apr 8, 2024

@zhyncs Yes, this is not a draft.

@zhyncs
Collaborator

zhyncs commented Apr 9, 2024

ref #1407 (comment)

@grimoire grimoire changed the title Torch engine prefix cacheing Torch engine prefix caching Apr 11, 2024
@grimoire grimoire marked this pull request as draft April 12, 2024 07:40
@zhyncs
Collaborator

zhyncs commented Apr 18, 2024

@grimoire We compared the prefix cache implementations of other projects:

  • vllm

    • Hash Table

    • computes a hash key for each block: hash(prefix tokens, tokens in this block)

    • block-level reuse: if seq1 is xxxxyyyy, seq2 is xxxxzzzz, seq3 is xxxxyyyyzzzz, and each block contains 4 tokens, then seq2 can reuse the first block of seq1 and seq3 can reuse 2 blocks of seq1

    • currently only supports prefix cache (xxxxxoooo), but plans to support general cache (xxxoooxxxooo) in the future

    • hash collisions may need to be considered

    • Complexity:

      • Assume N is the number of sequences and L is the length of a sequence
      • Time (Find & Insert): O(N*(L^2)), because computing the hash keys needs O(L^2), as mentioned here
      • Space: O(N*L)
  • rtp-llm

    • Hash Table

    • computes a hash key for each sequence: hash(tokens in sequence)

    • block-level reuse, like vllm

    • Complexity:

      • Time

        • Find: O((N^2)*L), due to token-level matching
        • Insert: O(N*L)
      • Space: O(N*L)

  • sglang

    • Radix Tree

    • can only support prefix cache (xxxxxoooo), not general cache (xxxoooxxxooo)

    • Complexity:

      • Time (Find & Insert): O(N*L)
      • Space: worst O(N*L), if no shared part

After sgl-project/sglang#364, the RPS of SGLang's radix tree implementation increased by nearly 10%.

@lvhan028 lvhan028 closed this May 7, 2024
@merrymercy

Very good discussion here. ref vllm-project/vllm#2614 (comment)
