Skip to content

Commit

Permalink
[Serving][Refactor] Major codebase refactor
Browse files Browse the repository at this point in the history
This PR is a major refactor of the serving framework.
It contains the following aspects:

* Changing Model and Sampler from `runtime.Module` to Object
in TVM. Exposing the public interface through public member
function (rather than through `GetFunction` which returns PackedFunc).
Separating the definition and implementation of Model/Sampler
classes. Now a base Model/Sampler class definition is in their
respective header files, and the implementations are in .cc files.
* Removing the TokenizerModule class, and directly using the
Tokenizer class in `tokenizer_cpp`, to reduce indirection.
* Introducing unique string `id` to Request. This id is passed
in from frontend as a Request constructor parameter, and is the
unique identifier of a request.
* Reducing the uses of `ShapeTuple` after the Model/Sampler/Tokenizer
interface changes.
* Moving some previous member functions in Engine (such as "getting
data length", "getting data embedding") to the corresponding
data structure side.
* Introducing struct `EngineStats` to contain all the runtime
statistics of the engine.
* Classifying and reordering the member functions in Engine.
  • Loading branch information
MasterJH5574 committed Nov 17, 2023
1 parent d609545 commit c0d233d
Show file tree
Hide file tree
Showing 19 changed files with 1,132 additions and 1,262 deletions.
20 changes: 18 additions & 2 deletions cpp/serve/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include <tvm/runtime/registry.h>

#include "model.h"

namespace mlc {
namespace llm {
namespace serve {
Expand All @@ -24,6 +26,16 @@ TextData::TextData(String text) {
data_ = std::move(n);
}

int TextDataNode::GetLength() const {
LOG(FATAL) << "\"GetLength\" for TextData is not supported. "
"Please tokenize the text and construct a TokenData object.";
}

NDArray TextDataNode::GetEmbedding(Model model) const {
LOG(FATAL) << "\"GetEmbedding\" for TextData is not supported. "
"Please tokenize the text and construct a TokenData object.";
}

TVM_REGISTER_GLOBAL("mlc.serve.TextData").set_body_typed([](String text) {
return TextData(std::move(text));
});
Expand All @@ -36,18 +48,22 @@ TVM_REGISTER_GLOBAL("mlc.serve.TextDataGetTextString").set_body_typed([](TextDat

TVM_REGISTER_OBJECT_TYPE(TokenDataNode);

TokenData::TokenData(ShapeTuple token_ids) {
TokenData::TokenData(IntTuple token_ids) {
ObjectPtr<TokenDataNode> n = make_object<TokenDataNode>();
n->token_ids = std::move(token_ids);
data_ = std::move(n);
}

TokenData::TokenData(std::vector<int32_t> token_ids) {
ObjectPtr<TokenDataNode> n = make_object<TokenDataNode>();
n->token_ids = ShapeTuple(token_ids.begin(), token_ids.end());
n->token_ids = IntTuple(token_ids.begin(), token_ids.end());
data_ = std::move(n);
}

int TokenDataNode::GetLength() const { return token_ids.size(); }

NDArray TokenDataNode::GetEmbedding(Model model) const { return model->TokenEmbed(token_ids); }

TVM_REGISTER_GLOBAL("mlc.serve.TokenData").set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<int32_t> token_ids;
token_ids.reserve(args.size());
Expand Down
19 changes: 17 additions & 2 deletions cpp/serve/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h>

namespace mlc {
Expand All @@ -15,11 +16,19 @@ namespace serve {

using namespace tvm::runtime;

class Model;

/****************** DataNode ******************/

/*! \brief The base class of multi-modality data (text, tokens, embedding, etc). */
class DataNode : public Object {
public:
/*! \brief Get the length (equivalent number of tokens) of the data. */
virtual int GetLength() const = 0;

/*! \brief Compute the embedding of this data with regard to the input model. */
virtual NDArray GetEmbedding(Model model) const = 0;

static constexpr const char* _type_key = "mlc.serve.Data";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
Expand All @@ -39,6 +48,9 @@ class TextDataNode : public DataNode {
/*! \brief The text string. */
String text;

int GetLength() const final;
NDArray GetEmbedding(Model model) const final;

static constexpr const char* _type_key = "mlc.serve.TextData";
TVM_DECLARE_BASE_OBJECT_INFO(TextDataNode, DataNode);
};
Expand All @@ -56,15 +68,18 @@ class TextData : public Data {
class TokenDataNode : public DataNode {
public:
/*! \brief The token ids. */
ShapeTuple token_ids;
IntTuple token_ids;

int GetLength() const final;
NDArray GetEmbedding(Model model) const final;

static constexpr const char* _type_key = "mlc.serve.TokenData";
TVM_DECLARE_BASE_OBJECT_INFO(TokenDataNode, DataNode);
};

class TokenData : public Data {
public:
explicit TokenData(ShapeTuple token_ids);
explicit TokenData(IntTuple token_ids);

explicit TokenData(std::vector<int32_t> token_ids);

Expand Down
Loading

0 comments on commit c0d233d

Please sign in to comment.