Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llamafile : improve sgemm.cpp #6796

Merged
merged 4 commits into from
Apr 22, 2024
Merged

llamafile : improve sgemm.cpp #6796

merged 4 commits into from
Apr 22, 2024

Conversation

jart
Copy link
Contributor

@jart jart commented Apr 20, 2024

- Re-enable by default
- Fix issue described in ggerganov#6716
- Make code more abstract, elegant, and maintainable
- Faster handling of weirdly shaped `m` an `n` edge cases
Copy link
Contributor

github-actions bot commented Apr 20, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 465 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=10185.65ms p(95)=27005.29ms fails=, finish reason: stop=416 truncated=49
  • Prompt processing (pp): avg=105.0tk/s p(95)=462.15tk/s
  • Token generation (tg): avg=25.75tk/s p(95)=36.86tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=sgemm2 commit=d3d40bfd1e20d7c77029081c56188c3381a1a1b4

prompt_tokens_seconds

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 465 iterations"
    y-axis "llamacpp:prompt_tokens_seconds"
    x-axis "llamacpp:prompt_tokens_seconds" 1713727596 --> 1713728232
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 903.41, 903.41, 903.41, 903.41, 903.41, 655.86, 655.86, 655.86, 655.86, 655.86, 683.83, 683.83, 683.83, 683.83, 683.83, 703.98, 703.98, 703.98, 703.98, 703.98, 725.35, 725.35, 725.35, 725.35, 725.35, 754.87, 754.87, 754.87, 754.87, 754.87, 750.82, 750.82, 750.82, 750.82, 750.82, 760.52, 760.52, 760.52, 760.52, 760.52, 768.01, 768.01, 768.01, 768.01, 768.01, 764.59, 764.59, 764.59, 764.59, 764.59, 771.48, 771.48, 771.48, 771.48, 771.48, 766.39, 766.39, 766.39, 766.39, 766.39, 769.0, 769.0, 769.0, 769.0, 769.0, 785.47, 785.47, 785.47, 785.47, 785.47, 783.08, 783.08, 783.08, 783.08, 783.08, 750.46, 750.46, 750.46, 750.46, 750.46, 646.23, 646.23, 646.23, 646.23, 646.23, 646.67, 646.67, 646.67, 646.67, 646.67, 650.24, 650.24, 650.24, 650.24, 650.24, 650.22, 650.22, 650.22, 650.22, 650.22, 655.8, 655.8, 655.8, 655.8, 655.8, 663.75, 663.75, 663.75, 663.75, 663.75, 661.86, 661.86, 661.86, 661.86, 661.86, 662.94, 662.94, 662.94, 662.94, 662.94, 666.12, 666.12, 666.12, 666.12, 666.12, 666.5, 666.5, 666.5, 666.5, 666.5, 669.18, 669.18, 669.18, 669.18, 669.18, 670.52, 670.52, 670.52, 670.52, 670.52, 643.8, 643.8, 643.8, 643.8, 643.8, 647.81, 647.81, 647.81, 647.81, 647.81, 648.55, 648.55, 648.55, 648.55, 648.55, 647.84, 647.84, 647.84, 647.84, 647.84, 646.89, 646.89, 646.89, 646.89, 646.89, 648.43, 648.43, 648.43, 648.43, 648.43, 648.49, 648.49, 648.49, 648.49, 648.49, 651.7, 651.7, 651.7, 651.7, 651.7, 654.74, 654.74, 654.74, 654.74, 654.74, 654.25, 654.25, 654.25, 654.25, 654.25, 655.42, 655.42, 655.42, 655.42, 655.42, 658.0, 658.0, 658.0, 658.0, 658.0, 666.24, 666.24, 666.24, 666.24, 666.24, 670.04, 670.04, 670.04, 670.04, 670.04, 667.2, 667.2, 667.2, 667.2, 667.2, 666.66, 666.66, 666.66, 666.66, 666.66, 665.98, 665.98, 665.98, 665.98, 665.98, 666.02, 666.02, 666.02, 666.02, 666.02, 668.94, 668.94, 668.94, 668.94, 668.94, 671.72, 671.72, 671.72, 671.72, 671.72, 680.15, 680.15, 680.15, 680.15, 680.15, 673.62, 673.62, 673.62, 673.62, 673.62, 655.87, 655.87, 655.87, 655.87, 655.87, 654.24, 654.24, 654.24, 654.24, 654.24, 653.49, 653.49, 653.49, 653.49, 653.49, 652.47, 652.47, 652.47, 652.47, 652.47, 650.04, 650.04, 650.04, 650.04, 650.04, 652.53, 652.53, 652.53, 652.53, 652.53, 652.77, 652.77, 652.77, 652.77, 652.77, 655.36, 655.36, 655.36, 655.36, 655.36, 657.45, 657.45, 657.45, 657.45, 657.45, 658.24, 658.24, 658.24, 658.24, 658.24, 657.69, 657.69, 657.69, 657.69, 657.69, 657.69, 657.69, 657.69, 657.69]
                    
