Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : custom attention mask + parallel decoding + no context swaps #3228

Merged
merged 57 commits into from
Sep 28, 2023

Conversation

ggerganov
Copy link
Owner

@ggerganov ggerganov commented Sep 17, 2023

close: #2060 #2813
ref: #3137

Merge ETA: ~ 27 - 30 Sep

Here we will attempt to add batched (a.k.a. parallel, multi-sequence) decoding to llama.cpp.
For a short summary of the plan, see #3137 (comment)

Features:

  • parallel decoding with common prompt
  • parallel decoding with separate prompts
  • parallel decoding with continuous batching
  • tree-based parallel decoding
  • avoided the slow context swaps when the context becomes full

API changes:

  • add batch API:
    • add struct llama_batch
    • add llama_batch_init
    • add llama_batch_free
    • add temporary helper llama_batch_get_one for easy API migration
  • deprecate llama_get_kv_cache_token_count
  • add KV cache management API:
    • llama_kv_cache_tokens_rm
    • llama_kv_cache_seq_rm
    • llama_kv_cache_seq_cp
    • llama_kv_cache_seq_keep
    • llama_kv_cache_seq_shift
  • update decoding API:
    • deprecate llama_eval and llama_eval_embd
    • add llama_decode
    • add llama_get_logits_ith

Demo

No context swaps

Previously, we had to re-evaluate the context when it becomes full and this could take a lot of time, especially on the CPU.
Now, this is avoided by correctly updating the KV cache on-the-fly:

# run on the CPU using a small context of 256 tokens
./main -m ./models/llama-7b/ggml-model-q4_0.gguf -p "I believe the meaning of life is" --ignore-eos -c 256 -n -1 -t 8 -ngl 0

Parallel decoding - basic example

examples/batched

The batched example has been extended with an argument to specify the number of sequences to generate using the given prompt. This is a good starting point for understanding the new llama_batch API introduced in this PR.

# the prompt is "Hello my name is" and the number of sequences that will be generated is 8
./batched ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" 8
Generated results
main: n_len = 32, n_ctx = 2048, n_parallel = 8, n_kv_req = 221

 Hello my name is

main: generating 8 sequences ...

main: stream 0 finished
main: stream 1 finished
main: stream 2 finished
main: stream 3 finished
main: stream 4 finished
main: stream 5 finished
main: stream 6 finished
main: stream 7 finished

sequence 0:

Hello my name is Sharon. I am looking for a caregiver job in Austin, Texas. I'm a really good candidate you'll

sequence 1:

Hello my name is Renee. I am a 20 year old female who has been working with children for several years. I have experience working

sequence 2:

Hello my name is Diana. I am 25 years old and I am from the Philippines. I have been working as a nanny for 

sequence 3:

Hello my name is Cathy and I'm from the UK. I'm 26 years old and I'm a student. I

sequence 4:

Hello my name is Chelsea. I'm 27 years old and live in Florida (United States). I'm looking for someone

sequence 5:

Hello my name is Katherine, I'm 27 years old and I'm from the UK. I'm a very friendly, out

sequence 6:

Hello my name is Tiffany. I am a 21 year old student at the University of South Carolina. I am a certified nurs

sequence 7:

Hello my name is Renee, I am a mother of two beautiful children. I am a stay at home mom and I have been a nanny

main: decoded 216 tokens in 1.52 s, speed: 141.83 t/s

llama_print_timings:        load time =   489.06 ms
llama_print_timings:      sample time =     4.48 ms /   224 runs   (    0.02 ms per token, 49988.84 tokens per second)
llama_print_timings: prompt eval time =  1962.52 ms /   226 tokens (    8.68 ms per token,   115.16 tokens per second)
llama_print_timings:        eval time =     0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =  2011.64 ms

ggml_metal_free: deallocating

Parallel decoding - server simulation

examples/parallel

Decoding 32 parallel streams for a total of 128 sequences with continuous batching in about ~26s. Token summary:

  • System: 305 (encoded once at the start)
  • Prompt: 2039 (summed over all inputs)
  • Decoded: 7986 (summed over all sequences)

Using 7B LLaMA v2 Q8_0 model on single RTX 4080 with 16GB VRAM.

./bin/parallel -m ../models/llama-7b-v2/ggml-model-q8_0.gguf -n 128 -t 1 -ngl 100 -c 8192 -b 512 -s 1 -np 32 -ns 128 -cb
parallel-cuda-1-lq.mp4

