CUDA: mul_mat_vec_q tiling, refactor mul mat logic #5434
Conversation
I hate to say it, but this is one of the downsides of splitting a repo into as few files as possible - compilation is very serial and inefficient. cmake is probably worse right now because it seems to be building and archiving static libraries before compiling the examples and tests that depend on them.
In this case the reason for the high compilation time is the number of template instantiations. I am not sure that anything short of putting each kernel in a different source file would help with that. I would still prefer to split the sources into more files, even if just to ease working with the code, though. I find it very hard to work with a 10k LOC file.
This commit seems to fail badly on ROCm.
Edit: fixed in 2bb97fc.
Some results
Small script to automate using compare-llama-bench.py:

```bash
#!/bin/bash
set -e
set -x
if [ $# -lt 2 ]; then
echo "usage: ./scripts/compare-commits.sh <commit1> <commit2> [additional llama-bench arguments]"
exit 1
fi
bench_args="${@:3}"
rm -f llama-bench.sqlite
git checkout $1
make clean && LLAMA_CUBLAS=1 make -j32 llama-bench
./llama-bench -o sql $bench_args | tee /dev/tty | sqlite3 llama-bench.sqlite
git checkout $2
make clean && LLAMA_CUBLAS=1 make -j32 llama-bench
./llama-bench -o sql $bench_args | tee /dev/tty | sqlite3 llama-bench.sqlite
./scripts/compare-llama-bench.py -b $1 -c $2
```

Example usage:
Results on V100, RTX 2060 and A100

bash scripts/compare-commits.sh master cuda-faster-mmvq-12 -p 2,3,4,5,6,7,8 -n 1 -r 100 -m $(echo models-mnt/open-llama/7B-v2/ggml-model-*.gguf | sed -e "s/ /,/g")

V100: Device 0: Tesla V100-PCIE-16GB, compute capability 7.0, VMM: yes
RTX 2060
A100 SXM 80GB
Edit: updated to reflect changes after 76a0128.
My results:
I think I'll revert part of the changes that cause a small regression for a batch size of 1. I did this to reduce register pressure, but for that batch size it does not seem to be beneficial on average.
I reverted part of the changes. This should fix the regression for a batch size of 1. It may be possible to squeeze out 1-2% more performance by utilizing bit shifts for pointer arithmetic but then you'd have to compile 8 times more kernels.
Nvm, it already scales almost perfectly for F16 when there is enough memory bandwidth:

Device 0: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes
build: 2bb97fc (2112)
I think that the biggest issue right now is that the CUDA grids are just too small. On my RTX 3090, for q8_0, a 4096x4096 weight matrix, and a batch size of 1, ~10% of the kernel runtime is lost to the initial latency when launching the kernel and another ~5% is lost to tail effects. That's why I think fusing the branching matrix multiplication kernels, like the KQV ones which all use the same hidden state as input, would be beneficial. It would also save time on the conversion of the hidden state (something like 1% of the total runtime).

Other than that, you could potentially get better performance by fundamentally changing the data layout. If you were to separate the quantized data into blocks that only contain the quantized values or the scales, you should be able to load it more efficiently. But in the prototypes where I tried this, performance did not improve. An identical data layout between backends also makes the code much easier to work with. Or write a completely new kernel since
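For reference, a minimal sketch of what separating values and scales could mean for q8_0. The first struct paraphrases the current ggml block layout; the split view and its names are purely hypothetical, not anything that exists in the code:

```cuda
#include <cuda_fp16.h>
#include <cstdint>

// Current ggml q8_0 layout (roughly): the 2-byte scale is interleaved with the
// 32 quantized values of each block, so streaming the quants also drags the
// scales through the same loads.
#define QK8_0 32
typedef struct {
    half   d;          // per-block scale
    int8_t qs[QK8_0];  // quantized values
} block_q8_0;

// Hypothetical "split" layout discussed above (illustrative only): all quantized
// values in one contiguous array and all scales in another, so each stream could
// be loaded with wide, fully coalesced accesses.
typedef struct {
    const int8_t * qs; // ncols values per row, contiguous
    const half   * d;  // ncols/QK8_0 scales per row, contiguous
} q8_0_split_view;
```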
When it comes to scaling in particular the issue could also be related to the amount of compute. The scaling as the batch size increases on my P40s is much worse than on my RTX 3090 and I think the reason is that (relative to the memory bandwidth) modern cards just have way more compute. |
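A back-of-the-envelope way to see this (rough reasoning, ignoring activations and caches): for an M x K weight matrix stored with b bytes per element and a batch of n tokens, the matrix multiplication does about 2*M*K*n FLOPs while moving roughly b*M*K bytes of weights, so

    arithmetic intensity ~ 2n / b   [FLOP/byte]

The kernel stays bandwidth bound only while 2n/b is below the card's FLOPS-to-bandwidth ratio. That ratio is much higher on an RTX 3090 than on a P40, so the 3090 keeps scaling with the batch size over a wider range before compute becomes the limit.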
With the current version I get this performance:
Compared to the previous version the performance for a batch size of 1 should be the same as on master but the performance for larger batch sizes is slightly worse. I think prioritizing a batch size of 1 makes more sense since it's the most common use case.
Yes. So should we look for ways to improve the build time, or should we merge it like this? I wish we didn't have to special-case so many sizes and architectures, but it looks like this is how CUDA (GPU?) programming goes. I don't want the code to start taking 10 minutes to compile, so what options are there to improve this?
The main issue is that to get good performance you have to do loop unrolling. The compiler can then optimize out a lot of conditional statements and rearrange the instructions in a better way, but this simply takes time, especially if you do it for multiple loop lengths. The biggest reduction in compile time would be achieved by just splitting the code into multiple files. Currently you cannot parallelize the compilation at all, but if you were to split the code based on e.g. the 14 different data formats, you could compile it with 14 parallel jobs, which (given enough cores) should be much faster. Or add a compile option that reduces compile time at the cost of performance.
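To make the per-format split concrete, here is a hypothetical sketch of what one such translation unit could look like. The file name, the shared header, and launch_mul_mat_vec_q8_0 are all made up for illustration, and the real kernels take more parameters:

```cuda
// mmvq_q8_0.cu -- hypothetical: one such file per quantization type, so up to
// 14 of them could be compiled as independent, parallel jobs.
#include "mmvq_common.cuh" // assumed to hold the shared mul_mat_vec_q template

// Host-side launcher with external linkage: only this translation unit
// instantiates (and pays the unrolling cost for) the q8_0 kernels; the rest
// of the backend only sees the launcher declaration.
void launch_mul_mat_vec_q8_0(const void * vx, const void * vy, float * dst,
                             const int ncols_x, const int nrows_x,
                             const int ncols_y, cudaStream_t stream) {
    const dim3 block(256);
    const dim3 grid((nrows_x + block.x - 1)/block.x);

    switch (ncols_y) { // one fully unrolled instantiation per batch size 1-8
        case 1: mul_mat_vec_q<block_q8_0, 1><<<grid, block, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x); break;
        case 2: mul_mat_vec_q<block_q8_0, 2><<<grid, block, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x); break;
        // ... cases 3 through 8 analogous ...
        default: break; // batch sizes > 8 are handled by other kernels
    }
}
```

Since each of these files would be self-contained, touching one quantization format would also only trigger a rebuild of that one translation unit.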
As a data point, on my system with a Ryzen 5950X the compile time on master is 13.461 s, with this PR it's 18.945 s. Command used:
I guess it's more realistic if we moved the kernels into a separate header + source and built them multiple times for various template specializations based on
To do that we would need to keep a list of specializations in the build scripts, litter the code with ifdefs for each combination of template parameters, and additionally create a different function for each combination of parameters so that the template can be linked externally. This would be an insane amount of complexity.
Yeah, I already tried to prototype the idea and saw it does not make sense. (The main issue is actually the cases where the template argument is a function, e.g. dequantize_xxx, but it's not important.)
What we could do instead is separate the code by kernel. As of right now what takes up most of the time are the matrix multiplication kernels that deal with quantized data:
Wouldn't it be possible to define the templates in one file, include that file in the quantization-specific files (where the actual kernels get compiled), and to then include the quantization-specific files in
Moving each kernel to a different source file, together with its
Ok, let's consider reorganizing the CUDA backend into multiple files
* CUDA: mul_mat_vec_q tiling, refactor mul mat logic Co-authored-by: slaren <[email protected]> --------- Co-authored-by: slaren <[email protected]>
This PR does the following:

- Refactor ggml_cuda_mul_mat and simplify the logic for choosing between different matrix multiplication kernels. This also fixes P100s not being treated as having good FP16 performance.
- Add tiling to mul_mat_vec_q. This increases arithmetic intensity and results in higher t/s for batch sizes > 1. This, and a reduction in the number of warps for batch sizes 5-8 (to reduce register pressure), mostly fixes the performance regression sometimes seen when increasing batch sizes. I increased the maximum batch size for mul_mat_vec_q to 8. There are still some cases where the performance regresses slightly as you increase the batch size if you get unlucky with occupancy, but these cases should be rare now.
- Rework mul_mat_vec_q to use as few registers as possible so that more are available for loop unrolling. This increases performance on average, but it is again possible to get unlucky with occupancy, so there are also some cases where performance is ~1% worse.
- Remove the mul_mat_vec_q kernel with variable ncols_y. With a batch size of 8 the kernel seems to already be hitting its limits, and making ncols_y purely a template parameter makes it possible to reduce compilation time.

I'm currently too tired to re-run all of the performance tests; I'll post the results tomorrow.
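To illustrate the tiling idea, here is a deliberately simplified sketch with plain float weights rather than quantized blocks; it is not the actual mul_mat_vec_q kernel, just the register-tiling pattern over the columns of y:

```cuda
// Minimal illustration of tiling a matrix-vector product over ncols_y columns
// of y: each weight element loaded from x is reused ncols_y times against the
// partial sums held in registers, which raises arithmetic intensity for batch
// sizes > 1. One thread per output row; no warp-level tricks.
template <int ncols_y>
static __global__ void mat_vec_tiled(const float * x, const float * y, float * dst,
                                     const int ncols_x, const int nrows_x) {
    const int row = blockIdx.x*blockDim.x + threadIdx.x;
    if (row >= nrows_x) {
        return;
    }

    float sum[ncols_y] = {0.0f};

    for (int k = 0; k < ncols_x; ++k) {
        const float xv = x[row*ncols_x + k]; // loaded once ...
#pragma unroll
        for (int j = 0; j < ncols_y; ++j) {  // ... reused for every column of y
            sum[j] += xv * y[j*ncols_x + k];
        }
    }

#pragma unroll
    for (int j = 0; j < ncols_y; ++j) {
        dst[j*nrows_x + row] = sum[j];
    }
}
```

Instantiating one such fully unrolled kernel per (quantization type, ncols_y) combination is also where the extra compile time discussed below comes from.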
Because of the loop unrolling for the new kernels the compilation time has increased: on master the command reports 13.443 s, with this PR it's 18.691 s. It may make sense to add an option like LLAMA_FAST_COMPILE that reduces compilation time as much as possible.
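One hypothetical shape for such an option (only the flag name comes from the comment above; the constant and the dispatch function are illustrative): limit how many fully unrolled template instantiations get built and fall back to a generic kernel for the rest.

```cuda
#include <cstdio>

// Hypothetical sketch of a LLAMA_FAST_COMPILE switch: with the flag set, only
// batch size 1 gets a dedicated, fully unrolled kernel instantiation; batch
// sizes 2-8 would fall back to a single generic kernel that takes ncols_y as a
// runtime argument (cheaper to compile, somewhat slower to run).
#ifdef LLAMA_FAST_COMPILE
#define MMVQ_MAX_TEMPLATE_NCOLS 1
#else
#define MMVQ_MAX_TEMPLATE_NCOLS 8
#endif

static void launch_mmvq_example(const int ncols_y) {
    if (ncols_y <= MMVQ_MAX_TEMPLATE_NCOLS) {
        std::printf("dispatch to the unrolled template for ncols_y=%d\n", ncols_y);
    } else {
        std::printf("dispatch to a generic kernel with runtime ncols_y\n");
    }
}
```

Whether the resulting slowdown for batch sizes 2-8 is acceptable would of course need to be measured.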