Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 committed Nov 11, 2023
1 parent c12fe04 commit a928fdd
Show file tree
Hide file tree
Showing 16 changed files with 1,587 additions and 268 deletions.
421 changes: 265 additions & 156 deletions cpp/serve/engine.cc

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions cpp/serve/function_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ void FunctionTable::_InitFunctions() {
get_global_func("vm.builtin.paged_attention_kv_cache_sync_aux_array_to_device");
this->remove_from_kv_cache_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_remove");
this->popn_from_kv_cache_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_popn");
this->get_num_available_pages_kv_cache_func_ =
get_global_func("vm.builtin.paged_attention_kv_cache_get_num_available_pages");
support_backtracking_kv_ = true;
}

Expand Down
1 change: 1 addition & 0 deletions cpp/serve/function_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ struct FunctionTable {
PackedFunc sync_device_kv_cache_func_;
PackedFunc remove_from_kv_cache_func_;
PackedFunc popn_from_kv_cache_func_;
PackedFunc get_num_available_pages_kv_cache_func_;
};

} // namespace serve
Expand Down
154 changes: 96 additions & 58 deletions cpp/serve/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ class ModelModule : public ModuleNode {
CHECK_EQ(args.size(), 1);
*rv = TokenEmbed(args[0]);
});
} else if (name == "single_seq_prefill") {
} else if (name == "batch_prefill") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 2);
*rv = SingleSequencePrefill(args[0], args[1]);
CHECK_EQ(args.size(), 3);
*rv = BatchPrefill(args[0], args[1], args[2]);
});
} else if (name == "decode") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
Expand Down Expand Up @@ -115,15 +115,18 @@ class ModelModule : public ModuleNode {
ICHECK_EQ(args.size(), 0);
Reset();
});
} else if (name == "get_num_available_pages") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size(), 0);
ICHECK(kv_cache_.defined());
*rv = ft_.get_num_available_pages_kv_cache_func_(kv_cache_);
});
} else if (name == "get_max_window_size") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size(), 0);
CHECK_NE(max_window_size_, -1) << "The model has not been initialized";
*rv = max_window_size_;
});
} else if (name == "runtime_stats_text") {
// Todo: JSON style
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = GetStats(); });
} else {
return PackedFunc(nullptr);
}
Expand Down Expand Up @@ -156,7 +159,8 @@ class ModelModule : public ModuleNode {
}
// Copy input token ids to device.
DLDataType dtype(DataType::Int(32));
NDArray token_ids_nd = CopyArrayToDevice(flattened_token_ids, &input_token_ids_, dtype, 2048);
NDArray token_ids_nd =
CopyArrayToDevice(flattened_token_ids, &input_token_ids_, dtype, max_window_size_);
ICHECK_EQ(token_ids_nd->ndim, 1);
ICHECK_EQ(token_ids_nd->shape[0], total_length);
token_ids_nd = token_ids_nd.CreateView({1, total_length}, dtype);
Expand All @@ -165,12 +169,7 @@ class ModelModule : public ModuleNode {
<< "`embed` function is not found in the model. Please make sure the model is compiled "
"with flag `--sep-embed` and `--enable-batching`";

auto tstart = std::chrono::high_resolution_clock::now();
NDArray embeddings = ft_.embed_func_(ft_.CopyToWorker0(token_ids_nd), params_);
auto tend = std::chrono::high_resolution_clock::now();

this->embed_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->embed_total_tokens += total_length;

// embeddings: (1, total_length, hidden_size)
ICHECK_EQ(embeddings->ndim, 3);
Expand All @@ -185,12 +184,30 @@ class ModelModule : public ModuleNode {
* \param seq_id The id of the sequence in the KV cache.
* \return The logits for the next token.
*/
NDArray SingleSequencePrefill(NDArray embeddings, int seq_id) {
NDArray BatchPrefill(Array<NDArray> embedding_arr, ShapeTuple seq_ids, ShapeTuple lengths) {
CHECK(!seq_ids.empty());
CHECK_EQ(seq_ids.size(), lengths.size());
int num_sequences = seq_ids.size();
int total_length = 0;
std::vector<int> logit_pos;
logit_pos.reserve(num_sequences);
for (int i = 0; i < num_sequences; ++i) {
total_length += lengths[i];
logit_pos.push_back(total_length);
if (i > 0) {
CHECK_GT(seq_ids[i], seq_ids[i - 1]) << "The input sequence ids must be non-decreasing.";
}
}

// embeddings: (1, n, h)
CHECK_EQ(embeddings->ndim, 3);
CHECK_EQ(embeddings->shape[0], 1);
CHECK_EQ(embeddings->device.device_type, device_.device_type);
CHECK_EQ(embeddings->device.device_id, device_.device_id);
NDArray embeddings = ConcatEmbeddings(std::move(embedding_arr), total_length);
ICHECK_EQ(embeddings->ndim, 3);
ICHECK_EQ(embeddings->shape[0], 1);
ICHECK_EQ(embeddings->shape[1], total_length);
ICHECK_EQ(embeddings->device.device_type, device_.device_type);
ICHECK_EQ(embeddings->device.device_id, device_.device_id);

NDArray logit_pos_nd = CopyArrayToDevice(logit_pos, &logit_pos_arr_, DataType::Int(32), 32);

CHECK(ft_.prefill_func_.defined())
<< "`prefill_with_embed` function is not found in the model. Please make sure the model is "
Expand All @@ -202,22 +219,20 @@ class ModelModule : public ModuleNode {

// Reserve in KV cache for the length of the input.
ft_.reset_append_length_kv_cache_func_(kv_cache_);
ft_.reserve_length_in_kv_cache_func_(kv_cache_, seq_id, /*length=*/embeddings->shape[1]);
for (int i = 0; i < num_sequences; ++i) {
ft_.reserve_length_in_kv_cache_func_(kv_cache_, seq_ids[i], lengths[i]);
}
ft_.sync_device_kv_cache_func_(kv_cache_);

auto tstart = std::chrono::high_resolution_clock::now();
// args: embeddings, kv_cache, params
Array<ObjectRef> ret = ft_.prefill_func_(ft_.CopyToWorker0(embeddings), kv_cache_, params_);
auto tend = std::chrono::high_resolution_clock::now();

this->prefill_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->prefill_total_tokens += embeddings->shape[1];
// args: embeddings, logit_pos, kv_cache, params
Array<ObjectRef> ret =
ft_.prefill_func_(ft_.CopyToWorker0(embeddings), logit_pos_nd, kv_cache_, params_);

// logits: (1, 1, v)
// logits: (1, num_sequences, v)
NDArray logits = Downcast<NDArray>(ret[0]);
ICHECK_EQ(logits->ndim, 3);
ICHECK_EQ(logits->shape[0], 1);
ICHECK_EQ(logits->shape[1], 1);
ICHECK_EQ(logits->shape[1], num_sequences);
return logits;
}

Expand Down Expand Up @@ -251,13 +266,8 @@ class ModelModule : public ModuleNode {
}
ft_.sync_device_kv_cache_func_(kv_cache_);

auto tstart = std::chrono::high_resolution_clock::now();
// args: embeddings, kv_cache, params
Array<ObjectRef> ret = ft_.decode_func_(ft_.CopyToWorker0(embeddings), kv_cache_, params_);
auto tend = std::chrono::high_resolution_clock::now();

this->decode_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->decode_total_tokens += embeddings->shape[0];

// logits: (b, 1, v)
NDArray logits = Downcast<NDArray>(ret[0]);
Expand Down Expand Up @@ -286,7 +296,7 @@ class ModelModule : public ModuleNode {
for (GenerationConfig cfg : generation_cfg) {
temperatures.push_back(cfg->temperature);
}
NDArray temperatures_nd = CopyArrayToDevice(temperatures, &temperature_arr_, logits->dtype, 16);
NDArray temperatures_nd = CopyArrayToDevice(temperatures, &temperature_arr_, logits->dtype, 32);
ICHECK_EQ(temperatures_nd->ndim, 1);
ICHECK_EQ(temperatures_nd->shape[0], batch_size);

Expand Down Expand Up @@ -318,6 +328,57 @@ class ModelModule : public ModuleNode {
return view;
}

/*!
 * \brief Concatenate the input embeddings along the sequence dimension
 * into a single shared device array.
 *
 * Each input is expected to have shape (1, n_i, hidden_size); the result
 * is a view of shape (1, total_length, hidden_size) where total_length is
 * the sum of the n_i. The backing storage `embeddings_` is reused across
 * calls and grown by doubling when the request does not fit.
 *
 * \param embedding_arr The per-sequence embedding arrays to concatenate.
 *        All must share the same hidden size, dtype, and device.
 * \param total_length The sum of the sequence lengths of all inputs;
 *        checked against the actual copied length at the end.
 * \return A view into the shared buffer `embeddings_` containing the
 *         concatenated embeddings.
 *
 * NOTE(review): the returned NDArray aliases the member `embeddings_`,
 * so it is only valid until the next call to ConcatEmbeddings (or until
 * `embeddings_` is otherwise reallocated) — callers must consume it first.
 */
NDArray ConcatEmbeddings(Array<NDArray> embedding_arr, int64_t total_length) {
  ICHECK(!embedding_arr.empty());
  // First pass: validate every input and discover the common hidden size
  // and dtype from the first array.
  int hidden_size = -1;
  DataType dtype;
  for (NDArray inp_embeddings : embedding_arr) {
    // inp_embedding: (1, n, h)
    CHECK_EQ(inp_embeddings->ndim, 3);
    CHECK_EQ(inp_embeddings->shape[0], 1);
    CHECK_EQ(inp_embeddings->device.device_type, device_.device_type);
    CHECK_EQ(inp_embeddings->device.device_id, device_.device_id);
    if (hidden_size == -1) {
      hidden_size = inp_embeddings->shape[2];
      dtype = inp_embeddings.DataType();
    } else {
      CHECK_EQ(inp_embeddings->shape[2], hidden_size);
      CHECK_EQ(inp_embeddings.DataType(), dtype);
    }
  }

  // - Resize the shared embedding array.
  if (embeddings_.defined()) {
    ICHECK_EQ(embeddings_->ndim, 3);
    ICHECK_EQ(embeddings_->shape[0], 1);
    ICHECK_EQ(embeddings_->shape[2], hidden_size);
  }
  // Grow capacity by doubling, starting from either the current capacity
  // or `max_window_size_` on first use.
  // NOTE(review): if `max_window_size_` were still -1 (model uninitialized)
  // this doubling loop would never terminate — confirm callers only reach
  // here after model initialization.
  int64_t init_size = embeddings_.defined() ? embeddings_->shape[1] : max_window_size_;
  while (init_size < total_length) {
    init_size *= 2;
  }
  // Reallocate whenever the computed capacity differs from the current one.
  if (!embeddings_.defined() || init_size != embeddings_->shape[1]) {
    embeddings_ = NDArray::Empty({1, init_size, hidden_size}, dtype, device_);
  }

  // - Copy input embeddings.
  // Each input is copied to its offset in the shared buffer via a raw
  // DLTensor view: byte_offset positions the write, and the shape pointer
  // is borrowed from the source so CopyFromTo copies exactly that extent.
  int64_t start_pos = 0;
  for (NDArray inp_embeddings : embedding_arr) {
    int64_t length = inp_embeddings->shape[1];
    CHECK_LE(start_pos + length, total_length);

    DLTensor copy_dst = *(embeddings_.operator->());
    copy_dst.byte_offset = start_pos * hidden_size * dtype.bytes();
    copy_dst.shape = inp_embeddings->shape;
    NDArray::CopyFromTo(inp_embeddings.operator->(), &copy_dst);

    start_pos += length;
  }
  // All inputs copied; the total must match the caller-supplied length.
  CHECK_EQ(start_pos, total_length);
  return embeddings_.CreateView({1, total_length, hidden_size}, dtype);
}

/*! \brief Load model configuration from JSON. */
void LoadModelConfigJSON(const std::string& config_str) {
picojson::value config_json;
Expand Down Expand Up @@ -350,37 +411,12 @@ class ModelModule : public ModuleNode {

/*! \brief reset the runtime states. */
void Reset() {
// Reset the statistics.
this->embed_total_tokens = 0;
this->prefill_total_tokens = 0;
this->decode_total_tokens = 0;
this->embed_total_time = 0;
this->prefill_total_time = 0;
this->decode_total_time = 0;
// Reset the KV cache.
if (kv_cache_.defined()) {
ft_.reset_kv_cache_func_(kv_cache_);
}
}

/*! \brief Return statistics in JSON format. */
String GetStats() {
picojson::object stats;
stats["prefill_speed"] = picojson::value(prefill_total_tokens / prefill_total_time);
stats["decode_speed"] = picojson::value(decode_total_tokens / decode_total_time);
stats["embed_speed"] = picojson::value(embed_total_tokens / embed_total_time);
return picojson::value(stats).serialize(true);
}

//----------------------------
// Statistics
//----------------------------
double embed_total_time = 0;
double decode_total_time = 0;
double prefill_total_time = 0;
int64_t embed_total_tokens = 0;
int64_t decode_total_tokens = 0;
int64_t prefill_total_tokens = 0;
//----------------------------
// Model configurations
//----------------------------
Expand All @@ -400,6 +436,8 @@ class ModelModule : public ModuleNode {
ObjectRef params_;
// Shared NDArray
NDArray input_token_ids_{nullptr};
NDArray embeddings_{nullptr};
NDArray logit_pos_arr_{nullptr};
NDArray temperature_arr_{nullptr};
};

Expand Down
15 changes: 15 additions & 0 deletions cpp/serve/request_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,21 @@ RequestModelState::RequestModelState(int model_id, Array<Data> inputs) {
data_ = std::move(n);
}

TVM_REGISTER_OBJECT_TYPE(RequestStateNode);

/*!
 * \brief Construct the state for a request: one RequestModelState per model,
 * all sharing the same input data, stamped with the time of admission.
 * \param num_models The number of models the engine serves.
 * \param inputs The request's input data, shared by every model state.
 * \param raw_input_length The summed-up input length of the request.
 */
RequestState::RequestState(int num_models, Array<Data> inputs, int raw_input_length) {
  // Build the per-model states first, then populate the node in one pass.
  Array<RequestModelState> model_states;
  model_states.reserve(num_models);
  for (int model_id = 0; model_id < num_models; ++model_id) {
    model_states.push_back(RequestModelState(model_id, inputs));
  }
  ObjectPtr<RequestStateNode> node = make_object<RequestStateNode>();
  node->mstates = std::move(model_states);
  node->raw_input_length = raw_input_length;
  node->tadd = std::chrono::high_resolution_clock::now();
  data_ = std::move(node);
}

} // namespace serve
} // namespace llm
} // namespace mlc
23 changes: 15 additions & 8 deletions cpp/serve/request_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,16 @@ class RequestModelState : public ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestModelState, ObjectRef, RequestModelStateNode);
};

struct RequestState {
class RequestStateNode : public Object {
public:
/*!
* \brief The state with regard to each model.
* \sa RequestModelState
*/
Array<RequestModelState> mstates;

/*! \brief The summed up input length of the request. */
int raw_input_length = 0;
/*! \brief The decoded text string output. */
std::string output = "";

Expand All @@ -101,13 +104,17 @@ struct RequestState {
/*! \brief The time of finishing prefill stage. */
std::chrono::_V2::system_clock::time_point tprefill_finish;

explicit RequestState(int num_models, Array<Data> inputs) {
mstates.reserve(num_models);
for (int i = 0; i < num_models; ++i) {
mstates.push_back(RequestModelState(i, inputs));
}
tadd = std::chrono::high_resolution_clock::now();
}
static constexpr const char* _type_key = "mlc.serve.RequestState";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(RequestStateNode, Object);
};

/*!
 * \brief Managed reference to RequestStateNode.
 * \sa RequestStateNode
 */
class RequestState : public ObjectRef {
 public:
  /*!
   * \brief Constructor.
   * \param num_models The number of models; one RequestModelState is
   *        created per model, each sharing the same inputs.
   * \param inputs The request's input data.
   * \param raw_input_length The summed-up input length of the request.
   */
  explicit RequestState(int num_models, Array<Data> inputs, int raw_input_length);

  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestState, ObjectRef, RequestStateNode);
};

} // namespace serve
Expand Down
Loading

0 comments on commit a928fdd

Please sign in to comment.