For comparison, generating the same amount of sequences using a single stream (i.e. -np 1) takes ~125s on the same hardware with the same model.

When using the parallel example, you must make sure that your KV cache will have enough size to fit all parallel requests. If it is not able to fit them, then you will start seeing cache misses that slow down the processing significantly and can eventually fail the decoding if there is no more space in the cache.

To set the KV cache size, use the -c, --context parameter. For example, for 32 parallel streams that are expected to generate a maximum of 128 tokens each (i.e. -n 128), you would need to set -c 4096 (i.e. 32*128). If continuous batching is enabled, you would need some extra KV space to deal with fragmentation of the cache. In the example above, we conveniently set the context size to 8192 to guarantee that there will be no issues.

Also, when computing the KV cache size you should also take into account the size of the system prompt (if any).

We can run a similar test on M2 Ultra. This time, since we have practically infinite VRAM, we can afford to run 128 parallel streams instead of just 32. And we can also use the F16 model instead of the quantum Q8_0. In this case, continuous batching is not needed, since we will be processing all of the 128 requests at once:

./bin/parallel -m ../models/llama-7b-v2/ggml-model-f16.gguf -n 128 -t 1 -ngl 100 -c 16384 -b 512 -s 1 -np 128 -ns 128
parallel-ultra-0-speed-up-x2.mp4

(video is speed-up to fit in Github 10MB limit)

Implementation details

  • KV cache as ring-buffer

    This is one of the major changes. Previously, the n_past number indicated what part of the KV cache to use in the computation. The assumption was that all cached data would be stored in the cache starting from the first cell forward and that this data belongs to one sequence only. We now want to store information from multiple sequences in the KV cache and the data from each sequence can be located in arbitrary cells. To achieve that, the KV cache now stores information about the sequences to which the cache data belongs and the corresponding position in these sequences:

    llama.cpp/llama.cpp

    Lines 1009 to 1018 in b377bf2

    struct llama_kv_cell {
    llama_pos pos = -1;
    llama_pos delta = 0;
    std::set<llama_seq_id> seq_id;
    bool has_seq_id(const llama_seq_id & id) const {
    return seq_id.find(id) != seq_id.end();
    }
    };

    With this change, we now have to attend to the entire cache when computing the self attention, because we don't know where the data for the sequences being processed is located. So the naive solution is when we build the graph to set n_kv = n_ctx. However, there is a simple heuristic that can be introduced in order to restore to most extend the behavior of n_past where we attend to just the first part of the cache:

    llama.cpp/llama.cpp

    Lines 2620 to 2624 in b377bf2

    const int32_t n_tokens = batch.n_tokens;
    const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n;
    const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;

    Here, kv_self.n contains the index of the last "occupied" cell in the cache. I.e. we know that all cells beyond kv_self.n are currently empty, so no need to attend to them. We naively compute it each time we build a graph:

    llama.cpp/llama.cpp

    Lines 4101 to 4107 in b377bf2

    // a heuristic, to avoid attending the full cache if it is not yet utilized
    // after enough generations, the benefit from this heuristic disappears
    // if we start defragmenting the cache, the benefit from this will be more important
    //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA?
    kv_self.n = std::max(32, llama_kv_cache_cell_max(kv_self));

    This heuristic is easy to implement and gives a measurable improvement, allowing to keep the performance the same as on master for single-sequence generations. It stops having any effect once the KV cache has been filled. In theory, we can improve this even further by computing n_kv_0 and n_kv_1 - the KV range that contains the cached tokens for all currently decoded sequences. We just need to find an elegant way to incorporate this information in the KV view offsets. (i.e. n_kv = n_kv_1 - n_kv_0, etc.).

Workplan

  • verify that RoPE is "additive" (tests/test-rope.cpp)
  • update ggml_rope to take a tensor with positions instead of just n_past
    • CPU
    • Metal
      • add rope_f16 kernel
    • CUDA
      • update rope_f32 kernel to work with vector of positions
      • add rope_f16 kernel (needed for K-cache shifts)
  • use ggml_add for applying a custom -INF mask to a tensor instead of the existing ggml_diag_mask_inf
    • CPU
    • Metal
      • kernel add support broadcast across dims 1,2,3
    • CUDA
  • update the graph to utilize the new ggml API for RoPE and mask
  • extend the KV cache to store position and sequence ID information
  • extend llama.h API for passing multi-sequence data
  • add llama_kv_shift() for "forgetting" old tokens, re-roping (if needed) the KV cache and compressing it
  • add example for generating N completions in parallel for a given prompt

