Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remote function planning #14718

Merged
merged 4 commits into from
Jul 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -138,18 +138,24 @@ public final void abort(FunctionNamespaceTransactionHandle transactionHandle)
public final Collection<SqlInvokedFunction> getFunctions(Optional<? extends FunctionNamespaceTransactionHandle> transactionHandle, QualifiedFunctionName functionName)
{
checkCatalog(functionName);
checkArgument(transactionHandle.isPresent(), "missing transactionHandle");
return transactions.get(transactionHandle.get()).loadAndGetFunctionsTransactional(functionName);
if (transactionHandle.isPresent()) {
return transactions.get(transactionHandle.get()).loadAndGetFunctionsTransactional(functionName);
}
return fetchFunctionsDirect(functionName);
}

@Override
public final FunctionHandle getFunctionHandle(Optional<? extends FunctionNamespaceTransactionHandle> transactionHandle, Signature signature)
{
checkCatalog(signature.getName());
checkArgument(transactionHandle.isPresent(), "missing transactionHandle");
// This is the only assumption in this class that we're dealing with sql-invoked regular function.
SqlFunctionId functionId = new SqlFunctionId(signature.getName(), signature.getArgumentTypes());
return transactions.get(transactionHandle.get()).getFunctionHandle(functionId);
if (transactionHandle.isPresent()) {
return transactions.get(transactionHandle.get()).getFunctionHandle(functionId);
}
FunctionCollection collection = new FunctionCollection();
collection.loadAndGetFunctionsTransactional(signature.getName());
return collection.getFunctionHandle(functionId);
caithagoras marked this conversation as resolved.
Show resolved Hide resolved
}

@Override
Expand Down
7 changes: 7 additions & 0 deletions presto-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,13 @@
<artifactId>jaxrs-testing</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-function-namespace-managers</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-plugin-toolkit</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,15 @@ public void loadFunctionNamespaceManager(
}
}

@VisibleForTesting
public void addFunctionNamespace(String catalogName, FunctionNamespaceManager functionNamespaceManager)
{
transactionManager.registerFunctionNamespaceManager(catalogName, functionNamespaceManager);
if (functionNamespaceManagers.putIfAbsent(catalogName, functionNamespaceManager) != null) {
throw new IllegalArgumentException(format("Function namespace manager is already registered for catalog [%s]", catalogName));
}
}

public FunctionInvokerProvider getFunctionInvokerProvider()
{
return functionInvokerProvider;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@

import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractAggregateFunctions;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractExpressions;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractExternalFunctions;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractWindowFunctions;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED;
import static java.util.Objects.requireNonNull;

public class Analyzer
Expand Down Expand Up @@ -107,4 +109,12 @@ static void verifyNoAggregateWindowOrGroupingFunctions(Map<NodeRef<FunctionCall>
throw new SemanticException(CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, predicate, "%s cannot contain aggregations, window functions or grouping operations: %s", clause, found);
}
}

static void verifyNoExternalFunctions(Map<NodeRef<FunctionCall>, FunctionHandle> functionHandles, FunctionManager functionManager, Expression predicate, String clause)
{
List<FunctionCall> externalFunctions = extractExternalFunctions(functionHandles, ImmutableList.of(predicate), functionManager);
if (!externalFunctions.isEmpty()) {
throw new SemanticException(NOT_SUPPORTED, predicate, "External functions in %s is not supported: %s", clause, externalFunctions);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, for better readability of the constructed error message:

External functions in [%s] is not supported: %s

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no unit test coverage for this commit. Can we add some.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, for better readability of the constructed error message:

External functions in [%s] is not supported: %s

I don't think the [] is useful. It would make the error message like this:

External functions in [Lambda expression] is not supported:....

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
import static com.facebook.presto.metadata.FunctionManager.qualifyFunctionName;
import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy;
import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoAggregateWindowOrGroupingFunctions;
import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoExternalFunctions;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.EXPRESSION_NOT_CONSTANT;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_LITERAL;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE;
Expand Down Expand Up @@ -896,6 +897,7 @@ protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext<C
Type type = innerExpressionAnalyzer.analyze(expression, baseScope, context.getContext().expectingLambda(types));
if (expression instanceof LambdaExpression) {
verifyNoAggregateWindowOrGroupingFunctions(innerExpressionAnalyzer.getResolvedFunctions(), functionManager, ((LambdaExpression) expression).getBody(), "Lambda expression");
verifyNoExternalFunctions(innerExpressionAnalyzer.getResolvedFunctions(), functionManager, ((LambdaExpression) expression).getBody(), "Lambda expression");
}
return type.getTypeSignature();
}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.InListExpression;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NodeRef;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -48,6 +46,11 @@ static List<FunctionCall> extractWindowFunctions(Iterable<? extends Node> nodes)
return extractExpressions(nodes, FunctionCall.class, ExpressionTreeUtils::isWindowFunction);
}

