Add Command-R Model #6033

Merged
merged 1 commit into ggerganov:master on Mar 15, 2024

Conversation

acanis
Contributor

@acanis acanis commented Mar 13, 2024

Information about the Command-R 35B model (128k context) can be found at:
https://huggingface.co/CohereForAI/c4ai-command-r-v01

Based on the llama2 model with a few changes:

  1. New hyperparameter to scale output logits (logit_scale)
  2. Uses LayerNorm instead of RMSNorm
  3. Transformer layers have a single shared LayerNorm that feeds into both the
    self-attention and FFN layers in parallel. There is no post-attention LayerNorm.
  4. No support for Rotary Position Embeddings (RoPE) scaling
  5. No biases used
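
To make the layer changes concrete, here is a minimal PyTorch-style sketch of the block structure described in the list above. It is illustrative only (toy sizes, a plain SiLU MLP instead of the real gated FFN, and no RoPE), not the reference implementation:

```python
import torch
import torch.nn as nn

class ParallelBlockSketch(nn.Module):
    # Illustrative only: one shared LayerNorm feeding attention and FFN in parallel,
    # no post-attention norm, no biases.
    def __init__(self, d_model=64, n_heads=4, d_ff=256, eps=1e-5):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(d_model, eps=eps)  # LayerNorm, not RMSNorm
        self.attn = nn.MultiheadAttention(d_model, n_heads, bias=False, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(d_model, d_ff, bias=False),
                                 nn.SiLU(),
                                 nn.Linear(d_ff, d_model, bias=False))

    def forward(self, x):
        h = self.input_layernorm(x)                        # single shared norm
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        return x + attn_out + self.mlp(h)                  # parallel residual

x = torch.randn(2, 16, 64)
print(ParallelBlockSketch()(x).shape)  # torch.Size([2, 16, 64])
```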

Find GGUF files here:
https://huggingface.co/andrewcanis/c4ai-command-r-v01-GGUF

To convert model to GGUF format yourself:

  1. Download Command-R Hugging Face safetensors:
    git lfs install
    git clone https://huggingface.co/CohereForAI/c4ai-command-r-v01

  2. Run:
    python3 convert-hf-to-gguf.py --outtype f16 ./c4ai-command-r-v01

@acanis
Contributor Author

acanis commented Mar 13, 2024

I can successfully convert the model to GGUF format with these changes.
Inference is still a work in progress. I will work on the build_command_r() function tomorrow.

@sweetcard

You are so cool 👍

@choyakawa

It seems that Cohere uses a different rotate_half in their codebase, unlike other llama-based models. Would this be okay with llama.cpp?

@acanis
Contributor Author

acanis commented Mar 13, 2024

I noticed some potentially useful comments about the model here: Lightning-AI/litgpt#1089

@choyakawa

choyakawa commented Mar 13, 2024

I noticed some potentially useful comments about the model here: Lightning-AI/litgpt#1089

I think they've missed the rotate_half part - it was different in my fine-tuning test, though I'm not sure if it is significant for inference.

def rotate_half(x):
    # Split and rotate
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
    return rot_x
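
For contrast, the HF Llama code uses a half-split rotate_half along these lines (a from-memory sketch, not copied from either repo); the Cohere version above instead rotates interleaved even/odd pairs:

```python
import torch

def rotate_half_llama_style(x):
    # Swap the two contiguous halves of the last dimension with a sign flip
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
```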

@ggerganov
Owner

ggerganov commented Mar 13, 2024

If I'm reading the modeling code correctly, this is the same RoPE as usual:

llama.cpp/ggml.c

Lines 12696 to 12702 in f30ea47

const float x0 = src[0];
const float x1 = src[1];
dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta;
dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
}

So I don't think any changes are needed. Just use LLAMA_ROPE_TYPE_NORM
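
As a sanity check (illustrative only, and ignoring the zeta scaling factor in the snippet above), the pairwise rotation in that ggml loop computes the same thing as the interleaved rotate_half applied with per-pair cos/sin:

```python
import torch

def rope_pairwise(x, theta):
    # Rotate each adjacent (even, odd) pair by theta, as in the ggml loop above
    x0, x1 = x[..., ::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., ::2] = x0 * torch.cos(theta) - x1 * torch.sin(theta)
    out[..., 1::2] = x0 * torch.sin(theta) + x1 * torch.cos(theta)
    return out

def rope_via_interleaved_rotate_half(x, theta):
    # x * cos + rotate_half(x) * sin, with cos/sin repeated per interleaved pair
    cos = torch.repeat_interleave(torch.cos(theta), 2, dim=-1)
    sin = torch.repeat_interleave(torch.sin(theta), 2, dim=-1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    rot = torch.stack([-x2, x1], dim=-1).flatten(-2)
    return x * cos + rot * sin

x = torch.randn(4, 8)
theta = torch.rand(4, 4)  # one angle per pair
assert torch.allclose(rope_pairwise(x, theta),
                      rope_via_interleaved_rotate_half(x, theta), atol=1e-6)
```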

@acanis
Contributor Author

acanis commented Mar 13, 2024

Any ideas on where to add logit_scale?

Comparing the llama and the cohere models: llama_cohere_diff.txt

@@ -1212,6 +1161,7 @@
             logits = torch.cat(logits, dim=-1)
         else:
             logits = self.lm_head(hidden_states)
+        logits = logits * self.logit_scale
         logits = logits.float()

@ggerganov
Owner

ggerganov commented Mar 13, 2024

Probably this would work:

diff --git a/llama.cpp b/llama.cpp
index 38e7036a..a911cdff 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -5857,6 +5857,12 @@ struct llm_build_context {
 
         // lm_head
         cur = ggml_mul_mat(ctx0, model.output, cur);
+
+        if (model.logits_scale != 1.0f) {
+            cur = ggml_scale(ctx0, cur, model.logits_scale);
+        }
+
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
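
For intuition only (not part of the suggested patch): scaling the logits by a constant before the softmax is the same as dividing them by a temperature. With the logit_scale of 0.0625 reported in the GGUF metadata later in this thread, that is an effective factor of 16:

```python
import torch

logits = torch.randn(256000)          # one row of lm_head output
logit_scale = 0.0625                  # value reported in the GGUF metadata below
probs_scaled = torch.softmax(logits * logit_scale, dim=-1)
probs_temp = torch.softmax(logits / 16.0, dim=-1)  # temperature = 1 / logit_scale
assert torch.allclose(probs_scaled, probs_temp)
```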

@acanis
Contributor Author

acanis commented Mar 13, 2024

Another change is in the transformer layer.
Compared to llama2:

graph TD;
    Input-->InputLayerNorm;
    InputLayerNorm-->Attention;
    Attention-->Plus;
    Input-->Plus;
    Plus-->PostAttnLayerNorm;
    PostAttnLayerNorm-->MLP;
    MLP-->Plus2;
    Plus-->Plus2
    Plus2-->Output;

In cohere transformer layers there is no post-attention layernorm:

graph TD;
    Input-->InputLayerNorm;
    Input-->Plus;
    InputLayerNorm-->Attention;
    Attention-->Plus;
    InputLayerNorm-->MLP;
    MLP-->Plus
    Plus-->Output;

In the Python code, comparing the Cohere model to the Llama model:

     def forward(
         self,
@@ -737,7 +691,7 @@
         hidden_states = self.input_layernorm(hidden_states)

         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -747,16 +701,12 @@
             cache_position=cache_position,
             **kwargs,
         )
-        hidden_states = residual + hidden_states

         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)

+        hidden_states_mlp = self.mlp(hidden_states)
-        hidden_states = self.mlp(hidden_states)

         # Add everything together
-        hidden_states = residual + hidden_states
+        hidden_states = residual + hidden_states_attention + hidden_states_mlp

         outputs = (hidden_states,)

@acanis
Contributor Author

acanis commented Mar 13, 2024

Okay I made the changes mentioned above. Thanks for the help Georgi.
I still need to look through the rest of the model comparison to see if there are other changes and test vs the original model.

@Noeda
Contributor

Noeda commented Mar 14, 2024

Howdy. I was working on this same thing to get this model running on llama.cpp, but my results were incoherent when I first got it to run. When I tested the code in this PR's branch, I got coherent text, so this is definitely more correct than my hack.

But I also noticed the Q8_0 quant does not work at all: quantize produces a 10 megabyte file that does not work. The Q6_K quant worked fine, produced a proper file, and produces coherent text. I am unsure whether this is in the new model code or a bug in llama.cpp overall; I had this problem with my own branch, and the same thing happens with this PR. I have not read the quantize code as of writing this comment, so I'm not sure why Q8_0 specifically would be broken rather than all quants.

Other than that, I can confirm that the code so far, as written by @acanis, produces coherent text that seems indistinguishable from the HF version (although I didn't compare at the exact logit-probability level at this time), so in terms of getting the model computed correctly it cannot be very wrong. I can come back to this tomorrow or later this week if no one else chimes in to confirm correctness.

Edit: I noticed @acanis pushed new code about an hour before I wrote this comment and again right after; I was testing the commits before those.

@acanis
Contributor Author

acanis commented Mar 14, 2024

Thanks for testing Noeda. Any help is appreciated.
The changes I just made were to properly set the f_norm_eps hyperparameter which was incorrectly set to 0 before.

I just tested llama.cpp on the F16 model (without quantization) using the suggested prompt:
./main -m /root/.cache/huggingface/hub/models--CohereForAI--c4ai-command-r-v01/snapshots/9fe64d67d13873f218cb05083b6fc2faab2d034a/ggml-model-f16.gguf -p "<BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" -n 500

Prompt output is:
Hello! I'm doing well, thank you for asking. I hope the same goes for you too! It's a pleasure to be of assistance and help you in any way that I can. How can I assist you today? Do you have any questions or inquiries that you'd like me to look into? [end of text]

This closely matches the output from the 8-bit quantized reference Python model:
Hello! I'm doing well, thank you for asking! I'm excited to assist you with whatever questions or tasks you have. How can I help you today?<|END_OF_TURN_TOKEN|>

Full log (CPU) is below:

Log start                                                                                                                                                              
main: build = 2410 (6a03064)                                                                                                                                           
main: built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu                                                                                         
main: seed  = 1710399628                                                                                                                                               
llama_model_loader: loaded meta data with 22 key-value pairs and 322 tensors from /root/.cache/huggingface/hub/models--CohereForAI--c4ai-command-r-v01/snapshots/9fe64d67d13873f218cb05083b6fc2faab2d034a/ggml-model-f16.gguf (version GGUF V3 (latest))  
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = command-r
llama_model_loader: - kv   1:                               general.name str              = 9fe64d67d13873f218cb05083b6fc2faab2d034a
llama_model_loader: - kv   2:                      command-r.block_count u32              = 40
llama_model_loader: - kv   3:                   command-r.context_length u32              = 8192
llama_model_loader: - kv   4:                 command-r.embedding_length u32              = 8192
llama_model_loader: - kv   5:              command-r.feed_forward_length u32              = 22528
llama_model_loader: - kv   6:             command-r.attention.head_count u32              = 64
llama_model_loader: - kv   7:          command-r.attention.head_count_kv u32              = 64
llama_model_loader: - kv   8:                   command-r.rope.freq_base f32              = 8000000.000000
llama_model_loader: - kv   9:     command-r.attention.layer_norm_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                          general.file_type u32              = 1
llama_model_loader: - kv  11:                      command-r.logit_scale f32              = 0.062500
llama_model_loader: - kv  12:                command-r.rope.scaling.type str              = none
llama_model_loader: - kv  13:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  14:                      tokenizer.ggml.tokens arr[str,256000]  = ["<PAD>", "<UNK>", "<CLS>", "<SEP>", ...
llama_model_loader: - kv  15:                  tokenizer.ggml.token_type arr[i32,256000]  = [3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 1, 1, ...
llama_model_loader: - kv  16:                      tokenizer.ggml.merges arr[str,253333]  = ["Ġ Ġ", "Ġ t", "e r", "i n", "Ġ a...
llama_model_loader: - kv  17:                tokenizer.ggml.bos_token_id u32              = 5
llama_model_loader: - kv  18:                tokenizer.ggml.eos_token_id u32              = 255001
llama_model_loader: - kv  19:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  20:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  21:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - type  f32:   41 tensors
llama_model_loader: - type  f16:  281 tensors
llm_load_vocab: special tokens definition check successful ( 1008/256000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = command-r
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 256000
llm_load_print_meta: n_merges         = 253333
llm_load_print_meta: n_ctx_train      = 8192
llm_load_print_meta: n_embd           = 8192
llm_load_print_meta: n_head           = 64
llm_load_print_meta: n_head_kv        = 64
llm_load_print_meta: n_layer          = 40
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 8192
llm_load_print_meta: n_embd_v_gqa     = 8192
llm_load_print_meta: f_norm_eps       = 1.0e-05
llm_load_print_meta: f_norm_rms_eps   = 0.0e+00
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale   = 6.2e-02
llm_load_print_meta: n_ff             = 22528
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attm      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = none
llm_load_print_meta: freq_base_train  = 8000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 8192
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 35B
llm_load_print_meta: model ftype      = F16
llm_load_print_meta: model params     = 34.98 B
llm_load_print_meta: model size       = 65.16 GiB (16.00 BPW) 
llm_load_print_meta: general.name     = 9fe64d67d13873f218cb05083b6fc2faab2d034a
llm_load_print_meta: BOS token        = 5 '<BOS_TOKEN>'
llm_load_print_meta: EOS token        = 255001 '<|END_OF_TURN_TOKEN|>'
llm_load_print_meta: PAD token        = 0 '<PAD>'
llm_load_print_meta: LF token         = 136 'Ä'
llm_load_tensors: ggml ctx size =    0.12 MiB
llm_load_tensors:        CPU buffer size = 66721.28 MiB
...........................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: freq_base  = 8000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:        CPU KV buffer size =   640.00 MiB
llama_new_context_with_model: KV self size  =  640.00 MiB, K (f16):  320.00 MiB, V (f16):  320.00 MiB
llama_new_context_with_model:        CPU input buffer size   =    18.01 MiB
llama_new_context_with_model:        CPU compute buffer size =   516.00 MiB
llama_new_context_with_model: graph splits (measure): 1

system_info: n_threads = 127 / 255 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | 
sampling: 
        repeat_last_n = 64, repeat_penalty = 1.100, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 512, n_batch = 512, n_predict = 500, n_keep = 1


<|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hello! I'm doing well, thank you for asking. I hope the same goes for you too! It's a pleasure to be of assistance and help you in any way that I can. How can I assist you today? Do you have any questions or inquiries that you'd like me to look into? [end of text]

llama_print_timings:        load time =    9549.74 ms
llama_print_timings:      sample time =     243.81 ms /    63 runs   (    3.87 ms per token,   258.40 tokens per second)
llama_print_timings: prompt eval time =    8395.35 ms /    13 tokens (  645.80 ms per token,     1.55 tokens per second)
llama_print_timings:        eval time =  314820.10 ms /    62 runs   ( 5077.74 ms per token,     0.20 tokens per second)
llama_print_timings:       total time =  323720.89 ms /    75 tokens
Log end

Another log this time running with cuBLAS (on the same machine with an A100 GPU) with low temperature:

[email protected]:/cohere/llama.cpp/build/bin$ ./main --temp 0.001 -m /root/.cache/huggingface/hub/models--CohereForAI--c4ai-command-r-v01/snapshots/9fe64d67d13873f218cb05083b6fc2faab2d034a/ggml-model-f16.gguf -p "<BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" -n 500   
log start
main: build = 2410 (6a03064)
main: built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: seed  = 1710401999
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA A100-PCIE-40GB, compute capability 8.0, VMM: yes
llama_model_loader: loaded meta data with 22 key-value pairs and 322 tensors from /root/.cache/huggingface/hub/models--CohereForAI--c4ai-command-r-v01/snapshots/9fe64d67d13873f218cb05083b6fc2faab2d034a/ggml-model-f16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = command-r
llama_model_loader: - kv   1:                               general.name str              = 9fe64d67d13873f218cb05083b6fc2faab2d034a
llama_model_loader: - kv   2:                      command-r.block_count u32              = 40
llama_model_loader: - kv   3:                   command-r.context_length u32              = 8192
llama_model_loader: - kv   4:                 command-r.embedding_length u32              = 8192
llama_model_loader: - kv   5:              command-r.feed_forward_length u32              = 22528
llama_model_loader: - kv   6:             command-r.attention.head_count u32              = 64
llama_model_loader: - kv   7:          command-r.attention.head_count_kv u32              = 64
llama_model_loader: - kv   8:                   command-r.rope.freq_base f32              = 8000000.000000
llama_model_loader: - kv   9:     command-r.attention.layer_norm_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                          general.file_type u32              = 1
llama_model_loader: - kv  11:                      command-r.logit_scale f32              = 0.062500
llama_model_loader: - kv  12:                command-r.rope.scaling.type str              = none
llama_model_loader: - kv  13:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  14:                      tokenizer.ggml.tokens arr[str,256000]  = ["<PAD>", "<UNK>", "<CLS>", "<SEP>", ...
llama_model_loader: - kv  15:                  tokenizer.ggml.token_type arr[i32,256000]  = [3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 1, 1, ...
llama_model_loader: - kv  16:                      tokenizer.ggml.merges arr[str,253333]  = ["Ġ Ġ", "Ġ t", "e r", "i n", "Ġ a...
llama_model_loader: - kv  17:                tokenizer.ggml.bos_token_id u32              = 5
llama_model_loader: - kv  18:                tokenizer.ggml.eos_token_id u32              = 255001
llama_model_loader: - kv  19:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  20:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  21:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - type  f32:   41 tensors
llama_model_loader: - type  f16:  281 tensors
llm_load_vocab: special tokens definition check successful ( 1008/256000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = command-r
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 256000
llm_load_print_meta: n_merges         = 253333
llm_load_print_meta: n_ctx_train      = 8192
llm_load_print_meta: n_embd           = 8192
llm_load_print_meta: n_head           = 64
llm_load_print_meta: n_head_kv        = 64
llm_load_print_meta: n_layer          = 40
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 8192
llm_load_print_meta: n_embd_v_gqa     = 8192
llm_load_print_meta: f_norm_eps       = 1.0e-05
llm_load_print_meta: f_norm_rms_eps   = 0.0e+00
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale   = 6.2e-02
llm_load_print_meta: n_ff             = 22528
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attm      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = none
llm_load_print_meta: freq_base_train  = 8000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 8192
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 35B
llm_load_print_meta: model ftype      = F16
llm_load_print_meta: model params     = 34.98 B
llm_load_print_meta: model size       = 65.16 GiB (16.00 BPW) 
llm_load_print_meta: general.name     = 9fe64d67d13873f218cb05083b6fc2faab2d034a
llm_load_print_meta: BOS token        = 5 '<BOS_TOKEN>'
llm_load_print_meta: EOS token        = 255001 '<|END_OF_TURN_TOKEN|>'
llm_load_print_meta: PAD token        = 0 '<PAD>'
llm_load_print_meta: LF token         = 136 'Ä'
llm_load_tensors: ggml ctx size =    0.12 MiB
llm_load_tensors: offloading 0 repeating layers to GPU
llm_load_tensors: offloaded 0/41 layers to GPU
llm_load_tensors:        CPU buffer size = 66721.28 MiB
...........................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: freq_base  = 8000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:  CUDA_Host KV buffer size =   640.00 MiB
llama_new_context_with_model: KV self size  =  640.00 MiB, K (f16):  320.00 MiB, V (f16):  320.00 MiB
llama_new_context_with_model:  CUDA_Host input buffer size   =    18.01 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =   516.00 MiB
llama_new_context_with_model: graph splits (measure): 1

system_info: n_threads = 127 / 255 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | 
sampling: 
        repeat_last_n = 64, repeat_penalty = 1.100, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.001
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 512, n_batch = 512, n_predict = 500, n_keep = 1


<|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I'm doing well, thank you for asking! I'm here and ready to assist you with whatever questions or topics you have. How can I help you today? Whether it's answering general knowledge questions, providing information on a specific subject, helping with problem solving, or just having a friendly chat, I'm here for you! [end of text]

llama_print_timings:        load time =   11663.40 ms
llama_print_timings:      sample time =     263.37 ms /    67 runs   (    3.93 ms per token,   254.40 tokens per second)
llama_print_timings: prompt eval time =    8989.45 ms /    13 tokens (  691.50 ms per token,     1.45 tokens per second)
llama_print_timings:        eval time =  285983.33 ms /    66 runs   ( 4333.08 ms per token,     0.23 tokens per second)
llama_print_timings:       total time =  295509.04 ms /    79 tokens

I also noticed that the quantize failed for Q8_0 with this error (didn't have time to investigate):

llama_model_loader: - type  f16:  281 tensors
llama_model_quantize_internal: meta size = 10884960 bytes
[   1/ 322]                    token_embd.weight - [ 8192, 256000,     1,     1], type =    f16, converting to q8_0 .. size =  4000.00 MiB -> 17592186042445.00 MiB
[   2/ 322]                  blk.0.attn_k.weight - [ 8192,  8192,     1,     1], type =    f16, converting to q8_0 .. size =   128.00 MiB ->    68.00 MiB
llama_model_quantize: failed to quantize: basic_ios::clear: iostream error
main: failed to quantize model from '/root/.cache/huggingface/hub/models--CohereForAI--c4ai-command-r-v01/snapshots/9fe64d67d13873f218cb05083b6fc2faab2d034a/ggml-model-f16.gguf'

@Noeda
Contributor

Noeda commented Mar 14, 2024

@acanis Thanks for your efforts! I'll pull your code tomorrow, do some testing, and compare with the HF implementation to see if there are any inaccuracies. If the text is coherent, I don't expect major divergences. I will also try to check why I am having trouble with Q8_0.

@Blaizzy

Blaizzy commented Mar 14, 2024

Any ideas on where to add logit_scale?

Comparing the llama and the cohere models: llama_cohere_diff.txt

@@ -1212,6 +1161,7 @@
             logits = torch.cat(logits, dim=-1)
         else:
             logits = self.lm_head(hidden_states)
+        logits = logits * self.logit_scale
         logits = logits.float()

This is correct ✅
Keep in mind that lm_head weights are tied with the embedding layer.
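
A tiny sketch (toy sizes, not the HF source) of what the tying plus logit_scale means: there is no separate lm_head tensor, the output projection reuses the token-embedding matrix:

```python
import torch
import torch.nn as nn

vocab, d_model = 1000, 64        # toy sizes; the real model is 256000 x 8192
logit_scale = 0.0625
embed = nn.Embedding(vocab, d_model)

hidden = torch.randn(1, 4, d_model)               # [batch, seq, d_model]
logits = (hidden @ embed.weight.T) * logit_scale  # lm_head shares embed.weight (tied)
print(logits.shape)                               # torch.Size([1, 4, 1000])
```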

@sweetcard

The commits from chatllm.cpp may help.

foldl/chatllm.cpp@7f43063
foldl/chatllm.cpp@51755b7

@Nold360
Contributor

Nold360 commented Mar 14, 2024

lgtm, successfully created a quant & ran it :)

@sweetcard

lgtm, successfully created a quant & ran it :)

./quantize for Q8_0 doesn't work.

@sweetcard

sweetcard commented Mar 14, 2024

@acanis How can llama.cpp support the 128K context for this model? Any ideas?

@Nold360
Contributor

Nold360 commented Mar 14, 2024

lgtm, successfully created a quant & ran it :)

./quantize for Q8_0 doesn't work.

oh, you're right.. that's weird

@acanis
Contributor Author

acanis commented Mar 14, 2024

@sweetcard I was planning to investigate that next after fixing the quantization. The current context is 8192 tokens.
I saw some discussion here:
https://huggingface.co/CohereForAI/c4ai-command-r-v01/discussions/12

@sweetcard

@sweetcard I was planning to investigate that next after fixing the quantization. The current context is 8192 tokens. I saw some discussion here: https://huggingface.co/CohereForAI/c4ai-command-r-v01/discussions/12

Thank you for your cool work👍

@ggerganov
Owner

@sweetcard I was planning to investigate that next after fixing the quantization. The current context is 8192 tokens. I saw some discussion here: huggingface.co/CohereForAI/c4ai-command-r-v01/discussions/12

If the model training context is 128k, make sure to set the context length in the GGUF meta data to 128k. The loading log should report 128k:

llama_model_loader: - kv   3:                   command-r.context_length u32              = 8192

How can llama.cpp support the 128K context for this model? Any ideas?

With llama.cpp you can configure the size of the context via -c. Just set it to -c 128000 and it should work (if you have enough RAM)

@sweetcard

@sweetcard I was planning to investigate that next after fixing the quantization. The current context is 8192 tokens. I saw some discussion here: huggingface.co/CohereForAI/c4ai-command-r-v01/discussions/12

If the model training context is 128k, make sure to set the context length in the GGUF meta data to 128k. The loading log should report 128k:

llama_model_loader: - kv   3:                   command-r.context_length u32              = 8192

How can llama.cpp support the 128K context for this model? Any ideas?

With llama.cpp you can configure the size of the context via -c. Just set it to -c 128000 and it should work (if you have enough RAM)

The discussion from the authors is about the context size: huggingface.co/CohereForAI/c4ai-command-r-v01/discussions/12.

I'm not clear on the real context size: is it 8k or 128k?

@ggerganov
Owner

@saurabhdash The Metal backend (this is what @sweetcard appears to be using) already computes the rope in F32 precision.

For the moment I wouldn't pay much attention to the reported "degradation" because there are no steps provided by @sweetcard to reproduce, so many other things could have gone wrong in their tests. Moreover, from the screenshot earlier with the log it looks to be an OOM or integer overflow problem in the Metal implementation, so it's likely not related to this PR specifically.

@acanis
Contributor Author

acanis commented Mar 15, 2024

Running the F16 unquantized command-r model on hellaswag for 400 random samples I get: 86.25%

Full log: https://gist.github.com/acanis/719c7474ff4439f59cddc3826a0ac34f

To reproduce run the following commands:

git clone https://github.com/acanis/llama.cpp.git
cd llama.cpp
mkdir build
cd build
cmake .. -DLLAMA_CUBLAS=ON
cmake --build . --config Release -- -j16
cd ..

wget https://huggingface.co/andrewcanis/c4ai-command-r-v01-GGUF/resolve/main/c4ai-command-r-v01-f16.gguf-split-a
wget https://huggingface.co/andrewcanis/c4ai-command-r-v01-GGUF/resolve/main/c4ai-command-r-v01-f16.gguf-split-b
cat c4ai-command-r-v01-f16.gguf-split-* > c4ai-command-r-v01-f16.gguf

# optional: to confirm the checksum of c4ai-command-r-v01-f16.gguf
wget https://huggingface.co/andrewcanis/c4ai-command-r-v01-GGUF/resolve/main/md5sum
md5sum -c md5sum 

./scripts/get-hellaswag.sh
./build/bin/perplexity --hellaswag -f hellaswag_val_full.txt -m c4ai-command-r-v01-f16.gguf -ngl 99 -ub 256 -b 2048

Edit: updated perplexity flags to offload to the GPU, thanks @slaren

@slaren
Collaborator

slaren commented Mar 15, 2024

@acanis Looking at your log, I noticed that you didn't offload any layers to the GPU. I am just wondering, is there any reason for that? You should be able to get much better performance adding -ngl 99 -ub 256 -b 2048 to the command line.

@acanis
Contributor Author

acanis commented Mar 15, 2024

@acanis Looking at your log, I noticed that you didn't offload any layers to the GPU. I am just wondering, is there any reason for that? You should be able to get much better performance adding -ngl 99 -ub 256 -b 2048 to the command line.

Thanks so much @slaren! My mistake, I hadn't realized that -ngl makes this much faster! :) That explains the low GPU utilization.

@acanis
Contributor Author

acanis commented Mar 15, 2024

If there is a way to run hellaswag/gsm8k using llama.cpp, that would be a sure shot way to verify if the implementation is correct. Hellaswag should be ~84%.

@saurabhdash @ggerganov
I just ran the F16 unquantized command-r model on the full hellaswag dataset giving 84.4% as expected.
Full log: https://gist.github.com/acanis/a4a1775f9a45fb05d377822280fe2a8c

To reproduce use the commands above but add --hellaswag-tasks 10042:
./build/bin/perplexity --hellaswag -f hellaswag_val_full.txt -m c4ai-command-r-v01-f16.gguf --hellaswag-tasks 10042 -ngl 99 -ub 256 -b 2048

@Noeda
Contributor

Noeda commented Mar 15, 2024

I don't have anything else from my side. I am fairly confident all is working. The only divergence I'm aware of is that tokenization is not entirely equal, and that slightly alters the results:

Just meddled with the HF implementation again, if we let HF run with its own tokenization, totally vanilla:

HF vanilla logits
|  6315 | Black    | 10.5 |
|  9732 | Western  | 10.2890625 |
|  7397 | Vir      | 10.2421875 |
|  6842 | region   | 9.953125 |
| 14664 | Eastern  | 9.859375 |
|  4376 | known    | 9.640625 |
|  4903 | major    | 9.625 |
|  5079 | city     | 9.609375 |
|  4509 | City     | 9.4140625 |
|  7155 | Av       | 9.3515625 |
HF logits if used with llama.cpp tokenization
|  6315 | Black    | 10.6796875 |
|  7397 | Vir      | 10.46875 |
|  9732 | Western  | 10.3046875 |
|  6842 | region   | 10.015625 |
| 14664 | Eastern  | 9.828125 |
| 12999 | Southern | 9.5 |
|  5079 | city     | 9.4609375 |
|  4903 | major    | 9.40625 |
|  7155 | Av       | 9.3203125 |
| 71010 | jungle   | 9.2734375 |

It's not a big difference, but it does alter top logits ordering a little bit.

The only difference I can see by eyeballing the sample text tokenizations is that tokens 206 ('\n') and 2126 ('\n\n') are handled differently: if there are two newlines in the text, llama.cpp outputs a 2126 token but HF tokenizes them as two 206s instead.

It feels like llama.cpp should be more correct, because why else would the token exist in the dictionary if you are not going to use it? Command-R has an unusually large dictionary anyway. But I guess that would depend on the training.

@saurabhdash Do you know which one is better for inference with your model? Token 2126 (double newline) vs two 206s? Or, in other words:

If you had the string hello\n\nworld, should we tokenize it to [34313, 206, 206, 17080] or [34313, 2126, 17080]?

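One quick way to check the reference behaviour (assuming the transformers tokenizer for this repo id; illustrative, output not captured here):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
print(tok.encode("hello\n\nworld", add_special_tokens=False))
# saurabhdash reports below that the reference encoding is [34313, 206, 206, 17080]
```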

I just ran the F16 unquantized command-r model on the full hellaswag dataset giving 84.4% as expected.
Full log: https://gist.github.com/acanis/a4a1775f9a45fb05d377822280fe2a8c

Over time I've been collecting results from a shorter 400-task run on random models I see hyped on Reddit, and Command-R scores around the same as most other big open-source models. Here are some other models:

| model | hellaswag 400 |
| --- | --- |
| cohere_f16.gguf | 82.75 |
| mixtral-8x7b-instruct-v0.1.Q8_0 | 84.25 |
| DolphinHermes-120b.Q6_K.gguf | 84.25 |
| gemma-7b-it.Q8_0-v2.gguf | 71.50 |
| miqu-Q6_K.gguf | 85 |

(don't use to rank models)

A typical score for a big model is 80-something. I use it as a kind of smoke test; I think at least once there was a highly upvoted, hyped model on the subreddit that was actually just completely broken when tested. If the hellaswag-400 test gives 40 or 50 or so as a result, you know it's probably broken.

Empirically Command-R kicks ass.

@Noeda
Contributor

Noeda commented Mar 15, 2024

One afterthought just as I wrote that: I only saw double newlines as a divergence, but a long context with emojis or lots of non-English characters (e.g. a long Chinese context) could be a lot more divergent. I still feel llama.cpp probably has it correct and the HF implementation is the one with the bug. @saurabhdash when you see this, see if you can find an answer to the question in my previous comment. 😄

@saurabhdash

If you had the string hello\n\nworld, should we tokenize it to [34313, 206, 206, 17080] or [34313, 2126, 17080]?

I checked: it is encoded as [34313, 206, 206, 17080]. PS: If there is a way to use the chat template, you'll see better results.
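
For reference, a sketch of what using the chat template could look like (assuming the tokenizer ships one and a transformers version with apply_chat_template; the message content is just an example):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
messages = [{"role": "user", "content": "Hello, how are you?"}]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)  # expected to match the <BOS_TOKEN><|START_OF_TURN_TOKEN|>... format used earlier in this thread
```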

@saurabhdash

Running the F16 unquantized command-r model on hellaswag for 400 random samples I get: 86.25%

Then the implementation looks correct :)

Owner

@ggerganov ggerganov left a comment


Thank you all for the implementation and the detailed analysis - much appreciated!

@ggerganov ggerganov merged commit 12247f4 into ggerganov:master Mar 15, 2024
58 of 62 checks passed
@Noeda
Contributor

Noeda commented Mar 15, 2024

@acanis If I could, I'd tip you for quickly getting this to work :) great speed. I'm impressed.

@saurabhdash I still have questions about tokenization, but later today or next week I'll open that discussion on HF instead. If anything needs to be done for llama.cpp I'll make that a separate thing.

Thanks everyone! 👏 👏

@acanis
Contributor Author

acanis commented Mar 15, 2024

Thanks so much for the help everyone! I'm glad we were able to get this support merged with correct functionality.

@Noeda great work on the detailed analysis between llama.cpp and the reference, and for fixing the quantization bug; that must have been hard to track down! I'm also curious about the tokenization, so I will check out your HF discussion.
@saurabhdash thanks for the quick turnaround to update the config.json file and for the feedback on the model architecture
@ggerganov thanks for the help with the logit scaling and other changes. Great work on the llama.cpp code, it's easy to understand and modify!
@sweetcard thanks for all your help testing the model and quantization
@slaren thanks for the -ngl GPU tip! I wonder if we should turn that on by default?

This was a great opportunity for me to learn the llama.cpp codebase.

@slaren
Collaborator

slaren commented Mar 15, 2024

@slaren thanks for the -ngl GPU tip! I wonder if we should turn that on by default?

We should definitely do better with the defaults. The problem is that not everybody has enough VRAM to offload the entire model, and currently we don't have a way to detect how many layers can be offloaded automatically.

@chrismrutherford

chrismrutherford commented Mar 15, 2024

Thanks so much for the patch and for merging it so quickly after the release of Command-R. I'm not sure if this is the correct place to comment on a possible issue, but I seem to have a problem converting the original model.

I have tried clean builds of the latest master and the acanis repo and get the same error. I had a quick look at the convert-hf-to-gguf source and tried hacking around a bit, but encountered other issues. I'm reasonably sure this is repeatable, but I'm also struggling to understand why no one else is reporting it. Presumably I can work around the problem by checking out the converted models, but I thought I'd mention it in case it's an actual issue.

python3 convert-hf-to-gguf.py --outtype f16 ./c4ai-command-r-v01
Loading model: c4ai-command-r-v01
gguf: This GGUF file is for Little Endian only
Traceback (most recent call last):
File "llama.cpp/convert-hf-to-gguf.py", line 2073, in
main()
File "llama.cpp/convert-hf-to-gguf.py", line 2054, in main
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)
File "lama.cpp/convert-hf-to-gguf.py", line 1977, in init
self.hparams["max_position_embeddings"] = self.hparams["model_max_length"]
KeyError: 'model_max_length'

@Noeda
Contributor

Noeda commented Mar 16, 2024

For those who are interested: I'm following up my tokenization question over here: https://huggingface.co/CohereForAI/c4ai-command-r-v01/discussions/27

@acanis
Contributor Author

acanis commented Mar 16, 2024

self.hparams["max_position_embeddings"] = self.hparams["model_max_length"]
KeyError: 'model_max_length'

@chrismrutherford you need to update your HF repo to get the latest version of the config.json:
https://huggingface.co/CohereForAI/c4ai-command-r-v01/commit/960d0bad58d35c695ff69dda1d2b90978f40b196

The PR to add model_max_length with the true 128k context length was only merged yesterday.
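
A quick sanity check that a local checkout has the updated field (the path is an example):

```python
import json

with open("c4ai-command-r-v01/config.json") as f:
    cfg = json.load(f)
# None here means the repo still needs to be re-pulled; with the updated config
# this should report the 128k training context (131072)
print(cfg.get("model_max_length"))
```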

@sweetcard

sweetcard commented Mar 16, 2024

Thank you very much for your excellent work.
We can enjoy this new model.

hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 3, 2024
@EwoutH
Contributor

EwoutH commented Apr 4, 2024

They now released a larger, 104B parameter model: C4AI Command R+

@Noeda
Contributor

Noeda commented Apr 4, 2024

The new model has a use_qk_norm=True flag compared to the old one. Otherwise it seems almost identical. It wants a development version of transformers, and looking at the top commit they added code for use_qk_norm, which adds CohereNormLayers for the query and key tensors in the attention block (huggingface/transformers@517a3e6). Nothing else that I can see, at least based on my 10 minutes of checking after I got the email notification from @EwoutH's comment here and got excited.

Some of the parameters are different. logit_scale is different but the converter .py picks that up already. max_model_length is not included 😢

@saurabhdash I think you are the committer of that top commit to transformers (different GitHub user though); did I read the changes correctly?

I feel like this is becoming my catchphrase... but I can help add the new layer code for llama.cpp, if no one else gets to it before I have time to sit down properly and do it. I would likely have time this evening or tomorrow. It doesn't look like much has changed.

@saurabhdash

saurabhdash commented Apr 4, 2024

I can help add the new layer code for llama.cpp, if no one else gets to it before I have time to sit down properly and do it.

Lemme know if you need any help. It's just qk_norm and GQA instead of MHA.
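
For anyone picking this up, a minimal sketch (assumed shapes and module choice, not the reference code) of what qk-norm adds: the query and key heads get normalized before RoPE and the attention scores:

```python
import torch
import torch.nn as nn

n_heads, head_dim, seq = 8, 64, 16          # toy sizes
q_norm = nn.LayerNorm(head_dim)             # the real model's norm reportedly has no bias term
k_norm = nn.LayerNorm(head_dim)

q = torch.randn(1, n_heads, seq, head_dim)  # [batch, heads, seq, head_dim]
k = torch.randn(1, n_heads, seq, head_dim)
q, k = q_norm(q), k_norm(k)                 # applied before rotary embedding / attention
```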

@acanis
Contributor Author

acanis commented Apr 4, 2024

I love how active this project is :)
There is already a PR adding the new support: #6491
