This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

[Neural Speed] Support continuous batching + beam search inference in LLAMA (#145)
zhentaoyu authored Mar 4, 2024
1 parent 9bcb612 commit 7c2199f
Showing 13 changed files with 781 additions and 346 deletions.
230 changes: 162 additions & 68 deletions developer_document.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion neural_speed/__init__.py
@@ -212,7 +212,7 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
out_count = 0
input_list = None
pad_token_id = generate_kwargs.get("pad_token", None)
if generate_kwargs.get("continuous_batching", False):
if input_ids.shape[0] > 1 and generate_kwargs.get("continuous_batching", True):
input_list = self._cont_batching_input(input_ids, pad_token_id)
else:
input_list = input_ids.tolist()
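With this change, generate() routes multi-prompt calls through the continuous-batching input path by default: continuous_batching now defaults to True, is only consulted when input_ids.shape[0] > 1, and can still be disabled by passing continuous_batching=False. The sketch below is an editor's illustration of how that surfaces in user code, not part of the commit; it assumes the usual high-level neural_speed.Model API (Model().init(...) with weight_dtype/compute_dtype, then generate(...)), and the model id and prompts are placeholders.

from transformers import AutoTokenizer
from neural_speed import Model

model_name = "meta-llama/Llama-2-7b-hf"   # placeholder model id
prompts = ["What is continuous batching?",
           "Write one sentence about beam search."]

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

model = Model()
model.init(model_name, weight_dtype="int4", compute_dtype="int8")

# Two prompts of different lengths are padded to a common length here; because
# input_ids.shape[0] > 1 and continuous_batching defaults to True, generate()
# builds the per-request token lists via _cont_batching_input() before inference.
input_ids = tokenizer(prompts, padding=True, return_tensors="pt").input_ids
outputs = model.generate(input_ids, max_new_tokens=32,
                         pad_token=tokenizer.pad_token_id)
for seq in outputs:
    print(tokenizer.decode(seq, skip_special_tokens=True))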
40 changes: 23 additions & 17 deletions neural_speed/application/main_pybind.cpp
@@ -76,13 +76,14 @@ using Response = Query;
using ResponseCallback = std::function<void(std::vector<Response>, int)>;
} // namespace

static std::set<model_archs> cont_batching_model_archs = {MODEL_GPTJ, MODEL_LLAMA};
void init_gpt_params(gpt_params* params, const std::string& model_path, int max_new_tokens = -1, int n_batch = 512,
int ctx_size = 512, int seed = -1, int threads = 8, float repetition_penalty = 1.1f,
int num_beams = 1, bool do_sample = false, int top_k = 40, float top_p = 0.95,
float temperature = 0.8, int min_new_tokens = 0, float length_penalty = 1.0f,
bool early_stopping = false, int n_keep = 0, int n_discard = -1, bool shift_roped_k = false,
int batch_size = 1, model_vocab::id pad_token = -1, const std::string& memory_dtype = "auto",
const bool& continuous_batching = false, const int& max_request_num = MODEL_MAX_REQUEST_NUM,
bool continuous_batching = true, const int& max_request_num = MODEL_MAX_REQUEST_NUM,
const float& model_scratch_enlarge_scale = 1.0f) {
MODEL_ASSERT(params != nullptr);
#ifdef MODEL_NAME
@@ -114,10 +115,13 @@ void init_gpt_params(gpt_params* params, const std::string& model_path, int max_
params->memory_type = KV_MEM_TYPE_AUTO;
else
fprintf(stderr, "Unexpected memory dtype %s!", memory_dtype.c_str());
if (batch_size > 1 && (!continuous_batching || params->model_arch != model_archs::MODEL_GPTJ)) {
params->memory_type = KV_MEM_TYPE_F16; // TODO(Yi & YZT): MHA IN MULTI-BATCH For More Model Archs
}
// TODO(Yi & YZT): MHA IN MULTI-BATCH For More Model Archs
params->cont_batching = continuous_batching;
if (params->shift_roped_k) params->cont_batching = false;
if (cont_batching_model_archs.count(params->model_arch) == 0) params->cont_batching = false;
if (batch_size > 1 && !continuous_batching) {
params->memory_type = KV_MEM_TYPE_F16;
}
params->max_request_num = std::max(batch_size, max_request_num);
params->min_new_tokens = min_new_tokens;
params->length_penalty = length_penalty;
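To summarize the new gating above: continuous batching is requested by default, but it is switched off when shift_roped_k is set or when the architecture is not in cont_batching_model_archs (currently GPT-J and LLaMA), and a multi-batch run that did not request continuous batching falls back to a plain F16 KV cache instead of the "auto" layout. The following Python snippet is only an editor's re-expression of that decision for readability; it is not Neural Speed code.

def resolve_batching(model_arch, batch_size, continuous_batching=True,
                     shift_roped_k=False, memory_dtype="auto"):
    # Mirrors the C++ gating in init_gpt_params (illustrative only).
    cont_batching_model_archs = {"gptj", "llama"}
    cont_batching = continuous_batching
    if shift_roped_k:
        cont_batching = False          # shifting RoPE'd K is incompatible here
    if model_arch not in cont_batching_model_archs:
        cont_batching = False          # other archs keep the padded batch path
    memory_type = memory_dtype
    if batch_size > 1 and not continuous_batching:
        memory_type = "f16"            # note: this checks the *requested* flag
    return cont_batching, memory_type

print(resolve_batching("llama", 4))      # -> (True, 'auto')
print(resolve_batching("gptneox", 4))    # -> (False, 'auto')  arch not supported
print(resolve_batching("llama", 4, continuous_batching=False))  # -> (False, 'f16')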
@@ -137,8 +141,8 @@ class ModelServer {
int n_batch, int ctx_size, int seed, int threads, float repetition_penalty, int num_beams, bool do_sample,
int top_k, float top_p, float temperature, int min_new_tokens, float length_penalty, bool early_stopping,
int n_keep, int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token,
const std::string& memory_dtype, const bool& continuous_batching, const int& max_request_num,
const float& model_scratch_enlarge_scale, const std::string& policy, const bool& print_log,
const std::string& memory_dtype, bool continuous_batching, const int& max_request_num,
const float& model_scratch_enlarge_scale, const std::string& policy, bool print_log,
const std::function<void()>& init_cb)
: response(response),
waiting(),
@@ -258,12 +262,16 @@ class ModelServer {
int threads, float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p,
float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep,
int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token,
const std::string& memory_dtype, const bool& continuous_batching, const int& max_request_num,
const std::string& memory_dtype, bool continuous_batching, const int& max_request_num,
const float& model_scratch_enlarge_scale) {
init_gpt_params(&params, model_path, max_new_tokens, n_batch, ctx_size, seed, threads, repetition_penalty,
num_beams, do_sample, top_k, top_p, temperature, min_new_tokens, length_penalty, early_stopping,
n_keep, n_discard, shift_roped_k, batch_size, pad_token, memory_dtype, continuous_batching,
max_request_num, model_scratch_enlarge_scale);
if (cont_batching_model_archs.count(params.model_arch) == 0) {
fprintf(stderr, "\nERROR: ModelServer only supports gpt-j, llama!\n");
running = false;
}
}

~ModelServer() {
@@ -317,8 +325,7 @@ class Model {
float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p, float temperature,
int min_new_tokens, float length_penalty, bool early_stopping, int n_keep, int n_discard,
bool shift_roped_k, int batch_size, model_vocab::id pad_token, const std::string& memory_dtype,
const bool& continuous_batching, const int& max_request_num,
const float& model_scratch_enlarge_scale);
bool continuous_batching, const int& max_request_num, const float& model_scratch_enlarge_scale);
void reinit();
std::vector<std::vector<model_token>> generate(const std::vector<std::vector<model_token>>& input_ids);
// deprecated API
@@ -411,7 +418,7 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int n_
int threads, float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p,
float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep,
int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token,
const std::string& memory_dtype, const bool& continuous_batching, const int& max_request_num,
const std::string& memory_dtype, bool continuous_batching, const int& max_request_num,
const float& model_scratch_enlarge_scale) {
init_gpt_params(&params, model_path, max_new_tokens, n_batch, ctx_size, seed, threads, repetition_penalty, num_beams,
do_sample, top_k, top_p, temperature, min_new_tokens, length_penalty, early_stopping, n_keep,
@@ -466,9 +473,9 @@ bool Model::check_input_and_count_padding(const std::vector<std::vector<model_to
} else { // multi-batch inputs (first token)
ctx->batch_size = input_ids.size();
MODEL_ASSERT(input_ids.size() <= ctx->max_request_num);
static std::set<model_archs> batched_model_archs = {MODEL_GPTJ, MODEL_GPTNEOX, MODEL_CHATGLM};
static std::set<model_archs> batched_model_archs = {MODEL_GPTJ, MODEL_GPTNEOX, MODEL_CHATGLM, MODEL_LLAMA};
if (batched_model_archs.count(params.model_arch) == 0) {
fprintf(stderr, "\nERROR: Only gpt-j, gpt-neox, chatglm support multi-batch generation!\n");
fprintf(stderr, "\nERROR: Only gpt-j, gpt-neox, chatglm, llama support multi-batch generation!\n");
return false;
}
if (ctx->vocab.pad_token_id == -1) {
@@ -738,7 +745,7 @@ std::vector<std::vector<model_token>> Model::post_beam_search(model_context* lct
const std::vector<model_input>& inputs,
const int& n_threads) {
// TODO(Zhentao): to implement
static std::set<model_archs> supported_archs = {MODEL_GPTJ, MODEL_GPTNEOX};
static std::set<model_archs> supported_archs = {MODEL_GPTJ, MODEL_GPTNEOX, MODEL_LLAMA};
if (supported_archs.count(params.model_arch) != 0) {
return beam_search(lctx, n_predict, inputs, n_threads);
} else {
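With MODEL_LLAMA added to supported_archs, post_beam_search() now dispatches LLaMA requests to beam_search() instead of rejecting them, so beam search and continuous batching can be combined for multi-prompt LLaMA inference. A hedged usage sketch follows (editor's illustration, not part of the commit); it assumes the same high-level neural_speed.Model API as the earlier sketch and that generation kwargs such as num_beams, do_sample and early_stopping are forwarded as in the pybind arguments below.

from transformers import AutoTokenizer
from neural_speed import Model

model_name = "meta-llama/Llama-2-7b-hf"   # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

model = Model()
model.init(model_name, weight_dtype="int4", compute_dtype="int8")

input_ids = tokenizer(["Summarize beam search.",
                       "Summarize continuous batching."],
                      padding=True, return_tensors="pt").input_ids

# num_beams > 1 with do_sample=False selects the beam-search decoder; with two
# requests in the batch, LLaMA now also takes the continuous-batching path.
outputs = model.generate(input_ids, num_beams=4, do_sample=False,
                         early_stopping=True, max_new_tokens=64,
                         pad_token=tokenizer.pad_token_id)
print([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])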
@@ -914,7 +921,7 @@ PYBIND11_MODULE(mixtral_cpp, m)
py::arg("min_new_tokens") = 0, py::arg("length_penalty") = 1.0, py::arg("early_stopping") = false,
py::arg("n_keep") = 0, py::arg("n_discard") = -1, py::arg("shift_roped_k") = false,
py::arg("batch_size") = 1, py::arg("pad_token") = -1, py::arg("memory_dtype") = "auto",
py::arg("continuous_batching") = false, py::arg("max_request_num") = MODEL_MAX_REQUEST_NUM,
py::arg("continuous_batching") = true, py::arg("max_request_num") = MODEL_MAX_REQUEST_NUM,
py::arg("model_scratch_enlarge_scale") = 1.0f)
.def("generate", &Model::generate, "Generate token with input ids", py::arg("input_ids"))
.def("evaluate", &Model::evaluate, "Evaluate token with input ids and output logits",
@@ -946,9 +953,8 @@ PYBIND11_MODULE(mixtral_cpp, m)
.def_readwrite("token_ids", &Query::token_ids);
py::class_<ModelServer>(m, "ModelServer", py::module_local())
.def(py::init<const ResponseCallback&, const std::string&, bool, int, int, int, int, int, float, int, bool, int,
float, float, int, float, bool, int, int, bool, int, model_vocab::id, const std::string&,
const bool&, const int&, const float&, const std::string&, const bool&,
const std::function<void()>&>(),
float, float, int, float, bool, int, int, bool, int, model_vocab::id, const std::string&, bool,
const int&, const float&, const std::string&, bool, const std::function<void()>&>(),
py::arg("response"), py::arg("model_path"), py::arg("return_prompt") = false, py::arg("max_new_tokens") = -1,
py::arg("n_batch") = 512, py::arg("ctx_size") = 512, py::arg("seed") = -1, py::arg("threads") = 8,
py::arg("repetition_penalty") = 1.1f, py::arg("num_beams") = 1, py::arg("do_sample") = false,
2 changes: 1 addition & 1 deletion neural_speed/core/ne.h
@@ -40,7 +40,7 @@
#define NE_FILE_VERSION 1

#define NE_MAX_DIMS 4
#define NE_MAX_NODES 16384
#define NE_MAX_NODES 40960
#define NE_MAX_PARAMS 256
#define NE_MAX_CONTEXTS 64
#define NE_MAX_OPT 36
23 changes: 12 additions & 11 deletions neural_speed/models/gptj/gptj.cpp
@@ -63,9 +63,10 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
const int N = inputs->n_tokens;
const int n_past = inputs->n_past;
const int n_total = inputs->n_total;
// continuous batching
// continuous batching (no padding)
// if each sequence length l_i != l_k
// input shape will be [1, l_sum]
const bool concat_multi_seqs = (batch_size > 1 && lctx.cont_batching) ? true : false;
std::vector<int> n_tokens(batch_size);
std::vector<int> n_pasts(batch_size);
std::vector<int> n_totals(batch_size);
@@ -78,15 +79,15 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
n_pasts[i] = inputs[i].n_past;
n_totals[i] = inputs[i].n_total;
block_ids[i] = inputs[i].request_idx * beam_size + inputs[i].beam_idx;
if (!lctx.cont_batching) {
if (!concat_multi_seqs) {
n_padding.push_back(inputs[i].n_padding);
if (no_padding && inputs[i].n_padding != 0) no_padding = false;
}
}
const int seq_len_sum = std::accumulate(n_tokens.begin(), n_tokens.end(), 0);
if (!lctx.cont_batching) MODEL_ASSERT(seq_len_sum == N * batch_size);
const int infer_bs = lctx.cont_batching ? 1 : batch_size;
const int infer_seq_len = lctx.cont_batching ? seq_len_sum : N;
if (!concat_multi_seqs) MODEL_ASSERT(seq_len_sum == N * batch_size);
const int infer_bs = concat_multi_seqs ? 1 : batch_size;
const int infer_seq_len = concat_multi_seqs ? seq_len_sum : N;
const std::vector<std::vector<int>> infer_groups = split_inputs_into_groups(inputs, n_input);
const auto& model = lctx.model;
const auto& hparams = model.hparams;
@@ -100,7 +101,7 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
const int n_ctx = lctx.n_ctx; // max number fo tokens to keep in the kv-cache
const int n_keep = lctx.n_keep;
const bool shift_roped_k = lctx.shift_roped_k;
MODEL_ASSERT(("continuous batching mechanism doesn't support shift rope.\n", !(lctx.cont_batching && shift_roped_k)));
MODEL_ASSERT(("continuous batching mechanism doesn't support shift rope.\n", !(concat_multi_seqs && shift_roped_k)));
const bool is_ring_full = shift_roped_k && n_total > n_past;
const int n_cached = shift_roped_k ? std::min(n_total + N, n_ctx) : (n_past + N); // #tokens cached after kv-append
int n_head = hparams.n_head;
@@ -120,7 +121,7 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
}
#endif

MODEL_ASSERT(("continuous batching mechanism doesn't support TP.\n", !(lctx.cont_batching && enable_tp)));
MODEL_ASSERT(("continuous batching mechanism doesn't support TP.\n", !(concat_multi_seqs && enable_tp)));
auto& mem_per_token = lctx.mem_per_token;
auto& buf_compute = lctx.buf_compute;

@@ -208,7 +209,7 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
infer_bs);
Vcur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
}
if (lctx.cont_batching) {
if (concat_multi_seqs) {
size_t off_sl = 0;
// per_request rope
for (int gi = 0; gi < infer_groups.size(); ++gi) {
@@ -414,9 +415,9 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
K = ne_permute(ctx0, K, 0, 2, 1, 3);
}
} else {
std::vector<int> attn_block_ids;
for (const auto& bsi : infer_groups[gi]) {
attn_block_ids.push_back(block_ids[bsi]);
std::vector<int> attn_block_ids(infer_groups[gi].size());
for (int j = 0; j < infer_groups[gi].size(); ++j) {
attn_block_ids[j] = block_ids[infer_groups[gi][j]];
}
K = model_kv_cache_seq_concat(&gf, &lctx, ctx0, head_size, n_cached_gi, n_head, attn_bs, attn_block_ids, il);
if (is_ring_full) {
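The concat_multi_seqs path used throughout this file is what the comment near the top of gptj_model_eval_internal describes: when sequence lengths differ, the requests are concatenated into a single [1, l_sum] row with no padding, and the per-request bookkeeping (n_tokens, block_ids, infer_groups) keeps them apart. A purely conceptual illustration of that layout (editor's sketch, not Neural Speed code):

# Three requests with different prompt lengths.
seqs = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]

n_tokens = [len(s) for s in seqs]                        # [3, 2, 4]
flat = [tok for s in seqs for tok in s]                  # one row, l_sum = 9
offsets = [sum(n_tokens[:i]) for i in range(len(seqs))]  # [0, 3, 5]

# Padded batching would instead allocate a [batch, l_max] tensor and waste
# (l_max - l_i) slots per request:
l_max = max(n_tokens)
padded = [s + [0] * (l_max - len(s)) for s in seqs]

assert len(flat) == sum(n_tokens)   # the [1, l_sum] layout carries no padding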
