Skip to content

Commit

Permalink
Refactor: API of KVStateCacheManager accept Client&.
Browse files Browse the repository at this point in the history
Make Client& client as a member of llm cache object.

Signed-off-by: vegetableysm <[email protected]>
  • Loading branch information
vegetableysm committed Mar 7, 2024
1 parent 9a3cc54 commit f0c3844
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 26 deletions.
22 changes: 11 additions & 11 deletions modules/llm-cache/ds/kv_state_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ KVStateCache::~KVStateCache() {}

KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int tensorBytes,
int layer,
std::shared_ptr<RadixTree>& rootTree) {
std::shared_ptr<RadixTree>& rootTree)
: client(client) {
this->tensorBytes = tensorBytes;
this->version = 0;
this->layer = layer;
Expand Down Expand Up @@ -126,7 +127,7 @@ Status KVStateCacheBuilder::Make(
}

Status KVStateCacheBuilder::Split(
Client& client, KVStateCacheBlockBuilder* kvStateCacheBlockBuilder,
KVStateCacheBlockBuilder* kvStateCacheBlockBuilder,
std::vector<std::shared_ptr<NodeData>> nodeDataList,
KVStateCacheBlockBuilder*& childKVStateCacheBlockBuilder) {
// Split the tree if the list of kvState is full.
Expand Down Expand Up @@ -155,7 +156,7 @@ Status KVStateCacheBuilder::Split(
}

Status KVStateCacheBuilder::Update(
Client& client, const std::vector<int>& tokenList, int nextToken,
const std::vector<int>& tokenList, int nextToken,
const std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
std::vector<int> tokenListCopy = tokenList;
tokenListCopy.push_back(nextToken);
Expand Down Expand Up @@ -189,7 +190,7 @@ Status KVStateCacheBuilder::Update(
rootTree->Split(tokenListCopy, subTreeHeader);
RETURN_ON_ASSERT(nodeDataList.size() != 0, "Split llm cache failed.");
KVStateCacheBlockBuilder* newKVStateCacheBlockBuilder;
Status status = Split(client, kvStateCacheBlockBuilder, nodeDataList,
Status status = Split(kvStateCacheBlockBuilder, nodeDataList,
newKVStateCacheBlockBuilder);
RETURN_ON_ERROR(status);

Expand All @@ -204,7 +205,7 @@ Status KVStateCacheBuilder::Update(
VLOG(100) << "block split success";

// kv_state_cache_builder->UnLock();
status = Update(client, tokenList, nextToken, kvState);
status = Update(tokenList, nextToken, kvState);
RETURN_ON_ERROR(status);
} else {
// Update the kv-state cache.
Expand All @@ -223,7 +224,7 @@ Status KVStateCacheBuilder::Update(
}

Status KVStateCacheBuilder::Query(
Client& client, const std::vector<int>& tokenList, int token,
const std::vector<int>& tokenList, int token,
std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
std::vector<int> tokenListCopy = tokenList;
tokenListCopy.push_back(token);
Expand All @@ -240,7 +241,7 @@ Status KVStateCacheBuilder::Query(
(reinterpret_cast<TreeData*>(nodeData->treeData->data))
->kvStateCacheBlockBuilder);

return kvStateCacheBlockBuilder->Query(client, offset, kvState);
return kvStateCacheBlockBuilder->Query(offset, kvState);
}

void KVStateCacheBuilder::Delete(std::shared_ptr<NodeData> evictedNodeData) {
Expand All @@ -263,8 +264,7 @@ void KVStateCacheBuilder::Delete(std::shared_ptr<NodeData> evictedNodeData) {
evictedNodeData->RecycleSource();
}

Status KVStateCacheBuilder::Merge(Client& client,
std::shared_ptr<KVStateCache> kvStateCache) {
Status KVStateCacheBuilder::Merge(std::shared_ptr<KVStateCache> kvStateCache) {
if (kvStateCache == nullptr) {
return Status::OK();
}
Expand Down Expand Up @@ -312,8 +312,8 @@ Status KVStateCacheBuilder::Merge(Client& client,
kvState.insert(
std::make_pair(currentLayer, std::make_pair(key_state, value_state)));
}
globalCacheBuilder->Query(client, tokenList, (*it).back(), kvState);
this->Update(client, tokenList, (*it).back(), kvState);
globalCacheBuilder->Query(tokenList, (*it).back(), kvState);
this->Update(tokenList, (*it).back(), kvState);
for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) {
LLMKV key_state = kvState[currentLayer].first;
LLMKV value_state = kvState[currentLayer].second;
Expand Down
11 changes: 5 additions & 6 deletions modules/llm-cache/ds/kv_state_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class KVStateCache : public vineyard::Registered<KVStateCache> {
};

class KVStateCacheBuilder : public vineyard::ObjectBuilder {
Client& client;
std::shared_ptr<RadixTree> rootTree;
int tensorBytes;
int layer;
Expand All @@ -97,21 +98,19 @@ class KVStateCacheBuilder : public vineyard::ObjectBuilder {
std::shared_ptr<KVStateCacheBuilder>& kvStateCacheBuilder,
std::shared_ptr<KVStateCache>& cache);

Status Split(Client& client,
KVStateCacheBlockBuilder* kvStateCacheBlockBuilder,
Status Split(KVStateCacheBlockBuilder* kvStateCacheBlockBuilder,
std::vector<std::shared_ptr<NodeData>> nodeDataList,
KVStateCacheBlockBuilder*& childKVStateCacheBlockBuilder);

Status Update(Client& client, const std::vector<int>& token_list,
int next_token,
Status Update(const std::vector<int>& token_list, int next_token,
const std::map<int, std::pair<LLMKV, LLMKV>>& kv_state);

Status Query(Client& client, const std::vector<int>& token_list, int token,
Status Query(const std::vector<int>& token_list, int token,
std::map<int, std::pair<LLMKV, LLMKV>>& kv_state);

void Delete(std::shared_ptr<NodeData> evicted_node);

Status Merge(Client& client, std::shared_ptr<KVStateCache> kv_state_cache);
Status Merge(std::shared_ptr<KVStateCache> kv_state_cache);

uint64_t GetVersion() { return this->version; }

Expand Down
9 changes: 5 additions & 4 deletions modules/llm-cache/ds/kv_state_cache_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ KVStateCacheBlock::~KVStateCacheBlock() { delete this->bitmap; }

KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(Client& client,
int tensorBytes, int layer,
int blockSize) {
int blockSize)
: client(client) {
this->blockSize = blockSize;
this->bitmapSize = (blockSize + 63) / 64;
this->bitmap = new uint64_t[this->bitmapSize];
Expand All @@ -99,7 +100,8 @@ KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(Client& client,
}

KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(
Client& client, std::shared_ptr<KVStateCacheBlock> kvStateCacheBlock) {
Client& client, std::shared_ptr<KVStateCacheBlock> kvStateCacheBlock)
: client(client) {
this->bitmapSize = kvStateCacheBlock->bitmapSize;
this->blockSize = kvStateCacheBlock->blockSize;
VLOG(100) << "create builder from block object, bitmap size:"
Expand Down Expand Up @@ -129,8 +131,7 @@ KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(
}

Status KVStateCacheBlockBuilder::Query(
Client& client, int index,
std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
int index, std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
RETURN_ON_ASSERT((index >= 0 && index < this->blockSize),
"Index out of range: " + std::to_string(index));
for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) {
Expand Down
4 changes: 2 additions & 2 deletions modules/llm-cache/ds/kv_state_cache_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class KVStateCacheBlock : public vineyard::Registered<KVStateCacheBlock> {

class KVStateCacheBlockBuilder : public ObjectBuilder {
private:
Client& client;
std::vector<std::shared_ptr<TensorBuilder<uint8_t>>>
keyStateTensorBuilderList;
std::vector<std::shared_ptr<TensorBuilder<uint8_t>>>
Expand Down Expand Up @@ -150,8 +151,7 @@ class KVStateCacheBlockBuilder : public ObjectBuilder {
* @param kv_state The kv-state of the prompt returned by radix-tree. If the
* kv-state is not found, the data of kv-state is invalid.
*/
Status Query(Client& client, int index,
std::map<int, std::pair<LLMKV, LLMKV>>& kv_state);
Status Query(int index, std::map<int, std::pair<LLMKV, LLMKV>>& kv_state);

bool IsFull();

Expand Down
6 changes: 3 additions & 3 deletions modules/llm-cache/ds/kv_state_cache_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ Status KVStateCacheManager::Make(Client& client,
Status KVStateCacheManager::UpdateInternal(
const std::vector<int>& tokenList, int nextToken,
const std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
return kvStateCacheBuilder->Update(client, tokenList, nextToken, kvState);
return kvStateCacheBuilder->Update(tokenList, nextToken, kvState);
}

Status KVStateCacheManager::QueryInternal(
const std::vector<int>& tokenList, int token,
std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
return kvStateCacheBuilder->Query(client, tokenList, token, kvState);
return kvStateCacheBuilder->Query(tokenList, token, kvState);
}

Status KVStateCacheManager::Update(
Expand Down Expand Up @@ -222,7 +222,7 @@ Status KVStateCacheManager::Sync() {
: std::to_string(globalKVStateCache->GetVersion()));
if (globalKVStateCache != nullptr &&
kvStateCacheBuilder->GetVersion() < globalKVStateCache->GetVersion()) {
status = kvStateCacheBuilder->Merge(client, globalKVStateCache);
status = kvStateCacheBuilder->Merge(globalKVStateCache);
RETURN_ON_ERROR(status);
}
kvStateCacheBuilder->UpdateVersion();
Expand Down

0 comments on commit f0c3844

Please sign in to comment.