From 2b5d41ff47c68c603968e2c672c278a4dfa45da8 Mon Sep 17 00:00:00 2001
From: Rick Zhou
Date: Thu, 25 Apr 2024 08:12:35 -0400
Subject: [PATCH 1/3] Change OpenAI protocol default values to None in JSON FFI
 engine

---
 cpp/json_ffi/openai_api_protocol.h | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h
index bed225d3d0..1266e9cb93 100644
--- a/cpp/json_ffi/openai_api_protocol.h
+++ b/cpp/json_ffi/openai_api_protocol.h
@@ -90,8 +90,8 @@ class ChatCompletionRequest {
  public:
   std::vector<ChatCompletionMessage> messages;
   std::string model;
-  double frequency_penalty = 0.0;
-  double presence_penalty = 0.0;
+  std::optional<double> frequency_penalty = std::nullopt;
+  std::optional<double> presence_penalty = std::nullopt;
   bool logprobs = false;
   int top_logprobs = 0;
   std::optional<std::unordered_map<int, float>> logit_bias = std::nullopt;
@@ -100,8 +100,8 @@ class ChatCompletionRequest {
   std::optional<int> seed = std::nullopt;
   std::optional<std::vector<std::string>> stop = std::nullopt;
   bool stream = false;
-  double temperature = 1.0;
-  double top_p = 1.0;
+  std::optional<double> temperature = std::nullopt;
+  std::optional<double> top_p = std::nullopt;
   std::optional<std::vector<ChatTool>> tools = std::nullopt;
   std::optional<std::string> tool_choice = std::nullopt;
   std::optional<std::string> user = std::nullopt;
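Patch 1 only changes the request schema: frequency_penalty, presence_penalty, temperature, and top_p become std::optional<double>, so the engine can distinguish a field the caller omitted from one the caller explicitly set. A small Python sketch of why that distinction matters (the model name and values here are illustrative, not taken from the patch):

    import json

    base = {"model": "demo-model", "messages": [{"role": "user", "content": "Hi"}]}

    # Omitted temperature: should fall back to the model-defined default,
    # which patch 2 of this series wires up.
    request_a = json.dumps(base)

    # Explicit temperature 0.0: must be honored. With the old
    # `double temperature = 1.0;` member, both requests deserialized to a
    # plain double and the two cases were indistinguishable.
    request_b = json.dumps({**base, "temperature": 0.0})

    print(request_a, request_b, sep="\n")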
From 4ccd937f68f97f81ce6f457a4bc9362ad51bab15 Mon Sep 17 00:00:00 2001
From: Rick Zhou
Date: Thu, 25 Apr 2024 18:32:16 -0400
Subject: [PATCH 2/3] [JSONFFIEngine] Support generation config in
 JSONFFIEngine. Default config values to NOT_GIVEN

---
 cpp/json_ffi/{conv_template.cc => config.cc}  | 46 ++++++++++++++-
 cpp/json_ffi/{conv_template.h => config.h}    | 56 ++++++++++++++++++-
 cpp/json_ffi/json_ffi_engine.cc               | 10 ++--
 cpp/json_ffi/json_ffi_engine.h                |  3 +-
 cpp/json_ffi/openai_api_protocol.h            |  2 +-
 cpp/metadata/json_parser.h                    | 16 ++++++
 cpp/serve/config.cc                           | 24 +++++---
 cpp/serve/config.h                            | 12 ++--
 .../mlc_llm/protocol/openai_api_protocol.py   |  2 +-
 python/mlc_llm/serve/engine_base.py           | 22 ++++++++
 python/mlc_llm/support/auto_config.py         |  2 +-
 tests/python/json_ffi/_ffi_api.py             |  6 ++
 tests/python/json_ffi/test_json_ffi_engine.py | 52 +++++++++++++++--
 13 files changed, 224 insertions(+), 29 deletions(-)
 rename cpp/json_ffi/{conv_template.cc => config.cc} (85%)
 rename cpp/json_ffi/{conv_template.h => config.h} (67%)
 create mode 100644 tests/python/json_ffi/_ffi_api.py

diff --git a/cpp/json_ffi/conv_template.cc b/cpp/json_ffi/config.cc
similarity index 85%
rename from cpp/json_ffi/conv_template.cc
rename to cpp/json_ffi/config.cc
index 02e0b3bdbd..8f5c0e1062 100644
--- a/cpp/json_ffi/conv_template.cc
+++ b/cpp/json_ffi/config.cc
@@ -1,4 +1,6 @@
-#include "conv_template.h"
+#include "config.h"
+
+#include <tvm/runtime/registry.h>
 
 #include "../metadata/json_parser.h"
 
@@ -8,6 +10,29 @@ namespace json_ffi {
 
 using namespace mlc::llm;
 
+/****************** Model-defined generation config ******************/
+
+TVM_REGISTER_OBJECT_TYPE(ModelDefinedGenerationConfigNode);
+
+ModelDefinedGenerationConfig::ModelDefinedGenerationConfig(double temperature, double top_p,
+                                                           double frequency_penalty,
+                                                           double presence_penalty) {
+  ObjectPtr<ModelDefinedGenerationConfigNode> n = make_object<ModelDefinedGenerationConfigNode>();
+  n->temperature = temperature;
+  n->top_p = top_p;
+  n->frequency_penalty = frequency_penalty;
+  n->presence_penalty = presence_penalty;
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("mlc.json_ffi.ModelDefinedGenerationConfig")
+    .set_body_typed([](double temperature, double top_p, double frequency_penalty,
+                       double presence_penalty) {
+      return ModelDefinedGenerationConfig(temperature, top_p, frequency_penalty, presence_penalty);
+    });
+
+/****************** Conversation template ******************/
+
 std::map<MessagePlaceholders, std::string> PLACEHOLDERS = {
     {MessagePlaceholders::SYSTEM, "{system_message}"},
     {MessagePlaceholders::USER, "{user_message}"},
@@ -308,6 +333,25 @@ std::optional<Conversation> Conversation::FromJSON(const std::string& json_str,
   }
   return Conversation::FromJSON(json_obj.value(), err);
 }
+
+/****************** JSON FFI engine config ******************/
+
+TVM_REGISTER_OBJECT_TYPE(JSONFFIEngineConfigNode);
+
+JSONFFIEngineConfig::JSONFFIEngineConfig(
+    String conv_template, Map<String, ModelDefinedGenerationConfig> model_generation_cfgs) {
+  ObjectPtr<JSONFFIEngineConfigNode> n = make_object<JSONFFIEngineConfigNode>();
+  n->conv_template = conv_template;
+  n->model_generation_cfgs = model_generation_cfgs;
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("mlc.json_ffi.JSONFFIEngineConfig")
+    .set_body_typed([](String conv_template,
+                       Map<String, ModelDefinedGenerationConfig> model_generation_cfgs) {
+      return JSONFFIEngineConfig(std::move(conv_template), std::move(model_generation_cfgs));
+    });
+
 }  // namespace json_ffi
 }  // namespace llm
 }  // namespace mlc
diff --git a/cpp/json_ffi/conv_template.h b/cpp/json_ffi/config.h
similarity index 67%
rename from cpp/json_ffi/conv_template.h
rename to cpp/json_ffi/config.h
index d3a1d1de2f..78c3cb16a9 100644
--- a/cpp/json_ffi/conv_template.h
+++ b/cpp/json_ffi/config.h
@@ -1,5 +1,9 @@
-#ifndef MLC_LLM_JSON_FFI_CONV_TEMPLATE_H
-#define MLC_LLM_JSON_FFI_CONV_TEMPLATE_H
+#ifndef MLC_LLM_JSON_FFI_CONFIG_H
+#define MLC_LLM_JSON_FFI_CONFIG_H
+
+#include <tvm/runtime/container/map.h>
+#include <tvm/runtime/container/string.h>
+#include <tvm/runtime/object.h>
 
 #include <map>
 #include <optional>
@@ -18,6 +22,32 @@ namespace mlc {
 namespace llm {
 namespace json_ffi {
 
+/****************** Model-defined generation config ******************/
+
+class ModelDefinedGenerationConfigNode : public Object {
+ public:
+  double temperature;
+  double top_p;
+  double frequency_penalty;
+  double presence_penalty;
+
+  static constexpr const char* _type_key = "mlc.json_ffi.ModelDefinedGenerationConfig";
+  static constexpr const bool _type_has_method_sequal_reduce = false;
+  static constexpr const bool _type_has_method_shash_reduce = false;
+  TVM_DECLARE_BASE_OBJECT_INFO(ModelDefinedGenerationConfigNode, Object);
+};
+
+class ModelDefinedGenerationConfig : public ObjectRef {
+ public:
+  explicit ModelDefinedGenerationConfig(double temperature, double top_p, double frequency_penalty,
+                                        double presence_penalty);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ModelDefinedGenerationConfig, ObjectRef,
+                                ModelDefinedGenerationConfigNode);
+};
+
+/****************** Conversation template ******************/
+
 enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION };
 
 MessagePlaceholders messagePlaceholderFromString(const std::string& role);
@@ -114,6 +144,28 @@ struct Conversation {
   static std::optional<Conversation> FromJSON(const std::string& json_str, std::string* err);
 };
 
+/****************** JSON FFI engine config ******************/
+
+class JSONFFIEngineConfigNode : public Object {
+ public:
+  String conv_template;
+  Map<String, ModelDefinedGenerationConfig> model_generation_cfgs;
+
+  static constexpr const char* _type_key = "mlc.json_ffi.JSONFFIEngineConfig";
+  static constexpr const bool _type_has_method_sequal_reduce = false;
+  static constexpr const bool _type_has_method_shash_reduce = false;
+  TVM_DECLARE_BASE_OBJECT_INFO(JSONFFIEngineConfigNode, Object);
+};
+
+class JSONFFIEngineConfig : public ObjectRef {
+ public:
+  explicit JSONFFIEngineConfig(
+      String conv_template,
+      Map<String, ModelDefinedGenerationConfig> model_generation_cfgs);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(JSONFFIEngineConfig, ObjectRef, JSONFFIEngineConfigNode);
+};
+
 }  // namespace json_ffi
 }  // namespace llm
 }  // namespace mlc
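The two new objects mirror each other across the FFI boundary: ModelDefinedGenerationConfig holds one model's default sampling parameters, and JSONFFIEngineConfig bundles the conversation template with a per-model map of those defaults. A minimal sketch of reaching the constructors registered above from Python through TVM's global registry (hedged: it assumes a build with these globals loaded, and the model key and values are made up):

    import tvm

    # Constructors registered by TVM_REGISTER_GLOBAL in config.cc above.
    make_gen_cfg = tvm.get_global_func("mlc.json_ffi.ModelDefinedGenerationConfig")
    make_engine_cfg = tvm.get_global_func("mlc.json_ffi.JSONFFIEngineConfig")

    # Arguments: temperature, top_p, frequency_penalty, presence_penalty.
    gen_cfg = make_gen_cfg(0.6, 0.9, 0.0, 0.0)

    # "{}" stands in for a serialized conversation template.
    engine_cfg = make_engine_cfg("{}", {"demo-model": gen_cfg})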
diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc
index 0e21735e2f..1a21c2962d 100644
--- a/cpp/json_ffi/json_ffi_engine.cc
+++ b/cpp/json_ffi/json_ffi_engine.cc
@@ -83,8 +83,8 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
   Array<Data> inputs = inputs_obj.value();
 
   // generation_cfg
-  Optional<GenerationConfig> generation_cfg =
-      GenerationConfig::FromJSON(request_json_str, &err_, conv_template);
+  Optional<GenerationConfig> generation_cfg = GenerationConfig::Create(
+      request_json_str, &err_, conv_template, this->model_generation_cfgs[request.model]);
   if (!generation_cfg.defined()) {
     return false;
   }
@@ -122,14 +122,16 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
   TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop);
   TVM_MODULE_VTABLE_END();
 
-  void InitBackgroundEngine(std::string conv_template_str, EngineConfig engine_config,
+  void InitBackgroundEngine(JSONFFIEngineConfig json_ffi_engine_config, EngineConfig engine_config,
                             Optional<PackedFunc> request_stream_callback,
                             Optional<EventTraceRecorder> trace_recorder) {
-    std::optional<Conversation> conv_template = Conversation::FromJSON(conv_template_str, &err_);
+    std::optional<Conversation> conv_template =
+        Conversation::FromJSON(json_ffi_engine_config->conv_template, &err_);
     if (!conv_template.has_value()) {
       LOG(FATAL) << "Invalid conversation template JSON: " << err_;
     }
     this->conv_template_ = conv_template.value();
+    this->model_generation_cfgs = json_ffi_engine_config->model_generation_cfgs;
 
     // Todo(mlc-team): decouple InitBackgroundEngine into two functions
     // by removing `engine_config` from arguments, after properly handling
diff --git a/cpp/json_ffi/json_ffi_engine.h b/cpp/json_ffi/json_ffi_engine.h
index 2c7501c337..d57384abb5 100644
--- a/cpp/json_ffi/json_ffi_engine.h
+++ b/cpp/json_ffi/json_ffi_engine.h
@@ -12,7 +12,7 @@
 
 #include "../serve/threaded_engine.h"
 #include "../streamer.h"
-#include "conv_template.h"
+#include "config.h"
 #include "openai_api_protocol.h"
 
 namespace mlc {
@@ -49,6 +49,7 @@ class JSONFFIEngine {
   PackedFunc request_stream_callback_;
   TextStreamer streamer_;  // TODO: Support "n", and support different streamers for each request
   Conversation conv_template_;
+  Map<String, ModelDefinedGenerationConfig> model_generation_cfgs;
 };
 
 }  // namespace json_ffi
diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h
index 1266e9cb93..429050da3c 100644
--- a/cpp/json_ffi/openai_api_protocol.h
+++ b/cpp/json_ffi/openai_api_protocol.h
@@ -13,7 +13,7 @@
 #include <unordered_map>
 #include <vector>
 
-#include "conv_template.h"
+#include "config.h"
 #include "picojson.h"
 
 namespace mlc {
diff --git a/cpp/metadata/json_parser.h b/cpp/metadata/json_parser.h
index f6ff10e1ac..99a284fc42 100644
--- a/cpp/metadata/json_parser.h
+++ b/cpp/metadata/json_parser.h
@@ -149,6 +149,22 @@ inline ValueType Lookup(const picojson::object& json, const std::string& key) {
   return it->second.get<ValueType>();
 }
 
+template <typename ValueType>
+inline ValueType LookupOrDefault(const picojson::object& json, const std::string& key,
+                                 const ValueType& default_value) {
+  auto it = json.find(key);
+  if (it == json.end()) {
+    return default_value;
+  }
+
+  if (it->second.is<picojson::null>()) {
+    return default_value;
+  }
+
+  CHECK(it->second.is<ValueType>()) << "ValueError: key `" << key << "` has unexpected type";
+  return it->second.get<ValueType>();
+}
+
 template <typename ValueType>
 inline ValueType Lookup(const picojson::array& json, int index) {
   CHECK(index < json.size()) << "IndexError: json::array index out of range";
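LookupOrDefault is what makes the optional request fields work end to end: a key that is absent or explicitly null yields the supplied default, while a present key of the wrong type still trips the CHECK rather than being coerced. A hypothetical Python rendering of those semantics, for intuition only (not part of the patch):

    def lookup_or_default(obj: dict, key: str, default):
        value = obj.get(key)  # missing key and explicit JSON null both give None
        if value is None:
            return default
        if not isinstance(value, type(default)):  # the C++ helper CHECK-fails here
            raise TypeError(f"key `{key}` has unexpected type")
        return value

    assert lookup_or_default({}, "temperature", 0.7) == 0.7
    assert lookup_or_default({"temperature": None}, "temperature", 0.7) == 0.7
    assert lookup_or_default({"temperature": 0.1}, "temperature", 0.7) == 0.1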
diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc
index f36bc151a3..19f26ff624 100644
--- a/cpp/serve/config.cc
+++ b/cpp/serve/config.cc
@@ -161,15 +161,26 @@ GenerationConfig::GenerationConfig(String config_json_str) {
   data_ = std::move(n);
 }
 
-Optional<GenerationConfig> GenerationConfig::FromJSON(const std::string& json_str, std::string* err,
-                                                      const Conversation& conv_template) {
-  std::optional<picojson::object> json_obj = json::LoadJSONFromString(json_str, err);
-  if (!err->empty() || !json_obj.has_value()) {
+Optional<GenerationConfig> GenerationConfig::Create(
+    const std::string& json_str, std::string* err, const Conversation& conv_template,
+    const ModelDefinedGenerationConfig& model_defined_gen_config) {
+  std::optional<picojson::object> optional_json_obj = json::LoadJSONFromString(json_str, err);
+  if (!err->empty() || !optional_json_obj.has_value()) {
     return NullOpt;
   }
+  picojson::object& json_obj = optional_json_obj.value();
   ObjectPtr<GenerationConfigNode> n = make_object<GenerationConfigNode>();
 
-  // TODO(mlc-team): Pass the parameters from `json_obj` to `n`.
+  n->temperature =
+      json::LookupOrDefault<double>(json_obj, "temperature", model_defined_gen_config->temperature);
+  n->top_p = json::LookupOrDefault<double>(json_obj, "top_p", model_defined_gen_config->top_p);
+  n->frequency_penalty = json::LookupOrDefault<double>(json_obj, "frequency_penalty",
+                                                       model_defined_gen_config->frequency_penalty);
+  n->presence_penalty = json::LookupOrDefault<double>(json_obj, "presence_penalty",
+                                                      model_defined_gen_config->presence_penalty);
+  n->logprobs = json::LookupOrDefault<bool>(json_obj, "logprobs", false);
+  n->top_logprobs = static_cast<int>(json::LookupOrDefault<int64_t>(json_obj, "top_logprobs", 0));
+  n->ignore_eos = json::LookupOrDefault<bool>(json_obj, "ignore_eos", false);
 
   // Copy stop str from conversation template to generation config
   for (auto& stop_str : conv_template.stop_str) {
@@ -179,9 +190,6 @@ Optional<GenerationConfig> GenerationConfig::FromJSON(const std::string& json_st
     n->stop_token_ids.push_back(stop_token_id);
   }
 
-  if (!err->empty()) {
-    return NullOpt;
-  }
   GenerationConfig gen_config;
   gen_config.data_ = std::move(n);
   return gen_config;
diff --git a/cpp/serve/config.h b/cpp/serve/config.h
index ef147b751b..6a3bdd8997 100644
--- a/cpp/serve/config.h
+++ b/cpp/serve/config.h
@@ -11,7 +11,7 @@
 
 #include <tvm/runtime/container/string.h>
 
-#include "../json_ffi/conv_template.h"
+#include "../json_ffi/config.h"
 
 namespace mlc {
 namespace llm {
@@ -63,11 +63,13 @@ class GenerationConfig : public ObjectRef {
   explicit GenerationConfig(String config_json_str);
 
   /*!
-   * \brief Parse the generation config from the given JSON string.
-   * When parsing fails, errors are dumped to the input error string, and NullOpt is returned.
+   * \brief Create a generation config from a ChatCompletionRequest.
+   * Any sampling field that the request leaves unset falls back to the
+   * model-defined generation config.
    */
-  static Optional<GenerationConfig> FromJSON(const std::string& json_str, std::string* err,
-                                             const Conversation& conv_template);
+  static Optional<GenerationConfig> Create(
+      const std::string& json_str, std::string* err, const Conversation& conv_template,
+      const ModelDefinedGenerationConfig& model_defined_gen_config);
 
   TVM_DEFINE_OBJECT_REF_METHODS(GenerationConfig, ObjectRef, GenerationConfigNode);
 };
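GenerationConfig::Create resolves each sampling field independently: an explicit request value wins, an omitted or null field falls back to the model-defined config, and logprobs, top_logprobs, and ignore_eos fall back to fixed constants rather than model defaults. In illustrative Python (names invented for this sketch):

    def resolve_generation_config(request: dict, model_defaults: dict) -> dict:
        resolved = {}
        for key in ("temperature", "top_p", "frequency_penalty", "presence_penalty"):
            value = request.get(key)  # absent or explicit null -> None
            resolved[key] = model_defaults[key] if value is None else value
        # These fall back to constants, not to the model-defined config.
        resolved["logprobs"] = request.get("logprobs") or False
        resolved["top_logprobs"] = request.get("top_logprobs") or 0
        resolved["ignore_eos"] = request.get("ignore_eos") or False
        return resolved

    defaults = {"temperature": 0.6, "top_p": 0.9,
                "frequency_penalty": 0.0, "presence_penalty": 0.0}
    print(resolve_generation_config({"top_p": 0.5}, defaults))
    # temperature falls back to 0.6; the explicit top_p=0.5 is kept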
diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py
index d6ce4a4fcb..4a5168f971 100644
--- a/python/mlc_llm/protocol/openai_api_protocol.py
+++ b/python/mlc_llm/protocol/openai_api_protocol.py
@@ -223,7 +223,7 @@ class ChatCompletionRequest(BaseModel):
     @classmethod
     def check_penalty_range(cls, penalty_value: float) -> float:
         """Check if the penalty value is in range [-2, 2]."""
-        if penalty_value < -2 or penalty_value > 2:
+        if penalty_value and (penalty_value < -2 or penalty_value > 2):
             raise ValueError("Penalty value should be in range [-2, 2].")
         return penalty_value
 
diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py
index 5d62dd5fb1..e6d52ab421 100644
--- a/python/mlc_llm/serve/engine_base.py
+++ b/python/mlc_llm/serve/engine_base.py
@@ -706,6 +706,28 @@ def _infer_kv_cache_config(  # pylint: disable=too-many-arguments,too-many-local
     )
 
 
+def _infer_generation_config(
+    model_config_dicts: List[Dict[str, Any]]
+) -> List[Tuple[float, float, float, float]]:
+    """Infer the generation config from the model config dictionaries.
+    The returned four floats are:
+    - temperature
+    - top_p
+    - frequency_penalty
+    - presence_penalty
+    """
+    generation_configs = []
+
+    for model_config in model_config_dicts:
+        temperature = model_config.get("temperature", 1.0)
+        top_p = model_config.get("top_p", 1.0)
+        frequency_penalty = model_config.get("frequency_penalty", 0.0)
+        presence_penalty = model_config.get("presence_penalty", 0.0)
+        generation_configs.append((temperature, top_p, frequency_penalty, presence_penalty))
+
+    return generation_configs
+
+
 @dataclass
 class CallbackStreamOutput:
     """The output of MLCEngine._generate and AsyncMLCEngine._generate
diff --git a/python/mlc_llm/support/auto_config.py b/python/mlc_llm/support/auto_config.py
index f0247a6ef9..be0ee8af98 100644
--- a/python/mlc_llm/support/auto_config.py
+++ b/python/mlc_llm/support/auto_config.py
@@ -62,7 +62,7 @@ def detect_mlc_chat_config(mlc_chat_config: str) -> Path:
         # search mlc-chat-config.json under path
         mlc_chat_config_json_path = mlc_chat_config_path / "mlc-chat-config.json"
         if not mlc_chat_config_json_path.exists():
-            raise ValueError(f"Fail to find mlc_chat_config.json under {mlc_chat_config_path}.")
+            raise ValueError(f"Fail to find mlc-chat-config.json under {mlc_chat_config_path}.")
     else:
         mlc_chat_config_json_path = mlc_chat_config_path
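On the Python side, _infer_generation_config reads the per-model defaults out of each model's mlc-chat-config.json dictionary, substituting the OpenAI defaults (1.0, 1.0, 0.0, 0.0) for missing fields. A quick usage sketch (the helper is module-private and the input dict is illustrative):

    from mlc_llm.serve.engine_base import _infer_generation_config

    # Only temperature is present, so the other three fields fall back.
    configs = _infer_generation_config([{"temperature": 0.6}])
    assert configs == [(0.6, 1.0, 0.0, 0.0)]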
diff --git a/tests/python/json_ffi/_ffi_api.py b/tests/python/json_ffi/_ffi_api.py
new file mode 100644
index 0000000000..3df07d6a1f
--- /dev/null
+++ b/tests/python/json_ffi/_ffi_api.py
@@ -0,0 +1,6 @@
+"""FFI APIs for mlc.json_ffi"""
+import tvm._ffi
+
+# Exports functions registered via TVM_REGISTER_GLOBAL with the "mlc.json_ffi" prefix.
+# e.g. TVM_REGISTER_GLOBAL("mlc.json_ffi.JSONFFIEngineConfig")
+tvm._ffi._init_api("mlc.json_ffi", __name__)  # pylint: disable=protected-access
diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py
index c0c749c0a7..1b433c61b3 100644
--- a/tests/python/json_ffi/test_json_ffi_engine.py
+++ b/tests/python/json_ffi/test_json_ffi_engine.py
@@ -19,6 +19,8 @@
 )
 from mlc_llm.tokenizer import Tokenizer
 
+from tests.python.json_ffi import _ffi_api
+
 chat_completion_prompts = [
     "What is the meaning of life?",
     "Introduce the history of Pittsburgh to me. Please elaborate in detail.",
@@ -60,6 +62,32 @@
 ]
 
 
+@tvm._ffi.register_object(
+    "mlc.json_ffi.ModelDefinedGenerationConfig"
+)  # pylint: disable=protected-access
+class ModelDefinedGenerationConfig(tvm.runtime.Object):
+    def __init__(  # pylint: disable=too-many-arguments
+        self, temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float
+    ) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.ModelDefinedGenerationConfig,
+            temperature,
+            top_p,
+            frequency_penalty,
+            presence_penalty,
+        )
+
+
+@tvm._ffi.register_object("mlc.json_ffi.JSONFFIEngineConfig")  # pylint: disable=protected-access
+class JSONFFIEngineConfig(tvm.runtime.Object):
+    def __init__(  # pylint: disable=too-many-arguments
+        self, conv_template: str, model_generation_cfgs: Dict[str, ModelDefinedGenerationConfig]
+    ) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.JSONFFIEngineConfig, conv_template, model_generation_cfgs
+        )
+
+
 class EngineState:
     sync_queue: queue.Queue
 
@@ -171,8 +199,22 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
             speculative_mode=speculative_mode,
             spec_draft_length=spec_draft_length,
         )
+
+        self.json_ffi_engine_config = JSONFFIEngineConfig(
+            conv_template=self.conv_template.model_dump_json(),
+            model_generation_cfgs={
+                model.model: ModelDefinedGenerationConfig(
+                    temperature=model_config["temperature"],
+                    top_p=model_config["top_p"],
+                    frequency_penalty=model_config["frequency_penalty"],
+                    presence_penalty=model_config["presence_penalty"],
+                )
+                for model, model_config in zip(models, self.model_config_dicts)
+            },
+        )
+
         self._ffi["init_background_engine"](
-            self.conv_template.model_dump_json(),
+            self.json_ffi_engine_config,
             self.engine_config,
             self.state.get_request_stream_callback(),
             None,
@@ -204,8 +246,8 @@ def chat_completion(  # pylint: disable=too-many-arguments,too-many-locals
         *,
         messages: List[Dict[str, Any]],
         model: str,
-        frequency_penalty: float = 0.0,
-        presence_penalty: float = 0.0,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
         logprobs: bool = False,
         top_logprobs: int = 0,
         logit_bias: Optional[Dict[int, float]] = None,
@@ -214,8 +256,8 @@ def chat_completion(  # pylint: disable=too-many-arguments,too-many-locals
         seed: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stream: bool = False,
-        temperature: float = 1.0,
-        top_p: float = 1.0,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None,
         user: Optional[str] = None,
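The test wrappers close the loop: the engine constructor packs one ModelDefinedGenerationConfig per model into a JSONFFIEngineConfig and hands it to init_background_engine, while chat_completion's sampling arguments now default to None so unset values flow through as not-given. A hedged usage sketch (assumes an engine built as in this test file, including a max_tokens keyword; the model id is hypothetical and output handling is elided):

    # `engine` is a JSONFFIEngine from this test. Arguments left as None
    # inherit the model-defined generation config.
    for responses in engine.chat_completion(
        messages=[{"role": "user", "content": "What is the meaning of life?"}],
        model="demo-model",
        presence_penalty=0.5,  # explicit value overrides the model default
        max_tokens=64,
    ):
        for choice in responses.choices:
            print(choice.delta.content, end="", flush=True)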
From 7bf6b4d989594c421908a54deda32b07d3dd92bd Mon Sep 17 00:00:00 2001
From: Rick Zhou
Date: Thu, 25 Apr 2024 20:56:04 -0400
Subject: [PATCH 3/3] Fix lint

---
 cpp/json_ffi/config.h                         | 5 ++---
 tests/python/json_ffi/test_json_ffi_engine.py | 3 +--
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/cpp/json_ffi/config.h b/cpp/json_ffi/config.h
index 78c3cb16a9..fe5e4e42e2 100644
--- a/cpp/json_ffi/config.h
+++ b/cpp/json_ffi/config.h
@@ -159,9 +159,8 @@ class JSONFFIEngineConfigNode : public Object {
 
 class JSONFFIEngineConfig : public ObjectRef {
  public:
-  explicit JSONFFIEngineConfig(
-      String conv_template,
-      Map<String, ModelDefinedGenerationConfig> model_generation_cfgs);
+  explicit JSONFFIEngineConfig(String conv_template,
+                               Map<String, ModelDefinedGenerationConfig> model_generation_cfgs);
 
   TVM_DEFINE_OBJECT_REF_METHODS(JSONFFIEngineConfig, ObjectRef, JSONFFIEngineConfigNode);
 };
diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py
index 1b433c61b3..f5235663be 100644
--- a/tests/python/json_ffi/test_json_ffi_engine.py
+++ b/tests/python/json_ffi/test_json_ffi_engine.py
@@ -6,6 +6,7 @@
 from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union
 
 import tvm
+from tests.python.json_ffi import _ffi_api
 
 from mlc_llm.protocol import openai_api_protocol
 from mlc_llm.serve import engine_utils
@@ -19,8 +20,6 @@
 )
 from mlc_llm.tokenizer import Tokenizer
 
-from tests.python.json_ffi import _ffi_api
-
 chat_completion_prompts = [
     "What is the meaning of life?",
     "Introduce the history of Pittsburgh to me. Please elaborate in detail.",