Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move the radix tree to the thirdparty and update the type of kv_state #1779

Merged
merged 3 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,7 @@ endfunction()

file_glob_recurse(FILES_NEED_FORMAT DIRECTORIES "src" "modules" "python" "test" "benchmark"
PATTERNS ".*\\.(cc|cpp|h|hpp|vineyard-mod)$"
EXCLUDE_PATTERNS "(.*\\.vineyard.h$)|(.*modules/llm-cache/radix-tree/radix\.(cc|h)$)"
EXCLUDE_PATTERNS "(.*\\.vineyard.h$)"
)

# the `memcpy.h` is borrowed from an external project
Expand Down
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -1184,7 +1184,7 @@ SOFTWARE.

-------------------------------------------------------------------------------

The files modules/llm-cache/radix-tree/{radix.cc, radix.h, rax_malloc} is referred from project antirez/rax,
The files thirdparty/rax/{radix.cc, radix.h, rax_malloc} is referred from project antirez/rax,
which has the following license:

Copyright (c) 2017, Salvatore Sanfilippo <[email protected]>
Expand Down
2 changes: 2 additions & 0 deletions modules/llm-cache/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ file(GLOB VINEYARD_LLM_CACHE_SRCS "${CMAKE_CURRENT_SOURCE_DIR}"
"ds/*.h"
"radix-tree/*.cc"
"radix-tree/*.h"
"${PROJECT_SOURCE_DIR}/thirdparty/rax/*.cc"
"${PROJECT_SOURCE_DIR}/thirdparty/rax/*.h"
)

add_library(vineyard_llm_cache ${VINEYARD_LLM_CACHE_SRCS})
Expand Down
36 changes: 28 additions & 8 deletions modules/llm-cache/ds/kv_state_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@ limitations under the License.
#include <memory>
#include <set>
#include <string>
#include <utility>

#include "client/client.h"
#include "common/util/base64.h"
#include "common/util/logging.h"
#include "common/util/status.h"
#include "llm-cache/ds/kv_state_cache.h"
#include "llm-cache/radix-tree/radix-tree.h"
#include "llm-cache/radix-tree/radix.h"

#include "rax/radix.h"

