Skip to content

Commit

Permalink
Add GDS support for vertex property scanning (#4453)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminwinger authored Nov 18, 2024
1 parent 1beaa37 commit 2591604
Show file tree
Hide file tree
Showing 28 changed files with 519 additions and 242 deletions.
6 changes: 5 additions & 1 deletion src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
namespace kuzu {
namespace common {

ValueVector::ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager)
ValueVector::ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager,
std::shared_ptr<DataChunkState> dataChunkState)
: dataType{std::move(dataType)}, nullMask{DEFAULT_VECTOR_CAPACITY} {
if (this->dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
// LCOV_EXCL_START
Expand All @@ -26,6 +27,9 @@ ValueVector::ValueVector(LogicalType dataType, storage::MemoryManager* memoryMan
numBytesPerValue = getDataTypeSize(this->dataType);
initializeValueBuffer();
auxiliaryBuffer = AuxiliaryBufferFactory::getAuxiliaryBuffer(this->dataType, memoryManager);
if (dataChunkState) {
setState(std::move(dataChunkState));
}
}

void ValueVector::setState(const std::shared_ptr<DataChunkState>& state_) {
Expand Down
4 changes: 2 additions & 2 deletions src/function/gds/all_shortest_paths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class AllSPDestinationsEdgeCompute : public SPEdgeCompute {
PathMultiplicities* multiplicities)
: SPEdgeCompute{frontierPair}, multiplicities{multiplicities} {};

std::vector<nodeID_t> edgeCompute(nodeID_t boundNodeID, GraphScanState::Chunk& resultChunk,
std::vector<nodeID_t> edgeCompute(nodeID_t boundNodeID, NbrScanState::Chunk& resultChunk,
bool) override {
std::vector<nodeID_t> activeNodes;
resultChunk.forEach([&](auto nbrNodeID, auto /*edgeID*/) {
Expand Down Expand Up @@ -190,7 +190,7 @@ class AllSPPathsEdgeCompute : public SPEdgeCompute {
parentListBlock = bfsGraph->addNewBlock();
}

std::vector<nodeID_t> edgeCompute(nodeID_t boundNodeID, GraphScanState::Chunk& resultChunk,
std::vector<nodeID_t> edgeCompute(nodeID_t boundNodeID, NbrScanState::Chunk& resultChunk,
bool fwdEdge) override {
std::vector<nodeID_t> activeNodes;
resultChunk.forEach([&](auto nbrNodeID, auto edgeID) {
Expand Down
16 changes: 10 additions & 6 deletions src/function/gds/gds_task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ using namespace kuzu::common;
namespace kuzu {
namespace function {

static uint64_t computeScanResult(nodeID_t sourceNodeID, graph::GraphScanState::Chunk& chunk,
static uint64_t computeScanResult(nodeID_t sourceNodeID, graph::NbrScanState::Chunk& nbrChunk,
EdgeCompute& ec, FrontierPair& frontierPair, bool isFwd) {
auto activeNodes = ec.edgeCompute(sourceNodeID, chunk, isFwd);
auto activeNodes = ec.edgeCompute(sourceNodeID, nbrChunk, isFwd);
frontierPair.getNextFrontierUnsafe().setActive(activeNodes);
return chunk.size();
return nbrChunk.size();
}

void FrontierTask::run() {
Expand Down Expand Up @@ -49,11 +49,15 @@ void FrontierTask::run() {

void VertexComputeTask::run() {
FrontierMorsel frontierMorsel;
auto graph = sharedState->graph;
std::vector<std::string> propertiesToScan;
auto scanState =
graph->prepareVertexScan(sharedState->morselDispatcher.getTableID(), propertiesToScan);
auto localVc = info.vc.copy();
while (sharedState->morselDispatcher.getNextRangeMorsel(frontierMorsel)) {
while (frontierMorsel.hasNextOffset()) {
common::nodeID_t nodeID = frontierMorsel.getNextNodeID();
localVc->vertexCompute(nodeID);
for (auto chunk : graph->scanVertices(frontierMorsel.getBeginOffset(),
frontierMorsel.getEndOffsetExclusive(), *scanState)) {
localVc->vertexCompute(chunk);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/function/gds/gds_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ void GDSUtils::runVertexComputeIteration(processor::ExecutionContext* executionC
auto maxThreads =
clientContext->getCurrentSetting(main::ThreadsSetting::name).getValue<uint64_t>();
auto info = VertexComputeTaskInfo(vc);
auto sharedState = std::make_shared<VertexComputeTaskSharedState>(maxThreads);
auto sharedState = std::make_shared<VertexComputeTaskSharedState>(maxThreads, graph);
for (auto& tableID : graph->getNodeTableIDs()) {
if (!vc.beginOnTable(tableID)) {
continue;
Expand Down
14 changes: 10 additions & 4 deletions src/function/gds/rec_joins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "function/gds/gds.h"
#include "function/gds/gds_frontier.h"
#include "function/gds/gds_utils.h"
#include "graph/graph.h"
#include "processor/execution_context.h"
#include "processor/result/factorized_table.h"
#include "storage/buffer_manager/memory_manager.h"
Expand Down Expand Up @@ -149,11 +150,16 @@ class RJVertexCompute : public VertexCompute {
return true;
}

void vertexCompute(nodeID_t nodeID) override {
if (sharedState->exceedLimit() || writer->skip(nodeID)) {
return;
void vertexCompute(const graph::VertexScanState::Chunk& chunk) override {
for (auto nodeID : chunk.getNodeIDs()) {
if (sharedState->exceedLimit()) {
return;
}
if (writer->skip(nodeID)) {
continue;
}
writer->write(*localFT, nodeID, sharedState->counter.get());
}
writer->write(*localFT, nodeID, sharedState->counter.get());
}

std::unique_ptr<VertexCompute> copy() override {
Expand Down
4 changes: 2 additions & 2 deletions src/function/gds/single_shortest_paths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class SingleSPDestinationsEdgeCompute : public SPEdgeCompute {
explicit SingleSPDestinationsEdgeCompute(SinglePathLengthsFrontierPair* frontierPair)
: SPEdgeCompute{frontierPair} {};

std::vector<nodeID_t> edgeCompute(common::nodeID_t, GraphScanState::Chunk& resultChunk,
std::vector<nodeID_t> edgeCompute(common::nodeID_t, NbrScanState::Chunk& resultChunk,
bool) override {
std::vector<nodeID_t> activeNodes;
resultChunk.forEach([&](auto nbrNode, auto) {
Expand All @@ -59,7 +59,7 @@ class SingleSPPathsEdgeCompute : public SPEdgeCompute {
parentListBlock = bfsGraph->addNewBlock();
}

std::vector<nodeID_t> edgeCompute(nodeID_t boundNodeID, GraphScanState::Chunk& resultChunk,
std::vector<nodeID_t> edgeCompute(nodeID_t boundNodeID, NbrScanState::Chunk& resultChunk,
bool isFwd) override {
std::vector<nodeID_t> activeNodes;
resultChunk.forEach([&](auto nbrNodeID, auto edgeID) {
Expand Down
2 changes: 1 addition & 1 deletion src/function/gds/variable_length_path.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct VarLenJoinsEdgeCompute : public EdgeCompute {
parentPtrsBlock = bfsGraph->addNewBlock();
};

std::vector<nodeID_t> edgeCompute(nodeID_t boundNodeID, graph::GraphScanState::Chunk& chunk,
std::vector<nodeID_t> edgeCompute(nodeID_t boundNodeID, graph::NbrScanState::Chunk& chunk,
bool isFwd) override {
std::vector<nodeID_t> activeNodes;
chunk.forEach([&](auto nbrNode, auto edgeID) {
Expand Down
3 changes: 1 addition & 2 deletions src/function/gds/weakly_connected_components.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ class WeaklyConnectedComponent final : public GDSAlgorithm {
}

private:
void findConnectedComponent(common::nodeID_t nodeID, int64_t groupID,
GraphScanState& scanState) {
void findConnectedComponent(common::nodeID_t nodeID, int64_t groupID, NbrScanState& scanState) {
KU_ASSERT(!visitedMap.contains(nodeID));
visitedMap.insert({nodeID, groupID});
// Collect the nodes so that the recursive scan doesn't begin until this scan is done
Expand Down
111 changes: 89 additions & 22 deletions src/graph/on_disk_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "binder/expression/property_expression.h"
#include "common/assert.h"
#include "common/cast.h"
#include "common/constants.h"
#include "common/data_chunk/data_chunk_state.h"
#include "common/enums/rel_direction.h"
#include "common/types/types.h"
Expand All @@ -13,9 +15,13 @@
#include "main/client_context.h"
#include "planner/operator/schema.h"
#include "processor/expression_mapper.h"
#include "storage/buffer_manager/memory_manager.h"
#include "storage/local_storage/local_rel_table.h"
#include "storage/local_storage/local_storage.h"
#include "storage/storage_manager.h"
#include "storage/storage_utils.h"
#include "storage/store/column.h"
#include "storage/store/node_table.h"
#include "storage/store/rel_table.h"
#include "storage/store/table.h"

Expand Down Expand Up @@ -72,8 +78,9 @@ static std::unique_ptr<RelTableScanState> getRelScanState(
return scanState;
}

OnDiskGraphScanStates::OnDiskGraphScanStates(ClientContext* context, std::span<RelTable*> tables,
const GraphEntry& graphEntry, std::optional<idx_t> edgePropertyIndex)
OnDiskGraphNbrScanStates::OnDiskGraphNbrScanStates(ClientContext* context,
std::span<RelTable*> tables, const GraphEntry& graphEntry,
std::optional<idx_t> edgePropertyIndex)
: iteratorIndex{0}, direction{RelDataDirection::INVALID} {
auto schema = graphEntry.getRelPropertiesSchema();
auto descriptor = ResultSetDescriptor(&schema);
Expand Down Expand Up @@ -121,7 +128,7 @@ OnDiskGraphScanStates::OnDiskGraphScanStates(ClientContext* context, std::span<R
relIDVector.get(), graphEntry.getRelProperties(), edgePropertyID, propertyVector.get(),
schema, resultSet);
scanStates.emplace_back(table->getTableID(),
OnDiskGraphScanState{context, *table, std::move(fwdState), std::move(bwdState)});
OnDiskGraphNbrScanState{context, *table, std::move(fwdState), std::move(bwdState)});
}
}

Expand Down Expand Up @@ -201,14 +208,14 @@ std::vector<RelTableIDInfo> OnDiskGraph::getRelTableIDInfos() {
return result;
}

std::unique_ptr<GraphScanState> OnDiskGraph::prepareScan(table_id_t relTableID,
std::unique_ptr<NbrScanState> OnDiskGraph::prepareScan(table_id_t relTableID,
std::optional<idx_t> edgePropertyIndex) {
auto relTable = context->getStorageManager()->getTable(relTableID)->ptrCast<RelTable>();
return std::unique_ptr<OnDiskGraphScanStates>(
new OnDiskGraphScanStates(context, std::span(&relTable, 1), graphEntry, edgePropertyIndex));
return std::unique_ptr<OnDiskGraphNbrScanStates>(new OnDiskGraphNbrScanStates(context,
std::span(&relTable, 1), graphEntry, edgePropertyIndex));
}

std::unique_ptr<GraphScanState> OnDiskGraph::prepareMultiTableScanFwd(
std::unique_ptr<NbrScanState> OnDiskGraph::prepareMultiTableScanFwd(
std::span<table_id_t> nodeTableIDs) {
std::unordered_set<table_id_t> relTableIDSet;
std::vector<RelTable*> tables;
Expand All @@ -220,11 +227,11 @@ std::unique_ptr<GraphScanState> OnDiskGraph::prepareMultiTableScanFwd(
}
}
}
return std::unique_ptr<OnDiskGraphScanStates>(
new OnDiskGraphScanStates(context, std::span(tables), graphEntry));
return std::unique_ptr<OnDiskGraphNbrScanStates>(
new OnDiskGraphNbrScanStates(context, std::span(tables), graphEntry));
}

std::unique_ptr<GraphScanState> OnDiskGraph::prepareMultiTableScanBwd(
std::unique_ptr<NbrScanState> OnDiskGraph::prepareMultiTableScanBwd(
std::span<table_id_t> nodeTableIDs) {
std::unordered_set<table_id_t> relTableIDSet;
std::vector<RelTable*> tables;
Expand All @@ -236,12 +243,12 @@ std::unique_ptr<GraphScanState> OnDiskGraph::prepareMultiTableScanBwd(
}
}
}
return std::unique_ptr<OnDiskGraphScanStates>(
new OnDiskGraphScanStates(context, std::span(tables), graphEntry));
return std::unique_ptr<OnDiskGraphNbrScanStates>(
new OnDiskGraphNbrScanStates(context, std::span(tables), graphEntry));
}

Graph::Iterator OnDiskGraph::scanFwd(nodeID_t nodeID, GraphScanState& state) {
auto& onDiskScanState = ku_dynamic_cast<OnDiskGraphScanStates&>(state);
Graph::EdgeIterator OnDiskGraph::scanFwd(nodeID_t nodeID, NbrScanState& state) {
auto& onDiskScanState = ku_dynamic_cast<OnDiskGraphNbrScanStates&>(state);
onDiskScanState.srcNodeIDVector->setValue<nodeID_t>(0, nodeID);
onDiskScanState.dstNodeIDVector->state->getSelVectorUnsafe().setSelSize(0);
KU_ASSERT(nodeTableIDToFwdRelTables.contains(nodeID.tableID));
Expand All @@ -253,11 +260,11 @@ Graph::Iterator OnDiskGraph::scanFwd(nodeID_t nodeID, GraphScanState& state) {
}
}
onDiskScanState.startScan(common::RelDataDirection::FWD);
return Graph::Iterator(&onDiskScanState);
return Graph::EdgeIterator(&onDiskScanState);
}

Graph::Iterator OnDiskGraph::scanBwd(nodeID_t nodeID, GraphScanState& state) {
auto& onDiskScanState = ku_dynamic_cast<OnDiskGraphScanStates&>(state);
Graph::EdgeIterator OnDiskGraph::scanBwd(nodeID_t nodeID, NbrScanState& state) {
auto& onDiskScanState = ku_dynamic_cast<OnDiskGraphNbrScanStates&>(state);
onDiskScanState.srcNodeIDVector->setValue<nodeID_t>(0, nodeID);
onDiskScanState.dstNodeIDVector->state->getSelVectorUnsafe().setSelSize(0);
KU_ASSERT(nodeTableIDToBwdRelTables.contains(nodeID.tableID));
Expand All @@ -269,10 +276,10 @@ Graph::Iterator OnDiskGraph::scanBwd(nodeID_t nodeID, GraphScanState& state) {
}
}
onDiskScanState.startScan(common::RelDataDirection::BWD);
return Graph::Iterator(&onDiskScanState);
return Graph::EdgeIterator(&onDiskScanState);
}

bool OnDiskGraphScanState::InnerIterator::next(evaluator::ExpressionEvaluator* predicate) {
bool OnDiskGraphNbrScanState::InnerIterator::next(evaluator::ExpressionEvaluator* predicate) {
while (true) {
if (!relTable->scan(context->getTx(), *tableScanState)) {
return false;
Expand All @@ -287,15 +294,15 @@ bool OnDiskGraphScanState::InnerIterator::next(evaluator::ExpressionEvaluator* p
}
}

OnDiskGraphScanState::InnerIterator::InnerIterator(const main::ClientContext* context,
OnDiskGraphNbrScanState::InnerIterator::InnerIterator(const main::ClientContext* context,
storage::RelTable* relTable, std::unique_ptr<storage::RelTableScanState> tableScanState)
: context{context}, relTable{relTable}, tableScanState{std::move(tableScanState)} {}

void OnDiskGraphScanState::InnerIterator::initScan() {
void OnDiskGraphNbrScanState::InnerIterator::initScan() {
relTable->initScanState(context->getTx(), *tableScanState);
}

bool OnDiskGraphScanStates::next() {
bool OnDiskGraphNbrScanStates::next() {
while (iteratorIndex < scanStates.size()) {
if (getInnerIterator().next(relPredicateEvaluator.get())) {
return true;
Expand All @@ -305,5 +312,65 @@ bool OnDiskGraphScanStates::next() {
return false;
}

// Sets up everything needed to scan the node table `tableID` for the named properties:
// resolves each property name to its column, builds the output data chunk for the
// property values, and creates a node ID vector that is attached to that chunk's state.
//
// context:       client context used for the storage manager, catalog, transaction and
//                memory manager lookups below.
// tableID:       table to scan; the storage manager's table for this ID is cast to NodeTable.
// propertyNames: names of the properties to read; may be empty.
OnDiskGraphVertexScanState::OnDiskGraphVertexScanState(ClientContext& context,
    common::table_id_t tableID, const std::vector<std::string>& propertyNames)
    : context{context},
      nodeTable{ku_dynamic_cast<const NodeTable&>(*context.getStorageManager()->getTable(tableID))},
      numNodesScanned{0}, tableID{tableID}, currentOffset{0}, endOffsetExclusive{0} {
    auto catalogEntry = context.getCatalog()->getTableCatalogEntry(context.getTx(), tableID);

    // Parallel lists, one entry per requested property, in the order the names were given.
    std::vector<column_id_t> columnIDs;
    columnIDs.reserve(propertyNames.size());
    std::vector<const Column*> propertyColumns;
    std::vector<LogicalType> propertyTypes;
    for (const auto& propertyName : propertyNames) {
        auto columnID = catalogEntry->getColumnID(propertyName);
        const auto& column = nodeTable.getColumn(columnID);
        columnIDs.push_back(columnID);
        propertyColumns.push_back(&column);
        propertyTypes.push_back(column.getDataType().copy());
    }

    // The property chunk is created first so that the node ID vector can be given its state.
    propertyVectors = nodeTable.constructDataChunk(std::move(propertyTypes));
    nodeIDVector = std::make_unique<ValueVector>(LogicalType::INTERNAL_ID(),
        context.getMemoryManager(), propertyVectors.state);
    tableScanState = std::make_unique<NodeTableScanState>(tableID, std::move(columnIDs),
        std::move(propertyColumns), propertyVectors, nodeIDVector.get());
}

// Scans the next batch of vertices in [currentOffset, endOffsetExclusive).
//
// A batch never extends past the start offset of the node group following the scan state's
// current node group, and never holds more than DEFAULT_VECTOR_CAPACITY nodes. Returns false
// once the range is exhausted; otherwise returns whatever the table scan state's scanNext
// reports for the batch.
bool OnDiskGraphVertexScanState::next() {
    if (currentOffset >= endOffsetExclusive) {
        return false;
    }

    // currentOffset < endOffsetExclusive is guaranteed from here on. If the cursor now falls
    // in a different node group than the one the scan state holds, re-initialise the scan at
    // the cursor.
    const bool inDifferentNodeGroup =
        StorageUtils::getNodeGroupIdx(currentOffset) != tableScanState->nodeGroupIdx;
    if (inDifferentNodeGroup) {
        startScan(currentOffset, endOffsetExclusive);
    }

    const auto nextGroupStart =
        StorageUtils::getStartOffsetOfNodeGroup(tableScanState->nodeGroupIdx + 1);
    const auto batchEnd = std::min(endOffsetExclusive, nextGroupStart);
    numNodesScanned = std::min(batchEnd - currentOffset, DEFAULT_VECTOR_CAPACITY);

    const auto scanned = tableScanState->scanNext(context.getTx(), currentOffset, numNodesScanned);
    // The cursor advances by the batch size regardless of what scanNext returned.
    currentOffset += numNodesScanned;
    return scanned;
}

// Positions `state` on the offset range [beginOffset, endOffsetExclusive) and returns an
// iterator over it. `state` must be an OnDiskGraphVertexScanState (it is cast via
// ku_dynamic_cast); the returned iterator holds a pointer to `state`.
Graph::VertexIterator OnDiskGraph::scanVertices(common::offset_t beginOffset,
    common::offset_t endOffsetExclusive, VertexScanState& state) {
    ku_dynamic_cast<OnDiskGraphVertexScanState&>(state).startScan(beginOffset,
        endOffsetExclusive);
    return Graph::VertexIterator(&state);
}

// Creates the on-disk vertex scan state for `tableID` that reads the properties named in
// `propertiesToScan`, using this graph's client context.
std::unique_ptr<VertexScanState> OnDiskGraph::prepareVertexScan(common::table_id_t tableID,
    const std::vector<std::string>& propertiesToScan) {
    std::unique_ptr<VertexScanState> scanState =
        std::make_unique<OnDiskGraphVertexScanState>(*context, tableID, propertiesToScan);
    return scanState;
}

// Resets the scan to cover [beginOffset, endOffsetExclusive): records the new range, clears
// the count of nodes scanned in the last batch, and initialises the table scan state at
// beginOffset.
void OnDiskGraphVertexScanState::startScan(common::offset_t beginOffset,
    common::offset_t endOffsetExclusive) {
    currentOffset = beginOffset;
    // The parameter shadows the member, hence the explicit this->.
    this->endOffsetExclusive = endOffsetExclusive;
    numNodesScanned = 0;
    nodeTable.initScanState(context.getTx(), *tableScanState, tableID, beginOffset);
}

} // namespace graph
} // namespace kuzu
3 changes: 2 additions & 1 deletion src/include/common/vector/value_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class KUZU_API ValueVector {
friend class ArrowColumnVector;

public:
explicit ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager = nullptr);
explicit ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager = nullptr,
std::shared_ptr<DataChunkState> dataChunkState = nullptr);
explicit ValueVector(LogicalTypeID dataTypeID, storage::MemoryManager* memoryManager = nullptr)
: ValueVector(LogicalType(dataTypeID), memoryManager) {
KU_ASSERT(dataTypeID != LogicalTypeID::LIST);
Expand Down
Loading

0 comments on commit 2591604

Please sign in to comment.