TODOs

  • fix MPI will remain for a future PR
  • deprecate ggml_alibi(), replace with ggml_add() (similar to ggml_diag_mask_inf, Baichuan 13B)
  • fix KV cache fragmentation with continuous batching

Performance

I decided to disable the concurrency optimization since it prevents from modifying the graph topology. With the changes in this PR, we now need to sometimes add extra nodes for K-cache shift, so it is no longer compatible. The result is ~8% slower prompt processing but ~5% faster text generations with Metal.

M2 Ultra

model size th test master t/s PR t/s speedup
LLaMA 7B mostly F16 12.55 GiB 4 pp 512 1490.37 ± 1.29 1381.20 ± 1.64 0.927
LLaMA 7B mostly Q8_0 6.67 GiB 4 pp 512 1326.17 ± 0.49 1231.50 ± 0.99 0.929
LLaMA 7B mostly Q4_0 3.56 GiB 4 pp 512 1355.75 ± 0.30 1258.24 ± 0.77 0.928
LLaMA 7B mostly Q4_1 3.95 GiB 4 pp 512 1351.95 ± 0.58 1256.14 ± 0.97 0.929
LLaMA 7B mostly Q6_K 5.15 GiB 4 pp 512 1106.49 ± 0.45 1033.03 ± 0.82 0.934
LLaMA 7B mostly Q5_K - Medium 4.45 GiB 4 pp 512 1103.09 ± 0.65 1028.90 ± 1.40 0.933
LLaMA 7B mostly Q5_K - Small 4.33 GiB 4 pp 512 1102.25 ± 0.33 1027.88 ± 0.83 0.933
LLaMA 7B mostly Q4_K - Medium 3.80 GiB 4 pp 512 1168.98 ± 0.60 1089.69 ± 1.08 0.932
LLaMA 7B mostly Q4_K - Small 3.59 GiB 4 pp 512 1178.79 ± 0.72 1097.91 ± 0.56 0.931
LLaMA 7B mostly Q3_K - Medium 3.07 GiB 4 pp 512 1142.54 ± 0.65 1064.15 ± 0.39 0.931
LLaMA 7B mostly Q3_K - Small 2.75 GiB 4 pp 512 1119.60 ± 0.62 1044.46 ± 0.53 0.933
LLaMA 7B mostly F16 12.55 GiB 4 tg 128 40.94 ± 0.04 41.39 ± 0.04 1.011
LLaMA 7B mostly Q8_0 6.67 GiB 4 tg 128 64.77 ± 0.05 67.84 ± 0.03 1.047
LLaMA 7B mostly Q4_0 3.56 GiB 4 tg 128 91.14 ± 0.10 96.55 ± 0.09 1.059
LLaMA 7B mostly Q4_1 3.95 GiB 4 tg 128 85.97 ± 0.10 89.57 ± 0.10 1.042
LLaMA 7B mostly Q6_K 5.15 GiB 4 tg 128 71.44 ± 0.04 74.77 ± 0.06 1.047
LLaMA 7B mostly Q5_K - Medium 4.45 GiB 4 tg 128 72.56 ± 0.05 75.44 ± 0.06 1.040
LLaMA 7B mostly Q5_K - Small 4.33 GiB 4 tg 128 74.00 ± 0.06 76.96 ± 0.10 1.040
LLaMA 7B mostly Q4_K - Medium 3.80 GiB 4 tg 128 83.71 ± 0.16 87.92 ± 0.12 1.050
LLaMA 7B mostly Q4_K - Small 3.59 GiB 4 tg 128 87.05 ± 0.08 91.49 ± 0.08 1.051
LLaMA 7B mostly Q3_K - Medium 3.07 GiB 4 tg 128 83.47 ± 0.14 87.62 ± 0.10 1.050
LLaMA 7B mostly Q3_K - Small 2.75 GiB 4 tg 128 85.14 ± 0.14 88.23 ± 0.08 1.036

build: 897cacc (1272)

@ggerganov ggerganov force-pushed the custom-attention-mask branch from d4cd263 to 1fb033f Compare September 17, 2023 18:17
data[i] = n_past + i;
}
}

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xaedes I'm changing the API of ggml_rope to take an entire vector with positions instead of n_past. I have a small concern about this particular change in train-text-from-scratch and cannot test it atm. I'm not sure if the allocator won't make some intermediate results to overwrite the data of KQ_pos at some point.

