Skip to content

Commit

Permalink
Recycle resources of kv state cache object. (#1748)
Browse files Browse the repository at this point in the history
Fixes #1732

Signed-off-by: vegetableysm <[email protected]>
  • Loading branch information
vegetableysm committed Feb 28, 2024
1 parent 024dc1e commit 03d6325
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 84 deletions.
57 changes: 42 additions & 15 deletions modules/kv-state-cache/ds/kv_state_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int dimension,
std::shared_ptr<NodeData> rootTreeHeader = this->rootTree->GetRootNode();
rootTreeHeader->treeData->data = treeData;
rootTreeHeader->treeData->dataLength = sizeof(TreeData);
this->rootTree->SetSubtreeData(treeData, sizeof(TreeData));
this->rootTree->SetSubtreeData(treeData);
LOG(INFO) << "set builder:" << builder
<< " to tree:" << this->rootTree->GetRootTree()->head;
LOG(INFO) << "data:" << treeData
Expand All @@ -102,7 +102,7 @@ KVStateCacheBuilder::KVStateCacheBuilder(Client& client,
std::set<void*> subTreeData = cache->rootTree->GetSubTreeDataSet();

for (auto iter = subTreeData.begin(); iter != subTreeData.end(); ++iter) {
TreeData* treeData = (TreeData*) ((DataWrapper*) *iter)->data;
TreeData* treeData = (TreeData*) (*iter);
LOG(INFO) << "tree data:" << treeData;
VINEYARD_ASSERT(treeData->isPtr == false);
LOG(INFO) << "id:" << treeData->builderObjectID;
Expand Down Expand Up @@ -135,13 +135,12 @@ KVStateCacheBlockBuilder* KVStateCacheBuilder::Split(
kvStateCacheBlockBuilder->GetKeyStateBuilder();
const std::shared_ptr<TensorBuilder<double>> valueStateTensorBuilder =
kvStateCacheBlockBuilder->GetValueStateBuilder();
OffsetData* new_offset_data = new OffsetData();
OffsetData new_offset_data;
childKVStateCacheBlockBuilder->Update(
keyStateTensorBuilder->data() + index * this->dimension,
valueStateTensorBuilder->data() + index * this->dimension,
this->dimension, new_offset_data);
nodeDataList[i]->nodeData->data = new_offset_data;
nodeDataList[i]->nodeData->dataLength = sizeof(OffsetData);
this->dimension, &new_offset_data);
data->offset = new_offset_data.offset;
// Clear the bitmap.
kvStateCacheBlockBuilder->DeleteKVCache(index);
}
Expand Down Expand Up @@ -177,6 +176,13 @@ void KVStateCacheBuilder::Update(Client& client,
Delete(evictedNodeData);
}

// if (evictedNodeData->treeData != nullptr && evictedNodeData->nodeData !=
// nullptr) {
// if (evictedNodeData->nodeData->data != nullptr) {
// delete (TreeData*) evictedNodeData->nodeData->data;
// }
// }

// TBD
// Use lock to protect the kv_state_cache_builder
LOG(INFO) << "data:" << nodeData->treeData->data
Expand Down Expand Up @@ -204,7 +210,7 @@ void KVStateCacheBuilder::Update(Client& client,

subTreeHeader->treeData->data = newTreeData;
subTreeHeader->treeData->dataLength = sizeof(TreeData);
rootTree->SetSubtreeData(newTreeData, sizeof(TreeData));
rootTree->SetSubtreeData(newTreeData);
LOG(INFO) << "block split success";

// kv_state_cache_builder->UnLock();
Expand Down Expand Up @@ -257,6 +263,15 @@ void KVStateCacheBuilder::Delete(std::shared_ptr<NodeData> evictedNodeData) {
kvStateCacheBlockBuilder->DeleteKVCache(data->offset);
LOG(INFO) << "stage4";
delete data;
// TBD
// Refactor this code. The data should be deleted by the RadixTree
// delete (DataWrapper*) evictedNodeData->nodeData;
LOG(INFO) << "tree data:" << evictedNodeData->treeData->data;
if (evictedNodeData->cleanTreeData) {
LOG(INFO) << "erase";
this->rootTree->GetSubTreeDataSet().erase(evictedNodeData->treeData->data);
}
evictedNodeData->RecycleSource();
}

void KVStateCacheBuilder::Merge(Client& client,
Expand Down Expand Up @@ -320,11 +335,10 @@ std::shared_ptr<Object> KVStateCacheBuilder::_Seal(Client& client) {
// change the tree data from pointer to object id

int count = 0;
LOG(INFO) << "count:" << count;
std::set<void*> subTreeDataSet = rootTree->GetSubTreeDataSet();
for (auto iter = subTreeDataSet.begin(); iter != subTreeDataSet.end();
++iter) {
TreeData* treeData = (TreeData*) ((DataWrapper*) *iter)->data;
TreeData* treeData = (TreeData*) (*iter);
VINEYARD_ASSERT(treeData != nullptr);
VINEYARD_ASSERT(treeData->isPtr == true);

Expand Down Expand Up @@ -357,12 +371,25 @@ std::shared_ptr<Object> KVStateCacheBuilder::_Seal(Client& client) {
}

KVStateCacheBuilder::~KVStateCacheBuilder() {
// TBD
// std::vector<std::shared_ptr<NodeData>> nodeDataList =
// RadixTree::TraverseTreeWithoutSubTree(this->rootTree);
// for (size_t i = 0; i < nodeDataList.size(); i++) {
// delete (OffsetData*) nodeDataList[i]->get_node()->get_data();
// }
LOG(INFO) << "KVStateCacheBuilder::~KVStateCacheBuilder";
// get all subtree data and node data
std::set<void*> subTreeDataSet = rootTree->GetSubTreeDataSet();
std::set<void*> nodeDataSet = rootTree->GetAllNodeData();
// 2. delete all subtree data and node data
for (auto iter = subTreeDataSet.begin(); iter != subTreeDataSet.end();
++iter) {
TreeData* treeData = (TreeData*) (*iter);
if (treeData->isPtr == true) {
delete (KVStateCacheBlockBuilder*) treeData->kvStateCacheBlockBuilder;
delete treeData;
}
}
for (auto iter = nodeDataSet.begin(); iter != nodeDataSet.end(); ++iter) {
OffsetData* data = (OffsetData*) (*iter);
if (data != nullptr) {
delete data;
}
}
}

} // namespace vineyard
135 changes: 97 additions & 38 deletions modules/kv-state-cache/radix-tree/radix-tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,23 @@ RadixTree::RadixTree(int cacheCapacity) {
data->data = nullptr;
data->dataLength = 0;
dataNode->custom_data = data;
LOG(INFO) << "root data wrapper:" << data;
dataNode->issubtree = true;
this->rootToken = rootToken;
}

RadixTree::~RadixTree() {
// TBD
// raxFreeWithCallback(this->tree, [](raxNode *n) {
// if (n->iskey && !n->isnull) {
// nodeData* nodedata = (nodeData*) raxGetData(n);
// delete nodedata;
// }
// if (n->issubtree && n->iscustomallocated && !n->iscustomnull) {
// customData* customdata = (customData*) raxGetCustomData(n);
// delete customdata;
// }
// });
LOG(INFO) << "~RadixTree";
raxShow(this->tree);

raxNode* dataNode = raxFindAndReturnDataNode(this->tree, rootToken.data(),
rootToken.size(), NULL, false);
if (dataNode != nullptr) {
delete (DataWrapper*) dataNode->custom_data;
delete (DataWrapper*) raxGetData(dataNode);
}

raxFree(this->tree);
}

std::shared_ptr<NodeData> RadixTree::Insert(
Expand Down Expand Up @@ -154,15 +155,22 @@ void RadixTree::DeleteInternal(std::vector<int> tokens,
DataWrapper* oldData;
raxNode* subTreeNode;
std::vector<int> pre;
// raxFindAndReturnDataNode(this->tree, deleteTokensArray,
// deleteTokensArrayLen,
// &subTreeNode, false);
raxNode* dataNode = raxFindAndReturnDataNode(
this->tree, deleteTokensArray, deleteTokensArrayLen, &subTreeNode, false);
bool nodeIsSubTree = false;
if (dataNode != nullptr && dataNode->issubtree) {
nodeIsSubTree = true;
}
int retval = raxRemove(this->tree, deleteTokensArray, deleteTokensArrayLen,
(void**) &oldData, &subTreeNode);
(void**) &oldData);
if (retval == 1) {
evictedNode = std::make_shared<NodeData>(
oldData, (DataWrapper*) subTreeNode->custom_data);
nodeCount--;
if (nodeIsSubTree) {
// subTreeDataSet.erase(subTreeNode->custom_data);
evictedNode->cleanTreeData = true;
}
} else {
LOG(INFO) << "remove failed";
}
Expand Down Expand Up @@ -224,6 +232,15 @@ std::string RadixTree::Serialize() {

serializedStr += timestampOSS.str() + "|";

raxNode* node =
raxFindAndReturnDataNode(this->tree, tokenList[index].data(),
tokenList[index].size(), NULL, false);
uint32_t numNodes = node->numnodes;
std::ostringstream subTreeSizeOSS;
subTreeSizeOSS << std::hex << numNodes;

serializedStr += subTreeSizeOSS.str() + "|";

// convert data to hex string
char* bytes = (char*) ((DataWrapper*) dataList[index])->data;
std::ostringstream dataOSS;
Expand Down Expand Up @@ -251,7 +268,7 @@ std::string RadixTree::Serialize() {
char* bytes = (char*) ((DataWrapper*) subTreeDataList[index])->data;
std::ostringstream dataOSS;

LOG(INFO) << "data lengtπh:"
LOG(INFO) << "data length:"
<< ((DataWrapper*) subTreeDataList[index])->dataLength;
for (int i = 0; i < ((DataWrapper*) subTreeDataList[index])->dataLength;
++i) {
Expand All @@ -267,17 +284,21 @@ std::string RadixTree::Serialize() {
// use ZSTD to compress the serialized string
size_t srcSize = serializedStr.size();
std::string compressedStr(srcSize, '\0');
int compressedSize = ZSTD_compress((void *)(compressedStr.c_str()), compressedStr.length(),
serializedStr.c_str(), srcSize, 3);
int compressedSize =
ZSTD_compress((void*) (compressedStr.c_str()), compressedStr.length(),
serializedStr.c_str(), srcSize, 3);
if (ZSTD_isError(compressedSize)) {
LOG(ERROR) << "ZSTD compression failed: " << ZSTD_getErrorName(compressedSize);
LOG(ERROR) << "ZSTD compression failed: "
<< ZSTD_getErrorName(compressedSize);
}
int cacheCapacity = this->cacheCapacity - 1;

std::string result = std::string((char*) &srcSize, sizeof(int)) +
std::string((char*) &cacheCapacity, sizeof(int)) +
compressedStr;

std::string result =
std::string((char*) &srcSize, sizeof(int)) +
std::string((char*) &cacheCapacity, sizeof(int)) +
std::string((char*) &(this->tree->head->numnodes), sizeof(uint32_t)) +
compressedStr;

return result;
}

Expand All @@ -288,11 +309,15 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
data.erase(0, sizeof(int));
int cacheCapacity = *(int*) data.c_str();
data.erase(0, sizeof(int));
int rootNumNodes = *(uint32_t*) data.c_str();
data.erase(0, sizeof(uint32_t));
std::string decompressedStr(srcSize, '\0');
int decompressedSize = ZSTD_decompress((void *)(decompressedStr.c_str()), decompressedStr.size(),
data.c_str(), srcSize);
int decompressedSize =
ZSTD_decompress((void*) (decompressedStr.c_str()), decompressedStr.size(),
data.c_str(), srcSize);
if (ZSTD_isError(decompressedSize)) {
LOG(ERROR) << "ZSTD decompression failed: " << ZSTD_getErrorName(decompressedSize);
LOG(ERROR) << "ZSTD decompression failed: "
<< ZSTD_getErrorName(decompressedSize);
}
data = decompressedStr.substr(0, decompressedSize);

Expand All @@ -303,6 +328,7 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
std::vector<std::vector<int>> subTreeTokenList;
std::vector<void*> subTreeDataList;
std::vector<int> subTreeDataSizeList;
std::vector<int> subTreeSizeList;
std::istringstream iss(data);
std::string line;
bool isMainTree = true;
Expand All @@ -315,7 +341,7 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
}
LOG(INFO) << "data line:" << line << std::endl;
std::istringstream lineStream(line);
std::string tokenListPart, timestampPart, dataPart;
std::string tokenListPart, timestampPart, dataPart, subTreeSizePart;

if (!std::getline(lineStream, tokenListPart, '|')) {
throw std::runtime_error(
Expand All @@ -326,6 +352,10 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
throw std::runtime_error(
"Invalid serialized string format in timestamp part.");
}
if (!std::getline(lineStream, subTreeSizePart, '|')) {
throw std::runtime_error(
"Invalid serialized string format in sub tree size part.");
}
}
if (!std::getline(lineStream, dataPart)) {
LOG(INFO) << "data length is 0";
Expand All @@ -345,6 +375,15 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
LOG(INFO) << "Invalid timestamp format.";
throw std::runtime_error("Invalid timestamp format.");
}

std::istringstream subTreeSizeStream(subTreeSizePart);
uint32_t subTreeSize;
if (!(subTreeSizeStream >> std::hex >> subTreeSize)) {
LOG(INFO) << "Invalid sub tree size format.";
throw std::runtime_error("Invalid sub tree size format.");
}
LOG(INFO) << "Deserialize sub tree size:" << subTreeSize;
subTreeSizeList.push_back(subTreeSize);
}

size_t dataSize = dataPart.length() /
Expand Down Expand Up @@ -425,6 +464,16 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
}
dataNode->timestamp = timestampList[i];
}

for (size_t i = 0; i < tokenList.size(); i++) {
raxNode* node = raxFindAndReturnDataNode(
radixTree->tree, tokenList[i].data(), tokenList[i].size(), NULL, false);
LOG(INFO) << "node:" << node << " sub tree node num:" << subTreeSizeList[i];
node->numnodes = subTreeSizeList[i];
}
radixTree->tree->head->numnodes = rootNumNodes;
raxShow(radixTree->tree);

LOG(INFO) << "start to insert sub tree token list" << std::endl;
for (size_t i = 0; i < subTreeTokenList.size(); i++) {
for (size_t j = 0; j < subTreeTokenList[i].size(); j++) {
Expand All @@ -449,20 +498,18 @@ std::shared_ptr<RadixTree> RadixTree::Deserialize(std::string data) {
node->issubtree = true;
raxSetCustomData(node, data);

// TBD
// refactor this code.
radixTree->subTreeDataSet.insert(data);
radixTree->subTreeDataSet.insert(subTreeDataList[i]);
}
LOG(INFO) << "Deserialize success";
raxShow(radixTree->tree);
return radixTree;
}

std::vector<std::shared_ptr<NodeData>> RadixTree::SplitInternal(
std::vector<int> tokens, std::shared_ptr<NodeData>& header) {
std::vector<int> rootToken;
DataWrapper* dummyData = new DataWrapper();
raxNode* subTreeRootNode =
raxSplit(this->tree, tokens.data(), tokens.size(), dummyData, rootToken);
raxSplit(this->tree, tokens.data(), tokens.size(), rootToken);

raxShow(this->tree);
subTreeRootNode->issubtree = true;
Expand Down Expand Up @@ -496,12 +543,9 @@ std::vector<std::shared_ptr<NodeData>> RadixTree::TraverseTreeWithoutSubTree(
return nodes;
}

void RadixTree::SetSubtreeData(void* data, int dataLength) {
LOG(INFO) << "set subtree data";
DataWrapper* dataWrapper = new DataWrapper();
dataWrapper->data = data;
dataWrapper->dataLength = dataLength;
subTreeDataSet.insert(dataWrapper);
void RadixTree::SetSubtreeData(void* data) {
LOG(INFO) << "set subtree data:" << data;
subTreeDataSet.insert(data);
}

std::shared_ptr<NodeData> RadixTree::GetRootNode() {
Expand Down Expand Up @@ -531,4 +575,19 @@ void RadixTree::MergeTree(std::shared_ptr<RadixTree> tree_1,
std::vector<int> tmp(vec.begin() + 1, vec.end());
insert_tokens.insert(tmp);
}
}

std::set<void*> RadixTree::GetAllNodeData() {
raxIterator iter;
raxStart(&iter, this->tree);
raxSeek(&iter, "^", NULL, 0);
std::set<void*> nodeDataSet;
while (raxNext(&iter)) {
raxNode* node = iter.node;
if (node->isnull) {
continue;
}
nodeDataSet.insert(((DataWrapper*) raxGetData(node))->data);
}
return nodeDataSet;
}
Loading

0 comments on commit 03d6325

Please sign in to comment.