From 7b295fa30ac38681870d05ab344f92e64dbd0ed7 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Mon, 4 Nov 2024 11:43:15 -0500 Subject: [PATCH 01/28] Fix COPY rollback Add tests for node insertion rollback Add tests for rollback node insert/update Refactor PK column scan Use normal scan (instead of one used for checkpoints) for hash index rollback Disable failing test for now Fix test Pass node group idx to undo record so that order of node group append doesn't matter Tables are appended to undo buffer (instead of NodeGroupCollection). Also update rollback so node group deletion doesn't depend on undo buffer log order Fix + move around tests Some code cleanup More fixes More code cleanup Make undo buffer delete info also operate at table level Add back rollbackInsert for chunkedNodeGroup Add test for rel insert rollback Add CSR node group source to undo buffer Remove chunked node group from undo buffer Some code cleanup Make transaction param in rollback const Fix rollback of empty inserts Fix getNumEmptyTrailingGroups Make rel BM exception recovery test trigger rollback at least once Refactor copy test Clean up node group Clean up undo buffer --- scripts/headers.txt | 1 + src/include/main/database.h | 4 +- .../enums/csr_node_group_scan_source.h | 12 + src/include/storage/index/hash_index.h | 2 +- src/include/storage/index/in_mem_hash_index.h | 29 +- .../storage/store/chunked_node_group.h | 8 +- src/include/storage/store/column.h | 17 +- src/include/storage/store/column_chunk.h | 4 +- .../storage/store/column_reader_writer.h | 10 +- src/include/storage/store/csr_node_group.h | 23 +- src/include/storage/store/dictionary_column.h | 8 +- src/include/storage/store/group_collection.h | 5 + src/include/storage/store/list_column.h | 16 +- src/include/storage/store/node_group.h | 34 ++- .../storage/store/node_group_collection.h | 32 ++- src/include/storage/store/node_table.h | 11 + src/include/storage/store/null_column.h | 4 +- src/include/storage/store/rel_table.h | 9 +- 
src/include/storage/store/rel_table_data.h | 14 +- src/include/storage/store/string_column.h | 8 +- src/include/storage/store/struct_column.h | 6 +- src/include/storage/store/version_info.h | 7 +- src/include/storage/undo_buffer.h | 37 ++- src/include/transaction/transaction.h | 18 +- src/main/database.cpp | 6 + .../operator/persistent/rel_batch_insert.cpp | 43 ++- src/storage/index/hash_index.cpp | 2 +- src/storage/store/chunked_node_group.cpp | 26 +- src/storage/store/column.cpp | 21 +- src/storage/store/column_chunk.cpp | 10 +- src/storage/store/column_reader_writer.cpp | 40 +-- src/storage/store/csr_chunked_node_group.cpp | 2 +- src/storage/store/csr_node_group.cpp | 35 ++- src/storage/store/dictionary_column.cpp | 12 +- src/storage/store/list_column.cpp | 12 +- src/storage/store/node_group.cpp | 136 ++++++--- src/storage/store/node_group_collection.cpp | 79 +++++- src/storage/store/node_table.cpp | 258 +++++++++++++++--- src/storage/store/null_column.cpp | 4 +- src/storage/store/rel_table.cpp | 50 ++-- src/storage/store/rel_table_data.cpp | 16 +- src/storage/store/string_column.cpp | 8 +- src/storage/store/struct_column.cpp | 6 +- src/storage/store/version_info.cpp | 16 +- src/storage/undo_buffer.cpp | 79 ++++-- src/transaction/transaction.cpp | 28 +- test/copy/copy_test.cpp | 176 +++++++++++- test/include/graph_test/base_graph_test.h | 3 +- test/storage/local_hash_index_test.cpp | 8 +- .../transaction/copy/copy_node.test | 34 ++- .../transaction/create_node/create_node.test | 18 ++ .../transaction/set_node/set_empty.test | 18 ++ 52 files changed, 1106 insertions(+), 359 deletions(-) create mode 100644 src/include/storage/enums/csr_node_group_scan_source.h diff --git a/scripts/headers.txt b/scripts/headers.txt index 36873186cc4..45600901d52 100644 --- a/scripts/headers.txt +++ b/scripts/headers.txt @@ -71,5 +71,6 @@ src/include/processor/result/flat_tuple.h src/include/processor/warning_context.h 
src/include/processor/operator/persistent/reader/copy_from_error.h src/include/storage/storage_version_info.h +src/include/storage/enums/csr_node_group_scan_source.h src/include/transaction/transaction.h src/include/transaction/transaction_context.h diff --git a/src/include/main/database.h b/src/include/main/database.h index 1861b2d1554..45a8f61caa2 100644 --- a/src/include/main/database.h +++ b/src/include/main/database.h @@ -142,8 +142,8 @@ class Database { void initMembers(std::string_view dbPath, construct_bm_func_t initBmFunc = initBufferManager); // factory method only to be used for tests - static std::unique_ptr construct(std::string_view databasePath, - SystemConfig systemConfig, construct_bm_func_t constructFunc); + Database(std::string_view databasePath, SystemConfig systemConfig, + construct_bm_func_t constructBMFunc); void openLockFile(); void initAndLockDBDir(); diff --git a/src/include/storage/enums/csr_node_group_scan_source.h b/src/include/storage/enums/csr_node_group_scan_source.h new file mode 100644 index 00000000000..5c11ec8773e --- /dev/null +++ b/src/include/storage/enums/csr_node_group_scan_source.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +namespace kuzu::storage { +enum class CSRNodeGroupScanSource : uint8_t { + COMMITTED_PERSISTENT = 0, + COMMITTED_IN_MEMORY = 1, + UNCOMMITTED = 2, + NONE = 10 +}; +} // namespace kuzu::storage diff --git a/src/include/storage/index/hash_index.h b/src/include/storage/index/hash_index.h index d1f5d91f2aa..b79aa9df660 100644 --- a/src/include/storage/index/hash_index.h +++ b/src/include/storage/index/hash_index.h @@ -343,7 +343,7 @@ class PrimaryKeyIndex { KU_ASSERT(keyDataTypeID == common::TypeUtils::getPhysicalTypeIDForType()); return getTypedHashIndex(key)->insertInternal(transaction, key, value, isVisible); } - bool insert(const transaction::Transaction* transaction, common::ValueVector* keyVector, + bool insert(const transaction::Transaction* transaction, const common::ValueVector* keyVector, 
uint64_t vectorPos, common::offset_t value, visible_func isVisible); // Appends the buffer to the index. Returns the number of values successfully inserted. diff --git a/src/include/storage/index/in_mem_hash_index.h b/src/include/storage/index/in_mem_hash_index.h index 946fc2bb2fa..5c6817fe3cb 100644 --- a/src/include/storage/index/in_mem_hash_index.h +++ b/src/include/storage/index/in_mem_hash_index.h @@ -159,14 +159,13 @@ class InMemHashIndex final { auto fingerprint = HashIndexUtils::getFingerprintForHash(hashValue); auto slotId = HashIndexUtils::getPrimarySlotIdForHash(this->indexHeader, hashValue); SlotIterator iter(slotId, this); - std::optional deletedPos = 0; + std::optional deletedPos; do { for (auto entryPos = 0u; entryPos < getSlotCapacity(); entryPos++) { if (iter.slot->header.isEntryValid(entryPos) && iter.slot->header.fingerprints[entryPos] == fingerprint && equals(key, iter.slot->entries[entryPos].key)) { deletedPos = entryPos; - iter.slot->header.setEntryInvalid(entryPos); break; } } @@ -177,22 +176,40 @@ class InMemHashIndex final { if (deletedPos.has_value()) { // Find the last valid entry and move it into the deleted position - auto newIter = iter; - while (nextChainedSlot(newIter)) - ; + auto newIter = getLastValidEntry(iter); if (newIter.slotInfo != iter.slotInfo || *deletedPos != newIter.slot->header.numEntries() - 1) { - auto lastEntryPos = newIter.slot->header.numEntries(); + KU_ASSERT(newIter.slot->header.numEntries() > 0); + auto lastEntryPos = newIter.slot->header.numEntries() - 1; iter.slot->entries[*deletedPos] = newIter.slot->entries[lastEntryPos]; iter.slot->header.setEntryValid(*deletedPos, newIter.slot->header.fingerprints[lastEntryPos]); newIter.slot->header.setEntryInvalid(lastEntryPos); + } else { + iter.slot->header.setEntryInvalid(*deletedPos); } + return true; } return false; } private: + SlotIterator getLastValidEntry(const SlotIterator& startIter) { + auto curIter = startIter; + auto newIter = startIter; + while 
(nextChainedSlot(curIter)) { + if (curIter.slotInfo.slotId == SlotHeader::INVALID_OVERFLOW_SLOT_ID) { + break; + } + if (curIter.slot->header.numEntries() == 0) { + // if the current overflow slot is empty, the last valid entry is in the previous slot + break; + } + newIter = curIter; + } + return newIter; + } + // Assumes that space has already been allocated for the entry bool appendInternal(Key key, common::offset_t value, common::hash_t hash, visible_func isVisible) { diff --git a/src/include/storage/store/chunked_node_group.h b/src/include/storage/store/chunked_node_group.h index d0405fba645..65009cc062f 100644 --- a/src/include/storage/store/chunked_node_group.h +++ b/src/include/storage/store/chunked_node_group.h @@ -103,7 +103,7 @@ class ChunkedNodeGroup { std::pair, std::unique_ptr> scanUpdates( transaction::Transaction* transaction, common::column_id_t columnID); - bool lookup(transaction::Transaction* transaction, const TableScanState& state, + bool lookup(const transaction::Transaction* transaction, const TableScanState& state, NodeGroupScanState& nodeGroupScanState, common::offset_t rowIdxInChunk, common::sel_t posInOutput) const; @@ -139,10 +139,12 @@ class ChunkedNodeGroup { void commitInsert(common::row_idx_t startRow, common::row_idx_t numRows_, common::transaction_t commitTS); - void rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows_); + void rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows_, + common::transaction_t commitTS); void commitDelete(common::row_idx_t startRow, common::row_idx_t numRows_, common::transaction_t commitTS); - void rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows_); + void rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows_, + common::transaction_t commitTS); uint64_t getEstimatedMemoryUsage() const; diff --git a/src/include/storage/store/column.h b/src/include/storage/store/column.h index a1ad5714e41..07b43b191a7 100644 ---
a/src/include/storage/store/column.h +++ b/src/include/storage/store/column.h @@ -52,15 +52,15 @@ class Column { virtual void scan(transaction::Transaction* transaction, const ChunkState& state, common::offset_t startOffsetInChunk, common::row_idx_t numValuesToScan, common::ValueVector* resultVector) const; - virtual void lookupValue(transaction::Transaction* transaction, const ChunkState& state, + virtual void lookupValue(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t nodeOffset, common::ValueVector* resultVector, uint32_t posInVector) const; // Scan from [startOffsetInGroup, endOffsetInGroup). - virtual void scan(transaction::Transaction* transaction, const ChunkState& state, + virtual void scan(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t startOffsetInGroup, common::offset_t endOffsetInGroup, common::ValueVector* resultVector, uint64_t offsetInVector) const; // Scan from [startOffsetInGroup, endOffsetInGroup). - virtual void scan(transaction::Transaction* transaction, const ChunkState& state, + virtual void scan(const transaction::Transaction* transaction, const ChunkState& state, ColumnChunkData* columnChunk, common::offset_t startOffset = 0, common::offset_t endOffset = common::INVALID_OFFSET) const; @@ -71,7 +71,7 @@ class Column { std::string getName() const { return name; } - virtual void scan(transaction::Transaction* transaction, const ChunkState& state, + virtual void scan(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t startOffsetInGroup, common::offset_t endOffsetInGroup, uint8_t* result); // Batch write to a set of sequential pages. 
@@ -99,8 +99,9 @@ class Column { common::offset_t startOffsetInChunk, common::row_idx_t numValuesToScan, common::ValueVector* resultVector) const; - virtual void lookupInternal(transaction::Transaction* transaction, const ChunkState& state, - common::offset_t nodeOffset, common::ValueVector* resultVector, uint32_t posInVector) const; + virtual void lookupInternal(const transaction::Transaction* transaction, + const ChunkState& state, common::offset_t nodeOffset, common::ValueVector* resultVector, + uint32_t posInVector) const; void writeValues(ChunkState& state, common::offset_t dstOffset, const uint8_t* data, const common::NullMask* nullChunkData, common::offset_t srcOffset = 0, @@ -162,7 +163,7 @@ class InternalIDColumn final : public Column { populateCommonTableID(resultVector); } - void scan(transaction::Transaction* transaction, const ChunkState& state, + void scan(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t startOffsetInGroup, common::offset_t endOffsetInGroup, common::ValueVector* resultVector, uint64_t offsetInVector) const override { Column::scan(transaction, state, startOffsetInGroup, endOffsetInGroup, resultVector, @@ -170,7 +171,7 @@ class InternalIDColumn final : public Column { populateCommonTableID(resultVector); } - void lookupInternal(transaction::Transaction* transaction, const ChunkState& state, + void lookupInternal(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t nodeOffset, common::ValueVector* resultVector, uint32_t posInVector) const override { Column::lookupInternal(transaction, state, nodeOffset, resultVector, posInVector); diff --git a/src/include/storage/store/column_chunk.h b/src/include/storage/store/column_chunk.h index 3c6a11752c5..75912ec9399 100644 --- a/src/include/storage/store/column_chunk.h +++ b/src/include/storage/store/column_chunk.h @@ -51,10 +51,10 @@ class ColumnChunk { void scan(const transaction::Transaction* transaction, const ChunkState& 
state, common::ValueVector& output, common::offset_t offsetInChunk, common::length_t length) const; template - void scanCommitted(transaction::Transaction* transaction, ChunkState& chunkState, + void scanCommitted(const transaction::Transaction* transaction, ChunkState& chunkState, ColumnChunk& output, common::row_idx_t startRow = 0, common::row_idx_t numRows = common::INVALID_ROW_IDX) const; - void lookup(transaction::Transaction* transaction, const ChunkState& state, + void lookup(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t rowInChunk, common::ValueVector& output, common::sel_t posInOutputVector) const; void update(const transaction::Transaction* transaction, common::offset_t offsetInChunk, diff --git a/src/include/storage/store/column_reader_writer.h b/src/include/storage/store/column_reader_writer.h index 9307651cca0..635d850bdc4 100644 --- a/src/include/storage/store/column_reader_writer.h +++ b/src/include/storage/store/column_reader_writer.h @@ -49,22 +49,22 @@ class ColumnReadWriter { virtual ~ColumnReadWriter() = default; - virtual void readCompressedValueToPage(transaction::Transaction* transaction, + virtual void readCompressedValueToPage(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t nodeOffset, uint8_t* result, uint32_t offsetInResult, const read_value_from_page_func_t& readFunc) = 0; - virtual void readCompressedValueToVector(transaction::Transaction* transaction, + virtual void readCompressedValueToVector(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t nodeOffset, common::ValueVector* result, uint32_t offsetInResult, const read_value_from_page_func_t& readFunc) = 0; - virtual uint64_t readCompressedValuesToPage(transaction::Transaction* transaction, + virtual uint64_t readCompressedValuesToPage(const transaction::Transaction* transaction, const ChunkState& state, uint8_t* result, uint32_t startOffsetInResult, uint64_t 
startNodeOffset, uint64_t endNodeOffset, const read_values_from_page_func_t& readFunc, const std::optional& filterFunc = {}) = 0; - virtual uint64_t readCompressedValuesToVector(transaction::Transaction* transaction, + virtual uint64_t readCompressedValuesToVector(const transaction::Transaction* transaction, const ChunkState& state, common::ValueVector* result, uint32_t startOffsetInResult, uint64_t startNodeOffset, uint64_t endNodeOffset, const read_values_from_page_func_t& readFunc, @@ -78,7 +78,7 @@ class ColumnReadWriter { const uint8_t* data, const common::NullMask* nullChunkData, common::offset_t srcOffset, common::offset_t numValues, const write_values_func_t& writeFunc) = 0; - void readFromPage(transaction::Transaction* transaction, common::page_idx_t pageIdx, + void readFromPage(const transaction::Transaction* transaction, common::page_idx_t pageIdx, const std::function& readFunc); void updatePageWithCursor(PageCursor cursor, diff --git a/src/include/storage/store/csr_node_group.h b/src/include/storage/store/csr_node_group.h index c73089df595..5ed1a60d7d5 100644 --- a/src/include/storage/store/csr_node_group.h +++ b/src/include/storage/store/csr_node_group.h @@ -4,6 +4,7 @@ #include #include "common/data_chunk/data_chunk.h" +#include "storage/enums/csr_node_group_scan_source.h" #include "storage/store/csr_chunked_node_group.h" #include "storage/store/node_group.h" @@ -21,13 +22,6 @@ struct csr_list_t { common::length_t length = 0; }; -enum class CSRNodeGroupScanSource : uint8_t { - COMMITTED_PERSISTENT = 0, - COMMITTED_IN_MEMORY = 1, - UNCOMMITTED = 2, - NONE = 10 -}; - // Store rows of a CSR list. // If rows of the CSR list are stored in a sequential order, then `isSequential` is set to true. // rowIndices consists of startRowIdx and length. 
@@ -187,9 +181,9 @@ class CSRNodeGroup final : public NodeGroup { } } - void initializeScanState(transaction::Transaction* transaction, + void initializeScanState(const transaction::Transaction* transaction, TableScanState& state) const override; - NodeGroupScanResult scan(transaction::Transaction* transaction, + NodeGroupScanResult scan(const transaction::Transaction* transaction, TableScanState& state) const override; void appendChunkedCSRGroup(const transaction::Transaction* transaction, @@ -211,6 +205,7 @@ class CSRNodeGroup final : public NodeGroup { bool isEmpty() const override { return !persistentChunkGroup && NodeGroup::isEmpty(); } + common::row_idx_t getNumPersistentRows() const; ChunkedNodeGroup* getPersistentChunkedGroup() const { return persistentChunkGroup.get(); } void setPersistentChunkedGroup(std::unique_ptr chunkedNodeGroup) { KU_ASSERT(chunkedNodeGroup->getFormat() == NodeGroupDataFormat::CSR); @@ -220,7 +215,11 @@ class CSRNodeGroup final : public NodeGroup { void serialize(common::Serializer& serializer) override; private: - void initScanForCommittedPersistent(transaction::Transaction* transaction, + std::pair actionOnChunkedGroups(const common::UniqLock& lock, + common::row_idx_t startRow, common::row_idx_t numRows_, common::transaction_t commitTS, + CSRNodeGroupScanSource source, chunked_group_transaction_operation_t operation) override; + + void initScanForCommittedPersistent(const transaction::Transaction* transaction, RelTableScanState& relScanState, CSRNodeGroupScanState& nodeGroupScanState) const; void initScanForCommittedInMem(RelTableScanState& relScanState, CSRNodeGroupScanState& nodeGroupScanState) const; @@ -237,11 +236,11 @@ class CSRNodeGroup final : public NodeGroup { const transaction::Transaction* transaction, RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const; - NodeGroupScanResult scanCommittedInMem(transaction::Transaction* transaction, + NodeGroupScanResult scanCommittedInMem(const 
transaction::Transaction* transaction, RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const; NodeGroupScanResult scanCommittedInMemSequential(const transaction::Transaction* transaction, const RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const; - NodeGroupScanResult scanCommittedInMemRandom(transaction::Transaction* transaction, + NodeGroupScanResult scanCommittedInMemRandom(const transaction::Transaction* transaction, const RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const; void checkpointInMemOnly(const common::UniqLock& lock, NodeGroupCheckpointState& state); diff --git a/src/include/storage/store/dictionary_column.h b/src/include/storage/store/dictionary_column.h index 3e9cb6abe75..b7f6ed39c15 100644 --- a/src/include/storage/store/dictionary_column.h +++ b/src/include/storage/store/dictionary_column.h @@ -12,12 +12,12 @@ class DictionaryColumn { DictionaryColumn(const std::string& name, FileHandle* dataFH, MemoryManager* mm, ShadowFile* shadowFile, bool enableCompression); - void scan(transaction::Transaction* transaction, const ChunkState& state, + void scan(const transaction::Transaction* transaction, const ChunkState& state, DictionaryChunk& dictChunk) const; // Offsets to scan should be a sorted list of pairs mapping the index of the entry in the string // dictionary (as read from the index column) to the output index in the result vector to store // the string. 
- void scan(transaction::Transaction* transaction, const ChunkState& offsetState, + void scan(const transaction::Transaction* transaction, const ChunkState& offsetState, const ChunkState& dataState, std::vector>& offsetsToScan, common::ValueVector* resultVector, const ColumnChunkMetadata& indexMeta) const; @@ -32,10 +32,10 @@ class DictionaryColumn { Column* getOffsetColumn() const { return offsetColumn.get(); } private: - void scanOffsets(transaction::Transaction* transaction, const ChunkState& state, + void scanOffsets(const transaction::Transaction* transaction, const ChunkState& state, DictionaryChunk::string_offset_t* offsets, uint64_t index, uint64_t numValues, uint64_t dataSize) const; - void scanValueToVector(transaction::Transaction* transaction, const ChunkState& dataState, + void scanValueToVector(const transaction::Transaction* transaction, const ChunkState& dataState, uint64_t startOffset, uint64_t endOffset, common::ValueVector* resultVector, uint64_t offsetInVector) const; diff --git a/src/include/storage/store/group_collection.h b/src/include/storage/store/group_collection.h index da107614993..de10c2a056e 100644 --- a/src/include/storage/store/group_collection.h +++ b/src/include/storage/store/group_collection.h @@ -24,6 +24,11 @@ class GroupCollection { [&](common::Deserializer& deser) { return T::deserialize(memoryManager, deser); }); } + void removeTrailingGroups(const common::UniqLock&, common::idx_t numGroupsToRemove) { + KU_ASSERT(numGroupsToRemove <= groups.size()); + groups.erase(groups.end() - numGroupsToRemove, groups.end()); + } + void serializeGroups(common::Serializer& ser) { auto lockGuard = lock(); ser.serializeVectorOfPtrs(groups); diff --git a/src/include/storage/store/list_column.h b/src/include/storage/store/list_column.h index b08a89ca218..4a932e32e35 100644 --- a/src/include/storage/store/list_column.h +++ b/src/include/storage/store/list_column.h @@ -56,10 +56,10 @@ class ListColumn final : public Column { static 
std::unique_ptr flushChunkData(const ColumnChunkData& chunk, FileHandle& dataFH); - void scan(transaction::Transaction* transaction, const ChunkState& state, + void scan(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t startOffsetInGroup, common::offset_t endOffsetInGroup, common::ValueVector* resultVector, uint64_t offsetInVector = 0) const override; - void scan(transaction::Transaction* transaction, const ChunkState& state, + void scan(const transaction::Transaction* transaction, const ChunkState& state, ColumnChunkData* columnChunk, common::offset_t startOffset = 0, common::offset_t endOffset = common::INVALID_OFFSET) const override; @@ -74,7 +74,7 @@ class ListColumn final : public Column { common::offset_t startOffsetInChunk, common::row_idx_t numValuesToScan, common::ValueVector* resultVector) const override; - void lookupInternal(transaction::Transaction* transaction, const ChunkState& state, + void lookupInternal(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t nodeOffset, common::ValueVector* resultVector, uint32_t posInVector) const override; @@ -85,12 +85,12 @@ class ListColumn final : public Column { void scanFiltered(transaction::Transaction* transaction, const ChunkState& state, common::ValueVector* offsetVector, const ListOffsetSizeInfo& listOffsetInfoInStorage) const; - common::offset_t readOffset(transaction::Transaction* transaction, const ChunkState& state, - common::offset_t offsetInNodeGroup) const; - common::list_size_t readSize(transaction::Transaction* transaction, const ChunkState& state, - common::offset_t offsetInNodeGroup) const; + common::offset_t readOffset(const transaction::Transaction* transaction, + const ChunkState& state, common::offset_t offsetInNodeGroup) const; + common::list_size_t readSize(const transaction::Transaction* transaction, + const ChunkState& state, common::offset_t offsetInNodeGroup) const; - ListOffsetSizeInfo 
getListOffsetSizeInfo(transaction::Transaction* transaction, + ListOffsetSizeInfo getListOffsetSizeInfo(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t startOffsetInNodeGroup, common::offset_t endOffsetInNodeGroup) const; diff --git a/src/include/storage/store/node_group.h b/src/include/storage/store/node_group.h index 268486acfac..e249ef6dd39 100644 --- a/src/include/storage/store/node_group.h +++ b/src/include/storage/store/node_group.h @@ -3,6 +3,7 @@ #include #include "common/uniq_lock.h" +#include "storage/enums/csr_node_group_scan_source.h" #include "storage/enums/residency_state.h" #include "storage/store/chunked_node_group.h" #include "storage/store/group_collection.h" @@ -125,11 +126,11 @@ class NodeGroup { void merge(transaction::Transaction* transaction, std::unique_ptr chunkedGroup); - virtual void initializeScanState(transaction::Transaction* transaction, + virtual void initializeScanState(const transaction::Transaction* transaction, TableScanState& state) const; - void initializeScanState(transaction::Transaction* transaction, const common::UniqLock& lock, - TableScanState& state) const; - virtual NodeGroupScanResult scan(transaction::Transaction* transaction, + void initializeScanState(const transaction::Transaction* transaction, + const common::UniqLock& lock, TableScanState& state) const; + virtual NodeGroupScanResult scan(const transaction::Transaction* transaction, TableScanState& state) const; virtual NodeGroupScanResult scan(transaction::Transaction* transaction, TableScanState& state, @@ -149,6 +150,16 @@ class NodeGroup { void flush(transaction::Transaction* transaction, FileHandle& dataFH); + void commitInsert(common::row_idx_t startRow, common::row_idx_t numRows_, + common::transaction_t commitTS, CSRNodeGroupScanSource source); + void commitDelete(common::row_idx_t startRow, common::row_idx_t numRows_, + common::transaction_t commitTS, CSRNodeGroupScanSource source); + + void 
rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows_, + CSRNodeGroupScanSource source); + void rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows_, + CSRNodeGroupScanSource source); + virtual void checkpoint(MemoryManager& memoryManager, NodeGroupCheckpointState& state); uint64_t getEstimatedMemoryUsage(); @@ -181,8 +192,21 @@ class NodeGroup { bool isDeleted(const transaction::Transaction* transaction, common::offset_t offsetInGroup); bool isInserted(const transaction::Transaction* transaction, common::offset_t offsetInGroup); + common::node_group_idx_t getNodeGroupIdx() const { return nodeGroupIdx; } + +protected: + static constexpr auto INVALID_CHUNKED_GROUP_IDX = UINT32_MAX; + static constexpr auto INVALID_START_ROW_IDX = UINT64_MAX; + + using chunked_group_transaction_operation_t = void ( + ChunkedNodeGroup::*)(common::row_idx_t, common::row_idx_t, common::transaction_t); + virtual std::pair actionOnChunkedGroups( + const common::UniqLock& lock, common::row_idx_t startRow, common::row_idx_t numRows_, + common::transaction_t commitTS, CSRNodeGroupScanSource source, + chunked_group_transaction_operation_t operation); + private: - common::idx_t findChunkedGroupIdxFromRowIdx(const common::UniqLock& lock, + std::pair findChunkedGroupIdxFromRowIdxNoLock( common::row_idx_t rowIdx) const; ChunkedNodeGroup* findChunkedGroupFromRowIdx(const common::UniqLock& lock, common::row_idx_t rowIdx) const; diff --git a/src/include/storage/store/node_group_collection.h b/src/include/storage/store/node_group_collection.h index e32b8163d6c..73e5fe43495 100644 --- a/src/include/storage/store/node_group_collection.h +++ b/src/include/storage/store/node_group_collection.h @@ -1,5 +1,6 @@ #pragma once +#include "storage/enums/csr_node_group_scan_source.h" #include "storage/stats/table_stats.h" #include "storage/store/group_collection.h" #include "storage/store/node_group.h" @@ -11,16 +12,19 @@ class Transaction; namespace storage { class MemoryManager; 
+using append_to_undo_buffer_func_t = + std::function; + class NodeGroupCollection { public: - explicit NodeGroupCollection(MemoryManager& memoryManager, - const std::vector& types, bool enableCompression, - FileHandle* dataFH = nullptr, common::Deserializer* deSer = nullptr); + NodeGroupCollection(MemoryManager& memoryManager, const std::vector& types, + bool enableCompression, FileHandle* dataFH = nullptr, common::Deserializer* deSer = nullptr, + append_to_undo_buffer_func_t appendToUndoBufferFunc = defaultAppendToUndoBuffer); void append(const transaction::Transaction* transaction, const std::vector& vectors); void append(const transaction::Transaction* transaction, NodeGroupCollection& other); - void appned(const transaction::Transaction* transaction, NodeGroup& nodeGroup); + void append(const transaction::Transaction* transaction, NodeGroup& nodeGroup); // This function only tries to append data into the last node group, and if the last node group // is not enough to hold all the data, it will append partially and return the number of rows @@ -36,6 +40,7 @@ class NodeGroupCollection { return nodeGroups.getNumGroups(lock); } NodeGroup* getNodeGroupNoLock(const common::node_group_idx_t groupIdx) { + KU_ASSERT(nodeGroups.getGroupNoLock(groupIdx)->getNodeGroupIdx() == groupIdx); return nodeGroups.getGroupNoLock(groupIdx); } NodeGroup* getNodeGroup(const common::node_group_idx_t groupIdx, @@ -44,6 +49,7 @@ class NodeGroupCollection { if (mayOutOfBound && groupIdx >= nodeGroups.getNumGroups(lock)) { return nullptr; } + KU_ASSERT(nodeGroups.getGroupNoLock(groupIdx)->getNodeGroupIdx() == groupIdx); return nodeGroups.getGroup(lock, groupIdx); } NodeGroup* getOrCreateNodeGroup(common::node_group_idx_t groupIdx, NodeGroupDataFormat format); @@ -54,6 +60,20 @@ class NodeGroupCollection { nodeGroups.replaceGroup(lock, nodeGroupIdx, std::move(group)); } + void commitInsert(common::row_idx_t startRow, common::row_idx_t numRows_, + common::node_group_idx_t nodeGroupIdx, 
common::transaction_t commitTS, + CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE); + void commitDelete(common::row_idx_t startRow, common::row_idx_t numRows_, + common::node_group_idx_t nodeGroupIdx, common::transaction_t commitTS, + CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE); + + void rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows_, + common::node_group_idx_t nodeGroupIdx, + CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE); + void rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows_, + common::node_group_idx_t nodeGroupIdx, + CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE); + void clear() { const auto lock = nodeGroups.lock(); nodeGroups.clear(lock); @@ -73,6 +93,9 @@ class NodeGroupCollection { void deserialize(common::Deserializer& deSer, MemoryManager& memoryManager); private: + static void defaultAppendToUndoBuffer(const transaction::Transaction*, NodeGroup*, + common::row_idx_t); + bool enableCompression; // Num rows in the collection regardless of deletions. 
common::row_idx_t numTotalRows; @@ -80,6 +103,7 @@ class NodeGroupCollection { GroupCollection nodeGroups; FileHandle* dataFH; TableStats stats; + append_to_undo_buffer_func_t appendToUndoBufferFunc; }; } // namespace storage diff --git a/src/include/storage/store/node_table.h b/src/include/storage/store/node_table.h index 065ccc7547d..9112b88fd8c 100644 --- a/src/include/storage/store/node_table.h +++ b/src/include/storage/store/node_table.h @@ -156,6 +156,9 @@ class NodeTable final : public Table { void commit(transaction::Transaction* transaction, LocalTable* localTable) override; void checkpoint(common::Serializer& ser, catalog::TableCatalogEntry* tableEntry) override; + void rollbackInsert(const transaction::Transaction* transaction, common::row_idx_t startRow, + common::row_idx_t numRows_, common::node_group_idx_t nodeGroupIdx); + common::node_group_idx_t getNumCommittedNodeGroups() const { return nodeGroups->getNumNodeGroups(); } @@ -171,6 +174,8 @@ class NodeTable final : public Table { return nodeGroups->getNodeGroupNoLock(nodeGroupIdx); } + NodeGroupCollection* getNodeGroups() { return nodeGroups.get(); } + TableStats getStats(const transaction::Transaction* transaction) const; private: @@ -181,6 +186,12 @@ class NodeTable final : public Table { void serialize(common::Serializer& serializer) const override; + std::unique_ptr initPKScanState(common::DataChunk& dataChunk, + TableScanSource source) const; + + visible_func getVisibleFunc(const transaction::Transaction* transaction) const; + common::DataChunk constructDataChunkForPKColumn() const; + private: std::vector> columns; std::unique_ptr nodeGroups; diff --git a/src/include/storage/store/null_column.h b/src/include/storage/store/null_column.h index eae35896baa..7fc3d0d6639 100644 --- a/src/include/storage/store/null_column.h +++ b/src/include/storage/store/null_column.h @@ -21,10 +21,10 @@ class NullColumn final : public Column { void scan(transaction::Transaction* transaction, const ChunkState& state, 
common::offset_t startOffsetInChunk, common::row_idx_t numValuesToScan, common::ValueVector* resultVector) const override; - void scan(transaction::Transaction* transaction, const ChunkState& state, + void scan(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t startOffsetInGroup, common::offset_t endOffsetInGroup, common::ValueVector* resultVector, uint64_t offsetInVector) const override; - void scan(transaction::Transaction* transaction, const ChunkState& state, + void scan(const transaction::Transaction* transaction, const ChunkState& state, ColumnChunkData* columnChunk, common::offset_t startOffset = 0, common::offset_t endOffset = common::INVALID_OFFSET) const override; diff --git a/src/include/storage/store/rel_table.h b/src/include/storage/store/rel_table.h index eea1f512b5a..7e8ee17bf9c 100644 --- a/src/include/storage/store/rel_table.h +++ b/src/include/storage/store/rel_table.h @@ -186,6 +186,9 @@ class RelTable final : public Table { return currentRelOffset; } + void pushInsertInfo(transaction::Transaction* transaction, common::RelDataDirection direction, + const CSRNodeGroup& nodeGroup, common::row_idx_t numRows_, CSRNodeGroupScanSource source); + private: static void prepareCommitForNodeGroup(const transaction::Transaction* transaction, NodeGroup& localNodeGroup, CSRNodeGroup& csrNodeGroup, common::offset_t boundOffsetInGroup, @@ -198,9 +201,9 @@ class RelTable final : public Table { static common::offset_t getCommittedOffset(common::offset_t uncommittedOffset, common::offset_t maxCommittedOffset); - void detachDeleteForCSRRels(transaction::Transaction* transaction, - const RelTableData* tableData, const RelTableData* reverseTableData, - RelTableScanState* relDataReadState, RelTableDeleteState* deleteState); + void detachDeleteForCSRRels(transaction::Transaction* transaction, RelTableData* tableData, + RelTableData* reverseTableData, RelTableScanState* relDataReadState, + RelTableDeleteState* deleteState); void 
checkRelMultiplicityConstraint(transaction::Transaction* transaction, const TableInsertState& state) const; diff --git a/src/include/storage/store/rel_table_data.h b/src/include/storage/store/rel_table_data.h index 0a3a2ca733b..bd453a04312 100644 --- a/src/include/storage/store/rel_table_data.h +++ b/src/include/storage/store/rel_table_data.h @@ -30,7 +30,7 @@ class RelTableData { const common::ValueVector& relIDVector, common::column_id_t columnID, const common::ValueVector& dataVector) const; bool delete_(transaction::Transaction* transaction, common::ValueVector& boundNodeIDVector, - const common::ValueVector& relIDVector) const; + const common::ValueVector& relIDVector); void addColumn(transaction::Transaction* transaction, TableAddColumnState& addColumnState); bool checkIfNodeHasRels(transaction::Transaction* transaction, @@ -57,14 +57,7 @@ class RelTableData { return nodeGroups->getOrCreateNodeGroup(nodeGroupIdx, NodeGroupDataFormat::CSR); } - common::row_idx_t getNumRows() const { - common::row_idx_t numRows = 0; - const auto numGroups = nodeGroups->getNumNodeGroups(); - for (auto nodeGroupIdx = 0u; nodeGroupIdx < numGroups; nodeGroupIdx++) { - numRows += nodeGroups->getNodeGroup(nodeGroupIdx)->getNumRows(); - } - return numRows; - } + NodeGroupCollection* getNodeGroups() { return nodeGroups.get(); } common::RelMultiplicity getMultiplicity() const { return multiplicity; } @@ -72,6 +65,9 @@ class RelTableData { void checkpoint(const std::vector& columnIDs); + void pushInsertInfo(transaction::Transaction* transaction, const CSRNodeGroup& nodeGroup, + common::row_idx_t numRows_, CSRNodeGroupScanSource source); + void serialize(common::Serializer& serializer) const; private: diff --git a/src/include/storage/store/string_column.h b/src/include/storage/store/string_column.h index a20ed597857..2cd64955fa1 100644 --- a/src/include/storage/store/string_column.h +++ b/src/include/storage/store/string_column.h @@ -17,10 +17,10 @@ class StringColumn final : public 
Column { static std::unique_ptr flushChunkData(const ColumnChunkData& chunkData, FileHandle& dataFH); - void scan(transaction::Transaction* transaction, const ChunkState& state, + void scan(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t startOffsetInGroup, common::offset_t endOffsetInGroup, common::ValueVector* resultVector, uint64_t offsetInVector = 0) const override; - void scan(transaction::Transaction* transaction, const ChunkState& state, + void scan(const transaction::Transaction* transaction, const ChunkState& state, ColumnChunkData* columnChunk, common::offset_t startOffset = 0, common::offset_t endOffset = common::INVALID_OFFSET) const override; @@ -39,13 +39,13 @@ class StringColumn final : public Column { void scanInternal(transaction::Transaction* transaction, const ChunkState& state, common::offset_t startOffsetInChunk, common::row_idx_t numValuesToScan, common::ValueVector* resultVector) const override; - void scanUnfiltered(transaction::Transaction* transaction, const ChunkState& state, + void scanUnfiltered(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t startOffsetInChunk, common::offset_t numValuesToRead, common::ValueVector* resultVector, common::sel_t startPosInVector = 0) const; void scanFiltered(transaction::Transaction* transaction, const ChunkState& state, common::offset_t startOffsetInChunk, common::ValueVector* resultVector) const; - void lookupInternal(transaction::Transaction* transaction, const ChunkState& state, + void lookupInternal(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t nodeOffset, common::ValueVector* resultVector, uint32_t posInVector) const override; diff --git a/src/include/storage/store/struct_column.h b/src/include/storage/store/struct_column.h index 3e71bc238a0..7828d758f9e 100644 --- a/src/include/storage/store/struct_column.h +++ b/src/include/storage/store/struct_column.h @@ -14,10 +14,10 @@ class 
StructColumn final : public Column { static std::unique_ptr flushChunkData(const ColumnChunkData& chunk, FileHandle& dataFH); - void scan(transaction::Transaction* transaction, const ChunkState& state, + void scan(const transaction::Transaction* transaction, const ChunkState& state, ColumnChunkData* columnChunk, common::offset_t startOffset = 0, common::offset_t endOffset = common::INVALID_OFFSET) const override; - void scan(transaction::Transaction* transaction, const ChunkState& state, + void scan(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t startOffsetInGroup, common::offset_t endOffsetInGroup, common::ValueVector* resultVector, uint64_t offsetInVector) const override; @@ -35,7 +35,7 @@ class StructColumn final : public Column { common::offset_t startOffsetInChunk, common::row_idx_t numValuesToScan, common::ValueVector* resultVector) const override; - void lookupInternal(transaction::Transaction* transaction, const ChunkState& state, + void lookupInternal(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t nodeOffset, common::ValueVector* resultVector, uint32_t posInVector) const override; diff --git a/src/include/storage/store/version_info.h b/src/include/storage/store/version_info.h index ae47a2dc74d..a8a2f82203f 100644 --- a/src/include/storage/store/version_info.h +++ b/src/include/storage/store/version_info.h @@ -86,10 +86,9 @@ class VersionInfo { public: VersionInfo() {} - void append(const transaction::Transaction* transaction, ChunkedNodeGroup* chunkedNodeGroup, - common::row_idx_t startRow, common::row_idx_t numRows); - bool delete_(const transaction::Transaction* transaction, ChunkedNodeGroup* chunkedNodeGroup, - common::row_idx_t rowIdx); + void append(const transaction::Transaction* transaction, common::row_idx_t startRow, + common::row_idx_t numRows); + bool delete_(const transaction::Transaction* transaction, common::row_idx_t rowIdx); void 
getSelVectorToScan(common::transaction_t startTS, common::transaction_t transactionID, common::SelectionVector& selVector, common::row_idx_t startRow, diff --git a/src/include/storage/undo_buffer.h b/src/include/storage/undo_buffer.h index 4d2f0f19e35..33ee87722ba 100644 --- a/src/include/storage/undo_buffer.h +++ b/src/include/storage/undo_buffer.h @@ -1,9 +1,11 @@ #pragma once +#include #include #include "common/constants.h" #include "common/types/types.h" +#include "storage/enums/csr_node_group_scan_source.h" namespace kuzu { namespace catalog { @@ -21,6 +23,9 @@ class ClientContext; } namespace storage { +using pre_rollback_callback_t = std::function; + // TODO(Guodong): This should be reworked to use MemoryManager for memory allocaiton. // For now, we use malloc to get around the limitation of 256KB from MM. class UndoMemoryBuffer { @@ -65,7 +70,9 @@ class UndoBufferIterator { class UpdateInfo; class VersionInfo; struct VectorUpdateInfo; -class ChunkedNodeGroup; +class RelTableData; +class NodeTable; +class NodeGroupCollection; class WAL; // This class is not thread safe, as it is supposed to be accessed by a single thread. 
class UndoBuffer { @@ -85,27 +92,36 @@ class UndoBuffer { void createCatalogEntry(catalog::CatalogSet& catalogSet, catalog::CatalogEntry& catalogEntry); void createSequenceChange(catalog::SequenceCatalogEntry& sequenceEntry, const catalog::SequenceRollbackData& data); - void createInsertInfo(ChunkedNodeGroup* chunkedNodeGroup, common::row_idx_t startRow, - common::row_idx_t numRows); - void createDeleteInfo(ChunkedNodeGroup* chunkedNodeGroup, common::row_idx_t startRow, - common::row_idx_t numRows); + void createInsertInfo(RelTableData* relTableData, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows, + storage::CSRNodeGroupScanSource source); + void createInsertInfo(NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows); + void createDeleteInfo(NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows); + void createDeleteInfo(RelTableData* relTableData, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows, + storage::CSRNodeGroupScanSource source); void createVectorUpdateInfo(UpdateInfo* updateInfo, common::idx_t vectorIdx, VectorUpdateInfo* vectorUpdateInfo); void commit(common::transaction_t commitTS) const; - void rollback(); + void rollback(const transaction::Transaction* transaction); uint64_t getMemUsage() const; private: uint8_t* createUndoRecord(uint64_t size); - void createVersionInfo(UndoRecordType recordType, ChunkedNodeGroup* chunkedNodeGroup, - common::row_idx_t startRow, common::row_idx_t numRows); + void createVersionInfo(UndoRecordType recordType, NodeGroupCollection* nodeGroupCollection, + pre_rollback_callback_t preRollbackCallback, common::row_idx_t startRow, + common::row_idx_t numRows, common::node_group_idx_t nodeGroupIdx = 0, + storage::CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE); void commitRecord(UndoRecordType 
recordType, const uint8_t* record, common::transaction_t commitTS) const; - void rollbackRecord(UndoRecordType recordType, const uint8_t* record); + void rollbackRecord(const transaction::Transaction* transaction, UndoRecordType recordType, + const uint8_t* record); void commitCatalogEntryRecord(const uint8_t* record, common::transaction_t commitTS) const; void rollbackCatalogEntryRecord(const uint8_t* record); @@ -115,7 +131,8 @@ class UndoBuffer { void commitVersionInfo(UndoRecordType recordType, const uint8_t* record, common::transaction_t commitTS) const; - void rollbackVersionInfo(UndoRecordType recordType, const uint8_t* record); + void rollbackVersionInfo(const transaction::Transaction* transaction, UndoRecordType recordType, + const uint8_t* record); void commitVectorUpdateInfo(const uint8_t* record, common::transaction_t commitTS) const; void rollbackVectorUpdateInfo(const uint8_t* record) const; diff --git a/src/include/transaction/transaction.h b/src/include/transaction/transaction.h index 0c5acbaa204..e55c3cb9023 100644 --- a/src/include/transaction/transaction.h +++ b/src/include/transaction/transaction.h @@ -2,6 +2,7 @@ #include "common/enums/statement_type.h" #include "common/types/types.h" +#include "storage/enums/csr_node_group_scan_source.h" namespace kuzu { namespace catalog { @@ -20,10 +21,13 @@ class WAL; class VersionInfo; class UpdateInfo; struct VectorUpdateInfo; +class RelTableData; +class NodeTable; class ChunkedNodeGroup; } // namespace storage namespace transaction { class TransactionManager; +class Transaction; enum class TransactionType : uint8_t { READ_ONLY, WRITE, CHECKPOINT, DUMMY, RECOVERY }; @@ -113,10 +117,16 @@ class Transaction { bool skipLoggingToWAL = false) const; void pushSequenceChange(catalog::SequenceCatalogEntry* sequenceEntry, int64_t kCount, const catalog::SequenceRollbackData& data) const; - void pushInsertInfo(storage::ChunkedNodeGroup* chunkedNodeGroup, common::row_idx_t startRow, - common::row_idx_t numRows) 
const; - void pushDeleteInfo(storage::ChunkedNodeGroup* chunkedNodeGroup, common::row_idx_t startRow, - common::row_idx_t numRows) const; + void pushInsertInfo(storage::RelTableData* relTableData, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows, + storage::CSRNodeGroupScanSource source) const; + void pushInsertInfo(storage::NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows) const; + void pushDeleteInfo(storage::RelTableData* relTableData, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows, + storage::CSRNodeGroupScanSource source) const; + void pushDeleteInfo(storage::NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows) const; void pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, common::idx_t vectorIdx, storage::VectorUpdateInfo& vectorUpdateInfo) const; diff --git a/src/main/database.cpp b/src/main/database.cpp index f9e695a0bee..11f224c0992 100644 --- a/src/main/database.cpp +++ b/src/main/database.cpp @@ -79,6 +79,12 @@ Database::Database(std::string_view databasePath, SystemConfig systemConfig) initMembers(databasePath); } +Database::Database(std::string_view databasePath, SystemConfig systemConfig, + construct_bm_func_t constructBMFunc) + : dbConfig(systemConfig) { + initMembers(databasePath, constructBMFunc); +} + std::unique_ptr Database::initBufferManager(const Database& db) { return std::make_unique(db.databasePath, db.vfs->joinPath(db.databasePath, StorageConstants::TEMP_SPILLING_FILE_NAME), diff --git a/src/processor/operator/persistent/rel_batch_insert.cpp b/src/processor/operator/persistent/rel_batch_insert.cpp index 51c44bdcf05..765ba03a84b 100644 --- a/src/processor/operator/persistent/rel_batch_insert.cpp +++ b/src/processor/operator/persistent/rel_batch_insert.cpp @@ -66,17 +66,40 @@ void 
RelBatchInsert::executeInternal(ExecutionContext* context) { // No more partitions left in the partitioning buffer. break; } - // TODO(Guodong): We need to handle the concurrency between COPY and other insertions into - // the same node group. + // TODO(Guodong): We need to handle the concurrency between COPY and other insertions + // into the same node group. auto& nodeGroup = relTable->getOrCreateNodeGroup(relLocalState->nodeGroupIdx, relInfo->direction) ->cast(); + if (nodeGroup.isEmpty()) { + // push an insert of size 0 so that we can rollback the creation of this node group if + // needed + relTable->pushInsertInfo(context->clientContext->getTx(), relInfo->direction, nodeGroup, + 0, CSRNodeGroupScanSource::COMMITTED_PERSISTENT); + } appendNodeGroup(context->clientContext->getTx(), nodeGroup, *relInfo, *relLocalState, *sharedState, *partitionerSharedState); updateProgress(context); } } +static void appendNewChunkedGroup(transaction::Transaction* transaction, + ChunkedCSRNodeGroup& chunkedGroup, RelTable& relTable, CSRNodeGroup& nodeGroup, + RelDataDirection direction) { + const bool isNewNodeGroup = nodeGroup.isEmpty(); + const CSRNodeGroupScanSource source = isNewNodeGroup ? 
+ CSRNodeGroupScanSource::COMMITTED_PERSISTENT : + CSRNodeGroupScanSource::COMMITTED_IN_MEMORY; + relTable.pushInsertInfo(transaction, direction, nodeGroup, chunkedGroup.getNumRows(), source); + if (isNewNodeGroup) { + auto flushedChunkedGroup = + chunkedGroup.flushAsNewChunkedNodeGroup(transaction, *relTable.getDataFH()); + nodeGroup.setPersistentChunkedGroup(std::move(flushedChunkedGroup)); + } else { + nodeGroup.appendChunkedCSRGroup(transaction, chunkedGroup); + } +} + void RelBatchInsert::appendNodeGroup(transaction::Transaction* transaction, CSRNodeGroup& nodeGroup, const RelBatchInsertInfo& relInfo, const RelBatchInsertLocalState& localState, BatchInsertSharedState& sharedState, const PartitionerSharedState& partitionerSharedState) { @@ -93,10 +116,9 @@ void RelBatchInsert::appendNodeGroup(transaction::Transaction* transaction, CSRN // This will be used to set the num of values of the node group. const auto numNodes = std::min(StorageConstants::NODE_GROUP_SIZE, partitionerSharedState.maxNodeOffsets[relInfo.partitioningIdx] - startNodeOffset + 1); - const auto isNewNodeGroup = nodeGroup.isEmpty(); // We optimistically flush new node group directly to disk in gapped CSR format. // There is no benefit of leaving gaps for existing node groups, which is kept in memory. 
- const auto leaveGaps = isNewNodeGroup; + const auto leaveGaps = nodeGroup.isEmpty(); populateCSRHeaderAndRowIdx(*partitioningBuffer, startNodeOffset, relInfo, localState, numNodes, leaveGaps); const auto& csrHeader = localState.chunkedGroup->cast().getCSRHeader(); @@ -122,14 +144,11 @@ void RelBatchInsert::appendNodeGroup(transaction::Transaction* transaction, CSRN } KU_ASSERT(localState.chunkedGroup->getNumRows() == maxSize); localState.chunkedGroup->finalize(); - if (isNewNodeGroup) { - auto flushedChunkedGroup = localState.chunkedGroup->flushAsNewChunkedNodeGroup(transaction, - *sharedState.table->getDataFH()); - nodeGroup.setPersistentChunkedGroup(std::move(flushedChunkedGroup)); - } else { - nodeGroup.appendChunkedCSRGroup(transaction, - localState.chunkedGroup->cast()); - } + + auto* relTable = sharedState.table->ptrCast(); + appendNewChunkedGroup(transaction, localState.chunkedGroup->cast(), + *relTable, nodeGroup, relInfo.direction); + localState.chunkedGroup->resetToEmpty(); } diff --git a/src/storage/index/hash_index.cpp b/src/storage/index/hash_index.cpp index cac17a1a9ad..05387288f0b 100644 --- a/src/storage/index/hash_index.cpp +++ b/src/storage/index/hash_index.cpp @@ -500,7 +500,7 @@ bool PrimaryKeyIndex::lookup(const Transaction* trx, ValueVector* keyVector, uin return retVal; } -bool PrimaryKeyIndex::insert(const Transaction* transaction, ValueVector* keyVector, +bool PrimaryKeyIndex::insert(const Transaction* transaction, const ValueVector* keyVector, uint64_t vectorPos, offset_t value, visible_func isVisible) { bool result = false; TypeUtils::visit( diff --git a/src/storage/store/chunked_node_group.cpp b/src/storage/store/chunked_node_group.cpp index b6ccb9b7663..d1acdeff66a 100644 --- a/src/storage/store/chunked_node_group.cpp +++ b/src/storage/store/chunked_node_group.cpp @@ -137,7 +137,7 @@ uint64_t ChunkedNodeGroup::append(const Transaction* transaction, if (!versionInfo) { versionInfo = std::make_unique(); } - 
versionInfo->append(transaction, this, numRows, numRowsToAppendInChunk); + versionInfo->append(transaction, numRows, numRowsToAppendInChunk); } numRows += numRowsToAppendInChunk; return numRowsToAppendInChunk; @@ -168,7 +168,7 @@ offset_t ChunkedNodeGroup::append(const Transaction* transaction, if (!versionInfo) { versionInfo = std::make_unique(); } - versionInfo->append(transaction, this, numRows, numToAppendInChunkedGroup); + versionInfo->append(transaction, numRows, numToAppendInChunkedGroup); } numRows += numToAppendInChunkedGroup; return numToAppendInChunkedGroup; @@ -292,7 +292,7 @@ std::pair, std::unique_ptr> ChunkedNod return getColumnChunk(columnID).scanUpdates(transaction); } -bool ChunkedNodeGroup::lookup(Transaction* transaction, const TableScanState& state, +bool ChunkedNodeGroup::lookup(const Transaction* transaction, const TableScanState& state, NodeGroupScanState& nodeGroupScanState, offset_t rowIdxInChunk, sel_t posInOutput) const { KU_ASSERT(rowIdxInChunk + 1 <= numRows); std::unique_ptr selVector = nullptr; @@ -334,7 +334,7 @@ bool ChunkedNodeGroup::delete_(const Transaction* transaction, row_idx_t rowIdxI if (!versionInfo) { versionInfo = std::make_unique(); } - return versionInfo->delete_(transaction, this, rowIdxInChunk); + return versionInfo->delete_(transaction, rowIdxInChunk); } void ChunkedNodeGroup::addColumn(Transaction* transaction, @@ -396,7 +396,7 @@ std::unique_ptr ChunkedNodeGroup::flushAsNewChunkedNodeGroup( std::make_unique(std::move(flushedChunks), 0 /*startRowIdx*/); flushedChunkedGroup->versionInfo = std::make_unique(); KU_ASSERT(flushedChunkedGroup->getNumRows() == numRows); - flushedChunkedGroup->versionInfo->append(transaction, flushedChunkedGroup.get(), 0, numRows); + flushedChunkedGroup->versionInfo->append(transaction, 0, numRows); return flushedChunkedGroup; } @@ -429,12 +429,8 @@ bool ChunkedNodeGroup::hasUpdates() const { return false; } -void ChunkedNodeGroup::commitInsert(row_idx_t startRow, row_idx_t 
numRowsToCommit, - transaction_t commitTS) { - versionInfo->commitInsert(startRow, numRowsToCommit, commitTS); -} - -void ChunkedNodeGroup::rollbackInsert(row_idx_t startRow, row_idx_t numRows_) { +void ChunkedNodeGroup::rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows_, + common::transaction_t) { if (startRow == 0) { setNumRows(0); versionInfo.reset(); @@ -448,12 +444,18 @@ void ChunkedNodeGroup::rollbackInsert(row_idx_t startRow, row_idx_t numRows_) { numRows = startRow; } +void ChunkedNodeGroup::commitInsert(row_idx_t startRow, row_idx_t numRowsToCommit, + transaction_t commitTS) { + versionInfo->commitInsert(startRow, numRowsToCommit, commitTS); +} + void ChunkedNodeGroup::commitDelete(row_idx_t startRow, row_idx_t numRows_, transaction_t commitTS) { versionInfo->commitDelete(startRow, numRows_, commitTS); } -void ChunkedNodeGroup::rollbackDelete(row_idx_t startRow, row_idx_t numRows_) { +void ChunkedNodeGroup::rollbackDelete(row_idx_t startRow, row_idx_t numRows_, + common::transaction_t) { versionInfo->rollbackDelete(startRow, numRows_); } diff --git a/src/storage/store/column.cpp b/src/storage/store/column.cpp index c351696414c..22179b16e21 100644 --- a/src/storage/store/column.cpp +++ b/src/storage/store/column.cpp @@ -196,8 +196,9 @@ void Column::scan(Transaction* transaction, const ChunkState& state, offset_t st scanInternal(transaction, state, startOffsetInChunk, numValuesToScan, resultVector); } -void Column::scan(Transaction* transaction, const ChunkState& state, offset_t startOffsetInGroup, - offset_t endOffsetInGroup, ValueVector* resultVector, uint64_t offsetInVector) const { +void Column::scan(const Transaction* transaction, const ChunkState& state, + offset_t startOffsetInGroup, offset_t endOffsetInGroup, ValueVector* resultVector, + uint64_t offsetInVector) const { if (nullColumn) { KU_ASSERT(state.nullState); nullColumn->scan(transaction, *state.nullState, startOffsetInGroup, endOffsetInGroup, @@ -207,8 +208,8 @@ void 
Column::scan(Transaction* transaction, const ChunkState& state, offset_t st startOffsetInGroup, endOffsetInGroup, readToVectorFunc); } -void Column::scan(Transaction* transaction, const ChunkState& state, ColumnChunkData* columnChunk, - offset_t startOffset, offset_t endOffset) const { +void Column::scan(const Transaction* transaction, const ChunkState& state, + ColumnChunkData* columnChunk, offset_t startOffset, offset_t endOffset) const { if (nullColumn) { nullColumn->scan(transaction, *state.nullState, columnChunk->getNullData(), startOffset, endOffset); @@ -233,8 +234,8 @@ void Column::scan(Transaction* transaction, const ChunkState& state, ColumnChunk columnChunk->setNumValues(numValuesScanned); } -void Column::scan(Transaction* transaction, const ChunkState& state, offset_t startOffsetInGroup, - offset_t endOffsetInGroup, uint8_t* result) { +void Column::scan(const Transaction* transaction, const ChunkState& state, + offset_t startOffsetInGroup, offset_t endOffsetInGroup, uint8_t* result) { columnReadWriter->readCompressedValuesToPage(transaction, state, result, 0, startOffsetInGroup, endOffsetInGroup, readToPageFunc); } @@ -267,8 +268,8 @@ void Column::scanInternal(Transaction* transaction, const ChunkState& state, } } -void Column::lookupValue(Transaction* transaction, const ChunkState& state, offset_t nodeOffset, - ValueVector* resultVector, uint32_t posInVector) const { +void Column::lookupValue(const Transaction* transaction, const ChunkState& state, + offset_t nodeOffset, ValueVector* resultVector, uint32_t posInVector) const { if (nullColumn) { nullColumn->lookupValue(transaction, *state.nullState, nodeOffset, resultVector, posInVector); @@ -279,8 +280,8 @@ void Column::lookupValue(Transaction* transaction, const ChunkState& state, offs lookupInternal(transaction, state, nodeOffset, resultVector, posInVector); } -void Column::lookupInternal(Transaction* transaction, const ChunkState& state, offset_t nodeOffset, - ValueVector* resultVector, uint32_t 
posInVector) const { +void Column::lookupInternal(const Transaction* transaction, const ChunkState& state, + offset_t nodeOffset, ValueVector* resultVector, uint32_t posInVector) const { columnReadWriter->readCompressedValueToVector(transaction, state, nodeOffset, resultVector, posInVector, readToVectorFunc); } diff --git a/src/storage/store/column_chunk.cpp b/src/storage/store/column_chunk.cpp index 2c5dae7aa48..13d93f75260 100644 --- a/src/storage/store/column_chunk.cpp +++ b/src/storage/store/column_chunk.cpp @@ -83,7 +83,7 @@ void ColumnChunk::scan(const Transaction* transaction, const ChunkState& state, } template -void ColumnChunk::scanCommitted(Transaction* transaction, ChunkState& chunkState, +void ColumnChunk::scanCommitted(const Transaction* transaction, ChunkState& chunkState, ColumnChunk& output, row_idx_t startRow, row_idx_t numRows) const { if (numRows == INVALID_ROW_IDX) { numRows = getNumValues(); @@ -111,9 +111,9 @@ void ColumnChunk::scanCommitted(Transaction* transaction, ChunkState& chunkState } } -template void ColumnChunk::scanCommitted(Transaction* transaction, +template void ColumnChunk::scanCommitted(const Transaction* transaction, ChunkState& chunkState, ColumnChunk& output, row_idx_t startRow, row_idx_t numRows) const; -template void ColumnChunk::scanCommitted(Transaction* transaction, +template void ColumnChunk::scanCommitted(const Transaction* transaction, ChunkState& chunkState, ColumnChunk& output, row_idx_t startRow, row_idx_t numRows) const; bool ColumnChunk::hasUpdates(const Transaction* transaction, row_idx_t startRow, @@ -160,8 +160,8 @@ void ColumnChunk::scanCommittedUpdates(const Transaction* transaction, ColumnChu } } -void ColumnChunk::lookup(Transaction* transaction, const ChunkState& state, offset_t rowInChunk, - ValueVector& output, sel_t posInOutputVector) const { +void ColumnChunk::lookup(const Transaction* transaction, const ChunkState& state, + offset_t rowInChunk, ValueVector& output, sel_t posInOutputVector) const { 
switch (getResidencyState()) { case ResidencyState::IN_MEMORY: { data->lookup(rowInChunk, output, posInOutputVector); diff --git a/src/storage/store/column_reader_writer.cpp b/src/storage/store/column_reader_writer.cpp index 4c6f60f8d20..728bbad2645 100644 --- a/src/storage/store/column_reader_writer.cpp +++ b/src/storage/store/column_reader_writer.cpp @@ -79,23 +79,24 @@ class DefaultColumnReadWriter : public ColumnReadWriter { ShadowFile* shadowFile) : ColumnReadWriter(dbFileID, dataFH, bufferManager, shadowFile) {} - void readCompressedValueToPage(transaction::Transaction* transaction, const ChunkState& state, - common::offset_t nodeOffset, uint8_t* result, uint32_t offsetInResult, - const read_value_from_page_func_t& readFunc) override { + void readCompressedValueToPage(const transaction::Transaction* transaction, + const ChunkState& state, common::offset_t nodeOffset, uint8_t* result, + uint32_t offsetInResult, const read_value_from_page_func_t& readFunc) override { auto [offsetInChunk, cursor] = getOffsetAndCursor(nodeOffset, state); readCompressedValue(transaction, state.metadata, cursor, offsetInChunk, result, offsetInResult, readFunc); } - void readCompressedValueToVector(transaction::Transaction* transaction, const ChunkState& state, - common::offset_t nodeOffset, common::ValueVector* result, uint32_t offsetInResult, + void readCompressedValueToVector(const transaction::Transaction* transaction, + const ChunkState& state, common::offset_t nodeOffset, common::ValueVector* result, + uint32_t offsetInResult, const read_value_from_page_func_t& readFunc) override { auto [offsetInChunk, cursor] = getOffsetAndCursor(nodeOffset, state); readCompressedValue(transaction, state.metadata, cursor, offsetInChunk, result, offsetInResult, readFunc); } - uint64_t readCompressedValuesToPage(transaction::Transaction* transaction, + uint64_t readCompressedValuesToPage(const transaction::Transaction* transaction, const ChunkState& state, uint8_t* result, uint32_t 
startOffsetInResult, uint64_t startNodeOffset, uint64_t endNodeOffset, const read_values_from_page_func_t& readFunc, @@ -104,7 +105,7 @@ class DefaultColumnReadWriter : public ColumnReadWriter { startNodeOffset, endNodeOffset, readFunc, filterFunc); } - uint64_t readCompressedValuesToVector(transaction::Transaction* transaction, + uint64_t readCompressedValuesToVector(const transaction::Transaction* transaction, const ChunkState& state, common::ValueVector* result, uint32_t startOffsetInResult, uint64_t startNodeOffset, uint64_t endNodeOffset, const read_values_from_page_func_t& readFunc, @@ -152,7 +153,7 @@ class DefaultColumnReadWriter : public ColumnReadWriter { } template - void readCompressedValue(transaction::Transaction* transaction, + void readCompressedValue(const transaction::Transaction* transaction, const ColumnChunkMetadata& metadata, PageCursor cursor, common::offset_t /*offsetInChunk*/, OutputType result, uint32_t offsetInResult, const read_value_from_page_func_t& readFunc) { @@ -164,7 +165,7 @@ class DefaultColumnReadWriter : public ColumnReadWriter { } template - uint64_t readCompressedValues(Transaction* transaction, const ChunkState& state, + uint64_t readCompressedValues(const Transaction* transaction, const ChunkState& state, OutputType result, uint32_t startOffsetInResult, uint64_t startNodeOffset, uint64_t endNodeOffset, const read_values_from_page_func_t& readFunc, const std::optional& filterFunc) { @@ -210,23 +211,24 @@ class FloatColumnReadWriter : public ColumnReadWriter { defaultReader(std::make_unique(dbFileID, dataFH, bufferManager, shadowFile)) {} - void readCompressedValueToPage(transaction::Transaction* transaction, const ChunkState& state, - common::offset_t nodeOffset, uint8_t* result, uint32_t offsetInResult, - const read_value_from_page_func_t& readFunc) override { + void readCompressedValueToPage(const transaction::Transaction* transaction, + const ChunkState& state, common::offset_t nodeOffset, uint8_t* result, + uint32_t 
offsetInResult, const read_value_from_page_func_t& readFunc) override { auto [offsetInChunk, cursor] = getOffsetAndCursor(nodeOffset, state); readCompressedValue(transaction, state, offsetInChunk, result, offsetInResult, readFunc); } - void readCompressedValueToVector(transaction::Transaction* transaction, const ChunkState& state, - common::offset_t nodeOffset, common::ValueVector* result, uint32_t offsetInResult, + void readCompressedValueToVector(const transaction::Transaction* transaction, + const ChunkState& state, common::offset_t nodeOffset, common::ValueVector* result, + uint32_t offsetInResult, const read_value_from_page_func_t& readFunc) override { auto [offsetInChunk, cursor] = getOffsetAndCursor(nodeOffset, state); readCompressedValue(transaction, state, offsetInChunk, result, offsetInResult, readFunc); } - uint64_t readCompressedValuesToPage(transaction::Transaction* transaction, + uint64_t readCompressedValuesToPage(const transaction::Transaction* transaction, const ChunkState& state, uint8_t* result, uint32_t startOffsetInResult, uint64_t startNodeOffset, uint64_t endNodeOffset, const read_values_from_page_func_t& readFunc, @@ -235,7 +237,7 @@ class FloatColumnReadWriter : public ColumnReadWriter { startNodeOffset, endNodeOffset, readFunc, filterFunc); } - uint64_t readCompressedValuesToVector(transaction::Transaction* transaction, + uint64_t readCompressedValuesToVector(const transaction::Transaction* transaction, const ChunkState& state, common::ValueVector* result, uint32_t startOffsetInResult, uint64_t startNodeOffset, uint64_t endNodeOffset, const read_values_from_page_func_t& readFunc, @@ -298,7 +300,7 @@ class FloatColumnReadWriter : public ColumnReadWriter { } template - void readCompressedValue(transaction::Transaction* transaction, const ChunkState& state, + void readCompressedValue(const transaction::Transaction* transaction, const ChunkState& state, common::offset_t offsetInChunk, OutputType result, uint32_t offsetInResult, const 
read_value_from_page_func_t& readFunc) { RUNTIME_CHECK(const ColumnChunkMetadata& metadata = state.metadata); @@ -311,7 +313,7 @@ class FloatColumnReadWriter : public ColumnReadWriter { } template - uint64_t readCompressedValues(Transaction* transaction, const ChunkState& state, + uint64_t readCompressedValues(const Transaction* transaction, const ChunkState& state, OutputType result, uint32_t startOffsetInResult, uint64_t startNodeOffset, uint64_t endNodeOffset, const read_values_from_page_func_t& readFunc, const std::optional& filterFunc) { @@ -429,7 +431,7 @@ ColumnReadWriter::ColumnReadWriter(DBFileID dbFileID, FileHandle* dataFH, BufferManager* bufferManager, ShadowFile* shadowFile) : dbFileID(dbFileID), dataFH(dataFH), bufferManager(bufferManager), shadowFile(shadowFile) {} -void ColumnReadWriter::readFromPage(Transaction* transaction, page_idx_t pageIdx, +void ColumnReadWriter::readFromPage(const Transaction* transaction, page_idx_t pageIdx, const std::function& readFunc) { // For constant compression, call read on a nullptr since there is no data on disk and // decompression only requires metadata diff --git a/src/storage/store/csr_chunked_node_group.cpp b/src/storage/store/csr_chunked_node_group.cpp index c2324863f2f..40978905585 100644 --- a/src/storage/store/csr_chunked_node_group.cpp +++ b/src/storage/store/csr_chunked_node_group.cpp @@ -245,7 +245,7 @@ std::unique_ptr ChunkedCSRNodeGroup::flushAsNewChunkedNodeGrou std::move(flushedChunks), 0 /*startRowIdx*/); flushedChunkedGroup->versionInfo = std::make_unique(); KU_ASSERT(numRows == flushedChunkedGroup->getNumRows()); - flushedChunkedGroup->versionInfo->append(transaction, flushedChunkedGroup.get(), 0, numRows); + flushedChunkedGroup->versionInfo->append(transaction, 0, numRows); return flushedChunkedGroup; } diff --git a/src/storage/store/csr_node_group.cpp b/src/storage/store/csr_node_group.cpp index a3f66dcabd0..eb3e7709826 100644 --- a/src/storage/store/csr_node_group.cpp +++ 
b/src/storage/store/csr_node_group.cpp @@ -55,7 +55,8 @@ bool CSRNodeGroupScanState::tryScanCachedTuples(RelTableScanState& tableScanStat return true; } -void CSRNodeGroup::initializeScanState(Transaction* transaction, TableScanState& state) const { +void CSRNodeGroup::initializeScanState(const Transaction* transaction, + TableScanState& state) const { auto& relScanState = state.cast(); KU_ASSERT(relScanState.nodeGroupScanState); auto& nodeGroupScanState = relScanState.nodeGroupScanState->cast(); @@ -78,7 +79,7 @@ void CSRNodeGroup::initializeScanState(Transaction* transaction, TableScanState& } } -void CSRNodeGroup::initScanForCommittedPersistent(Transaction* transaction, +void CSRNodeGroup::initScanForCommittedPersistent(const Transaction* transaction, RelTableScanState& relScanState, CSRNodeGroupScanState& nodeGroupScanState) const { // Scan the csr header chunks from disk. ChunkState offsetState, lengthState; @@ -115,7 +116,8 @@ void CSRNodeGroup::initScanForCommittedInMem(RelTableScanState& relScanState, nodeGroupScanState.inMemCSRList.clear(); } -NodeGroupScanResult CSRNodeGroup::scan(Transaction* transaction, TableScanState& state) const { +NodeGroupScanResult CSRNodeGroup::scan(const Transaction* transaction, + TableScanState& state) const { auto& relScanState = state.cast(); auto& nodeGroupScanState = relScanState.nodeGroupScanState->cast(); while (true) { @@ -220,7 +222,7 @@ NodeGroupScanResult CSRNodeGroup::scanCommittedPersistentWtihoutCache( return NodeGroupScanResult{startRow, numToScan}; } -NodeGroupScanResult CSRNodeGroup::scanCommittedInMem(Transaction* transaction, +NodeGroupScanResult CSRNodeGroup::scanCommittedInMem(const Transaction* transaction, RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const { while (true) { if (tableState.currBoundNodeIdx >= tableState.cachedBoundNodeSelVector.getSelSize()) { @@ -276,7 +278,7 @@ NodeGroupScanResult CSRNodeGroup::scanCommittedInMemSequential(const Transaction return 
NodeGroupScanResult{startRow, numRows}; } -NodeGroupScanResult CSRNodeGroup::scanCommittedInMemRandom(Transaction* transaction, +NodeGroupScanResult CSRNodeGroup::scanCommittedInMemRandom(const Transaction* transaction, const RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const { const auto numRows = std::min(nodeGroupScanState.inMemCSRList.rowIndices.size() - nodeGroupScanState.nextRowToScan, @@ -953,5 +955,28 @@ void CSRNodeGroup::finalizeCheckpoint(const UniqLock& lock) { csrIndex.reset(); } +std::pair CSRNodeGroup::actionOnChunkedGroups(const common::UniqLock& lock, + common::row_idx_t startRow, common::row_idx_t numRows_, common::transaction_t commitTS, + CSRNodeGroupScanSource source, chunked_group_transaction_operation_t operation) { + if (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) { + KU_ASSERT(persistentChunkGroup || (numRows_ == 0)); + if (persistentChunkGroup) { + std::invoke(operation, *persistentChunkGroup, startRow, numRows_, commitTS); + } + return {UINT32_MAX, UINT32_MAX}; + } else { + KU_ASSERT(source == CSRNodeGroupScanSource::COMMITTED_IN_MEMORY); + return NodeGroup::actionOnChunkedGroups(lock, startRow, numRows_, commitTS, source, + operation); + } +} + +common::row_idx_t CSRNodeGroup::getNumPersistentRows() const { + if (!persistentChunkGroup) { + return 0; + } + return persistentChunkGroup->getNumRows(); +} + } // namespace storage } // namespace kuzu diff --git a/src/storage/store/dictionary_column.cpp b/src/storage/store/dictionary_column.cpp index 91d3350dacb..28a5a8f7939 100644 --- a/src/storage/store/dictionary_column.cpp +++ b/src/storage/store/dictionary_column.cpp @@ -28,7 +28,7 @@ DictionaryColumn::DictionaryColumn(const std::string& name, FileHandle* dataFH, shadowFile, enableCompression, false /*requireNullColumn*/); } -void DictionaryColumn::scan(Transaction* transaction, const ChunkState& state, +void DictionaryColumn::scan(const Transaction* transaction, const ChunkState& state, 
DictionaryChunk& dictChunk) const { auto& dataMetadata = StringColumn::getChildState(state, StringColumn::ChildStateIndex::DATA).metadata; @@ -51,7 +51,7 @@ void DictionaryColumn::scan(Transaction* transaction, const ChunkState& state, StringColumn::getChildState(state, StringColumn::ChildStateIndex::OFFSET), offsetChunk); } -void DictionaryColumn::scan(Transaction* transaction, const ChunkState& offsetState, +void DictionaryColumn::scan(const Transaction* transaction, const ChunkState& offsetState, const ChunkState& dataState, std::vector>& offsetsToScan, ValueVector* resultVector, const ColumnChunkMetadata& indexMeta) const { string_index_t firstOffsetToScan = 0, lastOffsetToScan = 0; @@ -108,7 +108,7 @@ string_index_t DictionaryColumn::append(const DictionaryChunk& dictChunk, ChunkS reinterpret_cast(&startOffset), nullptr /*nullChunkData*/, 1 /*numValues*/); } -void DictionaryColumn::scanOffsets(Transaction* transaction, const ChunkState& state, +void DictionaryColumn::scanOffsets(const Transaction* transaction, const ChunkState& state, DictionaryChunk::string_offset_t* offsets, uint64_t index, uint64_t numValues, uint64_t dataSize) const { // We either need to read the next value, or store the maximum string offset at the end. 
@@ -121,9 +121,9 @@ void DictionaryColumn::scanOffsets(Transaction* transaction, const ChunkState& s } } -void DictionaryColumn::scanValueToVector(Transaction* transaction, const ChunkState& dataState, - uint64_t startOffset, uint64_t endOffset, ValueVector* resultVector, - uint64_t offsetInVector) const { +void DictionaryColumn::scanValueToVector(const Transaction* transaction, + const ChunkState& dataState, uint64_t startOffset, uint64_t endOffset, + ValueVector* resultVector, uint64_t offsetInVector) const { KU_ASSERT(endOffset >= startOffset); // Add string to vector first and read directly into the vector auto& kuString = diff --git a/src/storage/store/list_column.cpp b/src/storage/store/list_column.cpp index 1cd6d83b35a..b693137a044 100644 --- a/src/storage/store/list_column.cpp +++ b/src/storage/store/list_column.cpp @@ -85,7 +85,7 @@ std::unique_ptr ListColumn::flushChunkData(const ColumnChunkDat return flushedChunk; } -void ListColumn::scan(Transaction* transaction, const ChunkState& state, +void ListColumn::scan(const Transaction* transaction, const ChunkState& state, offset_t startOffsetInGroup, offset_t endOffsetInGroup, ValueVector* resultVector, uint64_t offsetInVector) const { nullColumn->scan(transaction, *state.nullState, startOffsetInGroup, endOffsetInGroup, @@ -126,7 +126,7 @@ void ListColumn::scan(Transaction* transaction, const ChunkState& state, } } -void ListColumn::scan(Transaction* transaction, const ChunkState& state, +void ListColumn::scan(const Transaction* transaction, const ChunkState& state, ColumnChunkData* columnChunk, offset_t startOffset, offset_t endOffset) const { Column::scan(transaction, state, columnChunk, startOffset, endOffset); if (columnChunk->getNumValues() == 0) { @@ -190,7 +190,7 @@ void ListColumn::scanInternal(Transaction* transaction, const ChunkState& state, } } -void ListColumn::lookupInternal(Transaction* transaction, const ChunkState& state, +void ListColumn::lookupInternal(const Transaction* transaction, const 
ChunkState& state, offset_t nodeOffset, ValueVector* resultVector, uint32_t posInVector) const { auto [nodeGroupIdx, offsetInChunk] = StorageUtils::getNodeGroupIdxAndOffsetInChunk(nodeOffset); const auto listEndOffset = readOffset(transaction, state, offsetInChunk); @@ -271,7 +271,7 @@ void ListColumn::scanFiltered(Transaction* transaction, const ChunkState& state, } } -offset_t ListColumn::readOffset(Transaction* transaction, const ChunkState& readState, +offset_t ListColumn::readOffset(const Transaction* transaction, const ChunkState& readState, offset_t offsetInNodeGroup) const { offset_t ret = INVALID_OFFSET; const auto& offsetState = readState.childrenStates[OFFSET_COLUMN_CHILD_READ_STATE_IDX]; @@ -280,7 +280,7 @@ offset_t ListColumn::readOffset(Transaction* transaction, const ChunkState& read return ret; } -list_size_t ListColumn::readSize(Transaction* transaction, const ChunkState& readState, +list_size_t ListColumn::readSize(const Transaction* transaction, const ChunkState& readState, offset_t offsetInNodeGroup) const { const auto& sizeState = readState.childrenStates[SIZE_COLUMN_CHILD_READ_STATE_IDX]; offset_t value = INVALID_OFFSET; @@ -289,7 +289,7 @@ list_size_t ListColumn::readSize(Transaction* transaction, const ChunkState& rea return value; } -ListOffsetSizeInfo ListColumn::getListOffsetSizeInfo(Transaction* transaction, +ListOffsetSizeInfo ListColumn::getListOffsetSizeInfo(const Transaction* transaction, const ChunkState& state, offset_t startOffsetInNodeGroup, offset_t endOffsetInNodeGroup) const { const auto numOffsetsToRead = endOffsetInNodeGroup - startOffsetInNodeGroup; auto offsetColumnChunk = ColumnChunkFactory::createColumnChunkData(*mm, LogicalType::INT64(), diff --git a/src/storage/store/node_group.cpp b/src/storage/store/node_group.cpp index 3ecb788f695..b5a96a61cf1 100644 --- a/src/storage/store/node_group.cpp +++ b/src/storage/store/node_group.cpp @@ -99,12 +99,12 @@ void NodeGroup::merge(Transaction*, std::unique_ptr chunkedGro 
KU_ASSERT(chunkedGroup->getColumnChunk(i).getDataType().getPhysicalType() == dataTypes[i].getPhysicalType()); } - const auto lock = chunkedGroups.lock(); numRows += chunkedGroup->getNumRows(); + const auto lock = chunkedGroups.lock(); chunkedGroups.appendGroup(lock, std::move(chunkedGroup)); } -void NodeGroup::initializeScanState(Transaction* transaction, TableScanState& state) const { +void NodeGroup::initializeScanState(const Transaction* transaction, TableScanState& state) const { const auto lock = chunkedGroups.lock(); initializeScanState(transaction, lock, state); } @@ -129,7 +129,7 @@ static void initializeScanStateForChunkedGroup(const TableScanState& state, } } -void NodeGroup::initializeScanState(Transaction*, const UniqLock& lock, +void NodeGroup::initializeScanState(const Transaction*, const UniqLock& lock, TableScanState& state) const { auto& nodeGroupScanState = *state.nodeGroupScanState; nodeGroupScanState.chunkedGroupIdx = 0; @@ -167,7 +167,7 @@ void applySemiMaskFilter(const TableScanState& state, row_idx_t numRowsToScan, } } -NodeGroupScanResult NodeGroup::scan(Transaction* transaction, TableScanState& state) const { +NodeGroupScanResult NodeGroup::scan(const Transaction* transaction, TableScanState& state) const { // TODO(Guodong): Move the locked part of figuring out the chunked group to initScan. 
const auto lock = chunkedGroups.lock(); auto& nodeGroupScanState = *state.nodeGroupScanState; @@ -239,7 +239,7 @@ NodeGroupScanResult NodeGroup::scanInternal(const common::UniqLock& lock, Transa auto& nodeGroupScanState = *state.nodeGroupScanState; nodeGroupScanState.nextRowToScan = startOffsetInGroup; - auto newChunkedGroupIdx = findChunkedGroupIdxFromRowIdx(lock, startOffsetInGroup); + auto [newChunkedGroupIdx, _] = findChunkedGroupIdxFromRowIdxNoLock(startOffsetInGroup); const auto* chunkedGroupToScan = chunkedGroups.getGroup(lock, newChunkedGroupIdx); if (newChunkedGroupIdx != nodeGroupScanState.chunkedGroupIdx) { @@ -353,25 +353,91 @@ void NodeGroup::flush(Transaction* transaction, FileHandle& dataFH) { chunkedGroups.resize(lock, 1); } +std::pair NodeGroup::actionOnChunkedGroups(const common::UniqLock& lock, + common::row_idx_t startRow, common::row_idx_t numRows_, common::transaction_t commitTS, + CSRNodeGroupScanSource, chunked_group_transaction_operation_t operation) { + const auto [startChunkedGroupIdx, startRowIdxInChunk] = + findChunkedGroupIdxFromRowIdxNoLock(startRow); + if (startChunkedGroupIdx != INVALID_CHUNKED_GROUP_IDX) { + auto curChunkedGroupIdx = startChunkedGroupIdx; + auto curStartRowIdxInChunk = startRowIdxInChunk; + + auto numRowsLeft = numRows_; + while (numRowsLeft > 0 && curChunkedGroupIdx < chunkedGroups.getNumGroups(lock)) { + auto* chunkedGroup = chunkedGroups.getGroup(lock, curChunkedGroupIdx); + const auto numRowsForGroup = + std::min(numRowsLeft, chunkedGroup->getNumRows() - curStartRowIdxInChunk); + std::invoke(operation, *chunkedGroup, curStartRowIdxInChunk, numRowsForGroup, commitTS); + + ++curChunkedGroupIdx; + numRowsLeft -= numRowsForGroup; + curStartRowIdxInChunk = 0; + } + } + + return {startChunkedGroupIdx, startRowIdxInChunk}; +} + +static constexpr common::transaction_t UNUSED_COMMIT_TS = INVALID_TRANSACTION; + +void NodeGroup::rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows_, + 
CSRNodeGroupScanSource source) { + const auto lock = chunkedGroups.lock(); + const auto [startChunkedGroupIdx, startRowIdxInChunk] = actionOnChunkedGroups(lock, startRow, + numRows_, UNUSED_COMMIT_TS, source, &ChunkedNodeGroup::rollbackInsert); + if (startChunkedGroupIdx != INVALID_CHUNKED_GROUP_IDX) { + const auto numChunkedGroups = chunkedGroups.getNumGroups(lock); + KU_ASSERT(startChunkedGroupIdx < numChunkedGroups); + const bool shouldRemoveStartChunk = (startRowIdxInChunk == 0); + const auto numChunksToRemove = + numChunkedGroups - startChunkedGroupIdx - (shouldRemoveStartChunk ? 0 : 1); + chunkedGroups.removeTrailingGroups(lock, numChunksToRemove); + + numRows = startRow; + } +} + +void NodeGroup::rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows_, + CSRNodeGroupScanSource source) { + const auto lock = chunkedGroups.lock(); + actionOnChunkedGroups(lock, startRow, numRows_, UNUSED_COMMIT_TS, source, + &ChunkedNodeGroup::rollbackDelete); +} + +void NodeGroup::commitInsert(row_idx_t startRow, row_idx_t numRows_, common::transaction_t commitTS, + CSRNodeGroupScanSource source) { + const auto lock = chunkedGroups.lock(); + actionOnChunkedGroups(lock, startRow, numRows_, commitTS, source, + &ChunkedNodeGroup::commitInsert); +} + +void NodeGroup::commitDelete(row_idx_t startRow, row_idx_t numRows_, common::transaction_t commitTS, + CSRNodeGroupScanSource source) { + const auto lock = chunkedGroups.lock(); + actionOnChunkedGroups(lock, startRow, numRows_, commitTS, source, + &ChunkedNodeGroup::commitDelete); +} + void NodeGroup::checkpoint(MemoryManager& memoryManager, NodeGroupCheckpointState& state) { // We don't need to consider deletions here, as they are flushed separately as metadata. // TODO(Guodong): A special case can be all rows are deleted or rollbacked, then we can skip // flushing the data. 
const auto lock = chunkedGroups.lock(); - KU_ASSERT(chunkedGroups.getNumGroups(lock) >= 1); - const auto firstGroup = chunkedGroups.getFirstGroup(lock); - const auto hasPersistentData = firstGroup->getResidencyState() == ResidencyState::ON_DISK; - // Re-populate version info here first. - auto checkpointedVersionInfo = checkpointVersionInfo(lock, &DUMMY_CHECKPOINT_TRANSACTION); - std::unique_ptr checkpointedChunkedGroup; - if (hasPersistentData) { - checkpointedChunkedGroup = checkpointInMemAndOnDisk(memoryManager, lock, state); - } else { - checkpointedChunkedGroup = checkpointInMemOnly(memoryManager, lock, state); + if (!chunkedGroups.isEmpty(lock)) { + const auto firstGroup = chunkedGroups.getFirstGroup(lock); + const auto hasPersistentData = firstGroup->getResidencyState() == ResidencyState::ON_DISK; + // Re-populate version info here first. + auto checkpointedVersionInfo = checkpointVersionInfo(lock, &DUMMY_CHECKPOINT_TRANSACTION); + std::unique_ptr checkpointedChunkedGroup; + if (hasPersistentData) { + checkpointedChunkedGroup = checkpointInMemAndOnDisk(memoryManager, lock, state); + } else { + checkpointedChunkedGroup = checkpointInMemOnly(memoryManager, lock, state); + } + checkpointedChunkedGroup->setVersionInfo(std::move(checkpointedVersionInfo)); + chunkedGroups.clear(lock); + chunkedGroups.appendGroup(lock, std::move(checkpointedChunkedGroup)); } - checkpointedChunkedGroup->setVersionInfo(std::move(checkpointedVersionInfo)); - chunkedGroups.clear(lock); - chunkedGroups.appendGroup(lock, std::move(checkpointedChunkedGroup)); } std::unique_ptr NodeGroup::checkpointInMemAndOnDisk(MemoryManager& memoryManager, @@ -444,7 +510,7 @@ std::unique_ptr NodeGroup::checkpointVersionInfo(const UniqLock& lo // TODO(Guodong): Optimize the for loop here to directly acess the version info. 
for (auto i = 0u; i < chunkedGroup->getNumRows(); i++) { if (chunkedGroup->isDeleted(transaction, i)) { - checkpointVersionInfo->delete_(transaction, nullptr, currRow + i); + checkpointVersionInfo->delete_(transaction, currRow + i); } } } @@ -526,33 +592,37 @@ std::unique_ptr NodeGroup::deserialize(MemoryManager& memoryManager, } } -idx_t NodeGroup::findChunkedGroupIdxFromRowIdx(const UniqLock& lock, row_idx_t rowIdx) const { - KU_ASSERT(!chunkedGroups.isEmpty(lock)); - const auto numRowsInFirstGroup = chunkedGroups.getFirstGroup(lock)->getNumRows(); +std::pair NodeGroup::findChunkedGroupIdxFromRowIdxNoLock(row_idx_t rowIdx) const { + if (chunkedGroups.getNumGroupsNoLock() == 0) { + return {INVALID_CHUNKED_GROUP_IDX, INVALID_START_ROW_IDX}; + } + const auto numRowsInFirstGroup = chunkedGroups.getFirstGroupNoLock()->getNumRows(); if (rowIdx < numRowsInFirstGroup) { - return 0; + return {0, rowIdx}; } rowIdx -= numRowsInFirstGroup; - return rowIdx / ChunkedNodeGroup::CHUNK_CAPACITY + 1; + const auto chunkedGroupIdx = rowIdx / ChunkedNodeGroup::CHUNK_CAPACITY + 1; + const auto rowIdxInChunk = rowIdx % ChunkedNodeGroup::CHUNK_CAPACITY; + if (chunkedGroupIdx >= chunkedGroups.getNumGroupsNoLock()) { + return {INVALID_CHUNKED_GROUP_IDX, INVALID_START_ROW_IDX}; + } + return {chunkedGroupIdx, rowIdxInChunk}; } ChunkedNodeGroup* NodeGroup::findChunkedGroupFromRowIdx(const UniqLock& lock, row_idx_t rowIdx) const { - auto chunkedGroupIdx = findChunkedGroupIdxFromRowIdx(lock, rowIdx); - if (chunkedGroupIdx >= chunkedGroups.getNumGroups(lock)) { + const auto [chunkedGroupIdx, rowIdxInChunkedGroup] = + findChunkedGroupIdxFromRowIdxNoLock(rowIdx); + if (chunkedGroupIdx == INVALID_CHUNKED_GROUP_IDX) { return nullptr; } return chunkedGroups.getGroup(lock, chunkedGroupIdx); } ChunkedNodeGroup* NodeGroup::findChunkedGroupFromRowIdxNoLock(row_idx_t rowIdx) const { - const auto numRowsInFirstGroup = chunkedGroups.getFirstGroupNoLock()->getNumRows(); - if (rowIdx < numRowsInFirstGroup) { 
- return chunkedGroups.getFirstGroupNoLock(); - } - rowIdx -= numRowsInFirstGroup; - const auto chunkedGroupIdx = rowIdx / ChunkedNodeGroup::CHUNK_CAPACITY + 1; - if (chunkedGroupIdx >= chunkedGroups.getNumGroupsNoLock()) { + const auto [chunkedGroupIdx, rowIdxInChunkedGroup] = + findChunkedGroupIdxFromRowIdxNoLock(rowIdx); + if (chunkedGroupIdx == INVALID_CHUNKED_GROUP_IDX) { return nullptr; } return chunkedGroups.getGroupNoLock(chunkedGroupIdx); diff --git a/src/storage/store/node_group_collection.cpp b/src/storage/store/node_group_collection.cpp index dbe7a53deef..91c7225508b 100644 --- a/src/storage/store/node_group_collection.cpp +++ b/src/storage/store/node_group_collection.cpp @@ -1,5 +1,6 @@ #include "storage/store/node_group_collection.h" +#include "common/utils.h" #include "common/vector/value_vector.h" #include "storage/buffer_manager/memory_manager.h" #include "storage/store/csr_node_group.h" @@ -14,9 +15,9 @@ namespace storage { NodeGroupCollection::NodeGroupCollection(MemoryManager& memoryManager, const std::vector& types, const bool enableCompression, FileHandle* dataFH, - Deserializer* deSer) + Deserializer* deSer, append_to_undo_buffer_func_t appendToUndoBufferFunc) : enableCompression{enableCompression}, numTotalRows{0}, types{LogicalType::copy(types)}, - dataFH{dataFH} { + dataFH{dataFH}, appendToUndoBufferFunc(std::move(appendToUndoBufferFunc)) { if (deSer) { deserialize(*deSer, memoryManager); } @@ -52,21 +53,22 @@ void NodeGroupCollection::append(const Transaction* transaction, const auto numToAppendInNodeGroup = std::min(numRowsToAppend - numRowsAppended, lastNodeGroup->getNumRowsLeftToAppend()); lastNodeGroup->moveNextRowToAppend(numToAppendInNodeGroup); + appendToUndoBufferFunc(transaction, lastNodeGroup, numToAppendInNodeGroup); lastNodeGroup->append(transaction, vectors, numRowsAppended, numToAppendInNodeGroup); numRowsAppended += numToAppendInNodeGroup; + numTotalRows += numToAppendInNodeGroup; } - numTotalRows += numRowsAppended; 
stats.incrementCardinality(numRowsAppended); } void NodeGroupCollection::append(const Transaction* transaction, NodeGroupCollection& other) { const auto otherLock = other.nodeGroups.lock(); for (auto& nodeGroup : other.nodeGroups.getAllGroups(otherLock)) { - appned(transaction, *nodeGroup); + append(transaction, *nodeGroup); } } -void NodeGroupCollection::appned(const Transaction* transaction, NodeGroup& nodeGroup) { +void NodeGroupCollection::append(const Transaction* transaction, NodeGroup& nodeGroup) { const auto numRowsToAppend = nodeGroup.getNumRows(); KU_ASSERT(nodeGroup.getDataTypes().size() == types.size()); const auto lock = nodeGroups.lock(); @@ -77,8 +79,8 @@ void NodeGroupCollection::appned(const Transaction* transaction, NodeGroup& node const auto numChunkedGroupsToAppend = nodeGroup.getNumChunkedGroups(); node_group_idx_t numChunkedGroupsAppended = 0; while (numChunkedGroupsAppended < numChunkedGroupsToAppend) { - const auto chunkedGrouoToAppend = nodeGroup.getChunkedNodeGroup(numChunkedGroupsAppended); - const auto numRowsToAppendInChunkedGroup = chunkedGrouoToAppend->getNumRows(); + const auto chunkedGroupToAppend = nodeGroup.getChunkedNodeGroup(numChunkedGroupsAppended); + const auto numRowsToAppendInChunkedGroup = chunkedGroupToAppend->getNumRows(); row_idx_t numRowsAppendedInChunkedGroup = 0; while (numRowsAppendedInChunkedGroup < numRowsToAppendInChunkedGroup) { auto lastNodeGroup = nodeGroups.getLastGroup(lock); @@ -92,13 +94,14 @@ void NodeGroupCollection::appned(const Transaction* transaction, NodeGroup& node std::min(numRowsToAppendInChunkedGroup - numRowsAppendedInChunkedGroup, lastNodeGroup->getNumRowsLeftToAppend()); lastNodeGroup->moveNextRowToAppend(numToAppendInBatch); - lastNodeGroup->append(transaction, *chunkedGrouoToAppend, numRowsAppendedInChunkedGroup, + appendToUndoBufferFunc(transaction, lastNodeGroup, numToAppendInBatch); + lastNodeGroup->append(transaction, *chunkedGroupToAppend, numRowsAppendedInChunkedGroup, 
numToAppendInBatch); numRowsAppendedInChunkedGroup += numToAppendInBatch; + numTotalRows += numToAppendInBatch; } numChunkedGroupsAppended++; } - numTotalRows += numRowsToAppend; stats.incrementCardinality(numRowsToAppend); } @@ -128,12 +131,12 @@ std::pair NodeGroupCollection::appendToLastNodeGroupAndFlush // If the node group is empty now and the chunked group is full, we can directly flush it. directFlushWhenAppend = numToAppend == numRowsLeftInLastNodeGroup && lastNodeGroup->getNumRows() == 0; + appendToUndoBufferFunc(transaction, lastNodeGroup, chunkedGroup.getNumRows()); if (!directFlushWhenAppend) { // TODO(Guodong): Furthur optimize on this. Should directly figure out startRowIdx to // start appending into the node group and pass in as param. lastNodeGroup->append(transaction, chunkedGroup, 0, numToAppend); } - numTotalRows += numToAppend; } if (directFlushWhenAppend) { chunkedGroup.finalize(); @@ -141,6 +144,7 @@ std::pair NodeGroupCollection::appendToLastNodeGroupAndFlush KU_ASSERT(lastNodeGroup->getNumChunkedGroups() == 0); lastNodeGroup->merge(transaction, std::move(flushedGroup)); } + numTotalRows += numToAppend; stats.incrementCardinality(numToAppend); return {startOffset, numToAppend}; } @@ -191,6 +195,58 @@ void NodeGroupCollection::checkpoint(MemoryManager& memoryManager, } } +static idx_t getNumEmptyTrailingGroups(const GroupCollection& nodeGroups, + const common::UniqLock& lock) { + const auto& nodeGroupVector = nodeGroups.getAllGroups(lock); + return safeIntegerConversion( + std::find_if(nodeGroupVector.rbegin(), nodeGroupVector.rend(), + [](const auto& nodeGroup) { return (nodeGroup->getNumRows() != 0); }) - + nodeGroupVector.rbegin()); +} + +void NodeGroupCollection::rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows_, + common::node_group_idx_t nodeGroupIdx, CSRNodeGroupScanSource source) { + const auto lock = nodeGroups.lock(); + auto numRowsToSubtract = numRows_; + // skip the rollback if all newly created node groups have 
already been deleted + if (!nodeGroups.isEmpty(lock) || nodeGroupIdx > 0) { + KU_ASSERT(nodeGroupIdx < nodeGroups.getNumGroups(lock)); + auto* nodeGroup = nodeGroups.getGroup(lock, nodeGroupIdx); + + KU_ASSERT(startRow <= nodeGroup->getNumRows()); + numRowsToSubtract = std::min(numRowsToSubtract, nodeGroup->getNumRows() - startRow); + nodeGroup->rollbackInsert(startRow, numRows_, source); + + // remove any empty trailing node groups after the rollback + const auto numGroupsToRemove = getNumEmptyTrailingGroups(nodeGroups, lock); + nodeGroups.removeTrailingGroups(lock, numGroupsToRemove); + } + KU_ASSERT(numRowsToSubtract <= numTotalRows); + numTotalRows -= numRowsToSubtract; +} + +void NodeGroupCollection::rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows_, + common::node_group_idx_t nodeGroupIdx, CSRNodeGroupScanSource source) { + const auto lock = nodeGroups.lock(); + KU_ASSERT(nodeGroupIdx < nodeGroups.getNumGroups(lock)); + nodeGroups.getGroup(lock, nodeGroupIdx)->rollbackDelete(startRow, numRows_, source); +} + +void NodeGroupCollection::commitInsert(row_idx_t startRow, row_idx_t numRows_, + node_group_idx_t nodeGroupIdx, common::transaction_t commitTS, CSRNodeGroupScanSource source) { + if (numRows_ == 0) { + return; + } + const auto lock = nodeGroups.lock(); + nodeGroups.getGroup(lock, nodeGroupIdx)->commitInsert(startRow, numRows_, commitTS, source); +} + +void NodeGroupCollection::commitDelete(row_idx_t startRow, row_idx_t numRows_, + node_group_idx_t nodeGroupIdx, common::transaction_t commitTS, CSRNodeGroupScanSource source) { + const auto lock = nodeGroups.lock(); + nodeGroups.getGroup(lock, nodeGroupIdx)->commitDelete(startRow, numRows_, commitTS, source); +} + void NodeGroupCollection::serialize(Serializer& ser) { ser.writeDebuggingInfo("node_groups"); nodeGroups.serializeGroups(ser); @@ -207,5 +263,8 @@ void NodeGroupCollection::deserialize(Deserializer& deSer, MemoryManager& memory stats.deserialize(deSer); } +void 
NodeGroupCollection::defaultAppendToUndoBuffer(const transaction::Transaction*, NodeGroup*, + common::row_idx_t) {} + } // namespace storage } // namespace kuzu diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index 62394029b1a..36b41b216af 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -43,6 +43,147 @@ bool NodeTableScanState::scanNext(Transaction* transaction, offset_t startOffset return true; } +template +concept notIndexHashable = !IndexHashable; + +namespace { +struct PKColumnScanHelper { + explicit PKColumnScanHelper(common::node_group_idx_t numNodeGroups, PrimaryKeyIndex* pkIndex, + common::DataChunk dataChunk, table_id_t tableID) + : numNodeGroups(numNodeGroups), dataChunk(std::move(dataChunk)), tableID(tableID), + pkIndex(pkIndex) {} + virtual ~PKColumnScanHelper() = default; + + virtual bool processScanOutput(const transaction::Transaction* transaction, + NodeGroupScanResult scanResult, const common::ValueVector& scannedVector) = 0; + virtual NodeGroup* getNodeGroup(common::node_group_idx_t nodeGroupIdx) const = 0; + + common::node_group_idx_t numNodeGroups; + common::DataChunk dataChunk; + table_id_t tableID; + PrimaryKeyIndex* pkIndex; +}; + +struct CommittedPKColumnScanHelper : public PKColumnScanHelper { +public: + CommittedPKColumnScanHelper(LocalNodeTable& localTable, row_idx_t startNodeOffset, + DataChunk dataChunk, table_id_t tableID, PrimaryKeyIndex* pkIndex, visible_func isVisible) + : PKColumnScanHelper(localTable.getNumNodeGroups(), pkIndex, std::move(dataChunk), tableID), + localTable(localTable), startNodeOffset(startNodeOffset), + nodeIDVector(LogicalType::INTERNAL_ID()), isVisible(std::move(isVisible)) { + nodeIDVector.setState(this->dataChunk.state); + } + + bool processScanOutput(const transaction::Transaction* transaction, + NodeGroupScanResult scanResult, const common::ValueVector& scannedVector) override; + + NodeGroup* getNodeGroup(common::node_group_idx_t 
nodeGroupIdx) const override { + return localTable.getNodeGroup(nodeGroupIdx); + } + + LocalNodeTable& localTable; + row_idx_t startNodeOffset; + ValueVector nodeIDVector; + visible_func isVisible; +}; + +struct RollbackPKColumnScanHelper : public PKColumnScanHelper { +public: + RollbackPKColumnScanHelper(row_idx_t startNodeOffset, row_idx_t numRows, + NodeGroupCollection& nodeGroups, DataChunk dataChunk, table_id_t tableID, + PrimaryKeyIndex* pkIndex) + : PKColumnScanHelper(nodeGroups.getNumNodeGroups(), pkIndex, std::move(dataChunk), tableID), + semiMask(RoaringBitmapSemiMaskUtil::createRoaringBitmapSemiMask(tableID, + startNodeOffset + numRows)), + nodeGroups(nodeGroups) { + for (row_idx_t i = 0; i < numRows; ++i) { + semiMask->mask(startNodeOffset + i); + } + } + + bool processScanOutput(const transaction::Transaction* transaction, + NodeGroupScanResult scanResult, const common::ValueVector& scannedVector) override; + + NodeGroup* getNodeGroup(common::node_group_idx_t nodeGroupIdx) const override { + return nodeGroups.getNodeGroup(nodeGroupIdx); + } + + std::unique_ptr semiMask; + NodeGroupCollection& nodeGroups; +}; + +static void insertPKInternal(const Transaction* transaction, const ValueVector& nodeIDVector, + const ValueVector& pkVector, PrimaryKeyIndex* pkIndex, const visible_func& isVisible) { + for (auto i = 0u; i < nodeIDVector.state->getSelVector().getSelSize(); i++) { + const auto nodeIDPos = nodeIDVector.state->getSelVector()[i]; + const auto offset = nodeIDVector.readNodeOffset(nodeIDPos); + auto pkPos = pkVector.state->getSelVector()[i]; + if (pkVector.isNull(pkPos)) { + throw RuntimeException(ExceptionMessage::nullPKException()); + } + if (!pkIndex->insert(transaction, &pkVector, pkPos, offset, isVisible)) { + throw RuntimeException( + ExceptionMessage::duplicatePKException(pkVector.getAsValue(pkPos)->toString())); + } + } +} + +void scanPKColumn(const Transaction* transaction, PKColumnScanHelper& scanHelper, + std::unique_ptr scanState) { + + 
node_group_idx_t nodeGroupToScan = 0u; + while (nodeGroupToScan < scanHelper.numNodeGroups) { + // We need to scan from local storage here because some tuples in local node groups might + // have been deleted. + scanState->nodeGroup = scanHelper.getNodeGroup(nodeGroupToScan); + KU_ASSERT(scanState->nodeGroup); + scanState->nodeGroup->initializeScanState(transaction, *scanState); + while (true) { + auto scanResult = scanState->nodeGroup->scan(transaction, *scanState); + if (!scanHelper.processScanOutput(transaction, scanResult, + *scanState->outputVectors[0])) { + break; + } + } + nodeGroupToScan++; + } +} + +bool CommittedPKColumnScanHelper::processScanOutput(const transaction::Transaction* transaction, + NodeGroupScanResult scanResult, const common::ValueVector& scannedVector) { + if (scanResult == NODE_GROUP_SCAN_EMMPTY_RESULT) { + return false; + } + for (auto i = 0u; i < scanResult.numRows; i++) { + nodeIDVector.setValue(i, nodeID_t{startNodeOffset + i, tableID}); + } + insertPKInternal(transaction, nodeIDVector, scannedVector, pkIndex, isVisible); + startNodeOffset += scanResult.numRows; + return true; +} + +bool RollbackPKColumnScanHelper::processScanOutput(const transaction::Transaction* transaction, + NodeGroupScanResult scanResult, const common::ValueVector& scannedVector) { + if (scanResult == NODE_GROUP_SCAN_EMMPTY_RESULT) { + return false; + } + const auto rollbackFunc = [&](T) { + for (idx_t i = 0; i < scannedVector.state->getSelSize(); ++i) { + const auto pos = scannedVector.state->getSelVector()[i]; + T key = scannedVector.getValue(pos); + static constexpr auto isVisible = [](offset_t) { return true; }; + offset_t lookupOffset = 0; + if (pkIndex->lookup(transaction, key, lookupOffset, isVisible)) { + pkIndex->delete_(key); + } + } + }; + TypeUtils::visit(scannedVector.dataType.getPhysicalType(), std::cref(rollbackFunc), + [](T) { KU_UNREACHABLE; }); + return true; +} +} // namespace + bool NodeTableScanState::scanNext(Transaction* transaction) { 
KU_ASSERT(columns.size() == outputVectors.size()); if (source == TableScanSource::NONE) { @@ -63,6 +204,16 @@ bool NodeTableScanState::scanNext(Transaction* transaction) { return true; } +static decltype(auto) createAppendToUndoBufferFunc(NodeTable* nodeTable) { + return [nodeTable](const transaction::Transaction* transaction, NodeGroup* nodeGroup, + common::row_idx_t numRows) { + if (transaction->shouldAppendToUndoBuffer()) { + transaction->pushInsertInfo(nodeTable, nodeGroup->getNodeGroupIdx(), + nodeGroup->getNumRows(), numRows); + } + }; +} + NodeTable::NodeTable(const StorageManager* storageManager, const NodeTableCatalogEntry* nodeTableEntry, MemoryManager* memoryManager, VirtualFileSystem* vfs, main::ClientContext* context, Deserializer* deSer) @@ -78,8 +229,10 @@ NodeTable::NodeTable(const StorageManager* storageManager, columns[columnID] = ColumnFactory::createColumn(columnName, property.getType().copy(), dataFH, memoryManager, shadowFile, enableCompression); } + nodeGroups = std::make_unique(*memoryManager, - getNodeTableColumnTypes(*this), enableCompression, storageManager->getDataFH(), deSer); + getNodeTableColumnTypes(*this), enableCompression, storageManager->getDataFH(), deSer, + createAppendToUndoBufferFunc(this)); initializePKIndex(storageManager->getDatabasePath(), nodeTableEntry, storageManager->isReadOnly(), vfs, context); } @@ -294,6 +447,9 @@ bool NodeTable::delete_(Transaction* transaction, TableDeleteState& deleteState) const auto rowIdxInGroup = nodeOffset - StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); isDeleted = nodeGroups->getNodeGroup(nodeGroupIdx)->delete_(transaction, rowIdxInGroup); + if (transaction->shouldAppendToUndoBuffer()) { + transaction->pushDeleteInfo(this, nodeGroupIdx, rowIdxInGroup, 1); + } } if (isDeleted) { hasChanges = true; @@ -329,6 +485,27 @@ std::pair NodeTable::appendToLastNodeGroup(Transaction* tran return nodeGroups->appendToLastNodeGroupAndFlushWhenFull(transaction, chunkedGroup); } +std::unique_ptr 
NodeTable::initPKScanState(DataChunk& dataChunk, + TableScanSource source) const { + std::vector columnIDs{getPKColumnID()}; + auto scanState = std::make_unique(tableID, columnIDs); + for (auto& vector : dataChunk.valueVectors) { + scanState->outputVectors.push_back(vector.get()); + } + scanState->outState = dataChunk.state.get(); + scanState->source = source; + for (const auto& column : columns) { + scanState->columns.push_back(column.get()); + } + return scanState; +} + +common::DataChunk NodeTable::constructDataChunkForPKColumn() const { + std::vector types; + types.push_back(columns[pkColumnID]->getDataType().copy()); + return constructDataChunk(std::move(types)); +} + void NodeTable::commit(Transaction* transaction, LocalTable* localTable) { auto startNodeOffset = nodeGroups->getNumTotalRows(); transaction->setMaxCommittedNodeOffset(tableID, startNodeOffset); @@ -353,62 +530,37 @@ void NodeTable::commit(Transaction* transaction, LocalTable* localTable) { const auto rowIdxInGroup = startNodeOffset + nodeOffset - StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); - nodeGroups->getNodeGroup(nodeGroupIdx)->delete_(transaction, rowIdxInGroup); + [[maybe_unused]] const bool isDeleted = + nodeGroups->getNodeGroup(nodeGroupIdx)->delete_(transaction, rowIdxInGroup); + KU_ASSERT(isDeleted); + if (transaction->shouldAppendToUndoBuffer()) { + transaction->pushDeleteInfo(this, nodeGroupIdx, rowIdxInGroup, 1); + } } } } numLocalRows += localNodeGroup->getNumRows(); } + // 3. Scan pk column for newly inserted tuples that are not deleted and insert into pk index. 
- std::vector columnIDs{getPKColumnID()}; - auto types = std::vector(); - types.push_back(columns[pkColumnID]->getDataType().copy()); - auto dataChunk = constructDataChunk(std::move(types)); - ValueVector nodeIDVector(LogicalType::INTERNAL_ID()); - nodeIDVector.setState(dataChunk.state); - const auto numNodeGroupsToScan = localNodeTable.getNumNodeGroups(); - const auto scanState = std::make_unique(tableID, columnIDs, - std::vector{}, dataChunk, &nodeIDVector); - scanState->source = TableScanSource::UNCOMMITTED; - node_group_idx_t nodeGroupToScan = 0u; - while (nodeGroupToScan < numNodeGroupsToScan) { - // We need to scan from local storage here because some tuples in local node groups might - // have been deleted. - scanState->nodeGroup = localNodeTable.getNodeGroup(nodeGroupToScan); - KU_ASSERT(scanState->nodeGroup); - scanState->nodeGroup->initializeScanState(transaction, *scanState); - while (true) { - auto scanResult = scanState->nodeGroup->scan(transaction, *scanState); - if (scanResult == NODE_GROUP_SCAN_EMMPTY_RESULT) { - break; - } - for (auto i = 0u; i < scanResult.numRows; i++) { - nodeIDVector.setValue(i, nodeID_t{startNodeOffset + i, tableID}); - } - insertPK(transaction, nodeIDVector, *scanState->outputVectors[0]); - startNodeOffset += scanResult.numRows; - } - nodeGroupToScan++; - } + CommittedPKColumnScanHelper scanHelper{localNodeTable, startNodeOffset, + constructDataChunkForPKColumn(), tableID, pkIndex.get(), getVisibleFunc(transaction)}; + scanPKColumn(transaction, scanHelper, + initPKScanState(scanHelper.dataChunk, TableScanSource::UNCOMMITTED)); + // 4. Clear local table. 
localTable->clear(); } +visible_func NodeTable::getVisibleFunc(const Transaction* transaction) const { + return + [this, transaction](offset_t offset_) -> bool { return isVisible(transaction, offset_); }; +} + void NodeTable::insertPK(const Transaction* transaction, const ValueVector& nodeIDVector, const ValueVector& pkVector) const { - for (auto i = 0u; i < nodeIDVector.state->getSelVector().getSelSize(); i++) { - const auto nodeIDPos = nodeIDVector.state->getSelVector()[i]; - const auto offset = nodeIDVector.readNodeOffset(nodeIDPos); - auto pkPos = pkVector.state->getSelVector()[i]; - if (pkVector.isNull(pkPos)) { - throw RuntimeException(ExceptionMessage::nullPKException()); - } - if (!pkIndex->insert(transaction, const_cast(&pkVector), pkPos, offset, - [&](offset_t offset_) { return isVisible(transaction, offset_); })) { - throw RuntimeException( - ExceptionMessage::duplicatePKException(pkVector.getAsValue(pkPos)->toString())); - } - } + return insertPKInternal(transaction, nodeIDVector, pkVector, pkIndex.get(), + getVisibleFunc(transaction)); } void NodeTable::checkpoint(Serializer& ser, TableCatalogEntry* tableEntry) { @@ -438,6 +590,19 @@ void NodeTable::checkpoint(Serializer& ser, TableCatalogEntry* tableEntry) { serialize(ser); } +void NodeTable::rollbackInsert(const transaction::Transaction* transaction, + common::row_idx_t startRow, common::row_idx_t numRows_, common::node_group_idx_t nodeGroupIdx) { + row_idx_t startNodeOffset = startRow; + for (node_group_idx_t i = 0; i < nodeGroupIdx; ++i) { + startNodeOffset += nodeGroups->getNodeGroupNoLock(i)->getNumRows(); + } + + RollbackPKColumnScanHelper scanHelper{startNodeOffset, numRows_, *nodeGroups, + constructDataChunkForPKColumn(), tableID, pkIndex.get()}; + scanPKColumn(transaction, scanHelper, + initPKScanState(scanHelper.dataChunk, TableScanSource::COMMITTED)); +} + TableStats NodeTable::getStats(const Transaction* transaction) const { auto stats = nodeGroups->getStats(); const auto localTable = 
transaction->getLocalStorage()->getLocalTable(tableID, @@ -462,6 +627,9 @@ bool NodeTable::isVisible(const Transaction* transaction, offset_t offset) const bool NodeTable::isVisibleNoLock(const Transaction* transaction, offset_t offset) const { auto [nodeGroupIdx, offsetInGroup] = StorageUtils::getNodeGroupIdxAndOffsetInChunk(offset); + if (nodeGroupIdx >= nodeGroups->getNumNodeGroups()) { + return false; + } auto* nodeGroup = getNodeGroupNoLock(nodeGroupIdx); return nodeGroup->isVisibleNoLock(transaction, offsetInGroup); } diff --git a/src/storage/store/null_column.cpp b/src/storage/store/null_column.cpp index 6a13e4cba10..c634b1b9c79 100644 --- a/src/storage/store/null_column.cpp +++ b/src/storage/store/null_column.cpp @@ -42,14 +42,14 @@ void NullColumn::scan(Transaction* transaction, const ChunkState& state, scanInternal(transaction, state, startOffsetInChunk, numValuesToScan, resultVector); } -void NullColumn::scan(Transaction* transaction, const ChunkState& state, +void NullColumn::scan(const Transaction* transaction, const ChunkState& state, offset_t startOffsetInGroup, offset_t endOffsetInGroup, ValueVector* resultVector, uint64_t offsetInVector) const { Column::scan(transaction, state, startOffsetInGroup, endOffsetInGroup, resultVector, offsetInVector); } -void NullColumn::scan(Transaction* transaction, const ChunkState& state, +void NullColumn::scan(const Transaction* transaction, const ChunkState& state, ColumnChunkData* columnChunk, offset_t startOffset, offset_t endOffset) const { Column::scan(transaction, state, columnChunk, startOffset, endOffset); } diff --git a/src/storage/store/rel_table.cpp b/src/storage/store/rel_table.cpp index 70357beee14..c84dfa9af5d 100644 --- a/src/storage/store/rel_table.cpp +++ b/src/storage/store/rel_table.cpp @@ -334,8 +334,8 @@ void RelTable::throwIfNodeHasRels(Transaction* transaction, RelDataDirection dir } } -void RelTable::detachDeleteForCSRRels(Transaction* transaction, const RelTableData* tableData, - const 
RelTableData* reverseTableData, RelTableScanState* relDataReadState, +void RelTable::detachDeleteForCSRRels(Transaction* transaction, RelTableData* tableData, + RelTableData* reverseTableData, RelTableScanState* relDataReadState, RelTableDeleteState* deleteState) { const auto localTable = transaction->getLocalStorage()->getLocalTable(tableID, LocalStorage::NotExistAction::RETURN_NULL); @@ -385,6 +385,15 @@ NodeGroup* RelTable::getOrCreateNodeGroup(node_group_idx_t nodeGroupIdx, bwdRelTableData->getOrCreateNodeGroup(nodeGroupIdx); } +void RelTable::pushInsertInfo(Transaction* transaction, RelDataDirection direction, + const CSRNodeGroup& nodeGroup, row_idx_t numRows_, CSRNodeGroupScanSource source) { + if (transaction->shouldAppendToUndoBuffer()) { + auto& relTableData = + (direction == common::RelDataDirection::FWD) ? fwdRelTableData : bwdRelTableData; + relTableData->pushInsertInfo(transaction, nodeGroup, numRows_, source); + } +} + void RelTable::commit(Transaction* transaction, LocalTable* localTable) { auto& localRelTable = localTable->cast(); if (localRelTable.isEmpty()) { @@ -401,22 +410,27 @@ void RelTable::commit(Transaction* transaction, LocalTable* localTable) { for (auto i = 0u; i < localRelTable.getNumColumns(); i++) { columnIDsToScan.push_back(i); } - auto& fwdIndex = localRelTable.getFWDIndex(); - for (auto& [boundNodeOffset, rowIndices] : fwdIndex) { - auto [nodeGroupIdx, boundOffsetInGroup] = - StorageUtils::getQuotientRemainder(boundNodeOffset, StorageConstants::NODE_GROUP_SIZE); - auto& nodeGroup = fwdRelTableData->getOrCreateNodeGroup(nodeGroupIdx)->cast(); - prepareCommitForNodeGroup(transaction, localNodeGroup, nodeGroup, boundOffsetInGroup, - rowIndices, LOCAL_BOUND_NODE_ID_COLUMN_ID); - } - auto& bwdIndex = localRelTable.getBWDIndex(); - for (auto& [boundNodeOffset, rowIndices] : bwdIndex) { - auto [nodeGroupIdx, boundOffsetInGroup] = - StorageUtils::getQuotientRemainder(boundNodeOffset, StorageConstants::NODE_GROUP_SIZE); - auto& nodeGroup = 
bwdRelTableData->getOrCreateNodeGroup(nodeGroupIdx)->cast(); - prepareCommitForNodeGroup(transaction, localNodeGroup, nodeGroup, boundOffsetInGroup, - rowIndices, LOCAL_NBR_NODE_ID_COLUMN_ID); - } + + const auto commitRelTableData = [&](RelDataDirection direction) { + auto [index, relTableData, columnToSkip] = + (direction == common::RelDataDirection::FWD) ? + std::tie(localRelTable.getFWDIndex(), fwdRelTableData, + LOCAL_BOUND_NODE_ID_COLUMN_ID) : + std::tie(localRelTable.getBWDIndex(), bwdRelTableData, LOCAL_NBR_NODE_ID_COLUMN_ID); + for (auto& [boundNodeOffset, rowIndices] : index) { + auto [nodeGroupIdx, boundOffsetInGroup] = StorageUtils::getQuotientRemainder( + boundNodeOffset, StorageConstants::NODE_GROUP_SIZE); + auto& nodeGroup = + relTableData->getOrCreateNodeGroup(nodeGroupIdx)->cast(); + pushInsertInfo(transaction, direction, nodeGroup, rowIndices.size(), + CSRNodeGroupScanSource::COMMITTED_IN_MEMORY); + prepareCommitForNodeGroup(transaction, localNodeGroup, nodeGroup, boundOffsetInGroup, + rowIndices, columnToSkip); + } + }; + commitRelTableData(common::RelDataDirection::FWD); + commitRelTableData(common::RelDataDirection::BWD); + localRelTable.clear(); } diff --git a/src/storage/store/rel_table_data.cpp b/src/storage/store/rel_table_data.cpp index b476c3dd2d3..e5a53eed0df 100644 --- a/src/storage/store/rel_table_data.cpp +++ b/src/storage/store/rel_table_data.cpp @@ -83,7 +83,7 @@ bool RelTableData::update(Transaction* transaction, ValueVector& boundNodeIDVect } bool RelTableData::delete_(Transaction* transaction, ValueVector& boundNodeIDVector, - const ValueVector& relIDVector) const { + const ValueVector& relIDVector) { const auto boundNodePos = boundNodeIDVector.state->getSelVector()[0]; const auto relIDPos = relIDVector.state->getSelVector()[0]; if (boundNodeIDVector.isNull(boundNodePos) || relIDVector.isNull(relIDPos)) { @@ -96,7 +96,11 @@ bool RelTableData::delete_(Transaction* transaction, ValueVector& boundNodeIDVec const auto boundNodeOffset 
= boundNodeIDVector.getValue(boundNodePos).offset; const auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(boundNodeOffset); auto& csrNodeGroup = getNodeGroup(nodeGroupIdx)->cast(); - return csrNodeGroup.delete_(transaction, source, rowIdx); + bool isDeleted = csrNodeGroup.delete_(transaction, source, rowIdx); + if (isDeleted && transaction->shouldAppendToUndoBuffer()) { + transaction->pushDeleteInfo(this, nodeGroupIdx, rowIdx, 1, source); + } + return isDeleted; } void RelTableData::addColumn(Transaction* transaction, TableAddColumnState& addColumnState) { @@ -186,6 +190,14 @@ bool RelTableData::checkIfNodeHasRels(Transaction* transaction, return false; } +void RelTableData::pushInsertInfo(transaction::Transaction* transaction, + const CSRNodeGroup& nodeGroup, common::row_idx_t numRows_, CSRNodeGroupScanSource source) { + const auto startRow = (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) ? + nodeGroup.getNumPersistentRows() : + nodeGroup.getNumRows(); + transaction->pushInsertInfo(this, nodeGroup.getNodeGroupIdx(), startRow, numRows_, source); +} + void RelTableData::checkpoint(const std::vector& columnIDs) { std::vector> checkpointColumns; for (auto i = 0u; i < columnIDs.size(); i++) { diff --git a/src/storage/store/string_column.cpp b/src/storage/store/string_column.cpp index 8fb281d1296..1b05e46dfc1 100644 --- a/src/storage/store/string_column.cpp +++ b/src/storage/store/string_column.cpp @@ -59,7 +59,7 @@ std::unique_ptr StringColumn::flushChunkData(const ColumnChunkD return flushedChunkData; } -void StringColumn::scan(Transaction* transaction, const ChunkState& state, +void StringColumn::scan(const Transaction* transaction, const ChunkState& state, offset_t startOffsetInGroup, offset_t endOffsetInGroup, ValueVector* resultVector, uint64_t offsetInVector) const { nullColumn->scan(transaction, *state.nullState, startOffsetInGroup, endOffsetInGroup, @@ -68,7 +68,7 @@ void StringColumn::scan(Transaction* transaction, const ChunkState& state, 
resultVector, offsetInVector); } -void StringColumn::scan(Transaction* transaction, const ChunkState& state, +void StringColumn::scan(const Transaction* transaction, const ChunkState& state, ColumnChunkData* columnChunk, offset_t startOffset, offset_t endOffset) const { KU_ASSERT(state.nullState); Column::scan(transaction, state, columnChunk, startOffset, endOffset); @@ -82,7 +82,7 @@ void StringColumn::scan(Transaction* transaction, const ChunkState& state, dictionary.scan(transaction, state, stringColumnChunk.getDictionaryChunk()); } -void StringColumn::lookupInternal(Transaction* transaction, const ChunkState& state, +void StringColumn::lookupInternal(const Transaction* transaction, const ChunkState& state, offset_t nodeOffset, ValueVector* resultVector, uint32_t posInVector) const { auto [nodeGroupIdx, offsetInChunk] = StorageUtils::getNodeGroupIdxAndOffsetInChunk(nodeOffset); string_index_t index = 0; @@ -141,7 +141,7 @@ void StringColumn::scanInternal(Transaction* transaction, const ChunkState& stat } } -void StringColumn::scanUnfiltered(Transaction* transaction, const ChunkState& state, +void StringColumn::scanUnfiltered(const Transaction* transaction, const ChunkState& state, offset_t startOffsetInChunk, offset_t numValuesToRead, ValueVector* resultVector, sel_t startPosInVector) const { // TODO: Replace indices with ValueVector to avoid maintaining `scan` interface from diff --git a/src/storage/store/struct_column.cpp b/src/storage/store/struct_column.cpp index 8ac5eff5067..81ae2bb6e23 100644 --- a/src/storage/store/struct_column.cpp +++ b/src/storage/store/struct_column.cpp @@ -39,7 +39,7 @@ std::unique_ptr StructColumn::flushChunkData(const ColumnChunkD return flushedChunk; } -void StructColumn::scan(Transaction* transaction, const ChunkState& state, +void StructColumn::scan(const Transaction* transaction, const ChunkState& state, ColumnChunkData* columnChunk, offset_t startOffset, offset_t endOffset) const { 
KU_ASSERT(columnChunk->getDataType().getPhysicalType() == PhysicalTypeID::STRUCT); Column::scan(transaction, state, columnChunk, startOffset, endOffset); @@ -50,7 +50,7 @@ void StructColumn::scan(Transaction* transaction, const ChunkState& state, } } -void StructColumn::scan(Transaction* transaction, const ChunkState& state, +void StructColumn::scan(const Transaction* transaction, const ChunkState& state, offset_t startOffsetInGroup, offset_t endOffsetInGroup, ValueVector* resultVector, uint64_t offsetInVector) const { nullColumn->scan(transaction, *state.nullState, startOffsetInGroup, endOffsetInGroup, @@ -71,7 +71,7 @@ void StructColumn::scanInternal(Transaction* transaction, const ChunkState& stat } } -void StructColumn::lookupInternal(Transaction* transaction, const ChunkState& state, +void StructColumn::lookupInternal(const Transaction* transaction, const ChunkState& state, offset_t nodeOffset, ValueVector* resultVector, uint32_t posInVector) const { for (auto i = 0u; i < childColumns.size(); i++) { const auto fieldVector = StructVector::getFieldVector(resultVector, i).get(); diff --git a/src/storage/store/version_info.cpp b/src/storage/store/version_info.cpp index 42e12ba26cd..aa68d64f2a8 100644 --- a/src/storage/store/version_info.cpp +++ b/src/storage/store/version_info.cpp @@ -352,8 +352,8 @@ VectorVersionInfo* VersionInfo::getVectorVersionInfo(idx_t vectorIdx) const { return vectorsInfo[vectorIdx].get(); } -void VersionInfo::append(const transaction::Transaction* transaction, - ChunkedNodeGroup* chunkedNodeGroup, const row_idx_t startRow, const row_idx_t numRows) { +void VersionInfo::append(const transaction::Transaction* transaction, const row_idx_t startRow, + const row_idx_t numRows) { if (numRows == 0) { return; } @@ -369,13 +369,9 @@ void VersionInfo::append(const transaction::Transaction* transaction, const auto numRowsInVector = endRowIdx - startRowIdx + 1; vectorVersionInfo.append(transaction->getID(), startRowIdx, numRowsInVector); } - if 
(transaction->shouldAppendToUndoBuffer()) { - transaction->pushInsertInfo(chunkedNodeGroup, startRow, numRows); - } } -bool VersionInfo::delete_(const transaction::Transaction* transaction, - ChunkedNodeGroup* chunkedNodeGroup, const row_idx_t rowIdx) { +bool VersionInfo::delete_(const transaction::Transaction* transaction, const row_idx_t rowIdx) { auto [vectorIdx, rowIdxInVector] = StorageUtils::getQuotientRemainder(rowIdx, DEFAULT_VECTOR_CAPACITY); auto& vectorVersionInfo = getOrCreateVersionInfo(vectorIdx); @@ -385,11 +381,7 @@ bool VersionInfo::delete_(const transaction::Transaction* transaction, // ALWAYS_INSERTED to avoid checking the version in the future. vectorVersionInfo.insertionStatus = VectorVersionInfo::InsertionStatus::ALWAYS_INSERTED; } - const auto deleted = vectorVersionInfo.delete_(transaction->getID(), rowIdxInVector); - if (deleted && transaction->shouldAppendToUndoBuffer()) { - transaction->pushDeleteInfo(chunkedNodeGroup, rowIdx, 1); - } - return deleted; + return vectorVersionInfo.delete_(transaction->getID(), rowIdxInVector); } void VersionInfo::getSelVectorToScan(const transaction_t startTS, const transaction_t transactionID, diff --git a/src/storage/undo_buffer.cpp b/src/storage/undo_buffer.cpp index 5cb1497a58b..d61ef3959f0 100644 --- a/src/storage/undo_buffer.cpp +++ b/src/storage/undo_buffer.cpp @@ -4,7 +4,8 @@ #include "catalog/catalog_entry/sequence_catalog_entry.h" #include "catalog/catalog_entry/table_catalog_entry.h" #include "catalog/catalog_set.h" -#include "storage/store/chunked_node_group.h" +#include "storage/store/node_table.h" +#include "storage/store/rel_table_data.h" #include "storage/store/update_info.h" using namespace kuzu::catalog; @@ -37,9 +38,12 @@ struct NodeBatchInsertRecord { }; struct VersionRecord { - ChunkedNodeGroup* chunkedNodeGroup; + NodeGroupCollection* nodeGroupCollection; row_idx_t startRow; row_idx_t numRows; + node_group_idx_t nodeGroupIdx; + pre_rollback_callback_t preRollbackCallback; + 
CSRNodeGroupScanSource source; }; struct VectorUpdateRecord { @@ -109,23 +113,47 @@ void UndoBuffer::createSequenceChange(SequenceCatalogEntry& sequenceEntry, *reinterpret_cast(buffer) = sequenceEntryRecord; } -void UndoBuffer::createInsertInfo(ChunkedNodeGroup* chunkedNodeGroup, row_idx_t startRow, - row_idx_t numRows) { - createVersionInfo(UndoRecordType::INSERT_INFO, chunkedNodeGroup, startRow, numRows); +static void noPreRollbackFunc(const transaction::Transaction*, common::row_idx_t, common::row_idx_t, + common::node_group_idx_t) {} + +void UndoBuffer::createInsertInfo(RelTableData* relTableData, node_group_idx_t nodeGroupIdx, + row_idx_t startRow, row_idx_t numRows, storage::CSRNodeGroupScanSource source) { + createVersionInfo(UndoRecordType::INSERT_INFO, relTableData->getNodeGroups(), noPreRollbackFunc, + startRow, numRows, nodeGroupIdx, source); +} + +void UndoBuffer::createInsertInfo(NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, + row_idx_t startRow, row_idx_t numRows) { + createVersionInfo( + UndoRecordType::INSERT_INFO, nodeTable->getNodeGroups(), + [nodeTable](const transaction::Transaction* transaction, common::row_idx_t startRow, + common::row_idx_t numRows, common::node_group_idx_t nodeGroupIdx) { + nodeTable->rollbackInsert(transaction, startRow, numRows, nodeGroupIdx); + }, + startRow, numRows, nodeGroupIdx); } -void UndoBuffer::createDeleteInfo(ChunkedNodeGroup* chunkedNodeGroup, row_idx_t startRow, - row_idx_t numRows) { - createVersionInfo(UndoRecordType::DELETE_INFO, chunkedNodeGroup, startRow, numRows); +void UndoBuffer::createDeleteInfo(NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows) { + createVersionInfo(UndoRecordType::DELETE_INFO, nodeTable->getNodeGroups(), noPreRollbackFunc, + startRow, numRows, nodeGroupIdx); +} + +void UndoBuffer::createDeleteInfo(RelTableData* relTableData, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, 
common::row_idx_t numRows, storage::CSRNodeGroupScanSource source) { + createVersionInfo(UndoRecordType::DELETE_INFO, relTableData->getNodeGroups(), noPreRollbackFunc, + startRow, numRows, nodeGroupIdx, source); } void UndoBuffer::createVersionInfo(const UndoRecordType recordType, - ChunkedNodeGroup* chunkedNodeGroup, row_idx_t startRow, row_idx_t numRows) { + NodeGroupCollection* nodeGroupCollection, pre_rollback_callback_t callback, row_idx_t startRow, + row_idx_t numRows, node_group_idx_t nodeGroupIdx, storage::CSRNodeGroupScanSource source) { auto buffer = createUndoRecord(sizeof(UndoRecordHeader) + sizeof(VersionRecord)); const UndoRecordHeader recordHeader{recordType, sizeof(VersionRecord)}; *reinterpret_cast(buffer) = recordHeader; buffer += sizeof(UndoRecordHeader); - const VersionRecord vectorVersionRecord{chunkedNodeGroup, startRow, numRows}; + const VersionRecord vectorVersionRecord{nodeGroupCollection, startRow, numRows, nodeGroupIdx, + callback, source}; *reinterpret_cast(buffer) = vectorVersionRecord; } @@ -162,10 +190,11 @@ void UndoBuffer::commit(transaction_t commitTS) const { }); } -void UndoBuffer::rollback() { +void UndoBuffer::rollback(const transaction::Transaction* transaction) { UndoBufferIterator iterator{*this}; - iterator.reverseIterate( - [&](UndoRecordType entryType, uint8_t const* entry) { rollbackRecord(entryType, entry); }); + iterator.reverseIterate([&](UndoRecordType entryType, uint8_t const* entry) { + rollbackRecord(transaction, entryType, entry); + }); } uint64_t UndoBuffer::getMemUsage() const { @@ -210,12 +239,12 @@ void UndoBuffer::commitVersionInfo(UndoRecordType recordType, const uint8_t* rec const auto& undoRecord = *reinterpret_cast(record); switch (recordType) { case UndoRecordType::INSERT_INFO: { - undoRecord.chunkedNodeGroup->commitInsert(undoRecord.startRow, undoRecord.numRows, - commitTS); + undoRecord.nodeGroupCollection->commitInsert(undoRecord.startRow, undoRecord.numRows, + undoRecord.nodeGroupIdx, commitTS, 
undoRecord.source); } break; case UndoRecordType::DELETE_INFO: { - undoRecord.chunkedNodeGroup->commitDelete(undoRecord.startRow, undoRecord.numRows, - commitTS); + undoRecord.nodeGroupCollection->commitDelete(undoRecord.startRow, undoRecord.numRows, + undoRecord.nodeGroupIdx, commitTS, undoRecord.source); } break; default: { KU_UNREACHABLE; @@ -228,7 +257,8 @@ void UndoBuffer::commitVectorUpdateInfo(const uint8_t* record, transaction_t com undoRecord.vectorUpdateInfo->version = commitTS; } -void UndoBuffer::rollbackRecord(const UndoRecordType recordType, const uint8_t* record) { +void UndoBuffer::rollbackRecord(const transaction::Transaction* transaction, + const UndoRecordType recordType, const uint8_t* record) { switch (recordType) { case UndoRecordType::CATALOG_ENTRY: { rollbackCatalogEntryRecord(record); @@ -238,7 +268,7 @@ void UndoBuffer::rollbackRecord(const UndoRecordType recordType, const uint8_t* } break; case UndoRecordType::INSERT_INFO: case UndoRecordType::DELETE_INFO: { - rollbackVersionInfo(recordType, record); + rollbackVersionInfo(transaction, recordType, record); } break; case UndoRecordType::UPDATE_INFO: { rollbackVectorUpdateInfo(record); @@ -285,14 +315,19 @@ void UndoBuffer::rollbackSequenceEntry(const uint8_t* entry) { sequenceEntry->rollbackVal(data.usageCount, data.currVal); } -void UndoBuffer::rollbackVersionInfo(UndoRecordType recordType, const uint8_t* record) { +void UndoBuffer::rollbackVersionInfo(const transaction::Transaction* transaction, + UndoRecordType recordType, const uint8_t* record) { auto& undoRecord = *reinterpret_cast(record); switch (recordType) { case UndoRecordType::INSERT_INFO: { - undoRecord.chunkedNodeGroup->rollbackInsert(undoRecord.startRow, undoRecord.numRows); + undoRecord.preRollbackCallback(transaction, undoRecord.startRow, undoRecord.numRows, + undoRecord.nodeGroupIdx); + undoRecord.nodeGroupCollection->rollbackInsert(undoRecord.startRow, undoRecord.numRows, + undoRecord.nodeGroupIdx, undoRecord.source); } 
break; case UndoRecordType::DELETE_INFO: { - undoRecord.chunkedNodeGroup->rollbackDelete(undoRecord.startRow, undoRecord.numRows); + undoRecord.nodeGroupCollection->rollbackDelete(undoRecord.startRow, undoRecord.numRows, + undoRecord.nodeGroupIdx, undoRecord.source); } break; default: { KU_UNREACHABLE; diff --git a/src/transaction/transaction.cpp b/src/transaction/transaction.cpp index 0ff6251d46d..94c58387a3b 100644 --- a/src/transaction/transaction.cpp +++ b/src/transaction/transaction.cpp @@ -63,7 +63,7 @@ void Transaction::commit(storage::WAL* wal) const { void Transaction::rollback(storage::WAL* wal) const { localStorage->rollback(); - undoBuffer->rollback(); + undoBuffer->rollback(this); if (isWriteTransaction() && shouldLogToWAL()) { KU_ASSERT(wal); wal->logRollback(); @@ -173,14 +173,28 @@ void Transaction::pushSequenceChange(SequenceCatalogEntry* sequenceEntry, int64_ } } -void Transaction::pushInsertInfo(storage::ChunkedNodeGroup* chunkedNodeGroup, - common::row_idx_t startRow, common::row_idx_t numRows) const { - undoBuffer->createInsertInfo(chunkedNodeGroup, startRow, numRows); +void Transaction::pushInsertInfo(storage::RelTableData* relTableData, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, + storage::CSRNodeGroupScanSource source) const { + undoBuffer->createInsertInfo(relTableData, nodeGroupIdx, startRow, numRows, source); } -void Transaction::pushDeleteInfo(storage::ChunkedNodeGroup* chunkedNodeGroup, - common::row_idx_t startRow, common::row_idx_t numRows) const { - undoBuffer->createDeleteInfo(chunkedNodeGroup, startRow, numRows); +void Transaction::pushInsertInfo(storage::NodeTable* nodeTable, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows) const { + undoBuffer->createInsertInfo(nodeTable, nodeGroupIdx, startRow, numRows); +} + +void Transaction::pushDeleteInfo(storage::NodeTable* nodeTable, + common::node_group_idx_t nodeGroupIdx, 
common::row_idx_t startRow, + common::row_idx_t numRows) const { + undoBuffer->createDeleteInfo(nodeTable, nodeGroupIdx, startRow, numRows); +} + +void Transaction::pushDeleteInfo(storage::RelTableData* relTableData, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, + storage::CSRNodeGroupScanSource source) const { + undoBuffer->createDeleteInfo(relTableData, nodeGroupIdx, startRow, numRows, source); } void Transaction::pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, diff --git a/test/copy/copy_test.cpp b/test/copy/copy_test.cpp index 3cdda5efdab..3773c5eec89 100644 --- a/test/copy/copy_test.cpp +++ b/test/copy/copy_test.cpp @@ -3,34 +3,58 @@ #include "graph_test/base_graph_test.h" #include "graph_test/graph_test.h" #include "main/database.h" + +#define private public #include "storage/buffer_manager/buffer_manager.h" +#include "transaction/transaction_manager.h" namespace kuzu { namespace testing { -// TODO(Royi) add tests that use this once enough issues are fixed so that the tests can pass class FlakyBufferManager : public storage::BufferManager { public: FlakyBufferManager(const std::string& databasePath, const std::string& spillToDiskPath, uint64_t bufferPoolSize, uint64_t maxDBSize, common::VirtualFileSystem* vfs, bool readOnly, - uint64_t& failureFrequency) + uint64_t& failureFrequency, bool canFailDuringExecute, bool canFailDuringCheckpoint) : storage::BufferManager(databasePath, spillToDiskPath, bufferPoolSize, maxDBSize, vfs, readOnly), - failureFrequency(failureFrequency) {} + failureFrequency(failureFrequency), canFailDuringCheckpoint(canFailDuringCheckpoint), + canFailDuringExecute(canFailDuringExecute) {} bool reserve(uint64_t sizeToReserve) override { + const bool inCheckpoint = + ctx && !ctx->getTransactionManagerUnsafe()->hasActiveWriteTransactionNoLock(); + const bool inCommit = + !inCheckpoint && ctx && ctx->getTx()->getCommitTS() != common::INVALID_TRANSACTION; + const bool inExecute = 
(!inCommit && !inCheckpoint); reserveCount = (reserveCount + 1) % failureFrequency; - if (reserveCount == 0) { + if ((canFailDuringCheckpoint || !inCheckpoint) && (canFailDuringExecute || !inExecute) && + reserveCount == 0) { failureFrequency *= 2; return false; } return storage::BufferManager::reserve(sizeToReserve); } + void setClientContext(main::ClientContext* newCtx) { ctx = newCtx; } + uint64_t& failureFrequency; + main::ClientContext* ctx{nullptr}; + bool canFailDuringCheckpoint; + bool canFailDuringExecute; uint64_t reserveCount = 0; }; +struct BMExceptionRecoveryTestConfig { + bool canFailDuringExecute; + bool canFailDuringCheckpoint; + std::function initFunc; + std::function(main::Connection*, int)> executeFunc; + std::function earlyExitOnFailureFunc; + std::function(main::Connection*)> checkFunc; + uint64_t checkResult; +}; + class CopyTest : public BaseGraphTest { public: void SetUp() override { @@ -40,29 +64,161 @@ class CopyTest : public BaseGraphTest { void resetDB(uint64_t bufferPoolSize) { systemConfig->bufferPoolSize = bufferPoolSize; - conn.reset(); database.reset(); + conn.reset(); createDBAndConn(); } - - void resetDBFlaky() { + void resetDBFlaky(bool canFailDuringExecute = true, bool canFailDuringCheckpoint = true) { database.reset(); conn.reset(); systemConfig->bufferPoolSize = main::SystemConfig{}.bufferPoolSize; auto constructBMFunc = [&](const main::Database& db) { - return std::unique_ptr(new FlakyBufferManager(databasePath, + auto bm = std::unique_ptr(new FlakyBufferManager(databasePath, getFileSystem(db)->joinPath(databasePath, "copy.tmp"), systemConfig->bufferPoolSize, systemConfig->maxDBSize, getFileSystem(db), systemConfig->readOnly, - failureFrequency)); + failureFrequency, canFailDuringExecute, canFailDuringCheckpoint)); + currentBM = bm.get(); + return bm; }; database = BaseGraphTest::constructDB(databasePath, *systemConfig, constructBMFunc); conn = std::make_unique(database.get()); + 
currentBM->setClientContext(conn->getClientContext()); } std::string getInputDir() override { KU_UNREACHABLE; } + void BMExceptionRecoveryTest(BMExceptionRecoveryTestConfig cfg); uint64_t failureFrequency; + FlakyBufferManager* currentBM; }; -TEST_F(CopyTest, DISABLED_OutOfMemoryRecovery) { +void CopyTest::BMExceptionRecoveryTest(BMExceptionRecoveryTestConfig cfg) { + if (inMemMode) { + GTEST_SKIP(); + } + static constexpr uint64_t dbSize = 64 * 1024 * 1024; + resetDB(dbSize); + cfg.initFunc(conn.get()); + + // this test only checks robustness during the transaction + // we don't want to trigger BM exceptions during checkpoint + // TODO(Royi) fix checkpointing so this test passes even if BM fails during checkpoint + resetDBFlaky(cfg.canFailDuringExecute, cfg.canFailDuringCheckpoint); + + for (int i = 0;; i++) { + ASSERT_LT(i, 20); + + const auto queryString = common::stringFormat( + "COPY account FROM \"{}/dataset/snap/twitter/csv/twitter-nodes.csv\"", + KUZU_ROOT_DIRECTORY); + + auto result = cfg.executeFunc(conn.get(), i); + if (!result->isSuccess()) { + if (cfg.earlyExitOnFailureFunc(result.get())) { + break; + } + ASSERT_EQ(result->getErrorMessage(), "Buffer manager exception: Unable to allocate " + "memory! 
The buffer pool is full and no " + "memory could be freed!"); + } else { + // the copy shouldn't succeed first try + ASSERT_GT(i, 0); + break; + } + } + + // Reopen the DB so no spurious errors occur during the query + resetDB(dbSize); + { + // Test that the table copied as expected after the query + auto result = cfg.checkFunc(conn.get()); + ASSERT_TRUE(result->isSuccess()) << result->getErrorMessage(); + ASSERT_TRUE(result->hasNext()); + ASSERT_EQ(result->getNext()->getValue(0)->getValue(), cfg.checkResult); + } +} + +TEST_F(CopyTest, NodeCopyBMExceptionRecoverySameConnection) { + BMExceptionRecoveryTestConfig cfg{.canFailDuringExecute = true, + .canFailDuringCheckpoint = false, + .initFunc = + [](main::Connection* conn) { + conn->query("CREATE NODE TABLE account(ID INT64, PRIMARY KEY(ID))"); + }, + .executeFunc = + [](main::Connection* conn, int) { + const auto queryString = common::stringFormat( + "COPY account FROM \"{}/dataset/snap/twitter/csv/twitter-nodes.csv\"", + KUZU_ROOT_DIRECTORY); + + return conn->query(queryString); + }, + .earlyExitOnFailureFunc = [](main::QueryResult*) { return false; }, + .checkFunc = + [](main::Connection* conn) { return conn->query("MATCH (a:account) RETURN COUNT(*)"); }, + .checkResult = 81306}; + BMExceptionRecoveryTest(cfg); +} + +TEST_F(CopyTest, RelCopyBMExceptionRecoverySameConnection) { + BMExceptionRecoveryTestConfig cfg{.canFailDuringExecute = true, + .canFailDuringCheckpoint = false, + .initFunc = + [](main::Connection* conn) { + conn->query("CREATE NODE TABLE account(ID INT64, PRIMARY KEY(ID))"); + conn->query("CREATE REL TABLE follows(FROM account TO account);"); + ASSERT_TRUE(conn->query(common::stringFormat( + "COPY account FROM \"{}/dataset/snap/twitter/csv/twitter-nodes.csv\"", + KUZU_ROOT_DIRECTORY))); + }, + .executeFunc = + [this](main::Connection* conn, int i) { + // there are many allocations in the partitioning phase + // we scale the failure frequency linearly so that we trigger at least one + // 
allocation failure in the batch insert phase + failureFrequency = 512 * (i + 15); + + return conn->query(common::stringFormat( + "COPY follows FROM '{}/dataset/snap/twitter/csv/twitter-edges.csv' (DELIM=' ')", + KUZU_ROOT_DIRECTORY)); + }, + .earlyExitOnFailureFunc = + [this](main::QueryResult*) { + // clear the BM so that the failure frequency isn't messed with by cached pages + while (0 != currentBM->evictPages()) + ; + return false; + }, + .checkFunc = + [](main::Connection* conn) { + return conn->query("MATCH (a:account)-[:follows]->(b:account) RETURN COUNT(*)"); + }, + .checkResult = 2420766}; + BMExceptionRecoveryTest(cfg); +} + +TEST_F(CopyTest, NodeInsertBMExceptionDuringCommitRecovery) { + static constexpr uint64_t numValues = 200000; + BMExceptionRecoveryTestConfig cfg{.canFailDuringExecute = false, + .canFailDuringCheckpoint = false, + .initFunc = + [this](main::Connection* conn) { + failureFrequency = 128; + conn->query("CREATE NODE TABLE account(ID INT64, PRIMARY KEY(ID))"); + }, + .executeFunc = + [](main::Connection* conn, int) { + const auto queryString = common::stringFormat( + "UNWIND RANGE(1,{}) AS i CREATE (a:account {ID:i})", numValues); + + return conn->query(queryString); + }, + .earlyExitOnFailureFunc = [](main::QueryResult*) { return false; }, + .checkFunc = + [](main::Connection* conn) { return conn->query("MATCH (a:account) RETURN COUNT(*)"); }, + .checkResult = numValues}; + BMExceptionRecoveryTest(cfg); +} + +TEST_F(CopyTest, OutOfMemoryRecovery) { if (inMemMode) { GTEST_SKIP(); } diff --git a/test/include/graph_test/base_graph_test.h b/test/include/graph_test/base_graph_test.h index 8dbd7bfd56c..f37bce101e0 100644 --- a/test/include/graph_test/base_graph_test.h +++ b/test/include/graph_test/base_graph_test.h @@ -65,7 +65,8 @@ class BaseGraphTest : public Test { // Static functions to access Database's non-public properties/interfaces. 
static std::unique_ptr constructDB(std::string_view databasePath, main::SystemConfig systemConfig, main::Database::construct_bm_func_t constructFunc) { - return main::Database::construct(databasePath, systemConfig, constructFunc); + return std::unique_ptr( + new main::Database(databasePath, systemConfig, constructFunc)); } static storage::BufferManager* getBufferManager(const main::Database& database) { diff --git a/test/storage/local_hash_index_test.cpp b/test/storage/local_hash_index_test.cpp index 54bc2d9953a..7ab323157b3 100644 --- a/test/storage/local_hash_index_test.cpp +++ b/test/storage/local_hash_index_test.cpp @@ -17,17 +17,17 @@ TEST(LocalHashIndexTests, LocalInserts) { auto hashIndex = std::make_unique(PhysicalTypeID::INT64, overflowFileHandle.get()); - for (int64_t i = 0u; i < 100; i++) { + for (int64_t i = 0u; i < 100000; i++) { ASSERT_TRUE(hashIndex->insert(i, i * 2, isVisible)); } - for (int64_t i = 0u; i < 100; i++) { + for (int64_t i = 0u; i < 100000; i++) { ASSERT_FALSE(hashIndex->insert(i, i, isVisible)); } - for (int64_t i = 0u; i < 100; i++) { + for (int64_t i = 0u; i < 100000; i++) { hashIndex->delete_(i); } - for (int64_t i = 0u; i < 100; i++) { + for (int64_t i = 0u; i < 100000; i++) { ASSERT_TRUE(hashIndex->insert(i, i, isVisible)); } } diff --git a/test/test_files/transaction/copy/copy_node.test b/test/test_files/transaction/copy/copy_node.test index b1184a01a20..3a7a7e0ad6c 100644 --- a/test/test_files/transaction/copy/copy_node.test +++ b/test/test_files/transaction/copy/copy_node.test @@ -1,12 +1,26 @@ -DATASET CSV empty -- +-CASE CopyNodeAfterPKErrorRollbackFlushedGroups +-STATEMENT create node table Comment (id int64, creationDate INT64, locationIP STRING, browserUsed STRING, content STRING, length INT32, PRIMARY KEY (id)); +---- ok +# COPY will trigger duplicate PK once the 2nd file is hit +-STATEMENT COPY Comment FROM ['${KUZU_ROOT_DIRECTORY}/dataset/ldbc-sf01/Comment.csv', '${KUZU_ROOT_DIRECTORY}/dataset/ldbc-sf01/Comment.csv'] 
(delim="|", header=true, parallel=false); +---- error(regex) +Copy exception: Found duplicated primary key value \w+, which violates the uniqueness constraint of the primary key column. +# The failed COPY should revert all of its insertions and the 2nd COPY should succeed +-STATEMENT COPY Comment FROM '${KUZU_ROOT_DIRECTORY}/dataset/ldbc-sf01/Comment.csv' (DELIM="|", header=true); +---- ok +-STATEMENT MATCH (c:Comment) WHERE c.ID = 962073046352 RETURN c.locationIP +---- 1 +36.95.74.186 + -CASE CopyNodeManualTransactionCheck -STATEMENT BEGIN TRANSACTION; ---- ok -STATEMENT CREATE NODE TABLE person (ID INT64, fName STRING, gender INT64, isStudent BOOLEAN, isWorker BOOLEAN, age INT64, eyeSight DOUBLE, birthdate DATE, registerTime TIMESTAMP, lastJobDuration INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], grades INT64[4], height float, u UUID, PRIMARY KEY (ID)); ---- ok --STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" +-STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" ---- error Connection exception: COPY FROM is only supported in auto transaction mode. @@ -16,7 +30,7 @@ Connection exception: COPY FROM is only supported in auto transaction mode. ---- ok -STATEMENT CREATE NODE TABLE person (ID INT64, fName STRING, gender INT64, isStudent BOOLEAN, isWorker BOOLEAN, age INT64, eyeSight DOUBLE, birthdate DATE, registerTime TIMESTAMP, lastJobDuration INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], grades INT64[4], height float, u UUID, PRIMARY KEY (ID)); ---- ok --STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" +-STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" ---- ok -STATEMENT COMMIT; ---- ok @@ -37,7 +51,7 @@ Connection exception: COPY FROM is only supported in auto transaction mode. 
---- ok -STATEMENT CREATE NODE TABLE person (ID INT64, fName STRING, gender INT64, isStudent BOOLEAN, isWorker BOOLEAN, age INT64, eyeSight DOUBLE, birthdate DATE, registerTime TIMESTAMP, lastJobDuration INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], grades INT64[4], height float, u UUID, PRIMARY KEY (ID)); ---- ok --STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" +-STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" ---- ok -STATEMENT COMMIT; ---- ok @@ -56,7 +70,7 @@ Connection exception: COPY FROM is only supported in auto transaction mode. -CASE CopyNodeRollbackDueToError -STATEMENT CREATE NODE TABLE person (ID INT64, name STRING, PRIMARY KEY (ID)); ---- ok --STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/copy-test/copy-error/person.csv" +-STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/copy-test/copy-error/person.csv" ---- error Copy exception: Found duplicated primary key value 10, which violates the uniqueness constraint of the primary key column. -STATEMENT MATCH (p:person) return count(*); @@ -66,7 +80,7 @@ Copy exception: Found duplicated primary key value 10, which violates the unique -CASE CopyNodeRollbackAndManualCheckpoint -STATEMENT CREATE NODE TABLE person (ID INT64, name STRING, PRIMARY KEY (ID)); ---- ok --STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/copy-test/copy-error/person.csv" +-STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/copy-test/copy-error/person.csv" ---- error Copy exception: Found duplicated primary key value 10, which violates the uniqueness constraint of the primary key column. 
-STATEMENT MATCH (p:person) return count(*); @@ -81,7 +95,7 @@ Copy exception: Found duplicated primary key value 10, which violates the unique -CASE CopyNodeRollbackAndManualCheckpointAndReloadDB -STATEMENT CREATE NODE TABLE person (ID INT64, name STRING, PRIMARY KEY (ID)); ---- ok --STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/copy-test/copy-error/person.csv" +-STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/copy-test/copy-error/person.csv" ---- error Copy exception: Found duplicated primary key value 10, which violates the uniqueness constraint of the primary key column. -STATEMENT MATCH (p:person) return count(*); @@ -100,9 +114,9 @@ Copy exception: Found duplicated primary key value 10, which violates the unique -CASE CopyNodeRollbackAndManualCheckpoint2 -STATEMENT CREATE NODE TABLE person (ID INT64, fName STRING, gender INT64, isStudent BOOLEAN, isWorker BOOLEAN, age INT64, eyeSight DOUBLE, birthdate DATE, registerTime TIMESTAMP, lastJobDuration INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], grades INT64[4], height float, u UUID, PRIMARY KEY (ID)); ---- ok --STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" +-STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" ---- ok --STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" +-STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" ---- error Copy exception: Found duplicated primary key value 0, which violates the uniqueness constraint of the primary key column. 
-STATEMENT MATCH (p:person) return count(*); @@ -181,7 +195,7 @@ Copy exception: Found duplicated primary key value 0, which violates the uniquen -CASE CopyNodeAndDeleteAndManualCheckpoint -STATEMENT CREATE NODE TABLE person (ID INT64, fName STRING, gender INT64, isStudent BOOLEAN, isWorker BOOLEAN, age INT64, eyeSight DOUBLE, birthdate DATE, registerTime TIMESTAMP, lastJobDuration INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], grades INT64[4], height float, u UUID, PRIMARY KEY (ID)); ---- ok --STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" +-STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" ---- ok -STATEMENT MATCH (p:person) return count(*); ---- 1 @@ -200,7 +214,7 @@ Copy exception: Found duplicated primary key value 0, which violates the uniquen -CASE CopyNodeAndDeleteAndManualCheckpointAndReloadDB -STATEMENT CREATE NODE TABLE person (ID INT64, fName STRING, gender INT64, isStudent BOOLEAN, isWorker BOOLEAN, age INT64, eyeSight DOUBLE, birthdate DATE, registerTime TIMESTAMP, lastJobDuration INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], grades INT64[4], height float, u UUID, PRIMARY KEY (ID)); ---- ok --STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" +-STATEMENT COPY person FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" ---- ok -STATEMENT MATCH (p:person) return count(*); ---- 1 diff --git a/test/test_files/transaction/create_node/create_node.test b/test/test_files/transaction/create_node/create_node.test index 4da2c74d794..3394309a63d 100644 --- a/test/test_files/transaction/create_node/create_node.test +++ b/test/test_files/transaction/create_node/create_node.test @@ -36,6 +36,24 @@ ---- 1 0|A|True|2019-01-01 +-CASE CreateRollbackThenRetry +-STATEMENT CREATE NODE TABLE test(id INT64, name STRING, isTrue BOOLEAN, birthday DATE, PRIMARY KEY(id)); +---- ok +-STATEMENT BEGIN TRANSACTION 
+---- ok +-STATEMENT CREATE (a:test {id:0, name:'A', isTrue:True, birthday:Date('2019-01-01')}) +---- ok +-STATEMENT MATCH (a:test) RETURN a.id, a.name, a.isTrue, a.birthday +---- 1 +0|A|True|2019-01-01 +-STATEMENT ROLLBACK +---- ok +-STATEMENT CREATE (a:test {id:0, name:'A', isTrue:True, birthday:Date('2019-01-01')}) +---- ok +-STATEMENT MATCH (a:test) RETURN a.id, a.name, a.isTrue, a.birthday +---- 1 +0|A|True|2019-01-01 + -CASE Create3 -STATEMENT CALL auto_checkpoint=false; ---- ok diff --git a/test/test_files/transaction/set_node/set_empty.test b/test/test_files/transaction/set_node/set_empty.test index 72dbeefd09c..41a9643ba3c 100644 --- a/test/test_files/transaction/set_node/set_empty.test +++ b/test/test_files/transaction/set_node/set_empty.test @@ -2,6 +2,24 @@ -- +-CASE ExceptionDuringHashIndexCommitRecovery +-STATEMENT CREATE NODE TABLE account(ID INT64, PRIMARY KEY(ID)) +---- ok +-STATEMENT BEGIN TRANSACTION +---- ok +-STATEMENT UNWIND RANGE(1,200000) AS i CREATE (a:account {ID:i}) +---- ok +-STATEMENT MATCH (a:account) WHERE a.ID = 199000 SET a.ID = 1 +---- ok +-STATEMENT COMMIT +---- error(regex) +Runtime exception: Found duplicated primary key value \d+, which violates the uniqueness constraint of the primary key column. 
+-STATEMENT UNWIND RANGE(1,200000) AS i CREATE (a:account {ID:i}) +---- ok +-STATEMENT MATCH (a:account) RETURN COUNT(*) +---- 1 +200000 + -CASE DoubleColumnInsertionsAndUpdatesLarge -STATEMENT CALL auto_checkpoint=false; ---- ok From 756fef4d5a65cdfcf13ac6e112ae8d10b18a8774 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Mon, 18 Nov 2024 13:09:53 -0500 Subject: [PATCH 02/28] Make VersionRecord all stack-allocated --- src/include/storage/store/node_table.h | 5 +++ src/include/storage/store/table.h | 3 ++ src/include/storage/undo_buffer.h | 12 +++---- src/storage/store/node_table.cpp | 11 ++++-- src/storage/undo_buffer.cpp | 48 +++++++++++--------------- 5 files changed, 43 insertions(+), 36 deletions(-) diff --git a/src/include/storage/store/node_table.h b/src/include/storage/store/node_table.h index 9112b88fd8c..fd84d4c1248 100644 --- a/src/include/storage/store/node_table.h +++ b/src/include/storage/store/node_table.h @@ -178,6 +178,10 @@ class NodeTable final : public Table { TableStats getStats(const transaction::Transaction* transaction) const; + const pre_rollback_insert_func_t& getPreRollbackInsertFunc() const { + return preRollbackInsertFunc; + } + private: void insertPK(const transaction::Transaction* transaction, const common::ValueVector& nodeIDVector, const common::ValueVector& pkVector) const; @@ -197,6 +201,7 @@ class NodeTable final : public Table { std::unique_ptr nodeGroups; common::column_id_t pkColumnID; std::unique_ptr pkIndex; + pre_rollback_insert_func_t preRollbackInsertFunc; }; } // namespace storage diff --git a/src/include/storage/store/table.h b/src/include/storage/store/table.h index 18971ac672d..ab0eccaedf4 100644 --- a/src/include/storage/store/table.h +++ b/src/include/storage/store/table.h @@ -13,6 +13,9 @@ class ExpressionEvaluator; namespace storage { class MemoryManager; +using pre_rollback_insert_func_t = std::function; + enum class TableScanSource : uint8_t { COMMITTED = 0, UNCOMMITTED = 1, NONE = UINT8_MAX }; struct 
TableScanState { diff --git a/src/include/storage/undo_buffer.h b/src/include/storage/undo_buffer.h index 33ee87722ba..9ec3ff01059 100644 --- a/src/include/storage/undo_buffer.h +++ b/src/include/storage/undo_buffer.h @@ -1,11 +1,11 @@ #pragma once -#include #include #include "common/constants.h" #include "common/types/types.h" #include "storage/enums/csr_node_group_scan_source.h" +#include "storage/store/table.h" namespace kuzu { namespace catalog { @@ -23,9 +23,6 @@ class ClientContext; } namespace storage { -using pre_rollback_callback_t = std::function; - // TODO(Guodong): This should be reworked to use MemoryManager for memory allocaiton. // For now, we use malloc to get around the limitation of 256KB from MM. class UndoMemoryBuffer { @@ -114,9 +111,10 @@ class UndoBuffer { uint8_t* createUndoRecord(uint64_t size); void createVersionInfo(UndoRecordType recordType, NodeGroupCollection* nodeGroupCollection, - pre_rollback_callback_t preRollbackCallback, common::row_idx_t startRow, - common::row_idx_t numRows, common::node_group_idx_t nodeGroupIdx = 0, - storage::CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE); + common::row_idx_t startRow, common::row_idx_t numRows, + common::node_group_idx_t nodeGroupIdx = 0, + storage::CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE, + const pre_rollback_insert_func_t* preRollbackCallback = nullptr); void commitRecord(UndoRecordType recordType, const uint8_t* record, common::transaction_t commitTS) const; diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index 36b41b216af..4fe7d767887 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -235,6 +235,12 @@ NodeTable::NodeTable(const StorageManager* storageManager, createAppendToUndoBufferFunc(this)); initializePKIndex(storageManager->getDatabasePath(), nodeTableEntry, storageManager->isReadOnly(), vfs, context); + + preRollbackInsertFunc = [this](const transaction::Transaction* 
transaction, + common::row_idx_t startRow, common::row_idx_t numRows_, + common::node_group_idx_t nodeGroupIdx_) { + return rollbackInsert(transaction, startRow, numRows_, nodeGroupIdx_); + }; } std::unique_ptr NodeTable::loadTable(Deserializer& deSer, const Catalog& catalog, @@ -591,9 +597,10 @@ void NodeTable::checkpoint(Serializer& ser, TableCatalogEntry* tableEntry) { } void NodeTable::rollbackInsert(const transaction::Transaction* transaction, - common::row_idx_t startRow, common::row_idx_t numRows_, common::node_group_idx_t nodeGroupIdx) { + common::row_idx_t startRow, common::row_idx_t numRows_, + common::node_group_idx_t nodeGroupIdx_) { row_idx_t startNodeOffset = startRow; - for (node_group_idx_t i = 0; i < nodeGroupIdx; ++i) { + for (node_group_idx_t i = 0; i < nodeGroupIdx_; ++i) { startNodeOffset += nodeGroups->getNodeGroupNoLock(i)->getNumRows(); } diff --git a/src/storage/undo_buffer.cpp b/src/storage/undo_buffer.cpp index d61ef3959f0..7498e753268 100644 --- a/src/storage/undo_buffer.cpp +++ b/src/storage/undo_buffer.cpp @@ -42,7 +42,7 @@ struct VersionRecord { row_idx_t startRow; row_idx_t numRows; node_group_idx_t nodeGroupIdx; - pre_rollback_callback_t preRollbackCallback; + const pre_rollback_insert_func_t* preRollbackCallback; CSRNodeGroupScanSource source; }; @@ -113,48 +113,40 @@ void UndoBuffer::createSequenceChange(SequenceCatalogEntry& sequenceEntry, *reinterpret_cast(buffer) = sequenceEntryRecord; } -static void noPreRollbackFunc(const transaction::Transaction*, common::row_idx_t, common::row_idx_t, - common::node_group_idx_t) {} +void UndoBuffer::createInsertInfo(NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, + row_idx_t startRow, row_idx_t numRows) { + createVersionInfo(UndoRecordType::INSERT_INFO, nodeTable->getNodeGroups(), startRow, numRows, + nodeGroupIdx, CSRNodeGroupScanSource::NONE, &nodeTable->getPreRollbackInsertFunc()); +} void UndoBuffer::createInsertInfo(RelTableData* relTableData, node_group_idx_t nodeGroupIdx, 
row_idx_t startRow, row_idx_t numRows, storage::CSRNodeGroupScanSource source) { - createVersionInfo(UndoRecordType::INSERT_INFO, relTableData->getNodeGroups(), noPreRollbackFunc, - startRow, numRows, nodeGroupIdx, source); -} - -void UndoBuffer::createInsertInfo(NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, - row_idx_t startRow, row_idx_t numRows) { - createVersionInfo( - UndoRecordType::INSERT_INFO, nodeTable->getNodeGroups(), - [nodeTable](const transaction::Transaction* transaction, common::row_idx_t startRow, - common::row_idx_t numRows, common::node_group_idx_t nodeGroupIdx) { - nodeTable->rollbackInsert(transaction, startRow, numRows, nodeGroupIdx); - }, - startRow, numRows, nodeGroupIdx); + createVersionInfo(UndoRecordType::INSERT_INFO, relTableData->getNodeGroups(), startRow, numRows, + nodeGroupIdx, source); } void UndoBuffer::createDeleteInfo(NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows) { - createVersionInfo(UndoRecordType::DELETE_INFO, nodeTable->getNodeGroups(), noPreRollbackFunc, - startRow, numRows, nodeGroupIdx); + createVersionInfo(UndoRecordType::DELETE_INFO, nodeTable->getNodeGroups(), startRow, numRows, + nodeGroupIdx); } void UndoBuffer::createDeleteInfo(RelTableData* relTableData, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, storage::CSRNodeGroupScanSource source) { - createVersionInfo(UndoRecordType::DELETE_INFO, relTableData->getNodeGroups(), noPreRollbackFunc, - startRow, numRows, nodeGroupIdx, source); + createVersionInfo(UndoRecordType::DELETE_INFO, relTableData->getNodeGroups(), startRow, numRows, + nodeGroupIdx, source); } void UndoBuffer::createVersionInfo(const UndoRecordType recordType, - NodeGroupCollection* nodeGroupCollection, pre_rollback_callback_t callback, row_idx_t startRow, - row_idx_t numRows, node_group_idx_t nodeGroupIdx, storage::CSRNodeGroupScanSource source) { + NodeGroupCollection* 
nodeGroupCollection, row_idx_t startRow, row_idx_t numRows, + node_group_idx_t nodeGroupIdx, storage::CSRNodeGroupScanSource source, + const pre_rollback_insert_func_t* callback) { auto buffer = createUndoRecord(sizeof(UndoRecordHeader) + sizeof(VersionRecord)); const UndoRecordHeader recordHeader{recordType, sizeof(VersionRecord)}; *reinterpret_cast(buffer) = recordHeader; buffer += sizeof(UndoRecordHeader); - const VersionRecord vectorVersionRecord{nodeGroupCollection, startRow, numRows, nodeGroupIdx, - callback, source}; - *reinterpret_cast(buffer) = vectorVersionRecord; + *reinterpret_cast(buffer) = + VersionRecord{nodeGroupCollection, startRow, numRows, nodeGroupIdx, callback, source}; } void UndoBuffer::createVectorUpdateInfo(UpdateInfo* updateInfo, const idx_t vectorIdx, @@ -320,8 +312,10 @@ void UndoBuffer::rollbackVersionInfo(const transaction::Transaction* transaction auto& undoRecord = *reinterpret_cast(record); switch (recordType) { case UndoRecordType::INSERT_INFO: { - undoRecord.preRollbackCallback(transaction, undoRecord.startRow, undoRecord.numRows, - undoRecord.nodeGroupIdx); + if (undoRecord.preRollbackCallback) { + (*undoRecord.preRollbackCallback)(transaction, undoRecord.startRow, undoRecord.numRows, + undoRecord.nodeGroupIdx); + } undoRecord.nodeGroupCollection->rollbackInsert(undoRecord.startRow, undoRecord.numRows, undoRecord.nodeGroupIdx, undoRecord.source); } break; From b9af266a925af48c468c459e62511c40c3ddea09 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Mon, 18 Nov 2024 13:32:21 -0500 Subject: [PATCH 03/28] Try fix tests --- src/storage/store/node_group_collection.cpp | 25 ++++++++++----------- test/copy/copy_test.cpp | 7 +++--- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/src/storage/store/node_group_collection.cpp b/src/storage/store/node_group_collection.cpp index 91c7225508b..d6acb10ebb2 100644 --- a/src/storage/store/node_group_collection.cpp +++ b/src/storage/store/node_group_collection.cpp @@ -54,9 +54,9 @@ void 
NodeGroupCollection::append(const Transaction* transaction, std::min(numRowsToAppend - numRowsAppended, lastNodeGroup->getNumRowsLeftToAppend()); lastNodeGroup->moveNextRowToAppend(numToAppendInNodeGroup); appendToUndoBufferFunc(transaction, lastNodeGroup, numToAppendInNodeGroup); + numTotalRows += numToAppendInNodeGroup; lastNodeGroup->append(transaction, vectors, numRowsAppended, numToAppendInNodeGroup); numRowsAppended += numToAppendInNodeGroup; - numTotalRows += numToAppendInNodeGroup; } stats.incrementCardinality(numRowsAppended); } @@ -95,10 +95,10 @@ void NodeGroupCollection::append(const Transaction* transaction, NodeGroup& node lastNodeGroup->getNumRowsLeftToAppend()); lastNodeGroup->moveNextRowToAppend(numToAppendInBatch); appendToUndoBufferFunc(transaction, lastNodeGroup, numToAppendInBatch); + numTotalRows += numToAppendInBatch; lastNodeGroup->append(transaction, *chunkedGroupToAppend, numRowsAppendedInChunkedGroup, numToAppendInBatch); numRowsAppendedInChunkedGroup += numToAppendInBatch; - numTotalRows += numToAppendInBatch; } numChunkedGroupsAppended++; } @@ -132,6 +132,7 @@ std::pair NodeGroupCollection::appendToLastNodeGroupAndFlush directFlushWhenAppend = numToAppend == numRowsLeftInLastNodeGroup && lastNodeGroup->getNumRows() == 0; appendToUndoBufferFunc(transaction, lastNodeGroup, chunkedGroup.getNumRows()); + numTotalRows += numToAppend; if (!directFlushWhenAppend) { // TODO(Guodong): Furthur optimize on this. Should directly figure out startRowIdx to // start appending into the node group and pass in as param. 
@@ -144,7 +145,6 @@ std::pair NodeGroupCollection::appendToLastNodeGroupAndFlush KU_ASSERT(lastNodeGroup->getNumChunkedGroups() == 0); lastNodeGroup->merge(transaction, std::move(flushedGroup)); } - numTotalRows += numToAppend; stats.incrementCardinality(numToAppend); return {startOffset, numToAppend}; } @@ -207,22 +207,21 @@ static idx_t getNumEmptyTrailingGroups(const GroupCollection& nodeGro void NodeGroupCollection::rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows_, common::node_group_idx_t nodeGroupIdx, CSRNodeGroupScanSource source) { const auto lock = nodeGroups.lock(); - auto numRowsToSubtract = numRows_; // skip the rollback if all newly created node groups have already been deleted if (!nodeGroups.isEmpty(lock) || nodeGroupIdx > 0) { KU_ASSERT(nodeGroupIdx < nodeGroups.getNumGroups(lock)); auto* nodeGroup = nodeGroups.getGroup(lock, nodeGroupIdx); + if (nodeGroup) { + KU_ASSERT(startRow <= nodeGroup->getNumRows()); + nodeGroup->rollbackInsert(startRow, numRows_, source); - KU_ASSERT(startRow <= nodeGroup->getNumRows()); - numRowsToSubtract = std::min(numRowsToSubtract, nodeGroup->getNumRows() - startRow); - nodeGroup->rollbackInsert(startRow, numRows_, source); - - // remove any empty trailing node groups after the rollback - const auto numGroupsToRemove = getNumEmptyTrailingGroups(nodeGroups, lock); - nodeGroups.removeTrailingGroups(lock, numGroupsToRemove); + // remove any empty trailing node groups after the rollback + const auto numGroupsToRemove = getNumEmptyTrailingGroups(nodeGroups, lock); + nodeGroups.removeTrailingGroups(lock, numGroupsToRemove); + } } - KU_ASSERT(numRowsToSubtract <= numTotalRows); - numTotalRows -= numRowsToSubtract; + KU_ASSERT(numRows_ <= numTotalRows); + numTotalRows -= numRows_; } void NodeGroupCollection::rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows_, diff --git a/test/copy/copy_test.cpp b/test/copy/copy_test.cpp index 3773c5eec89..7f13502fff7 100644 --- a/test/copy/copy_test.cpp 
+++ b/test/copy/copy_test.cpp @@ -119,8 +119,6 @@ void CopyTest::BMExceptionRecoveryTest(BMExceptionRecoveryTestConfig cfg) { "memory! The buffer pool is full and no " "memory could be freed!"); } else { - // the copy shouldn't succeed first try - ASSERT_GT(i, 0); break; } } @@ -183,8 +181,9 @@ TEST_F(CopyTest, RelCopyBMExceptionRecoverySameConnection) { .earlyExitOnFailureFunc = [this](main::QueryResult*) { // clear the BM so that the failure frequency isn't messed with by cached pages - while (0 != currentBM->evictPages()) - ; + for (auto& fh : currentBM->fileHandles) { + currentBM->removeFilePagesFromFrames(*fh); + } return false; }, .checkFunc = From 789c54eb1b643e5a6018ebcbc6b05902c5d3131c Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Mon, 18 Nov 2024 14:47:59 -0500 Subject: [PATCH 04/28] Combine interfaces for creating undo buffer info --- src/include/storage/store/node_table.h | 2 -- src/include/storage/store/rel_table_data.h | 2 -- src/include/storage/undo_buffer.h | 8 +++---- src/include/transaction/transaction.h | 21 ++++++++--------- src/storage/store/node_table.cpp | 13 ++++++----- src/storage/store/rel_table_data.cpp | 5 ++-- src/storage/undo_buffer.cpp | 27 +++++++--------------- src/transaction/transaction.cpp | 20 ++++------------ 8 files changed, 34 insertions(+), 64 deletions(-) diff --git a/src/include/storage/store/node_table.h b/src/include/storage/store/node_table.h index fd84d4c1248..a94ed1bb860 100644 --- a/src/include/storage/store/node_table.h +++ b/src/include/storage/store/node_table.h @@ -174,8 +174,6 @@ class NodeTable final : public Table { return nodeGroups->getNodeGroupNoLock(nodeGroupIdx); } - NodeGroupCollection* getNodeGroups() { return nodeGroups.get(); } - TableStats getStats(const transaction::Transaction* transaction) const; const pre_rollback_insert_func_t& getPreRollbackInsertFunc() const { diff --git a/src/include/storage/store/rel_table_data.h b/src/include/storage/store/rel_table_data.h index bd453a04312..b9121d3e533 
100644 --- a/src/include/storage/store/rel_table_data.h +++ b/src/include/storage/store/rel_table_data.h @@ -57,8 +57,6 @@ class RelTableData { return nodeGroups->getOrCreateNodeGroup(nodeGroupIdx, NodeGroupDataFormat::CSR); } - NodeGroupCollection* getNodeGroups() { return nodeGroups.get(); } - common::RelMultiplicity getMultiplicity() const { return multiplicity; } TableStats getStats() const { return nodeGroups->getStats(); } diff --git a/src/include/storage/undo_buffer.h b/src/include/storage/undo_buffer.h index 9ec3ff01059..0be92a42b3e 100644 --- a/src/include/storage/undo_buffer.h +++ b/src/include/storage/undo_buffer.h @@ -89,14 +89,12 @@ class UndoBuffer { void createCatalogEntry(catalog::CatalogSet& catalogSet, catalog::CatalogEntry& catalogEntry); void createSequenceChange(catalog::SequenceCatalogEntry& sequenceEntry, const catalog::SequenceRollbackData& data); - void createInsertInfo(RelTableData* relTableData, common::node_group_idx_t nodeGroupIdx, + void createInsertInfo(NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - storage::CSRNodeGroupScanSource source); + storage::CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE); void createInsertInfo(NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows); - void createDeleteInfo(NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows); - void createDeleteInfo(RelTableData* relTableData, common::node_group_idx_t nodeGroupIdx, + void createDeleteInfo(NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, storage::CSRNodeGroupScanSource source); void createVectorUpdateInfo(UpdateInfo* updateInfo, common::idx_t vectorIdx, diff --git a/src/include/transaction/transaction.h b/src/include/transaction/transaction.h index 
e55c3cb9023..cf6ba1b475f 100644 --- a/src/include/transaction/transaction.h +++ b/src/include/transaction/transaction.h @@ -21,8 +21,7 @@ class WAL; class VersionInfo; class UpdateInfo; struct VectorUpdateInfo; -class RelTableData; -class NodeTable; +class NodeGroupCollection; class ChunkedNodeGroup; } // namespace storage namespace transaction { @@ -117,16 +116,14 @@ class Transaction { bool skipLoggingToWAL = false) const; void pushSequenceChange(catalog::SequenceCatalogEntry* sequenceEntry, int64_t kCount, const catalog::SequenceRollbackData& data) const; - void pushInsertInfo(storage::RelTableData* relTableData, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows, - storage::CSRNodeGroupScanSource source) const; - void pushInsertInfo(storage::NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows) const; - void pushDeleteInfo(storage::RelTableData* relTableData, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows, - storage::CSRNodeGroupScanSource source) const; - void pushDeleteInfo(storage::NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows) const; + void pushInsertInfo(storage::NodeGroupCollection* nodeGroups, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, + storage::CSRNodeGroupScanSource source = storage::CSRNodeGroupScanSource::NONE) const; + void pushDeleteInfo(storage::NodeGroupCollection* nodeGroups, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, + storage::CSRNodeGroupScanSource source = storage::CSRNodeGroupScanSource::NONE) const; void pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, common::idx_t vectorIdx, storage::VectorUpdateInfo& vectorUpdateInfo) const; diff --git a/src/storage/store/node_table.cpp 
b/src/storage/store/node_table.cpp index 4fe7d767887..e3e6438c969 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -204,11 +204,11 @@ bool NodeTableScanState::scanNext(Transaction* transaction) { return true; } -static decltype(auto) createAppendToUndoBufferFunc(NodeTable* nodeTable) { - return [nodeTable](const transaction::Transaction* transaction, NodeGroup* nodeGroup, +static decltype(auto) createAppendToUndoBufferFunc(NodeGroupCollection* nodeGroups) { + return [nodeGroups](const transaction::Transaction* transaction, NodeGroup* nodeGroup, common::row_idx_t numRows) { if (transaction->shouldAppendToUndoBuffer()) { - transaction->pushInsertInfo(nodeTable, nodeGroup->getNodeGroupIdx(), + transaction->pushInsertInfo(nodeGroups, nodeGroup->getNodeGroupIdx(), nodeGroup->getNumRows(), numRows); } }; @@ -232,7 +232,7 @@ NodeTable::NodeTable(const StorageManager* storageManager, nodeGroups = std::make_unique(*memoryManager, getNodeTableColumnTypes(*this), enableCompression, storageManager->getDataFH(), deSer, - createAppendToUndoBufferFunc(this)); + createAppendToUndoBufferFunc(nodeGroups.get())); initializePKIndex(storageManager->getDatabasePath(), nodeTableEntry, storageManager->isReadOnly(), vfs, context); @@ -454,7 +454,7 @@ bool NodeTable::delete_(Transaction* transaction, TableDeleteState& deleteState) nodeOffset - StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); isDeleted = nodeGroups->getNodeGroup(nodeGroupIdx)->delete_(transaction, rowIdxInGroup); if (transaction->shouldAppendToUndoBuffer()) { - transaction->pushDeleteInfo(this, nodeGroupIdx, rowIdxInGroup, 1); + transaction->pushDeleteInfo(nodeGroups.get(), nodeGroupIdx, rowIdxInGroup, 1); } } if (isDeleted) { @@ -540,7 +540,8 @@ void NodeTable::commit(Transaction* transaction, LocalTable* localTable) { nodeGroups->getNodeGroup(nodeGroupIdx)->delete_(transaction, rowIdxInGroup); KU_ASSERT(isDeleted); if (transaction->shouldAppendToUndoBuffer()) { - 
transaction->pushDeleteInfo(this, nodeGroupIdx, rowIdxInGroup, 1); + transaction->pushDeleteInfo(nodeGroups.get(), nodeGroupIdx, rowIdxInGroup, + 1); } } } diff --git a/src/storage/store/rel_table_data.cpp b/src/storage/store/rel_table_data.cpp index e5a53eed0df..9f28185c97f 100644 --- a/src/storage/store/rel_table_data.cpp +++ b/src/storage/store/rel_table_data.cpp @@ -98,7 +98,7 @@ bool RelTableData::delete_(Transaction* transaction, ValueVector& boundNodeIDVec auto& csrNodeGroup = getNodeGroup(nodeGroupIdx)->cast(); bool isDeleted = csrNodeGroup.delete_(transaction, source, rowIdx); if (isDeleted && transaction->shouldAppendToUndoBuffer()) { - transaction->pushDeleteInfo(this, nodeGroupIdx, rowIdx, 1, source); + transaction->pushDeleteInfo(nodeGroups.get(), nodeGroupIdx, rowIdx, 1, source); } return isDeleted; } @@ -195,7 +195,8 @@ void RelTableData::pushInsertInfo(transaction::Transaction* transaction, const auto startRow = (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) ? nodeGroup.getNumPersistentRows() : nodeGroup.getNumRows(); - transaction->pushInsertInfo(this, nodeGroup.getNodeGroupIdx(), startRow, numRows_, source); + transaction->pushInsertInfo(nodeGroups.get(), nodeGroup.getNodeGroupIdx(), startRow, numRows_, + source); } void RelTableData::checkpoint(const std::vector& columnIDs) { diff --git a/src/storage/undo_buffer.cpp b/src/storage/undo_buffer.cpp index 7498e753268..4fe60193917 100644 --- a/src/storage/undo_buffer.cpp +++ b/src/storage/undo_buffer.cpp @@ -113,28 +113,17 @@ void UndoBuffer::createSequenceChange(SequenceCatalogEntry& sequenceEntry, *reinterpret_cast(buffer) = sequenceEntryRecord; } -void UndoBuffer::createInsertInfo(NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, - row_idx_t startRow, row_idx_t numRows) { - createVersionInfo(UndoRecordType::INSERT_INFO, nodeTable->getNodeGroups(), startRow, numRows, - nodeGroupIdx, CSRNodeGroupScanSource::NONE, &nodeTable->getPreRollbackInsertFunc()); -} - -void 
UndoBuffer::createInsertInfo(RelTableData* relTableData, node_group_idx_t nodeGroupIdx, +void UndoBuffer::createInsertInfo(NodeGroupCollection* nodeGroups, node_group_idx_t nodeGroupIdx, row_idx_t startRow, row_idx_t numRows, storage::CSRNodeGroupScanSource source) { - createVersionInfo(UndoRecordType::INSERT_INFO, relTableData->getNodeGroups(), startRow, numRows, - nodeGroupIdx, source); -} - -void UndoBuffer::createDeleteInfo(NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows) { - createVersionInfo(UndoRecordType::DELETE_INFO, nodeTable->getNodeGroups(), startRow, numRows, - nodeGroupIdx); + createVersionInfo(UndoRecordType::INSERT_INFO, nodeGroups, startRow, numRows, nodeGroupIdx, + source); } -void UndoBuffer::createDeleteInfo(RelTableData* relTableData, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows, storage::CSRNodeGroupScanSource source) { - createVersionInfo(UndoRecordType::DELETE_INFO, relTableData->getNodeGroups(), startRow, numRows, - nodeGroupIdx, source); +void UndoBuffer::createDeleteInfo(NodeGroupCollection* nodeGroups, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, + storage::CSRNodeGroupScanSource source) { + createVersionInfo(UndoRecordType::DELETE_INFO, nodeGroups, startRow, numRows, nodeGroupIdx, + source); } void UndoBuffer::createVersionInfo(const UndoRecordType recordType, diff --git a/src/transaction/transaction.cpp b/src/transaction/transaction.cpp index 94c58387a3b..f73492d6156 100644 --- a/src/transaction/transaction.cpp +++ b/src/transaction/transaction.cpp @@ -173,28 +173,16 @@ void Transaction::pushSequenceChange(SequenceCatalogEntry* sequenceEntry, int64_ } } -void Transaction::pushInsertInfo(storage::RelTableData* relTableData, +void Transaction::pushInsertInfo(storage::NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, 
common::row_idx_t numRows, storage::CSRNodeGroupScanSource source) const { - undoBuffer->createInsertInfo(relTableData, nodeGroupIdx, startRow, numRows, source); + undoBuffer->createInsertInfo(nodeGroups, nodeGroupIdx, startRow, numRows, source); } -void Transaction::pushInsertInfo(storage::NodeTable* nodeTable, - common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows) const { - undoBuffer->createInsertInfo(nodeTable, nodeGroupIdx, startRow, numRows); -} - -void Transaction::pushDeleteInfo(storage::NodeTable* nodeTable, - common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows) const { - undoBuffer->createDeleteInfo(nodeTable, nodeGroupIdx, startRow, numRows); -} - -void Transaction::pushDeleteInfo(storage::RelTableData* relTableData, +void Transaction::pushDeleteInfo(storage::NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, storage::CSRNodeGroupScanSource source) const { - undoBuffer->createDeleteInfo(relTableData, nodeGroupIdx, startRow, numRows, source); + undoBuffer->createDeleteInfo(nodeGroups, nodeGroupIdx, startRow, numRows, source); } void Transaction::pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, From 6a617e0f2fefb2df9285ae3a4fb4ebffe7e53c0b Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Mon, 18 Nov 2024 15:24:18 -0500 Subject: [PATCH 05/28] Properly maintain numTotalRows for rel table data --- .../storage/store/node_group_collection.h | 19 ++++---- src/include/storage/store/rel_table.h | 4 +- src/include/storage/store/rel_table_data.h | 6 ++- .../operator/persistent/rel_batch_insert.cpp | 13 ++---- src/storage/store/node_group_collection.cpp | 45 ++++++++++++++----- src/storage/store/node_table.cpp | 13 +----- src/storage/store/rel_table.cpp | 18 ++++---- src/storage/store/rel_table_data.cpp | 2 +- 8 files changed, 63 insertions(+), 57 deletions(-) diff --git 
a/src/include/storage/store/node_group_collection.h b/src/include/storage/store/node_group_collection.h index 73e5fe43495..920f14f2cd2 100644 --- a/src/include/storage/store/node_group_collection.h +++ b/src/include/storage/store/node_group_collection.h @@ -12,14 +12,11 @@ class Transaction; namespace storage { class MemoryManager; -using append_to_undo_buffer_func_t = - std::function; - class NodeGroupCollection { public: NodeGroupCollection(MemoryManager& memoryManager, const std::vector& types, - bool enableCompression, FileHandle* dataFH = nullptr, common::Deserializer* deSer = nullptr, - append_to_undo_buffer_func_t appendToUndoBufferFunc = defaultAppendToUndoBuffer); + bool enableCompression, FileHandle* dataFH = nullptr, + common::Deserializer* deSer = nullptr); void append(const transaction::Transaction* transaction, const std::vector& vectors); @@ -52,7 +49,8 @@ class NodeGroupCollection { KU_ASSERT(nodeGroups.getGroupNoLock(groupIdx)->getNodeGroupIdx() == groupIdx); return nodeGroups.getGroup(lock, groupIdx); } - NodeGroup* getOrCreateNodeGroup(common::node_group_idx_t groupIdx, NodeGroupDataFormat format); + NodeGroup* getOrCreateNodeGroup(transaction::Transaction* transaction, + common::node_group_idx_t groupIdx, NodeGroupDataFormat format); void setNodeGroup(const common::node_group_idx_t nodeGroupIdx, std::unique_ptr group) { @@ -92,9 +90,13 @@ class NodeGroupCollection { void serialize(common::Serializer& ser); void deserialize(common::Deserializer& deSer, MemoryManager& memoryManager); + void pushInsertInfo(const transaction::Transaction* transaction, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, storage::CSRNodeGroupScanSource source); + private: - static void defaultAppendToUndoBuffer(const transaction::Transaction*, NodeGroup*, - common::row_idx_t); + void appendToUndoBuffer(const transaction::Transaction* transaction, NodeGroup* nodeGroup, + common::row_idx_t numRows, CSRNodeGroupScanSource 
source = CSRNodeGroupScanSource::NONE); bool enableCompression; // Num rows in the collection regardless of deletions. @@ -103,7 +105,6 @@ class NodeGroupCollection { GroupCollection nodeGroups; FileHandle* dataFH; TableStats stats; - append_to_undo_buffer_func_t appendToUndoBufferFunc; }; } // namespace storage diff --git a/src/include/storage/store/rel_table.h b/src/include/storage/store/rel_table.h index 7e8ee17bf9c..dc15221a5e3 100644 --- a/src/include/storage/store/rel_table.h +++ b/src/include/storage/store/rel_table.h @@ -166,8 +166,8 @@ class RelTable final : public Table { bwdRelTableData->getColumn(columnID); } - NodeGroup* getOrCreateNodeGroup(common::node_group_idx_t nodeGroupIdx, - common::RelDataDirection direction) const; + NodeGroup* getOrCreateNodeGroup(transaction::Transaction* transaction, + common::node_group_idx_t nodeGroupIdx, common::RelDataDirection direction) const; void commit(transaction::Transaction* transaction, LocalTable* localTable) override; void checkpoint(common::Serializer& ser, catalog::TableCatalogEntry* tableEntry) override; diff --git a/src/include/storage/store/rel_table_data.h b/src/include/storage/store/rel_table_data.h index b9121d3e533..98b488ce2bf 100644 --- a/src/include/storage/store/rel_table_data.h +++ b/src/include/storage/store/rel_table_data.h @@ -53,8 +53,10 @@ class RelTableData { NodeGroup* getNodeGroup(common::node_group_idx_t nodeGroupIdx) const { return nodeGroups->getNodeGroup(nodeGroupIdx, true /*mayOutOfBound*/); } - NodeGroup* getOrCreateNodeGroup(common::node_group_idx_t nodeGroupIdx) const { - return nodeGroups->getOrCreateNodeGroup(nodeGroupIdx, NodeGroupDataFormat::CSR); + NodeGroup* getOrCreateNodeGroup(transaction::Transaction* transaction, + common::node_group_idx_t nodeGroupIdx) const { + return nodeGroups->getOrCreateNodeGroup(transaction, nodeGroupIdx, + NodeGroupDataFormat::CSR); } common::RelMultiplicity getMultiplicity() const { return multiplicity; } diff --git 
a/src/processor/operator/persistent/rel_batch_insert.cpp b/src/processor/operator/persistent/rel_batch_insert.cpp index 765ba03a84b..4b8b5d181c6 100644 --- a/src/processor/operator/persistent/rel_batch_insert.cpp +++ b/src/processor/operator/persistent/rel_batch_insert.cpp @@ -68,15 +68,10 @@ void RelBatchInsert::executeInternal(ExecutionContext* context) { } // TODO(Guodong): We need to handle the concurrency between COPY and other insertions // into the same node group. - auto& nodeGroup = - relTable->getOrCreateNodeGroup(relLocalState->nodeGroupIdx, relInfo->direction) - ->cast(); - if (nodeGroup.isEmpty()) { - // push an insert of size 0 so that we can rollback the creation of this node group if - // needed - relTable->pushInsertInfo(context->clientContext->getTx(), relInfo->direction, nodeGroup, - 0, CSRNodeGroupScanSource::COMMITTED_PERSISTENT); - } + auto& nodeGroup = relTable + ->getOrCreateNodeGroup(context->clientContext->getTx(), + relLocalState->nodeGroupIdx, relInfo->direction) + ->cast(); appendNodeGroup(context->clientContext->getTx(), nodeGroup, *relInfo, *relLocalState, *sharedState, *partitionerSharedState); updateProgress(context); diff --git a/src/storage/store/node_group_collection.cpp b/src/storage/store/node_group_collection.cpp index d6acb10ebb2..6efded6296e 100644 --- a/src/storage/store/node_group_collection.cpp +++ b/src/storage/store/node_group_collection.cpp @@ -15,9 +15,9 @@ namespace storage { NodeGroupCollection::NodeGroupCollection(MemoryManager& memoryManager, const std::vector& types, const bool enableCompression, FileHandle* dataFH, - Deserializer* deSer, append_to_undo_buffer_func_t appendToUndoBufferFunc) + Deserializer* deSer) : enableCompression{enableCompression}, numTotalRows{0}, types{LogicalType::copy(types)}, - dataFH{dataFH}, appendToUndoBufferFunc(std::move(appendToUndoBufferFunc)) { + dataFH{dataFH} { if (deSer) { deserialize(*deSer, memoryManager); } @@ -53,7 +53,7 @@ void NodeGroupCollection::append(const 
Transaction* transaction, const auto numToAppendInNodeGroup = std::min(numRowsToAppend - numRowsAppended, lastNodeGroup->getNumRowsLeftToAppend()); lastNodeGroup->moveNextRowToAppend(numToAppendInNodeGroup); - appendToUndoBufferFunc(transaction, lastNodeGroup, numToAppendInNodeGroup); + appendToUndoBuffer(transaction, lastNodeGroup, numToAppendInNodeGroup); numTotalRows += numToAppendInNodeGroup; lastNodeGroup->append(transaction, vectors, numRowsAppended, numToAppendInNodeGroup); numRowsAppended += numToAppendInNodeGroup; @@ -94,7 +94,7 @@ void NodeGroupCollection::append(const Transaction* transaction, NodeGroup& node std::min(numRowsToAppendInChunkedGroup - numRowsAppendedInChunkedGroup, lastNodeGroup->getNumRowsLeftToAppend()); lastNodeGroup->moveNextRowToAppend(numToAppendInBatch); - appendToUndoBufferFunc(transaction, lastNodeGroup, numToAppendInBatch); + appendToUndoBuffer(transaction, lastNodeGroup, numToAppendInBatch); numTotalRows += numToAppendInBatch; lastNodeGroup->append(transaction, *chunkedGroupToAppend, numRowsAppendedInChunkedGroup, numToAppendInBatch); @@ -131,7 +131,7 @@ std::pair NodeGroupCollection::appendToLastNodeGroupAndFlush // If the node group is empty now and the chunked group is full, we can directly flush it. directFlushWhenAppend = numToAppend == numRowsLeftInLastNodeGroup && lastNodeGroup->getNumRows() == 0; - appendToUndoBufferFunc(transaction, lastNodeGroup, chunkedGroup.getNumRows()); + appendToUndoBuffer(transaction, lastNodeGroup, numToAppend); numTotalRows += numToAppend; if (!directFlushWhenAppend) { // TODO(Guodong): Furthur optimize on this. 
Should directly figure out startRowIdx to @@ -154,8 +154,8 @@ row_idx_t NodeGroupCollection::getNumTotalRows() { return numTotalRows; } -NodeGroup* NodeGroupCollection::getOrCreateNodeGroup(node_group_idx_t groupIdx, - NodeGroupDataFormat format) { +NodeGroup* NodeGroupCollection::getOrCreateNodeGroup(transaction::Transaction* transaction, + node_group_idx_t groupIdx, NodeGroupDataFormat format) { const auto lock = nodeGroups.lock(); while (groupIdx >= nodeGroups.getNumGroups(lock)) { const auto currentGroupIdx = nodeGroups.getNumGroups(lock); @@ -164,6 +164,10 @@ NodeGroup* NodeGroupCollection::getOrCreateNodeGroup(node_group_idx_t groupIdx, enableCompression, LogicalType::copy(types)) : std::make_unique(currentGroupIdx, enableCompression, LogicalType::copy(types))); + // push an insert of size 0 so that we can rollback the creation of this node group if + // needed + appendToUndoBuffer(transaction, nodeGroups.getLastGroup(lock), 0, + CSRNodeGroupScanSource::COMMITTED_PERSISTENT); } KU_ASSERT(groupIdx < nodeGroups.getNumGroups(lock)); return nodeGroups.getGroup(lock, groupIdx); @@ -220,8 +224,11 @@ void NodeGroupCollection::rollbackInsert(common::row_idx_t startRow, common::row nodeGroups.removeTrailingGroups(lock, numGroupsToRemove); } } - KU_ASSERT(numRows_ <= numTotalRows); - numTotalRows -= numRows_; + + if (source != CSRNodeGroupScanSource::COMMITTED_PERSISTENT) { + KU_ASSERT(numRows_ <= numTotalRows); + numTotalRows -= numRows_; + } } void NodeGroupCollection::rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows_, @@ -246,6 +253,23 @@ void NodeGroupCollection::commitDelete(row_idx_t startRow, row_idx_t numRows_, nodeGroups.getGroup(lock, nodeGroupIdx)->commitDelete(startRow, numRows_, commitTS, source); } +void NodeGroupCollection::appendToUndoBuffer(const transaction::Transaction* transaction, + NodeGroup* nodeGroup, common::row_idx_t numRows, CSRNodeGroupScanSource source) { + pushInsertInfo(transaction, nodeGroup->getNodeGroupIdx(), 
nodeGroup->getNumRows(), numRows, + source); +}; + +void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transaction, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, + storage::CSRNodeGroupScanSource source) { + if (transaction->shouldAppendToUndoBuffer()) { + transaction->pushInsertInfo(this, nodeGroupIdx, startRow, numRows, source); + if (source == CSRNodeGroupScanSource::COMMITTED_IN_MEMORY) { + numTotalRows += numRows; + } + } +} + void NodeGroupCollection::serialize(Serializer& ser) { ser.writeDebuggingInfo("node_groups"); nodeGroups.serializeGroups(ser); @@ -262,8 +286,5 @@ void NodeGroupCollection::deserialize(Deserializer& deSer, MemoryManager& memory stats.deserialize(deSer); } -void NodeGroupCollection::defaultAppendToUndoBuffer(const transaction::Transaction*, NodeGroup*, - common::row_idx_t) {} - } // namespace storage } // namespace kuzu diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index e3e6438c969..05175456cc5 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -204,16 +204,6 @@ bool NodeTableScanState::scanNext(Transaction* transaction) { return true; } -static decltype(auto) createAppendToUndoBufferFunc(NodeGroupCollection* nodeGroups) { - return [nodeGroups](const transaction::Transaction* transaction, NodeGroup* nodeGroup, - common::row_idx_t numRows) { - if (transaction->shouldAppendToUndoBuffer()) { - transaction->pushInsertInfo(nodeGroups, nodeGroup->getNodeGroupIdx(), - nodeGroup->getNumRows(), numRows); - } - }; -} - NodeTable::NodeTable(const StorageManager* storageManager, const NodeTableCatalogEntry* nodeTableEntry, MemoryManager* memoryManager, VirtualFileSystem* vfs, main::ClientContext* context, Deserializer* deSer) @@ -231,8 +221,7 @@ NodeTable::NodeTable(const StorageManager* storageManager, } nodeGroups = std::make_unique(*memoryManager, - getNodeTableColumnTypes(*this), enableCompression, 
storageManager->getDataFH(), deSer, - createAppendToUndoBufferFunc(nodeGroups.get())); + getNodeTableColumnTypes(*this), enableCompression, storageManager->getDataFH(), deSer); initializePKIndex(storageManager->getDatabasePath(), nodeTableEntry, storageManager->isReadOnly(), vfs, context); diff --git a/src/storage/store/rel_table.cpp b/src/storage/store/rel_table.cpp index c84dfa9af5d..4b711f48cb1 100644 --- a/src/storage/store/rel_table.cpp +++ b/src/storage/store/rel_table.cpp @@ -378,20 +378,18 @@ void RelTable::addColumn(Transaction* transaction, TableAddColumnState& addColum hasChanges = true; } -NodeGroup* RelTable::getOrCreateNodeGroup(node_group_idx_t nodeGroupIdx, - RelDataDirection direction) const { +NodeGroup* RelTable::getOrCreateNodeGroup(transaction::Transaction* transaction, + node_group_idx_t nodeGroupIdx, RelDataDirection direction) const { return direction == RelDataDirection::FWD ? - fwdRelTableData->getOrCreateNodeGroup(nodeGroupIdx) : - bwdRelTableData->getOrCreateNodeGroup(nodeGroupIdx); + fwdRelTableData->getOrCreateNodeGroup(transaction, nodeGroupIdx) : + bwdRelTableData->getOrCreateNodeGroup(transaction, nodeGroupIdx); } void RelTable::pushInsertInfo(Transaction* transaction, RelDataDirection direction, const CSRNodeGroup& nodeGroup, row_idx_t numRows_, CSRNodeGroupScanSource source) { - if (transaction->shouldAppendToUndoBuffer()) { - auto& relTableData = - (direction == common::RelDataDirection::FWD) ? fwdRelTableData : bwdRelTableData; - relTableData->pushInsertInfo(transaction, nodeGroup, numRows_, source); - } + auto& relTableData = + (direction == common::RelDataDirection::FWD) ? 
fwdRelTableData : bwdRelTableData; + relTableData->pushInsertInfo(transaction, nodeGroup, numRows_, source); } void RelTable::commit(Transaction* transaction, LocalTable* localTable) { @@ -421,7 +419,7 @@ void RelTable::commit(Transaction* transaction, LocalTable* localTable) { auto [nodeGroupIdx, boundOffsetInGroup] = StorageUtils::getQuotientRemainder( boundNodeOffset, StorageConstants::NODE_GROUP_SIZE); auto& nodeGroup = - relTableData->getOrCreateNodeGroup(nodeGroupIdx)->cast(); + relTableData->getOrCreateNodeGroup(transaction, nodeGroupIdx)->cast(); pushInsertInfo(transaction, direction, nodeGroup, rowIndices.size(), CSRNodeGroupScanSource::COMMITTED_IN_MEMORY); prepareCommitForNodeGroup(transaction, localNodeGroup, nodeGroup, boundOffsetInGroup, diff --git a/src/storage/store/rel_table_data.cpp b/src/storage/store/rel_table_data.cpp index 9f28185c97f..8c36efb9c62 100644 --- a/src/storage/store/rel_table_data.cpp +++ b/src/storage/store/rel_table_data.cpp @@ -195,7 +195,7 @@ void RelTableData::pushInsertInfo(transaction::Transaction* transaction, const auto startRow = (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) ? 
nodeGroup.getNumPersistentRows() : nodeGroup.getNumRows(); - transaction->pushInsertInfo(nodeGroups.get(), nodeGroup.getNodeGroupIdx(), startRow, numRows_, + nodeGroups->pushInsertInfo(transaction, nodeGroup.getNodeGroupIdx(), startRow, numRows_, source); } From ac3e9e1f6c50d4020570329ed500364d08cb9d2b Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Mon, 18 Nov 2024 15:41:45 -0500 Subject: [PATCH 06/28] Code cleanup --- src/include/storage/store/rel_table.h | 2 ++ src/storage/store/rel_table.cpp | 21 +++++++++------------ 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/include/storage/store/rel_table.h b/src/include/storage/store/rel_table.h index dc15221a5e3..5b8ed574653 100644 --- a/src/include/storage/store/rel_table.h +++ b/src/include/storage/store/rel_table.h @@ -208,6 +208,8 @@ class RelTable final : public Table { void checkRelMultiplicityConstraint(transaction::Transaction* transaction, const TableInsertState& state) const; + RelTableData* getRelTableData(common::RelDataDirection direction) const; + private: common::table_id_t fromNodeTableID; common::table_id_t toNodeTableID; diff --git a/src/storage/store/rel_table.cpp b/src/storage/store/rel_table.cpp index 4b711f48cb1..0c6733117bd 100644 --- a/src/storage/store/rel_table.cpp +++ b/src/storage/store/rel_table.cpp @@ -153,9 +153,7 @@ void RelTable::initScanState(Transaction* transaction, TableScanState& scanState const auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(boundNodeID.offset); if (relScanState.nodeGroupIdx != nodeGroupIdx) { // We need to re-initialize the node group scan state. - nodeGroup = relScanState.direction == RelDataDirection::FWD ? 
- fwdRelTableData->getNodeGroup(nodeGroupIdx) : - bwdRelTableData->getNodeGroup(nodeGroupIdx); + nodeGroup = getRelTableData(relScanState.direction)->getNodeGroup(nodeGroupIdx); } else { nodeGroup = relScanState.nodeGroup; } @@ -319,9 +317,8 @@ bool RelTable::checkIfNodeHasRels(Transaction* transaction, RelDataDirection dir hasRels = hasRels || localTable->cast().checkIfNodeHasRels(srcNodeIDVector, direction); } - hasRels = hasRels || ((direction == RelDataDirection::FWD) ? - fwdRelTableData->checkIfNodeHasRels(transaction, srcNodeIDVector) : - bwdRelTableData->checkIfNodeHasRels(transaction, srcNodeIDVector)); + hasRels = + hasRels || (getRelTableData(direction)->checkIfNodeHasRels(transaction, srcNodeIDVector)); return hasRels; } @@ -378,18 +375,18 @@ void RelTable::addColumn(Transaction* transaction, TableAddColumnState& addColum hasChanges = true; } +RelTableData* RelTable::getRelTableData(common::RelDataDirection direction) const { + return direction == RelDataDirection::FWD ? fwdRelTableData.get() : bwdRelTableData.get(); +} + NodeGroup* RelTable::getOrCreateNodeGroup(transaction::Transaction* transaction, node_group_idx_t nodeGroupIdx, RelDataDirection direction) const { - return direction == RelDataDirection::FWD ? - fwdRelTableData->getOrCreateNodeGroup(transaction, nodeGroupIdx) : - bwdRelTableData->getOrCreateNodeGroup(transaction, nodeGroupIdx); + return getRelTableData(direction)->getOrCreateNodeGroup(transaction, nodeGroupIdx); } void RelTable::pushInsertInfo(Transaction* transaction, RelDataDirection direction, const CSRNodeGroup& nodeGroup, row_idx_t numRows_, CSRNodeGroupScanSource source) { - auto& relTableData = - (direction == common::RelDataDirection::FWD) ? 
fwdRelTableData : bwdRelTableData; - relTableData->pushInsertInfo(transaction, nodeGroup, numRows_, source); + getRelTableData(direction)->pushInsertInfo(transaction, nodeGroup, numRows_, source); } void RelTable::commit(Transaction* transaction, LocalTable* localTable) { From 99e2ef17aae65549ac2c627b03e8ceb97fa0d496 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Mon, 18 Nov 2024 16:02:38 -0500 Subject: [PATCH 07/28] Avoid appending to undo buffer for local tables --- .../storage/store/node_group_collection.h | 2 +- src/storage/store/node_group_collection.cpp | 22 +++++++++---------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/include/storage/store/node_group_collection.h b/src/include/storage/store/node_group_collection.h index 920f14f2cd2..186028b0b1f 100644 --- a/src/include/storage/store/node_group_collection.h +++ b/src/include/storage/store/node_group_collection.h @@ -95,7 +95,7 @@ class NodeGroupCollection { common::row_idx_t numRows, storage::CSRNodeGroupScanSource source); private: - void appendToUndoBuffer(const transaction::Transaction* transaction, NodeGroup* nodeGroup, + void pushInsertInfo(const transaction::Transaction* transaction, NodeGroup* nodeGroup, common::row_idx_t numRows, CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE); bool enableCompression; diff --git a/src/storage/store/node_group_collection.cpp b/src/storage/store/node_group_collection.cpp index 6efded6296e..aeb5f213a9f 100644 --- a/src/storage/store/node_group_collection.cpp +++ b/src/storage/store/node_group_collection.cpp @@ -53,8 +53,7 @@ void NodeGroupCollection::append(const Transaction* transaction, const auto numToAppendInNodeGroup = std::min(numRowsToAppend - numRowsAppended, lastNodeGroup->getNumRowsLeftToAppend()); lastNodeGroup->moveNextRowToAppend(numToAppendInNodeGroup); - appendToUndoBuffer(transaction, lastNodeGroup, numToAppendInNodeGroup); - numTotalRows += numToAppendInNodeGroup; + pushInsertInfo(transaction, lastNodeGroup, 
numToAppendInNodeGroup); lastNodeGroup->append(transaction, vectors, numRowsAppended, numToAppendInNodeGroup); numRowsAppended += numToAppendInNodeGroup; } @@ -94,8 +93,7 @@ void NodeGroupCollection::append(const Transaction* transaction, NodeGroup& node std::min(numRowsToAppendInChunkedGroup - numRowsAppendedInChunkedGroup, lastNodeGroup->getNumRowsLeftToAppend()); lastNodeGroup->moveNextRowToAppend(numToAppendInBatch); - appendToUndoBuffer(transaction, lastNodeGroup, numToAppendInBatch); - numTotalRows += numToAppendInBatch; + pushInsertInfo(transaction, lastNodeGroup, numToAppendInBatch); lastNodeGroup->append(transaction, *chunkedGroupToAppend, numRowsAppendedInChunkedGroup, numToAppendInBatch); numRowsAppendedInChunkedGroup += numToAppendInBatch; @@ -131,8 +129,7 @@ std::pair NodeGroupCollection::appendToLastNodeGroupAndFlush // If the node group is empty now and the chunked group is full, we can directly flush it. directFlushWhenAppend = numToAppend == numRowsLeftInLastNodeGroup && lastNodeGroup->getNumRows() == 0; - appendToUndoBuffer(transaction, lastNodeGroup, numToAppend); - numTotalRows += numToAppend; + pushInsertInfo(transaction, lastNodeGroup, numToAppend); if (!directFlushWhenAppend) { // TODO(Guodong): Furthur optimize on this. Should directly figure out startRowIdx to // start appending into the node group and pass in as param. 
@@ -166,7 +163,7 @@ NodeGroup* NodeGroupCollection::getOrCreateNodeGroup(transaction::Transaction* t enableCompression, LogicalType::copy(types))); // push an insert of size 0 so that we can rollback the creation of this node group if // needed - appendToUndoBuffer(transaction, nodeGroups.getLastGroup(lock), 0, + pushInsertInfo(transaction, nodeGroups.getLastGroup(lock), 0, CSRNodeGroupScanSource::COMMITTED_PERSISTENT); } KU_ASSERT(groupIdx < nodeGroups.getNumGroups(lock)); @@ -253,7 +250,7 @@ void NodeGroupCollection::commitDelete(row_idx_t startRow, row_idx_t numRows_, nodeGroups.getGroup(lock, nodeGroupIdx)->commitDelete(startRow, numRows_, commitTS, source); } -void NodeGroupCollection::appendToUndoBuffer(const transaction::Transaction* transaction, +void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transaction, NodeGroup* nodeGroup, common::row_idx_t numRows, CSRNodeGroupScanSource source) { pushInsertInfo(transaction, nodeGroup->getNodeGroupIdx(), nodeGroup->getNumRows(), numRows, source); @@ -262,11 +259,12 @@ void NodeGroupCollection::appendToUndoBuffer(const transaction::Transaction* tra void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, storage::CSRNodeGroupScanSource source) { - if (transaction->shouldAppendToUndoBuffer()) { + // we only append to the undo buffer if the node group collection is persistent + if (dataFH && transaction->shouldAppendToUndoBuffer()) { transaction->pushInsertInfo(this, nodeGroupIdx, startRow, numRows, source); - if (source == CSRNodeGroupScanSource::COMMITTED_IN_MEMORY) { - numTotalRows += numRows; - } + } + if (source != CSRNodeGroupScanSource::COMMITTED_PERSISTENT) { + numTotalRows += numRows; } } From 1308842801f2a1f21fdf5233bf5d821dfa5aa810 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Mon, 18 Nov 2024 18:18:46 -0500 Subject: [PATCH 08/28] Bug fixes --- 
.../storage/store/node_group_collection.h | 1 + src/include/storage/store/node_table.h | 6 +++--- src/include/storage/store/table.h | 3 --- src/include/storage/undo_buffer.h | 10 ++++------ src/include/transaction/transaction.h | 9 ++++++++- src/storage/store/csr_node_group.cpp | 2 +- src/storage/store/node_table.cpp | 12 ++++++------ src/storage/undo_buffer.cpp | 19 ++++++++++--------- src/transaction/transaction.cpp | 6 ++++-- 9 files changed, 37 insertions(+), 31 deletions(-) diff --git a/src/include/storage/store/node_group_collection.h b/src/include/storage/store/node_group_collection.h index 186028b0b1f..ac0e24cbbed 100644 --- a/src/include/storage/store/node_group_collection.h +++ b/src/include/storage/store/node_group_collection.h @@ -4,6 +4,7 @@ #include "storage/stats/table_stats.h" #include "storage/store/group_collection.h" #include "storage/store/node_group.h" +#include "transaction/transaction.h" namespace kuzu { namespace transaction { diff --git a/src/include/storage/store/node_table.h b/src/include/storage/store/node_table.h index a94ed1bb860..00397bed590 100644 --- a/src/include/storage/store/node_table.h +++ b/src/include/storage/store/node_table.h @@ -176,8 +176,8 @@ class NodeTable final : public Table { TableStats getStats(const transaction::Transaction* transaction) const; - const pre_rollback_insert_func_t& getPreRollbackInsertFunc() const { - return preRollbackInsertFunc; + const transaction::rollback_insert_func_t& getRollbackInsertFunc() const { + return rollbackInsertFunc; } private: @@ -199,7 +199,7 @@ class NodeTable final : public Table { std::unique_ptr nodeGroups; common::column_id_t pkColumnID; std::unique_ptr pkIndex; - pre_rollback_insert_func_t preRollbackInsertFunc; + transaction::rollback_insert_func_t rollbackInsertFunc; }; } // namespace storage diff --git a/src/include/storage/store/table.h b/src/include/storage/store/table.h index ab0eccaedf4..18971ac672d 100644 --- a/src/include/storage/store/table.h +++ 
b/src/include/storage/store/table.h @@ -13,9 +13,6 @@ class ExpressionEvaluator; namespace storage { class MemoryManager; -using pre_rollback_insert_func_t = std::function; - enum class TableScanSource : uint8_t { COMMITTED = 0, UNCOMMITTED = 1, NONE = UINT8_MAX }; struct TableScanState { diff --git a/src/include/storage/undo_buffer.h b/src/include/storage/undo_buffer.h index 0be92a42b3e..95dccdce206 100644 --- a/src/include/storage/undo_buffer.h +++ b/src/include/storage/undo_buffer.h @@ -5,7 +5,7 @@ #include "common/constants.h" #include "common/types/types.h" #include "storage/enums/csr_node_group_scan_source.h" -#include "storage/store/table.h" +#include "transaction/transaction.h" namespace kuzu { namespace catalog { @@ -69,7 +69,6 @@ class VersionInfo; struct VectorUpdateInfo; class RelTableData; class NodeTable; -class NodeGroupCollection; class WAL; // This class is not thread safe, as it is supposed to be accessed by a single thread. class UndoBuffer { @@ -91,9 +90,8 @@ class UndoBuffer { const catalog::SequenceRollbackData& data); void createInsertInfo(NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - storage::CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE); - void createInsertInfo(NodeTable* nodeTable, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows); + storage::CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE, + const transaction::rollback_insert_func_t* rollbackInsertFunc = nullptr); void createDeleteInfo(NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, storage::CSRNodeGroupScanSource source); @@ -112,7 +110,7 @@ class UndoBuffer { common::row_idx_t startRow, common::row_idx_t numRows, common::node_group_idx_t nodeGroupIdx = 0, storage::CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE, - const 
pre_rollback_insert_func_t* preRollbackCallback = nullptr); + const transaction::rollback_insert_func_t* rollbackInsertFunc = nullptr); void commitRecord(UndoRecordType recordType, const uint8_t* record, common::transaction_t commitTS) const; diff --git a/src/include/transaction/transaction.h b/src/include/transaction/transaction.h index cf6ba1b475f..11e3f0617f4 100644 --- a/src/include/transaction/transaction.h +++ b/src/include/transaction/transaction.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "common/enums/statement_type.h" #include "common/types/types.h" #include "storage/enums/csr_node_group_scan_source.h" @@ -28,6 +30,10 @@ namespace transaction { class TransactionManager; class Transaction; +using rollback_insert_func_t = + std::function; + enum class TransactionType : uint8_t { READ_ONLY, WRITE, CHECKPOINT, DUMMY, RECOVERY }; class Transaction { @@ -119,7 +125,8 @@ class Transaction { void pushInsertInfo(storage::NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - storage::CSRNodeGroupScanSource source = storage::CSRNodeGroupScanSource::NONE) const; + storage::CSRNodeGroupScanSource source = storage::CSRNodeGroupScanSource::NONE, + const transaction::rollback_insert_func_t* rollbackInsertCallback = nullptr) const; void pushDeleteInfo(storage::NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, diff --git a/src/storage/store/csr_node_group.cpp b/src/storage/store/csr_node_group.cpp index eb3e7709826..ac3420de4af 100644 --- a/src/storage/store/csr_node_group.cpp +++ b/src/storage/store/csr_node_group.cpp @@ -963,7 +963,7 @@ std::pair CSRNodeGroup::actionOnChunkedGroups(const common::Un if (persistentChunkGroup) { std::invoke(operation, *persistentChunkGroup, startRow, numRows_, commitTS); } - return {UINT32_MAX, UINT32_MAX}; + return {INVALID_CHUNKED_GROUP_IDX, INVALID_START_ROW_IDX}; } else { 
KU_ASSERT(source == CSRNodeGroupScanSource::COMMITTED_IN_MEMORY); return NodeGroup::actionOnChunkedGroups(lock, startRow, numRows_, commitTS, source, diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index 05175456cc5..6b279706397 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -220,16 +220,16 @@ NodeTable::NodeTable(const StorageManager* storageManager, dataFH, memoryManager, shadowFile, enableCompression); } + rollbackInsertFunc = [this](const transaction::Transaction* transaction, + common::row_idx_t startRow, common::row_idx_t numRows_, + common::node_group_idx_t nodeGroupIdx_, CSRNodeGroupScanSource) { + return rollbackInsert(transaction, startRow, numRows_, nodeGroupIdx_); + }; + nodeGroups = std::make_unique(*memoryManager, getNodeTableColumnTypes(*this), enableCompression, storageManager->getDataFH(), deSer); initializePKIndex(storageManager->getDatabasePath(), nodeTableEntry, storageManager->isReadOnly(), vfs, context); - - preRollbackInsertFunc = [this](const transaction::Transaction* transaction, - common::row_idx_t startRow, common::row_idx_t numRows_, - common::node_group_idx_t nodeGroupIdx_) { - return rollbackInsert(transaction, startRow, numRows_, nodeGroupIdx_); - }; } std::unique_ptr NodeTable::loadTable(Deserializer& deSer, const Catalog& catalog, diff --git a/src/storage/undo_buffer.cpp b/src/storage/undo_buffer.cpp index 4fe60193917..e097d2588c2 100644 --- a/src/storage/undo_buffer.cpp +++ b/src/storage/undo_buffer.cpp @@ -42,7 +42,7 @@ struct VersionRecord { row_idx_t startRow; row_idx_t numRows; node_group_idx_t nodeGroupIdx; - const pre_rollback_insert_func_t* preRollbackCallback; + const transaction::rollback_insert_func_t* rollbackInsertCallback; CSRNodeGroupScanSource source; }; @@ -114,9 +114,10 @@ void UndoBuffer::createSequenceChange(SequenceCatalogEntry& sequenceEntry, } void UndoBuffer::createInsertInfo(NodeGroupCollection* nodeGroups, node_group_idx_t 
nodeGroupIdx, - row_idx_t startRow, row_idx_t numRows, storage::CSRNodeGroupScanSource source) { + row_idx_t startRow, row_idx_t numRows, storage::CSRNodeGroupScanSource source, + const transaction::rollback_insert_func_t* rollbackInsertFunc) { createVersionInfo(UndoRecordType::INSERT_INFO, nodeGroups, startRow, numRows, nodeGroupIdx, - source); + source, rollbackInsertFunc); } void UndoBuffer::createDeleteInfo(NodeGroupCollection* nodeGroups, @@ -129,13 +130,13 @@ void UndoBuffer::createDeleteInfo(NodeGroupCollection* nodeGroups, void UndoBuffer::createVersionInfo(const UndoRecordType recordType, NodeGroupCollection* nodeGroupCollection, row_idx_t startRow, row_idx_t numRows, node_group_idx_t nodeGroupIdx, storage::CSRNodeGroupScanSource source, - const pre_rollback_insert_func_t* callback) { + const transaction::rollback_insert_func_t* rollbackInsertFunc) { auto buffer = createUndoRecord(sizeof(UndoRecordHeader) + sizeof(VersionRecord)); const UndoRecordHeader recordHeader{recordType, sizeof(VersionRecord)}; *reinterpret_cast(buffer) = recordHeader; buffer += sizeof(UndoRecordHeader); - *reinterpret_cast(buffer) = - VersionRecord{nodeGroupCollection, startRow, numRows, nodeGroupIdx, callback, source}; + *reinterpret_cast(buffer) = VersionRecord{nodeGroupCollection, startRow, + numRows, nodeGroupIdx, rollbackInsertFunc, source}; } void UndoBuffer::createVectorUpdateInfo(UpdateInfo* updateInfo, const idx_t vectorIdx, @@ -301,9 +302,9 @@ void UndoBuffer::rollbackVersionInfo(const transaction::Transaction* transaction auto& undoRecord = *reinterpret_cast(record); switch (recordType) { case UndoRecordType::INSERT_INFO: { - if (undoRecord.preRollbackCallback) { - (*undoRecord.preRollbackCallback)(transaction, undoRecord.startRow, undoRecord.numRows, - undoRecord.nodeGroupIdx); + if (undoRecord.rollbackInsertCallback) { + (*undoRecord.rollbackInsertCallback)(transaction, undoRecord.startRow, + undoRecord.numRows, undoRecord.nodeGroupIdx, undoRecord.source); } 
undoRecord.nodeGroupCollection->rollbackInsert(undoRecord.startRow, undoRecord.numRows, undoRecord.nodeGroupIdx, undoRecord.source); diff --git a/src/transaction/transaction.cpp b/src/transaction/transaction.cpp index f73492d6156..b8fd6cc0743 100644 --- a/src/transaction/transaction.cpp +++ b/src/transaction/transaction.cpp @@ -175,8 +175,10 @@ void Transaction::pushSequenceChange(SequenceCatalogEntry* sequenceEntry, int64_ void Transaction::pushInsertInfo(storage::NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - storage::CSRNodeGroupScanSource source) const { - undoBuffer->createInsertInfo(nodeGroups, nodeGroupIdx, startRow, numRows, source); + storage::CSRNodeGroupScanSource source, + const transaction::rollback_insert_func_t* rollbackInsertCallback) const { + undoBuffer->createInsertInfo(nodeGroups, nodeGroupIdx, startRow, numRows, source, + rollbackInsertCallback); } void Transaction::pushDeleteInfo(storage::NodeGroupCollection* nodeGroups, From 8c72b078f24c4fa561bb7f27c8f68f3c4fd239b1 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Tue, 19 Nov 2024 09:20:41 -0500 Subject: [PATCH 09/28] Actually pass rollback insert callback --- src/include/storage/store/node_group_collection.h | 5 +++-- src/storage/store/node_group_collection.cpp | 7 ++++--- src/storage/store/node_table.cpp | 5 +++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/include/storage/store/node_group_collection.h b/src/include/storage/store/node_group_collection.h index ac0e24cbbed..7373642aa6f 100644 --- a/src/include/storage/store/node_group_collection.h +++ b/src/include/storage/store/node_group_collection.h @@ -16,8 +16,8 @@ class MemoryManager; class NodeGroupCollection { public: NodeGroupCollection(MemoryManager& memoryManager, const std::vector& types, - bool enableCompression, FileHandle* dataFH = nullptr, - common::Deserializer* deSer = nullptr); + bool enableCompression, FileHandle* dataFH = 
nullptr, common::Deserializer* deSer = nullptr, + const transaction::rollback_insert_func_t* rollbackInsertFunc = nullptr); void append(const transaction::Transaction* transaction, const std::vector& vectors); @@ -106,6 +106,7 @@ class NodeGroupCollection { GroupCollection nodeGroups; FileHandle* dataFH; TableStats stats; + const transaction::rollback_insert_func_t* rollbackInsertFunc; }; } // namespace storage diff --git a/src/storage/store/node_group_collection.cpp b/src/storage/store/node_group_collection.cpp index aeb5f213a9f..9a6ff10cf3e 100644 --- a/src/storage/store/node_group_collection.cpp +++ b/src/storage/store/node_group_collection.cpp @@ -15,9 +15,9 @@ namespace storage { NodeGroupCollection::NodeGroupCollection(MemoryManager& memoryManager, const std::vector& types, const bool enableCompression, FileHandle* dataFH, - Deserializer* deSer) + Deserializer* deSer, const transaction::rollback_insert_func_t* rollbackInsertFunc) : enableCompression{enableCompression}, numTotalRows{0}, types{LogicalType::copy(types)}, - dataFH{dataFH} { + dataFH{dataFH}, rollbackInsertFunc(rollbackInsertFunc) { if (deSer) { deserialize(*deSer, memoryManager); } @@ -261,7 +261,8 @@ void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transac storage::CSRNodeGroupScanSource source) { // we only append to the undo buffer if the node group collection is persistent if (dataFH && transaction->shouldAppendToUndoBuffer()) { - transaction->pushInsertInfo(this, nodeGroupIdx, startRow, numRows, source); + transaction->pushInsertInfo(this, nodeGroupIdx, startRow, numRows, source, + rollbackInsertFunc); } if (source != CSRNodeGroupScanSource::COMMITTED_PERSISTENT) { numTotalRows += numRows; diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index 6b279706397..ec86b403a74 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -226,8 +226,9 @@ NodeTable::NodeTable(const StorageManager* storageManager, return 
rollbackInsert(transaction, startRow, numRows_, nodeGroupIdx_); }; - nodeGroups = std::make_unique(*memoryManager, - getNodeTableColumnTypes(*this), enableCompression, storageManager->getDataFH(), deSer); + nodeGroups = + std::make_unique(*memoryManager, getNodeTableColumnTypes(*this), + enableCompression, storageManager->getDataFH(), deSer, &rollbackInsertFunc); initializePKIndex(storageManager->getDatabasePath(), nodeTableEntry, storageManager->isReadOnly(), vfs, context); } From ffd0cd5b30a03173b88fb27f4a33cb1addbcad4c Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Tue, 19 Nov 2024 13:55:01 -0500 Subject: [PATCH 10/28] Add second layer of iterators to undo buffer --- scripts/headers.txt | 1 - .../store/chunked_group_undo_iterator.h | 47 ++++++++ src/include/storage/store/csr_node_group.h | 16 ++- src/include/storage/store/node_group.h | 32 +++--- .../storage/store/node_group_collection.h | 27 ++--- src/include/storage/store/node_table.h | 17 ++- src/include/storage/store/rel_table_data.h | 7 +- src/include/storage/undo_buffer.h | 22 ++-- src/include/transaction/transaction.h | 19 ++-- .../operator/persistent/rel_batch_insert.cpp | 1 + src/storage/store/csr_node_group.cpp | 35 +++--- src/storage/store/node_group.cpp | 106 ++++++++---------- src/storage/store/node_group_collection.cpp | 68 +++-------- src/storage/store/node_table.cpp | 27 +++-- src/storage/store/rel_table_data.cpp | 31 ++++- src/storage/undo_buffer.cpp | 61 +++++----- src/transaction/transaction.cpp | 18 ++- 17 files changed, 286 insertions(+), 249 deletions(-) create mode 100644 src/include/storage/store/chunked_group_undo_iterator.h diff --git a/scripts/headers.txt b/scripts/headers.txt index 45600901d52..36873186cc4 100644 --- a/scripts/headers.txt +++ b/scripts/headers.txt @@ -71,6 +71,5 @@ src/include/processor/result/flat_tuple.h src/include/processor/warning_context.h src/include/processor/operator/persistent/reader/copy_from_error.h src/include/storage/storage_version_info.h 
-src/include/storage/enums/csr_node_group_scan_source.h src/include/transaction/transaction.h src/include/transaction/transaction_context.h diff --git a/src/include/storage/store/chunked_group_undo_iterator.h b/src/include/storage/store/chunked_group_undo_iterator.h new file mode 100644 index 00000000000..de15baabddf --- /dev/null +++ b/src/include/storage/store/chunked_group_undo_iterator.h @@ -0,0 +1,47 @@ +#pragma once + +#include + +#include "common/types/types.h" + +namespace kuzu { + +namespace transaction { +class Transaction; +} + +namespace storage { +class ChunkedNodeGroup; +class NodeGroupCollection; +class ChunkedGroupUndoIterator; + +using chunked_group_undo_op_t = void ( + ChunkedNodeGroup::*)(common::row_idx_t, common::row_idx_t, common::transaction_t); + +using chunked_group_iterator_construct_t = + std::function(common::row_idx_t, common::row_idx_t, + common::node_group_idx_t, common::transaction_t commitTS)>; + +// Note: these iterators are not necessarily thread-safe when used on their own +class ChunkedGroupUndoIterator { +public: + ChunkedGroupUndoIterator(NodeGroupCollection* nodeGroups, common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS) + : startRow(startRow), numRows(numRows), commitTS(commitTS), nodeGroups(nodeGroups) {} + + virtual ~ChunkedGroupUndoIterator() = default; + + virtual void initRollbackInsert(const transaction::Transaction* /*transaction*/) {} + virtual void finalizeRollbackInsert() {}; + virtual void iterate(chunked_group_undo_op_t undoFunc) = 0; + +protected: + common::row_idx_t startRow; + common::row_idx_t numRows; + common::transaction_t commitTS; + + NodeGroupCollection* nodeGroups; +}; + +} // namespace storage +} // namespace kuzu diff --git a/src/include/storage/store/csr_node_group.h b/src/include/storage/store/csr_node_group.h index 5ed1a60d7d5..338ef9b607f 100644 --- a/src/include/storage/store/csr_node_group.h +++ b/src/include/storage/store/csr_node_group.h @@ -165,6 +165,18 
@@ static constexpr common::column_id_t REL_ID_COLUMN_ID = 1; struct RelTableScanState; class CSRNodeGroup final : public NodeGroup { public: + class PersistentIterator : public ChunkedGroupUndoIterator { + public: + PersistentIterator(NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS); + + void iterate(chunked_group_undo_op_t undoFunc) override; + void finalizeRollbackInsert() override; + + private: + CSRNodeGroup* nodeGroup; + }; + static constexpr PackedCSRInfo DEFAULT_PACKED_CSR_INFO{}; CSRNodeGroup(const common::node_group_idx_t nodeGroupIdx, const bool enableCompression, @@ -215,10 +227,6 @@ class CSRNodeGroup final : public NodeGroup { void serialize(common::Serializer& serializer) override; private: - std::pair actionOnChunkedGroups(const common::UniqLock& lock, - common::row_idx_t startRow, common::row_idx_t numRows_, common::transaction_t commitTS, - CSRNodeGroupScanSource source, chunked_group_transaction_operation_t operation) override; - void initScanForCommittedPersistent(const transaction::Transaction* transaction, RelTableScanState& relScanState, CSRNodeGroupScanState& nodeGroupScanState) const; void initScanForCommittedInMem(RelTableScanState& relScanState, diff --git a/src/include/storage/store/node_group.h b/src/include/storage/store/node_group.h index e249ef6dd39..005e5e314d8 100644 --- a/src/include/storage/store/node_group.h +++ b/src/include/storage/store/node_group.h @@ -5,6 +5,7 @@ #include "common/uniq_lock.h" #include "storage/enums/csr_node_group_scan_source.h" #include "storage/enums/residency_state.h" +#include "storage/store/chunked_group_undo_iterator.h" #include "storage/store/chunked_node_group.h" #include "storage/store/group_collection.h" @@ -18,6 +19,8 @@ class MemoryManager; struct TableAddColumnState; class NodeGroup; +class NodeGroupCollection; + struct NodeGroupScanState { // Index of committed but not yet 
checkpointed chunked group to scan. common::idx_t chunkedGroupIdx = 0; @@ -80,6 +83,18 @@ static auto NODE_GROUP_SCAN_EMMPTY_RESULT = NodeGroupScanResult{}; struct TableScanState; class NodeGroup { public: + class NodeGroupBaseIterator : public ChunkedGroupUndoIterator { + public: + NodeGroupBaseIterator(NodeGroupCollection* nodeGroups, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS); + void iterate(chunked_group_undo_op_t undoFunc) override; + void finalizeRollbackInsert() override; + + protected: + NodeGroup* nodeGroup; + }; + NodeGroup(const common::node_group_idx_t nodeGroupIdx, const bool enableCompression, std::vector dataTypes, common::row_idx_t capacity = common::StorageConstants::NODE_GROUP_SIZE, @@ -150,15 +165,7 @@ class NodeGroup { void flush(transaction::Transaction* transaction, FileHandle& dataFH); - void commitInsert(common::row_idx_t startRow, common::row_idx_t numRows_, - common::transaction_t commitTS, CSRNodeGroupScanSource source); - void commitDelete(common::row_idx_t startRow, common::row_idx_t numRows_, - common::transaction_t commitTS, CSRNodeGroupScanSource source); - - void rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows_, - CSRNodeGroupScanSource source); - void rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows_, - CSRNodeGroupScanSource source); + void rollbackInsert(common::row_idx_t startRow); virtual void checkpoint(MemoryManager& memoryManager, NodeGroupCheckpointState& state); @@ -198,13 +205,6 @@ class NodeGroup { static constexpr auto INVALID_CHUNKED_GROUP_IDX = UINT32_MAX; static constexpr auto INVALID_START_ROW_IDX = UINT64_MAX; - using chunked_group_transaction_operation_t = void ( - ChunkedNodeGroup::*)(common::row_idx_t, common::row_idx_t, common::transaction_t); - virtual std::pair actionOnChunkedGroups( - const common::UniqLock& lock, common::row_idx_t startRow, common::row_idx_t numRows_, - 
common::transaction_t commitTS, CSRNodeGroupScanSource source, - chunked_group_transaction_operation_t operation); - private: std::pair findChunkedGroupIdxFromRowIdxNoLock( common::row_idx_t rowIdx) const; diff --git a/src/include/storage/store/node_group_collection.h b/src/include/storage/store/node_group_collection.h index 7373642aa6f..e2c2bd070df 100644 --- a/src/include/storage/store/node_group_collection.h +++ b/src/include/storage/store/node_group_collection.h @@ -17,7 +17,7 @@ class NodeGroupCollection { public: NodeGroupCollection(MemoryManager& memoryManager, const std::vector& types, bool enableCompression, FileHandle* dataFH = nullptr, common::Deserializer* deSer = nullptr, - const transaction::rollback_insert_func_t* rollbackInsertFunc = nullptr); + const chunked_group_iterator_construct_t* iteratorConstructFunc = nullptr); void append(const transaction::Transaction* transaction, const std::vector& vectors); @@ -51,7 +51,8 @@ class NodeGroupCollection { return nodeGroups.getGroup(lock, groupIdx); } NodeGroup* getOrCreateNodeGroup(transaction::Transaction* transaction, - common::node_group_idx_t groupIdx, NodeGroupDataFormat format); + common::node_group_idx_t groupIdx, NodeGroupDataFormat format, + const chunked_group_iterator_construct_t* constructIteratorFunc_); void setNodeGroup(const common::node_group_idx_t nodeGroupIdx, std::unique_ptr group) { @@ -59,19 +60,7 @@ class NodeGroupCollection { nodeGroups.replaceGroup(lock, nodeGroupIdx, std::move(group)); } - void commitInsert(common::row_idx_t startRow, common::row_idx_t numRows_, - common::node_group_idx_t nodeGroupIdx, common::transaction_t commitTS, - CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE); - void commitDelete(common::row_idx_t startRow, common::row_idx_t numRows_, - common::node_group_idx_t nodeGroupIdx, common::transaction_t commitTS, - CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE); - - void rollbackInsert(common::row_idx_t startRow, common::row_idx_t 
numRows_, - common::node_group_idx_t nodeGroupIdx, - CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE); - void rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows_, - common::node_group_idx_t nodeGroupIdx, - CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE); + void rollbackInsert(common::row_idx_t numRows_, bool updateNumRows = true); void clear() { const auto lock = nodeGroups.lock(); @@ -93,11 +82,13 @@ class NodeGroupCollection { void pushInsertInfo(const transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, storage::CSRNodeGroupScanSource source); + common::row_idx_t numRows, + const chunked_group_iterator_construct_t* constructIteratorFunc_); private: void pushInsertInfo(const transaction::Transaction* transaction, NodeGroup* nodeGroup, - common::row_idx_t numRows, CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE); + common::row_idx_t numRows, + const chunked_group_iterator_construct_t* constructIteratorFunc_ = nullptr); bool enableCompression; // Num rows in the collection regardless of deletions. 
@@ -106,7 +97,7 @@ class NodeGroupCollection { GroupCollection nodeGroups; FileHandle* dataFH; TableStats stats; - const transaction::rollback_insert_func_t* rollbackInsertFunc; + const chunked_group_iterator_construct_t* iteratorConstructFunc; }; } // namespace storage diff --git a/src/include/storage/store/node_table.h b/src/include/storage/store/node_table.h index 00397bed590..ae73526ef01 100644 --- a/src/include/storage/store/node_table.h +++ b/src/include/storage/store/node_table.h @@ -84,6 +84,17 @@ struct NodeTableDeleteState final : TableDeleteState { class StorageManager; class NodeTable final : public Table { public: + class NodeGroupIterator : public NodeGroup::NodeGroupBaseIterator { + public: + NodeGroupIterator(NodeTable* table, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS); + + void initRollbackInsert(const transaction::Transaction* transaction) override; + + private: + NodeTable* table; + }; + static std::vector getNodeTableColumnTypes(const NodeTable& table) { std::vector types; for (auto i = 0u; i < table.getNumColumns(); i++) { @@ -176,8 +187,8 @@ class NodeTable final : public Table { TableStats getStats(const transaction::Transaction* transaction) const; - const transaction::rollback_insert_func_t& getRollbackInsertFunc() const { - return rollbackInsertFunc; + const chunked_group_iterator_construct_t& getIteratorConstructFunc() const { + return iteratorConstructFunc; } private: @@ -199,7 +210,7 @@ class NodeTable final : public Table { std::unique_ptr nodeGroups; common::column_id_t pkColumnID; std::unique_ptr pkIndex; - transaction::rollback_insert_func_t rollbackInsertFunc; + chunked_group_iterator_construct_t iteratorConstructFunc; }; } // namespace storage diff --git a/src/include/storage/store/rel_table_data.h b/src/include/storage/store/rel_table_data.h index 98b488ce2bf..bf80c1da125 100644 --- a/src/include/storage/store/rel_table_data.h +++ 
b/src/include/storage/store/rel_table_data.h @@ -55,8 +55,8 @@ class RelTableData { } NodeGroup* getOrCreateNodeGroup(transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx) const { - return nodeGroups->getOrCreateNodeGroup(transaction, nodeGroupIdx, - NodeGroupDataFormat::CSR); + return nodeGroups->getOrCreateNodeGroup(transaction, nodeGroupIdx, NodeGroupDataFormat::CSR, + &persistentIteratorConstructFunc); } common::RelMultiplicity getMultiplicity() const { return multiplicity; } @@ -113,6 +113,9 @@ class RelTableData { CSRHeaderColumns csrHeaderColumns; std::vector> columns; + + chunked_group_iterator_construct_t inMemIteratorConstructFunc; + chunked_group_iterator_construct_t persistentIteratorConstructFunc; }; } // namespace storage diff --git a/src/include/storage/undo_buffer.h b/src/include/storage/undo_buffer.h index 95dccdce206..6fa6e318828 100644 --- a/src/include/storage/undo_buffer.h +++ b/src/include/storage/undo_buffer.h @@ -4,8 +4,7 @@ #include "common/constants.h" #include "common/types/types.h" -#include "storage/enums/csr_node_group_scan_source.h" -#include "transaction/transaction.h" +#include "storage/store/node_group.h" namespace kuzu { namespace catalog { @@ -88,13 +87,10 @@ class UndoBuffer { void createCatalogEntry(catalog::CatalogSet& catalogSet, catalog::CatalogEntry& catalogEntry); void createSequenceChange(catalog::SequenceCatalogEntry& sequenceEntry, const catalog::SequenceRollbackData& data); - void createInsertInfo(NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows, - storage::CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE, - const transaction::rollback_insert_func_t* rollbackInsertFunc = nullptr); - void createDeleteInfo(NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows, - storage::CSRNodeGroupScanSource source); + void 
createInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, const chunked_group_iterator_construct_t* iteratorConstructFunc); + void createDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, const chunked_group_iterator_construct_t* iteratorConstructFunc); void createVectorUpdateInfo(UpdateInfo* updateInfo, common::idx_t vectorIdx, VectorUpdateInfo* vectorUpdateInfo); @@ -106,11 +102,9 @@ class UndoBuffer { private: uint8_t* createUndoRecord(uint64_t size); - void createVersionInfo(UndoRecordType recordType, NodeGroupCollection* nodeGroupCollection, - common::row_idx_t startRow, common::row_idx_t numRows, - common::node_group_idx_t nodeGroupIdx = 0, - storage::CSRNodeGroupScanSource source = CSRNodeGroupScanSource::NONE, - const transaction::rollback_insert_func_t* rollbackInsertFunc = nullptr); + void createVersionInfo(UndoRecordType recordType, common::row_idx_t startRow, + common::row_idx_t numRows, const chunked_group_iterator_construct_t* iteratorConstructFunc, + common::node_group_idx_t nodeGroupIdx = 0); void commitRecord(UndoRecordType recordType, const uint8_t* record, common::transaction_t commitTS) const; diff --git a/src/include/transaction/transaction.h b/src/include/transaction/transaction.h index 11e3f0617f4..d2cdc7109b3 100644 --- a/src/include/transaction/transaction.h +++ b/src/include/transaction/transaction.h @@ -4,7 +4,6 @@ #include "common/enums/statement_type.h" #include "common/types/types.h" -#include "storage/enums/csr_node_group_scan_source.h" namespace kuzu { namespace catalog { @@ -25,14 +24,15 @@ class UpdateInfo; struct VectorUpdateInfo; class NodeGroupCollection; class ChunkedNodeGroup; +class ChunkedGroupUndoIterator; } // namespace storage namespace transaction { class TransactionManager; class Transaction; -using rollback_insert_func_t = - std::function; +using chunked_group_iterator_construct_t = + 
std::function(common::row_idx_t, + common::row_idx_t, common::node_group_idx_t, common::transaction_t commitTS)>; enum class TransactionType : uint8_t { READ_ONLY, WRITE, CHECKPOINT, DUMMY, RECOVERY }; @@ -122,15 +122,12 @@ class Transaction { bool skipLoggingToWAL = false) const; void pushSequenceChange(catalog::SequenceCatalogEntry* sequenceEntry, int64_t kCount, const catalog::SequenceRollbackData& data) const; - void pushInsertInfo(storage::NodeGroupCollection* nodeGroups, - common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + void pushInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - storage::CSRNodeGroupScanSource source = storage::CSRNodeGroupScanSource::NONE, - const transaction::rollback_insert_func_t* rollbackInsertCallback = nullptr) const; - void pushDeleteInfo(storage::NodeGroupCollection* nodeGroups, - common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + const chunked_group_iterator_construct_t* constructIteratorFunc = nullptr) const; + void pushDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - storage::CSRNodeGroupScanSource source = storage::CSRNodeGroupScanSource::NONE) const; + const chunked_group_iterator_construct_t* constructIteratorFunc) const; void pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, common::idx_t vectorIdx, storage::VectorUpdateInfo& vectorUpdateInfo) const; diff --git a/src/processor/operator/persistent/rel_batch_insert.cpp b/src/processor/operator/persistent/rel_batch_insert.cpp index 4b8b5d181c6..c2e4545fadc 100644 --- a/src/processor/operator/persistent/rel_batch_insert.cpp +++ b/src/processor/operator/persistent/rel_batch_insert.cpp @@ -85,6 +85,7 @@ static void appendNewChunkedGroup(transaction::Transaction* transaction, const CSRNodeGroupScanSource source = isNewNodeGroup ? 
CSRNodeGroupScanSource::COMMITTED_PERSISTENT : CSRNodeGroupScanSource::COMMITTED_IN_MEMORY; + // TODO this may need to be atomic relTable.pushInsertInfo(transaction, direction, nodeGroup, chunkedGroup.getNumRows(), source); if (isNewNodeGroup) { auto flushedChunkedGroup = diff --git a/src/storage/store/csr_node_group.cpp b/src/storage/store/csr_node_group.cpp index ac3420de4af..88a7a8a19f4 100644 --- a/src/storage/store/csr_node_group.cpp +++ b/src/storage/store/csr_node_group.cpp @@ -12,6 +12,25 @@ using namespace kuzu::transaction; namespace kuzu { namespace storage { +CSRNodeGroup::PersistentIterator::PersistentIterator(NodeGroupCollection* nodeGroups, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, + common::transaction_t commitTS) + : ChunkedGroupUndoIterator(nodeGroups, startRow, numRows, commitTS), nodeGroup(nullptr) { + if (nodeGroupIdx < nodeGroups->getNumNodeGroups()) { + nodeGroup = ku_dynamic_cast(nodeGroups->getNodeGroupNoLock(nodeGroupIdx)); + } +} + +void CSRNodeGroup::PersistentIterator::iterate(chunked_group_undo_op_t undoFunc) { + if (nodeGroup && nodeGroup->persistentChunkGroup) { + std::invoke(undoFunc, *nodeGroup->persistentChunkGroup, startRow, numRows, commitTS); + } +} + +void CSRNodeGroup::PersistentIterator::finalizeRollbackInsert() { + nodeGroups->rollbackInsert(numRows, false); +} + bool CSRNodeGroupScanState::tryScanCachedTuples(RelTableScanState& tableScanState) { if (numCachedRows == 0 || tableScanState.currBoundNodeIdx >= tableScanState.cachedBoundNodeSelVector.getSelSize()) { @@ -955,22 +974,6 @@ void CSRNodeGroup::finalizeCheckpoint(const UniqLock& lock) { csrIndex.reset(); } -std::pair CSRNodeGroup::actionOnChunkedGroups(const common::UniqLock& lock, - common::row_idx_t startRow, common::row_idx_t numRows_, common::transaction_t commitTS, - CSRNodeGroupScanSource source, chunked_group_transaction_operation_t operation) { - if (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) { 
- KU_ASSERT(persistentChunkGroup || (numRows_ == 0)); - if (persistentChunkGroup) { - std::invoke(operation, *persistentChunkGroup, startRow, numRows_, commitTS); - } - return {INVALID_CHUNKED_GROUP_IDX, INVALID_START_ROW_IDX}; - } else { - KU_ASSERT(source == CSRNodeGroupScanSource::COMMITTED_IN_MEMORY); - return NodeGroup::actionOnChunkedGroups(lock, startRow, numRows_, commitTS, source, - operation); - } -} - common::row_idx_t CSRNodeGroup::getNumPersistentRows() const { if (!persistentChunkGroup) { return 0; diff --git a/src/storage/store/node_group.cpp b/src/storage/store/node_group.cpp index b5a96a61cf1..f35b8758292 100644 --- a/src/storage/store/node_group.cpp +++ b/src/storage/store/node_group.cpp @@ -4,6 +4,7 @@ #include "common/constants.h" #include "common/types/types.h" #include "common/uniq_lock.h" +#include "common/utils.h" #include "main/client_context.h" #include "storage/buffer_manager/memory_manager.h" #include "storage/enums/residency_state.h" @@ -21,6 +22,40 @@ using namespace kuzu::transaction; namespace kuzu { namespace storage { +NodeGroup::NodeGroupBaseIterator::NodeGroupBaseIterator(NodeGroupCollection* nodeGroups, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, + transaction_t commitTS) + : ChunkedGroupUndoIterator(nodeGroups, startRow, numRows, commitTS), + nodeGroup(nodeGroups->getNodeGroupNoLock(nodeGroupIdx)) {} + +void NodeGroup::NodeGroupBaseIterator::iterate(chunked_group_undo_op_t undoFunc) { + auto lock = nodeGroup->chunkedGroups.lock(); + const auto [chunkedGroupIdx, startRowInChunkedGroup] = + nodeGroup->findChunkedGroupIdxFromRowIdxNoLock(startRow); + if (chunkedGroupIdx != INVALID_CHUNKED_GROUP_IDX) { + auto curChunkedGroupIdx = chunkedGroupIdx; + auto curStartRowIdxInChunk = startRowInChunkedGroup; + + auto numRowsLeft = numRows; + while ( + numRowsLeft > 0 && curChunkedGroupIdx < nodeGroup->chunkedGroups.getNumGroups(lock)) { + auto* chunkedGroup = 
nodeGroup->chunkedGroups.getGroup(lock, curChunkedGroupIdx); + const auto numRowsForGroup = + std::min(numRowsLeft, chunkedGroup->getNumRows() - curStartRowIdxInChunk); + std::invoke(undoFunc, *chunkedGroup, curStartRowIdxInChunk, numRowsForGroup, commitTS); + + ++curChunkedGroupIdx; + numRowsLeft -= numRowsForGroup; + curStartRowIdxInChunk = 0; + } + } +} + +void NodeGroup::NodeGroupBaseIterator::finalizeRollbackInsert() { + nodeGroup->rollbackInsert(startRow); + nodeGroups->rollbackInsert(numRows); +} + row_idx_t NodeGroup::append(const Transaction* transaction, ChunkedNodeGroup& chunkedGroup, row_idx_t startRowIdx, row_idx_t numRowsToAppend) { KU_ASSERT(numRowsToAppend <= chunkedGroup.getNumRows()); @@ -353,69 +388,20 @@ void NodeGroup::flush(Transaction* transaction, FileHandle& dataFH) { chunkedGroups.resize(lock, 1); } -std::pair NodeGroup::actionOnChunkedGroups(const common::UniqLock& lock, - common::row_idx_t startRow, common::row_idx_t numRows_, common::transaction_t commitTS, - CSRNodeGroupScanSource, chunked_group_transaction_operation_t operation) { - const auto [startChunkedGroupIdx, startRowIdxInChunk] = - findChunkedGroupIdxFromRowIdxNoLock(startRow); - if (startChunkedGroupIdx != INVALID_CHUNKED_GROUP_IDX) { - auto curChunkedGroupIdx = startChunkedGroupIdx; - auto curStartRowIdxInChunk = startRowIdxInChunk; - - auto numRowsLeft = numRows_; - while (numRowsLeft > 0 && curChunkedGroupIdx < chunkedGroups.getNumGroups(lock)) { - auto* chunkedGroup = chunkedGroups.getGroup(lock, curChunkedGroupIdx); - const auto numRowsForGroup = - std::min(numRowsLeft, chunkedGroup->getNumRows() - curStartRowIdxInChunk); - std::invoke(operation, *chunkedGroup, curStartRowIdxInChunk, numRowsForGroup, commitTS); - - ++curChunkedGroupIdx; - numRowsLeft -= numRowsForGroup; - curStartRowIdxInChunk = 0; - } - } - - return {startChunkedGroupIdx, startRowIdxInChunk}; -} - -static constexpr common::transaction_t UNUSED_COMMIT_TS = INVALID_TRANSACTION; - -void 
NodeGroup::rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows_, - CSRNodeGroupScanSource source) { - const auto lock = chunkedGroups.lock(); - const auto [startChunkedGroupIdx, startRowIdxInChunk] = actionOnChunkedGroups(lock, startRow, - numRows_, UNUSED_COMMIT_TS, source, &ChunkedNodeGroup::rollbackInsert); - if (startChunkedGroupIdx != INVALID_CHUNKED_GROUP_IDX) { - const auto numChunkedGroups = chunkedGroups.getNumGroups(lock); - KU_ASSERT(startChunkedGroupIdx < numChunkedGroups); - const bool shouldRemoveStartChunk = (startRowIdxInChunk == 0); - const auto numChunksToRemove = - numChunkedGroups - startChunkedGroupIdx - (shouldRemoveStartChunk ? 0 : 1); - chunkedGroups.removeTrailingGroups(lock, numChunksToRemove); - - numRows = startRow; - } -} - -void NodeGroup::rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows_, - CSRNodeGroupScanSource source) { - const auto lock = chunkedGroups.lock(); - actionOnChunkedGroups(lock, startRow, numRows_, UNUSED_COMMIT_TS, source, - &ChunkedNodeGroup::rollbackDelete); -} - -void NodeGroup::commitInsert(row_idx_t startRow, row_idx_t numRows_, common::transaction_t commitTS, - CSRNodeGroupScanSource source) { - const auto lock = chunkedGroups.lock(); - actionOnChunkedGroups(lock, startRow, numRows_, commitTS, source, - &ChunkedNodeGroup::commitInsert); +static idx_t getNumEmptyTrailingGroups(const GroupCollection& nodeGroups, + const common::UniqLock& lock) { + const auto& chunkedGroupVector = nodeGroups.getAllGroups(lock); + return safeIntegerConversion( + std::find_if(chunkedGroupVector.rbegin(), chunkedGroupVector.rend(), + [](const auto& chunkedGroup) { return (chunkedGroup->getNumRows() != 0); }) - + chunkedGroupVector.rbegin()); } -void NodeGroup::commitDelete(row_idx_t startRow, row_idx_t numRows_, common::transaction_t commitTS, - CSRNodeGroupScanSource source) { +void NodeGroup::rollbackInsert(common::row_idx_t startRow) { const auto lock = chunkedGroups.lock(); - 
actionOnChunkedGroups(lock, startRow, numRows_, commitTS, source, - &ChunkedNodeGroup::commitDelete); + const auto numEmptyTrailingGroups = getNumEmptyTrailingGroups(chunkedGroups, lock); + chunkedGroups.removeTrailingGroups(lock, numEmptyTrailingGroups); + numRows = startRow; } void NodeGroup::checkpoint(MemoryManager& memoryManager, NodeGroupCheckpointState& state) { diff --git a/src/storage/store/node_group_collection.cpp b/src/storage/store/node_group_collection.cpp index 9a6ff10cf3e..b6e032764f6 100644 --- a/src/storage/store/node_group_collection.cpp +++ b/src/storage/store/node_group_collection.cpp @@ -15,9 +15,9 @@ namespace storage { NodeGroupCollection::NodeGroupCollection(MemoryManager& memoryManager, const std::vector& types, const bool enableCompression, FileHandle* dataFH, - Deserializer* deSer, const transaction::rollback_insert_func_t* rollbackInsertFunc) + Deserializer* deSer, const chunked_group_iterator_construct_t* iteratorConstructFunc) : enableCompression{enableCompression}, numTotalRows{0}, types{LogicalType::copy(types)}, - dataFH{dataFH}, rollbackInsertFunc(rollbackInsertFunc) { + dataFH{dataFH}, iteratorConstructFunc(iteratorConstructFunc) { if (deSer) { deserialize(*deSer, memoryManager); } @@ -55,6 +55,7 @@ void NodeGroupCollection::append(const Transaction* transaction, lastNodeGroup->moveNextRowToAppend(numToAppendInNodeGroup); pushInsertInfo(transaction, lastNodeGroup, numToAppendInNodeGroup); lastNodeGroup->append(transaction, vectors, numRowsAppended, numToAppendInNodeGroup); + numTotalRows += numToAppendInNodeGroup; numRowsAppended += numToAppendInNodeGroup; } stats.incrementCardinality(numRowsAppended); @@ -96,6 +97,7 @@ void NodeGroupCollection::append(const Transaction* transaction, NodeGroup& node pushInsertInfo(transaction, lastNodeGroup, numToAppendInBatch); lastNodeGroup->append(transaction, *chunkedGroupToAppend, numRowsAppendedInChunkedGroup, numToAppendInBatch); + numTotalRows += numToAppendInBatch; 
numRowsAppendedInChunkedGroup += numToAppendInBatch; } numChunkedGroupsAppended++; @@ -130,6 +132,7 @@ std::pair NodeGroupCollection::appendToLastNodeGroupAndFlush directFlushWhenAppend = numToAppend == numRowsLeftInLastNodeGroup && lastNodeGroup->getNumRows() == 0; pushInsertInfo(transaction, lastNodeGroup, numToAppend); + numTotalRows += numToAppend; if (!directFlushWhenAppend) { // TODO(Guodong): Furthur optimize on this. Should directly figure out startRowIdx to // start appending into the node group and pass in as param. @@ -152,7 +155,8 @@ row_idx_t NodeGroupCollection::getNumTotalRows() { } NodeGroup* NodeGroupCollection::getOrCreateNodeGroup(transaction::Transaction* transaction, - node_group_idx_t groupIdx, NodeGroupDataFormat format) { + node_group_idx_t groupIdx, NodeGroupDataFormat format, + const chunked_group_iterator_construct_t* constructIteratorFunc_) { const auto lock = nodeGroups.lock(); while (groupIdx >= nodeGroups.getNumGroups(lock)) { const auto currentGroupIdx = nodeGroups.getNumGroups(lock); @@ -163,8 +167,7 @@ NodeGroup* NodeGroupCollection::getOrCreateNodeGroup(transaction::Transaction* t enableCompression, LogicalType::copy(types))); // push an insert of size 0 so that we can rollback the creation of this node group if // needed - pushInsertInfo(transaction, nodeGroups.getLastGroup(lock), 0, - CSRNodeGroupScanSource::COMMITTED_PERSISTENT); + pushInsertInfo(transaction, nodeGroups.getLastGroup(lock), 0, constructIteratorFunc_); } KU_ASSERT(groupIdx < nodeGroups.getNumGroups(lock)); return nodeGroups.getGroup(lock, groupIdx); @@ -205,67 +208,32 @@ static idx_t getNumEmptyTrailingGroups(const GroupCollection& nodeGro nodeGroupVector.rbegin()); } -void NodeGroupCollection::rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows_, - common::node_group_idx_t nodeGroupIdx, CSRNodeGroupScanSource source) { +void NodeGroupCollection::rollbackInsert(common::row_idx_t numRows_, bool updateNumRows) { const auto lock = nodeGroups.lock(); 
- // skip the rollback if all newly created node groups have already been deleted - if (!nodeGroups.isEmpty(lock) || nodeGroupIdx > 0) { - KU_ASSERT(nodeGroupIdx < nodeGroups.getNumGroups(lock)); - auto* nodeGroup = nodeGroups.getGroup(lock, nodeGroupIdx); - if (nodeGroup) { - KU_ASSERT(startRow <= nodeGroup->getNumRows()); - nodeGroup->rollbackInsert(startRow, numRows_, source); - // remove any empty trailing node groups after the rollback - const auto numGroupsToRemove = getNumEmptyTrailingGroups(nodeGroups, lock); - nodeGroups.removeTrailingGroups(lock, numGroupsToRemove); - } - } + // remove any empty trailing node groups after the rollback + const auto numGroupsToRemove = getNumEmptyTrailingGroups(nodeGroups, lock); + nodeGroups.removeTrailingGroups(lock, numGroupsToRemove); - if (source != CSRNodeGroupScanSource::COMMITTED_PERSISTENT) { + if (updateNumRows) { KU_ASSERT(numRows_ <= numTotalRows); numTotalRows -= numRows_; } } -void NodeGroupCollection::rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows_, - common::node_group_idx_t nodeGroupIdx, CSRNodeGroupScanSource source) { - const auto lock = nodeGroups.lock(); - KU_ASSERT(nodeGroupIdx < nodeGroups.getNumGroups(lock)); - nodeGroups.getGroup(lock, nodeGroupIdx)->rollbackDelete(startRow, numRows_, source); -} - -void NodeGroupCollection::commitInsert(row_idx_t startRow, row_idx_t numRows_, - node_group_idx_t nodeGroupIdx, common::transaction_t commitTS, CSRNodeGroupScanSource source) { - if (numRows_ == 0) { - return; - } - const auto lock = nodeGroups.lock(); - nodeGroups.getGroup(lock, nodeGroupIdx)->commitInsert(startRow, numRows_, commitTS, source); -} - -void NodeGroupCollection::commitDelete(row_idx_t startRow, row_idx_t numRows_, - node_group_idx_t nodeGroupIdx, common::transaction_t commitTS, CSRNodeGroupScanSource source) { - const auto lock = nodeGroups.lock(); - nodeGroups.getGroup(lock, nodeGroupIdx)->commitDelete(startRow, numRows_, commitTS, source); -} - void 
NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transaction, - NodeGroup* nodeGroup, common::row_idx_t numRows, CSRNodeGroupScanSource source) { + NodeGroup* nodeGroup, common::row_idx_t numRows, + const chunked_group_iterator_construct_t* constructIteratorFunc_) { pushInsertInfo(transaction, nodeGroup->getNodeGroupIdx(), nodeGroup->getNumRows(), numRows, - source); + constructIteratorFunc_ ? constructIteratorFunc_ : iteratorConstructFunc); }; void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - storage::CSRNodeGroupScanSource source) { + const chunked_group_iterator_construct_t* constructIteratorFunc_) { // we only append to the undo buffer if the node group collection is persistent if (dataFH && transaction->shouldAppendToUndoBuffer()) { - transaction->pushInsertInfo(this, nodeGroupIdx, startRow, numRows, source, - rollbackInsertFunc); - } - if (source != CSRNodeGroupScanSource::COMMITTED_PERSISTENT) { - numTotalRows += numRows; + transaction->pushInsertInfo(nodeGroupIdx, startRow, numRows, constructIteratorFunc_); } } diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index ec86b403a74..5789da1fe02 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -21,6 +21,16 @@ using namespace kuzu::evaluator; namespace kuzu { namespace storage { +NodeTable::NodeGroupIterator::NodeGroupIterator(NodeTable* table, node_group_idx_t nodeGroupidx, + common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS) + : NodeGroup::NodeGroupBaseIterator(table->nodeGroups.get(), nodeGroupidx, startRow, numRows, + commitTS), + table(table) {} + +void NodeTable::NodeGroupIterator::initRollbackInsert(const transaction::Transaction* transaction) { + table->rollbackInsert(transaction, startRow, numRows, nodeGroup->getNodeGroupIdx()); +} + bool 
NodeTableScanState::scanNext(Transaction* transaction, offset_t startOffset, offset_t numNodes) { KU_ASSERT(columns.size() == outputVectors.size()); @@ -220,15 +230,16 @@ NodeTable::NodeTable(const StorageManager* storageManager, dataFH, memoryManager, shadowFile, enableCompression); } - rollbackInsertFunc = [this](const transaction::Transaction* transaction, - common::row_idx_t startRow, common::row_idx_t numRows_, - common::node_group_idx_t nodeGroupIdx_, CSRNodeGroupScanSource) { - return rollbackInsert(transaction, startRow, numRows_, nodeGroupIdx_); + iteratorConstructFunc = [this](common::row_idx_t startRow, common::row_idx_t numRows_, + common::node_group_idx_t nodeGroupIdx_, + common::transaction_t commitTS) { + return std::make_unique(this, nodeGroupIdx_, startRow, numRows_, + commitTS); }; nodeGroups = std::make_unique(*memoryManager, getNodeTableColumnTypes(*this), - enableCompression, storageManager->getDataFH(), deSer, &rollbackInsertFunc); + enableCompression, storageManager->getDataFH(), deSer, &iteratorConstructFunc); initializePKIndex(storageManager->getDatabasePath(), nodeTableEntry, storageManager->isReadOnly(), vfs, context); } @@ -444,7 +455,7 @@ bool NodeTable::delete_(Transaction* transaction, TableDeleteState& deleteState) nodeOffset - StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); isDeleted = nodeGroups->getNodeGroup(nodeGroupIdx)->delete_(transaction, rowIdxInGroup); if (transaction->shouldAppendToUndoBuffer()) { - transaction->pushDeleteInfo(nodeGroups.get(), nodeGroupIdx, rowIdxInGroup, 1); + transaction->pushDeleteInfo(nodeGroupIdx, rowIdxInGroup, 1, &iteratorConstructFunc); } } if (isDeleted) { @@ -530,8 +541,8 @@ void NodeTable::commit(Transaction* transaction, LocalTable* localTable) { nodeGroups->getNodeGroup(nodeGroupIdx)->delete_(transaction, rowIdxInGroup); KU_ASSERT(isDeleted); if (transaction->shouldAppendToUndoBuffer()) { - transaction->pushDeleteInfo(nodeGroups.get(), nodeGroupIdx, rowIdxInGroup, - 1); + 
transaction->pushDeleteInfo(nodeGroupIdx, rowIdxInGroup, 1, + &iteratorConstructFunc); } } } diff --git a/src/storage/store/rel_table_data.cpp b/src/storage/store/rel_table_data.cpp index 8c36efb9c62..f952873b276 100644 --- a/src/storage/store/rel_table_data.cpp +++ b/src/storage/store/rel_table_data.cpp @@ -25,6 +25,21 @@ RelTableData::RelTableData(FileHandle* dataFH, MemoryManager* mm, ShadowFile* sh multiplicity = tableEntry->constCast().getMultiplicity(direction); initCSRHeaderColumns(); initPropertyColumns(tableEntry); + + inMemIteratorConstructFunc = [this](common::row_idx_t startRow, common::row_idx_t numRows_, + common::node_group_idx_t nodeGroupIdx_, + common::transaction_t commitTS) { + return std::make_unique(nodeGroups.get(), nodeGroupIdx_, + startRow, numRows_, commitTS); + }; + + persistentIteratorConstructFunc = [this](common::row_idx_t startRow, common::row_idx_t numRows_, + common::node_group_idx_t nodeGroupIdx_, + common::transaction_t commitTS) { + return std::make_unique(nodeGroups.get(), nodeGroupIdx_, + startRow, numRows_, commitTS); + }; + nodeGroups = std::make_unique(*mm, getColumnTypes(), enableCompression, dataFH, deSer); } @@ -98,7 +113,11 @@ bool RelTableData::delete_(Transaction* transaction, ValueVector& boundNodeIDVec auto& csrNodeGroup = getNodeGroup(nodeGroupIdx)->cast(); bool isDeleted = csrNodeGroup.delete_(transaction, source, rowIdx); if (isDeleted && transaction->shouldAppendToUndoBuffer()) { - transaction->pushDeleteInfo(nodeGroups.get(), nodeGroupIdx, rowIdx, 1, source); + const auto* constructIteratorFunc = + (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) ? 
+ &persistentIteratorConstructFunc : + &inMemIteratorConstructFunc; + transaction->pushDeleteInfo(nodeGroupIdx, rowIdx, 1, constructIteratorFunc); } return isDeleted; } @@ -192,11 +211,13 @@ bool RelTableData::checkIfNodeHasRels(Transaction* transaction, void RelTableData::pushInsertInfo(transaction::Transaction* transaction, const CSRNodeGroup& nodeGroup, common::row_idx_t numRows_, CSRNodeGroupScanSource source) { - const auto startRow = (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) ? - nodeGroup.getNumPersistentRows() : - nodeGroup.getNumRows(); + const auto [startRow, constructIteratorFunc] = + (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) ? + std::make_pair(nodeGroup.getNumPersistentRows(), &persistentIteratorConstructFunc) : + std::make_pair(nodeGroup.getNumRows(), &inMemIteratorConstructFunc); + nodeGroups->pushInsertInfo(transaction, nodeGroup.getNodeGroupIdx(), startRow, numRows_, - source); + constructIteratorFunc); } void RelTableData::checkpoint(const std::vector& columnIDs) { diff --git a/src/storage/undo_buffer.cpp b/src/storage/undo_buffer.cpp index e097d2588c2..41c38ef2749 100644 --- a/src/storage/undo_buffer.cpp +++ b/src/storage/undo_buffer.cpp @@ -38,12 +38,10 @@ struct NodeBatchInsertRecord { }; struct VersionRecord { - NodeGroupCollection* nodeGroupCollection; row_idx_t startRow; row_idx_t numRows; node_group_idx_t nodeGroupIdx; - const transaction::rollback_insert_func_t* rollbackInsertCallback; - CSRNodeGroupScanSource source; + const chunked_group_iterator_construct_t* iteratorConstructFunc; }; struct VectorUpdateRecord { @@ -113,30 +111,27 @@ void UndoBuffer::createSequenceChange(SequenceCatalogEntry& sequenceEntry, *reinterpret_cast(buffer) = sequenceEntryRecord; } -void UndoBuffer::createInsertInfo(NodeGroupCollection* nodeGroups, node_group_idx_t nodeGroupIdx, - row_idx_t startRow, row_idx_t numRows, storage::CSRNodeGroupScanSource source, - const transaction::rollback_insert_func_t* rollbackInsertFunc) { - 
createVersionInfo(UndoRecordType::INSERT_INFO, nodeGroups, startRow, numRows, nodeGroupIdx, - source, rollbackInsertFunc); +void UndoBuffer::createInsertInfo(node_group_idx_t nodeGroupIdx, row_idx_t startRow, + row_idx_t numRows, const chunked_group_iterator_construct_t* iteratorConstructFunc) { + createVersionInfo(UndoRecordType::INSERT_INFO, startRow, numRows, iteratorConstructFunc, + nodeGroupIdx); } -void UndoBuffer::createDeleteInfo(NodeGroupCollection* nodeGroups, - common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - storage::CSRNodeGroupScanSource source) { - createVersionInfo(UndoRecordType::DELETE_INFO, nodeGroups, startRow, numRows, nodeGroupIdx, - source); +void UndoBuffer::createDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, const chunked_group_iterator_construct_t* iteratorConstructFunc) { + createVersionInfo(UndoRecordType::DELETE_INFO, startRow, numRows, iteratorConstructFunc, + nodeGroupIdx); } -void UndoBuffer::createVersionInfo(const UndoRecordType recordType, - NodeGroupCollection* nodeGroupCollection, row_idx_t startRow, row_idx_t numRows, - node_group_idx_t nodeGroupIdx, storage::CSRNodeGroupScanSource source, - const transaction::rollback_insert_func_t* rollbackInsertFunc) { +void UndoBuffer::createVersionInfo(const UndoRecordType recordType, row_idx_t startRow, + row_idx_t numRows, const chunked_group_iterator_construct_t* iteratorConstructFunc, + node_group_idx_t nodeGroupIdx) { auto buffer = createUndoRecord(sizeof(UndoRecordHeader) + sizeof(VersionRecord)); const UndoRecordHeader recordHeader{recordType, sizeof(VersionRecord)}; *reinterpret_cast(buffer) = recordHeader; buffer += sizeof(UndoRecordHeader); - *reinterpret_cast(buffer) = VersionRecord{nodeGroupCollection, startRow, - numRows, nodeGroupIdx, rollbackInsertFunc, source}; + *reinterpret_cast(buffer) = + VersionRecord{startRow, numRows, nodeGroupIdx, iteratorConstructFunc}; } 
void UndoBuffer::createVectorUpdateInfo(UpdateInfo* updateInfo, const idx_t vectorIdx, @@ -221,12 +216,14 @@ void UndoBuffer::commitVersionInfo(UndoRecordType recordType, const uint8_t* rec const auto& undoRecord = *reinterpret_cast(record); switch (recordType) { case UndoRecordType::INSERT_INFO: { - undoRecord.nodeGroupCollection->commitInsert(undoRecord.startRow, undoRecord.numRows, - undoRecord.nodeGroupIdx, commitTS, undoRecord.source); + auto it = (*undoRecord.iteratorConstructFunc)(undoRecord.startRow, undoRecord.numRows, + undoRecord.nodeGroupIdx, commitTS); + it->iterate(&ChunkedNodeGroup::commitInsert); } break; case UndoRecordType::DELETE_INFO: { - undoRecord.nodeGroupCollection->commitDelete(undoRecord.startRow, undoRecord.numRows, - undoRecord.nodeGroupIdx, commitTS, undoRecord.source); + auto it = (*undoRecord.iteratorConstructFunc)(undoRecord.startRow, undoRecord.numRows, + undoRecord.nodeGroupIdx, commitTS); + it->iterate(&ChunkedNodeGroup::commitDelete); } break; default: { KU_UNREACHABLE; @@ -299,19 +296,21 @@ void UndoBuffer::rollbackSequenceEntry(const uint8_t* entry) { void UndoBuffer::rollbackVersionInfo(const transaction::Transaction* transaction, UndoRecordType recordType, const uint8_t* record) { + static constexpr transaction_t UNUSED_COMMIT_TS = INVALID_TRANSACTION; + auto& undoRecord = *reinterpret_cast(record); switch (recordType) { case UndoRecordType::INSERT_INFO: { - if (undoRecord.rollbackInsertCallback) { - (*undoRecord.rollbackInsertCallback)(transaction, undoRecord.startRow, - undoRecord.numRows, undoRecord.nodeGroupIdx, undoRecord.source); - } - undoRecord.nodeGroupCollection->rollbackInsert(undoRecord.startRow, undoRecord.numRows, - undoRecord.nodeGroupIdx, undoRecord.source); + auto it = (*undoRecord.iteratorConstructFunc)(undoRecord.startRow, undoRecord.numRows, + undoRecord.nodeGroupIdx, UNUSED_COMMIT_TS); + it->initRollbackInsert(transaction); + it->iterate(&ChunkedNodeGroup::rollbackInsert); + it->finalizeRollbackInsert(); 
} break; case UndoRecordType::DELETE_INFO: { - undoRecord.nodeGroupCollection->rollbackDelete(undoRecord.startRow, undoRecord.numRows, - undoRecord.nodeGroupIdx, undoRecord.source); + auto it = (*undoRecord.iteratorConstructFunc)(undoRecord.startRow, undoRecord.numRows, + undoRecord.nodeGroupIdx, UNUSED_COMMIT_TS); + it->iterate(&ChunkedNodeGroup::rollbackDelete); } break; default: { KU_UNREACHABLE; diff --git a/src/transaction/transaction.cpp b/src/transaction/transaction.cpp index b8fd6cc0743..4d5636f3413 100644 --- a/src/transaction/transaction.cpp +++ b/src/transaction/transaction.cpp @@ -173,18 +173,16 @@ void Transaction::pushSequenceChange(SequenceCatalogEntry* sequenceEntry, int64_ } } -void Transaction::pushInsertInfo(storage::NodeGroupCollection* nodeGroups, - common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - storage::CSRNodeGroupScanSource source, - const transaction::rollback_insert_func_t* rollbackInsertCallback) const { - undoBuffer->createInsertInfo(nodeGroups, nodeGroupIdx, startRow, numRows, source, - rollbackInsertCallback); +void Transaction::pushInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, + const chunked_group_iterator_construct_t* constructIteratorFunc) const { + undoBuffer->createInsertInfo(nodeGroupIdx, startRow, numRows, constructIteratorFunc); } -void Transaction::pushDeleteInfo(storage::NodeGroupCollection* nodeGroups, - common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - storage::CSRNodeGroupScanSource source) const { - undoBuffer->createDeleteInfo(nodeGroups, nodeGroupIdx, startRow, numRows, source); +void Transaction::pushDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, + const chunked_group_iterator_construct_t* constructIteratorFunc) const { + undoBuffer->createDeleteInfo(nodeGroupIdx, startRow, numRows, constructIteratorFunc); 
} void Transaction::pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, From 739b5d87645d1fa78b8fcc0599c443e67d9173da Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Tue, 19 Nov 2024 16:25:25 -0500 Subject: [PATCH 11/28] Bug fixes + code cleanup --- src/include/storage/store/group_collection.h | 9 ++++ src/include/storage/store/node_group.h | 2 +- .../storage/store/node_group_collection.h | 1 - src/storage/store/node_group.cpp | 46 ++++++++----------- src/storage/store/node_group_collection.cpp | 12 +---- src/storage/undo_buffer.cpp | 2 - 6 files changed, 30 insertions(+), 42 deletions(-) diff --git a/src/include/storage/store/group_collection.h b/src/include/storage/store/group_collection.h index de10c2a056e..8aa7185397e 100644 --- a/src/include/storage/store/group_collection.h +++ b/src/include/storage/store/group_collection.h @@ -6,6 +6,7 @@ #include "common/serializer/serializer.h" #include "common/types/types.h" #include "common/uniq_lock.h" +#include "common/utils.h" namespace kuzu { namespace storage { @@ -116,6 +117,14 @@ class GroupCollection { groups.clear(); } + common::idx_t getNumEmptyTrailingGroups(const common::UniqLock& lock) { + const auto& groupsVector = getAllGroups(lock); + return common::safeIntegerConversion( + std::find_if(groupsVector.rbegin(), groupsVector.rend(), + [](const auto& group) { return (group->getNumRows() != 0); }) - + groupsVector.rbegin()); + } + private: mutable std::mutex mtx; std::vector> groups; diff --git a/src/include/storage/store/node_group.h b/src/include/storage/store/node_group.h index 005e5e314d8..70c76f13854 100644 --- a/src/include/storage/store/node_group.h +++ b/src/include/storage/store/node_group.h @@ -3,7 +3,6 @@ #include #include "common/uniq_lock.h" -#include "storage/enums/csr_node_group_scan_source.h" #include "storage/enums/residency_state.h" #include "storage/store/chunked_group_undo_iterator.h" #include "storage/store/chunked_node_group.h" @@ -93,6 +92,7 @@ class NodeGroup { protected: NodeGroup* 
nodeGroup; + common::row_idx_t numRowsToRollback; }; NodeGroup(const common::node_group_idx_t nodeGroupIdx, const bool enableCompression, diff --git a/src/include/storage/store/node_group_collection.h b/src/include/storage/store/node_group_collection.h index e2c2bd070df..ccb3b32c157 100644 --- a/src/include/storage/store/node_group_collection.h +++ b/src/include/storage/store/node_group_collection.h @@ -1,6 +1,5 @@ #pragma once -#include "storage/enums/csr_node_group_scan_source.h" #include "storage/stats/table_stats.h" #include "storage/store/group_collection.h" #include "storage/store/node_group.h" diff --git a/src/storage/store/node_group.cpp b/src/storage/store/node_group.cpp index f35b8758292..f68e59ba86b 100644 --- a/src/storage/store/node_group.cpp +++ b/src/storage/store/node_group.cpp @@ -4,7 +4,6 @@ #include "common/constants.h" #include "common/types/types.h" #include "common/uniq_lock.h" -#include "common/utils.h" #include "main/client_context.h" #include "storage/buffer_manager/memory_manager.h" #include "storage/enums/residency_state.h" @@ -26,7 +25,10 @@ NodeGroup::NodeGroupBaseIterator::NodeGroupBaseIterator(NodeGroupCollection* nod common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, transaction_t commitTS) : ChunkedGroupUndoIterator(nodeGroups, startRow, numRows, commitTS), - nodeGroup(nodeGroups->getNodeGroupNoLock(nodeGroupIdx)) {} + nodeGroup(nodeGroups->getNodeGroupNoLock(nodeGroupIdx)), + numRowsToRollback(std::min(numRows, nodeGroup->getNumRows() - startRow)) { + KU_ASSERT(startRow <= nodeGroup->getNumRows()); +} void NodeGroup::NodeGroupBaseIterator::iterate(chunked_group_undo_op_t undoFunc) { auto lock = nodeGroup->chunkedGroups.lock(); @@ -53,7 +55,7 @@ void NodeGroup::NodeGroupBaseIterator::iterate(chunked_group_undo_op_t undoFunc) void NodeGroup::NodeGroupBaseIterator::finalizeRollbackInsert() { nodeGroup->rollbackInsert(startRow); - nodeGroups->rollbackInsert(numRows); + 
nodeGroups->rollbackInsert(numRowsToRollback); } row_idx_t NodeGroup::append(const Transaction* transaction, ChunkedNodeGroup& chunkedGroup, @@ -388,18 +390,9 @@ void NodeGroup::flush(Transaction* transaction, FileHandle& dataFH) { chunkedGroups.resize(lock, 1); } -static idx_t getNumEmptyTrailingGroups(const GroupCollection& nodeGroups, - const common::UniqLock& lock) { - const auto& chunkedGroupVector = nodeGroups.getAllGroups(lock); - return safeIntegerConversion( - std::find_if(chunkedGroupVector.rbegin(), chunkedGroupVector.rend(), - [](const auto& chunkedGroup) { return (chunkedGroup->getNumRows() != 0); }) - - chunkedGroupVector.rbegin()); -} - void NodeGroup::rollbackInsert(common::row_idx_t startRow) { const auto lock = chunkedGroups.lock(); - const auto numEmptyTrailingGroups = getNumEmptyTrailingGroups(chunkedGroups, lock); + const auto numEmptyTrailingGroups = chunkedGroups.getNumEmptyTrailingGroups(lock); chunkedGroups.removeTrailingGroups(lock, numEmptyTrailingGroups); numRows = startRow; } @@ -409,21 +402,20 @@ void NodeGroup::checkpoint(MemoryManager& memoryManager, NodeGroupCheckpointStat // TODO(Guodong): A special case can be all rows are deleted or rollbacked, then we can skip // flushing the data. const auto lock = chunkedGroups.lock(); - if (!chunkedGroups.isEmpty(lock)) { - const auto firstGroup = chunkedGroups.getFirstGroup(lock); - const auto hasPersistentData = firstGroup->getResidencyState() == ResidencyState::ON_DISK; - // Re-populate version info here first. 
- auto checkpointedVersionInfo = checkpointVersionInfo(lock, &DUMMY_CHECKPOINT_TRANSACTION); - std::unique_ptr checkpointedChunkedGroup; - if (hasPersistentData) { - checkpointedChunkedGroup = checkpointInMemAndOnDisk(memoryManager, lock, state); - } else { - checkpointedChunkedGroup = checkpointInMemOnly(memoryManager, lock, state); - } - checkpointedChunkedGroup->setVersionInfo(std::move(checkpointedVersionInfo)); - chunkedGroups.clear(lock); - chunkedGroups.appendGroup(lock, std::move(checkpointedChunkedGroup)); + KU_ASSERT(chunkedGroups.getNumGroups(lock) >= 1); + const auto firstGroup = chunkedGroups.getFirstGroup(lock); + const auto hasPersistentData = firstGroup->getResidencyState() == ResidencyState::ON_DISK; + // Re-populate version info here first. + auto checkpointedVersionInfo = checkpointVersionInfo(lock, &DUMMY_CHECKPOINT_TRANSACTION); + std::unique_ptr checkpointedChunkedGroup; + if (hasPersistentData) { + checkpointedChunkedGroup = checkpointInMemAndOnDisk(memoryManager, lock, state); + } else { + checkpointedChunkedGroup = checkpointInMemOnly(memoryManager, lock, state); } + checkpointedChunkedGroup->setVersionInfo(std::move(checkpointedVersionInfo)); + chunkedGroups.clear(lock); + chunkedGroups.appendGroup(lock, std::move(checkpointedChunkedGroup)); } std::unique_ptr NodeGroup::checkpointInMemAndOnDisk(MemoryManager& memoryManager, diff --git a/src/storage/store/node_group_collection.cpp b/src/storage/store/node_group_collection.cpp index b6e032764f6..23a2937332c 100644 --- a/src/storage/store/node_group_collection.cpp +++ b/src/storage/store/node_group_collection.cpp @@ -1,6 +1,5 @@ #include "storage/store/node_group_collection.h" -#include "common/utils.h" #include "common/vector/value_vector.h" #include "storage/buffer_manager/memory_manager.h" #include "storage/store/csr_node_group.h" @@ -199,20 +198,11 @@ void NodeGroupCollection::checkpoint(MemoryManager& memoryManager, } } -static idx_t getNumEmptyTrailingGroups(const GroupCollection& 
nodeGroups, - const common::UniqLock& lock) { - const auto& nodeGroupVector = nodeGroups.getAllGroups(lock); - return safeIntegerConversion( - std::find_if(nodeGroupVector.rbegin(), nodeGroupVector.rend(), - [](const auto& nodeGroup) { return (nodeGroup->getNumRows() != 0); }) - - nodeGroupVector.rbegin()); -} - void NodeGroupCollection::rollbackInsert(common::row_idx_t numRows_, bool updateNumRows) { const auto lock = nodeGroups.lock(); // remove any empty trailing node groups after the rollback - const auto numGroupsToRemove = getNumEmptyTrailingGroups(nodeGroups, lock); + const auto numGroupsToRemove = nodeGroups.getNumEmptyTrailingGroups(lock); nodeGroups.removeTrailingGroups(lock, numGroupsToRemove); if (updateNumRows) { diff --git a/src/storage/undo_buffer.cpp b/src/storage/undo_buffer.cpp index 41c38ef2749..1c059a18e3f 100644 --- a/src/storage/undo_buffer.cpp +++ b/src/storage/undo_buffer.cpp @@ -4,8 +4,6 @@ #include "catalog/catalog_entry/sequence_catalog_entry.h" #include "catalog/catalog_entry/table_catalog_entry.h" #include "catalog/catalog_set.h" -#include "storage/store/node_table.h" -#include "storage/store/rel_table_data.h" #include "storage/store/update_info.h" using namespace kuzu::catalog; From 14fb9a1f0feeb88327ce67f0a4f0d578c972acf8 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Tue, 19 Nov 2024 17:23:47 -0500 Subject: [PATCH 12/28] Cleanup node table + actually use semi mask --- src/include/common/mask.h | 7 + src/include/storage/store/node_table.h | 20 ++- src/storage/store/node_table.cpp | 179 +++++++++++-------------- 3 files changed, 102 insertions(+), 104 deletions(-) diff --git a/src/include/common/mask.h b/src/include/common/mask.h index ad1fbfb1f2f..0f2d2ce667d 100644 --- a/src/include/common/mask.h +++ b/src/include/common/mask.h @@ -17,6 +17,7 @@ class RoaringBitmapSemiMask { virtual ~RoaringBitmapSemiMask() = default; virtual void mask(common::offset_t nodeOffset) = 0; + virtual void maskRange(common::offset_t startNodeOffset, 
common::offset_t endNodeOffset) = 0; virtual bool isMasked(common::offset_t startNodeOffset) = 0; @@ -44,6 +45,9 @@ class Roaring32BitmapSemiMask : public RoaringBitmapSemiMask { } void mask(common::offset_t nodeOffset) override { roaring->add(nodeOffset); } + void maskRange(common::offset_t startNodeOffset, common::offset_t endNodeOffset) override { + roaring->addRange(startNodeOffset, endNodeOffset); + } bool isMasked(common::offset_t startNodeOffset) override { return roaring->contains(startNodeOffset); @@ -76,6 +80,9 @@ class Roaring64BitmapSemiMask : public RoaringBitmapSemiMask { roaring(std::make_shared()) {} void mask(common::offset_t nodeOffset) override { roaring->add(nodeOffset); } + void maskRange(common::offset_t startNodeOffset, common::offset_t endNodeOffset) override { + roaring->addRange(startNodeOffset, endNodeOffset); + } bool isMasked(common::offset_t startNodeOffset) override { return roaring->contains(startNodeOffset); diff --git a/src/include/storage/store/node_table.h b/src/include/storage/store/node_table.h index ae73526ef01..671be55b8b2 100644 --- a/src/include/storage/store/node_table.h +++ b/src/include/storage/store/node_table.h @@ -81,6 +81,20 @@ struct NodeTableDeleteState final : TableDeleteState { : nodeIDVector{nodeIDVector}, pkVector{pkVector} {} }; +struct PKColumnScanHelper { + explicit PKColumnScanHelper(PrimaryKeyIndex* pkIndex, common::table_id_t tableID) + : tableID(tableID), pkIndex(pkIndex) {} + virtual ~PKColumnScanHelper() = default; + + virtual std::unique_ptr initPKScanState(common::DataChunk& dataChunk, + common::column_id_t pkColumnID, const std::vector>& columns); + virtual bool processScanOutput(const transaction::Transaction* transaction, + NodeGroupScanResult scanResult, const common::ValueVector& scannedVector) = 0; + + common::table_id_t tableID; + PrimaryKeyIndex* pkIndex; +}; + class StorageManager; class NodeTable final : public Table { public: @@ -192,18 +206,14 @@ class NodeTable final : public Table { } 
private: - void insertPK(const transaction::Transaction* transaction, - const common::ValueVector& nodeIDVector, const common::ValueVector& pkVector) const; void validatePkNotExists(const transaction::Transaction* transaction, common::ValueVector* pkVector); void serialize(common::Serializer& serializer) const override; - std::unique_ptr initPKScanState(common::DataChunk& dataChunk, - TableScanSource source) const; - visible_func getVisibleFunc(const transaction::Transaction* transaction) const; common::DataChunk constructDataChunkForPKColumn() const; + void scanPKColumn(const transaction::Transaction* transaction, PKColumnScanHelper& scanHelper); private: std::vector> columns; diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index 5789da1fe02..89883f45dce 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -53,76 +53,46 @@ bool NodeTableScanState::scanNext(Transaction* transaction, offset_t startOffset return true; } -template -concept notIndexHashable = !IndexHashable; - namespace { -struct PKColumnScanHelper { - explicit PKColumnScanHelper(common::node_group_idx_t numNodeGroups, PrimaryKeyIndex* pkIndex, - common::DataChunk dataChunk, table_id_t tableID) - : numNodeGroups(numNodeGroups), dataChunk(std::move(dataChunk)), tableID(tableID), - pkIndex(pkIndex) {} - virtual ~PKColumnScanHelper() = default; - - virtual bool processScanOutput(const transaction::Transaction* transaction, - NodeGroupScanResult scanResult, const common::ValueVector& scannedVector) = 0; - virtual NodeGroup* getNodeGroup(common::node_group_idx_t nodeGroupIdx) const = 0; - - common::node_group_idx_t numNodeGroups; - common::DataChunk dataChunk; - table_id_t tableID; - PrimaryKeyIndex* pkIndex; -}; -struct CommittedPKColumnScanHelper : public PKColumnScanHelper { +struct CommittedPKInserter : public PKColumnScanHelper { public: - CommittedPKColumnScanHelper(LocalNodeTable& localTable, row_idx_t startNodeOffset, - DataChunk 
dataChunk, table_id_t tableID, PrimaryKeyIndex* pkIndex, visible_func isVisible) - : PKColumnScanHelper(localTable.getNumNodeGroups(), pkIndex, std::move(dataChunk), tableID), - localTable(localTable), startNodeOffset(startNodeOffset), - nodeIDVector(LogicalType::INTERNAL_ID()), isVisible(std::move(isVisible)) { - nodeIDVector.setState(this->dataChunk.state); - } + CommittedPKInserter(row_idx_t startNodeOffset, table_id_t tableID, PrimaryKeyIndex* pkIndex, + visible_func isVisible) + : PKColumnScanHelper(pkIndex, tableID), startNodeOffset(startNodeOffset), + nodeIDVector(LogicalType::INTERNAL_ID()), isVisible(std::move(isVisible)) {} + + std::unique_ptr initPKScanState(DataChunk& dataChunk, + column_id_t pkColumnID, const std::vector>& columns) override; bool processScanOutput(const transaction::Transaction* transaction, NodeGroupScanResult scanResult, const common::ValueVector& scannedVector) override; - NodeGroup* getNodeGroup(common::node_group_idx_t nodeGroupIdx) const override { - return localTable.getNodeGroup(nodeGroupIdx); - } - - LocalNodeTable& localTable; row_idx_t startNodeOffset; ValueVector nodeIDVector; visible_func isVisible; }; -struct RollbackPKColumnScanHelper : public PKColumnScanHelper { +struct RollbackPKDeleter : public PKColumnScanHelper { public: - RollbackPKColumnScanHelper(row_idx_t startNodeOffset, row_idx_t numRows, - NodeGroupCollection& nodeGroups, DataChunk dataChunk, table_id_t tableID, + RollbackPKDeleter(row_idx_t startNodeOffset, row_idx_t numRows, table_id_t tableID, PrimaryKeyIndex* pkIndex) - : PKColumnScanHelper(nodeGroups.getNumNodeGroups(), pkIndex, std::move(dataChunk), tableID), + : PKColumnScanHelper(pkIndex, tableID), semiMask(RoaringBitmapSemiMaskUtil::createRoaringBitmapSemiMask(tableID, - startNodeOffset + numRows)), - nodeGroups(nodeGroups) { - for (row_idx_t i = 0; i < numRows; ++i) { - semiMask->mask(startNodeOffset + i); - } + startNodeOffset + numRows)) { + semiMask->maskRange(startNodeOffset, startNodeOffset + 
numRows); } + std::unique_ptr initPKScanState(DataChunk& dataChunk, + column_id_t pkColumnID, const std::vector>& columns) override; + bool processScanOutput(const transaction::Transaction* transaction, NodeGroupScanResult scanResult, const common::ValueVector& scannedVector) override; - NodeGroup* getNodeGroup(common::node_group_idx_t nodeGroupIdx) const override { - return nodeGroups.getNodeGroup(nodeGroupIdx); - } - std::unique_ptr semiMask; - NodeGroupCollection& nodeGroups; }; -static void insertPKInternal(const Transaction* transaction, const ValueVector& nodeIDVector, +static void insertPK(const Transaction* transaction, const ValueVector& nodeIDVector, const ValueVector& pkVector, PrimaryKeyIndex* pkIndex, const visible_func& isVisible) { for (auto i = 0u; i < nodeIDVector.state->getSelVector().getSelSize(); i++) { const auto nodeIDPos = nodeIDVector.state->getSelVector()[i]; @@ -138,28 +108,15 @@ static void insertPKInternal(const Transaction* transaction, const ValueVector& } } -void scanPKColumn(const Transaction* transaction, PKColumnScanHelper& scanHelper, - std::unique_ptr scanState) { - - node_group_idx_t nodeGroupToScan = 0u; - while (nodeGroupToScan < scanHelper.numNodeGroups) { - // We need to scan from local storage here because some tuples in local node groups might - // have been deleted. 
- scanState->nodeGroup = scanHelper.getNodeGroup(nodeGroupToScan); - KU_ASSERT(scanState->nodeGroup); - scanState->nodeGroup->initializeScanState(transaction, *scanState); - while (true) { - auto scanResult = scanState->nodeGroup->scan(transaction, *scanState); - if (!scanHelper.processScanOutput(transaction, scanResult, - *scanState->outputVectors[0])) { - break; - } - } - nodeGroupToScan++; - } +std::unique_ptr CommittedPKInserter::initPKScanState(DataChunk& dataChunk, + column_id_t pkColumnID, const std::vector>& columns) { + auto scanState = PKColumnScanHelper::initPKScanState(dataChunk, pkColumnID, columns); + nodeIDVector.setState(dataChunk.state); + scanState->source = TableScanSource::UNCOMMITTED; + return scanState; } -bool CommittedPKColumnScanHelper::processScanOutput(const transaction::Transaction* transaction, +bool CommittedPKInserter::processScanOutput(const transaction::Transaction* transaction, NodeGroupScanResult scanResult, const common::ValueVector& scannedVector) { if (scanResult == NODE_GROUP_SCAN_EMMPTY_RESULT) { return false; @@ -167,12 +124,23 @@ bool CommittedPKColumnScanHelper::processScanOutput(const transaction::Transacti for (auto i = 0u; i < scanResult.numRows; i++) { nodeIDVector.setValue(i, nodeID_t{startNodeOffset + i, tableID}); } - insertPKInternal(transaction, nodeIDVector, scannedVector, pkIndex, isVisible); + insertPK(transaction, nodeIDVector, scannedVector, pkIndex, isVisible); startNodeOffset += scanResult.numRows; return true; } -bool RollbackPKColumnScanHelper::processScanOutput(const transaction::Transaction* transaction, +std::unique_ptr RollbackPKDeleter::initPKScanState(DataChunk& dataChunk, + column_id_t pkColumnID, const std::vector>& columns) { + auto scanState = PKColumnScanHelper::initPKScanState(dataChunk, pkColumnID, columns); + scanState->source = TableScanSource::COMMITTED; + scanState->semiMask = semiMask.get(); + return scanState; +} + +template +concept notIndexHashable = !IndexHashable; + +bool 
RollbackPKDeleter::processScanOutput(const transaction::Transaction* transaction, NodeGroupScanResult scanResult, const common::ValueVector& scannedVector) { if (scanResult == NODE_GROUP_SCAN_EMMPTY_RESULT) { return false; @@ -194,6 +162,42 @@ bool RollbackPKColumnScanHelper::processScanOutput(const transaction::Transactio } } // namespace +std::unique_ptr PKColumnScanHelper::initPKScanState(DataChunk& dataChunk, + column_id_t pkColumnID, const std::vector>& columns) { + std::vector columnIDs{pkColumnID}; + auto scanState = std::make_unique(tableID, columnIDs); + for (auto& vector : dataChunk.valueVectors) { + scanState->outputVectors.push_back(vector.get()); + } + scanState->outState = dataChunk.state.get(); + for (const auto& column : columns) { + scanState->columns.push_back(column.get()); + } + return scanState; +} + +void NodeTable::scanPKColumn(const Transaction* transaction, PKColumnScanHelper& scanHelper) { + auto dataChunk = constructDataChunkForPKColumn(); + auto scanState = scanHelper.initPKScanState(dataChunk, pkColumnID, columns); + + node_group_idx_t nodeGroupToScan = 0u; + while (nodeGroupToScan < nodeGroups->getNumNodeGroups()) { + // We need to scan from local storage here because some tuples in local node groups might + // have been deleted. 
+ scanState->nodeGroup = nodeGroups->getNodeGroup(nodeGroupToScan); + KU_ASSERT(scanState->nodeGroup); + scanState->nodeGroup->initializeScanState(transaction, *scanState); + while (true) { + auto scanResult = scanState->nodeGroup->scan(transaction, *scanState); + if (!scanHelper.processScanOutput(transaction, scanResult, + *scanState->outputVectors[0])) { + break; + } + } + nodeGroupToScan++; + } +} + bool NodeTableScanState::scanNext(Transaction* transaction) { KU_ASSERT(columns.size() == outputVectors.size()); if (source == TableScanSource::NONE) { @@ -416,7 +420,8 @@ void NodeTable::update(Transaction* transaction, TableUpdateState& updateState) localTable->update(&dummyTrx, updateState); } else { if (nodeUpdateState.columnID == pkColumnID && pkIndex) { - insertPK(transaction, nodeUpdateState.nodeIDVector, nodeUpdateState.propertyVector); + insertPK(transaction, nodeUpdateState.nodeIDVector, nodeUpdateState.propertyVector, + pkIndex.get(), getVisibleFunc(transaction)); } const auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(nodeOffset); const auto rowIdxInGroup = @@ -492,21 +497,6 @@ std::pair NodeTable::appendToLastNodeGroup(Transaction* tran return nodeGroups->appendToLastNodeGroupAndFlushWhenFull(transaction, chunkedGroup); } -std::unique_ptr NodeTable::initPKScanState(DataChunk& dataChunk, - TableScanSource source) const { - std::vector columnIDs{getPKColumnID()}; - auto scanState = std::make_unique(tableID, columnIDs); - for (auto& vector : dataChunk.valueVectors) { - scanState->outputVectors.push_back(vector.get()); - } - scanState->outState = dataChunk.state.get(); - scanState->source = source; - for (const auto& column : columns) { - scanState->columns.push_back(column.get()); - } - return scanState; -} - common::DataChunk NodeTable::constructDataChunkForPKColumn() const { std::vector types; types.push_back(columns[pkColumnID]->getDataType().copy()); @@ -551,10 +541,9 @@ void NodeTable::commit(Transaction* transaction, LocalTable* localTable) { } // 
3. Scan pk column for newly inserted tuples that are not deleted and insert into pk index. - CommittedPKColumnScanHelper scanHelper{localNodeTable, startNodeOffset, - constructDataChunkForPKColumn(), tableID, pkIndex.get(), getVisibleFunc(transaction)}; - scanPKColumn(transaction, scanHelper, - initPKScanState(scanHelper.dataChunk, TableScanSource::UNCOMMITTED)); + CommittedPKInserter scanHelper{startNodeOffset, tableID, pkIndex.get(), + getVisibleFunc(transaction)}; + scanPKColumn(transaction, scanHelper); // 4. Clear local table. localTable->clear(); @@ -565,12 +554,6 @@ visible_func NodeTable::getVisibleFunc(const Transaction* transaction) const { [this, transaction](offset_t offset_) -> bool { return isVisible(transaction, offset_); }; } -void NodeTable::insertPK(const Transaction* transaction, const ValueVector& nodeIDVector, - const ValueVector& pkVector) const { - return insertPKInternal(transaction, nodeIDVector, pkVector, pkIndex.get(), - getVisibleFunc(transaction)); -} - void NodeTable::checkpoint(Serializer& ser, TableCatalogEntry* tableEntry) { if (hasChanges) { // Deleted columns are vaccumed and not checkpointed or serialized. 
@@ -606,10 +589,8 @@ void NodeTable::rollbackInsert(const transaction::Transaction* transaction, startNodeOffset += nodeGroups->getNodeGroupNoLock(i)->getNumRows(); } - RollbackPKColumnScanHelper scanHelper{startNodeOffset, numRows_, *nodeGroups, - constructDataChunkForPKColumn(), tableID, pkIndex.get()}; - scanPKColumn(transaction, scanHelper, - initPKScanState(scanHelper.dataChunk, TableScanSource::COMMITTED)); + RollbackPKDeleter scanHelper{startNodeOffset, numRows_, tableID, pkIndex.get()}; + scanPKColumn(transaction, scanHelper); } TableStats NodeTable::getStats(const Transaction* transaction) const { From 76ed7d00f2267fa9628bb6c7b95bd7adfe5c1af9 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Wed, 20 Nov 2024 10:49:13 -0500 Subject: [PATCH 13/28] Actually enable semi mask --- src/include/storage/store/node_table.h | 3 ++- src/storage/store/node_table.cpp | 21 ++++++++++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/include/storage/store/node_table.h b/src/include/storage/store/node_table.h index 671be55b8b2..c8ead97d3ae 100644 --- a/src/include/storage/store/node_table.h +++ b/src/include/storage/store/node_table.h @@ -213,7 +213,8 @@ class NodeTable final : public Table { visible_func getVisibleFunc(const transaction::Transaction* transaction) const; common::DataChunk constructDataChunkForPKColumn() const; - void scanPKColumn(const transaction::Transaction* transaction, PKColumnScanHelper& scanHelper); + void scanPKColumn(const transaction::Transaction* transaction, PKColumnScanHelper& scanHelper, + NodeGroupCollection& nodeGroups_); private: std::vector> columns; diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index 89883f45dce..82510a6fac5 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -81,6 +81,7 @@ struct RollbackPKDeleter : public PKColumnScanHelper { semiMask(RoaringBitmapSemiMaskUtil::createRoaringBitmapSemiMask(tableID, startNodeOffset + numRows)) 
{ semiMask->maskRange(startNodeOffset, startNodeOffset + numRows); + semiMask->enable(); } std::unique_ptr initPKScanState(DataChunk& dataChunk, @@ -92,7 +93,7 @@ struct RollbackPKDeleter : public PKColumnScanHelper { std::unique_ptr semiMask; }; -static void insertPK(const Transaction* transaction, const ValueVector& nodeIDVector, +void insertPK(const Transaction* transaction, const ValueVector& nodeIDVector, const ValueVector& pkVector, PrimaryKeyIndex* pkIndex, const visible_func& isVisible) { for (auto i = 0u; i < nodeIDVector.state->getSelVector().getSelSize(); i++) { const auto nodeIDPos = nodeIDVector.state->getSelVector()[i]; @@ -125,7 +126,7 @@ bool CommittedPKInserter::processScanOutput(const transaction::Transaction* tran nodeIDVector.setValue(i, nodeID_t{startNodeOffset + i, tableID}); } insertPK(transaction, nodeIDVector, scannedVector, pkIndex, isVisible); - startNodeOffset += scanResult.numRows; + startNodeOffset = scanResult.startRow + scanResult.numRows; return true; } @@ -176,15 +177,15 @@ std::unique_ptr PKColumnScanHelper::initPKScanState(DataChun return scanState; } -void NodeTable::scanPKColumn(const Transaction* transaction, PKColumnScanHelper& scanHelper) { +void NodeTable::scanPKColumn(const Transaction* transaction, PKColumnScanHelper& scanHelper, + NodeGroupCollection& nodeGroups_) { auto dataChunk = constructDataChunkForPKColumn(); auto scanState = scanHelper.initPKScanState(dataChunk, pkColumnID, columns); node_group_idx_t nodeGroupToScan = 0u; - while (nodeGroupToScan < nodeGroups->getNumNodeGroups()) { - // We need to scan from local storage here because some tuples in local node groups might - // have been deleted. 
- scanState->nodeGroup = nodeGroups->getNodeGroup(nodeGroupToScan); + while (nodeGroupToScan < nodeGroups_.getNumNodeGroups()) { + scanState->nodeGroup = nodeGroups_.getNodeGroup(nodeGroupToScan); + scanState->nodeGroupIdx = nodeGroupToScan; KU_ASSERT(scanState->nodeGroup); scanState->nodeGroup->initializeScanState(transaction, *scanState); while (true) { @@ -543,7 +544,9 @@ void NodeTable::commit(Transaction* transaction, LocalTable* localTable) { // 3. Scan pk column for newly inserted tuples that are not deleted and insert into pk index. CommittedPKInserter scanHelper{startNodeOffset, tableID, pkIndex.get(), getVisibleFunc(transaction)}; - scanPKColumn(transaction, scanHelper); + // We need to scan from local storage here because some tuples in local node groups might + // have been deleted. + scanPKColumn(transaction, scanHelper, localNodeTable.getNodeGroups()); // 4. Clear local table. localTable->clear(); @@ -590,7 +593,7 @@ void NodeTable::rollbackInsert(const transaction::Transaction* transaction, } RollbackPKDeleter scanHelper{startNodeOffset, numRows_, tableID, pkIndex.get()}; - scanPKColumn(transaction, scanHelper); + scanPKColumn(transaction, scanHelper, *nodeGroups); } TableStats NodeTable::getStats(const Transaction* transaction) const { From 5a1e3537cd2484dcc2ddf38e2777abfe7693b093 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Wed, 20 Nov 2024 11:00:32 -0500 Subject: [PATCH 14/28] Self-review --- .../store/chunked_group_undo_iterator.h | 5 +- src/include/storage/store/csr_node_group.h | 1 - src/include/storage/store/node_group.h | 7 +- .../storage/store/node_group_collection.h | 2 +- src/include/storage/store/node_table.h | 4 +- src/include/transaction/transaction.h | 1 - src/storage/store/chunked_node_group.cpp | 12 +-- src/storage/store/csr_node_group.cpp | 7 -- src/storage/store/node_group.cpp | 8 +- src/storage/store/node_group_collection.cpp | 8 +- src/storage/store/node_table.cpp | 84 ++++++++++--------- src/storage/store/rel_table_data.cpp | 
10 ++- src/storage/undo_buffer.cpp | 7 +- 13 files changed, 76 insertions(+), 80 deletions(-) diff --git a/src/include/storage/store/chunked_group_undo_iterator.h b/src/include/storage/store/chunked_group_undo_iterator.h index de15baabddf..bae490c5a51 100644 --- a/src/include/storage/store/chunked_group_undo_iterator.h +++ b/src/include/storage/store/chunked_group_undo_iterator.h @@ -18,9 +18,8 @@ class ChunkedGroupUndoIterator; using chunked_group_undo_op_t = void ( ChunkedNodeGroup::*)(common::row_idx_t, common::row_idx_t, common::transaction_t); -using chunked_group_iterator_construct_t = - std::function(common::row_idx_t, common::row_idx_t, - common::node_group_idx_t, common::transaction_t commitTS)>; +using chunked_group_iterator_construct_t = std::function( + common::row_idx_t, common::row_idx_t, common::node_group_idx_t, common::transaction_t)>; // Note: these iterators are not necessarily thread-safe when used on their own class ChunkedGroupUndoIterator { diff --git a/src/include/storage/store/csr_node_group.h b/src/include/storage/store/csr_node_group.h index 338ef9b607f..a064b42f44e 100644 --- a/src/include/storage/store/csr_node_group.h +++ b/src/include/storage/store/csr_node_group.h @@ -217,7 +217,6 @@ class CSRNodeGroup final : public NodeGroup { bool isEmpty() const override { return !persistentChunkGroup && NodeGroup::isEmpty(); } - common::row_idx_t getNumPersistentRows() const; ChunkedNodeGroup* getPersistentChunkedGroup() const { return persistentChunkGroup.get(); } void setPersistentChunkedGroup(std::unique_ptr chunkedNodeGroup) { KU_ASSERT(chunkedNodeGroup->getFormat() == NodeGroupDataFormat::CSR); diff --git a/src/include/storage/store/node_group.h b/src/include/storage/store/node_group.h index 70c76f13854..40e3d13a11f 100644 --- a/src/include/storage/store/node_group.h +++ b/src/include/storage/store/node_group.h @@ -82,11 +82,10 @@ static auto NODE_GROUP_SCAN_EMMPTY_RESULT = NodeGroupScanResult{}; struct TableScanState; class NodeGroup { 
public: - class NodeGroupBaseIterator : public ChunkedGroupUndoIterator { + class ChunkedGroupIterator : public ChunkedGroupUndoIterator { public: - NodeGroupBaseIterator(NodeGroupCollection* nodeGroups, - common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, common::transaction_t commitTS); + ChunkedGroupIterator(NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS); void iterate(chunked_group_undo_op_t undoFunc) override; void finalizeRollbackInsert() override; diff --git a/src/include/storage/store/node_group_collection.h b/src/include/storage/store/node_group_collection.h index ccb3b32c157..f0964dcfa33 100644 --- a/src/include/storage/store/node_group_collection.h +++ b/src/include/storage/store/node_group_collection.h @@ -87,7 +87,7 @@ class NodeGroupCollection { private: void pushInsertInfo(const transaction::Transaction* transaction, NodeGroup* nodeGroup, common::row_idx_t numRows, - const chunked_group_iterator_construct_t* constructIteratorFunc_ = nullptr); + const chunked_group_iterator_construct_t* constructIteratorOverrideFunc = nullptr); bool enableCompression; // Num rows in the collection regardless of deletions. 
diff --git a/src/include/storage/store/node_table.h b/src/include/storage/store/node_table.h index c8ead97d3ae..c3b1e5be8ea 100644 --- a/src/include/storage/store/node_table.h +++ b/src/include/storage/store/node_table.h @@ -98,9 +98,9 @@ struct PKColumnScanHelper { class StorageManager; class NodeTable final : public Table { public: - class NodeGroupIterator : public NodeGroup::NodeGroupBaseIterator { + class ChunkedGroupIterator : public NodeGroup::ChunkedGroupIterator { public: - NodeGroupIterator(NodeTable* table, common::node_group_idx_t nodeGroupIdx, + ChunkedGroupIterator(NodeTable* table, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS); void initRollbackInsert(const transaction::Transaction* transaction) override; diff --git a/src/include/transaction/transaction.h b/src/include/transaction/transaction.h index d2cdc7109b3..363fdc09066 100644 --- a/src/include/transaction/transaction.h +++ b/src/include/transaction/transaction.h @@ -22,7 +22,6 @@ class WAL; class VersionInfo; class UpdateInfo; struct VectorUpdateInfo; -class NodeGroupCollection; class ChunkedNodeGroup; class ChunkedGroupUndoIterator; } // namespace storage diff --git a/src/storage/store/chunked_node_group.cpp b/src/storage/store/chunked_node_group.cpp index d1acdeff66a..4e2c3971162 100644 --- a/src/storage/store/chunked_node_group.cpp +++ b/src/storage/store/chunked_node_group.cpp @@ -429,7 +429,12 @@ bool ChunkedNodeGroup::hasUpdates() const { return false; } -void ChunkedNodeGroup::rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows_, +void ChunkedNodeGroup::commitInsert(row_idx_t startRow, row_idx_t numRowsToCommit, + transaction_t commitTS) { + versionInfo->commitInsert(startRow, numRowsToCommit, commitTS); +} + +void ChunkedNodeGroup::rollbackInsert(row_idx_t startRow, row_idx_t numRows_, common::transaction_t) { if (startRow == 0) { setNumRows(0); @@ -444,11 +449,6 @@ void 
ChunkedNodeGroup::rollbackInsert(common::row_idx_t startRow, common::row_id numRows = startRow; } -void ChunkedNodeGroup::commitInsert(row_idx_t startRow, row_idx_t numRowsToCommit, - transaction_t commitTS) { - versionInfo->commitInsert(startRow, numRowsToCommit, commitTS); -} - void ChunkedNodeGroup::commitDelete(row_idx_t startRow, row_idx_t numRows_, transaction_t commitTS) { versionInfo->commitDelete(startRow, numRows_, commitTS); diff --git a/src/storage/store/csr_node_group.cpp b/src/storage/store/csr_node_group.cpp index 88a7a8a19f4..d68e40edf5b 100644 --- a/src/storage/store/csr_node_group.cpp +++ b/src/storage/store/csr_node_group.cpp @@ -974,12 +974,5 @@ void CSRNodeGroup::finalizeCheckpoint(const UniqLock& lock) { csrIndex.reset(); } -common::row_idx_t CSRNodeGroup::getNumPersistentRows() const { - if (!persistentChunkGroup) { - return 0; - } - return persistentChunkGroup->getNumRows(); -} - } // namespace storage } // namespace kuzu diff --git a/src/storage/store/node_group.cpp b/src/storage/store/node_group.cpp index f68e59ba86b..da6ae51e1d6 100644 --- a/src/storage/store/node_group.cpp +++ b/src/storage/store/node_group.cpp @@ -21,7 +21,7 @@ using namespace kuzu::transaction; namespace kuzu { namespace storage { -NodeGroup::NodeGroupBaseIterator::NodeGroupBaseIterator(NodeGroupCollection* nodeGroups, +NodeGroup::ChunkedGroupIterator::ChunkedGroupIterator(NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, transaction_t commitTS) : ChunkedGroupUndoIterator(nodeGroups, startRow, numRows, commitTS), @@ -30,7 +30,7 @@ NodeGroup::NodeGroupBaseIterator::NodeGroupBaseIterator(NodeGroupCollection* nod KU_ASSERT(startRow <= nodeGroup->getNumRows()); } -void NodeGroup::NodeGroupBaseIterator::iterate(chunked_group_undo_op_t undoFunc) { +void NodeGroup::ChunkedGroupIterator::iterate(chunked_group_undo_op_t undoFunc) { auto lock = nodeGroup->chunkedGroups.lock(); const auto 
[chunkedGroupIdx, startRowInChunkedGroup] = nodeGroup->findChunkedGroupIdxFromRowIdxNoLock(startRow); @@ -53,7 +53,7 @@ void NodeGroup::NodeGroupBaseIterator::iterate(chunked_group_undo_op_t undoFunc) } } -void NodeGroup::NodeGroupBaseIterator::finalizeRollbackInsert() { +void NodeGroup::ChunkedGroupIterator::finalizeRollbackInsert() { nodeGroup->rollbackInsert(startRow); nodeGroups->rollbackInsert(numRowsToRollback); } @@ -136,8 +136,8 @@ void NodeGroup::merge(Transaction*, std::unique_ptr chunkedGro KU_ASSERT(chunkedGroup->getColumnChunk(i).getDataType().getPhysicalType() == dataTypes[i].getPhysicalType()); } - numRows += chunkedGroup->getNumRows(); const auto lock = chunkedGroups.lock(); + numRows += chunkedGroup->getNumRows(); chunkedGroups.appendGroup(lock, std::move(chunkedGroup)); } diff --git a/src/storage/store/node_group_collection.cpp b/src/storage/store/node_group_collection.cpp index 23a2937332c..f115c475454 100644 --- a/src/storage/store/node_group_collection.cpp +++ b/src/storage/store/node_group_collection.cpp @@ -53,8 +53,8 @@ void NodeGroupCollection::append(const Transaction* transaction, std::min(numRowsToAppend - numRowsAppended, lastNodeGroup->getNumRowsLeftToAppend()); lastNodeGroup->moveNextRowToAppend(numToAppendInNodeGroup); pushInsertInfo(transaction, lastNodeGroup, numToAppendInNodeGroup); - lastNodeGroup->append(transaction, vectors, numRowsAppended, numToAppendInNodeGroup); numTotalRows += numToAppendInNodeGroup; + lastNodeGroup->append(transaction, vectors, numRowsAppended, numToAppendInNodeGroup); numRowsAppended += numToAppendInNodeGroup; } stats.incrementCardinality(numRowsAppended); @@ -94,9 +94,9 @@ void NodeGroupCollection::append(const Transaction* transaction, NodeGroup& node lastNodeGroup->getNumRowsLeftToAppend()); lastNodeGroup->moveNextRowToAppend(numToAppendInBatch); pushInsertInfo(transaction, lastNodeGroup, numToAppendInBatch); + numTotalRows += numToAppendInBatch; lastNodeGroup->append(transaction, 
*chunkedGroupToAppend, numRowsAppendedInChunkedGroup, numToAppendInBatch); - numTotalRows += numToAppendInBatch; numRowsAppendedInChunkedGroup += numToAppendInBatch; } numChunkedGroupsAppended++; @@ -213,9 +213,9 @@ void NodeGroupCollection::rollbackInsert(common::row_idx_t numRows_, bool update void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transaction, NodeGroup* nodeGroup, common::row_idx_t numRows, - const chunked_group_iterator_construct_t* constructIteratorFunc_) { + const chunked_group_iterator_construct_t* constructIteratorOverrideFunc) { pushInsertInfo(transaction, nodeGroup->getNodeGroupIdx(), nodeGroup->getNumRows(), numRows, - constructIteratorFunc_ ? constructIteratorFunc_ : iteratorConstructFunc); + constructIteratorOverrideFunc ? constructIteratorOverrideFunc : iteratorConstructFunc); }; void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transaction, diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index 82510a6fac5..511603f2885 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -21,13 +21,15 @@ using namespace kuzu::evaluator; namespace kuzu { namespace storage { -NodeTable::NodeGroupIterator::NodeGroupIterator(NodeTable* table, node_group_idx_t nodeGroupidx, - common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS) - : NodeGroup::NodeGroupBaseIterator(table->nodeGroups.get(), nodeGroupidx, startRow, numRows, +NodeTable::ChunkedGroupIterator::ChunkedGroupIterator(NodeTable* table, + node_group_idx_t nodeGroupidx, common::row_idx_t startRow, common::row_idx_t numRows, + common::transaction_t commitTS) + : NodeGroup::ChunkedGroupIterator(table->nodeGroups.get(), nodeGroupidx, startRow, numRows, commitTS), table(table) {} -void NodeTable::NodeGroupIterator::initRollbackInsert(const transaction::Transaction* transaction) { +void NodeTable::ChunkedGroupIterator::initRollbackInsert( + const 
transaction::Transaction* transaction) { table->rollbackInsert(transaction, startRow, numRows, nodeGroup->getNodeGroupIdx()); } @@ -163,42 +165,6 @@ bool RollbackPKDeleter::processScanOutput(const transaction::Transaction* transa } } // namespace -std::unique_ptr PKColumnScanHelper::initPKScanState(DataChunk& dataChunk, - column_id_t pkColumnID, const std::vector>& columns) { - std::vector columnIDs{pkColumnID}; - auto scanState = std::make_unique(tableID, columnIDs); - for (auto& vector : dataChunk.valueVectors) { - scanState->outputVectors.push_back(vector.get()); - } - scanState->outState = dataChunk.state.get(); - for (const auto& column : columns) { - scanState->columns.push_back(column.get()); - } - return scanState; -} - -void NodeTable::scanPKColumn(const Transaction* transaction, PKColumnScanHelper& scanHelper, - NodeGroupCollection& nodeGroups_) { - auto dataChunk = constructDataChunkForPKColumn(); - auto scanState = scanHelper.initPKScanState(dataChunk, pkColumnID, columns); - - node_group_idx_t nodeGroupToScan = 0u; - while (nodeGroupToScan < nodeGroups_.getNumNodeGroups()) { - scanState->nodeGroup = nodeGroups_.getNodeGroup(nodeGroupToScan); - scanState->nodeGroupIdx = nodeGroupToScan; - KU_ASSERT(scanState->nodeGroup); - scanState->nodeGroup->initializeScanState(transaction, *scanState); - while (true) { - auto scanResult = scanState->nodeGroup->scan(transaction, *scanState); - if (!scanHelper.processScanOutput(transaction, scanResult, - *scanState->outputVectors[0])) { - break; - } - } - nodeGroupToScan++; - } -} - bool NodeTableScanState::scanNext(Transaction* transaction) { KU_ASSERT(columns.size() == outputVectors.size()); if (source == TableScanSource::NONE) { @@ -238,7 +204,7 @@ NodeTable::NodeTable(const StorageManager* storageManager, iteratorConstructFunc = [this](common::row_idx_t startRow, common::row_idx_t numRows_, common::node_group_idx_t nodeGroupIdx_, common::transaction_t commitTS) { - return std::make_unique(this, nodeGroupIdx_, 
startRow, numRows_, + return std::make_unique(this, nodeGroupIdx_, startRow, numRows_, commitTS); }; @@ -641,5 +607,41 @@ bool NodeTable::lookupPK(const Transaction* transaction, ValueVector* keyVector, [&](offset_t offset) { return isVisibleNoLock(transaction, offset); }); } +void NodeTable::scanPKColumn(const Transaction* transaction, PKColumnScanHelper& scanHelper, + NodeGroupCollection& nodeGroups_) { + auto dataChunk = constructDataChunkForPKColumn(); + auto scanState = scanHelper.initPKScanState(dataChunk, pkColumnID, columns); + + node_group_idx_t nodeGroupToScan = 0u; + while (nodeGroupToScan < nodeGroups_.getNumNodeGroups()) { + scanState->nodeGroup = nodeGroups_.getNodeGroup(nodeGroupToScan); + scanState->nodeGroupIdx = nodeGroupToScan; + KU_ASSERT(scanState->nodeGroup); + scanState->nodeGroup->initializeScanState(transaction, *scanState); + while (true) { + auto scanResult = scanState->nodeGroup->scan(transaction, *scanState); + if (!scanHelper.processScanOutput(transaction, scanResult, + *scanState->outputVectors[0])) { + break; + } + } + nodeGroupToScan++; + } +} + +std::unique_ptr PKColumnScanHelper::initPKScanState(DataChunk& dataChunk, + column_id_t pkColumnID, const std::vector>& columns) { + std::vector columnIDs{pkColumnID}; + auto scanState = std::make_unique(tableID, columnIDs); + for (auto& vector : dataChunk.valueVectors) { + scanState->outputVectors.push_back(vector.get()); + } + scanState->outState = dataChunk.state.get(); + for (const auto& column : columns) { + scanState->columns.push_back(column.get()); + } + return scanState; +} + } // namespace storage } // namespace kuzu diff --git a/src/storage/store/rel_table_data.cpp b/src/storage/store/rel_table_data.cpp index f952873b276..9a81bb76d1f 100644 --- a/src/storage/store/rel_table_data.cpp +++ b/src/storage/store/rel_table_data.cpp @@ -29,7 +29,7 @@ RelTableData::RelTableData(FileHandle* dataFH, MemoryManager* mm, ShadowFile* sh inMemIteratorConstructFunc = [this](common::row_idx_t 
startRow, common::row_idx_t numRows_, common::node_group_idx_t nodeGroupIdx_, common::transaction_t commitTS) { - return std::make_unique(nodeGroups.get(), nodeGroupIdx_, + return std::make_unique(nodeGroups.get(), nodeGroupIdx_, startRow, numRows_, commitTS); }; @@ -211,9 +211,15 @@ bool RelTableData::checkIfNodeHasRels(Transaction* transaction, void RelTableData::pushInsertInfo(transaction::Transaction* transaction, const CSRNodeGroup& nodeGroup, common::row_idx_t numRows_, CSRNodeGroupScanSource source) { + // we shouldn't be appending directly to the to the persistent data + // unless we are performing batch insert and the persistent chunked group is empty + KU_ASSERT(source != CSRNodeGroupScanSource::COMMITTED_PERSISTENT || + !nodeGroup.getPersistentChunkedGroup() || + nodeGroup.getPersistentChunkedGroup()->getNumRows() == 0); + const auto [startRow, constructIteratorFunc] = (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) ? - std::make_pair(nodeGroup.getNumPersistentRows(), &persistentIteratorConstructFunc) : + std::make_pair(static_cast(0), &persistentIteratorConstructFunc) : std::make_pair(nodeGroup.getNumRows(), &inMemIteratorConstructFunc); nodeGroups->pushInsertInfo(transaction, nodeGroup.getNodeGroupIdx(), startRow, numRows_, diff --git a/src/storage/undo_buffer.cpp b/src/storage/undo_buffer.cpp index 1c059a18e3f..27fa98314f0 100644 --- a/src/storage/undo_buffer.cpp +++ b/src/storage/undo_buffer.cpp @@ -5,6 +5,7 @@ #include "catalog/catalog_entry/table_catalog_entry.h" #include "catalog/catalog_set.h" #include "storage/store/update_info.h" +#include "transaction/transaction.h" using namespace kuzu::catalog; using namespace kuzu::common; @@ -294,20 +295,18 @@ void UndoBuffer::rollbackSequenceEntry(const uint8_t* entry) { void UndoBuffer::rollbackVersionInfo(const transaction::Transaction* transaction, UndoRecordType recordType, const uint8_t* record) { - static constexpr transaction_t UNUSED_COMMIT_TS = INVALID_TRANSACTION; - auto& undoRecord = 
*reinterpret_cast(record); switch (recordType) { case UndoRecordType::INSERT_INFO: { auto it = (*undoRecord.iteratorConstructFunc)(undoRecord.startRow, undoRecord.numRows, - undoRecord.nodeGroupIdx, UNUSED_COMMIT_TS); + undoRecord.nodeGroupIdx, transaction->getCommitTS()); it->initRollbackInsert(transaction); it->iterate(&ChunkedNodeGroup::rollbackInsert); it->finalizeRollbackInsert(); } break; case UndoRecordType::DELETE_INFO: { auto it = (*undoRecord.iteratorConstructFunc)(undoRecord.startRow, undoRecord.numRows, - undoRecord.nodeGroupIdx, UNUSED_COMMIT_TS); + undoRecord.nodeGroupIdx, transaction->getCommitTS()); it->iterate(&ChunkedNodeGroup::rollbackDelete); } break; default: { From a87f669cb4ae1d9577bb69b845c0c457e199abf7 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Wed, 20 Nov 2024 13:48:10 -0500 Subject: [PATCH 15/28] Add tests --- .../operator/persistent/rel_batch_insert.cpp | 4 ++- test/copy/copy_test.cpp | 36 +++++++++++++++---- .../create_rel/create_tinysnb.test | 15 ++++++++ 3 files changed, 48 insertions(+), 7 deletions(-) diff --git a/src/processor/operator/persistent/rel_batch_insert.cpp b/src/processor/operator/persistent/rel_batch_insert.cpp index c2e4545fadc..d01358f1555 100644 --- a/src/processor/operator/persistent/rel_batch_insert.cpp +++ b/src/processor/operator/persistent/rel_batch_insert.cpp @@ -85,7 +85,9 @@ static void appendNewChunkedGroup(transaction::Transaction* transaction, const CSRNodeGroupScanSource source = isNewNodeGroup ? 
CSRNodeGroupScanSource::COMMITTED_PERSISTENT : CSRNodeGroupScanSource::COMMITTED_IN_MEMORY; - // TODO this may need to be atomic + // since each thread operates on distinct node groups + // We don't need a lock here (to ensure the insert info and append agree on the number of rows + // in the node group) relTable.pushInsertInfo(transaction, direction, nodeGroup, chunkedGroup.getNumRows(), source); if (isNewNodeGroup) { auto flushedChunkedGroup = diff --git a/test/copy/copy_test.cpp b/test/copy/copy_test.cpp index 7f13502fff7..dfecfdcff8e 100644 --- a/test/copy/copy_test.cpp +++ b/test/copy/copy_test.cpp @@ -105,11 +105,6 @@ void CopyTest::BMExceptionRecoveryTest(BMExceptionRecoveryTestConfig cfg) { for (int i = 0;; i++) { ASSERT_LT(i, 20); - - const auto queryString = common::stringFormat( - "COPY account FROM \"{}/dataset/snap/twitter/csv/twitter-nodes.csv\"", - KUZU_ROOT_DIRECTORY); - auto result = cfg.executeFunc(conn.get(), i); if (!result->isSuccess()) { if (cfg.earlyExitOnFailureFunc(result.get())) { @@ -130,7 +125,7 @@ void CopyTest::BMExceptionRecoveryTest(BMExceptionRecoveryTestConfig cfg) { auto result = cfg.checkFunc(conn.get()); ASSERT_TRUE(result->isSuccess()) << result->getErrorMessage(); ASSERT_TRUE(result->hasNext()); - ASSERT_EQ(result->getNext()->getValue(0)->getValue(), cfg.checkResult); + ASSERT_EQ(cfg.checkResult, result->getNext()->getValue(0)->getValue()); } } @@ -217,6 +212,35 @@ TEST_F(CopyTest, NodeInsertBMExceptionDuringCommitRecovery) { BMExceptionRecoveryTest(cfg); } +TEST_F(CopyTest, RelInsertBMExceptionDuringCommitRecovery) { + static constexpr auto numNodes = 10000; + BMExceptionRecoveryTestConfig cfg{.canFailDuringExecute = false, + .canFailDuringCheckpoint = false, + .initFunc = + [this](main::Connection* conn) { + failureFrequency = 128; + conn->query("CREATE NODE TABLE account(ID INT64, PRIMARY KEY(ID))"); + conn->query("CREATE REL TABLE follows(FROM account TO account);"); + const auto queryString = common::stringFormat( + 
"UNWIND RANGE(1,{}) AS i CREATE (a:account {ID:i})", numNodes); + ASSERT_TRUE(conn->query(queryString)->isSuccess()); + }, + .executeFunc = + [](main::Connection* conn, int) { + return conn->query(common::stringFormat( + "UNWIND RANGE(1,{}) AS i MATCH (a:account), (b:account) WHERE a.ID = i AND " + "b.ID = i + 1 CREATE (a)-[f:follows]->(b)", + numNodes)); + }, + .earlyExitOnFailureFunc = [](main::QueryResult*) { return false; }, + .checkFunc = + [](main::Connection* conn) { + return conn->query("MATCH (a)-[f:follows]->(b) RETURN COUNT(*)"); + }, + .checkResult = numNodes - 1}; + BMExceptionRecoveryTest(cfg); +} + TEST_F(CopyTest, OutOfMemoryRecovery) { if (inMemMode) { GTEST_SKIP(); diff --git a/test/test_files/transaction/create_rel/create_tinysnb.test b/test/test_files/transaction/create_rel/create_tinysnb.test index 3fcb809be4d..50ab5f68ce4 100644 --- a/test/test_files/transaction/create_rel/create_tinysnb.test +++ b/test/test_files/transaction/create_rel/create_tinysnb.test @@ -34,6 +34,21 @@ ---- error Runtime exception: Node(nodeOffset: 7) has more than one neighbour in table marries in the bwd direction, which violates the rel multiplicity constraint. 
+# Retry +-STATEMENT MATCH (a:person)-[m:marries]->(b:person) RETURN a.ID, b.ID +---- 3 +0|2 +3|5 +7|8 +-STATEMENT MATCH (a:person), (b:person) WHERE a.ID = 9 AND b.ID = 10 CREATE (a)-[:marries]->(b) +---- ok +-STATEMENT MATCH (a:person)-[m:marries]->(b:person) RETURN a.ID, b.ID +---- 4 +0|2 +3|5 +7|8 +9|10 + -LOG Bwd -STATEMENT BEGIN TRANSACTION ---- ok From 07cef58ae6b57bd475ea38cf0d0d367395097845 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Thu, 21 Nov 2024 12:02:09 -0500 Subject: [PATCH 16/28] Reclaim overflow slots in in-mem hash index after delete --- src/include/storage/index/in_mem_hash_index.h | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/include/storage/index/in_mem_hash_index.h b/src/include/storage/index/in_mem_hash_index.h index 5c6817fe3cb..5e739849658 100644 --- a/src/include/storage/index/in_mem_hash_index.h +++ b/src/include/storage/index/in_mem_hash_index.h @@ -188,6 +188,11 @@ class InMemHashIndex final { } else { iter.slot->header.setEntryInvalid(*deletedPos); } + + if (newIter.slot->header.numEntries() == 0) { + reclaimOverflowSlots(SlotIterator(slotId, this)); + } + return true; } return false; @@ -196,18 +201,9 @@ class InMemHashIndex final { private: SlotIterator getLastValidEntry(const SlotIterator& startIter) { auto curIter = startIter; - auto newIter = startIter; - while (nextChainedSlot(curIter)) { - if (curIter.slotInfo.slotId == SlotHeader::INVALID_OVERFLOW_SLOT_ID) { - break; - } - if (curIter.slot->header.numEntries() == 0) { - // if the current overflow slot is empty if last valid entry is in the previous slot - break; - } - newIter = curIter; - } - return newIter; + while (curIter.slot->header.nextOvfSlotId != SlotHeader::INVALID_OVERFLOW_SLOT_ID && + nextChainedSlot(curIter)) {} + return curIter; } // Assumes that space has already been allocated for the entry From a91e4482dd332ec223131dca21eb74fcae9312d5 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Fri, 29 Nov 2024 10:29:25 -0500 
Subject: [PATCH 17/28] Address review comments --- src/include/storage/store/group_collection.h | 4 +++- src/include/storage/store/version_info.h | 4 ++-- src/include/storage/undo_buffer.h | 2 +- src/storage/store/chunked_node_group.cpp | 8 ++++---- src/storage/store/csr_chunked_node_group.cpp | 2 +- src/storage/store/node_group.cpp | 2 +- src/storage/store/node_table.cpp | 16 ++++++++-------- src/storage/store/version_info.cpp | 8 ++++---- src/storage/undo_buffer.cpp | 2 +- src/transaction/transaction.cpp | 2 +- 10 files changed, 26 insertions(+), 24 deletions(-) diff --git a/src/include/storage/store/group_collection.h b/src/include/storage/store/group_collection.h index 8aa7185397e..c6c5dc34bb8 100644 --- a/src/include/storage/store/group_collection.h +++ b/src/include/storage/store/group_collection.h @@ -25,7 +25,9 @@ class GroupCollection { [&](common::Deserializer& deser) { return T::deserialize(memoryManager, deser); }); } - void removeTrailingGroups(const common::UniqLock&, common::idx_t numGroupsToRemove) { + void removeTrailingGroups([[maybe_unused]] const common::UniqLock& lock, + common::idx_t numGroupsToRemove) { + KU_ASSERT(lock.isLocked()); KU_ASSERT(numGroupsToRemove <= groups.size()); groups.erase(groups.end() - numGroupsToRemove, groups.end()); } diff --git a/src/include/storage/store/version_info.h b/src/include/storage/store/version_info.h index a8a2f82203f..a4753c0e1ed 100644 --- a/src/include/storage/store/version_info.h +++ b/src/include/storage/store/version_info.h @@ -86,9 +86,9 @@ class VersionInfo { public: VersionInfo() {} - void append(const transaction::Transaction* transaction, common::row_idx_t startRow, + void append(common::transaction_t transactionID, common::row_idx_t startRow, common::row_idx_t numRows); - bool delete_(const transaction::Transaction* transaction, common::row_idx_t rowIdx); + bool delete_(common::transaction_t transactionID, common::row_idx_t rowIdx); void getSelVectorToScan(common::transaction_t startTS, 
common::transaction_t transactionID, common::SelectionVector& selVector, common::row_idx_t startRow, diff --git a/src/include/storage/undo_buffer.h b/src/include/storage/undo_buffer.h index 6fa6e318828..d9d0624d841 100644 --- a/src/include/storage/undo_buffer.h +++ b/src/include/storage/undo_buffer.h @@ -95,7 +95,7 @@ class UndoBuffer { VectorUpdateInfo* vectorUpdateInfo); void commit(common::transaction_t commitTS) const; - void rollback(const transaction::Transaction* transaction); + void rollback(); uint64_t getMemUsage() const; diff --git a/src/storage/store/chunked_node_group.cpp b/src/storage/store/chunked_node_group.cpp index 4e2c3971162..34664877e83 100644 --- a/src/storage/store/chunked_node_group.cpp +++ b/src/storage/store/chunked_node_group.cpp @@ -137,7 +137,7 @@ uint64_t ChunkedNodeGroup::append(const Transaction* transaction, if (!versionInfo) { versionInfo = std::make_unique(); } - versionInfo->append(transaction, numRows, numRowsToAppendInChunk); + versionInfo->append(transaction->getID(), numRows, numRowsToAppendInChunk); } numRows += numRowsToAppendInChunk; return numRowsToAppendInChunk; @@ -168,7 +168,7 @@ offset_t ChunkedNodeGroup::append(const Transaction* transaction, if (!versionInfo) { versionInfo = std::make_unique(); } - versionInfo->append(transaction, numRows, numToAppendInChunkedGroup); + versionInfo->append(transaction->getID(), numRows, numToAppendInChunkedGroup); } numRows += numToAppendInChunkedGroup; return numToAppendInChunkedGroup; @@ -334,7 +334,7 @@ bool ChunkedNodeGroup::delete_(const Transaction* transaction, row_idx_t rowIdxI if (!versionInfo) { versionInfo = std::make_unique(); } - return versionInfo->delete_(transaction, rowIdxInChunk); + return versionInfo->delete_(transaction->getID(), rowIdxInChunk); } void ChunkedNodeGroup::addColumn(Transaction* transaction, @@ -396,7 +396,7 @@ std::unique_ptr ChunkedNodeGroup::flushAsNewChunkedNodeGroup( std::make_unique(std::move(flushedChunks), 0 /*startRowIdx*/); 
flushedChunkedGroup->versionInfo = std::make_unique(); KU_ASSERT(flushedChunkedGroup->getNumRows() == numRows); - flushedChunkedGroup->versionInfo->append(transaction, 0, numRows); + flushedChunkedGroup->versionInfo->append(transaction->getID(), 0, numRows); return flushedChunkedGroup; } diff --git a/src/storage/store/csr_chunked_node_group.cpp b/src/storage/store/csr_chunked_node_group.cpp index 40978905585..b6e405fd0ae 100644 --- a/src/storage/store/csr_chunked_node_group.cpp +++ b/src/storage/store/csr_chunked_node_group.cpp @@ -245,7 +245,7 @@ std::unique_ptr ChunkedCSRNodeGroup::flushAsNewChunkedNodeGrou std::move(flushedChunks), 0 /*startRowIdx*/); flushedChunkedGroup->versionInfo = std::make_unique(); KU_ASSERT(numRows == flushedChunkedGroup->getNumRows()); - flushedChunkedGroup->versionInfo->append(transaction, 0, numRows); + flushedChunkedGroup->versionInfo->append(transaction->getID(), 0, numRows); return flushedChunkedGroup; } diff --git a/src/storage/store/node_group.cpp b/src/storage/store/node_group.cpp index da6ae51e1d6..3cf0b1f70ca 100644 --- a/src/storage/store/node_group.cpp +++ b/src/storage/store/node_group.cpp @@ -488,7 +488,7 @@ std::unique_ptr NodeGroup::checkpointVersionInfo(const UniqLock& lo // TODO(Guodong): Optimize the for loop here to directly acess the version info. 
for (auto i = 0u; i < chunkedGroup->getNumRows(); i++) { if (chunkedGroup->isDeleted(transaction, i)) { - checkpointVersionInfo->delete_(transaction, currRow + i); + checkpointVersionInfo->delete_(transaction->getID(), currRow + i); } } } diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index 511603f2885..da78d7dd09e 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -57,9 +57,9 @@ bool NodeTableScanState::scanNext(Transaction* transaction, offset_t startOffset namespace { -struct CommittedPKInserter : public PKColumnScanHelper { +struct UncommittedPKInserter : public PKColumnScanHelper { public: - CommittedPKInserter(row_idx_t startNodeOffset, table_id_t tableID, PrimaryKeyIndex* pkIndex, + UncommittedPKInserter(row_idx_t startNodeOffset, table_id_t tableID, PrimaryKeyIndex* pkIndex, visible_func isVisible) : PKColumnScanHelper(pkIndex, tableID), startNodeOffset(startNodeOffset), nodeIDVector(LogicalType::INTERNAL_ID()), isVisible(std::move(isVisible)) {} @@ -111,7 +111,7 @@ void insertPK(const Transaction* transaction, const ValueVector& nodeIDVector, } } -std::unique_ptr CommittedPKInserter::initPKScanState(DataChunk& dataChunk, +std::unique_ptr UncommittedPKInserter::initPKScanState(DataChunk& dataChunk, column_id_t pkColumnID, const std::vector>& columns) { auto scanState = PKColumnScanHelper::initPKScanState(dataChunk, pkColumnID, columns); nodeIDVector.setState(dataChunk.state); @@ -119,7 +119,7 @@ std::unique_ptr CommittedPKInserter::initPKScanState(DataChu return scanState; } -bool CommittedPKInserter::processScanOutput(const transaction::Transaction* transaction, +bool UncommittedPKInserter::processScanOutput(const transaction::Transaction* transaction, NodeGroupScanResult scanResult, const common::ValueVector& scannedVector) { if (scanResult == NODE_GROUP_SCAN_EMMPTY_RESULT) { return false; @@ -508,11 +508,11 @@ void NodeTable::commit(Transaction* transaction, LocalTable* localTable) { 
} // 3. Scan pk column for newly inserted tuples that are not deleted and insert into pk index. - CommittedPKInserter scanHelper{startNodeOffset, tableID, pkIndex.get(), + UncommittedPKInserter pkInserter{startNodeOffset, tableID, pkIndex.get(), getVisibleFunc(transaction)}; // We need to scan from local storage here because some tuples in local node groups might // have been deleted. - scanPKColumn(transaction, scanHelper, localNodeTable.getNodeGroups()); + scanPKColumn(transaction, pkInserter, localNodeTable.getNodeGroups()); // 4. Clear local table. localTable->clear(); @@ -558,8 +558,8 @@ void NodeTable::rollbackInsert(const transaction::Transaction* transaction, startNodeOffset += nodeGroups->getNodeGroupNoLock(i)->getNumRows(); } - RollbackPKDeleter scanHelper{startNodeOffset, numRows_, tableID, pkIndex.get()}; - scanPKColumn(transaction, scanHelper, *nodeGroups); + RollbackPKDeleter pkDeleter{startNodeOffset, numRows_, tableID, pkIndex.get()}; + scanPKColumn(transaction, pkDeleter, *nodeGroups); } TableStats NodeTable::getStats(const Transaction* transaction) const { diff --git a/src/storage/store/version_info.cpp b/src/storage/store/version_info.cpp index aa68d64f2a8..e2a9f3f023c 100644 --- a/src/storage/store/version_info.cpp +++ b/src/storage/store/version_info.cpp @@ -352,7 +352,7 @@ VectorVersionInfo* VersionInfo::getVectorVersionInfo(idx_t vectorIdx) const { return vectorsInfo[vectorIdx].get(); } -void VersionInfo::append(const transaction::Transaction* transaction, const row_idx_t startRow, +void VersionInfo::append(transaction_t transactionID, const row_idx_t startRow, const row_idx_t numRows) { if (numRows == 0) { return; @@ -367,11 +367,11 @@ void VersionInfo::append(const transaction::Transaction* transaction, const row_ const auto endRowIdx = vectorIdx == endVectorIdx ? 
endRowIdxInVector : DEFAULT_VECTOR_CAPACITY - 1; const auto numRowsInVector = endRowIdx - startRowIdx + 1; - vectorVersionInfo.append(transaction->getID(), startRowIdx, numRowsInVector); + vectorVersionInfo.append(transactionID, startRowIdx, numRowsInVector); } } -bool VersionInfo::delete_(const transaction::Transaction* transaction, const row_idx_t rowIdx) { +bool VersionInfo::delete_(transaction_t transactionID, const row_idx_t rowIdx) { auto [vectorIdx, rowIdxInVector] = StorageUtils::getQuotientRemainder(rowIdx, DEFAULT_VECTOR_CAPACITY); auto& vectorVersionInfo = getOrCreateVersionInfo(vectorIdx); @@ -381,7 +381,7 @@ bool VersionInfo::delete_(const transaction::Transaction* transaction, const row // ALWAYS_INSERTED to avoid checking the version in the future. vectorVersionInfo.insertionStatus = VectorVersionInfo::InsertionStatus::ALWAYS_INSERTED; } - return vectorVersionInfo.delete_(transaction->getID(), rowIdxInVector); + return vectorVersionInfo.delete_(transactionID, rowIdxInVector); } void VersionInfo::getSelVectorToScan(const transaction_t startTS, const transaction_t transactionID, diff --git a/src/storage/undo_buffer.cpp b/src/storage/undo_buffer.cpp index 27fa98314f0..0bf33b792e1 100644 --- a/src/storage/undo_buffer.cpp +++ b/src/storage/undo_buffer.cpp @@ -166,7 +166,7 @@ void UndoBuffer::commit(transaction_t commitTS) const { }); } -void UndoBuffer::rollback(const transaction::Transaction* transaction) { +void UndoBuffer::rollback() { UndoBufferIterator iterator{*this}; iterator.reverseIterate([&](UndoRecordType entryType, uint8_t const* entry) { rollbackRecord(transaction, entryType, entry); diff --git a/src/transaction/transaction.cpp b/src/transaction/transaction.cpp index 4d5636f3413..7ec55dfcf43 100644 --- a/src/transaction/transaction.cpp +++ b/src/transaction/transaction.cpp @@ -63,7 +63,7 @@ void Transaction::commit(storage::WAL* wal) const { void Transaction::rollback(storage::WAL* wal) const { localStorage->rollback(); - 
undoBuffer->rollback(this); + undoBuffer->rollback(); if (isWriteTransaction() && shouldLogToWAL()) { KU_ASSERT(wal); wal->logRollback(); From 9134abb0d12d6ea3590b8094e983859b9e6ac05d Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Fri, 29 Nov 2024 12:05:29 -0500 Subject: [PATCH 18/28] Replace construct iterator callback with virtual class --- .../store/chunked_group_undo_iterator.h | 25 +++++--- src/include/storage/store/csr_node_group.h | 4 +- src/include/storage/store/node_group.h | 4 +- .../storage/store/node_group_collection.h | 10 +-- src/include/storage/store/node_table.h | 19 ++++-- src/include/storage/store/rel_table_data.h | 29 ++++++++- src/include/storage/undo_buffer.h | 9 ++- src/include/transaction/transaction.h | 11 ++-- src/storage/store/csr_node_group.cpp | 6 +- src/storage/store/node_group.cpp | 6 +- src/storage/store/node_group_collection.cpp | 18 +++--- src/storage/store/node_table.cpp | 27 ++++---- src/storage/store/rel_table_data.cpp | 64 +++++++++++-------- src/storage/undo_buffer.cpp | 35 +++++----- src/transaction/transaction.cpp | 8 +-- 15 files changed, 167 insertions(+), 108 deletions(-) diff --git a/src/include/storage/store/chunked_group_undo_iterator.h b/src/include/storage/store/chunked_group_undo_iterator.h index bae490c5a51..6f52284e8ed 100644 --- a/src/include/storage/store/chunked_group_undo_iterator.h +++ b/src/include/storage/store/chunked_group_undo_iterator.h @@ -13,26 +13,26 @@ class Transaction; namespace storage { class ChunkedNodeGroup; class NodeGroupCollection; -class ChunkedGroupUndoIterator; +class VersionRecordHandler; -using chunked_group_undo_op_t = void ( +using version_record_handler_op_t = void ( ChunkedNodeGroup::*)(common::row_idx_t, common::row_idx_t, common::transaction_t); -using chunked_group_iterator_construct_t = std::function( +using version_record_handler_construct_t = std::function( common::row_idx_t, common::row_idx_t, common::node_group_idx_t, common::transaction_t)>; // Note: these iterators are not 
necessarily thread-safe when used on their own -class ChunkedGroupUndoIterator { +class VersionRecordHandler { public: - ChunkedGroupUndoIterator(NodeGroupCollection* nodeGroups, common::row_idx_t startRow, + VersionRecordHandler(NodeGroupCollection* nodeGroups, common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS) : startRow(startRow), numRows(numRows), commitTS(commitTS), nodeGroups(nodeGroups) {} - virtual ~ChunkedGroupUndoIterator() = default; + virtual ~VersionRecordHandler() = default; virtual void initRollbackInsert(const transaction::Transaction* /*transaction*/) {} - virtual void finalizeRollbackInsert() {}; - virtual void iterate(chunked_group_undo_op_t undoFunc) = 0; + virtual void finalizeRollbackInsert(){}; + virtual void applyFuncToChunkedGroups(version_record_handler_op_t func) = 0; protected: common::row_idx_t startRow; @@ -42,5 +42,14 @@ class ChunkedGroupUndoIterator { NodeGroupCollection* nodeGroups; }; +class VersionRecordHandlerData { +public: + virtual ~VersionRecordHandlerData() = default; + + virtual std::unique_ptr constructVersionRecordHandler( + common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS, + common::node_group_idx_t nodeGroupIdx) const = 0; +}; + } // namespace storage } // namespace kuzu diff --git a/src/include/storage/store/csr_node_group.h b/src/include/storage/store/csr_node_group.h index a064b42f44e..0a88477924e 100644 --- a/src/include/storage/store/csr_node_group.h +++ b/src/include/storage/store/csr_node_group.h @@ -165,12 +165,12 @@ static constexpr common::column_id_t REL_ID_COLUMN_ID = 1; struct RelTableScanState; class CSRNodeGroup final : public NodeGroup { public: - class PersistentIterator : public ChunkedGroupUndoIterator { + class PersistentIterator : public VersionRecordHandler { public: PersistentIterator(NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, 
common::transaction_t commitTS); - void iterate(chunked_group_undo_op_t undoFunc) override; + void applyFuncToChunkedGroups(version_record_handler_op_t func) override; void finalizeRollbackInsert() override; private: diff --git a/src/include/storage/store/node_group.h b/src/include/storage/store/node_group.h index 40e3d13a11f..54b2b2d22bb 100644 --- a/src/include/storage/store/node_group.h +++ b/src/include/storage/store/node_group.h @@ -82,11 +82,11 @@ static auto NODE_GROUP_SCAN_EMMPTY_RESULT = NodeGroupScanResult{}; struct TableScanState; class NodeGroup { public: - class ChunkedGroupIterator : public ChunkedGroupUndoIterator { + class ChunkedGroupIterator : public VersionRecordHandler { public: ChunkedGroupIterator(NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS); - void iterate(chunked_group_undo_op_t undoFunc) override; + void applyFuncToChunkedGroups(version_record_handler_op_t func) override; void finalizeRollbackInsert() override; protected: diff --git a/src/include/storage/store/node_group_collection.h b/src/include/storage/store/node_group_collection.h index f0964dcfa33..b2153a2ce68 100644 --- a/src/include/storage/store/node_group_collection.h +++ b/src/include/storage/store/node_group_collection.h @@ -16,7 +16,7 @@ class NodeGroupCollection { public: NodeGroupCollection(MemoryManager& memoryManager, const std::vector& types, bool enableCompression, FileHandle* dataFH = nullptr, common::Deserializer* deSer = nullptr, - const chunked_group_iterator_construct_t* iteratorConstructFunc = nullptr); + const VersionRecordHandlerData* versionRecordHandlerData = nullptr); void append(const transaction::Transaction* transaction, const std::vector& vectors); @@ -51,7 +51,7 @@ class NodeGroupCollection { } NodeGroup* getOrCreateNodeGroup(transaction::Transaction* transaction, common::node_group_idx_t groupIdx, NodeGroupDataFormat format, - const 
chunked_group_iterator_construct_t* constructIteratorFunc_); + const VersionRecordHandlerData* versionRecordHandlerData); void setNodeGroup(const common::node_group_idx_t nodeGroupIdx, std::unique_ptr group) { @@ -82,12 +82,12 @@ class NodeGroupCollection { void pushInsertInfo(const transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const chunked_group_iterator_construct_t* constructIteratorFunc_); + const VersionRecordHandlerData* overridedVersionRecordHandlerData); private: void pushInsertInfo(const transaction::Transaction* transaction, NodeGroup* nodeGroup, common::row_idx_t numRows, - const chunked_group_iterator_construct_t* constructIteratorOverrideFunc = nullptr); + const VersionRecordHandlerData* overridedVersionRecordHandlerData = nullptr); bool enableCompression; // Num rows in the collection regardless of deletions. @@ -96,7 +96,7 @@ class NodeGroupCollection { GroupCollection nodeGroups; FileHandle* dataFH; TableStats stats; - const chunked_group_iterator_construct_t* iteratorConstructFunc; + const VersionRecordHandlerData* versionRecordHandlerData; }; } // namespace storage diff --git a/src/include/storage/store/node_table.h b/src/include/storage/store/node_table.h index c3b1e5be8ea..3696dba10fb 100644 --- a/src/include/storage/store/node_table.h +++ b/src/include/storage/store/node_table.h @@ -21,6 +21,7 @@ class Transaction; } // namespace transaction namespace storage { +class NodeTable; struct NodeTableScanState final : TableScanState { // Scan state for un-committed data. 
@@ -95,6 +96,18 @@ struct PKColumnScanHelper { PrimaryKeyIndex* pkIndex; }; +class NodeTableVersionRecordHandlerData : public VersionRecordHandlerData { +public: + explicit NodeTableVersionRecordHandlerData(NodeTable* nodeTable) : nodeTable(nodeTable) {} + + std::unique_ptr constructVersionRecordHandler(common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS, + common::node_group_idx_t nodeGroupIdx) const override; + +private: + NodeTable* nodeTable; +}; + class StorageManager; class NodeTable final : public Table { public: @@ -201,10 +214,6 @@ class NodeTable final : public Table { TableStats getStats(const transaction::Transaction* transaction) const; - const chunked_group_iterator_construct_t& getIteratorConstructFunc() const { - return iteratorConstructFunc; - } - private: void validatePkNotExists(const transaction::Transaction* transaction, common::ValueVector* pkVector); @@ -221,7 +230,7 @@ class NodeTable final : public Table { std::unique_ptr nodeGroups; common::column_id_t pkColumnID; std::unique_ptr pkIndex; - chunked_group_iterator_construct_t iteratorConstructFunc; + NodeTableVersionRecordHandlerData versionRecordHandlerData; }; } // namespace storage diff --git a/src/include/storage/store/rel_table_data.h b/src/include/storage/store/rel_table_data.h index bf80c1da125..abe2d2550c3 100644 --- a/src/include/storage/store/rel_table_data.h +++ b/src/include/storage/store/rel_table_data.h @@ -14,12 +14,27 @@ class Transaction; } namespace storage { class MemoryManager; +class RelTableData; struct CSRHeaderColumns { std::unique_ptr offset; std::unique_ptr length; }; +class RelTableVersionRecordHandlerData : public VersionRecordHandlerData { +public: + RelTableVersionRecordHandlerData(RelTableData* relTableData, CSRNodeGroupScanSource source) + : relTableData(relTableData), source(source) {} + + std::unique_ptr constructVersionRecordHandler(common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS, 
+ common::node_group_idx_t nodeGroupIdx) const override; + +private: + RelTableData* relTableData; + CSRNodeGroupScanSource source; +}; + class RelTableData { public: RelTableData(FileHandle* dataFH, MemoryManager* mm, ShadowFile* shadowFile, @@ -56,7 +71,7 @@ class RelTableData { NodeGroup* getOrCreateNodeGroup(transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx) const { return nodeGroups->getOrCreateNodeGroup(transaction, nodeGroupIdx, NodeGroupDataFormat::CSR, - &persistentIteratorConstructFunc); + &persistentVersionRecordHandlerData); } common::RelMultiplicity getMultiplicity() const { return multiplicity; } @@ -70,6 +85,11 @@ class RelTableData { void serialize(common::Serializer& serializer) const; + std::unique_ptr constructVersionRecordHandler( + CSRNodeGroupScanSource source, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows, + common::transaction_t commitTS) const; + private: void initCSRHeaderColumns(); void initPropertyColumns(const catalog::TableCatalogEntry* tableEntry); @@ -98,6 +118,9 @@ class RelTableData { return types; } + const RelTableVersionRecordHandlerData* getVersionRecordHandlerData( + CSRNodeGroupScanSource source); + private: FileHandle* dataFH; common::table_id_t tableID; @@ -114,8 +137,8 @@ class RelTableData { CSRHeaderColumns csrHeaderColumns; std::vector> columns; - chunked_group_iterator_construct_t inMemIteratorConstructFunc; - chunked_group_iterator_construct_t persistentIteratorConstructFunc; + RelTableVersionRecordHandlerData persistentVersionRecordHandlerData; + RelTableVersionRecordHandlerData inMemoryVersionRecordHandlerData; }; } // namespace storage diff --git a/src/include/storage/undo_buffer.h b/src/include/storage/undo_buffer.h index d9d0624d841..98c836101d8 100644 --- a/src/include/storage/undo_buffer.h +++ b/src/include/storage/undo_buffer.h @@ -88,9 +88,11 @@ class UndoBuffer { void createSequenceChange(catalog::SequenceCatalogEntry& sequenceEntry, 
const catalog::SequenceRollbackData& data); void createInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, const chunked_group_iterator_construct_t* iteratorConstructFunc); + common::row_idx_t numRows, + const storage::VersionRecordHandlerData* versionRecordHandlerData); void createDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, const chunked_group_iterator_construct_t* iteratorConstructFunc); + common::row_idx_t numRows, + const storage::VersionRecordHandlerData* versionRecordHandlerData); void createVectorUpdateInfo(UpdateInfo* updateInfo, common::idx_t vectorIdx, VectorUpdateInfo* vectorUpdateInfo); @@ -103,7 +105,8 @@ class UndoBuffer { uint8_t* createUndoRecord(uint64_t size); void createVersionInfo(UndoRecordType recordType, common::row_idx_t startRow, - common::row_idx_t numRows, const chunked_group_iterator_construct_t* iteratorConstructFunc, + common::row_idx_t numRows, + const storage::VersionRecordHandlerData* versionRecordHandlerData, common::node_group_idx_t nodeGroupIdx = 0); void commitRecord(UndoRecordType recordType, const uint8_t* record, diff --git a/src/include/transaction/transaction.h b/src/include/transaction/transaction.h index 363fdc09066..2ac04b68dd2 100644 --- a/src/include/transaction/transaction.h +++ b/src/include/transaction/transaction.h @@ -23,16 +23,13 @@ class VersionInfo; class UpdateInfo; struct VectorUpdateInfo; class ChunkedNodeGroup; -class ChunkedGroupUndoIterator; +class VersionRecordHandler; +class VersionRecordHandlerData; } // namespace storage namespace transaction { class TransactionManager; class Transaction; -using chunked_group_iterator_construct_t = - std::function(common::row_idx_t, - common::row_idx_t, common::node_group_idx_t, common::transaction_t commitTS)>; - enum class TransactionType : uint8_t { READ_ONLY, WRITE, CHECKPOINT, DUMMY, RECOVERY }; class Transaction { @@ -123,10 +120,10 @@ class 
Transaction { const catalog::SequenceRollbackData& data) const; void pushInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const chunked_group_iterator_construct_t* constructIteratorFunc = nullptr) const; + const storage::VersionRecordHandlerData* versionRecordHandlerData) const; void pushDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const chunked_group_iterator_construct_t* constructIteratorFunc) const; + const storage::VersionRecordHandlerData* versionRecordHandlerData) const; void pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, common::idx_t vectorIdx, storage::VectorUpdateInfo& vectorUpdateInfo) const; diff --git a/src/storage/store/csr_node_group.cpp b/src/storage/store/csr_node_group.cpp index d68e40edf5b..300c7adea73 100644 --- a/src/storage/store/csr_node_group.cpp +++ b/src/storage/store/csr_node_group.cpp @@ -15,15 +15,15 @@ namespace storage { CSRNodeGroup::PersistentIterator::PersistentIterator(NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS) - : ChunkedGroupUndoIterator(nodeGroups, startRow, numRows, commitTS), nodeGroup(nullptr) { + : VersionRecordHandler(nodeGroups, startRow, numRows, commitTS), nodeGroup(nullptr) { if (nodeGroupIdx < nodeGroups->getNumNodeGroups()) { nodeGroup = ku_dynamic_cast(nodeGroups->getNodeGroupNoLock(nodeGroupIdx)); } } -void CSRNodeGroup::PersistentIterator::iterate(chunked_group_undo_op_t undoFunc) { +void CSRNodeGroup::PersistentIterator::applyFuncToChunkedGroups(version_record_handler_op_t func) { if (nodeGroup && nodeGroup->persistentChunkGroup) { - std::invoke(undoFunc, *nodeGroup->persistentChunkGroup, startRow, numRows, commitTS); + std::invoke(func, *nodeGroup->persistentChunkGroup, startRow, numRows, commitTS); } } diff --git a/src/storage/store/node_group.cpp 
b/src/storage/store/node_group.cpp index 3cf0b1f70ca..9995b635f83 100644 --- a/src/storage/store/node_group.cpp +++ b/src/storage/store/node_group.cpp @@ -24,13 +24,13 @@ namespace storage { NodeGroup::ChunkedGroupIterator::ChunkedGroupIterator(NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, transaction_t commitTS) - : ChunkedGroupUndoIterator(nodeGroups, startRow, numRows, commitTS), + : VersionRecordHandler(nodeGroups, startRow, numRows, commitTS), nodeGroup(nodeGroups->getNodeGroupNoLock(nodeGroupIdx)), numRowsToRollback(std::min(numRows, nodeGroup->getNumRows() - startRow)) { KU_ASSERT(startRow <= nodeGroup->getNumRows()); } -void NodeGroup::ChunkedGroupIterator::iterate(chunked_group_undo_op_t undoFunc) { +void NodeGroup::ChunkedGroupIterator::applyFuncToChunkedGroups(version_record_handler_op_t func) { auto lock = nodeGroup->chunkedGroups.lock(); const auto [chunkedGroupIdx, startRowInChunkedGroup] = nodeGroup->findChunkedGroupIdxFromRowIdxNoLock(startRow); @@ -44,7 +44,7 @@ void NodeGroup::ChunkedGroupIterator::iterate(chunked_group_undo_op_t undoFunc) auto* chunkedGroup = nodeGroup->chunkedGroups.getGroup(lock, curChunkedGroupIdx); const auto numRowsForGroup = std::min(numRowsLeft, chunkedGroup->getNumRows() - curStartRowIdxInChunk); - std::invoke(undoFunc, *chunkedGroup, curStartRowIdxInChunk, numRowsForGroup, commitTS); + std::invoke(func, *chunkedGroup, curStartRowIdxInChunk, numRowsForGroup, commitTS); ++curChunkedGroupIdx; numRowsLeft -= numRowsForGroup; diff --git a/src/storage/store/node_group_collection.cpp b/src/storage/store/node_group_collection.cpp index f115c475454..257554b05f7 100644 --- a/src/storage/store/node_group_collection.cpp +++ b/src/storage/store/node_group_collection.cpp @@ -14,9 +14,9 @@ namespace storage { NodeGroupCollection::NodeGroupCollection(MemoryManager& memoryManager, const std::vector& types, const bool enableCompression, FileHandle* dataFH, - 
Deserializer* deSer, const chunked_group_iterator_construct_t* iteratorConstructFunc) + Deserializer* deSer, const VersionRecordHandlerData* versionRecordHandlerData) : enableCompression{enableCompression}, numTotalRows{0}, types{LogicalType::copy(types)}, - dataFH{dataFH}, iteratorConstructFunc(iteratorConstructFunc) { + dataFH{dataFH}, versionRecordHandlerData(versionRecordHandlerData) { if (deSer) { deserialize(*deSer, memoryManager); } @@ -155,7 +155,7 @@ row_idx_t NodeGroupCollection::getNumTotalRows() { NodeGroup* NodeGroupCollection::getOrCreateNodeGroup(transaction::Transaction* transaction, node_group_idx_t groupIdx, NodeGroupDataFormat format, - const chunked_group_iterator_construct_t* constructIteratorFunc_) { + const VersionRecordHandlerData* versionRecordHandlerData) { const auto lock = nodeGroups.lock(); while (groupIdx >= nodeGroups.getNumGroups(lock)) { const auto currentGroupIdx = nodeGroups.getNumGroups(lock); @@ -166,7 +166,7 @@ NodeGroup* NodeGroupCollection::getOrCreateNodeGroup(transaction::Transaction* t enableCompression, LogicalType::copy(types))); // push an insert of size 0 so that we can rollback the creation of this node group if // needed - pushInsertInfo(transaction, nodeGroups.getLastGroup(lock), 0, constructIteratorFunc_); + pushInsertInfo(transaction, nodeGroups.getLastGroup(lock), 0, versionRecordHandlerData); } KU_ASSERT(groupIdx < nodeGroups.getNumGroups(lock)); return nodeGroups.getGroup(lock, groupIdx); @@ -213,17 +213,19 @@ void NodeGroupCollection::rollbackInsert(common::row_idx_t numRows_, bool update void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transaction, NodeGroup* nodeGroup, common::row_idx_t numRows, - const chunked_group_iterator_construct_t* constructIteratorOverrideFunc) { + const VersionRecordHandlerData* overridedVersionRecordHandlerData) { pushInsertInfo(transaction, nodeGroup->getNodeGroupIdx(), nodeGroup->getNumRows(), numRows, - constructIteratorOverrideFunc ? 
constructIteratorOverrideFunc : iteratorConstructFunc); + overridedVersionRecordHandlerData ? overridedVersionRecordHandlerData : + versionRecordHandlerData); }; void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const chunked_group_iterator_construct_t* constructIteratorFunc_) { + const VersionRecordHandlerData* overridedVersionRecordHandlerData) { // we only append to the undo buffer if the node group collection is persistent if (dataFH && transaction->shouldAppendToUndoBuffer()) { - transaction->pushInsertInfo(nodeGroupIdx, startRow, numRows, constructIteratorFunc_); + transaction->pushInsertInfo(nodeGroupIdx, startRow, numRows, + overridedVersionRecordHandlerData); } } diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index da78d7dd09e..cd22b44f2af 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -21,6 +21,14 @@ using namespace kuzu::evaluator; namespace kuzu { namespace storage { +std::unique_ptr +NodeTableVersionRecordHandlerData::constructVersionRecordHandler(common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS, + common::node_group_idx_t nodeGroupIdx) const { + return std::make_unique(nodeTable, nodeGroupIdx, startRow, + numRows, commitTS); +} + NodeTable::ChunkedGroupIterator::ChunkedGroupIterator(NodeTable* table, node_group_idx_t nodeGroupidx, common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS) @@ -189,7 +197,8 @@ NodeTable::NodeTable(const StorageManager* storageManager, const NodeTableCatalogEntry* nodeTableEntry, MemoryManager* memoryManager, VirtualFileSystem* vfs, main::ClientContext* context, Deserializer* deSer) : Table{nodeTableEntry, storageManager, memoryManager}, - pkColumnID{nodeTableEntry->getColumnID(nodeTableEntry->getPrimaryKeyName())} { + 
pkColumnID{nodeTableEntry->getColumnID(nodeTableEntry->getPrimaryKeyName())}, + versionRecordHandlerData(this) { const auto maxColumnID = nodeTableEntry->getMaxColumnID(); columns.resize(maxColumnID + 1); for (auto i = 0u; i < nodeTableEntry->getNumProperties(); i++) { @@ -201,16 +210,8 @@ NodeTable::NodeTable(const StorageManager* storageManager, dataFH, memoryManager, shadowFile, enableCompression); } - iteratorConstructFunc = [this](common::row_idx_t startRow, common::row_idx_t numRows_, - common::node_group_idx_t nodeGroupIdx_, - common::transaction_t commitTS) { - return std::make_unique(this, nodeGroupIdx_, startRow, numRows_, - commitTS); - }; - - nodeGroups = - std::make_unique(*memoryManager, getNodeTableColumnTypes(*this), - enableCompression, storageManager->getDataFH(), deSer, &iteratorConstructFunc); + nodeGroups = std::make_unique(*memoryManager, + getNodeTableColumnTypes(*this), enableCompression, storageManager->getDataFH(), deSer); initializePKIndex(storageManager->getDatabasePath(), nodeTableEntry, storageManager->isReadOnly(), vfs, context); } @@ -427,7 +428,7 @@ bool NodeTable::delete_(Transaction* transaction, TableDeleteState& deleteState) nodeOffset - StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); isDeleted = nodeGroups->getNodeGroup(nodeGroupIdx)->delete_(transaction, rowIdxInGroup); if (transaction->shouldAppendToUndoBuffer()) { - transaction->pushDeleteInfo(nodeGroupIdx, rowIdxInGroup, 1, &iteratorConstructFunc); + transaction->pushDeleteInfo(nodeGroupIdx, rowIdxInGroup, 1, &versionRecordHandlerData); } } if (isDeleted) { @@ -499,7 +500,7 @@ void NodeTable::commit(Transaction* transaction, LocalTable* localTable) { KU_ASSERT(isDeleted); if (transaction->shouldAppendToUndoBuffer()) { transaction->pushDeleteInfo(nodeGroupIdx, rowIdxInGroup, 1, - &iteratorConstructFunc); + &versionRecordHandlerData); } } } diff --git a/src/storage/store/rel_table_data.cpp b/src/storage/store/rel_table_data.cpp index 9a81bb76d1f..f2387ce6ce3 100644 --- 
a/src/storage/store/rel_table_data.cpp +++ b/src/storage/store/rel_table_data.cpp @@ -16,30 +16,26 @@ using namespace kuzu::transaction; namespace kuzu { namespace storage { +std::unique_ptr +RelTableVersionRecordHandlerData::constructVersionRecordHandler(common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS, + common::node_group_idx_t nodeGroupIdx) const { + return relTableData->constructVersionRecordHandler(source, nodeGroupIdx, startRow, numRows, + commitTS); +} + RelTableData::RelTableData(FileHandle* dataFH, MemoryManager* mm, ShadowFile* shadowFile, const TableCatalogEntry* tableEntry, RelDataDirection direction, bool enableCompression, Deserializer* deSer) : dataFH{dataFH}, tableID{tableEntry->getTableID()}, tableName{tableEntry->getName()}, memoryManager{mm}, shadowFile{shadowFile}, enableCompression{enableCompression}, - direction{direction} { + direction{direction}, + persistentVersionRecordHandlerData(this, CSRNodeGroupScanSource::COMMITTED_PERSISTENT), + inMemoryVersionRecordHandlerData(this, CSRNodeGroupScanSource::COMMITTED_IN_MEMORY) { multiplicity = tableEntry->constCast().getMultiplicity(direction); initCSRHeaderColumns(); initPropertyColumns(tableEntry); - inMemIteratorConstructFunc = [this](common::row_idx_t startRow, common::row_idx_t numRows_, - common::node_group_idx_t nodeGroupIdx_, - common::transaction_t commitTS) { - return std::make_unique(nodeGroups.get(), nodeGroupIdx_, - startRow, numRows_, commitTS); - }; - - persistentIteratorConstructFunc = [this](common::row_idx_t startRow, common::row_idx_t numRows_, - common::node_group_idx_t nodeGroupIdx_, - common::transaction_t commitTS) { - return std::make_unique(nodeGroups.get(), nodeGroupIdx_, - startRow, numRows_, commitTS); - }; - nodeGroups = std::make_unique(*mm, getColumnTypes(), enableCompression, dataFH, deSer); } @@ -113,11 +109,7 @@ bool RelTableData::delete_(Transaction* transaction, ValueVector& boundNodeIDVec auto& csrNodeGroup = 
getNodeGroup(nodeGroupIdx)->cast(); bool isDeleted = csrNodeGroup.delete_(transaction, source, rowIdx); if (isDeleted && transaction->shouldAppendToUndoBuffer()) { - const auto* constructIteratorFunc = - (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) ? - &persistentIteratorConstructFunc : - &inMemIteratorConstructFunc; - transaction->pushDeleteInfo(nodeGroupIdx, rowIdx, 1, constructIteratorFunc); + transaction->pushDeleteInfo(nodeGroupIdx, rowIdx, 1, getVersionRecordHandlerData(source)); } return isDeleted; } @@ -217,13 +209,12 @@ void RelTableData::pushInsertInfo(transaction::Transaction* transaction, !nodeGroup.getPersistentChunkedGroup() || nodeGroup.getPersistentChunkedGroup()->getNumRows() == 0); - const auto [startRow, constructIteratorFunc] = - (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) ? - std::make_pair(static_cast(0), &persistentIteratorConstructFunc) : - std::make_pair(nodeGroup.getNumRows(), &inMemIteratorConstructFunc); + const auto startRow = (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) ? 
+ static_cast(0) : + nodeGroup.getNumRows(); nodeGroups->pushInsertInfo(transaction, nodeGroup.getNodeGroupIdx(), startRow, numRows_, - constructIteratorFunc); + getVersionRecordHandlerData(source)); } void RelTableData::checkpoint(const std::vector& columnIDs) { @@ -248,5 +239,28 @@ void RelTableData::serialize(Serializer& serializer) const { nodeGroups->serialize(serializer); } +std::unique_ptr RelTableData::constructVersionRecordHandler( + CSRNodeGroupScanSource source, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS) const { + if (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) { + return std::make_unique(nodeGroups.get(), nodeGroupIdx, + startRow, numRows, commitTS); + } else { + KU_ASSERT(source == CSRNodeGroupScanSource::COMMITTED_IN_MEMORY); + return std::make_unique(nodeGroups.get(), nodeGroupIdx, + startRow, numRows, commitTS); + } +} + +const RelTableVersionRecordHandlerData* RelTableData::getVersionRecordHandlerData( + CSRNodeGroupScanSource source) { + if (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) { + return &persistentVersionRecordHandlerData; + } else { + KU_ASSERT(source == CSRNodeGroupScanSource::COMMITTED_IN_MEMORY); + return &inMemoryVersionRecordHandlerData; + } +} + } // namespace storage } // namespace kuzu diff --git a/src/storage/undo_buffer.cpp b/src/storage/undo_buffer.cpp index 0bf33b792e1..d30eb34343f 100644 --- a/src/storage/undo_buffer.cpp +++ b/src/storage/undo_buffer.cpp @@ -40,7 +40,7 @@ struct VersionRecord { row_idx_t startRow; row_idx_t numRows; node_group_idx_t nodeGroupIdx; - const chunked_group_iterator_construct_t* iteratorConstructFunc; + const storage::VersionRecordHandlerData* versionRecordHandlerData; }; struct VectorUpdateRecord { @@ -111,26 +111,26 @@ void UndoBuffer::createSequenceChange(SequenceCatalogEntry& sequenceEntry, } void UndoBuffer::createInsertInfo(node_group_idx_t nodeGroupIdx, row_idx_t startRow, - 
row_idx_t numRows, const chunked_group_iterator_construct_t* iteratorConstructFunc) { - createVersionInfo(UndoRecordType::INSERT_INFO, startRow, numRows, iteratorConstructFunc, + row_idx_t numRows, const storage::VersionRecordHandlerData* versionRecordHandlerData) { + createVersionInfo(UndoRecordType::INSERT_INFO, startRow, numRows, versionRecordHandlerData, nodeGroupIdx); } void UndoBuffer::createDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, const chunked_group_iterator_construct_t* iteratorConstructFunc) { - createVersionInfo(UndoRecordType::DELETE_INFO, startRow, numRows, iteratorConstructFunc, + common::row_idx_t numRows, const storage::VersionRecordHandlerData* versionRecordHandlerData) { + createVersionInfo(UndoRecordType::DELETE_INFO, startRow, numRows, versionRecordHandlerData, nodeGroupIdx); } void UndoBuffer::createVersionInfo(const UndoRecordType recordType, row_idx_t startRow, - row_idx_t numRows, const chunked_group_iterator_construct_t* iteratorConstructFunc, + row_idx_t numRows, const storage::VersionRecordHandlerData* versionRecordHandlerData, node_group_idx_t nodeGroupIdx) { auto buffer = createUndoRecord(sizeof(UndoRecordHeader) + sizeof(VersionRecord)); const UndoRecordHeader recordHeader{recordType, sizeof(VersionRecord)}; *reinterpret_cast(buffer) = recordHeader; buffer += sizeof(UndoRecordHeader); *reinterpret_cast(buffer) = - VersionRecord{startRow, numRows, nodeGroupIdx, iteratorConstructFunc}; + VersionRecord{startRow, numRows, nodeGroupIdx, versionRecordHandlerData}; } void UndoBuffer::createVectorUpdateInfo(UpdateInfo* updateInfo, const idx_t vectorIdx, @@ -215,14 +215,14 @@ void UndoBuffer::commitVersionInfo(UndoRecordType recordType, const uint8_t* rec const auto& undoRecord = *reinterpret_cast(record); switch (recordType) { case UndoRecordType::INSERT_INFO: { - auto it = (*undoRecord.iteratorConstructFunc)(undoRecord.startRow, undoRecord.numRows, - undoRecord.nodeGroupIdx, 
commitTS); - it->iterate(&ChunkedNodeGroup::commitInsert); + auto handler = undoRecord.versionRecordHandlerData->constructVersionRecordHandler( + undoRecord.startRow, undoRecord.numRows, commitTS, undoRecord.nodeGroupIdx); + handler->applyFuncToChunkedGroups(&ChunkedNodeGroup::commitInsert); } break; case UndoRecordType::DELETE_INFO: { - auto it = (*undoRecord.iteratorConstructFunc)(undoRecord.startRow, undoRecord.numRows, - undoRecord.nodeGroupIdx, commitTS); - it->iterate(&ChunkedNodeGroup::commitDelete); + auto handler = undoRecord.versionRecordHandlerData->constructVersionRecordHandler( + undoRecord.startRow, undoRecord.numRows, commitTS, undoRecord.nodeGroupIdx); + handler->applyFuncToChunkedGroups(&ChunkedNodeGroup::commitDelete); } break; default: { KU_UNREACHABLE; @@ -301,13 +301,14 @@ void UndoBuffer::rollbackVersionInfo(const transaction::Transaction* transaction auto it = (*undoRecord.iteratorConstructFunc)(undoRecord.startRow, undoRecord.numRows, undoRecord.nodeGroupIdx, transaction->getCommitTS()); it->initRollbackInsert(transaction); - it->iterate(&ChunkedNodeGroup::rollbackInsert); + it->applyFuncToChunkedGroups(&ChunkedNodeGroup::rollbackInsert); it->finalizeRollbackInsert(); } break; case UndoRecordType::DELETE_INFO: { - auto it = (*undoRecord.iteratorConstructFunc)(undoRecord.startRow, undoRecord.numRows, - undoRecord.nodeGroupIdx, transaction->getCommitTS()); - it->iterate(&ChunkedNodeGroup::rollbackDelete); + auto handler = + undoRecord.versionRecordHandlerData->constructVersionRecordHandler(undoRecord.startRow, + undoRecord.numRows, transaction->getCommitTS(), undoRecord.nodeGroupIdx); + handler->applyFuncToChunkedGroups(&ChunkedNodeGroup::rollbackDelete); } break; default: { KU_UNREACHABLE; diff --git a/src/transaction/transaction.cpp b/src/transaction/transaction.cpp index 7ec55dfcf43..50a3ed536be 100644 --- a/src/transaction/transaction.cpp +++ b/src/transaction/transaction.cpp @@ -175,14 +175,14 @@ void 
Transaction::pushSequenceChange(SequenceCatalogEntry* sequenceEntry, int64_ void Transaction::pushInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const chunked_group_iterator_construct_t* constructIteratorFunc) const { - undoBuffer->createInsertInfo(nodeGroupIdx, startRow, numRows, constructIteratorFunc); + const storage::VersionRecordHandlerData* versionRecordHandlerData) const { + undoBuffer->createInsertInfo(nodeGroupIdx, startRow, numRows, versionRecordHandlerData); } void Transaction::pushDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const chunked_group_iterator_construct_t* constructIteratorFunc) const { - undoBuffer->createDeleteInfo(nodeGroupIdx, startRow, numRows, constructIteratorFunc); + const storage::VersionRecordHandlerData* versionRecordHandlerData) const { + undoBuffer->createDeleteInfo(nodeGroupIdx, startRow, numRows, versionRecordHandlerData); } void Transaction::pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, From c3c638e4c90398a52478e56c667066e15f34494b Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Fri, 29 Nov 2024 13:10:05 -0500 Subject: [PATCH 19/28] Refactor version record handler --- src/include/storage/store/csr_node_group.h | 2 +- src/include/storage/store/node_group.h | 11 +++--- .../storage/store/node_group_collection.h | 10 +++--- src/include/storage/store/node_table.h | 10 +++--- src/include/storage/store/rel_table_data.h | 12 +++---- ...do_iterator.h => version_record_handler.h} | 19 +++++----- src/include/storage/undo_buffer.h | 6 ++-- src/include/transaction/transaction.h | 6 ++-- src/storage/store/csr_node_group.cpp | 3 +- src/storage/store/node_group.cpp | 17 ++++----- src/storage/store/node_group_collection.cpp | 18 +++++----- src/storage/store/node_table.cpp | 25 +++++++------ src/storage/store/rel_table_data.cpp | 21 +++++------ src/storage/undo_buffer.cpp | 35 ++++++++++--------- 
src/transaction/transaction.cpp | 8 ++--- 15 files changed, 102 insertions(+), 101 deletions(-) rename src/include/storage/store/{chunked_group_undo_iterator.h => version_record_handler.h} (73%) diff --git a/src/include/storage/store/csr_node_group.h b/src/include/storage/store/csr_node_group.h index 0a88477924e..9f4090a0f31 100644 --- a/src/include/storage/store/csr_node_group.h +++ b/src/include/storage/store/csr_node_group.h @@ -171,7 +171,7 @@ class CSRNodeGroup final : public NodeGroup { common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS); void applyFuncToChunkedGroups(version_record_handler_op_t func) override; - void finalizeRollbackInsert() override; + void rollbackInsert(const transaction::Transaction* transaction) override; private: CSRNodeGroup* nodeGroup; diff --git a/src/include/storage/store/node_group.h b/src/include/storage/store/node_group.h index 54b2b2d22bb..754298f3cb8 100644 --- a/src/include/storage/store/node_group.h +++ b/src/include/storage/store/node_group.h @@ -4,9 +4,9 @@ #include "common/uniq_lock.h" #include "storage/enums/residency_state.h" -#include "storage/store/chunked_group_undo_iterator.h" #include "storage/store/chunked_node_group.h" #include "storage/store/group_collection.h" +#include "storage/store/version_record_handler.h" namespace kuzu { namespace transaction { @@ -82,16 +82,15 @@ static auto NODE_GROUP_SCAN_EMMPTY_RESULT = NodeGroupScanResult{}; struct TableScanState; class NodeGroup { public: - class ChunkedGroupIterator : public VersionRecordHandler { + class NodeGroupVersionRecordHandler : public VersionRecordHandler { public: - ChunkedGroupIterator(NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS); + NodeGroupVersionRecordHandler(NodeGroupCollection* nodeGroups, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, 
common::transaction_t commitTS); void applyFuncToChunkedGroups(version_record_handler_op_t func) override; - void finalizeRollbackInsert() override; protected: NodeGroup* nodeGroup; - common::row_idx_t numRowsToRollback; }; NodeGroup(const common::node_group_idx_t nodeGroupIdx, const bool enableCompression, diff --git a/src/include/storage/store/node_group_collection.h b/src/include/storage/store/node_group_collection.h index b2153a2ce68..80a44bcba8e 100644 --- a/src/include/storage/store/node_group_collection.h +++ b/src/include/storage/store/node_group_collection.h @@ -16,7 +16,7 @@ class NodeGroupCollection { public: NodeGroupCollection(MemoryManager& memoryManager, const std::vector& types, bool enableCompression, FileHandle* dataFH = nullptr, common::Deserializer* deSer = nullptr, - const VersionRecordHandlerData* versionRecordHandlerData = nullptr); + const VersionRecordHandlerSelector* versionRecordHandlerSelector = nullptr); void append(const transaction::Transaction* transaction, const std::vector& vectors); @@ -51,7 +51,7 @@ class NodeGroupCollection { } NodeGroup* getOrCreateNodeGroup(transaction::Transaction* transaction, common::node_group_idx_t groupIdx, NodeGroupDataFormat format, - const VersionRecordHandlerData* versionRecordHandlerData); + const VersionRecordHandlerSelector* versionRecordHandlerSelector); void setNodeGroup(const common::node_group_idx_t nodeGroupIdx, std::unique_ptr group) { @@ -82,12 +82,12 @@ class NodeGroupCollection { void pushInsertInfo(const transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const VersionRecordHandlerData* overridedVersionRecordHandlerData); + const VersionRecordHandlerSelector* overridedVersionRecordHandlerSelector); private: void pushInsertInfo(const transaction::Transaction* transaction, NodeGroup* nodeGroup, common::row_idx_t numRows, - const VersionRecordHandlerData* overridedVersionRecordHandlerData = nullptr); + const 
VersionRecordHandlerSelector* overridedVersionRecordHandlerSelector = nullptr); bool enableCompression; // Num rows in the collection regardless of deletions. @@ -96,7 +96,7 @@ class NodeGroupCollection { GroupCollection nodeGroups; FileHandle* dataFH; TableStats stats; - const VersionRecordHandlerData* versionRecordHandlerData; + const VersionRecordHandlerSelector* versionRecordHandlerSelector; }; } // namespace storage diff --git a/src/include/storage/store/node_table.h b/src/include/storage/store/node_table.h index 3696dba10fb..fc717da66b2 100644 --- a/src/include/storage/store/node_table.h +++ b/src/include/storage/store/node_table.h @@ -96,9 +96,9 @@ struct PKColumnScanHelper { PrimaryKeyIndex* pkIndex; }; -class NodeTableVersionRecordHandlerData : public VersionRecordHandlerData { +class NodeTableVersionRecordHandlerSelector : public VersionRecordHandlerSelector { public: - explicit NodeTableVersionRecordHandlerData(NodeTable* nodeTable) : nodeTable(nodeTable) {} + explicit NodeTableVersionRecordHandlerSelector(NodeTable* nodeTable) : nodeTable(nodeTable) {} std::unique_ptr constructVersionRecordHandler(common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS, @@ -111,12 +111,12 @@ class NodeTableVersionRecordHandlerData : public VersionRecordHandlerData { class StorageManager; class NodeTable final : public Table { public: - class ChunkedGroupIterator : public NodeGroup::ChunkedGroupIterator { + class ChunkedGroupIterator : public NodeGroup::NodeGroupVersionRecordHandler { public: ChunkedGroupIterator(NodeTable* table, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS); - void initRollbackInsert(const transaction::Transaction* transaction) override; + void rollbackInsert(const transaction::Transaction* transaction) override; private: NodeTable* table; @@ -230,7 +230,7 @@ class NodeTable final : public Table { std::unique_ptr nodeGroups; common::column_id_t 
pkColumnID; std::unique_ptr pkIndex; - NodeTableVersionRecordHandlerData versionRecordHandlerData; + NodeTableVersionRecordHandlerSelector versionRecordHandlerSelector; }; } // namespace storage diff --git a/src/include/storage/store/rel_table_data.h b/src/include/storage/store/rel_table_data.h index abe2d2550c3..637c5d9c8c3 100644 --- a/src/include/storage/store/rel_table_data.h +++ b/src/include/storage/store/rel_table_data.h @@ -21,9 +21,9 @@ struct CSRHeaderColumns { std::unique_ptr length; }; -class RelTableVersionRecordHandlerData : public VersionRecordHandlerData { +class RelTableVersionRecordHandlerSelector : public VersionRecordHandlerSelector { public: - RelTableVersionRecordHandlerData(RelTableData* relTableData, CSRNodeGroupScanSource source) + RelTableVersionRecordHandlerSelector(RelTableData* relTableData, CSRNodeGroupScanSource source) : relTableData(relTableData), source(source) {} std::unique_ptr constructVersionRecordHandler(common::row_idx_t startRow, @@ -71,7 +71,7 @@ class RelTableData { NodeGroup* getOrCreateNodeGroup(transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx) const { return nodeGroups->getOrCreateNodeGroup(transaction, nodeGroupIdx, NodeGroupDataFormat::CSR, - &persistentVersionRecordHandlerData); + &persistentVersionRecordHandlerSelector); } common::RelMultiplicity getMultiplicity() const { return multiplicity; } @@ -118,7 +118,7 @@ class RelTableData { return types; } - const RelTableVersionRecordHandlerData* getVersionRecordHandlerData( + const RelTableVersionRecordHandlerSelector* getVersionRecordHandlerSelector( CSRNodeGroupScanSource source); private: @@ -137,8 +137,8 @@ class RelTableData { CSRHeaderColumns csrHeaderColumns; std::vector> columns; - RelTableVersionRecordHandlerData persistentVersionRecordHandlerData; - RelTableVersionRecordHandlerData inMemoryVersionRecordHandlerData; + RelTableVersionRecordHandlerSelector persistentVersionRecordHandlerSelector; + RelTableVersionRecordHandlerSelector 
inMemoryVersionRecordHandlerSelector; }; } // namespace storage diff --git a/src/include/storage/store/chunked_group_undo_iterator.h b/src/include/storage/store/version_record_handler.h similarity index 73% rename from src/include/storage/store/chunked_group_undo_iterator.h rename to src/include/storage/store/version_record_handler.h index 6f52284e8ed..fe6e137b70a 100644 --- a/src/include/storage/store/chunked_group_undo_iterator.h +++ b/src/include/storage/store/version_record_handler.h @@ -1,8 +1,7 @@ #pragma once -#include - #include "common/types/types.h" +#include "storage/store/chunked_node_group.h" namespace kuzu { @@ -11,16 +10,12 @@ class Transaction; } namespace storage { -class ChunkedNodeGroup; class NodeGroupCollection; class VersionRecordHandler; using version_record_handler_op_t = void ( ChunkedNodeGroup::*)(common::row_idx_t, common::row_idx_t, common::transaction_t); -using version_record_handler_construct_t = std::function( - common::row_idx_t, common::row_idx_t, common::node_group_idx_t, common::transaction_t)>; - // Note: these iterators are not necessarily thread-safe when used on their own class VersionRecordHandler { public: @@ -30,10 +25,12 @@ class VersionRecordHandler { virtual ~VersionRecordHandler() = default; - virtual void initRollbackInsert(const transaction::Transaction* /*transaction*/) {} - virtual void finalizeRollbackInsert(){}; virtual void applyFuncToChunkedGroups(version_record_handler_op_t func) = 0; + virtual void rollbackInsert(const transaction::Transaction* /*transaction*/) { + applyFuncToChunkedGroups(&ChunkedNodeGroup::rollbackInsert); + } + protected: common::row_idx_t startRow; common::row_idx_t numRows; @@ -42,9 +39,11 @@ class VersionRecordHandler { NodeGroupCollection* nodeGroups; }; -class VersionRecordHandlerData { +// Contains pointer to a table + any information needed by the table to construct a +// VersionRecordHandler +class VersionRecordHandlerSelector { public: - virtual ~VersionRecordHandlerData() = 
default; + virtual ~VersionRecordHandlerSelector() = default; virtual std::unique_ptr constructVersionRecordHandler( common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS, diff --git a/src/include/storage/undo_buffer.h b/src/include/storage/undo_buffer.h index 98c836101d8..f5fc93819d2 100644 --- a/src/include/storage/undo_buffer.h +++ b/src/include/storage/undo_buffer.h @@ -89,10 +89,10 @@ class UndoBuffer { const catalog::SequenceRollbackData& data); void createInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const storage::VersionRecordHandlerData* versionRecordHandlerData); + const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector); void createDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const storage::VersionRecordHandlerData* versionRecordHandlerData); + const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector); void createVectorUpdateInfo(UpdateInfo* updateInfo, common::idx_t vectorIdx, VectorUpdateInfo* vectorUpdateInfo); @@ -106,7 +106,7 @@ class UndoBuffer { void createVersionInfo(UndoRecordType recordType, common::row_idx_t startRow, common::row_idx_t numRows, - const storage::VersionRecordHandlerData* versionRecordHandlerData, + const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector, common::node_group_idx_t nodeGroupIdx = 0); void commitRecord(UndoRecordType recordType, const uint8_t* record, diff --git a/src/include/transaction/transaction.h b/src/include/transaction/transaction.h index 2ac04b68dd2..207f824472c 100644 --- a/src/include/transaction/transaction.h +++ b/src/include/transaction/transaction.h @@ -24,7 +24,7 @@ class UpdateInfo; struct VectorUpdateInfo; class ChunkedNodeGroup; class VersionRecordHandler; -class VersionRecordHandlerData; +class VersionRecordHandlerSelector; } // namespace storage namespace transaction { class 
TransactionManager; @@ -120,10 +120,10 @@ class Transaction { const catalog::SequenceRollbackData& data) const; void pushInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const storage::VersionRecordHandlerData* versionRecordHandlerData) const; + const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector) const; void pushDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const storage::VersionRecordHandlerData* versionRecordHandlerData) const; + const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector) const; void pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, common::idx_t vectorIdx, storage::VectorUpdateInfo& vectorUpdateInfo) const; diff --git a/src/storage/store/csr_node_group.cpp b/src/storage/store/csr_node_group.cpp index 300c7adea73..f84b2f5005b 100644 --- a/src/storage/store/csr_node_group.cpp +++ b/src/storage/store/csr_node_group.cpp @@ -27,7 +27,8 @@ void CSRNodeGroup::PersistentIterator::applyFuncToChunkedGroups(version_record_h } } -void CSRNodeGroup::PersistentIterator::finalizeRollbackInsert() { +void CSRNodeGroup::PersistentIterator::rollbackInsert(const transaction::Transaction* transaction) { + VersionRecordHandler::rollbackInsert(transaction); nodeGroups->rollbackInsert(numRows, false); } diff --git a/src/storage/store/node_group.cpp b/src/storage/store/node_group.cpp index 9995b635f83..9cb5e2f60bd 100644 --- a/src/storage/store/node_group.cpp +++ b/src/storage/store/node_group.cpp @@ -21,16 +21,16 @@ using namespace kuzu::transaction; namespace kuzu { namespace storage { -NodeGroup::ChunkedGroupIterator::ChunkedGroupIterator(NodeGroupCollection* nodeGroups, - common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - transaction_t commitTS) +NodeGroup::NodeGroupVersionRecordHandler::NodeGroupVersionRecordHandler( + NodeGroupCollection* nodeGroups, 
common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows, transaction_t commitTS) : VersionRecordHandler(nodeGroups, startRow, numRows, commitTS), - nodeGroup(nodeGroups->getNodeGroupNoLock(nodeGroupIdx)), - numRowsToRollback(std::min(numRows, nodeGroup->getNumRows() - startRow)) { + nodeGroup(nodeGroups->getNodeGroupNoLock(nodeGroupIdx)) { KU_ASSERT(startRow <= nodeGroup->getNumRows()); } -void NodeGroup::ChunkedGroupIterator::applyFuncToChunkedGroups(version_record_handler_op_t func) { +void NodeGroup::NodeGroupVersionRecordHandler::applyFuncToChunkedGroups( + version_record_handler_op_t func) { auto lock = nodeGroup->chunkedGroups.lock(); const auto [chunkedGroupIdx, startRowInChunkedGroup] = nodeGroup->findChunkedGroupIdxFromRowIdxNoLock(startRow); @@ -53,11 +53,6 @@ void NodeGroup::ChunkedGroupIterator::applyFuncToChunkedGroups(version_record_ha } } -void NodeGroup::ChunkedGroupIterator::finalizeRollbackInsert() { - nodeGroup->rollbackInsert(startRow); - nodeGroups->rollbackInsert(numRowsToRollback); -} - row_idx_t NodeGroup::append(const Transaction* transaction, ChunkedNodeGroup& chunkedGroup, row_idx_t startRowIdx, row_idx_t numRowsToAppend) { KU_ASSERT(numRowsToAppend <= chunkedGroup.getNumRows()); diff --git a/src/storage/store/node_group_collection.cpp b/src/storage/store/node_group_collection.cpp index 257554b05f7..e456e17afd9 100644 --- a/src/storage/store/node_group_collection.cpp +++ b/src/storage/store/node_group_collection.cpp @@ -14,9 +14,9 @@ namespace storage { NodeGroupCollection::NodeGroupCollection(MemoryManager& memoryManager, const std::vector& types, const bool enableCompression, FileHandle* dataFH, - Deserializer* deSer, const VersionRecordHandlerData* versionRecordHandlerData) + Deserializer* deSer, const VersionRecordHandlerSelector* versionRecordHandlerSelector) : enableCompression{enableCompression}, numTotalRows{0}, types{LogicalType::copy(types)}, - dataFH{dataFH}, 
versionRecordHandlerData(versionRecordHandlerData) { + dataFH{dataFH}, versionRecordHandlerSelector(versionRecordHandlerSelector) { if (deSer) { deserialize(*deSer, memoryManager); } @@ -155,7 +155,7 @@ row_idx_t NodeGroupCollection::getNumTotalRows() { NodeGroup* NodeGroupCollection::getOrCreateNodeGroup(transaction::Transaction* transaction, node_group_idx_t groupIdx, NodeGroupDataFormat format, - const VersionRecordHandlerData* versionRecordHandlerData) { + const VersionRecordHandlerSelector* versionRecordHandlerSelector) { const auto lock = nodeGroups.lock(); while (groupIdx >= nodeGroups.getNumGroups(lock)) { const auto currentGroupIdx = nodeGroups.getNumGroups(lock); @@ -166,7 +166,7 @@ NodeGroup* NodeGroupCollection::getOrCreateNodeGroup(transaction::Transaction* t enableCompression, LogicalType::copy(types))); // push an insert of size 0 so that we can rollback the creation of this node group if // needed - pushInsertInfo(transaction, nodeGroups.getLastGroup(lock), 0, versionRecordHandlerData); + pushInsertInfo(transaction, nodeGroups.getLastGroup(lock), 0, versionRecordHandlerSelector); } KU_ASSERT(groupIdx < nodeGroups.getNumGroups(lock)); return nodeGroups.getGroup(lock, groupIdx); @@ -213,19 +213,19 @@ void NodeGroupCollection::rollbackInsert(common::row_idx_t numRows_, bool update void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transaction, NodeGroup* nodeGroup, common::row_idx_t numRows, - const VersionRecordHandlerData* overridedVersionRecordHandlerData) { + const VersionRecordHandlerSelector* overridedVersionRecordHandlerSelector) { pushInsertInfo(transaction, nodeGroup->getNodeGroupIdx(), nodeGroup->getNumRows(), numRows, - overridedVersionRecordHandlerData ? overridedVersionRecordHandlerData : - versionRecordHandlerData); + overridedVersionRecordHandlerSelector ? 
overridedVersionRecordHandlerSelector : + versionRecordHandlerSelector); }; void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const VersionRecordHandlerData* overridedVersionRecordHandlerData) { + const VersionRecordHandlerSelector* overridedVersionRecordHandlerSelector) { // we only append to the undo buffer if the node group collection is persistent if (dataFH && transaction->shouldAppendToUndoBuffer()) { transaction->pushInsertInfo(nodeGroupIdx, startRow, numRows, - overridedVersionRecordHandlerData); + overridedVersionRecordHandlerSelector); } } diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index cd22b44f2af..e95315ae510 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -22,7 +22,7 @@ namespace kuzu { namespace storage { std::unique_ptr -NodeTableVersionRecordHandlerData::constructVersionRecordHandler(common::row_idx_t startRow, +NodeTableVersionRecordHandlerSelector::constructVersionRecordHandler(common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS, common::node_group_idx_t nodeGroupIdx) const { return std::make_unique(nodeTable, nodeGroupIdx, startRow, @@ -32,13 +32,16 @@ NodeTableVersionRecordHandlerData::constructVersionRecordHandler(common::row_idx NodeTable::ChunkedGroupIterator::ChunkedGroupIterator(NodeTable* table, node_group_idx_t nodeGroupidx, common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS) - : NodeGroup::ChunkedGroupIterator(table->nodeGroups.get(), nodeGroupidx, startRow, numRows, - commitTS), + : NodeGroup::NodeGroupVersionRecordHandler(table->nodeGroups.get(), nodeGroupidx, startRow, + numRows, commitTS), table(table) {} -void NodeTable::ChunkedGroupIterator::initRollbackInsert( - const transaction::Transaction* transaction) { +void 
NodeTable::ChunkedGroupIterator::rollbackInsert(const transaction::Transaction* transaction) { table->rollbackInsert(transaction, startRow, numRows, nodeGroup->getNodeGroupIdx()); + NodeGroup::NodeGroupVersionRecordHandler::rollbackInsert(transaction); + nodeGroup->rollbackInsert(startRow); + const auto numRowsToRollback = std::min(numRows, nodeGroup->getNumRows() - startRow); + nodeGroups->rollbackInsert(numRowsToRollback); } bool NodeTableScanState::scanNext(Transaction* transaction, offset_t startOffset, @@ -198,7 +201,7 @@ NodeTable::NodeTable(const StorageManager* storageManager, VirtualFileSystem* vfs, main::ClientContext* context, Deserializer* deSer) : Table{nodeTableEntry, storageManager, memoryManager}, pkColumnID{nodeTableEntry->getColumnID(nodeTableEntry->getPrimaryKeyName())}, - versionRecordHandlerData(this) { + versionRecordHandlerSelector(this) { const auto maxColumnID = nodeTableEntry->getMaxColumnID(); columns.resize(maxColumnID + 1); for (auto i = 0u; i < nodeTableEntry->getNumProperties(); i++) { @@ -210,8 +213,9 @@ NodeTable::NodeTable(const StorageManager* storageManager, dataFH, memoryManager, shadowFile, enableCompression); } - nodeGroups = std::make_unique(*memoryManager, - getNodeTableColumnTypes(*this), enableCompression, storageManager->getDataFH(), deSer); + nodeGroups = + std::make_unique(*memoryManager, getNodeTableColumnTypes(*this), + enableCompression, storageManager->getDataFH(), deSer, &versionRecordHandlerSelector); initializePKIndex(storageManager->getDatabasePath(), nodeTableEntry, storageManager->isReadOnly(), vfs, context); } @@ -428,7 +432,8 @@ bool NodeTable::delete_(Transaction* transaction, TableDeleteState& deleteState) nodeOffset - StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); isDeleted = nodeGroups->getNodeGroup(nodeGroupIdx)->delete_(transaction, rowIdxInGroup); if (transaction->shouldAppendToUndoBuffer()) { - transaction->pushDeleteInfo(nodeGroupIdx, rowIdxInGroup, 1, &versionRecordHandlerData); + 
transaction->pushDeleteInfo(nodeGroupIdx, rowIdxInGroup, 1, + &versionRecordHandlerSelector); } } if (isDeleted) { @@ -500,7 +505,7 @@ void NodeTable::commit(Transaction* transaction, LocalTable* localTable) { KU_ASSERT(isDeleted); if (transaction->shouldAppendToUndoBuffer()) { transaction->pushDeleteInfo(nodeGroupIdx, rowIdxInGroup, 1, - &versionRecordHandlerData); + &versionRecordHandlerSelector); } } } diff --git a/src/storage/store/rel_table_data.cpp b/src/storage/store/rel_table_data.cpp index f2387ce6ce3..1485c9e030b 100644 --- a/src/storage/store/rel_table_data.cpp +++ b/src/storage/store/rel_table_data.cpp @@ -17,7 +17,7 @@ namespace kuzu { namespace storage { std::unique_ptr -RelTableVersionRecordHandlerData::constructVersionRecordHandler(common::row_idx_t startRow, +RelTableVersionRecordHandlerSelector::constructVersionRecordHandler(common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS, common::node_group_idx_t nodeGroupIdx) const { return relTableData->constructVersionRecordHandler(source, nodeGroupIdx, startRow, numRows, @@ -30,8 +30,8 @@ RelTableData::RelTableData(FileHandle* dataFH, MemoryManager* mm, ShadowFile* sh : dataFH{dataFH}, tableID{tableEntry->getTableID()}, tableName{tableEntry->getName()}, memoryManager{mm}, shadowFile{shadowFile}, enableCompression{enableCompression}, direction{direction}, - persistentVersionRecordHandlerData(this, CSRNodeGroupScanSource::COMMITTED_PERSISTENT), - inMemoryVersionRecordHandlerData(this, CSRNodeGroupScanSource::COMMITTED_IN_MEMORY) { + persistentVersionRecordHandlerSelector(this, CSRNodeGroupScanSource::COMMITTED_PERSISTENT), + inMemoryVersionRecordHandlerSelector(this, CSRNodeGroupScanSource::COMMITTED_IN_MEMORY) { multiplicity = tableEntry->constCast().getMultiplicity(direction); initCSRHeaderColumns(); initPropertyColumns(tableEntry); @@ -109,7 +109,8 @@ bool RelTableData::delete_(Transaction* transaction, ValueVector& boundNodeIDVec auto& csrNodeGroup = 
getNodeGroup(nodeGroupIdx)->cast(); bool isDeleted = csrNodeGroup.delete_(transaction, source, rowIdx); if (isDeleted && transaction->shouldAppendToUndoBuffer()) { - transaction->pushDeleteInfo(nodeGroupIdx, rowIdx, 1, getVersionRecordHandlerData(source)); + transaction->pushDeleteInfo(nodeGroupIdx, rowIdx, 1, + getVersionRecordHandlerSelector(source)); } return isDeleted; } @@ -214,7 +215,7 @@ void RelTableData::pushInsertInfo(transaction::Transaction* transaction, nodeGroup.getNumRows(); nodeGroups->pushInsertInfo(transaction, nodeGroup.getNodeGroupIdx(), startRow, numRows_, - getVersionRecordHandlerData(source)); + getVersionRecordHandlerSelector(source)); } void RelTableData::checkpoint(const std::vector& columnIDs) { @@ -247,18 +248,18 @@ std::unique_ptr RelTableData::constructVersionRecordHandle startRow, numRows, commitTS); } else { KU_ASSERT(source == CSRNodeGroupScanSource::COMMITTED_IN_MEMORY); - return std::make_unique(nodeGroups.get(), nodeGroupIdx, - startRow, numRows, commitTS); + return std::make_unique(nodeGroups.get(), + nodeGroupIdx, startRow, numRows, commitTS); } } -const RelTableVersionRecordHandlerData* RelTableData::getVersionRecordHandlerData( +const RelTableVersionRecordHandlerSelector* RelTableData::getVersionRecordHandlerSelector( CSRNodeGroupScanSource source) { if (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) { - return &persistentVersionRecordHandlerData; + return &persistentVersionRecordHandlerSelector; } else { KU_ASSERT(source == CSRNodeGroupScanSource::COMMITTED_IN_MEMORY); - return &inMemoryVersionRecordHandlerData; + return &inMemoryVersionRecordHandlerSelector; } } diff --git a/src/storage/undo_buffer.cpp b/src/storage/undo_buffer.cpp index d30eb34343f..faf25294f9f 100644 --- a/src/storage/undo_buffer.cpp +++ b/src/storage/undo_buffer.cpp @@ -40,7 +40,7 @@ struct VersionRecord { row_idx_t startRow; row_idx_t numRows; node_group_idx_t nodeGroupIdx; - const storage::VersionRecordHandlerData* versionRecordHandlerData; + 
const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector; }; struct VectorUpdateRecord { @@ -111,26 +111,28 @@ void UndoBuffer::createSequenceChange(SequenceCatalogEntry& sequenceEntry, } void UndoBuffer::createInsertInfo(node_group_idx_t nodeGroupIdx, row_idx_t startRow, - row_idx_t numRows, const storage::VersionRecordHandlerData* versionRecordHandlerData) { - createVersionInfo(UndoRecordType::INSERT_INFO, startRow, numRows, versionRecordHandlerData, + row_idx_t numRows, const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector) { + createVersionInfo(UndoRecordType::INSERT_INFO, startRow, numRows, versionRecordHandlerSelector, nodeGroupIdx); } void UndoBuffer::createDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, const storage::VersionRecordHandlerData* versionRecordHandlerData) { - createVersionInfo(UndoRecordType::DELETE_INFO, startRow, numRows, versionRecordHandlerData, + common::row_idx_t numRows, + const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector) { + createVersionInfo(UndoRecordType::DELETE_INFO, startRow, numRows, versionRecordHandlerSelector, nodeGroupIdx); } void UndoBuffer::createVersionInfo(const UndoRecordType recordType, row_idx_t startRow, - row_idx_t numRows, const storage::VersionRecordHandlerData* versionRecordHandlerData, + row_idx_t numRows, const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector, node_group_idx_t nodeGroupIdx) { + KU_ASSERT(versionRecordHandlerSelector); auto buffer = createUndoRecord(sizeof(UndoRecordHeader) + sizeof(VersionRecord)); const UndoRecordHeader recordHeader{recordType, sizeof(VersionRecord)}; *reinterpret_cast(buffer) = recordHeader; buffer += sizeof(UndoRecordHeader); *reinterpret_cast(buffer) = - VersionRecord{startRow, numRows, nodeGroupIdx, versionRecordHandlerData}; + VersionRecord{startRow, numRows, nodeGroupIdx, versionRecordHandlerSelector}; } void 
UndoBuffer::createVectorUpdateInfo(UpdateInfo* updateInfo, const idx_t vectorIdx, @@ -215,12 +217,12 @@ void UndoBuffer::commitVersionInfo(UndoRecordType recordType, const uint8_t* rec const auto& undoRecord = *reinterpret_cast(record); switch (recordType) { case UndoRecordType::INSERT_INFO: { - auto handler = undoRecord.versionRecordHandlerData->constructVersionRecordHandler( + auto handler = undoRecord.versionRecordHandlerSelector->constructVersionRecordHandler( undoRecord.startRow, undoRecord.numRows, commitTS, undoRecord.nodeGroupIdx); handler->applyFuncToChunkedGroups(&ChunkedNodeGroup::commitInsert); } break; case UndoRecordType::DELETE_INFO: { - auto handler = undoRecord.versionRecordHandlerData->constructVersionRecordHandler( + auto handler = undoRecord.versionRecordHandlerSelector->constructVersionRecordHandler( undoRecord.startRow, undoRecord.numRows, commitTS, undoRecord.nodeGroupIdx); handler->applyFuncToChunkedGroups(&ChunkedNodeGroup::commitDelete); } break; @@ -298,16 +300,15 @@ void UndoBuffer::rollbackVersionInfo(const transaction::Transaction* transaction auto& undoRecord = *reinterpret_cast(record); switch (recordType) { case UndoRecordType::INSERT_INFO: { - auto it = (*undoRecord.iteratorConstructFunc)(undoRecord.startRow, undoRecord.numRows, - undoRecord.nodeGroupIdx, transaction->getCommitTS()); - it->initRollbackInsert(transaction); - it->applyFuncToChunkedGroups(&ChunkedNodeGroup::rollbackInsert); - it->finalizeRollbackInsert(); + auto handler = undoRecord.versionRecordHandlerSelector->constructVersionRecordHandler( + undoRecord.startRow, undoRecord.numRows, transaction->getCommitTS(), + undoRecord.nodeGroupIdx); + handler->rollbackInsert(transaction); } break; case UndoRecordType::DELETE_INFO: { - auto handler = - undoRecord.versionRecordHandlerData->constructVersionRecordHandler(undoRecord.startRow, - undoRecord.numRows, transaction->getCommitTS(), undoRecord.nodeGroupIdx); + auto handler = 
undoRecord.versionRecordHandlerSelector->constructVersionRecordHandler( + undoRecord.startRow, undoRecord.numRows, transaction->getCommitTS(), + undoRecord.nodeGroupIdx); handler->applyFuncToChunkedGroups(&ChunkedNodeGroup::rollbackDelete); } break; default: { diff --git a/src/transaction/transaction.cpp b/src/transaction/transaction.cpp index 50a3ed536be..79d4f43c53a 100644 --- a/src/transaction/transaction.cpp +++ b/src/transaction/transaction.cpp @@ -175,14 +175,14 @@ void Transaction::pushSequenceChange(SequenceCatalogEntry* sequenceEntry, int64_ void Transaction::pushInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const storage::VersionRecordHandlerData* versionRecordHandlerData) const { - undoBuffer->createInsertInfo(nodeGroupIdx, startRow, numRows, versionRecordHandlerData); + const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector) const { + undoBuffer->createInsertInfo(nodeGroupIdx, startRow, numRows, versionRecordHandlerSelector); } void Transaction::pushDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const storage::VersionRecordHandlerData* versionRecordHandlerData) const { - undoBuffer->createDeleteInfo(nodeGroupIdx, startRow, numRows, versionRecordHandlerData); + const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector) const { + undoBuffer->createDeleteInfo(nodeGroupIdx, startRow, numRows, versionRecordHandlerSelector); } void Transaction::pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, From 89f713fa4ad1a3410c806bd9df1e50b52062d4b4 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Fri, 29 Nov 2024 14:12:38 -0500 Subject: [PATCH 20/28] Refactor version record handler again --- src/include/storage/store/csr_node_group.h | 12 --- src/include/storage/store/node_group.h | 13 +--- .../storage/store/node_group_collection.h | 11 ++- src/include/storage/store/node_table.h | 35 ++++----- 
src/include/storage/store/rel_table_data.h | 48 ++++++++---- .../storage/store/version_record_handler.h | 37 ++------- src/include/storage/undo_buffer.h | 9 +-- src/include/transaction/transaction.h | 8 +- src/storage/store/CMakeLists.txt | 4 +- src/storage/store/csr_node_group.cpp | 20 ----- src/storage/store/node_group.cpp | 57 +++++++------- src/storage/store/node_group_collection.cpp | 18 ++--- src/storage/store/node_table.cpp | 52 ++++++------- src/storage/store/rel_table_data.cpp | 75 ++++++++++++------- src/storage/store/version_record_handler.cpp | 10 +++ src/storage/undo_buffer.cpp | 40 +++++----- src/transaction/transaction.cpp | 10 +-- 17 files changed, 207 insertions(+), 252 deletions(-) create mode 100644 src/storage/store/version_record_handler.cpp diff --git a/src/include/storage/store/csr_node_group.h b/src/include/storage/store/csr_node_group.h index 9f4090a0f31..75125f5e8a3 100644 --- a/src/include/storage/store/csr_node_group.h +++ b/src/include/storage/store/csr_node_group.h @@ -165,18 +165,6 @@ static constexpr common::column_id_t REL_ID_COLUMN_ID = 1; struct RelTableScanState; class CSRNodeGroup final : public NodeGroup { public: - class PersistentIterator : public VersionRecordHandler { - public: - PersistentIterator(NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS); - - void applyFuncToChunkedGroups(version_record_handler_op_t func) override; - void rollbackInsert(const transaction::Transaction* transaction) override; - - private: - CSRNodeGroup* nodeGroup; - }; - static constexpr PackedCSRInfo DEFAULT_PACKED_CSR_INFO{}; CSRNodeGroup(const common::node_group_idx_t nodeGroupIdx, const bool enableCompression, diff --git a/src/include/storage/store/node_group.h b/src/include/storage/store/node_group.h index 754298f3cb8..0ab4dccf8a2 100644 --- a/src/include/storage/store/node_group.h +++ b/src/include/storage/store/node_group.h @@ 
-82,17 +82,6 @@ static auto NODE_GROUP_SCAN_EMMPTY_RESULT = NodeGroupScanResult{}; struct TableScanState; class NodeGroup { public: - class NodeGroupVersionRecordHandler : public VersionRecordHandler { - public: - NodeGroupVersionRecordHandler(NodeGroupCollection* nodeGroups, - common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, common::transaction_t commitTS); - void applyFuncToChunkedGroups(version_record_handler_op_t func) override; - - protected: - NodeGroup* nodeGroup; - }; - NodeGroup(const common::node_group_idx_t nodeGroupIdx, const bool enableCompression, std::vector dataTypes, common::row_idx_t capacity = common::StorageConstants::NODE_GROUP_SIZE, @@ -163,6 +152,8 @@ class NodeGroup { void flush(transaction::Transaction* transaction, FileHandle& dataFH); + void applyFuncToChunkedGroups(version_record_handler_op_t func, common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS) const; void rollbackInsert(common::row_idx_t startRow); virtual void checkpoint(MemoryManager& memoryManager, NodeGroupCheckpointState& state); diff --git a/src/include/storage/store/node_group_collection.h b/src/include/storage/store/node_group_collection.h index 80a44bcba8e..4b9353054c8 100644 --- a/src/include/storage/store/node_group_collection.h +++ b/src/include/storage/store/node_group_collection.h @@ -16,7 +16,7 @@ class NodeGroupCollection { public: NodeGroupCollection(MemoryManager& memoryManager, const std::vector& types, bool enableCompression, FileHandle* dataFH = nullptr, common::Deserializer* deSer = nullptr, - const VersionRecordHandlerSelector* versionRecordHandlerSelector = nullptr); + const VersionRecordHandler* versionRecordHandler = nullptr); void append(const transaction::Transaction* transaction, const std::vector& vectors); @@ -51,7 +51,7 @@ class NodeGroupCollection { } NodeGroup* getOrCreateNodeGroup(transaction::Transaction* transaction, common::node_group_idx_t groupIdx, 
NodeGroupDataFormat format, - const VersionRecordHandlerSelector* versionRecordHandlerSelector); + const VersionRecordHandler* versionRecordHandler); void setNodeGroup(const common::node_group_idx_t nodeGroupIdx, std::unique_ptr group) { @@ -81,13 +81,12 @@ class NodeGroupCollection { void pushInsertInfo(const transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, - const VersionRecordHandlerSelector* overridedVersionRecordHandlerSelector); + common::row_idx_t numRows, const VersionRecordHandler* overridedVersionRecordHandler); private: void pushInsertInfo(const transaction::Transaction* transaction, NodeGroup* nodeGroup, common::row_idx_t numRows, - const VersionRecordHandlerSelector* overridedVersionRecordHandlerSelector = nullptr); + const VersionRecordHandler* overridedVersionRecordHandler = nullptr); bool enableCompression; // Num rows in the collection regardless of deletions. @@ -96,7 +95,7 @@ class NodeGroupCollection { GroupCollection nodeGroups; FileHandle* dataFH; TableStats stats; - const VersionRecordHandlerSelector* versionRecordHandlerSelector; + const VersionRecordHandler* versionRecordHandler; }; } // namespace storage diff --git a/src/include/storage/store/node_table.h b/src/include/storage/store/node_table.h index fc717da66b2..13b2fba8076 100644 --- a/src/include/storage/store/node_table.h +++ b/src/include/storage/store/node_table.h @@ -96,32 +96,25 @@ struct PKColumnScanHelper { PrimaryKeyIndex* pkIndex; }; -class NodeTableVersionRecordHandlerSelector : public VersionRecordHandlerSelector { +class NodeTableVersionRecordHandler : public VersionRecordHandler { public: - explicit NodeTableVersionRecordHandlerSelector(NodeTable* nodeTable) : nodeTable(nodeTable) {} + explicit NodeTableVersionRecordHandler(NodeTable* table); - std::unique_ptr constructVersionRecordHandler(common::row_idx_t startRow, - common::row_idx_t numRows, common::transaction_t commitTS, - 
common::node_group_idx_t nodeGroupIdx) const override; + void applyFuncToChunkedGroups(version_record_handler_op_t func, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS) const override; + void rollbackInsert(const transaction::Transaction* transaction, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows) const override; private: - NodeTable* nodeTable; + NodeTable* table; }; class StorageManager; + class NodeTable final : public Table { public: - class ChunkedGroupIterator : public NodeGroup::NodeGroupVersionRecordHandler { - public: - ChunkedGroupIterator(NodeTable* table, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS); - - void rollbackInsert(const transaction::Transaction* transaction) override; - - private: - NodeTable* table; - }; - static std::vector getNodeTableColumnTypes(const NodeTable& table) { std::vector types; for (auto i = 0u; i < table.getNumColumns(); i++) { @@ -194,8 +187,10 @@ class NodeTable final : public Table { void commit(transaction::Transaction* transaction, LocalTable* localTable) override; void checkpoint(common::Serializer& ser, catalog::TableCatalogEntry* tableEntry) override; - void rollbackInsert(const transaction::Transaction* transaction, common::row_idx_t startRow, - common::row_idx_t numRows_, common::node_group_idx_t nodeGroupIdx); + void rollbackPKIndexInsert(const transaction::Transaction* transaction, + common::row_idx_t startRow, common::row_idx_t numRows_, + common::node_group_idx_t nodeGroupIdx); + void rollbackGroupCollectionInsert(common::row_idx_t numRows_); common::node_group_idx_t getNumCommittedNodeGroups() const { return nodeGroups->getNumNodeGroups(); @@ -230,7 +225,7 @@ class NodeTable final : public Table { std::unique_ptr nodeGroups; common::column_id_t pkColumnID; std::unique_ptr pkIndex; - 
NodeTableVersionRecordHandlerSelector versionRecordHandlerSelector; + NodeTableVersionRecordHandler versionRecordHandler; }; } // namespace storage diff --git a/src/include/storage/store/rel_table_data.h b/src/include/storage/store/rel_table_data.h index 637c5d9c8c3..f8d4d66b22b 100644 --- a/src/include/storage/store/rel_table_data.h +++ b/src/include/storage/store/rel_table_data.h @@ -21,18 +21,34 @@ struct CSRHeaderColumns { std::unique_ptr length; }; -class RelTableVersionRecordHandlerSelector : public VersionRecordHandlerSelector { +class PersistentVersionRecordHandler : public VersionRecordHandler { public: - RelTableVersionRecordHandlerSelector(RelTableData* relTableData, CSRNodeGroupScanSource source) - : relTableData(relTableData), source(source) {} + explicit PersistentVersionRecordHandler(RelTableData* relTableData); - std::unique_ptr constructVersionRecordHandler(common::row_idx_t startRow, - common::row_idx_t numRows, common::transaction_t commitTS, - common::node_group_idx_t nodeGroupIdx) const override; + void applyFuncToChunkedGroups(version_record_handler_op_t func, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS) const override; + void rollbackInsert(const transaction::Transaction* transaction, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows) const override; + +private: + RelTableData* relTableData; +}; + +class InMemoryVersionRecordHandler : public VersionRecordHandler { +public: + explicit InMemoryVersionRecordHandler(RelTableData* relTableData); + + void applyFuncToChunkedGroups(version_record_handler_op_t func, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS) const override; + void rollbackInsert(const transaction::Transaction* transaction, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows) 
const override; private: RelTableData* relTableData; - CSRNodeGroupScanSource source; }; class RelTableData { @@ -71,7 +87,7 @@ class RelTableData { NodeGroup* getOrCreateNodeGroup(transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx) const { return nodeGroups->getOrCreateNodeGroup(transaction, nodeGroupIdx, NodeGroupDataFormat::CSR, - &persistentVersionRecordHandlerSelector); + &persistentVersionRecordHandler); } common::RelMultiplicity getMultiplicity() const { return multiplicity; } @@ -85,10 +101,11 @@ class RelTableData { void serialize(common::Serializer& serializer) const; - std::unique_ptr constructVersionRecordHandler( - CSRNodeGroupScanSource source, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows, - common::transaction_t commitTS) const; + NodeGroup* getNodeGroupNoLock(common::node_group_idx_t nodeGroupIdx) const { + return nodeGroups->getNodeGroupNoLock(nodeGroupIdx); + } + + void rollbackGroupCollectionInsert(common::row_idx_t numRows_, bool isPersistent); private: void initCSRHeaderColumns(); @@ -118,8 +135,7 @@ class RelTableData { return types; } - const RelTableVersionRecordHandlerSelector* getVersionRecordHandlerSelector( - CSRNodeGroupScanSource source); + const VersionRecordHandler* getVersionRecordHandler(CSRNodeGroupScanSource source); private: FileHandle* dataFH; @@ -137,8 +153,8 @@ class RelTableData { CSRHeaderColumns csrHeaderColumns; std::vector> columns; - RelTableVersionRecordHandlerSelector persistentVersionRecordHandlerSelector; - RelTableVersionRecordHandlerSelector inMemoryVersionRecordHandlerSelector; + PersistentVersionRecordHandler persistentVersionRecordHandler; + InMemoryVersionRecordHandler inMemoryVersionRecordHandler; }; } // namespace storage diff --git a/src/include/storage/store/version_record_handler.h b/src/include/storage/store/version_record_handler.h index fe6e137b70a..4fded7703f4 100644 --- a/src/include/storage/store/version_record_handler.h +++ 
b/src/include/storage/store/version_record_handler.h @@ -2,13 +2,10 @@ #include "common/types/types.h" #include "storage/store/chunked_node_group.h" +#include "transaction/transaction.h" namespace kuzu { -namespace transaction { -class Transaction; -} - namespace storage { class NodeGroupCollection; class VersionRecordHandler; @@ -19,35 +16,15 @@ using version_record_handler_op_t = void ( // Note: these iterators are not necessarily thread-safe when used on their own class VersionRecordHandler { public: - VersionRecordHandler(NodeGroupCollection* nodeGroups, common::row_idx_t startRow, - common::row_idx_t numRows, common::transaction_t commitTS) - : startRow(startRow), numRows(numRows), commitTS(commitTS), nodeGroups(nodeGroups) {} - virtual ~VersionRecordHandler() = default; - virtual void applyFuncToChunkedGroups(version_record_handler_op_t func) = 0; - - virtual void rollbackInsert(const transaction::Transaction* /*transaction*/) { - applyFuncToChunkedGroups(&ChunkedNodeGroup::rollbackInsert); - } - -protected: - common::row_idx_t startRow; - common::row_idx_t numRows; - common::transaction_t commitTS; - - NodeGroupCollection* nodeGroups; -}; - -// Contains pointer to a table + any information needed by the table to construct a -// VersionRecordHandler -class VersionRecordHandlerSelector { -public: - virtual ~VersionRecordHandlerSelector() = default; + virtual void applyFuncToChunkedGroups(version_record_handler_op_t func, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS) const = 0; - virtual std::unique_ptr constructVersionRecordHandler( - common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS, - common::node_group_idx_t nodeGroupIdx) const = 0; + virtual void rollbackInsert(const transaction::Transaction* transaction, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows) const; }; } // namespace storage diff 
--git a/src/include/storage/undo_buffer.h b/src/include/storage/undo_buffer.h index f5fc93819d2..38eb0b4158e 100644 --- a/src/include/storage/undo_buffer.h +++ b/src/include/storage/undo_buffer.h @@ -88,11 +88,9 @@ class UndoBuffer { void createSequenceChange(catalog::SequenceCatalogEntry& sequenceEntry, const catalog::SequenceRollbackData& data); void createInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, - const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector); + common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler); void createDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, - const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector); + common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler); void createVectorUpdateInfo(UpdateInfo* updateInfo, common::idx_t vectorIdx, VectorUpdateInfo* vectorUpdateInfo); @@ -105,8 +103,7 @@ class UndoBuffer { uint8_t* createUndoRecord(uint64_t size); void createVersionInfo(UndoRecordType recordType, common::row_idx_t startRow, - common::row_idx_t numRows, - const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector, + common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler, common::node_group_idx_t nodeGroupIdx = 0); void commitRecord(UndoRecordType recordType, const uint8_t* record, diff --git a/src/include/transaction/transaction.h b/src/include/transaction/transaction.h index 207f824472c..d50a2c31225 100644 --- a/src/include/transaction/transaction.h +++ b/src/include/transaction/transaction.h @@ -1,7 +1,5 @@ #pragma once -#include - #include "common/enums/statement_type.h" #include "common/types/types.h" @@ -119,11 +117,9 @@ class Transaction { void pushSequenceChange(catalog::SequenceCatalogEntry* sequenceEntry, int64_t kCount, const catalog::SequenceRollbackData& data) 
const; void pushInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, - const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector) const; + common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler) const; void pushDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, - const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector) const; + common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler) const; void pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, common::idx_t vectorIdx, storage::VectorUpdateInfo& vectorUpdateInfo) const; diff --git a/src/storage/store/CMakeLists.txt b/src/storage/store/CMakeLists.txt index 4c368d87b9b..3d02a1629a6 100644 --- a/src/storage/store/CMakeLists.txt +++ b/src/storage/store/CMakeLists.txt @@ -29,7 +29,9 @@ add_library(kuzu_storage_store struct_column.cpp table.cpp update_info.cpp - version_info.cpp) + version_info.cpp + version_record_handler.cpp +) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/storage/store/csr_node_group.cpp b/src/storage/store/csr_node_group.cpp index f84b2f5005b..0d3d0a99580 100644 --- a/src/storage/store/csr_node_group.cpp +++ b/src/storage/store/csr_node_group.cpp @@ -12,26 +12,6 @@ using namespace kuzu::transaction; namespace kuzu { namespace storage { -CSRNodeGroup::PersistentIterator::PersistentIterator(NodeGroupCollection* nodeGroups, - common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - common::transaction_t commitTS) - : VersionRecordHandler(nodeGroups, startRow, numRows, commitTS), nodeGroup(nullptr) { - if (nodeGroupIdx < nodeGroups->getNumNodeGroups()) { - nodeGroup = ku_dynamic_cast(nodeGroups->getNodeGroupNoLock(nodeGroupIdx)); - } -} - -void CSRNodeGroup::PersistentIterator::applyFuncToChunkedGroups(version_record_handler_op_t func) { - if 
(nodeGroup && nodeGroup->persistentChunkGroup) { - std::invoke(func, *nodeGroup->persistentChunkGroup, startRow, numRows, commitTS); - } -} - -void CSRNodeGroup::PersistentIterator::rollbackInsert(const transaction::Transaction* transaction) { - VersionRecordHandler::rollbackInsert(transaction); - nodeGroups->rollbackInsert(numRows, false); -} - bool CSRNodeGroupScanState::tryScanCachedTuples(RelTableScanState& tableScanState) { if (numCachedRows == 0 || tableScanState.currBoundNodeIdx >= tableScanState.cachedBoundNodeSelVector.getSelSize()) { diff --git a/src/storage/store/node_group.cpp b/src/storage/store/node_group.cpp index 9cb5e2f60bd..7dd24807641 100644 --- a/src/storage/store/node_group.cpp +++ b/src/storage/store/node_group.cpp @@ -21,38 +21,6 @@ using namespace kuzu::transaction; namespace kuzu { namespace storage { -NodeGroup::NodeGroupVersionRecordHandler::NodeGroupVersionRecordHandler( - NodeGroupCollection* nodeGroups, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows, transaction_t commitTS) - : VersionRecordHandler(nodeGroups, startRow, numRows, commitTS), - nodeGroup(nodeGroups->getNodeGroupNoLock(nodeGroupIdx)) { - KU_ASSERT(startRow <= nodeGroup->getNumRows()); -} - -void NodeGroup::NodeGroupVersionRecordHandler::applyFuncToChunkedGroups( - version_record_handler_op_t func) { - auto lock = nodeGroup->chunkedGroups.lock(); - const auto [chunkedGroupIdx, startRowInChunkedGroup] = - nodeGroup->findChunkedGroupIdxFromRowIdxNoLock(startRow); - if (chunkedGroupIdx != INVALID_CHUNKED_GROUP_IDX) { - auto curChunkedGroupIdx = chunkedGroupIdx; - auto curStartRowIdxInChunk = startRowInChunkedGroup; - - auto numRowsLeft = numRows; - while ( - numRowsLeft > 0 && curChunkedGroupIdx < nodeGroup->chunkedGroups.getNumGroups(lock)) { - auto* chunkedGroup = nodeGroup->chunkedGroups.getGroup(lock, curChunkedGroupIdx); - const auto numRowsForGroup = - std::min(numRowsLeft, chunkedGroup->getNumRows() - 
curStartRowIdxInChunk); - std::invoke(func, *chunkedGroup, curStartRowIdxInChunk, numRowsForGroup, commitTS); - - ++curChunkedGroupIdx; - numRowsLeft -= numRowsForGroup; - curStartRowIdxInChunk = 0; - } - } -} - row_idx_t NodeGroup::append(const Transaction* transaction, ChunkedNodeGroup& chunkedGroup, row_idx_t startRowIdx, row_idx_t numRowsToAppend) { KU_ASSERT(numRowsToAppend <= chunkedGroup.getNumRows()); @@ -684,5 +652,30 @@ bool NodeGroup::isInserted(const Transaction* transaction, offset_t offsetInGrou return chunkedGroup->isInserted(transaction, offsetInGroup - chunkedGroup->getStartRowIdx()); } +void NodeGroup::applyFuncToChunkedGroups(version_record_handler_op_t func, + common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS) const { + KU_ASSERT(startRow <= getNumRows()); + + auto lock = chunkedGroups.lock(); + const auto [chunkedGroupIdx, startRowInChunkedGroup] = + findChunkedGroupIdxFromRowIdxNoLock(startRow); + if (chunkedGroupIdx != INVALID_CHUNKED_GROUP_IDX) { + auto curChunkedGroupIdx = chunkedGroupIdx; + auto curStartRowIdxInChunk = startRowInChunkedGroup; + + auto numRowsLeft = numRows; + while (numRowsLeft > 0 && curChunkedGroupIdx < chunkedGroups.getNumGroups(lock)) { + auto* chunkedGroup = chunkedGroups.getGroup(lock, curChunkedGroupIdx); + const auto numRowsForGroup = + std::min(numRowsLeft, chunkedGroup->getNumRows() - curStartRowIdxInChunk); + std::invoke(func, *chunkedGroup, curStartRowIdxInChunk, numRowsForGroup, commitTS); + + ++curChunkedGroupIdx; + numRowsLeft -= numRowsForGroup; + curStartRowIdxInChunk = 0; + } + } +} + } // namespace storage } // namespace kuzu diff --git a/src/storage/store/node_group_collection.cpp b/src/storage/store/node_group_collection.cpp index e456e17afd9..2ece802cb81 100644 --- a/src/storage/store/node_group_collection.cpp +++ b/src/storage/store/node_group_collection.cpp @@ -14,9 +14,9 @@ namespace storage { NodeGroupCollection::NodeGroupCollection(MemoryManager& memoryManager, 
const std::vector& types, const bool enableCompression, FileHandle* dataFH, - Deserializer* deSer, const VersionRecordHandlerSelector* versionRecordHandlerSelector) + Deserializer* deSer, const VersionRecordHandler* versionRecordHandler) : enableCompression{enableCompression}, numTotalRows{0}, types{LogicalType::copy(types)}, - dataFH{dataFH}, versionRecordHandlerSelector(versionRecordHandlerSelector) { + dataFH{dataFH}, versionRecordHandler(versionRecordHandler) { if (deSer) { deserialize(*deSer, memoryManager); } @@ -155,7 +155,7 @@ row_idx_t NodeGroupCollection::getNumTotalRows() { NodeGroup* NodeGroupCollection::getOrCreateNodeGroup(transaction::Transaction* transaction, node_group_idx_t groupIdx, NodeGroupDataFormat format, - const VersionRecordHandlerSelector* versionRecordHandlerSelector) { + const VersionRecordHandler* versionRecordHandler) { const auto lock = nodeGroups.lock(); while (groupIdx >= nodeGroups.getNumGroups(lock)) { const auto currentGroupIdx = nodeGroups.getNumGroups(lock); @@ -166,7 +166,7 @@ NodeGroup* NodeGroupCollection::getOrCreateNodeGroup(transaction::Transaction* t enableCompression, LogicalType::copy(types))); // push an insert of size 0 so that we can rollback the creation of this node group if // needed - pushInsertInfo(transaction, nodeGroups.getLastGroup(lock), 0, versionRecordHandlerSelector); + pushInsertInfo(transaction, nodeGroups.getLastGroup(lock), 0, versionRecordHandler); } KU_ASSERT(groupIdx < nodeGroups.getNumGroups(lock)); return nodeGroups.getGroup(lock, groupIdx); @@ -213,19 +213,17 @@ void NodeGroupCollection::rollbackInsert(common::row_idx_t numRows_, bool update void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transaction, NodeGroup* nodeGroup, common::row_idx_t numRows, - const VersionRecordHandlerSelector* overridedVersionRecordHandlerSelector) { + const VersionRecordHandler* overridedVersionRecordHandler) { pushInsertInfo(transaction, nodeGroup->getNodeGroupIdx(), 
nodeGroup->getNumRows(), numRows, - overridedVersionRecordHandlerSelector ? overridedVersionRecordHandlerSelector : - versionRecordHandlerSelector); + overridedVersionRecordHandler ? overridedVersionRecordHandler : versionRecordHandler); }; void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const VersionRecordHandlerSelector* overridedVersionRecordHandlerSelector) { + const VersionRecordHandler* overridedVersionRecordHandler) { // we only append to the undo buffer if the node group collection is persistent if (dataFH && transaction->shouldAppendToUndoBuffer()) { - transaction->pushInsertInfo(nodeGroupIdx, startRow, numRows, - overridedVersionRecordHandlerSelector); + transaction->pushInsertInfo(nodeGroupIdx, startRow, numRows, overridedVersionRecordHandler); } } diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index e95315ae510..d78b0436fb5 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -21,27 +21,24 @@ using namespace kuzu::evaluator; namespace kuzu { namespace storage { -std::unique_ptr -NodeTableVersionRecordHandlerSelector::constructVersionRecordHandler(common::row_idx_t startRow, - common::row_idx_t numRows, common::transaction_t commitTS, - common::node_group_idx_t nodeGroupIdx) const { - return std::make_unique(nodeTable, nodeGroupIdx, startRow, - numRows, commitTS); -} - -NodeTable::ChunkedGroupIterator::ChunkedGroupIterator(NodeTable* table, - node_group_idx_t nodeGroupidx, common::row_idx_t startRow, common::row_idx_t numRows, - common::transaction_t commitTS) - : NodeGroup::NodeGroupVersionRecordHandler(table->nodeGroups.get(), nodeGroupidx, startRow, - numRows, commitTS), - table(table) {} - -void NodeTable::ChunkedGroupIterator::rollbackInsert(const transaction::Transaction* transaction) { - table->rollbackInsert(transaction, startRow, numRows, 
nodeGroup->getNodeGroupIdx()); - NodeGroup::NodeGroupVersionRecordHandler::rollbackInsert(transaction); - nodeGroup->rollbackInsert(startRow); +NodeTableVersionRecordHandler::NodeTableVersionRecordHandler(NodeTable* table) : table(table) {} + +void NodeTableVersionRecordHandler::applyFuncToChunkedGroups(version_record_handler_op_t func, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, + common::transaction_t commitTS) const { + auto* nodeGroup = table->getNodeGroupNoLock(nodeGroupIdx); + nodeGroup->applyFuncToChunkedGroups(func, startRow, numRows, commitTS); +} + +void NodeTableVersionRecordHandler::rollbackInsert(const transaction::Transaction* transaction, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows) const { + table->rollbackPKIndexInsert(transaction, startRow, numRows, nodeGroupIdx); + + VersionRecordHandler::rollbackInsert(transaction, nodeGroupIdx, startRow, numRows); + auto* nodeGroup = table->getNodeGroupNoLock(nodeGroupIdx); const auto numRowsToRollback = std::min(numRows, nodeGroup->getNumRows() - startRow); - nodeGroups->rollbackInsert(numRowsToRollback); + table->rollbackGroupCollectionInsert(numRowsToRollback); } bool NodeTableScanState::scanNext(Transaction* transaction, offset_t startOffset, @@ -201,7 +198,7 @@ NodeTable::NodeTable(const StorageManager* storageManager, VirtualFileSystem* vfs, main::ClientContext* context, Deserializer* deSer) : Table{nodeTableEntry, storageManager, memoryManager}, pkColumnID{nodeTableEntry->getColumnID(nodeTableEntry->getPrimaryKeyName())}, - versionRecordHandlerSelector(this) { + versionRecordHandler(this) { const auto maxColumnID = nodeTableEntry->getMaxColumnID(); columns.resize(maxColumnID + 1); for (auto i = 0u; i < nodeTableEntry->getNumProperties(); i++) { @@ -215,7 +212,7 @@ NodeTable::NodeTable(const StorageManager* storageManager, nodeGroups = std::make_unique(*memoryManager, getNodeTableColumnTypes(*this), - 
enableCompression, storageManager->getDataFH(), deSer, &versionRecordHandlerSelector); + enableCompression, storageManager->getDataFH(), deSer, &versionRecordHandler); initializePKIndex(storageManager->getDatabasePath(), nodeTableEntry, storageManager->isReadOnly(), vfs, context); } @@ -432,8 +429,7 @@ bool NodeTable::delete_(Transaction* transaction, TableDeleteState& deleteState) nodeOffset - StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); isDeleted = nodeGroups->getNodeGroup(nodeGroupIdx)->delete_(transaction, rowIdxInGroup); if (transaction->shouldAppendToUndoBuffer()) { - transaction->pushDeleteInfo(nodeGroupIdx, rowIdxInGroup, 1, - &versionRecordHandlerSelector); + transaction->pushDeleteInfo(nodeGroupIdx, rowIdxInGroup, 1, &versionRecordHandler); } } if (isDeleted) { @@ -505,7 +501,7 @@ void NodeTable::commit(Transaction* transaction, LocalTable* localTable) { KU_ASSERT(isDeleted); if (transaction->shouldAppendToUndoBuffer()) { transaction->pushDeleteInfo(nodeGroupIdx, rowIdxInGroup, 1, - &versionRecordHandlerSelector); + &versionRecordHandler); } } } @@ -556,7 +552,7 @@ void NodeTable::checkpoint(Serializer& ser, TableCatalogEntry* tableEntry) { serialize(ser); } -void NodeTable::rollbackInsert(const transaction::Transaction* transaction, +void NodeTable::rollbackPKIndexInsert(const transaction::Transaction* transaction, common::row_idx_t startRow, common::row_idx_t numRows_, common::node_group_idx_t nodeGroupIdx_) { row_idx_t startNodeOffset = startRow; @@ -568,6 +564,10 @@ void NodeTable::rollbackInsert(const transaction::Transaction* transaction, scanPKColumn(transaction, pkDeleter, *nodeGroups); } +void NodeTable::rollbackGroupCollectionInsert(common::row_idx_t numRows_) { + nodeGroups->rollbackInsert(numRows_); +} + TableStats NodeTable::getStats(const Transaction* transaction) const { auto stats = nodeGroups->getStats(); const auto localTable = transaction->getLocalStorage()->getLocalTable(tableID, diff --git 
a/src/storage/store/rel_table_data.cpp b/src/storage/store/rel_table_data.cpp index 1485c9e030b..ffe1659c2a9 100644 --- a/src/storage/store/rel_table_data.cpp +++ b/src/storage/store/rel_table_data.cpp @@ -16,12 +16,45 @@ using namespace kuzu::transaction; namespace kuzu { namespace storage { -std::unique_ptr -RelTableVersionRecordHandlerSelector::constructVersionRecordHandler(common::row_idx_t startRow, - common::row_idx_t numRows, common::transaction_t commitTS, - common::node_group_idx_t nodeGroupIdx) const { - return relTableData->constructVersionRecordHandler(source, nodeGroupIdx, startRow, numRows, - commitTS); +PersistentVersionRecordHandler::PersistentVersionRecordHandler(RelTableData* relTableData) + : relTableData(relTableData) {} + +void PersistentVersionRecordHandler::applyFuncToChunkedGroups(version_record_handler_op_t func, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, + common::transaction_t commitTS) const { + if (nodeGroupIdx < relTableData->getNumNodeGroups()) { + auto& nodeGroup = relTableData->getNodeGroupNoLock(nodeGroupIdx)->cast(); + auto* persistentChunkedGroup = nodeGroup.getPersistentChunkedGroup(); + if (persistentChunkedGroup) { + std::invoke(func, *persistentChunkedGroup, startRow, numRows, commitTS); + } + } +} + +void PersistentVersionRecordHandler::rollbackInsert(const transaction::Transaction* transaction, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows) const { + VersionRecordHandler::rollbackInsert(transaction, nodeGroupIdx, startRow, numRows); + relTableData->rollbackGroupCollectionInsert(numRows, true); +} + +InMemoryVersionRecordHandler::InMemoryVersionRecordHandler(RelTableData* relTableData) + : relTableData(relTableData) {} + +void InMemoryVersionRecordHandler::applyFuncToChunkedGroups(version_record_handler_op_t func, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, + 
common::transaction_t commitTS) const { + auto* nodeGroup = relTableData->getNodeGroupNoLock(nodeGroupIdx); + nodeGroup->applyFuncToChunkedGroups(func, startRow, numRows, commitTS); +} + +void InMemoryVersionRecordHandler::rollbackInsert(const transaction::Transaction* transaction, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows) const { + VersionRecordHandler::rollbackInsert(transaction, nodeGroupIdx, startRow, numRows); + auto* nodeGroup = relTableData->getNodeGroupNoLock(nodeGroupIdx); + const auto numRowsToRollback = std::min(numRows, nodeGroup->getNumRows() - startRow); + relTableData->rollbackGroupCollectionInsert(numRowsToRollback, false); } RelTableData::RelTableData(FileHandle* dataFH, MemoryManager* mm, ShadowFile* shadowFile, @@ -29,9 +62,8 @@ RelTableData::RelTableData(FileHandle* dataFH, MemoryManager* mm, ShadowFile* sh Deserializer* deSer) : dataFH{dataFH}, tableID{tableEntry->getTableID()}, tableName{tableEntry->getName()}, memoryManager{mm}, shadowFile{shadowFile}, enableCompression{enableCompression}, - direction{direction}, - persistentVersionRecordHandlerSelector(this, CSRNodeGroupScanSource::COMMITTED_PERSISTENT), - inMemoryVersionRecordHandlerSelector(this, CSRNodeGroupScanSource::COMMITTED_IN_MEMORY) { + direction{direction}, persistentVersionRecordHandler(this), + inMemoryVersionRecordHandler(this) { multiplicity = tableEntry->constCast().getMultiplicity(direction); initCSRHeaderColumns(); initPropertyColumns(tableEntry); @@ -109,8 +141,7 @@ bool RelTableData::delete_(Transaction* transaction, ValueVector& boundNodeIDVec auto& csrNodeGroup = getNodeGroup(nodeGroupIdx)->cast(); bool isDeleted = csrNodeGroup.delete_(transaction, source, rowIdx); if (isDeleted && transaction->shouldAppendToUndoBuffer()) { - transaction->pushDeleteInfo(nodeGroupIdx, rowIdx, 1, - getVersionRecordHandlerSelector(source)); + transaction->pushDeleteInfo(nodeGroupIdx, rowIdx, 1, getVersionRecordHandler(source)); } 
return isDeleted; } @@ -215,7 +246,7 @@ void RelTableData::pushInsertInfo(transaction::Transaction* transaction, nodeGroup.getNumRows(); nodeGroups->pushInsertInfo(transaction, nodeGroup.getNodeGroupIdx(), startRow, numRows_, - getVersionRecordHandlerSelector(source)); + getVersionRecordHandler(source)); } void RelTableData::checkpoint(const std::vector& columnIDs) { @@ -240,27 +271,17 @@ void RelTableData::serialize(Serializer& serializer) const { nodeGroups->serialize(serializer); } -std::unique_ptr RelTableData::constructVersionRecordHandler( - CSRNodeGroupScanSource source, common::node_group_idx_t nodeGroupIdx, - common::row_idx_t startRow, common::row_idx_t numRows, common::transaction_t commitTS) const { +const VersionRecordHandler* RelTableData::getVersionRecordHandler(CSRNodeGroupScanSource source) { if (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) { - return std::make_unique(nodeGroups.get(), nodeGroupIdx, - startRow, numRows, commitTS); + return &persistentVersionRecordHandler; } else { KU_ASSERT(source == CSRNodeGroupScanSource::COMMITTED_IN_MEMORY); - return std::make_unique(nodeGroups.get(), - nodeGroupIdx, startRow, numRows, commitTS); + return &inMemoryVersionRecordHandler; } } -const RelTableVersionRecordHandlerSelector* RelTableData::getVersionRecordHandlerSelector( - CSRNodeGroupScanSource source) { - if (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) { - return &persistentVersionRecordHandlerSelector; - } else { - KU_ASSERT(source == CSRNodeGroupScanSource::COMMITTED_IN_MEMORY); - return &inMemoryVersionRecordHandlerSelector; - } +void RelTableData::rollbackGroupCollectionInsert(common::row_idx_t numRows_, bool isPersistent) { + nodeGroups->rollbackInsert(numRows_, !isPersistent); } } // namespace storage diff --git a/src/storage/store/version_record_handler.cpp b/src/storage/store/version_record_handler.cpp new file mode 100644 index 00000000000..62953c13741 --- /dev/null +++ b/src/storage/store/version_record_handler.cpp @@ 
-0,0 +1,10 @@ +#include "storage/store/version_record_handler.h" + +namespace kuzu::storage { +void VersionRecordHandler::rollbackInsert(const transaction::Transaction* transaction, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows) const { + applyFuncToChunkedGroups(&ChunkedNodeGroup::rollbackInsert, nodeGroupIdx, startRow, numRows, + transaction->getCommitTS()); +} +} // namespace kuzu::storage diff --git a/src/storage/undo_buffer.cpp b/src/storage/undo_buffer.cpp index faf25294f9f..a33aa14ef70 100644 --- a/src/storage/undo_buffer.cpp +++ b/src/storage/undo_buffer.cpp @@ -40,7 +40,7 @@ struct VersionRecord { row_idx_t startRow; row_idx_t numRows; node_group_idx_t nodeGroupIdx; - const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector; + const storage::VersionRecordHandler* versionRecordHandler; }; struct VectorUpdateRecord { @@ -111,28 +111,27 @@ void UndoBuffer::createSequenceChange(SequenceCatalogEntry& sequenceEntry, } void UndoBuffer::createInsertInfo(node_group_idx_t nodeGroupIdx, row_idx_t startRow, - row_idx_t numRows, const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector) { - createVersionInfo(UndoRecordType::INSERT_INFO, startRow, numRows, versionRecordHandlerSelector, + row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler) { + createVersionInfo(UndoRecordType::INSERT_INFO, startRow, numRows, versionRecordHandler, nodeGroupIdx); } void UndoBuffer::createDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, - const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector) { - createVersionInfo(UndoRecordType::DELETE_INFO, startRow, numRows, versionRecordHandlerSelector, + common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler) { + createVersionInfo(UndoRecordType::DELETE_INFO, startRow, numRows, versionRecordHandler, nodeGroupIdx); } void 
UndoBuffer::createVersionInfo(const UndoRecordType recordType, row_idx_t startRow, - row_idx_t numRows, const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector, + row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler, node_group_idx_t nodeGroupIdx) { - KU_ASSERT(versionRecordHandlerSelector); + KU_ASSERT(versionRecordHandler); auto buffer = createUndoRecord(sizeof(UndoRecordHeader) + sizeof(VersionRecord)); const UndoRecordHeader recordHeader{recordType, sizeof(VersionRecord)}; *reinterpret_cast(buffer) = recordHeader; buffer += sizeof(UndoRecordHeader); *reinterpret_cast(buffer) = - VersionRecord{startRow, numRows, nodeGroupIdx, versionRecordHandlerSelector}; + VersionRecord{startRow, numRows, nodeGroupIdx, versionRecordHandler}; } void UndoBuffer::createVectorUpdateInfo(UpdateInfo* updateInfo, const idx_t vectorIdx, @@ -217,14 +216,12 @@ void UndoBuffer::commitVersionInfo(UndoRecordType recordType, const uint8_t* rec const auto& undoRecord = *reinterpret_cast(record); switch (recordType) { case UndoRecordType::INSERT_INFO: { - auto handler = undoRecord.versionRecordHandlerSelector->constructVersionRecordHandler( - undoRecord.startRow, undoRecord.numRows, commitTS, undoRecord.nodeGroupIdx); - handler->applyFuncToChunkedGroups(&ChunkedNodeGroup::commitInsert); + undoRecord.versionRecordHandler->applyFuncToChunkedGroups(&ChunkedNodeGroup::commitInsert, + undoRecord.nodeGroupIdx, undoRecord.startRow, undoRecord.numRows, commitTS); } break; case UndoRecordType::DELETE_INFO: { - auto handler = undoRecord.versionRecordHandlerSelector->constructVersionRecordHandler( - undoRecord.startRow, undoRecord.numRows, commitTS, undoRecord.nodeGroupIdx); - handler->applyFuncToChunkedGroups(&ChunkedNodeGroup::commitDelete); + undoRecord.versionRecordHandler->applyFuncToChunkedGroups(&ChunkedNodeGroup::commitDelete, + undoRecord.nodeGroupIdx, undoRecord.startRow, undoRecord.numRows, commitTS); } break; default: { KU_UNREACHABLE; @@ -300,16 
+297,13 @@ void UndoBuffer::rollbackVersionInfo(const transaction::Transaction* transaction auto& undoRecord = *reinterpret_cast(record); switch (recordType) { case UndoRecordType::INSERT_INFO: { - auto handler = undoRecord.versionRecordHandlerSelector->constructVersionRecordHandler( - undoRecord.startRow, undoRecord.numRows, transaction->getCommitTS(), - undoRecord.nodeGroupIdx); - handler->rollbackInsert(transaction); + undoRecord.versionRecordHandler->rollbackInsert(transaction, undoRecord.nodeGroupIdx, + undoRecord.startRow, undoRecord.numRows); } break; case UndoRecordType::DELETE_INFO: { - auto handler = undoRecord.versionRecordHandlerSelector->constructVersionRecordHandler( - undoRecord.startRow, undoRecord.numRows, transaction->getCommitTS(), - undoRecord.nodeGroupIdx); - handler->applyFuncToChunkedGroups(&ChunkedNodeGroup::rollbackDelete); + undoRecord.versionRecordHandler->applyFuncToChunkedGroups(&ChunkedNodeGroup::rollbackDelete, + undoRecord.nodeGroupIdx, undoRecord.startRow, undoRecord.numRows, + transaction->getCommitTS()); } break; default: { KU_UNREACHABLE; diff --git a/src/transaction/transaction.cpp b/src/transaction/transaction.cpp index 79d4f43c53a..3c539ce5b25 100644 --- a/src/transaction/transaction.cpp +++ b/src/transaction/transaction.cpp @@ -174,15 +174,13 @@ void Transaction::pushSequenceChange(SequenceCatalogEntry* sequenceEntry, int64_ } void Transaction::pushInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, - const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector) const { - undoBuffer->createInsertInfo(nodeGroupIdx, startRow, numRows, versionRecordHandlerSelector); + common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler) const { + undoBuffer->createInsertInfo(nodeGroupIdx, startRow, numRows, versionRecordHandler); } void Transaction::pushDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - 
common::row_idx_t numRows, - const storage::VersionRecordHandlerSelector* versionRecordHandlerSelector) const { - undoBuffer->createDeleteInfo(nodeGroupIdx, startRow, numRows, versionRecordHandlerSelector); + common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler) const { + undoBuffer->createDeleteInfo(nodeGroupIdx, startRow, numRows, versionRecordHandler); } void Transaction::pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, From 08757c393f8185cb79369938f766020c2bb685c4 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Fri, 29 Nov 2024 14:46:29 -0500 Subject: [PATCH 21/28] Rollback insert for node groups --- src/storage/store/node_group.cpp | 1 + src/storage/store/node_table.cpp | 1 + src/storage/store/rel_table_data.cpp | 1 + 3 files changed, 3 insertions(+) diff --git a/src/storage/store/node_group.cpp b/src/storage/store/node_group.cpp index 7dd24807641..b1461d7de4b 100644 --- a/src/storage/store/node_group.cpp +++ b/src/storage/store/node_group.cpp @@ -185,6 +185,7 @@ NodeGroupScanResult NodeGroup::scan(const Transaction* transaction, TableScanSta } const auto& chunkedGroupToScan = *chunkedGroups.getGroup(lock, nodeGroupScanState.chunkedGroupIdx); + KU_ASSERT(nodeGroupScanState.nextRowToScan >= chunkedGroupToScan.getStartRowIdx()); const auto rowIdxInChunkToScan = nodeGroupScanState.nextRowToScan - chunkedGroupToScan.getStartRowIdx(); const auto numRowsToScan = diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index d78b0436fb5..3355a770c4e 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -37,6 +37,7 @@ void NodeTableVersionRecordHandler::rollbackInsert(const transaction::Transactio VersionRecordHandler::rollbackInsert(transaction, nodeGroupIdx, startRow, numRows); auto* nodeGroup = table->getNodeGroupNoLock(nodeGroupIdx); + nodeGroup->rollbackInsert(startRow); const auto numRowsToRollback = std::min(numRows, nodeGroup->getNumRows() - startRow); 
table->rollbackGroupCollectionInsert(numRowsToRollback); } diff --git a/src/storage/store/rel_table_data.cpp b/src/storage/store/rel_table_data.cpp index ffe1659c2a9..be481a3b1cc 100644 --- a/src/storage/store/rel_table_data.cpp +++ b/src/storage/store/rel_table_data.cpp @@ -53,6 +53,7 @@ void InMemoryVersionRecordHandler::rollbackInsert(const transaction::Transaction common::row_idx_t numRows) const { VersionRecordHandler::rollbackInsert(transaction, nodeGroupIdx, startRow, numRows); auto* nodeGroup = relTableData->getNodeGroupNoLock(nodeGroupIdx); + nodeGroup->rollbackInsert(startRow); const auto numRowsToRollback = std::min(numRows, nodeGroup->getNumRows() - startRow); relTableData->rollbackGroupCollectionInsert(numRowsToRollback, false); } From 6b30e3513485ec2a7bde0ccef0d3f959b612d7bd Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Fri, 29 Nov 2024 14:57:08 -0500 Subject: [PATCH 22/28] Remove unused forward declares --- src/include/storage/store/node_group.h | 1 - src/include/storage/store/version_record_handler.h | 2 -- src/include/storage/undo_buffer.h | 2 -- src/include/transaction/transaction.h | 2 -- 4 files changed, 7 deletions(-) diff --git a/src/include/storage/store/node_group.h b/src/include/storage/store/node_group.h index 0ab4dccf8a2..ce97a8a5213 100644 --- a/src/include/storage/store/node_group.h +++ b/src/include/storage/store/node_group.h @@ -18,7 +18,6 @@ class MemoryManager; struct TableAddColumnState; class NodeGroup; -class NodeGroupCollection; struct NodeGroupScanState { // Index of committed but not yet checkpointed chunked group to scan. 
diff --git a/src/include/storage/store/version_record_handler.h b/src/include/storage/store/version_record_handler.h index 4fded7703f4..d3cf5516397 100644 --- a/src/include/storage/store/version_record_handler.h +++ b/src/include/storage/store/version_record_handler.h @@ -7,8 +7,6 @@ namespace kuzu { namespace storage { -class NodeGroupCollection; -class VersionRecordHandler; using version_record_handler_op_t = void ( ChunkedNodeGroup::*)(common::row_idx_t, common::row_idx_t, common::transaction_t); diff --git a/src/include/storage/undo_buffer.h b/src/include/storage/undo_buffer.h index 38eb0b4158e..5b02cfd452a 100644 --- a/src/include/storage/undo_buffer.h +++ b/src/include/storage/undo_buffer.h @@ -66,8 +66,6 @@ class UndoBufferIterator { class UpdateInfo; class VersionInfo; struct VectorUpdateInfo; -class RelTableData; -class NodeTable; class WAL; // This class is not thread safe, as it is supposed to be accessed by a single thread. class UndoBuffer { diff --git a/src/include/transaction/transaction.h b/src/include/transaction/transaction.h index d50a2c31225..0a237fc5dac 100644 --- a/src/include/transaction/transaction.h +++ b/src/include/transaction/transaction.h @@ -22,11 +22,9 @@ class UpdateInfo; struct VectorUpdateInfo; class ChunkedNodeGroup; class VersionRecordHandler; -class VersionRecordHandlerSelector; } // namespace storage namespace transaction { class TransactionManager; -class Transaction; enum class TransactionType : uint8_t { READ_ONLY, WRITE, CHECKPOINT, DUMMY, RECOVERY }; From 1a554437a0a479830bac34bf6ceed061ea30cfa9 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Fri, 29 Nov 2024 15:05:00 -0500 Subject: [PATCH 23/28] Get correct num of total rows to rollback in node group collection --- src/storage/store/node_table.cpp | 2 +- src/storage/store/rel_table_data.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index 3355a770c4e..432878ad02d 100644 --- 
a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -37,8 +37,8 @@ void NodeTableVersionRecordHandler::rollbackInsert(const transaction::Transactio VersionRecordHandler::rollbackInsert(transaction, nodeGroupIdx, startRow, numRows); auto* nodeGroup = table->getNodeGroupNoLock(nodeGroupIdx); - nodeGroup->rollbackInsert(startRow); const auto numRowsToRollback = std::min(numRows, nodeGroup->getNumRows() - startRow); + nodeGroup->rollbackInsert(startRow); table->rollbackGroupCollectionInsert(numRowsToRollback); } diff --git a/src/storage/store/rel_table_data.cpp b/src/storage/store/rel_table_data.cpp index be481a3b1cc..ec2ef50fdd1 100644 --- a/src/storage/store/rel_table_data.cpp +++ b/src/storage/store/rel_table_data.cpp @@ -53,8 +53,8 @@ void InMemoryVersionRecordHandler::rollbackInsert(const transaction::Transaction common::row_idx_t numRows) const { VersionRecordHandler::rollbackInsert(transaction, nodeGroupIdx, startRow, numRows); auto* nodeGroup = relTableData->getNodeGroupNoLock(nodeGroupIdx); - nodeGroup->rollbackInsert(startRow); const auto numRowsToRollback = std::min(numRows, nodeGroup->getNumRows() - startRow); + nodeGroup->rollbackInsert(startRow); relTableData->rollbackGroupCollectionInsert(numRowsToRollback, false); } From 84b6e1735d5700a6a6bf1ba887d80b2966150f4a Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Fri, 29 Nov 2024 16:14:03 -0500 Subject: [PATCH 24/28] Make BM exception during rel commit trigger earlier --- test/copy/copy_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/copy/copy_test.cpp b/test/copy/copy_test.cpp index dfecfdcff8e..9fe9de5b9f5 100644 --- a/test/copy/copy_test.cpp +++ b/test/copy/copy_test.cpp @@ -218,7 +218,7 @@ TEST_F(CopyTest, RelInsertBMExceptionDuringCommitRecovery) { .canFailDuringCheckpoint = false, .initFunc = [this](main::Connection* conn) { - failureFrequency = 128; + failureFrequency = 32; conn->query("CREATE NODE TABLE account(ID INT64, PRIMARY KEY(ID))"); 
conn->query("CREATE REL TABLE follows(FROM account TO account);"); const auto queryString = common::stringFormat( From 40691ca9a421fc7f60d49c6c048359fdf9d96396 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Fri, 29 Nov 2024 19:04:24 -0500 Subject: [PATCH 25/28] Update num total rows for rel table data node group collection --- .../storage/store/node_group_collection.h | 9 ++++----- src/include/storage/store/rel_table_data.h | 4 ++-- src/storage/store/node_group_collection.cpp | 17 +++++++++-------- src/storage/store/rel_table_data.cpp | 14 +++++++++----- 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/src/include/storage/store/node_group_collection.h b/src/include/storage/store/node_group_collection.h index 4b9353054c8..d3a18324ad8 100644 --- a/src/include/storage/store/node_group_collection.h +++ b/src/include/storage/store/node_group_collection.h @@ -50,8 +50,7 @@ class NodeGroupCollection { return nodeGroups.getGroup(lock, groupIdx); } NodeGroup* getOrCreateNodeGroup(transaction::Transaction* transaction, - common::node_group_idx_t groupIdx, NodeGroupDataFormat format, - const VersionRecordHandler* versionRecordHandler); + common::node_group_idx_t groupIdx, NodeGroupDataFormat format); void setNodeGroup(const common::node_group_idx_t nodeGroupIdx, std::unique_ptr group) { @@ -81,12 +80,12 @@ class NodeGroupCollection { void pushInsertInfo(const transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, const VersionRecordHandler* overridedVersionRecordHandler); + common::row_idx_t numRows, const VersionRecordHandler* versionRecordHandler, + bool incrementNumTotalRows); private: void pushInsertInfo(const transaction::Transaction* transaction, NodeGroup* nodeGroup, - common::row_idx_t numRows, - const VersionRecordHandler* overridedVersionRecordHandler = nullptr); + common::row_idx_t numRows); bool enableCompression; // Num rows in the collection regardless of deletions. 
diff --git a/src/include/storage/store/rel_table_data.h b/src/include/storage/store/rel_table_data.h index f8d4d66b22b..62beab1d439 100644 --- a/src/include/storage/store/rel_table_data.h +++ b/src/include/storage/store/rel_table_data.h @@ -86,8 +86,8 @@ class RelTableData { } NodeGroup* getOrCreateNodeGroup(transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx) const { - return nodeGroups->getOrCreateNodeGroup(transaction, nodeGroupIdx, NodeGroupDataFormat::CSR, - &persistentVersionRecordHandler); + return nodeGroups->getOrCreateNodeGroup(transaction, nodeGroupIdx, + NodeGroupDataFormat::CSR); } common::RelMultiplicity getMultiplicity() const { return multiplicity; } diff --git a/src/storage/store/node_group_collection.cpp b/src/storage/store/node_group_collection.cpp index 2ece802cb81..b577ad1c971 100644 --- a/src/storage/store/node_group_collection.cpp +++ b/src/storage/store/node_group_collection.cpp @@ -154,8 +154,7 @@ row_idx_t NodeGroupCollection::getNumTotalRows() { } NodeGroup* NodeGroupCollection::getOrCreateNodeGroup(transaction::Transaction* transaction, - node_group_idx_t groupIdx, NodeGroupDataFormat format, - const VersionRecordHandler* versionRecordHandler) { + node_group_idx_t groupIdx, NodeGroupDataFormat format) { const auto lock = nodeGroups.lock(); while (groupIdx >= nodeGroups.getNumGroups(lock)) { const auto currentGroupIdx = nodeGroups.getNumGroups(lock); @@ -166,7 +165,7 @@ NodeGroup* NodeGroupCollection::getOrCreateNodeGroup(transaction::Transaction* t enableCompression, LogicalType::copy(types))); // push an insert of size 0 so that we can rollback the creation of this node group if // needed - pushInsertInfo(transaction, nodeGroups.getLastGroup(lock), 0, versionRecordHandler); + pushInsertInfo(transaction, nodeGroups.getLastGroup(lock), 0); } KU_ASSERT(groupIdx < nodeGroups.getNumGroups(lock)); return nodeGroups.getGroup(lock, groupIdx); @@ -212,18 +211,20 @@ void 
NodeGroupCollection::rollbackInsert(common::row_idx_t numRows_, bool update } void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transaction, - NodeGroup* nodeGroup, common::row_idx_t numRows, - const VersionRecordHandler* overridedVersionRecordHandler) { + NodeGroup* nodeGroup, common::row_idx_t numRows) { pushInsertInfo(transaction, nodeGroup->getNodeGroupIdx(), nodeGroup->getNumRows(), numRows, - overridedVersionRecordHandler ? overridedVersionRecordHandler : versionRecordHandler); + versionRecordHandler, false); }; void NodeGroupCollection::pushInsertInfo(const transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, - const VersionRecordHandler* overridedVersionRecordHandler) { + const VersionRecordHandler* versionRecordHandler, bool incrementNumRows) { // we only append to the undo buffer if the node group collection is persistent if (dataFH && transaction->shouldAppendToUndoBuffer()) { - transaction->pushInsertInfo(nodeGroupIdx, startRow, numRows, overridedVersionRecordHandler); + transaction->pushInsertInfo(nodeGroupIdx, startRow, numRows, versionRecordHandler); + } + if (incrementNumRows) { + numTotalRows += numRows; } } diff --git a/src/storage/store/rel_table_data.cpp b/src/storage/store/rel_table_data.cpp index ec2ef50fdd1..2d9f8a33343 100644 --- a/src/storage/store/rel_table_data.cpp +++ b/src/storage/store/rel_table_data.cpp @@ -69,8 +69,11 @@ RelTableData::RelTableData(FileHandle* dataFH, MemoryManager* mm, ShadowFile* sh initCSRHeaderColumns(); initPropertyColumns(tableEntry); + // default to using the persistent version record handler + // if we want to use the in-memory handler we will explicitly pass it into + // nodeGroups.pushInsertInfo() nodeGroups = std::make_unique(*mm, getColumnTypes(), enableCompression, - dataFH, deSer); + dataFH, deSer, &persistentVersionRecordHandler); } void RelTableData::initCSRHeaderColumns() { @@ -242,12 +245,13 @@ void 
RelTableData::pushInsertInfo(transaction::Transaction* transaction, !nodeGroup.getPersistentChunkedGroup() || nodeGroup.getPersistentChunkedGroup()->getNumRows() == 0); - const auto startRow = (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) ? - static_cast(0) : - nodeGroup.getNumRows(); + const auto [startRow, shouldIncrementNumRows] = + (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) ? + std::make_pair(static_cast(0), false) : + std::make_pair(nodeGroup.getNumRows(), true); nodeGroups->pushInsertInfo(transaction, nodeGroup.getNodeGroupIdx(), startRow, numRows_, - getVersionRecordHandler(source)); + getVersionRecordHandler(source), shouldIncrementNumRows); } void RelTableData::checkpoint(const std::vector& columnIDs) { From 4ad2632eae3c888938b4c36af62282e2d5737270 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Mon, 2 Dec 2024 09:06:30 -0500 Subject: [PATCH 26/28] Address review comments --- src/include/storage/buffer_manager/buffer_manager.h | 8 ++++++-- src/include/storage/store/version_record_handler.h | 3 ++- src/include/storage/undo_buffer.h | 8 ++++---- src/include/transaction/transaction_manager.h | 2 ++ src/storage/store/version_record_handler.cpp | 2 ++ src/storage/undo_buffer.cpp | 2 ++ test/copy/copy_test.cpp | 12 +++++++++--- 7 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/include/storage/buffer_manager/buffer_manager.h b/src/include/storage/buffer_manager/buffer_manager.h index a1ea8a785d4..9e2784f0cea 100644 --- a/src/include/storage/buffer_manager/buffer_manager.h +++ b/src/include/storage/buffer_manager/buffer_manager.h @@ -20,8 +20,9 @@ namespace common { class VirtualFileSystem; }; namespace testing { -class EmptyBufferManagerTest; -}; +class FlakyBufferManager; +class CopyTestHelper; +}; // namespace testing namespace storage { class ChunkedNodeGroup; class Spiller; @@ -179,6 +180,9 @@ class EvictionQueue { * https://github.com/fabubaker/kuzu/blob/umbra-bm/final_project_report.pdf. 
*/ class BufferManager { + friend class testing::FlakyBufferManager; + friend class testing::CopyTestHelper; + friend class FileHandle; friend class MemoryManager; diff --git a/src/include/storage/store/version_record_handler.h b/src/include/storage/store/version_record_handler.h index d3cf5516397..92f16188f91 100644 --- a/src/include/storage/store/version_record_handler.h +++ b/src/include/storage/store/version_record_handler.h @@ -1,13 +1,14 @@ #pragma once #include "common/types/types.h" -#include "storage/store/chunked_node_group.h" #include "transaction/transaction.h" namespace kuzu { namespace storage { +class ChunkedNodeGroup; + using version_record_handler_op_t = void ( ChunkedNodeGroup::*)(common::row_idx_t, common::row_idx_t, common::transaction_t); diff --git a/src/include/storage/undo_buffer.h b/src/include/storage/undo_buffer.h index 5b02cfd452a..ed656733c87 100644 --- a/src/include/storage/undo_buffer.h +++ b/src/include/storage/undo_buffer.h @@ -4,7 +4,6 @@ #include "common/constants.h" #include "common/types/types.h" -#include "storage/store/node_group.h" namespace kuzu { namespace catalog { @@ -21,6 +20,7 @@ namespace main { class ClientContext; } namespace storage { +class VersionRecordHandler; // TODO(Guodong): This should be reworked to use MemoryManager for memory allocaiton. // For now, we use malloc to get around the limitation of 256KB from MM. 
@@ -86,9 +86,9 @@ class UndoBuffer { void createSequenceChange(catalog::SequenceCatalogEntry& sequenceEntry, const catalog::SequenceRollbackData& data); void createInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler); + common::row_idx_t numRows, const VersionRecordHandler* versionRecordHandler); void createDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, - common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler); + common::row_idx_t numRows, const VersionRecordHandler* versionRecordHandler); void createVectorUpdateInfo(UpdateInfo* updateInfo, common::idx_t vectorIdx, VectorUpdateInfo* vectorUpdateInfo); @@ -101,7 +101,7 @@ class UndoBuffer { uint8_t* createUndoRecord(uint64_t size); void createVersionInfo(UndoRecordType recordType, common::row_idx_t startRow, - common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler, + common::row_idx_t numRows, const VersionRecordHandler* versionRecordHandler, common::node_group_idx_t nodeGroupIdx = 0); void commitRecord(UndoRecordType recordType, const uint8_t* record, diff --git a/src/include/transaction/transaction_manager.h b/src/include/transaction/transaction_manager.h index ae530b437b1..97471cc5095 100644 --- a/src/include/transaction/transaction_manager.h +++ b/src/include/transaction/transaction_manager.h @@ -15,12 +15,14 @@ class ClientContext; namespace testing { class DBTest; +class FlakyBufferManager; } // namespace testing namespace transaction { class TransactionManager { friend class testing::DBTest; + friend class testing::FlakyBufferManager; public: // Timestamp starts from 1. 0 is reserved for the dummy system transaction. 
diff --git a/src/storage/store/version_record_handler.cpp b/src/storage/store/version_record_handler.cpp index 62953c13741..5abf77c4e2a 100644 --- a/src/storage/store/version_record_handler.cpp +++ b/src/storage/store/version_record_handler.cpp @@ -1,5 +1,7 @@ #include "storage/store/version_record_handler.h" +#include "storage/store/chunked_node_group.h" + namespace kuzu::storage { void VersionRecordHandler::rollbackInsert(const transaction::Transaction* transaction, common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, diff --git a/src/storage/undo_buffer.cpp b/src/storage/undo_buffer.cpp index a33aa14ef70..c98de54b233 100644 --- a/src/storage/undo_buffer.cpp +++ b/src/storage/undo_buffer.cpp @@ -4,7 +4,9 @@ #include "catalog/catalog_entry/sequence_catalog_entry.h" #include "catalog/catalog_entry/table_catalog_entry.h" #include "catalog/catalog_set.h" +#include "storage/store/chunked_node_group.h" #include "storage/store/update_info.h" +#include "storage/store/version_record_handler.h" #include "transaction/transaction.h" using namespace kuzu::catalog; diff --git a/test/copy/copy_test.cpp b/test/copy/copy_test.cpp index 9fe9de5b9f5..3fe2a22a95d 100644 --- a/test/copy/copy_test.cpp +++ b/test/copy/copy_test.cpp @@ -3,14 +3,20 @@ #include "graph_test/base_graph_test.h" #include "graph_test/graph_test.h" #include "main/database.h" - -#define private public #include "storage/buffer_manager/buffer_manager.h" #include "transaction/transaction_manager.h" namespace kuzu { namespace testing { +class CopyTestHelper { +public: + static std::vector>& getBMFileHandles( + storage::BufferManager* bm) { + return bm->fileHandles; + } +}; + class FlakyBufferManager : public storage::BufferManager { public: FlakyBufferManager(const std::string& databasePath, const std::string& spillToDiskPath, @@ -176,7 +182,7 @@ TEST_F(CopyTest, RelCopyBMExceptionRecoverySameConnection) { .earlyExitOnFailureFunc = [this](main::QueryResult*) { // clear the BM so that the failure 
frequency isn't messed with by cached pages - for (auto& fh : currentBM->fileHandles) { + for (auto& fh : CopyTestHelper::getBMFileHandles(currentBM)) { currentBM->removeFilePagesFromFrames(*fh); } return false; From c052e90cd5fbccc89951dd6ae9b2f79d3b36b78f Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Mon, 2 Dec 2024 09:42:00 -0500 Subject: [PATCH 27/28] Rework nextChainedSlots() for in mem hash index --- src/include/storage/index/in_mem_hash_index.h | 17 +++++++---------- src/storage/index/in_mem_hash_index.cpp | 1 + 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/include/storage/index/in_mem_hash_index.h b/src/include/storage/index/in_mem_hash_index.h index 5e739849658..8caa30edbd4 100644 --- a/src/include/storage/index/in_mem_hash_index.h +++ b/src/include/storage/index/in_mem_hash_index.h @@ -134,9 +134,11 @@ class InMemHashIndex final { // Leaves the slot pointer pointing at the last slot to make it easier to add a new one bool nextChainedSlot(SlotIterator& iter) const { - iter.slotInfo.slotId = iter.slot->header.nextOvfSlotId; - iter.slotInfo.slotType = SlotType::OVF; + KU_ASSERT(iter.slotInfo.slotType == SlotType::PRIMARY || + iter.slotInfo.slotId != iter.slot->header.nextOvfSlotId); if (iter.slot->header.nextOvfSlotId != SlotHeader::INVALID_OVERFLOW_SLOT_ID) { + iter.slotInfo.slotId = iter.slot->header.nextOvfSlotId; + iter.slotInfo.slotType = SlotType::OVF; iter.slot = getSlot(iter.slotInfo); return true; } @@ -176,7 +178,9 @@ class InMemHashIndex final { if (deletedPos.has_value()) { // Find the last valid entry and move it into the deleted position - auto newIter = getLastValidEntry(iter); + auto newIter = iter; + while (nextChainedSlot(newIter)) + ; if (newIter.slotInfo != iter.slotInfo || *deletedPos != newIter.slot->header.numEntries() - 1) { KU_ASSERT(newIter.slot->header.numEntries() > 0); @@ -199,13 +203,6 @@ class InMemHashIndex final { } private: - SlotIterator getLastValidEntry(const SlotIterator& startIter) { - auto curIter 
= startIter; - while (curIter.slot->header.nextOvfSlotId != SlotHeader::INVALID_OVERFLOW_SLOT_ID && - nextChainedSlot(curIter)) {} - return curIter; - } - // Assumes that space has already been allocated for the entry bool appendInternal(Key key, common::offset_t value, common::hash_t hash, visible_func isVisible) { diff --git a/src/storage/index/in_mem_hash_index.cpp b/src/storage/index/in_mem_hash_index.cpp index c7bf71f742b..f8fcff25096 100644 --- a/src/storage/index/in_mem_hash_index.cpp +++ b/src/storage/index/in_mem_hash_index.cpp @@ -141,6 +141,7 @@ void InMemHashIndex::reclaimOverflowSlots(SlotIterator iter) { while (iter.slot->header.numEntries() > 0 || iter.slotInfo.slotType == SlotType::PRIMARY) { lastNonEmptySlot = iter.slot; if (!nextChainedSlot(iter)) { + iter.slotInfo = HashIndexUtils::INVALID_OVF_INFO; break; } } From cec96e58571b7e9925be710584ccc94b139e2b24 Mon Sep 17 00:00:00 2001 From: Royi Luo Date: Mon, 2 Dec 2024 10:58:32 -0500 Subject: [PATCH 28/28] Update splitSlots so behaviour is same as before nextChainedSlot() refactor --- src/storage/index/in_mem_hash_index.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/storage/index/in_mem_hash_index.cpp b/src/storage/index/in_mem_hash_index.cpp index f8fcff25096..fed3062287b 100644 --- a/src/storage/index/in_mem_hash_index.cpp +++ b/src/storage/index/in_mem_hash_index.cpp @@ -99,7 +99,10 @@ void InMemHashIndex::splitSlot(HashIndexHeader& header) { if (newSlotPos >= getSlotCapacity()) { auto newOvfSlotId = allocateAOSlot(); newSlot.slot->header.nextOvfSlotId = newOvfSlotId; - nextChainedSlot(newSlot); + if (newSlot.slot->header.nextOvfSlotId != + SlotHeader::INVALID_OVERFLOW_SLOT_ID) { + nextChainedSlot(newSlot); + } newSlotPos = 0; } newSlot.slot->entries[newSlotPos] = entry; @@ -114,7 +117,10 @@ void InMemHashIndex::splitSlot(HashIndexHeader& header) { entryPosToInsert++; if (entryPosToInsert >= getSlotCapacity()) { entryPosToInsert = 0; - 
nextChainedSlot(originalSlotForInsert); + if (originalSlotForInsert.slot->header.nextOvfSlotId != + SlotHeader::INVALID_OVERFLOW_SLOT_ID) { + nextChainedSlot(originalSlotForInsert); + } } } originalSlotForInsert.slot->entries[entryPosToInsert] = entry;