Loading
predicted_tokens_seconds
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 465 iterations"
    y-axis "llamacpp:predicted_tokens_seconds"
    x-axis "llamacpp:predicted_tokens_seconds" 1713727596 --> 1713728232
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 36.3, 36.3, 36.3, 36.3, 36.3, 22.22, 22.22, 22.22, 22.22, 22.22, 22.9, 22.9, 22.9, 22.9, 22.9, 24.05, 24.05, 24.05, 24.05, 24.05, 24.57, 24.57, 24.57, 24.57, 24.57, 24.7, 24.7, 24.7, 24.7, 24.7, 26.32, 26.32, 26.32, 26.32, 26.32, 26.81, 26.81, 26.81, 26.81, 26.81, 27.0, 27.0, 27.0, 27.0, 27.0, 26.82, 26.82, 26.82, 26.82, 26.82, 26.41, 26.41, 26.41, 26.41, 26.41, 26.1, 26.1, 26.1, 26.1, 26.1, 25.1, 25.1, 25.1, 25.1, 25.1, 24.45, 24.45, 24.45, 24.45, 24.45, 24.38, 24.38, 24.38, 24.38, 24.38, 23.83, 23.83, 23.83, 23.83, 23.83, 23.23, 23.23, 23.23, 23.23, 23.23, 22.94, 22.94, 22.94, 22.94, 22.94, 22.93, 22.93, 22.93, 22.93, 22.93, 22.99, 22.99, 22.99, 22.99, 22.99, 22.74, 22.74, 22.74, 22.74, 22.74, 22.5, 22.5, 22.5, 22.5, 22.5, 22.35, 22.35, 22.35, 22.35, 22.35, 22.06, 22.06, 22.06, 22.06, 22.06, 21.89, 21.89, 21.89, 21.89, 21.89, 21.92, 21.92, 21.92, 21.92, 21.92, 22.05, 22.05, 22.05, 22.05, 22.05, 22.1, 22.1, 22.1, 22.1, 22.1, 22.27, 22.27, 22.27, 22.27, 22.27, 22.43, 22.43, 22.43, 22.43, 22.43, 22.54, 22.54, 22.54, 22.54, 22.54, 22.36, 22.36, 22.36, 22.36, 22.36, 22.36, 22.36, 22.36, 22.36, 22.36, 22.38, 22.38, 22.38, 22.38, 22.38, 22.57, 22.57, 22.57, 22.57, 22.57, 22.63, 22.63, 22.63, 22.63, 22.63, 22.72, 22.72, 22.72, 22.72, 22.72, 22.92, 22.92, 22.92, 22.92, 22.92, 22.98, 22.98, 22.98, 22.98, 22.98, 22.99, 22.99, 22.99, 22.99, 22.99, 22.95, 22.95, 22.95, 22.95, 22.95, 22.87, 22.87, 22.87, 22.87, 22.87, 22.78, 22.78, 22.78, 22.78, 22.78, 22.72, 22.72, 22.72, 22.72, 22.72, 22.74, 22.74, 22.74, 22.74, 22.74, 22.75, 22.75, 22.75, 22.75, 22.75, 22.93, 22.93, 22.93, 22.93, 22.93, 22.97, 22.97, 22.97, 22.97, 22.97, 22.94, 22.94, 22.94, 22.94, 22.94, 22.6, 22.6, 22.6, 22.6, 22.6, 22.42, 22.42, 22.42, 22.42, 22.42, 22.34, 22.34, 22.34, 22.34, 22.34, 22.05, 22.05, 22.05, 22.05, 22.05, 21.71, 21.71, 21.71, 21.71, 21.71, 21.58, 21.58, 21.58, 21.58, 21.58, 21.61, 21.61, 21.61, 21.61, 21.61, 21.7, 21.7, 21.7, 21.7, 21.7, 21.71, 21.71, 21.71, 21.71, 21.71, 21.79, 21.79, 21.79, 21.79, 21.79, 21.84, 21.84, 21.84, 21.84, 21.84, 21.88, 21.88, 21.88, 21.88]
                    