In other places, we fix this using ggml_allocr_alloc():

llama.cpp/llama.cpp

Lines 2431 to 2439 in 1fb033f

// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_allocr_alloc(lctx.alloc, KQ_pos);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}

But wasn't sure if it's applicable here.

Copy link
Collaborator

@xaedes xaedes Sep 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

During training (finetune and train-text-from-scratch) n_past is always zero, so I guess KQ_pos would always be empty.

not sure if the allocator won't make some intermediate results to overwrite the data

To avoid deallocation of certain tensors T until the end of computation, I added a temporary scale_inplace(T, 1.0f) operation at the end of the computation graph before giving it to the allocator. With this the allocator cannot deallocate T before the original end of the graph. Those temporary operations are removed from the graph after allocations are done, so that they are not actually executed.
For example here:

// make sure some tensors are not reallocated by inserting new temporary nodes depending on them

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hah, clever solution :) I added a scale op for KQ_pos to be safe.

During training (finetune and train-text-from-scratch) n_past is always zero, so I guess KQ_pos would always be empty.

Btw, when n_past == 0, the KQ_pos tensor would have values 0, 1, 2, 3, ... (i.e. n_past + i).

@cebtenzzre
Copy link
Collaborator

See my comment at #2060 (comment). We should consider implementing non-RoPEd KV cache in the future. I hope these changes wouldn't make that more difficult.

@ggerganov
Copy link
Owner Author

ggerganov commented Sep 17, 2023

@cebtenzzre Shouldn't be more difficult than before this change. The main issue is updating ggml_rope to work with the transposed V data and potentially there is a concern about a performance hit due to having to RoPE the entire cache instead of just the new tokens. I'll probably give it a try within the scope of this PR, but first want to quickly finalize this version and make sure this entire approach is viable

Edit: here is a PoC #3234

@ggerganov ggerganov added high priority Very important issue need feedback Testing and feedback with results are needed labels Sep 18, 2023
@ggerganov ggerganov force-pushed the custom-attention-mask branch from 6289ed6 to 0cbf3bf Compare September 18, 2023 15:10
@JohannesGaessler
Copy link
Collaborator

JohannesGaessler commented Sep 18, 2023

The main issue is a concern about a performance hit due to having to RoPE the entire cache instead of just the new tokens.

I don't think this would be too bad. The main issue that would arise from RoPeing the KV cache on-the-fly would be that trigonometric functions are relatively expensive in terms of compute. For large batches you are compute bound but you still need to compute RoPE only once per eval so the compute per token should still be acceptable. Conversely, for small batches you are I/O bound so RoPEing the KV cache on-the-fly after loading it should not make a difference.

@ggerganov ggerganov force-pushed the custom-attention-mask branch from 976ff05 to 5bda9e2 Compare September 18, 2023 17:31
@ggerganov ggerganov force-pushed the custom-attention-mask branch from 5bda9e2 to 0161372 Compare September 18, 2023 17:37
@ggerganov
Copy link
Owner Author

ggerganov commented Sep 18, 2023

I'm still thinking about the non-RoPEd K cache, but I guess I'm more leaning towards the current solution. I updated the graph to perform a "shift" of the K-cache when RoPE is involved via an optional call to ggml_rope:

llama.cpp/llama.cpp

Lines 2750 to 2762 in 897cacc

// shift the entire K-cache if needed
if (do_rope_shift) {
ggml_build_forward_expand(gf,
ggml_rope_custom_inplace(ctx0,
ggml_view_3d(ctx0, kv_self.k,
n_embd_head, n_head_kv, n_ctx,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
K_shift, n_embd_head, 0, 0, freq_base, freq_scale));
}

Since ggml_rope now takes a vector of positions (actually, deltas of the position) instead of just n_past, we can apply custom shifts to the entire KV cache or portions of it via the new llama_kv_cache_shift_seq(). This integrates well with the unified KV cache idea also implemented here, where we store the cache from multiple independent sequences in the same buffer and utilize custom attention mask to attend to the correct data.

Overall, I think it works pretty well. There are some tricky things still happening and the devs have to be very careful. For example, the KV cache is now fully manually managed, so we have to be careful when to clear it and to clear it correctly.

I'll be looking the following days to polish the API and make things less error-prone. The branch should be ready for testing on CPU and Metal. It is lacking CUDA support, which I will try to implement, but would appreciate any help with it.

yusiwen pushed a commit to yusiwen/llama.cpp that referenced this pull request Oct 7, 2023
…gerganov#3228)

* tests : verify that RoPE is "additive"

* llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)

