Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recycle resources of kv state cache object. #1748

Merged
merged 3 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 42 additions & 15 deletions modules/kv-state-cache/ds/kv_state_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int dimension,
std::shared_ptr<NodeData> rootTreeHeader = this->rootTree->GetRootNode();
rootTreeHeader->treeData->data = treeData;
rootTreeHeader->treeData->dataLength = sizeof(TreeData);
this->rootTree->SetSubtreeData(treeData, sizeof(TreeData));
this->rootTree->SetSubtreeData(treeData);
LOG(INFO) << "set builder:" << builder
<< " to tree:" << this->rootTree->GetRootTree()->head;
LOG(INFO) << "data:" << treeData
Expand All @@ -102,7 +102,7 @@ KVStateCacheBuilder::KVStateCacheBuilder(Client& client,
std::set<void*> subTreeData = cache->rootTree->GetSubTreeDataSet();

for (auto iter = subTreeData.begin(); iter != subTreeData.end(); ++iter) {
TreeData* treeData = (TreeData*) ((DataWrapper*) *iter)->data;
TreeData* treeData = (TreeData*) (*iter);
LOG(INFO) << "tree data:" << treeData;
VINEYARD_ASSERT(treeData->isPtr == false);
LOG(INFO) << "id:" << treeData->builderObjectID;
Expand Down Expand Up @@ -135,13 +135,12 @@ KVStateCacheBlockBuilder* KVStateCacheBuilder::Split(
kvStateCacheBlockBuilder->GetKeyStateBuilder();
const std::shared_ptr<TensorBuilder<double>> valueStateTensorBuilder =
kvStateCacheBlockBuilder->GetValueStateBuilder();
OffsetData* new_offset_data = new OffsetData();
OffsetData new_offset_data;
childKVStateCacheBlockBuilder->Update(
keyStateTensorBuilder->data() + index * this->dimension,
valueStateTensorBuilder->data() + index * this->dimension,
this->dimension, new_offset_data);
nodeDataList[i]->nodeData->data = new_offset_data;
nodeDataList[i]->nodeData->dataLength = sizeof(OffsetData);
this->dimension, &new_offset_data);
data->offset = new_offset_data.offset;
// Clear the bitmap.
kvStateCacheBlockBuilder->DeleteKVCache(index);
}
Expand Down Expand Up @@ -177,6 +176,13 @@ void KVStateCacheBuilder::Update(Client& client,
Delete(evictedNodeData);
}

// if (evictedNodeData->treeData != nullptr && evictedNodeData->nodeData !=
// nullptr) {
// if (evictedNodeData->nodeData->data != nullptr) {
// delete (TreeData*) evictedNodeData->nodeData->data;
// }
// }

// TBD
// Use lock to protect the kv_state_cache_builder
LOG(INFO) << "data:" << nodeData->treeData->data
Expand Down Expand Up @@ -204,7 +210,7 @@ void KVStateCacheBuilder::Update(Client& client,

subTreeHeader->treeData->data = newTreeData;
subTreeHeader->treeData->dataLength = sizeof(TreeData);
rootTree->SetSubtreeData(newTreeData, sizeof(TreeData));
rootTree->SetSubtreeData(newTreeData);
LOG(INFO) << "block split success";

// kv_state_cache_builder->UnLock();
Expand Down Expand Up @@ -257,6 +263,15 @@ void KVStateCacheBuilder::Delete(std::shared_ptr<NodeData> evictedNodeData) {
kvStateCacheBlockBuilder->DeleteKVCache(data->offset);
LOG(INFO) << "stage4";
delete data;
// TBD
// Refactor this code. The data should be deleted by the RadixTree
// delete (DataWrapper*) evictedNodeData->nodeData;
LOG(INFO) << "tree data:" << evictedNodeData->treeData->data;
if (evictedNodeData->cleanTreeData) {
LOG(INFO) << "erase";
this->rootTree->GetSubTreeDataSet().erase(evictedNodeData->treeData->data);
}
evictedNodeData->RecycleSource();
}

void KVStateCacheBuilder::Merge(Client& client,
Expand Down Expand Up @@ -320,11 +335,10 @@ std::shared_ptr<Object> KVStateCacheBuilder::_Seal(Client& client) {
// change the tree data from pointer to object id

int count = 0;
LOG(INFO) << "count:" << count;
std::set<void*> subTreeDataSet = rootTree->GetSubTreeDataSet();
for (auto iter = subTreeDataSet.begin(); iter != subTreeDataSet.end();
++iter) {
TreeData* treeData = (TreeData*) ((DataWrapper*) *iter)->data;
TreeData* treeData = (TreeData*) (*iter);
VINEYARD_ASSERT(treeData != nullptr);
VINEYARD_ASSERT(treeData->isPtr == true);

Expand Down Expand Up @@ -357,12 +371,25 @@ std::shared_ptr<Object> KVStateCacheBuilder::_Seal(Client& client) {
}

/**
 * Free the heap objects still referenced through the root radix tree.
 *
 * Two collections are gathered from the tree before anything is released:
 * the sub-tree payloads (TreeData) and the per-node payloads (OffsetData).
 * A TreeData entry is released, together with the KVStateCacheBlockBuilder
 * it points to, only while its `isPtr` flag is true; entries with the flag
 * cleared are left untouched. Every per-node payload is then deleted.
 */
KVStateCacheBuilder::~KVStateCacheBuilder() {
  LOG(INFO) << "KVStateCacheBuilder::~KVStateCacheBuilder";

  // Gather both collections up front, before the first delete happens.
  std::set<void*> subTreePayloads = rootTree->GetSubTreeDataSet();
  std::set<void*> nodePayloads = rootTree->GetAllNodeData();

  for (void* entry : subTreePayloads) {
    TreeData* treeData = static_cast<TreeData*>(entry);
    const bool holdsBuilderPointer = (treeData->isPtr == true);
    if (!holdsBuilderPointer) {
      continue;
    }
    delete (KVStateCacheBlockBuilder*) treeData->kvStateCacheBlockBuilder;
    delete treeData;
  }

  for (void* entry : nodePayloads) {
    // delete on a null pointer is a no-op, so no explicit null test is needed.
    delete static_cast<OffsetData*>(entry);
  }
}

} // namespace vineyard
135 changes: 97 additions & 38 deletions modules/kv-state-cache/radix-tree/radix-tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,23 @@ RadixTree::RadixTree(int cacheCapacity) {
data->data = nullptr;
data->dataLength = 0;
dataNode->custom_data = data;
LOG(INFO) << "root data wrapper:" << data;
dataNode->issubtree = true;
this->rootToken = rootToken;
}

/**
 * Tear down the radix tree.
 *
 * Logs and dumps the tree with raxShow, looks up the node addressed by the
 * root token, deletes the two DataWrapper objects attached to that node (its
 * custom data first, then its regular data), and finally releases the rax
 * structure with raxFree.
 */
RadixTree::~RadixTree() {
  LOG(INFO) << "~RadixTree";
  raxShow(this->tree);

  raxNode* rootDataNode = raxFindAndReturnDataNode(
      this->tree, rootToken.data(), rootToken.size(), nullptr, false);
  if (rootDataNode != nullptr) {
    DataWrapper* customWrapper = (DataWrapper*) rootDataNode->custom_data;
    delete customWrapper;

    DataWrapper* nodeWrapper = (DataWrapper*) raxGetData(rootDataNode);
    delete nodeWrapper;
  }

  raxFree(this->tree);
}

std::shared_ptr<NodeData> RadixTree::Insert(
Expand Down Expand Up @@ -154,15 +155,22 @@ void RadixTree::DeleteInternal(std::vector<int> tokens,
DataWrapper* oldData;
raxNode* subTreeNode;
std::vector<int> pre;
// raxFindAndReturnDataNode(this->tree, deleteTokensArray,
// deleteTokensArrayLen,
// &subTreeNode, false);
raxNode* dataNode = raxFindAndReturnDataNode(
this->tree, deleteTokensArray, deleteTokensArrayLen, &subTreeNode, false);
bool nodeIsSubTree = false;
if (dataNode != nullptr && dataNode->issubtree) {
nodeIsSubTree = true;
}
int retval = raxRemove(this->tree, deleteTokensArray, deleteTokensArrayLen,
(void**) &oldData, &subTreeNode);
(void**) &oldData);
if (retval == 1) {
evictedNode = std::make_shared<NodeData>(
oldData, (DataWrapper*) subTreeNode->custom_data);
nodeCount--;
if (nodeIsSubTree) {
// subTreeDataSet.erase(subTreeNode->custom_data);
evictedNode->cleanTreeData = true;
}
} else {
LOG(INFO) << "remove failed";
}
Expand Down Expand Up @@ -224,6 +232,15 @@ std::string RadixTree::Serialize() {

serializedStr += timestampOSS.str() + "|";

raxNode* node =
raxFindAndReturnDataNode(this->tree, tokenList[index].data(),
tokenList[index].size(), NULL, false);
uint32_t numNodes = node->numnodes;
std::ostringstream subTreeSizeOSS;
subTreeSizeOSS << std::hex << numNodes;

serializedStr += subTreeSizeOSS.str() + "|";

// convert data to hex string
char* bytes = (char*) ((DataWrapper*) dataList[index])->data;
std::ostringstream dataOSS;
Expand Down Expand Up @@ -251,7 +268,7 @@ std::string RadixTree::Serialize() {
char* bytes = (char*) ((DataWrapper*) subTreeDataList[index])->data;
std::ostringstream dataOSS;

LOG(INFO) << "data lengtπh:"
LOG(INFO) << "data length:"
<< ((DataWrapper*) subTreeDataList[index])->dataLength;
for (int i = 0; i < ((DataWrapper*) subTreeDataList[index])->dataLength;
++i) {
Expand All @@ -267,17 +284,21 @@ std::string RadixTree::Serialize() {
// use ZSTD to compress the serialized string
size_t srcSize = serializedStr.size();
std::string compressedStr(srcSize, '\0');
int compressedSize = ZSTD_compress((void *)(compressedStr.c_str()), compressedStr.length(),
serializedStr.c_str(), srcSize, 3);
int compressedSize =
ZSTD_compress((void*) (compressedStr.c_str()), compressedStr.length(),
serializedStr.c_str(), srcSize, 3);
if (ZSTD_isError(compressedSize)) {
LOG(ERROR) << "ZSTD compression failed: " << ZSTD_getErrorName(compressedSize);
LOG(ERROR) << "ZSTD compression failed: "
<< ZSTD_getErrorName(compressedSize);
}
int cacheCapacity = this->cacheCapacity - 1;

std::string result = std::string((char*) &srcSize, sizeof(int)) +
std::string((char*) &cacheCapacity, sizeof(int)) +
compressedStr;

std::string result =
std::string((char*) &srcSize, sizeof(int)) +
std::string((char*) &cacheCapacity, sizeof(int)) +
std::string((char*) &(this->tree->head->numnodes), sizeof(uint32_t)) +
compressedStr;

return result;
}

Expand All @@ -288,11 +309,15 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
data.erase(0, sizeof(int));
int cacheCapacity = *(int*) data.c_str();
data.erase(0, sizeof(int));
int rootNumNodes = *(uint32_t*) data.c_str();
data.erase(0, sizeof(uint32_t));
std::string decompressedStr(srcSize, '\0');
int decompressedSize = ZSTD_decompress((void *)(decompressedStr.c_str()), decompressedStr.size(),
data.c_str(), srcSize);
int decompressedSize =
ZSTD_decompress((void*) (decompressedStr.c_str()), decompressedStr.size(),
data.c_str(), srcSize);
if (ZSTD_isError(decompressedSize)) {
LOG(ERROR) << "ZSTD decompression failed: " << ZSTD_getErrorName(decompressedSize);
LOG(ERROR) << "ZSTD decompression failed: "
<< ZSTD_getErrorName(decompressedSize);
}
data = decompressedStr.substr(0, decompressedSize);

Expand All @@ -303,6 +328,7 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
std::vector<std::vector<int>> subTreeTokenList;
std::vector<void*> subTreeDataList;
std::vector<int> subTreeDataSizeList;
std::vector<int> subTreeSizeList;
std::istringstream iss(data);
std::string line;
bool isMainTree = true;
Expand All @@ -315,7 +341,7 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
}
LOG(INFO) << "data line:" << line << std::endl;
std::istringstream lineStream(line);
std::string tokenListPart, timestampPart, dataPart;
std::string tokenListPart, timestampPart, dataPart, subTreeSizePart;

if (!std::getline(lineStream, tokenListPart, '|')) {
throw std::runtime_error(
Expand All @@ -326,6 +352,10 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
throw std::runtime_error(
"Invalid serialized string format in timestamp part.");
}
if (!std::getline(lineStream, subTreeSizePart, '|')) {
throw std::runtime_error(
"Invalid serialized string format in sub tree size part.");
}
}
if (!std::getline(lineStream, dataPart)) {
LOG(INFO) << "data length is 0";
Expand All @@ -345,6 +375,15 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
LOG(INFO) << "Invalid timestamp format.";
throw std::runtime_error("Invalid timestamp format.");
}

std::istringstream subTreeSizeStream(subTreeSizePart);
uint32_t subTreeSize;
if (!(subTreeSizeStream >> std::hex >> subTreeSize)) {
LOG(INFO) << "Invalid sub tree size format.";
throw std::runtime_error("Invalid sub tree size format.");
}
LOG(INFO) << "Deserialize sub tree size:" << subTreeSize;
subTreeSizeList.push_back(subTreeSize);
}

size_t dataSize = dataPart.length() /
Expand Down Expand Up @@ -425,6 +464,16 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
}
dataNode->timestamp = timestampList[i];
}

for (size_t i = 0; i < tokenList.size(); i++) {
raxNode* node = raxFindAndReturnDataNode(
radixTree->tree, tokenList[i].data(), tokenList[i].size(), NULL, false);
LOG(INFO) << "node:" << node << " sub tree node num:" << subTreeSizeList[i];
node->numnodes = subTreeSizeList[i];
}
radixTree->tree->head->numnodes = rootNumNodes;
raxShow(radixTree->tree);

LOG(INFO) << "start to insert sub tree token list" << std::endl;
for (size_t i = 0; i < subTreeTokenList.size(); i++) {
for (size_t j = 0; j < subTreeTokenList[i].size(); j++) {
Expand All @@ -449,20 +498,18 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
node->issubtree = true;
raxSetCustomData(node, data);

// TBD
// refactor this code.
radixTree->subTreeDataSet.insert(data);
radixTree->subTreeDataSet.insert(subTreeDataList[i]);
}
LOG(INFO) << "Deserialize success";
raxShow(radixTree->tree);
return radixTree;
}

std::vector<std::shared_ptr<NodeData>> RadixTree::SplitInternal(
std::vector<int> tokens, std::shared_ptr<NodeData>& header) {
std::vector<int> rootToken;
DataWrapper* dummyData = new DataWrapper();
raxNode* subTreeRootNode =
raxSplit(this->tree, tokens.data(), tokens.size(), dummyData, rootToken);
raxSplit(this->tree, tokens.data(), tokens.size(), rootToken);

raxShow(this->tree);
subTreeRootNode->issubtree = true;
Expand Down Expand Up @@ -496,12 +543,9 @@ std::vector<std::shared_ptr<NodeData>> RadixTree::TraverseTreeWithoutSubTree(
return nodes;
}

void RadixTree::SetSubtreeData(void* data, int dataLength) {
LOG(INFO) << "set subtree data";
DataWrapper* dataWrapper = new DataWrapper();
dataWrapper->data = data;
dataWrapper->dataLength = dataLength;
subTreeDataSet.insert(dataWrapper);
void RadixTree::SetSubtreeData(void* data) {
LOG(INFO) << "set subtree data:" << data;
subTreeDataSet.insert(data);
}

std::shared_ptr<NodeData> RadixTree::GetRootNode() {
Expand Down Expand Up @@ -531,4 +575,19 @@ void RadixTree::MergeTree(std::shared_ptr<RadixTree> tree_1,
std::vector<int> tmp(vec.begin() + 1, vec.end());
insert_tokens.insert(tmp);
}
}

std::set<void*> RadixTree::GetAllNodeData() {
raxIterator iter;
raxStart(&iter, this->tree);
raxSeek(&iter, "^", NULL, 0);
std::set<void*> nodeDataSet;
while (raxNext(&iter)) {
raxNode* node = iter.node;
if (node->isnull) {
continue;
}
nodeDataSet.insert(((DataWrapper*) raxGetData(node))->data);
}
return nodeDataSet;
}
Loading
Loading