Loading

Details

kv_cache_usage_ratio

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 465 iterations"
    y-axis "llamacpp:kv_cache_usage_ratio"
    x-axis "llamacpp:kv_cache_usage_ratio" 1713727596 --> 1713728232
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.01, 0.01, 0.01, 0.01, 0.33, 0.33, 0.33, 0.33, 0.33, 0.23, 0.23, 0.23, 0.23, 0.23, 0.14, 0.14, 0.14, 0.14, 0.14, 0.12, 0.12, 0.12, 0.12, 0.12, 0.11, 0.11, 0.11, 0.11, 0.11, 0.12, 0.12, 0.12, 0.12, 0.12, 0.16, 0.16, 0.16, 0.16, 0.16, 0.13, 0.13, 0.13, 0.13, 0.13, 0.17, 0.17, 0.17, 0.17, 0.17, 0.2, 0.2, 0.2, 0.2, 0.2, 0.21, 0.21, 0.21, 0.21, 0.21, 0.21, 0.21, 0.21, 0.21, 0.21, 0.25, 0.25, 0.25, 0.25, 0.25, 0.18, 0.18, 0.18, 0.18, 0.18, 0.35, 0.35, 0.35, 0.35, 0.35, 0.28, 0.28, 0.28, 0.28, 0.28, 0.14, 0.14, 0.14, 0.14, 0.14, 0.18, 0.18, 0.18, 0.18, 0.18, 0.19, 0.19, 0.19, 0.19, 0.19, 0.17, 0.17, 0.17, 0.17, 0.17, 0.26, 0.26, 0.26, 0.26, 0.26, 0.25, 0.25, 0.25, 0.25, 0.25, 0.29, 0.29, 0.29, 0.29, 0.29, 0.21, 0.21, 0.21, 0.21, 0.21, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.16, 0.16, 0.16, 0.16, 0.16, 0.17, 0.17, 0.17, 0.17, 0.17, 0.11, 0.11, 0.11, 0.11, 0.11, 0.15, 0.15, 0.15, 0.15, 0.15, 0.26, 0.26, 0.26, 0.26, 0.26, 0.13, 0.13, 0.13, 0.13, 0.13, 0.17, 0.17, 0.17, 0.17, 0.17, 0.15, 0.15, 0.15, 0.15, 0.15, 0.1, 0.1, 0.1, 0.1, 0.1, 0.15, 0.15, 0.15, 0.15, 0.15, 0.08, 0.08, 0.08, 0.08, 0.08, 0.14, 0.14, 0.14, 0.14, 0.14, 0.13, 0.13, 0.13, 0.13, 0.13, 0.21, 0.21, 0.21, 0.21, 0.21, 0.2, 0.2, 0.2, 0.2, 0.2, 0.29, 0.29, 0.29, 0.29, 0.29, 0.15, 0.15, 0.15, 0.15, 0.15, 0.16, 0.16, 0.16, 0.16, 0.16, 0.13, 0.13, 0.13, 0.13, 0.13, 0.15, 0.15, 0.15, 0.15, 0.15, 0.11, 0.11, 0.11, 0.11, 0.11, 0.3, 0.3, 0.3, 0.3, 0.3, 0.42, 0.42, 0.42, 0.42, 0.42, 0.51, 0.51, 0.51, 0.51, 0.51, 0.4, 0.4, 0.4, 0.4, 0.4, 0.36, 0.36, 0.36, 0.36, 0.36, 0.39, 0.39, 0.39, 0.39, 0.39, 0.16, 0.16, 0.16, 0.16, 0.16, 0.14, 0.14, 0.14, 0.14, 0.14, 0.13, 0.13, 0.13, 0.13, 0.13, 0.19, 0.19, 0.19, 0.19, 0.19, 0.14, 0.14, 0.14, 0.14, 0.14, 0.1, 0.1, 0.1, 0.1, 0.1, 0.19, 0.19, 0.19, 0.19, 0.19, 0.27, 0.27, 0.27, 0.27]
                    
