From ff8e91036426ccdfd94e6f005d5647b3f6c3c660 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 1 Feb 2024 01:17:43 +0100 Subject: [PATCH] llama.cpp: update server --- backend/cpp/llama/CMakeLists.txt | 2 +- backend/cpp/llama/Makefile | 1 + backend/cpp/llama/grpc-server.cpp | 913 ++++++++---------------------- backend/cpp/llama/utils.hpp | 510 +++++++++++++++++ 4 files changed, 745 insertions(+), 681 deletions(-) create mode 100644 backend/cpp/llama/utils.hpp diff --git a/backend/cpp/llama/CMakeLists.txt b/backend/cpp/llama/CMakeLists.txt index 7caa10cda483..8299705a36a0 100644 --- a/backend/cpp/llama/CMakeLists.txt +++ b/backend/cpp/llama/CMakeLists.txt @@ -70,7 +70,7 @@ add_library(hw_grpc_proto ${hw_proto_srcs} ${hw_proto_hdrs} ) -add_executable(${TARGET} grpc-server.cpp json.hpp ) +add_executable(${TARGET} grpc-server.cpp utils.hpp json.hpp) target_link_libraries(${TARGET} PRIVATE common llama myclip ${CMAKE_THREAD_LIBS_INIT} absl::flags hw_grpc_proto absl::flags_parse gRPC::${_REFLECTION} diff --git a/backend/cpp/llama/Makefile b/backend/cpp/llama/Makefile index d66f4e73c9e4..b050b62003a5 100644 --- a/backend/cpp/llama/Makefile +++ b/backend/cpp/llama/Makefile @@ -40,6 +40,7 @@ llama.cpp/examples/grpc-server: cp -r $(abspath ./)/CMakeLists.txt llama.cpp/examples/grpc-server/ cp -r $(abspath ./)/grpc-server.cpp llama.cpp/examples/grpc-server/ cp -rfv $(abspath ./)/json.hpp llama.cpp/examples/grpc-server/ + cp -rfv $(abspath ./)/utils.hpp llama.cpp/examples/grpc-server/ echo "add_subdirectory(grpc-server)" >> llama.cpp/examples/CMakeLists.txt ## XXX: In some versions of CMake clip wasn't being built before llama. ## This is an hack for now, but it should be fixed in the future. diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp index 76a82a33dc86..1b82d197e640 100644 --- a/backend/cpp/llama/grpc-server.cpp +++ b/backend/cpp/llama/grpc-server.cpp @@ -19,6 +19,7 @@ #include "grammar-parser.h" #include "backend.pb.h" #include "backend.grpc.pb.h" +#include "utils.hpp" // include std::regex #include @@ -30,6 +31,7 @@ #include #include #include +#include using grpc::Server; using grpc::ServerBuilder; @@ -42,205 +44,19 @@ using backend::HealthMessage; ///// LLAMA.CPP server code below -#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" - using json = nlohmann::json; struct server_params { std::string hostname = "127.0.0.1"; - std::string api_key; + std::vector api_keys; std::string public_path = "examples/server/public"; int32_t port = 8080; int32_t read_timeout = 600; int32_t write_timeout = 600; }; -static bool server_verbose = false; - -#if SERVER_VERBOSE != 1 -#define LOG_VERBOSE(MSG, ...) -#else -#define LOG_VERBOSE(MSG, ...) \ - do \ - { \ - if (server_verbose) \ - { \ - server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \ - } \ - } while (0) -#endif - -#define LOG_ERROR( MSG, ...) server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) - -json oaicompat_completion_params_parse(const json &body); -std::string format_chatml(std::vector messages); - - -// -// base64 utils (TODO: move to common in the future) -// - -static const std::string base64_chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; - -static inline bool is_base64(uint8_t c) -{ - return (isalnum(c) || (c == '+') || (c == '/')); -} - -static std::vector base64_decode(const std::string & encoded_string) -{ - int i = 0; - int j = 0; - int in_ = 0; - - int in_len = encoded_string.size(); - - uint8_t char_array_4[4]; - uint8_t char_array_3[3]; - - std::vector ret; - - while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) - { - char_array_4[i++] = encoded_string[in_]; in_++; - if (i == 4) - { - for (i = 0; i <4; i++) - { - char_array_4[i] = base64_chars.find(char_array_4[i]); - } - - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (i = 0; (i < 3); i++) - { - ret.push_back(char_array_3[i]); - } - i = 0; - } - } - - if (i) - { - for (j = i; j <4; j++) - { - char_array_4[j] = 0; - } - - for (j = 0; j <4; j++) - { - char_array_4[j] = base64_chars.find(char_array_4[j]); - } - - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (j = 0; (j < i - 1); j++) - { - ret.push_back(char_array_3[j]); - } - } - - return ret; -} - -// -// parallel -// - -enum task_type { - TASK_TYPE_COMPLETION, - TASK_TYPE_CANCEL, -}; - -struct task_server { - int id; - int target_id; - task_type type; - json data; - bool infill_mode = false; - bool embedding_mode = false; - int multitask_id = -1; -}; - -struct task_result { - int id; - int multitask_id = -1; - bool stop; - bool error; - json result_json; -}; - -struct task_multi { - int id; - std::set subtasks_remaining{}; - std::vector results{}; -}; - -// TODO: can become bool if we can't find use of more states -enum slot_state -{ - IDLE, - PROCESSING, -}; - -enum slot_command -{ - NONE, - LOAD_PROMPT, - RELEASE, -}; - -struct slot_params -{ - bool stream = true; - bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt - - uint32_t seed = -1; // RNG seed - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_predict = -1; // new tokens to predict - - std::vector antiprompt; - - json input_prefix; - json input_suffix; -}; - -struct slot_image -{ - int32_t id; - - bool request_encode_image = false; - float * image_embedding = nullptr; - int32_t image_tokens = 0; - - clip_image_u8 * img_data; - - std::string prefix_prompt; // before of this image -}; - -// completion token output with probabilities -struct completion_token_output -{ - struct token_prob - { - llama_token tok; - float prob; - }; - - std::vector probs; - llama_token tok; - std::string text_to_send; -}; +bool server_verbose = false; static size_t common_part(const std::vector &a, const std::vector &b) { @@ -296,28 +112,6 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) return ret; } -static void server_log(const char *level, const char *function, int line, - const char *message, const nlohmann::ordered_json &extra) -{ - nlohmann::ordered_json log - { - {"timestamp", time(nullptr)}, - {"level", level}, - {"function", function}, - {"line", line}, - {"message", message}, - }; - - if (!extra.empty()) - { - log.merge_patch(extra); - } - - const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace); - printf("%.*s\n", (int)str.size(), str.data()); - fflush(stdout); -} - // format incomplete utf-8 multibyte character for output static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) { @@ -359,15 +153,6 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector -static T json_value(const json &body, const std::string &key, const T &default_value) -{ - // Fallback null to default value - return body.contains(key) && !body.at(key).is_null() - ? body.value(key, default_value) - : default_value; -} - struct llama_client_slot { int id; @@ -414,6 +199,12 @@ struct llama_client_slot struct llama_sampling_params sparams; llama_sampling_context *ctx_sampling = nullptr; + int32_t ga_i = 0; // group-attention state + int32_t ga_n = 1; // group-attention factor + int32_t ga_w = 512; // group-attention width + + int32_t n_past_se = 0; // self-extend + // multimodal std::vector images; @@ -442,6 +233,8 @@ struct llama_client_slot sent_count = 0; sent_token_probs_index = 0; infill = false; + ga_i = 0; + n_past_se = 0; generated_token_probs.clear(); @@ -462,7 +255,9 @@ struct llama_client_slot { return true; // limitless } + n_remaining = -1; + if (params.n_predict != -1) { n_remaining = params.n_predict - n_decoded; @@ -471,6 +266,7 @@ struct llama_client_slot { n_remaining = global_params.n_predict - n_decoded; } + return n_remaining > 0; // no budget } @@ -492,7 +288,7 @@ struct llama_client_slot } void release() { - if (state == IDLE || state == PROCESSING) + if (state == PROCESSING) { t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3; command = RELEASE; @@ -540,7 +336,6 @@ struct llama_server_context bool all_slots_are_idle = false; bool add_bos_token = true; - int32_t id_gen; int32_t n_ctx; // total context for all clients / slots // system prompt @@ -555,13 +350,8 @@ struct llama_server_context // slots / clients std::vector slots; - std::vector queue_tasks; - std::vector queue_results; - std::vector queue_multitasks; - std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks - std::condition_variable condition_tasks; - std::mutex mutex_results; - std::condition_variable condition_results; + llama_server_queue queue_tasks; + llama_server_response queue_results; ~llama_server_context() { @@ -620,8 +410,6 @@ struct llama_server_context } void initialize() { - id_gen = 0; - // create slots all_slots_are_idle = true; @@ -634,9 +422,26 @@ struct llama_server_context slot.id = i; slot.n_ctx = n_ctx_slot; - slot.reset(); LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot); + + const int ga_n = params.grp_attn_n; + const int ga_w = params.grp_attn_w; + + if (ga_n != 1) { + GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT + GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT + //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT + //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT + LOG_TEE(" -> Slot %i - self-extend: ga_n = %d, ga_w = %d\n", slot.id, ga_n, ga_w); + } + + slot.ga_i = 0; + slot.ga_n = ga_n; + slot.ga_w = ga_w; + + slot.reset(); + slots.push_back(slot); } @@ -892,7 +697,7 @@ struct llama_server_context while ((pos = prompt.find(pattern, pos)) != std::string::npos) { size_t end_prefix = pos; pos += pattern.length(); - size_t end_pos = prompt.find("]", pos); + size_t end_pos = prompt.find(']', pos); if (end_pos != std::string::npos) { std::string image_id = prompt.substr(pos, end_pos - pos); @@ -1184,39 +989,13 @@ struct llama_server_context void send_error(task_server& task, const std::string &error) { LOG_TEE("task %i - error: %s\n", task.id, error.c_str()); - std::unique_lock lock(mutex_results); task_result res; res.id = task.id; res.multitask_id = task.multitask_id; res.stop = false; res.error = true; res.result_json = { { "content", error } }; - queue_results.push_back(res); - condition_results.notify_all(); - } - - void add_multi_task(int id, std::vector& sub_ids) - { - std::lock_guard lock(mutex_tasks); - task_multi multi; - multi.id = id; - std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end())); - queue_multitasks.push_back(multi); - condition_tasks.notify_one(); - } - - void update_multi_task(int multitask_id, int subtask_id, task_result& result) - { - std::lock_guard lock(mutex_tasks); - for (auto& multitask : queue_multitasks) - { - if (multitask.id == multitask_id) - { - multitask.subtasks_remaining.erase(subtask_id); - multitask.results.push_back(result); - condition_tasks.notify_one(); - } - } + queue_results.send(res); } json get_model_props() @@ -1262,7 +1041,6 @@ struct llama_server_context void send_partial_response(llama_client_slot &slot, completion_token_output tkn) { - std::unique_lock lock(mutex_results); task_result res; res.id = slot.task_id; res.multitask_id = slot.multitask_id; @@ -1297,13 +1075,11 @@ struct llama_server_context res.result_json["model"] = slot.oaicompat_model; } - queue_results.push_back(res); - condition_results.notify_all(); + queue_results.send(res); } void send_final_response(llama_client_slot &slot) { - std::unique_lock lock(mutex_results); task_result res; res.id = slot.task_id; res.multitask_id = slot.multitask_id; @@ -1351,23 +1127,12 @@ struct llama_server_context res.result_json["oaicompat_token_ctr"] = slot.n_decoded; res.result_json["model"] = slot.oaicompat_model; } - queue_results.push_back(res); - condition_results.notify_all(); - - // done with results, unlock - lock.unlock(); - - // parent multitask, if any, needs to be updated - if (slot.multitask_id != -1) - { - update_multi_task(slot.multitask_id, slot.task_id, res); - } + queue_results.send(res); } void send_embedding(llama_client_slot &slot) { - std::unique_lock lock(mutex_results); task_result res; res.id = slot.task_id; res.multitask_id = slot.multitask_id; @@ -1394,15 +1159,13 @@ struct llama_server_context {"embedding", embedding }, }; } - queue_results.push_back(res); - condition_results.notify_all(); + queue_results.send(res); } - int request_completion(json data, bool infill, bool embedding, int multitask_id) + void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id) { - std::unique_lock lock(mutex_tasks); task_server task; - task.id = id_gen++; + task.id = task_id; task.target_id = 0; task.data = std::move(data); task.infill_mode = infill; @@ -1413,47 +1176,11 @@ struct llama_server_context // when a completion task's prompt array is not a singleton, we split it into multiple requests if (task.data.count("prompt") && task.data.at("prompt").size() > 1) { - lock.unlock(); // entering new func scope - return split_multiprompt_task(task); + split_multiprompt_task(task_id, task); } // otherwise, it's a single-prompt task, we actually queue it - queue_tasks.push_back(task); - condition_tasks.notify_one(); - return task.id; - } - - task_result next_result(int task_id) - { - while (true) - { - std::unique_lock lock(mutex_results); - condition_results.wait(lock, [&]{ - return !queue_results.empty(); - }); - - for (int i = 0; i < (int) queue_results.size(); i++) - { - // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result - if (queue_results[i].multitask_id == task_id) - { - update_multi_task(task_id, queue_results[i].id, queue_results[i]); - queue_results.erase(queue_results.begin() + i); - continue; - } - - if (queue_results[i].id == task_id) - { - assert(queue_results[i].multitask_id == -1); - task_result res = queue_results[i]; - queue_results.erase(queue_results.begin() + i); - return res; - } - } - } - - // never reached - //return task_result{-1, false, false, {}}; + queue_tasks.post(task); } // for multiple images processing @@ -1516,7 +1243,7 @@ struct llama_server_context std::vector append_tokens = tokenize(json_prompt, false); // has next image for (int i = 0; i < (int) append_tokens.size(); ++i) { - llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true); + llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true); slot.n_past += 1; } } @@ -1526,146 +1253,119 @@ struct llama_server_context void request_cancel(int task_id) { - std::unique_lock lock(mutex_tasks); task_server task; - task.id = id_gen++; task.type = TASK_TYPE_CANCEL; task.target_id = task_id; - queue_tasks.push_back(task); - condition_tasks.notify_one(); + queue_tasks.post(task); } - int split_multiprompt_task(task_server& multiprompt_task) + void split_multiprompt_task(int multitask_id, task_server& multiprompt_task) { int prompt_count = multiprompt_task.data.at("prompt").size(); assert(prompt_count > 1); - int multitask_id = id_gen++; + // generate all the ID for subtask std::vector subtask_ids(prompt_count); for (int i = 0; i < prompt_count; i++) + { + subtask_ids[i] = queue_tasks.get_new_id(); + } + + // queue up the multitask so we can track its subtask progression + queue_tasks.add_multitask(multitask_id, subtask_ids); + + // add subtasks + for (int i = 0; i < prompt_count; i++) { json subtask_data = multiprompt_task.data; subtask_data["prompt"] = subtask_data["prompt"][i]; // subtasks inherit everything else (infill mode, embedding mode, etc.) - subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id); + request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id); } - - // queue up the multitask so we can track its subtask progression - add_multi_task(multitask_id, subtask_ids); - return multitask_id; } - void process_tasks() + void process_single_task(task_server& task) { - std::unique_lock lock(mutex_tasks); - std::vector deferred_tasks; - while (!queue_tasks.empty()) + switch (task.type) { - task_server task = queue_tasks.front(); - queue_tasks.erase(queue_tasks.begin()); - switch (task.type) - { - case TASK_TYPE_COMPLETION: { - llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1)); - if (slot == nullptr) - { - // if no slot is available, we defer this task for processing later - deferred_tasks.push_back(task); + case TASK_TYPE_COMPLETION: { + llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1)); + if (slot == nullptr) + { + // if no slot is available, we defer this task for processing later + LOG_VERBOSE("no slot is available", {}); + queue_tasks.defer(task); + break; + } + + if (task.data.contains("system_prompt")) + { + if (!all_slots_are_idle) { + send_error(task, "system prompt can only be updated when all slots are idle"); break; } + process_system_prompt_data(task.data["system_prompt"]); - if (task.data.contains("system_prompt")) + // reset cache_tokens for all slots + for (llama_client_slot &slot : slots) { - if (!all_slots_are_idle) { - send_error(task, "system prompt can only be updated when all slots are idle"); - break; - } - process_system_prompt_data(task.data["system_prompt"]); - // reset cache_tokens for all slots - for (llama_client_slot &slot : slots) - { - slot.cache_tokens.clear(); - } + slot.cache_tokens.clear(); + slot.n_past = 0; + slot.n_past_se = 0; } + } - slot->reset(); + slot->reset(); - slot->infill = task.infill_mode; - slot->embedding = task.embedding_mode; - slot->task_id = task.id; - slot->multitask_id = task.multitask_id; + slot->infill = task.infill_mode; + slot->embedding = task.embedding_mode; + slot->task_id = task.id; + slot->multitask_id = task.multitask_id; - if (!launch_slot_with_data(slot, task.data)) + if (!launch_slot_with_data(slot, task.data)) + { + // send error result + send_error(task, "internal_error"); + break; + } + } break; + case TASK_TYPE_CANCEL: { // release slot linked with the task id + for (auto & slot : slots) + { + if (slot.task_id == task.target_id) { - // send error result - send_error(task, "internal_error"); + slot.release(); break; } - } break; - case TASK_TYPE_CANCEL: { // release slot linked with the task id - for (auto & slot : slots) - { - if (slot.task_id == task.target_id) - { - slot.release(); - break; - } - } - } break; - } + } + } break; + case TASK_TYPE_NEXT_RESPONSE: { + // do nothing + } break; } + } - // add all the deferred tasks back the the queue - for (task_server &task : deferred_tasks) - { - queue_tasks.push_back(task); - } + void on_finish_multitask(task_multi& multitask) + { + // all subtasks done == multitask is done + task_result result; + result.id = multitask.id; + result.stop = true; + result.error = false; - // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue - std::vector agg_results; - auto queue_iterator = queue_multitasks.begin(); - while (queue_iterator != queue_multitasks.end()) + // collect json results into one json result + std::vector result_jsons; + for (auto& subres : multitask.results) { - if (queue_iterator->subtasks_remaining.empty()) - { - // all subtasks done == multitask is done - task_result aggregate_result; - aggregate_result.id = queue_iterator->id; - aggregate_result.stop = true; - aggregate_result.error = false; - - // collect json results into one json result - std::vector result_jsons; - for (auto& subres : queue_iterator->results) - { - result_jsons.push_back(subres.result_json); - aggregate_result.error = aggregate_result.error && subres.error; - } - aggregate_result.result_json = json{ "results", result_jsons }; - - agg_results.push_back(aggregate_result); - condition_results.notify_all(); - - queue_iterator = queue_multitasks.erase(queue_iterator); - } - else - { - ++queue_iterator; - } + result_jsons.push_back(subres.result_json); + result.error = result.error && subres.error; } - // done with tasks, unlock - lock.unlock(); - - // copy aggregate results of complete multi-tasks to the results queue - std::lock_guard lock_results(mutex_results); - queue_results.insert(queue_results.end(), agg_results.begin(), agg_results.end()); + result.result_json = json{ { "results", result_jsons } }; + queue_results.send(result); } bool update_slots() { - // attend tasks - process_tasks(); - if (system_need_update) { LOG_TEE("updating system prompt\n"); @@ -1681,40 +1381,45 @@ struct llama_server_context LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n"); kv_cache_clear(); } - std::unique_lock lock(mutex_tasks); - condition_tasks.wait(lock, [&]{ - return !queue_tasks.empty(); - }); + return true; } + task_server task; + task.type = TASK_TYPE_NEXT_RESPONSE; + task.target_id = -1; + queue_tasks.post(task); + for (llama_client_slot &slot : slots) { - if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx) + if (slot.ga_n == 1) { - // Shift context - const int n_left = slot.n_past - slot.params.n_keep - 1; - const int n_discard = n_left / 2; + if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx) + { + // Shift context + const int n_left = system_tokens.size() + slot.n_past - slot.params.n_keep - 1; + const int n_discard = n_left / 2; - LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard); - llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1); - llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard); + LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard); + llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, system_tokens.size() + slot.n_past, -n_discard); - for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++) - { - slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; - } + for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++) + { + slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; + } - slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); + slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); - slot.n_past -= n_discard; + slot.n_past -= n_discard; - slot.truncated = true; + slot.truncated = true; - LOG_VERBOSE("context shift", { - {"n_ctx", n_ctx}, - {"n_keep", params.n_keep}, - {"n_left", n_left}, - }); + LOG_VERBOSE("context shift", { + { "n_ctx", n_ctx }, + { "n_keep", params.n_keep }, + { "n_left", n_left }, + }); + } } } @@ -1729,6 +1434,7 @@ struct llama_server_context slot.t_last_used = ggml_time_us(); LOG_TEE("slot %d released (%d tokens in cache)\n", slot.id, (int) slot.cache_tokens.size()); + queue_tasks.notify_slot_changed(); continue; } @@ -1740,8 +1446,11 @@ struct llama_server_context slot.i_batch = batch.n_tokens; - llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true); + const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past; + // TODO: we always have to take into account the "system_tokens" + // this is not great and needs to be improved somehow + llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true); slot.n_past += 1; } @@ -1792,8 +1501,8 @@ struct llama_server_context prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model)); prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS - prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model)); - prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); + prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model)); + prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); prefix_tokens.push_back(llama_token_middle(model)); prompt_tokens = prefix_tokens; } @@ -1838,6 +1547,8 @@ struct llama_server_context llama_sampling_reset(slot.ctx_sampling); slot.n_past = 0; + slot.n_past_se = 0; + slot.ga_i = 0; slot.num_prompt_tokens_processed = slot.num_prompt_tokens; } else @@ -1851,6 +1562,25 @@ struct llama_server_context slot.n_past = common_part(slot.cache_tokens, prompt_tokens); slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past; + if (slot.ga_n != 1) + { + int ga_i = 0; + int32_t ga_n = slot.ga_n; + int32_t ga_w = slot.ga_w; + int32_t slot_npast = 0; + for (int k = 0; k < slot.n_past; ++k) + { + while (slot_npast >= ga_i + ga_w) { + const int bd = (ga_w/ga_n)*(ga_n - 1); + slot_npast -= bd; + ga_i += ga_w/ga_n; + } + slot_npast++; + } + slot.n_past_se = slot_npast; + slot.ga_i = ga_i; + } + LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed); } @@ -1865,11 +1595,15 @@ struct llama_server_context // we have to evaluate at least 1 token to generate logits. LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id); slot.n_past--; + if (slot.ga_i > 0) + { + slot.n_past_se--; + } } LOG_VERBOSE("prompt ingested", { - {"n_past", slot.n_past}, - {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, + {"n_past", slot.n_past}, + {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())}, }); @@ -1877,9 +1611,25 @@ struct llama_server_context // process the prefix of first image std::vector prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens; + + int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past; + + int32_t ga_i = slot.ga_i; + int32_t ga_n = slot.ga_n; + int32_t ga_w = slot.ga_w; + for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past) { - llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false); + if (slot.ga_n != 1) + { + while (slot_npast >= ga_i + ga_w) { + const int bd = (ga_w/ga_n)*(ga_n - 1); + slot_npast -= bd; + ga_i += ga_w/ga_n; + } + } + llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false); + slot_npast++; } if (has_images && !ingest_images(slot, n_batch)) @@ -1909,6 +1659,37 @@ struct llama_server_context for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + + for (auto & slot : slots) + { + if (slot.ga_n != 1) + { + // context extension via Self-Extend + while (slot.n_past_se >= slot.ga_i + slot.ga_w) + { + const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w; + const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1); + const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w; + + LOG_TEE("\n"); + LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd); + LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); + LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd); + + llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd); + llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n); + llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd); + + slot.n_past_se -= bd; + + slot.ga_i += slot.ga_w / slot.ga_n; + + LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i); + } + slot.n_past_se += n_tokens; + } + } + llama_batch batch_view = { n_tokens, @@ -1922,6 +1703,7 @@ struct llama_server_context }; const int ret = llama_decode(ctx, batch_view); + if (ret != 0) { if (n_batch == 1 || ret < 0) @@ -1994,242 +1776,13 @@ struct llama_server_context } return true; } -}; - - -static std::string random_string() -{ - static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); - - std::random_device rd; - std::mt19937 generator(rd()); - - std::string result(32, ' '); - - for (int i = 0; i < 32; ++i) { - result[i] = str[generator() % str.size()]; - } - - return result; -} - -static std::string gen_chatcmplid() -{ - std::stringstream chatcmplid; - chatcmplid << "chatcmpl-" << random_string(); - return chatcmplid.str(); -} - -std::string format_chatml(std::vector messages) -{ - std::ostringstream chatml_msgs; - for (auto it = messages.begin(); it != messages.end(); ++it) { - chatml_msgs << "<|im_start|>" - << json_value(*it, "role", std::string("user")) << '\n'; - chatml_msgs << json_value(*it, "content", std::string("")) - << "<|im_end|>\n"; + void run_on_all_tasks_finished() { + update_slots(); } - - chatml_msgs << "<|im_start|>assistant" << '\n'; - - return chatml_msgs.str(); -} +}; /* llama.cpp completion api semantics */ -json oaicompat_completion_params_parse( - const json &body /* openai api json semantics */) -{ - json llama_params; - - llama_params["__oaicompat"] = true; - - // Map OpenAI parameters to llama.cpp parameters - // - // For parameters that are defined by the OpenAI documentation (e.g. - // temperature), we explicitly specify OpenAI's intended default; we - // need to do that because sometimes OpenAI disagrees with llama.cpp - // - // https://platform.openai.com/docs/api-reference/chat/create - llama_sampling_params default_sparams; - llama_params["model"] = json_value(body, "model", std::string("unknown")); - llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt' - llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); - llama_params["temperature"] = json_value(body, "temperature", 0.0); - llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k); - llama_params["top_p"] = json_value(body, "top_p", 1.0); - llama_params["n_predict"] = json_value(body, "max_tokens", -1); - llama_params["logit_bias"] = json_value(body, "logit_bias",json::object()); - llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0); - llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0); - llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED); - llama_params["stream"] = json_value(body, "stream", false); - llama_params["mirostat"] = json_value(body, "mirostat", default_sparams.mirostat); - llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau); - llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta); - llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl); - llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p); - llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n); - llama_params["ignore_eos"] = json_value(body, "ignore_eos", false); - llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z); - - if (body.count("grammar") != 0) { - llama_params["grammar"] = json_value(body, "grammar", json::object()); - } - - // Handle 'stop' field - if (body.contains("stop") && body["stop"].is_string()) { - llama_params["stop"] = json::array({body["stop"].get()}); - } else { - llama_params["stop"] = json_value(body, "stop", json::array()); - } - - // Ensure there is ChatML-specific end sequence among stop words - llama_params["stop"].push_back("<|im_end|>"); - - return llama_params; -} - -static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false) -{ - json result = response.result_json; - - bool stopped_word = result.count("stopped_word") != 0; - bool stopped_eos = json_value(result, "stopped_eos", false); - int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - std::string content = json_value(result, "content", std::string("")); - - std::string finish_reason = "length"; - if (stopped_word || stopped_eos) { - finish_reason = "stop"; - } - - json choices = - streaming ? json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}}}) - : json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"message", json{{"content", content}, - {"role", "assistant"}}}}}); - - std::time_t t = std::time(0); - - json res = - json{{"choices", choices}, - {"created", t}, - {"model", - json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, - {"usage", - json{{"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens}}}, - {"id", gen_chatcmplid()}}; - - if (server_verbose) { - res["__verbose"] = result; - } - - if (result.contains("completion_probabilities")) { - res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array()); - } - - return res; -} - -// return value is vector as there is one case where we might need to generate two responses -static std::vector format_partial_response_oaicompat(const task_result &response) { - json result = response.result_json; - - if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) { - return std::vector({response.result_json}); - } - - bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; - std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - - bool stopped_word = json_value(result, "stopped_word", false); - bool stopped_eos = json_value(result, "stopped_eos", false); - bool stopped_limit = json_value(result, "stopped_limit", false); - std::string content = json_value(result, "content", std::string("")); - - std::string finish_reason; - if (stopped_word || stopped_eos) { - finish_reason = "stop"; - } - if (stopped_limit) { - finish_reason = "length"; - } - - std::time_t t = std::time(0); - - json choices; - - if (!finish_reason.empty()) { - choices = json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}}}); - } else { - if (first) { - if (content.empty()) { - choices = json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{{"role", "assistant"}}}}}); - } else { - // We have to send this as two updates to conform to openai behavior - json initial_ret = json{{"choices", json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"role", "assistant"} - }}}})}, - {"created", t}, - {"id", gen_chatcmplid()}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - json second_ret = json{ - {"choices", json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"content", content}}} - }})}, - {"created", t}, - {"id", gen_chatcmplid()}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - return std::vector({initial_ret, second_ret}); - } - } else { - // Some idiosyncrasy in task processing logic makes several trailing calls - // with empty content, we ignore these at the calee site. - if (content.empty()) { - return std::vector({json::object()}); - } - - choices = json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", - json{ - {"content", content}, - }}, - }}); - } - } - - json ret = json{{"choices", choices}, - {"created", t}, - {"id", gen_chatcmplid()}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - return std::vector({ret}); -} - static json format_partial_response( llama_server_context &llama, llama_client_slot *slot, const std::string &content, const std::vector &probs ) { diff --git a/backend/cpp/llama/utils.hpp b/backend/cpp/llama/utils.hpp new file mode 100644 index 000000000000..c5dafbf0f9ce --- /dev/null +++ b/backend/cpp/llama/utils.hpp @@ -0,0 +1,510 @@ +// https://github.com/ggerganov/llama.cpp/blob/master/examples/server/utils.hpp + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "json.hpp" + +#include "../llava/clip.h" + +using json = nlohmann::json; + +extern bool server_verbose; + +#ifndef SERVER_VERBOSE +#define SERVER_VERBOSE 1 +#endif + +#if SERVER_VERBOSE != 1 +#define LOG_VERBOSE(MSG, ...) +#else +#define LOG_VERBOSE(MSG, ...) \ + do \ + { \ + if (server_verbose) \ + { \ + server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \ + } \ + } while (0) +#endif + +#define LOG_ERROR( MSG, ...) server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__) +#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__) +#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) + +// +// parallel +// + +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded + SERVER_STATE_ERROR // An error occurred, load_model failed +}; + +enum task_type { + TASK_TYPE_COMPLETION, + TASK_TYPE_CANCEL, + TASK_TYPE_NEXT_RESPONSE +}; + +struct task_server { + int id = -1; // to be filled by llama_server_queue + int target_id; + task_type type; + json data; + bool infill_mode = false; + bool embedding_mode = false; + int multitask_id = -1; +}; + +struct task_result { + int id; + int multitask_id = -1; + bool stop; + bool error; + json result_json; +}; + +struct task_multi { + int id; + std::set subtasks_remaining{}; + std::vector results{}; +}; + +// TODO: can become bool if we can't find use of more states +enum slot_state +{ + IDLE, + PROCESSING, +}; + +enum slot_command +{ + NONE, + LOAD_PROMPT, + RELEASE, +}; + +struct slot_params +{ + bool stream = true; + bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt + + uint32_t seed = -1; // RNG seed + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_predict = -1; // new tokens to predict + + std::vector antiprompt; + + json input_prefix; + json input_suffix; +}; + +struct slot_image +{ + int32_t id; + + bool request_encode_image = false; + float * image_embedding = nullptr; + int32_t image_tokens = 0; + + clip_image_u8 * img_data; + + std::string prefix_prompt; // before of this image +}; + +// completion token output with probabilities +struct completion_token_output +{ + struct token_prob + { + llama_token tok; + float prob; + }; + + std::vector probs; + llama_token tok; + std::string text_to_send; +}; + +static inline void server_log(const char *level, const char *function, int line, + const char *message, const nlohmann::ordered_json &extra) +{ + nlohmann::ordered_json log + { + {"timestamp", time(nullptr)}, + {"level", level}, + {"function", function}, + {"line", line}, + {"message", message}, + }; + + if (!extra.empty()) + { + log.merge_patch(extra); + } + + const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace); + printf("%.*s\n", (int)str.size(), str.data()); + fflush(stdout); +} + +// +// server utils +// + +template +static T json_value(const json &body, const std::string &key, const T &default_value) +{ + // Fallback null to default value + return body.contains(key) && !body.at(key).is_null() + ? body.value(key, default_value) + : default_value; +} + +inline std::string format_chatml(std::vector messages) +{ + std::ostringstream chatml_msgs; + + for (auto it = messages.begin(); it != messages.end(); ++it) { + chatml_msgs << "<|im_start|>" + << json_value(*it, "role", std::string("user")) << '\n'; + chatml_msgs << json_value(*it, "content", std::string("")) + << "<|im_end|>\n"; + } + + chatml_msgs << "<|im_start|>assistant" << '\n'; + + return chatml_msgs.str(); +} + +// +// work queue utils +// + +struct llama_server_queue { + int id = 0; + std::mutex mutex_tasks; + // queues + std::vector queue_tasks; + std::vector queue_tasks_deferred; + std::vector queue_multitasks; + std::condition_variable condition_tasks; + // callback functions + std::function callback_new_task; + std::function callback_finish_multitask; + std::function callback_all_task_finished; + + // Add a new task to the end of the queue + int post(task_server task) { + std::unique_lock lock(mutex_tasks); + if (task.id == -1) { + task.id = id++; + } + queue_tasks.push_back(std::move(task)); + condition_tasks.notify_one(); + return task.id; + } + + // Add a new task, but defer until one slot is available + void defer(task_server task) { + std::unique_lock lock(mutex_tasks); + queue_tasks_deferred.push_back(std::move(task)); + } + + // Get the next id for creating anew task + int get_new_id() { + std::unique_lock lock(mutex_tasks); + return id++; + } + + // Register function to process a new task + void on_new_task(std::function callback) { + callback_new_task = callback; + } + + // Register function to process a multitask + void on_finish_multitask(std::function callback) { + callback_finish_multitask = callback; + } + + // Register the function to be called when the batch of tasks is finished + void on_all_tasks_finished(std::function callback) { + callback_all_task_finished = callback; + } + + // Call when the state of one slot is changed + void notify_slot_changed() { + // move deferred tasks back to main loop + std::unique_lock lock(mutex_tasks); + for (auto & task : queue_tasks_deferred) { + queue_tasks.push_back(std::move(task)); + } + queue_tasks_deferred.clear(); + } + + // Start the main loop. This call is blocking + [[noreturn]] + void start_loop() { + while (true) { + // new task arrived + LOG_VERBOSE("have new task", {}); + { + while (true) + { + std::unique_lock lock(mutex_tasks); + if (queue_tasks.empty()) { + lock.unlock(); + break; + } + task_server task = queue_tasks.front(); + queue_tasks.erase(queue_tasks.begin()); + lock.unlock(); + LOG_VERBOSE("callback_new_task", {}); + callback_new_task(task); + } + LOG_VERBOSE("callback_all_task_finished", {}); + // process and update all the multitasks + auto queue_iterator = queue_multitasks.begin(); + while (queue_iterator != queue_multitasks.end()) + { + if (queue_iterator->subtasks_remaining.empty()) + { + // all subtasks done == multitask is done + task_multi current_multitask = *queue_iterator; + callback_finish_multitask(current_multitask); + // remove this multitask + queue_iterator = queue_multitasks.erase(queue_iterator); + } + else + { + ++queue_iterator; + } + } + // all tasks in the current loop is finished + callback_all_task_finished(); + } + LOG_VERBOSE("wait for new task", {}); + // wait for new task + { + std::unique_lock lock(mutex_tasks); + if (queue_tasks.empty()) { + condition_tasks.wait(lock, [&]{ + return !queue_tasks.empty(); + }); + } + } + } + } + + // + // functions to manage multitasks + // + + // add a multitask by specifying the id of all subtask (subtask is a task_server) + void add_multitask(int multitask_id, std::vector& sub_ids) + { + std::lock_guard lock(mutex_tasks); + task_multi multi; + multi.id = multitask_id; + std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end())); + queue_multitasks.push_back(multi); + } + + // updatethe remaining subtasks, while appending results to multitask + void update_multitask(int multitask_id, int subtask_id, task_result& result) + { + std::lock_guard lock(mutex_tasks); + for (auto& multitask : queue_multitasks) + { + if (multitask.id == multitask_id) + { + multitask.subtasks_remaining.erase(subtask_id); + multitask.results.push_back(result); + } + } + } +}; + +struct llama_server_response { + typedef std::function callback_multitask_t; + callback_multitask_t callback_update_multitask; + // for keeping track of all tasks waiting for the result + std::set waiting_task_ids; + // the main result queue + std::vector queue_results; + std::mutex mutex_results; + std::condition_variable condition_results; + + void add_waiting_task_id(int task_id) { + std::unique_lock lock(mutex_results); + waiting_task_ids.insert(task_id); + } + + void remove_waiting_task_id(int task_id) { + std::unique_lock lock(mutex_results); + waiting_task_ids.erase(task_id); + } + + // This function blocks the thread until there is a response for this task_id + task_result recv(int task_id) { + while (true) + { + std::unique_lock lock(mutex_results); + condition_results.wait(lock, [&]{ + return !queue_results.empty(); + }); + LOG_VERBOSE("condition_results unblock", {}); + + for (int i = 0; i < (int) queue_results.size(); i++) + { + if (queue_results[i].id == task_id) + { + assert(queue_results[i].multitask_id == -1); + task_result res = queue_results[i]; + queue_results.erase(queue_results.begin() + i); + return res; + } + } + } + + // should never reach here + } + + // Register the function to update multitask + void on_multitask_update(callback_multitask_t callback) { + callback_update_multitask = callback; + } + + // Send a new result to a waiting task_id + void send(task_result result) { + std::unique_lock lock(mutex_results); + LOG_VERBOSE("send new result", {}); + for (auto& task_id : waiting_task_ids) { + // LOG_TEE("waiting task id %i \n", task_id); + // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result + if (result.multitask_id == task_id) + { + LOG_VERBOSE("callback_update_multitask", {}); + callback_update_multitask(task_id, result.id, result); + continue; + } + + if (result.id == task_id) + { + LOG_VERBOSE("queue_results.push_back", {}); + queue_results.push_back(result); + condition_results.notify_one(); + return; + } + } + } +}; + +// +// base64 utils (TODO: move to common in the future) +// + +static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + +static inline bool is_base64(uint8_t c) +{ + return (isalnum(c) || (c == '+') || (c == '/')); +} + +static inline std::vector base64_decode(const std::string & encoded_string) +{ + int i = 0; + int j = 0; + int in_ = 0; + + int in_len = encoded_string.size(); + + uint8_t char_array_4[4]; + uint8_t char_array_3[3]; + + std::vector ret; + + while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) + { + char_array_4[i++] = encoded_string[in_]; in_++; + if (i == 4) + { + for (i = 0; i <4; i++) + { + char_array_4[i] = base64_chars.find(char_array_4[i]); + } + + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; (i < 3); i++) + { + ret.push_back(char_array_3[i]); + } + i = 0; + } + } + + if (i) + { + for (j = i; j <4; j++) + { + char_array_4[j] = 0; + } + + for (j = 0; j <4; j++) + { + char_array_4[j] = base64_chars.find(char_array_4[j]); + } + + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; (j < i - 1); j++) + { + ret.push_back(char_array_3[j]); + } + } + + return ret; +} + +// +// random string / id +// + +static std::string random_string() +{ + static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); + + std::random_device rd; + std::mt19937 generator(rd()); + + std::string result(32, ' '); + + for (int i = 0; i < 32; ++i) { + result[i] = str[generator() % str.size()]; + } + + return result; +} + +static std::string gen_chatcmplid() +{ + std::stringstream chatcmplid; + chatcmplid << "chatcmpl-" << random_string(); + return chatcmplid.str(); +} \ No newline at end of file