diff --git a/modules/llm-cache/ds/kv_state_cache.cc b/modules/llm-cache/ds/kv_state_cache.cc index 23ab37b19..fe8da7d4f 100644 --- a/modules/llm-cache/ds/kv_state_cache.cc +++ b/modules/llm-cache/ds/kv_state_cache.cc @@ -71,7 +71,8 @@ KVStateCache::~KVStateCache() {} KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int tensorBytes, int layer, - std::shared_ptr& rootTree) { + std::shared_ptr& rootTree) + : client(client) { this->tensorBytes = tensorBytes; this->version = 0; this->layer = layer; @@ -126,7 +127,7 @@ Status KVStateCacheBuilder::Make( } Status KVStateCacheBuilder::Split( - Client& client, KVStateCacheBlockBuilder* kvStateCacheBlockBuilder, + KVStateCacheBlockBuilder* kvStateCacheBlockBuilder, std::vector> nodeDataList, KVStateCacheBlockBuilder*& childKVStateCacheBlockBuilder) { // Split the tree if the list of kvState is full. @@ -155,7 +156,7 @@ Status KVStateCacheBuilder::Split( } Status KVStateCacheBuilder::Update( - Client& client, const std::vector& tokenList, int nextToken, + const std::vector& tokenList, int nextToken, const std::map>& kvState) { std::vector tokenListCopy = tokenList; tokenListCopy.push_back(nextToken); @@ -189,7 +190,7 @@ Status KVStateCacheBuilder::Update( rootTree->Split(tokenListCopy, subTreeHeader); RETURN_ON_ASSERT(nodeDataList.size() != 0, "Split llm cache failed."); KVStateCacheBlockBuilder* newKVStateCacheBlockBuilder; - Status status = Split(client, kvStateCacheBlockBuilder, nodeDataList, + Status status = Split(kvStateCacheBlockBuilder, nodeDataList, newKVStateCacheBlockBuilder); RETURN_ON_ERROR(status); @@ -204,7 +205,7 @@ Status KVStateCacheBuilder::Update( VLOG(100) << "block split success"; // kv_state_cache_builder->UnLock(); - status = Update(client, tokenList, nextToken, kvState); + status = Update(tokenList, nextToken, kvState); RETURN_ON_ERROR(status); } else { // Update the kv-state cache. @@ -223,7 +224,7 @@ Status KVStateCacheBuilder::Update( } Status KVStateCacheBuilder::Query( - Client& client, const std::vector& tokenList, int token, + const std::vector& tokenList, int token, std::map>& kvState) { std::vector tokenListCopy = tokenList; tokenListCopy.push_back(token); @@ -240,7 +241,7 @@ Status KVStateCacheBuilder::Query( (reinterpret_cast(nodeData->treeData->data)) ->kvStateCacheBlockBuilder); - return kvStateCacheBlockBuilder->Query(client, offset, kvState); + return kvStateCacheBlockBuilder->Query(offset, kvState); } void KVStateCacheBuilder::Delete(std::shared_ptr evictedNodeData) { @@ -263,8 +264,7 @@ void KVStateCacheBuilder::Delete(std::shared_ptr evictedNodeData) { evictedNodeData->RecycleSource(); } -Status KVStateCacheBuilder::Merge(Client& client, - std::shared_ptr kvStateCache) { +Status KVStateCacheBuilder::Merge(std::shared_ptr kvStateCache) { if (kvStateCache == nullptr) { return Status::OK(); } @@ -312,8 +312,8 @@ Status KVStateCacheBuilder::Merge(Client& client, kvState.insert( std::make_pair(currentLayer, std::make_pair(key_state, value_state))); } - globalCacheBuilder->Query(client, tokenList, (*it).back(), kvState); - this->Update(client, tokenList, (*it).back(), kvState); + globalCacheBuilder->Query(tokenList, (*it).back(), kvState); + this->Update(tokenList, (*it).back(), kvState); for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { LLMKV key_state = kvState[currentLayer].first; LLMKV value_state = kvState[currentLayer].second; diff --git a/modules/llm-cache/ds/kv_state_cache.h b/modules/llm-cache/ds/kv_state_cache.h index ac0545a1f..f8ddd3435 100644 --- a/modules/llm-cache/ds/kv_state_cache.h +++ b/modules/llm-cache/ds/kv_state_cache.h @@ -77,6 +77,7 @@ class KVStateCache : public vineyard::Registered { }; class KVStateCacheBuilder : public vineyard::ObjectBuilder { + Client& client; std::shared_ptr rootTree; int tensorBytes; int layer; @@ -97,21 +98,19 @@ class KVStateCacheBuilder : public vineyard::ObjectBuilder { std::shared_ptr& kvStateCacheBuilder, std::shared_ptr& cache); - Status Split(Client& client, - KVStateCacheBlockBuilder* kvStateCacheBlockBuilder, + Status Split(KVStateCacheBlockBuilder* kvStateCacheBlockBuilder, std::vector> nodeDataList, KVStateCacheBlockBuilder*& childKVStateCacheBlockBuilder); - Status Update(Client& client, const std::vector& token_list, - int next_token, + Status Update(const std::vector& token_list, int next_token, const std::map>& kv_state); - Status Query(Client& client, const std::vector& token_list, int token, + Status Query(const std::vector& token_list, int token, std::map>& kv_state); void Delete(std::shared_ptr evicted_node); - Status Merge(Client& client, std::shared_ptr kv_state_cache); + Status Merge(std::shared_ptr kv_state_cache); uint64_t GetVersion() { return this->version; } diff --git a/modules/llm-cache/ds/kv_state_cache_block.cc b/modules/llm-cache/ds/kv_state_cache_block.cc index 8620be704..bed4682b9 100644 --- a/modules/llm-cache/ds/kv_state_cache_block.cc +++ b/modules/llm-cache/ds/kv_state_cache_block.cc @@ -82,7 +82,8 @@ KVStateCacheBlock::~KVStateCacheBlock() { delete this->bitmap; } KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(Client& client, int tensorBytes, int layer, - int blockSize) { + int blockSize) + : client(client) { this->blockSize = blockSize; this->bitmapSize = (blockSize + 63) / 64; this->bitmap = new uint64_t[this->bitmapSize]; @@ -99,7 +100,8 @@ KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(Client& client, } KVStateCacheBlockBuilder::KVStateCacheBlockBuilder( - Client& client, std::shared_ptr kvStateCacheBlock) { + Client& client, std::shared_ptr kvStateCacheBlock) + : client(client) { this->bitmapSize = kvStateCacheBlock->bitmapSize; this->blockSize = kvStateCacheBlock->blockSize; VLOG(100) << "create builder from block object, bitmap size:" @@ -129,8 +131,7 @@ KVStateCacheBlockBuilder::KVStateCacheBlockBuilder( } Status KVStateCacheBlockBuilder::Query( - Client& client, int index, - std::map>& kvState) { + int index, std::map>& kvState) { RETURN_ON_ASSERT((index >= 0 && index < this->blockSize), "Index out of range: " + std::to_string(index)); for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { diff --git a/modules/llm-cache/ds/kv_state_cache_block.h b/modules/llm-cache/ds/kv_state_cache_block.h index 6ad82a4ac..579a8fa44 100644 --- a/modules/llm-cache/ds/kv_state_cache_block.h +++ b/modules/llm-cache/ds/kv_state_cache_block.h @@ -111,6 +111,7 @@ class KVStateCacheBlock : public vineyard::Registered { class KVStateCacheBlockBuilder : public ObjectBuilder { private: + Client& client; std::vector>> keyStateTensorBuilderList; std::vector>> @@ -150,8 +151,7 @@ class KVStateCacheBlockBuilder : public ObjectBuilder { * @param kv_state The kv-state of the prompt returned by radix-tree. If the * kv-state is not found, the data of kv-state is invalid. */ - Status Query(Client& client, int index, - std::map>& kv_state); + Status Query(int index, std::map>& kv_state); bool IsFull(); diff --git a/modules/llm-cache/ds/kv_state_cache_manager.cc b/modules/llm-cache/ds/kv_state_cache_manager.cc index 0c478ddd1..ab2983d19 100644 --- a/modules/llm-cache/ds/kv_state_cache_manager.cc +++ b/modules/llm-cache/ds/kv_state_cache_manager.cc @@ -86,13 +86,13 @@ Status KVStateCacheManager::Make(Client& client, Status KVStateCacheManager::UpdateInternal( const std::vector& tokenList, int nextToken, const std::map>& kvState) { - return kvStateCacheBuilder->Update(client, tokenList, nextToken, kvState); + return kvStateCacheBuilder->Update(tokenList, nextToken, kvState); } Status KVStateCacheManager::QueryInternal( const std::vector& tokenList, int token, std::map>& kvState) { - return kvStateCacheBuilder->Query(client, tokenList, token, kvState); + return kvStateCacheBuilder->Query(tokenList, token, kvState); } Status KVStateCacheManager::Update( @@ -222,7 +222,7 @@ Status KVStateCacheManager::Sync() { : std::to_string(globalKVStateCache->GetVersion())); if (globalKVStateCache != nullptr && kvStateCacheBuilder->GetVersion() < globalKVStateCache->GetVersion()) { - status = kvStateCacheBuilder->Merge(client, globalKVStateCache); + status = kvStateCacheBuilder->Merge(globalKVStateCache); RETURN_ON_ERROR(status); } kvStateCacheBuilder->UpdateVersion();