diff --git a/cpp/serve/data.cc b/cpp/serve/data.cc
index 6d25b41953..138620dc5a 100644
--- a/cpp/serve/data.cc
+++ b/cpp/serve/data.cc
@@ -6,6 +6,8 @@
 #include
 
+#include "model.h"
+
 namespace mlc {
 namespace llm {
 namespace serve {
@@ -24,6 +26,16 @@ TextData::TextData(String text) {
   data_ = std::move(n);
 }
 
+int TextDataNode::GetLength() const {
+  LOG(FATAL) << "\"GetLength\" for TextData is not supported. "
+                "Please tokenize the text and construct a TokenData object.";
+}
+
+NDArray TextDataNode::GetEmbedding(Model model) const {
+  LOG(FATAL) << "\"GetEmbedding\" for TextData is not supported. "
+                "Please tokenize the text and construct a TokenData object.";
+}
+
 TVM_REGISTER_GLOBAL("mlc.serve.TextData").set_body_typed([](String text) {
   return TextData(std::move(text));
 });
@@ -36,7 +48,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.TextDataGetTextString").set_body_typed([](TextDat
 
 TVM_REGISTER_OBJECT_TYPE(TokenDataNode);
 
-TokenData::TokenData(ShapeTuple token_ids) {
+TokenData::TokenData(IntTuple token_ids) {
   ObjectPtr<TokenDataNode> n = make_object<TokenDataNode>();
   n->token_ids = std::move(token_ids);
   data_ = std::move(n);
@@ -44,10 +56,14 @@ TokenData::TokenData(ShapeTuple token_ids) {
 
 TokenData::TokenData(std::vector<int32_t> token_ids) {
   ObjectPtr<TokenDataNode> n = make_object<TokenDataNode>();
-  n->token_ids = ShapeTuple(token_ids.begin(), token_ids.end());
+  n->token_ids = IntTuple(token_ids.begin(), token_ids.end());
   data_ = std::move(n);
 }
 
+int TokenDataNode::GetLength() const { return token_ids.size(); }
+
+NDArray TokenDataNode::GetEmbedding(Model model) const { return model->TokenEmbed(token_ids); }
+
 TVM_REGISTER_GLOBAL("mlc.serve.TokenData").set_body([](TVMArgs args, TVMRetValue* rv) {
   std::vector<int32_t> token_ids;
   token_ids.reserve(args.size());
diff --git a/cpp/serve/data.h b/cpp/serve/data.h
index 3e8668f38a..e097529df2 100644
--- a/cpp/serve/data.h
+++ b/cpp/serve/data.h
@@ -7,6 +7,7 @@
 #include
 #include
+#include
 #include
 
 namespace mlc {
@@ -15,11 +16,19 @@ namespace serve {
 
 using namespace tvm::runtime;
 
+class Model;
+
 /****************** DataNode ******************/
 
 /*! \brief The base class of multi-modality data (text, tokens, embedding, etc). */
 class DataNode : public Object {
  public:
+  /*! \brief Get the length (equivalent number of tokens) of the data. */
+  virtual int GetLength() const = 0;
+
+  /*! \brief Compute the embedding of this data with regard to the input model. */
+  virtual NDArray GetEmbedding(Model model) const = 0;
+
   static constexpr const char* _type_key = "mlc.serve.Data";
   static constexpr const bool _type_has_method_sequal_reduce = false;
   static constexpr const bool _type_has_method_shash_reduce = false;
@@ -39,6 +48,9 @@ class TextDataNode : public DataNode {
   /*! \brief The text string. */
   String text;
 
+  int GetLength() const final;
+  NDArray GetEmbedding(Model model) const final;
+
   static constexpr const char* _type_key = "mlc.serve.TextData";
   TVM_DECLARE_BASE_OBJECT_INFO(TextDataNode, DataNode);
 };
@@ -56,7 +68,10 @@ class TextData : public Data {
 class TokenDataNode : public DataNode {
  public:
   /*! \brief The token ids.
*/ - ShapeTuple token_ids; + IntTuple token_ids; + + int GetLength() const final; + NDArray GetEmbedding(Model model) const final; static constexpr const char* _type_key = "mlc.serve.TokenData"; TVM_DECLARE_BASE_OBJECT_INFO(TokenDataNode, DataNode); @@ -64,7 +79,7 @@ class TokenDataNode : public DataNode { class TokenData : public Data { public: - explicit TokenData(ShapeTuple token_ids); + explicit TokenData(IntTuple token_ids); explicit TokenData(std::vector token_ids); diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 62127e92ac..967e460800 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -4,19 +4,19 @@ * \brief The implementation for runtime module of serving engine module in MLC LLM. */ #define __STDC_FORMAT_MACROS -#define PICOJSON_USE_INT64 -#include +#include #include #include #include #include +#include "../tokenizers.h" +#include "engine_stats.h" #include "model.h" #include "request.h" #include "request_state.h" #include "sampler.h" -#include "tokenizer.h" namespace mlc { namespace llm { @@ -43,11 +43,22 @@ class EngineModule; * \note For now only one model run in the engine is supported. * Multiple model support such as speculative inference will * be followed soon in the future. + * + * The public interface of Engine has the following three categories: + * - engine management, + * - high-level request management, + * - engine "step" action. + * + * The internal implementation of Engine has the following categories: + * - internal request management, + * - actions and request schedule policy (such as prefill, decode, etc.) */ class Engine { friend class EngineModule; public: + /********************** Engine Management **********************/ + /*! * \brief (Re)initialize the engine with the given lists of * models and KV cache config. @@ -63,77 +74,67 @@ class Engine { ICHECK_GE(num_models, 1); ICHECK_EQ(model_paths.size(), num_models); ICHECK_EQ(devices.size(), num_models); - devices_ = std::move(devices); // Step 1. Create models and their PackedFuncs. ICHECK(models_.empty()); models_.reserve(num_models); - fmodel_batch_prefill_.clear(); - fmodel_decode_.clear(); - fmodel_token_embed_.clear(); - fmodel_add_new_sequence_.clear(); - fmodel_remove_sequence_.clear(); - fmodel_softmax_with_temperature_.clear(); - fmodel_get_num_available_pages_.clear(); for (int i = 0; i < num_models; ++i) { - Module model = CreateModelModule(reload_libs[i], model_paths[i], devices_[i]); - models_.push_back(model); - fmodel_batch_prefill_.push_back(model->GetFunction("batch_prefill")); - fmodel_decode_.push_back(model->GetFunction("decode")); - fmodel_token_embed_.push_back(model->GetFunction("token_embed")); - fmodel_add_new_sequence_.push_back(model->GetFunction("add_new_sequence")); - fmodel_remove_sequence_.push_back(model->GetFunction("remove_sequence")); - fmodel_softmax_with_temperature_.push_back(model->GetFunction("softmax_with_temperature")); - fmodel_get_num_available_pages_.push_back(model->GetFunction("get_num_available_pages")); + models_.push_back(Model::Create(reload_libs[i], model_paths[i], devices[i])); } // Step 2. Fetch max single sequence length from models. max_single_sequence_length_ = std::numeric_limits::max(); - for (Module model : models_) { - int max_window_size = model->GetFunction("get_max_window_size")(); + for (Model model : models_) { + int max_window_size = model->GetMaxWindowSize(); max_single_sequence_length_ = std::min(max_single_sequence_length_, max_window_size); } // Step 3. Process KV cache config json string. 
kv_cache_config_ = KVCacheConfig(kv_cache_config_json, max_single_sequence_length_); // Step 4. Create KV cache for each model. - for (Module model : models_) { - model->GetFunction("create_kv_cache")(kv_cache_config_); + for (Model model : models_) { + model->CreateKVCache(kv_cache_config_); } // Step 5. Create sampler and tokenizer. - // The sampler is created one per model on each device. // The tokenizer is created from the first model. // We assume all models have the same tokenizer, which is the basic // requirement of speculative encoding. - fsampler_require_gpu_softmax_.clear(); - fsampler_compute_probs_from_logits_inplace_.clear(); - fsampler_sample_token_from_probs_.clear(); - for (int i = 0; i < num_models; ++i) { - Module sampler = CreateSamplerModule(devices_[i]); - samplers_.push_back(sampler); - fsampler_require_gpu_softmax_.push_back(sampler->GetFunction("require_gpu_softmax")); - fsampler_compute_probs_from_logits_inplace_.push_back( - sampler->GetFunction("compute_probs_from_logits_inplace")); - fsampler_sample_token_from_probs_.push_back(sampler->GetFunction("sample_token_from_probs")); - } - tokenizer_ = CreateTokenizerModule(model_paths[0]); - ftokenizer_tokenize = tokenizer_->GetFunction("tokenize"); - ftokenizer_decode = tokenizer_->GetFunction("decode"); + sampler_ = Sampler::Create(/*sampler_kind=*/"cpu"); + tokenizer_ = TokenizerFromPath(model_paths[0]); ResetEngine(); } + /*! \brief Reset the engine, clean up all running data and statistics. */ + void ResetEngine() { + running_queue_.clear(); + waiting_queue_.clear(); + abort_queue_.clear(); + request_states_.clear(); + stats_.Reset(); + for (Model model : models_) { + model->Reset(); + } + } + + /***************** High-level Request Management *****************/ + /*! * \brief Add a new request to the engine. * \param request The request to add. */ void AddRequest(Request request) { + // Get a request copy where all text inputs are tokenized. + request = Request::FromUntokenized(request, tokenizer_); + ICHECK_NE(request->input_total_length, -1); + // Append to the waiting queue and create the request state. waiting_queue_.push_back(request); - request_states_.emplace( - request, RequestState(models_.size(), request->inputs, GetInputLength(request->inputs))); + request_states_.emplace(request->id, RequestState(request, models_.size())); } /*! \brief Abort the input request. */ void AbortRequest(Request request) { abort_queue_.push_back(request); } + /*********************** Engine Action ***********************/ + /*! * \brief The main function that the engine takes a step of action. * At each step, the engine may decide to @@ -157,80 +158,158 @@ class Engine { return; } - // - Action 2. Run speculation step for small models. - // NOTE: Right now we do not really support speculation. - // Here we just reserve room for extension. - bool speculate_processed = StepSpeculate(); - if (speculate_processed) { - return; - } - // - Action 3. Run speculation verification step. - bool verify_processed = StepVerify(); - if (verify_processed) { - UpdateFinishedRequest(); - return; - } - - // - Action 4. Run decode step. + // - Action 2. Run decode step. bool decode_processed = StepDecode(); if (decode_processed) { - UpdateFinishedRequest(); + ProcessFinishedRequest(); return; } - // - Action 5. Preempt the last running sequence. 
- if (!running_queue_.empty()) { - ICHECK_GT(static_cast(running_queue_.size()), 1); - StepPreempt(running_queue_.back()); + ICHECK(running_queue_.empty()) + << "Not taking any action in a step is not expected with running requests."; + } + + private: + /***************** Internal Request Management *****************/ + + /*! \brief Assign the given internal id for the given request. */ + void AssignIDForRequest(Request request, int req_id) { + // Set internal id in the request state. + RequestState state = request_states_.at(request->id); + for (RequestModelState mstate : state->mstates) { + mstate->request_id = req_id; + } + // Add a new sequence to each model. + for (int i = 0; i < static_cast(models_.size()); ++i) { + int seq_id_in_model = models_[i]->AddNewSequence(); + ICHECK_EQ(seq_id_in_model, req_id); } } - /*! \brief Reset the engine, clean up all running data and statistics. */ - void ResetEngine() { - running_queue_.clear(); - waiting_queue_.clear(); - abort_queue_.clear(); - request_states_.clear(); + /*! + * \brief Remove the given request from models and update request states. + * \param req_id The internal id of the request to remove. + */ + void RemoveRequestFromModel(int req_id) { + // Remove the request from all models (usually the KV cache). + for (Model model : models_) { + model->RemoveSequence(req_id); + } + // Update the internal request id of other requests. + for (auto& it : request_states_) { + RequestState state = it.second; + for (RequestModelState mstate : state->mstates) { + ICHECK_NE(mstate->request_id, req_id); + if (mstate->request_id > req_id) { + --mstate->request_id; + } + } + } + } - for (Module model : models_) { - model->GetFunction("reset")(); + /*! + * \brief Preempt the generation of the given request, moving + * it from running request set to the foremost of waiting + * request queue. + */ + void PreemptRequest(std::vector::iterator request_it) { + Request request = *request_it; + + // Remove from models. + // - Reset `request_id` of states. + // - Clear model speculation draft. + // - Update `inputs` for future prefill. + RequestState state = request_states_.at(request->id); + int req_id = state->mstates[0]->request_id; + stats_.current_total_seq_len -= + request->input_total_length + state->mstates[0]->committed_tokens.size() - 1; + for (RequestModelState mstate : state->mstates) { + mstate->request_id = -1; + mstate->draft_output_tokens.clear(); + mstate->draft_output_token_prob.clear(); + mstate->draft_output_prob_dist.clear(); + ICHECK(mstate->inputs.empty()); + ICHECK(!mstate->committed_tokens.empty()); + + Array inputs = request->inputs; + if (const auto* token_input = inputs.back().as()) { + // Merge the TokenData so that a single time TokenEmbed is needed. + std::vector token_ids{token_input->token_ids->data, + token_input->token_ids->data + token_input->token_ids.size()}; + token_ids.insert(token_ids.end(), mstate->committed_tokens.begin(), + mstate->committed_tokens.end()); + inputs.Set(inputs.size() - 1, TokenData(token_ids)); + } else { + inputs.push_back(TokenData(mstate->committed_tokens)); + } + mstate->inputs = std::move(inputs); } + RemoveRequestFromModel(req_id); - current_total_seq_len_ = 0; - request_total_prefill_time_ = 0.0f; - request_total_decode_time_ = 0.0f; - engine_total_prefill_time_ = 0.0f; - engine_total_decode_time_ = 0.0f; - total_prefill_length_ = 0; - total_decode_length_ = 0; - tokenize_cache_.clear(); + // Move from running queue to the front of waiting queue. 
+ running_queue_.erase(request_it); + waiting_queue_.insert(waiting_queue_.begin(), request); } /*! - * \brief Return the engine runtime statistics in JSON string. - * We collect the following entries: - * - single token prefill latency (s/tok): avg latency of processing one token in prefill - * - single token decode latency (s/tok): avg latency of processing one token in decode - * - engine time for prefill (sec) - * - engine time for decode (sec) - * - total number of processed tokens in prefill. - * - total number of processed tokens in decode. - * \return The statistics in JSON string. + * \brief For each request, check if the request has finished + * its generation. And update the state and return the generation + * result for the finished requests. + * \note This function removes requests from the running request + * queue. */ - String StatisticsJSON() { - picojson::object config; - config["single_token_prefill_latency"] = - picojson::value(request_total_prefill_time_ / total_prefill_length_); - config["single_token_decode_latency"] = - picojson::value(request_total_decode_time_ / total_decode_length_); - config["engine_total_prefill_time"] = picojson::value(engine_total_prefill_time_); - config["engine_total_decode_time"] = picojson::value(engine_total_decode_time_); - config["total_prefill_tokens"] = picojson::value(total_prefill_length_); - config["total_decode_tokens"] = picojson::value(total_decode_length_); - return picojson::value(config).serialize(true); + void ProcessFinishedRequest() { + // - Collect finished requests. + // We don't remove on the fly to avoid concurrent modification. + std::vector request_to_remove; + for (Request request : running_queue_) { + if (request_states_.at(request->id)->GenerationFinished(max_single_sequence_length_)) { + request_to_remove.push_back(request); + } + } + + // - Remove the finished request. + for (Request request : request_to_remove) { + // Remove from running queue. + auto it = std::find(running_queue_.begin(), running_queue_.end(), request); + ICHECK(it != running_queue_.end()); + running_queue_.erase(it); + + // Update engine states. + RequestState state = request_states_.at(request->id); + int req_id = state->mstates[0]->request_id; + for (RequestModelState mstate : state->mstates) { + ICHECK_EQ(mstate->request_id, req_id); + mstate->request_id = -1; + } + RemoveRequestFromModel(req_id); + request_states_.erase(request->id); + + // Update engine statistics. + int num_input_tokens = request->input_total_length; + int num_output_tokens = state->mstates[0]->committed_tokens.size() - 1; + stats_.current_total_seq_len -= num_input_tokens + num_output_tokens; + auto trequest_finish = std::chrono::high_resolution_clock::now(); + stats_.request_total_prefill_time += + static_cast((state->tprefill_finish - state->tadd).count()) / 1e9; + stats_.total_prefill_length += num_input_tokens; + stats_.request_total_decode_time += + static_cast((trequest_finish - state->tprefill_finish).count()) / 1e9; + stats_.total_decode_length += num_output_tokens; + + // NOTE: right now we only return the generated text. + // In the future we might optional return text or token ids. + String output = tokenizer_->Decode(state->mstates[0]->committed_tokens); + request->fcallback(request, TextData(output)); + } } - private: + /************** Engine Actions and Request Schedule Policy **************/ + + /********************* + * Action 1. Prefill * + *********************/ + /*! \brief Pick applicable requests and run prefill. 
*/ bool StepPrefill() { auto [requests, states, sample_new_token] = GetRequestsToPrefill(); @@ -261,7 +340,7 @@ class Engine { mstates_for_sample.reserve(requests.size()); generation_cfg_for_sample.reserve(requests.size()); for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { - Module model = models_[model_id]; + Model model = models_[model_id]; auto [request_list, mstates, prefill_lengths] = FilterPrefillRequests(requests, states, model_id); Array embeddings; @@ -275,7 +354,7 @@ class Engine { if (model_id == 0) { // Accumulate the sequence length. sum_prefill_lengths += prefill_length; - current_total_seq_len_ += prefill_length; + stats_.current_total_seq_len += prefill_length; mstates_for_sample.push_back(mstate); generation_cfg_for_sample.push_back(request->generation_cfg); } @@ -285,14 +364,13 @@ class Engine { ICHECK(!mstate->inputs.empty()); request_ids.push_back(mstate->request_id); for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { - embeddings.push_back(GetEmbedding(mstate->inputs[i], fmodel_token_embed_[model_id])); + embeddings.push_back(mstate->inputs[i]->GetEmbedding(model)); } // Clean up `inputs` after prefill mstate->inputs.clear(); } - NDArray logits = fmodel_batch_prefill_[model_id]( - embeddings, ShapeTuple(request_ids.begin(), request_ids.end()), prefill_lengths); + NDArray logits = model->BatchPrefill(embeddings, request_ids, prefill_lengths); ICHECK_EQ(logits->ndim, 3); ICHECK_EQ(logits->shape[0], 1); ICHECK_EQ(logits->shape[1], request_list.size()); @@ -312,8 +390,8 @@ class Engine { ICHECK_EQ(generation_cfg_for_sample.size(), num_requests); logits_for_sample = logits_for_sample.CreateView( {num_requests, 1, logits_for_sample->shape[2]}, logits_for_sample->dtype); - ShapeTuple next_tokens = SampleTokens(logits_for_sample, /*model_id=*/0, /*sampler_id=*/0, - mstates_for_sample, generation_cfg_for_sample); + std::vector next_tokens = sampler_->SampleTokens( + logits_for_sample, models_[0], mstates_for_sample, generation_cfg_for_sample); ICHECK_EQ(next_tokens.size(), num_requests); // - Update the committed tokens of states. // - If a request is first-time prefilled, set the prefill finish time. @@ -321,178 +399,17 @@ class Engine { for (int i = 0; i < num_requests; ++i) { mstates_for_sample[i]->committed_tokens.push_back(next_tokens[i]); if (mstates_for_sample[i]->committed_tokens.size() == 1) { - request_states_.at(requests[i])->tprefill_finish = tnow; + request_states_.at(requests[i]->id)->tprefill_finish = tnow; } } } auto tend = std::chrono::high_resolution_clock::now(); - engine_total_prefill_time_ += static_cast((tend - tstart).count()) / 1e9; - - return true; - } - - /*! \brief Pick applicable requests and run decode. */ - bool StepDecode() { - // - Do not run decode when there are multiple models. - if (models_.size() > 1) { - return false; - } - - PreemptUnfittableRequests(); - if (running_queue_.empty()) { - return false; - } - - auto tstart = std::chrono::high_resolution_clock::now(); - - // NOTE: Right now we only support decode all the running requests at a time. - int num_requests = running_queue_.size(); - // Check if the requests ids are in an ascending order. - for (int i = 1; i < num_requests; ++i) { - ICHECK_GT(request_states_.at(running_queue_[i])->mstates[0]->request_id, - request_states_.at(running_queue_[i - 1])->mstates[0]->request_id); - } - - current_total_seq_len_ += num_requests; - // Collect - // - the last committed token, - // - the request states, - // - the sampling parameters, - // of each request. 
- Array inputs; - Array mstates; - Array generation_cfg; - inputs.reserve(num_requests); - mstates.reserve(num_requests); - generation_cfg.reserve(num_requests); - for (Request request : running_queue_) { - RequestState& state = request_states_.at(request); - inputs.push_back(TokenData(ShapeTuple({state->mstates[0]->committed_tokens.back()}))); - mstates.push_back(state->mstates[0]); - generation_cfg.push_back(request->generation_cfg); - } - - // - Compute embeddings. - NDArray embeddings = GetTokenEmbeddings(inputs, fmodel_token_embed_[0], - /*return_flattened_view=*/false); - - // - Invoke model decode. - NDArray logits = fmodel_decode_[0](embeddings); - ICHECK_EQ(logits->ndim, 3); - ICHECK_EQ(logits->shape[0], embeddings->shape[0]); - ICHECK_EQ(logits->shape[1], 1); - - // - Sample tokens. - ShapeTuple next_tokens = - SampleTokens(logits, /*model_id=*/0, /*sampler_id=*/0, mstates, generation_cfg); - ICHECK_EQ(next_tokens.size(), num_requests); - - // - Update the committed tokens of states. - for (int i = 0; i < num_requests; ++i) { - mstates[i]->committed_tokens.push_back(next_tokens[i]); - } - - auto tend = std::chrono::high_resolution_clock::now(); - engine_total_decode_time_ += static_cast((tend - tstart).count()) / 1e9; - - return true; - } - - /*! \brief Pick applicable requests and run speculation. */ - bool StepSpeculate() { - // - No speculate when there is only one model. - if (models_.size() == 1) { - return false; - } + stats_.engine_total_prefill_time += static_cast((tend - tstart).count()) / 1e9; - // NOTE: We do not support speculation right now. - // The following is the possible sketch implementation for speculation step. - // - // Array mstates = GetRequestStatesToSpeculate(); - // if (mstates.empty()) { - // return false; - // } - // ... - ICHECK(false) << "Cannot reach here at this moment."; return true; } - /*! \brief Pick applicable requests and run verification of speculation results. */ - bool StepVerify() { - // - No verification when there is only one model. - if (models_.size() == 1) { - return false; - } - - // NOTE: We do not support speculation and verification right now. - // The following is the possible sketch implementation for speculation step. - // - // Array requests = GetRequestsToVerify(); - // if (requests.empty()) { - // return false; - // } - // ... - ICHECK(false) << "Cannot reach here at this moment."; - return true; - } - - /*! \brief Abort the generation of the given request. */ - void StepAbort(Request request) { - auto it_running = std::find(running_queue_.begin(), running_queue_.end(), request); - auto it_waiting = std::find(waiting_queue_.begin(), waiting_queue_.end(), request); - ICHECK(it_running != running_queue_.end() || it_waiting != waiting_queue_.end()); - if (it_running != running_queue_.end()) { - // The request to abort is in running queue - int req_id = it_running - running_queue_.begin(); - running_queue_.erase(it_running); - RequestState state = request_states_.at(request); - current_total_seq_len_ -= - state->raw_input_length + state->mstates[0]->committed_tokens.size() - 1; - RemoveSequenceFromModels(req_id); - UpdateRequestIDAfterRemoval(req_id); - } else { - // The request to abort is in waiting queue - waiting_queue_.erase(it_waiting); - } - request_states_.erase(request); - } - - /*! - * \brief Preempt the generation of the given request, moving - * it from running request set to the foremost of waiting - * request queue. 
- */ - void StepPreempt(Request request) { - auto it = std::find(running_queue_.begin(), running_queue_.end(), request); - ICHECK(it != running_queue_.end()); - - // Remove from models. - int req_id = it - running_queue_.begin(); - // - Reset `request_id` of states. - // - Clear model speculation draft. - // - Update `inputs` for future prefill. - RequestState& state = request_states_.at(request); - current_total_seq_len_ -= - state->raw_input_length + state->mstates[0]->committed_tokens.size() - 1; - for (RequestModelState mstate : state->mstates) { - mstate->request_id = -1; - mstate->draft_output_tokens.clear(); - mstate->draft_output_token_prob.clear(); - mstate->draft_output_prob_dist.clear(); - ICHECK(mstate->inputs.empty()); - ICHECK(!mstate->committed_tokens.empty()); - mstate->inputs = request->inputs; - mstate->inputs.push_back(TokenData(mstate->committed_tokens)); - } - RemoveSequenceFromModels(req_id); - UpdateRequestIDAfterRemoval(req_id); - - // Move from running queue to the front of waiting queue. - running_queue_.erase(it); - waiting_queue_.insert(waiting_queue_.begin(), request); - } - /*! * \brief Find one or multiple requests to run prefill. * \return The requests to prefill. For each request, we @@ -506,13 +423,12 @@ class Engine { if (!waiting_queue_.empty()) { int total_input_length = 0; int total_required_pages = 0; - ICHECK(fmodel_get_num_available_pages_[0].defined()); - int num_available_pages = fmodel_get_num_available_pages_[0](); + int num_available_pages = models_[0]->GetNumAvailablePages(); for (int i = 0; i < static_cast(waiting_queue_.size()); ++i) { Request request = waiting_queue_[i]; - RequestState state = request_states_.at(request); - int input_length = GetInputLength(state->mstates[0]->inputs); + RequestState state = request_states_.at(request->id); + int input_length = state->mstates[0]->GetInputLength(); int num_require_pages = (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; total_input_length += input_length; @@ -534,7 +450,7 @@ class Engine { // Try to prefill for small models. for (Request request : running_queue_) { - RequestState state = request_states_.at(request); + RequestState state = request_states_.at(request->id); Array mstates = state->mstates; for (int i = 0; i < static_cast(mstates.size()); ++i) { if (!mstates[i]->inputs.empty()) { @@ -551,36 +467,6 @@ class Engine { return {prefill_requests, states, false}; } - /*! \brief Preempt the requests unfittable for decode. */ - void PreemptUnfittableRequests() { - if (running_queue_.empty()) { - return; - } - - int num_available_pages = fmodel_get_num_available_pages_[0](); - while (true) { - if (CanDecode(running_queue_.size())) { - break; - } - StepPreempt(running_queue_.back()); - } - } - - /*! \brief Assign the given id for the given request. */ - void AssignIDForRequest(Request request, int req_id) { - // Set id in the request state. - RequestState& state = request_states_.at(request); - for (RequestModelState mstate : state->mstates) { - mstate->request_id = req_id; - } - // Add a new sequence to each model. - for (int i = 0; i < static_cast(models_.size()); ++i) { - Module model = models_[i]; - int seq_id_in_model = fmodel_add_new_sequence_[i](); - ICHECK_EQ(seq_id_in_model, req_id); - } - } - /*! \brief Check if the input requests can be prefilled under conditions. 
*/ bool CanPrefill(int num_prefill_req, int total_input_length, int num_required_pages, int num_available_pages) { @@ -595,25 +481,19 @@ class Engine { // NOTE: The conditions are heuristic and can be revised. // Cond 1: total input length <= max allowed single sequence length. - // Cond 2: remaining pages >= 10, where 10 is a watermark number can + // Cond 2: at least one decode can be performed after prefill. + // Cond 3: number of total tokens after 8 times of decode does not + // exceed the limit, where 8 is a watermark number can // be configured and adjusted in the future. - // Cond 3: at least one decode can be performed after prefill. - // Todo: move watermark to config. int new_batch_size = num_running_requests + num_prefill_req; return total_input_length <= max_single_sequence_length_ && num_required_pages + new_batch_size <= num_available_pages && - current_total_seq_len_ + total_input_length + 8 * new_batch_size <= + stats_.current_total_seq_len + total_input_length + 8 * new_batch_size <= kv_cache_config_->max_total_sequence_length; } - /*! \brief Check if the input requests can be decoded under conditions. */ - bool CanDecode(int num_requests) { - int num_available_pages = fmodel_get_num_available_pages_[0](); - return num_requests <= num_available_pages; - } - /*! \brief Filter the requests to prefill on the given model. */ - std::tuple, Array, ShapeTuple> FilterPrefillRequests( + std::tuple, Array, std::vector> FilterPrefillRequests( Array requests, Array states, int model_id) { ICHECK_EQ(requests.size(), states.size()); int num_requests = requests.size(); @@ -625,376 +505,141 @@ class Engine { prefill_length.reserve(num_requests); for (int i = 0; i < num_requests; ++i) { - int length = GetInputLength(states[i]->mstates[model_id]->inputs); + int length = states[i]->mstates[model_id]->GetInputLength(); if (length > 0) { filtered_requests.push_back(requests[i]); filtered_mstates.push_back(states[i]->mstates[model_id]); prefill_length.push_back(length); } } - return {filtered_requests, filtered_mstates, - ShapeTuple(prefill_length.begin(), prefill_length.end())}; + return {filtered_requests, filtered_mstates, prefill_length}; } - /*! \brief Get the total input length of the given inputs. */ - int GetInputLength(Array inputs) { - int length_sum = 0; - for (Data input : inputs) { - length_sum += GetInputLength(input); - } - return length_sum; - } + /******************** + * Action 2. Decode * + ********************/ - /*! \brief Get the equivalent length of the given input. */ - int GetInputLength(const Data& input) { - // Dispatch according to input type. - if (const auto* text_input = input.as()) { - return Tokenize(text_input->text).size(); - } else if (const auto* tokens_input = input.as()) { - return tokens_input->token_ids.size(); - } else { - ICHECK(false) << "Cannot reach here"; - throw; + /*! \brief Pick applicable requests and run decode. */ + bool StepDecode() { + // - Do not run decode when there are multiple models. + if (models_.size() > 1) { + return false; } - } - /*! - * \brief Tokenize the input text string using tokenizer. - * \note We use an engine-wise tokenize cache. - * The cache will be reset once its size reaches the full capacity. 
- */ - ShapeTuple Tokenize(String text) { - auto it = tokenize_cache_.find(text); - if (it != tokenize_cache_.end()) { - return it->second; + if (running_queue_.empty()) { + return false; } - ShapeTuple token_ids = ftokenizer_tokenize(text); - tokenize_cache_.emplace(text, token_ids); - // Clean up cache to avoid unlimited growth. - static constexpr int max_cache_size = 100000; - if (tokenize_cache_.size() == max_cache_size) { - tokenize_cache_.clear(); + // Preempt requests when decode cannot apply. + while (!CanDecode(running_queue_.size())) { + PreemptRequest(running_queue_.end() - 1); } - return token_ids; - } + auto tstart = std::chrono::high_resolution_clock::now(); - /*! - * \brief Compute the embedding of the given **single** input with - * regard to the given model. - */ - NDArray GetEmbedding(Data input, PackedFunc fmodel_token_embed) { - // Dispatch according to input type. - if (const auto* text_input = input.as()) { - ShapeTuple token_ids = Tokenize(text_input->text); - return fmodel_token_embed(Array{token_ids}); - } else if (const auto* tokens_input = input.as()) { - return fmodel_token_embed(Array{tokens_input->token_ids}); - } else { - ICHECK(false) << "Cannot reach here"; - throw; + // NOTE: Right now we only support decode all the running requests at a time. + int num_requests = running_queue_.size(); + // Check if the requests ids are in an ascending order. + for (int i = 1; i < num_requests; ++i) { + ICHECK_GT(request_states_.at(running_queue_[i]->id)->mstates[0]->request_id, + request_states_.at(running_queue_[i - 1]->id)->mstates[0]->request_id); } - } - /*! - * \brief Get token embeddings for all inputs in a **batched style**. - * It requires all inputs are either TextData or TokenData. - * This function is usually called for batch-wise actions such as decode - * (or batched prefill if supported in the future). - * \param inputs The inputs to compute embeddings. - * \param fmodel_token_embed The token embedding function of the model of interest. - * \param return_flattened_view A boolean indicating if flatten the - * embeddings across the batch dimension or not. For batch decode we - * do not flatten, and for batch prefill we usually flatten to handle - * raggedness. - * \return The computed embeddings with regard to the required view. - */ - NDArray GetTokenEmbeddings(Array inputs, PackedFunc fmodel_token_embed, - bool return_flattened_view) { - CHECK(!inputs.empty()); - int num_inputs = inputs.size(); - Array token_ids; - token_ids.reserve(num_inputs); - for (Data input : inputs) { - if (const auto* text_input = input.as()) { - token_ids.push_back(Tokenize(text_input->text)); - } else if (const auto* tokens_input = input.as()) { - token_ids.push_back(tokens_input->token_ids); - } else { - CHECK(false) << "Input type " << input->GetTypeKey() << " is not accepted"; - } + stats_.current_total_seq_len += num_requests; + // Collect + // - the last committed token, + // - the request states, + // - the sampling parameters, + // of each request. + std::vector input_tokens; + Array mstates; + Array generation_cfg; + input_tokens.reserve(num_requests); + mstates.reserve(num_requests); + generation_cfg.reserve(num_requests); + for (Request request : running_queue_) { + RequestState state = request_states_.at(request->id); + input_tokens.push_back(state->mstates[0]->committed_tokens.back()); + mstates.push_back(state->mstates[0]); + generation_cfg.push_back(request->generation_cfg); } - // - If it is expected to return in a flattened view, just return the embeddings. 
- NDArray embeddings = fmodel_token_embed(token_ids); - if (return_flattened_view) { - return embeddings; - } - // - Otherwise, it is required that each input has the same length. - // Because we cannot return embeddings with raggedness in an - // unflattened way. - int input_length = token_ids[0].size(); - for (ShapeTuple ids : token_ids) { - CHECK_EQ(ids.size(), input_length) - << "When it is required not to return flattened embeddings, " - "all inputs are supposed to have the same length"; - } + // - Compute embeddings. + NDArray embeddings = + models_[0]->TokenEmbed({IntTuple{input_tokens.begin(), input_tokens.end()}}); ICHECK_EQ(embeddings->ndim, 3); ICHECK_EQ(embeddings->shape[0], 1); - ICHECK_EQ(embeddings->shape[1], input_length * num_inputs); - return embeddings.CreateView({num_inputs, input_length, embeddings->shape[2]}, - embeddings->dtype); - } + ICHECK_EQ(embeddings->shape[1], num_requests); + embeddings = embeddings.CreateView({num_requests, 1, embeddings->shape[2]}, embeddings->dtype); - /*! - * \brief Sample tokens from the input logits. - * \param logits_on_device The logits to sample tokens from. - * \param model_id The id of the LLM model module which contains the softmax - * function on device that might be helpful. - * \param sampler_id The id of the sampler module to run sampling. - * \param request_mstates The request states of each sequence in - * the batch with regard to the given model. - * \param generation_cfg The generation config of each request - * in the input batch. - * \return The sampled tokens, one for each request in the batch. - */ - ShapeTuple SampleTokens(NDArray logits_on_device, int model_id, int sampler_id, - Array request_mstates, - Array generation_cfg) { - ICHECK(logits_on_device.defined()); - ICHECK_EQ(logits_on_device->ndim, 3); - ICHECK_EQ(logits_on_device->shape[1], 1) - << "Multi-token sampling for one sequence is not supported yet."; - ICHECK_EQ(logits_on_device->shape[0], generation_cfg.size()); - ICHECK_EQ(request_mstates.size(), generation_cfg.size()); - - Module model = models_[model_id]; - Module sampler = samplers_[sampler_id]; - - int num_sequence = logits_on_device->shape[0]; - bool require_gpu_softmax = fsampler_require_gpu_softmax_[sampler_id](generation_cfg); - - // - Compute probabilities from logits. - NDArray logits_or_probs_on_cpu{nullptr}; - if (require_gpu_softmax) { - NDArray probs_on_device = - fmodel_softmax_with_temperature_[model_id](logits_on_device, generation_cfg); - logits_or_probs_on_cpu = CopyLogitsOrProbsToCPU(probs_on_device); - } else { - logits_or_probs_on_cpu = CopyLogitsOrProbsToCPU(logits_on_device); - // The "compute_probs_from_logits_inplace" function updates - // `logits_or_probs_on_cpu` in place. - fsampler_compute_probs_from_logits_inplace_[sampler_id]( - logits_or_probs_on_cpu, std::move(request_mstates), generation_cfg); - } - // `CopyLogitsOrProbsToCPU` flattens the first two dimensions. - ICHECK_EQ(logits_or_probs_on_cpu->ndim, 2); - - // - Sample tokens from probabilities. - // NOTE: Though we have the probability field in RequestModelState, - // we do not save the probabilities right now. - // We will handle this in the future when we work on speculation. - ShapeTuple new_tokens = - fsampler_sample_token_from_probs_[sampler_id](logits_or_probs_on_cpu, generation_cfg); - return new_tokens; - } + // - Invoke model decode. 
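+    // `BatchDecode` runs one decode step for every running sequence at once:
+    // it consumes the (num_requests, 1, hidden_size) last-token embeddings
+    // prepared above and returns (num_requests, 1, vocab_size) logits, as the
+    // shape checks below assert before sampling.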
+ NDArray logits = models_[0]->BatchDecode(embeddings); + ICHECK_EQ(logits->ndim, 3); + ICHECK_EQ(logits->shape[0], embeddings->shape[0]); + ICHECK_EQ(logits->shape[1], 1); - /*! - * \brief Copy logits or prob distributions from device to CPU. - * The input array is in layout (b, n, v). - * This function flattens the first dimension, returns an NDArray - * in shape (b * n, v). - */ - NDArray CopyLogitsOrProbsToCPU(NDArray arr_on_device) { - // arr_on_device: (b, n, v) - ICHECK_EQ(arr_on_device->ndim, 3); - ICHECK(!logits_or_probs_on_cpu_.defined() || (logits_or_probs_on_cpu_)->ndim == 2); - ICHECK(arr_on_device->device.device_type != kDLCPU); - if (logits_or_probs_on_cpu_.defined()) { - ICHECK_EQ(logits_or_probs_on_cpu_->shape[1], arr_on_device->shape[2]); - } + // - Sample tokens. + std::vector next_tokens = + sampler_->SampleTokens(logits, models_[0], mstates, generation_cfg); + ICHECK_EQ(next_tokens.size(), num_requests); - int64_t init_size = logits_or_probs_on_cpu_.defined() ? logits_or_probs_on_cpu_->shape[0] : 32; - int64_t num_tokens = arr_on_device->shape[0] * arr_on_device->shape[1]; - int64_t vocab_size = arr_on_device->shape[2]; - while (init_size < num_tokens) { - init_size *= 2; - } - if (!logits_or_probs_on_cpu_.defined() || init_size != logits_or_probs_on_cpu_->shape[0]) { - logits_or_probs_on_cpu_ = - NDArray::Empty({init_size, vocab_size}, arr_on_device->dtype, DLDevice{kDLCPU, 0}); + // - Update the committed tokens of states. + for (int i = 0; i < num_requests; ++i) { + mstates[i]->committed_tokens.push_back(next_tokens[i]); } - ICHECK_LE(num_tokens, logits_or_probs_on_cpu_->shape[0]); - NDArray view = - logits_or_probs_on_cpu_.CreateView({num_tokens, vocab_size}, arr_on_device->dtype); - view.CopyFrom(arr_on_device); - return view; - } - /*! \brief Remove the given request from all models (usually the KV cache). */ - void RemoveSequenceFromModels(int req_id) { - for (int i = 0; i < static_cast(models_.size()); ++i) { - fmodel_remove_sequence_[i](req_id); - } - } + auto tend = std::chrono::high_resolution_clock::now(); + stats_.engine_total_decode_time += static_cast((tend - tstart).count()) / 1e9; - /*! - * \brief Update the request ids of all running requests after - * the removal of the given request. - */ - void UpdateRequestIDAfterRemoval(int removed_req_id) { - for (auto& it : request_states_) { - RequestState& state = it.second; - for (RequestModelState mstate : state->mstates) { - ICHECK_NE(mstate->request_id, removed_req_id); - if (mstate->request_id > removed_req_id) { - --mstate->request_id; - } - } - } + return true; } - /*! - * \brief For each request, check if the request has finished - * its generation. And update the state and return the generation - * result for the finished requests. - * \note This function removes requests from the running request - * queue. - */ - void UpdateFinishedRequest() { - // - Collect finished requests. - // We don't remove on the fly to avoid concurrent modification. - std::vector request_to_remove; - for (Request request : running_queue_) { - if (RequestIsFinished(request)) { - request_to_remove.push_back(request); - } - } - - // - Remove the finished request. 
- for (Request request : request_to_remove) { - auto it = std::find(running_queue_.begin(), running_queue_.end(), request); - ICHECK(it != running_queue_.end()); - int req_id = it - running_queue_.begin(); - running_queue_.erase(it); - - RequestState& state = request_states_.at(request); - int num_input_tokens = state->raw_input_length; - int num_output_tokens = state->mstates[0]->committed_tokens.size() - 1; - current_total_seq_len_ -= num_input_tokens + num_output_tokens; - for (RequestModelState mstate : state->mstates) { - ICHECK_EQ(mstate->request_id, req_id); - mstate->request_id = -1; - } - RemoveSequenceFromModels(req_id); - UpdateRequestIDAfterRemoval(req_id); - - auto trequest_finish = std::chrono::high_resolution_clock::now(); - request_total_prefill_time_ += - static_cast((state->tprefill_finish - state->tadd).count()) / 1e9; - total_prefill_length_ += num_input_tokens; - request_total_decode_time_ += - static_cast((trequest_finish - state->tprefill_finish).count()) / 1e9; - total_decode_length_ += num_output_tokens; - - // NOTE: right now we only return the generated text. - // In the future we might optional return text or token ids. - String output = ftokenizer_decode(ShapeTuple(state->mstates[0]->committed_tokens.begin(), - state->mstates[0]->committed_tokens.end())); - state->output = output.operator std::string(); - request->fcallback(request, TextData(state->output)); - - // Remove the request from states. - request_states_.erase(request); - } + /*! \brief Check if the input requests can be decoded under conditions. */ + bool CanDecode(int num_requests) { + int num_available_pages = models_[0]->GetNumAvailablePages(); + return num_requests <= num_available_pages; } - /*! \brief Check if the given request is finished under conditions. */ - bool RequestIsFinished(Request request) { - RequestState& state = request_states_.at(request); - - // - Case 0. There is remaining draft output ==> Unfinished - // All draft outputs are supposed to be processed before finish. - for (RequestModelState mstate : state->mstates) { - if (!mstate->draft_output_tokens.empty()) { - return false; - } - } - - // - Decode committed tokens. - const std::vector& committed_tokens = state->mstates[0]->committed_tokens; - - // Case 1. Any of the stop strings appears in output ==> Finished - // Todo: handle stop_str by tokenizing. So that we don't detokenize during check + /******************* + * Action 3. Abort * + *******************/ - // Case 2. Any of the stop tokens appears in the committed tokens ===> Finished - if (std::any_of(request->generation_cfg->stop_tokens.begin(), - request->generation_cfg->stop_tokens.end(), [&committed_tokens](int32_t token) { - return token == committed_tokens.back(); - })) { - return true; - } - // Case 3. Generation reaches the specified max generation length ==> Finished - if (static_cast(committed_tokens.size()) >= request->generation_cfg->max_new_tokens) { - return true; - } - // Case 4. Total length of the request reaches the maximum single sequence length ==> Finished - if (state->raw_input_length + static_cast(committed_tokens.size()) >= - max_single_sequence_length_) { - return true; + /*! \brief Abort the generation of the given request. 
*/ + void StepAbort(Request request) { + auto it_running = std::find(running_queue_.begin(), running_queue_.end(), request); + auto it_waiting = std::find(waiting_queue_.begin(), waiting_queue_.end(), request); + ICHECK(it_running != running_queue_.end() || it_waiting != waiting_queue_.end()); + if (it_running != running_queue_.end()) { + // The request to abort is in running queue + int req_id = it_running - running_queue_.begin(); + running_queue_.erase(it_running); + RequestState state = request_states_.at(request->id); + stats_.current_total_seq_len -= + request->input_total_length + state->mstates[0]->committed_tokens.size() - 1; + RemoveRequestFromModel(req_id); + } else { + // The request to abort is in waiting queue + waiting_queue_.erase(it_waiting); } - return false; + request_states_.erase(request->id); } + /***************** Engine Data Structures *****************/ + // Request queues std::vector running_queue_; std::vector waiting_queue_; std::vector abort_queue_; // Request states - std::unordered_map request_states_; - - // Models - Array models_; - Array samplers_; - Module tokenizer_; - // Device corresponding to each model - std::vector devices_; - - /*! \brief Shared array for logits and probability distributions on cpu. */ - NDArray logits_or_probs_on_cpu_{nullptr}; - - // PackedFuncs from model/tokenizer/sampler/env. - std::vector fmodel_batch_prefill_; - std::vector fmodel_decode_; - std::vector fmodel_token_embed_; - std::vector fmodel_add_new_sequence_; - std::vector fmodel_remove_sequence_; - std::vector fmodel_softmax_with_temperature_; - std::vector fmodel_get_num_available_pages_; - std::vector fsampler_require_gpu_softmax_; - std::vector fsampler_compute_probs_from_logits_inplace_; - std::vector fsampler_sample_token_from_probs_; - PackedFunc ftokenizer_tokenize; - PackedFunc ftokenizer_decode; + std::unordered_map request_states_; + + // Models, sampler and tokenizer. + Array models_; + Sampler sampler_; + std::unique_ptr tokenizer_; // Runtime statistics - int64_t current_total_seq_len_; - /*! \brief The sum of "prefill time of each request". */ - double request_total_prefill_time_ = 0.0f; - /*! \brief The sum of "decode time of each request". */ - double request_total_decode_time_ = 0.0f; - /*! \brief The total engine time on prefill. */ - double engine_total_prefill_time_ = 0.0f; - /*! \brief The total engine time on decode. */ - double engine_total_decode_time_ = 0.0f; - /*! \brief The total number of processed tokens in prefill. */ - int64_t total_prefill_length_ = 0; - /*! \brief The total number of processed tokens in decode. 
*/ - int64_t total_decode_length_ = 0; - - // Tokenization cache - std::unordered_map tokenize_cache_; + EngineStats stats_; // Configurations KVCacheConfig kv_cache_config_; @@ -1051,12 +696,6 @@ class EngineModule : public ModuleNode { } engine_->Reload(reload_libs, model_paths, devices, args[num_models * 4]); }); - } else if (name == "unload") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 0); - engine_ = nullptr; - ClearGlobalMemoryManager(); - }); } else if (name == "add_request") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.size(), 1); @@ -1075,7 +714,7 @@ class EngineModule : public ModuleNode { } else if (name == "stats") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.size(), 0); - *rv = GetEngine()->StatisticsJSON(); + *rv = GetEngine()->stats_.AsJSON(); }); } else if (name == "reset") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { @@ -1104,9 +743,7 @@ tvm::runtime::Module CreateEngineModule() { } // register as a system function that can be queried -TVM_REGISTER_GLOBAL("mlc.serve.create_engine").set_body_typed([]() { - return CreateEngineModule(); -}); +TVM_REGISTER_GLOBAL("mlc.serve.create_engine").set_body_typed(CreateEngineModule); } // namespace serve } // namespace llm diff --git a/cpp/serve/engine_stats.cc b/cpp/serve/engine_stats.cc new file mode 100644 index 0000000000..8339008fba --- /dev/null +++ b/cpp/serve/engine_stats.cc @@ -0,0 +1,40 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/engine_stats.cc + */ +#define PICOJSON_USE_INT64 + +#include "engine_stats.h" + +#include + +namespace mlc { +namespace llm { +namespace serve { + +String EngineStats::AsJSON() const { + picojson::object config; + config["single_token_prefill_latency"] = + picojson::value(request_total_prefill_time / total_prefill_length); + config["single_token_decode_latency"] = + picojson::value(request_total_decode_time / total_decode_length); + config["engine_total_prefill_time"] = picojson::value(engine_total_prefill_time); + config["engine_total_decode_time"] = picojson::value(engine_total_decode_time); + config["total_prefill_tokens"] = picojson::value(total_prefill_length); + config["total_decode_tokens"] = picojson::value(total_decode_length); + return picojson::value(config).serialize(true); +} + +void EngineStats::Reset() { + current_total_seq_len = 0; + request_total_prefill_time = 0.0f; + request_total_decode_time = 0.0f; + engine_total_prefill_time = 0.0f; + engine_total_decode_time = 0.0f; + total_prefill_length = 0; + total_decode_length = 0; +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/engine_stats.h b/cpp/serve/engine_stats.h new file mode 100644 index 0000000000..6aa3f7397f --- /dev/null +++ b/cpp/serve/engine_stats.h @@ -0,0 +1,53 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/engine_stats.h + */ +#ifndef MLC_LLM_SERVE_ENGINE_STATS_H_ +#define MLC_LLM_SERVE_ENGINE_STATS_H_ + +#include + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! \brief Runtime statistics of engine. */ +struct EngineStats { + /*! \brief The current total sequence length in the first model. */ + int64_t current_total_seq_len; + /*! \brief The sum of "prefill time of each request". */ + double request_total_prefill_time = 0.0f; + /*! \brief The sum of "decode time of each request". */ + double request_total_decode_time = 0.0f; + /*! 
\brief The total engine time on prefill. */ + double engine_total_prefill_time = 0.0f; + /*! \brief The total engine time on decode. */ + double engine_total_decode_time = 0.0f; + /*! \brief The total number of processed tokens in prefill. */ + int64_t total_prefill_length = 0; + /*! \brief The total number of processed tokens in decode. */ + int64_t total_decode_length = 0; + + /*! + * \brief Return the engine runtime statistics in JSON string. + * We collect the following entries: + * - single token prefill latency (s/tok): avg latency of processing one token in prefill + * - single token decode latency (s/tok): avg latency of processing one token in decode + * - engine time for prefill (sec) + * - engine time for decode (sec) + * - total number of processed tokens in prefill. + * - total number of processed tokens in decode. + * \return The statistics in JSON string. + */ + String AsJSON() const; + /*! \brief Reset all the statistics. */ + void Reset(); +}; + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_ENGINE_STATS_H_ diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 9c2a1e768b..fe77418c84 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -9,48 +9,117 @@ #include "model.h" #include -#include -#include #include #include #include -#include "config.h" -#include "function_table.h" - namespace mlc { namespace llm { namespace serve { +/*********************** Utils ***********************/ + /*! - * \brief The runtime module for LLM functions. - * It runs an LLM, and has an internal KV cache that maintains - * the history KV values of all processed tokens. - * - * It contains the following functions: - * - * Model related: - * - "token_embed": take token ids as input and return the embeddings, - * - "single_seq_prefill": take embedding of a single sequence - * as input, forward the embedding through LLM and return the logits, - * - "decode": take the embeddings of the last-committed token of an - * entire batch as input, forward through LLM and return the logits - * for all sequences in the batch, - * - "softmax_with_temperature": take logits and temperatures, return - * probabilities. - * - * KV cache related: - * - "create_kv_cache": create the KV cache for this module, - * - "add_new_sequence": add (declare) a new sequence in the KV cache, - * - "remove_sequence": remove a sequence from KV cache. - * - * ... and more other auxiliary functions. + * \brief Concatenate the input embeddings along the sequence dimension. + * Store the concatenation result into the input destination NDarray. + * Return concatenation result as an NDArray view of the destination array. + * \param embedding_arr The array of embeddings to concatenate. + * \param total_length The total length of the input embeddings along the sequence dim. + * \param device The device where the embeddings locate. + * \param initial_seq_len The initial sequence length to allocate for embeddings. + * \param dst The destination of the concatenation + * \return The concatenated embeddings. 
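+ *
+ * \note A sketch of the expected call pattern, mirroring the use in
+ * `ModelImpl::BatchPrefill` below (the destination buffer name here is illustrative):
+ * \code
+ *   NDArray embedding_dst{nullptr};
+ *   NDArray embeddings = ConcatEmbeddings(std::move(embedding_arr), total_length,
+ *                                         device, initial_seq_len, &embedding_dst);
+ *   // `embeddings` is a (1, total_length, hidden_size) view into `embedding_dst`.
+ * \endcode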
*/ -class ModelModule : public ModuleNode { +NDArray ConcatEmbeddings(Array embedding_arr, int64_t total_length, DLDevice device, + int initial_seq_len, NDArray* dst) { + ICHECK(!embedding_arr.empty()); + ICHECK_NOTNULL(dst); + int hidden_size = -1; + DataType dtype; + for (NDArray inp_embeddings : embedding_arr) { + // inp_embedding: (1, n, h) + CHECK_EQ(inp_embeddings->ndim, 3); + CHECK_EQ(inp_embeddings->shape[0], 1); + CHECK_EQ(inp_embeddings->device.device_type, device.device_type); + CHECK_EQ(inp_embeddings->device.device_id, device.device_id); + if (hidden_size == -1) { + hidden_size = inp_embeddings->shape[2]; + dtype = inp_embeddings.DataType(); + } else { + CHECK_EQ(inp_embeddings->shape[2], hidden_size); + CHECK_EQ(inp_embeddings.DataType(), dtype); + } + } + + // - Resize the shared embedding array. + if (dst->defined()) { + ICHECK_EQ((*dst)->ndim, 3); + ICHECK_EQ((*dst)->shape[0], 1); + ICHECK_EQ((*dst)->shape[2], hidden_size); + } + int64_t init_size = dst->defined() ? (*dst)->shape[1] : initial_seq_len; + while (init_size < total_length) { + init_size *= 2; + } + if (!dst->defined() || init_size != (*dst)->shape[1]) { + *dst = NDArray::Empty({1, init_size, hidden_size}, dtype, device); + } + + // - Copy input embeddings. + int64_t start_pos = 0; + for (NDArray inp_embeddings : embedding_arr) { + int64_t length = inp_embeddings->shape[1]; + CHECK_LE(start_pos + length, total_length); + + DLTensor copy_dst = *(dst->operator->()); + copy_dst.byte_offset = start_pos * hidden_size * dtype.bytes(); + copy_dst.shape = inp_embeddings->shape; + NDArray::CopyFromTo(inp_embeddings.operator->(), ©_dst); + + start_pos += length; + } + CHECK_EQ(start_pos, total_length); + return dst->CreateView({1, total_length, hidden_size}, dtype); +} + +/*! \brief Utility function that copies input array to the device. */ +template +NDArray CopyArrayToDevice(const std::vector& array, NDArray* dst, DLDataType dtype, + int default_init_size, Device device) { + ICHECK(!array.empty()); + ICHECK(dst != nullptr); + ICHECK(!dst->defined() || (*dst)->ndim == 1); + int64_t init_size = dst->defined() ? (*dst)->shape[0] : default_init_size; + while (init_size < static_cast(array.size())) { + init_size *= 2; + } + if (!dst->defined() || init_size != (*dst)->shape[0]) { + (*dst) = NDArray::Empty({init_size}, dtype, device); + } + ICHECK_LE(static_cast(array.size()), (*dst)->shape[0]); + NDArray view = dst->CreateView(ShapeTuple({static_cast(array.size())}), dtype); + view.CopyFromBytes(array.data(), array.size() * sizeof(T)); + return view; +} + +/*********************** Model Implementation ***********************/ + +class ModelImpl; + +TVM_REGISTER_OBJECT_TYPE(ModelObj); + +Model Model::Create(TVMArgValue reload_lib, String model_path, DLDevice device) { + return Model(make_object(reload_lib, model_path, device)); +} + +class ModelImpl : public ModelObj { public: - explicit ModelModule(TVMArgValue reload_lib, String model_path, DLDevice device) - : device_(device) { + /*! + * \brief Constructor of ModelImpl. + * \sa Model::Create + */ + explicit ModelImpl(TVMArgValue reload_lib, String model_path, DLDevice device) : device_(device) { // Step 1. Process model config json string. 
{ std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str()); @@ -70,100 +139,18 @@ class ModelModule : public ModuleNode { this->Reset(); } - // overrides - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { - if (name == "token_embed") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 1); - *rv = TokenEmbed(args[0]); - }); - } else if (name == "batch_prefill") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 3); - *rv = BatchPrefill(args[0], args[1], args[2]); - }); - } else if (name == "decode") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 1); - *rv = Decode(args[0]); - }); - } else if (name == "softmax_with_temperature") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 2); - *rv = SoftmaxWithTemperature(args[0], args[1]); - }); - } else if (name == "create_kv_cache") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.size(), 1); - KVCacheConfig kv_cache_config = args[0]; - kv_cache_ = ft_.create_kv_cache_func_( - ShapeTuple({kv_cache_config->max_num_sequence, - kv_cache_config->max_total_sequence_length, kv_cache_config->page_size})); - }); - } else if (name == "add_new_sequence") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.size(), 0); - *rv = ft_.add_sequence_to_kv_cache_func_(kv_cache_); - }); - } else if (name == "remove_sequence") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.size(), 1); - ft_.remove_from_kv_cache_func_(kv_cache_, args[0]); - }); - } else if (name == "reset") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.size(), 0); - Reset(); - }); - } else if (name == "get_num_available_pages") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.size(), 0); - ICHECK(kv_cache_.defined()); - *rv = ft_.get_num_available_pages_kv_cache_func_(kv_cache_); - }); - } else if (name == "get_max_window_size") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.size(), 0); - CHECK_NE(max_window_size_, -1) << "The model has not been initialized"; - *rv = max_window_size_; - }); - } else { - return PackedFunc(nullptr); - } - } - - const char* type_key() const final { return "mlc.serve.Model"; } + /*********************** Model Computation ***********************/ - private: - /*! - * \brief Compute embeddings for the input token ids. - * \param batch_token_ids The batch of token ids to compute embedding for. - * \return The computed embeddings. - * \note This function will **flatten** the input batch token ids, - * and return the NDArray flattened on the batch/sequence dimension. - * The caller side can decide whether to reshape the returned - * NDArray into some other shape or not. - * This brings the convenience for batched prefill and speculation - * verification where input sequences / draft outputs might can - * have different lengths, and we forward the flattened embeddings - * to prefill/verification. - */ - NDArray TokenEmbed(Array batch_token_ids) { - // Flatten input tokens. 
- int total_length = 0; - std::vector flattened_token_ids; - for (ShapeTuple token_ids : batch_token_ids) { - flattened_token_ids.insert(flattened_token_ids.end(), token_ids->data, - token_ids->data + token_ids.size()); - total_length += token_ids.size(); - } + NDArray TokenEmbed(IntTuple token_ids) final { + int num_tokens = token_ids.size(); + std::vector vec_token_ids(token_ids->data, token_ids->data + num_tokens); // Copy input token ids to device. DLDataType dtype(DataType::Int(32)); NDArray token_ids_nd = - CopyArrayToDevice(flattened_token_ids, &input_token_ids_, dtype, max_window_size_); + CopyArrayToDevice(vec_token_ids, &input_token_ids_, dtype, max_window_size_, device_); ICHECK_EQ(token_ids_nd->ndim, 1); - ICHECK_EQ(token_ids_nd->shape[0], total_length); - token_ids_nd = token_ids_nd.CreateView({1, total_length}, dtype); + ICHECK_EQ(token_ids_nd->shape[0], num_tokens); + token_ids_nd = token_ids_nd.CreateView({1, num_tokens}, dtype); CHECK(ft_.embed_func_.defined()) << "`embed` function is not found in the model. Please make sure the model is compiled " @@ -174,18 +161,12 @@ class ModelModule : public ModuleNode { // embeddings: (1, total_length, hidden_size) ICHECK_EQ(embeddings->ndim, 3); ICHECK_EQ(embeddings->shape[0], 1); - ICHECK_EQ(embeddings->shape[1], total_length); + ICHECK_EQ(embeddings->shape[1], num_tokens); return embeddings; } - /*! - * \brief Single-sequence prefill function. Embedding in, logits out. - * \param embeddings The embedding of the input to be prefilled. - * \param seq_id The id of the sequence in the KV cache. - * \param lengths The length of each sequence to prefill. - * \return The logits for the next token. - */ - NDArray BatchPrefill(Array embedding_arr, ShapeTuple seq_ids, ShapeTuple lengths) { + NDArray BatchPrefill(Array embedding_arr, std::vector seq_ids, + std::vector lengths) final { CHECK(!seq_ids.empty()); CHECK_EQ(seq_ids.size(), lengths.size()); int num_sequences = seq_ids.size(); @@ -201,14 +182,16 @@ class ModelModule : public ModuleNode { } // embeddings: (1, n, h) - NDArray embeddings = ConcatEmbeddings(std::move(embedding_arr), total_length); + NDArray embeddings = ConcatEmbeddings(std::move(embedding_arr), total_length, device_, + max_window_size_, &embeddings_); ICHECK_EQ(embeddings->ndim, 3); ICHECK_EQ(embeddings->shape[0], 1); ICHECK_EQ(embeddings->shape[1], total_length); ICHECK_EQ(embeddings->device.device_type, device_.device_type); ICHECK_EQ(embeddings->device.device_id, device_.device_id); - NDArray logit_pos_nd = CopyArrayToDevice(logit_pos, &logit_pos_arr_, DataType::Int(32), 32); + NDArray logit_pos_nd = + CopyArrayToDevice(logit_pos, &logit_pos_arr_, DataType::Int(32), 32, device_); CHECK(ft_.prefill_func_.defined()) << "`prefill_with_embed` function is not found in the model. Please make sure the model is " @@ -237,15 +220,7 @@ class ModelModule : public ModuleNode { return logits; } - /*! - * \brief Batch decode function. Embedding in, logits out. - * \param embeddings The embedding of last generated token in the entire batch. - * \return The logits for the next token for each sequence in the batch. - * \note The function runs for **every** sequence in the batch. - * That is to say, it does not accept "running a decode step for a subset - * of the full batch". 
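// [Editor's sketch] BatchPrefill above works on one flattened sequence
// dimension: per-request embeddings of shape (1, n_i, h) are concatenated
// into (1, sum(n_i), h), and logit_pos records where each request's last
// prefilled token lands in that flattened layout, so only those rows of the
// output need next-token logits. The loop that fills logit_pos sits in the
// unchanged context of this hunk; assuming it accumulates lengths into
// zero-based end offsets, the bookkeeping looks like:
//
//   std::vector<int> lengths = {3, 5, 2};     // three requests in the batch
//   std::vector<int32_t> logit_pos;
//   int total_length = 0;
//   for (int len : lengths) {
//     total_length += len;
//     logit_pos.push_back(total_length - 1);  // -> {2, 7, 9}
//   }
//   // total_length == 10: embeddings become (1, 10, h), and only flattened
//   // positions 2, 7 and 9 matter for sampling.
//
// Whether the recorded offset is the cumulative length or the cumulative
// length minus one is not visible in this diff; the sketch assumes the
// zero-based form.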
- */ - NDArray Decode(NDArray embeddings) { + NDArray BatchDecode(NDArray embeddings) final { // embeddings: (b, 1, h) CHECK_EQ(embeddings->ndim, 3); CHECK_EQ(embeddings->shape[1], 1); @@ -278,13 +253,7 @@ class ModelModule : public ModuleNode { return logits; } - /*! - * \brief Computing probabilities from logits with softmax and temperatures. - * \param logits The logits to compute from. - * \param generation_cfg The generation config which contains the temperatures. - * \return The computed probabilities distribution. - */ - NDArray SoftmaxWithTemperature(NDArray logits, Array generation_cfg) { + NDArray SoftmaxWithTemperature(NDArray logits, Array generation_cfg) final { // logits: (b, n, v) CHECK_EQ(logits->ndim, 3); CHECK_EQ(logits->shape[0], generation_cfg.size()); @@ -297,7 +266,8 @@ class ModelModule : public ModuleNode { for (GenerationConfig cfg : generation_cfg) { temperatures.push_back(cfg->temperature); } - NDArray temperatures_nd = CopyArrayToDevice(temperatures, &temperature_arr_, logits->dtype, 32); + NDArray temperatures_nd = + CopyArrayToDevice(temperatures, &temperature_arr_, logits->dtype, 32, device_); ICHECK_EQ(temperatures_nd->ndim, 1); ICHECK_EQ(temperatures_nd->shape[0], batch_size); @@ -309,77 +279,40 @@ class ModelModule : public ModuleNode { return probs; } - /*! \brief Copy input array to the device. */ - template - NDArray CopyArrayToDevice(const std::vector& array, NDArray* dst, DLDataType dtype, - int default_init_size) { - ICHECK(!array.empty()); - ICHECK(dst != nullptr); - ICHECK(!dst->defined() || (*dst)->ndim == 1); - int64_t init_size = dst->defined() ? (*dst)->shape[0] : default_init_size; - while (init_size < static_cast(array.size())) { - init_size *= 2; - } - if (!dst->defined() || init_size != (*dst)->shape[0]) { - (*dst) = NDArray::Empty({init_size}, dtype, device_); - } - ICHECK_LE(static_cast(array.size()), (*dst)->shape[0]); - NDArray view = dst->CreateView(ShapeTuple({static_cast(array.size())}), dtype); - view.CopyFromBytes(array.data(), array.size() * sizeof(T)); - return view; + /*********************** KV Cache Management ***********************/ + + void CreateKVCache(KVCacheConfig kv_cache_config) final { + kv_cache_ = ft_.create_kv_cache_func_( + IntTuple({kv_cache_config->max_num_sequence, kv_cache_config->max_total_sequence_length, + kv_cache_config->page_size})); } - /*! \brief Concatenate the input embeddings. */ - NDArray ConcatEmbeddings(Array embedding_arr, int64_t total_length) { - ICHECK(!embedding_arr.empty()); - int hidden_size = -1; - DataType dtype; - for (NDArray inp_embeddings : embedding_arr) { - // inp_embedding: (1, n, h) - CHECK_EQ(inp_embeddings->ndim, 3); - CHECK_EQ(inp_embeddings->shape[0], 1); - CHECK_EQ(inp_embeddings->device.device_type, device_.device_type); - CHECK_EQ(inp_embeddings->device.device_id, device_.device_id); - if (hidden_size == -1) { - hidden_size = inp_embeddings->shape[2]; - dtype = inp_embeddings.DataType(); - } else { - CHECK_EQ(inp_embeddings->shape[2], hidden_size); - CHECK_EQ(inp_embeddings.DataType(), dtype); - } - } + /*! \brief Add a new sequence to the KV cache. Return the in-cache id of the sequence. */ + int AddNewSequence() final { return ft_.add_sequence_to_kv_cache_func_(kv_cache_); } - // - Resize the shared embedding array. - if (embeddings_.defined()) { - ICHECK_EQ(embeddings_->ndim, 3); - ICHECK_EQ(embeddings_->shape[0], 1); - ICHECK_EQ(embeddings_->shape[2], hidden_size); - } - int64_t init_size = embeddings_.defined() ? 
embeddings_->shape[1] : max_window_size_; - while (init_size < total_length) { - init_size *= 2; - } - if (!embeddings_.defined() || init_size != embeddings_->shape[1]) { - embeddings_ = NDArray::Empty({1, init_size, hidden_size}, dtype, device_); - } + /*! \brief Remove the given sequence from the KV cache in the model. */ + void RemoveSequence(int seq_id) final { ft_.remove_from_kv_cache_func_(kv_cache_, seq_id); } - // - Copy input embeddings. - int64_t start_pos = 0; - for (NDArray inp_embeddings : embedding_arr) { - int64_t length = inp_embeddings->shape[1]; - CHECK_LE(start_pos + length, total_length); + /*! \brief Get the number of available pages in KV cache. */ + int GetNumAvailablePages() const final { + return ft_.get_num_available_pages_kv_cache_func_(kv_cache_); + } + + /*********************** Utilities ***********************/ - DLTensor copy_dst = *(embeddings_.operator->()); - copy_dst.byte_offset = start_pos * hidden_size * dtype.bytes(); - copy_dst.shape = inp_embeddings->shape; - NDArray::CopyFromTo(inp_embeddings.operator->(), ©_dst); + int GetMaxWindowSize() const final { + CHECK_NE(max_window_size_, -1) << "The model has not been initialized"; + return max_window_size_; + } - start_pos += length; + void Reset() final { + // Reset the KV cache. + if (kv_cache_.defined()) { + ft_.reset_kv_cache_func_(kv_cache_); } - CHECK_EQ(start_pos, total_length); - return embeddings_.CreateView({1, total_length, hidden_size}, dtype); } + private: /*! \brief Load model configuration from JSON. */ void LoadModelConfigJSON(const std::string& config_str) { picojson::value config_json; @@ -410,20 +343,13 @@ class ModelModule : public ModuleNode { } } - /*! \brief reset the runtime states. */ - void Reset() { - // Reset the KV cache. - if (kv_cache_.defined()) { - ft_.reset_kv_cache_func_(kv_cache_); - } - } - //---------------------------- // Model configurations //---------------------------- std::string model_name_; int num_shards_ = -1; int max_window_size_ = -1; + //---------------------------- // TVM related states //---------------------------- @@ -431,9 +357,9 @@ class ModelModule : public ModuleNode { FunctionTable ft_; // Paged KV cache ObjectRef kv_cache_{nullptr}; - // runtime device + // Runtime device Device device_; - // model params + // Model parameters ObjectRef params_; // Shared NDArray NDArray input_token_ids_{nullptr}; @@ -442,11 +368,6 @@ class ModelModule : public ModuleNode { NDArray temperature_arr_{nullptr}; }; -tvm::runtime::Module CreateModelModule(TVMArgValue reload_lib, String model_path, DLDevice device) { - ObjectPtr n = make_object(reload_lib, std::move(model_path), device); - return Module(n); -} - } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 25cf92e9c6..0260de7014 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -8,9 +8,11 @@ #define MLC_LLM_SERVE_MODEL_H_ #include -#include +#include #include "../base.h" +#include "config.h" +#include "function_table.h" namespace mlc { namespace llm { @@ -20,15 +22,114 @@ using tvm::Device; using namespace tvm::runtime; /*! - * \brief Create the runtime module for LLM functions. - * \param reload_lib The model library. It might be a path to the binary - * file or an executable module that is pre-loaded. - * \param model_path The path to the model weight parameters. - * \param device The device to run the model on. - * \return The created runtime module. + * \brief The model module for LLM functions. 
+ * It runs an LLM, and has an internal KV cache that maintains + * the history KV values of all processed tokens. + * + * It contains the following functions: + * + * Model related: + * - "token_embed": take token ids as input and return the embeddings, + * - "batch_prefill": take embedding of a single sequence + * as input, forward the embedding through LLM and return the logits, + * - "decode": take the embeddings of the last-committed token of an + * entire batch as input, forward through LLM and return the logits + * for all sequences in the batch, + * - "softmax_with_temperature": take logits and temperatures, return + * probabilities. + * + * KV cache related: + * - "create_kv_cache": create the KV cache for this module, + * - "add_new_sequence": add (declare) a new sequence in the KV cache, + * - "remove_sequence": remove a sequence from KV cache. + * + * ... and some other auxiliary functions. */ -MLC_LLM_DLL tvm::runtime::Module CreateModelModule(TVMArgValue reload_lib, String model_path, - DLDevice device); +class ModelObj : public Object { + public: + /*********************** Model Computation ***********************/ + + /*! + * \brief Compute embeddings for the input token ids. + * \param token_ids The token ids to compute embedding for. + * \return The computed embeddings. + */ + virtual NDArray TokenEmbed(IntTuple batch_token_ids) = 0; + + /*! + * \brief Batch prefill function. Embedding in, logits out. + * \param embeddings The embedding of the input to be prefilled. + * \param seq_id The id of the sequence in the KV cache. + * \param lengths The length of each sequence to prefill. + * \return The logits for the next token. + */ + virtual NDArray BatchPrefill(Array embedding_arr, std::vector seq_ids, + std::vector) = 0; + + /*! + * \brief Batch decode function. Embedding in, logits out. + * \param embeddings The embedding of last generated token in the entire batch. + * \return The logits for the next token for each sequence in the batch. + * \note The function runs for **every** sequence in the batch. + * That is to say, it does not accept "running a decode step for a subset + * of the full batch". + */ + virtual NDArray BatchDecode(NDArray embeddings) = 0; + + /*! + * \brief Computing probabilities from logits with softmax and temperatures. + * \param logits The logits to compute from. + * \param generation_cfg The generation config which contains the temperatures. + * \return The computed probabilities distribution. + */ + virtual NDArray SoftmaxWithTemperature(NDArray logits, + Array generation_cfg) = 0; + + /*********************** KV Cache Management ***********************/ + + /*! + * \brief Create the KV cache inside the model with regard to the input config. + * \param kv_cache_config The configuration of KV cache. + */ + virtual void CreateKVCache(KVCacheConfig kv_cache_config) = 0; + + /*! \brief Add a new sequence to the KV cache. Return the in-cache id of the sequence. */ + virtual int AddNewSequence() = 0; + + /*! \brief Remove the given sequence from the KV cache in the model. */ + virtual void RemoveSequence(int seq_id) = 0; + + /*! \brief Get the number of available pages in KV cache. */ + virtual int GetNumAvailablePages() const = 0; + + /*********************** Utilities ***********************/ + + /*! \brief Get the max window size of the model. */ + virtual int GetMaxWindowSize() const = 0; + + /*! \brief Reset the model KV cache and other statistics. 
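// [Editor's sketch, not part of the interface] A plausible life cycle against
// ModelObj as declared above. The real scheduling lives in engine.cc and is
// more involved, and every name that is not a ModelObj method is a
// placeholder, so treat this purely as orientation:
//
//   model->CreateKVCache(kv_cache_config);            // once per (re)load
//   int seq_id = model->AddNewSequence();             // per admitted request
//
//   NDArray emb    = model->TokenEmbed(token_ids);    // IntTuple in, (1, n, h) out
//   NDArray logits = model->BatchPrefill({emb}, /*seq_ids=*/{seq_id},
//                                        /*lengths=*/{static_cast<int>(token_ids.size())});
//   NDArray probs  = model->SoftmaxWithTemperature(logits, {generation_cfg});
//
//   // Later, one decode step is taken for the whole running batch at once:
//   NDArray step_logits = model->BatchDecode(last_token_embeddings);  // (b, 1, h) in
//
//   if (model->GetNumAvailablePages() == 0) { /* stop admitting / preempt */ }
//   model->RemoveSequence(seq_id);                    // on finish or abort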
*/ + virtual void Reset() = 0; + + static constexpr const char* _type_key = "mlc.serve.Model"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_BASE_OBJECT_INFO(ModelObj, Object); +}; + +class Model : public ObjectRef { + public: + /*! + * \brief Create the runtime module for LLM functions. + * \param reload_lib The model library. It might be a path to the binary + * file or an executable module that is pre-loaded. + * \param model_path The path to the model weight parameters. + * \param device The device to run the model on. + * \return The created runtime module. + */ + TVM_DLL static Model Create(TVMArgValue reload_lib, String model_path, DLDevice device); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Model, ObjectRef, ModelObj); +}; } // namespace serve } // namespace llm diff --git a/cpp/serve/request.cc b/cpp/serve/request.cc index d5b3f50cff..af5f54b489 100644 --- a/cpp/serve/request.cc +++ b/cpp/serve/request.cc @@ -18,18 +18,59 @@ namespace serve { TVM_REGISTER_OBJECT_TYPE(RequestNode); -Request::Request(Array inputs, GenerationConfig generation_cfg, PackedFunc fcallback) { +Request::Request(String id, Array inputs, GenerationConfig generation_cfg, + PackedFunc fcallback) { + CHECK(!inputs.empty()) << "No input data is given."; + // Compute the total input length, or fall back to "-1" which means + // unknown due to the existence of untokenized data. + int input_total_length = 0; + for (Data input : inputs) { + if (const auto* token_data = input.as()) { + input_total_length += token_data->token_ids.size(); + } else { + input_total_length = -1; + break; + } + } + ObjectPtr n = make_object(); + n->id = std::move(id); n->inputs = std::move(inputs); + n->input_total_length = input_total_length; n->generation_cfg = std::move(generation_cfg); n->fcallback = std::move(fcallback); data_ = std::move(n); } +Request Request::FromUntokenized(Request request, const std::unique_ptr& tokenizer) { + bool has_untokenized_input = false; + Array inputs; + inputs.reserve(request->inputs.size()); + // Tokenize all text inputs. + for (Data input : request->inputs) { + if (const auto* text_data = input.as()) { + has_untokenized_input = true; + std::vector token_ids = tokenizer->Encode(text_data->text); + inputs.push_back(TokenData(token_ids)); + } else { + inputs.push_back(input); + } + } + + // If there is no untokenized input, we don't need to create a new request. 
+ if (!has_untokenized_input) { + ICHECK_NE(request->input_total_length, -1); + return request; + } else { + return Request(request->id, std::move(inputs), request->generation_cfg, request->fcallback); + } +} + TVM_REGISTER_GLOBAL("mlc.serve.Request") - .set_body_typed([](Array inputs, String generation_cfg_json, PackedFunc fcallback) { - return Request(std::move(inputs), GenerationConfig(std::move(generation_cfg_json)), - std::move(fcallback)); + .set_body_typed([](String id, Array inputs, String generation_cfg_json, + PackedFunc fcallback) { + return Request(std::move(id), std::move(inputs), + GenerationConfig(std::move(generation_cfg_json)), std::move(fcallback)); }); TVM_REGISTER_GLOBAL("mlc.serve.RequestGetInputs").set_body_typed([](Request request) { diff --git a/cpp/serve/request.h b/cpp/serve/request.h index 6e169a9afc..4bbda345cd 100644 --- a/cpp/serve/request.h +++ b/cpp/serve/request.h @@ -11,6 +11,7 @@ #include #include +#include "../tokenizers.h" #include "config.h" #include "data.h" @@ -31,11 +32,22 @@ using namespace tvm::runtime; */ class RequestNode : public Object { public: + /*! + * \brief The unique identifier of the request. + * Different requests should have different ids. + */ + String id; /*! * \brief The user inputs of a request. Input may have multi-modality. * \sa data.h */ Array inputs; + /*! + * \brief The equivalent total input sequence length of the request. + * "-1" means the total input length is unknown due to the existence + * of untokenized text data. + */ + int input_total_length = -1; /*! * \brief The sampling configuration which may contain temperature, * top_p, repetition_penalty, max_gen_len, etc. @@ -56,7 +68,17 @@ class RequestNode : public Object { class Request : public ObjectRef { public: - explicit Request(Array inputs, GenerationConfig generation_cfg, PackedFunc fcallback); + explicit Request(String id, Array inputs, GenerationConfig generation_cfg, + PackedFunc fcallback); + + /*! + * \brief Return a request object with all text data tokenized, + * and the request ID kept the same as the input one. + * \param request The request to be tokenized. + * \param tokenizer The tokenizer to tokenize the input data of the given request. + * \return The request object whose data are tokenized. 
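// [Editor's sketch] How the two-step request construction above is meant to
// be used, assuming the unique_ptr's element type is the Tokenizer class
// declared in ../tokenizers.h (other names are placeholders):
//
//   Request request("req-0", {TextData("What is the meaning of life?")},
//                   GenerationConfig(generation_cfg_json), fcallback);
//   // Text input present, so request->input_total_length == -1 for now.
//
//   auto tokenizer = TokenizerFromPath(model_path);
//   request = Request::FromUntokenized(request, tokenizer);
//   // Every TextData is now TokenData, input_total_length is the real token
//   // count, and the request id "req-0" is preserved.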
+ */ + static Request FromUntokenized(Request request, const std::unique_ptr& tokenizer); TVM_DEFINE_OBJECT_REF_METHODS(Request, ObjectRef, RequestNode); }; diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index 38b5a7bf66..e1d5000081 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -5,8 +5,6 @@ #include "request_state.h" -#include "data.h" - namespace mlc { namespace llm { namespace serve { @@ -23,21 +21,62 @@ RequestModelState::RequestModelState(int model_id, Array inputs) { data_ = std::move(n); } +int RequestModelStateNode::GetInputLength() const { + int total_length = 0; + for (Data input : inputs) { + total_length += input->GetLength(); + } + return total_length; +} + TVM_REGISTER_OBJECT_TYPE(RequestStateNode); -RequestState::RequestState(int num_models, Array inputs, int raw_input_length) { +RequestState::RequestState(Request request, int num_models) { ObjectPtr n = make_object(); Array mstates; mstates.reserve(num_models); for (int i = 0; i < num_models; ++i) { - mstates.push_back(RequestModelState(i, inputs)); + mstates.push_back(RequestModelState(i, request->inputs)); } + n->request = std::move(request); n->mstates = std::move(mstates); - n->raw_input_length = raw_input_length; n->tadd = std::chrono::high_resolution_clock::now(); data_ = std::move(n); } +bool RequestStateNode::GenerationFinished(int max_single_sequence_length) const { + // - Case 0. There is remaining draft output ==> Unfinished + // All draft outputs are supposed to be processed before finish. + for (RequestModelState mstate : mstates) { + if (!mstate->draft_output_tokens.empty()) { + return false; + } + } + + // - Decode committed tokens. + const std::vector& committed_tokens = mstates[0]->committed_tokens; + + // Case 1. Any of the stop strings appears in output ==> Finished + // Todo: handle stop_str by tokenizing. So that we don't detokenize during check + + // Case 2. Any of the stop tokens appears in the committed tokens ===> Finished + if (std::any_of( + request->generation_cfg->stop_tokens.begin(), request->generation_cfg->stop_tokens.end(), + [&committed_tokens](int32_t token) { return token == committed_tokens.back(); })) { + return true; + } + // Case 3. Generation reaches the specified max generation length ==> Finished + if (static_cast(committed_tokens.size()) >= request->generation_cfg->max_new_tokens) { + return true; + } + // Case 4. Total length of the request reaches the maximum single sequence length ==> Finished + if (request->input_total_length + static_cast(committed_tokens.size()) >= + max_single_sequence_length) { + return true; + } + return false; +} + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h index 16f39548c2..71a3966eaa 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -11,7 +11,7 @@ #include #include "config.h" -#include "data.h" +#include "request.h" namespace mlc { namespace llm { @@ -30,7 +30,7 @@ using namespace tvm::runtime; class RequestModelStateNode : public Object { public: /*! - * \brief The corresponding request id of this state. + * \brief The internal request id of this state. * It is the **physical index** of the request in the running request queue. * If the request is on hold (not in the running queue), the request id * should be -1. @@ -73,6 +73,9 @@ class RequestModelStateNode : public Object { */ std::vector draft_output_token_prob; + /*! \brief Return the total length of the input data. 
*/ + int GetInputLength() const; + static constexpr const char* _type_key = "mlc.serve.RequestModelState"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; @@ -88,22 +91,25 @@ class RequestModelState : public ObjectRef { class RequestStateNode : public Object { public: + /*! \brief The request that this state corresponds to. */ + Request request; /*! * \brief The state with regard to each model. * \sa RequestModelState */ Array mstates; - /*! \brief The summed up input length of the request. */ - int raw_input_length = 0; - /*! \brief The decoded text string output. */ - std::string output = ""; - /*! \brief The time of adding the request to engine. */ std::chrono::_V2::system_clock::time_point tadd; /*! \brief The time of finishing prefill stage. */ std::chrono::_V2::system_clock::time_point tprefill_finish; + /*! + * \brief Check if the request generation is finished. + * \param max_single_sequence_length The maximum allowed single sequence length. + */ + bool GenerationFinished(int max_single_sequence_length) const; + static constexpr const char* _type_key = "mlc.serve.RequestState"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; @@ -112,7 +118,7 @@ class RequestStateNode : public Object { class RequestState : public ObjectRef { public: - explicit RequestState(int num_models, Array inputs, int raw_input_length); + explicit RequestState(Request request, int num_models); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestState, ObjectRef, RequestStateNode); }; diff --git a/cpp/serve/sampler.cc b/cpp/serve/sampler.cc index 7012b22869..4a911c32fc 100644 --- a/cpp/serve/sampler.cc +++ b/cpp/serve/sampler.cc @@ -7,269 +7,129 @@ #include "sampler.h" -#include -#include #include #include #include +#include #include #include "../random.h" -#include "request_state.h" namespace mlc { namespace llm { namespace serve { -int SampleTopPFromProb(NDArray prob, int unit_offset, double top_p, double uniform_sample); +/***** Utility function for in-place logits/prob update on CPU *****/ /*! - * \brief Execute the given lambda function in parallel with - * threading backend in TVM. - * \tparam T The type of the lambda: "void (int i)". - * \param flambda The lambda to be executed in parallel. - * It should have the signature "void (int i)". - * \param begin The start index of this parallel loop (inclusive). - * \param end The end index of this parallel loop (exclusive). - * \example - * - * The for loop - * for (int i = 0; i < 10; i++) { - * a[i] = i; - * } - * should work the same as: - * parallel_for_with_threading_backend([&a](int i) { - * a[i] = i; - * }, 0, 10); + * \brief In-place apply repetition penalty to logits based on history tokens. + * \param logits The logits (a batch) to be in-place mutated. + * \param token_offset The offset of the token in the batch + * whose logits will be updated. + * \param state The request state that contains history tokens. + * \param repetition_penalty The value of repetition penalty. */ -template -inline void parallel_for_with_threading_backend(T flambda, int64_t begin, int64_t end); - -/*! - * \brief The sampler runtime module. - * It contains functions to - * - compute probability distribution out from logits, - * - sample token from probability distribution. 
- */ -class SamplerModule : public ModuleNode { - public: - explicit SamplerModule(DLDevice device) : device_(device), rng_(RandomGenerator::GetInstance()) { - // Set customized "logits -> prob" function. - const PackedFunc* f_logits_to_probs = - Registry::Get("mlc.llm.compute_probs_from_logits_inplace"); - if (f_logits_to_probs != nullptr) { - flogits_to_probs_inplace_ = *f_logits_to_probs; - } - } - - // overrides - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { - if (name == "compute_probs_from_logits_inplace") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 3); - ComputeProbsFromLogitsInplace(args[0], args[1], args[2]); - }); - } else if (name == "sample_token_from_probs") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 2); - *rv = SampleTokenFromProbs(args[0], args[1]); - }); - } else if (name == "require_gpu_softmax") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 1); - *rv = RequireGPUSoftmax(args[0]); - }); +void ApplyRepetitionPenaltyOnCPU(NDArray logits, int token_offset, RequestModelState state, + double repetition_penalty) { + // logits: (n, v) + CHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; + CHECK_EQ(logits->ndim, 2); + CHECK_EQ(logits->device.device_type, DLDeviceType::kDLCPU); + int vocab_size = logits->shape[1]; + + // Collect appeared tokens. + std::unordered_set appeared_token_ids; + appeared_token_ids.insert(state->committed_tokens.begin(), state->committed_tokens.end()); + appeared_token_ids.insert(state->draft_output_tokens.begin(), state->draft_output_tokens.end()); + + float* logits_raw_data = static_cast(logits->data) + (token_offset * vocab_size); + for (int32_t token_id : appeared_token_ids) { + ICHECK_GE(token_id, 0); + ICHECK_LT(token_id, vocab_size); + if (logits_raw_data[token_id] <= 0) { + logits_raw_data[token_id] *= repetition_penalty; } else { - return PackedFunc(nullptr); - } - } - - const char* type_key() const final { return "mlc.serve.Sampler"; } - - private: - /*! - * \brief Given the generation config of a batch, check if the - * probability distributions needs to be computed on device via softmax. - * \param generation_cfg The input generation config. - * \return A boolean flag indicating if the check result. - */ - bool RequireGPUSoftmax(Array generation_cfg) { - // - Return false if there is customized probability compute function. - if (flogits_to_probs_inplace_.defined()) { - return false; + logits_raw_data[token_id] /= repetition_penalty; } - // - Return false if any sampling param has repetition penalty other than 1.0. - // - Return false if any sampling param has zero temperature. - for (GenerationConfig cfg : generation_cfg) { - if (cfg->repetition_penalty != 1.0 || cfg->temperature < 1e-6) { - return false; - } - } - return true; - } - - /*! - * \brief Compute the probability distribution from on-cpu logits for - * a batch of tokens **in place**. - * \param logits The input logits on CPU. - * \param states The request states, which contains the history generated tokens. - * \param generation_cfg The generation config. - * \note The function returns nothing. It in-place updates the input logits array. 
- */ - void ComputeProbsFromLogitsInplace(NDArray logits, Array states, - Array generation_cfg) { - // logits: (n, v) - CHECK_EQ(logits->ndim, 2); - CHECK_EQ(logits->device.device_type, kDLCPU); - - // - Invoke environment compute function if exists. - if (flogits_to_probs_inplace_.defined()) { - flogits_to_probs_inplace_(logits, states, generation_cfg); - return; - } - - parallel_for_with_threading_backend( - [this, &logits, &states, &generation_cfg](int i) { - // - Apply repetition penalty (inplace). - if (generation_cfg[i]->repetition_penalty != 1.0) { - ApplyRepetitionPenaltyOnCPU(logits, i, states[i], - generation_cfg[i]->repetition_penalty); - } - // - Compute probability (inplace) from logits. - // Using softmax if temperature is non-zero. - // Or set probability of the max-logit position to 1. - if (generation_cfg[i]->temperature >= 1e-6) { - ApplySoftmaxWithTemperatureOnCPU(logits, i, generation_cfg[i]->temperature); - } else { - SetProbWithArgmaxOnCPU(logits, i); - } - }, - 0, logits->shape[0]); } +} - /*! - * \brief Sample tokens from a batch of input probability distributions. - * \param probs The input batch of probability distributions. - * \param generation_cfg The generation config. - * \return The sampled tokens, one for each instance of the batch. - */ - ShapeTuple SampleTokenFromProbs(NDArray probs, Array generation_cfg) { - // probs: (n, v) - CHECK_EQ(probs->ndim, 2); - CHECK_EQ(probs->device.device_type, kDLCPU); - - int n = probs->shape[0]; - std::vector random_numbers; - std::vector sampled_tokens; - random_numbers.reserve(n); - sampled_tokens.resize(n); - for (int i = 0; i < n; ++i) { - random_numbers.push_back(rng_.GetRandomNumber()); - } - - parallel_for_with_threading_backend( - [&sampled_tokens, &probs, &generation_cfg, &random_numbers](int i) { - // Sample top p from probability. - sampled_tokens[i] = - SampleTopPFromProb(probs, i, generation_cfg[i]->top_p, random_numbers[i]); - }, - 0, n); - return ShapeTuple(sampled_tokens.begin(), sampled_tokens.end()); +/*! + * \brief In-place compute softmax with temperature on CPU. + * \param logits The logits (a batch) to compute softmax from. + * \param token_offset The offset of the token in the batch + * to compute softmax for. Only the logits of the specified + * token will be updated to probability after softmax. + * \param temperature The temperature to apply before softmax. + */ +void ApplySoftmaxWithTemperatureOnCPU(NDArray logits, int token_offset, double temperature) { + // logits: (n, v) + CHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; + CHECK_EQ(logits->ndim, 2); + CHECK_EQ(logits->device.device_type, DLDeviceType::kDLCPU); + int vocab_size = logits->shape[1]; + + float* __restrict logits_raw_data = + static_cast(__builtin_assume_aligned(logits->data, 4)) + (token_offset * vocab_size); + float m = std::numeric_limits::min(); + float inv_temp = 1.0f / temperature; + double d = 0.0f; + for (int i = 0; i < vocab_size; ++i) { + float x = logits_raw_data[i] * inv_temp; + float m_prev = m; + m = std::max(m, x); + d = d * std::exp(m_prev - m) + std::exp(x - m); } - - /*! \brief Apply repetition penalty to logits based on history tokens. 
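// [Editor's note] The new ApplyRepetitionPenaltyOnCPU keeps the widely used
// convention (as in the CTRL paper): for every token id already seen
// (committed tokens plus pending draft tokens),
//
//   logit <= 0 :  logit *= repetition_penalty   (pushed further negative)
//   logit  > 0 :  logit /= repetition_penalty   (shrunk toward zero)
//
// e.g. with repetition_penalty = 1.3, a seen token at logit 2.0 drops to
// about 1.54 and one at -1.0 drops to -1.3, so both lose probability after
// softmax. A penalty of exactly 1.0 is a no-op, which is why the CPU
// logits-to-probs path later in this file skips the call in that case.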
*/ - void ApplyRepetitionPenaltyOnCPU(NDArray logits, int token_offset, RequestModelState state, - double repetition_penalty) { - // logits: (n, v) - CHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - CHECK_EQ(logits->ndim, 2); - CHECK_EQ(logits->device.device_type, DLDeviceType::kDLCPU); - int vocab_size = logits->shape[1]; - - // Collect appeared tokens. - std::unordered_set appeared_token_ids; - appeared_token_ids.insert(state->committed_tokens.begin(), state->committed_tokens.end()); - appeared_token_ids.insert(state->draft_output_tokens.begin(), state->draft_output_tokens.end()); - - float* logits_raw_data = static_cast(logits->data) + (token_offset * vocab_size); - for (int32_t token_id : appeared_token_ids) { - ICHECK_GE(token_id, 0); - ICHECK_LT(token_id, vocab_size); - if (logits_raw_data[token_id] <= 0) { - logits_raw_data[token_id] *= repetition_penalty; - } else { - logits_raw_data[token_id] /= repetition_penalty; - } - } + for (int i = 0; i < vocab_size; ++i) { + float x = logits_raw_data[i] * inv_temp; + logits_raw_data[i] = std::exp(x - m) / d; } +} - /*! \brief Compute softmax with temperature on CPU. */ - void ApplySoftmaxWithTemperatureOnCPU(NDArray logits, int token_offset, double temperature) { - // logits: (n, v) - CHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - CHECK_EQ(logits->ndim, 2); - CHECK_EQ(logits->device.device_type, DLDeviceType::kDLCPU); - int vocab_size = logits->shape[1]; - - float* __restrict logits_raw_data = - static_cast(__builtin_assume_aligned(logits->data, 4)) + - (token_offset * vocab_size); - float m = std::numeric_limits::min(); - float inv_temp = 1.0f / temperature; - double d = 0.0f; - for (int i = 0; i < vocab_size; ++i) { - float x = logits_raw_data[i] * inv_temp; - float m_prev = m; - m = std::max(m, x); - d = d * std::exp(m_prev - m) + std::exp(x - m); - } - for (int i = 0; i < vocab_size; ++i) { - float x = logits_raw_data[i] * inv_temp; - logits_raw_data[i] = std::exp(x - m) / d; +/*! + * \brief In-place set probability via argmax. + * This is used for zero-temperature sampling cases. + * \param logits The logits (a batch) to set probability. + * \param token_offset The offset of the token in the batch + * to set probability for. Only the logits of the specified + * token will be updated to probability. + */ +void SetProbWithArgmaxOnCPU(NDArray logits, int token_offset) { + // logits: (n, v) + CHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; + CHECK_EQ(logits->ndim, 2); + CHECK_EQ(logits->device.device_type, kDLCPU); + int vocab_size = logits->shape[1]; + + float* logits_raw_data = static_cast(logits->data) + (token_offset * vocab_size); + int argmax_pos = -1; + float max_logits = std::numeric_limits::lowest(); + for (int i = 0; i < vocab_size; ++i) { + if (logits_raw_data[i] > max_logits) { + max_logits = logits_raw_data[i]; + argmax_pos = i; } } - /*! - * \brief Inplace set probability via argmax. 
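// [Editor's note] ApplySoftmaxWithTemperatureOnCPU above uses the streaming
// ("online") softmax recurrence, so the scaled logits x_i = logit_i / T are
// read in only two passes and never exponentiated without a stabilizing max:
//
//   m_k = max(m_{k-1}, x_k)
//   d_k = d_{k-1} * exp(m_{k-1} - m_k) + exp(x_k - m_k)
//   p_i = exp(x_i - m_n) / d_n
//
// which is algebraically the ordinary softmax of the temperature-scaled
// logits. Tiny check with T = 1 and logits {1, 3, 2}:
//   after the first pass   m = 3, d = exp(-2) + 1 + exp(-1) ~= 1.503
//   after the second pass  p ~= {0.090, 0.665, 0.245}, summing to 1,
// matching exp(x_i) / sum_j exp(x_j). SetProbWithArgmaxOnCPU is the
// zero-temperature limit of the same step: all mass on the arg-max position.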
- * This is used for zero-temperature sampling cases - */ - void SetProbWithArgmaxOnCPU(NDArray logits, int token_offset) { - // logits: (n, v) - CHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - CHECK_EQ(logits->ndim, 2); - CHECK_EQ(logits->device.device_type, kDLCPU); - int vocab_size = logits->shape[1]; - - float* logits_raw_data = static_cast(logits->data) + (token_offset * vocab_size); - int argmax_pos = -1; - float max_logits = std::numeric_limits::min(); - for (int i = 0; i < vocab_size; ++i) { - if (logits_raw_data[i] > max_logits) { - max_logits = logits_raw_data[i]; - argmax_pos = i; - } - } - - ICHECK_NE(argmax_pos, -1); - for (int i = 0; i < vocab_size; ++i) { - logits_raw_data[i] = i == argmax_pos ? 1.0f : 0.0f; - } + ICHECK_NE(argmax_pos, -1); + for (int i = 0; i < vocab_size; ++i) { + logits_raw_data[i] = i == argmax_pos ? 1.0f : 0.0f; } - - /*! \brief The runtime device where the input logits is. */ - DLDevice device_; - /*! \brief The random generator. */ - RandomGenerator& rng_; - /*! \brief Customized function which computes prob distribution from logits */ - PackedFunc flogits_to_probs_inplace_; -}; - -tvm::runtime::Module CreateSamplerModule(DLDevice device) { - ObjectPtr n = make_object(device); - return Module(n); } +/*! + * \brief Sample a value from the input probability distribution with top-p. + * The input is a batch of distributions, and we use `unit_offset` to specify + * which distribution to sample from. + * \param prob The input batch of probability distributions. + * \param unit_offset The offset specifying which distribution to sample from. + * \param top_p The top-p value of sampling. + * \param uniform_sample The random number in [0, 1] for sampling. + * \return The sampled value. + * \note This function is an enhancement of SampleTopPFromProb in TVM Unity. + * We will upstream the enhancement after it gets stable. + */ int SampleTopPFromProb(NDArray prob, int unit_offset, double top_p, double uniform_sample) { // prob: (*, v) // The prob array may have arbitrary ndim and shape. @@ -380,6 +240,197 @@ int SampleTopPFromProb(NDArray prob, int unit_offset, double top_p, double unifo return sampled_index; } +/*! + * \brief Copy logits or prob distributions from device to CPU. + * The input array is in layout (b, n, v). + * This function flattens the first dimension, returns an NDArray + * in shape (b * n, v). + */ +NDArray CopyLogitsOrProbsToCPU(NDArray arr_on_device, NDArray* arr_on_cpu) { + // arr_on_device: (b, n, v) + ICHECK_EQ(arr_on_device->ndim, 3); + ICHECK(!arr_on_cpu->defined() || (*arr_on_cpu)->ndim == 2); + ICHECK(arr_on_device->device.device_type != kDLCPU); + if (arr_on_cpu->defined()) { + ICHECK_EQ((*arr_on_cpu)->shape[1], arr_on_device->shape[2]); + } + + int64_t init_size = arr_on_cpu->defined() ? 
(*arr_on_cpu)->shape[0] : 32; + int64_t num_tokens = arr_on_device->shape[0] * arr_on_device->shape[1]; + int64_t vocab_size = arr_on_device->shape[2]; + while (init_size < num_tokens) { + init_size *= 2; + } + if (!arr_on_cpu->defined() || init_size != (*arr_on_cpu)->shape[0]) { + (*arr_on_cpu) = + NDArray::Empty({init_size, vocab_size}, arr_on_device->dtype, DLDevice{kDLCPU, 0}); + } + ICHECK_LE(num_tokens, (*arr_on_cpu)->shape[0]); + NDArray view = arr_on_cpu->CreateView({num_tokens, vocab_size}, arr_on_device->dtype); + view.CopyFrom(arr_on_device); + return view; +} + +/********************* CPU Sampler *********************/ + +class CPUSampler : public SamplerObj { + public: + explicit CPUSampler() : rng_(RandomGenerator::GetInstance()) { + // Set customized "logits -> prob" function. + const PackedFunc* f_logits_to_probs = + Registry::Get("mlc.llm.compute_probs_from_logits_inplace"); + if (f_logits_to_probs != nullptr) { + flogits_to_probs_inplace_ = *f_logits_to_probs; + } + } + + std::vector SampleTokens(NDArray logits_on_device, Model model, + Array request_mstates, + Array generation_cfg) final { + ICHECK(logits_on_device.defined()); + ICHECK_EQ(logits_on_device->ndim, 3); + ICHECK_EQ(logits_on_device->shape[1], 1) + << "Multi-token sampling for one sequence is not supported yet."; + ICHECK_EQ(logits_on_device->shape[0], generation_cfg.size()); + ICHECK_EQ(request_mstates.size(), generation_cfg.size()); + + int num_sequence = logits_on_device->shape[0]; + bool require_gpu_softmax = RequireGPUSoftmax(generation_cfg); + + // - Compute probabilities from logits. + NDArray logits_or_probs_on_cpu{nullptr}; + if (require_gpu_softmax) { + NDArray probs_on_device = model->SoftmaxWithTemperature(logits_on_device, generation_cfg); + logits_or_probs_on_cpu = CopyLogitsOrProbsToCPU(probs_on_device, &logits_or_probs_on_cpu_); + } else { + logits_or_probs_on_cpu = CopyLogitsOrProbsToCPU(logits_on_device, &logits_or_probs_on_cpu_); + // The "ComputeProbsFromLogitsInplace" function updates + // `logits_or_probs_on_cpu` in place. + ComputeProbsFromLogitsInplace(logits_or_probs_on_cpu, std::move(request_mstates), + generation_cfg); + } + // `CopyLogitsOrProbsToCPU` flattens the first two dimensions. + ICHECK_EQ(logits_or_probs_on_cpu->ndim, 2); + + // - Sample tokens from probabilities. + // NOTE: Though we have the probability field in RequestModelState, + // we do not save the probabilities right now. + // We will handle this in the future when we work on speculation. + return SampleTokenFromProbs(logits_or_probs_on_cpu, generation_cfg); + } + + private: + /*! + * \brief Given the generation config of a batch, check if the + * probability distributions needs to be computed on device via softmax. + * \param generation_cfg The input generation config. + * \return A boolean flag indicating if the check result. + */ + bool RequireGPUSoftmax(Array generation_cfg) { + // - Return false if there is customized probability compute function. + if (flogits_to_probs_inplace_.defined()) { + return false; + } + // - Return false if any sampling param has repetition penalty other than 1.0. + // - Return false if any sampling param has zero temperature. + for (GenerationConfig cfg : generation_cfg) { + if (cfg->repetition_penalty != 1.0 || cfg->temperature < 1e-6) { + return false; + } + } + return true; + } + + /*! + * \brief Compute the probability distribution from on-cpu logits for + * a batch of tokens **in place**. + * \param logits The input logits on CPU. 
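// [Editor's sketch] The body of SampleTopPFromProb is unchanged context in
// this diff, so only its contract is visible here. For orientation, the
// reference behaviour of top-p ("nucleus") sampling that such a routine
// implements, written naively with a full sort that the real kernel
// presumably avoids, is:
//
//   // assumes <vector>, <numeric>, <algorithm>; prob is one row of (n, v)
//   int SampleTopPReference(const std::vector<float>& prob, double top_p, double u) {
//     std::vector<int> order(prob.size());
//     std::iota(order.begin(), order.end(), 0);
//     std::sort(order.begin(), order.end(),
//               [&](int a, int b) { return prob[a] > prob[b]; });
//     double mass = 0.0;
//     size_t k = 0;
//     while (k < order.size()) {            // smallest prefix with mass >= top_p
//       mass += prob[order[k++]];
//       if (mass >= top_p) break;
//     }
//     double r = u * mass, acc = 0.0;       // renormalize, then invert the CDF
//     for (size_t i = 0; i < k; ++i) {
//       acc += prob[order[i]];
//       if (acc >= r) return order[i];
//     }
//     return order[k - 1];
//   }
//
// CPUSampler::SampleTokens above decides per batch whether the softmax
// already ran on device (RequireGPUSoftmax) or whether raw logits are copied
// to CPU and converted to probabilities in place before this sampling step.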
+ * \param states The request states, which contains the history generated tokens. + * \param generation_cfg The generation config. + * \note The function returns nothing. It in-place updates the input logits array. + */ + void ComputeProbsFromLogitsInplace(NDArray logits, Array states, + Array generation_cfg) { + // logits: (n, v) + CHECK_EQ(logits->ndim, 2); + CHECK_EQ(logits->device.device_type, kDLCPU); + + // - Invoke environment compute function if exists. + if (flogits_to_probs_inplace_.defined()) { + flogits_to_probs_inplace_(logits, states, generation_cfg); + return; + } + + tvm::runtime::parallel_for_with_threading_backend( + [this, &logits, &states, &generation_cfg](int i) { + // - Apply repetition penalty (inplace). + if (generation_cfg[i]->repetition_penalty != 1.0) { + ApplyRepetitionPenaltyOnCPU(logits, i, states[i], + generation_cfg[i]->repetition_penalty); + } + // - Compute probability (inplace) from logits. + // Using softmax if temperature is non-zero. + // Or set probability of the max-logit position to 1. + if (generation_cfg[i]->temperature >= 1e-6) { + ApplySoftmaxWithTemperatureOnCPU(logits, i, generation_cfg[i]->temperature); + } else { + SetProbWithArgmaxOnCPU(logits, i); + } + }, + 0, logits->shape[0]); + } + + /*! + * \brief Sample tokens from a batch of input probability distributions. + * \param probs The input batch of probability distributions. + * \param generation_cfg The generation config. + * \return The sampled tokens, one for each instance of the batch. + */ + std::vector SampleTokenFromProbs(NDArray probs, Array generation_cfg) { + // probs: (n, v) + CHECK_EQ(probs->ndim, 2); + CHECK_EQ(probs->device.device_type, kDLCPU); + + int n = probs->shape[0]; + std::vector random_numbers; + std::vector sampled_tokens; + random_numbers.reserve(n); + sampled_tokens.resize(n); + for (int i = 0; i < n; ++i) { + random_numbers.push_back(rng_.GetRandomNumber()); + } + + tvm::runtime::parallel_for_with_threading_backend( + [&sampled_tokens, &probs, &generation_cfg, &random_numbers](int i) { + // Sample top p from probability. + sampled_tokens[i] = + SampleTopPFromProb(probs, i, generation_cfg[i]->top_p, random_numbers[i]); + }, + 0, n); + return sampled_tokens; + } + + /*! \brief The random generator. */ + RandomGenerator& rng_; + /*! \brief Customized function which computes prob distribution from logits */ + PackedFunc flogits_to_probs_inplace_; + /*! \brief Shared array for logits and probability distributions on cpu. */ + NDArray logits_or_probs_on_cpu_{nullptr}; +}; + +/*********************** Sampler ***********************/ + +TVM_REGISTER_OBJECT_TYPE(SamplerObj); + +Sampler Sampler::Create(std::string sampler_kind) { + if (sampler_kind == "cpu") { + return Sampler(make_object()); + } else { + LOG(FATAL) << "Unsupported sampler_kind \"" << sampler_kind << "\""; + throw; + } +} + namespace detail { // The detailed implementation of `parallel_for_with_threading_backend`. diff --git a/cpp/serve/sampler.h b/cpp/serve/sampler.h index 758ada92ac..45158d7c37 100644 --- a/cpp/serve/sampler.h +++ b/cpp/serve/sampler.h @@ -11,6 +11,8 @@ #include #include "../base.h" +#include "model.h" +#include "request_state.h" namespace mlc { namespace llm { @@ -20,11 +22,45 @@ using tvm::Device; using namespace tvm::runtime; /*! - * \brief Create the runtime module for sampler functions. - * \param device The device to run the sampling-related functions on. - * \return The created runtime module. + * \brief The base class of runtime sampler. 
+ * Its main function is `SampleTokens`, which takes a batch of + * logits and corresponding configuration, and sample one token + * for each instance of the batch. */ -MLC_LLM_DLL tvm::runtime::Module CreateSamplerModule(DLDevice device); +class SamplerObj : public Object { + public: + /*! + * \brief Sample tokens from the input batch of logits. + * \param logits_on_device The logits to sample tokens from. + * \param model The LLM model which contains the softmax + * function on device that might be used to compute probability distribution. + * \param request_mstates The request states of each sequence in + * the batch with regard to the given model. + * \param generation_cfg The generation config of each request + * in the input batch. + * \return The sampled tokens, one for each request in the batch. + */ + virtual std::vector SampleTokens(NDArray logits_on_device, Model model, + Array request_mstates, + Array generation_cfg) = 0; + + static constexpr const char* _type_key = "mlc.serve.Sampler"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_BASE_OBJECT_INFO(SamplerObj, Object); +}; + +class Sampler : public ObjectRef { + public: + /*! + * \brief Create the runtime sampler module. + * \param sampler_kind The sampler name denoting which sampler to create. + * \return The created runtime module. + */ + TVM_DLL static Sampler Create(std::string sampler_kind); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Sampler, ObjectRef, SamplerObj); +}; } // namespace serve } // namespace llm diff --git a/cpp/serve/tokenizer.cc b/cpp/serve/tokenizer.cc deleted file mode 100644 index a6c53d5ff6..0000000000 --- a/cpp/serve/tokenizer.cc +++ /dev/null @@ -1,82 +0,0 @@ -/*! - * Copyright (c) 2023 by Contributors - * \file serve/tokenizer.cc - * \brief The implementation of runtime module of tokenizer encode/decode functions. - */ -#define __STDC_FORMAT_MACROS - -#include "tokenizer.h" - -#include -#include -#include -#include -#include - -#include "../tokenizers.h" - -namespace mlc { -namespace llm { -namespace serve { - -/*! - * \brief The tokenizer runtime module. - * It contains the encode and decode functions of the tokenizer. - */ -class TokenizerModule : public ModuleNode { - public: - explicit TokenizerModule(String model_path) { tokenizer_ = TokenizerFromPath(model_path); } - - // overrides - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { - if (name == "tokenize") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 1); - *rv = Tokenize(args[0]); - }); - } else if (name == "decode") { - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 1); - *rv = Decode(args[0]); - }); - } else { - return PackedFunc(nullptr); - } - } - - const char* type_key() const final { return "mlc.serve.Tokenizer"; } - - private: - /*! - * \brief Encode the input text to token ids. - * \param text The input to be tokenized - * \return The tokenization result. - */ - ShapeTuple Tokenize(std::string text) { - CHECK(tokenizer_ != nullptr) << "Tokenizer is not initialized."; - std::vector token_ids = this->tokenizer_->Encode(text); - return ShapeTuple(token_ids.begin(), token_ids.end()); - } - - /*! - * \brief Decode the input token ids to text. - * \param token_ids The input token ids to decode. - * \return The decode result. 
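// [Editor's note, illustrative only] Engine-side, the new sampler object is
// obtained and used roughly as follows (the exact wiring is outside this
// hunk):
//
//   Sampler sampler = Sampler::Create("cpu");   // only "cpu" is supported so far
//   std::vector<int32_t> next_tokens =
//       sampler->SampleTokens(logits_on_device, model, request_mstates,
//                             generation_cfgs);
//
// The tokenizer.cc / tokenizer.h removal around this point is the counterpart
// of the same refactor on the tokenizer side: encode/decode now go straight
// through the Tokenizer from ../tokenizers.h (see Request::FromUntokenized
// earlier in this patch) instead of a string-dispatched runtime module.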
- */ - std::string Decode(ShapeTuple token_ids) { - CHECK(tokenizer_ != nullptr) << "Tokenizer is not initialized."; - return this->tokenizer_->Decode(std::vector(token_ids.begin(), token_ids.end())); - } - - /*! \brief The tokenizer pointer. */ - std::unique_ptr tokenizer_; -}; - -tvm::runtime::Module CreateTokenizerModule(String model_path) { - ObjectPtr n = make_object(model_path); - return Module(n); -} - -} // namespace serve -} // namespace llm -} // namespace mlc diff --git a/cpp/serve/tokenizer.h b/cpp/serve/tokenizer.h deleted file mode 100644 index 39bc1d9aca..0000000000 --- a/cpp/serve/tokenizer.h +++ /dev/null @@ -1,33 +0,0 @@ -/*! - * Copyright (c) 2023 by Contributors - * \file serve/tokenizer.h - * \brief The header for runtime module of tokenizer encode/decode functions. - */ - -#ifndef MLC_LLM_SERVE_TOKENIZER_H_ -#define MLC_LLM_SERVE_TOKENIZER_H_ - -#include -#include - -#include "../base.h" - -namespace mlc { -namespace llm { -namespace serve { - -using tvm::Device; -using namespace tvm::runtime; - -/*! - * \brief Create the runtime module for tokenizer encode/decode functions. - * \param model_path The path to the model weights which also contains the tokenizer. - * \return The created runtime module. - */ -MLC_LLM_DLL tvm::runtime::Module CreateTokenizerModule(String model_path); - -} // namespace serve -} // namespace llm -} // namespace mlc - -#endif // MLC_LLM_SERVE_TOKENIZER_H_ diff --git a/python/mlc_chat/serve/engine.py b/python/mlc_chat/serve/engine.py index f9c0ec6c99..308d2c4969 100644 --- a/python/mlc_chat/serve/engine.py +++ b/python/mlc_chat/serve/engine.py @@ -101,7 +101,6 @@ def __init__( # - Set the engine functions self._reload_func = engine["reload"] - self._unload_func = engine["unload"] self._add_request_func = engine["add_request"] self._abort_func = engine["abort"] self._step_func = engine["step"] @@ -175,6 +174,7 @@ def fcallback(request: Request, output: data.Data): # pylint: disable=unused-ar ) self.add_request( Request( + request_id=str(req_id), inputs=input_data, generation_config=generation_cfg, fcallback=callback_getter(req_id), diff --git a/python/mlc_chat/serve/request.py b/python/mlc_chat/serve/request.py index 39332c22ee..1cb4b0d344 100644 --- a/python/mlc_chat/serve/request.py +++ b/python/mlc_chat/serve/request.py @@ -17,6 +17,10 @@ class Request(Object): Parameters ---------- + request_id : str + The unique identifier of the request. + Different requests should have different ids. + inputs : List[Data] The user inputs of a request. Input may have multi-modality. 
@@ -32,6 +36,7 @@ class Request(Object): def __init__( self, + request_id: str, inputs: Union[Data, List[Data]], generation_config: GenerationConfig, fcallback: Callable[["Request", Data], None], @@ -40,6 +45,7 @@ def __init__( inputs = [inputs] self.__init_handle_by_constructor__( _ffi_api.Request, # type: ignore # pylint: disable=no-member + request_id, inputs, generation_config.asjson(), fcallback, diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py index 19f572bb77..f4bc707a9c 100644 --- a/tests/python/serve/evaluate_engine.py +++ b/tests/python/serve/evaluate_engine.py @@ -11,7 +11,7 @@ def _parse_args(): args = argparse.ArgumentParser() args.add_argument("--model-id", type=str, default="Llama-2-7b-chat-hf-q0f16") args.add_argument("--device", type=str, default="auto") - args.add_argument("--batch-size", type=int, default=128) + args.add_argument("--batch-size", type=int, default=80) args.add_argument("--page-size", type=int, default=16) args.add_argument("--max-total-seq-length", type=int, default=16000) args.add_argument("--seed", type=int, default=0) diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index b47595202d..0739055704 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -38,6 +38,7 @@ def create_requests( max_new_tokens = np.random.randint(max_new_tokens_low, max_new_tokens_high) requests.append( Request( + request_id=str(req_id), inputs=data.TextData(prompt), generation_config=GenerationConfig( temperature=temperature,