Loading
requests_processing
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 465 iterations"
    y-axis "llamacpp:requests_processing"
    x-axis "llamacpp:requests_processing" 1713727596 --> 1713728232
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 2.0, 2.0, 2.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0, 4.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 2.0, 2.0, 2.0, 2.0, 2.0, 7.0, 7.0, 7.0, 7.0, 7.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 5.0, 5.0, 5.0, 5.0, 5.0, 7.0, 7.0, 7.0, 7.0, 7.0, 4.0, 4.0, 4.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 7.0, 7.0, 7.0, 7.0, 7.0, 4.0, 4.0, 4.0, 4.0, 4.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 3.0, 3.0, 3.0, 3.0, 3.0, 7.0, 7.0, 7.0, 7.0, 7.0, 3.0, 3.0, 3.0, 3.0, 3.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 3.0, 3.0, 3.0, 3.0, 3.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0]
                    
Loading

Copy link
Owner

@ggerganov ggerganov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The build fails on Mac:

$ make -j
I llama.cpp build info: 
I UNAME_S:   Darwin
I UNAME_P:   arm
I UNAME_M:   arm64
I CFLAGS:    -I. -Icommon -D_XOPEN_SOURCE=600 -D_DARWIN_C_SOURCE -DNDEBUG -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 -DGGML_USE_LLAMAFILE -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY  -std=c11   -fPIC -O3 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int -Werror=implicit-function-declaration -pthread -Wunreachable-code-break -Wunreachable-code-return -Wdouble-promotion 
I CXXFLAGS:  -std=c++11 -fPIC -O3 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wmissing-declarations -Wmissing-noreturn -pthread   -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi -I. -Icommon -D_XOPEN_SOURCE=600 -D_DARWIN_C_SOURCE -DNDEBUG -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 -DGGML_USE_LLAMAFILE -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY 
I NVCCFLAGS: -std=c++11 -O3 
I LDFLAGS:   -framework Accelerate -framework Foundation -framework Metal -framework MetalKit 
I CC:        Apple clang version 15.0.0 (clang-1500.1.0.2.5)
I CXX:       Apple clang version 15.0.0 (clang-1500.1.0.2.5)

c++ -std=c++11 -fPIC -O3 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wmissing-declarations -Wmissing-noreturn -pthread   -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi -I. -Icommon -D_XOPEN_SOURCE=600 -D_DARWIN_C_SOURCE -DNDEBUG -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 -DGGML_USE_LLAMAFILE -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY  -c sgemm.cpp -o sgemm.o
sgemm.cpp:515:13: error: unknown type name 'D'
            D Cv[RN][RM] = {};
            ^
sgemm.cpp:516:41: error: use of undeclared identifier 'KN'
            for (int l = 0; l < k; l += KN)
                                        ^
2 errors generated.
make: *** [sgemm.o] Error 1

Also, apply the following patch to enable LLAMAFILE with CMake builds:

diff --git a/CMakeLists.txt b/CMakeLists.txt
index f134a153..58a1805b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -43,17 +43,11 @@ else()
     set(LLAMA_METAL_DEFAULT OFF)
 endif()
 
-# TODO: fix this for Android CI
-#       https://github.com/ggerganov/llama.cpp/pull/6716#issuecomment-2061509191
-#if (CMAKE_SYSTEM_NAME MATCHES "ANDROID")
-#    set(LLAMA_LLAMAFILE_DEFAULT OFF)
-#else()
-#    set(LLAMA_LLAMAFILE_DEFAULT ON)
-#endif()
-
-# TODO: temporary disable until MoE is fixed
-#       https://github.com/ggerganov/llama.cpp/pull/6716
-set(LLAMA_LLAMAFILE_DEFAULT OFF)
+if (CMAKE_SYSTEM_NAME MATCHES "ANDROID")
+    set(LLAMA_LLAMAFILE_DEFAULT OFF)
+else()
+    set(LLAMA_LLAMAFILE_DEFAULT ON)
+endif()
 
 # general
 option(BUILD_SHARED_LIBS                "build shared libraries"                                OFF)

@jart
Copy link
Contributor Author

jart commented Apr 21, 2024

@ggerganov Fixed. PTAL

Please give Q8_0 a try (build it with make -j32 LLAMA_NO_ACCELERATE=1 LLAMA_NO_METAL=1 main) On my Mac Studio M2 Ultra, Mistral 7b Q8 prompt eval goes from 90 tok/sec to 140 tok/sec with this change. That's 16% faster than Apple Accelerate cblas_sgemm() which goes 120 tok/sec for me. The other quants I've tried (q40 and f16) seem to be equal to Accelerate in speed.

