Skip to content

Commit

Permalink
[POAE7-2280] [M2] [Integration] VeloxToSubstrait- support variadic fu…
Browse files Browse the repository at this point in the history
…nction lookup (oap-project#45)

* logical scalar function support & code refactor

* method 'loadExtension' should be load once at runtime

* revert setup-ubuntu.sh

* remove unused methods

* code style fix

* code style fix

* license header fix

* merge logical op test into same method
  • Loading branch information
chaojun-zhang authored and ZJie1 committed Dec 14, 2022
1 parent dc5d08f commit ea8fc59
Show file tree
Hide file tree
Showing 23 changed files with 1,141 additions and 1,016 deletions.
34 changes: 34 additions & 0 deletions velox/substrait/ExprUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/substrait/ExprUtils.h"
#include "velox/substrait/SubstraitType.h"

namespace facebook::velox::substrait {

SubstraitSignaturePtr toSubstraitSignature(
const core::CallTypedExprPtr& callTypedExpr) {
std::vector<SubstraitTypePtr> types;
types.reserve(callTypedExpr->inputs().size());
for (const auto& input : callTypedExpr->inputs()) {
types.emplace_back(fromVelox(input->type()));
}

return SubstraitFunctionSignature::of(
callTypedExpr->name(), types, fromVelox(callTypedExpr->type()));
}

} // namespace facebook::velox::substrait
29 changes: 29 additions & 0 deletions velox/substrait/ExprUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "velox/expression/Expr.h"
#include "velox/substrait/SubstraitSignature.h"
#include "velox/substrait/TypeUtils.h"

namespace facebook::velox::substrait {

/// convert velox callTyped expression to substrait function signature.
SubstraitSignaturePtr toSubstraitSignature(
const core::CallTypedExprPtr& callTypedExpr);

} // namespace facebook::velox::substrait
60 changes: 40 additions & 20 deletions velox/substrait/SubstraitExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,42 @@ static bool decodeFunctionVariant(
while (std::getline(ss, lastReturnType, '\n')) {
}
function.returnType = SubstraitType::decode(lastReturnType);
auto& args = node["args"];
if (args && args.IsSequence()) {
for (auto& arg : args) {
if (arg["options"]) { // enum argument
auto enumArgument = std::make_shared<SubstraitEnumArgument>(
arg.as<SubstraitEnumArgument>());
function.arguments.emplace_back(enumArgument);
} else if (arg["value"]) { // value argument
auto valueArgument = std::make_shared<SubstraitValueArgument>(
arg.as<SubstraitValueArgument>());
function.arguments.emplace_back(valueArgument);
} else { // type argument
auto typeArgument = std::make_shared<SubstraitTypeArgument>(
arg.as<SubstraitTypeArgument>());
function.arguments.emplace_back(typeArgument);
}
}
auto& args = node["args"];
if (args && args.IsSequence()) {
for (auto& arg : args) {
if (arg["options"]) { // enum argument
auto enumArgument = std::make_shared<SubstraitEnumArgument>(
arg.as<SubstraitEnumArgument>());
function.arguments.emplace_back(enumArgument);
} else if (arg["value"]) { // value argument
auto valueArgument = std::make_shared<SubstraitValueArgument>(
arg.as<SubstraitValueArgument>());
function.arguments.emplace_back(valueArgument);
} else { // type argument
auto typeArgument = std::make_shared<SubstraitTypeArgument>(
arg.as<SubstraitTypeArgument>());
function.arguments.emplace_back(typeArgument);
}
}
return true;
}
return false;

auto& variadic = node["variadic"];
if (variadic) {
auto& min = variadic["min"];
auto& max = variadic["max"];
if (min) {
function.variadic = std::make_optional<SubstraitFunctionVariadic>(
{min.as<int>(),
max ? std::make_optional<int>(max.as<int>()) : std::nullopt});
} else {
function.variadic = std::nullopt;
}
} else {
function.variadic = std::nullopt;
}

return true;
}

template <>
Expand Down Expand Up @@ -110,7 +125,7 @@ struct convert<SubstraitAggregateFunctionVariant> {
static bool decode(
const Node& node,
SubstraitAggregateFunctionVariant& function) {
auto res = decodeFunctionVariant(node, function);
const auto& res = decodeFunctionVariant(node, function);
if (res) {
const auto& intermediate = node["intermediate"];
if (intermediate) {
Expand Down Expand Up @@ -246,7 +261,12 @@ std::string getSubstraitExtensionAbsolutePath() {
} // namespace

std::shared_ptr<SubstraitExtension> SubstraitExtension::loadExtension() {
std::vector<std::string> extensionFiles = {
static const auto& extension = loadDefault();
return extension;
}

std::shared_ptr<SubstraitExtension> SubstraitExtension::loadDefault() {
static const std::vector<std::string> extensionFiles = {
"functions_aggregate_approx.yaml",
"functions_aggregate_generic.yaml",
"functions_arithmetic.yaml",
Expand Down
7 changes: 6 additions & 1 deletion velox/substrait/SubstraitExtension.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
namespace facebook::velox::substrait {

/// class used to deserialize substrait YAML extension files.
struct SubstraitExtension {
class SubstraitExtension {
public:
/// deserialize default substrait extension.
static std::shared_ptr<SubstraitExtension> loadExtension();

Expand Down Expand Up @@ -63,6 +64,10 @@ struct SubstraitExtension {

/// substrait user defined types loaded from Substrait extension yaml.
std::vector<SubstraitTypeAnchorPtr> types;

private:
/// deserialize default substrait extension.
static std::shared_ptr<SubstraitExtension> loadDefault();
};

using SubstraitExtensionPtr = std::shared_ptr<const SubstraitExtension>;
Expand Down
2 changes: 2 additions & 0 deletions velox/substrait/SubstraitFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
*/

#include "velox/substrait/SubstraitFunction.h"
#include <sstream>
#include "velox/substrait/SubstraitType.h"

namespace facebook::velox::substrait {

std::string SubstraitFunctionVariant::signature(
const std::string& name,
const std::vector<SubstraitFunctionArgumentPtr>& arguments) {
Expand Down
25 changes: 14 additions & 11 deletions velox/substrait/SubstraitFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,23 +93,22 @@ struct SubstraitFunctionAnchor {
}
};

struct SubstraitFunctionVariadic {
int min;
std::optional<int> max;
};

struct SubstraitFunctionVariant {
/// scalar function name.
std::string name;
/// scalar function uri.
std::string uri;
/// function arguments.
std::vector<SubstraitFunctionArgumentPtr> arguments;
/// return type of scalar function.
SubstraitTypePtr returnType;

SubstraitFunctionVariant() {}

SubstraitFunctionVariant(const SubstraitFunctionVariant& that) {
this->name = that.name;
this->returnType = that.returnType;
this->uri = that.uri;
this->arguments = that.arguments;
}
/// function variadic
std::optional<SubstraitFunctionVariadic> variadic;

/// create function signature by given function name and arguments.
static std::string signature(
Expand All @@ -118,14 +117,14 @@ struct SubstraitFunctionVariant {

/// create function signature by function name and arguments.
const std::string signature() const {
return SubstraitFunctionVariant::signature(name, arguments);
return signature(name, arguments);
}

const SubstraitFunctionAnchor anchor() const {
return {uri, signature()};
}

const bool hasWildcardArgument() const {
const bool isWildcard() const {
for (auto& arg : arguments) {
if (arg->isWildcardType()) {
return true;
Expand All @@ -134,6 +133,10 @@ struct SubstraitFunctionVariant {
return false;
}

const bool isVariadic() const {
return variadic.has_value();
}

/// create an new function variant with given arguments.
SubstraitFunctionVariant& operator=(const SubstraitFunctionVariant& that) {
this->name = that.name;
Expand Down
24 changes: 13 additions & 11 deletions velox/substrait/SubstraitFunctionCollector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ void SubstraitFunctionCollector::addFunctionToPlan(
std::unordered_map<std::string, SimpleExtensionURI*> uris;
for (auto& [referenceNum, function] : functions_->forwardMap_) {
SimpleExtensionURI* extensionUri;
if (uris.find(function.uri) == uris.end()) {
const auto uri = uris.find(function.uri);
if (uri == uris.end()) {
extensionUri = substraitPlan->add_extension_uris();
extensionUri->set_extension_uri_anchor(++uriPos);
extensionUri->set_uri(function.uri);
uris[function.uri] = extensionUri;
} else {
extensionUri = uris.at(function.uri);
extensionUri = uri->second;
}

auto extensionFunction =
Expand All @@ -51,9 +52,10 @@ void SubstraitFunctionCollector::addFunctionToPlan(

int SubstraitFunctionCollector::getFunctionReference(
const SubstraitFunctionVariantPtr& function) {
if (functions_->reverseMap_.find(function->anchor()) !=
functions_->reverseMap_.end()) {
return functions_->reverseMap_.at(function->anchor());
const auto& anchorReference =
functions_->reverseMap_.find(function->anchor());
if (anchorReference != functions_->reverseMap_.end()) {
return anchorReference->second;
}
++functionReference_;
functions_->put(functionReference_, function->anchor());
Expand Down Expand Up @@ -101,8 +103,9 @@ void SubstraitFunctionCollector::addTypeToPlan(

int SubstraitFunctionCollector::getTypeReference(
const SubstraitTypeAnchorPtr& typeAnchor) {
if (types_->reverseMap_.find(*typeAnchor) != types_->reverseMap_.end()) {
return types_->reverseMap_.at(*typeAnchor);
const auto& anchorReference = types_->reverseMap_.find(*typeAnchor);
if (anchorReference != types_->reverseMap_.end()) {
return anchorReference->second;
}
++typeReference_;
types_->put(functionReference_, *typeAnchor);
Expand All @@ -116,8 +119,7 @@ SubstraitFunctionCollector::getScalarFunctionVariant(
const auto& functionAnchor = functions_->forwardMap_.find(referernce);
if (functionAnchor != functions_->forwardMap_.end()) {
for (const auto& scalarFunctionVariant : extension.scalarFunctionVariants) {
if (scalarFunctionVariant->anchor() ==
functions_->forwardMap_.at(referernce)) {
if (scalarFunctionVariant->anchor() == functionAnchor->second) {
return scalarFunctionVariant;
}
}
Expand All @@ -134,12 +136,12 @@ SubstraitFunctionCollector::getAggregateFunctionVariant(
if (functionAnchor != functions_->forwardMap_.end()) {
for (const auto& aggregateFunctionVaraint :
extension.aggregateFunctionVariants) {
if (aggregateFunctionVaraint->anchor() ==
functions_->forwardMap_.at(referernce)) {
if (aggregateFunctionVaraint->anchor() == functionAnchor->second) {
return aggregateFunctionVaraint;
}
}
}

VELOX_NYI(
"Unknown aggregate function id. Make sure that the function id provided was shared in the extensions section of the plan.");
}
Expand Down
Loading

0 comments on commit ea8fc59

Please sign in to comment.