static List<FunctionCall> extractExternalFunctions(Map<NodeRef<FunctionCall>, FunctionHandle> functionHandles, Iterable<? extends Node> nodes, FunctionManager functionManager)
{
return extractExpressions(nodes, FunctionCall.class, isExternalFunctionPredicate(functionHandles, functionManager));
}

public static <T extends Expression> List<T> extractExpressions(
Iterable<? extends Node> nodes,
Class<T> clazz)
Expand All @@ -57,16 +60,21 @@ public static <T extends Expression> List<T> extractExpressions(

private static Predicate<FunctionCall> isAggregationPredicate(Map<NodeRef<FunctionCall>, FunctionHandle> functionHandles, FunctionManager functionManager)
{
return ((functionCall) -> (functionManager.getFunctionMetadata(functionHandles.get(NodeRef.of(functionCall))).getFunctionKind() == AGGREGATE || functionCall.getFilter().isPresent())
return functionCall -> (functionManager.getFunctionMetadata(functionHandles.get(NodeRef.of(functionCall))).getFunctionKind() == AGGREGATE || functionCall.getFilter().isPresent())
&& !functionCall.getWindow().isPresent()
|| functionCall.getOrderBy().isPresent());
|| functionCall.getOrderBy().isPresent();
}

private static boolean isWindowFunction(FunctionCall functionCall)
{
return functionCall.getWindow().isPresent();
}

private static Predicate<FunctionCall> isExternalFunctionPredicate(Map<NodeRef<FunctionCall>, FunctionHandle> functionHandles, FunctionManager functionManager)
caithagoras marked this conversation as resolved.
Show resolved Hide resolved
{
return functionCall -> functionManager.getFunctionMetadata(functionHandles.get(NodeRef.of(functionCall))).getImplementationType().isExternal();
}

private static <T extends Expression> List<T> extractExpressions(
Iterable<? extends Node> nodes,
Class<T> clazz,
Expand Down Expand Up @@ -104,9 +112,4 @@ public static boolean isEqualComparisonExpression(Expression expression)
{
return expression instanceof ComparisonExpression && ((ComparisonExpression) expression).getOperator() == ComparisonExpression.Operator.EQUAL;
}

public static boolean isInValuesComparisonExpression(Expression expression)
{
return expression instanceof InPredicate && ((InPredicate) expression).getValueList() instanceof InListExpression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@
import static com.facebook.presto.sql.ParsingUtil.createParsingOptions;
import static com.facebook.presto.sql.analyzer.AggregationAnalyzer.verifyOrderByAggregations;
import static com.facebook.presto.sql.analyzer.AggregationAnalyzer.verifySourceAggregations;
import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoAggregateWindowOrGroupingFunctions;
import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoExternalFunctions;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.createConstantAnalyzer;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractAggregateFunctions;
Expand Down Expand Up @@ -596,7 +598,8 @@ protected Scope visitCreateFunction(CreateFunction node, Optional<Scope> scope)
throw new SemanticException(TYPE_MISMATCH, node, "Function implementation type '%s' does not match declared return type '%s'", bodyType, returnType);
}

Analyzer.verifyNoAggregateWindowOrGroupingFunctions(analysis.getFunctionHandles(), metadata.getFunctionManager(), returnExpression, "CREATE FUNCTION body");
verifyNoAggregateWindowOrGroupingFunctions(analysis.getFunctionHandles(), metadata.getFunctionManager(), returnExpression, "CREATE FUNCTION body");
verifyNoExternalFunctions(analysis.getFunctionHandles(), metadata.getFunctionManager(), returnExpression, "CREATE FUNCTION body");

// TODO: Check body contains no SQL invoked functions
}
Expand Down Expand Up @@ -1305,7 +1308,7 @@ else if (criteria instanceof JoinOn) {
analysis.addCoercion(expression, BOOLEAN, false);
}

Analyzer.verifyNoAggregateWindowOrGroupingFunctions(analysis.getFunctionHandles(), metadata.getFunctionManager(), expression, "JOIN clause");
verifyNoAggregateWindowOrGroupingFunctions(analysis.getFunctionHandles(), metadata.getFunctionManager(), expression, "JOIN clause");

analysis.recordSubqueries(node, expressionAnalysis);
analysis.setJoinCriteria(node, expression);
Expand Down Expand Up @@ -1702,7 +1705,7 @@ private List<Expression> analyzeGroupBy(QuerySpecification node, Scope scope, Li
sets.add(ImmutableList.of(ImmutableSet.of(field)));
}
else {
Analyzer.verifyNoAggregateWindowOrGroupingFunctions(analysis.getFunctionHandles(), metadata.getFunctionManager(), column, "GROUP BY clause");
verifyNoAggregateWindowOrGroupingFunctions(analysis.getFunctionHandles(), metadata.getFunctionManager(), column, "GROUP BY clause");
analysis.recordSubqueries(node, analyzeExpression(column, scope));
complexExpressions.add(column);
}
Expand Down Expand Up @@ -1935,7 +1938,7 @@ public void analyzeWhere(Node node, Scope scope, Expression predicate)
{
ExpressionAnalysis expressionAnalysis = analyzeExpression(predicate, scope);

Analyzer.verifyNoAggregateWindowOrGroupingFunctions(analysis.getFunctionHandles(), metadata.getFunctionManager(), predicate, "WHERE clause");
verifyNoAggregateWindowOrGroupingFunctions(analysis.getFunctionHandles(), metadata.getFunctionManager(), predicate, "WHERE clause");

analysis.recordSubqueries(node, expressionAnalysis);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import com.facebook.presto.sql.planner.iterative.rule.MergeLimits;
import com.facebook.presto.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct;
import com.facebook.presto.sql.planner.iterative.rule.PickTableLayout;
import com.facebook.presto.sql.planner.iterative.rule.PlanRemotePojections;
import com.facebook.presto.sql.planner.iterative.rule.PruneAggregationColumns;
import com.facebook.presto.sql.planner.iterative.rule.PruneAggregationSourceColumns;
import com.facebook.presto.sql.planner.iterative.rule.PruneCountAggregationOverScalar;
Expand Down Expand Up @@ -88,6 +89,7 @@
import com.facebook.presto.sql.planner.iterative.rule.RemoveUnreferencedScalarApplyNodes;
import com.facebook.presto.sql.planner.iterative.rule.RemoveUnreferencedScalarLateralNodes;
import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins;
import com.facebook.presto.sql.planner.iterative.rule.RewriteFilterWithExternalFunctionToProject;
import com.facebook.presto.sql.planner.iterative.rule.RewriteSpatialPartitioningAggregation;
import com.facebook.presto.sql.planner.iterative.rule.RuntimeReorderJoinSides;
import com.facebook.presto.sql.planner.iterative.rule.SimplifyCountOverConstant;
Expand Down Expand Up @@ -432,6 +434,16 @@ public PlanOptimizers(
new TranslateExpressions(metadata, sqlParser).rules()));
// After this point, all planNodes should not contain OriginalExpression

// PlanRemoteProjections only handles RowExpression so this need to run after TranslateExpressions
// Rules applied after this need to handle locality of ProjectNode properly.
builder.add(new IterativeOptimizer(
ruleStats,
statsCalculator,
costCalculator,
ImmutableSet.of(
new RewriteFilterWithExternalFunctionToProject(metadata.getFunctionManager()),
new PlanRemotePojections(metadata.getFunctionManager()))));

// Pass a supplier so that we pickup connector optimizers that are installed later
builder.add(
new ApplyConnectorOptimization(() -> planOptimizerManager.getOptimizers(LOGICAL)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
import static com.facebook.presto.spi.plan.AggregationNode.groupingSets;
import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.spi.plan.LimitNode.Step.FINAL;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL;
import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy;
import static com.facebook.presto.sql.planner.PlannerUtils.toOrderingScheme;
import static com.facebook.presto.sql.planner.PlannerUtils.toSortOrder;
Expand Down Expand Up @@ -406,7 +407,8 @@ private PlanBuilder explicitCoercionFields(PlanBuilder subPlan, Iterable<Express
return new PlanBuilder(translations, new ProjectNode(
idAllocator.getNextId(),
subPlan.getRoot(),
projections.build()),
projections.build(),
LOCAL),
analysis.getParameters());
}

Expand All @@ -422,7 +424,8 @@ private PlanBuilder explicitCoercionVariables(PlanBuilder subPlan, List<Variable
return new PlanBuilder(translations, new ProjectNode(
idAllocator.getNextId(),
subPlan.getRoot(),
assignments),
assignments,
LOCAL),
analysis.getParameters());
}

Expand Down Expand Up @@ -541,7 +544,7 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node)
aggregationArguments.stream().map(AssignmentUtils::identityAsSymbolReference).forEach(assignments::put);
groupingSetMappings.forEach((key, value) -> assignments.put(key, castToRowExpression(asSymbolReference(value))));