Do you know if there's a better way to fix the Android issue? IIUC it's due to an instruction not being available on 32-bit ARM. Is there a way we could solve that with an #ifdef instead? I don't own a 32-bit ARM system, so I have no way of doing this myself.

@jart
Copy link
Contributor Author

jart commented Apr 21, 2024

@ggerganov Could you advise me on how I might bring the benefits of llamafile_sgemm() to GGML_MUL_MAT_ID? I know very little about mixture of expert architecture. It's not obvious to me how I might go about decomposing that operation into 2d matrix multiplications.

@ggerganov
Copy link
Owner

On master with Accelerate I get:

make clean && LLAMA_NO_METAL=1 make -j && ./llama-bench -m models/mistral-7b-v0.2/ggml-model-fp16.gguf -m models/mistral-7b-v0.2/ggml-model-q8_0.gguf -m models/mistral-7b-v0.2/ggml-model-q4_0.gguf -ngl 0 -n 0
model size params backend threads test t/s
llama 8B F16 13.49 GiB 7.24 B BLAS 16 pp 512 152.87 ± 1.06
llama 8B Q8_0 7.17 GiB 7.24 B BLAS 16 pp 512 147.44 ± 5.19
llama 8B Q4_0 3.83 GiB 7.24 B BLAS 16 pp 512 149.98 ± 1.63

build: 8960fe8 (2713)

With this PR without Accelerate:

make clean && LLAMA_NO_ACCELERATE=1 LLAMA_NO_METAL=1 make -j && ./llama-bench -m models/mistral-7b-v0.2/ggml-model-fp16.gguf -m models/mistral-7b-v0.2/ggml-model-q8_0.gguf -m models/mistral-7b-v0.2/ggml-model-q4_0.gguf -ngl 0 -n 0
model size params backend threads test t/s
llama 7B F16 13.49 GiB 7.24 B CPU 16 pp 512 172.84 ± 0.39
llama 7B Q8_0 7.17 GiB 7.24 B CPU 16 pp 512 146.22 ± 0.44
llama 7B Q4_0 3.83 GiB 7.24 B CPU 16 pp 512 123.81 ± 0.43

build: 6b220dc (2704)

So for me, F16 is faster now, Q8_0 is the same and Q4_0 is slower.

Btw, I've looked some more, and I think the proper call in ggml.c should be like this:

diff --git a/ggml.c b/ggml.c
index e3356bdb..086db96a 100644
--- a/ggml.c
+++ b/ggml.c
@@ -10878,15 +10878,13 @@ UseGgmlGemm1:;
     const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
 #if GGML_USE_LLAMAFILE