* ggml : ggml_rope now takes a vector with positions instead of n_past

* metal : add rope_f16 kernel + optimize cpy kernels

* llama : unified KV cache + batch inference API

* llama : add new llama_decode() API that works with llama_batch

* llama : add cell_max heuristic for more efficient kv_cache

* llama : extend llama_kv_cache API

* llama : more robust cell_max heuristic + wip shift

* metal : disable concurrency optimization

* llama : add llama_kv_cache_shift_seq + no more context swaps

* llama : apply K-cache roping for Falcon and Baichuan

* speculative : fix KV cache management

* parallel : example for serving multiple users in parallel

* parallel : disable hot-plug to avoid cache fragmentation

* fixes : speculative KV cache + llama worst-case graph

* llama : extend batch API to select which logits to output

* llama : fix worst case graph build

* ggml-cuda : update rope implementation for parallel decoding (ggerganov#3254)

* ggml-cuda : update rope implementation for parallel decoding

* better solution for p0 computation

* fix rope

* simpler rope implementation

---------

Co-authored-by: Georgi Gerganov <[email protected]>

* make : add parallel to build + fix static functions in llama.cpp

* simple : fix token counting

* parallel : various improvements

* llama : fix cell_max logic + rename functions

* parallel : try smaller batches when the KV cache is fragmented

* parallel : fix sequence termination criteria

* llama : silence errors KV cache errors

* parallel : remove new line from prompt

* parallel : process system prompt once + configurable paramters + llama API

* parallel : remove question with short answers

* parallel : count cache misses

* parallel : print misses on each request

* parallel : minor

* llama : fix n_kv to never become 0

* parallel : rename hot-plug to continuous-batching

* llama : improve llama_batch API + simplify parallel example

* simple : add parallel decoding support

* simple : improve comments + free batch

* ggml-cuda : add rope f16, restore performance with parallel decoding (ggerganov#3272)

* ggml-cuda : add rope f16, restore performance

* offload KQ_mask with all models

* fix rope shift

---------

Co-authored-by: Georgi Gerganov <[email protected]>

* llama : disable MPI for now

ggml-ci

* train : make KQ_pos memory buffer permanent via dummy scale op

* ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (ggerganov#3275)

ggml-ci

* parallel : fix bug (extra BOS) + smaller token_prev array

* parallel : fix cases where the input prompts can overflow the batch

* parallel : add disabled experimental batch chunking in powers of two

* llama : llama.h formatting + comments

* simple : add README.md

* llama : fix kv cache heuristic when context is less than 32

* parallel : fix crash when `-n -1`

* llama : simplify returns if/else branches

* metal : use mm kernels for batch size > 2

* examples : utilize new llama_get_logits_ith()

* examples : add example for batched decoding

* examples : do not eval prompt 2 times (close ggerganov#3348)

* server : clear the KV cache beyond n_past before llama_decode

* server : avoid context swaps by shifting the KV cache

---------

Co-authored-by: slaren <[email protected]>
crasm added a commit to crasm/ensemble that referenced this pull request Oct 17, 2023
@MrJackSpade
Copy link

Is the intent of the cache shifting to allow for arbitrary shifting of data in any direction?

I've recently run into an issue with receiving garbage data from the model, and it appears to be related to shifting cells with a + delta.

I know the immediate use case is to - shift to account for context overflow, however I was + shifting to insert data into the prompt, and that seems to be what's causing everything to break down.

I know its not a standard use case, but since I don't see anything in the code preventing use of a + delta, I assumed it was supported.

I've been trying to dig in and see if I can figure out exactly what's going wrong, but I'm getting in a bit over my head in the code trying to chase it down.

@GentlemanOfCulture
Copy link

I see the warning about using world info, but how exactly does context shifting work with world info? If a world entry is triggered and is added to the exact middle of the context, for example, then what portions of the prompt need re-evaluation? Is it just the world entry itself, or is it all context that follows the world entry?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
high priority Very important issue need feedback Testing and feedback with results are needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

llama : try to avoid context swap