From 2591604a296914fbf6be4e7fc8cb57649b9fb712 Mon Sep 17 00:00:00 2001 From: Benjamin Winger Date: Mon, 18 Nov 2024 15:31:18 +0000 Subject: [PATCH] Add GDS support for vertex property scanning (#4453) --- src/common/vector/value_vector.cpp | 6 +- src/function/gds/all_shortest_paths.cpp | 4 +- src/function/gds/gds_task.cpp | 16 +- src/function/gds/gds_utils.cpp | 2 +- src/function/gds/rec_joins.cpp | 14 +- src/function/gds/single_shortest_paths.cpp | 4 +- src/function/gds/variable_length_path.cpp | 2 +- .../gds/weakly_connected_components.cpp | 3 +- src/graph/on_disk_graph.cpp | 111 +++++++++--- src/include/common/vector/value_vector.h | 3 +- src/include/function/gds/gds_frontier.h | 9 +- src/include/function/gds/gds_task.h | 5 +- src/include/graph/graph.h | 114 ++++++++++--- src/include/graph/on_disk_graph.h | 60 +++++-- src/include/storage/store/column_chunk_data.h | 11 -- src/include/storage/store/csr_node_group.h | 16 +- src/include/storage/store/node_group.h | 25 +-- src/include/storage/store/node_table.h | 26 ++- src/include/storage/store/rel_table.h | 5 - src/include/storage/store/table.h | 20 +-- src/main/storage_driver.cpp | 6 +- .../operator/scan/offset_scan_node_table.cpp | 11 +- .../scan/primary_key_scan_node_table.cpp | 11 +- src/storage/store/csr_node_group.cpp | 35 ++-- src/storage/store/node_group.cpp | 160 +++++++++++++----- src/storage/store/node_table.cpp | 51 +++++- src/storage/store/table.cpp | 8 +- test/storage/rel_scan_test.cpp | 23 +++ 28 files changed, 519 insertions(+), 242 deletions(-) diff --git a/src/common/vector/value_vector.cpp b/src/common/vector/value_vector.cpp index f73f5e0c4f9..f08312e945d 100644 --- a/src/common/vector/value_vector.cpp +++ b/src/common/vector/value_vector.cpp @@ -14,7 +14,8 @@ namespace kuzu { namespace common { -ValueVector::ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager) +ValueVector::ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager, + std::shared_ptr dataChunkState) : dataType{std::move(dataType)}, nullMask{DEFAULT_VECTOR_CAPACITY} { if (this->dataType.getLogicalTypeID() == LogicalTypeID::ANY) { // LCOV_EXCL_START @@ -26,6 +27,9 @@ ValueVector::ValueVector(LogicalType dataType, storage::MemoryManager* memoryMan numBytesPerValue = getDataTypeSize(this->dataType); initializeValueBuffer(); auxiliaryBuffer = AuxiliaryBufferFactory::getAuxiliaryBuffer(this->dataType, memoryManager); + if (dataChunkState) { + setState(std::move(dataChunkState)); + } } void ValueVector::setState(const std::shared_ptr& state_) { diff --git a/src/function/gds/all_shortest_paths.cpp b/src/function/gds/all_shortest_paths.cpp index b2f9d1afa15..2601f0fefb4 100644 --- a/src/function/gds/all_shortest_paths.cpp +++ b/src/function/gds/all_shortest_paths.cpp @@ -150,7 +150,7 @@ class AllSPDestinationsEdgeCompute : public SPEdgeCompute { PathMultiplicities* multiplicities) : SPEdgeCompute{frontierPair}, multiplicities{multiplicities} {}; - std::vector edgeCompute(nodeID_t boundNodeID, GraphScanState::Chunk& resultChunk, + std::vector edgeCompute(nodeID_t boundNodeID, NbrScanState::Chunk& resultChunk, bool) override { std::vector activeNodes; resultChunk.forEach([&](auto nbrNodeID, auto /*edgeID*/) { @@ -190,7 +190,7 @@ class AllSPPathsEdgeCompute : public SPEdgeCompute { parentListBlock = bfsGraph->addNewBlock(); } - std::vector edgeCompute(nodeID_t boundNodeID, GraphScanState::Chunk& resultChunk, + std::vector edgeCompute(nodeID_t boundNodeID, NbrScanState::Chunk& resultChunk, bool fwdEdge) override { std::vector activeNodes; resultChunk.forEach([&](auto nbrNodeID, auto edgeID) { diff --git a/src/function/gds/gds_task.cpp b/src/function/gds/gds_task.cpp index e5ecded932e..9eb0c36949b 100644 --- a/src/function/gds/gds_task.cpp +++ b/src/function/gds/gds_task.cpp @@ -7,11 +7,11 @@ using namespace kuzu::common; namespace kuzu { namespace function { -static uint64_t computeScanResult(nodeID_t sourceNodeID, graph::GraphScanState::Chunk& chunk, +static uint64_t computeScanResult(nodeID_t sourceNodeID, graph::NbrScanState::Chunk& nbrChunk, EdgeCompute& ec, FrontierPair& frontierPair, bool isFwd) { - auto activeNodes = ec.edgeCompute(sourceNodeID, chunk, isFwd); + auto activeNodes = ec.edgeCompute(sourceNodeID, nbrChunk, isFwd); frontierPair.getNextFrontierUnsafe().setActive(activeNodes); - return chunk.size(); + return nbrChunk.size(); } void FrontierTask::run() { @@ -49,11 +49,15 @@ void FrontierTask::run() { void VertexComputeTask::run() { FrontierMorsel frontierMorsel; + auto graph = sharedState->graph; + std::vector propertiesToScan; + auto scanState = + graph->prepareVertexScan(sharedState->morselDispatcher.getTableID(), propertiesToScan); auto localVc = info.vc.copy(); while (sharedState->morselDispatcher.getNextRangeMorsel(frontierMorsel)) { - while (frontierMorsel.hasNextOffset()) { - common::nodeID_t nodeID = frontierMorsel.getNextNodeID(); - localVc->vertexCompute(nodeID); + for (auto chunk : graph->scanVertices(frontierMorsel.getBeginOffset(), + frontierMorsel.getEndOffsetExclusive(), *scanState)) { + localVc->vertexCompute(chunk); } } } diff --git a/src/function/gds/gds_utils.cpp b/src/function/gds/gds_utils.cpp index fc785ff8713..a384f7a8db9 100644 --- a/src/function/gds/gds_utils.cpp +++ b/src/function/gds/gds_utils.cpp @@ -82,7 +82,7 @@ void GDSUtils::runVertexComputeIteration(processor::ExecutionContext* executionC auto maxThreads = clientContext->getCurrentSetting(main::ThreadsSetting::name).getValue(); auto info = VertexComputeTaskInfo(vc); - auto sharedState = std::make_shared(maxThreads); + auto sharedState = std::make_shared(maxThreads, graph); for (auto& tableID : graph->getNodeTableIDs()) { if (!vc.beginOnTable(tableID)) { continue; diff --git a/src/function/gds/rec_joins.cpp b/src/function/gds/rec_joins.cpp index 38e5c377609..03f3137220c 100644 --- a/src/function/gds/rec_joins.cpp +++ b/src/function/gds/rec_joins.cpp @@ -8,6 +8,7 @@ #include "function/gds/gds.h" #include "function/gds/gds_frontier.h" #include "function/gds/gds_utils.h" +#include "graph/graph.h" #include "processor/execution_context.h" #include "processor/result/factorized_table.h" #include "storage/buffer_manager/memory_manager.h" @@ -149,11 +150,16 @@ class RJVertexCompute : public VertexCompute { return true; } - void vertexCompute(nodeID_t nodeID) override { - if (sharedState->exceedLimit() || writer->skip(nodeID)) { - return; + void vertexCompute(const graph::VertexScanState::Chunk& chunk) override { + for (auto nodeID : chunk.getNodeIDs()) { + if (sharedState->exceedLimit()) { + return; + } + if (writer->skip(nodeID)) { + continue; + } + writer->write(*localFT, nodeID, sharedState->counter.get()); } - writer->write(*localFT, nodeID, sharedState->counter.get()); } std::unique_ptr copy() override { diff --git a/src/function/gds/single_shortest_paths.cpp b/src/function/gds/single_shortest_paths.cpp index 5972efa192a..c935fa524ce 100644 --- a/src/function/gds/single_shortest_paths.cpp +++ b/src/function/gds/single_shortest_paths.cpp @@ -35,7 +35,7 @@ class SingleSPDestinationsEdgeCompute : public SPEdgeCompute { explicit SingleSPDestinationsEdgeCompute(SinglePathLengthsFrontierPair* frontierPair) : SPEdgeCompute{frontierPair} {}; - std::vector edgeCompute(common::nodeID_t, GraphScanState::Chunk& resultChunk, + std::vector edgeCompute(common::nodeID_t, NbrScanState::Chunk& resultChunk, bool) override { std::vector activeNodes; resultChunk.forEach([&](auto nbrNode, auto) { @@ -59,7 +59,7 @@ class SingleSPPathsEdgeCompute : public SPEdgeCompute { parentListBlock = bfsGraph->addNewBlock(); } - std::vector edgeCompute(nodeID_t boundNodeID, GraphScanState::Chunk& resultChunk, + std::vector edgeCompute(nodeID_t boundNodeID, NbrScanState::Chunk& resultChunk, bool isFwd) override { std::vector activeNodes; resultChunk.forEach([&](auto nbrNodeID, auto edgeID) { diff --git a/src/function/gds/variable_length_path.cpp b/src/function/gds/variable_length_path.cpp index 90084393ed0..c3560b8a548 100644 --- a/src/function/gds/variable_length_path.cpp +++ b/src/function/gds/variable_length_path.cpp @@ -51,7 +51,7 @@ struct VarLenJoinsEdgeCompute : public EdgeCompute { parentPtrsBlock = bfsGraph->addNewBlock(); }; - std::vector edgeCompute(nodeID_t boundNodeID, graph::GraphScanState::Chunk& chunk, + std::vector edgeCompute(nodeID_t boundNodeID, graph::NbrScanState::Chunk& chunk, bool isFwd) override { std::vector activeNodes; chunk.forEach([&](auto nbrNode, auto edgeID) { diff --git a/src/function/gds/weakly_connected_components.cpp b/src/function/gds/weakly_connected_components.cpp index 594de5ba8a2..f2afc0a6488 100644 --- a/src/function/gds/weakly_connected_components.cpp +++ b/src/function/gds/weakly_connected_components.cpp @@ -112,8 +112,7 @@ class WeaklyConnectedComponent final : public GDSAlgorithm { } private: - void findConnectedComponent(common::nodeID_t nodeID, int64_t groupID, - GraphScanState& scanState) { + void findConnectedComponent(common::nodeID_t nodeID, int64_t groupID, NbrScanState& scanState) { KU_ASSERT(!visitedMap.contains(nodeID)); visitedMap.insert({nodeID, groupID}); // Collect the nodes so that the recursive scan doesn't begin until this scan is done diff --git a/src/graph/on_disk_graph.cpp b/src/graph/on_disk_graph.cpp index 09a960d1bb0..4e71c3cab0f 100644 --- a/src/graph/on_disk_graph.cpp +++ b/src/graph/on_disk_graph.cpp @@ -4,6 +4,8 @@ #include "binder/expression/property_expression.h" #include "common/assert.h" +#include "common/cast.h" +#include "common/constants.h" #include "common/data_chunk/data_chunk_state.h" #include "common/enums/rel_direction.h" #include "common/types/types.h" @@ -13,9 +15,13 @@ #include "main/client_context.h" #include "planner/operator/schema.h" #include "processor/expression_mapper.h" +#include "storage/buffer_manager/memory_manager.h" #include "storage/local_storage/local_rel_table.h" #include "storage/local_storage/local_storage.h" #include "storage/storage_manager.h" +#include "storage/storage_utils.h" +#include "storage/store/column.h" +#include "storage/store/node_table.h" #include "storage/store/rel_table.h" #include "storage/store/table.h" @@ -72,8 +78,9 @@ static std::unique_ptr getRelScanState( return scanState; } -OnDiskGraphScanStates::OnDiskGraphScanStates(ClientContext* context, std::span tables, - const GraphEntry& graphEntry, std::optional edgePropertyIndex) +OnDiskGraphNbrScanStates::OnDiskGraphNbrScanStates(ClientContext* context, + std::span tables, const GraphEntry& graphEntry, + std::optional edgePropertyIndex) : iteratorIndex{0}, direction{RelDataDirection::INVALID} { auto schema = graphEntry.getRelPropertiesSchema(); auto descriptor = ResultSetDescriptor(&schema); @@ -121,7 +128,7 @@ OnDiskGraphScanStates::OnDiskGraphScanStates(ClientContext* context, std::spangetTableID(), - OnDiskGraphScanState{context, *table, std::move(fwdState), std::move(bwdState)}); + OnDiskGraphNbrScanState{context, *table, std::move(fwdState), std::move(bwdState)}); } } @@ -201,14 +208,14 @@ std::vector OnDiskGraph::getRelTableIDInfos() { return result; } -std::unique_ptr OnDiskGraph::prepareScan(table_id_t relTableID, +std::unique_ptr OnDiskGraph::prepareScan(table_id_t relTableID, std::optional edgePropertyIndex) { auto relTable = context->getStorageManager()->getTable(relTableID)->ptrCast(); - return std::unique_ptr( - new OnDiskGraphScanStates(context, std::span(&relTable, 1), graphEntry, edgePropertyIndex)); + return std::unique_ptr(new OnDiskGraphNbrScanStates(context, + std::span(&relTable, 1), graphEntry, edgePropertyIndex)); } -std::unique_ptr OnDiskGraph::prepareMultiTableScanFwd( +std::unique_ptr OnDiskGraph::prepareMultiTableScanFwd( std::span nodeTableIDs) { std::unordered_set relTableIDSet; std::vector tables; @@ -220,11 +227,11 @@ std::unique_ptr OnDiskGraph::prepareMultiTableScanFwd( } } } - return std::unique_ptr( - new OnDiskGraphScanStates(context, std::span(tables), graphEntry)); + return std::unique_ptr( + new OnDiskGraphNbrScanStates(context, std::span(tables), graphEntry)); } -std::unique_ptr OnDiskGraph::prepareMultiTableScanBwd( +std::unique_ptr OnDiskGraph::prepareMultiTableScanBwd( std::span nodeTableIDs) { std::unordered_set relTableIDSet; std::vector tables; @@ -236,12 +243,12 @@ std::unique_ptr OnDiskGraph::prepareMultiTableScanBwd( } } } - return std::unique_ptr( - new OnDiskGraphScanStates(context, std::span(tables), graphEntry)); + return std::unique_ptr( + new OnDiskGraphNbrScanStates(context, std::span(tables), graphEntry)); } -Graph::Iterator OnDiskGraph::scanFwd(nodeID_t nodeID, GraphScanState& state) { - auto& onDiskScanState = ku_dynamic_cast(state); +Graph::EdgeIterator OnDiskGraph::scanFwd(nodeID_t nodeID, NbrScanState& state) { + auto& onDiskScanState = ku_dynamic_cast(state); onDiskScanState.srcNodeIDVector->setValue(0, nodeID); onDiskScanState.dstNodeIDVector->state->getSelVectorUnsafe().setSelSize(0); KU_ASSERT(nodeTableIDToFwdRelTables.contains(nodeID.tableID)); @@ -253,11 +260,11 @@ Graph::Iterator OnDiskGraph::scanFwd(nodeID_t nodeID, GraphScanState& state) { } } onDiskScanState.startScan(common::RelDataDirection::FWD); - return Graph::Iterator(&onDiskScanState); + return Graph::EdgeIterator(&onDiskScanState); } -Graph::Iterator OnDiskGraph::scanBwd(nodeID_t nodeID, GraphScanState& state) { - auto& onDiskScanState = ku_dynamic_cast(state); +Graph::EdgeIterator OnDiskGraph::scanBwd(nodeID_t nodeID, NbrScanState& state) { + auto& onDiskScanState = ku_dynamic_cast(state); onDiskScanState.srcNodeIDVector->setValue(0, nodeID); onDiskScanState.dstNodeIDVector->state->getSelVectorUnsafe().setSelSize(0); KU_ASSERT(nodeTableIDToBwdRelTables.contains(nodeID.tableID)); @@ -269,10 +276,10 @@ Graph::Iterator OnDiskGraph::scanBwd(nodeID_t nodeID, GraphScanState& state) { } } onDiskScanState.startScan(common::RelDataDirection::BWD); - return Graph::Iterator(&onDiskScanState); + return Graph::EdgeIterator(&onDiskScanState); } -bool OnDiskGraphScanState::InnerIterator::next(evaluator::ExpressionEvaluator* predicate) { +bool OnDiskGraphNbrScanState::InnerIterator::next(evaluator::ExpressionEvaluator* predicate) { while (true) { if (!relTable->scan(context->getTx(), *tableScanState)) { return false; @@ -287,15 +294,15 @@ bool OnDiskGraphScanState::InnerIterator::next(evaluator::ExpressionEvaluator* p } } -OnDiskGraphScanState::InnerIterator::InnerIterator(const main::ClientContext* context, +OnDiskGraphNbrScanState::InnerIterator::InnerIterator(const main::ClientContext* context, storage::RelTable* relTable, std::unique_ptr tableScanState) : context{context}, relTable{relTable}, tableScanState{std::move(tableScanState)} {} -void OnDiskGraphScanState::InnerIterator::initScan() { +void OnDiskGraphNbrScanState::InnerIterator::initScan() { relTable->initScanState(context->getTx(), *tableScanState); } -bool OnDiskGraphScanStates::next() { +bool OnDiskGraphNbrScanStates::next() { while (iteratorIndex < scanStates.size()) { if (getInnerIterator().next(relPredicateEvaluator.get())) { return true; @@ -305,5 +312,65 @@ bool OnDiskGraphScanStates::next() { return false; } +OnDiskGraphVertexScanState::OnDiskGraphVertexScanState(ClientContext& context, + common::table_id_t tableID, const std::vector& propertyNames) + : context{context}, + nodeTable{ku_dynamic_cast(*context.getStorageManager()->getTable(tableID))}, + numNodesScanned{0}, tableID{tableID}, currentOffset{0}, endOffsetExclusive{0} { + std::vector propertyColumnIDs; + propertyColumnIDs.reserve(propertyNames.size()); + auto tableCatalogEntry = context.getCatalog()->getTableCatalogEntry(context.getTx(), tableID); + std::vector columns; + std::vector types; + for (const auto& property : propertyNames) { + auto columnID = tableCatalogEntry->getColumnID(property); + propertyColumnIDs.push_back(columnID); + columns.push_back(&nodeTable.getColumn(columnID)); + types.push_back(columns.back()->getDataType().copy()); + } + propertyVectors = nodeTable.constructDataChunk(std::move(types)); + nodeIDVector = std::make_unique(LogicalType::INTERNAL_ID(), + context.getMemoryManager(), propertyVectors.state); + tableScanState = std::make_unique(tableID, std::move(propertyColumnIDs), + std::move(columns), propertyVectors, nodeIDVector.get()); +} + +bool OnDiskGraphVertexScanState::next() { + if (currentOffset >= endOffsetExclusive) { + return false; + } + if (currentOffset < endOffsetExclusive && + StorageUtils::getNodeGroupIdx(currentOffset) != tableScanState->nodeGroupIdx) { + startScan(currentOffset, endOffsetExclusive); + } + + auto endOffset = std::min(endOffsetExclusive, + StorageUtils::getStartOffsetOfNodeGroup(tableScanState->nodeGroupIdx + 1)); + numNodesScanned = std::min(endOffset - currentOffset, DEFAULT_VECTOR_CAPACITY); + auto result = tableScanState->scanNext(context.getTx(), currentOffset, numNodesScanned); + currentOffset += numNodesScanned; + return result; +} + +Graph::VertexIterator OnDiskGraph::scanVertices(common::offset_t beginOffset, + common::offset_t endOffsetExclusive, VertexScanState& state) { + auto& onDiskVertexScanState = ku_dynamic_cast(state); + onDiskVertexScanState.startScan(beginOffset, endOffsetExclusive); + return Graph::VertexIterator(&state); +} + +std::unique_ptr OnDiskGraph::prepareVertexScan(common::table_id_t tableID, + const std::vector& propertiesToScan) { + return std::make_unique(*context, tableID, propertiesToScan); +} + +void OnDiskGraphVertexScanState::startScan(common::offset_t beginOffset, + common::offset_t endOffsetExclusive) { + numNodesScanned = 0; + this->currentOffset = beginOffset; + this->endOffsetExclusive = endOffsetExclusive; + nodeTable.initScanState(context.getTx(), *tableScanState, tableID, beginOffset); +} + } // namespace graph } // namespace kuzu diff --git a/src/include/common/vector/value_vector.h b/src/include/common/vector/value_vector.h index 95f1d94615e..b40c60477cf 100644 --- a/src/include/common/vector/value_vector.h +++ b/src/include/common/vector/value_vector.h @@ -25,7 +25,8 @@ class KUZU_API ValueVector { friend class ArrowColumnVector; public: - explicit ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager = nullptr); + explicit ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager = nullptr, + std::shared_ptr dataChunkState = nullptr); explicit ValueVector(LogicalTypeID dataTypeID, storage::MemoryManager* memoryManager = nullptr) : ValueVector(LogicalType(dataTypeID), memoryManager) { KU_ASSERT(dataTypeID != LogicalTypeID::LIST); diff --git a/src/include/function/gds/gds_frontier.h b/src/include/function/gds/gds_frontier.h index 3d93607e3c6..474ecc4e8e3 100644 --- a/src/include/function/gds/gds_frontier.h +++ b/src/include/function/gds/gds_frontier.h @@ -26,7 +26,7 @@ class EdgeCompute { // So if the implementing class has access to the next frontier as a field, // **do not** call setActive. Helper functions in GDSUtils will do that work. virtual std::vector edgeCompute(common::nodeID_t boundNodeID, - graph::GraphScanState::Chunk& results, bool fwdEdge) = 0; + graph::NbrScanState::Chunk& results, bool fwdEdge) = 0; virtual void resetSingleThreadState() {} @@ -52,7 +52,7 @@ class VertexCompute { // GDSUtils helper functions call isActive on nodes to check if any work should be done for // the edges of a node. Instead, here GDSUtils helper functions for VertexCompute blindly run // the function on each node in a graph. - virtual void vertexCompute(common::nodeID_t curNodeID) = 0; + virtual void vertexCompute(const graph::VertexScanState::Chunk& chunk) = 0; virtual std::unique_ptr copy() = 0; }; @@ -67,6 +67,9 @@ class FrontierMorsel { common::nodeID_t getNextNodeID() { return {nextOffset++, tableID}; } + common::offset_t getBeginOffset() const { return beginOffset; } + common::offset_t getEndOffsetExclusive() const { return endOffsetExclusive; } + protected: void initMorsel(common::table_id_t _tableID, common::offset_t _beginOffset, common::offset_t _endOffsetExclusive) { @@ -97,6 +100,8 @@ class FrontierMorselDispatcher { bool getNextRangeMorsel(FrontierMorsel& frontierMorsel); + common::table_id_t getTableID() const { return tableID; } + private: std::atomic maxThreadsForExec; std::atomic tableID; diff --git a/src/include/function/gds/gds_task.h b/src/include/function/gds/gds_task.h index 5658ffa568e..ad6a715a835 100644 --- a/src/include/function/gds/gds_task.h +++ b/src/include/function/gds/gds_task.h @@ -44,9 +44,10 @@ class FrontierTask : public common::Task { struct VertexComputeTaskSharedState { FrontierMorselDispatcher morselDispatcher; + graph::Graph* graph; - explicit VertexComputeTaskSharedState(uint64_t maxThreadsForExecution) - : morselDispatcher{maxThreadsForExecution} {} + explicit VertexComputeTaskSharedState(uint64_t maxThreadsForExecution, graph::Graph* graph) + : morselDispatcher{maxThreadsForExecution}, graph{graph} {} }; struct VertexComputeTaskInfo { diff --git a/src/include/graph/graph.h b/src/include/graph/graph.h index 98ad76f39c3..fd71253f819 100644 --- a/src/include/graph/graph.h +++ b/src/include/graph/graph.h @@ -26,10 +26,10 @@ struct RelTableIDInfo { common::table_id_t toNodeTableID; }; -class GraphScanState { +class NbrScanState { public: struct Chunk { - friend class GraphScanState; + friend class NbrScanState; // Any neighbour for which the given function returns false // will be omitted from future iterations @@ -69,7 +69,7 @@ class GraphScanState { const common::ValueVector* propertyVector; }; - virtual ~GraphScanState() = default; + virtual ~NbrScanState() = default; virtual Chunk getChunk() = 0; // Returns true if there are more values after the current batch @@ -83,6 +83,44 @@ class GraphScanState { } }; +class VertexScanState { +public: + struct Chunk { + friend class VertexScanState; + + size_t size() const { return nodeIDs.size(); } + std::span getNodeIDs() const { return nodeIDs; } + template + std::span getProperties(size_t propertyIndex) const { + return std::span(reinterpret_cast(propertyVectors[propertyIndex]->getData()), + nodeIDs.size()); + } + + private: + Chunk(std::span nodeIDs, + std::span> propertyVectors) + : nodeIDs{nodeIDs}, propertyVectors{propertyVectors} { + KU_ASSERT(nodeIDs.size() <= common::DEFAULT_VECTOR_CAPACITY); + } + + private: + std::span nodeIDs; + std::span> propertyVectors; + }; + virtual Chunk getChunk() = 0; + + // Returns true if there are more values after the current batch + virtual bool next() = 0; + + virtual ~VertexScanState() = default; + +protected: + Chunk createChunk(std::span nodeIDs, + std::span> propertyVectors) const { + return Chunk{nodeIDs, propertyVectors}; + } +}; + /** * Graph interface to be use by GDS algorithms to get neighbors of nodes. * @@ -93,25 +131,24 @@ class GraphScanState { */ class Graph { public: - class Iterator { + class EdgeIterator { public: - explicit constexpr Iterator(GraphScanState* scanState) : scanState{scanState} {} - DEFAULT_BOTH_MOVE(Iterator); - Iterator(const Iterator& other) = default; - Iterator() : scanState{nullptr} {} - using iterator_category = std::input_iterator_tag; + explicit constexpr EdgeIterator(NbrScanState* scanState) : scanState{scanState} {} + DEFAULT_BOTH_MOVE(EdgeIterator); + EdgeIterator(const EdgeIterator& other) = default; + EdgeIterator() : scanState{nullptr} {} using difference_type = std::ptrdiff_t; - using value_type = GraphScanState::Chunk; + using value_type = NbrScanState::Chunk; value_type operator*() const { return scanState->getChunk(); } - Iterator& operator++() { + EdgeIterator& operator++() { if (!scanState->next()) { scanState = nullptr; } return *this; } void operator++(int) { ++*this; } - bool operator==(const Iterator& other) const { + bool operator==(const EdgeIterator& other) const { // Only needed for comparing to the end, so they are equal if and only if both are null return scanState == nullptr && other.scanState == nullptr; } @@ -135,13 +172,13 @@ class Graph { return nbrNodes; } - Iterator& begin() noexcept { return *this; } - static constexpr Iterator end() noexcept { return Iterator(nullptr); } + EdgeIterator& begin() noexcept { return *this; } + static constexpr EdgeIterator end() noexcept { return EdgeIterator(nullptr); } private: - GraphScanState* scanState; + NbrScanState* scanState; }; - static_assert(std::input_iterator); + static_assert(std::input_iterator); Graph() = default; virtual ~Graph() = default; @@ -164,28 +201,63 @@ class Graph { virtual std::vector getRelTableIDInfos() = 0; // Prepares scan on the specified relationship table (works for backwards and forwards scans) - virtual std::unique_ptr prepareScan(common::table_id_t relTableID, + virtual std::unique_ptr prepareScan(common::table_id_t relTableID, std::optional edgePropertyIndex = std::nullopt) = 0; // Prepares scan on all connected relationship tables using forward adjList. - virtual std::unique_ptr prepareMultiTableScanFwd( + virtual std::unique_ptr prepareMultiTableScanFwd( std::span nodeTableIDs) = 0; // scanFwd an scanBwd scan a single source node under the assumption that many nodes in the same // group will be scanned at once. // Get dst nodeIDs for given src nodeID using forward adjList. - virtual Iterator scanFwd(common::nodeID_t nodeID, GraphScanState& state) = 0; + virtual EdgeIterator scanFwd(common::nodeID_t nodeID, NbrScanState& state) = 0; // We don't use scanBwd currently. I'm adding them because they are the mirroring to scanFwd. // Also, algorithm may only need adjList index in single direction so we should make double // indexing optional. // Prepares scan on all connected relationship tables using backward adjList. - virtual std::unique_ptr prepareMultiTableScanBwd( + virtual std::unique_ptr prepareMultiTableScanBwd( std::span nodeTableIDs) = 0; // Get dst nodeIDs for given src nodeID tables using backward adjList. - virtual Iterator scanBwd(common::nodeID_t nodeID, GraphScanState& state) = 0; + virtual EdgeIterator scanBwd(common::nodeID_t nodeID, NbrScanState& state) = 0; + + class VertexIterator { + public: + explicit constexpr VertexIterator(VertexScanState* scanState) : scanState{scanState} {} + DEFAULT_BOTH_MOVE(VertexIterator); + VertexIterator(const VertexIterator& other) = default; + VertexIterator() : scanState{nullptr} {} + using difference_type = std::ptrdiff_t; + using value_type = VertexScanState::Chunk; + + value_type operator*() const { return scanState->getChunk(); } + VertexIterator& operator++() { + if (!scanState->next()) { + scanState = nullptr; + } + return *this; + } + void operator++(int) { ++*this; } + bool operator==(const VertexIterator& other) const { + // Only needed for comparing to the end, so they are equal if and only if both are null + return scanState == nullptr && other.scanState == nullptr; + } + + VertexIterator& begin() noexcept { return *this; } + static constexpr VertexIterator end() noexcept { return VertexIterator(nullptr); } + + private: + VertexScanState* scanState; + }; + static_assert(std::input_iterator); + + virtual std::unique_ptr prepareVertexScan(common::table_id_t tableID, + const std::vector& propertiesToScan) = 0; + virtual VertexIterator scanVertices(common::offset_t startNodeOffset, + common::offset_t endNodeOffsetExclusive, VertexScanState& scanState) = 0; }; } // namespace graph diff --git a/src/include/graph/on_disk_graph.h b/src/include/graph/on_disk_graph.h index 211e4b5ec6e..f596cb7e850 100644 --- a/src/include/graph/on_disk_graph.h +++ b/src/include/graph/on_disk_graph.h @@ -20,7 +20,7 @@ class MemoryManager; } namespace graph { -struct OnDiskGraphScanState { +struct OnDiskGraphNbrScanState { class InnerIterator { public: InnerIterator(const main::ClientContext* context, storage::RelTable* relTable, @@ -69,18 +69,18 @@ struct OnDiskGraphScanState { InnerIterator fwdIterator; InnerIterator bwdIterator; - OnDiskGraphScanState(main::ClientContext* context, storage::RelTable& table, + OnDiskGraphNbrScanState(main::ClientContext* context, storage::RelTable& table, std::unique_ptr fwdState, std::unique_ptr bwdState) : fwdIterator{context, &table, std::move(fwdState)}, bwdIterator{context, &table, std::move(bwdState)} {} }; -class OnDiskGraphScanStates : public GraphScanState { +class OnDiskGraphNbrScanStates : public NbrScanState { friend class OnDiskGraph; public: - GraphScanState::Chunk getChunk() override { + NbrScanState::Chunk getChunk() override { auto& iter = getInnerIterator(); return createChunk(iter.getNbrNodes(), iter.getEdges(), iter.getSelVectorUnsafe(), propertyVector.get()); @@ -93,7 +93,7 @@ class OnDiskGraphScanStates : public GraphScanState { } private: - const OnDiskGraphScanState::InnerIterator& getInnerIterator() const { + const OnDiskGraphNbrScanState::InnerIterator& getInnerIterator() const { KU_ASSERT(iteratorIndex < scanStates.size()); if (direction == common::RelDataDirection::FWD) { return scanStates[iteratorIndex].second.fwdIterator; @@ -102,9 +102,9 @@ class OnDiskGraphScanStates : public GraphScanState { } } - OnDiskGraphScanState::InnerIterator& getInnerIterator() { - return const_cast( - const_cast(this)->getInnerIterator()); + OnDiskGraphNbrScanState::InnerIterator& getInnerIterator() { + return const_cast( + const_cast(this)->getInnerIterator()); } private: @@ -117,10 +117,35 @@ class OnDiskGraphScanStates : public GraphScanState { std::unique_ptr relPredicateEvaluator; - explicit OnDiskGraphScanStates(main::ClientContext* context, + explicit OnDiskGraphNbrScanStates(main::ClientContext* context, std::span tableIDs, const GraphEntry& graphEntry, std::optional edgePropertyIndex = std::nullopt); - std::vector> scanStates; + std::vector> scanStates; +}; + +class OnDiskGraphVertexScanState : public VertexScanState { +public: + OnDiskGraphVertexScanState(main::ClientContext& context, common::table_id_t tableID, + const std::vector& propertyNames); + + void startScan(common::offset_t beginOffset, common::offset_t endOffsetExclusive); + + bool next() override; + Chunk getChunk() override { + return createChunk(std::span(&nodeIDVector->getValue(0), numNodesScanned), + std::span(propertyVectors.valueVectors)); + } + +private: + common::DataChunk propertyVectors; + std::unique_ptr nodeIDVector; + std::unique_ptr tableScanState; + const main::ClientContext& context; + const storage::NodeTable& nodeTable; + common::offset_t numNodesScanned; + common::table_id_t tableID; + common::offset_t currentOffset; + common::offset_t endOffsetExclusive; }; class OnDiskGraph final : public Graph { @@ -139,15 +164,20 @@ class OnDiskGraph final : public Graph { std::vector getRelTableIDInfos() override; - std::unique_ptr prepareScan(common::table_id_t relTableID, + std::unique_ptr prepareScan(common::table_id_t relTableID, std::optional edgePropertyIndex = std::nullopt) override; - std::unique_ptr prepareMultiTableScanFwd( + std::unique_ptr prepareMultiTableScanFwd( std::span nodeTableIDs) override; - std::unique_ptr prepareMultiTableScanBwd( + std::unique_ptr prepareMultiTableScanBwd( std::span nodeTableIDs) override; + std::unique_ptr prepareVertexScan(common::table_id_t tableID, + const std::vector& propertiesToScan) override; + + Graph::EdgeIterator scanFwd(common::nodeID_t nodeID, NbrScanState& state) override; + Graph::EdgeIterator scanBwd(common::nodeID_t nodeID, NbrScanState& state) override; - Graph::Iterator scanFwd(common::nodeID_t nodeID, GraphScanState& state) override; - Graph::Iterator scanBwd(common::nodeID_t nodeID, GraphScanState& state) override; + VertexIterator scanVertices(common::offset_t beginOffset, common::offset_t endOffsetExclusive, + VertexScanState& state) override; private: main::ClientContext* context; diff --git a/src/include/storage/store/column_chunk_data.h b/src/include/storage/store/column_chunk_data.h index 4c0ce1f25fa..20ececf42ca 100644 --- a/src/include/storage/store/column_chunk_data.h +++ b/src/include/storage/store/column_chunk_data.h @@ -64,17 +64,6 @@ struct ChunkState { return childrenStates[childIdx]; } - void resetState() { - numValuesPerPage = UINT64_MAX; - metadata = ColumnChunkMetadata{}; - if (nullState) { - nullState->resetState(); - } - for (auto& childState : childrenStates) { - childState.resetState(); - } - } - template InMemoryExceptionChunk* getExceptionChunk() { using GetType = std::unique_ptr>; diff --git a/src/include/storage/store/csr_node_group.h b/src/include/storage/store/csr_node_group.h index 1ba5939d156..c0003da4a93 100644 --- a/src/include/storage/store/csr_node_group.h +++ b/src/include/storage/store/csr_node_group.h @@ -127,34 +127,22 @@ struct CSRNodeGroupScanState final : NodeGroupScanState { common::row_idx_t nextCachedRowToScan; // States at the csr list level. Cached during scan over a single csr list. - common::row_idx_t nextRowToScan; NodeCSRIndex inMemCSRList; CSRNodeGroupScanSource source; explicit CSRNodeGroupScanState(common::idx_t numChunks) : NodeGroupScanState{numChunks}, header{nullptr}, numTotalRows{0}, numCachedRows{0}, - nextCachedRowToScan{0}, nextRowToScan{0}, - source{CSRNodeGroupScanSource::COMMITTED_PERSISTENT} {} + nextCachedRowToScan{0}, source{CSRNodeGroupScanSource::COMMITTED_PERSISTENT} {} CSRNodeGroupScanState(MemoryManager& mm, common::idx_t numChunks) : NodeGroupScanState{numChunks}, numTotalRows{0}, numCachedRows{0}, nextCachedRowToScan{0}, - nextRowToScan{0}, source{CSRNodeGroupScanSource::COMMITTED_PERSISTENT} { + source{CSRNodeGroupScanSource::COMMITTED_PERSISTENT} { header = std::make_unique(mm, false, common::StorageConstants::NODE_GROUP_SIZE, ResidencyState::IN_MEMORY); cachedScannedVectorsSelBitset.set(); } bool tryScanCachedTuples(RelTableScanState& tableScanState); - - void resetState() override { - NodeGroupScanState::resetState(); - numTotalRows = 0; - numCachedRows = 0; - nextCachedRowToScan = 0; - nextRowToScan = 0; - inMemCSRList = NodeCSRIndex{}; - source = CSRNodeGroupScanSource::COMMITTED_PERSISTENT; - } }; struct CSRNodeGroupCheckpointState final : NodeGroupCheckpointState { diff --git a/src/include/storage/store/node_group.h b/src/include/storage/store/node_group.h index 432b42512f3..268486acfac 100644 --- a/src/include/storage/store/node_group.h +++ b/src/include/storage/store/node_group.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "common/uniq_lock.h" #include "storage/enums/residency_state.h" #include "storage/store/chunked_node_group.h" @@ -18,7 +20,7 @@ class NodeGroup; struct NodeGroupScanState { // Index of committed but not yet checkpointed chunked group to scan. common::idx_t chunkedGroupIdx = 0; - common::row_idx_t numScannedRows = 0; + common::row_idx_t nextRowToScan = 0; // State of each chunk in the checkpointed chunked group. std::vector chunkStates; @@ -26,14 +28,6 @@ struct NodeGroupScanState { virtual ~NodeGroupScanState() = default; DELETE_COPY_DEFAULT_MOVE(NodeGroupScanState); - virtual void resetState() { - chunkedGroupIdx = 0; - numScannedRows = 0; - for (auto& chunkState : chunkStates) { - chunkState.resetState(); - } - } - template TARGET& cast() { return common::ku_dynamic_cast(*this); @@ -138,6 +132,9 @@ class NodeGroup { virtual NodeGroupScanResult scan(transaction::Transaction* transaction, TableScanState& state) const; + virtual NodeGroupScanResult scan(transaction::Transaction* transaction, TableScanState& state, + common::offset_t startOffset, common::offset_t numNodes) const; + bool lookup(const common::UniqLock& lock, transaction::Transaction* transaction, const TableScanState& state); bool lookup(transaction::Transaction* transaction, const TableScanState& state); @@ -185,9 +182,11 @@ class NodeGroup { bool isInserted(const transaction::Transaction* transaction, common::offset_t offsetInGroup); private: + common::idx_t findChunkedGroupIdxFromRowIdx(const common::UniqLock& lock, + common::row_idx_t rowIdx) const; ChunkedNodeGroup* findChunkedGroupFromRowIdx(const common::UniqLock& lock, - common::row_idx_t rowIdx); - ChunkedNodeGroup* findChunkedGroupFromRowIdxNoLock(common::row_idx_t rowIdx); + common::row_idx_t rowIdx) const; + ChunkedNodeGroup* findChunkedGroupFromRowIdxNoLock(common::row_idx_t rowIdx) const; std::unique_ptr checkpointInMemOnly(MemoryManager& memoryManager, const common::UniqLock& lock, NodeGroupCheckpointState& state); @@ -203,6 +202,10 @@ class NodeGroup { const common::UniqLock& lock, const std::vector& columnIDs, const std::vector& columns) const; + virtual NodeGroupScanResult scanInternal(const common::UniqLock& lock, + transaction::Transaction* transaction, TableScanState& state, common::offset_t startOffset, + common::offset_t numNodes) const; + protected: common::node_group_idx_t nodeGroupIdx; NodeGroupDataFormat format; diff --git a/src/include/storage/store/node_table.h b/src/include/storage/store/node_table.h index a1a17df9d8c..5d7c8234191 100644 --- a/src/include/storage/store/node_table.h +++ b/src/include/storage/store/node_table.h @@ -26,20 +26,30 @@ namespace storage { struct NodeTableScanState final : TableScanState { // Scan state for un-committed data. // Ideally we shouldn't need columns to scan un-checkpointed but committed data. - NodeTableScanState(common::table_id_t tableID, std::vector columnIDs) - : NodeTableScanState{tableID, std::move(columnIDs), {}} {} NodeTableScanState(common::table_id_t tableID, std::vector columnIDs, - std::vector columns) - : NodeTableScanState{tableID, std::move(columnIDs), std::move(columns), - std::vector{}} {} - NodeTableScanState(common::table_id_t tableID, std::vector columnIDs, - std::vector columns, std::vector columnPredicateSets) + std::vector columns = {}, + std::vector columnPredicateSets = {}) : TableScanState{tableID, std::move(columnIDs), std::move(columns), std::move(columnPredicateSets)} { nodeGroupScanState = std::make_unique(this->columnIDs.size()); } + NodeTableScanState(common::table_id_t tableID, std::vector columnIDs, + std::vector columns, const common::DataChunk& dataChunk, + common::ValueVector* nodeIDVector) + : NodeTableScanState{tableID, std::move(columnIDs), std::move(columns)} { + for (auto& vector : dataChunk.valueVectors) { + outputVectors.push_back(vector.get()); + } + outState = dataChunk.state.get(); + this->nodeIDVector = nodeIDVector; + rowIdxVector->state = this->nodeIDVector->state; + } + bool scanNext(transaction::Transaction* transaction) override; + + bool scanNext(transaction::Transaction* transaction, common::offset_t startOffset, + common::offset_t numNodes); }; struct NodeTableInsertState final : TableInsertState { @@ -100,6 +110,8 @@ class NodeTable final : public Table { void initScanState(transaction::Transaction* transaction, TableScanState& scanState) const override; + void initScanState(transaction::Transaction* transaction, TableScanState& scanState, + common::table_id_t tableID, common::offset_t startOffset) const; bool scanInternal(transaction::Transaction* transaction, TableScanState& scanState) override; bool lookup(transaction::Transaction* transaction, const TableScanState& scanState) const; diff --git a/src/include/storage/store/rel_table.h b/src/include/storage/store/rel_table.h index 8cc8fda8131..eea1f512b5a 100644 --- a/src/include/storage/store/rel_table.h +++ b/src/include/storage/store/rel_table.h @@ -50,11 +50,6 @@ struct RelTableScanState : TableScanState { bool scanNext(transaction::Transaction* transaction) override; - void resetState() override { - TableScanState::resetState(); - currBoundNodeIdx = 0; - } - void setNodeIDVectorToFlat(common::sel_t selPos) const; private: diff --git a/src/include/storage/store/table.h b/src/include/storage/store/table.h index ee098e8dd9c..18971ac672d 100644 --- a/src/include/storage/store/table.h +++ b/src/include/storage/store/table.h @@ -35,13 +35,9 @@ struct TableScanState { std::vector columnPredicateSets; - TableScanState(common::table_id_t tableID, std::vector columnIDs) - : TableScanState{tableID, std::move(columnIDs), {}} {} TableScanState(common::table_id_t tableID, std::vector columnIDs, - std::vector columns) - : TableScanState{tableID, std::move(columnIDs), std::move(columns), {}} {} - TableScanState(common::table_id_t tableID, std::vector columnIDs, - std::vector columns, std::vector columnPredicateSets) + std::vector columns = {}, + std::vector columnPredicateSets = {}) : tableID{tableID}, nodeIDVector(nullptr), outState{nullptr}, columnIDs{std::move(columnIDs)}, semiMask{nullptr}, columns{std::move(columns)}, columnPredicateSets{std::move(columnPredicateSets)} { @@ -61,13 +57,6 @@ struct TableScanState { void resetOutVectors(); - virtual void resetState() { - source = TableScanSource::NONE; - nodeGroupIdx = common::INVALID_NODE_GROUP_IDX; - nodeGroup = nullptr; - nodeGroupScanState->resetState(); - } - template TARGET& cast() { return common::ku_dynamic_cast(*this); @@ -187,14 +176,13 @@ class Table { MemoryManager& getMemoryManager() const { return *memoryManager; } + common::DataChunk constructDataChunk(std::vector types) const; + protected: virtual bool scanInternal(transaction::Transaction* transaction, TableScanState& scanState) = 0; virtual void serialize(common::Serializer& serializer) const; - std::unique_ptr constructDataChunk( - const std::vector& types); - protected: common::TableType tableType; common::table_id_t tableID; diff --git a/src/main/storage_driver.cpp b/src/main/storage_driver.cpp index d604cb5079f..efb52a3a1f9 100644 --- a/src/main/storage_driver.cpp +++ b/src/main/storage_driver.cpp @@ -128,12 +128,10 @@ uint64_t StorageDriver::getNumRels(const std::string& relName) { void StorageDriver::scanColumn(storage::Table* table, column_id_t columnID, offset_t* offsets, size_t size, uint8_t* result) { // Create scan state. - auto columnIDs = std::vector{columnID}; auto nodeTable = table->ptrCast(); auto column = &nodeTable->getColumn(columnID); - std::vector columns; - columns.push_back(column); - auto scanState = std::make_unique(table->getTableID(), columnIDs, columns); + auto scanState = std::make_unique(table->getTableID(), + std::vector{columnID}, std::vector{column}); // Create value vectors auto idVector = std::make_unique(LogicalType::INTERNAL_ID()); auto columnVector = std::make_unique(column->getDataType().copy(), diff --git a/src/processor/operator/scan/offset_scan_node_table.cpp b/src/processor/operator/scan/offset_scan_node_table.cpp index 2b3d5f52209..18922e6eb6c 100644 --- a/src/processor/operator/scan/offset_scan_node_table.cpp +++ b/src/processor/operator/scan/offset_scan_node_table.cpp @@ -34,15 +34,8 @@ bool OffsetScanNodeTable::getNextTuplesInternal(ExecutionContext* context) { auto nodeID = nodeIDVector->getValue(0); KU_ASSERT(tableIDToNodeInfo.contains(nodeID.tableID)); auto& nodeInfo = tableIDToNodeInfo.at(nodeID.tableID); - if (transaction->isUnCommitted(nodeID.tableID, nodeID.offset)) { - nodeInfo.localScanState->source = TableScanSource::UNCOMMITTED; - nodeInfo.localScanState->nodeGroupIdx = StorageUtils::getNodeGroupIdx( - transaction->getLocalRowIdx(nodeID.tableID, nodeID.offset)); - } else { - nodeInfo.localScanState->source = TableScanSource::COMMITTED; - nodeInfo.localScanState->nodeGroupIdx = StorageUtils::getNodeGroupIdx(nodeID.offset); - } - nodeInfo.table->initScanState(transaction, *nodeInfo.localScanState); + nodeInfo.table->initScanState(transaction, *nodeInfo.localScanState, nodeID.tableID, + nodeID.offset); if (!nodeInfo.table->lookup(transaction, *nodeInfo.localScanState)) { // LCOV_EXCL_START throw RuntimeException(stringFormat("Cannot perform lookup on {}. This should not happen.", diff --git a/src/processor/operator/scan/primary_key_scan_node_table.cpp b/src/processor/operator/scan/primary_key_scan_node_table.cpp index b0c298b18d3..a1cdedfdd12 100644 --- a/src/processor/operator/scan/primary_key_scan_node_table.cpp +++ b/src/processor/operator/scan/primary_key_scan_node_table.cpp @@ -79,15 +79,8 @@ bool PrimaryKeyScanNodeTable::getNextTuplesInternal(ExecutionContext* context) { } auto nodeID = nodeID_t{nodeOffset, nodeInfo.table->getTableID()}; nodeInfo.localScanState->nodeIDVector->setValue(pos, nodeID); - if (transaction->isUnCommitted(nodeID.tableID, nodeOffset)) { - nodeInfo.localScanState->source = TableScanSource::UNCOMMITTED; - nodeInfo.localScanState->nodeGroupIdx = - StorageUtils::getNodeGroupIdx(transaction->getLocalRowIdx(nodeID.tableID, nodeOffset)); - } else { - nodeInfo.localScanState->source = TableScanSource::COMMITTED; - nodeInfo.localScanState->nodeGroupIdx = StorageUtils::getNodeGroupIdx(nodeOffset); - } - nodeInfo.table->initScanState(transaction, *nodeInfo.localScanState); + nodeInfo.table->initScanState(transaction, *nodeInfo.localScanState, nodeID.tableID, + nodeOffset); metrics->numOutputTuple.incrementByOne(); return nodeInfo.table->lookup(transaction, *nodeInfo.localScanState); } diff --git a/src/storage/store/csr_node_group.cpp b/src/storage/store/csr_node_group.cpp index f2152e75b15..734255b6923 100644 --- a/src/storage/store/csr_node_group.cpp +++ b/src/storage/store/csr_node_group.cpp @@ -22,15 +22,15 @@ bool CSRNodeGroupScanState::tryScanCachedTuples(RelTableScanState& tableScanStat const auto startCSROffset = header->getStartCSROffset(boundNodeOffsetInGroup); const auto csrLength = header->getCSRLength(boundNodeOffsetInGroup); nextCachedRowToScan = std::max(nextCachedRowToScan, startCSROffset); - if (nextCachedRowToScan >= numScannedRows || - nextCachedRowToScan < numScannedRows - numCachedRows) { + if (nextCachedRowToScan >= nextRowToScan || + nextCachedRowToScan < nextRowToScan - numCachedRows) { // Out of the bound of cached rows. return false; } - KU_ASSERT(nextCachedRowToScan >= numScannedRows - numCachedRows); + KU_ASSERT(nextCachedRowToScan >= nextRowToScan - numCachedRows); const auto numRowsToScan = - std::min(numScannedRows, startCSROffset + csrLength) - nextCachedRowToScan; - const auto startCachedRow = nextCachedRowToScan - (numScannedRows - numCachedRows); + std::min(nextRowToScan, startCSROffset + csrLength) - nextCachedRowToScan; + const auto startCachedRow = nextCachedRowToScan - (nextRowToScan - numCachedRows); auto numSelected = 0u; tableScanState.outState->getSelVectorUnsafe().setToFiltered(); for (auto i = 0u; i < numRowsToScan; i++) { @@ -54,7 +54,6 @@ void CSRNodeGroup::initializeScanState(Transaction* transaction, TableScanState& KU_ASSERT(relScanState.nodeGroupScanState); auto& nodeGroupScanState = relScanState.nodeGroupScanState->cast(); if (relScanState.nodeGroupIdx != nodeGroupIdx) { - relScanState.nodeGroupScanState->resetState(); relScanState.nodeGroupIdx = nodeGroupIdx; if (persistentChunkGroup) { initScanForCommittedPersistent(transaction, relScanState, nodeGroupScanState); @@ -62,15 +61,14 @@ void CSRNodeGroup::initializeScanState(Transaction* transaction, TableScanState& } // Switch to a new Vector of bound nodes (i.e., new csr lists) in the node group. if (persistentChunkGroup) { - nodeGroupScanState.numScannedRows = 0; - nodeGroupScanState.numCachedRows = 0; nodeGroupScanState.nextRowToScan = 0; + nodeGroupScanState.numCachedRows = 0; nodeGroupScanState.source = CSRNodeGroupScanSource::COMMITTED_PERSISTENT; } else if (csrIndex) { initScanForCommittedInMem(relScanState, nodeGroupScanState); } else { nodeGroupScanState.source = CSRNodeGroupScanSource::NONE; - nodeGroupScanState.numScannedRows = 0; + nodeGroupScanState.nextRowToScan = 0; } } @@ -106,9 +104,8 @@ void CSRNodeGroup::initScanForCommittedInMem(RelTableScanState& relScanState, CSRNodeGroupScanState& nodeGroupScanState) const { relScanState.currBoundNodeIdx = 0; nodeGroupScanState.source = CSRNodeGroupScanSource::COMMITTED_IN_MEMORY; - nodeGroupScanState.numScannedRows = 0; - nodeGroupScanState.numCachedRows = 0; nodeGroupScanState.nextRowToScan = 0; + nodeGroupScanState.numCachedRows = 0; nodeGroupScanState.inMemCSRList.clear(); } @@ -160,11 +157,11 @@ NodeGroupScanResult CSRNodeGroup::scanCommittedPersistentWithCache(const Transac while (nodeGroupScanState.tryScanCachedTuples(tableState)) { if (tableState.outState->getSelVector().getSelSize() > 0) { // Note: This is a dummy return value. - return NodeGroupScanResult{nodeGroupScanState.numScannedRows, + return NodeGroupScanResult{nodeGroupScanState.nextRowToScan, tableState.outState->getSelVector().getSelSize()}; } } - if (nodeGroupScanState.numScannedRows == nodeGroupScanState.numTotalRows || + if (nodeGroupScanState.nextRowToScan == nodeGroupScanState.numTotalRows || tableState.currBoundNodeIdx >= tableState.cachedBoundNodeSelVector.getSelSize()) { return NODE_GROUP_SCAN_EMMPTY_RESULT; } @@ -172,17 +169,17 @@ NodeGroupScanResult CSRNodeGroup::scanCommittedPersistentWithCache(const Transac tableState.cachedBoundNodeSelVector[tableState.currBoundNodeIdx]); const auto offsetInGroup = currNodeOffset % StorageConstants::NODE_GROUP_SIZE; const auto startCSROffset = nodeGroupScanState.header->getStartCSROffset(offsetInGroup); - if (startCSROffset > nodeGroupScanState.numScannedRows) { - nodeGroupScanState.numScannedRows = startCSROffset; + if (startCSROffset > nodeGroupScanState.nextRowToScan) { + nodeGroupScanState.nextRowToScan = startCSROffset; } - KU_ASSERT(nodeGroupScanState.numScannedRows <= nodeGroupScanState.numTotalRows); + KU_ASSERT(nodeGroupScanState.nextRowToScan <= nodeGroupScanState.numTotalRows); const auto numToScan = - std::min(nodeGroupScanState.numTotalRows - nodeGroupScanState.numScannedRows, + std::min(nodeGroupScanState.numTotalRows - nodeGroupScanState.nextRowToScan, DEFAULT_VECTOR_CAPACITY); persistentChunkGroup->scan(transaction, tableState, nodeGroupScanState, - nodeGroupScanState.numScannedRows, numToScan); + nodeGroupScanState.nextRowToScan, numToScan); nodeGroupScanState.numCachedRows = numToScan; - nodeGroupScanState.numScannedRows += numToScan; + nodeGroupScanState.nextRowToScan += numToScan; if (tableState.outState->getSelVector().isUnfiltered()) { nodeGroupScanState.cachedScannedVectorsSelBitset.set(); } else { diff --git a/src/storage/store/node_group.cpp b/src/storage/store/node_group.cpp index 21e8767c114..3ecb788f695 100644 --- a/src/storage/store/node_group.cpp +++ b/src/storage/store/node_group.cpp @@ -1,6 +1,9 @@ #include "storage/store/node_group.h" +#include "common/assert.h" +#include "common/constants.h" #include "common/types/types.h" +#include "common/uniq_lock.h" #include "main/client_context.h" #include "storage/buffer_manager/memory_manager.h" #include "storage/enums/residency_state.h" @@ -107,7 +110,7 @@ void NodeGroup::initializeScanState(Transaction* transaction, TableScanState& st } static void initializeScanStateForChunkedGroup(const TableScanState& state, - ChunkedNodeGroup* chunkedGroup) { + const ChunkedNodeGroup* chunkedGroup) { KU_ASSERT(chunkedGroup); if (chunkedGroup->getResidencyState() != ResidencyState::ON_DISK) { return; @@ -130,19 +133,48 @@ void NodeGroup::initializeScanState(Transaction*, const UniqLock& lock, TableScanState& state) const { auto& nodeGroupScanState = *state.nodeGroupScanState; nodeGroupScanState.chunkedGroupIdx = 0; - nodeGroupScanState.numScannedRows = 0; + nodeGroupScanState.nextRowToScan = 0; ChunkedNodeGroup* firstChunkedGroup = chunkedGroups.getFirstGroup(lock); initializeScanStateForChunkedGroup(state, firstChunkedGroup); } +void applySemiMaskFilter(const TableScanState& state, row_idx_t numRowsToScan, + SelectionVector& selVector) { + auto& nodeGroupScanState = *state.nodeGroupScanState; + const auto startNodeOffset = nodeGroupScanState.nextRowToScan + + StorageUtils::getStartOffsetOfNodeGroup(state.nodeGroupIdx); + const auto endNodeOffset = startNodeOffset + numRowsToScan; + const auto& arr = state.semiMask->range(startNodeOffset, endNodeOffset); + if (arr.empty()) { + selVector.setSelSize(0); + } else { + auto stat = selVector.getMutableBuffer(); + uint64_t numSelectedValues = 0; + size_t i = 0, j = 0; + while (i < numRowsToScan && j < arr.size()) { + auto temp = arr[j] - startNodeOffset; + if (selVector[i] < temp) { + ++i; + } else if (selVector[i] > temp) { + ++j; + } else { + stat[numSelectedValues++] = temp; + ++i; + ++j; + } + } + selVector.setToFiltered(numSelectedValues); + } +} + NodeGroupScanResult NodeGroup::scan(Transaction* transaction, TableScanState& state) const { // TODO(Guodong): Move the locked part of figuring out the chunked group to initScan. const auto lock = chunkedGroups.lock(); auto& nodeGroupScanState = *state.nodeGroupScanState; KU_ASSERT(nodeGroupScanState.chunkedGroupIdx < chunkedGroups.getNumGroups(lock)); const auto chunkedGroup = chunkedGroups.getGroup(lock, nodeGroupScanState.chunkedGroupIdx); - if (chunkedGroup && nodeGroupScanState.numScannedRows >= - chunkedGroup->getNumRows() + chunkedGroup->getStartRowIdx()) { + if (nodeGroupScanState.nextRowToScan >= + chunkedGroup->getNumRows() + chunkedGroup->getStartRowIdx()) { nodeGroupScanState.chunkedGroupIdx++; if (nodeGroupScanState.chunkedGroupIdx >= chunkedGroups.getNumGroups(lock)) { return NODE_GROUP_SCAN_EMMPTY_RESULT; @@ -154,50 +186,89 @@ NodeGroupScanResult NodeGroup::scan(Transaction* transaction, TableScanState& st const auto& chunkedGroupToScan = *chunkedGroups.getGroup(lock, nodeGroupScanState.chunkedGroupIdx); const auto rowIdxInChunkToScan = - nodeGroupScanState.numScannedRows - chunkedGroupToScan.getStartRowIdx(); + nodeGroupScanState.nextRowToScan - chunkedGroupToScan.getStartRowIdx(); const auto numRowsToScan = std::min(chunkedGroupToScan.getNumRows() - rowIdxInChunkToScan, DEFAULT_VECTOR_CAPACITY); bool enableSemiMask = state.source == TableScanSource::COMMITTED && state.semiMask && state.semiMask->isEnabled(); if (enableSemiMask) { - const auto startNodeOffset = nodeGroupScanState.numScannedRows + - StorageUtils::getStartOffsetOfNodeGroup(state.nodeGroupIdx); - const auto endNodeOffset = startNodeOffset + numRowsToScan; - const auto& arr = state.semiMask->range(startNodeOffset, endNodeOffset); - if (arr.empty()) { - state.outState->getSelVectorUnsafe().setSelSize(0); - nodeGroupScanState.numScannedRows += numRowsToScan; - return NodeGroupScanResult{nodeGroupScanState.numScannedRows, 0}; - } else { - chunkedGroupToScan.scan(transaction, state, nodeGroupScanState, rowIdxInChunkToScan, - numRowsToScan); - auto& selVector = state.outState->getSelVectorUnsafe(); - auto stat = selVector.getMutableBuffer(); - uint64_t numSelectedValues = 0; - size_t i = 0, j = 0; - while (i < selVector.getSelSize() && j < arr.size()) { - auto temp = arr[j] - startNodeOffset; - if (selVector[i] < temp) { - ++i; - } else if (selVector[i] > temp) { - ++j; - } else { - stat[numSelectedValues++] = temp; - ++i; - ++j; - } - } - selVector.setToFiltered(numSelectedValues); + applySemiMaskFilter(state, numRowsToScan, state.outState->getSelVectorUnsafe()); + if (state.outState->getSelVector().getSelSize() == 0) { + state.nodeGroupScanState->nextRowToScan += numRowsToScan; + return NodeGroupScanResult{nodeGroupScanState.nextRowToScan, 0}; } - } else { - chunkedGroupToScan.scan(transaction, state, nodeGroupScanState, rowIdxInChunkToScan, - numRowsToScan); } - const auto startRow = nodeGroupScanState.numScannedRows; - nodeGroupScanState.numScannedRows += numRowsToScan; + chunkedGroupToScan.scan(transaction, state, nodeGroupScanState, rowIdxInChunkToScan, + numRowsToScan); + const auto startRow = nodeGroupScanState.nextRowToScan; + nodeGroupScanState.nextRowToScan += numRowsToScan; return NodeGroupScanResult{startRow, numRowsToScan}; } +NodeGroupScanResult NodeGroup::scan(Transaction* transaction, TableScanState& state, + offset_t startOffset, offset_t numRowsToScan) const { + bool enableSemiMask = + state.source == TableScanSource::COMMITTED && state.semiMask && state.semiMask->isEnabled(); + if (enableSemiMask) { + applySemiMaskFilter(state, numRowsToScan, state.outState->getSelVectorUnsafe()); + if (state.outState->getSelVector().getSelSize() == 0) { + state.nodeGroupScanState->nextRowToScan += numRowsToScan; + return NodeGroupScanResult{state.nodeGroupScanState->nextRowToScan, 0}; + } + } + if (state.outputVectors.size() == 0) { + auto startOffsetInGroup = + startOffset - StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); + KU_ASSERT(scanInternal(chunkedGroups.lock(), transaction, state, startOffset, + numRowsToScan) == NodeGroupScanResult(startOffsetInGroup, numRowsToScan)); + return NodeGroupScanResult{startOffsetInGroup, numRowsToScan}; + } + return scanInternal(chunkedGroups.lock(), transaction, state, startOffset, numRowsToScan); +} + +NodeGroupScanResult NodeGroup::scanInternal(const common::UniqLock& lock, Transaction* transaction, + TableScanState& state, offset_t startOffset, offset_t numRowsToScan) const { + // Only meant for scanning once + KU_ASSERT(numRowsToScan <= DEFAULT_VECTOR_CAPACITY); + + auto nodeGroupStartOffset = StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); + KU_ASSERT(startOffset >= nodeGroupStartOffset); + auto startOffsetInGroup = startOffset - nodeGroupStartOffset; + KU_ASSERT(startOffsetInGroup + numRowsToScan <= numRows); + + auto& nodeGroupScanState = *state.nodeGroupScanState; + nodeGroupScanState.nextRowToScan = startOffsetInGroup; + + auto newChunkedGroupIdx = findChunkedGroupIdxFromRowIdx(lock, startOffsetInGroup); + + const auto* chunkedGroupToScan = chunkedGroups.getGroup(lock, newChunkedGroupIdx); + if (newChunkedGroupIdx != nodeGroupScanState.chunkedGroupIdx) { + // If the chunked group matches the scan state, don't re-initialize it. + // E.g. we may scan a group multiple times in parts + initializeScanStateForChunkedGroup(state, chunkedGroupToScan); + nodeGroupScanState.chunkedGroupIdx = newChunkedGroupIdx; + } + + uint64_t numRowsScanned = 0; + do { + const auto rowIdxInChunkToScan = + (startOffsetInGroup + numRowsScanned) - chunkedGroupToScan->getStartRowIdx(); + + uint64_t numRowsToScanInChunk = std::min(numRowsToScan - numRowsScanned, + chunkedGroupToScan->getNumRows() - rowIdxInChunkToScan); + chunkedGroupToScan->scan(transaction, state, nodeGroupScanState, rowIdxInChunkToScan, + numRowsToScanInChunk); + numRowsScanned += numRowsToScanInChunk; + nodeGroupScanState.nextRowToScan += numRowsToScanInChunk; + if (numRowsScanned < numRowsToScan) { + nodeGroupScanState.chunkedGroupIdx++; + chunkedGroupToScan = chunkedGroups.getGroup(lock, nodeGroupScanState.chunkedGroupIdx); + } + } while (numRowsScanned < numRowsToScan); + + return NodeGroupScanResult{startOffsetInGroup, numRowsScanned}; +} + bool NodeGroup::lookup(const UniqLock& lock, Transaction* transaction, const TableScanState& state) { idx_t numTuplesFound = 0; @@ -455,21 +526,26 @@ std::unique_ptr NodeGroup::deserialize(MemoryManager& memoryManager, } } -ChunkedNodeGroup* NodeGroup::findChunkedGroupFromRowIdx(const UniqLock& lock, row_idx_t rowIdx) { +idx_t NodeGroup::findChunkedGroupIdxFromRowIdx(const UniqLock& lock, row_idx_t rowIdx) const { KU_ASSERT(!chunkedGroups.isEmpty(lock)); const auto numRowsInFirstGroup = chunkedGroups.getFirstGroup(lock)->getNumRows(); if (rowIdx < numRowsInFirstGroup) { - return chunkedGroups.getFirstGroup(lock); + return 0; } rowIdx -= numRowsInFirstGroup; - const auto chunkedGroupIdx = rowIdx / ChunkedNodeGroup::CHUNK_CAPACITY + 1; + return rowIdx / ChunkedNodeGroup::CHUNK_CAPACITY + 1; +} + +ChunkedNodeGroup* NodeGroup::findChunkedGroupFromRowIdx(const UniqLock& lock, + row_idx_t rowIdx) const { + auto chunkedGroupIdx = findChunkedGroupIdxFromRowIdx(lock, rowIdx); if (chunkedGroupIdx >= chunkedGroups.getNumGroups(lock)) { return nullptr; } return chunkedGroups.getGroup(lock, chunkedGroupIdx); } -ChunkedNodeGroup* NodeGroup::findChunkedGroupFromRowIdxNoLock(row_idx_t rowIdx) { +ChunkedNodeGroup* NodeGroup::findChunkedGroupFromRowIdxNoLock(row_idx_t rowIdx) const { const auto numRowsInFirstGroup = chunkedGroups.getFirstGroupNoLock()->getNumRows(); if (rowIdx < numRowsInFirstGroup) { return chunkedGroups.getFirstGroupNoLock(); diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index 62fb6a1c674..62394029b1a 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -21,12 +21,34 @@ using namespace kuzu::evaluator; namespace kuzu { namespace storage { +bool NodeTableScanState::scanNext(Transaction* transaction, offset_t startOffset, + offset_t numNodes) { + KU_ASSERT(columns.size() == outputVectors.size()); + if (source == TableScanSource::NONE) { + return false; + } + const NodeGroupScanResult scanResult = + nodeGroup->scan(transaction, *this, startOffset, numNodes); + if (scanResult == NODE_GROUP_SCAN_EMMPTY_RESULT) { + return false; + } + auto nodeGroupStartOffset = StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); + if (source == TableScanSource::UNCOMMITTED) { + nodeGroupStartOffset = transaction->getUncommittedOffset(tableID, nodeGroupStartOffset); + } + for (auto i = 0u; i < scanResult.numRows; i++) { + nodeIDVector->setValue(i, + nodeID_t{nodeGroupStartOffset + scanResult.startRow + i, tableID}); + } + return true; +} + bool NodeTableScanState::scanNext(Transaction* transaction) { KU_ASSERT(columns.size() == outputVectors.size()); if (source == TableScanSource::NONE) { return false; } - const auto scanResult = nodeGroup->scan(transaction, *this); + const NodeGroupScanResult scanResult = nodeGroup->scan(transaction, *this); if (scanResult == NODE_GROUP_SCAN_EMMPTY_RESULT) { return false; } @@ -122,6 +144,20 @@ void NodeTable::initScanState(Transaction* transaction, TableScanState& scanStat nodeScanState.initState(transaction, nodeGroup); } +void NodeTable::initScanState(Transaction* transaction, TableScanState& scanState, + table_id_t tableID, offset_t startOffset) const { + if (transaction->isUnCommitted(tableID, startOffset)) { + scanState.source = TableScanSource::UNCOMMITTED; + scanState.nodeGroupIdx = + StorageUtils::getNodeGroupIdx(transaction->getLocalRowIdx(tableID, startOffset)); + + } else { + scanState.source = TableScanSource::COMMITTED; + scanState.nodeGroupIdx = StorageUtils::getNodeGroupIdx(startOffset); + } + initScanState(transaction, scanState); +} + bool NodeTable::scanInternal(Transaction* transaction, TableScanState& scanState) { scanState.resetOutVectors(); return scanState.scanNext(transaction); @@ -325,17 +361,14 @@ void NodeTable::commit(Transaction* transaction, LocalTable* localTable) { } // 3. Scan pk column for newly inserted tuples that are not deleted and insert into pk index. std::vector columnIDs{getPKColumnID()}; - std::vector types; + auto types = std::vector(); types.push_back(columns[pkColumnID]->getDataType().copy()); - const auto dataChunk = constructDataChunk({types}); + auto dataChunk = constructDataChunk(std::move(types)); ValueVector nodeIDVector(LogicalType::INTERNAL_ID()); - nodeIDVector.setState(dataChunk->state); + nodeIDVector.setState(dataChunk.state); const auto numNodeGroupsToScan = localNodeTable.getNumNodeGroups(); - const auto scanState = std::make_unique(tableID, columnIDs); - for (auto& vector : dataChunk->valueVectors) { - scanState->outputVectors.push_back(vector.get()); - } - scanState->outState = dataChunk->state.get(); + const auto scanState = std::make_unique(tableID, columnIDs, + std::vector{}, dataChunk, &nodeIDVector); scanState->source = TableScanSource::UNCOMMITTED; node_group_idx_t nodeGroupToScan = 0u; while (nodeGroupToScan < numNodeGroupsToScan) { diff --git a/src/storage/store/table.cpp b/src/storage/store/table.cpp index 59e188b9ff0..0d34cdb06ef 100644 --- a/src/storage/store/table.cpp +++ b/src/storage/store/table.cpp @@ -60,11 +60,11 @@ void Table::serialize(Serializer& serializer) const { serializer.write(tableID); } -std::unique_ptr Table::constructDataChunk(const std::vector& types) { - auto dataChunk = std::make_unique(types.size()); +DataChunk Table::constructDataChunk(std::vector types) const { + DataChunk dataChunk(types.size()); for (auto i = 0u; i < types.size(); i++) { - auto valueVector = std::make_unique(types[i].copy(), memoryManager); - dataChunk->insert(i, std::move(valueVector)); + auto valueVector = std::make_unique(std::move(types[i]), memoryManager); + dataChunk.insert(i, std::move(valueVector)); } return dataChunk; } diff --git a/test/storage/rel_scan_test.cpp b/test/storage/rel_scan_test.cpp index 2d69cd31a5e..d9d84a80096 100644 --- a/test/storage/rel_scan_test.cpp +++ b/test/storage/rel_scan_test.cpp @@ -4,6 +4,7 @@ #include "catalog/catalog.h" #include "common/types/date_t.h" +#include "common/types/ku_string.h" #include "common/types/types.h" #include "graph/graph_entry.h" #include "graph/on_disk_graph.h" @@ -131,5 +132,27 @@ TEST_F(RelScanTest, ScanFwd) { compare(2, {0, 1, 3}, {6, 7, 8}, {1, 4, 11}); } +TEST_F(RelScanTest, ScanVertexProperties) { + auto tableID = catalog->getTableID(context->getTx(), "person"); + std::vector properties = {"fname", "height"}; + auto scanState = graph->prepareVertexScan(tableID, properties); + + const auto compare = [&](offset_t startNodeOffset, offset_t endNodeOffset, + std::vector> expectedNames) { + std::vector> results; + for (auto chunk : graph->scanVertices(startNodeOffset, endNodeOffset, *scanState)) { + for (size_t i = 0; i < chunk.size(); i++) { + results.push_back(std::make_tuple(chunk.getNodeIDs()[i].offset, + chunk.getProperties(0)[i].getAsString(), + chunk.getProperties(1)[i])); + }; + } + ASSERT_EQ(results, expectedNames); + }; + compare(0, 3, {{0, "Alice", 1.731}, {1, "Bob", 0.99}, {2, "Carol", 1.0}}); + compare(1, 3, {{1, "Bob", 0.99}, {2, "Carol", 1.0}}); + compare(2, 4, {{2, "Carol", 1.0}, {3, "Dan", 1.3}}); +} + } // namespace testing } // namespace kuzu