-    if (src1_cont) {
+    if (src1->type != vec_dot_type) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
                 if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                      nb01/ggml_type_size(src0->type),
-                                     (const char *)wdata + ggml_row_size(vec_dot_type,
-                                         nb12/ggml_type_size(src1->type)*i12 +
-                                         nb13/ggml_type_size(src1->type)*i13),
+                                     (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
                                      row_size/ggml_type_size(vec_dot_type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),

Regarding Android: in ggml-quants.c we do stuff like this to provide 32-bit ARM compatibility:

llama.cpp/ggml-quants.c

Lines 291 to 307 in e931888

#if !defined(__aarch64__)
// 64-bit compatibility
// vaddvq_s16
// vpaddq_s16
// vpaddq_s32
// vaddvq_s32
// vaddvq_f32
// vmaxvq_f32
// vcvtnq_s32_f32
// vzip1_u8
// vzip2_u8
inline static int32_t vaddvq_s16(int16x8_t v) {
return
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +

We should not repeat the same implementation twice - we have to see if we can reuse it. I also don't have setup Android builds and it takes me some time to get the build running. So for now, let's focus on fixing the SGEMM and later we can think about improving Android support

Let me think some time about GGML_MUL_MAT_ID support

@jart
Copy link
Contributor Author

jart commented Apr 22, 2024

Review comments addressed. PTAL. Agree on Android 32-bit.

Let me think some time about GGML_MUL_MAT_ID support

I studied the code for hours and managed to figure it out. I've got a llamafile_mixmul() function working now that enables mixtral to go 2x faster on my machine for prompt processing.

@ggerganov ggerganov merged commit 192090b into ggerganov:master Apr 22, 2024
56 of 59 checks passed
@fairydreaming
Copy link
Collaborator

These changes cause failed assertions when running Cohere's Command R+ model:

main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.

Thread 22 "main" received signal SIGABRT, Aborted.
[Switching to Thread 0x7fcf41ffb640 (LWP 108365)]
__pthread_kill_implementation (no_tid=0, signo=6, threadid=140528142235200) at ./nptl/pthread_kill.c:44
44	./nptl/pthread_kill.c: No such file or directory.
(gdb) bt
#0  __pthread_kill_implementation (no_tid=0, signo=6, threadid=140528142235200) at ./nptl/pthread_kill.c:44
#1  __pthread_kill_internal (signo=6, threadid=140528142235200) at ./nptl/pthread_kill.c:78
#2  __GI___pthread_kill (threadid=140528142235200, signo=signo@entry=6) at ./nptl/pthread_kill.c:89
#3  0x00007ffff7a99476 in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007ffff7a7f7f3 in __GI_abort () at ./stdlib/abort.c:79
#5  0x00007ffff7a7f71b in __assert_fail_base (fmt=0x7ffff7c34130 "%s%s%s:%u: %s%sAssertion `%s' failed.\n%n", 
    assertion=0x555555811556 "1ll * lda * m <= 0x7fffffff", file=0x55555581150a "sgemm.cpp", line=827, 
    function=<optimized out>) at ./assert/assert.c:92
#6  0x00007ffff7a90e96 in __GI___assert_fail (assertion=0x555555811556 "1ll * lda * m <= 0x7fffffff", 
    file=0x55555581150a "sgemm.cpp", line=827, 
    function=0x555555811498 "bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int)") at ./assert/assert.c:101
#7  0x000055555576c199 in llamafile_sgemm (m=256000, n=1, k=12288, A=0x7fcf9dd7b500, lda=12288, B=0x7fcf88918020, 
    ldb=12288, C=0x7fcf89518020, ldc=256000, ith=21, nth=32, task=0, Atype=1, Btype=0, Ctype=0) at sgemm.cpp:827
#8  0x0000555555588c78 in ggml_compute_forward_mul_mat (params=0x7fcf41ffae20, dst=0x55555bf087d0) at ggml.c:10831
#9  0x00005555555a1125 in ggml_compute_forward (params=0x7fcf41ffae20, tensor=0x55555bf087d0) at ggml.c:16254
#10 0x00005555555a75b0 in ggml_graph_compute_thread (data=0x7fffffffb660) at ggml.c:18398
#11 0x00007ffff7aebac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#12 0x00007ffff7b7d850 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

When I reverted commit 192090b the problem disappeared. Log files attached.

main-crash.log
main-working.log

@ggerganov
Copy link
Owner

Yes, this assert has to be avoided. The Command-R model has a very large output tensor and it's number of elements exceeds int. That's why in order to support it, we switched to int64_t in many places across the codebase: #6491

This issue should be prioritized, @jart PTAL

@ggerganov
Copy link
Owner

Here are instruction to trigger this assert:

# convert to GGUF
python3 convert-hf-to-gguf.py ~/Data/huggingface/c4ai-command-r-plus/ --outfile models/command-r-plus/ggml-model-f16.gguf --outtype f16

# quantize to Q8_0 + F16 token embeddings
make -j
./quantize --token-embedding-type f16 ./models/command-r-plus/ggml-model-f16.gguf ./models/command-r-plus/ggml-model-q8_0.gguf q8_0

# build in DEBUG and run
make clean
LLAMA_DEBUG=1 LLAMA_NO_METAL=1 LLAMA_NO_ACCELERATE=1 make -j
./main -m ./models/command-r-plus/ggml-model-q8_0.gguf

@fairydreaming
Copy link
Collaborator

I just tried a naive solution and replaced all ints in sgemm.cpp and sgemm.h with int64_t, and the resulting code works fine without any performance penalty (at least on my Epyc Genoa). Also there are no more crashes due to int overflow in pointer calculations when using Command R+.

By the way, @jart thank you for these changes, they improved the prompt eval time on my system by 65% on llama-3 70B Q8!

@jart
Copy link
Contributor Author

jart commented Apr 26, 2024

Thanks for the inbox bump. Making this my top priority now. Expect a PR shortly.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants