[Serving] Support RWKV for serving (#2111)
feat: support serving for rwkv
Celve authored Apr 25, 2024
1 parent 85fffee commit 71c7b3c
Showing 19 changed files with 543 additions and 112 deletions.
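In brief, the commit lets the serving engine back a model with either the paged attention KV cache or an RWKV-style RNN state. Two new knobs are threaded through EngineConfig end to end: max_history_size, the number of RNN state snapshots kept for rollback, and kv_state_kind, which selects between the two cache kinds. The sketch below constructs the extended config for an RWKV model using the constructor signature added in cpp/serve/config.h; every path and size in it is a hypothetical placeholder, not a value taken from this commit.

    // Sketch only: the signature matches cpp/serve/config.h after this commit,
    // but the paths and sizes are made-up placeholders.
    EngineConfig config(
        /*model=*/"dist/rwkv-model-q4f16_1",                  // hypothetical path
        /*model_lib_path=*/"dist/rwkv-model-q4f16_1/lib.so",  // hypothetical path
        /*additional_models=*/{}, /*additional_model_lib_paths=*/{},
        /*device=*/DLDevice{kDLCUDA, 0},
        /*kv_cache_page_size=*/16, /*max_num_sequence=*/4,
        /*max_total_sequence_length=*/4096, /*max_single_sequence_length=*/4096,
        /*prefill_chunk_size=*/512,
        /*max_history_size=*/1,                    // RNN state rollback depth
        /*kv_state_kind=*/KVStateKind::kRNNState,  // RWKV uses the RNN state path
        /*speculative_mode=*/SpeculativeMode::kDisable,
        /*spec_draft_length=*/4);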
13 changes: 8 additions & 5 deletions cpp/serve/config.cc
@@ -239,8 +239,8 @@ EngineConfig::EngineConfig(String model, String model_lib_path, Array<String> additional_models,
                            Array<String> additional_model_lib_paths, DLDevice device,
                            int kv_cache_page_size, int max_num_sequence,
                            int max_total_sequence_length, int max_single_sequence_length,
-                           int prefill_chunk_size, SpeculativeMode speculative_mode,
-                           int spec_draft_length) {
+                           int prefill_chunk_size, int max_history_size, KVStateKind kv_state_kind,
+                           SpeculativeMode speculative_mode, int spec_draft_length) {
   ObjectPtr<EngineConfigNode> n = make_object<EngineConfigNode>();
   n->model = std::move(model);
   n->model_lib_path = std::move(model_lib_path);
@@ -252,6 +252,8 @@ EngineConfig::EngineConfig(String model, String model_lib_path, Array<String> additional_models,
   n->max_total_sequence_length = max_total_sequence_length;
   n->max_single_sequence_length = max_single_sequence_length;
   n->prefill_chunk_size = prefill_chunk_size;
+  n->max_history_size = max_history_size;
+  n->kv_state_kind = kv_state_kind;
   n->spec_draft_length = spec_draft_length;
   n->speculative_mode = speculative_mode;
   data_ = std::move(n);
@@ -261,12 +263,13 @@ TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig")
     .set_body_typed([](String model, String model_lib_path, Array<String> additional_models,
                        Array<String> additional_model_lib_paths, DLDevice device,
                        int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length,
-                       int max_single_sequence_length, int prefill_chunk_size, int speculative_mode,
-                       int spec_draft_length) {
+                       int max_single_sequence_length, int prefill_chunk_size, int max_history_size,
+                       int kv_state_kind, int speculative_mode, int spec_draft_length) {
       return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models),
                           std::move(additional_model_lib_paths), device, kv_cache_page_size,
                           max_num_sequence, max_total_sequence_length, max_single_sequence_length,
-                          prefill_chunk_size, SpeculativeMode(speculative_mode), spec_draft_length);
+                          prefill_chunk_size, max_history_size, KVStateKind(kv_state_kind),
+                          SpeculativeMode(speculative_mode), spec_draft_length);
     });
 
 } // namespace serve
11 changes: 11 additions & 0 deletions cpp/serve/config.h
@@ -84,6 +84,12 @@ enum class SpeculativeMode : int {
   kEagle = 2,
 };
 
+/*! \brief The kind of cache. */
+enum KVStateKind {
+  kAttention = 0,
+  kRNNState = 1,
+};
+
 /*! \brief The configuration of engine execution config. */
 class EngineConfigNode : public Object {
  public:
@@ -121,6 +127,10 @@ class EngineConfigNode : public Object {
   int max_single_sequence_length;
   /*! \brief The maximum total sequence length in a prefill. */
   int prefill_chunk_size;
+  /*! \brief The maximum history size for RNN state. KV cache does not need this. */
+  int max_history_size;
+  /*! \brief The kind of cache. Whether it's KV cache or RNN state. */
+  KVStateKind kv_state_kind;
 
   /*************** Speculative decoding ***************/
 
@@ -143,6 +153,7 @@ class EngineConfig : public ObjectRef {
                Array<String> additional_model_lib_paths, DLDevice device,
                int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length,
                int max_single_sequence_length, int prefill_chunk_size,
+               int max_history_size, KVStateKind kv_state_kind,
                SpeculativeMode speculative_mode, int spec_draft_length);
 
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode);
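Since KVStateKind is declared as a plain unscoped enum (unlike SpeculativeMode, which is an enum class), its values convert implicitly to int. That is what the FFI registration in config.cc relies on: the kind travels across the TVM packed-function boundary as a plain int and is cast back on arrival. A two-line illustration, assuming the enum above is in scope:

    int packed = kRNNState;                 // unscoped enum converts implicitly: packed == 1
    KVStateKind kind = KVStateKind(packed); // the functional cast used in config.cc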
3 changes: 2 additions & 1 deletion cpp/serve/engine.cc
@@ -69,7 +69,8 @@ class EngineImpl : public Engine {
                         /*trace_enabled=*/trace_recorder.defined());
       model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence,
                            engine_config->max_total_sequence_length,
-                           engine_config->prefill_chunk_size);
+                           engine_config->prefill_chunk_size, engine_config->max_history_size,
+                           engine_config->kv_state_kind);
       CHECK_GE(model->GetMaxWindowSize(), engine_config->max_single_sequence_length)
           << "The window size of the model, " << model->GetMaxWindowSize()
           << ", is smaller than the pre-defined max single sequence length, "
5 changes: 5 additions & 0 deletions cpp/serve/engine_actions/new_request_prefill.cc
@@ -396,6 +396,11 @@ class NewRequestPrefillActionObj : public EngineActionObj {
                                 int num_running_rsentries) {
     ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence);
 
+    // For RNN State, it can prefill as long as it can be instantiated.
+    if (engine_config_->kv_state_kind == KVStateKind::kRNNState) {
+      return true;
+    }
+
     // No exceeding of the maximum allowed requests that can
     // run simultaneously.
     int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable
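The early return above exempts RNN-state models from the batch-capacity arithmetic that follows. An RNN state is a fixed-size per-sequence buffer rather than a pool of shared pages, so once a sequence can be instantiated at all, prefilling it does not compete for cache memory the way paged attention KV entries do.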
7 changes: 6 additions & 1 deletion cpp/serve/function_table.cc
@@ -244,7 +244,12 @@ void FunctionTable::_InitFunctions() {
   this->alloc_embedding_tensor_func_ = mod_get_func("alloc_embedding_tensor");
   this->create_kv_cache_func_ = mod_get_func("create_flashinfer_paged_kv_cache");
   if (!this->create_kv_cache_func_.defined()) {
-    this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache");
+    PackedFunc f_create_rnn_state = mod_get_func("create_rnn_state");
+    if (f_create_rnn_state.defined()) {
+      this->create_kv_cache_func_ = f_create_rnn_state;
+    } else {
+      this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache");
+    }
     ICHECK(this->create_kv_cache_func_.defined());
   }
   this->reset_kv_cache_func_ = get_global_func("vm.builtin.kv_state_clear");
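The fallback chain appears to assume that each compiled model library exports exactly one cache constructor. The table probes create_flashinfer_paged_kv_cache first, then create_rnn_state, then create_tir_paged_kv_cache, and stores whichever it finds in the single create_kv_cache_func_ slot; the model layer (see model.cc below) is then responsible for invoking that slot with the argument shape matching the cache kind.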
53 changes: 41 additions & 12 deletions cpp/serve/model.cc
@@ -13,6 +13,7 @@
 
 #include <fstream>
 
+#include "config.h"
 #include "logit_processor.h"
 
 namespace mlc {
@@ -68,6 +69,12 @@ class ModelImpl : public ModelObj {
     token_ids_storage_ = memory::Storage(
         allocator->Alloc(device_host, {prefill_chunk_size_}, DataType::Int(32)), allocator);
     this->logit_pos_arr_ = NDArray::Empty({max_num_sequence}, DataType::Int(32), device_host);
+    // Step 7. Set model type
+    if (model_config["model_type"].get<std::string>().find("rwkv") != std::string::npos) {
+      this->kind = KVStateKind::kRNNState;
+    } else {
+      this->kind = KVStateKind::kAttention;
+    }
   }
 
   /*********************** Model Computation ***********************/
@@ -739,16 +746,26 @@ class ModelImpl : public ModelObj {
   /*********************** KV Cache Management ***********************/
 
   void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length,
-                     int prefill_chunk_size) final {
-    IntTuple max_num_sequence_tuple{max_num_sequence};
-    IntTuple max_total_sequence_length_tuple{max_total_sequence_length};
-    IntTuple prefill_chunk_size_tuple{prefill_chunk_size};
-    IntTuple page_size_tuple{page_size};
-    IntTuple support_sliding_window{sliding_window_size_ != -1};
-    kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_total_sequence_length_tuple,
-                                          prefill_chunk_size_tuple, page_size_tuple,
-                                          support_sliding_window);
-    local_kv_cache_ = ft_.use_disco ? Downcast<DRef>(kv_cache_)->DebugGetFromRemote(0) : kv_cache_;
+                     int prefill_chunk_size, int max_history_size,
+                     KVStateKind kv_state_kind) final {
+    if (kv_state_kind == KVStateKind::kAttention) {
+      IntTuple max_num_sequence_tuple{max_num_sequence};
+      IntTuple max_total_sequence_length_tuple{max_total_sequence_length};
+      IntTuple prefill_chunk_size_tuple{prefill_chunk_size};
+      IntTuple page_size_tuple{page_size};
+      IntTuple support_sliding_window{sliding_window_size_ != -1};
+      kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_total_sequence_length_tuple,
+                                            prefill_chunk_size_tuple, page_size_tuple,
+                                            support_sliding_window);
+      local_kv_cache_ =
+          ft_.use_disco ? Downcast<DRef>(kv_cache_)->DebugGetFromRemote(0) : kv_cache_;
+    } else {
+      IntTuple max_num_sequence_tuple{max_num_sequence};
+      IntTuple max_history_size_tuple = {std::max(max_history_size, 1)};
+      kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_history_size_tuple);
+      local_kv_cache_ =
+          ft_.use_disco ? Downcast<DRef>(kv_cache_)->DebugGetFromRemote(0) : kv_cache_;
+    }
   }
 
   void AddNewSequence(int64_t seq_id) final { ft_.kv_cache_add_sequence_func_(kv_cache_, seq_id); }
@@ -775,11 +792,21 @@ class ModelImpl : public ModelObj {
   /************** Raw Info Query **************/
 
   int GetNumAvailablePages() const final {
-    return ft_.kv_cache_get_num_available_pages_func_(local_kv_cache_);
+    if (this->kind == KVStateKind::kRNNState) {
+      // RNNState does not introduce new page at runtime
+      return std::numeric_limits<int>::max();
+    } else {
+      return ft_.kv_cache_get_num_available_pages_func_(local_kv_cache_);
+    }
   }
 
   int GetCurrentTotalSequenceLength() const final {
-    return ft_.kv_cache_get_total_sequence_length_func_(local_kv_cache_);
+    if (this->kind == KVStateKind::kRNNState) {
+      // RNNState does not have a total sequence length limit
+      return 0;
+    } else {
+      return ft_.kv_cache_get_total_sequence_length_func_(local_kv_cache_);
+    }
   }
 
   /*********************** Utilities ***********************/
@@ -946,6 +973,8 @@ class ModelImpl : public ModelObj {
   NDArray logit_pos_arr_{nullptr};
   // A boolean indicating if tracing is enabled.
   bool trace_enabled_;
+  // An enum indicating whether it's RNN-based.
+  KVStateKind kind;
 };
 
 TVM_REGISTER_GLOBAL("mlc.copy_embedding_to_offset")
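Taken together, the model.cc changes make the cache kind a property the model derives for itself: the constructor inspects the model_type string from the model's JSON config, and any value containing the substring "rwkv" selects the RNN-state path. A self-contained sketch of that dispatch (the enum is redeclared locally and the sample values are hypothetical):

    #include <string>

    enum KVStateKind { kAttention = 0, kRNNState = 1 };  // as in cpp/serve/config.h

    // Mirrors the substring test in ModelImpl's constructor: "rwkv5", "rwkv6",
    // or any other model_type containing "rwkv" maps to the RNN state.
    KVStateKind DetectKind(const std::string& model_type) {
      return model_type.find("rwkv") != std::string::npos ? kRNNState : kAttention;
    }
    // DetectKind("rwkv6") == kRNNState, DetectKind("llama") == kAttention

Downstream, CreateKVCache calls the shared create_kv_cache_func_ slot with kind-specific argument shapes: five IntTuples for the paged attention cache versus just (max_num_sequence, max_history_size) for the RNN state, with the history clamped to at least 1. The two query overrides then neutralize the paged-cache accounting for RNN models, reporting effectively unlimited available pages and a running total length of zero, since an RNN state never allocates pages at runtime.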
6 changes: 5 additions & 1 deletion cpp/serve/model.h
@@ -234,9 +234,13 @@ class ModelObj : public Object {
    * in the engine.
    * \param prefill_chunk_size The maximum total number of tokens whose KV data
    * are allowed to exist in the KV cache at any time.
+   * \param max_history_size The maximum history size for RNN state to roll back.
+   * The KV cache does not need this.
+   * \param kv_state_kind The kind of cache. It can be KV cache or RNN state.
    */
   virtual void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length,
-                             int prefill_chunk_size) = 0;
+                             int prefill_chunk_size, int max_history_size,
+                             KVStateKind kv_state_kind) = 0;
 
   /*! \brief Add a new sequence with the given sequence id to the KV cache. */
   virtual void AddNewSequence(int64_t seq_id) = 0;
4 changes: 4 additions & 0 deletions python/mlc_llm/cli/serve.py
@@ -44,6 +44,9 @@ def main(argv):
         "--max-total-seq-length", type=int, help=HELP["max_total_sequence_length_serve"]
     )
     parser.add_argument("--prefill-chunk-size", type=int, help=HELP["prefill_chunk_size_serve"])
+    parser.add_argument(
+        "--max-history-size", type=int, default=1, help=HELP["max_history_size_serve"]
+    )
     parser.add_argument(
         "--gpu-memory-utilization", type=float, help=HELP["gpu_memory_utilization_serve"]
     )
@@ -100,6 +103,7 @@ def main(argv):
         max_batch_size=parsed.max_batch_size,
         max_total_sequence_length=parsed.max_total_seq_length,
         prefill_chunk_size=parsed.prefill_chunk_size,
+        max_history_size=parsed.max_history_size,
         gpu_memory_utilization=parsed.gpu_memory_utilization,
         speculative_mode=SpeculativeMode[parsed.speculative_mode],
         spec_draft_length=parsed.spec_draft_length,
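With the flag registered, serving an RWKV model from the command line can pass an explicit rollback depth, along the lines of mlc_llm serve <model> --max-history-size 1, where <model> is a placeholder for a compiled RWKV model. Omitting the flag falls back to the default of 1 declared above.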
2 changes: 1 addition & 1 deletion python/mlc_llm/conversation_template.py
@@ -365,7 +365,7 @@ def get_conv_template(name: str) -> Optional[Conversation]:
 # RWKV World
 ConvTemplateRegistry.register_conv_template(
     Conversation(
-        name="rwkv-world",
+        name="rwkv_world",
         system_template=f"User: hi\n\nAssistant: {MessagePlaceholders.SYSTEM.value}",
         system_message=(
             "Hi. I am your assistant and I will provide expert full response "
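The underscore rename keeps the registered template name consistent: get_conv_template looks templates up by exact string, so "rwkv-world" and "rwkv_world" are distinct keys, and the underscore form is presumably the one RWKV model configurations reference.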
5 changes: 5 additions & 0 deletions python/mlc_llm/help.py
@@ -152,6 +152,11 @@
     The maximum number of tokens the model passes for prefill each time.
     It should not exceed the prefill chunk size in model config.
     If not specified, this defaults to the prefill chunk size in model config.
     """.strip(),
+    "max_history_size_serve": """
+    The maximum history length for rolling back the RNN state.
+    If unspecified, the default value is 1.
+    KV cache does not need this.
+    """.strip(),
     "enable_tracing_serve": """
     Enable Chrome Tracing for the server.
2 changes: 2 additions & 0 deletions python/mlc_llm/interface/serve.py
@@ -22,6 +22,7 @@ def serve(
     max_batch_size: Optional[int],
     max_total_sequence_length: Optional[int],
     prefill_chunk_size: Optional[int],
+    max_history_size: Optional[int],
     gpu_memory_utilization: Optional[float],
     speculative_mode: SpeculativeMode,
     spec_draft_length: int,
@@ -44,6 +45,7 @@ def serve(
         max_batch_size=max_batch_size,
         max_total_sequence_length=max_total_sequence_length,
         prefill_chunk_size=prefill_chunk_size,
+        max_history_size=max_history_size,
         gpu_memory_utilization=gpu_memory_utilization,
         speculative_mode=speculative_mode,
         spec_draft_length=spec_draft_length,