ProjectNode project = new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), assignments.build());
ProjectNode project = new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), assignments.build(), LOCAL);
subPlan = new PlanBuilder(groupingTranslations, project, analysis.getParameters());
}

Expand Down Expand Up @@ -701,7 +704,7 @@ private PlanBuilder handleGroupingOperations(PlanBuilder subPlan, QuerySpecifica
newTranslations.put(groupingOperation, variable);
}

return new PlanBuilder(newTranslations, new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), projections.build()), analysis.getParameters());
return new PlanBuilder(newTranslations, new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), projections.build(), LOCAL), analysis.getParameters());
}

private PlanBuilder window(PlanBuilder subPlan, OrderBy node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
import java.util.Set;

import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.isEqualComparisonExpression;
import static com.facebook.presto.sql.analyzer.SemanticExceptions.notSupportedException;
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences;
Expand Down Expand Up @@ -195,7 +196,7 @@ protected RelationPlan visitAliasedRelation(AliasedRelation node, Void context)
}
}

root = new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), assignments.build());
root = new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), assignments.build(), LOCAL);
mappings = newMappings.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context)
if (projectNode.getAssignments().equals(assignments)) {
return Result.empty();
}
return Result.ofPlanNode(new ProjectNode(projectNode.getId(), projectNode.getSource(), assignments));
return Result.ofPlanNode(new ProjectNode(projectNode.getId(), projectNode.getSource(), assignments, projectNode.getLocality()));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
import static com.facebook.presto.spi.StandardErrorCode.INVALID_SPATIAL_PARTITIONING;
import static com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy.UNGROUPED_SCHEDULING;
import static com.facebook.presto.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.planner.VariablesExtractor.extractUnique;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER;
Expand Down Expand Up @@ -642,7 +643,7 @@ private static PlanNode addProjection(Context context, PlanNode node, VariableRe
}

projections.put(variable, expression);
return new ProjectNode(context.getIdAllocator().getNextId(), node, projections.build());
return new ProjectNode(context.getIdAllocator().getNextId(), node, projections.build(), LOCAL);
}

private static PlanNode addPartitioningNodes(Context context, FunctionManager functionManager, PlanNode node, VariableReferenceExpression partitionVariable, KdbTree kdbTree, RowExpression geometry, Optional<RowExpression> radius)
Expand Down Expand Up @@ -672,7 +673,7 @@ private static PlanNode addPartitioningNodes(Context context, FunctionManager fu

return new UnnestNode(
context.getIdAllocator().getNextId(),
new ProjectNode(context.getIdAllocator().getNextId(), node, projections.build()),
new ProjectNode(context.getIdAllocator().getNextId(), node, projections.build(), LOCAL),
node.getOutputVariables(),
ImmutableMap.of(partitionsVariable, ImmutableList.of(partitionVariable)),
Optional.empty());
Expand Down
Loading