Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JSONFFIEngine] Support generation config in JSONFFIEngine. Default config values to NOT_GIVEN #2225

Merged
merged 3 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion cpp/json_ffi/conv_template.cc → cpp/json_ffi/config.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include "conv_template.h"
#include "config.h"

#include <tvm/runtime/registry.h>

#include "../metadata/json_parser.h"

Expand All @@ -8,6 +10,29 @@ namespace json_ffi {

using namespace mlc::llm;

/****************** Model-defined generation config ******************/

TVM_REGISTER_OBJECT_TYPE(ModelDefinedGenerationConfigNode);

ModelDefinedGenerationConfig::ModelDefinedGenerationConfig(double temperature, double top_p,
                                                           double frequency_penalty,
                                                           double presence_penalty) {
  // Allocate the node and fill in the model-level sampling defaults.
  ObjectPtr<ModelDefinedGenerationConfigNode> node =
      make_object<ModelDefinedGenerationConfigNode>();
  node->presence_penalty = presence_penalty;
  node->frequency_penalty = frequency_penalty;
  node->top_p = top_p;
  node->temperature = temperature;
  data_ = std::move(node);
}

// FFI hook: exposes the constructor to Python as the packed function
// "mlc.json_ffi.ModelDefinedGenerationConfig" (see tests/python/json_ffi/_ffi_api.py).
TVM_REGISTER_GLOBAL("mlc.json_ffi.ModelDefinedGenerationConfig")
.set_body_typed([](double temperature, double top_p, double frequency_penalty,
double presence_penalty) {
return ModelDefinedGenerationConfig(temperature, top_p, frequency_penalty, presence_penalty);
});

/****************** Conversation template ******************/

std::map<MessagePlaceholders, std::string> PLACEHOLDERS = {
{MessagePlaceholders::SYSTEM, "{system_message}"},
{MessagePlaceholders::USER, "{user_message}"},
Expand Down Expand Up @@ -308,6 +333,25 @@ std::optional<Conversation> Conversation::FromJSON(const std::string& json_str,
}
return Conversation::FromJSON(json_obj.value(), err);
}

/****************** JSON FFI engine config ******************/

TVM_REGISTER_OBJECT_TYPE(JSONFFIEngineConfigNode);

JSONFFIEngineConfig::JSONFFIEngineConfig(
    String conv_template, Map<String, ModelDefinedGenerationConfig> model_generation_cfgs) {
  // Both parameters are by-value TVM ObjectRefs; move them into the node
  // instead of copying to avoid redundant reference-count updates.
  ObjectPtr<JSONFFIEngineConfigNode> n = make_object<JSONFFIEngineConfigNode>();
  n->conv_template = std::move(conv_template);
  n->model_generation_cfgs = std::move(model_generation_cfgs);
  data_ = std::move(n);
}

// FFI hook: exposes the constructor to Python as the packed function
// "mlc.json_ffi.JSONFFIEngineConfig".
TVM_REGISTER_GLOBAL("mlc.json_ffi.JSONFFIEngineConfig")
.set_body_typed([](String conv_template,
Map<String, ModelDefinedGenerationConfig> model_generation_cfgs) {
return JSONFFIEngineConfig(std::move(conv_template), std::move(model_generation_cfgs));
});

} // namespace json_ffi
} // namespace llm
} // namespace mlc
55 changes: 53 additions & 2 deletions cpp/json_ffi/conv_template.h → cpp/json_ffi/config.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#ifndef MLC_LLM_JSON_FFI_CONV_TEMPLATE_H
#define MLC_LLM_JSON_FFI_CONV_TEMPLATE_H
#ifndef MLC_LLM_JSON_FFI_CONFIG_H
#define MLC_LLM_JSON_FFI_CONFIG_H

#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/object.h>

