
Feature Request: add DeepSeek-v3 support #10981

Closed
4 tasks done
RodriMora opened this issue Dec 26, 2024 · 64 comments · Fixed by #11049
Labels
enhancement New feature or request

Comments

RodriMora commented Dec 26, 2024

Prerequisites

  • I am running the latest code. Mention the version if possible as well.
  • Version b4391
  • I carefully followed the README.md.
  • I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed).
  • I reviewed the Discussions, and have a new and useful enhancement to share.

Feature Description

Add support for DeepSeek-v3

https://huggingface.co/deepseek-ai/DeepSeek-V3

Currently not supported:

ERROR:hf-to-gguf:Model DeepseekV3ForCausalLM is not supported

Motivation

DeepSeek-v3 is a large MoE model of 685B params; support would be great, since offloading to RAM will be a must for most systems.

Possible Implementation

There is no model card or technical report yet, so I don't know how different it is from v2.

Edit: they have uploaded the model card and paper:
https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/README.md

RodriMora added the enhancement label Dec 26, 2024

nisten commented Dec 26, 2024

The sigmoid routing is a bit different, but the rest of the arch is largely the same as DeepSeek 2.5, just larger.

There's no PR yet in HF transformers; it looks like they've built this atop transformers 4.33, so that will be quite a merge to get right, I guess.
So it's not that hard to implement given the modelling code, but it's still not trivial because it's again a "2nd gen" MoE with a routing pool.

@web-traveler

In case it helps: transformers 4.46.3 is listed in https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/requirements.txt

@arthurwolf

What's missing to get this to work, and can one do anything to help?

@Nottlespike

What's missing to get this to work, and can one do anything to help?

huggingface/transformers#35425

Contributor

cpumaxx commented Dec 31, 2024

Can a dev help break down for us what would be required in convert_hf_to_gguf.py to at least get a gguf created with the new model bits? e.g. mlp.gate.e_score_correction_bias, etc.
Without a gguf I think implementation of multi-token prediction and the other novel parts of this model's inference architecture will be much harder to develop.
I'm happy to help if I can, but the papers and deepseek2 diffs aren't giving me as much of a clue as I was hoping.


web-traveler commented Dec 31, 2024

Contributor

cpumaxx commented Dec 31, 2024

@fairydreaming : How much more work is needed before you can accept collaborators and testers on your branch? I see on localllama that you have at least a PoC running.

Collaborator

fairydreaming commented Dec 31, 2024

@fairydreaming : How much more work is needed before you can accept collaborators and testers on your branch? I see on localllama that you have at least a PoC running.

I still have to add a new pre-tokenizer regex and test the tokenization. I'm not sure how many weird regex quirks I'll encounter along the way, but I estimate it will take a few days at most.

Edit: Also, I don't have MTP implemented, but it can be added later.

@Nottlespike

@fairydreaming : How much more work is needed before you can accept collaborators and testers on your branch? I see on localllama that you have at least a PoC running.

I still have to add a new pre-tokenizer regex and test the tokenization. I'm not sure how many weird regex quirks I'll encounter along the way, but I estimate it will take a few days at most.

Edit: Also, I don't have MTP implemented, but it can be added later.

You can do this without official HF transformers support and without trust_remote_code=True?
This is my main concern and why I'm working with HF on an official HF transformers implementation.
huggingface/transformers#35425 (comment)
What is the branch and how can I help?

@fairydreaming
Collaborator

My DeepSeek-V3 branch is here: https://github.com/fairydreaming/llama.cpp/tree/deepseek-v3

To convert the model to GGUF you need a dequantized DeepSeek V3. You can download it from HF (there are several BF16 DeepSeek V3 models available, but I didn't test any of them) or run the inference/fp8_cast_bf16.py script from the original model repo to convert it to bf16 (that's what I did). Note that it uses triton, so I think you need a GPU for this. In case you experience CUDA out of memory errors during conversion check this: https://huggingface.co/deepseek-ai/DeepSeek-V3/discussions/17

There are some minor tokenization differences compared to the original model, but I think it's usable.
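For anyone following along, here is a rough sketch of the two steps described above. Paths and output names are placeholders, and the fp8_cast_bf16.py argument names are assumptions based on the DeepSeek-V3 inference scripts, so check each script's --help before running:

```sh
# 1) Dequantize the original FP8 checkpoint to BF16 (uses triton, so a GPU is needed).
#    The argument names below are assumptions; verify them against the script.
python inference/fp8_cast_bf16.py \
  --input-fp8-hf-path /models/DeepSeek-V3 \
  --output-bf16-hf-path /models/DeepSeek-V3-BF16

# 2) Convert the BF16 safetensors to GGUF with the conversion script from the branch.
python convert_hf_to_gguf.py /models/DeepSeek-V3-BF16 \
  --outtype bf16 \
  --outfile /models/DeepSeek-V3-BF16.gguf
```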

@fairydreaming
Collaborator

Some initial perplexity values over wiki.test.raw (not a full run) with Q4_K_S quantized model:

$ ./build/bin/llama-perplexity --numa distribute -t 32 -m /mnt/md0/models/deepseek-v3-Q4_K_S.gguf --no-context-shift -f ../perplexity/wikitext-2-raw/wiki.test.raw
build: 4407 (ad77e9b3) with cc (Ubuntu 13.2.0-23ubuntu4) 13.2.0 for x86_64-linux-gnu
llama_model_loader: loaded meta data with 43 key-value pairs and 1025 tensors from /mnt/md0/models/deepseek-v3-Q4_K_S.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = deepseek2
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Models Deepseek Ai DeepSeek V3 Bf16
llama_model_loader: - kv   3:                         general.size_label str              = 256x20B
llama_model_loader: - kv   4:                      deepseek2.block_count u32              = 61
llama_model_loader: - kv   5:                   deepseek2.context_length u32              = 163840
llama_model_loader: - kv   6:                 deepseek2.embedding_length u32              = 7168
llama_model_loader: - kv   7:              deepseek2.feed_forward_length u32              = 18432
llama_model_loader: - kv   8:             deepseek2.attention.head_count u32              = 128
llama_model_loader: - kv   9:          deepseek2.attention.head_count_kv u32              = 128
llama_model_loader: - kv  10:                   deepseek2.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  11: deepseek2.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  12:                deepseek2.expert_used_count u32              = 8
llama_model_loader: - kv  13:                          general.file_type u32              = 14
llama_model_loader: - kv  14:        deepseek2.leading_dense_block_count u32              = 3
llama_model_loader: - kv  15:                       deepseek2.vocab_size u32              = 129280
llama_model_loader: - kv  16:            deepseek2.attention.q_lora_rank u32              = 1536
llama_model_loader: - kv  17:           deepseek2.attention.kv_lora_rank u32              = 512
llama_model_loader: - kv  18:             deepseek2.attention.key_length u32              = 192
llama_model_loader: - kv  19:           deepseek2.attention.value_length u32              = 128
llama_model_loader: - kv  20:       deepseek2.expert_feed_forward_length u32              = 2048
llama_model_loader: - kv  21:                     deepseek2.expert_count u32              = 256
llama_model_loader: - kv  22:              deepseek2.expert_shared_count u32              = 1
llama_model_loader: - kv  23:             deepseek2.expert_weights_scale f32              = 2.500000
llama_model_loader: - kv  24:              deepseek2.expert_weights_norm bool             = true
llama_model_loader: - kv  25:               deepseek2.expert_gating_func u32              = 2
llama_model_loader: - kv  26:             deepseek2.rope.dimension_count u32              = 64
llama_model_loader: - kv  27:                deepseek2.rope.scaling.type str              = yarn
llama_model_loader: - kv  28:              deepseek2.rope.scaling.factor f32              = 40.000000
llama_model_loader: - kv  29: deepseek2.rope.scaling.original_context_length u32              = 4096
llama_model_loader: - kv  30: deepseek2.rope.scaling.yarn_log_multiplier f32              = 0.100000
llama_model_loader: - kv  31:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  32:                         tokenizer.ggml.pre str              = deepseek-v3
llama_model_loader: - kv  33:                      tokenizer.ggml.tokens arr[str,129280]  = ["<|begin▁of▁sentence|>", "<�...
llama_model_loader: - kv  34:                  tokenizer.ggml.token_type arr[i32,129280]  = [3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  35:                      tokenizer.ggml.merges arr[str,127741]  = ["Ġ t", "Ġ a", "i n", "Ġ Ġ", "h e...
llama_model_loader: - kv  36:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  37:                tokenizer.ggml.eos_token_id u32              = 1
llama_model_loader: - kv  38:            tokenizer.ggml.padding_token_id u32              = 1
llama_model_loader: - kv  39:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  40:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  41:                    tokenizer.chat_template str              = {% if not add_generation_prompt is de...
llama_model_loader: - kv  42:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:  361 tensors
llama_model_loader: - type q4_K:  652 tensors
llama_model_loader: - type q5_K:   11 tensors
llama_model_loader: - type q6_K:    1 tensors
llm_load_vocab: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
llm_load_vocab: special tokens cache size = 818
llm_load_vocab: token to piece cache size = 0.8223 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = deepseek2
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 129280
llm_load_print_meta: n_merges         = 127741
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 163840
llm_load_print_meta: n_embd           = 7168
llm_load_print_meta: n_layer          = 61
llm_load_print_meta: n_head           = 128
llm_load_print_meta: n_head_kv        = 128
llm_load_print_meta: n_rot            = 64
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 192
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 24576
llm_load_print_meta: n_embd_v_gqa     = 16384
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 18432
llm_load_print_meta: n_expert         = 256
llm_load_print_meta: n_expert_used    = 8
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = yarn
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 0.025
llm_load_print_meta: n_ctx_orig_yarn  = 4096
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: ssm_dt_b_c_rms   = 0
llm_load_print_meta: model type       = 671B
llm_load_print_meta: model ftype      = Q4_K - Small
llm_load_print_meta: model params     = 671.03 B
llm_load_print_meta: model size       = 353.90 GiB (4.53 BPW) 
llm_load_print_meta: general.name     = Models Deepseek Ai DeepSeek V3 Bf16
llm_load_print_meta: BOS token        = 0 '<|begin▁of▁sentence|>'
llm_load_print_meta: EOS token        = 1 '<|end▁of▁sentence|>'
llm_load_print_meta: EOT token        = 1 '<|end▁of▁sentence|>'
llm_load_print_meta: PAD token        = 1 '<|end▁of▁sentence|>'
llm_load_print_meta: LF token         = 131 'Ä'
llm_load_print_meta: FIM PRE token    = 128801 '<|fim▁begin|>'
llm_load_print_meta: FIM SUF token    = 128800 '<|fim▁hole|>'
llm_load_print_meta: FIM MID token    = 128802 '<|fim▁end|>'
llm_load_print_meta: EOG token        = 1 '<|end▁of▁sentence|>'
llm_load_print_meta: max token length = 256
llm_load_print_meta: n_layer_dense_lead   = 3
llm_load_print_meta: n_lora_q             = 1536
llm_load_print_meta: n_lora_kv            = 512
llm_load_print_meta: n_ff_exp             = 2048
llm_load_print_meta: n_expert_shared      = 1
llm_load_print_meta: expert_weights_scale = 2.5
llm_load_print_meta: expert_weights_norm  = 1
llm_load_print_meta: expert_gating_func   = sigmoid
llm_load_print_meta: rope_yarn_log_mul    = 0.1000
llm_load_tensors:   CPU_Mapped model buffer size = 362392.97 MiB
....................................................................................................
llama_new_context_with_model: n_seq_max     = 4
llama_new_context_with_model: n_ctx         = 2048
llama_new_context_with_model: n_ctx_per_seq = 512
llama_new_context_with_model: n_batch       = 2048
llama_new_context_with_model: n_ubatch      = 512
llama_new_context_with_model: flash_attn    = 0
llama_new_context_with_model: freq_base     = 10000.0
llama_new_context_with_model: freq_scale    = 0.025
llama_new_context_with_model: n_ctx_per_seq (512) < n_ctx_train (163840) -- the full capacity of the model will not be utilized
llama_kv_cache_init: kv_size = 2048, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 61
llama_kv_cache_init:        CPU KV buffer size =  9760.00 MiB
llama_new_context_with_model: KV self size  = 9760.00 MiB, K (f16): 5856.00 MiB, V (f16): 3904.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     1.97 MiB
llama_new_context_with_model:        CPU compute buffer size =   670.01 MiB
llama_new_context_with_model: graph nodes  = 5025
llama_new_context_with_model: graph splits = 1
common_init_from_params: setting dry_penalty_last_n to ctx_size = 2048
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)

system_info: n_threads = 32 (n_threads_batch = 32) / 64 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 | 
perplexity: tokenizing the input ..
perplexity: tokenization took 722.377 ms
perplexity: calculating perplexity over 569 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 115.76 seconds per pass - ETA 4 hours 34.45 minutes
[1]4.3360,[2]4.8332,[3]4.7754,[4]3.3915,[5]2.6790,[6]2.3034,[7]2.1953,[8]2.1407,[9]1.9801,[10]1.8571,[11]1.8597,[12]1.8759,[13]1.8085,[14]1.9133,[15]2.0826,[16]2.2035,[17]2.3526,[18]2.5832,[19]2.6840,[20]2.7049,[21]2.8202,[22]2.8124,[23]2.7522,[24]2.7122,[25]2.6636,[26]2.6258,[27]2.6398,[28]2.6945,[29]2.7111,[30]2.7600,[31]2.8559,[32]2.9283,[33]2.9444,[34]2.9556,[35]3.0040,[36]3.0353,[37]3.0683,[38]3.1502,[39]3.2140,[40]3.2328,[41]3.3069,[42]3.3719,[43]3.3812,[44]3.4136,[45]3.5169,[46]3.5905,[47]3.5758,[48]3.4899

@Nottlespike

Some initial perplexity values over wiki.test.raw (not a full run) with Q4_K_S quantized model: […]

THANKS! Will begin running https://github.com/EleutherAI/lm-evaluation-harness on it ASAP!

@fairydreaming
Collaborator

I ran farel-bench locally on the model, looks good! (first two are via OpenRouter, third is local)

| Nr | Model | FaRel | child | parent | grand-child | sibling | grand-parent | great grand-child | niece or nephew | aunt or uncle | great grand-parent |
|----|-------|-------|-------|--------|-------------|---------|--------------|-------------------|-----------------|---------------|--------------------|
| 1 | deepseek-v3-sys | 96.89 | 100.00 | 100.00 | 98.00 | 98.00 | 100.00 | 98.00 | 88.00 | 90.00 | 100.00 |
| 2 | deepseek-v3 | 96.44 | 100.00 | 100.00 | 100.00 | 96.00 | 100.00 | 100.00 | 82.00 | 92.00 | 98.00 |
| 3 | deepseek-v3-Q4_K_S | 96.22 | 100.00 | 100.00 | 100.00 | 98.00 | 96.00 | 96.00 | 86.00 | 94.00 | 96.00 |


Nottlespike commented Jan 2, 2025

I ran farel-bench locally on the model, looks good! (first two are via OpenRouter, third is local) […]

What is your rig, specs-wise?

@fairydreaming
Collaborator

I ran farel-bench locally on the model, looks good! (first two are via OpenRouter, third is local) […]

What is your rig?

@Nottlespike Epyc 9374F, 384GB RAM. It took almost 5 hours to run all 450 prompts.

@Nottlespike

I ran farel-bench locally on the model, looks good! (first two are via OpenRouter, third is local) […]

What is your rig?

@Nottlespike Epyc 9374F, 384GB RAM. It took almost 5 hours to run all 450 prompts.

No GPUs? I've got 4x 3090 Ti FEs linked together with the hacked P2P driver, plus a Threadripper Pro with 8 channels of 128GB DDR4, so I should be able to run it MUCH faster! I've seen your work before and REALLY appreciate your contributions! Any way we can get in contact? I know @bartowski1182 very well, if they have a contact for you?

@fairydreaming
Collaborator

@Nottlespike Epyc 9374F, 384GB RAM. It took almost 5 hours to run all 450 prompts.

No GPU's? I got as 4x3090 Ti FE's linked together with the hacked P2P driver plus a ThreadRipper Pro 8 channels of 128GB DDR4 so I should be able to run it MUCH faster! I've seen your work before and REALLY appreciate your contributions! Any way we can get in contact? I know @bartowski1182 very well if they have a contact with you?

@Nottlespike I have a single RTX 4090, but I didn't use it here. What is your exact CPU model?

Regarding the contact I'm active on Reddit (mostly on r/LocalLLaMA) with the same username.

@Nottlespike

@Nottlespike Epyc 9374F, 384GB RAM. It took almost 5 hours to run all 450 prompts.

No GPU's? I got as 4x3090 Ti FE's linked together with the hacked P2P driver plus a ThreadRipper Pro 8 channels of 128GB DDR4 so I should be able to run it MUCH faster! I've seen your work before and REALLY appreciate your contributions! Any way we can get in contact? I know @bartowski1182 very well if they have a contact with you?

@Nottlespike I have a single RTX 4090, but I didn't use it here. What is your exact CPU model?

Regarding the contact I'm active on Reddit (mostly on r/LocalLLaMA) with the same username.

I have been informed I am "unpopular to hated" on r/LocalLLaMA… given I am basically using a "server" with 4 of the best consumer GPUs on the market, and I called the tinybox a grift at best and a scam at worst.

fairydreaming linked a pull request Jan 2, 2025 that will close this issue
@Nottlespike

@Nottlespike Epyc 9374F, 384GB RAM. It took almost 5 hours to run all 450 prompts.

No GPU's? I got as 4x3090 Ti FE's linked together with the hacked P2P driver plus a ThreadRipper Pro 8 channels of 128GB DDR4 so I should be able to run it MUCH faster! I've seen your work before and REALLY appreciate your contributions! Any way we can get in contact? I know @bartowski1182 very well if they have a contact with you?

@Nottlespike I have a single RTX 4090, but I didn't use it here. What is your exact CPU model?

Regarding the contact I'm active on Reddit (mostly on r/LocalLLaMA) with the same username.

@fairydreaming Am I reading your PR correctly and you DON'T NEED trust_remote_code=True? HOW? Can you help us at HF out on an official HF transformers implementation? Also MASSIVE kudos... 131 lines of ELEGANT CODE... I'm in shock and awe

@fairydreaming
Collaborator

@fairydreaming Am I reading your PR correctly and you DON'T NEED trust_remote_code=True HOW? Can you help us at HF out on a offical HF transformers implementation? Also MASSIVE kudos.... 131 lines of ELEGANTE CODE.... I'm in shock and awe

@Nottlespike AFAIK llama.cpp conversion scripts only use HF transformers AutoTokenizer class and DeepSeek V3 has no custom tokenizer class implementation, so I guess there is no need for trust_remote_code=True - it simply doesn't run any.

@Nottlespike

@fairydreaming Am I reading your PR correctly and you DON'T NEED trust_remote_code=True HOW? Can you help us at HF out on a offical HF transformers implementation? Also MASSIVE kudos.... 131 lines of ELEGANTE CODE.... I'm in shock and awe

@Nottlespike AFAIK llama.cpp conversion scripts only use HF transformers AutoTokenizer class and DeepSeek V3 has no custom tokenizer class implementation, so I guess there is no need for trust_remote_code=True - it simply doesn't run any.

@fairydreaming This is elegant.... props. The previous HF transformers "implementation" forced trust_remote_code.


etafund commented Jan 2, 2025

EDIT: Ignore below, simple user error.

@fairydreaming, I'm running your convert_hf_to_gguf_update.py file to create a GGUF after dequantizing the model, but when I run the script, I get an error. Any advice on what I'm doing wrong?

python ./convert_hf_to_gguf_update.py /zfspool/user/models/DeepSeek-V3-BF16/ --outtype q8_0 --outfile /zfspool/user/models/

It always gives the same error, no matter what I run:
INFO:convert_hf_to_gguf_update:Usage: python convert_hf_to_gguf_update.py <huggingface_token>

Excited to replicate what you've done! Great work.

@bartowski1182
Contributor

@etafund that's the script for updating the conversion script, use the one without _update
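In other words, the GGUF comes from convert_hf_to_gguf.py itself; the _update script only refreshes the pre-tokenizer list inside it. A minimal sketch of the intended invocation, reusing the paths from the command above (the output file name is an assumption):

```sh
# Use the conversion script without the "_update" suffix to produce the GGUF
python ./convert_hf_to_gguf.py /zfspool/user/models/DeepSeek-V3-BF16/ \
  --outtype q8_0 \
  --outfile /zfspool/user/models/DeepSeek-V3-BF16-Q8_0.gguf
```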

Contributor

cpumaxx commented Jan 2, 2025

My DeepSeek-V3 branch is here: https://github.com/fairydreaming/llama.cpp/tree/deepseek-v3 […]

Thanks, @fairydreaming! Your updated conversion script is working perfectly going from BF16 to q8_0.

I'll update with inference results once the quanting finishes and I have a chance to run it through its paces.
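For the quanting step, a minimal sketch using the stock llama.cpp quantizer; the file names are placeholders, and Q4_K_S is just the type used in the perplexity run earlier in the thread:

```sh
# Quantize the BF16 GGUF down to a smaller K-quant (e.g. Q4_K_S)
./build/bin/llama-quantize \
  /models/DeepSeek-V3-BF16.gguf \
  /models/deepseek-v3-Q4_K_S.gguf \
  Q4_K_S
```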

@fraschm1998

@Nottlespike I have a single RTX 4090, but I didn't use it here. What is your exact CPU model?

What are your speeds with the 4090?

@Nottlespike

@Nottlespike I have a single RTX 4090, but I didn't use it here. What is your exact CPU model?

What are your speeds with the 4090?

Wrong person. I'm the one with this thing
[Screenshot from 2024-12-28 09-32-56]


Nottlespike commented Jan 3, 2025

@Nottlespike Yeah, better wait for merged llama.cpp DeepSeek V3 support. I mean it's not a race, is it?

@fairydreaming https://manifold.markets/Kearm20/will-i-be-able-to-run-deepseekv3-10 It is now.

Contributor

cpumaxx commented Jan 3, 2025

I have it running on my dual-socket genoa rig now. First result is 8.83t/s cpu-only inference.
@fairydreaming: is there anything I can do to assist in implementation? I'd need to pull and requant, but don't mind doing so if I can be of use. Thanks for getting this model working!


etafund commented Jan 3, 2025

I got it to work also. Similar path as @RodriMora.

  • Original fp8 > bf16 safetensors
  • bf16 safetensors > q8_0 gguf

Specs:
2x Epyc 9684X
1x RTX 6000 Ada

Hint for anyone trying to replicate: Make sure to run with --no-context-shift

CUDA build with -ngl 0. Real life test on q8_0 gguf:
llama_perf_sampler_print: sampling time = 36.48 ms / 591 runs (0.06 ms per token, 16198.44 tokens per second)
llama_perf_context_print: load time = 624846.58 ms
llama_perf_context_print: prompt eval time = 973.60 ms / 7 tokens (139.09 ms per token, 7.19 tokens per second)
llama_perf_context_print: eval time = 135586.41 ms / 583 runs (232.57 ms per token, 4.30 tokens per second)
llama_perf_context_print: total time = 136689.20 ms / 590 tokens


etafund commented Jan 3, 2025

I have it running on my dual-socket genoa rig now. First result is 8.83t/s cpu-only inference. fairydreaming: is there anything I can do to assist in implementation? I'd need to pull and requant, but don't mind doing so if I can be of use. Thanks for getting this model working!

@cpumaxx, running at q8_0 or q4?

Contributor

cpumaxx commented Jan 3, 2025

I have it running on my dual-socket genoa rig now. First result is 8.83t/s cpu-only inference. fairydreaming: is there anything I can do to assist in implementation? I'd need to pull and requant, but don't mind doing so if I can be of use. Thanks for getting this model working!

@cpumaxx, running at q8_0 or q4?

I quanted to q8_0.
To get max throughput you'll want to drop all memory caches and use the --numa distribute flag
I've seen 9.22t/s throughput now on a short prompt from empty context.

arthurwolf commented Jan 3, 2025 via email

just curious, how much RAM does this use to run?


etafund commented Jan 3, 2025

@arthurwolf, 680238 MiB is the model buffer size on my rig, so about 664 GiB of RAM.

Contributor

cpumaxx commented Jan 3, 2025

just curious, how much RAM does this use to run?

llama-server process is using 711GB on my rig


etafund commented Jan 3, 2025

@cpumaxx have tried a few different commands here and am stuck at 4-5 tokens/second. Mind posting the command you're running that gets you to 8-9 tokens/second on your dual Genoa setup? Thanks so much!

sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches' to clear the system page cache. Then...
numactl --interleave=0-1 ./llama-cli -m /zfspool/user/models/DeepSeek-V3-BF16-256x20B-Q8_0.gguf -p "What is DeepSeek?" -n 128 -ngl 0 --no-mmap --no-context-shift --numa distribute

@Nottlespike

@cpumaxx have tried a few different commands here and am stuck at 4-5 tokens/second. Mind posting the command you're running? Thanks so much!

sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches' to clear the system page cache. Then... numactl --interleave=0-1 ./llama-cli -m /zfspool/user/models/DeepSeek-V3-BF16-256x20B-Q8_0.gguf -p "What is DeepSeek?" -n 128 -ngl 0 --no-mmap --no-context-shift --numa distribute

100%: -ngl should not be 0. -ngl 7 should work with an A6000.


etafund commented Jan 3, 2025

@Nottlespike Agree with you - just trying to figure out the Epyc Genoa CPU issue first, then will layer in the GPU to improve performance.

Contributor

cpumaxx commented Jan 3, 2025

@cpumaxx have tried a few different commands here and am stuck at 4-5 tokens/second. Mind posting the command you're running that gets you to 8-9 tokens/second on your dual Genoa setup? Thanks so much!

sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches' to clear the system page cache. Then... numactl --interleave=0-1 ./llama-cli -m /zfspool/user/models/DeepSeek-V3-BF16-256x20B-Q8_0.gguf -p "What is DeepSeek?" -n 128 -ngl 0 --no-mmap --no-context-shift --numa distribute

Remove the numactl command and --no-mmap flag. Everything else is the same as what I'm using.
You want it to use mmap so that you get maximum memory locality given the current lcpp architecture.
You won't see maximum speed until a few responses have faulted in all the memory caches, so if you're looking for performance numbers you probably want either llama-bench or an interactive session.
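Putting that advice together with the earlier command, the adjusted invocation would look roughly like this (same model path as above, -t 32 mirroring the earlier perplexity run; a sketch rather than a tuned recipe):

```sh
# Drop the page cache once, then let llama.cpp mmap the model and
# spread work across NUMA nodes; no numactl and no --no-mmap.
sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'
./llama-cli -m /zfspool/user/models/DeepSeek-V3-BF16-256x20B-Q8_0.gguf \
  -p "What is DeepSeek?" -n 128 \
  -ngl 0 --no-context-shift --numa distribute -t 32
```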


Nottlespike commented Jan 3, 2025

Yo if y'all want to collaborate in real time I'm livestreaming this all over X at https://x.com/i/broadcasts/1MnxnDkgaMyGO

@Nottlespike

But I'm also a maniac who is doing this from BF16 DeepSeek-V3-Bf16-256x20B-BF16.gguf

@fairydreaming
Collaborator

@etafund Try to limit the number of threads to 32 or 48 (-t 32)


etafund commented Jan 3, 2025

Still testing at about 5.5 tokens/second on a dual Epyc Genoa system. If anyone has advice on how to get this closer to 8-9 tokens/second, let me know. Thanks @fairydreaming and @cpumaxx and @Nottlespike for all the help so far.

./llama-bench -ngl 0 --numa distribute -t 32 -m /zfspool/user/models/DeepSeek-V3-BF16-256x20B-Q8_0.gguf
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA RTX 6000 Ada Generation, compute capability 8.9, VMM: yes

| model | size | params | backend | ngl | threads | test | t/s |
|-------|------|--------|---------|-----|---------|------|-----|
| deepseek2 671B Q8_0 | 664.29 GiB | 671.03 B | CUDA | 0 | 32 | pp512 | 14.11 ± 0.01 |
| deepseek2 671B Q8_0 | 664.29 GiB | 671.03 B | CUDA | 0 | 32 | tg128 | 5.38 ± 0.07 |

Collaborator

fairydreaming commented Jan 3, 2025

@etafund What performance do you have on a CPU-only llama.cpp build? (compiled without CUDA)

Edit: You can also try increasing the number of threads to 48 or 64 and check when performance starts decreasing; I'm not sure what the right value is for your CPU.
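One quick way to do that sweep is llama-bench, which accepts a comma-separated list of thread counts and prints one result row per configuration (a sketch; the model path is a placeholder):

```sh
# CPU-only build: compare pp512/tg128 throughput across several thread counts
./llama-bench -m /models/DeepSeek-V3-256x20B-Q8_0.gguf \
  --numa distribute -t 16,32,48,64
```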

@fairydreaming
Collaborator

I have it running on my dual-socket genoa rig now. First result is 8.83t/s cpu-only inference. @fairydreaming: is there anything I can do to assist in implementation? I'd need to pull and requant, but don't mind doing so if I can be of use. Thanks for getting this model working!

@cpumaxx If you have time please repeat the whole conversion and model testing process with the current code to confirm that it still works without problems (and that old DeepSeek V2 and V2.5 still work in case you have them). I just finished dealing with the llama.cpp file explosion caused by #10902, time to get some rest.

Contributor

cpumaxx commented Jan 3, 2025

I have it running on my dual-socket genoa rig now. First result is 8.83t/s cpu-only inference. @fairydreaming: is there anything I can do to assist in implementation? I'd need to pull and requant, but don't mind doing so if I can be of use. Thanks for getting this model working!

@cpumaxx If you have time please repeat the whole conversion and model testing process with the current code to confirm that it still works without problems (and that old DeepSeek V2 and V2.5 still work in case you have them). I just finished dealing with the llama.cpp file explosion caused by #10902, time to get some rest.

Sure. I should be done by end of day if the last conversion was any barometer.
I can sympathize with trying to keep a dev branch in sync with lcpp's head... it's a quickly moving target and a lot of work just to keep it mergeable without breaking things!

Contributor

cpumaxx commented Jan 4, 2025

I have it running on my dual-socket genoa rig now. First result is 8.83t/s cpu-only inference. @fairydreaming: is there anything I can do to assist in implementation? I'd need to pull and requant, but don't mind doing so if I can be of use. Thanks for getting this model working!

@cpumaxx If you have time please repeat the whole conversion and model testing process with the current code to confirm that it still works without problems (and that old DeepSeek V2 and V2.5 still work in case you have them). I just finished dealing with the llama.cpp file explosion caused by #10902, time to get some rest.

I can now confirm that the re-quanted model works with the new code (and that the old model doesn't)

Collaborator

fairydreaming commented Jan 4, 2025

@cpumaxx Can you post the error you get with the old model?
Edit: do you mean the previous DeepSeek V3 GGUF or V2/V2.5?

Contributor

cpumaxx commented Jan 4, 2025

@cpumaxx Can you post the error you get with the old model? Edit: do you mean the previous DeepSeek V3 GGUF or V2/V2.5?

Sorry, I meant the old v3 quant doesn't work with the new code (I tried as a sanity check to make sure I was on the new code and that the new quant was really different)
I haven't regression tested on v2/2.5 yet. I should be able to do that this afternoon (I have some weekend work today).

@fairydreaming
Collaborator

@cpumaxx Can you post the error you get with the old model? Edit: do you mean the previous DeepSeek V3 GGUF or V2/V2.5?

Sorry, I meant the old v3 quant doesn't work with the new code (I tried as a sanity check to make sure I was on the new code and that the new quant was really different) I haven't regression tested on v2/2.5 yet. I should be able to this afternoon (I have some weekend work today)

OK, that was expected since a tensor name changed.

@RodriMora
Author

Same here, new quants working fine with the latest commit. Gives this error with the old quants:

llama_model_load: error loading model: done_getting_tensors: wrong number of tensors;

So everything working as expected.

I also ran the MMLU-PRO computer science benchmark and got really good results:
Q4_K_M: 77.32
API: 78.05, 78.05, 77.80, 77.80 (test done by @WolframRavenwolf)

@fairydreaming
Collaborator

Same here, new quants working fine with the latest commit. Gives this error with the old quants:

llama_model_load: error loading model: done_getting_tensors: wrong number of tensors;

So everything working as expected.

I also run the MMLU-PRO computer science benchmark and got really good results: Q4_K_M: 77.32 API: 78.05, 78.05, 77.80, 77.80 (test done by @WolframRavenwolf)

Great! BTW how did you run the benchmark, did you use llama-server? I also tried to run this bench today out of curiosity (via OpenAI-compatible endpoint of llama-server) but experienced llama-server token generation speed gradually getting slower and slower (at the beginning it was over 9 t/s, but around question 180 only around 5 t/s) so I started investigating why it does that. I don't know, maybe the prompt cache started to grow too big and caused parts of the model to be removed from RAM. Did you notice a similar behavior?
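For anyone wanting to reproduce this kind of run, a minimal sketch of serving a quant over llama-server's OpenAI-compatible endpoint (port, context size, and model path are assumptions); the benchmark client then points at http://localhost:8080/v1:

```sh
# Serve the GGUF with an OpenAI-compatible API for the benchmark harness
./build/bin/llama-server -m /models/deepseek-v3-Q4_K_M.gguf \
  --numa distribute -t 32 -c 4096 --no-context-shift \
  --host 0.0.0.0 --port 8080
```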

@RodriMora
Author

Same here, new quants working fine with the latest commit. Gives this error with the old quants:
llama_model_load: error loading model: done_getting_tensors: wrong number of tensors;
So everything working as expected.
I also run the MMLU-PRO computer science benchmark and got really good results: Q4_K_M: 77.32 API: 78.05, 78.05, 77.80, 77.80 (test done by @WolframRavenwolf)

Great! BTW how did you run the benchmark, did you use llama-server? I also tried to run this bench today out of curiosity (via OpenAI-compatible endpoint of llama-server) but experienced llama-server token generation speed gradually getting slower and slower (at the beginning it was over 9 t/s, but around question 180 only around 5 t/s) so I started investigating why it does that. I don't know, maybe the prompt cache started to grow too big and caused parts of the model to be removed from RAM. Did you notice a similar behavior?

Same, running llama-server with the OAI api endpoint as the backend and ollama-mmlu-pro for the benchmark.

This is the report results I got at the end:

Finished testing computer science in 9 hours 15 minutes 45 seconds.
Total, 317/410, 77.32%
Random Guess Attempts, 0/410, 0.00%
Correct Random Guesses, division by zero error
Adjusted Score Without Random Guesses, 317/410, 77.32%
Finished the benchmark in 9 hours 15 minutes 49 seconds.
Total, 317/410, 77.32%
Token Usage:
Prompt tokens: min 1598, average 1747, max 2688, total 709250, tk/s 21.27
Completion tokens: min 36, average 167, max 844, total 67668, tk/s 2.03
Markdown Table:

| overall | computer science |
|---------|------------------|
| 77.32 | 77.32 |

So 2 t/s average at the end. I didn't monitor it, as I left it overnight and had already closed the SSH session, so I don't have the llama-server output. I'll test again and monitor RAM. But I think it may just be the longer questions? From some quick tests I get 5 t/s at super low context (under 200 tokens), 4 t/s at 500-1000 context, and 2 t/s at 2000-3000 context.

Collaborator

fairydreaming commented Jan 4, 2025

@RodriMora I'm not sure, there is definitely large variance caused by different length of prompts/generated token sequences (and different sets of activated experts), but values close to 8 t/s are only at the beginning.
[chart: llama-server-eval-time]

Edit: tomorrow I'm going to disable prompt caching and run it again, will see if it changes anything.
Edit2: Disabling prompt caching didn't change much, it seems that there is a bunch of short prompts at the beginning of the benchmark, that's why the generation is faster there. The token generation speed quickly goes down with increased prompt length.

Contributor

cpumaxx commented Jan 4, 2025

I know this is closed/merged, but as a datapoint: deepseek 2.5 didn't show any regressions.

@fairydreaming
Collaborator

I know this is closed/merged, but as a datapoint: deepseek 2.5 didn't show any regressions.

Great, thanks for checking!
