Skip to content

Commit

Permalink
llama : aboud ggml_repeat during classification
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Sep 23, 2024
1 parent 5f95dcc commit 5b6468f
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10096,9 +10096,6 @@ struct llm_build_context {
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
cur = ggml_tanh(ctx0, cur);
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);

// broadcast across the embedding size to make it compatible with the llama_get_embeddings API
cur = ggml_repeat(ctx0, cur, inp);
} break;
default:
{
Expand Down Expand Up @@ -16831,7 +16828,6 @@ static int llama_decode_internal(
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
case LLAMA_POOLING_TYPE_RANK:
{
// extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = lctx.embd_seq;
Expand All @@ -16845,6 +16841,20 @@ static int llama_decode_internal(
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// extract the rank score - a single float per sequence
auto & embd_seq_out = lctx.embd_seq;

for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(1);
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ABORT("unknown pooling type");
Expand Down

0 comments on commit 5b6468f

Please sign in to comment.