#include <iostream>
#include <map>
Expand All @@ -18,6 +22,32 @@ namespace mlc {
namespace llm {
namespace json_ffi {

/****************** Model-defined generation config ******************/

/*!
 * \brief Per-model default generation parameters.
 * These serve as fallback values when an incoming request omits the
 * corresponding field (consumed by GenerationConfig::Create).
 */
class ModelDefinedGenerationConfigNode : public Object {
public:
double temperature;        // default sampling temperature
double top_p;              // default nucleus-sampling probability mass
double frequency_penalty;  // default frequency penalty
double presence_penalty;   // default presence penalty

static constexpr const char* _type_key = "mlc.json_ffi.ModelDefinedGenerationConfig";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(ModelDefinedGenerationConfigNode, Object);
};

/*! \brief Managed reference to ModelDefinedGenerationConfigNode. */
class ModelDefinedGenerationConfig : public ObjectRef {
public:
explicit ModelDefinedGenerationConfig(double temperature, double top_p, double frequency_penalty,
double presence_penalty);

TVM_DEFINE_OBJECT_REF_METHODS(ModelDefinedGenerationConfig, ObjectRef,
ModelDefinedGenerationConfigNode);
};

/****************** Conversation template ******************/

enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION };

MessagePlaceholders messagePlaceholderFromString(const std::string& role);
Expand Down Expand Up @@ -114,6 +144,27 @@ struct Conversation {
static std::optional<Conversation> FromJSON(const std::string& json_str, std::string* err);
};

/****************** JSON FFI engine config ******************/

/*!
 * \brief Configuration for the JSON FFI engine.
 * Bundles the conversation template (as a JSON string) with one
 * model-defined generation config per model, keyed by model id.
 */
class JSONFFIEngineConfigNode : public Object {
public:
String conv_template;  // conversation template serialized as JSON
// Per-model generation defaults, keyed by model id (looked up with the
// request's "model" field in JSONFFIEngine::AddRequest).
Map<String, ModelDefinedGenerationConfig> model_generation_cfgs;

static constexpr const char* _type_key = "mlc.json_ffi.JSONFFIEngineConfig";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(JSONFFIEngineConfigNode, Object);
};

/*! \brief Managed reference to JSONFFIEngineConfigNode. */
class JSONFFIEngineConfig : public ObjectRef {
public:
explicit JSONFFIEngineConfig(String conv_template,
Map<String, ModelDefinedGenerationConfig> model_generation_cfgs);

TVM_DEFINE_OBJECT_REF_METHODS(JSONFFIEngineConfig, ObjectRef, JSONFFIEngineConfigNode);
};

} // namespace json_ffi
} // namespace llm
} // namespace mlc
Expand Down
10 changes: 6 additions & 4 deletions cpp/json_ffi/json_ffi_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
Array<Data> inputs = inputs_obj.value();

// generation_cfg
Optional<GenerationConfig> generation_cfg =
GenerationConfig::FromJSON(request_json_str, &err_, conv_template);
Optional<GenerationConfig> generation_cfg = GenerationConfig::Create(
request_json_str, &err_, conv_template, this->model_generation_cfgs[request.model]);
if (!generation_cfg.defined()) {
return false;
}
Expand Down Expand Up @@ -122,14 +122,16 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop);
TVM_MODULE_VTABLE_END();

