diff --git a/src/include/storage/store/node_group_collection.h b/src/include/storage/store/node_group_collection.h index 605b6b5c8da..bb1bf73db4c 100644 --- a/src/include/storage/store/node_group_collection.h +++ b/src/include/storage/store/node_group_collection.h @@ -4,6 +4,7 @@ #include "storage/stats/table_stats.h" #include "storage/store/group_collection.h" #include "storage/store/node_group.h" +#include "transaction/transaction.h" namespace kuzu { namespace transaction { @@ -15,8 +16,8 @@ class MemoryManager; class NodeGroupCollection { public: NodeGroupCollection(MemoryManager& memoryManager, const std::vector& types, - bool enableCompression, FileHandle* dataFH = nullptr, - common::Deserializer* deSer = nullptr); + bool enableCompression, FileHandle* dataFH = nullptr, common::Deserializer* deSer = nullptr, + const transaction::rollback_insert_func_t* rollbackInsertFunc = nullptr); void append(const transaction::Transaction* transaction, const std::vector& vectors); @@ -95,6 +96,7 @@ class NodeGroupCollection { GroupCollection nodeGroups; FileHandle* dataFH; TableStats stats; + const transaction::rollback_insert_func_t* rollbackInsertFunc; }; } // namespace storage diff --git a/src/include/storage/store/node_table.h b/src/include/storage/store/node_table.h index adc07826dd8..35cbd951ca6 100644 --- a/src/include/storage/store/node_table.h +++ b/src/include/storage/store/node_table.h @@ -177,7 +177,9 @@ class NodeTable final : public Table { TableStats getStats(const transaction::Transaction* transaction) const; - const rollback_insert_func_t& getRollbackInsertFunc() const { return rollbackInsertFunc; } + const transaction::rollback_insert_func_t& getRollbackInsertFunc() const { + return rollbackInsertFunc; + } private: void insertPK(const transaction::Transaction* transaction, @@ -198,7 +200,7 @@ class NodeTable final : public Table { std::unique_ptr nodeGroups; common::column_id_t pkColumnID; std::unique_ptr pkIndex; - rollback_insert_func_t rollbackInsertFunc; + transaction::rollback_insert_func_t rollbackInsertFunc; }; } // namespace storage diff --git a/src/include/storage/store/rel_table_data.h b/src/include/storage/store/rel_table_data.h index d9f1f747820..36adda5d10b 100644 --- a/src/include/storage/store/rel_table_data.h +++ b/src/include/storage/store/rel_table_data.h @@ -113,6 +113,8 @@ class RelTableData { CSRHeaderColumns csrHeaderColumns; std::vector> columns; + + transaction::rollback_insert_func_t rollbackInsertFunc; }; } // namespace storage diff --git a/src/include/storage/store/table.h b/src/include/storage/store/table.h index 62e5b70fee9..18971ac672d 100644 --- a/src/include/storage/store/table.h +++ b/src/include/storage/store/table.h @@ -13,9 +13,6 @@ class ExpressionEvaluator; namespace storage { class MemoryManager; -using rollback_insert_func_t = std::function; - enum class TableScanSource : uint8_t { COMMITTED = 0, UNCOMMITTED = 1, NONE = UINT8_MAX }; struct TableScanState { diff --git a/src/include/storage/undo_buffer.h b/src/include/storage/undo_buffer.h index d59fed7af35..c1a6baade68 100644 --- a/src/include/storage/undo_buffer.h +++ b/src/include/storage/undo_buffer.h @@ -5,7 +5,7 @@ #include "common/constants.h" #include "common/types/types.h" #include "storage/enums/csr_node_group_scan_source.h" -#include "storage/store/table.h" +#include "transaction/transaction.h" namespace kuzu { namespace catalog { @@ -90,6 +90,7 @@ class UndoBuffer { const catalog::SequenceRollbackData& data); void createInsertInfo(NodeGroup* nodeGroup, common::row_idx_t startRow, common::row_idx_t numRows, + const transaction::rollback_insert_func_t* rollbackInsertFunc = nullptr, storage::CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE); void createDeleteInfo(NodeGroup* nodeGroup, common::row_idx_t startRow, common::row_idx_t numRows, storage::CSRNodeGroupScanSource source); @@ -107,7 +108,7 @@ class UndoBuffer { void createVersionInfo(UndoRecordType recordType, NodeGroup* nodeGroup, common::row_idx_t startRow, common::row_idx_t numRows, storage::CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE, - const rollback_insert_func_t* preRollbackCallback = nullptr); + const transaction::rollback_insert_func_t* rollbackInsertFunc = nullptr); void commitRecord(UndoRecordType recordType, const uint8_t* record, common::transaction_t commitTS) const; diff --git a/src/include/transaction/transaction.h b/src/include/transaction/transaction.h index b6117339979..56b65f9b4a6 100644 --- a/src/include/transaction/transaction.h +++ b/src/include/transaction/transaction.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "common/enums/statement_type.h" #include "common/types/types.h" #include "storage/enums/csr_node_group_scan_source.h" @@ -28,6 +30,10 @@ namespace transaction { class TransactionManager; class Transaction; +using rollback_insert_func_t = + std::function; + enum class TransactionType : uint8_t { READ_ONLY, WRITE, CHECKPOINT, DUMMY, RECOVERY }; class Transaction { @@ -117,7 +123,7 @@ class Transaction { void pushSequenceChange(catalog::SequenceCatalogEntry* sequenceEntry, int64_t kCount, const catalog::SequenceRollbackData& data) const; void pushInsertInfo(storage::NodeGroup* nodeGroup, common::row_idx_t startRow, - common::row_idx_t numRows, + common::row_idx_t numRows, const rollback_insert_func_t* rollbackInsertFunc = nullptr, storage::CSRNodeGroupScanSource source = storage::CSRNodeGroupScanSource::NONE) const; void pushDeleteInfo(storage::NodeGroup* nodeGroup, common::row_idx_t startRow, common::row_idx_t numRows, diff --git a/src/storage/store/csr_node_group.cpp b/src/storage/store/csr_node_group.cpp index bbd3499db08..e28e619428f 100644 --- a/src/storage/store/csr_node_group.cpp +++ b/src/storage/store/csr_node_group.cpp @@ -956,7 +956,7 @@ std::pair CSRNodeGroup::actionOnChunkedGroups(const common::Un if (persistentChunkGroup) { std::invoke(operation, *persistentChunkGroup, startRow, numRows_, commitTS); } - return {UINT32_MAX, UINT32_MAX}; + return {INVALID_CHUNKED_GROUP_IDX, INVALID_START_ROW_IDX}; } else { KU_ASSERT(source == CSRNodeGroupScanSource::COMMITTED_IN_MEMORY); return NodeGroup::actionOnChunkedGroups(lock, startRow, numRows_, commitTS, source, diff --git a/src/storage/store/node_group_collection.cpp b/src/storage/store/node_group_collection.cpp index fd53f40724e..ba5761cc4e8 100644 --- a/src/storage/store/node_group_collection.cpp +++ b/src/storage/store/node_group_collection.cpp @@ -15,9 +15,9 @@ namespace storage { NodeGroupCollection::NodeGroupCollection(MemoryManager& memoryManager, const std::vector& types, const bool enableCompression, FileHandle* dataFH, - Deserializer* deSer) + Deserializer* deSer, const transaction::rollback_insert_func_t* rollbackInsertFunc) : enableCompression{enableCompression}, numTotalRows{0}, types{LogicalType::copy(types)}, - dataFH{dataFH} { + dataFH{dataFH}, rollbackInsertFunc(rollbackInsertFunc) { if (deSer) { deserialize(*deSer, memoryManager); } @@ -238,7 +238,7 @@ void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transac storage::CSRNodeGroupScanSource source) { // we only append to the undo buffer if the node group collection is persistent if (dataFH && transaction->shouldAppendToUndoBuffer()) { - transaction->pushInsertInfo(nodeGroup, startRow, numRows, source); + transaction->pushInsertInfo(nodeGroup, startRow, numRows, rollbackInsertFunc, source); } if (source != CSRNodeGroupScanSource::COMMITTED_PERSISTENT) { numTotalRows += numRows; diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index c2fea23704d..cebd77ba956 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -220,16 +220,17 @@ NodeTable::NodeTable(const StorageManager* storageManager, dataFH, memoryManager, shadowFile, enableCompression); } - nodeGroups = std::make_unique(*memoryManager, - getNodeTableColumnTypes(*this), enableCompression, storageManager->getDataFH(), deSer); - initializePKIndex(storageManager->getDatabasePath(), nodeTableEntry, - storageManager->isReadOnly(), vfs, context); - rollbackInsertFunc = [this](const transaction::Transaction* transaction, common::row_idx_t startRow, common::row_idx_t numRows_, - common::node_group_idx_t nodeGroupIdx_) { + common::node_group_idx_t nodeGroupIdx_, CSRNodeGroupScanSource) { return rollbackInsert(transaction, startRow, numRows_, nodeGroupIdx_); }; + + nodeGroups = + std::make_unique(*memoryManager, getNodeTableColumnTypes(*this), + enableCompression, storageManager->getDataFH(), deSer, &rollbackInsertFunc); + initializePKIndex(storageManager->getDatabasePath(), nodeTableEntry, + storageManager->isReadOnly(), vfs, context); } std::unique_ptr NodeTable::loadTable(Deserializer& deSer, const Catalog& catalog, diff --git a/src/storage/store/rel_table_data.cpp b/src/storage/store/rel_table_data.cpp index 3a4526b6439..f367aefb7d9 100644 --- a/src/storage/store/rel_table_data.cpp +++ b/src/storage/store/rel_table_data.cpp @@ -25,8 +25,15 @@ RelTableData::RelTableData(FileHandle* dataFH, MemoryManager* mm, ShadowFile* sh multiplicity = tableEntry->constCast().getMultiplicity(direction); initCSRHeaderColumns(); initPropertyColumns(tableEntry); + + rollbackInsertFunc = [this](const transaction::Transaction*, common::row_idx_t startRow, + common::row_idx_t numRows_, common::node_group_idx_t nodeGroupIdx_, + CSRNodeGroupScanSource source) { + return nodeGroups->rollbackInsert(startRow, numRows_, nodeGroupIdx_, source); + }; + nodeGroups = std::make_unique(*mm, getColumnTypes(), enableCompression, - dataFH, deSer); + dataFH, deSer, &rollbackInsertFunc); } void RelTableData::initCSRHeaderColumns() { diff --git a/src/storage/undo_buffer.cpp b/src/storage/undo_buffer.cpp index 5da00570e2b..0fbc2dbe214 100644 --- a/src/storage/undo_buffer.cpp +++ b/src/storage/undo_buffer.cpp @@ -41,7 +41,7 @@ struct VersionRecord { NodeGroup* nodeGroup; row_idx_t startRow; row_idx_t numRows; - const rollback_insert_func_t* preRollbackCallback; + const transaction::rollback_insert_func_t* rollbackInsertFunc; CSRNodeGroupScanSource source; }; @@ -113,8 +113,10 @@ void UndoBuffer::createSequenceChange(SequenceCatalogEntry& sequenceEntry, } void UndoBuffer::createInsertInfo(NodeGroup* nodeGroup, row_idx_t startRow, row_idx_t numRows, + const transaction::rollback_insert_func_t* rollbackInsertFunc, storage::CSRNodeGroupScanSource source) { - createVersionInfo(UndoRecordType::INSERT_INFO, nodeGroup, startRow, numRows, source); + createVersionInfo(UndoRecordType::INSERT_INFO, nodeGroup, startRow, numRows, source, + rollbackInsertFunc); } void UndoBuffer::createDeleteInfo(NodeGroup* nodeGroup, common::row_idx_t startRow, @@ -124,13 +126,13 @@ void UndoBuffer::createDeleteInfo(NodeGroup* nodeGroup, common::row_idx_t startR void UndoBuffer::createVersionInfo(const UndoRecordType recordType, NodeGroup* nodeGroup, row_idx_t startRow, row_idx_t numRows, storage::CSRNodeGroupScanSource source, - const rollback_insert_func_t* callback) { + const transaction::rollback_insert_func_t* rollbackInsertFunc) { auto buffer = createUndoRecord(sizeof(UndoRecordHeader) + sizeof(VersionRecord)); const UndoRecordHeader recordHeader{recordType, sizeof(VersionRecord)}; *reinterpret_cast(buffer) = recordHeader; buffer += sizeof(UndoRecordHeader); *reinterpret_cast(buffer) = - VersionRecord{nodeGroup, startRow, numRows, callback, source}; + VersionRecord{nodeGroup, startRow, numRows, rollbackInsertFunc, source}; } void UndoBuffer::createVectorUpdateInfo(UpdateInfo* updateInfo, const idx_t vectorIdx, @@ -296,12 +298,13 @@ void UndoBuffer::rollbackVersionInfo(const transaction::Transaction* transaction auto& undoRecord = *reinterpret_cast(record); switch (recordType) { case UndoRecordType::INSERT_INFO: { - if (undoRecord.preRollbackCallback) { - (*undoRecord.preRollbackCallback)(transaction, undoRecord.startRow, undoRecord.numRows, - undoRecord.nodeGroup->getNodeGroupIdx()); + if (undoRecord.rollbackInsertFunc) { + (*undoRecord.rollbackInsertFunc)(transaction, undoRecord.startRow, undoRecord.numRows, + undoRecord.nodeGroup->getNodeGroupIdx(), undoRecord.source); + } else { + undoRecord.nodeGroup->rollbackInsert(undoRecord.startRow, undoRecord.numRows, + undoRecord.source); } - undoRecord.nodeGroup->rollbackInsert(undoRecord.startRow, undoRecord.numRows, - undoRecord.source); } break; case UndoRecordType::DELETE_INFO: { undoRecord.nodeGroup->rollbackDelete(undoRecord.startRow, undoRecord.numRows, diff --git a/src/transaction/transaction.cpp b/src/transaction/transaction.cpp index 6359d412bf1..c34608de818 100644 --- a/src/transaction/transaction.cpp +++ b/src/transaction/transaction.cpp @@ -172,8 +172,9 @@ void Transaction::pushSequenceChange(SequenceCatalogEntry* sequenceEntry, int64_ } void Transaction::pushInsertInfo(storage::NodeGroup* nodeGroup, common::row_idx_t startRow, - common::row_idx_t numRows, storage::CSRNodeGroupScanSource source) const { - undoBuffer->createInsertInfo(nodeGroup, startRow, numRows, source); + common::row_idx_t numRows, const rollback_insert_func_t* rollbackInsertFunc, + storage::CSRNodeGroupScanSource source) const { + undoBuffer->createInsertInfo(nodeGroup, startRow, numRows, rollbackInsertFunc, source); } void Transaction::pushDeleteInfo(storage::NodeGroup* nodeGroup, common::row_idx_t startRow,