Skip to content

Commit

Permalink
Fix rollback during Node Table COPY (#4467)
Browse files Browse the repository at this point in the history
  • Loading branch information
royi-luo authored Dec 2, 2024
1 parent f5224d7 commit a82a76a
Show file tree
Hide file tree
Showing 59 changed files with 1,206 additions and 374 deletions.
7 changes: 7 additions & 0 deletions src/include/common/mask.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class RoaringBitmapSemiMask {
virtual ~RoaringBitmapSemiMask() = default;

virtual void mask(common::offset_t nodeOffset) = 0;
virtual void maskRange(common::offset_t startNodeOffset, common::offset_t endNodeOffset) = 0;

virtual bool isMasked(common::offset_t startNodeOffset) = 0;

Expand Down Expand Up @@ -44,6 +45,9 @@ class Roaring32BitmapSemiMask : public RoaringBitmapSemiMask {
}

void mask(common::offset_t nodeOffset) override { roaring->add(nodeOffset); }
// Masks a whole range of node offsets in one call by forwarding both bounds, unchanged, to
// roaring->addRange.
// NOTE(review): CRoaring documents Roaring::addRange as the half-open interval [min, max), which
// would make endNodeOffset exclusive here -- confirm that callers pass an exclusive end.
void maskRange(common::offset_t startNodeOffset, common::offset_t endNodeOffset) override {
roaring->addRange(startNodeOffset, endNodeOffset);
}

bool isMasked(common::offset_t startNodeOffset) override {
return roaring->contains(startNodeOffset);
Expand Down Expand Up @@ -76,6 +80,9 @@ class Roaring64BitmapSemiMask : public RoaringBitmapSemiMask {
roaring(std::make_shared<roaring::Roaring64Map>()) {}

void mask(common::offset_t nodeOffset) override { roaring->add(nodeOffset); }
// Masks a whole range of node offsets in one call by forwarding both bounds, unchanged, to
// roaring->addRange.
// NOTE(review): CRoaring documents Roaring64Map::addRange as the half-open interval [min, max),
// which would make endNodeOffset exclusive here -- confirm that callers pass an exclusive end.
void maskRange(common::offset_t startNodeOffset, common::offset_t endNodeOffset) override {
roaring->addRange(startNodeOffset, endNodeOffset);
}

bool isMasked(common::offset_t startNodeOffset) override {
return roaring->contains(startNodeOffset);
Expand Down
4 changes: 2 additions & 2 deletions src/include/main/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ class Database {
void initMembers(std::string_view dbPath, construct_bm_func_t initBmFunc = initBufferManager);

// factory method only to be used for tests
static std::unique_ptr<Database> construct(std::string_view databasePath,
SystemConfig systemConfig, construct_bm_func_t constructFunc);
Database(std::string_view databasePath, SystemConfig systemConfig,
construct_bm_func_t constructBMFunc);

void openLockFile();
void initAndLockDBDir();
Expand Down
8 changes: 6 additions & 2 deletions src/include/storage/buffer_manager/buffer_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ namespace common {
class VirtualFileSystem;
};
namespace testing {
class EmptyBufferManagerTest;
};
class FlakyBufferManager;
class CopyTestHelper;
}; // namespace testing
namespace storage {
class ChunkedNodeGroup;
class Spiller;
Expand Down Expand Up @@ -179,6 +180,9 @@ class EvictionQueue {
* https://github.com/fabubaker/kuzu/blob/umbra-bm/final_project_report.pdf.
*/
class BufferManager {
friend class testing::FlakyBufferManager;
friend class testing::CopyTestHelper;

friend class FileHandle;
friend class MemoryManager;

Expand Down
12 changes: 12 additions & 0 deletions src/include/storage/enums/csr_node_group_scan_source.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include <cstdint>

namespace kuzu::storage {
// Tag naming the source of a CSR node group scan; the underlying type is uint8_t.
// NOTE(review): the per-member notes below are inferred from the enumerator names only --
// confirm them against the scan code that consumes this enum.
enum class CSRNodeGroupScanSource : uint8_t {
// Presumably committed data that has been persisted -- TODO confirm.
COMMITTED_PERSISTENT = 0,
// Presumably committed data that is still held in memory -- TODO confirm.
COMMITTED_IN_MEMORY = 1,
// Presumably data written by a transaction that has not committed yet -- TODO confirm.
UNCOMMITTED = 2,
// Explicit value 10 leaves 3..9 unassigned; the reason for the gap is not visible here.
NONE = 10
};
} // namespace kuzu::storage
2 changes: 1 addition & 1 deletion src/include/storage/index/hash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ class PrimaryKeyIndex {
KU_ASSERT(keyDataTypeID == common::TypeUtils::getPhysicalTypeIDForType<T>());
return getTypedHashIndex(key)->insertInternal(transaction, key, value, isVisible);
}
bool insert(const transaction::Transaction* transaction, common::ValueVector* keyVector,
bool insert(const transaction::Transaction* transaction, const common::ValueVector* keyVector,
uint64_t vectorPos, common::offset_t value, visible_func isVisible);

// Appends the buffer to the index. Returns the number of values successfully inserted.
Expand Down
20 changes: 15 additions & 5 deletions src/include/storage/index/in_mem_hash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,11 @@ class InMemHashIndex final {

// Leaves the slot pointer pointing at the last slot to make it easier to add a new one
bool nextChainedSlot(SlotIterator& iter) const {
iter.slotInfo.slotId = iter.slot->header.nextOvfSlotId;
iter.slotInfo.slotType = SlotType::OVF;
KU_ASSERT(iter.slotInfo.slotType == SlotType::PRIMARY ||
iter.slotInfo.slotId != iter.slot->header.nextOvfSlotId);
if (iter.slot->header.nextOvfSlotId != SlotHeader::INVALID_OVERFLOW_SLOT_ID) {
iter.slotInfo.slotId = iter.slot->header.nextOvfSlotId;
iter.slotInfo.slotType = SlotType::OVF;
iter.slot = getSlot(iter.slotInfo);
return true;
}
Expand All @@ -159,14 +161,13 @@ class InMemHashIndex final {
auto fingerprint = HashIndexUtils::getFingerprintForHash(hashValue);
auto slotId = HashIndexUtils::getPrimarySlotIdForHash(this->indexHeader, hashValue);
SlotIterator iter(slotId, this);
std::optional<entry_pos_t> deletedPos = 0;
std::optional<entry_pos_t> deletedPos;
do {
for (auto entryPos = 0u; entryPos < getSlotCapacity<T>(); entryPos++) {
if (iter.slot->header.isEntryValid(entryPos) &&
iter.slot->header.fingerprints[entryPos] == fingerprint &&
equals(key, iter.slot->entries[entryPos].key)) {
deletedPos = entryPos;
iter.slot->header.setEntryInvalid(entryPos);
break;
}
}
Expand All @@ -182,12 +183,21 @@ class InMemHashIndex final {
;
if (newIter.slotInfo != iter.slotInfo ||
*deletedPos != newIter.slot->header.numEntries() - 1) {
auto lastEntryPos = newIter.slot->header.numEntries();
KU_ASSERT(newIter.slot->header.numEntries() > 0);
auto lastEntryPos = newIter.slot->header.numEntries() - 1;
iter.slot->entries[*deletedPos] = newIter.slot->entries[lastEntryPos];
iter.slot->header.setEntryValid(*deletedPos,
newIter.slot->header.fingerprints[lastEntryPos]);
newIter.slot->header.setEntryInvalid(lastEntryPos);
} else {
iter.slot->header.setEntryInvalid(*deletedPos);
}

if (newIter.slot->header.numEntries() == 0) {
reclaimOverflowSlots(SlotIterator(slotId, this));
}

return true;
}
return false;
}
Expand Down
8 changes: 5 additions & 3 deletions src/include/storage/store/chunked_node_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class ChunkedNodeGroup {
std::pair<std::unique_ptr<ColumnChunk>, std::unique_ptr<ColumnChunk>> scanUpdates(
transaction::Transaction* transaction, common::column_id_t columnID);

bool lookup(transaction::Transaction* transaction, const TableScanState& state,
bool lookup(const transaction::Transaction* transaction, const TableScanState& state,
NodeGroupScanState& nodeGroupScanState, common::offset_t rowIdxInChunk,
common::sel_t posInOutput) const;

Expand Down Expand Up @@ -139,10 +139,12 @@ class ChunkedNodeGroup {

void commitInsert(common::row_idx_t startRow, common::row_idx_t numRows_,
common::transaction_t commitTS);
void rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows_);
void rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows_,
common::transaction_t commitTS);
void commitDelete(common::row_idx_t startRow, common::row_idx_t numRows_,
common::transaction_t commitTS);
void rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows_);
void rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows_,
common::transaction_t commitTS);

uint64_t getEstimatedMemoryUsage() const;

Expand Down
17 changes: 9 additions & 8 deletions src/include/storage/store/column.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ class Column {
virtual void scan(transaction::Transaction* transaction, const ChunkState& state,
common::offset_t startOffsetInChunk, common::row_idx_t numValuesToScan,
common::ValueVector* resultVector) const;
virtual void lookupValue(transaction::Transaction* transaction, const ChunkState& state,
virtual void lookupValue(const transaction::Transaction* transaction, const ChunkState& state,
common::offset_t nodeOffset, common::ValueVector* resultVector, uint32_t posInVector) const;

// Scan from [startOffsetInGroup, endOffsetInGroup).
virtual void scan(transaction::Transaction* transaction, const ChunkState& state,
virtual void scan(const transaction::Transaction* transaction, const ChunkState& state,
common::offset_t startOffsetInGroup, common::offset_t endOffsetInGroup,
common::ValueVector* resultVector, uint64_t offsetInVector) const;
// Scan from [startOffsetInGroup, endOffsetInGroup).
virtual void scan(transaction::Transaction* transaction, const ChunkState& state,
virtual void scan(const transaction::Transaction* transaction, const ChunkState& state,
ColumnChunkData* columnChunk, common::offset_t startOffset = 0,
common::offset_t endOffset = common::INVALID_OFFSET) const;

Expand All @@ -71,7 +71,7 @@ class Column {

std::string getName() const { return name; }

virtual void scan(transaction::Transaction* transaction, const ChunkState& state,
virtual void scan(const transaction::Transaction* transaction, const ChunkState& state,
common::offset_t startOffsetInGroup, common::offset_t endOffsetInGroup, uint8_t* result);

// Batch write to a set of sequential pages.
Expand Down Expand Up @@ -99,8 +99,9 @@ class Column {
common::offset_t startOffsetInChunk, common::row_idx_t numValuesToScan,
common::ValueVector* resultVector) const;

virtual void lookupInternal(transaction::Transaction* transaction, const ChunkState& state,
common::offset_t nodeOffset, common::ValueVector* resultVector, uint32_t posInVector) const;
virtual void lookupInternal(const transaction::Transaction* transaction,
const ChunkState& state, common::offset_t nodeOffset, common::ValueVector* resultVector,
uint32_t posInVector) const;

void writeValues(ChunkState& state, common::offset_t dstOffset, const uint8_t* data,
const common::NullMask* nullChunkData, common::offset_t srcOffset = 0,
Expand Down Expand Up @@ -162,15 +163,15 @@ class InternalIDColumn final : public Column {
populateCommonTableID(resultVector);
}

void scan(transaction::Transaction* transaction, const ChunkState& state,
void scan(const transaction::Transaction* transaction, const ChunkState& state,
common::offset_t startOffsetInGroup, common::offset_t endOffsetInGroup,
common::ValueVector* resultVector, uint64_t offsetInVector) const override {
Column::scan(transaction, state, startOffsetInGroup, endOffsetInGroup, resultVector,
offsetInVector);
populateCommonTableID(resultVector);
}

void lookupInternal(transaction::Transaction* transaction, const ChunkState& state,
void lookupInternal(const transaction::Transaction* transaction, const ChunkState& state,
common::offset_t nodeOffset, common::ValueVector* resultVector,
uint32_t posInVector) const override {
Column::lookupInternal(transaction, state, nodeOffset, resultVector, posInVector);
Expand Down
4 changes: 2 additions & 2 deletions src/include/storage/store/column_chunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ class ColumnChunk {
void scan(const transaction::Transaction* transaction, const ChunkState& state,
common::ValueVector& output, common::offset_t offsetInChunk, common::length_t length) const;
template<ResidencyState SCAN_RESIDENCY_STATE>
void scanCommitted(transaction::Transaction* transaction, ChunkState& chunkState,
void scanCommitted(const transaction::Transaction* transaction, ChunkState& chunkState,
ColumnChunk& output, common::row_idx_t startRow = 0,
common::row_idx_t numRows = common::INVALID_ROW_IDX) const;
void lookup(transaction::Transaction* transaction, const ChunkState& state,
void lookup(const transaction::Transaction* transaction, const ChunkState& state,
common::offset_t rowInChunk, common::ValueVector& output,
common::sel_t posInOutputVector) const;
void update(const transaction::Transaction* transaction, common::offset_t offsetInChunk,
Expand Down
10 changes: 5 additions & 5 deletions src/include/storage/store/column_reader_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,22 @@ class ColumnReadWriter {

virtual ~ColumnReadWriter() = default;

virtual void readCompressedValueToPage(transaction::Transaction* transaction,
virtual void readCompressedValueToPage(const transaction::Transaction* transaction,
const ChunkState& state, common::offset_t nodeOffset, uint8_t* result,
uint32_t offsetInResult, const read_value_from_page_func_t<uint8_t*>& readFunc) = 0;

virtual void readCompressedValueToVector(transaction::Transaction* transaction,
virtual void readCompressedValueToVector(const transaction::Transaction* transaction,
const ChunkState& state, common::offset_t nodeOffset, common::ValueVector* result,
uint32_t offsetInResult,
const read_value_from_page_func_t<common::ValueVector*>& readFunc) = 0;

virtual uint64_t readCompressedValuesToPage(transaction::Transaction* transaction,
virtual uint64_t readCompressedValuesToPage(const transaction::Transaction* transaction,
const ChunkState& state, uint8_t* result, uint32_t startOffsetInResult,
uint64_t startNodeOffset, uint64_t endNodeOffset,
const read_values_from_page_func_t<uint8_t*>& readFunc,
const std::optional<filter_func_t>& filterFunc = {}) = 0;

virtual uint64_t readCompressedValuesToVector(transaction::Transaction* transaction,
virtual uint64_t readCompressedValuesToVector(const transaction::Transaction* transaction,
const ChunkState& state, common::ValueVector* result, uint32_t startOffsetInResult,
uint64_t startNodeOffset, uint64_t endNodeOffset,
const read_values_from_page_func_t<common::ValueVector*>& readFunc,
Expand All @@ -78,7 +78,7 @@ class ColumnReadWriter {
const uint8_t* data, const common::NullMask* nullChunkData, common::offset_t srcOffset,
common::offset_t numValues, const write_values_func_t& writeFunc) = 0;

void readFromPage(transaction::Transaction* transaction, common::page_idx_t pageIdx,
void readFromPage(const transaction::Transaction* transaction, common::page_idx_t pageIdx,
const std::function<void(uint8_t*)>& readFunc);

void updatePageWithCursor(PageCursor cursor,
Expand Down
18 changes: 6 additions & 12 deletions src/include/storage/store/csr_node_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <bitset>

#include "common/data_chunk/data_chunk.h"
#include "storage/enums/csr_node_group_scan_source.h"
#include "storage/store/csr_chunked_node_group.h"
#include "storage/store/node_group.h"

Expand All @@ -21,13 +22,6 @@ struct csr_list_t {
common::length_t length = 0;
};

enum class CSRNodeGroupScanSource : uint8_t {
COMMITTED_PERSISTENT = 0,
COMMITTED_IN_MEMORY = 1,
UNCOMMITTED = 2,
NONE = 10
};

// Store rows of a CSR list.
// If rows of the CSR list are stored in a sequential order, then `isSequential` is set to true.
// rowIndices consists of startRowIdx and length.
Expand Down Expand Up @@ -187,9 +181,9 @@ class CSRNodeGroup final : public NodeGroup {
}
}

void initializeScanState(transaction::Transaction* transaction,
void initializeScanState(const transaction::Transaction* transaction,
TableScanState& state) const override;
NodeGroupScanResult scan(transaction::Transaction* transaction,
NodeGroupScanResult scan(const transaction::Transaction* transaction,
TableScanState& state) const override;

void appendChunkedCSRGroup(const transaction::Transaction* transaction,
Expand Down Expand Up @@ -220,7 +214,7 @@ class CSRNodeGroup final : public NodeGroup {
void serialize(common::Serializer& serializer) override;

private:
void initScanForCommittedPersistent(transaction::Transaction* transaction,
void initScanForCommittedPersistent(const transaction::Transaction* transaction,
RelTableScanState& relScanState, CSRNodeGroupScanState& nodeGroupScanState) const;
void initScanForCommittedInMem(RelTableScanState& relScanState,
CSRNodeGroupScanState& nodeGroupScanState) const;
Expand All @@ -237,11 +231,11 @@ class CSRNodeGroup final : public NodeGroup {
const transaction::Transaction* transaction, RelTableScanState& tableState,
CSRNodeGroupScanState& nodeGroupScanState) const;

NodeGroupScanResult scanCommittedInMem(transaction::Transaction* transaction,
NodeGroupScanResult scanCommittedInMem(const transaction::Transaction* transaction,
RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const;
NodeGroupScanResult scanCommittedInMemSequential(const transaction::Transaction* transaction,
const RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const;
NodeGroupScanResult scanCommittedInMemRandom(transaction::Transaction* transaction,
NodeGroupScanResult scanCommittedInMemRandom(const transaction::Transaction* transaction,
const RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const;

void checkpointInMemOnly(const common::UniqLock& lock, NodeGroupCheckpointState& state);
Expand Down
8 changes: 4 additions & 4 deletions src/include/storage/store/dictionary_column.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ class DictionaryColumn {
DictionaryColumn(const std::string& name, FileHandle* dataFH, MemoryManager* mm,
ShadowFile* shadowFile, bool enableCompression);

void scan(transaction::Transaction* transaction, const ChunkState& state,
void scan(const transaction::Transaction* transaction, const ChunkState& state,
DictionaryChunk& dictChunk) const;
// Offsets to scan should be a sorted list of pairs mapping the index of the entry in the string
// dictionary (as read from the index column) to the output index in the result vector to store
// the string.
void scan(transaction::Transaction* transaction, const ChunkState& offsetState,
void scan(const transaction::Transaction* transaction, const ChunkState& offsetState,
const ChunkState& dataState,
std::vector<std::pair<DictionaryChunk::string_index_t, uint64_t>>& offsetsToScan,
common::ValueVector* resultVector, const ColumnChunkMetadata& indexMeta) const;
Expand All @@ -32,10 +32,10 @@ class DictionaryColumn {
Column* getOffsetColumn() const { return offsetColumn.get(); }

private:
void scanOffsets(transaction::Transaction* transaction, const ChunkState& state,
void scanOffsets(const transaction::Transaction* transaction, const ChunkState& state,
DictionaryChunk::string_offset_t* offsets, uint64_t index, uint64_t numValues,
uint64_t dataSize) const;
void scanValueToVector(transaction::Transaction* transaction, const ChunkState& dataState,
void scanValueToVector(const transaction::Transaction* transaction, const ChunkState& dataState,
uint64_t startOffset, uint64_t endOffset, common::ValueVector* resultVector,
uint64_t offsetInVector) const;

Expand Down
16 changes: 16 additions & 0 deletions src/include/storage/store/group_collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "common/serializer/serializer.h"
#include "common/types/types.h"
#include "common/uniq_lock.h"
#include "common/utils.h"

namespace kuzu {
namespace storage {
Expand All @@ -24,6 +25,13 @@ class GroupCollection {
[&](common::Deserializer& deser) { return T::deserialize(memoryManager, deser); });
}

// Drops the last `numGroupsToRemove` entries of `groups`.
// `lock` is only inspected by the assertion below (hence [[maybe_unused]]); the assertions also
// require that no more groups are requested for removal than the collection currently holds.
void removeTrailingGroups([[maybe_unused]] const common::UniqLock& lock,
    common::idx_t numGroupsToRemove) {
    KU_ASSERT(lock.isLocked());
    KU_ASSERT(numGroupsToRemove <= groups.size());
    // First element of the tail that is being discarded.
    const auto firstRemoved = groups.end() - numGroupsToRemove;
    groups.erase(firstRemoved, groups.end());
}

void serializeGroups(common::Serializer& ser) {
auto lockGuard = lock();
ser.serializeVectorOfPtrs<T>(groups);
Expand Down Expand Up @@ -111,6 +119,14 @@ class GroupCollection {
groups.clear();
}

// Counts how many consecutive groups at the back of the collection report zero rows.
// Walks the group vector in reverse until the first group whose getNumRows() is non-zero; the
// distance walked is the count (equal to the total number of groups when every group is empty).
common::idx_t getNumEmptyTrailingGroups(const common::UniqLock& lock) {
    const auto& allGroups = getAllGroups(lock);
    const auto firstNonEmpty = std::find_if(allGroups.rbegin(), allGroups.rend(),
        [](const auto& candidate) { return (candidate->getNumRows() != 0); });
    return common::safeIntegerConversion<common::idx_t>(firstNonEmpty - allGroups.rbegin());
}

private:
mutable std::mutex mtx;
std::vector<std::unique_ptr<T>> groups;
Expand Down
Loading

0 comments on commit a82a76a

Please sign in to comment.