
SYCL: Add gated linear attention kernel #11175

Merged: 3 commits into ggerganov:master, Jan 15, 2025

Conversation

@qnixsynapse (Contributor) commented Jan 10, 2025

Following #11001, this adds a gated linear attention kernel based on the logic of the CUDA kernel.

This is my first attempt at translating CUDA kernels to SYCL, so please excuse any mistakes.
test-backend-ops is passing for now.

I could not test the model (which is 32B) because of a lack of memory. Testing on an Nvidia GPU should give results.
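
For reference, the general shape of this kind of CUDA-to-SYCL translation is roughly the following (a minimal sketch only, not the actual gla.cpp code; the element-wise kernel and its names are illustrative):

```cpp
#include <sycl/sycl.hpp>

// CUDA launch:   kernel<<<grid, block>>>(dst, src, n);
// CUDA indexing: idx = blockIdx.x * blockDim.x + threadIdx.x;
// In SYCL the same grid/block structure becomes an nd_range, and the
// per-thread index comes from the nd_item.
static void launch_elementwise(sycl::queue & q, float * dst, const float * src, size_t n) {
    const size_t block = 256;                       // CUDA blockDim.x
    const size_t grid  = (n + block - 1) / block;   // CUDA gridDim.x

    q.parallel_for(
        sycl::nd_range<1>(sycl::range<1>(grid * block), sycl::range<1>(block)),
        [=](sycl::nd_item<1> it) {
            const size_t i = it.get_global_id(0);   // blockIdx.x * blockDim.x + threadIdx.x
            if (i < n) {
                dst[i] = src[i] * 2.0f;             // placeholder element-wise op
            }
        });
}
```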

@github-actions bot added the labels `ggml` (changes relating to the ggml tensor library for machine learning) and `SYCL` (https://en.wikipedia.org/wiki/SYCL - GPU programming language) on Jan 10, 2025
@Alcpz (Collaborator) left a comment

Thanks for the contribution! Overall the code looks good, though I found an issue. I will try to launch the complete model on an A100 once the barrier issue is fixed.

Review comment on ggml/src/ggml-sycl/gla.cpp (outdated, resolved)
@Alcpz (Collaborator) left a comment

I was able to run a small perf test with the model:

| model | size | params | backend | ngl | sm | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| rwkv6qwen2 32B Q4_K - Medium | 20.42 GiB | 34.74 B | SYCL | 99 | none | pp512 | 276.94 ± 0.29 |
| rwkv6qwen2 32B Q4_K - Medium | 20.42 GiB | 34.74 B | SYCL | 99 | none | tg128 | 21.67 ± 0.06 |

A run of llama-cli with the model seems fine. LGTM!

@Alcpz (Collaborator) commented Jan 14, 2025

@qnixsynapse Let's wait a day for others to review. If no one else comes by, I will merge it.

@qnixsynapse (Contributor, Author)

@Alcpz sure.

BTW, I did some inspection because this felt a bit slow for an A100. It turns out that the dequant matmul kernels aren't vectorized enough and don't make use of local_accessors. The entire backend seems to have been converted using a tool called SYCLomatic, and I find that tool problematic, to be honest.

@Alcpz (Collaborator) commented Jan 14, 2025

Yes, you are right. The original code for this backend was primarily generated using SYCLomatic, which prioritizes functionality over optimal design due to the nature of the tool. The reference code from which it was converted has seen significant improvements since then.
It's great that contributors like you improve the backend by writing proper SYCL code instead.

Regarding local memory, there is no universal guarantee that adding it to matrix multiplication kernels will improve performance across all devices, given the hardware differences across vendors. Achieving consistent performance could require splitting code paths or adjusting kernels for specific hardware, so keep that in mind in case you want to start working there.
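
For illustration, one way such a split could look is sketched below (a rough sketch only; the dispatch heuristic, the 64 KiB threshold, and the kernel names are assumptions, not existing ggml-sycl code):

```cpp
#include <string>
#include <sycl/sycl.hpp>

// Pick a kernel variant based on what the device reports; the threshold and
// the vendor check are placeholder heuristics for the idea of per-hardware
// code paths.
void mul_mat_dispatch(sycl::queue & q /*, tensor arguments ... */) {
    const sycl::device dev       = q.get_device();
    const size_t      local_mem  = dev.get_info<sycl::info::device::local_mem_size>();
    const std::string vendor     = dev.get_info<sycl::info::device::vendor>();

    if (local_mem >= 64 * 1024 && vendor.find("Intel") != std::string::npos) {
        // mul_mat_local_mem(q, ...);   // hypothetical tiled variant using local_accessor
    } else {
        // mul_mat_plain(q, ...);       // hypothetical straightforward global-memory variant
    }
}
```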

@qnixsynapse (Contributor, Author) commented Jan 14, 2025

> Regarding local memory, there is no universal guarantee that adding it to matrix multiplication kernels will improve performance across all devices, given the hardware differences across vendors. Achieving consistent performance could require splitting code paths or adjusting kernels for specific hardware, so keep that in mind in case you want to start working there.

Actually, my aim is to reduce the number of global memory accesses during an operation by caching the data that the operation touches repeatedly (like scales for dequantization) in local memory. For now, I will wait for someone from the Codeplay/Intel side, or the person who wrote the original code, to improve it; they have better knowledge of and access to the hardware than I do.
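
As a rough illustration of that idea (a minimal sketch, not the actual ggml-sycl code; the one-scale-per-block layout and all names are assumptions), caching a dequantization scale in local memory could look like:

```cpp
#include <cstdint>
#include <sycl/sycl.hpp>

// One work-group per quantized block: work-item 0 loads the block's scale
// from global memory once, the group barrier publishes it via local memory,
// and every work-item then dequantizes its element without re-reading the
// scale from global memory.
void dequant_cached_scale(sycl::queue & q, const float * scales, const int8_t * quants,
                          float * out, size_t nblocks, size_t block_size) {
    q.submit([&](sycl::handler & cgh) {
        sycl::local_accessor<float, 1> scale_l(sycl::range<1>(1), cgh);

        cgh.parallel_for(
            sycl::nd_range<1>(sycl::range<1>(nblocks * block_size), sycl::range<1>(block_size)),
            [=](sycl::nd_item<1> it) {
                const size_t blk = it.get_group(0);
                const size_t lid = it.get_local_id(0);

                if (lid == 0) {
                    scale_l[0] = scales[blk];          // single global load per group
                }
                sycl::group_barrier(it.get_group());

                const size_t i = blk * block_size + lid;
                out[i] = scale_l[0] * (float) quants[i];
            });
    });
}
```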

@NeoZhangJianyu (Collaborator) left a comment

Great job!
Since you have made the unit tests pass, I think the quality should be OK, at least functionally.

Thank you!

@NeoZhangJianyu NeoZhangJianyu merged commit f446c2c into ggerganov:master Jan 15, 2025
48 checks passed
@qnixsynapse qnixsynapse deleted the gla branch January 15, 2025 04:53
@Rbiessy (Collaborator) commented Jan 15, 2025

> Actually, my aim is to reduce the number of global memory accesses during an operation by caching the data that the operation touches repeatedly (like scales for dequantization) in local memory. For now, I will wait for someone from the Codeplay/Intel side, or the person who wrote the original code, to improve it; they have better knowledge of and access to the hardware than I do.

FYI, we are not planning to optimize SYCL kernels for Nvidia devices in the short term. There may be longer-term options that would allow us to compile and launch native CUDA kernels via SYCL interop mode.
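
Roughly, such an interop path could look like the sketch below (assuming DPC++'s oneAPI CUDA backend extension; the backend enum and the native-queue return type are assumptions, and `launch_native_gla_kernel` is a hypothetical pre-compiled CUDA-side launcher, not an existing API):

```cpp
#include <cuda.h>
#include <sycl/sycl.hpp>

// Hypothetical launcher compiled separately with nvcc; it enqueues the
// native CUDA kernel on the given stream.
extern "C" void launch_native_gla_kernel(CUstream stream, float * dst, const float * src, int n);

void run_gla_via_interop(sycl::queue & q, float * dst, const float * src, int n) {
    q.submit([&](sycl::handler & cgh) {
        cgh.host_task([=](sycl::interop_handle ih) {
            // Get the native CUDA stream backing this SYCL queue and hand it
            // to the pre-compiled CUDA-side launcher.
            CUstream stream = ih.get_native_queue<sycl::backend::ext_oneapi_cuda>();
            launch_native_gla_kernel(stream, dst, src, n);
        });
    });
}
```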

@qnixsynapse (Contributor, Author)

> FYI, we are not planning to optimize SYCL kernels for Nvidia devices in the short term.

@Rbiessy It's okay. My primary focus for this backend is Intel GPUs, since I personally own one.

@NeoZhangJianyu (Collaborator)

> Actually, my aim is to reduce the number of global memory accesses during an operation by caching the data that the operation touches repeatedly (like scales for dequantization) in local memory. For now, I will wait for someone from the Codeplay/Intel side, or the person who wrote the original code, to improve it; they have better knowledge of and access to the hardware than I do.

> FYI, we are not planning to optimize SYCL kernels for Nvidia devices in the short term. There may be longer-term options that would allow us to compile and launch native CUDA kernels via SYCL interop mode.

From the viewpoint of migrating CUDA to SYCL, or supporting CUDA devices through SYCL, it is fine to use SYCL interop mode to run native CUDA kernels.
But from the project's viewpoint, both SYCL -> CUDA -> NV GPU and CUDA -> NV GPU look the same to an end user of llama.cpp, and the first path is likely slower than the second because of the extra wrapping.

CUDA users have more choices, which means only the best solution will end up being widely used. I don't think the SYCL backend can attract CUDA users away from the CUDA backend.

The value of the SYCL backend is supporting Intel GPUs with good performance; that is why some Intel GPU users have moved from other tools to the llama.cpp SYCL backend. The SYCL backend is the first choice for Intel GPUs.

This year, my target is still to optimize the SYCL backend for good functionality and performance on Intel GPUs only.
