-
Notifications
You must be signed in to change notification settings - Fork 10.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
llama : custom attention mask + parallel decoding + no context swaps #3228
Conversation
d4cd263
to
1fb033f
Compare
data[i] = n_past + i; | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@xaedes I'm changing the API of ggml_rope
to take an entire vector with positions instead of n_past
. I have a small concern about this particular change in train-text-from-scratch
and cannot test it atm. I'm not sure if the allocator won't make some intermediate results to overwrite the data of KQ_pos
at some point.
In other places, we fix this using ggml_allocr_alloc()
:
Lines 2431 to 2439 in 1fb033f
// KQ_pos - contains the positions | |
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); | |
ggml_allocr_alloc(lctx.alloc, KQ_pos); | |
if (!ggml_allocr_is_measure(lctx.alloc)) { | |
int * data = (int *) KQ_pos->data; | |
for (int i = 0; i < N; ++i) { | |
data[i] = n_past + i; | |
} | |
} |
But wasn't sure if it's applicable here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
During training (finetune and train-text-from-scratch) n_past is always zero, so I guess KQ_pos would always be empty.
not sure if the allocator won't make some intermediate results to overwrite the data
To avoid deallocation of certain tensors T until the end of computation, I added a temporary scale_inplace(T, 1.0f) operation at the end of the computation graph before giving it to the allocator. With this the allocator cannot deallocate T before the original end of the graph. Those temporary operations are removed from the graph after allocations are done, so that they are not actually executed.
For example here:
llama.cpp/examples/finetune/finetune.cpp
Line 768 in 5ce74ee
// make sure some tensors are not reallocated by inserting new temporary nodes depending on them |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hah, clever solution :) I added a scale op for KQ_pos
to be safe.
During training (finetune and train-text-from-scratch) n_past is always zero, so I guess KQ_pos would always be empty.
Btw, when n_past == 0
, the KQ_pos
tensor would have values 0, 1, 2, 3, ...
(i.e. n_past + i
).
See my comment at #2060 (comment). We should consider implementing non-RoPEd KV cache in the future. I hope these changes wouldn't make that more difficult. |
@cebtenzzre Shouldn't be more difficult than before this change. The main issue is Edit: here is a PoC #3234 |
57cea73
to
fad5693
Compare
6289ed6
to
0cbf3bf
Compare
I don't think this would be too bad. The main issue that would arise from RoPeing the KV cache on-the-fly would be that trigonometric functions are relatively expensive in terms of compute. For large batches you are compute bound but you still need to compute RoPE only once per eval so the compute per token should still be acceptable. Conversely, for small batches you are I/O bound so RoPEing the KV cache on-the-fly after loading it should not make a difference. |
976ff05
to
5bda9e2
Compare
5bda9e2
to
0161372
Compare
I'm still thinking about the non-RoPEd K cache, but I guess I'm more leaning towards the current solution. I updated the graph to perform a "shift" of the K-cache when RoPE is involved via an optional call to Lines 2750 to 2762 in 897cacc
Since Overall, I think it works pretty well. There are some tricky things still happening and the devs have to be very careful. For example, the KV cache is now fully manually managed, so we have to be careful when to clear it and to clear it correctly. I'll be looking the following days to polish the API and make things less error-prone. The branch should be ready for testing on CPU and Metal. It is lacking CUDA support, which I will try to implement, but would appreciate any help with it. |
…gerganov#3228) * tests : verify that RoPE is "additive" * llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask) * ggml : ggml_rope now takes a vector with positions instead of n_past * metal : add rope_f16 kernel + optimize cpy kernels * llama : unified KV cache + batch inference API * llama : add new llama_decode() API that works with llama_batch * llama : add cell_max heuristic for more efficient kv_cache * llama : extend llama_kv_cache API * llama : more robust cell_max heuristic + wip shift * metal : disable concurrency optimization * llama : add llama_kv_cache_shift_seq + no more context swaps * llama : apply K-cache roping for Falcon and Baichuan * speculative : fix KV cache management * parallel : example for serving multiple users in parallel * parallel : disable hot-plug to avoid cache fragmentation * fixes : speculative KV cache + llama worst-case graph * llama : extend batch API to select which logits to output * llama : fix worst case graph build * ggml-cuda : update rope implementation for parallel decoding (ggerganov#3254) * ggml-cuda : update rope implementation for parallel decoding * better solution for p0 computation * fix rope * simpler rope implementation --------- Co-authored-by: Georgi Gerganov <[email protected]> * make : add parallel to build + fix static functions in llama.cpp * simple : fix token counting * parallel : various improvements * llama : fix cell_max logic + rename functions * parallel : try smaller batches when the KV cache is fragmented * parallel : fix sequence termination criteria * llama : silence errors KV cache errors * parallel : remove new line from prompt * parallel : process system prompt once + configurable paramters + llama API * parallel : remove question with short answers * parallel : count cache misses * parallel : print misses on each request * parallel : minor * llama : fix n_kv to never become 0 * parallel : rename hot-plug to continuous-batching * llama : improve llama_batch API + simplify parallel example * simple : add parallel decoding support * simple : improve comments + free batch * ggml-cuda : add rope f16, restore performance with parallel decoding (ggerganov#3272) * ggml-cuda : add rope f16, restore performance * offload KQ_mask with all models * fix rope shift --------- Co-authored-by: Georgi Gerganov <[email protected]> * llama : disable MPI for now ggml-ci * train : make KQ_pos memory buffer permanent via dummy scale op * ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (ggerganov#3275) ggml-ci * parallel : fix bug (extra BOS) + smaller token_prev array * parallel : fix cases where the input prompts can overflow the batch * parallel : add disabled experimental batch chunking in powers of two * llama : llama.h formatting + comments * simple : add README.md * llama : fix kv cache heuristic when context is less than 32 * parallel : fix crash when `-n -1` * llama : simplify returns if/else branches * metal : use mm kernels for batch size > 2 * examples : utilize new llama_get_logits_ith() * examples : add example for batched decoding * examples : do not eval prompt 2 times (close ggerganov#3348) * server : clear the KV cache beyond n_past before llama_decode * server : avoid context swaps by shifting the KV cache --------- Co-authored-by: slaren <[email protected]>
Is the intent of the cache shifting to allow for arbitrary shifting of data in any direction? I've recently run into an issue with receiving garbage data from the model, and it appears to be related to shifting cells with a + delta. I know the immediate use case is to - shift to account for context overflow, however I was + shifting to insert data into the prompt, and that seems to be what's causing everything to break down. I know its not a standard use case, but since I don't see anything in the code preventing use of a + delta, I assumed it was supported. I've been trying to dig in and see if I can figure out exactly what's going wrong, but I'm getting in a bit over my head in the code trying to chase it down. |
I see the warning about using world info, but how exactly does context shifting work with world info? If a world entry is triggered and is added to the exact middle of the context, for example, then what portions of the prompt need re-evaluation? Is it just the world entry itself, or is it all context that follows the world entry? |
close: #2060 #2813
ref: #3137
Merge ETA: ~ 27 - 30 Sep
Here we will attempt to add batched (a.k.a. parallel, multi-sequence) decoding to
llama.cpp
.For a short summary of the plan, see #3137 (comment)
Features:
API changes:
struct llama_batch
llama_batch_init
llama_batch_free
llama_batch_get_one
for easy API migrationllama_get_kv_cache_token_count
llama_kv_cache_tokens_rm
llama_kv_cache_seq_rm
llama_kv_cache_seq_cp
llama_kv_cache_seq_keep
llama_kv_cache_seq_shift
llama_eval
andllama_eval_embd
llama_decode
llama_get_logits_ith
Demo
No context swaps
Previously, we had to re-evaluate the context when it becomes full and this could take a lot of time, especially on the CPU.
Now, this is avoided by correctly updating the KV cache on-the-fly:
Parallel decoding - basic example
examples/batched
The
batched
example has been extended with an argument to specify the number of sequences to generate using the given prompt. This is a good starting point for understanding the newllama_batch
API introduced in this PR.Generated results
Parallel decoding - server simulation
examples/parallel
Decoding 32 parallel streams for a total of 128 sequences with continuous batching in about ~26s. Token summary:
Using 7B LLaMA v2 Q8_0 model on single RTX 4080 with 16GB VRAM.
parallel-cuda-1-lq.mp4
For comparison, generating the same amount of sequences using a single stream (i.e.
-np 1
) takes ~125s on the same hardware with the same model.When using the
parallel
example, you must make sure that your KV cache will have enough size to fit all parallel requests. If it is not able to fit them, then you will start seeing cache misses that slow down the processing significantly and can eventually fail the decoding if there is no more space in the cache.To set the KV cache size, use the
-c, --context
parameter. For example, for 32 parallel streams that are expected to generate a maximum of 128 tokens each (i.e.-n 128
), you would need to set-c 4096
(i.e. 32*128). If continuous batching is enabled, you would need some extra KV space to deal with fragmentation of the cache. In the example above, we conveniently set the context size to 8192 to guarantee that there will be no issues.Also, when computing the KV cache size you should also take into account the size of the system prompt (if any).
We can run a similar test on M2 Ultra. This time, since we have practically infinite VRAM, we can afford to run 128 parallel streams instead of just 32. And we can also use the F16 model instead of the quantum Q8_0. In this case, continuous batching is not needed, since we will be processing all of the 128 requests at once:
parallel-ultra-0-speed-up-x2.mp4
(video is speed-up to fit in Github 10MB limit)
Implementation details
KV cache as ring-buffer
This is one of the major changes. Previously, the
n_past
number indicated what part of the KV cache to use in the computation. The assumption was that all cached data would be stored in the cache starting from the first cell forward and that this data belongs to one sequence only. We now want to store information from multiple sequences in the KV cache and the data from each sequence can be located in arbitrary cells. To achieve that, the KV cache now stores information about the sequences to which the cache data belongs and the corresponding position in these sequences:llama.cpp/llama.cpp
Lines 1009 to 1018 in b377bf2
With this change, we now have to attend to the entire cache when computing the self attention, because we don't know where the data for the sequences being processed is located. So the naive solution is when we build the graph to set
n_kv = n_ctx
. However, there is a simple heuristic that can be introduced in order to restore to most extend the behavior ofn_past
where we attend to just the first part of the cache:llama.cpp/llama.cpp
Lines 2620 to 2624 in b377bf2
Here,
kv_self.n
contains the index of the last "occupied" cell in the cache. I.e. we know that all cells beyondkv_self.n
are currently empty, so no need to attend to them. We naively compute it each time we build a graph:llama.cpp/llama.cpp
Lines 4101 to 4107 in b377bf2
This heuristic is easy to implement and gives a measurable improvement, allowing to keep the performance the same as on
master
for single-sequence generations. It stops having any effect once the KV cache has been filled. In theory, we can improve this even further by computingn_kv_0
andn_kv_1
- the KV range that contains the cached tokens for all currently decoded sequences. We just need to find an elegant way to incorporate this information in the KV view offsets. (i.e.n_kv = n_kv_1 - n_kv_0
, etc.).Workplan
tests/test-rope.cpp
)ggml_rope
to take a tensor with positions instead of justn_past
rope_f16
kernelrope_f32
kernel to work with vector of positionsrope_f16
kernel (needed for K-cache shifts)ggml_add
for applying a custom-INF
mask to a tensor instead of the existingggml_diag_mask_inf
add
support broadcast across dims 1,2,3ggml
API for RoPE and maskllama.h
API for passing multi-sequence datallama_kv_shift()
for "forgetting" old tokens, re-roping (if needed) the KV cache and compressing itTODOs
fix MPIwill remain for a future PRggml_alibi()
, replace withggml_add()
(similar toggml_diag_mask_inf
, Baichuan 13B)Performance
I decided to disable the concurrency optimization since it prevents from modifying the graph topology. With the changes in this PR, we now need to sometimes add extra nodes for K-cache shift, so it is no longer compatible. The result is ~8% slower prompt processing but ~5% faster text generations with Metal.
M2 Ultra
build: 897cacc (1272)