Skip to content

Commit

Permalink
Bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
royi-luo committed Nov 18, 2024
1 parent f6f618d commit 98076f8
Showing 12 changed files with 54 additions and 32 deletions.
6 changes: 4 additions & 2 deletions src/include/storage/store/node_group_collection.h
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
#include "storage/stats/table_stats.h"
#include "storage/store/group_collection.h"
#include "storage/store/node_group.h"
#include "transaction/transaction.h"

namespace kuzu {
namespace transaction {
@@ -15,8 +16,8 @@ class MemoryManager;
class NodeGroupCollection {
public:
NodeGroupCollection(MemoryManager& memoryManager, const std::vector<common::LogicalType>& types,
bool enableCompression, FileHandle* dataFH = nullptr,
common::Deserializer* deSer = nullptr);
bool enableCompression, FileHandle* dataFH = nullptr, common::Deserializer* deSer = nullptr,
const transaction::rollback_insert_func_t* rollbackInsertFunc = nullptr);

void append(const transaction::Transaction* transaction,
const std::vector<common::ValueVector*>& vectors);
@@ -95,6 +96,7 @@ class NodeGroupCollection {
GroupCollection<NodeGroup> nodeGroups;
FileHandle* dataFH;
TableStats stats;
const transaction::rollback_insert_func_t* rollbackInsertFunc;
};

} // namespace storage
6 changes: 4 additions & 2 deletions src/include/storage/store/node_table.h
Original file line number Diff line number Diff line change
@@ -177,7 +177,9 @@ class NodeTable final : public Table {

TableStats getStats(const transaction::Transaction* transaction) const;

const rollback_insert_func_t& getRollbackInsertFunc() const { return rollbackInsertFunc; }
const transaction::rollback_insert_func_t& getRollbackInsertFunc() const {
return rollbackInsertFunc;
}

private:
void insertPK(const transaction::Transaction* transaction,
@@ -198,7 +200,7 @@ class NodeTable final : public Table {
std::unique_ptr<NodeGroupCollection> nodeGroups;
common::column_id_t pkColumnID;
std::unique_ptr<PrimaryKeyIndex> pkIndex;
rollback_insert_func_t rollbackInsertFunc;
transaction::rollback_insert_func_t rollbackInsertFunc;
};

} // namespace storage
2 changes: 2 additions & 0 deletions src/include/storage/store/rel_table_data.h
Original file line number Diff line number Diff line change
@@ -113,6 +113,8 @@ class RelTableData {

CSRHeaderColumns csrHeaderColumns;
std::vector<std::unique_ptr<Column>> columns;

transaction::rollback_insert_func_t rollbackInsertFunc;
};

} // namespace storage
3 changes: 0 additions & 3 deletions src/include/storage/store/table.h
Original file line number Diff line number Diff line change
@@ -13,9 +13,6 @@ class ExpressionEvaluator;
namespace storage {
class MemoryManager;

using rollback_insert_func_t = std::function<void(const transaction::Transaction*,
common::row_idx_t, common::row_idx_t, common::node_group_idx_t)>;

enum class TableScanSource : uint8_t { COMMITTED = 0, UNCOMMITTED = 1, NONE = UINT8_MAX };

struct TableScanState {
5 changes: 3 additions & 2 deletions src/include/storage/undo_buffer.h
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
#include "common/constants.h"
#include "common/types/types.h"
#include "storage/enums/csr_node_group_scan_source.h"
#include "storage/store/table.h"
#include "transaction/transaction.h"

namespace kuzu {
namespace catalog {
@@ -90,6 +90,7 @@ class UndoBuffer {
const catalog::SequenceRollbackData& data);
void createInsertInfo(NodeGroup* nodeGroup, common::row_idx_t startRow,
common::row_idx_t numRows,
const transaction::rollback_insert_func_t* rollbackInsertFunc = nullptr,
storage::CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE);
void createDeleteInfo(NodeGroup* nodeGroup, common::row_idx_t startRow,
common::row_idx_t numRows, storage::CSRNodeGroupScanSource source);
@@ -107,7 +108,7 @@ class UndoBuffer {
void createVersionInfo(UndoRecordType recordType, NodeGroup* nodeGroup,
common::row_idx_t startRow, common::row_idx_t numRows,
storage::CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE,
const rollback_insert_func_t* preRollbackCallback = nullptr);
const transaction::rollback_insert_func_t* rollbackInsertFunc = nullptr);

void commitRecord(UndoRecordType recordType, const uint8_t* record,
common::transaction_t commitTS) const;
8 changes: 7 additions & 1 deletion src/include/transaction/transaction.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <functional>

#include "common/enums/statement_type.h"
#include "common/types/types.h"
#include "storage/enums/csr_node_group_scan_source.h"
@@ -28,6 +30,10 @@ namespace transaction {
class TransactionManager;
class Transaction;

using rollback_insert_func_t =
std::function<void(const transaction::Transaction*, common::row_idx_t, common::row_idx_t,
common::node_group_idx_t, storage::CSRNodeGroupScanSource)>;

enum class TransactionType : uint8_t { READ_ONLY, WRITE, CHECKPOINT, DUMMY, RECOVERY };

class Transaction {
@@ -117,7 +123,7 @@ class Transaction {
void pushSequenceChange(catalog::SequenceCatalogEntry* sequenceEntry, int64_t kCount,
const catalog::SequenceRollbackData& data) const;
void pushInsertInfo(storage::NodeGroup* nodeGroup, common::row_idx_t startRow,
common::row_idx_t numRows,
common::row_idx_t numRows, const rollback_insert_func_t* rollbackInsertFunc = nullptr,
storage::CSRNodeGroupScanSource source = storage::CSRNodeGroupScanSource::NONE) const;
void pushDeleteInfo(storage::NodeGroup* nodeGroup, common::row_idx_t startRow,
common::row_idx_t numRows,
2 changes: 1 addition & 1 deletion src/storage/store/csr_node_group.cpp
Original file line number Diff line number Diff line change
@@ -956,7 +956,7 @@ std::pair<idx_t, row_idx_t> CSRNodeGroup::actionOnChunkedGroups(const common::Un
if (persistentChunkGroup) {
std::invoke(operation, *persistentChunkGroup, startRow, numRows_, commitTS);
}
return {UINT32_MAX, UINT32_MAX};
return {INVALID_CHUNKED_GROUP_IDX, INVALID_START_ROW_IDX};
} else {
KU_ASSERT(source == CSRNodeGroupScanSource::COMMITTED_IN_MEMORY);
return NodeGroup::actionOnChunkedGroups(lock, startRow, numRows_, commitTS, source,
6 changes: 3 additions & 3 deletions src/storage/store/node_group_collection.cpp
Original file line number Diff line number Diff line change
@@ -15,9 +15,9 @@ namespace storage {

NodeGroupCollection::NodeGroupCollection(MemoryManager& memoryManager,
const std::vector<LogicalType>& types, const bool enableCompression, FileHandle* dataFH,
Deserializer* deSer)
Deserializer* deSer, const transaction::rollback_insert_func_t* rollbackInsertFunc)
: enableCompression{enableCompression}, numTotalRows{0}, types{LogicalType::copy(types)},
dataFH{dataFH} {
dataFH{dataFH}, rollbackInsertFunc(rollbackInsertFunc) {
if (deSer) {
deserialize(*deSer, memoryManager);
}
@@ -238,7 +238,7 @@ void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transac
storage::CSRNodeGroupScanSource source) {
// we only append to the undo buffer if the node group collection is persistent
if (dataFH && transaction->shouldAppendToUndoBuffer()) {
transaction->pushInsertInfo(nodeGroup, startRow, numRows, source);
transaction->pushInsertInfo(nodeGroup, startRow, numRows, rollbackInsertFunc, source);
}
if (source != CSRNodeGroupScanSource::COMMITTED_PERSISTENT) {
numTotalRows += numRows;
13 changes: 7 additions & 6 deletions src/storage/store/node_table.cpp
Original file line number Diff line number Diff line change
@@ -220,16 +220,17 @@ NodeTable::NodeTable(const StorageManager* storageManager,
dataFH, memoryManager, shadowFile, enableCompression);
}

nodeGroups = std::make_unique<NodeGroupCollection>(*memoryManager,
getNodeTableColumnTypes(*this), enableCompression, storageManager->getDataFH(), deSer);
initializePKIndex(storageManager->getDatabasePath(), nodeTableEntry,
storageManager->isReadOnly(), vfs, context);

rollbackInsertFunc = [this](const transaction::Transaction* transaction,
common::row_idx_t startRow, common::row_idx_t numRows_,
common::node_group_idx_t nodeGroupIdx_) {
common::node_group_idx_t nodeGroupIdx_, CSRNodeGroupScanSource) {
return rollbackInsert(transaction, startRow, numRows_, nodeGroupIdx_);
};

nodeGroups =
std::make_unique<NodeGroupCollection>(*memoryManager, getNodeTableColumnTypes(*this),
enableCompression, storageManager->getDataFH(), deSer, &rollbackInsertFunc);
initializePKIndex(storageManager->getDatabasePath(), nodeTableEntry,
storageManager->isReadOnly(), vfs, context);
}

std::unique_ptr<NodeTable> NodeTable::loadTable(Deserializer& deSer, const Catalog& catalog,
9 changes: 8 additions & 1 deletion src/storage/store/rel_table_data.cpp
Original file line number Diff line number Diff line change
@@ -25,8 +25,15 @@ RelTableData::RelTableData(FileHandle* dataFH, MemoryManager* mm, ShadowFile* sh
multiplicity = tableEntry->constCast<RelTableCatalogEntry>().getMultiplicity(direction);
initCSRHeaderColumns();
initPropertyColumns(tableEntry);

rollbackInsertFunc = [this](const transaction::Transaction*, common::row_idx_t startRow,
common::row_idx_t numRows_, common::node_group_idx_t nodeGroupIdx_,
CSRNodeGroupScanSource source) {
return nodeGroups->rollbackInsert(startRow, numRows_, nodeGroupIdx_, source);
};

nodeGroups = std::make_unique<NodeGroupCollection>(*mm, getColumnTypes(), enableCompression,
dataFH, deSer);
dataFH, deSer, &rollbackInsertFunc);
}

void RelTableData::initCSRHeaderColumns() {
21 changes: 12 additions & 9 deletions src/storage/undo_buffer.cpp
Original file line number Diff line number Diff line change
@@ -41,7 +41,7 @@ struct VersionRecord {
NodeGroup* nodeGroup;
row_idx_t startRow;
row_idx_t numRows;
const rollback_insert_func_t* preRollbackCallback;
const transaction::rollback_insert_func_t* rollbackInsertFunc;
CSRNodeGroupScanSource source;
};

@@ -113,8 +113,10 @@ void UndoBuffer::createSequenceChange(SequenceCatalogEntry& sequenceEntry,
}

void UndoBuffer::createInsertInfo(NodeGroup* nodeGroup, row_idx_t startRow, row_idx_t numRows,
const transaction::rollback_insert_func_t* rollbackInsertFunc,
storage::CSRNodeGroupScanSource source) {
createVersionInfo(UndoRecordType::INSERT_INFO, nodeGroup, startRow, numRows, source);
createVersionInfo(UndoRecordType::INSERT_INFO, nodeGroup, startRow, numRows, source,
rollbackInsertFunc);
}

void UndoBuffer::createDeleteInfo(NodeGroup* nodeGroup, common::row_idx_t startRow,
@@ -124,13 +126,13 @@ void UndoBuffer::createDeleteInfo(NodeGroup* nodeGroup, common::row_idx_t startR

void UndoBuffer::createVersionInfo(const UndoRecordType recordType, NodeGroup* nodeGroup,
row_idx_t startRow, row_idx_t numRows, storage::CSRNodeGroupScanSource source,
const rollback_insert_func_t* callback) {
const transaction::rollback_insert_func_t* rollbackInsertFunc) {
auto buffer = createUndoRecord(sizeof(UndoRecordHeader) + sizeof(VersionRecord));
const UndoRecordHeader recordHeader{recordType, sizeof(VersionRecord)};
*reinterpret_cast<UndoRecordHeader*>(buffer) = recordHeader;
buffer += sizeof(UndoRecordHeader);
*reinterpret_cast<VersionRecord*>(buffer) =
VersionRecord{nodeGroup, startRow, numRows, callback, source};
VersionRecord{nodeGroup, startRow, numRows, rollbackInsertFunc, source};
}

void UndoBuffer::createVectorUpdateInfo(UpdateInfo* updateInfo, const idx_t vectorIdx,
@@ -296,12 +298,13 @@ void UndoBuffer::rollbackVersionInfo(const transaction::Transaction* transaction
auto& undoRecord = *reinterpret_cast<VersionRecord const*>(record);
switch (recordType) {
case UndoRecordType::INSERT_INFO: {
if (undoRecord.preRollbackCallback) {
(*undoRecord.preRollbackCallback)(transaction, undoRecord.startRow, undoRecord.numRows,
undoRecord.nodeGroup->getNodeGroupIdx());
if (undoRecord.rollbackInsertFunc) {
(*undoRecord.rollbackInsertFunc)(transaction, undoRecord.startRow, undoRecord.numRows,
undoRecord.nodeGroup->getNodeGroupIdx(), undoRecord.source);
} else {
undoRecord.nodeGroup->rollbackInsert(undoRecord.startRow, undoRecord.numRows,
undoRecord.source);
}
undoRecord.nodeGroup->rollbackInsert(undoRecord.startRow, undoRecord.numRows,
undoRecord.source);
} break;
case UndoRecordType::DELETE_INFO: {
undoRecord.nodeGroup->rollbackDelete(undoRecord.startRow, undoRecord.numRows,
5 changes: 3 additions & 2 deletions src/transaction/transaction.cpp
Original file line number Diff line number Diff line change
@@ -172,8 +172,9 @@ void Transaction::pushSequenceChange(SequenceCatalogEntry* sequenceEntry, int64_
}

void Transaction::pushInsertInfo(storage::NodeGroup* nodeGroup, common::row_idx_t startRow,
common::row_idx_t numRows, storage::CSRNodeGroupScanSource source) const {
undoBuffer->createInsertInfo(nodeGroup, startRow, numRows, source);
common::row_idx_t numRows, const rollback_insert_func_t* rollbackInsertFunc,
storage::CSRNodeGroupScanSource source) const {
undoBuffer->createInsertInfo(nodeGroup, startRow, numRows, rollbackInsertFunc, source);
}

void Transaction::pushDeleteInfo(storage::NodeGroup* nodeGroup, common::row_idx_t startRow,

0 comments on commit 98076f8

Please sign in to comment.