Introduction of CUDA Graphs to LLama.cpp #6766
Conversation
Labelled as DRAFT since this will need some more testing across different models, CUDA versions, etc. before it is merged. See #6763.
EDIT: the job just failed: https://github.com/ggerganov/llama.cpp/actions/runs/8753504588/job/24040839682?pr=6766
The speedup does not seem to be consistent: for a batch size of 1 it's faster, but for batch sizes >> 1 it's slower. Still, the potential speedup is higher than I thought, so it seems I was wrong when I previously said using CUDA graphs would not be worthwhile. Error for perplexity:
Thanks for these tests. I haven't yet optimized/tested for batch size greater than one - it might be a good idea for me to only enable CUDA graphs for size 1 initially. I'll also look at the failures.
It's not obvious - even without CUDA graphs, llama.cpp already does a good job of pre-launching all kernels in the GGML graph, so CPU-side launch overheads are not the issue. But CUDA graphs also optimise GPU-side launch overheads, reducing the "gaps" between kernels, and that is the benefit we are seeing here.
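For readers less familiar with the API, the mechanism being discussed is roughly the sketch below (illustrative names and sizes, not the actual llama.cpp code): the kernel launches are recorded once via stream capture, instantiated into an executable graph, and then replayed with a single launch, which is what lets the driver tighten the gaps between kernels.

```cpp
// Minimal sketch of the pattern in question (illustrative names and sizes, not
// the actual llama.cpp code). d_x must point to at least 16*256 floats.
#include <cuda_runtime.h>

__global__ void scale_kernel(float * x) {
    x[blockIdx.x * blockDim.x + threadIdx.x] *= 2.0f;
}

void run_with_graph(float * d_x, cudaStream_t stream) {
    cudaGraph_t     graph;
    cudaGraphExec_t graph_exec;

    // Capture: the launches below are recorded into a graph, not executed.
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    for (int i = 0; i < 8; ++i) {
        scale_kernel<<<16, 256, 0, stream>>>(d_x);
    }
    cudaStreamEndCapture(stream, &graph);

    // Instantiate once; a single launch then replays all recorded kernels,
    // letting the driver minimise the gaps between them.
    cudaGraphInstantiate(&graph_exec, graph, 0); // CUDA 12 signature
    cudaGraphLaunch(graph_exec, stream);
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(graph_exec);
    cudaGraphDestroy(graph);
}
```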
Tried to add ROCm HIP compatibility but it errors with:
Here is my patch:
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index 481065b2..f1553f7c 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -117,6 +117,23 @@
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
+#define CUDA_KERNEL_NODE_PARAMS_v2 hipKernelNodeParams
+#define CUresult hipError_t
+#define cuGetErrorString hipDrvGetErrorString
+#define cuGraphKernelNodeGetParams hipGraphKernelNodeGetParams
+#define cuGraphKernelNodeSetParams hipGraphKernelNodeSetParams
+#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
+#define cudaGraphExec_t hipGraphExec_t
+#define cudaGraphGetNodes hipGraphGetNodes
+#define cudaGraphInstantiate hipGraphInstantiate
+#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
+#define cudaGraphLaunch hipGraphLaunch
+#define cudaGraphNode_t hipGraphNode_t
+#define cudaGraph_t hipGraph_t
+#define cudaKernelNodeParams hipKernelNodeParams
+#define cudaStreamBeginCapture hipStreamBeginCapture
+#define cudaStreamCaptureModeGlobal hipStreamCaptureModeGlobal
+#define cudaStreamEndCapture hipStreamEndCapture
#else
#include <cuda_runtime.h>
#include <cuda.h>
@@ -208,14 +225,12 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
-#if !defined(GGML_USE_HIPBLAS)
static const char * cu_get_error_str(CUresult err) {
const char * err_str;
cuGetErrorString(err, &err_str);
return err_str;
}
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
-#endif
#if CUDART_VERSION >= 11100
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
@@ -389,6 +404,16 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#endif
return c;
}
+
+struct cudaGraphExecUpdateResultInfo {
+ cudaGraphNode_t errorFromNode;
+ cudaGraphNode_t errorNode;
+ cudaGraphExecUpdateResult result;
+};
+
+static __host__ __forceinline__ cudaError_t cudaGraphExecUpdate(cudaGraphExec_t hGraphExec, cudaGraph_t hGraph, cudaGraphExecUpdateResultInfo* resultInfo ) {
+ return hipGraphExecUpdate(hGraphExec, hGraph, &resultInfo->errorNode, &resultInfo->result);
+}
#endif // defined(GGML_USE_HIPBLAS)
// TODO: move to ggml-common.h
It does not seem to work on the P40, and I cannot get it to compile on ROCm.
Good spot! I had this earlier when developing and somehow it got lost when refactoring/tidying, but it still worked. Now fixed.
@JohannesGaessler I think the llama-bench and perplexity issues should now be fixed with the latest commit - can you confirm from your end? Perplexity is slower with CUDA graphs ATM because it has batch size > 1 - as above, I think I should only enable CUDA graphs for batch size 1 initially.
The P40 issue may be due to the CUDA version in use; we need CUDA >= 12.0 for the functionality here. I have now added a macro such that this new code won't be compiled with earlier CUDA versions, and the original code will be used instead. Similarly, it now won't be compiled with HIP/ROCm (which can be added in a follow-up if there is adequate support and performance benefit on that platform).
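In other words, the gating is just a compile-time guard along these lines (a sketch; the macro name is illustrative, not the one in the diff):

```cpp
#include <cuda_runtime.h>

// Sketch of the compile-time gate (macro name illustrative): the CUDA graph path
// is only built for plain CUDA >= 12.0; HIP/ROCm and older CUDA keep the old path.
#if !defined(GGML_USE_HIPBLAS) && defined(CUDART_VERSION) && CUDART_VERSION >= 12000
#define EXAMPLE_USE_CUDA_GRAPH
#endif

static bool cuda_graphs_compiled_in(void) {
#ifdef EXAMPLE_USE_CUDA_GRAPH
    return true;   // graph capture/launch code is available in this build
#else
    return false;  // fall back to launching each kernel in the ggml graph directly
#endif
}
```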
I am using CUDA 12.4 with my P40, but this might have something to do with the hardware, right?
@agray3 OK, it seems to work now. I encounter an error during runtime:
@sorasoras thanks for testing. Can you let me know the exact command for which you are seeing a failure, so I can try to reproduce? I don't have access to a P40, but I have done a test on a P100 that works OK.
I am running batch inference for translations. My guess is it has some issue with continuous batching.
With your changes, it now works with ROCm HIP (with the patch linked below), but it is slower, making it likely not worth enabling on that platform. I'm using an RX 6700 XT.
The patch can be found here: https://gist.github.com/ardfork/a223a10d20961707e7b5f3ee0b76c7d5 - I didn't want to bloat your PR comments with a wall of text that will be useless for most people reading it.
Some single GPU test results:
There is a performance regression for P40. I was not able to run multi-GPU tests because of the error below.
Error with 3x P40:
It works decently on RDNA3 though. I forgot to record the value, but it's about 2% faster for token generation.
OK, thanks. I've now disabled CUDA graphs for multi-GPU and batch size > 1, which should prevent these crashes and regressions (I can investigate those cases later). I can also disable it for Pascal; I'll have a look at that tomorrow (also assessing Volta, etc.).
I've reproduced the llama-bench regression on Pascal (CC 6) and Volta (CC 7), so I've now added code to disable CUDA graphs for CC < 8. I've also added an env var:
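For reference, the runtime side of that gating is along these lines (a sketch; the environment-variable name below is a placeholder rather than the actual variable added here):

```cpp
#include <cstdlib>

// Sketch of the runtime checks (placeholder env var name): users get an explicit
// opt-out, and compute capability < 8 (Pascal/Volta) is excluded because of the
// regressions reported in this thread.
static bool cuda_graphs_allowed(int cc_major) {
    if (std::getenv("EXAMPLE_CUDA_DISABLE_GRAPHS") != nullptr) {
        return false;      // explicit opt-out via (placeholder) environment variable
    }
    return cc_major >= 8;  // skip Pascal (6.x) and Volta (7.x)
}
```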
Here are the results on my 7900 XTX:
Over those 4 models, your PR yields a significant but very small average TG speedup of 0.2%.
I just noticed that instead of using GitHub's built-in draft feature you added "DRAFT:" to the title. Please let me know when you think the PR is ready for review; I'll take a look then.
Unfortunately the graph is not exposed in the app until cudaStreamEndCapture(). You could possibly do something like this manually by introducing a wrapper around all kernel launches, plus some mechanism in that wrapper to keep track of things.
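To make the wrapper idea concrete, a hypothetical version (not something this PR implements) could append each launch as a kernel node to a manually built graph; the bookkeeping of parameters and dependencies in that wrapper is exactly the overhead alluded to above.

```cpp
#include <cuda_runtime.h>
#include <vector>

// Hypothetical launch wrapper: instead of launching, each call appends a kernel
// node to a manually built graph, chained serially after the previous node.
struct graph_builder {
    cudaGraph_t graph = nullptr;
    std::vector<cudaGraphNode_t> nodes;

    void init() {
        cudaGraphCreate(&graph, 0);
    }

    void add_kernel(void * kernel_func, dim3 grid, dim3 block, void ** args, unsigned shared_mem) {
        cudaKernelNodeParams params = {};
        params.func           = kernel_func;
        params.gridDim        = grid;
        params.blockDim       = block;
        params.sharedMemBytes = shared_mem;
        params.kernelParams   = args;

        cudaGraphNode_t node;
        const cudaGraphNode_t * dep = nodes.empty() ? nullptr : &nodes.back();
        cudaGraphAddKernelNode(&node, graph, dep, nodes.empty() ? 0 : 1, &params);
        nodes.push_back(node);
    }
};
```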
If I understand
Based on static code analysis I only have minor comments like `i++` vs. `++i` and the things I otherwise mentioned. Definitely nothing that I would consider worth blocking a merge for; let me also check the performance and correctness.
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
for (int j = 0; j < GGML_MAX_SRC; j++) {
    if (node->src[j] != nullptr) {
        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
Suggested change:
-assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
+GGML_ASSERT(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
These are correctness checks for ggml-backend that are only meant to be enabled in debug builds (that code was already there before). In normal circumstances this should never happen, so there is no need to check it in release builds.
}

#ifndef NDEBUG
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
Suggested change:
-assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+GGML_ASSERT(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
ggml-cuda.cu (outdated):
#if 0
if (disable_cuda_graphs_due_to_failed_capture) {
    use_cuda_graph = false;
    cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
#ifndef NDEBUG
    fprintf(stderr, "%s: disabling CUDA graphs due to failed graph capture\n", __func__);
#endif
} else {
    graph_evaluated_or_captured = true; // CUDA graph has been captured
}
#endif
What is the purpose of this code block?
This code used to check for graph capture failures, but I removed that since the implementation had issues. I think that there shouldn't be any capture failures if the ggml graph is checked correctly for incompatible ops early on. Also it uses relaxed capture mode now, which allows ops such as allocations to succeed even if they cannot be captured into the graph.
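Concretely, the difference is just the mode argument passed when starting the capture (a minimal sketch, with the kernel enqueueing elided):

```cpp
#include <cuda_runtime.h>

// Relaxed capture lets the capturing thread make calls that cannot be recorded
// into the graph (e.g. certain allocations); they execute normally instead of
// invalidating the capture, as described above.
static cudaGraph_t capture_relaxed(cudaStream_t stream) {
    cudaGraph_t graph = nullptr;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
    // ... enqueue the ggml graph's kernels on `stream` here ...
    cudaStreamEndCapture(stream, &graph);
    return graph;
}
```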
Yes, you are right, that should work - I've not tried it, but it should append. But as you say, I think it would have quite substantial overhead because there are a lot of kernels in each graph.
GPU | Model | Model Size [GiB] | Test | t/s master | t/s f42312e | Speedup |
---|---|---|---|---|---|---|
RTX 4090 | llama 8B F16 | 14.96 | tg128 | 56.68 | 57.85 | 1.02 |
RTX 4090 | llama 8B Q4_0 | 4.33 | tg128 | 149.28 | 159.67 | 1.07 |
According to my testing the results are bit-for-bit identical, the evaluation is simply faster. It's unfortunate that it right now only works for batch size 1 since that means it won't be compatible with >1 server slots or speculative decoding such as with #6828 (I'm not sure whether the constant switching between batch sizes causes issues but I consider that my own problem to fix).
I don't have Ampere hardware handy, but the PR builds and works on RTX 2060 as usual.
I have been testing the VRAM usage of the CUDA graphs, and for me it is ~10MB for 7B, ~12MB for 13B, and ~18MB for 30B. So I think it is low enough that it is unlikely to cause any regressions, and can be left enabled by default.
I'm guessing 30x0 is good too (also Ampere)...
Did anyone check server performance before and after this PR? I am seeing no difference in terms of user request throughput on 1x RTX 4090 even with a single server slot: #6828 (comment)
If it is being disabled for some reason, you can find out with a debug build. |
According to the log with debugging enabled this seems to be the reason:
I think the problem is that the graph is initially
The logic to determine when to disable graphs is not great, but it is necessary to avoid degrading performance when CUDA graphs cannot be used. For example, with pipeline parallelism tensor addresses change with every eval, so it must be disabled. But maybe prompt evals shouldn't increase the counter.
@JohannesGaessler can you let me know how to reproduce this? I can then try to relax these checks appropriately. Thanks.
The easiest way is to just run
Thanks. Yeah, I confirm it works if I comment out the line
From playing around with the code, it seems that, in
I think you can use the already existing code that checks for batch size > 1 to do this. I changed your implementation to check for an add operation instead of a softmax, but the idea is the same. Did you find it unreliable? llama.cpp knows when it is a prompt (or batch) evaluation when the call to
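For context, the kind of check being discussed looks roughly like the sketch below (against the public ggml structures; the exact op and dimension inspected in the PR may differ):

```cpp
#include "ggml.h"

// Sketch: treat the eval as a batch/prompt eval if some node's second dimension
// is > 1 (here an ADD node, as mentioned above); CUDA graphs are then skipped.
static bool is_batched_eval(const struct ggml_cgraph * cgraph) {
    for (int i = 0; i < cgraph->n_nodes; ++i) {
        const struct ggml_tensor * node = cgraph->nodes[i];
        if (node->op == GGML_OP_ADD && node->ne[1] > 1) {
            return true;  // more than one token in this eval
        }
    }
    return false;
}
```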
As discussed in PR ggml-org#6766, CUDA graphs were being disabled in the presence of long prompts. This fixes the issue by preventing the consecutive-update counter from incrementing unnecessarily for tokens in which CUDA graphs are disabled due to batch size > 1.
Thanks, I found the issue - the counter is being unnecessarily incremented even for tokens where graphs are disabled. See the simple fix at #7302.
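In pseudocode, the fix boils down to only counting consecutive updates for evals where the graph path is actually in use (a sketch; the names and the threshold only approximate the real code in ggml-cuda.cu):

```cpp
// Sketch of the #7302 fix: only count consecutive graph updates when CUDA graphs
// are actually used for this eval, so batch-size > 1 (prompt) tokens no longer
// push the counter over the "give up on graphs" threshold.
if (use_cuda_graph && cuda_graph_update_required) {
    cuda_graph->number_consecutive_updates++;
} else {
    cuda_graph->number_consecutive_updates = 0;
}
if (cuda_graph->number_consecutive_updates >= MAX_CONSECUTIVE_UPDATES) {
    cuda_graph->disable_due_to_too_many_updates = true;  // fall back to regular launches
}
```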
See Issue #6763