Skip to content

Commit

Permalink
Merge pull request facebookincubator#23 from karthikeyann/fea-cudfdri…
Browse files Browse the repository at this point in the history
…ver_inspect2

Add plan inspect to cudfDriverAdapter to cache planNodes
add unregisterCudf and call it from TearDown
replace public API changes with teardown unregister pattern
  • Loading branch information
karthikeyann authored Aug 30, 2024
2 parents ed2a092 + 1bf7390 commit 2397144
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 20 deletions.
4 changes: 0 additions & 4 deletions velox/exec/HashBuild.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,6 @@ class HashBuild final : public Operator {

void close() override;

/// Returns the `joinNode_` member as a shared pointer to a const
/// core::HashJoinNode.
std::shared_ptr<const core::HashJoinNode> getPlanNode() const {
return joinNode_;
}

private:
void setState(State state);
void checkStateTransition(State state);
Expand Down
4 changes: 0 additions & 4 deletions velox/exec/HashProbe.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ class HashProbe : public Operator {
return inputSpiller_ != nullptr;
}

/// Returns the `joinNode_` member as a shared pointer to a const
/// core::HashJoinNode.
std::shared_ptr<const core::HashJoinNode> getPlanNode() const {
return joinNode_;
}

private:
// Indicates if the join type includes misses from the left side in the
// output.
Expand Down
75 changes: 65 additions & 10 deletions velox/experimental/cudf/exec/ToCudf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace facebook::velox::cudf_velox {
bool CompileState::compile() {
std::cout << "Calling cudfDriverAdapter" << std::endl;
auto operators = driver_.operators();
auto& nodes = driverFactory_.planNodes;
auto& nodes = planNodes_;
std::cout << "Number of operators: " << operators.size() << std::endl;
for (auto& op : operators) {
std::cout << " Operator: ID " << op->operatorId() << ": " << op->toString()
Expand All @@ -49,6 +49,15 @@ bool CompileState::compile() {
bool replacements_made = false;
auto ctx = driver_.driverCtx();

// Looks up a plan node by id: scans `nodes` linearly for the first element
// whose id() equals `id` and returns that element by value. The VELOX_CHECK
// below fails when no element matches.
auto get_plan_node = [&](const core::PlanNodeId& id) {
auto it =
std::find_if(nodes.cbegin(), nodes.cend(), [&id](const auto& node) {
return node->id() == id;
});
// Guard the dereference: `it` is the end iterator when nothing matched.
VELOX_CHECK(it != nodes.end());
return *it;
};
// Replace HashBuild and HashProbe operators with CudfHashJoinBuild and
// CudfHashJoinProbe operators.
for (int32_t operatorIndex = 0; operatorIndex < operators.size();
Expand All @@ -60,7 +69,7 @@ bool CompileState::compile() {
if (auto joinBuildOp = dynamic_cast<exec::HashBuild*>(oper)) {
auto id = joinBuildOp->operatorId();
auto plan_node = std::dynamic_pointer_cast<const core::HashJoinNode>(
joinBuildOp->getPlanNode());
get_plan_node(joinBuildOp->planNodeId()));
VELOX_CHECK(plan_node != nullptr);
replace_op.push_back(
std::make_unique<CudfHashJoinBuild>(id, ctx, plan_node));
Expand All @@ -71,7 +80,7 @@ bool CompileState::compile() {
} else if (auto joinProbeOp = dynamic_cast<exec::HashProbe*>(oper)) {
auto id = joinProbeOp->operatorId();
auto plan_node = std::dynamic_pointer_cast<const core::HashJoinNode>(
joinProbeOp->getPlanNode());
get_plan_node(joinProbeOp->planNodeId()));
VELOX_CHECK(plan_node != nullptr);
replace_op.push_back(
std::make_unique<CudfHashJoinProbe>(id, ctx, plan_node));
Expand All @@ -84,12 +93,52 @@ bool CompileState::compile() {
return replacements_made;
}

/// Driver adapter entry point: builds a CompileState from the given factory
/// and driver, runs CompileState::compile() on it, and returns its result.
bool cudfDriverAdapter(
    const exec::DriverFactory& factory,
    exec::Driver& driver) {
  CompileState compileState(factory, driver);
  const bool replaced = compileState.compile();
  return replaced;
}
struct cudfDriverAdapter {
std::shared_ptr<std::vector<std::shared_ptr<core::PlanNode const>>> planNodes;
cudfDriverAdapter() {
std::cout << "cudfDriverAdapter constructor" << std::endl;
planNodes =
std::make_shared<std::vector<std::shared_ptr<core::PlanNode const>>>();
}
~cudfDriverAdapter() {
std::cout << "cudfDriverAdapter destructor" << std::endl;
printf(
"cached planNodes %p, %ld\n", planNodes.get(), planNodes.use_count());
}
// driveradapter
bool operator()(const exec::DriverFactory& factory, exec::Driver& driver) {
auto state = CompileState(factory, driver, *planNodes);
// Stored planNodes from inspect.
printf("driver.planNodes=%p\n", planNodes.get());
for (auto planNode : *planNodes) {
std::cout << "PlanNode: " << (*planNode).toString() << std::endl;
}
auto res = state.compile();
return res;
}
// Iterate recursively and store them in the planNodes_ptr.
void storePlanNodes(const std::shared_ptr<const core::PlanNode>& planNode) {
const auto& sources = planNode->sources();
for (int32_t i = 0; i < sources.size(); ++i) {
storePlanNodes(sources[i]);
}
planNodes->push_back(planNode);
}

// inspect
void operator()(const core::PlanFragment& planFragment) {
// signature: std::function<void(const core::PlanFragment&)> inspect;
// call: adapter.inspect(planFragment);
planNodes->clear();
std::cout << "Inspecting PlanFragment: " << std::endl;
if (planNodes) {
printf("inspect.planNodes=%p\n", planNodes.get());
storePlanNodes(planFragment.planNode);
} else {
std::cout << "planNodes_ptr is nullptr" << std::endl;
}
}
};

void registerCudf() {
CUDF_FUNC_RANGE();
Expand All @@ -98,7 +147,13 @@ void registerCudf() {
exec::Operator::registerOperator(
std::make_unique<CudfHashJoinBridgeTranslator>());
std::cout << "Registering cudfDriverAdapter" << std::endl;
exec::DriverAdapter cudfAdapter{"cuDF", {}, cudfDriverAdapter};
cudfDriverAdapter cda{};
exec::DriverAdapter cudfAdapter{"cuDF", cda, cda};
exec::DriverFactory::registerAdapter(cudfAdapter);
}

void unregisterCudf() {
std::cout << "unRegistering cudfDriverAdapter" << std::endl;
exec::DriverFactory::adapters.clear();
}
} // namespace facebook::velox::cudf_velox
9 changes: 7 additions & 2 deletions velox/experimental/cudf/exec/ToCudf.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ namespace facebook::velox::cudf_velox {

class CompileState {
public:
CompileState(const exec::DriverFactory& driverFactory, exec::Driver& driver)
: driverFactory_(driverFactory), driver_(driver) {}
CompileState(
const exec::DriverFactory& driverFactory,
exec::Driver& driver,
std::vector<std::shared_ptr<core::PlanNode const>>& planNodes)
: driverFactory_(driverFactory), driver_(driver), planNodes_(planNodes) {}

exec::Driver& driver() {
return driver_;
Expand All @@ -36,9 +39,11 @@ class CompileState {

const exec::DriverFactory& driverFactory_;
exec::Driver& driver_;
const std::vector<std::shared_ptr<core::PlanNode const>>& planNodes_;
};

/// Registers adapter to add cuDF operators to Drivers.
void registerCudf();
void unregisterCudf();

} // namespace facebook::velox::cudf_velox
5 changes: 5 additions & 0 deletions velox/experimental/cudf/tests/HashJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,11 @@ class HashJoinTest : public HiveConnectorTestBase {
.allowLazyVector = false};
}

// Test fixture teardown: calls cudf_velox::unregisterCudf() first, then
// delegates to HiveConnectorTestBase::TearDown().
// NOTE(review): presumably this keeps adapters registered by one test from
// carrying over into the next — confirm against where registerCudf() is called.
void TearDown() override {
cudf_velox::unregisterCudf();
HiveConnectorTestBase::TearDown();
}

// Make splits with each plan node having a number of source files.
SplitInput makeSpiltInput(
const std::vector<core::PlanNodeId>& nodeIds,
Expand Down

0 comments on commit 2397144

Please sign in to comment.