From 21af13229d4491945a8a7061dd980b3252e5915e Mon Sep 17 00:00:00 2001 From: Yi He Date: Fri, 10 May 2019 12:16:58 -0700 Subject: [PATCH] Add RowExpressionSymbolInliner --- .../sql/planner/ExpressionSymbolInliner.java | 1 + .../planner/RowExpressionSymbolInliner.java | 65 ++++++++++++++++ .../TestRowExpressionSymbolInliner.java | 74 +++++++++++++++++++ 3 files changed, 140 insertions(+) create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionSymbolInliner.java create mode 100644 presto-main/src/test/java/com/facebook/presto/sql/planner/TestRowExpressionSymbolInliner.java diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionSymbolInliner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionSymbolInliner.java index 0a5bc150397a2..6a4989cecd20c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionSymbolInliner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionSymbolInliner.java @@ -28,6 +28,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +@Deprecated public final class ExpressionSymbolInliner { public static Expression inlineSymbols(Map mapping, Expression expression) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionSymbolInliner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionSymbolInliner.java new file mode 100644 index 0000000000000..f025c40d6c8ae --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionSymbolInliner.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.spi.relation.LambdaDefinitionExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.relational.RowExpressionRewriter; +import com.facebook.presto.sql.relational.RowExpressionTreeRewriter; + +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; + +public final class RowExpressionSymbolInliner + extends RowExpressionRewriter +{ + private final Set excludedNames = new HashSet<>(); + private final Map mapping; + + private RowExpressionSymbolInliner(Map mapping) + { + this.mapping = mapping; + } + + public static RowExpression inlineSymbols(Map mapping, RowExpression expression) + { + return RowExpressionTreeRewriter.rewriteWith(new RowExpressionSymbolInliner(mapping), expression); + } + + @Override + public RowExpression rewriteVariableReference(VariableReferenceExpression node, Void context, RowExpressionTreeRewriter treeRewriter) + { + if (!excludedNames.contains(node.getName())) { + RowExpression result = new VariableReferenceExpression(mapping.get(new Symbol(node.getName())).getName(), node.getType()); + checkState(result != null, "Cannot resolve symbol %s", node.getName()); + return result; + } + return null; + } + + @Override + public RowExpression rewriteLambda(LambdaDefinitionExpression node, Void context, RowExpressionTreeRewriter treeRewriter) + { + checkArgument(!node.getArguments().stream().anyMatch(excludedNames::contains), "Lambda argument already contained in excluded names."); + excludedNames.addAll(node.getArguments()); + RowExpression result = treeRewriter.defaultRewrite(node, context); + excludedNames.removeAll(node.getArguments()); + return result; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestRowExpressionSymbolInliner.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestRowExpressionSymbolInliner.java new file mode 100644 index 0000000000000..2d113998bd781 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestRowExpressionSymbolInliner.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.LambdaDefinitionExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static org.testng.Assert.assertEquals; + +public class TestRowExpressionSymbolInliner +{ + private static final FunctionHandle TEST_FUNCTION = () -> null; + + @Test + public void testInlineVariable() + { + assertEquals(RowExpressionSymbolInliner.inlineSymbols( + ImmutableMap.of( + symbol("a"), + symbol("b")), + variable("a")), + variable("b")); + } + + @Test + public void testInlineLambda() + { + assertEquals(RowExpressionSymbolInliner.inlineSymbols( + ImmutableMap.of( + symbol("a"), + symbol("b"), + symbol("lambda_argument"), + symbol("c")), + new CallExpression("apply", TEST_FUNCTION, BIGINT, ImmutableList.of( + variable("a"), + new LambdaDefinitionExpression( + ImmutableList.of(BIGINT), + ImmutableList.of("lambda_argument"), + new CallExpression("add", TEST_FUNCTION, BIGINT, ImmutableList.of(variable("lambda_argument"), variable("a"))))))), + new CallExpression("apply", TEST_FUNCTION, BIGINT, ImmutableList.of( + variable("b"), + new LambdaDefinitionExpression( + ImmutableList.of(BIGINT), + ImmutableList.of("lambda_argument"), + new CallExpression("add", TEST_FUNCTION, BIGINT, ImmutableList.of(variable("lambda_argument"), variable("b"))))))); + } + + private Symbol symbol(String name) + { + return new Symbol(name); + } + + private VariableReferenceExpression variable(String name) + { + return new VariableReferenceExpression(name, BIGINT); + } +}