diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 7379bad7ed..f36bc151a3 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -239,8 +239,8 @@ EngineConfig::EngineConfig(String model, String model_lib_path, Array ad Array additional_model_lib_paths, DLDevice device, int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, int max_single_sequence_length, - int prefill_chunk_size, SpeculativeMode speculative_mode, - int spec_draft_length) { + int prefill_chunk_size, int max_history_size, KVStateKind kv_state_kind, + SpeculativeMode speculative_mode, int spec_draft_length) { ObjectPtr n = make_object(); n->model = std::move(model); n->model_lib_path = std::move(model_lib_path); @@ -252,6 +252,8 @@ EngineConfig::EngineConfig(String model, String model_lib_path, Array ad n->max_total_sequence_length = max_total_sequence_length; n->max_single_sequence_length = max_single_sequence_length; n->prefill_chunk_size = prefill_chunk_size; + n->max_history_size = max_history_size; + n->kv_state_kind = kv_state_kind; n->spec_draft_length = spec_draft_length; n->speculative_mode = speculative_mode; data_ = std::move(n); @@ -261,12 +263,13 @@ TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig") .set_body_typed([](String model, String model_lib_path, Array additional_models, Array additional_model_lib_paths, DLDevice device, int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, - int max_single_sequence_length, int prefill_chunk_size, int speculative_mode, - int spec_draft_length) { + int max_single_sequence_length, int prefill_chunk_size, int max_history_size, + int kv_state_kind, int speculative_mode, int spec_draft_length) { return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models), std::move(additional_model_lib_paths), device, kv_cache_page_size, max_num_sequence, max_total_sequence_length, max_single_sequence_length, - prefill_chunk_size, SpeculativeMode(speculative_mode), spec_draft_length); + prefill_chunk_size, max_history_size, KVStateKind(kv_state_kind), + SpeculativeMode(speculative_mode), spec_draft_length); }); } // namespace serve diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 41ddb3c6e4..ef147b751b 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -84,6 +84,12 @@ enum class SpeculativeMode : int { kEagle = 2, }; +/*! \brief The kind of cache. */ +enum KVStateKind { + kAttention = 0, + kRNNState = 1, +}; + /*! \brief The configuration of engine execution config. */ class EngineConfigNode : public Object { public: @@ -121,6 +127,10 @@ class EngineConfigNode : public Object { int max_single_sequence_length; /*! \brief The maximum total sequence length in a prefill. */ int prefill_chunk_size; + /*! \brief The maximum history size for RNN state. KV cache does not need this. */ + int max_history_size; + /*! \brief The kind of cache. Whether it's KV cache or RNN state. 
*/ + KVStateKind kv_state_kind; /*************** Speculative decoding ***************/ @@ -143,6 +153,7 @@ class EngineConfig : public ObjectRef { Array additional_model_lib_paths, DLDevice device, int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, int max_single_sequence_length, int prefill_chunk_size, + int max_history_size, KVStateKind kv_state_kind, SpeculativeMode speculative_mode, int spec_draft_length); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode); diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 8568c6ce94..0348f7f40a 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -69,7 +69,8 @@ class EngineImpl : public Engine { /*trace_enabled=*/trace_recorder.defined()); model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence, engine_config->max_total_sequence_length, - engine_config->prefill_chunk_size); + engine_config->prefill_chunk_size, engine_config->max_history_size, + engine_config->kv_state_kind); CHECK_GE(model->GetMaxWindowSize(), engine_config->max_single_sequence_length) << "The window size of the model, " << model->GetMaxWindowSize() << ", is smaller than the pre-defined max single sequence length, " diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index c80c5e0ede..b4192a04f1 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -396,6 +396,11 @@ class NewRequestPrefillActionObj : public EngineActionObj { int num_running_rsentries) { ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence); + // For RNN State, it can prefill as long as it can be instantiated. + if (engine_config_->kv_state_kind == KVStateKind::kRNNState) { + return true; + } + // No exceeding of the maximum allowed requests that can // run simultaneously. int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index b33d3709e8..b721eae7c3 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -244,7 +244,12 @@ void FunctionTable::_InitFunctions() { this->alloc_embedding_tensor_func_ = mod_get_func("alloc_embedding_tensor"); this->create_kv_cache_func_ = mod_get_func("create_flashinfer_paged_kv_cache"); if (!this->create_kv_cache_func_.defined()) { - this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); + PackedFunc f_create_rnn_state = mod_get_func("create_rnn_state"); + if (f_create_rnn_state.defined()) { + this->create_kv_cache_func_ = f_create_rnn_state; + } else { + this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); + } ICHECK(this->create_kv_cache_func_.defined()); } this->reset_kv_cache_func_ = get_global_func("vm.builtin.kv_state_clear"); diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 3583b5d84b..27a0043850 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -13,6 +13,7 @@ #include +#include "config.h" #include "logit_processor.h" namespace mlc { @@ -68,6 +69,12 @@ class ModelImpl : public ModelObj { token_ids_storage_ = memory::Storage( allocator->Alloc(device_host, {prefill_chunk_size_}, DataType::Int(32)), allocator); this->logit_pos_arr_ = NDArray::Empty({max_num_sequence}, DataType::Int(32), device_host); + // Step 7. 
Set model type + if (model_config["model_type"].get().find("rwkv") != std::string::npos) { + this->kind = KVStateKind::kRNNState; + } else { + this->kind = KVStateKind::kAttention; + } } /*********************** Model Computation ***********************/ @@ -739,16 +746,26 @@ class ModelImpl : public ModelObj { /*********************** KV Cache Management ***********************/ void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size) final { - IntTuple max_num_sequence_tuple{max_num_sequence}; - IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; - IntTuple prefill_chunk_size_tuple{prefill_chunk_size}; - IntTuple page_size_tuple{page_size}; - IntTuple support_sliding_window{sliding_window_size_ != -1}; - kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_total_sequence_length_tuple, - prefill_chunk_size_tuple, page_size_tuple, - support_sliding_window); - local_kv_cache_ = ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + int prefill_chunk_size, int max_history_size, + KVStateKind kv_state_kind) final { + if (kv_state_kind == KVStateKind::kAttention) { + IntTuple max_num_sequence_tuple{max_num_sequence}; + IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; + IntTuple prefill_chunk_size_tuple{prefill_chunk_size}; + IntTuple page_size_tuple{page_size}; + IntTuple support_sliding_window{sliding_window_size_ != -1}; + kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_total_sequence_length_tuple, + prefill_chunk_size_tuple, page_size_tuple, + support_sliding_window); + local_kv_cache_ = + ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + } else { + IntTuple max_num_sequence_tuple{max_num_sequence}; + IntTuple max_history_size_tuple = {std::max(max_history_size, 1)}; + kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_history_size_tuple); + local_kv_cache_ = + ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + } } void AddNewSequence(int64_t seq_id) final { ft_.kv_cache_add_sequence_func_(kv_cache_, seq_id); } @@ -775,11 +792,21 @@ class ModelImpl : public ModelObj { /************** Raw Info Query **************/ int GetNumAvailablePages() const final { - return ft_.kv_cache_get_num_available_pages_func_(local_kv_cache_); + if (this->kind == KVStateKind::kRNNState) { + // RNNState does not introduce new page at runtime + return std::numeric_limits::max(); + } else { + return ft_.kv_cache_get_num_available_pages_func_(local_kv_cache_); + } } int GetCurrentTotalSequenceLength() const final { - return ft_.kv_cache_get_total_sequence_length_func_(local_kv_cache_); + if (this->kind == KVStateKind::kRNNState) { + // RNNState does not have a total sequence length limit + return 0; + } else { + return ft_.kv_cache_get_total_sequence_length_func_(local_kv_cache_); + } } /*********************** Utilities ***********************/ @@ -946,6 +973,8 @@ class ModelImpl : public ModelObj { NDArray logit_pos_arr_{nullptr}; // A boolean indicating if tracing is enabled. bool trace_enabled_; + // An enum indicating whether it's RNN-based. + KVStateKind kind; }; TVM_REGISTER_GLOBAL("mlc.copy_embedding_to_offset") diff --git a/cpp/serve/model.h b/cpp/serve/model.h index da532f83e8..045daff874 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -234,9 +234,13 @@ class ModelObj : public Object { * in the engine. 
* \param prefill_chunk_size The maximum total number of tokens whose KV data * are allowed to exist in the KV cache at any time. + * \param max_history_size The maximum history size for RNN state to roll back. + * The KV cache does not need this. + * \param kv_state_kind The kind of cache. It can be KV cache or RNN state. */ virtual void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size) = 0; + int prefill_chunk_size, int max_history_size, + KVStateKind kv_state_kind) = 0; /*! \brief Add a new sequence with the given sequence id to the KV cache. */ virtual void AddNewSequence(int64_t seq_id) = 0; diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py index 9f7c1c3580..6663a0c230 100644 --- a/python/mlc_llm/cli/serve.py +++ b/python/mlc_llm/cli/serve.py @@ -44,6 +44,9 @@ def main(argv): "--max-total-seq-length", type=int, help=HELP["max_total_sequence_length_serve"] ) parser.add_argument("--prefill-chunk-size", type=int, help=HELP["prefill_chunk_size_serve"]) + parser.add_argument( + "--max-history-size", type=int, default=1, help=HELP["max_history_size_serve"] + ) parser.add_argument( "--gpu-memory-utilization", type=float, help=HELP["gpu_memory_utilization_serve"] ) @@ -100,6 +103,7 @@ def main(argv): max_batch_size=parsed.max_batch_size, max_total_sequence_length=parsed.max_total_seq_length, prefill_chunk_size=parsed.prefill_chunk_size, + max_history_size=parsed.max_history_size, gpu_memory_utilization=parsed.gpu_memory_utilization, speculative_mode=SpeculativeMode[parsed.speculative_mode], spec_draft_length=parsed.spec_draft_length, diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index 917e229632..1c599fa875 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -365,7 +365,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: # RWKV World ConvTemplateRegistry.register_conv_template( Conversation( - name="rwkv-world", + name="rwkv_world", system_template=f"User: hi\n\nAssistant: {MessagePlaceholders.SYSTEM.value}", system_message=( "Hi. I am your assistant and I will provide expert full response " diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index 14e5cee321..86930fa5ea 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -152,6 +152,11 @@ The maximum number of tokens the model passes for prefill each time. It should not exceed the prefill chunk size in model config. If not specified, this defaults to the prefill chunk size in model config. +""".strip(), + "max_history_size_serve": """ +The maximum history length for rolling back the RNN state. +If unspecified, the default value is 1. +KV cache does not need this. """.strip(), "enable_tracing_serve": """ Enable Chrome Tracing for the server. 
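The hunks above wire a new cache kind (paged KV cache vs. RNN state) through the C++ engine and the serve CLI. A condensed Python sketch of that dispatch, illustrative only, mirroring the "rwkv" substring check in cpp/serve/model.cc above; the enum values match the KVStateKind added to python/mlc_llm/serve/config.py later in this patch:

# Illustrative sketch of the cache-kind dispatch (not part of the patch itself).
from enum import IntEnum


class KVStateKind(IntEnum):
    ATTENTION = 0
    RNNSTATE = 1


def infer_kv_state_kind(model_type: str) -> KVStateKind:
    """Return RNNSTATE for RWKV-family models, ATTENTION for attention models."""
    return KVStateKind.RNNSTATE if "rwkv" in model_type else KVStateKind.ATTENTION


# RWKV models use the RNN state with a rollback history (CLI default 1),
# while attention models keep using the paged KV cache.
assert infer_kv_state_kind("rwkv6") == KVStateKind.RNNSTATE
assert infer_kv_state_kind("llama") == KVStateKind.ATTENTION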
diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index d0cbd4690b..40fa9fdda8 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -22,6 +22,7 @@ def serve( max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], + max_history_size: Optional[int], gpu_memory_utilization: Optional[float], speculative_mode: SpeculativeMode, spec_draft_length: int, @@ -44,6 +45,7 @@ def serve( max_batch_size=max_batch_size, max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, gpu_memory_utilization=gpu_memory_utilization, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, diff --git a/python/mlc_llm/model/rwkv5/rwkv5_model.py b/python/mlc_llm/model/rwkv5/rwkv5_model.py index 49386720da..81c9e9aa7f 100644 --- a/python/mlc_llm/model/rwkv5/rwkv5_model.py +++ b/python/mlc_llm/model/rwkv5/rwkv5_model.py @@ -40,6 +40,7 @@ class RWKV5Config(ConfigBase): # pylint: disable=too-many-instance-attributes context_window_size: int = -1 # RWKV does not have context window limitation. prefill_chunk_size: int = 4096 num_heads: int = 0 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -129,23 +130,18 @@ def wkv_func( def token_shift(state: Tensor, x: Tensor): - # x.shape = (batch, seq_len, hidden_size) - # state.shape = (batch, hidden_size) - seq_len = x.shape[1] - def _te_token_shift(state: te.Tensor, x: te.Tensor): return te.compute( x.shape, lambda b, i, j: tir.if_then_else(i == 0, state[b, j], x[b, i - 1, j]), ) - return state if seq_len == 1 else op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) + return op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) def last_token(x: Tensor): # x.shape = (batch, seq_len, hidden_size) batch, seq_len, hidden_size = x.shape - assert batch == 1 def _te_last_token(x: te.Tensor): return te.compute((batch, 1, hidden_size), lambda b, _, j: x[b, x.shape[1] - 1, j]) @@ -350,10 +346,14 @@ def to(self, dtype: Optional[str] = None): def embed(self, input_ids: Tensor): return self.model.embeddings(input_ids) - def forward(self, input_embed: Tensor, state: RNNState): + def forward( + self, input_embed: Tensor, state: RNNState, logit_positions: Optional[Tensor] = None + ): """Forward pass.""" hidden_states, state = self.model(input_embed, state) hidden_states = last_token(hidden_states) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") @@ -367,11 +367,27 @@ def decode(self, input_embed: Tensor, state: RNNState): """Decoding step.""" return self.forward(input_embed, state) + def batch_prefill(self, input_embeds: Tensor, logit_positions: Tensor, state: RNNState): + """Prefilling the prompt.""" + return self.forward(input_embeds, state, logit_positions=logit_positions) + + def batch_decode(self, input_embeds: Tensor, state: RNNState): + """Decoding step.""" + return self.forward(input_embeds, state) + + def batch_verify(self, input_embeds: Tensor, state: RNNState): + """Verify step.""" + return self.forward(input_embeds, state) + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): """Softmax.""" - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), 
axis=-1) - def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Object: + def create_rnn_state( + self, + max_batch_size: tir.Var, + max_history: tir.Var, + ) -> Object: """Create RNN state.""" init_values = [ op.zeros((self.hidden_size,), dtype=self.dtype), # ATT_X @@ -386,7 +402,6 @@ def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Obj ) def get_default_spec(self): - batch_size = 1 mod_spec = { "embed": { "input_ids": nn.spec.Tensor(["seq_len"], "int32"), @@ -396,9 +411,7 @@ def get_default_spec(self): }, }, "prefill": { - "input_embed": nn.spec.Tensor( - [batch_size, "seq_len", self.hidden_size], self.dtype - ), + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -406,7 +419,32 @@ def get_default_spec(self): }, }, "decode": { - "input_embed": nn.spec.Tensor([batch_size, 1, self.hidden_size], self.dtype), + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -414,8 +452,8 @@ def get_default_spec(self): }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([batch_size, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/rwkv6/rwkv6_model.py b/python/mlc_llm/model/rwkv6/rwkv6_model.py index 0e1887310d..a8faf48a6b 100644 --- a/python/mlc_llm/model/rwkv6/rwkv6_model.py +++ b/python/mlc_llm/model/rwkv6/rwkv6_model.py @@ -40,6 +40,7 @@ class RWKV6Config(ConfigBase): # pylint: disable=too-many-instance-attributes context_window_size: int = -1 # RWKV does not have context window limitation. 
prefill_chunk_size: int = 4096 num_heads: int = 0 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -126,20 +127,17 @@ def wkv_func( def token_shift(state: Tensor, x: Tensor): - seq_len = x.shape[1] - def _te_token_shift(state: te.Tensor, x: te.Tensor): return te.compute( x.shape, lambda b, i, j: tir.if_then_else(i == 0, state[b, j], x[b, i - 1, j]), ) - return state if seq_len == 1 else op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) + return op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) def last_token(x: Tensor): batch, seq_len, hidden_size = x.shape - assert batch == 1 def _te_last_token(x: te.Tensor): return te.compute((batch, 1, hidden_size), lambda b, _, j: x[b, x.shape[1] - 1, j]) @@ -390,10 +388,14 @@ def to(self, dtype: Optional[str] = None): def embed(self, input_ids: Tensor): return self.model.embeddings(input_ids) - def forward(self, input_embed: Tensor, state: RNNState): + def forward( + self, input_embed: Tensor, state: RNNState, logit_positions: Optional[Tensor] = None + ): """Forward pass.""" hidden_states, state = self.model(input_embed, state) hidden_states = last_token(hidden_states) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") @@ -407,11 +409,27 @@ def decode(self, input_embed: Tensor, state: RNNState): """Decoding step.""" return self.forward(input_embed, state) + def batch_prefill(self, input_embeds: Tensor, logit_positions: Tensor, state: RNNState): + """Prefilling the prompt.""" + return self.forward(input_embeds, state, logit_positions=logit_positions) + + def batch_decode(self, input_embeds: Tensor, state: RNNState): + """Decoding step.""" + return self.forward(input_embeds, state) + + def batch_verify(self, input_embeds: Tensor, state: RNNState): + """Verify step.""" + return self.forward(input_embeds, state) + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): """Softmax.""" - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Object: + def create_rnn_state( + self, + max_batch_size: tir.Var, + max_history: tir.Var, + ) -> Object: """Create RNN state.""" init_values = [ op.zeros((self.hidden_size,), dtype=self.dtype), # ATT_X @@ -426,7 +444,6 @@ def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Obj ) def get_default_spec(self): - batch_size = 1 mod_spec = { "embed": { "input_ids": nn.spec.Tensor(["seq_len"], "int32"), @@ -436,9 +453,7 @@ def get_default_spec(self): }, }, "prefill": { - "input_embed": nn.spec.Tensor( - [batch_size, "seq_len", self.hidden_size], self.dtype - ), + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -446,7 +461,32 @@ def get_default_spec(self): }, }, "decode": { - "input_embed": nn.spec.Tensor([batch_size, 1, self.hidden_size], self.dtype), + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": 
nn.spec.Tensor(["batch_size"], "int32"), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -454,8 +494,8 @@ def get_default_spec(self): }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([batch_size, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 60e4eca8c5..40c53e336a 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -128,6 +128,13 @@ def from_json(json_str: str) -> "GenerationConfig": return GenerationConfig(**json.loads(json_str)) +class KVStateKind(enum.IntEnum): # pylint: disable=too-few-public-methods + """Possible kinds of KV state.""" + + ATTENTION = 0 + RNNSTATE = 1 + + class SpeculativeMode(enum.IntEnum): """The speculative mode.""" @@ -177,6 +184,12 @@ class EngineConfig(tvm.runtime.Object): prefill_chunk_size : int The maximum total sequence length in a prefill. + max_history_size: int + The maximum history size for RNN state to rool back. + + kv_state_kind: KVStateKind + The kind of cache. + speculative_mode : SpeculativeMode The speculative mode. @@ -196,6 +209,8 @@ def __init__( # pylint: disable=too-many-arguments max_total_sequence_length: int, max_single_sequence_length: int, prefill_chunk_size: int, + max_history_size: int, + kv_state_kind: KVStateKind, speculative_mode: SpeculativeMode, spec_draft_length: int, ) -> None: @@ -211,6 +226,8 @@ def __init__( # pylint: disable=too-many-arguments max_total_sequence_length, max_single_sequence_length, prefill_chunk_size, + max_history_size, + kv_state_kind, speculative_mode, spec_draft_length, ) diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index d9721b4864..413c856db1 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -816,6 +816,9 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): It should not exceed the prefill chunk size in model config. If not specified, this defaults to the prefill chunk size in model config. + max_history_size : Optional[int] + The maximum history for RNN state. + gpu_memory_utilization : Optional[float] A number in (0, 1) denoting the fraction of GPU memory used by the server in total. It is used to infer to maximum possible KV cache capacity. 
@@ -846,6 +849,7 @@ def __init__( # pylint: disable=too-many-arguments max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, + max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, spec_draft_length: int = 4, @@ -861,6 +865,7 @@ def __init__( # pylint: disable=too-many-arguments max_batch_size=max_batch_size, max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, gpu_memory_utilization=gpu_memory_utilization, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, @@ -1392,6 +1397,7 @@ def __init__( # pylint: disable=too-many-arguments max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, + max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, spec_draft_length: int = 4, @@ -1407,6 +1413,7 @@ def __init__( # pylint: disable=too-many-arguments max_batch_size=max_batch_size, max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, gpu_memory_utilization=gpu_memory_utilization, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 7b2ede60b2..5d62dd5fb1 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -17,10 +17,16 @@ from tvm.runtime import Device from mlc_llm.chat_module import _get_chat_config, _get_lib_module_path, _get_model_path +from mlc_llm.cli.model_metadata import _compute_memory_usage, _extract_metadata from mlc_llm.protocol import openai_api_protocol, protocol_utils from mlc_llm.protocol.conversation_protocol import Conversation from mlc_llm.serve import data, engine_utils -from mlc_llm.serve.config import EngineConfig, GenerationConfig, SpeculativeMode +from mlc_llm.serve.config import ( + EngineConfig, + GenerationConfig, + KVStateKind, + SpeculativeMode, +) from mlc_llm.serve.event_trace_recorder import EventTraceRecorder from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging @@ -121,7 +127,7 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: return model_args, config_file_paths, conversation -def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-many-locals,too-many-arguments +def _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache( # pylint: disable=too-many-locals,too-many-arguments models: List[ModelInfo], device: tvm.runtime.Device, model_config_paths: List[str], @@ -240,6 +246,77 @@ def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-ma ) +def _estimate_mem_usage_and_max_history_size_for_rnn_state( # pylint: disable=too-many-arguments, too-many-locals, unused-argument + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_paths: List[str], + model_config_dicts: List[Dict[str, Any]], + max_num_sequence: int, + gpu_memory_utilization: Optional[float], +) -> Tuple[float, float, float, int]: + # Get single-card GPU size. 
+ gpu_size_bytes = device.total_global_memory + if gpu_size_bytes is None: + raise ValueError("Cannot read total GPU global memory from device.") + if gpu_memory_utilization is None: + gpu_memory_utilization = 0.90 + + rnn_state_base_bytes = 0.0 # the memory usage for rnn state when history = 1 + param_bytes = 0.0 + model_workspace_bytes = 0.0 + logit_processor_workspace_bytes = 0.0 + for model, model_config_dict in zip(models, model_config_dicts): + model_config = model_config_dict["model_config"] + vocab_size = model_config_dict["vocab_size"] + head_size = model_config["head_size"] + num_heads = model_config["num_heads"] + num_layers = model_config["num_hidden_layers"] + hidden_size = model_config["hidden_size"] + prefill_chunk_size = model_config["prefill_chunk_size"] + logit_processor_workspace_bytes += ( + max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 + ) + + model_workspace_bytes += ( + prefill_chunk_size * 4 + + max_num_sequence * 4 + + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 + ) + + rnn_state_base_bytes += ( + max_num_sequence * hidden_size * num_layers * 2 * 2 + + max_num_sequence * num_heads * head_size * head_size * num_layers * 2 + ) + + metadata = _extract_metadata(Path(model.model_lib_path)) + metadata["memory_usage"] = {} + metadata["kv_cache_bytes"] = 0 + current_param_bytes, _, _ = _compute_memory_usage(metadata, model_config_dict) + param_bytes += current_param_bytes + + max_history_size = int( + ( + gpu_size_bytes * gpu_memory_utilization + - logit_processor_workspace_bytes + - model_workspace_bytes + - param_bytes + ) + / rnn_state_base_bytes + ) + if max_history_size < 1: + raise ValueError( + f"Memory required by models may be larger than available GPU memory " + f"size {gpu_size_bytes * gpu_memory_utilization} bytes." + ) + + return ( + param_bytes, + model_workspace_bytes + logit_processor_workspace_bytes, + rnn_state_base_bytes, + max_history_size, + ) + + def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[int, int, int]: """Read the model config dictionaries, and return the maximum single sequence length the models can support, the maximum prefill chunk @@ -294,7 +371,7 @@ def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[i return model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size -def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements +def _infer_kv_cache_config_for_kv_cache( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements mode: Literal["local", "interactive", "server"], max_batch_size: Optional[int], max_total_sequence_length: Optional[int], @@ -304,12 +381,13 @@ def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-local device: tvm.runtime.Device, model_config_dicts: List[Dict[str, Any]], model_config_paths: List[str], -) -> Tuple[int, int, int, int]: +) -> Tuple[int, int, int, KVStateKind, int]: """Initialize the KV cache config with user input and GPU memory usage estimation. 
The returned four integers are: - max_batch_size - max_total_sequence_length - prefill_chunk_size + - kv_state_kind - model_max_single_sequence_length """ ( @@ -323,7 +401,7 @@ def infer_args_under_mode( max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], - ) -> Tuple[Tuple[int, int, int], List[float]]: + ) -> Tuple[Tuple[int, int, int, KVStateKind], List[float]]: logging_msg = "" # - max_batch_size if max_batch_size is None: @@ -343,7 +421,7 @@ def infer_args_under_mode( kv_aux_workspace_bytes, temp_workspace_bytes, model_max_total_sequence_length, - ) = _estimate_mem_usage_and_max_total_sequence_length( + ) = _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache( models, device, model_config_paths, @@ -400,7 +478,12 @@ def infer_args_under_mode( # - Construct the KV cache config # - Estimate total GPU memory usage on single GPU. - return (max_batch_size, max_total_sequence_length, prefill_chunk_size), [ + return ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + KVStateKind.ATTENTION, + ), [ total_mem_usage_except_kv_cache + max_total_sequence_length * kv_bytes_per_token, model_params_bytes, kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes, @@ -462,6 +545,167 @@ def infer_args_under_mode( return *kv_cache_config, model_max_single_sequence_length +def _infer_kv_cache_config_for_rnn_state( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + mode: Literal["local", "interactive", "server"], + max_batch_size: Optional[int], + max_total_sequence_length: Optional[int], + prefill_chunk_size: Optional[int], + max_history_size: Optional[int], + gpu_memory_utilization: Optional[float], + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_dicts: List[Dict[str, Any]], + model_config_paths: List[str], +) -> Tuple[int, int, int, KVStateKind, int]: + """Initialize the RNN state config with user input and GPU memory usage estimation. + The returned four integers are: + - max_batch_size + - max_total_sequence_length + - prefill_chunk_size + - kv_state_kind + - max_history_size + """ + logging_msg = "" + prefill_chunk_size = 0 + + if prefill_chunk_size is None: + prefill_chunk_size = min( + config["prefill_chunk_size"] if "prefill_chunk_size" in config else 4096 + for config in model_config_dicts + ) + logging_msg += f"prefill chunk size is set to {prefill_chunk_size}. " + else: + logging_msg += f"prefill chunk size {prefill_chunk_size} is specified by user. " + if max_batch_size is None: + max_batch_size = 1 if mode == "interactive" else 4 + logging_msg += f"max batch size is set to {max_batch_size}, " + else: + logging_msg += f"max batch size {max_batch_size} is specified by user, " + + if mode == "local": + logging_msg += ( + "We choose small max batch size and RNN state capacity to use less GPU memory." + ) + elif mode == "interactive": + logging_msg += "We fix max batch size to 1 for interactive single sequence use." + else: + logging_msg += ( + "We use as much GPU memory as possible (within the" " limit of gpu_memory_utilization)." 
+ ) + logger.info('Under mode "%s", %s', mode, logging_msg) + + ( + model_param_bytes, + model_temp_bytes, + model_rnn_state_base_bytes, + model_max_history_size, + ) = _estimate_mem_usage_and_max_history_size_for_rnn_state( + models, + device, + model_config_paths, + model_config_dicts, + max_batch_size, + gpu_memory_utilization, + ) + if max_history_size is None: + max_history_size = model_max_history_size + else: + max_history_size = min(max_history_size, model_max_history_size) + max_total_sequence_length = 32768 + prefill_chunk_size = 0 + kind = KVStateKind.RNNSTATE + + logger.info( + "%s: %.2f MB (Parameters: %.2f MB. RNNState: %.2f MB. Temporary buffer: %.2f MB). " + "The actual usage might be slightly larger than the estimated number.", + green("Estimated total single GPU memory usage"), + (model_param_bytes + model_temp_bytes + model_rnn_state_base_bytes) / 1024 / 1024, + model_param_bytes / 1024 / 1024, + max_history_size * model_rnn_state_base_bytes / 1024 / 1024, + model_temp_bytes / 1024 / 1024, + ) + + return ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + kind, + max_history_size, + ) + + +def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + mode: Literal["local", "interactive", "server"], + max_batch_size: Optional[int], + max_total_sequence_length: Optional[int], + prefill_chunk_size: Optional[int], + max_history_size: Optional[int], + gpu_memory_utilization: Optional[float], + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_dicts: List[Dict[str, Any]], + model_config_paths: List[str], +) -> Tuple[int, int, int, int, int, KVStateKind]: + """Initialize the cache config with user input and GPU memory usage estimation. + The returned four integers are: + - max_batch_size + - max_total_sequence_length + - prefill_chunk_size + - max_single_sequence_length + - max_history_size + - kv_state_kind + """ + if all("rwkv" not in model.model for model in models): + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + kv_state_kind, + max_single_sequence_length, + ) = _infer_kv_cache_config_for_kv_cache( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + gpu_memory_utilization, + models, + device, + model_config_dicts, + model_config_paths, + ) + max_history_size = 0 # KV cache doesn't need this + elif all("rwkv" in model.model for model in models): + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + kv_state_kind, + max_history_size, + ) = _infer_kv_cache_config_for_rnn_state( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_history_size, + gpu_memory_utilization, + models, + device, + model_config_dicts, + model_config_paths, + ) + max_single_sequence_length = max_total_sequence_length # RNN state doesn't need this + else: + raise ValueError("The models should be either all KV cache models or all RNN state models.") + return ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_single_sequence_length, + max_history_size, + kv_state_kind, + ) + + @dataclass class CallbackStreamOutput: """The output of MLCEngine._generate and AsyncMLCEngine._generate @@ -728,6 +972,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], + max_history_size: Optional[int], gpu_memory_utilization: Optional[float], speculative_mode: 
SpeculativeMode, spec_draft_length: int, @@ -757,11 +1002,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length, prefill_chunk_size, max_single_sequence_length, + max_history_size, + kv_state_kind, ) = _infer_kv_cache_config( mode, max_batch_size, max_total_sequence_length, prefill_chunk_size, + max_history_size, gpu_memory_utilization, models, device, @@ -803,6 +1051,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length=max_total_sequence_length, max_single_sequence_length=max_single_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, ) diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py index 257338da3a..7469ddc241 100644 --- a/python/mlc_llm/serve/sync_engine.py +++ b/python/mlc_llm/serve/sync_engine.py @@ -98,6 +98,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, + max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, enable_tracing: bool = False, speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, @@ -128,11 +129,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length, prefill_chunk_size, max_single_sequence_length, + max_history_size, + kv_state_kind, ) = _infer_kv_cache_config( mode, max_batch_size, max_total_sequence_length, prefill_chunk_size, + max_history_size, gpu_memory_utilization, models, device, @@ -168,6 +172,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length=max_total_sequence_length, max_single_sequence_length=max_single_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, ), diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index 9b594e9784..c0c749c0a7 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -89,6 +89,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, + max_history_size: Optional[int] = None, prefill_chunk_size: Optional[int] = None, speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, spec_draft_length: int = 4, @@ -118,11 +119,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length, prefill_chunk_size, max_single_sequence_length, + max_history_size, + kv_state_kind, ) = _infer_kv_cache_config( mode, max_batch_size, max_total_sequence_length, prefill_chunk_size, + max_history_size, gpu_memory_utilization, models, device, @@ -162,6 +166,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length=max_total_sequence_length, max_single_sequence_length=max_single_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, ) diff --git a/tests/python/serve/test_serve_engine.py 
b/tests/python/serve/test_serve_engine.py index f965e8cc82..37d1833b14 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -2,6 +2,8 @@ # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable from typing import List +import pytest + from mlc_llm.serve import GenerationConfig, MLCEngine prompts = [ @@ -17,17 +19,39 @@ "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.", ] +test_models = [ + ( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ), + ( + "dist/rwkv-6-world-1b6-q0f16-MLC", + "dist/rwkv-6-world-1b6-q0f16-MLC/rwkv-6-world-1b6-q0f16-MLC-cuda.so", + ), +] -def test_engine_generate(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = MLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) + +def create_engine(model: str, model_lib_path: str): + if "rwkv" in model: + return MLCEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_batch_size=8, + max_history_size=1, + ) + else: + return MLCEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) + + +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_engine_generate(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 10 max_tokens = 256 @@ -57,16 +81,10 @@ def test_engine_generate(): del engine -def test_chat_completion(): +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_chat_completion(model: str, model_lib_path: str): # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = MLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 64 @@ -101,16 +119,9 @@ def test_chat_completion(): del engine -def test_chat_completion_non_stream(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = MLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_chat_completion_non_stream(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 64 @@ -144,16 +155,9 @@ def test_chat_completion_non_stream(): del engine -def test_completion(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = MLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_completion(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 128 @@ -188,16 +192,9 @@ def test_completion(): del engine -def test_completion_non_stream(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = 
"dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = MLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_completion_non_stream(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 128 @@ -232,8 +229,9 @@ def test_completion_non_stream(): if __name__ == "__main__": - test_engine_generate() - test_chat_completion() - test_chat_completion_non_stream() - test_completion() - test_completion_non_stream() + for model, model_lib_path in test_models: + test_engine_generate(model, model_lib_path) + test_chat_completion(model, model_lib_path) + test_chat_completion_non_stream(model, model_lib_path) + test_completion(model, model_lib_path) + test_completion_non_stream(model, model_lib_path)