void InitBackgroundEngine(std::string conv_template_str, EngineConfig engine_config,
void InitBackgroundEngine(JSONFFIEngineConfig json_ffi_engine_config, EngineConfig engine_config,
Optional<PackedFunc> request_stream_callback,
Optional<EventTraceRecorder> trace_recorder) {
std::optional<Conversation> conv_template = Conversation::FromJSON(conv_template_str, &err_);
std::optional<Conversation> conv_template =
Conversation::FromJSON(json_ffi_engine_config->conv_template, &err_);
if (!conv_template.has_value()) {
LOG(FATAL) << "Invalid conversation template JSON: " << err_;
}
this->conv_template_ = conv_template.value();
this->model_generation_cfgs = json_ffi_engine_config->model_generation_cfgs;

// Todo(mlc-team): decouple InitBackgroundEngine into two functions
// by removing `engine_config` from arguments, after properly handling
Expand Down
3 changes: 2 additions & 1 deletion cpp/json_ffi/json_ffi_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

#include "../serve/threaded_engine.h"
#include "../streamer.h"
#include "conv_template.h"
#include "config.h"
#include "openai_api_protocol.h"

namespace mlc {
Expand Down Expand Up @@ -49,6 +49,7 @@ class JSONFFIEngine {
PackedFunc request_stream_callback_;
TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request
Conversation conv_template_;
Map<String, ModelDefinedGenerationConfig> model_generation_cfgs;
};

} // namespace json_ffi
Expand Down
10 changes: 5 additions & 5 deletions cpp/json_ffi/openai_api_protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <unordered_map>
#include <vector>

#include "conv_template.h"
#include "config.h"
#include "picojson.h"

namespace mlc {
Expand Down Expand Up @@ -90,8 +90,8 @@ class ChatCompletionRequest {
public:
std::vector<ChatCompletionMessage> messages;
std::string model;
double frequency_penalty = 0.0;
double presence_penalty = 0.0;
std::optional<double> frequency_penalty = std::nullopt;
std::optional<double> presence_penalty = std::nullopt;
bool logprobs = false;
int top_logprobs = 0;
std::optional<std::unordered_map<int, double>> logit_bias = std::nullopt;
Expand All @@ -100,8 +100,8 @@ class ChatCompletionRequest {
std::optional<int> seed = std::nullopt;
std::optional<std::vector<std::string>> stop = std::nullopt;
bool stream = false;
double temperature = 1.0;
double top_p = 1.0;
std::optional<double> temperature = std::nullopt;
std::optional<double> top_p = std::nullopt;
std::optional<std::vector<ChatTool>> tools = std::nullopt;
std::optional<std::string> tool_choice = std::nullopt;
std::optional<std::string> user = std::nullopt;
Expand Down
16 changes: 16 additions & 0 deletions cpp/metadata/json_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,22 @@ inline ValueType Lookup(const picojson::object& json, const std::string& key) {
return it->second.get<ValueType>();
}

/*!
 * \brief Look up \p key in \p json, falling back to \p default_value when the
 * key is absent or its value is JSON null. Aborts (CHECK) on a type mismatch.
 */
template <typename ValueType>
inline ValueType LookupOrDefault(const picojson::object& json, const std::string& key,
                                 const ValueType& default_value) {
  auto it = json.find(key);
  // A missing key and an explicit null are treated the same way.
  if (it == json.end() || it->second.is<picojson::null>()) {
    return default_value;
  }
  CHECK(it->second.is<ValueType>()) << "ValueError: key `" << key << "` has unexpected type";
  return it->second.get<ValueType>();
}

template <typename ValueType>
inline ValueType Lookup(const picojson::array& json, int index) {
CHECK(index < json.size()) << "IndexError: json::array index out of range";
Expand Down
24 changes: 16 additions & 8 deletions cpp/serve/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,26 @@ GenerationConfig::GenerationConfig(String config_json_str) {
data_ = std::move(n);
}

Optional<GenerationConfig> GenerationConfig::FromJSON(const std::string& json_str, std::string* err,
const Conversation& conv_template) {
std::optional<picojson::object> json_obj = json::LoadJSONFromString(json_str, err);
if (!err->empty() || !json_obj.has_value()) {
Optional<GenerationConfig> GenerationConfig::Create(
const std::string& json_str, std::string* err, const Conversation& conv_template,
const ModelDefinedGenerationConfig& model_defined_gen_config) {
std::optional<picojson::object> optional_json_obj = json::LoadJSONFromString(json_str, err);
if (!err->empty() || !optional_json_obj.has_value()) {
return NullOpt;
}
picojson::object& json_obj = optional_json_obj.value();
ObjectPtr<GenerationConfigNode> n = make_object<GenerationConfigNode>();

// TODO(mlc-team): Pass the parameters from `json_obj` to `n`.
n->temperature =
json::LookupOrDefault<double>(json_obj, "temperature", model_defined_gen_config->temperature);
n->top_p = json::LookupOrDefault<double>(json_obj, "top_p", model_defined_gen_config->top_p);
n->frequency_penalty = json::LookupOrDefault<double>(json_obj, "frequency_penalty",
model_defined_gen_config->frequency_penalty);
n->presence_penalty = json::LookupOrDefault<double>(json_obj, "presence_penalty",
model_defined_gen_config->presence_penalty);
n->logprobs = json::LookupOrDefault<bool>(json_obj, "logprobs", false);
n->top_logprobs = static_cast<int>(json::LookupOrDefault<double>(json_obj, "top_logprobs", 0));
n->ignore_eos = json::LookupOrDefault<bool>(json_obj, "ignore_eos", false);

// Copy stop str from conversation template to generation config
for (auto& stop_str : conv_template.stop_str) {
Expand All @@ -179,9 +190,6 @@ Optional<GenerationConfig> GenerationConfig::FromJSON(const std::string& json_st
n->stop_token_ids.push_back(stop_token_id);
}

if (!err->empty()) {
return NullOpt;
}
GenerationConfig gen_config;
gen_config.data_ = std::move(n);
return gen_config;
Expand Down
12 changes: 7 additions & 5 deletions cpp/serve/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

#include <optional>

#include "../json_ffi/conv_template.h"
#include "../json_ffi/config.h"

namespace mlc {
namespace llm {
Expand Down Expand Up @@ -63,11 +63,13 @@ class GenerationConfig : public ObjectRef {
explicit GenerationConfig(String config_json_str);

/*!
* \brief Parse the generation config from the given JSON string.
* When parsing fails, errors are dumped to the input error string, and NullOpt is returned.
* \brief Create a generation config from a ChatCompletionRequest.
* If the request does not contain a generation config, the model-defined
* generation config will be used.
*/
static Optional<GenerationConfig> FromJSON(const std::string& json_str, std::string* err,
const Conversation& conv_template);
static Optional<GenerationConfig> Create(
const std::string& json_str, std::string* err, const Conversation& conv_template,
const ModelDefinedGenerationConfig& model_defined_gen_config);

TVM_DEFINE_OBJECT_REF_METHODS(GenerationConfig, ObjectRef, GenerationConfigNode);
};
Expand Down
2 changes: 1 addition & 1 deletion python/mlc_llm/protocol/openai_api_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class ChatCompletionRequest(BaseModel):
@classmethod
def check_penalty_range(cls, penalty_value: float) -> float:
"""Check if the penalty value is in range [-2, 2]."""
if penalty_value < -2 or penalty_value > 2:
if penalty_value and (penalty_value < -2 or penalty_value > 2):
raise ValueError("Penalty value should be in range [-2, 2].")
return penalty_value

Expand Down
22 changes: 22 additions & 0 deletions python/mlc_llm/serve/engine_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,28 @@ def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-local
)


def _infer_generation_config(
model_config_dicts: List[Dict[str, Any]]
) -> List[Tuple[float, float, float, float]]:
"""Infer the generation config from the model config dictionaries.
The returned four floats are:
- temperature
- top_p
- frequency_penalty
- presence_penalty
"""
generation_configs = []

for model_config in model_config_dicts:
temperature = model_config.get("temperature", 1.0)
top_p = model_config.get("top_p", 1.0)
frequency_penalty = model_config.get("frequency_penalty", 0.0)
presence_penalty = model_config.get("presence_penalty", 0.0)
generation_configs.append((temperature, top_p, frequency_penalty, presence_penalty))

return generation_configs


@dataclass
class CallbackStreamOutput:
"""The output of MLCEngine._generate and AsyncMLCEngine._generate
Expand Down
2 changes: 1 addition & 1 deletion python/mlc_llm/support/auto_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def detect_mlc_chat_config(mlc_chat_config: str) -> Path:
# search mlc-chat-config.json under path
mlc_chat_config_json_path = mlc_chat_config_path / "mlc-chat-config.json"
if not mlc_chat_config_json_path.exists():
raise ValueError(f"Fail to find mlc_chat_config.json under {mlc_chat_config_path}.")
raise ValueError(f"Fail to find mlc-chat-config.json under {mlc_chat_config_path}.")
else:
mlc_chat_config_json_path = mlc_chat_config_path

Expand Down
6 changes: 6 additions & 0 deletions tests/python/json_ffi/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""FFI APIs for mlc.json_ffi"""
import tvm._ffi

# Exports functions registered via TVM_REGISTER_GLOBAL with the "mlc.json_ffi" prefix,
# e.g. TVM_REGISTER_GLOBAL("mlc.json_ffi.JSONFFIEngineConfig"), as module attributes.
tvm._ffi._init_api("mlc.json_ffi", __name__)  # pylint: disable=protected-access
Loading
Loading