server : output embeddings for all tokens when pooling = none #10861

Merged (12 commits), Dec 18, 2024
4 changes: 3 additions & 1 deletion common/common.cpp
@@ -1780,7 +1780,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm)
break;
case 0: // max absolute
for (int i = 0; i < n; i++) {
if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
if (sum < std::abs(inp[i])) {
sum = std::abs(inp[i]);
}
}
sum /= 32760.0; // make an int16 range
break;
3 changes: 2 additions & 1 deletion common/common.h
@@ -596,7 +596,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si
// Embedding utils
//

void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
// TODO: replace embd_norm with an enum
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);

float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);

2 changes: 1 addition & 1 deletion examples/gritlm/gritlm.cpp
@@ -75,7 +75,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
}

std::vector<float> emb_norm(emb_unorm.size());
common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2);
result.push_back(emb_norm);

#ifdef GRIT_DEBUG
2 changes: 1 addition & 1 deletion examples/retrieval/retrieval.cpp
@@ -107,7 +107,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}

float * out = output + batch.seq_id[i][0] * n_embd;
common_embd_normalize(embd, out, n_embd);
common_embd_normalize(embd, out, n_embd, 2);
}
}

42 changes: 42 additions & 0 deletions examples/server/README.md
@@ -763,6 +763,8 @@ curl http://localhost:8080/v1/chat/completions \

### POST `/v1/embeddings`: OpenAI-compatible embeddings API

This endpoint requires that the model uses a pooling type other than `none`. The embeddings are normalized using the Euclidean norm.
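
In other words, each returned vector is scaled to unit length. A minimal sketch of Euclidean (L2) normalization, for illustration only (the server's actual implementation lives in `common_embd_normalize`):

```python
import math

def l2_normalize(vec: list[float]) -> list[float]:
    # scale the vector so its Euclidean (L2) norm becomes 1
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else vec
```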

*Options:*

See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings).
@@ -795,6 +797,46 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r
}'
```
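
For example, the endpoint can be used with the official `openai` Python client by pointing `base_url` at the server with the `/v1` suffix (a sketch; host, port, and the model name are placeholders, as in the server's test suite):

```python
from openai import OpenAI

# Assumed host/port; the model name is a placeholder (the server loads its own model).
client = OpenAI(api_key="dummy", base_url="http://localhost:8080/v1")

res = client.embeddings.create(
    model="text-embedding-3-small",
    input="I believe the meaning of life is",
)
print(len(res.data[0].embedding))
```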

### POST `/embeddings`: non-OpenAI-compatible embeddings API

This endpoint supports all pooling types, including `--pooling none`. When the pooling type is `none`, the response contains the *unnormalized* embeddings for *all* input tokens. For all other pooling types, only the pooled embedding is returned, normalized using the Euclidean norm.

Note that the response format of this endpoint is different from `/v1/embeddings`.

*Options:*

Same as the `/v1/embeddings` endpoint.

*Examples:*

Same as the `/v1/embeddings` endpoint.

**Response format**

```json
[
  {
    "index": 0,
    "embedding": [
      [ ... embeddings for token 0 ... ],
      [ ... embeddings for token 1 ... ],
      [ ... ],
      [ ... embeddings for token N-1 ... ]
    ]
  },
  ...
  {
    "index": P,
    "embedding": [
      [ ... embeddings for token 0 ... ],
      [ ... embeddings for token 1 ... ],
      [ ... ],
      [ ... embeddings for token N-1 ... ]
    ]
  }
]
```
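
A rough usage sketch for this endpoint (assuming a server started with `--pooling none` on `localhost:8080`; the `requests`-based client here is purely illustrative):

```python
import requests

# Assumed server address; the server must be started with `--pooling none`.
res = requests.post(
    "http://localhost:8080/embeddings",
    json={"input": "hello hello hello"},
)
res.raise_for_status()

for item in res.json():
    token_embeddings = item["embedding"]  # one unnormalized vector per input token
    print(f"prompt {item['index']}: {len(token_embeddings)} tokens, "
          f"{len(token_embeddings[0])} dimensions each")
```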

### GET `/slots`: Returns the current slots processing state

> [!WARNING]
74 changes: 56 additions & 18 deletions examples/server/server.cpp
@@ -726,18 +726,32 @@ struct server_task_result_cmpl_partial : server_task_result {

struct server_task_result_embd : server_task_result {
int index = 0;
std::vector<float> embedding;
std::vector<std::vector<float>> embedding;

int32_t n_tokens;

// OAI-compat fields
bool oaicompat = false;

virtual int get_index() override {
return index;
}

virtual json to_json() override {
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
}

json to_json_non_oaicompat() {
return json {
{"index", index},
{"embedding", embedding},
};
}

json to_json_oaicompat() {
return json {
{"index", index},
{"embedding", embedding},
{"embedding", embedding[0]},
{"tokens_evaluated", n_tokens},
};
}
@@ -2017,9 +2031,10 @@ struct server_context {

void send_embedding(const server_slot & slot, const llama_batch & batch) {
auto res = std::make_unique<server_task_result_embd>();
res->id = slot.id_task;
res->index = slot.index;
res->n_tokens = slot.n_prompt_tokens;
res->id = slot.id_task;
res->index = slot.index;
res->n_tokens = slot.n_prompt_tokens;
res->oaicompat = slot.params.oaicompat;

const int n_embd = llama_n_embd(model);

@@ -2038,12 +2053,18 @@ struct server_context {
if (embd == NULL) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);

res->embedding = std::vector<float>(n_embd, 0.0f);
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
continue;
}

common_embd_normalize(embd, embd_res.data(), n_embd);
res->embedding = embd_res;
// normalize only when there is pooling
// TODO: configurable
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
res->embedding.push_back(embd_res);
} else {
res->embedding.push_back({ embd, embd + n_embd });
}
}

SLT_DBG(slot, "%s", "sending embeddings\n");
@@ -2657,7 +2678,10 @@ struct server_context {

// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
// without pooling, we want to output the embeddings for all the tokens in the batch
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;

common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);

if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3665,14 +3689,17 @@ int main(int argc, char ** argv) {
res_ok(res, data);
};

const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
const json body = json::parse(req.body);
bool oaicompat = false;

if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
return;
}

// for the shape of input/content, see tokenize_input_prompts()
json prompt;
if (body.contains("input")) {
oaicompat = true;
if (body.count("input") != 0) {
prompt = body.at("input");
} else if (body.contains("content")) {
oaicompat = false;
@@ -3697,10 +3724,15 @@ int main(int argc, char ** argv) {
{
std::vector<server_task> tasks;
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);

task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.prompt_tokens = std::move(tokenized_prompts[i]);

// OAI-compat
task.params.oaicompat = oaicompat;

tasks.push_back(task);
}

@@ -3728,12 +3760,18 @@
}

// write JSON response
json root = oaicompat
? format_embeddings_response_oaicompat(body, responses)
: responses.size() == 1 ? responses[0] : json(responses);
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
res_ok(res, root);
};

const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
handle_embeddings_impl(req, res, false);
};

const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
handle_embeddings_impl(req, res, true);
};

const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
@@ -3907,7 +3945,7 @@ int main(int argc, char ** argv) {
svr->Post("/infill", handle_infill);
svr->Post("/embedding", handle_embeddings); // legacy
svr->Post("/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings_oai);
svr->Post("/rerank", handle_rerank);
svr->Post("/reranking", handle_rerank);
svr->Post("/v1/rerank", handle_rerank);
65 changes: 50 additions & 15 deletions examples/server/tests/unit/test_embedding.py
@@ -14,8 +14,9 @@ def create_server():

def test_embedding_single():
global server
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/embeddings", data={
res = server.make_request("POST", "/v1/embeddings", data={
"input": "I believe the meaning of life is",
})
assert res.status_code == 200
@@ -29,8 +30,9 @@ def test_embedding_single():

def test_embedding_multiple():
global server
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/embeddings", data={
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
@@ -46,7 +48,7 @@ def test_embedding_multiple():


@pytest.mark.parametrize(
"content,is_multi_prompt",
"input,is_multi_prompt",
[
# single prompt
("string", False),
@@ -59,34 +61,65 @@ def test_embedding_multiple():
([[12, 34, 56], [12, "string", 34, 56]], True),
]
)
def test_embedding_mixed_input(content, is_multi_prompt: bool):
def test_embedding_mixed_input(input, is_multi_prompt: bool):
global server
server.start()
res = server.make_request("POST", "/embeddings", data={"content": content})
res = server.make_request("POST", "/v1/embeddings", data={"input": input})
assert res.status_code == 200
data = res.body['data']
if is_multi_prompt:
assert len(res.body) == len(content)
for d in res.body:
assert len(data) == len(input)
for d in data:
assert 'embedding' in d
assert len(d['embedding']) > 1
else:
assert 'embedding' in res.body
assert len(res.body['embedding']) > 1
assert 'embedding' in data[0]
assert len(data[0]['embedding']) > 1


def test_embedding_pooling_none():
global server
server.pooling = 'none'
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": "hello hello hello",
})
assert res.status_code == 200
assert 'embedding' in res.body[0]
assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special

# make sure embedding vector is not normalized
for x in res.body[0]['embedding']:
assert abs(sum(v ** 2 for v in x) - 1) > EPSILON


def test_embedding_pooling_none_oai():
global server
server.pooling = 'none'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": "hello hello hello",
})

# /v1/embeddings does not support pooling type 'none'
assert res.status_code == 400


def test_embedding_openai_library_single():
global server
server.pooling = 'last'
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
assert len(res.data) == 1
assert len(res.data[0].embedding) > 1


def test_embedding_openai_library_multiple():
global server
server.pooling = 'last'
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
Member Author:

I had to add the `/v1` suffix for these calls, otherwise the client would hit the `/embeddings` endpoint, which is not OAI-compatible. Is this the correct fix?

Collaborator (@ngxson, Dec 17, 2024):

Yes, that's correct; I hadn't noticed that before. The original OAI code has `/v1` in its default value: https://github.com/openai/openai-python/blob/e94d98e9bf97a5d2d02d79d58f2abdbab26ff2bd/src/openai/_client.py#L117

We should change all the other occurrences too (in another PR).

Collaborator:

By the way, for the embeddings endpoint we also have the switch between non-OAI and OAI based on the presence of the `content` and `input` fields in the request. Is that still relevant now?

Member Author:

I've modified this so that both endpoints can be used with both inputs. I'm also thinking of making `content` an alias for `input`; maybe we can do that directly in #10866.

res = client.embeddings.create(model="text-embedding-3-small", input=[
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
@@ -100,17 +133,19 @@ def test_embedding_openai_library_multiple():

def test_embedding_error_prompt_too_long():
global server
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/embeddings", data={
res = server.make_request("POST", "/v1/embeddings", data={
"input": "This is a test " * 512,
})
assert res.status_code != 200
assert "too large" in res.body["error"]["message"]


def test_same_prompt_give_same_result():
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/embeddings", data={
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",
@@ -138,7 +173,7 @@ def test_same_prompt_give_same_result():
def test_embedding_usage_single(content, n_tokens):
global server
server.start()
res = server.make_request("POST", "/embeddings", data={"input": content})
res = server.make_request("POST", "/v1/embeddings", data={"input": content})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == n_tokens
@@ -147,7 +182,7 @@ def test_embedding_usage_single(content, n_tokens):
def test_embedding_usage_multiple():
global server
server.start()
res = server.make_request("POST", "/embeddings", data={
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",