Commit a3a1063

fix

Celve committed Apr 23, 2024
1 parent 64d2a09 commit a3a1063
Showing 5 changed files with 73 additions and 22 deletions.
4 changes: 2 additions & 2 deletions cpp/serve/config.cc
@@ -264,11 +264,11 @@ TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig")
      Array<String> additional_model_lib_paths, DLDevice device,
      int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length,
      int max_single_sequence_length, int prefill_chunk_size, int max_history_size,
-     KVStateKind kv_state_kind, int speculative_mode, int spec_draft_length) {
+     int kv_state_kind, int speculative_mode, int spec_draft_length) {
      return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models),
                          std::move(additional_model_lib_paths), device, kv_cache_page_size,
                          max_num_sequence, max_total_sequence_length, max_single_sequence_length,
-                         prefill_chunk_size, max_history_size, kv_state_kind,
+                         prefill_chunk_size, max_history_size, KVStateKind(kv_state_kind),
                          SpeculativeMode(speculative_mode), spec_draft_length);
    });
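The binding now receives kv_state_kind as a plain int and rebuilds the enum with an explicit KVStateKind(kv_state_kind) cast, since the value crosses the TVM FFI boundary as an integer. Below is a minimal Python-side sketch of the same pattern; the member names and numeric values are illustrative assumptions, not taken from this commit.

from enum import IntEnum

# Hypothetical mirror of the C++ KVStateKind enum; the numeric values are
# assumptions for illustration only.
class KVStateKind(IntEnum):
    ATTENTION = 0  # paged KV cache for attention-based models
    RNNSTATE = 1   # recurrent state for RNN models such as RWKV

# An IntEnum member is itself an int, so it can be passed to an FFI entry
# point that expects an integer; the C++ side recovers the enum by casting.
def to_ffi_arg(kind: KVStateKind) -> int:
    return int(kind)

assert to_ffi_arg(KVStateKind.RNNSTATE) == 1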

3 changes: 2 additions & 1 deletion cpp/serve/engine.cc
@@ -69,7 +69,8 @@ class EngineImpl : public Engine {
                           /*trace_enabled=*/trace_recorder.defined());
      model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence,
                           engine_config->max_total_sequence_length,
-                          engine_config->prefill_chunk_size);
+                          engine_config->prefill_chunk_size, engine_config->max_history_size,
+                          engine_config->kv_state_kind);
      CHECK_GE(model->GetMaxWindowSize(), engine_config->max_single_sequence_length)
          << "The window size of the model, " << model->GetMaxWindowSize()
          << ", is smaller than the pre-defined max single sequence length, "
2 changes: 1 addition & 1 deletion cpp/serve/engine_actions/new_request_prefill.cc
@@ -397,7 +397,7 @@ class NewRequestPrefillActionObj : public EngineActionObj {
    ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence);

    // For RNN State, it can prefill as long as it can be instantiated.
-   if (kv_cache_config_->kind == KVStateKind::kRNNState) {
+   if (engine_config_->kv_state_kind == KVStateKind::kRNNState) {
      return true;
    }

10 changes: 10 additions & 0 deletions python/mlc_llm/serve/config.py
@@ -184,6 +184,12 @@ class EngineConfig(tvm.runtime.Object):
    prefill_chunk_size : int
        The maximum total sequence length in a prefill.
+   max_history_size: int
+       The maximum history size for RNN state to roll back.
+   kv_state_kind: KVStateKind
+       The kind of cache.
    speculative_mode : SpeculativeMode
        The speculative mode.
@@ -203,6 +209,8 @@ def __init__( # pylint: disable=too-many-arguments
        max_total_sequence_length: int,
        max_single_sequence_length: int,
        prefill_chunk_size: int,
+       max_history_size: int,
+       kv_state_kind: KVStateKind,
        speculative_mode: SpeculativeMode,
        spec_draft_length: int,
    ) -> None:
@@ -218,6 +226,8 @@ def __init__( # pylint: disable=too-many-arguments
            max_total_sequence_length,
            max_single_sequence_length,
            prefill_chunk_size,
+           max_history_size,
+           kv_state_kind,
            speculative_mode,
            spec_draft_length,
        )
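For reference, here is a hypothetical construction of EngineConfig with the two new fields in place. Every concrete value is a placeholder, and the import path and the SpeculativeMode member name are assumptions, not taken from this commit.

import tvm
from mlc_llm.serve.config import EngineConfig, KVStateKind, SpeculativeMode

config = EngineConfig(
    model="dist/rwkv-model",                   # placeholder model path
    model_lib_path="dist/rwkv-model/lib.so",   # placeholder library path
    additional_models=[],
    additional_model_lib_paths=[],
    device=tvm.cuda(0),
    kv_cache_page_size=16,
    max_num_sequence=4,
    max_total_sequence_length=32768,           # RNN state cap used in this commit
    max_single_sequence_length=32768,
    prefill_chunk_size=0,                      # 0 for RNN state, per the diff below
    max_history_size=1,                        # new field: RNN state rollback depth
    kv_state_kind=KVStateKind.RNNSTATE,        # new field: which cache kind to build
    speculative_mode=SpeculativeMode.DISABLE,  # assumed member name
    spec_draft_length=0,
)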
76 changes: 58 additions & 18 deletions python/mlc_llm/serve/engine_base.py
@@ -376,12 +376,13 @@ def _infer_kv_cache_config_for_kv_cache( # pylint: disable=too-many-arguments,t
    device: tvm.runtime.Device,
    model_config_dicts: List[Dict[str, Any]],
    model_config_paths: List[str],
-) -> Tuple[int, int, int, int, KVStateKind]:
+) -> Tuple[int, int, int, KVStateKind, int]:
    """Initialize the KV cache config with user input and GPU memory usage estimation.
    The returned values are:
    - max_batch_size
    - max_total_sequence_length
    - prefill_chunk_size
+   - kv_state_kind
    - model_max_single_sequence_length
    """
(
@@ -476,7 +477,6 @@ def infer_args_under_mode(
            max_batch_size,
            max_total_sequence_length,
            prefill_chunk_size,
-           0,  # max_history_size, placeholder for RNN state
            KVStateKind.ATTENTION,
        ), [
            total_mem_usage_except_kv_cache + max_total_sequence_length * kv_bytes_per_token,
@@ -551,7 +551,15 @@ def _infer_kv_cache_config_for_rnn_state( # pylint: disable=too-many-arguments,
    device: tvm.runtime.Device,
    model_config_dicts: List[Dict[str, Any]],
    model_config_paths: List[str],
-) -> Tuple[Tuple[int, int, int, int, KVStateKind], int]:
+) -> Tuple[int, int, int, KVStateKind, int]:
+   """Initialize the RNN state config with user input and GPU memory usage estimation.
+   The returned values are:
+   - max_batch_size
+   - max_total_sequence_length
+   - prefill_chunk_size
+   - kv_state_kind
+   - max_history_size
+   """
logging_msg = ""
if prefill_chunk_size is None:
if "prefill_chunk_size" in model_config_dicts:
@@ -596,8 +604,7 @@ def _infer_kv_cache_config_for_rnn_state( # pylint: disable=too-many-arguments,
        max_history_size = model_max_history_size
    else:
        max_history_size = min(max_history_size, model_max_history_size)
-   model_max_single_sequence_length = 2147483647
-   max_total_sequence_length = 2147483647
+   max_total_sequence_length = 32768
    prefill_chunk_size = 0
    kind = KVStateKind.RNNSTATE
@@ -612,14 +619,11 @@ def _infer_kv_cache_config_for_rnn_state( # pylint: disable=too-many-arguments,
    )

    return (
-       (
-           max_batch_size,
-           max_total_sequence_length,
-           prefill_chunk_size,
-           max_history_size,
-           kind,
-       ),
-       model_max_single_sequence_length,
+       max_batch_size,
+       max_total_sequence_length,
+       prefill_chunk_size,
+       kind,
+       max_history_size,
    )


@@ -634,9 +638,24 @@ def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-local
    device: tvm.runtime.Device,
    model_config_dicts: List[Dict[str, Any]],
    model_config_paths: List[str],
-) -> Tuple[int, int, int, int, KVStateKind]:
+) -> Tuple[int, int, int, int, int, KVStateKind]:
+   """Initialize the cache config with user input and GPU memory usage estimation.
+   The returned values are:
+   - max_batch_size
+   - max_total_sequence_length
+   - prefill_chunk_size
+   - max_single_sequence_length
+   - max_history_size
+   - kv_state_kind
+   """
    if all("rwkv" not in model.model for model in models):
-       return _infer_kv_cache_config_for_kv_cache(
+       (
+           max_batch_size,
+           max_total_sequence_length,
+           prefill_chunk_size,
+           kv_state_kind,
+           max_single_sequence_length,
+       ) = _infer_kv_cache_config_for_kv_cache(
            mode,
            max_batch_size,
            max_total_sequence_length,
@@ -647,8 +666,15 @@ def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-local
            model_config_dicts,
            model_config_paths,
        )
-   if all("rwkv" in model.model for model in models):
-       return _infer_kv_cache_config_for_rnn_state(
+       max_history_size = 0  # KV cache doesn't need this
+   elif all("rwkv" in model.model for model in models):
+       (
+           max_batch_size,
+           max_total_sequence_length,
+           prefill_chunk_size,
+           kv_state_kind,
+           max_history_size,
+       ) = _infer_kv_cache_config_for_rnn_state(
            mode,
            max_batch_size,
            max_total_sequence_length,
@@ -660,7 +686,17 @@ def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-local
            model_config_dicts,
            model_config_paths,
        )
-   raise ValueError("The models should be either all KV cache models or all RNN state models.")
+       max_single_sequence_length = max_total_sequence_length  # RNN state doesn't need this
+   else:
+       raise ValueError("The models should be either all KV cache models or all RNN state models.")
+   return (
+       max_batch_size,
+       max_total_sequence_length,
+       prefill_chunk_size,
+       max_single_sequence_length,
+       max_history_size,
+       kv_state_kind,
+   )
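The dispatch rule above, restated as a self-contained sketch. ModelSpec is a stand-in for the real model records, which expose a .model path string in the surrounding code; everything else mirrors the branches in _infer_kv_cache_config.

from dataclasses import dataclass
from typing import List

@dataclass
class ModelSpec:  # stand-in for the engine's model record type
    model: str    # model path / identifier

def cache_kind(models: List[ModelSpec]) -> str:
    # All-or-nothing: mixing RWKV (RNN state) models with attention models is rejected.
    if all("rwkv" not in m.model for m in models):
        return "ATTENTION"  # paged KV cache; max_history_size is fixed to 0
    if all("rwkv" in m.model for m in models):
        return "RNNSTATE"   # RNN state; max_single_sequence_length mirrors max_total_sequence_length
    raise ValueError("The models should be either all KV cache models or all RNN state models.")

print(cache_kind([ModelSpec("dist/rwkv-6-world")]))  # -> RNNSTATE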


@dataclass
@@ -959,6 +995,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
            max_total_sequence_length,
            prefill_chunk_size,
            max_single_sequence_length,
+           max_history_size,
+           kv_state_kind,
        ) = _infer_kv_cache_config(
            mode,
            max_batch_size,
@@ -1006,6 +1044,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
            max_total_sequence_length=max_total_sequence_length,
            max_single_sequence_length=max_single_sequence_length,
            prefill_chunk_size=prefill_chunk_size,
+           max_history_size=max_history_size,
+           kv_state_kind=kv_state_kind,
            speculative_mode=speculative_mode,
            spec_draft_length=spec_draft_length,
        )
