
feat: add potential to run Jina Embeddings architecture #6826

Merged: 39 commits merged into ggml-org:master from feat-jina-embeddings on May 11, 2024

Conversation

@JoanFM (Contributor) commented Apr 22, 2024

Hello @ggerganov,

Thanks for this awesome project. I have been trying to add support for Jina Embeddings (https://huggingface.co/jinaai/jina-embeddings-v2-base-en) in llama.cpp.

This PR aims to make the Jina Embeddings architecture runnable in llama.cpp.

For this, the following changes were made:

  • Define JinaBertModel in convert-hf-to-gguf.py so the model's tensors can be extracted into GGUF.
  • Set vocab options so that llama.cpp loads the model with the proper settings (adding the EOS and BOS tokens).
  • Introduce the LLM_ARCH_JINA_BERT architecture and adapt the tensors used by the implementation.
  • Adapt build_bert to the small differences this model needs (such as having no positional embeddings).
  • (The most controversial part) Adapt/fix the ALiBi softmax computation so that the per-head slope is multiplied by the distance to the diagonal of that head's attention matrix (see the sketch after this list).
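
For illustration, a minimal standalone sketch of that bias (a NumPy toy assuming the common power-of-two-heads slope schedule; this is not the llama.cpp implementation):

# Illustrative NumPy sketch of the ALiBi bias described above; a toy,
# not the llama.cpp code. Head h gets a slope m_h, and the score at
# position (i, j) is biased by -m_h * |i - j|: the slope times the
# distance to the diagonal of the attention matrix.
import numpy as np

def alibi_slopes(n_head: int, max_bias: float = 8.0) -> np.ndarray:
    # Common schedule, assuming a power-of-two head count:
    # m_h = 2^(-max_bias * h / n_head) for h = 1 .. n_head
    return 2.0 ** (-max_bias * np.arange(1, n_head + 1) / n_head)

def alibi_bias(n_tokens: int, n_head: int) -> np.ndarray:
    pos = np.arange(n_tokens)
    dist = np.abs(pos[None, :] - pos[:, None])           # |i - j|
    return -alibi_slopes(n_head)[:, None, None] * dist   # (n_head, n_tokens, n_tokens)

# The bias is added to the raw attention scores before softmax:
# softmax(Q @ K.T / sqrt(d_head) + alibi_bias(n_tokens, n_head))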

@JoanFM changed the title from "Feat jina embeddings" to "(DRAFT) feat: add potential to run Jina Embeddings architecture" Apr 22, 2024
@JoanFM marked this pull request as ready for review April 22, 2024 16:03
@JoanFM changed the title from "(DRAFT) feat: add potential to run Jina Embeddings architecture" to "feat: add potential to run Jina Embeddings architecture" Apr 22, 2024
@JoanFM (Contributor, Author) commented Apr 22, 2024

Hey @ggerganov,

I would like to get some comments, especially on the ALiBi implementation, which is the part I found most confusing.

@JoanFM force-pushed the feat-jina-embeddings branch from e946cb0 to d7d6a4e April 23, 2024 07:49
@ggerganov (Member) commented:

The way it is implemented now, ALiBi is not applied because KQ_pos is null. You need to apply the following patch:

diff --git a/llama.cpp b/llama.cpp
index 309f4eec..1230a4bc 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4135,7 +4135,7 @@ static void llm_load_hparams(
 
     model.ftype = ml.ftype;
 
-    if (hparams.f_max_alibi_bias > 0.0f && model.arch != LLM_ARCH_JINA_BERT) {
+    if (hparams.f_max_alibi_bias > 0.0f) {
         hparams.need_kq_pos = true;
     }
 
@@ -7984,11 +7984,8 @@ struct llm_build_context {
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
-        struct ggml_tensor * inp_pos = nullptr;
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
-        if (model.arch != LLM_ARCH_JINA_BERT) {
-            inp_pos = build_inp_pos();
-        }
         struct ggml_tensor * inp_mean = build_inp_mean();
         struct ggml_tensor * inp_cls  = build_inp_cls();
 
@@ -8010,6 +8007,9 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask(false);
 
+        // positions of the tokens in the KV cache
+        struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
+
         // iterate layers
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * cur = inpL;
@@ -8065,7 +8065,7 @@ struct llm_build_context {
             struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
             cb(kq, "kq", il);
 
-            kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
+            kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, KQ_pos, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
             cb(kq, "kq_soft_max_ext", il);
 
             struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
@@ -11131,7 +11131,7 @@ static int llama_decode_internal(
         }
 
         // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
+        if (hparams.causal_attn || model.arch == LLM_ARCH_JINA_BERT) {
             llama_kv_cache_update(&lctx);
 
             // if we have enough unused cells before the current head ->

But this still does not work because the KQ_pos is padded.

Let's revisit this PR after merging #5021. I think the fix should be relatively simple, but it will be easier to resolve conflicts once #5021 is merged.

@JoanFM (Contributor, Author) commented Apr 30, 2024

Hey @ggerganov,

I have a couple of comments on the suggestions you made:

  • The inp_pos is not needed for the Jina Embeddings architecture, and without the architecture check an assertion error is triggered at runtime.

  • The KQ_pos is not what we need: with ALiBi, the vector [0, 1, 2, ..., num_tokens] * slope is added to the data each time, but we need a matrix like [[0, 1, 2, 3, ...], [1, 0, 1, 2, ...], [2, 1, 0, 1, ...]]. I am not sure how to create this behavior; also, either the slope or the matrix needs to be negative (see the sketch below).
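
To make the difference concrete, a standalone NumPy sketch (illustrative only, with a made-up slope value; not llama.cpp code):

import numpy as np

n_tokens = 4
slope = 0.5  # one head's ALiBi slope (made-up value for illustration)

# What the existing KQ_pos path effectively contributes: a single position
# vector, broadcast across rows, so the bias depends only on the column j.
kq_pos_bias = slope * np.arange(n_tokens)  # [0, 1, 2, 3] * slope

# What the model needs instead: the relative-distance matrix
# [[0, 1, 2, 3],
#  [1, 0, 1, 2],
#  [2, 1, 0, 1],
#  [3, 2, 1, 0]]
# negated, so that tokens farther from the diagonal are penalized.
pos = np.arange(n_tokens)
needed_bias = -slope * np.abs(pos[None, :] - pos[:, None])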

@JoanFM force-pushed the feat-jina-embeddings branch from da96368 to d9b8dd6 April 30, 2024 12:34
@JoanFM (Contributor, Author) commented May 8, 2024

Hey @ggerganov,

Is there anything missing?

@ggerganov self-requested a review May 8, 2024 14:33
@JoanFM (Contributor, Author) commented May 9, 2024

Hey @ggerganov, I fixed the last conflicts.

@ggerganov (Member) commented:

Yup, thanks. I'll be looking a bit more today; I think the ALiBi stuff still needs some changes/improvements. Hope it will be ready soon.

@mofosyne added the "Review Complexity: High" and "enhancement" labels May 9, 2024
@ggerganov mentioned this pull request May 10, 2024
@ggerganov (Member) commented:

@JoanFM Thanks to your ALiBi-related changes here, I realized that my understanding of how ALiBi works was wrong. I will do a refactoring of the functionality on the master branch: #7192

After the refactoring is ready, we will rebase this PR and merge it

@ggerganov changed the base branch from master to gg/refactor-alibi-2 May 10, 2024 12:28
@ggerganov changed the base branch from gg/refactor-alibi-2 to master May 11, 2024 07:32
@ggerganov merged commit b83cc3f into ggml-org:master May 11, 2024 (56 of 61 checks passed)
JoanFM added a commit to JoanFM/llama.cpp that referenced this pull request May 11, 2024