namespace vineyard {

Expand Down Expand Up @@ -197,12 +199,12 @@ void KVStateCacheBuilder::Update(Client& client,
<< " bitmap:" << kvStateCacheBlockBuilder->GetBitmapStr();
}

KV_STATE_WITH_LAYER KVStateCacheBuilder::Query(
Client& client, const std::vector<int>& tokenList, int token) {
int KVStateCacheBuilder::Query(Client& client,
const std::vector<int>& tokenList, int token,
KV_STATE_WITH_LAYER& kvState) {
std::vector<int> tokenListCopy = tokenList;
tokenListCopy.push_back(token);

KV_STATE_WITH_LAYER kvState;
std::shared_ptr<NodeData> nodeData = this->rootTree->Query(tokenListCopy);

if (nodeData != nullptr) {
Expand All @@ -214,9 +216,9 @@ KV_STATE_WITH_LAYER KVStateCacheBuilder::Query(
(reinterpret_cast<TreeData*>(nodeData->treeData->data))
->kvStateCacheBlockBuilder);

kvStateCacheBlockBuilder->Query(client, offset, kvState);
return kvStateCacheBlockBuilder->Query(client, offset, kvState);
}
return kvState;
return -1;
}

void KVStateCacheBuilder::Delete(std::shared_ptr<NodeData> evictedNodeData) {
Expand Down Expand Up @@ -273,10 +275,28 @@ void KVStateCacheBuilder::Merge(Client& client,
for (auto it = insertTokenList.begin(); it != insertTokenList.end(); ++it) {
std::vector<int> tokenList =
std::vector<int>((*it).begin(), (*it).end() - 1);
KV_STATE_WITH_LAYER kvState =
globalCacheBuilder->Query(client, tokenList, (*it).back());
KV_STATE_WITH_LAYER kvState;
for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) {
K_STATE key_state;
V_STATE value_state;
key_state.data = malloc(this->dimension * sizeof(double));
key_state.length = this->dimension * sizeof(double);
value_state.data = malloc(this->dimension * sizeof(double));
value_state.length = this->dimension * sizeof(double);

kvState.insert(
std::make_pair(currentLayer, std::make_pair(key_state, value_state)));
}
globalCacheBuilder->Query(client, tokenList, (*it).back(), kvState);
this->Update(client, tokenList, (*it).back(), kvState);
for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) {
K_STATE key_state = kvState[currentLayer].first;
V_STATE value_state = kvState[currentLayer].second;
free(key_state.data);
free(value_state.data);
}
}

this->version = globalCacheBuilder->GetVersion();
return;
}
Expand Down
4 changes: 2 additions & 2 deletions modules/llm-cache/ds/kv_state_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ class KVStateCacheBuilder : public vineyard::ObjectBuilder {
void Update(Client& client, const std::vector<int>& token_list,
int next_token, const KV_STATE_WITH_LAYER& kv_state);

KV_STATE_WITH_LAYER Query(Client& client, const std::vector<int>& token_list,
int token);
/**
 * @brief Look up the kv-state stored for `token` when it follows the prefix
 * `token_list`.
 *
 * The definition appends `token` to a copy of `token_list` and queries the
 * radix tree with the resulting sequence.
 *
 * @param client The client that is forwarded to the block builder's Query.
 * @param token_list The prefix tokens; it is not modified.
 * @param token The token whose kv-state is requested.
 * @param kv_state Output. The block builder copies the key/value data into
 * the buffers already referenced by this map.
 * NOTE(review): the block builder looks up every layer with `find()` and
 * copies `dimension * sizeof(double)` bytes without checking the entry
 * exists, so the caller apparently has to pre-populate one entry per layer
 * with large enough buffers - confirm against callers.
 *
 * @return -1 if the radix tree has no node for the sequence; otherwise the
 * value returned by the block builder's Query.
 */
int Query(Client& client, const std::vector<int>& token_list, int token,
KV_STATE_WITH_LAYER& kv_state);

void Delete(std::shared_ptr<NodeData> evicted_node);

Expand Down
44 changes: 18 additions & 26 deletions modules/llm-cache/ds/kv_state_cache_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,26 +129,18 @@ KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(
}

// Currently we do not consider the layer.
Status KVStateCacheBlockBuilder::Query(Client& client, int index,
KV_STATE_WITH_LAYER& kvState) {
int KVStateCacheBlockBuilder::Query(Client& client, int index,
KV_STATE_WITH_LAYER& kvState) {
for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) {
std::vector<double> keyStateVector;
std::vector<double> valueStateVector;

for (int i = 0; i < this->dimension; ++i) {
keyStateVector.push_back((keyStateTensorBuilderList[currentLayer]
->data())[index * dimension + i]);
}

for (int i = 0; i < this->dimension; ++i) {
valueStateVector.push_back((valueStateTensorBuilderList[currentLayer]
->data())[index * dimension + i]);
}

kvState.insert(std::make_pair(
currentLayer, std::make_pair(keyStateVector, valueStateVector)));
memcpy((kvState.find(currentLayer)->second).first.data,
keyStateTensorBuilderList[currentLayer]->data() + index * dimension,
dimension * sizeof(double));
memcpy(
(kvState.find(currentLayer)->second).second.data,
valueStateTensorBuilderList[currentLayer]->data() + index * dimension,
dimension * sizeof(double));
}
return Status::OK();
return 0;
}

int KVStateCacheBlockBuilder::FindEmptySlot() {
Expand Down Expand Up @@ -176,18 +168,18 @@ void KVStateCacheBlockBuilder::Update(const KV_STATE_WITH_LAYER& kvState,
OffsetData* data) {
int index = this->FindEmptySlot();
for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) {
std::vector<double> keyStateVector =
(kvState.find(currentLayer)->second).first;
std::vector<double> valueStateVector =
(kvState.find(currentLayer)->second).second;
VINEYARD_ASSERT(keyStateVector.size() == (size_t) this->dimension);
VINEYARD_ASSERT(valueStateVector.size() == (size_t) this->dimension);
K_STATE keyState = (kvState.find(currentLayer)->second).first;
V_STATE valueState = (kvState.find(currentLayer)->second).second;
VINEYARD_ASSERT(keyState.length ==
(size_t) this->dimension * sizeof(double));
VINEYARD_ASSERT(valueState.length ==
(size_t) this->dimension * sizeof(double));

double* keyData = keyStateTensorBuilderList[currentLayer]->data();
double* valueData = valueStateTensorBuilderList[currentLayer]->data();
memcpy(keyData + index * this->dimension, keyStateVector.data(),
memcpy(keyData + index * this->dimension, keyState.data,
this->dimension * sizeof(double));
memcpy(valueData + index * this->dimension, valueStateVector.data(),
memcpy(valueData + index * this->dimension, valueState.data,
this->dimension * sizeof(double));
}
data->offset = index;
Expand Down
23 changes: 14 additions & 9 deletions modules/llm-cache/ds/kv_state_cache_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,19 @@ limitations under the License.
#include "client/ds/i_object.h"
#include "llm-cache/radix-tree/radix-tree.h"

using KV_STATE_WITH_LAYER =
std::map<int, std::pair<std::vector<double>, std::vector<double>>>;
using LIST_KV_STATE_WITH_LAYER = std::vector<
std::map<int, std::pair<std::vector<double>, std::vector<double>>>>;
using KV_STATE =
std::vector<std::pair<std::vector<double>, std::vector<double>>>;
using LIST_KV_STATE =
std::vector<std::pair<std::vector<double>, std::vector<double>>>;
// A plain view of one raw key-state or value-state buffer.
//
// `data` points at the first byte of the buffer and `length` is the buffer
// size in bytes. The struct has no constructor or destructor of its own, so
// it never allocates or releases `data`; the code that created the buffer
// stays responsible for freeing it.
//
// Both members carry default initializers so that a default-constructed
// State is a well-defined empty view (nullptr, 0) instead of holding
// indeterminate values.
struct State {
  void* data = nullptr;
  size_t length = 0;
};

using K_STATE = State;
using V_STATE = State;

using KV_STATE_WITH_LAYER = std::map<int, std::pair<K_STATE, V_STATE>>;
using LIST_KV_STATE_WITH_LAYER =
std::vector<std::map<int, std::pair<K_STATE, V_STATE>>>;
using KV_STATE = std::vector<std::pair<K_STATE, V_STATE>>;
using LIST_KV_STATE = std::vector<std::pair<K_STATE, V_STATE>>;

// Set the bit to 1, which means the resource is not being used
#define FREE_BIT_RESOURCE(value, bit) ((value) |= (((uint64_t) 1) << (bit)))
Expand Down Expand Up @@ -155,7 +160,7 @@ class KVStateCacheBlockBuilder : public ObjectBuilder {
* @param kv_state The kv-state of the prompt returned by radix-tree. If the
* kv-state is not found, the data of kv-state is invalid.
*/
Status Query(Client& client, int index, KV_STATE_WITH_LAYER& kv_state);
int Query(Client& client, int index, KV_STATE_WITH_LAYER& kv_state);

bool IsFull();

Expand Down
28 changes: 14 additions & 14 deletions modules/llm-cache/ds/kv_state_cache_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ void KVStateCacheManager::UpdateInternal(const std::vector<int>& tokenList,
kvStateCacheBuilder->Update(client, tokenList, nextToken, kvState);
}

KV_STATE_WITH_LAYER KVStateCacheManager::QueryInternal(
const std::vector<int>& tokenList, int token) {
return kvStateCacheBuilder->Query(client, tokenList, token);
// Forwards a single-token lookup to the cache builder, using the manager's
// own client. The builder writes into `kvState`, and its return value is
// handed back to the caller unchanged.
int KVStateCacheManager::QueryInternal(const std::vector<int>& tokenList,
                                       int token,
                                       KV_STATE_WITH_LAYER& kvState) {
  const int status =
      kvStateCacheBuilder->Query(client, tokenList, token, kvState);
  return status;
}

void KVStateCacheManager::Update(const std::vector<int>& tokenList,
Expand Down Expand Up @@ -113,36 +114,35 @@ void KVStateCacheManager::Update(const std::vector<int>& tokenList,
syncMutex.unlock();
}

KV_STATE_WITH_LAYER KVStateCacheManager::Query(
const std::vector<int>& tokenList, int token) {
KV_STATE_WITH_LAYER result;
int KVStateCacheManager::Query(const std::vector<int>& tokenList, int token,
KV_STATE_WITH_LAYER& kvState) {
int result = -1;

if (!syncMutex.try_lock()) {
return result;
}

result = QueryInternal(tokenList, token);
result = QueryInternal(tokenList, token, kvState);
syncMutex.unlock();

return result;
}

LIST_KV_STATE_WITH_LAYER KVStateCacheManager::Query(
const std::vector<int>& tokenList) {
LIST_KV_STATE_WITH_LAYER listKVState;
int KVStateCacheManager::Query(const std::vector<int>& tokenList,
LIST_KV_STATE_WITH_LAYER& listKVState) {
int result = -1;
if (!syncMutex.try_lock()) {
return listKVState;
return result;
}

std::vector<int> tokenListCopy;
for (size_t i = 0; i < tokenList.size(); i++) {
KV_STATE_WITH_LAYER kvState = QueryInternal(tokenListCopy, tokenList[i]);
listKVState.push_back(kvState);
result = QueryInternal(tokenListCopy, tokenList[i], listKVState[i]);
tokenListCopy.push_back(tokenList[i]);
}

syncMutex.unlock();
return listKVState;
return result;
}

KVStateCacheManager::~KVStateCacheManager() {
Expand Down
10 changes: 6 additions & 4 deletions modules/llm-cache/ds/kv_state_cache_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,20 @@ class KVStateCacheManager {
void Update(const std::vector<int>& tokenList,
const LIST_KV_STATE_WITH_LAYER& kvState);

KV_STATE_WITH_LAYER Query(const std::vector<int>& tokenList, int token);
/**
 * @brief Query the kv-state of `token` following the prefix `tokenList`.
 *
 * The definition only attempts the lookup if `syncMutex` can be acquired
 * with `try_lock()`; it does not wait for the lock.
 *
 * @param tokenList The prefix tokens.
 * @param token The token whose kv-state is requested.
 * @param kvState Output map that is passed through to QueryInternal.
 *
 * @return -1 if the lock could not be acquired; otherwise the value returned
 * by QueryInternal.
 */
int Query(const std::vector<int>& tokenList, int token,
KV_STATE_WITH_LAYER& kvState);

LIST_KV_STATE_WITH_LAYER Query(const std::vector<int>& tokenList);
/**
 * @brief Query the kv-state of every token in `tokenList`, one at a time.
 *
 * For position i the definition calls QueryInternal with the tokens before
 * position i as the prefix, `tokenList[i]` as the token and `listKVState[i]`
 * as the output. Nothing is queried if `syncMutex` cannot be acquired with
 * `try_lock()`.
 *
 * @param tokenList The tokens to look up.
 * @param listKVState Output list. It is indexed with `listKVState[i]` and is
 * never resized by this function, so it must already hold at least
 * `tokenList.size()` elements.
 *
 * @return -1 if the lock could not be acquired or `tokenList` is empty;
 * otherwise the value returned by QueryInternal for the last token. Results
 * for earlier tokens are overwritten and not reported.
 */
int Query(const std::vector<int>& tokenList,
LIST_KV_STATE_WITH_LAYER& listKVState);

~KVStateCacheManager();

private:
void UpdateInternal(const std::vector<int>& tokenList, int nextToken,
const KV_STATE_WITH_LAYER& kvState);

KV_STATE_WITH_LAYER QueryInternal(const std::vector<int>& tokenList,
int token);
/**
 * @brief Forward a single-token lookup to the cache builder.
 *
 * The definition itself does not touch `syncMutex`; the public Query
 * overloads acquire it before calling this function.
 *
 * @param tokenList The prefix tokens.
 * @param token The token whose kv-state is requested.
 * @param kvState Output map handed to the cache builder's Query.
 *
 * @return The value returned by the cache builder's Query.
 */
int QueryInternal(const std::vector<int>& tokenList, int token,
KV_STATE_WITH_LAYER& kvState);

void Delete(std::vector<int> token);

Expand Down
11 changes: 4 additions & 7 deletions modules/llm-cache/radix-tree/radix-tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -359,17 +359,14 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
std::string tokenListPart, timestampPart, dataPart, subTreeSizePart;

if (!std::getline(lineStream, tokenListPart, '|')) {
throw std::runtime_error(
"Invalid serialized string format in token list part.");
LOG(ERROR) << "Invalid serialized string format in token list part.";
}
if (isMainTree) {
if (!std::getline(lineStream, timestampPart, '|')) {
throw std::runtime_error(
"Invalid serialized string format in timestamp part.");
LOG(ERROR) << "Invalid serialized string format in timestamp part.";
}
if (!std::getline(lineStream, subTreeSizePart, '|')) {
throw std::runtime_error(
"Invalid serialized string format in sub tree size part.");
LOG(ERROR) << "Invalid serialized string format in sub tree size part.";
}
}
if (!std::getline(lineStream, dataPart)) {
Expand Down Expand Up @@ -471,7 +468,7 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
reinterpret_cast<void**>(&dataNode), NULL);

if (dataNode == NULL) {
throw std::runtime_error("Insert token list failed");
LOG(ERROR) << "Insert token list failed";
}
dataNode->timestamp = timestampList[i];
}
Expand Down
2 changes: 1 addition & 1 deletion modules/llm-cache/radix-tree/radix-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ limitations under the License.
#ifndef MODULES_LLM_CACHE_RADIX_TREE_RADIX_TREE_H_
#define MODULES_LLM_CACHE_RADIX_TREE_RADIX_TREE_H_

#include "llm-cache/radix-tree/radix.h"
#include "rax/radix.h"

#include <iomanip>
#include <map>
Expand Down
Loading
Loading