From e323438166466f646ba899f1f9b1afb90dca0b6b Mon Sep 17 00:00:00 2001 From: Tim Meehan Date: Tue, 30 Jul 2024 10:15:40 -0400 Subject: [PATCH 01/10] Add sidecars to NodeManager and add PluginNodeManager --- .../connector/ConnectorAwareNodeManager.java | 18 ++++ .../presto/nodeManager/PluginNodeManager.java | 97 +++++++++++++++++++ .../presto/testing/TestingNodeManager.java | 6 ++ .../nodeManager/TestPluginNodeManager.java | 94 ++++++++++++++++++ .../com/facebook/presto/spi/NodeManager.java | 2 + 5 files changed, 217 insertions(+) create mode 100644 presto-main/src/main/java/com/facebook/presto/nodeManager/PluginNodeManager.java create mode 100644 presto-main/src/test/java/com/facebook/presto/nodeManager/TestPluginNodeManager.java diff --git a/presto-main/src/main/java/com/facebook/presto/connector/ConnectorAwareNodeManager.java b/presto-main/src/main/java/com/facebook/presto/connector/ConnectorAwareNodeManager.java index 6764405dca877..3506d55393f21 100644 --- a/presto-main/src/main/java/com/facebook/presto/connector/ConnectorAwareNodeManager.java +++ b/presto-main/src/main/java/com/facebook/presto/connector/ConnectorAwareNodeManager.java @@ -13,14 +13,19 @@ */ package com.facebook.presto.connector; +import com.facebook.presto.metadata.InternalNode; import com.facebook.presto.metadata.InternalNodeManager; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.Node; import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableSet; import java.util.Set; +import static com.facebook.presto.spi.StandardErrorCode.NO_CPP_SIDECARS; +import static com.facebook.presto.spi.StandardErrorCode.TOO_MANY_SIDECARS; +import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.Objects.requireNonNull; public class ConnectorAwareNodeManager @@ -58,6 +63,19 @@ public Node getCurrentNode() return nodeManager.getCurrentNode(); } + @Override + public Node getSidecarNode() + { + Set coordinatorSidecars = nodeManager.getCoordinatorSidecars(); + if (coordinatorSidecars.isEmpty()) { + throw new PrestoException(NO_CPP_SIDECARS, "Expected exactly one coordinator sidecar, but found none"); + } + if (coordinatorSidecars.size() > 1) { + throw new PrestoException(TOO_MANY_SIDECARS, "Expected exactly one coordinator sidecar, but found " + coordinatorSidecars.size()); + } + return getOnlyElement(coordinatorSidecars); + } + @Override public String getEnvironment() { diff --git a/presto-main/src/main/java/com/facebook/presto/nodeManager/PluginNodeManager.java b/presto-main/src/main/java/com/facebook/presto/nodeManager/PluginNodeManager.java new file mode 100644 index 0000000000000..34c2a1a3ccaab --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/nodeManager/PluginNodeManager.java @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.nodeManager; + +import com.facebook.presto.metadata.InternalNode; +import com.facebook.presto.metadata.InternalNodeManager; +import com.facebook.presto.spi.Node; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.PrestoException; +import com.google.common.collect.ImmutableSet; +import com.google.inject.Inject; + +import java.util.Set; + +import static com.facebook.presto.spi.StandardErrorCode.NO_CPP_SIDECARS; +import static com.facebook.presto.spi.StandardErrorCode.TOO_MANY_SIDECARS; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; + +/** + * This class simplifies managing Presto's cluster nodes, + * focusing on active workers and coordinators without tying to specific connectors. + */ +public class PluginNodeManager + implements NodeManager +{ + private final InternalNodeManager nodeManager; + private final String environment; + + @Inject + public PluginNodeManager(InternalNodeManager nodeManager) + { + this.nodeManager = nodeManager; + this.environment = "test"; + } + + public PluginNodeManager(InternalNodeManager nodeManager, String environment) + { + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.environment = requireNonNull(environment, "environment is null"); + } + + @Override + public Set getAllNodes() + { + return ImmutableSet.builder() + .addAll(getWorkerNodes()) + .addAll(nodeManager.getCoordinators()) + .build(); + } + + @Override + public Set getWorkerNodes() + { + //Retrieves all active worker nodes, excluding coordinators, resource managers, and catalog servers. + return nodeManager.getAllNodes().getActiveNodes().stream() + .filter(node -> !node.isResourceManager() && !node.isCoordinator() && !node.isCatalogServer()) + .collect(toImmutableSet()); + } + + @Override + public Node getCurrentNode() + { + return nodeManager.getCurrentNode(); + } + + @Override + public Node getSidecarNode() + { + Set coordinatorSidecars = nodeManager.getCoordinatorSidecars(); + if (coordinatorSidecars.isEmpty()) { + throw new PrestoException(NO_CPP_SIDECARS, "Expected exactly one coordinator sidecar, but found none"); + } + if (coordinatorSidecars.size() > 1) { + throw new PrestoException(TOO_MANY_SIDECARS, "Expected exactly one coordinator sidecar, but found " + coordinatorSidecars.size()); + } + return getOnlyElement(coordinatorSidecars); + } + + @Override + public String getEnvironment() + { + return environment; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/testing/TestingNodeManager.java b/presto-main/src/main/java/com/facebook/presto/testing/TestingNodeManager.java index d76239c726078..5b2eab906325b 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/TestingNodeManager.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/TestingNodeManager.java @@ -92,6 +92,12 @@ public Node getCurrentNode() return localNode; } + @Override + public Node getSidecarNode() + { + return localNode; + } + @Override public String getEnvironment() { diff --git a/presto-main/src/test/java/com/facebook/presto/nodeManager/TestPluginNodeManager.java b/presto-main/src/test/java/com/facebook/presto/nodeManager/TestPluginNodeManager.java new file mode 100644 index 0000000000000..aa5ef70e7f3f1 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/nodeManager/TestPluginNodeManager.java @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.nodeManager; + +import com.facebook.presto.client.NodeVersion; +import com.facebook.presto.metadata.InMemoryNodeManager; +import com.facebook.presto.metadata.InternalNode; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.Node; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.net.URI; +import java.util.Arrays; +import java.util.Set; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class TestPluginNodeManager +{ + private InMemoryNodeManager inMemoryNodeManager; + private PluginNodeManager pluginNodeManager; + + @BeforeClass + public void setUp() + { + // Initialize the InMemoryNodeManager and PluginNodeManager before each test. + inMemoryNodeManager = new InMemoryNodeManager(); + pluginNodeManager = new PluginNodeManager(inMemoryNodeManager, "test-env"); + } + + @Test + public void testGetAllNodes() + { + ConnectorId connectorId = new ConnectorId("test-connector"); + InternalNode activeNode1 = new InternalNode("activeNode1", URI.create("http://example1.com"), new NodeVersion("1"), false); + InternalNode activeNode2 = new InternalNode("activeNode2", URI.create("http://example2.com"), new NodeVersion("1"), false); + InternalNode coordinatorNode = new InternalNode("coordinatorNode", URI.create("http://example3.com"), new NodeVersion("1"), true); + + inMemoryNodeManager.addNode(connectorId, activeNode1); + inMemoryNodeManager.addNode(connectorId, activeNode2); + inMemoryNodeManager.addNode(connectorId, coordinatorNode); + + Set allNodes = pluginNodeManager.getAllNodes(); + // The expected count is 4, considering two active nodes, one coordinator, and one local node added by InMemoryNodeManager by default. + assertEquals(4, allNodes.size()); + assertTrue(allNodes.containsAll(Arrays.asList(activeNode1, activeNode2, coordinatorNode))); + } + + @Test + public void testGetWorkerNodes() + { + ConnectorId connectorId = new ConnectorId("test-connector"); + InternalNode activeNode1 = new InternalNode("activeNode1", URI.create("http://example1.com"), new NodeVersion("1"), false); + InternalNode activeNode2 = new InternalNode("activeNode2", URI.create("http://example2.com"), new NodeVersion("1"), false); + + inMemoryNodeManager.addNode(connectorId, activeNode1); + inMemoryNodeManager.addNode(connectorId, activeNode2); + + Set workerNodes = pluginNodeManager.getWorkerNodes(); + // Expected count is 3, accounting for two explicitly added active nodes and one local node. + assertEquals(3, workerNodes.size()); + assertTrue(workerNodes.containsAll(Arrays.asList(activeNode1, activeNode2))); + } + + @Test + public void testGetEnvironment() + { + // Validate that the PluginNodeManager correctly returns the environment string set during initialization. + assertEquals("test-env", pluginNodeManager.getEnvironment()); + } + + @Test + public void testGetCurrentNode() + { + Node currentNode = pluginNodeManager.getCurrentNode(); + assertNotNull(currentNode); + // Validate that the current node is not null and its identifier matches the expected local node identifier. + assertEquals("local", currentNode.getNodeIdentifier()); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/NodeManager.java b/presto-spi/src/main/java/com/facebook/presto/spi/NodeManager.java index 92b84d0187fb3..602205786503d 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/NodeManager.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/NodeManager.java @@ -25,6 +25,8 @@ public interface NodeManager Node getCurrentNode(); + Node getSidecarNode(); + String getEnvironment(); default Set getRequiredWorkerNodes() From 38f511e742413e4ccfd561a1da40758c259c5801 Mon Sep 17 00:00:00 2001 From: Tim Meehan Date: Fri, 6 Sep 2024 13:34:40 -0400 Subject: [PATCH 02/10] Allow workers to act as coordinator sidecars --- .../presto/connector/ConnectorAwareNodeManager.java | 7 +------ .../facebook/presto/execution/ClusterSizeMonitor.java | 10 ++-------- .../facebook/presto/nodeManager/PluginNodeManager.java | 7 +------ .../presto/execution/TestClusterSizeMonitor.java | 6 +++--- 4 files changed, 7 insertions(+), 23 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/connector/ConnectorAwareNodeManager.java b/presto-main/src/main/java/com/facebook/presto/connector/ConnectorAwareNodeManager.java index 3506d55393f21..07cc71dd00a0a 100644 --- a/presto-main/src/main/java/com/facebook/presto/connector/ConnectorAwareNodeManager.java +++ b/presto-main/src/main/java/com/facebook/presto/connector/ConnectorAwareNodeManager.java @@ -24,8 +24,6 @@ import java.util.Set; import static com.facebook.presto.spi.StandardErrorCode.NO_CPP_SIDECARS; -import static com.facebook.presto.spi.StandardErrorCode.TOO_MANY_SIDECARS; -import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.Objects.requireNonNull; public class ConnectorAwareNodeManager @@ -70,10 +68,7 @@ public Node getSidecarNode() if (coordinatorSidecars.isEmpty()) { throw new PrestoException(NO_CPP_SIDECARS, "Expected exactly one coordinator sidecar, but found none"); } - if (coordinatorSidecars.size() > 1) { - throw new PrestoException(TOO_MANY_SIDECARS, "Expected exactly one coordinator sidecar, but found " + coordinatorSidecars.size()); - } - return getOnlyElement(coordinatorSidecars); + return coordinatorSidecars.iterator().next(); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/execution/ClusterSizeMonitor.java b/presto-main/src/main/java/com/facebook/presto/execution/ClusterSizeMonitor.java index 396d6d575f2d3..6833912e31c18 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/ClusterSizeMonitor.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/ClusterSizeMonitor.java @@ -40,7 +40,6 @@ import static com.facebook.airlift.concurrent.Threads.threadsNamed; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES; import static com.facebook.presto.spi.StandardErrorCode.NO_CPP_SIDECARS; -import static com.facebook.presto.spi.StandardErrorCode.TOO_MANY_SIDECARS; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.util.concurrent.Futures.immediateFuture; import static java.lang.String.format; @@ -183,11 +182,7 @@ public boolean hasRequiredCoordinators() */ public boolean hasRequiredCoordinatorSidecars() { - if (currentCoordinatorSidecarCount > 1) { - throw new PrestoException(TOO_MANY_SIDECARS, - format("Expected a single active coordinator sidecar. Found %s active coordinator sidecars", currentCoordinatorSidecarCount)); - } - return currentCoordinatorSidecarCount == 1; + return currentCoordinatorSidecarCount > 0; } /** @@ -257,7 +252,7 @@ public synchronized ListenableFuture waitForMinimumCoordinators() public synchronized ListenableFuture waitForMinimumCoordinatorSidecars() { - if (currentCoordinatorSidecarCount == 1 || !isCoordinatorSidecarEnabled) { + if (currentCoordinatorSidecarCount > 0 || !isCoordinatorSidecarEnabled) { return immediateFuture(null); } @@ -309,7 +304,6 @@ private synchronized void updateAllNodes(AllNodes allNodes) Set activeNodes = new HashSet<>(allNodes.getActiveNodes()); activeNodes.removeAll(allNodes.getActiveCoordinators()); activeNodes.removeAll(allNodes.getActiveResourceManagers()); - activeNodes.removeAll(allNodes.getActiveCoordinatorSidecars()); currentWorkerCount = activeNodes.size(); } currentCoordinatorCount = allNodes.getActiveCoordinators().size(); diff --git a/presto-main/src/main/java/com/facebook/presto/nodeManager/PluginNodeManager.java b/presto-main/src/main/java/com/facebook/presto/nodeManager/PluginNodeManager.java index 34c2a1a3ccaab..a77debc16bb30 100644 --- a/presto-main/src/main/java/com/facebook/presto/nodeManager/PluginNodeManager.java +++ b/presto-main/src/main/java/com/facebook/presto/nodeManager/PluginNodeManager.java @@ -24,9 +24,7 @@ import java.util.Set; import static com.facebook.presto.spi.StandardErrorCode.NO_CPP_SIDECARS; -import static com.facebook.presto.spi.StandardErrorCode.TOO_MANY_SIDECARS; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.Objects.requireNonNull; /** @@ -83,10 +81,7 @@ public Node getSidecarNode() if (coordinatorSidecars.isEmpty()) { throw new PrestoException(NO_CPP_SIDECARS, "Expected exactly one coordinator sidecar, but found none"); } - if (coordinatorSidecars.size() > 1) { - throw new PrestoException(TOO_MANY_SIDECARS, "Expected exactly one coordinator sidecar, but found " + coordinatorSidecars.size()); - } - return getOnlyElement(coordinatorSidecars); + return coordinatorSidecars.iterator().next(); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TestClusterSizeMonitor.java b/presto-main/src/test/java/com/facebook/presto/execution/TestClusterSizeMonitor.java index de0ce3e05fd14..9ef10e188e0d7 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/TestClusterSizeMonitor.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/TestClusterSizeMonitor.java @@ -17,7 +17,6 @@ import com.facebook.presto.metadata.InMemoryNodeManager; import com.facebook.presto.metadata.InternalNode; import com.facebook.presto.spi.ConnectorId; -import com.facebook.presto.spi.PrestoException; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.Duration; import org.testng.annotations.AfterMethod; @@ -170,7 +169,7 @@ public void testHasRequiredCoordinatorSidecars() assertTrue(monitor.hasRequiredCoordinatorSidecars()); } - @Test(expectedExceptions = PrestoException.class, expectedExceptionsMessageRegExp = "Expected a single active coordinator sidecar. Found 2 active coordinator sidecars") + @Test public void testHasRequiredCoordinatorSidecarsMoreThanOne() throws InterruptedException { @@ -178,7 +177,7 @@ public void testHasRequiredCoordinatorSidecarsMoreThanOne() for (int i = numCoordinatorSidecars.get(); i < DESIRED_COORDINATOR_SIDECAR_COUNT + 1; i++) { addCoordinatorSidecar(nodeManager); } - assertFalse(monitor.hasRequiredCoordinatorSidecars()); + assertTrue(monitor.hasRequiredCoordinatorSidecars()); } @Test @@ -223,6 +222,7 @@ private ListenableFuture waitForMinimumCoordinatorSidecars() addSuccessCallback(coordinatorSidecarsFuture, () -> { assertFalse(coordinatorSidecarsTimeout.get()); minCoordinatorSidecarsLatch.countDown(); + minCoordinatorSidecarsLatch.countDown(); }); addExceptionCallback(coordinatorSidecarsFuture, () -> { assertTrue(coordinatorSidecarsTimeout.compareAndSet(false, true)); From 548eae185a1ff407ed89ac1d24add01f749ca700 Mon Sep 17 00:00:00 2001 From: Tim Meehan Date: Tue, 30 Jul 2024 10:11:46 -0400 Subject: [PATCH 03/10] Add SPI to customize expression optimization --- .../presto/spi/CoordinatorPlugin.java | 6 +++ .../planner/ExpressionOptimizerContext.java | 49 +++++++++++++++++++ .../planner/ExpressionOptimizerFactory.java | 25 ++++++++++ 3 files changed, 80 insertions(+) create mode 100644 presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java create mode 100644 presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerFactory.java diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/CoordinatorPlugin.java b/presto-spi/src/main/java/com/facebook/presto/spi/CoordinatorPlugin.java index f855e6a63e2d3..7619d61abf287 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/CoordinatorPlugin.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/CoordinatorPlugin.java @@ -14,6 +14,7 @@ package com.facebook.presto.spi; import com.facebook.presto.spi.function.FunctionNamespaceManagerFactory; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory; import static java.util.Collections.emptyList; @@ -30,4 +31,9 @@ default Iterable getFunctionNamespaceManagerFac { return emptyList(); } + + default Iterable getRowExpressionInterpreterServiceFactories() + { + return emptyList(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java b/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java new file mode 100644 index 0000000000000..c9a6f84aaaa1d --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.sql.planner; + +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; + +import static java.util.Objects.requireNonNull; + +public class ExpressionOptimizerContext +{ + private final NodeManager nodeManager; + private final FunctionMetadataManager functionMetadataManager; + private final StandardFunctionResolution functionResolution; + + public ExpressionOptimizerContext(NodeManager nodeManager, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution) + { + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); + this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); + } + + public NodeManager getNodeManager() + { + return nodeManager; + } + + public FunctionMetadataManager getFunctionMetadataManager() + { + return functionMetadataManager; + } + + public StandardFunctionResolution getFunctionResolution() + { + return functionResolution; + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerFactory.java b/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerFactory.java new file mode 100644 index 0000000000000..ad645e0c04a8d --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerFactory.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.sql.planner; + +import com.facebook.presto.spi.relation.ExpressionOptimizer; + +import java.util.Map; + +public interface ExpressionOptimizerFactory +{ + ExpressionOptimizer createOptimizer(Map config, ExpressionOptimizerContext context); + + String getName(); +} From 85e2aa5fba3cba6b6de020881a7eda95a51fd490 Mon Sep 17 00:00:00 2001 From: Tim Meehan Date: Tue, 30 Jul 2024 10:37:54 -0400 Subject: [PATCH 04/10] Add DelegatingRowExpressionOptimizer --- .../presto/SystemSessionProperties.java | 11 + .../facebook/presto/server/PluginManager.java | 12 +- .../facebook/presto/server/PrestoServer.java | 2 + .../presto/server/ServerMainModule.java | 4 + .../server/testing/TestingPrestoServer.java | 7 + .../presto/sql/analyzer/FeaturesConfig.java | 14 + .../ExpressionOptimizerManager.java | 104 +++++++ .../ExpressionOptimizerProvider.java | 21 ++ .../presto/sql/planner/PlanOptimizers.java | 16 +- .../rule/SimplifyRowExpressions.java | 35 ++- .../CteProjectionAndPredicatePushDown.java | 9 +- .../DelegatingRowExpressionOptimizer.java | 85 +++++ .../relational/RowExpressionOptimizer.java | 17 +- .../presto/testing/LocalQueryRunner.java | 16 +- .../facebook/presto/testing/QueryRunner.java | 3 + .../sql/analyzer/TestFeaturesConfig.java | 7 +- .../sql/planner/TestLogicalPlanner.java | 34 +- .../planner/assertions/OptimizerAssert.java | 10 +- .../iterative/rule/TestRemoveMapCastRule.java | 24 +- ...teConstantArrayContainsToInExpression.java | 49 +-- .../rule/TestSimplifyRowExpressions.java | 21 +- .../iterative/rule/test/BaseRuleTest.java | 6 + .../iterative/rule/test/RuleTester.java | 9 + ...TestCteProjectionAndPredicatePushdown.java | 2 +- .../TestDelegatingRowExpressionOptimizer.java | 112 +++++++ .../nativeworker/ContainerQueryRunner.java | 7 + .../presto/spark/PrestoSparkModule.java | 4 + .../presto/spark/PrestoSparkQueryRunner.java | 7 + .../spi/relation/ExpressionOptimizer.java | 4 + .../tests/AbstractTestQueryFramework.java | 15 +- .../presto/tests/DistributedQueryRunner.java | 8 + .../presto/tests/StandaloneQueryRunner.java | 7 + .../presto/memory/TestMemoryManager.java | 3 +- .../TestDelegatingExpressionOptimizer.java | 182 +++++++++++ .../TestExpressionInterpreter.java | 225 ++++++++++++++ .../expressions/TestExpressionOptimizers.java | 133 ++++++++ .../tests/expressions/TestExpressions.java | 292 ++++++------------ .../thrift/integration/ThriftQueryRunner.java | 7 + 38 files changed, 1257 insertions(+), 267 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerProvider.java create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/relational/DelegatingRowExpressionOptimizer.java create mode 100644 presto-main/src/test/java/com/facebook/presto/sql/relational/TestDelegatingRowExpressionOptimizer.java create mode 100644 presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestDelegatingExpressionOptimizer.java create mode 100644 presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionInterpreter.java create mode 100644 presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionOptimizers.java rename presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java => presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressions.java (88%) diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 1afbc272f89b3..41b42fe7444f4 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -364,6 +364,7 @@ public final class SystemSessionProperties public static final String OPTIMIZER_USE_HISTOGRAMS = "optimizer_use_histograms"; public static final String WARN_ON_COMMON_NAN_PATTERNS = "warn_on_common_nan_patterns"; public static final String INLINE_PROJECTIONS_ON_VALUES = "inline_projections_on_values"; + public static final String DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED = "delegating_row_expression_optimizer_enabled"; private final List> sessionProperties; @@ -2038,6 +2039,11 @@ public SystemSessionProperties( booleanProperty(INLINE_PROJECTIONS_ON_VALUES, "Whether to evaluate project node on values node", featuresConfig.getInlineProjectionsOnValues(), + false), + booleanProperty( + DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED, + "Enable delegating row optimizer", + featuresConfig.isDelegatingRowExpressionOptimizerEnabled(), false)); } @@ -3370,4 +3376,9 @@ public static boolean isInlineProjectionsOnValues(Session session) { return session.getSystemProperty(INLINE_PROJECTIONS_ON_VALUES, Boolean.class); } + + public static boolean isDelegatingRowExpressionOptimizerEnabled(Session session) + { + return session.getSystemProperty(DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED, Boolean.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java b/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java index 22ecb4b1c4731..63d1bec784c2e 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java +++ b/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java @@ -41,6 +41,7 @@ import com.facebook.presto.spi.security.PasswordAuthenticatorFactory; import com.facebook.presto.spi.security.SystemAccessControlFactory; import com.facebook.presto.spi.session.SessionPropertyConfigurationManagerFactory; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory; import com.facebook.presto.spi.statistics.HistoryBasedPlanStatisticsProvider; import com.facebook.presto.spi.storage.TempStorageFactory; import com.facebook.presto.spi.tracing.TracerProvider; @@ -48,6 +49,7 @@ import com.facebook.presto.spi.ttl.NodeTtlFetcherFactory; import com.facebook.presto.sql.analyzer.AnalyzerProviderManager; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.storage.TempStorageManager; import com.facebook.presto.tracing.TracerProviderManager; import com.facebook.presto.ttl.clusterttlprovidermanagers.ClusterTtlProviderManager; @@ -131,6 +133,7 @@ public class PluginManager private final AnalyzerProviderManager analyzerProviderManager; private final QueryPreparerProviderManager queryPreparerProviderManager; private final NodeStatusNotificationManager nodeStatusNotificationManager; + private final ExpressionOptimizerManager expressionOptimizerManager; @Inject public PluginManager( @@ -152,7 +155,8 @@ public PluginManager( ClusterTtlProviderManager clusterTtlProviderManager, HistoryBasedPlanStatisticsManager historyBasedPlanStatisticsManager, TracerProviderManager tracerProviderManager, - NodeStatusNotificationManager nodeStatusNotificationManager) + NodeStatusNotificationManager nodeStatusNotificationManager, + ExpressionOptimizerManager expressionOptimizerManager) { requireNonNull(nodeInfo, "nodeInfo is null"); requireNonNull(config, "config is null"); @@ -184,6 +188,7 @@ public PluginManager( this.analyzerProviderManager = requireNonNull(analyzerProviderManager, "analyzerProviderManager is null"); this.queryPreparerProviderManager = requireNonNull(queryPreparerProviderManager, "queryPreparerProviderManager is null"); this.nodeStatusNotificationManager = requireNonNull(nodeStatusNotificationManager, "nodeStatusNotificationManager is null"); + this.expressionOptimizerManager = requireNonNull(expressionOptimizerManager, "expressionManager is null"); } public void loadPlugins() @@ -356,6 +361,11 @@ public void installCoordinatorPlugin(CoordinatorPlugin plugin) log.info("Registering function namespace manager %s", functionNamespaceManagerFactory.getName()); metadata.getFunctionAndTypeManager().addFunctionNamespaceFactory(functionNamespaceManagerFactory); } + + for (ExpressionOptimizerFactory batchRowExpressionInterpreterProvider : plugin.getRowExpressionInterpreterServiceFactories()) { + log.info("Registering batch row expression interpreter provider %s", batchRowExpressionInterpreterProvider.getName()); + expressionOptimizerManager.addExpressionOptimizerFactory(batchRowExpressionInterpreterProvider); + } } private URLClassLoader buildClassLoader(String plugin) diff --git a/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java index 8b3aaa3009cca..6ab338b6c6d7c 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java @@ -46,6 +46,7 @@ import com.facebook.presto.server.security.PasswordAuthenticatorManager; import com.facebook.presto.server.security.ServerSecurityModule; import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParserOptions; import com.facebook.presto.storage.TempStorageManager; import com.facebook.presto.storage.TempStorageModule; @@ -177,6 +178,7 @@ public void run() injector.getInstance(TracerProviderManager.class).loadTracerProvider(); injector.getInstance(NodeStatusNotificationManager.class).loadNodeStatusNotificationProvider(); injector.getInstance(GracefulShutdownHandler.class).loadNodeStatusNotification(); + injector.getInstance(ExpressionOptimizerManager.class).loadExpressions(); startAssociatedProcesses(injector); injector.getInstance(Announcer.class).start(); diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java index 85658fe457f7b..dac2dd3440a92 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java @@ -186,6 +186,7 @@ import com.facebook.presto.sql.analyzer.MetadataExtractorMBean; import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -346,6 +347,9 @@ else if (serverConfig.isCoordinator()) { binder.bind(SystemSessionProperties.class).in(Scopes.SINGLETON); binder.bind(SessionPropertyDefaults.class).in(Scopes.SINGLETON); + // expression manager + binder.bind(ExpressionOptimizerManager.class).in(Scopes.SINGLETON); + // schema properties binder.bind(SchemaPropertyManager.class).in(Scopes.SINGLETON); diff --git a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java index 6157d89c19e0a..939abfc8b9cd4 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java @@ -69,6 +69,7 @@ import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.parser.SqlParserOptions; import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; @@ -379,6 +380,7 @@ public TestingPrestoServer( statsCalculator = injector.getInstance(StatsCalculator.class); eventListenerManager = ((TestingEventListenerManager) injector.getInstance(EventListenerManager.class)); clusterStateProvider = null; + injector.getInstance(ExpressionOptimizerManager.class).loadExpressions(); } else if (resourceManager) { dispatchManager = null; @@ -682,6 +684,11 @@ public ShutdownAction getShutdownAction() return shutdownAction; } + public ExpressionOptimizerManager getExpressionManager() + { + return injector.getInstance(ExpressionOptimizerManager.class); + } + public boolean isCoordinator() { return coordinator; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index c88271e97dd55..1a99642605331 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -291,6 +291,7 @@ public class FeaturesConfig private boolean generateDomainFilters; private boolean printEstimatedStatsFromCache; private boolean removeCrossJoinWithSingleConstantRow = true; + private boolean delegatingRowOptimizerEnabled; private CreateView.Security defaultViewSecurityMode = DEFINER; private boolean useHistograms; @@ -2969,4 +2970,17 @@ public FeaturesConfig setInlineProjectionsOnValues(boolean isInlineProjectionsOn this.isInlineProjectionsOnValuesEnabled = isInlineProjectionsOnValuesEnabled; return this; } + + public boolean isDelegatingRowExpressionOptimizerEnabled() + { + return delegatingRowOptimizerEnabled; + } + + @Config("optimizer.delegating-row-expression-optimizer-enabled") + @ConfigDescription("Enable delegating row optimizer") + public FeaturesConfig setDelegatingRowExpressionOptimizerEnabled(boolean delegatingRowOptimizerEnabled) + { + this.delegatingRowOptimizerEnabled = delegatingRowOptimizerEnabled; + return this; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java b/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java new file mode 100644 index 0000000000000..ca9f39404a48e --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.expressions; + +import com.facebook.airlift.node.NodeInfo; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.InternalNodeManager; +import com.facebook.presto.nodeManager.PluginNodeManager; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerContext; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.relational.RowExpressionOptimizer; + +import javax.inject.Inject; + +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; + +import static com.facebook.presto.util.PropertiesUtil.loadProperties; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.isNullOrEmpty; +import static java.util.Objects.requireNonNull; + +public class ExpressionOptimizerManager + implements ExpressionOptimizerProvider +{ + private static final File EXPRESSION_MANAGER_CONFIGURATION = new File("etc/expression-manager.properties"); + public static final String EXPRESSION_MANAGER_FACTORY_NAME = "expression-manager-factory.name"; + + private final Map expressionOptimizerFactories = new ConcurrentHashMap<>(); + private final AtomicReference rowExpressionInterpreter = new AtomicReference<>(); + private final NodeManager nodeManager; + private final FunctionAndTypeManager functionAndTypeManager; + private final FunctionResolution functionResolution; + private final ExpressionOptimizer defaultExpressionOptimizer; + + @Inject + public ExpressionOptimizerManager(InternalNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager, NodeInfo nodeInfo) + { + requireNonNull(nodeManager, "nodeManager is null"); + this.nodeManager = new PluginNodeManager(nodeManager, nodeInfo.getEnvironment()); + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()); + this.defaultExpressionOptimizer = new RowExpressionOptimizer(functionAndTypeManager); + rowExpressionInterpreter.set(defaultExpressionOptimizer); + } + + public void loadExpressions() + { + try { + if (EXPRESSION_MANAGER_CONFIGURATION.exists()) { + Map properties = new HashMap<>(loadProperties(EXPRESSION_MANAGER_CONFIGURATION)); + loadExpressions(properties); + } + } + catch (IOException e) { + throw new UncheckedIOException("Failed to load expression manager configuration", e); + } + } + + public void loadExpressions(Map properties) + { + properties = new HashMap<>(properties); + String factoryName = properties.remove(EXPRESSION_MANAGER_FACTORY_NAME); + checkArgument(!isNullOrEmpty(factoryName), "%s does not contain %s", EXPRESSION_MANAGER_CONFIGURATION, EXPRESSION_MANAGER_FACTORY_NAME); + checkArgument( + rowExpressionInterpreter.compareAndSet( + defaultExpressionOptimizer, + expressionOptimizerFactories.get(factoryName).createOptimizer(properties, new ExpressionOptimizerContext(nodeManager, functionAndTypeManager, functionResolution))), + "ExpressionManager is already loaded"); + } + + public void addExpressionOptimizerFactory(ExpressionOptimizerFactory expressionOptimizerFactory) + { + String name = expressionOptimizerFactory.getName(); + checkArgument( + this.expressionOptimizerFactories.putIfAbsent(name, expressionOptimizerFactory) == null, + "ExpressionOptimizerFactory %s is already registered", name); + } + + @Override + public ExpressionOptimizer getExpressionOptimizer() + { + return rowExpressionInterpreter.get(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerProvider.java b/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerProvider.java new file mode 100644 index 0000000000000..77179582ae4fe --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerProvider.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.expressions; + +import com.facebook.presto.spi.relation.ExpressionOptimizer; + +public interface ExpressionOptimizerProvider +{ + ExpressionOptimizer getExpressionOptimizer(); +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 32184097720da..435e966c625c1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -22,6 +22,7 @@ import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; import com.facebook.presto.sql.planner.iterative.Rule; @@ -218,7 +219,8 @@ public PlanOptimizers( CostComparator costComparator, TaskCountEstimator taskCountEstimator, PartitioningProviderManager partitioningProviderManager, - FeaturesConfig featuresConfig) + FeaturesConfig featuresConfig, + ExpressionOptimizerManager expressionOptimizerManager) { this(metadata, sqlParser, @@ -233,7 +235,8 @@ public PlanOptimizers( costComparator, taskCountEstimator, partitioningProviderManager, - featuresConfig); + featuresConfig, + expressionOptimizerManager); } @PostConstruct @@ -264,7 +267,8 @@ public PlanOptimizers( CostComparator costComparator, TaskCountEstimator taskCountEstimator, PartitioningProviderManager partitioningProviderManager, - FeaturesConfig featuresConfig) + FeaturesConfig featuresConfig, + ExpressionOptimizerManager expressionOptimizerManager) { this.exporter = exporter; ImmutableList.Builder builder = ImmutableList.builder(); @@ -319,7 +323,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.>builder() - .addAll(new SimplifyRowExpressions(metadata).rules()) + .addAll(new SimplifyRowExpressions(metadata, expressionOptimizerManager).rules()) .add(new PruneRedundantProjectionAssignments()) .build()); @@ -485,7 +489,7 @@ public PlanOptimizers( estimatedExchangesCostCalculator, ImmutableSet.>builder() .add(new InlineProjectionsOnValues(metadata.getFunctionAndTypeManager())) - .addAll(new SimplifyRowExpressions(metadata).rules()) + .addAll(new SimplifyRowExpressions(metadata, expressionOptimizerManager).rules()) .build()), new IterativeOptimizer( metadata, @@ -844,7 +848,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new PushTableWriteThroughUnion()))); // Must run before AddExchanges - builder.add(new CteProjectionAndPredicatePushDown(metadata)); // must run before PhysicalCteOptimizer + builder.add(new CteProjectionAndPredicatePushDown(metadata, expressionOptimizerManager)); // must run before PhysicalCteOptimizer builder.add(new PhysicalCteOptimizer(metadata)); // Must run before AddExchanges builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new AddExchanges(metadata, partitioningProviderManager, featuresConfig.isNativeExecutionEnabled()))); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java index 9705f7e836fd5..3633dfc267e20 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java @@ -13,22 +13,26 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.Session; import com.facebook.presto.common.type.BooleanType; import com.facebook.presto.expressions.LogicalRowExpressions; import com.facebook.presto.expressions.RowExpressionRewriter; import com.facebook.presto.expressions.RowExpressionTreeRewriter; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.relational.DelegatingRowExpressionOptimizer; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; import com.facebook.presto.sql.relational.RowExpressionOptimizer; import com.google.common.annotations.VisibleForTesting; +import static com.facebook.presto.SystemSessionProperties.isDelegatingRowExpressionOptimizerEnabled; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.AND; @@ -39,44 +43,51 @@ public class SimplifyRowExpressions extends RowExpressionRewriteRuleSet { - public SimplifyRowExpressions(Metadata metadata) + public SimplifyRowExpressions(Metadata metadata, ExpressionOptimizerManager expressionOptimizerManager) { - super(new Rewriter(metadata)); + super(new Rewriter(metadata, expressionOptimizerManager)); } private static class Rewriter implements PlanRowExpressionRewriter { - private final RowExpressionOptimizer optimizer; + private final ExpressionOptimizer inMemoryExpressionOptimizer; + private final ExpressionOptimizer delegatingExpressionOptimizer; private final LogicalExpressionRewriter logicalExpressionRewriter; - public Rewriter(Metadata metadata) + public Rewriter(Metadata metadata, ExpressionOptimizerManager expressionOptimizerManager) { requireNonNull(metadata, "metadata is null"); - this.optimizer = new RowExpressionOptimizer(metadata); + requireNonNull(expressionOptimizerManager, "expressionOptimizerManager is null"); + this.inMemoryExpressionOptimizer = new RowExpressionOptimizer(metadata); + this.delegatingExpressionOptimizer = new DelegatingRowExpressionOptimizer(metadata, expressionOptimizerManager); this.logicalExpressionRewriter = new LogicalExpressionRewriter(metadata.getFunctionAndTypeManager()); } @Override public RowExpression rewrite(RowExpression expression, Rule.Context context) { - return rewrite(expression, context.getSession().toConnectorSession()); + return rewrite(expression, context.getSession()); } - private RowExpression rewrite(RowExpression expression, ConnectorSession session) + private RowExpression rewrite(RowExpression expression, Session session) { // Rewrite RowExpression first to reduce depth of RowExpression tree by balancing AND/OR predicates. // It doesn't matter whether we rewrite/optimize first because this will be called by IterativeOptimizer. RowExpression rewritten = RowExpressionTreeRewriter.rewriteWith(logicalExpressionRewriter, expression, true); - RowExpression optimizedRowExpression = optimizer.optimize(rewritten, SERIALIZABLE, session); - return optimizedRowExpression; + if (isDelegatingRowExpressionOptimizerEnabled(session)) { + return delegatingExpressionOptimizer.optimize(rewritten, SERIALIZABLE, session.toConnectorSession()); + } + else { + return inMemoryExpressionOptimizer.optimize(rewritten, SERIALIZABLE, session.toConnectorSession()); + } } } @VisibleForTesting - public static RowExpression rewrite(RowExpression expression, Metadata metadata, ConnectorSession session) + public static RowExpression rewrite(RowExpression expression, Metadata metadata, Session session, ExpressionOptimizerManager expressionOptimizerManager) { - return new Rewriter(metadata).rewrite(expression, session); + return new Rewriter(metadata, expressionOptimizerManager).rewrite(expression, session); } private static class LogicalExpressionRewriter diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CteProjectionAndPredicatePushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CteProjectionAndPredicatePushDown.java index d64068aeb3c27..15434ee95b2a8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CteProjectionAndPredicatePushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CteProjectionAndPredicatePushDown.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.PlannerUtils; import com.facebook.presto.sql.planner.RowExpressionVariableInliner; import com.facebook.presto.sql.planner.SimplePlanVisitor; @@ -87,10 +88,12 @@ public class CteProjectionAndPredicatePushDown implements PlanOptimizer { private final Metadata metadata; + private final ExpressionOptimizerManager expressionOptimizerManager; - public CteProjectionAndPredicatePushDown(Metadata metadata) + public CteProjectionAndPredicatePushDown(Metadata metadata, ExpressionOptimizerManager expressionOptimizerManager) { - this.metadata = metadata; + this.metadata = requireNonNull(metadata, "metadata is null"); + this.expressionOptimizerManager = requireNonNull(expressionOptimizerManager, "expressionManager is null"); } @Override @@ -383,7 +386,7 @@ private PlanNode addFilter(PlanNode node, List predicates) resultPredicate, predicates.get(i)); } } - resultPredicate = SimplifyRowExpressions.rewrite(resultPredicate, metadata, session.toConnectorSession()); + resultPredicate = SimplifyRowExpressions.rewrite(resultPredicate, metadata, session, expressionOptimizerManager); return new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), node, resultPredicate); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/DelegatingRowExpressionOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/DelegatingRowExpressionOptimizer.java new file mode 100644 index 0000000000000..dafd7ef55f927 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/DelegatingRowExpressionOptimizer.java @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.relational; + +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.InputReferenceExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.expressions.ExpressionOptimizerProvider; + +import java.util.function.Function; + +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.DO_NOT_EVALUATE; +import static com.facebook.presto.sql.planner.LiteralEncoder.toRowExpression; +import static java.util.Objects.requireNonNull; + +public final class DelegatingRowExpressionOptimizer + implements ExpressionOptimizer +{ + private static final int MAX_OPTIMIZATION_ATTEMPTS = 10; + private final ExpressionOptimizerProvider expressionOptimizerManager; + private final ExpressionOptimizer inMemoryOptimizer; + + public DelegatingRowExpressionOptimizer(Metadata metadata, ExpressionOptimizerProvider expressionOptimizerManager) + { + requireNonNull(metadata, "metadata is null"); + this.expressionOptimizerManager = requireNonNull(expressionOptimizerManager, "expressionOptimizerManager is null"); + this.inMemoryOptimizer = new RowExpressionOptimizer(metadata); + } + + @Override + public RowExpression optimize(RowExpression rowExpression, Level level, ConnectorSession session) + { + ExpressionOptimizer delegate = expressionOptimizerManager.getExpressionOptimizer(); + RowExpression originalExpression; + for (int i = 0; i < MAX_OPTIMIZATION_ATTEMPTS; i++) { + // Do not optimize VariableReferenceExpression, ConstantExpression, and InputReferenceExpression because they cannot be optimized further + if (rowExpression instanceof VariableReferenceExpression || rowExpression instanceof ConstantExpression || rowExpression instanceof InputReferenceExpression) { + return rowExpression; + } + originalExpression = rowExpression; + rowExpression = delegate.optimize(rowExpression, level, session); + rowExpression = inMemoryOptimizer.optimize(rowExpression, DO_NOT_EVALUATE, session); + if (originalExpression.equals(rowExpression)) { + break; + } + } + return rowExpression; + } + + @Override + public Object optimize(RowExpression rowExpression, Level level, ConnectorSession session, Function variableResolver) + { + ExpressionOptimizer delegate = expressionOptimizerManager.getExpressionOptimizer(); + Object currentExpression = rowExpression; + Object originalExpression; + for (int i = 0; i < MAX_OPTIMIZATION_ATTEMPTS; i++) { + // Do not optimize VariableReferenceExpression, ConstantExpression, and InputReferenceExpression because they cannot be optimized further + if (currentExpression instanceof VariableReferenceExpression || currentExpression instanceof ConstantExpression || currentExpression instanceof InputReferenceExpression) { + return currentExpression; + } + originalExpression = currentExpression; + currentExpression = delegate.optimize(toRowExpression(currentExpression, rowExpression.getType()), level, session, variableResolver); + currentExpression = inMemoryOptimizer.optimize(toRowExpression(currentExpression, rowExpression.getType()), DO_NOT_EVALUATE, session); + if (originalExpression.equals(currentExpression)) { + break; + } + } + return currentExpression; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java index 1ab9d2ead2e6c..f4a7a067f7802 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java @@ -13,8 +13,10 @@ */ package com.facebook.presto.sql.relational; +import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.function.FunctionMetadataManager; import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; @@ -24,23 +26,30 @@ import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; import static com.facebook.presto.sql.planner.LiteralEncoder.toRowExpression; +import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; public final class RowExpressionOptimizer implements ExpressionOptimizer { - private final Metadata metadata; + private final FunctionAndTypeManager functionAndTypeManager; public RowExpressionOptimizer(Metadata metadata) { - this.metadata = requireNonNull(metadata, "metadata is null"); + this(requireNonNull(metadata, "metadata is null").getFunctionAndTypeManager()); + } + + public RowExpressionOptimizer(FunctionMetadataManager functionMetadataManager) + { + checkArgument(functionMetadataManager instanceof FunctionAndTypeManager, "Expected functionMetadataManager to be instance of FunctionAndTypeManager"); + this.functionAndTypeManager = (FunctionAndTypeManager) requireNonNull(functionMetadataManager, "functionMetadataManager is null"); } @Override public RowExpression optimize(RowExpression rowExpression, Level level, ConnectorSession session) { if (level.ordinal() <= OPTIMIZED.ordinal()) { - return toRowExpression(rowExpression.getSourceLocation(), new RowExpressionInterpreter(rowExpression, metadata.getFunctionAndTypeManager(), session, level).optimize(), rowExpression.getType()); + return toRowExpression(rowExpression.getSourceLocation(), new RowExpressionInterpreter(rowExpression, functionAndTypeManager, session, level).optimize(), rowExpression.getType()); } throw new IllegalArgumentException("Not supported optimization level: " + level); } @@ -48,7 +57,7 @@ public RowExpression optimize(RowExpression rowExpression, Level level, Connecto @Override public Object optimize(RowExpression expression, Level level, ConnectorSession session, Function variableResolver) { - RowExpressionInterpreter interpreter = new RowExpressionInterpreter(expression, metadata.getFunctionAndTypeManager(), session, level); + RowExpressionInterpreter interpreter = new RowExpressionInterpreter(expression, functionAndTypeManager, session, level); return interpreter.optimize(variableResolver::apply); } } diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index 93a30a8664d1b..d84eaeeec6bee 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -166,6 +166,7 @@ import com.facebook.presto.sql.analyzer.FunctionsConfig; import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -345,6 +346,7 @@ public class LocalQueryRunner private static ExecutorService metadataExtractorExecutor = newCachedThreadPool(threadsNamed("query-execution-%s")); private final ReadWriteLock lock = new ReentrantReadWriteLock(); + private ExpressionOptimizerManager expressionOptimizerManager; public LocalQueryRunner(Session defaultSession) { @@ -481,6 +483,8 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, blockEncodingManager, featuresConfig); + expressionOptimizerManager = new ExpressionOptimizerManager(nodeManager, getFunctionAndTypeManager(), nodeInfo); + GlobalSystemConnectorFactory globalSystemConnectorFactory = new GlobalSystemConnectorFactory(ImmutableSet.of( new NodeSystemTable(nodeManager), new CatalogSystemTable(metadata, accessControl), @@ -515,7 +519,8 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, new ThrowingClusterTtlProviderManager(), historyBasedPlanStatisticsManager, new TracerProviderManager(new TracingConfig()), - new NodeStatusNotificationManager()); + new NodeStatusNotificationManager(), + expressionOptimizerManager); connectorManager.addConnectorFactory(globalSystemConnectorFactory); connectorManager.createConnection(GlobalSystemConnector.NAME, GlobalSystemConnector.NAME, ImmutableMap.of()); @@ -685,6 +690,12 @@ public TestingAccessControlManager getAccessControl() return accessControl; } + @Override + public ExpressionOptimizerManager getExpressionManager() + { + return expressionOptimizerManager; + } + public ExecutorService getExecutor() { return notificationExecutor; @@ -1096,7 +1107,8 @@ public List getPlanOptimizers(boolean forceSingleNode) new CostComparator(featuresConfig), taskCountEstimator, partitioningProviderManager, - featuresConfig).getPlanningTimeOptimizers(); + featuresConfig, + expressionOptimizerManager).getPlanningTimeOptimizers(); } public Plan createPlan(Session session, @Language("SQL") String sql, List optimizers, WarningCollector warningCollector) diff --git a/presto-main/src/main/java/com/facebook/presto/testing/QueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/QueryRunner.java index 2cae8bfa14d56..b7e1301f6a1c1 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/QueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/QueryRunner.java @@ -24,6 +24,7 @@ import com.facebook.presto.spi.eventlistener.EventListener; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; import com.facebook.presto.sql.planner.NodePartitioningManager; import com.facebook.presto.sql.planner.Plan; @@ -60,6 +61,8 @@ public interface QueryRunner TestingAccessControlManager getAccessControl(); + ExpressionOptimizerManager getExpressionManager(); + MaterializedResult execute(@Language("SQL") String sql); MaterializedResult execute(Session session, @Language("SQL") String sql); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 01ddec6dee152..2e077cf330329 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -255,7 +255,8 @@ public void testDefaults() .setPrintEstimatedStatsFromCache(false) .setRemoveCrossJoinWithSingleConstantRow(true) .setUseHistograms(false) - .setInlineProjectionsOnValues(false)); + .setInlineProjectionsOnValues(false) + .setDelegatingRowExpressionOptimizerEnabled(false)); } @Test @@ -460,6 +461,7 @@ public void testExplicitPropertyMappings() .put("optimizer.remove-cross-join-with-single-constant-row", "false") .put("optimizer.use-histograms", "true") .put("optimizer.inline-projections-on-values", "true") + .put("optimizer.delegating-row-expression-optimizer-enabled", "true") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -661,7 +663,8 @@ public void testExplicitPropertyMappings() .setPrintEstimatedStatsFromCache(true) .setRemoveCrossJoinWithSingleConstantRow(false) .setUseHistograms(true) - .setInlineProjectionsOnValues(true); + .setInlineProjectionsOnValues(true) + .setDelegatingRowExpressionOptimizerEnabled(true); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index 1a967e51a145c..399ad22f1871c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -14,10 +14,18 @@ package com.facebook.presto.sql.planner; import com.facebook.presto.Session; +import com.facebook.presto.common.CatalogSchemaName; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.functionNamespace.FunctionNamespaceManagerPlugin; import com.facebook.presto.functionNamespace.json.JsonFileBasedFunctionNamespaceManagerFactory; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.AggregationFunctionMetadata; +import com.facebook.presto.spi.function.FunctionKind; +import com.facebook.presto.spi.function.Parameter; +import com.facebook.presto.spi.function.RoutineCharacteristics; +import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.FilterNode; @@ -76,8 +84,12 @@ import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; import static com.facebook.presto.common.predicate.Domain.singleValue; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.common.type.VarcharType.createVarcharType; import static com.facebook.presto.spi.StandardErrorCode.INVALID_LIMIT_CLAUSE; +import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP; import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL; import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL; import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE; @@ -88,8 +100,6 @@ import static com.facebook.presto.spi.plan.JoinType.RIGHT; import static com.facebook.presto.sql.Optimizer.PlanStage.OPTIMIZED; import static com.facebook.presto.sql.Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED; -import static com.facebook.presto.sql.TestExpressionInterpreter.AVG_UDAF_CPP; -import static com.facebook.presto.sql.TestExpressionInterpreter.SQUARE_UDF_CPP; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any; @@ -144,6 +154,26 @@ public class TestLogicalPlanner extends BasePlanTest { + public static final SqlInvokedFunction SQUARE_UDF_CPP = new SqlInvokedFunction( + QualifiedObjectName.valueOf(new CatalogSchemaName("json", "test_schema"), "square"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "Integer square", + RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setLanguage(CPP).build(), + "", + notVersioned()); + + public static final SqlInvokedFunction AVG_UDAF_CPP = new SqlInvokedFunction( + QualifiedObjectName.valueOf(new CatalogSchemaName("json", "test_schema"), "avg"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.DOUBLE))), + parseTypeSignature(StandardTypes.DOUBLE), + "Returns mean of doubles", + RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setLanguage(CPP).build(), + "", + notVersioned(), + FunctionKind.AGGREGATE, + Optional.of(new AggregationFunctionMetadata(parseTypeSignature("ROW(double, int)"), false))); + // TODO: Use com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder#tableScan with required node/stream // partitioning to properly test aggregation, window function and join. diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java index b4ad658409b82..36f271133e194 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java @@ -13,9 +13,11 @@ */ package com.facebook.presto.sql.planner.assertions; +import com.facebook.airlift.node.NodeInfo; import com.facebook.presto.Session; import com.facebook.presto.cost.StatsAndCosts; import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.metadata.InMemoryNodeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; @@ -23,6 +25,7 @@ import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.sql.Optimizer; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.TypeProvider; @@ -170,7 +173,12 @@ private List getMinimalOptimizers() new RuleStatsRecorder(), queryRunner.getStatsCalculator(), queryRunner.getCostCalculator(), - new SimplifyRowExpressions(metadata).rules())); + new SimplifyRowExpressions( + metadata, + new ExpressionOptimizerManager( + new InMemoryNodeManager(), + queryRunner.getFunctionAndTypeManager(), + new NodeInfo("test"))).rules())); } private void inTransaction(Function transactionSessionConsumer) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java index 19638181d35e4..64f3a251acfd2 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java @@ -18,8 +18,10 @@ import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import static com.facebook.presto.SystemSessionProperties.DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED; import static com.facebook.presto.SystemSessionProperties.REMOVE_MAP_CAST; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.DoubleType.DOUBLE; @@ -33,12 +35,21 @@ public class TestRemoveMapCastRule extends BaseRuleTest { - @Test - public void testSubscriptCast() + @DataProvider(name = "delegating-row-expression-optimizer-enabled") + public Object[][] delegatingDataProvider() + { + return new Object[][] { + {true}, + {false}, + }; + } + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testSubscriptCast(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat( - ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata()).rules()).addAll(new RemoveMapCastRule(getFunctionManager()).rules()).build()) + ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager()).rules()).addAll(new RemoveMapCastRule(getFunctionManager()).rules()).build()) .setSystemProperty(REMOVE_MAP_CAST, "true") + .setSystemProperty(DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED, Boolean.toString(enableDelegatingRowExpressionOptimizer)) .on(p -> { VariableReferenceExpression a = p.variable("a", DOUBLE); VariableReferenceExpression feature = p.variable("feature", createMapType(getFunctionManager(), INTEGER, DOUBLE)); @@ -53,12 +64,13 @@ public void testSubscriptCast() values("feature", "key"))); } - @Test - public void testElementAtCast() + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testElementAtCast(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat( - ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata()).rules()).addAll(new RemoveMapCastRule(getFunctionManager()).rules()).build()) + ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager()).rules()).addAll(new RemoveMapCastRule(getFunctionManager()).rules()).build()) .setSystemProperty(REMOVE_MAP_CAST, "true") + .setSystemProperty(DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED, Boolean.toString(enableDelegatingRowExpressionOptimizer)) .on(p -> { VariableReferenceExpression a = p.variable("a", DOUBLE); VariableReferenceExpression feature = p.variable("feature", createMapType(getFunctionManager(), INTEGER, DOUBLE)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteConstantArrayContainsToInExpression.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteConstantArrayContainsToInExpression.java index 47c50c31a7308..b33e399853715 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteConstantArrayContainsToInExpression.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteConstantArrayContainsToInExpression.java @@ -19,8 +19,10 @@ import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import static com.facebook.presto.SystemSessionProperties.DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED; import static com.facebook.presto.SystemSessionProperties.REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; @@ -32,13 +34,23 @@ public class TestRewriteConstantArrayContainsToInExpression extends BaseRuleTest { - @Test - public void testNoNull() + @DataProvider(name = "delegating-row-expression-optimizer-enabled") + public Object[][] delegatingDataProvider() + { + return new Object[][] { + {true}, + {false}, + }; + } + + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testNoNull(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat( - ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata()).rules()).addAll( + ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager()).rules()).addAll( new RewriteConstantArrayContainsToInExpression(getFunctionManager()).rules()).build()) .setSystemProperty(REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION, "true") + .setSystemProperty(DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED, Boolean.toString(enableDelegatingRowExpressionOptimizer)) .on(p -> { VariableReferenceExpression a = p.variable("a", BOOLEAN); VariableReferenceExpression b = p.variable("b"); @@ -52,11 +64,12 @@ public void testNoNull() values("b"))); } - @Test - public void testDoesNotFireForNestedArray() + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testDoesNotFireForNestedArray(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat(new RewriteConstantArrayContainsToInExpression(getFunctionManager()).projectRowExpressionRewriteRule()) .setSystemProperty(REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION, "true") + .setSystemProperty(DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED, Boolean.toString(enableDelegatingRowExpressionOptimizer)) .on(p -> { VariableReferenceExpression a = p.variable("a", BOOLEAN); VariableReferenceExpression b = p.variable("b", new ArrayType(BIGINT)); @@ -67,8 +80,8 @@ public void testDoesNotFireForNestedArray() .doesNotFire(); } - @Test - public void testDoesNotFireForNull() + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testDoesNotFireForNull(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat(new RewriteConstantArrayContainsToInExpression(getFunctionManager()).projectRowExpressionRewriteRule()) .setSystemProperty(REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION, "true") @@ -82,8 +95,8 @@ public void testDoesNotFireForNull() .doesNotFire(); } - @Test - public void testDoesNotFireForEmpty() + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testDoesNotFireForEmpty(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat(new RewriteConstantArrayContainsToInExpression(getFunctionManager()).projectRowExpressionRewriteRule()) .setSystemProperty(REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION, "true") @@ -97,11 +110,11 @@ public void testDoesNotFireForEmpty() .doesNotFire(); } - @Test - public void testNotFire() + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testNotFire(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat( - ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata()).rules()).addAll( + ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager()).rules()).addAll( new RewriteConstantArrayContainsToInExpression(getFunctionManager()).rules()).build()) .setSystemProperty(REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION, "true") .on(p -> { @@ -118,11 +131,11 @@ public void testNotFire() values("b", "c"))); } - @Test - public void testWithNull() + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testWithNull(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat( - ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata()).rules()).addAll( + ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager()).rules()).addAll( new RewriteConstantArrayContainsToInExpression(getFunctionManager()).rules()).build()) .setSystemProperty(REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION, "true") .on(p -> { @@ -138,11 +151,11 @@ public void testWithNull() values("b"))); } - @Test - public void testLambda() + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testLambda(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat( - ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata()).rules()).addAll( + ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager()).rules()).addAll( new RewriteConstantArrayContainsToInExpression(getFunctionManager()).rules()).build()) .setSystemProperty(REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION, "true") .on(p -> { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java index 96da619e3ccb1..19568cb110fe9 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java @@ -13,15 +13,19 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.airlift.node.NodeInfo; +import com.facebook.presto.Session; import com.facebook.presto.common.type.Type; import com.facebook.presto.expressions.LogicalRowExpressions; import com.facebook.presto.expressions.RowExpressionRewriter; import com.facebook.presto.expressions.RowExpressionTreeRewriter; +import com.facebook.presto.metadata.InMemoryNodeManager; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.sql.TestingRowExpressionTranslator; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.tree.Expression; @@ -36,6 +40,7 @@ import java.util.stream.Stream; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.SystemSessionProperties.DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; @@ -43,6 +48,8 @@ import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.OR; import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static com.facebook.presto.sql.relational.Expressions.specialForm; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; import static java.lang.String.format; import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toMap; @@ -181,12 +188,24 @@ private static void assertSimplifies(String expression, String rowExpressionExpe { Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expression)); + InMemoryNodeManager nodeManager = new InMemoryNodeManager(); + ExpressionOptimizerManager expressionOptimizerManager = new ExpressionOptimizerManager(nodeManager, METADATA.getFunctionAndTypeManager(), new NodeInfo("test")); + expressionOptimizerManager.loadExpressions(); + TestingRowExpressionTranslator translator = new TestingRowExpressionTranslator(METADATA); RowExpression actualRowExpression = translator.translate(actualExpression, TypeProvider.viewOf(TYPES)); - RowExpression simplifiedRowExpression = SimplifyRowExpressions.rewrite(actualRowExpression, METADATA, TEST_SESSION.toConnectorSession()); + RowExpression simplifiedRowExpression = SimplifyRowExpressions.rewrite(actualRowExpression, METADATA, TEST_SESSION, expressionOptimizerManager); Expression expectedByRowExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(rowExpressionExpected)); RowExpression simplifiedByExpression = translator.translate(expectedByRowExpression, TypeProvider.viewOf(TYPES)); assertEquals(normalize(simplifiedRowExpression), normalize(simplifiedByExpression)); + + Session session = testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .setSystemProperty(DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED, "true") + .build(); + RowExpression sidecarSimplifiedExpressions = SimplifyRowExpressions.rewrite(simplifiedRowExpression, METADATA, session, expressionOptimizerManager); + assertEquals(normalize(sidecarSimplifiedExpressions), normalize(simplifiedByExpression)); } private static RowExpression normalize(RowExpression expression) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/BaseRuleTest.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/BaseRuleTest.java index 04ce733e96aab..0dafb0c96b8b6 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/BaseRuleTest.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/BaseRuleTest.java @@ -16,6 +16,7 @@ import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.Plugin; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.Plan; import com.google.common.collect.ImmutableList; import org.testng.annotations.AfterClass; @@ -84,4 +85,9 @@ protected void assertNodePresentInPlan(Plan plan, Class nodeClass) .matches(), "Expected " + nodeClass.toString() + " in plan after optimization. "); } + + protected ExpressionOptimizerManager getExpressionManager() + { + return tester.getExpressionManager(); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java index bad0f6b342d13..77b0de9cb8c18 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.assertions.OptimizerAssert; @@ -60,6 +61,7 @@ public class RuleTester private final PageSourceManager pageSourceManager; private final AccessControl accessControl; private final SqlParser sqlParser; + private ExpressionOptimizerManager expressionOptimizerManager; public RuleTester() { @@ -106,6 +108,8 @@ public RuleTester(List plugins, Map sessionProperties, S connectorFactory, ImmutableMap.of()); plugins.stream().forEach(queryRunner::installPlugin); + expressionOptimizerManager = queryRunner.getExpressionManager(); + expressionOptimizerManager.loadExpressions(); this.metadata = queryRunner.getMetadata(); this.transactionManager = queryRunner.getTransactionManager(); @@ -196,4 +200,9 @@ public List> getTableConstraints(TableHandle table return metadata.getTableMetadata(transactionSession, tableHandle).getMetadata().getTableConstraintsHolder().getTableConstraintsWithColumnHandles(); }); } + + public ExpressionOptimizerManager getExpressionManager() + { + return expressionOptimizerManager; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestCteProjectionAndPredicatePushdown.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestCteProjectionAndPredicatePushdown.java index f36c3aaabcf15..b50f378b9e700 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestCteProjectionAndPredicatePushdown.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestCteProjectionAndPredicatePushdown.java @@ -144,7 +144,7 @@ private void assertCtePlan(String sql, PlanMatchPattern pattern) new RemoveIdentityProjectionsBelowProjection(), new PruneRedundantProjectionAssignments())), new PruneUnreferencedOutputs(), - new CteProjectionAndPredicatePushDown(metadata)); + new CteProjectionAndPredicatePushDown(metadata, getQueryRunner().getExpressionManager())); assertPlan(sql, getSession(), Optimizer.PlanStage.OPTIMIZED, pattern, optimizers); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/relational/TestDelegatingRowExpressionOptimizer.java b/presto-main/src/test/java/com/facebook/presto/sql/relational/TestDelegatingRowExpressionOptimizer.java new file mode 100644 index 0000000000000..fb5088f9b72e6 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/relational/TestDelegatingRowExpressionOptimizer.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.relational; + +import com.facebook.airlift.node.NodeInfo; +import com.facebook.presto.client.NodeVersion; +import com.facebook.presto.common.block.IntArrayBlock; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.metadata.InMemoryNodeManager; +import com.facebook.presto.metadata.InternalNode; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.net.URI; + +import static com.facebook.airlift.testing.Assertions.assertInstanceOf; +import static com.facebook.presto.block.BlockAssertions.toValues; +import static com.facebook.presto.common.function.OperatorType.EQUAL; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.JsonType.JSON; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.metadata.CastType.CAST; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.sql.relational.Expressions.constant; +import static com.facebook.presto.testing.TestingConnectorSession.SESSION; +import static io.airlift.slice.Slices.utf8Slice; +import static org.testng.Assert.assertEquals; + +public class TestDelegatingRowExpressionOptimizer +{ + private DelegatingRowExpressionOptimizer optimizer; + private static final MetadataManager METADATA = MetadataManager.createTestMetadataManager(); + + @BeforeClass + public void setUp() + { + InMemoryNodeManager inMemoryNodeManager = new InMemoryNodeManager(); + inMemoryNodeManager.addNode(new ConnectorId("test"), new InternalNode("id", URI.create("id"), new NodeVersion("test"), false)); + ExpressionOptimizerManager expressionOptimizerManager = new ExpressionOptimizerManager(inMemoryNodeManager, METADATA.getFunctionAndTypeManager(), new NodeInfo("env")); + optimizer = new DelegatingRowExpressionOptimizer(METADATA, expressionOptimizerManager); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + optimizer = null; + } + + @Test + public void testIfConstantOptimization() + { + assertEquals(optimize(ifExpression(constant(true, BOOLEAN), 1L, 2L)), constant(1L, BIGINT)); + assertEquals(optimize(ifExpression(constant(false, BOOLEAN), 1L, 2L)), constant(2L, BIGINT)); + assertEquals(optimize(ifExpression(constant(null, BOOLEAN), 1L, 2L)), constant(2L, BIGINT)); + + FunctionHandle bigintEquals = METADATA.getFunctionAndTypeManager().resolveOperator(EQUAL, fromTypes(BIGINT, BIGINT)); + RowExpression condition = new CallExpression(EQUAL.name(), bigintEquals, BOOLEAN, ImmutableList.of(constant(3L, BIGINT), constant(3L, BIGINT))); + assertEquals(optimize(ifExpression(condition, 1L, 2L)), constant(1L, BIGINT)); + } + + @Test + public void testCastWithJsonParseOptimization() + { + FunctionHandle jsonParseFunctionHandle = METADATA.getFunctionAndTypeManager().lookupFunction("json_parse", fromTypes(VARCHAR)); + + // constant + FunctionHandle jsonCastFunctionHandle = METADATA.getFunctionAndTypeManager().lookupCast(CAST, JSON, METADATA.getFunctionAndTypeManager().getType(parseTypeSignature("array(integer)"))); + RowExpression jsonCastExpression = new CallExpression(CAST.name(), jsonCastFunctionHandle, new ArrayType(INTEGER), ImmutableList.of(call("json_parse", jsonParseFunctionHandle, JSON, constant(utf8Slice("[1, 2]"), VARCHAR)))); + RowExpression resultExpression = optimize(jsonCastExpression); + assertInstanceOf(resultExpression, ConstantExpression.class); + Object resultValue = ((ConstantExpression) resultExpression).getValue(); + assertInstanceOf(resultValue, IntArrayBlock.class); + assertEquals(toValues(INTEGER, (IntArrayBlock) resultValue), ImmutableList.of(1, 2)); + } + + private static RowExpression ifExpression(RowExpression condition, long trueValue, long falseValue) + { + return new SpecialFormExpression(IF, BIGINT, ImmutableList.of(condition, constant(trueValue, BIGINT), constant(falseValue, BIGINT))); + } + + private RowExpression optimize(RowExpression expression) + { + return optimizer.optimize(expression, OPTIMIZED, SESSION); + } +} diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java index 09cd72184cac8..b8fdbe06b72da 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java @@ -23,6 +23,7 @@ import com.facebook.presto.spi.eventlistener.EventListener; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; import com.facebook.presto.sql.planner.NodePartitioningManager; import com.facebook.presto.testing.MaterializedResult; @@ -203,6 +204,12 @@ public TestingAccessControlManager getAccessControl() throw new UnsupportedOperationException(); } + @Override + public ExpressionOptimizerManager getExpressionManager() + { + return null; + } + @Override public MaterializedResult execute(String sql) { diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java index 8eb985d2a059f..aeeacdcf588f6 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java @@ -170,6 +170,7 @@ import com.facebook.presto.sql.analyzer.MetadataExtractorMBean; import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -336,6 +337,9 @@ protected void setup(Binder binder) binder.bind(AnalyzePropertyManager.class).in(Scopes.SINGLETON); binder.bind(QuerySessionSupplier.class).in(Scopes.SINGLETON); + // expression manager + binder.bind(ExpressionOptimizerManager.class).in(Scopes.SINGLETON); + // tracer provider managers binder.bind(TracerProviderManager.class).in(Scopes.SINGLETON); diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java index 015f1936c4bc0..723056a4bac17 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java @@ -64,6 +64,7 @@ import com.facebook.presto.spi.security.PrincipalType; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParserOptions; import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; import com.facebook.presto.sql.planner.NodePartitioningManager; @@ -495,6 +496,12 @@ public TestingAccessControlManager getAccessControl() return testingAccessControlManager; } + @Override + public ExpressionOptimizerManager getExpressionManager() + { + throw new UnsupportedOperationException(); + } + public HistoryBasedPlanStatisticsManager getHistoryBasedPlanStatisticsManager() { return historyBasedPlanStatisticsManager; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/relation/ExpressionOptimizer.java b/presto-spi/src/main/java/com/facebook/presto/spi/relation/ExpressionOptimizer.java index 70f72a555058c..2ee314495fec1 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/relation/ExpressionOptimizer.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/relation/ExpressionOptimizer.java @@ -28,6 +28,10 @@ public interface ExpressionOptimizer enum Level { + /** + * DO_NOT_EVALUATE does not evaluate functions, but will simplify expressions where logical equivalents can be made + */ + DO_NOT_EVALUATE, /** * SERIALIZABLE guarantees the optimized RowExpression can be serialized and deserialized. */ diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index 1757ba65e8d5a..a3cfdef39361f 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.tests; +import com.facebook.airlift.node.NodeInfo; import com.facebook.presto.Session; import com.facebook.presto.common.type.Type; import com.facebook.presto.cost.CostCalculator; @@ -21,11 +22,15 @@ import com.facebook.presto.cost.CostComparator; import com.facebook.presto.cost.TaskCountEstimator; import com.facebook.presto.execution.QueryManagerConfig; +import com.facebook.presto.metadata.InMemoryNodeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.security.AccessDeniedException; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.analyzer.QueryExplainer; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.facebook.presto.sql.planner.Plan; @@ -57,6 +62,7 @@ import java.util.OptionalLong; import java.util.function.Consumer; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException; import static com.facebook.presto.sql.SqlFormatter.formatSql; import static com.facebook.presto.transaction.TransactionBuilder.transaction; @@ -74,6 +80,7 @@ public abstract class AbstractTestQueryFramework { + private static final NodeInfo NODE_INFO = new NodeInfo("test"); private QueryRunner queryRunner; private ExpectedQueryRunner expectedQueryRunner; private SqlParser sqlParser; @@ -568,7 +575,13 @@ private QueryExplainer getQueryExplainer() new CostComparator(featuresConfig), taskCountEstimator, new PartitioningProviderManager(), - featuresConfig) + featuresConfig, + new ExpressionOptimizerManager( + new InMemoryNodeManager(), + queryRunner.getMetadata().getFunctionAndTypeManager(), + NODE_INFO, + // TODO: @tdm simple codec won't work, need to wire it + new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class)))) .getPlanningTimeOptimizers(); return new QueryExplainer( optimizers, diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java index 8b5f88303ee0e..1f217ce0e7ae2 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java @@ -42,6 +42,7 @@ import com.facebook.presto.spi.eventlistener.EventListener; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParserOptions; import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; import com.facebook.presto.sql.planner.NodePartitioningManager; @@ -581,6 +582,13 @@ public TestingAccessControlManager getAccessControl() return coordinators.get(0).getAccessControl(); } + @Override + public ExpressionOptimizerManager getExpressionManager() + { + checkState(coordinators.size() == 1, "Expected a single coordinator"); + return coordinators.get(0).getExpressionManager(); + } + public TestingPrestoServer getCoordinator() { checkState(coordinators.size() == 1, "Expected a single coordinator"); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java b/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java index 05e87aa335de3..4a748fdfa01b3 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java @@ -27,6 +27,7 @@ import com.facebook.presto.spi.eventlistener.EventListener; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParserOptions; import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; import com.facebook.presto.sql.planner.NodePartitioningManager; @@ -181,6 +182,12 @@ public TestingAccessControlManager getAccessControl() return server.getAccessControl(); } + @Override + public ExpressionOptimizerManager getExpressionManager() + { + return server.getExpressionManager(); + } + public TestingPrestoServer getServer() { return server; diff --git a/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java b/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java index 30951ad186cbf..19d22733f67a5 100644 --- a/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java +++ b/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java @@ -512,7 +512,7 @@ public void clusterPoolsMultiCoordinatorCleanup() queryRunner2.close(); } - @Test(timeOut = 60_000, groups = {"clusterPoolsMultiCoordinator"}) + @Test(timeOut = 600_000, groups = {"clusterPoolsMultiCoordinator"}) public void testClusterPoolsMultiCoordinator() throws Exception { @@ -544,6 +544,7 @@ public void testClusterPoolsMultiCoordinator() generalPool = memoryManager.getClusterInfo(GENERAL_POOL); reservedPool = memoryManager.getClusterInfo(RESERVED_POOL); MILLISECONDS.sleep(10); + System.out.println("waiting"); } // Make sure the queries are blocked diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestDelegatingExpressionOptimizer.java b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestDelegatingExpressionOptimizer.java new file mode 100644 index 0000000000000..57f28fa483645 --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestDelegatingExpressionOptimizer.java @@ -0,0 +1,182 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.expressions; + +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.relational.DelegatingRowExpressionOptimizer; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.google.common.collect.ImmutableList; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.operator.scalar.ApplyFunction.APPLY_FUNCTION; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.EVALUATED; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.type.LikeFunctions.castVarcharToLikePattern; +import static com.facebook.presto.type.LikePatternType.LIKE_PATTERN; +import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.Slices.wrappedBuffer; +import static org.testng.Assert.assertEquals; + +public class TestDelegatingExpressionOptimizer + extends TestExpressions +{ + public static final FunctionResolution RESOLUTION = new FunctionResolution(METADATA.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + private ExpressionOptimizer expressionOptimizer; + + @BeforeClass + public void setup() + { + METADATA.getFunctionAndTypeManager().registerBuiltInFunctions(ImmutableList.of(APPLY_FUNCTION)); + setupJsonFunctionNamespaceManager(METADATA.getFunctionAndTypeManager()); + + expressionOptimizer = new DelegatingRowExpressionOptimizer(METADATA, () -> TestNativeExpressions.getExpressionOptimizer(METADATA, HANDLE_RESOLVER)); + } + + @Test + public void assertLikeOptimizations() + { + assertOptimizedMatches("unbound_string LIKE bound_pattern", "unbound_string LIKE CAST('%el%' AS varchar)"); + } + + // TODO: this test is invalid as it manually constructs an expression which can't be serialized + @Test(enabled = false) + @Override + public void testLikeInvalidUtf8() + { + } + + // TODO: lambdas are currently unsupported by this test + @Test(enabled = false) + @Override + public void testLambda() + { + } + + // TODO: current timestamp returns the session timestamp, which is not supported by this test + @Test(enabled = false) + @Override + public void testCurrentTimestamp() + { + } + + // TODO: this function is not supported by this test because its contents are not serializable + @Test(enabled = false) + @Override + public void testMassiveArray() + { + } + + // TODO: non-deterministic function calls are not supported by this test and need to be tested separately + @Test(enabled = false) + @Override + public void testNonDeterministicFunctionCall() + { } + + @Override + protected void assertLike(byte[] value, String pattern, boolean expected) + { + CallExpression predicate = call( + "LIKE", + RESOLUTION.likeVarcharFunction(), + BOOLEAN, + ImmutableList.of( + new ConstantExpression(wrappedBuffer(value), VARCHAR), + new ConstantExpression(castVarcharToLikePattern(utf8Slice(pattern)), LIKE_PATTERN))); + assertEquals(optimizeRowExpression(predicate, EVALUATED), expected); + } + @Override + protected Object evaluate(String expression, boolean deterministic) + { + assertRoundTrip(expression); + RowExpression rowExpression = sqlToRowExpression(expression); + return optimizeRowExpression(rowExpression, EVALUATED); + } + + @Override + protected Object optimize(@Language("SQL") String expression) + { + assertRoundTrip(expression); + RowExpression parsedExpression = sqlToRowExpression(expression); + return optimizeRowExpression(parsedExpression, OPTIMIZED); + } + + @Override + protected Object optimizeRowExpression(RowExpression expression, Level level) + { + Object optimized = expressionOptimizer.optimize( + expression, + level, + TEST_SESSION.toConnectorSession(), + variable -> { + Symbol symbol = new Symbol(variable.getName()); + Object value = symbolConstant(symbol); + if (value == null) { + return new VariableReferenceExpression(Optional.empty(), symbol.getName(), SYMBOL_TYPES.get(symbol.toSymbolReference())); + } + return value; + }); + return unwrap(optimized); + } + + public Object unwrap(Object result) + { + if (result instanceof ConstantExpression) { + return ((ConstantExpression) result).getValue(); + } + else { + return result; + } + } + + @Override + protected void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + { + Object optimizedActual = optimize(actual); + Object optimizedExpected = optimize(expected); + assertRowExpressionEvaluationEquals(optimizedActual, optimizedExpected); + } + + @Override + protected void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected) + { + Object actualOptimized = optimize(actual); + Object expectedOptimized = optimize(expected); + assertRowExpressionEvaluationEquals( + actualOptimized, + expectedOptimized); + } + + @Override + protected void assertDoNotOptimize(@Language("SQL") String expression, Level optimizationLevel) + { + assertRoundTrip(expression); + RowExpression rowExpression = sqlToRowExpression(expression); + Object rowExpressionResult = optimizeRowExpression(rowExpression, optimizationLevel); + assertRowExpressionEvaluationEquals(rowExpressionResult, rowExpression); + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionInterpreter.java b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionInterpreter.java new file mode 100644 index 0000000000000..0486e807bd0c1 --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionInterpreter.java @@ -0,0 +1,225 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.expressions; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.operator.scalar.FunctionAssertions; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.TestingRowExpressionTranslator; +import com.facebook.presto.sql.planner.ExpressionInterpreter; +import com.facebook.presto.sql.planner.RowExpressionInterpreter; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.ExpressionRewriter; +import com.facebook.presto.sql.tree.ExpressionTreeRewriter; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.LikePredicate; +import com.facebook.presto.sql.tree.NodeRef; +import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.StringLiteral; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; +import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; +import static com.facebook.presto.sql.planner.ExpressionInterpreter.expressionInterpreter; +import static com.facebook.presto.sql.planner.ExpressionInterpreter.expressionOptimizer; +import static com.facebook.presto.sql.planner.RowExpressionInterpreter.rowExpressionInterpreter; +import static java.util.Collections.emptyMap; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestExpressionInterpreter + extends TestExpressions +{ + private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA); + + @Test + public void assertLikeOptimizations() + { + assertOptimizedEquals("unbound_string LIKE bound_pattern", "unbound_string LIKE bound_pattern"); + } + + @Override + protected void assertLike(byte[] value, String pattern, boolean expected) + { + Expression predicate = new LikePredicate( + rawStringLiteral(Slices.wrappedBuffer(value)), + new StringLiteral(pattern), + Optional.empty()); + assertEquals(evaluate(predicate, true), expected); + } + + private static StringLiteral rawStringLiteral(final Slice slice) + { + return new StringLiteral(slice.toStringUtf8()) + { + @Override + public Slice getSlice() + { + return slice; + } + }; + } + + @Override + protected void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + { + assertEquals(optimize(actual), optimize(expected)); + } + + @Override + protected void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected) + { + // replaces FunctionCalls to FailureFunction by fail() + Object actualOptimized = optimize(actual); + if (actualOptimized instanceof Expression) { + actualOptimized = ExpressionTreeRewriter.rewriteWith(new FailedFunctionRewriter(), (Expression) actualOptimized); + } + assertEquals( + actualOptimized, + rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expected))); + } + + @Override + protected Object optimize(@Language("SQL") String expression) + { + assertRoundTrip(expression); + + Expression parsedExpression = expression(expression); + Object expressionResult = optimize(parsedExpression); + + RowExpression rowExpression = toRowExpression(parsedExpression); + Object rowExpressionResult = optimizeRowExpression(rowExpression, OPTIMIZED); + assertExpressionAndRowExpressionEquals(expressionResult, rowExpressionResult); + return expressionResult; + } + + @Override + protected Object optimizeRowExpression(RowExpression expression, ExpressionOptimizer.Level level) + { + RowExpressionInterpreter rowExpressionInterpreter = new RowExpressionInterpreter(expression, METADATA, TEST_SESSION.toConnectorSession(), level); + return rowExpressionInterpreter.optimize(variable -> { + Symbol symbol = new Symbol(variable.getName()); + Object value = symbolConstant(symbol); + if (value == null) { + return new VariableReferenceExpression(Optional.empty(), symbol.getName(), SYMBOL_TYPES.get(symbol.toSymbolReference())); + } + return value; + }); + } + + private static Expression expression(String expression) + { + return FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); + } + + private static RowExpression toRowExpression(Expression expression) + { + return TRANSLATOR.translate(expression, SYMBOL_TYPES); + } + + private Object optimize(Expression expression) + { + Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP); + ExpressionInterpreter interpreter = expressionOptimizer(expression, METADATA, TEST_SESSION, expressionTypes); + return interpreter.optimize(variable -> { + Symbol symbol = new Symbol(variable.getName()); + Object value = symbolConstant(symbol); + if (value == null) { + return symbol.toSymbolReference(); + } + return value; + }); + } + + @Override + protected void assertDoNotOptimize(@Language("SQL") String expression, Level optimizationLevel) + { + assertRoundTrip(expression); + Expression translatedExpression = expression(expression); + RowExpression rowExpression = toRowExpression(translatedExpression); + + Object expressionResult = optimize(translatedExpression); + if (expressionResult instanceof Expression) { + expressionResult = toRowExpression((Expression) expressionResult); + } + Object rowExpressionResult = optimizeRowExpression(rowExpression, optimizationLevel); + assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); + assertRowExpressionEvaluationEquals(rowExpressionResult, rowExpression); + } + + private void assertExpressionAndRowExpressionEquals(Object expressionResult, Object rowExpressionResult) + { + if (rowExpressionResult instanceof RowExpression) { + // Cannot be completely evaluated into a constant; compare expressions + assertTrue(expressionResult instanceof Expression); + + // It is tricky to check the equivalence of an expression and a row expression. + // We rely on the optimized translator to fill the gap. + RowExpression translated = TRANSLATOR.translateAndOptimize((Expression) expressionResult, SYMBOL_TYPES); + assertRowExpressionEvaluationEquals(translated, rowExpressionResult); + } + else { + // We have constants; directly compare + assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); + } + } + @Override + protected Object evaluate(String expression, boolean deterministic) + { + assertRoundTrip(expression); + + Expression parsedExpression = FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); + + return evaluate(parsedExpression, deterministic); + } + + private Object evaluate(Expression expression, boolean deterministic) + { + Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP); + Object expressionResult = expressionInterpreter(expression, METADATA, TEST_SESSION, expressionTypes).evaluate(); + Object rowExpressionResult = rowExpressionInterpreter(TRANSLATOR.translateAndOptimize(expression), METADATA.getFunctionAndTypeManager(), TEST_SESSION.toConnectorSession()).evaluate(); + + if (deterministic) { + assertExpressionAndRowExpressionEquals(expressionResult, rowExpressionResult); + } + return expressionResult; + } + + private static class FailedFunctionRewriter + extends ExpressionRewriter + { + @Override + public Expression rewriteFunctionCall(FunctionCall node, Object context, ExpressionTreeRewriter treeRewriter) + { + if (node.getName().equals(QualifiedName.of("fail"))) { + return new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(node.getArguments().get(0), new StringLiteral("ignored failure message"))); + } + return node; + } + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionOptimizers.java b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionOptimizers.java new file mode 100644 index 0000000000000..1db46cd8862ee --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionOptimizers.java @@ -0,0 +1,133 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.expressions; + +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.relational.RowExpressionOptimizer; +import com.google.common.collect.ImmutableList; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.EVALUATED; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.type.LikeFunctions.castVarcharToLikePattern; +import static com.facebook.presto.type.LikePatternType.LIKE_PATTERN; +import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.Slices.wrappedBuffer; +import static org.testng.Assert.assertEquals; + +public class TestExpressionOptimizers + extends TestExpressions +{ + public static final FunctionResolution RESOLUTION = new FunctionResolution(METADATA.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + private ExpressionOptimizer expressionOptimizer; + + @BeforeClass + public void setup() + { + expressionOptimizer = new RowExpressionOptimizer(METADATA.getFunctionAndTypeManager()); + } + + @Test + public void assertLikeOptimizations() + { + assertOptimizedMatches("unbound_string LIKE bound_pattern", "unbound_string LIKE CAST('%el%' AS varchar)"); + } + + @Override + protected void assertLike(byte[] value, String pattern, boolean expected) + { + CallExpression predicate = call( + "LIKE", + RESOLUTION.likeVarcharFunction(), + BOOLEAN, + ImmutableList.of( + new ConstantExpression(wrappedBuffer(value), VARCHAR), + new ConstantExpression(castVarcharToLikePattern(utf8Slice(pattern)), LIKE_PATTERN))); + assertEquals(optimizeRowExpression(predicate, EVALUATED), expected); + } + @Override + protected Object evaluate(String expression, boolean deterministic) + { + assertRoundTrip(expression); + RowExpression rowExpression = sqlToRowExpression(expression); + return optimizeRowExpression(rowExpression, EVALUATED); + } + + @Override + protected Object optimize(@Language("SQL") String expression) + { + assertRoundTrip(expression); + RowExpression parsedExpression = sqlToRowExpression(expression); + return optimizeRowExpression(parsedExpression, OPTIMIZED); + } + + @Override + protected Object optimizeRowExpression(RowExpression expression, ExpressionOptimizer.Level level) + { + return expressionOptimizer.optimize( + expression, + level, + TEST_SESSION.toConnectorSession(), + variable -> { + Symbol symbol = new Symbol(variable.getName()); + Object value = symbolConstant(symbol); + if (value == null) { + return new VariableReferenceExpression(Optional.empty(), symbol.getName(), SYMBOL_TYPES.get(symbol.toSymbolReference())); + } + return value; + }); + } + + @Override + protected void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + { + Object optimizedActual = optimize(actual); + Object optimizedExpected = optimize(expected); + assertEquals(optimizedActual, optimizedExpected); + } + + @Override + protected void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected) + { + Object actualOptimized = optimize(actual); + Object expectedOptimized = optimize(expected); + assertRowExpressionEvaluationEquals( + actualOptimized, + expectedOptimized); + } + + @Override + protected void assertDoNotOptimize(@Language("SQL") String expression, Level optimizationLevel) + { + assertRoundTrip(expression); + RowExpression rowExpression = sqlToRowExpression(expression); + Object rowExpressionResult = optimizeRowExpression(rowExpression, optimizationLevel); + assertRowExpressionEvaluationEquals(rowExpressionResult, rowExpression); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressions.java similarity index 88% rename from presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java rename to presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressions.java index 37295146c4296..785be015e06b8 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressions.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.sql; +package com.facebook.presto.tests.expressions; import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.QualifiedObjectName; @@ -24,14 +24,19 @@ import com.facebook.presto.common.type.SqlTimestampWithTimeZone; import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; -import com.facebook.presto.common.type.VarbinaryType; import com.facebook.presto.functionNamespace.json.JsonFileBasedFunctionNamespaceManagerFactory; +import com.facebook.presto.metadata.AnalyzePropertyManager; +import com.facebook.presto.metadata.CatalogManager; +import com.facebook.presto.metadata.ColumnPropertyManager; import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.HandleResolver; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.metadata.SchemaPropertyManager; +import com.facebook.presto.metadata.SessionPropertyManager; +import com.facebook.presto.metadata.TablePropertyManager; import com.facebook.presto.operator.scalar.FunctionAssertions; import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.function.AggregationFunctionMetadata; import com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.spi.function.Parameter; @@ -39,29 +44,26 @@ import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.relation.InputReferenceExpression; import com.facebook.presto.spi.relation.LambdaDefinitionExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.TestingRowExpressionTranslator; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.analyzer.FunctionsConfig; import com.facebook.presto.sql.parser.ParsingOptions; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.ExpressionInterpreter; -import com.facebook.presto.sql.planner.RowExpressionInterpreter; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.ExpressionRewriter; -import com.facebook.presto.sql.tree.ExpressionTreeRewriter; -import com.facebook.presto.sql.tree.FunctionCall; -import com.facebook.presto.sql.tree.LikePredicate; -import com.facebook.presto.sql.tree.NodeRef; -import com.facebook.presto.sql.tree.QualifiedName; -import com.facebook.presto.sql.tree.StringLiteral; +import com.facebook.presto.transaction.TransactionManager; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; @@ -75,7 +77,6 @@ import org.testng.annotations.Test; import java.math.BigInteger; -import java.util.Map; import java.util.Optional; import java.util.concurrent.TimeUnit; import java.util.stream.IntStream; @@ -91,6 +92,7 @@ import static com.facebook.presto.common.type.TimeZoneKey.getTimeZoneKey; import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.common.type.VarcharType.createVarcharType; import static com.facebook.presto.operator.scalar.ApplyFunction.APPLY_FUNCTION; @@ -98,28 +100,23 @@ import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP; -import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE; import static com.facebook.presto.sql.ExpressionFormatter.formatExpression; -import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; -import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; -import static com.facebook.presto.sql.planner.ExpressionInterpreter.expressionInterpreter; -import static com.facebook.presto.sql.planner.ExpressionInterpreter.expressionOptimizer; -import static com.facebook.presto.sql.planner.RowExpressionInterpreter.rowExpressionInterpreter; +import static com.facebook.presto.transaction.InMemoryTransactionManager.createTestTransactionManager; import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME; import static com.facebook.presto.util.AnalyzerUtil.createParsingOptions; import static com.facebook.presto.util.DateTimeZoneIndex.getDateTimeZone; import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.Slices.wrappedBuffer; import static java.lang.String.format; -import static java.util.Collections.emptyMap; import static java.util.Locale.ENGLISH; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertThrows; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; -public class TestExpressionInterpreter +public abstract class TestExpressions { public static final SqlInvokedFunction SQUARE_UDF_CPP = new SqlInvokedFunction( QualifiedObjectName.valueOf(new CatalogSchemaName("json", "test_schema"), "square"), @@ -142,11 +139,11 @@ public class TestExpressionInterpreter Optional.of(new AggregationFunctionMetadata(parseTypeSignature("ROW(double, int)"), false))); private static final int TEST_VARCHAR_TYPE_LENGTH = 17; - private static final TypeProvider SYMBOL_TYPES = TypeProvider.viewOf(ImmutableMap.builder() + protected static final TypeProvider SYMBOL_TYPES = TypeProvider.viewOf(ImmutableMap.builder() .put("bound_integer", INTEGER) .put("bound_long", BIGINT) .put("bound_string", createVarcharType(TEST_VARCHAR_TYPE_LENGTH)) - .put("bound_varbinary", VarbinaryType.VARBINARY) + .put("bound_varbinary", VARBINARY) .put("bound_double", DOUBLE) .put("bound_boolean", BOOLEAN) .put("bound_date", DATE) @@ -173,15 +170,40 @@ public class TestExpressionInterpreter .put("unbound_null_string", VARCHAR) .build()); - private static final SqlParser SQL_PARSER = new SqlParser(); - private static final Metadata METADATA = MetadataManager.createTestMetadataManager(); + protected static final HandleResolver HANDLE_RESOLVER = new HandleResolver(); + protected static final Metadata METADATA = createTestingMetadata(HANDLE_RESOLVER); + + // This complicated factory method is needed instead of MetadataManager.createTestMetadataManager because + // we need to retain access to the HandleResolver instance for serializing and deserializing custom registered functions + // in child tests + private static Metadata createTestingMetadata(HandleResolver handleResolver) + { + BlockEncodingManager blockEncodingManager = new BlockEncodingManager(); + CatalogManager catalogManager = new CatalogManager(); + TransactionManager transactionManager = createTestTransactionManager(catalogManager); + return new MetadataManager( + new FunctionAndTypeManager(transactionManager, blockEncodingManager, new FeaturesConfig(), new FunctionsConfig(), handleResolver, ImmutableSet.of()), + blockEncodingManager, + new SessionPropertyManager(), + new SchemaPropertyManager(), + new TablePropertyManager(), + new ColumnPropertyManager(), + new AnalyzePropertyManager(), + transactionManager); + } + + protected static final SqlParser SQL_PARSER = new SqlParser(); private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA); - private static final BlockEncodingSerde blockEncodingSerde = new BlockEncodingManager(); + protected static final BlockEncodingSerde BLOCK_ENCODING_SERDE = new BlockEncodingManager(); + + public TestExpressions() + { + METADATA.getFunctionAndTypeManager().registerBuiltInFunctions(ImmutableList.of(APPLY_FUNCTION)); + } @BeforeClass public void setup() { - METADATA.getFunctionAndTypeManager().registerBuiltInFunctions(ImmutableList.of(APPLY_FUNCTION)); setupJsonFunctionNamespaceManager(METADATA.getFunctionAndTypeManager()); } @@ -427,7 +449,7 @@ public void testCppAggregateFunctionCall() } // Run this method exactly once. - private void setupJsonFunctionNamespaceManager(FunctionAndTypeManager functionAndTypeManager) + protected void setupJsonFunctionNamespaceManager(FunctionAndTypeManager functionAndTypeManager) { functionAndTypeManager.addFunctionNamespaceFactory(new JsonFileBasedFunctionNamespaceManagerFactory()); functionAndTypeManager.loadFunctionNamespaceManager( @@ -1162,7 +1184,7 @@ public void testSimpleCase() "else 3 " + "end"); - assertOptimizedEquals("case true " + + assertOptimizedMatches("case true " + "when unbound_long = 1 then 1 " + "when 0 / 0 = 0 then 2 " + "else 33 end", @@ -1199,18 +1221,6 @@ public void testSimpleCase() "when unbound_long then 4 " + "end"); - assertOptimizedMatches("case 1 " + - "when unbound_long then 1 " + - "when 0 / 0 then 2 " + - "else 1 " + - "end", - "" + - "case BIGINT '1' " + - "when unbound_long then 1 " + - "when cast(fail(8, 'ignored failure message') AS integer) then 2 " + - "else 1 " + - "end"); - assertOptimizedMatches("case 1 " + "when 0 / 0 then 1 " + "when 0 / 0 then 2 " + @@ -1394,16 +1404,15 @@ public void testLikeOptimization() assertOptimizedEquals("unbound_string LIKE 'a#_b' ESCAPE '#'", "unbound_string = CAST('a_b' AS VARCHAR)"); assertOptimizedEquals("unbound_string LIKE 'a#%b' ESCAPE '#'", "unbound_string = CAST('a%b' AS VARCHAR)"); assertOptimizedEquals("unbound_string LIKE 'a#_##b' ESCAPE '#'", "unbound_string = CAST('a_#b' AS VARCHAR)"); - assertOptimizedEquals("unbound_string LIKE 'a#__b' ESCAPE '#'", "unbound_string LIKE 'a#__b' ESCAPE '#'"); - assertOptimizedEquals("unbound_string LIKE 'a##%b' ESCAPE '#'", "unbound_string LIKE 'a##%b' ESCAPE '#'"); + assertOptimizedMatches("unbound_string LIKE 'a#__b' ESCAPE '#'", "unbound_string LIKE 'a#__b' ESCAPE '#'"); + assertOptimizedMatches("unbound_string LIKE 'a##%b' ESCAPE '#'", "unbound_string LIKE 'a##%b' ESCAPE '#'"); assertOptimizedEquals("bound_string LIKE bound_pattern", "true"); assertOptimizedEquals("'abc' LIKE bound_pattern", "false"); - assertOptimizedEquals("unbound_string LIKE bound_pattern", "unbound_string LIKE bound_pattern"); assertDoNotOptimize("unbound_string LIKE 'abc%'", SERIALIZABLE); - assertOptimizedEquals("unbound_string LIKE unbound_pattern ESCAPE unbound_string", "unbound_string LIKE unbound_pattern ESCAPE unbound_string"); + assertOptimizedMatches("unbound_string LIKE unbound_pattern ESCAPE unbound_string", "unbound_string LIKE unbound_pattern ESCAPE unbound_string"); } @Test @@ -1586,123 +1595,27 @@ public void testLiterals() optimize("interval '3' day * unbound_long"); optimize("interval '3' year * unbound_long"); - assertEquals(optimize("X'1234'"), Slices.wrappedBuffer((byte) 0x12, (byte) 0x34)); - } - - private static void assertLike(byte[] value, String pattern, boolean expected) - { - Expression predicate = new LikePredicate( - rawStringLiteral(Slices.wrappedBuffer(value)), - new StringLiteral(pattern), - Optional.empty()); - assertEquals(evaluate(predicate, true), expected); + assertEquals(optimize("X'1234'"), wrappedBuffer((byte) 0x12, (byte) 0x34)); } + protected abstract Object evaluate(String expression, boolean deterministic); - private static StringLiteral rawStringLiteral(final Slice slice) - { - return new StringLiteral(slice.toStringUtf8()) - { - @Override - public Slice getSlice() - { - return slice; - } - }; - } + protected abstract Object optimize(@Language("SQL") String expression); - private static void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected) - { - assertEquals(optimize(actual), optimize(expected)); - } - - private static void assertRowExpressionEquals(Level level, @Language("SQL") String actual, @Language("SQL") String expected) - { - Object actualResult = optimize(toRowExpression(expression(actual)), level); - Object expectedResult = optimize(toRowExpression(expression(expected)), level); - if (actualResult instanceof Block && expectedResult instanceof Block) { - assertEquals(blockToSlice((Block) actualResult), blockToSlice((Block) expectedResult)); - return; - } - assertEquals(actualResult, expectedResult); - } - - private static void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected) - { - // replaces FunctionCalls to FailureFunction by fail() - Object actualOptimized = optimize(actual); - if (actualOptimized instanceof Expression) { - actualOptimized = ExpressionTreeRewriter.rewriteWith(new FailedFunctionRewriter(), (Expression) actualOptimized); - } - assertEquals( - actualOptimized, - rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expected))); - } - - private static Object optimize(@Language("SQL") String expression) - { - assertRoundTrip(expression); - - Expression parsedExpression = expression(expression); - Object expressionResult = optimize(parsedExpression); - - RowExpression rowExpression = toRowExpression(parsedExpression); - Object rowExpressionResult = optimize(rowExpression, OPTIMIZED); - assertExpressionAndRowExpressionEquals(expressionResult, rowExpressionResult); - return expressionResult; - } - - private static Expression expression(String expression) - { - return FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); - } + protected abstract void assertOptimizedEquals(@Language("SQL") String expression, @Language("SQL") String expected); - private static RowExpression toRowExpression(Expression expression) - { - return TRANSLATOR.translate(expression, SYMBOL_TYPES); - } + protected abstract void assertLike(byte[] value, String pattern, boolean expected); - private static Object optimize(Expression expression) - { - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP); - ExpressionInterpreter interpreter = expressionOptimizer(expression, METADATA, TEST_SESSION, expressionTypes); - return interpreter.optimize(variable -> { - Symbol symbol = new Symbol(variable.getName()); - Object value = symbolConstant(symbol); - if (value == null) { - return symbol.toSymbolReference(); - } - return value; - }); - } + protected abstract void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected); - private static Object optimize(RowExpression expression, Level level) - { - return new RowExpressionInterpreter(expression, METADATA.getFunctionAndTypeManager(), TEST_SESSION.toConnectorSession(), level).optimize(variable -> { - Symbol symbol = new Symbol(variable.getName()); - Object value = symbolConstant(symbol); - if (value == null) { - return new VariableReferenceExpression(Optional.empty(), symbol.getName(), SYMBOL_TYPES.get(symbol.toSymbolReference())); - } - return value; - }); - } + protected abstract void assertDoNotOptimize(@Language("SQL") String expression, ExpressionOptimizer.Level optimizationLevel); - private static void assertDoNotOptimize(@Language("SQL") String expression, Level optimizationLevel) + protected RowExpression sqlToRowExpression(String expression) { - assertRoundTrip(expression); - Expression translatedExpression = expression(expression); - RowExpression rowExpression = toRowExpression(translatedExpression); - - Object expressionResult = optimize(translatedExpression); - if (expressionResult instanceof Expression) { - expressionResult = toRowExpression((Expression) expressionResult); - } - Object rowExpressionResult = optimize(rowExpression, optimizationLevel); - assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); - assertRowExpressionEvaluationEquals(rowExpressionResult, rowExpression); + Expression parsedExpression = FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); + return TRANSLATOR.translate(parsedExpression, SYMBOL_TYPES); } - private static Object symbolConstant(Symbol symbol) + protected Object symbolConstant(Symbol symbol) { switch (symbol.getName().toLowerCase(ENGLISH)) { case "bound_integer": @@ -1733,32 +1646,14 @@ private static Object symbolConstant(Symbol symbol) return null; } - private static void assertExpressionAndRowExpressionEquals(Object expressionResult, Object rowExpressionResult) - { - if (rowExpressionResult instanceof RowExpression) { - // Cannot be completely evaluated into a constant; compare expressions - assertTrue(expressionResult instanceof Expression); - - // It is tricky to check the equivalence of an expression and a row expression. - // We rely on the optimized translator to fill the gap. - RowExpression translated = TRANSLATOR.translateAndOptimize((Expression) expressionResult, SYMBOL_TYPES); - assertRowExpressionEvaluationEquals(translated, rowExpressionResult); - } - else { - // We have constants; directly compare - assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); - } - } - /** * Assert the evaluation result of two row expressions equivalent * no matter they are constants or remaining row expressions. */ - private static void assertRowExpressionEvaluationEquals(Object left, Object right) + protected void assertRowExpressionEvaluationEquals(Object left, Object right) { if (right instanceof RowExpression) { assertTrue(left instanceof RowExpression); - // assertEquals(((RowExpression) left).getType(), ((RowExpression) right).getType()); if (left instanceof ConstantExpression) { if (isRemovableCast(right)) { assertRowExpressionEvaluationEquals(left, ((CallExpression) right).getArguments().get(0)); @@ -1770,6 +1665,13 @@ private static void assertRowExpressionEvaluationEquals(Object left, Object righ else if (left instanceof InputReferenceExpression || left instanceof VariableReferenceExpression) { assertEquals(left, right); } + else if (left instanceof CallExpression && ((CallExpression) left).getFunctionHandle().getName().contains("fail")) { + assertTrue(right instanceof CallExpression && ((CallExpression) right).getFunctionHandle().getName().contains("fail")); + assertEquals(((CallExpression) left).getArguments().size(), ((CallExpression) right).getArguments().size()); + for (int i = 0; i < ((CallExpression) left).getArguments().size(); i++) { + assertRowExpressionEvaluationEquals(((CallExpression) left).getArguments().get(i), ((CallExpression) right).getArguments().get(i)); + } + } else if (left instanceof CallExpression) { assertTrue(right instanceof CallExpression); assertEquals(((CallExpression) left).getFunctionHandle(), ((CallExpression) right).getFunctionHandle()); @@ -1806,7 +1708,7 @@ else if (left instanceof SpecialFormExpression) { } } - private static boolean isRemovableCast(Object value) + protected boolean isRemovableCast(Object value) { if (value instanceof CallExpression && new FunctionResolution(METADATA.getFunctionAndTypeManager().getFunctionAndTypeResolver()).isCastFunction(((CallExpression) value).getFunctionHandle())) { @@ -1817,57 +1719,35 @@ private static boolean isRemovableCast(Object value) return false; } - private static Slice blockToSlice(Block block) + protected Slice blockToSlice(Block block) { // This function is strictly for testing use only SliceOutput sliceOutput = new DynamicSliceOutput(1000); - BlockSerdeUtil.writeBlock(blockEncodingSerde, sliceOutput, block); + BlockSerdeUtil.writeBlock(BLOCK_ENCODING_SERDE, sliceOutput, block); return sliceOutput.slice(); } - private static void assertEvaluatedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + protected void assertEvaluatedEquals(@Language("SQL") String actual, @Language("SQL") String expected) { assertEquals(evaluate(actual, true), evaluate(expected, true)); } - private static Object evaluate(String expression, boolean deterministic) - { - assertRoundTrip(expression); - - Expression parsedExpression = FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); - - return evaluate(parsedExpression, deterministic); - } - - private static void assertRoundTrip(String expression) + protected void assertRoundTrip(String expression) { ParsingOptions parsingOptions = createParsingOptions(TEST_SESSION); assertEquals(SQL_PARSER.createExpression(expression, parsingOptions), SQL_PARSER.createExpression(formatExpression(SQL_PARSER.createExpression(expression, parsingOptions), Optional.empty()), parsingOptions)); } - - private static Object evaluate(Expression expression, boolean deterministic) + protected void assertRowExpressionEquals(ExpressionOptimizer.Level level, @Language("SQL") String actual, @Language("SQL") String expected) { - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP); - Object expressionResult = expressionInterpreter(expression, METADATA, TEST_SESSION, expressionTypes).evaluate(); - Object rowExpressionResult = rowExpressionInterpreter(TRANSLATOR.translateAndOptimize(expression), METADATA.getFunctionAndTypeManager(), TEST_SESSION.toConnectorSession()).evaluate(); - - if (deterministic) { - assertExpressionAndRowExpressionEquals(expressionResult, rowExpressionResult); + Object actualResult = optimizeRowExpression(sqlToRowExpression(actual), level); + Object expectedResult = optimizeRowExpression(sqlToRowExpression(expected), level); + if (actualResult instanceof Block && expectedResult instanceof Block) { + assertEquals(blockToSlice((Block) actualResult), blockToSlice((Block) expectedResult)); + return; } - return expressionResult; + assertEquals(actualResult, expectedResult); } - private static class FailedFunctionRewriter - extends ExpressionRewriter - { - @Override - public Expression rewriteFunctionCall(FunctionCall node, Object context, ExpressionTreeRewriter treeRewriter) - { - if (node.getName().equals(QualifiedName.of("fail"))) { - return new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(node.getArguments().get(0), new StringLiteral("ignored failure message"))); - } - return node; - } - } + protected abstract Object optimizeRowExpression(RowExpression expression, ExpressionOptimizer.Level level); } diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java index 7f47f65ff49e8..adcfc11e7ce00 100644 --- a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java @@ -34,6 +34,7 @@ import com.facebook.presto.spi.eventlistener.EventListener; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; import com.facebook.presto.sql.planner.NodePartitioningManager; import com.facebook.presto.testing.MaterializedResult; @@ -247,6 +248,12 @@ public TestingAccessControlManager getAccessControl() return source.getAccessControl(); } + @Override + public ExpressionOptimizerManager getExpressionManager() + { + return source.getExpressionManager(); + } + @Override public MaterializedResult execute(String sql) { From 6288bcd39f5d5b605697a0f89368b01b4d455da0 Mon Sep 17 00:00:00 2001 From: Tim Meehan Date: Tue, 30 Jul 2024 16:21:46 -0400 Subject: [PATCH 05/10] Add native row expression optimizer --- pom.xml | 7 + .../presto/metadata/HandleJsonModule.java | 19 +- .../presto/server/ServerMainModule.java | 5 + .../ExpressionOptimizerManager.java | 7 +- .../JsonCodecRowExpressionSerde.java | 48 ++ .../presto/testing/LocalQueryRunner.java | 5 +- .../planner/assertions/OptimizerAssert.java | 6 +- .../rule/TestSimplifyRowExpressions.java | 4 +- .../TestDelegatingRowExpressionOptimizer.java | 4 +- presto-native-plugin/pom.xml | 123 ++++ .../sql/expressions/ForSidecarInfo.java | 26 + .../NativeExpressionOptimizer.java | 610 ++++++++++++++++++ .../NativeExpressionOptimizerFactory.java | 64 ++ .../NativeExpressionOptimizerProvider.java | 42 ++ .../NativeExpressionsCommunicationModule.java | 29 + .../expressions/NativeExpressionsModule.java | 66 ++ .../expressions/NativeExpressionsPlugin.java | 41 ++ .../NativeSidecarExpressionInterpreter.java | 95 +++ .../RowExpressionDeserializer.java | 52 ++ .../expressions/RowExpressionSerializer.java | 52 ++ .../com.facebook.presto.spi.CoordinatorPlugin | 2 + .../TestNativeExpressionsPlugin.java | 40 ++ .../UnimplementedFunctionMetadataManager.java | 28 + .../UnimplementedFunctionResolution.java | 223 +++++++ .../expressions/UnimplementedNodeManager.java | 117 ++++ .../UnimplementedRowExpressionSerde.java | 33 + .../presto/spark/PrestoSparkModule.java | 5 + .../presto/spi/RowExpressionSerde.java | 23 + .../planner/ExpressionOptimizerContext.java | 10 +- presto-tests/pom.xml | 17 + .../expressions/TestNativeExpressions.java | 227 +++++++ 31 files changed, 2022 insertions(+), 8 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java create mode 100644 presto-native-plugin/pom.xml create mode 100644 presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/ForSidecarInfo.java create mode 100644 presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizer.java create mode 100644 presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerFactory.java create mode 100644 presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerProvider.java create mode 100644 presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsCommunicationModule.java create mode 100644 presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsModule.java create mode 100644 presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsPlugin.java create mode 100644 presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeSidecarExpressionInterpreter.java create mode 100644 presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionDeserializer.java create mode 100644 presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionSerializer.java create mode 100644 presto-native-plugin/src/main/resources/META-INF/services/com.facebook.presto.spi.CoordinatorPlugin create mode 100644 presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressionsPlugin.java create mode 100644 presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedFunctionMetadataManager.java create mode 100644 presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedFunctionResolution.java create mode 100644 presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedNodeManager.java create mode 100644 presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedRowExpressionSerde.java create mode 100644 presto-spi/src/main/java/com/facebook/presto/spi/RowExpressionSerde.java create mode 100644 presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestNativeExpressions.java diff --git a/pom.xml b/pom.xml index e53fa8717e359..f3486341bc2f9 100644 --- a/pom.xml +++ b/pom.xml @@ -199,6 +199,7 @@ presto-singlestore presto-hana presto-openapi + presto-native-plugin @@ -892,6 +893,12 @@ ${project.version} + + com.facebook.presto + presto-native-plugin + ${project.version} + + com.facebook.hive hive-dwrf diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java b/presto-main/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java index 0c07b99aaab4e..cc2dee1bd61d4 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java @@ -23,6 +23,18 @@ public class HandleJsonModule implements Module { + private final HandleResolver handleResolver; + + public HandleJsonModule() + { + this(null); + } + + public HandleJsonModule(HandleResolver handleResolver) + { + this.handleResolver = handleResolver; + } + @Override public void configure(Binder binder) { @@ -38,6 +50,11 @@ public void configure(Binder binder) jsonBinder(binder).addModuleBinding().to(FunctionHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(MetadataUpdateJacksonModule.class); - binder.bind(HandleResolver.class).in(Scopes.SINGLETON); + if (handleResolver == null) { + binder.bind(HandleResolver.class).in(Scopes.SINGLETON); + } + else { + binder.bind(HandleResolver.class).toInstance(handleResolver); + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java index dac2dd3440a92..59f70e5a79ad5 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java @@ -144,11 +144,13 @@ import com.facebook.presto.spi.ConnectorTypeSerde; import com.facebook.presto.spi.PageIndexerFactory; import com.facebook.presto.spi.PageSorter; +import com.facebook.presto.spi.RowExpressionSerde; import com.facebook.presto.spi.analyzer.ViewDefinition; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.DomainTranslator; import com.facebook.presto.spi.relation.PredicateCompiler; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spiller.FileSingleStreamSpillerFactory; import com.facebook.presto.spiller.GenericPartitioningSpillerFactory; @@ -187,6 +189,7 @@ import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -349,6 +352,7 @@ else if (serverConfig.isCoordinator()) { // expression manager binder.bind(ExpressionOptimizerManager.class).in(Scopes.SINGLETON); + binder.bind(RowExpressionSerde.class).to(JsonCodecRowExpressionSerde.class).in(Scopes.SINGLETON); // schema properties binder.bind(SchemaPropertyManager.class).in(Scopes.SINGLETON); @@ -542,6 +546,7 @@ public ListeningExecutorService createResourceManagerExecutor(ResourceManagerCon jsonCodecBinder(binder).bindJsonCodec(SqlInvokedFunction.class); jsonCodecBinder(binder).bindJsonCodec(TaskSource.class); jsonCodecBinder(binder).bindJsonCodec(TableWriteInfo.class); + jsonCodecBinder(binder).bindJsonCodec(RowExpression.class); smileCodecBinder(binder).bindSmileCodec(TaskStatus.class); smileCodecBinder(binder).bindSmileCodec(TaskInfo.class); thriftCodecBinder(binder).bindThriftCodec(TaskStatus.class); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java b/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java index ca9f39404a48e..5631d8b0613da 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java @@ -18,6 +18,7 @@ import com.facebook.presto.metadata.InternalNodeManager; import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.RowExpressionSerde; import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.sql.planner.ExpressionOptimizerContext; import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory; @@ -49,15 +50,17 @@ public class ExpressionOptimizerManager private final AtomicReference rowExpressionInterpreter = new AtomicReference<>(); private final NodeManager nodeManager; private final FunctionAndTypeManager functionAndTypeManager; + private final RowExpressionSerde rowExpressionSerde; private final FunctionResolution functionResolution; private final ExpressionOptimizer defaultExpressionOptimizer; @Inject - public ExpressionOptimizerManager(InternalNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager, NodeInfo nodeInfo) + public ExpressionOptimizerManager(InternalNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager, NodeInfo nodeInfo, RowExpressionSerde rowExpressionSerde) { requireNonNull(nodeManager, "nodeManager is null"); this.nodeManager = new PluginNodeManager(nodeManager, nodeInfo.getEnvironment()); this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()); this.defaultExpressionOptimizer = new RowExpressionOptimizer(functionAndTypeManager); rowExpressionInterpreter.set(defaultExpressionOptimizer); @@ -84,7 +87,7 @@ public void loadExpressions(Map properties) checkArgument( rowExpressionInterpreter.compareAndSet( defaultExpressionOptimizer, - expressionOptimizerFactories.get(factoryName).createOptimizer(properties, new ExpressionOptimizerContext(nodeManager, functionAndTypeManager, functionResolution))), + expressionOptimizerFactories.get(factoryName).createOptimizer(properties, new ExpressionOptimizerContext(nodeManager, rowExpressionSerde, functionAndTypeManager, functionResolution))), "ExpressionManager is already loaded"); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java b/presto-main/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java new file mode 100644 index 0000000000000..20cfcf151df6c --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.expressions; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.relation.RowExpression; + +import javax.inject.Inject; + +import java.nio.charset.StandardCharsets; + +import static java.util.Objects.requireNonNull; + +public class JsonCodecRowExpressionSerde + implements RowExpressionSerde +{ + private final JsonCodec codec; + + @Inject + public JsonCodecRowExpressionSerde(JsonCodec codec) + { + this.codec = requireNonNull(codec, "codec is null"); + } + + @Override + public String serialize(RowExpression expression) + { + return new String(codec.toBytes(expression), StandardCharsets.UTF_8); + } + + @Override + public RowExpression deserialize(String data) + { + return codec.fromBytes(data.getBytes(StandardCharsets.UTF_8)); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index d84eaeeec6bee..dad3f71e175a1 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -142,6 +142,7 @@ import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.StageExecutionDescriptor; import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spiller.FileSingleStreamSpillerFactory; import com.facebook.presto.spiller.GenericPartitioningSpillerFactory; import com.facebook.presto.spiller.GenericSpillerFactory; @@ -167,6 +168,7 @@ import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -483,7 +485,8 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, blockEncodingManager, featuresConfig); - expressionOptimizerManager = new ExpressionOptimizerManager(nodeManager, getFunctionAndTypeManager(), nodeInfo); + // TODO: @tdm wire this in + expressionOptimizerManager = new ExpressionOptimizerManager(nodeManager, getFunctionAndTypeManager(), nodeInfo, new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class))); GlobalSystemConnectorFactory globalSystemConnectorFactory = new GlobalSystemConnectorFactory(ImmutableSet.of( new NodeSystemTable(nodeManager), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java index 36f271133e194..4075560ed90ef 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java @@ -23,9 +23,11 @@ import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.sql.Optimizer; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.TypeProvider; @@ -49,6 +51,7 @@ import java.util.function.Consumer; import java.util.function.Function; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlan; import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlanDoesNotMatch; import static com.facebook.presto.transaction.TransactionBuilder.transaction; @@ -178,7 +181,8 @@ private List getMinimalOptimizers() new ExpressionOptimizerManager( new InMemoryNodeManager(), queryRunner.getFunctionAndTypeManager(), - new NodeInfo("test"))).rules())); + new NodeInfo("test"), + new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class)))).rules())); } private void inTransaction(Function transactionSessionConsumer) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java index 19568cb110fe9..ac12125a62646 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.tree.Expression; @@ -39,6 +40,7 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.SystemSessionProperties.DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; @@ -189,7 +191,7 @@ private static void assertSimplifies(String expression, String rowExpressionExpe Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expression)); InMemoryNodeManager nodeManager = new InMemoryNodeManager(); - ExpressionOptimizerManager expressionOptimizerManager = new ExpressionOptimizerManager(nodeManager, METADATA.getFunctionAndTypeManager(), new NodeInfo("test")); + ExpressionOptimizerManager expressionOptimizerManager = new ExpressionOptimizerManager(nodeManager, METADATA.getFunctionAndTypeManager(), new NodeInfo("test"), new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class))); expressionOptimizerManager.loadExpressions(); TestingRowExpressionTranslator translator = new TestingRowExpressionTranslator(METADATA); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/relational/TestDelegatingRowExpressionOptimizer.java b/presto-main/src/test/java/com/facebook/presto/sql/relational/TestDelegatingRowExpressionOptimizer.java index fb5088f9b72e6..045ebc8570039 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/relational/TestDelegatingRowExpressionOptimizer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/relational/TestDelegatingRowExpressionOptimizer.java @@ -27,6 +27,7 @@ import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.google.common.collect.ImmutableList; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; @@ -34,6 +35,7 @@ import java.net.URI; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.airlift.testing.Assertions.assertInstanceOf; import static com.facebook.presto.block.BlockAssertions.toValues; import static com.facebook.presto.common.function.OperatorType.EQUAL; @@ -63,7 +65,7 @@ public void setUp() { InMemoryNodeManager inMemoryNodeManager = new InMemoryNodeManager(); inMemoryNodeManager.addNode(new ConnectorId("test"), new InternalNode("id", URI.create("id"), new NodeVersion("test"), false)); - ExpressionOptimizerManager expressionOptimizerManager = new ExpressionOptimizerManager(inMemoryNodeManager, METADATA.getFunctionAndTypeManager(), new NodeInfo("env")); + ExpressionOptimizerManager expressionOptimizerManager = new ExpressionOptimizerManager(inMemoryNodeManager, METADATA.getFunctionAndTypeManager(), new NodeInfo("env"), new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class))); optimizer = new DelegatingRowExpressionOptimizer(METADATA, expressionOptimizerManager); } diff --git a/presto-native-plugin/pom.xml b/presto-native-plugin/pom.xml new file mode 100644 index 0000000000000..eeba938f82aa6 --- /dev/null +++ b/presto-native-plugin/pom.xml @@ -0,0 +1,123 @@ + + + 4.0.0 + + + com.facebook.presto + presto-root + 0.290-SNAPSHOT + + + presto-native-plugin + Presto - Session Property Managers + + + ${project.parent.basedir} + 1.8 + 1.8 + + + + + com.facebook.airlift + bootstrap + + + + com.facebook.airlift + json + + + + com.google.guava + guava + + + + com.google.inject + guice + + + + javax.inject + javax.inject + + + + com.fasterxml.jackson.core + jackson-databind + + + + com.fasterxml.jackson.core + jackson-core + + + + com.facebook.airlift + http-client + + + + + com.facebook.presto + presto-spi + provided + + + + com.facebook.presto + presto-common + provided + + + + io.airlift + units + provided + + + + com.fasterxml.jackson.core + jackson-annotations + provided + + + + com.facebook.drift + drift-api + provided + + + + io.airlift + slice + provided + + + + org.openjdk.jol + jol-core + provided + + + + + com.facebook.presto + presto-testng-services + test + + + + org.testng + testng + test + + + + com.facebook.airlift + testing + test + + + diff --git a/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/ForSidecarInfo.java b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/ForSidecarInfo.java new file mode 100644 index 0000000000000..8bf10b20223d2 --- /dev/null +++ b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/ForSidecarInfo.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.google.inject.BindingAnnotation; + +import java.lang.annotation.Retention; + +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@BindingAnnotation +public @interface ForSidecarInfo +{ +} diff --git a/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizer.java b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizer.java new file mode 100644 index 0000000000000..05faf7a6cf101 --- /dev/null +++ b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizer.java @@ -0,0 +1,610 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.SourceLocation; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.InputReferenceExpression; +import com.facebook.presto.spi.relation.IntermediateFormExpression; +import com.facebook.presto.spi.relation.LambdaDefinitionExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.RowExpressionVisitor; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; + +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.EVALUATED; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; + +public class NativeExpressionOptimizer + implements ExpressionOptimizer +{ + private final FunctionMetadataManager functionMetadataManager; + private final StandardFunctionResolution resolution; + private final NativeSidecarExpressionInterpreter rowExpressionInterpreterService; + + public NativeExpressionOptimizer( + NativeSidecarExpressionInterpreter rowExpressionInterpreterService, + FunctionMetadataManager functionMetadataManager, + StandardFunctionResolution resolution) + { + this.rowExpressionInterpreterService = requireNonNull(rowExpressionInterpreterService, "rowExpressionInterpreterService is null"); + this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); + this.resolution = requireNonNull(resolution, "resolution is null"); + } + + @Override + public RowExpression optimize(RowExpression expression, Level level, ConnectorSession session) + { + CollectingVisitor collectingVisitor = new CollectingVisitor(functionMetadataManager, resolution, level); + ReplacingVisitor replacingVisitor = new ReplacingVisitor(); + + checkState(level.ordinal() <= EVALUATED.ordinal(), "optimize(SymbolResolver) not allowed for interpreter"); + ResolvedRowExpression resolvedExpression = expression.accept(collectingVisitor, null); + collectingVisitor.addRowExpressionToOptimize(resolvedExpression); + Map expressions = collectingVisitor.getExpressionsToOptimize(); + if (!expressions.isEmpty()) { + Map replacements = rowExpressionInterpreterService.optimizeBatch(session, expressions, level); + return toRowExpression( + expression.getSourceLocation(), + expression.accept(replacingVisitor, new ReplacingState(replacements)), + expression.getType()); + } + return expression; + } + + @Override + public Object optimize(RowExpression expression, Level level, ConnectorSession session, Function variableResolver) + { + CollectingVisitor collectingVisitor = new CollectingVisitor(functionMetadataManager, resolution, level); + ReplacingVisitor replacingVisitor = new ReplacingVisitor(); + + checkState(level.ordinal() <= EVALUATED.ordinal(), "optimize(SymbolResolver) not allowed for interpreter"); + ResolvedRowExpression resolvedExpression = expression.accept(collectingVisitor, (VariableResolver) variableResolver::apply); + collectingVisitor.addRowExpressionToOptimize(resolvedExpression); + Map expressions = collectingVisitor.getExpressionsToOptimize(); + if (!expressions.isEmpty()) { + Map replacements = rowExpressionInterpreterService.optimizeBatch(session, expressions, level); + return toRowExpression( + expression.getSourceLocation(), + expression.accept(replacingVisitor, new ReplacingState(replacements)), + expression.getType()); + } + return expression; + } + + public interface VariableResolver + { + Object getValue(VariableReferenceExpression variable); + } + + private static final boolean CAN_BE_OPTIMIZED = true; + private static final boolean CANNOT_BE_OPTIMIZED = false; + + private static class ReplacingState + { + private final Map replacements; + + public ReplacingState(Map replacements) + { + this.replacements = requireNonNull(replacements, "replacements is null"); + } + + public Map getReplacements() + { + return replacements; + } + } + + private static class ResolvedRowExpression + { + private final boolean canBeOptimized; + private final boolean anyChildrenCanBeOptimized; + private final RowExpression originalExpression; + private final RowExpression resolvedExpression; + private final Set children; + + private ResolvedRowExpression(boolean canBeOptimized, RowExpression originalExpression, RowExpression resolvedExpression, Set children) + { + this.canBeOptimized = canBeOptimized; + this.anyChildrenCanBeOptimized = canBeOptimized || children.stream().anyMatch(ResolvedRowExpression::anyChildrenCanBeOptimized); + this.originalExpression = requireNonNull(originalExpression, "originalExpression is null"); + this.resolvedExpression = requireNonNull(resolvedExpression, "resolvedExpression is null"); + this.children = children instanceof HashSet ? children : new HashSet<>(children); + } + + public ResolvedRowExpression(boolean canBeOptimized, RowExpression originalExpression, RowExpression resolvedExpression, ResolvedRowExpression... children) + { + this(canBeOptimized, originalExpression, resolvedExpression, new HashSet<>(Arrays.asList(children))); + } + + public ResolvedRowExpression(boolean canBeOptimized, RowExpression originalExpression, List children) + { + this( + canBeOptimized, + originalExpression, + originalExpression.accept(new ResolvingVisitor(children), null), + new HashSet<>(children)); + } + + public ResolvedRowExpression(boolean canBeOptimized, RowExpression originalExpression, ResolvedRowExpression... children) + { + this(canBeOptimized, originalExpression, Arrays.asList(children)); + } + + public boolean canBeOptimized() + { + return canBeOptimized; + } + + public boolean anyChildrenCanBeOptimized() + { + return anyChildrenCanBeOptimized; + } + + public RowExpression getOriginalExpression() + { + return originalExpression; + } + + public RowExpression getResolvedExpression() + { + return resolvedExpression; + } + + public Set getChildren() + { + return children; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ResolvedRowExpression that = (ResolvedRowExpression) o; + return canBeOptimized == that.canBeOptimized && anyChildrenCanBeOptimized == that.anyChildrenCanBeOptimized && Objects.equals(resolvedExpression, that.resolvedExpression) && Objects.equals(children, that.children); + } + + @Override + public int hashCode() + { + return Objects.hash(canBeOptimized, anyChildrenCanBeOptimized, resolvedExpression, children); + } + + @Override + public String toString() + { + return "ResolvedRowExpression{" + + ", canBeOptimized=" + canBeOptimized + + ", anyChildrenCanBeOptimized=" + anyChildrenCanBeOptimized + + ", originalExpression=" + originalExpression + + ", resolvedExpression=" + resolvedExpression + + ", children=" + children + + '}'; + } + } + + private static class ResolvingVisitor + implements RowExpressionVisitor + { + private final List resolvedChildren; + + public ResolvingVisitor(List resolvedChildren) + { + this.resolvedChildren = requireNonNull(resolvedChildren, "resolvedChildren is null"); + } + + @Override + public RowExpression visitExpression(RowExpression node, Void context) + { + return node; + } + + @Override + public RowExpression visitCall(CallExpression node, Void context) + { + return new CallExpression( + node.getSourceLocation(), + node.getDisplayName(), + node.getFunctionHandle(), + node.getType(), + resolvedChildren.stream().map(ResolvedRowExpression::getResolvedExpression).collect(toImmutableList())); + } + + @Override + public RowExpression visitSpecialForm(SpecialFormExpression node, Void context) + { + return new SpecialFormExpression( + node.getSourceLocation(), + node.getForm(), + node.getType(), + resolvedChildren.stream().map(ResolvedRowExpression::getResolvedExpression).collect(toImmutableList())); + } + + @Override + public RowExpression visitLambda(LambdaDefinitionExpression node, Void context) + { + return new LambdaDefinitionExpression( + node.getSourceLocation(), + node.getArgumentTypes(), + node.getArguments(), + resolvedChildren.get(0).getResolvedExpression()); + } + } + + private class CollectingVisitor + implements RowExpressionVisitor + { + private final FunctionMetadataManager functionMetadataManager; + private final StandardFunctionResolution resolution; + private final Level optimizationLevel; + + public CollectingVisitor(FunctionMetadataManager functionMetadataManager, StandardFunctionResolution resolution, Level optimizationLevel) + { + this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); + this.resolution = requireNonNull(resolution, "resolution is null"); + this.optimizationLevel = requireNonNull(optimizationLevel, "optimizationLevel is null"); + } + private final Set expressionsToOptimize = new HashSet<>(); + + @Override + public ResolvedRowExpression visitInputReference(InputReferenceExpression node, Object context) + { + return new ResolvedRowExpression(CANNOT_BE_OPTIMIZED, node); + } + + @Override + public ResolvedRowExpression visitConstant(ConstantExpression node, Object context) + { + return new ResolvedRowExpression(CAN_BE_OPTIMIZED, node); + } + + @Override + public ResolvedRowExpression visitVariableReference(VariableReferenceExpression node, Object context) + { + if (context instanceof VariableResolver) { + Object value = ((VariableResolver) context).getValue(node); + if (value != null) { + if (value instanceof RowExpression) { + return new ResolvedRowExpression(CANNOT_BE_OPTIMIZED, (RowExpression) value); + } + return new ResolvedRowExpression(CAN_BE_OPTIMIZED, node, new ConstantExpression(node.getSourceLocation(), value, node.getType())); + } + } + return new ResolvedRowExpression(CANNOT_BE_OPTIMIZED, node); + } + + @Override + public ResolvedRowExpression visitCall(CallExpression node, Object context) + { + List arguments = node.getArguments(); + List resolvedArguments = new ArrayList<>(); + for (RowExpression argument : arguments) { + ResolvedRowExpression returned = argument.accept(this, context); + resolvedArguments.add(returned); + } + + FunctionMetadata functionMetadata = functionMetadataManager.getFunctionMetadata(node.getFunctionHandle()); + boolean canBeEvaluated = (optimizationLevel.ordinal() < EVALUATED.ordinal() && functionMetadata.isDeterministic()) || optimizationLevel.ordinal() == EVALUATED.ordinal(); + if (node.getArguments().isEmpty()) { + return new ResolvedRowExpression(canBeEvaluated ? CAN_BE_OPTIMIZED : CANNOT_BE_OPTIMIZED, node); + } + + boolean anyConstantFoldable = false; + boolean allConstantFoldable = true; + + for (ResolvedRowExpression returned : resolvedArguments) { + boolean constantFoldable = returned.canBeOptimized(); + anyConstantFoldable = anyConstantFoldable || constantFoldable; + allConstantFoldable = allConstantFoldable && constantFoldable; + } + + FunctionHandle functionHandle = node.getFunctionHandle(); + + if (resolution.isCastFunction(functionHandle)) { + // If the function is a cast function, we can optimize it if the argument is constant foldable + // e.g. CAST(1 AS BIGINT) is 1 + ResolvedRowExpression resolved = node.getArguments().get(0).accept(this, context); + boolean canBeOptimized = resolved.canBeOptimized() || node.getArguments().get(0).getType().equals(node.getType()); + ResolvedRowExpression rowExpression = new ResolvedRowExpression( + // If the destination type and the source type are the same, the cast is a no-op, mark it as constant foldable + canBeOptimized, + node, + resolved); + addRowExpressionToOptimize(rowExpression); + return rowExpression; + } + + boolean canBeOptimized = anyConstantFoldable && canBeEvaluated; + ResolvedRowExpression resolvedRowExpression = new ResolvedRowExpression( + canBeOptimized, + // If this node can be optimized, then the children shouldn't be + node, + resolvedArguments); + addRowExpressionToOptimize(resolvedRowExpression); + return resolvedRowExpression; + } + + @Override + public ResolvedRowExpression visitSpecialForm(SpecialFormExpression node, Object context) + { + if (node.getForm() == SWITCH) { + return handleSwitchExpression(node, context); + } + + ImmutableList.Builder resolvedArgumentsBuilder = ImmutableList.builder(); + boolean anyArgumentsConstantFoldable = false; + boolean allArgumentsConstantFoldable = true; + + for (RowExpression argument : node.getArguments()) { + ResolvedRowExpression returned = argument.accept(this, context); + resolvedArgumentsBuilder.add(returned); + + boolean canBeOptimized = returned.canBeOptimized(); + anyArgumentsConstantFoldable = anyArgumentsConstantFoldable || canBeOptimized; + allArgumentsConstantFoldable = allArgumentsConstantFoldable && canBeOptimized; + } + List resolvedArguments = resolvedArgumentsBuilder.build(); + + switch (node.getForm()) { + case IF: { + ResolvedRowExpression returned = node.getArguments().get(0).accept(this, context); + // If the first argument is constant foldable, the whole expression is constant foldable + boolean canBeOptimized = returned.canBeOptimized(); + ResolvedRowExpression rowExpression = new ResolvedRowExpression( + canBeOptimized, + node, + resolvedArguments); + addRowExpressionToOptimize(rowExpression); + return rowExpression; + } + case COALESCE: { + ImmutableSet.Builder builder = ImmutableSet.builder(); + // Check if there's any duplicate arguments, these can be de-duplicated + for (ResolvedRowExpression argument : resolvedArguments) { + ResolvedRowExpression returned = argument.getResolvedExpression().accept(this, context); + // The duplicate argument must either be a leaf (variable reference) or constant foldable + if (returned.canBeOptimized()) { + builder.add(argument); + } + } + // If there were any duplicates, or if there's no arguments (cancel out), or if there's only one argument (just return it), + // then it's also constant foldable + boolean canBeOptimized = builder.build().size() <= resolvedArguments.size() || resolvedArguments.size() <= 1; + ResolvedRowExpression rowExpression = new ResolvedRowExpression( + canBeOptimized, + node, + resolvedArguments); + addRowExpressionToOptimize(rowExpression); + return rowExpression; + } + default: + ResolvedRowExpression rowExpression = new ResolvedRowExpression( + anyArgumentsConstantFoldable, + node, + resolvedArguments); + addRowExpressionToOptimize(rowExpression); + return rowExpression; + } + } + + /** + * Switch statements require special handling, because when statements require special handling ({@code RowExpressionInterpreter} can't handle them). + * This method will resolve the expression and all when clauses, and if any part of the switch expression is constant foldable, the entire switch + * expression will be sent to the delegated expression optimizer. + */ + private ResolvedRowExpression handleSwitchExpression(SpecialFormExpression node, Object context) + { + // First argument is the expression, follow by N when clauses, and an optional else clause + RowExpression expression = node.getArguments().get(0); + Optional elseClause = Optional.empty(); + + // Collect all when clauses + List whenClauses = buildWhenClauses(node); + // Determine if the final clause is an else clause or a when clause + RowExpression finalClause = node.getArguments().get(node.getArguments().size() - 1); + if (finalClause instanceof SpecialFormExpression && ((SpecialFormExpression) finalClause).getForm() == WHEN) { + whenClauses = ImmutableList.builder().addAll(whenClauses).add(finalClause).build(); + } + else { + elseClause = Optional.of(finalClause); + } + + boolean canBeOptimized; + ResolvedRowExpression resolvedExpression; + List resolvedWhenClauses = new ArrayList<>(); + Optional resolvedElseClause = Optional.empty(); + + // First determine if the expression is constant foldable + resolvedExpression = expression.accept(this, context); + canBeOptimized = resolvedExpression.canBeOptimized(); + addRowExpressionToOptimize(resolvedExpression); + + // Next determine if all when clauses are constant foldable + for (RowExpression whenClause : whenClauses) { + SpecialFormExpression whenClauseSpecialForm = (SpecialFormExpression) whenClause; + List whenClauseArguments = whenClauseSpecialForm.getArguments(); + checkArgument(whenClauseArguments.size() == 2, "WHEN clause must have 2 arguments, got [%s]", whenClauseArguments); + ResolvedRowExpression resolvedArgument = whenClauseArguments.get(0).accept(this, context); + canBeOptimized = canBeOptimized || resolvedArgument.canBeOptimized(); + + ResolvedRowExpression thenClause = whenClauseArguments.get(1).accept(this, context); + + // Create a rewritten when clause that's resolved all variables + ResolvedRowExpression resolvedWhenClause = new ResolvedRowExpression( + resolvedArgument.canBeOptimized() || thenClause.canBeOptimized(), + whenClauseSpecialForm, + resolvedArgument, + thenClause); + addRowExpressionToOptimize(resolvedWhenClause); + resolvedWhenClauses.add(resolvedWhenClause); + } + + // Resolve the else clause if it exists + if (elseClause.isPresent()) { + ResolvedRowExpression elseExpression = elseClause.get().accept(this, context); + resolvedElseClause = Optional.of(elseExpression); + canBeOptimized = canBeOptimized || elseExpression.canBeOptimized(); + addRowExpressionToOptimize(elseExpression); + } + + ImmutableList.Builder resolvedArguments = ImmutableList.builder().add(resolvedExpression).addAll(resolvedWhenClauses); + resolvedElseClause.ifPresent(resolvedArguments::add); + + ResolvedRowExpression rowExpression = new ResolvedRowExpression( + canBeOptimized, + node, + resolvedArguments.build()); + // If any part of the entire switch expression is constant foldable, send the whole thing over + addRowExpressionToOptimize(rowExpression); + // Otherwise it's not constant foldable. + return rowExpression; + } + + private List buildWhenClauses(SpecialFormExpression node) + { + ImmutableList.Builder whenClausesBuilder = ImmutableList.builder(); + for (int i = 1; i < node.getArguments().size() - 1; i++) { + whenClausesBuilder.add(node.getArguments().get(i)); + } + return whenClausesBuilder.build(); + } + + @Override + public ResolvedRowExpression visitLambda(LambdaDefinitionExpression node, Object context) + { + ResolvedRowExpression resolvedBody = node.getBody().accept(this, null); + return new ResolvedRowExpression( + resolvedBody.canBeOptimized(), + node, + resolvedBody); + } + + @Override + public ResolvedRowExpression visitIntermediateFormExpression(IntermediateFormExpression intermediateFormExpression, Object context) + { + return new ResolvedRowExpression(CANNOT_BE_OPTIMIZED, intermediateFormExpression); + } + + private void removeChildren(ResolvedRowExpression resolvedRowExpression) + { + resolvedRowExpression.getChildren().forEach(this::removeChildren); + resolvedRowExpression.getChildren().clear(); + expressionsToOptimize.remove(resolvedRowExpression); + } + + private void addRowExpressionToOptimize(ResolvedRowExpression resolvedRowExpression) + { + if (resolvedRowExpression.canBeOptimized()) { + removeChildren(resolvedRowExpression); + expressionsToOptimize.add(resolvedRowExpression); + } + } + + public Map getExpressionsToOptimize() + { + return expressionsToOptimize.stream().collect(toImmutableMap( + ResolvedRowExpression::getOriginalExpression, + ResolvedRowExpression::getResolvedExpression)); + } + } + + private class ReplacingVisitor + implements RowExpressionVisitor + { + @Override + public Object visitExpression(RowExpression originalExpression, ReplacingState context) + { + return context.getReplacements().getOrDefault(originalExpression, originalExpression); + } + + @Override + public Object visitLambda(LambdaDefinitionExpression lambda, ReplacingState context) + { + if (context.getReplacements().containsKey(lambda.getBody())) { + return new LambdaDefinitionExpression( + lambda.getSourceLocation(), + lambda.getArgumentTypes(), + lambda.getArguments(), + toRowExpression(lambda.getSourceLocation(), context.getReplacements().get(lambda.getBody()), lambda.getBody().getType())); + } + return lambda; + } + + @Override + public Object visitCall(CallExpression call, ReplacingState context) + { + if (context.getReplacements().containsKey(call)) { + return context.getReplacements().get(call); + } + List updatedArguments = call.getArguments().stream() + .map(argument -> toRowExpression(argument.getSourceLocation(), argument.accept(this, context), argument.getType())) + .collect(toImmutableList()); + return new CallExpression(call.getSourceLocation(), call.getDisplayName(), call.getFunctionHandle(), call.getType(), updatedArguments); + } + + @Override + public Object visitSpecialForm(SpecialFormExpression specialForm, ReplacingState context) + { + if (context.getReplacements().containsKey(specialForm)) { + return context.getReplacements().get(specialForm); + } + List updatedArguments = specialForm.getArguments().stream() + .map(argument -> toRowExpression(argument.getSourceLocation(), argument.accept(this, context), argument.getType())) + .collect(toImmutableList()); + return new SpecialFormExpression(specialForm.getSourceLocation(), specialForm.getForm(), specialForm.getType(), updatedArguments); + } + } + + private static RowExpression toRowExpression(Optional sourceLocation, Object object, Type type) + { + requireNonNull(type, "type is null"); + + if (object instanceof RowExpression) { + return (RowExpression) object; + } + + return new ConstantExpression(sourceLocation, object, type); + } +} diff --git a/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerFactory.java b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerFactory.java new file mode 100644 index 0000000000000..4157cffe2c839 --- /dev/null +++ b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerFactory.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.airlift.json.JsonModule; +import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerContext; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory; +import com.google.inject.Injector; + +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public class NativeExpressionOptimizerFactory + implements ExpressionOptimizerFactory +{ + private final ClassLoader classLoader; + + public NativeExpressionOptimizerFactory(ClassLoader classLoader) + { + this.classLoader = requireNonNull(classLoader, "classLoader is null"); + } + + @Override + public ExpressionOptimizer createOptimizer(Map config, ExpressionOptimizerContext context) + { + requireNonNull(context, "context is null"); + + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + Bootstrap app = new Bootstrap( + new JsonModule(), + new NativeExpressionsCommunicationModule(), + new NativeExpressionsModule(context.getNodeManager(), context.getRowExpressionSerde(), context.getFunctionMetadataManager(), context.getFunctionResolution())); + + Injector injector = app + .noStrictConfig() + .doNotInitializeLogging() + .setRequiredConfigurationProperties(config) + .quiet() + .initialize(); + return injector.getInstance(NativeExpressionOptimizerProvider.class).createOptimizer(); + } + } + + @Override + public String getName() + { + return "native"; + } +} diff --git a/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerProvider.java b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerProvider.java new file mode 100644 index 0000000000000..8efe3b8a77eb5 --- /dev/null +++ b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerProvider.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.relation.ExpressionOptimizer; + +import javax.inject.Inject; + +import static java.util.Objects.requireNonNull; + +public class NativeExpressionOptimizerProvider +{ + private final NativeSidecarExpressionInterpreter expressionInterpreterService; + private final FunctionMetadataManager functionMetadataManager; + private final StandardFunctionResolution resolution; + + @Inject + public NativeExpressionOptimizerProvider(NativeSidecarExpressionInterpreter expressionInterpreterService, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution resolution) + { + this.expressionInterpreterService = requireNonNull(expressionInterpreterService, "expressionInterpreterService is null"); + this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); + this.resolution = requireNonNull(resolution, "resolution is null"); + } + + public ExpressionOptimizer createOptimizer() + { + return new NativeExpressionOptimizer(expressionInterpreterService, functionMetadataManager, resolution); + } +} diff --git a/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsCommunicationModule.java b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsCommunicationModule.java new file mode 100644 index 0000000000000..fdc6da53f91bc --- /dev/null +++ b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsCommunicationModule.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.google.inject.Binder; +import com.google.inject.Module; + +import static com.facebook.airlift.http.client.HttpClientBinder.httpClientBinder; + +public class NativeExpressionsCommunicationModule + implements Module +{ + @Override + public void configure(Binder binder) + { + httpClientBinder(binder).bindHttpClient("sidecar", ForSidecarInfo.class); + } +} diff --git a/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsModule.java b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsModule.java new file mode 100644 index 0000000000000..046b4e60eac7a --- /dev/null +++ b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsModule.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.airlift.json.JsonModule; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.relation.RowExpression; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static com.facebook.airlift.json.JsonBinder.jsonBinder; +import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static java.util.Objects.requireNonNull; + +public class NativeExpressionsModule + implements Module +{ + private final NodeManager nodeManager; + private final RowExpressionSerde rowExpressionSerde; + private final FunctionMetadataManager functionMetadataManager; + private final StandardFunctionResolution functionResolution; + + public NativeExpressionsModule(NodeManager nodeManager, RowExpressionSerde rowExpressionSerde, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution) + { + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); + this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); + this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); + } + + @Override + public void configure(Binder binder) + { + // Core dependencies + binder.bind(NodeManager.class).toInstance(nodeManager); + binder.bind(RowExpressionSerde.class).toInstance(rowExpressionSerde); + binder.bind(FunctionMetadataManager.class).toInstance(functionMetadataManager); + binder.bind(StandardFunctionResolution.class).toInstance(functionResolution); + + // JSON dependencies and setup + binder.install(new JsonModule()); + jsonBinder(binder).addDeserializerBinding(RowExpression.class).to(RowExpressionDeserializer.class).in(Scopes.SINGLETON); + jsonBinder(binder).addSerializerBinding(RowExpression.class).to(RowExpressionSerializer.class).in(Scopes.SINGLETON); + jsonCodecBinder(binder).bindListJsonCodec(RowExpression.class); + + binder.bind(NativeSidecarExpressionInterpreter.class).in(Scopes.SINGLETON); + + // The main service provider + binder.bind(NativeExpressionOptimizerProvider.class).in(Scopes.SINGLETON); + } +} diff --git a/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsPlugin.java b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsPlugin.java new file mode 100644 index 0000000000000..14a23b87be31a --- /dev/null +++ b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsPlugin.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.presto.spi.CoordinatorPlugin; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory; +import com.google.common.collect.ImmutableList; + +public class NativeExpressionsPlugin + implements CoordinatorPlugin +{ + @Override + public Iterable getRowExpressionInterpreterServiceFactories() + { + return ImmutableList.of(new NativeExpressionOptimizerFactory(getClassLoader())); + } + + private static ClassLoader getClassLoader() + { + return firstNonNull(Thread.currentThread().getContextClassLoader(), NativeExpressionsPlugin.class.getClassLoader()); + } + + private static ClassLoader firstNonNull(ClassLoader contextClassLoader, ClassLoader classLoader) + { + if (contextClassLoader != null) { + return contextClassLoader; + } + return classLoader; + } +} diff --git a/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeSidecarExpressionInterpreter.java b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeSidecarExpressionInterpreter.java new file mode 100644 index 0000000000000..165ba866b6371 --- /dev/null +++ b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeSidecarExpressionInterpreter.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.airlift.http.client.HttpClient; +import com.facebook.airlift.http.client.HttpUriBuilder; +import com.facebook.airlift.http.client.Request; +import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.Node; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import javax.inject.Inject; + +import java.net.URI; +import java.util.List; +import java.util.Map; + +import static com.facebook.airlift.http.client.JsonBodyGenerator.jsonBodyGenerator; +import static com.facebook.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; +import static com.facebook.airlift.http.client.Request.Builder.preparePost; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.net.HttpHeaders.ACCEPT; +import static com.google.common.net.HttpHeaders.CONTENT_TYPE; +import static com.google.common.net.MediaType.JSON_UTF_8; +import static java.util.Objects.requireNonNull; + +public class NativeSidecarExpressionInterpreter +{ + private final NodeManager nodeManager; + private final HttpClient httpClient; + private final JsonCodec> rowExpressionSerde; + + @Inject + public NativeSidecarExpressionInterpreter(NodeManager nodeManager, @ForSidecarInfo HttpClient httpClient, JsonCodec> rowExpressionSerde) + { + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.httpClient = requireNonNull(httpClient, "httpClient is null"); + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); + } + + public Map optimizeBatch(ConnectorSession session, Map expressions, ExpressionOptimizer.Level level) + { + ImmutableList.Builder unaliasedBuilder = ImmutableList.builder(); + ImmutableList.Builder aliasedBuilder = ImmutableList.builder(); + for (Map.Entry entry : expressions.entrySet()) { + unaliasedBuilder.add(entry.getKey()); + aliasedBuilder.add(entry.getValue()); + } + List unaliased = unaliasedBuilder.build(); + List aliased = aliasedBuilder.build(); + + Request request = preparePost() + .setUri(getLocation()) + .setBodyGenerator(jsonBodyGenerator(rowExpressionSerde, aliased)) + .setHeader(CONTENT_TYPE, JSON_UTF_8.toString()) + .setHeader(ACCEPT, JSON_UTF_8.toString()) + .build(); + + List optimized = httpClient.execute(request, createJsonResponseHandler(rowExpressionSerde)); + checkArgument(optimized.size() == aliased.size(), "Expected %s optimized expressions, but got %s", aliased.size(), optimized.size()); + + ImmutableMap.Builder result = ImmutableMap.builder(); + for (int i = 0; i < optimized.size(); i++) { + result.put(unaliased.get(i), optimized.get(i)); + } + return result.build(); + } + + private URI getLocation() + { + Node sidecarNode = nodeManager.getSidecarNode(); + return HttpUriBuilder.uriBuilder() + .scheme("http") // The sidecar is presumed to be colocated with the coordinator + .host(sidecarNode.getHost()) + .port(sidecarNode.getHostAndPort().getPort()) + .appendPath("/v1/expressions") + .build(); + } +} diff --git a/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionDeserializer.java b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionDeserializer.java new file mode 100644 index 0000000000000..ffc7039d4568b --- /dev/null +++ b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionDeserializer.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.relation.RowExpression; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.jsontype.TypeDeserializer; +import com.google.inject.Inject; + +import java.io.IOException; + +import static java.util.Objects.requireNonNull; + +public final class RowExpressionDeserializer + extends JsonDeserializer +{ + private final RowExpressionSerde rowExpressionSerde; + + @Inject + public RowExpressionDeserializer(RowExpressionSerde rowExpressionSerde) + { + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); + } + + @Override + public RowExpression deserialize(JsonParser jsonParser, DeserializationContext context) + throws IOException + { + return rowExpressionSerde.deserialize(jsonParser.readValueAsTree().toString()); + } + + @Override + public RowExpression deserializeWithType(JsonParser jsonParser, DeserializationContext context, TypeDeserializer typeDeserializer) + throws IOException + { + return deserialize(jsonParser, context); + } +} diff --git a/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionSerializer.java b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionSerializer.java new file mode 100644 index 0000000000000..6ca5dcc4354f4 --- /dev/null +++ b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionSerializer.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.relation.RowExpression; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.jsontype.TypeSerializer; +import com.google.inject.Inject; + +import java.io.IOException; + +import static java.util.Objects.requireNonNull; + +public final class RowExpressionSerializer + extends JsonSerializer +{ + private final RowExpressionSerde rowExpressionSerde; + + @Inject + public RowExpressionSerializer(RowExpressionSerde rowExpressionSerde) + { + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); + } + + @Override + public void serialize(RowExpression rowExpression, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) + throws IOException + { + jsonGenerator.writeRawValue(rowExpressionSerde.serialize(rowExpression)); + } + + @Override + public void serializeWithType(RowExpression rowExpression, JsonGenerator jsonGenerator, SerializerProvider serializerProvider, TypeSerializer typeSerializer) + throws IOException + { + serialize(rowExpression, jsonGenerator, serializerProvider); + } +} diff --git a/presto-native-plugin/src/main/resources/META-INF/services/com.facebook.presto.spi.CoordinatorPlugin b/presto-native-plugin/src/main/resources/META-INF/services/com.facebook.presto.spi.CoordinatorPlugin new file mode 100644 index 0000000000000..0ef5589b94dfd --- /dev/null +++ b/presto-native-plugin/src/main/resources/META-INF/services/com.facebook.presto.spi.CoordinatorPlugin @@ -0,0 +1,2 @@ +com.facebook.presto.session.sql.expressions.NativeExpressionsPlugin + diff --git a/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressionsPlugin.java b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressionsPlugin.java new file mode 100644 index 0000000000000..f6521926142a2 --- /dev/null +++ b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressionsPlugin.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.presto.spi.CoordinatorPlugin; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerContext; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.google.common.collect.Iterables.getOnlyElement; + +public class TestNativeExpressionsPlugin +{ + @Test + public void testLoadPlugin() + { + CoordinatorPlugin plugin = new NativeExpressionsPlugin(); + Iterable serviceFactories = plugin.getRowExpressionInterpreterServiceFactories(); + ExpressionOptimizerFactory factory = getOnlyElement(serviceFactories); + factory.createOptimizer( + ImmutableMap.of(), + new ExpressionOptimizerContext( + new UnimplementedNodeManager(), + new UnimplementedRowExpressionSerde(), + new UnimplementedFunctionMetadataManager(), + new UnimplementedFunctionResolution())); + } +} diff --git a/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedFunctionMetadataManager.java b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedFunctionMetadataManager.java new file mode 100644 index 0000000000000..0cb1ccddfcec9 --- /dev/null +++ b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedFunctionMetadataManager.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.function.FunctionMetadataManager; + +class UnimplementedFunctionMetadataManager + implements FunctionMetadataManager +{ + @Override + public FunctionMetadata getFunctionMetadata(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } +} diff --git a/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedFunctionResolution.java b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedFunctionResolution.java new file mode 100644 index 0000000000000..b70fe03d113e9 --- /dev/null +++ b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedFunctionResolution.java @@ -0,0 +1,223 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.presto.common.function.OperatorType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.StandardFunctionResolution; + +import java.util.List; + +class UnimplementedFunctionResolution + implements StandardFunctionResolution +{ + @Override + public FunctionHandle notFunction() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isNotFunction(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle negateFunction(Type type) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isNegateFunction(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle likeVarcharFunction() + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle likeCharFunction(Type valueType) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isLikeFunction(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle likePatternFunction() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isLikePatternFunction(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle arrayConstructor(List argumentTypes) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle arithmeticFunction(OperatorType operator, Type leftType, Type rightType) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isArithmeticFunction(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle comparisonFunction(OperatorType operator, Type leftType, Type rightType) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isComparisonFunction(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isEqualsFunction(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle betweenFunction(Type valueType, Type lowerBoundType, Type upperBoundType) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isBetweenFunction(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle subscriptFunction(Type baseType, Type indexType) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isSubscriptFunction(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isCastFunction(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isCountFunction(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle countFunction() + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle countFunction(Type valueType) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isMaxFunction(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle maxFunction(Type valueType) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle greatestFunction(List valueTypes) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isMinFunction(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle minFunction(Type valueType) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle leastFunction(List valueTypes) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isApproximateCountDistinctFunction(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle approximateCountDistinctFunction(Type valueType) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isApproximateSetFunction(FunctionHandle functionHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionHandle approximateSetFunction(Type valueType) + { + throw new UnsupportedOperationException(); + } +} diff --git a/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedNodeManager.java b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedNodeManager.java new file mode 100644 index 0000000000000..a8d7c5e31c4b7 --- /dev/null +++ b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedNodeManager.java @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.Node; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.NodePoolType; + +import java.net.URI; +import java.util.Set; + +class UnimplementedNodeManager + implements NodeManager +{ + @Override + public Set getAllNodes() + { + throw new UnsupportedOperationException(); + } + + @Override + public Set getWorkerNodes() + { + throw new UnsupportedOperationException(); + } + + @Override + public Node getCurrentNode() + { + return new Node() + { + @Override + public String getHost() + { + throw new UnsupportedOperationException(); + } + + @Override + public HostAddress getHostAndPort() + { + throw new UnsupportedOperationException(); + } + + @Override + public URI getHttpUri() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getNodeIdentifier() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getVersion() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isCoordinator() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isResourceManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isCatalogServer() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isCoordinatorSidecar() + { + throw new UnsupportedOperationException(); + } + + @Override + public NodePoolType getPoolType() + { + throw new UnsupportedOperationException(); + } + }; + } + + @Override + public Node getSidecarNode() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getEnvironment() + { + throw new UnsupportedOperationException(); + } +} diff --git a/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedRowExpressionSerde.java b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedRowExpressionSerde.java new file mode 100644 index 0000000000000..c86447be0fa91 --- /dev/null +++ b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/UnimplementedRowExpressionSerde.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.relation.RowExpression; + +class UnimplementedRowExpressionSerde + implements RowExpressionSerde +{ + @Override + public String serialize(RowExpression expression) + { + throw new UnsupportedOperationException(); + } + + @Override + public RowExpression deserialize(String value) + { + throw new UnsupportedOperationException(); + } +} diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java index aeeacdcf588f6..0c2c489ef45ad 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java @@ -134,11 +134,13 @@ import com.facebook.presto.spi.ConnectorTypeSerde; import com.facebook.presto.spi.PageIndexerFactory; import com.facebook.presto.spi.PageSorter; +import com.facebook.presto.spi.RowExpressionSerde; import com.facebook.presto.spi.analyzer.ViewDefinition; import com.facebook.presto.spi.memory.ClusterMemoryPoolManager; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.DomainTranslator; import com.facebook.presto.spi.relation.PredicateCompiler; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spiller.GenericPartitioningSpillerFactory; import com.facebook.presto.spiller.GenericSpillerFactory; @@ -171,6 +173,7 @@ import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -293,6 +296,7 @@ protected void setup(Binder binder) jsonCodecBinder(binder).bindJsonCodec(PrestoSparkLocalShuffleWriteInfo.class); jsonCodecBinder(binder).bindJsonCodec(BatchTaskUpdateRequest.class); jsonCodecBinder(binder).bindJsonCodec(BroadcastFileInfo.class); + jsonCodecBinder(binder).bindJsonCodec(RowExpression.class); // smile codecs smileCodecBinder(binder).bindSmileCodec(TaskSource.class); @@ -339,6 +343,7 @@ protected void setup(Binder binder) // expression manager binder.bind(ExpressionOptimizerManager.class).in(Scopes.SINGLETON); + binder.bind(RowExpressionSerde.class).to(JsonCodecRowExpressionSerde.class).in(Scopes.SINGLETON); // tracer provider managers binder.bind(TracerProviderManager.class).in(Scopes.SINGLETON); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/RowExpressionSerde.java b/presto-spi/src/main/java/com/facebook/presto/spi/RowExpressionSerde.java new file mode 100644 index 0000000000000..ab5381aa2556c --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/RowExpressionSerde.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi; + +import com.facebook.presto.spi.relation.RowExpression; + +public interface RowExpressionSerde +{ + String serialize(RowExpression expression); + + RowExpression deserialize(String value); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java b/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java index c9a6f84aaaa1d..51b7cce149d8f 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java @@ -14,6 +14,7 @@ package com.facebook.presto.spi.sql.planner; import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.RowExpressionSerde; import com.facebook.presto.spi.function.FunctionMetadataManager; import com.facebook.presto.spi.function.StandardFunctionResolution; @@ -22,12 +23,14 @@ public class ExpressionOptimizerContext { private final NodeManager nodeManager; + private final RowExpressionSerde rowExpressionSerde; private final FunctionMetadataManager functionMetadataManager; private final StandardFunctionResolution functionResolution; - public ExpressionOptimizerContext(NodeManager nodeManager, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution) + public ExpressionOptimizerContext(NodeManager nodeManager, RowExpressionSerde rowExpressionSerde, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution) { this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); } @@ -37,6 +40,11 @@ public NodeManager getNodeManager() return nodeManager; } + public RowExpressionSerde getRowExpressionSerde() + { + return rowExpressionSerde; + } + public FunctionMetadataManager getFunctionMetadataManager() { return functionMetadataManager; diff --git a/presto-tests/pom.xml b/presto-tests/pom.xml index bb806244f3027..a2d03b17b6cfb 100644 --- a/presto-tests/pom.xml +++ b/presto-tests/pom.xml @@ -244,6 +244,11 @@ javax.servlet-api + + com.facebook.airlift + jaxrs + + com.facebook.airlift @@ -277,6 +282,18 @@ test + + com.facebook.presto + presto-native-plugin + test + + + + com.facebook.airlift + jaxrs-testing + test + + org.openjdk.jmh diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestNativeExpressions.java b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestNativeExpressions.java new file mode 100644 index 0000000000000..66e786923fe68 --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestNativeExpressions.java @@ -0,0 +1,227 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.expressions; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.airlift.http.client.HttpClient; +import com.facebook.airlift.http.client.testing.TestingHttpClient; +import com.facebook.airlift.jaxrs.JsonMapper; +import com.facebook.airlift.jaxrs.testing.JaxrsTestingHttpProcessor; +import com.facebook.airlift.json.JsonModule; +import com.facebook.presto.block.BlockJsonSerde; +import com.facebook.presto.client.NodeVersion; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockEncoding; +import com.facebook.presto.common.block.BlockEncodingManager; +import com.facebook.presto.common.block.BlockEncodingSerde; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.HandleJsonModule; +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.metadata.InMemoryNodeManager; +import com.facebook.presto.metadata.InternalNode; +import com.facebook.presto.metadata.InternalNodeManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.nodeManager.PluginNodeManager; +import com.facebook.presto.session.sql.expressions.ForSidecarInfo; +import com.facebook.presto.session.sql.expressions.NativeExpressionOptimizerProvider; +import com.facebook.presto.session.sql.expressions.NativeExpressionsModule; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.sql.TestingRowExpressionTranslator; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; +import com.facebook.presto.sql.planner.RowExpressionInterpreter; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.type.TypeDeserializer; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.inject.Injector; +import com.google.inject.Module; +import com.google.inject.Scopes; +import org.testng.annotations.Test; + +import javax.ws.rs.Consumes; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.Produces; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.UriBuilder; + +import java.net.URI; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static com.facebook.airlift.json.JsonBinder.jsonBinder; +import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE; +import static com.facebook.presto.sql.planner.LiteralEncoder.toRowExpression; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static java.util.Objects.requireNonNull; +import static org.testng.Assert.assertEquals; + +public class TestNativeExpressions +{ + public static final URI SIDECAR_URI = URI.create("http://127.0.0.1:1122"); + private static final Metadata METADATA = MetadataManager.createTestMetadataManager(); + private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA); + @Test + public void testLoadPlugin() + { + ExpressionOptimizer interpreterService = getExpressionOptimizer(METADATA, null); + + // Test the native row expression interpreter service with some simple expressions + RowExpression simpleAddition = compileExpression("1+1"); + RowExpression unnecessaryCoalesce = compileExpression("coalesce(1, 2)"); + + // Assert simple optimizations are performed + assertEquals(interpreterService.optimize(simpleAddition, OPTIMIZED, TEST_SESSION.toConnectorSession()), toRowExpression(2L, simpleAddition.getType())); + assertEquals(interpreterService.optimize(unnecessaryCoalesce, OPTIMIZED, TEST_SESSION.toConnectorSession()), toRowExpression(1L, unnecessaryCoalesce.getType())); + } + + private static RowExpression compileExpression(String expression) + { + return TRANSLATOR.translate(expression, ImmutableMap.of()); + } + + protected static ExpressionOptimizer getExpressionOptimizer(Metadata metadata, HandleResolver handleResolver) + { + // Set up dependencies in main for this module + InMemoryNodeManager nodeManager = getNodeManagerWithSidecar(SIDECAR_URI); + Injector prestoMainInjector = getPrestoMainInjector(metadata, handleResolver); + JsonMapper jsonMapper = prestoMainInjector.getInstance(JsonMapper.class); + RowExpressionSerde rowExpressionSerde = prestoMainInjector.getInstance(RowExpressionSerde.class); + FunctionAndTypeManager functionMetadataManager = prestoMainInjector.getInstance(FunctionAndTypeManager.class); + + // Set up the mock HTTP endpoint that delegates to the Java based row expression interpreter + TestingExpressionOptimizerResource resource = new TestingExpressionOptimizerResource( + metadata.getFunctionAndTypeManager(), + testSessionBuilder().build().toConnectorSession(), + SERIALIZABLE); + JaxrsTestingHttpProcessor jaxrsTestingHttpProcessor = new JaxrsTestingHttpProcessor( + UriBuilder.fromUri(SIDECAR_URI).path("/").build(), + resource, + jsonMapper); + TestingHttpClient testingHttpClient = new TestingHttpClient(jaxrsTestingHttpProcessor); + + // Create the native row expression interpreter service + return createExpressionOptimizer(nodeManager, rowExpressionSerde, testingHttpClient, functionMetadataManager); + } + + private static InMemoryNodeManager getNodeManagerWithSidecar(URI sidecarUri) + { + InMemoryNodeManager nodeManager = new InMemoryNodeManager(); + nodeManager.addNode(new ConnectorId("test"), new InternalNode("test", sidecarUri, NodeVersion.UNKNOWN, false, false, false, true)); + return nodeManager; + } + + private static ExpressionOptimizer createExpressionOptimizer(InternalNodeManager internalNodeManager, RowExpressionSerde rowExpressionSerde, HttpClient httpClient, FunctionAndTypeManager functionMetadataManager) + { + requireNonNull(internalNodeManager, "inMemoryNodeManager is null"); + NodeManager nodeManager = new PluginNodeManager(internalNodeManager); + FunctionResolution functionResolution = new FunctionResolution(functionMetadataManager.getFunctionAndTypeResolver()); + + Bootstrap app = new Bootstrap( + // Specially use a testing HTTP client instead of a real one + binder -> binder.bind(HttpClient.class).annotatedWith(ForSidecarInfo.class).toInstance(httpClient), + // Otherwise use the exact same module as the native row expression interpreter service + new NativeExpressionsModule(nodeManager, rowExpressionSerde, functionMetadataManager, functionResolution)); + + Injector injector = app + .noStrictConfig() + .doNotInitializeLogging() + .setRequiredConfigurationProperties(ImmutableMap.of()) + .quiet() + .initialize(); + return injector.getInstance(NativeExpressionOptimizerProvider.class).createOptimizer(); + } + + private static Injector getPrestoMainInjector(Metadata metadata, HandleResolver handleResolver) + { + Module module = binder -> { + // Installs the JSON codec + binder.install(new JsonModule()); + // Required to deserialize function handles + binder.install(new HandleJsonModule(handleResolver)); + // Required for this test in the JaxrsTestingHttpProcessor because the underlying object mapper + // must be the same as all other object mappers + binder.bind(JsonMapper.class); + + // These dependencies are needed to serialize and deserialize types (found in expressions) + FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager(); + binder.bind(FunctionAndTypeManager.class).toInstance(functionAndTypeManager); + binder.bind(TypeManager.class).toInstance(functionAndTypeManager); + jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); + newSetBinder(binder, Type.class); + + // These dependencies are needed to serialize and deserialize blocks (found in constant values of expressions) + binder.bind(BlockEncodingSerde.class).to(BlockEncodingManager.class).in(Scopes.SINGLETON); + newSetBinder(binder, BlockEncoding.class); + jsonBinder(binder).addSerializerBinding(Block.class).to(BlockJsonSerde.Serializer.class); + jsonBinder(binder).addDeserializerBinding(Block.class).to(BlockJsonSerde.Deserializer.class); + + // Create the serde which is used by the plugin to serialize and deserialize expressions + jsonCodecBinder(binder).bindJsonCodec(RowExpression.class); + binder.bind(RowExpressionSerde.class).to(JsonCodecRowExpressionSerde.class).in(Scopes.SINGLETON); + }; + Bootstrap app = new Bootstrap(ImmutableList.of(module)); + Injector injector = app + .doNotInitializeLogging() + .quiet() + .initialize(); + return injector; + } + + @Path("/v1/expressions") + public static class TestingExpressionOptimizerResource + { + private final FunctionAndTypeManager functionAndTypeManager; + private final ConnectorSession connectorSession; + private final ExpressionOptimizer.Level level; + + public TestingExpressionOptimizerResource(FunctionAndTypeManager functionAndTypeManager, ConnectorSession connectorSession, ExpressionOptimizer.Level level) + { + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + this.connectorSession = requireNonNull(connectorSession, "connectorSession is null"); + this.level = requireNonNull(level, "level is null"); + } + + @POST + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + public List post(List rowExpressions) + { + Map input = rowExpressions.stream().collect(toImmutableMap(i -> i, i -> i)); + Map optimizedExpressions = new HashMap<>(); + input.forEach((key, value) -> optimizedExpressions.put( + key, + new RowExpressionInterpreter(key, functionAndTypeManager, connectorSession, level).optimize())); + ImmutableList.Builder builder = ImmutableList.builder(); + for (RowExpression inputExpression : rowExpressions) { + builder.add(toRowExpression(optimizedExpressions.get(inputExpression), inputExpression.getType())); + } + return builder.build(); + } + } +} From 10862ebc7ea31124d54cdb06ac2e43aa39f375fa Mon Sep 17 00:00:00 2001 From: Tim Meehan Date: Thu, 8 Aug 2024 11:16:47 -0400 Subject: [PATCH 06/10] Add OpenAPI documentation for /v1/expressions --- .../src/main/resources/expressions.yaml | 173 ++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 presto-openapi/src/main/resources/expressions.yaml diff --git a/presto-openapi/src/main/resources/expressions.yaml b/presto-openapi/src/main/resources/expressions.yaml new file mode 100644 index 0000000000000..8d1e3ef5d9379 --- /dev/null +++ b/presto-openapi/src/main/resources/expressions.yaml @@ -0,0 +1,173 @@ +openapi: 3.0.0 +info: + title: Presto Expression API + description: API for evaluating and simplifying row expressions in Presto + version: "1" +servers: + - url: http://localhost:8080 + description: Presto endpoint when running locally +paths: + /v1/expressions: + post: + summary: Simplify the list of row expressions + description: This endpoint takes in a list of row expressions and attempts to simplify them to their simplest logical equivalent expression. + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RowExpressions' + required: true + responses: + '200': + description: Results + content: + application/json: + schema: + $ref: '#/components/schemas/RowExpressions' +components: + schemas: + RowExpressions: + type: array + maxItems: 100 + items: + $ref: "#/components/schemas/RowExpression" + RowExpression: + oneOf: + - $ref: "#/components/schemas/ConstantExpression" + - $ref: "#/components/schemas/VariableReferenceExpression" + - $ref: "#/components/schemas/InputReferenceExpression" + - $ref: "#/components/schemas/LambdaDefinitionExpression" + - $ref: "#/components/schemas/SpecialFormExpression" + - $ref: "#/components/schemas/CallExpression" + RowExpressionParent: + type: object + properties: + sourceLocation: + $ref: "#/components/schemas/SourceLocation" + SourceLocation: + description: The source location of the row expression in the original query, referencing the line and the column of the query. + type: object + properties: + line: + type: integer + column: + type: integer + ConstantExpression: + description: A constant expression is a row expression that represents a constant value. The value attribute is the constant value. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["constant"] + typeSignature: + type: string + valueBlock: + type: string + VariableReferenceExpression: + description: A variable reference expression is a row expression that represents a reference to a variable. The name attribute indicates the name of the variable. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["variable"] + typeSignature: + type: string + name: + type: string + InputReferenceExpression: + description: > + An input reference expression is a row expression that represents a reference to a column in the input schema. The field attribute indicates the index of the column in the + input schema. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["input"] + typeSignature: + type: string + field: + type: integer + LambdaDefinitionExpression: + description: > + A lambda definition expression is a row expression that represents a lambda function. The lambda function is defined by a list of argument types, a list of argument names, + and a body expression. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["lambda"] + argumentTypeSignatures: + type: array + items: + type: string + arguments: + type: array + items: + type: string + body: + $ref: "#/components/schemas/RowExpression" + SpecialFormExpression: + description: > + A special form expression is a row expression that represents a special language construct. The form attribute indicates the specific form of the special form, + which is a well known list, and with each having special semantics. The arguments attribute is a list of row expressions that are the arguments to the special form, with + each form taking in a specific number of arguments. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["special"] + form: + type: string + enum: ["IF","NULL_IF","SWITCH","WHEN","IS_NULL","COALESCE","IN","AND","OR","DEREFERENCE","ROW_CONSTRUCTOR","BIND"] + returnTypeSignature: + type: string + arguments: + type: array + items: + $ref: "#/components/schemas/RowExpression" + CallExpression: + description: > + A call expression is a row expression that represents a call to a function. The functionHandle attribute is an opaque handle to the function that is being called. + The arguments attribute is a list of row expressions that are the arguments to the function. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["call"] + displayName: + type: string + functionHandle: + $ref: "#/components/schemas/FunctionHandle" + returnTypeSignature: + type: string + arguments: + type: array + items: + $ref: "#/components/schemas/RowExpression" + FunctionHandle: + description: An opaque handle to a function that may be invoked. This is interpreted by the registered function namespace manager. + anyOf: + - $ref: "#/components/schemas/OpaqueFunctionHandle" + - $ref: "#/components/schemas/SqlFunctionHandle" + OpaqueFunctionHandle: + type: object + properties: {} # any opaque object may be passed and interpreted by a function namespace manager + SqlFunctionHandle: + type: object + properties: + functionId: + type: string + version: + type: string From d07a76a821c8f82fb9bba226ca5fc941e919b739 Mon Sep 17 00:00:00 2001 From: auden-woolfson Date: Wed, 11 Sep 2024 16:23:23 -0400 Subject: [PATCH 07/10] Add end to end expression tests with Presto sideccar --- pom.xml | 14 ++ .../PrestoNativeQueryRunnerUtils.java | 18 ++- presto-native-plugin/pom.xml | 96 +++++++++++++ .../expressions/NativePluginQueryRunner.java | 48 +++++++ .../TestDelegatingExpressionOptimizer.java | 42 ++++-- .../TestNativeExpressionOptimization.java | 127 ++++++------------ presto-tests/pom.xml | 11 -- 7 files changed, 247 insertions(+), 109 deletions(-) create mode 100644 presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/NativePluginQueryRunner.java rename {presto-tests/src/test/java/com/facebook/presto/tests => presto-native-plugin/src/test/java/com/facebook/presto/session/sql}/expressions/TestDelegatingExpressionOptimizer.java (79%) rename presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestNativeExpressions.java => presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressionOptimization.java (59%) diff --git a/pom.xml b/pom.xml index f3486341bc2f9..8418ee4390e33 100644 --- a/pom.xml +++ b/pom.xml @@ -698,6 +698,13 @@ ${project.version} + + com.facebook.presto + presto-tests + ${project.version} + test-jar + + com.facebook.presto presto-benchmark @@ -899,6 +906,13 @@ ${project.version} + + com.facebook.presto + presto-native-execution + ${project.version} + test-jar + + com.facebook.hive hive-dwrf diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java index 69bb2433b3f18..c245f17b1b394 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java @@ -26,6 +26,7 @@ import java.io.IOException; import java.io.UncheckedIOException; +import java.net.ServerSocket; import java.net.URI; import java.nio.file.Files; import java.nio.file.Path; @@ -420,6 +421,11 @@ public static NativeQueryRunnerParameters getNativeQueryRunnerParameters() } public static Optional> getExternalWorkerLauncher(String catalogName, String prestoServerPath, int cacheMaxSize, Optional remoteFunctionServerUds, Boolean failOnNestedLoopJoin) + { + return getExternalWorkerLauncher(catalogName, prestoServerPath, OptionalInt.empty(), cacheMaxSize, remoteFunctionServerUds, failOnNestedLoopJoin); + } + + public static Optional> getExternalWorkerLauncher(String catalogName, String prestoServerPath, OptionalInt port, int cacheMaxSize, Optional remoteFunctionServerUds, Boolean failOnNestedLoopJoin) { return Optional.of((workerIndex, discoveryUri) -> { @@ -428,13 +434,13 @@ public static Optional> getExternalWorkerLaunc Files.createDirectories(dir); Path tempDirectoryPath = Files.createTempDirectory(dir, "worker"); log.info("Temp directory for Worker #%d: %s", workerIndex, tempDirectoryPath.toString()); - int port = 1234 + workerIndex; // Write config file String configProperties = format("discovery.uri=%s%n" + "presto.version=testversion%n" + "system-memory-gb=4%n" + - "http-server.http.port=%d", discoveryUri, port); + "native-sidecar=true%n" + + "http-server.http.port=%d", discoveryUri, port.orElse(1234 + workerIndex)); if (remoteFunctionServerUds.isPresent()) { String jsonSignaturesPath = Resources.getResource(REMOTE_FUNCTION_JSON_SIGNATURES).getFile(); @@ -518,4 +524,12 @@ public static void setupJsonFunctionNamespaceManager(QueryRunner queryRunner, St "function-implementation-type", "CPP", "json-based-function-manager.path-to-function-definition", jsonDefinitionPath)); } + + public static int findRandomPortForWorker() + throws IOException + { + try (ServerSocket socket = new ServerSocket(0)) { + return socket.getLocalPort(); + } + } } diff --git a/presto-native-plugin/pom.xml b/presto-native-plugin/pom.xml index eeba938f82aa6..8172f9e1bdb9f 100644 --- a/presto-native-plugin/pom.xml +++ b/presto-native-plugin/pom.xml @@ -58,6 +58,16 @@ http-client + + com.facebook.airlift + log + + + + com.facebook.airlift + log-manager + + com.facebook.presto @@ -119,5 +129,91 @@ testing test + + + com.facebook.presto + presto-tests + test + test-jar + + + + com.facebook.presto + presto-tests + test + + + + com.facebook.presto + presto-native-execution + test + test-jar + + + + com.facebook.presto + presto-main + test + test-jar + + + + com.facebook.presto + presto-main + test + + + + com.facebook.presto + presto-tpcds + test + + + + com.facebook.airlift + jaxrs + test + + + + com.facebook.presto + presto-client + test + + + + com.facebook.airlift + jaxrs-testing + test + + + + + + + org.basepom.maven + duplicate-finder-maven-plugin + + + parquet.thrift + about.html + mozilla/public-suffix-list.txt + iceberg-build.properties + org.apache.avro.data/Json.avsc + + + com.esotericsoftware.kryo.* + com.esotericsoftware.minlog.Log + com.esotericsoftware.reflectasm.* + module-info + META-INF.versions.9.module-info + org.apache.avro.* + com.github.benmanes.caffeine.* + org.roaringbitmap.* + + + + + diff --git a/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/NativePluginQueryRunner.java b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/NativePluginQueryRunner.java new file mode 100644 index 0000000000000..af55642b2acf7 --- /dev/null +++ b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/NativePluginQueryRunner.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.airlift.log.Logger; +import com.facebook.airlift.log.Logging; +import com.facebook.presto.nativeworker.NativeQueryRunnerUtils; +import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; + +public class NativePluginQueryRunner +{ + private NativePluginQueryRunner() {} + + public static void main(String[] args) + throws Exception + { + // You need to add "--user user" to your CLI for your queries to work. + Logging.initialize(); + + // Create tables before launching distributed runner. + QueryRunner javaQueryRunner = PrestoNativeQueryRunnerUtils.createJavaQueryRunner(false); + NativeQueryRunnerUtils.createAllTables(javaQueryRunner); + javaQueryRunner.close(); + + // Launch distributed runner. + DistributedQueryRunner queryRunner = (DistributedQueryRunner) PrestoNativeQueryRunnerUtils.createQueryRunner(false); + queryRunner.getExpressionManager().addExpressionOptimizerFactory(new NativeExpressionOptimizerFactory(ClassLoader.getSystemClassLoader())); + queryRunner.getExpressionManager().loadExpressions(ImmutableMap.builder().put("expression-manager-factory.name", "native").build()); + Thread.sleep(10); + Logger log = Logger.get(DistributedQueryRunner.class); + log.info("======== SERVER STARTED ========"); + log.info("\n====\n%s\n====", queryRunner.getCoordinator().getBaseUrl()); + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestDelegatingExpressionOptimizer.java b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestDelegatingExpressionOptimizer.java similarity index 79% rename from presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestDelegatingExpressionOptimizer.java rename to presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestDelegatingExpressionOptimizer.java index 57f28fa483645..bc71240832fa1 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestDelegatingExpressionOptimizer.java +++ b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestDelegatingExpressionOptimizer.java @@ -11,8 +11,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.tests.expressions; +package com.facebook.presto.session.sql.expressions; +import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.ExpressionOptimizer; @@ -21,17 +22,22 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.relational.DelegatingRowExpressionOptimizer; import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.tests.expressions.TestExpressions; import com.google.common.collect.ImmutableList; -import org.intellij.lang.annotations.Language; +import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import java.net.URI; import java.util.Optional; +import java.util.OptionalInt; +import java.util.function.BiFunction; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.VarcharType.VARCHAR; -import static com.facebook.presto.operator.scalar.ApplyFunction.APPLY_FUNCTION; +import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.findRandomPortForWorker; +import static com.facebook.presto.session.sql.expressions.TestNativeExpressionOptimization.getExpressionOptimizer; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.EVALUATED; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; @@ -47,14 +53,30 @@ public class TestDelegatingExpressionOptimizer { public static final FunctionResolution RESOLUTION = new FunctionResolution(METADATA.getFunctionAndTypeManager().getFunctionAndTypeResolver()); private ExpressionOptimizer expressionOptimizer; + private Process sidecar; @BeforeClass public void setup() + throws Exception { - METADATA.getFunctionAndTypeManager().registerBuiltInFunctions(ImmutableList.of(APPLY_FUNCTION)); - setupJsonFunctionNamespaceManager(METADATA.getFunctionAndTypeManager()); + int port = findRandomPortForWorker(); + URI sidecarUri = URI.create("http://127.0.0.1:" + port); + Optional> launcher = PrestoNativeQueryRunnerUtils.getExternalWorkerLauncher( + "hive", + "/Users/tdcmeehan/git/presto/presto-native-execution/_build/release/presto_cpp/main/presto_server", + OptionalInt.of(port), + 0, + Optional.empty(), + false); + sidecar = launcher.get().apply(0, URI.create("http://test.invalid/")); + + expressionOptimizer = new DelegatingRowExpressionOptimizer(METADATA, () -> getExpressionOptimizer(METADATA, HANDLE_RESOLVER, sidecarUri)); + } - expressionOptimizer = new DelegatingRowExpressionOptimizer(METADATA, () -> TestNativeExpressions.getExpressionOptimizer(METADATA, HANDLE_RESOLVER)); + @AfterClass + public void tearDown() + { + sidecar.destroyForcibly(); } @Test @@ -118,7 +140,7 @@ protected Object evaluate(String expression, boolean deterministic) } @Override - protected Object optimize(@Language("SQL") String expression) + protected Object optimize(String expression) { assertRoundTrip(expression); RowExpression parsedExpression = sqlToRowExpression(expression); @@ -154,7 +176,7 @@ public Object unwrap(Object result) } @Override - protected void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + protected void assertOptimizedEquals(String actual, String expected) { Object optimizedActual = optimize(actual); Object optimizedExpected = optimize(expected); @@ -162,7 +184,7 @@ protected void assertOptimizedEquals(@Language("SQL") String actual, @Language(" } @Override - protected void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected) + protected void assertOptimizedMatches(String actual, String expected) { Object actualOptimized = optimize(actual); Object expectedOptimized = optimize(expected); @@ -172,7 +194,7 @@ protected void assertOptimizedMatches(@Language("SQL") String actual, @Language( } @Override - protected void assertDoNotOptimize(@Language("SQL") String expression, Level optimizationLevel) + protected void assertDoNotOptimize(String expression, Level optimizationLevel) { assertRoundTrip(expression); RowExpression rowExpression = sqlToRowExpression(expression); diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestNativeExpressions.java b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressionOptimization.java similarity index 59% rename from presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestNativeExpressions.java rename to presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressionOptimization.java index 66e786923fe68..b4062c4e7d5e2 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestNativeExpressions.java +++ b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressionOptimization.java @@ -11,14 +11,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.tests.expressions; +package com.facebook.presto.session.sql.expressions; import com.facebook.airlift.bootstrap.Bootstrap; -import com.facebook.airlift.http.client.HttpClient; -import com.facebook.airlift.http.client.testing.TestingHttpClient; import com.facebook.airlift.jaxrs.JsonMapper; -import com.facebook.airlift.jaxrs.testing.JaxrsTestingHttpProcessor; import com.facebook.airlift.json.JsonModule; +import com.facebook.presto.SessionTestUtils; import com.facebook.presto.block.BlockJsonSerde; import com.facebook.presto.client.NodeVersion; import com.facebook.presto.common.block.Block; @@ -35,19 +33,16 @@ import com.facebook.presto.metadata.InternalNodeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils; import com.facebook.presto.nodeManager.PluginNodeManager; -import com.facebook.presto.session.sql.expressions.ForSidecarInfo; -import com.facebook.presto.session.sql.expressions.NativeExpressionOptimizerProvider; -import com.facebook.presto.session.sql.expressions.NativeExpressionsModule; import com.facebook.presto.spi.ConnectorId; -import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.RowExpressionSerde; import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; -import com.facebook.presto.sql.planner.RowExpressionInterpreter; +import com.facebook.presto.sql.planner.LiteralEncoder; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.type.TypeDeserializer; import com.google.common.collect.ImmutableList; @@ -57,47 +52,53 @@ import com.google.inject.Scopes; import org.testng.annotations.Test; -import javax.ws.rs.Consumes; -import javax.ws.rs.POST; -import javax.ws.rs.Path; -import javax.ws.rs.Produces; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.UriBuilder; - import java.net.URI; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.function.BiFunction; import static com.facebook.airlift.json.JsonBinder.jsonBinder; import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; -import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.findRandomPortForWorker; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; -import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE; -import static com.facebook.presto.sql.planner.LiteralEncoder.toRowExpression; -import static com.facebook.presto.testing.TestingSession.testSessionBuilder; -import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.inject.multibindings.Multibinder.newSetBinder; import static java.util.Objects.requireNonNull; import static org.testng.Assert.assertEquals; -public class TestNativeExpressions +public class TestNativeExpressionOptimization { - public static final URI SIDECAR_URI = URI.create("http://127.0.0.1:1122"); + public static final String SIDECAR_URI = "http://127.0.0.1:"; private static final Metadata METADATA = MetadataManager.createTestMetadataManager(); private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA); @Test public void testLoadPlugin() + throws Exception { - ExpressionOptimizer interpreterService = getExpressionOptimizer(METADATA, null); - - // Test the native row expression interpreter service with some simple expressions - RowExpression simpleAddition = compileExpression("1+1"); - RowExpression unnecessaryCoalesce = compileExpression("coalesce(1, 2)"); - - // Assert simple optimizations are performed - assertEquals(interpreterService.optimize(simpleAddition, OPTIMIZED, TEST_SESSION.toConnectorSession()), toRowExpression(2L, simpleAddition.getType())); - assertEquals(interpreterService.optimize(unnecessaryCoalesce, OPTIMIZED, TEST_SESSION.toConnectorSession()), toRowExpression(1L, unnecessaryCoalesce.getType())); + int port = findRandomPortForWorker(); + URI sidecarUri = URI.create(SIDECAR_URI + port); + Optional> launcher = PrestoNativeQueryRunnerUtils.getExternalWorkerLauncher( + "hive", + "/Users/tdcmeehan/git/presto/presto-native-execution/_build/release/presto_cpp/main/presto_server", + OptionalInt.of(port), + 0, + Optional.empty(), + false); + Process process = launcher.get().apply(0, URI.create("http://test.invalid/")); + + try { + ExpressionOptimizer interpreterService = getExpressionOptimizer(METADATA, null, sidecarUri); + + // Test the native row expression interpreter service with some simple expressions + RowExpression simpleAddition = compileExpression("1+1"); + RowExpression unnecessaryCoalesce = compileExpression("coalesce(1, 2)"); + + // Assert simple optimizations are performed + assertEquals(interpreterService.optimize(simpleAddition, OPTIMIZED, SessionTestUtils.TEST_SESSION.toConnectorSession()), LiteralEncoder.toRowExpression(2L, simpleAddition.getType())); + assertEquals(interpreterService.optimize(unnecessaryCoalesce, OPTIMIZED, SessionTestUtils.TEST_SESSION.toConnectorSession()), LiteralEncoder.toRowExpression(1L, unnecessaryCoalesce.getType())); + } + finally { + process.destroyForcibly(); + } } private static RowExpression compileExpression(String expression) @@ -105,28 +106,16 @@ private static RowExpression compileExpression(String expression) return TRANSLATOR.translate(expression, ImmutableMap.of()); } - protected static ExpressionOptimizer getExpressionOptimizer(Metadata metadata, HandleResolver handleResolver) + protected static ExpressionOptimizer getExpressionOptimizer(Metadata metadata, HandleResolver handleResolver, URI sidecarUri) { // Set up dependencies in main for this module - InMemoryNodeManager nodeManager = getNodeManagerWithSidecar(SIDECAR_URI); + InMemoryNodeManager nodeManager = getNodeManagerWithSidecar(sidecarUri); Injector prestoMainInjector = getPrestoMainInjector(metadata, handleResolver); - JsonMapper jsonMapper = prestoMainInjector.getInstance(JsonMapper.class); RowExpressionSerde rowExpressionSerde = prestoMainInjector.getInstance(RowExpressionSerde.class); FunctionAndTypeManager functionMetadataManager = prestoMainInjector.getInstance(FunctionAndTypeManager.class); - // Set up the mock HTTP endpoint that delegates to the Java based row expression interpreter - TestingExpressionOptimizerResource resource = new TestingExpressionOptimizerResource( - metadata.getFunctionAndTypeManager(), - testSessionBuilder().build().toConnectorSession(), - SERIALIZABLE); - JaxrsTestingHttpProcessor jaxrsTestingHttpProcessor = new JaxrsTestingHttpProcessor( - UriBuilder.fromUri(SIDECAR_URI).path("/").build(), - resource, - jsonMapper); - TestingHttpClient testingHttpClient = new TestingHttpClient(jaxrsTestingHttpProcessor); - // Create the native row expression interpreter service - return createExpressionOptimizer(nodeManager, rowExpressionSerde, testingHttpClient, functionMetadataManager); + return createExpressionOptimizer(nodeManager, rowExpressionSerde, functionMetadataManager); } private static InMemoryNodeManager getNodeManagerWithSidecar(URI sidecarUri) @@ -136,16 +125,14 @@ private static InMemoryNodeManager getNodeManagerWithSidecar(URI sidecarUri) return nodeManager; } - private static ExpressionOptimizer createExpressionOptimizer(InternalNodeManager internalNodeManager, RowExpressionSerde rowExpressionSerde, HttpClient httpClient, FunctionAndTypeManager functionMetadataManager) + private static ExpressionOptimizer createExpressionOptimizer(InternalNodeManager internalNodeManager, RowExpressionSerde rowExpressionSerde, FunctionAndTypeManager functionMetadataManager) { requireNonNull(internalNodeManager, "inMemoryNodeManager is null"); NodeManager nodeManager = new PluginNodeManager(internalNodeManager); FunctionResolution functionResolution = new FunctionResolution(functionMetadataManager.getFunctionAndTypeResolver()); Bootstrap app = new Bootstrap( - // Specially use a testing HTTP client instead of a real one - binder -> binder.bind(HttpClient.class).annotatedWith(ForSidecarInfo.class).toInstance(httpClient), - // Otherwise use the exact same module as the native row expression interpreter service + new NativeExpressionsCommunicationModule(), new NativeExpressionsModule(nodeManager, rowExpressionSerde, functionMetadataManager, functionResolution)); Injector injector = app @@ -192,36 +179,4 @@ private static Injector getPrestoMainInjector(Metadata metadata, HandleResolver .initialize(); return injector; } - - @Path("/v1/expressions") - public static class TestingExpressionOptimizerResource - { - private final FunctionAndTypeManager functionAndTypeManager; - private final ConnectorSession connectorSession; - private final ExpressionOptimizer.Level level; - - public TestingExpressionOptimizerResource(FunctionAndTypeManager functionAndTypeManager, ConnectorSession connectorSession, ExpressionOptimizer.Level level) - { - this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); - this.connectorSession = requireNonNull(connectorSession, "connectorSession is null"); - this.level = requireNonNull(level, "level is null"); - } - - @POST - @Consumes(MediaType.APPLICATION_JSON) - @Produces(MediaType.APPLICATION_JSON) - public List post(List rowExpressions) - { - Map input = rowExpressions.stream().collect(toImmutableMap(i -> i, i -> i)); - Map optimizedExpressions = new HashMap<>(); - input.forEach((key, value) -> optimizedExpressions.put( - key, - new RowExpressionInterpreter(key, functionAndTypeManager, connectorSession, level).optimize())); - ImmutableList.Builder builder = ImmutableList.builder(); - for (RowExpression inputExpression : rowExpressions) { - builder.add(toRowExpression(optimizedExpressions.get(inputExpression), inputExpression.getType())); - } - return builder.build(); - } - } } diff --git a/presto-tests/pom.xml b/presto-tests/pom.xml index a2d03b17b6cfb..6be55b940be62 100644 --- a/presto-tests/pom.xml +++ b/presto-tests/pom.xml @@ -244,11 +244,6 @@ javax.servlet-api - - com.facebook.airlift - jaxrs - - com.facebook.airlift @@ -282,12 +277,6 @@ test - - com.facebook.presto - presto-native-plugin - test - - com.facebook.airlift jaxrs-testing From f94197d2f034b57ba32e2d8189e437573d108d3e Mon Sep 17 00:00:00 2001 From: Pramod Date: Tue, 10 Sep 2024 21:43:47 +0530 Subject: [PATCH 08/10] [native] Add expression evaluation support in sidecar --- .../presto_cpp/main/CMakeLists.txt | 2 + .../presto_cpp/main/PrestoServer.cpp | 9 + .../presto_cpp/main/PrestoServer.h | 2 + .../presto_cpp/main/expression/CMakeLists.txt | 32 + .../expression/RowExpressionEvaluator.cpp | 571 ++++++++++++++++++ .../main/expression/RowExpressionEvaluator.h | 129 ++++ .../main/expression/tests/CMakeLists.txt | 26 + .../tests/RowExpressionEvaluatorTest.cpp | 183 ++++++ .../tests/data/SimpleExpressionsExpected.json | 17 + .../tests/data/SimpleExpressionsInput.json | 278 +++++++++ .../tests/data/SpecialFormExpected.json | 17 + .../tests/data/SpecialFormInput.json | 193 ++++++ .../main/types/PrestoToVeloxExpr.cpp | 35 +- .../presto_cpp/main/types/PrestoToVeloxExpr.h | 24 + 14 files changed, 1494 insertions(+), 24 deletions(-) create mode 100644 presto-native-execution/presto_cpp/main/expression/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/expression/RowExpressionEvaluator.cpp create mode 100644 presto-native-execution/presto_cpp/main/expression/RowExpressionEvaluator.h create mode 100644 presto-native-execution/presto_cpp/main/expression/tests/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/expression/tests/RowExpressionEvaluatorTest.cpp create mode 100644 presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsExpected.json create mode 100644 presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsInput.json create mode 100644 presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormExpected.json create mode 100644 presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormInput.json diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index 30ba84dc5461d..489b4f2a01225 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -13,6 +13,7 @@ add_subdirectory(operators) add_subdirectory(types) add_subdirectory(http) add_subdirectory(common) +add_subdirectory(expression) add_subdirectory(thrift) add_library( @@ -49,6 +50,7 @@ target_link_libraries( presto_exception presto_http presto_operators + presto_expr_eval velox_aggregates velox_caching velox_common_base diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index 5a1e09d7cc840..253500fffe691 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -479,6 +479,15 @@ void PrestoServer::run() { taskManager_->getQueryContextManager()->getSessionProperties(); http::sendOkResponse(downstream, sessionProperties.serialize()); }); + rowExpressionEvaluator_ = + std::make_unique(); + httpServer_->registerPost( + "/v1/expressions", + [&](proxygen::HTTPMessage* /*message*/, + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + return rowExpressionEvaluator_->evaluate(body, downstream); + }); } std::string taskUri; diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.h b/presto-native-execution/presto_cpp/main/PrestoServer.h index bee2d8d43391a..37cb2faf37d6f 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.h +++ b/presto-native-execution/presto_cpp/main/PrestoServer.h @@ -25,6 +25,7 @@ #include "presto_cpp/main/PeriodicHeartbeatManager.h" #include "presto_cpp/main/PrestoExchangeSource.h" #include "presto_cpp/main/PrestoServerOperations.h" +#include "presto_cpp/main/expression/RowExpressionEvaluator.h" #include "presto_cpp/main/types/VeloxPlanValidator.h" #include "velox/common/caching/AsyncDataCache.h" #include "velox/common/memory/MemoryAllocator.h" @@ -277,6 +278,7 @@ class PrestoServer { std::string address_; std::string nodeLocation_; folly::SSLContextPtr sslContext_; + std::unique_ptr rowExpressionEvaluator_; }; } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/expression/CMakeLists.txt b/presto-native-execution/presto_cpp/main/expression/CMakeLists.txt new file mode 100644 index 0000000000000..1e1c970547fc4 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_library(presto_expr_eval RowExpressionEvaluator.cpp) + +target_link_libraries( + presto_expr_eval + presto_type_converter + presto_types + presto_protocol + presto_http + velox_coverage_util + velox_parse_expression + velox_parse_parser + velox_presto_serializer + velox_serialization + velox_type_parser + ${FOLLY_WITH_DEPENDENCIES}) + +if(PRESTO_ENABLE_TESTING) + add_subdirectory(tests) +endif() diff --git a/presto-native-execution/presto_cpp/main/expression/RowExpressionEvaluator.cpp b/presto-native-execution/presto_cpp/main/expression/RowExpressionEvaluator.cpp new file mode 100644 index 0000000000000..6dc9fa2e02b64 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/RowExpressionEvaluator.cpp @@ -0,0 +1,571 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/expression/RowExpressionEvaluator.h" +#include +#include "presto_cpp/main/common/Utils.h" +#include "velox/common/encode/Base64.h" +#include "velox/exec/ExchangeQueue.h" +#include "velox/expression/EvalCtx.h" +#include "velox/expression/Expr.h" +#include "velox/expression/ExprCompiler.h" +#include "velox/expression/FieldReference.h" + +using namespace facebook::presto; +using namespace facebook::velox; + +namespace facebook::presto::expression { + +namespace { + +protocol::TypeSignature getTypeSignature(const TypePtr& type) { + std::string typeSignature; + if (type->parameters().empty()) { + typeSignature = type->toString(); + boost::algorithm::to_lower(typeSignature); + } else if (type->isDecimal()) { + typeSignature = type->toString(); + } else { + std::vector childTypes; + if (type->isRow()) { + typeSignature = "row("; + childTypes = asRowType(type)->children(); + } else if (type->isArray()) { + typeSignature = "array("; + childTypes = type->asArray().children(); + } else if (type->isMap()) { + typeSignature = "map("; + const auto mapType = type->asMap(); + childTypes = {mapType.keyType(), mapType.valueType()}; + } else { + VELOX_USER_FAIL("Invalid type {}", type->toString()); + } + + if (!childTypes.empty()) { + auto numChildren = childTypes.size(); + for (auto i = 0; i < numChildren - 1; i++) { + typeSignature += fmt::format("{},", getTypeSignature(childTypes[i])); + } + typeSignature += getTypeSignature(childTypes[numChildren - 1]); + } + typeSignature += ")"; + } + return typeSignature; +} + +json toVariableReferenceExpression( + const std::shared_ptr& fieldReference, + const json& input) { + protocol::VariableReferenceExpression vexpr; + vexpr.name = fieldReference->name(); + vexpr._type = "variable"; + vexpr.type = getTypeSignature(fieldReference->type()); + + json res; + protocol::to_json(res, vexpr); + if (input.contains("sourceLocation")) { + res["sourceLocation"] = input["sourceLocation"]; + } + return res; +} + +bool isPrestoSpecialForm(const std::string& name) { + static const std::unordered_set kPrestoSpecialForms = { + "and", + "coalesce", + "if", + "in", + "is_null", + "or", + "switch", + "when", + "null_if"}; + return kPrestoSpecialForms.count(name) != 0; +} + +json::array_t getInputExpressions( + const std::vector>& body) { + std::ostringstream oss; + for (auto& buf : body) { + oss << std::string((const char*)buf->data(), buf->length()); + } + return json::parse(oss.str()); +} +} // namespace + +// ValueBlock in ConstantExpression requires only the column from the serialized +// PrestoPage without the page header. +std::string RowExpressionConverter::getValueBlock(const VectorPtr& vector) { + std::ostringstream output; + serde_->serializeSingleColumn(vector, nullptr, pool_.get(), &output); + const auto serialized = output.str(); + const auto serializedSize = serialized.size(); + return velox::encoding::Base64::encode(serialized.c_str(), serializedSize); +} + +std::shared_ptr +RowExpressionConverter::getConstantRowExpression( + const std::shared_ptr& constantExpr) { + protocol::ConstantExpression cexpr; + cexpr.type = getTypeSignature(constantExpr->type()); + cexpr.valueBlock.data = getValueBlock(constantExpr->value()); + return std::make_shared(cexpr); +} + +json RowExpressionConverter::getRowConstructorSpecialForm( + const exec::ExprPtr& expr, + const json& input) { + json res; + res["@type"] = "special"; + res["form"] = "ROW_CONSTRUCTOR"; + res["returnType"] = getTypeSignature(expr->type()); + + res["arguments"] = json::array(); + auto exprInputs = expr->inputs(); + if (!exprInputs.empty()) { + for (const auto& exprInput : exprInputs) { + res["arguments"].push_back(veloxExprToRowExpression(exprInput, input)); + } + } else if ( + auto constantExpr = + std::dynamic_pointer_cast(expr)) { + auto value = constantExpr->value(); + auto* constVector = value->as>(); + auto* rowVector = constVector->valueVector()->as(); + auto type = asRowType(constantExpr->type()); + auto children = rowVector->children(); + auto size = children.size(); + + json j; + protocol::ConstantExpression cexpr; + for (auto i = 0; i < size; i++) { + cexpr.type = getTypeSignature(type->childAt(i)); + cexpr.valueBlock.data = getValueBlock(rowVector->childAt(i)); + protocol::to_json(j, cexpr); + res["arguments"].push_back(j); + } + } + + if (input.contains("sourceLocation")) { + res["sourceLocation"] = input["sourceLocation"]; + } + return res; +} + +json RowExpressionConverter::getWhenSpecialForm( + const std::vector& exprInputs, + const vector_size_t& idx, + json::array_t inputArgs, + bool isSearchedForm) { + json res; + res["@type"] = "special"; + res["form"] = "WHEN"; + const auto& equalExprInputs = exprInputs[idx]->inputs(); + // expressions to the left and right of WHEN. + const auto& leftExpr = + isSearchedForm ? equalExprInputs[1] : equalExprInputs[0]; + const auto& rightExpr = exprInputs[idx + 1]; + const vector_size_t argsIdx = idx / 2 + 1; + json::array_t whenArgs = inputArgs[argsIdx].at("arguments"); + + json::array_t args; + if (!leftExpr->inputs().empty()) { + args.emplace_back(veloxExprToCallExpr( + leftExpr->inputs()[0], whenArgs[0].at("arguments")[0])); + } else { + args.emplace_back(veloxExprToRowExpression(leftExpr, whenArgs[0])); + } + args.emplace_back(veloxExprToRowExpression(rightExpr, whenArgs[1])); + res["arguments"] = args; + res["returnType"] = getTypeSignature(rightExpr->type()); + + if (inputArgs[argsIdx].contains("sourceLocation")) { + res["sourceLocation"] = inputArgs[argsIdx].at("sourceLocation"); + } + return res; +} + +json::array_t RowExpressionConverter::getSwitchSpecialFormArgs( + const exec::ExprPtr& expr, + const json& input) { + json::array_t inputArgs = input["arguments"]; + auto numArgs = inputArgs.size(); + bool isSearchedForm = false; + if (typeParser_.parse(inputArgs[0]["type"]) == BOOLEAN() && + inputArgs[0]["@type"] == "constant") { + isSearchedForm = true; + } + const std::vector exprInputs = expr->inputs(); + const auto numInputs = exprInputs.size(); + + json::array_t result = json::array(); + if (isSearchedForm) { + auto variableExpr = exprInputs[0]->inputs()[0]; + result.push_back(veloxExprToRowExpression(variableExpr, inputArgs[0])); + for (auto i = 0; i < numInputs - 1; i += 2) { + result.push_back(getWhenSpecialForm(exprInputs, i, inputArgs, true)); + } + } else { + auto variableExpr = exprInputs[0]->inputs()[1]; + result.push_back(veloxExprToRowExpression(variableExpr, inputArgs[0])); + for (auto i = 0; i < numInputs - 1; i += 2) { + result.push_back(getWhenSpecialForm(exprInputs, i, inputArgs, false)); + } + } + result.push_back(inputArgs[numArgs - 1]); + + return result; +} + +json RowExpressionConverter::getSpecialForm( + const exec::ExprPtr& expr, + const json& input) { + json res; + res["@type"] = "special"; + std::string form; + if (input.contains("form")) { + form = input["form"]; + } else { + // If input json is a call expression instead of a special form, for cases + // like 'is_null', the key 'form' will not be present in the input json. + form = expr->name(); + } + // Presto requires the field form to be in upper case. + std::transform(form.begin(), form.end(), form.begin(), ::toupper); + res["form"] = form; + auto exprInputs = expr->inputs(); + res["arguments"] = json::array(); + + // Arguments for switch expression include special form expression 'when' + // which needs to be constructed separately. + if (form == "SWITCH") { + res["arguments"] = getSwitchSpecialFormArgs(expr, input); + } else { + json::array_t inputArguments = input["arguments"]; + const auto numInputs = exprInputs.size(); + VELOX_USER_CHECK_LE(numInputs, inputArguments.size()); + for (auto i = 0; i < numInputs; i++) { + res["arguments"].push_back( + veloxExprToRowExpression(exprInputs[i], inputArguments[i])); + } + } + res["returnType"] = getTypeSignature(expr->type()); + + if (input.contains("sourceLocation")) { + res["sourceLocation"] = input["sourceLocation"]; + } + return res; +} + +json RowExpressionConverter::veloxExprToCallExpr( + const exec::ExprPtr& expr, + const json& input) { + json res; + res["@type"] = "call"; + protocol::Signature signature; + std::string exprName = expr->name(); + if (veloxToPrestoOperatorMap_.find(expr->name()) != + veloxToPrestoOperatorMap_.end()) { + exprName = veloxToPrestoOperatorMap_.at(expr->name()); + } + signature.name = exprName; + res["displayName"] = exprName; + signature.kind = protocol::FunctionKind::SCALAR; + signature.typeVariableConstraints = {}; + signature.longVariableConstraints = {}; + signature.returnType = getTypeSignature(expr->type()); + + std::vector argumentTypes; + auto exprInputs = expr->inputs(); + auto numArgs = exprInputs.size(); + argumentTypes.reserve(numArgs); + for (auto i = 0; i < numArgs; i++) { + argumentTypes.emplace_back(getTypeSignature(exprInputs[i]->type())); + } + signature.argumentTypes = argumentTypes; + signature.variableArity = false; + + protocol::BuiltInFunctionHandle builtInFunctionHandle; + builtInFunctionHandle._type = "$static"; + builtInFunctionHandle.signature = signature; + res["functionHandle"] = builtInFunctionHandle; + res["returnType"] = getTypeSignature(expr->type()); + res["arguments"] = json::array(); + for (const auto& exprInput : exprInputs) { + res["arguments"].push_back(veloxExprToRowExpression(exprInput, input)); + } + + return res; +} + +json RowExpressionConverter::veloxExprToRowExpression( + const exec::ExprPtr& expr, + const json& input) { + if (expr->type()->isRow()) { + // Velox constant expressions of ROW type map to special form expression + // row_constructor in Presto. + return getRowConstructorSpecialForm(expr, input); + } else if (expr->isConstant()) { + if (expr->inputs().empty()) { + json res; + auto constantExpr = + std::dynamic_pointer_cast(expr); + VELOX_USER_CHECK_NOT_NULL(constantExpr); + auto constantRowExpr = getConstantRowExpression(constantExpr); + protocol::to_json(res, constantRowExpr); + return res; + } else { + // Inputs to constant expressions are constant, eg: divide(1, 2). + return input; + } + } else if ( + auto field = + std::dynamic_pointer_cast(expr)) { + // variable + return toVariableReferenceExpression(field, input); + } else if (expr->isSpecialForm() || expr->vectorFunction()) { + // Check if special form expression or call expression. + auto exprName = expr->name(); + boost::algorithm::to_lower(exprName); + if (isPrestoSpecialForm(exprName)) { + return getSpecialForm(expr, input); + } else { + return veloxExprToCallExpr(expr, input); + } + } + + VELOX_NYI( + "Conversion of Velox Expr {} to Presto RowExpression is not supported", + expr->toString()); +} + +RowExpressionPtr RowExpressionEvaluator::optimizeAndSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr) { + auto left = specialFormExpr->arguments[0]; + auto right = specialFormExpr->arguments[1]; + auto leftExpr = compileExpression(left); + bool isLeftNull; + + if (auto constantExpr = + std::dynamic_pointer_cast(leftExpr)) { + isLeftNull = constantExpr->value()->isNullAt(0); + if (!isLeftNull) { + if (auto constVector = + constantExpr->value()->as>()) { + if (!constVector->valueAt(0)) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } else { + return right; + } + } + } + } + + auto rightExpr = compileExpression(right); + if (auto constantExpr = + std::dynamic_pointer_cast(rightExpr)) { + if (isLeftNull && constantExpr->value()->isNullAt(0)) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } + if (auto constVector = constantExpr->value()->as>()) { + if (constVector->valueAt(0)) { + return left; + } + return right; + } + } + + return specialFormExpr; +} + +RowExpressionPtr RowExpressionEvaluator::optimizeIfSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr) { + auto condition = specialFormExpr->arguments[0]; + auto expr = compileExpression(condition); + + if (auto constantExpr = + std::dynamic_pointer_cast(expr)) { + if (auto constVector = constantExpr->value()->as>()) { + if (constVector->valueAt(0)) { + return specialFormExpr->arguments[1]; + } + return specialFormExpr->arguments[2]; + } + } + + return specialFormExpr; +} + +RowExpressionPtr RowExpressionEvaluator::optimizeIsNullSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr) { + auto expr = compileExpression(specialFormExpr); + if (auto constantExpr = + std::dynamic_pointer_cast(expr)) { + if (constantExpr->value()->isNullAt(0)) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } + } + + return specialFormExpr; +} + +RowExpressionPtr RowExpressionEvaluator::optimizeOrSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr) { + auto left = specialFormExpr->arguments[0]; + auto right = specialFormExpr->arguments[1]; + auto leftExpr = compileExpression(left); + bool isLeftNull; + + if (auto constantExpr = + std::dynamic_pointer_cast(leftExpr)) { + isLeftNull = constantExpr->value()->isNullAt(0); + if (!isLeftNull) { + if (auto constVector = + constantExpr->value()->as>()) { + if (constVector->valueAt(0)) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } + return right; + } + } + } + + auto rightExpr = compileExpression(right); + if (auto constantExpr = + std::dynamic_pointer_cast(rightExpr)) { + if (isLeftNull && constantExpr->value()->isNullAt(0)) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } + if (auto constVector = constantExpr->value()->as>()) { + if (!constVector->valueAt(0)) { + return left; + } + return right; + } + } + + return specialFormExpr; +} + +RowExpressionPtr RowExpressionEvaluator::optimizeCoalesceSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr) { + auto argsNoNulls = specialFormExpr->arguments; + argsNoNulls.erase( + std::remove_if( + argsNoNulls.begin(), + argsNoNulls.end(), + [&](const auto& arg) { + auto compiledExpr = compileExpression(arg); + if (auto constantExpr = + std::dynamic_pointer_cast( + compiledExpr)) { + return constantExpr->value()->isNullAt(0); + } + return false; + }), + argsNoNulls.end()); + + if (argsNoNulls.empty()) { + return specialFormExpr->arguments[0]; + } + specialFormExpr->arguments = argsNoNulls; + return specialFormExpr; +} + +RowExpressionPtr RowExpressionEvaluator::optimizeSpecialForm( + const std::shared_ptr& specialFormExpr) { + switch (specialFormExpr->form) { + case protocol::Form::IF: + return optimizeIfSpecialForm(specialFormExpr); + case protocol::Form::NULL_IF: + VELOX_USER_FAIL("NULL_IF specialForm not supported"); + break; + case protocol::Form::IS_NULL: + return optimizeIsNullSpecialForm(specialFormExpr); + case protocol::Form::AND: + return optimizeAndSpecialForm(specialFormExpr); + case protocol::Form::OR: + return optimizeOrSpecialForm(specialFormExpr); + case protocol::Form::COALESCE: + return optimizeCoalesceSpecialForm(specialFormExpr); + case protocol::Form::IN: + case protocol::Form::DEREFERENCE: + case protocol::Form::SWITCH: + case protocol::Form::WHEN: + case protocol::Form::ROW_CONSTRUCTOR: + case protocol::Form::BIND: + default: + break; + } + + return specialFormExpr; +} + +exec::ExprPtr RowExpressionEvaluator::compileExpression( + const std::shared_ptr& inputRowExpr) { + auto typedExpr = veloxExprConverter_.toVeloxExpr(inputRowExpr); + exec::ExprSet exprSet{{typedExpr}, execCtx_.get()}; + auto compiledExprs = + exec::compileExpressions({typedExpr}, execCtx_.get(), &exprSet, true); + return compiledExprs[0]; +} + +json::array_t RowExpressionEvaluator::evaluateExpressions( + json::array_t& input) { + auto numExpr = input.size(); + json::array_t output = json::array(); + + for (auto i = 0; i < numExpr; i++) { + std::shared_ptr inputRowExpr = input[i]; + VLOG(2) << input[i].dump(); + if (const auto special = + std::dynamic_pointer_cast( + inputRowExpr)) { + inputRowExpr = optimizeSpecialForm(special); + } + const auto compiledExpr = compileExpression(inputRowExpr); + json resultJson = rowExpressionConverter_.veloxExprToRowExpression( + compiledExpr, input[i]); + VLOG(2) << resultJson.dump(); + output.push_back(resultJson); + } + + return output; +} + +void RowExpressionEvaluator::evaluate( + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + try { + json::array_t inputList = getInputExpressions(body); + json output = evaluateExpressions(inputList); + + proxygen::ResponseBuilder(downstream) + .status(http::kHttpOk, "OK") + .header( + proxygen::HTTP_HEADER_CONTENT_TYPE, http::kMimeTypeApplicationJson) + .body(output.dump()) + .sendWithEOM(); + } catch (const velox::VeloxUserError& e) { + VLOG(2) << e.what(); + http::sendErrorResponse(downstream, e.what()); + } catch (const velox::VeloxException& e) { + VLOG(2) << e.what(); + http::sendErrorResponse(downstream, e.what()); + } catch (const std::exception& e) { + VLOG(2) << e.what(); + http::sendErrorResponse(downstream, e.what()); + } +} + +} // namespace facebook::presto::expression diff --git a/presto-native-execution/presto_cpp/main/expression/RowExpressionEvaluator.h b/presto-native-execution/presto_cpp/main/expression/RowExpressionEvaluator.h new file mode 100644 index 0000000000000..d253defc62ac7 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/RowExpressionEvaluator.h @@ -0,0 +1,129 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/external/json/nlohmann/json.hpp" +#include "presto_cpp/main/http/HttpServer.h" +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "velox/core/QueryCtx.h" +#include "velox/expression/ConstantExpr.h" +#include "velox/expression/Expr.h" +#include "velox/serializers/PrestoSerializer.h" + +namespace facebook::presto::expression { + +using RowExpressionPtr = std::shared_ptr; +using SpecialFormExpressionPtr = + std::shared_ptr; + +// Helper class to convert Velox Expr of different types to the respective kind +// of Presto RowExpression. +class RowExpressionConverter { + public: + RowExpressionConverter(const std::shared_ptr& pool) + : pool_(pool), veloxToPrestoOperatorMap_(veloxToPrestoOperatorMap()) {} + + std::shared_ptr getConstantRowExpression( + const std::shared_ptr& constantExpr); + + json veloxExprToRowExpression( + const velox::exec::ExprPtr& expr, + const json& inputRowExpr); + + protected: + std::string getValueBlock(const velox::VectorPtr& vector); + + json getRowConstructorSpecialForm( + const velox::exec::ExprPtr& expr, + const json& inputRowExpr); + + json getWhenSpecialForm( + const std::vector& exprInputs, + const velox::vector_size_t& idx, + json::array_t inputArgs, + bool isSearchedForm); + + json::array_t getSwitchSpecialFormArgs( + const velox::exec::ExprPtr& expr, + const json& input); + + json getSpecialForm( + const velox::exec::ExprPtr& expr, + const json& inputRowExpr); + + json veloxExprToCallExpr(const velox::exec::ExprPtr& expr, const json& input); + + const std::shared_ptr pool_; + const std::unordered_map veloxToPrestoOperatorMap_; + const std::unique_ptr serde_ = + std::make_unique(); + const TypeParser typeParser_; +}; + +class RowExpressionEvaluator { + public: + explicit RowExpressionEvaluator() + : pool_(velox::memory::MemoryManager::getInstance()->addLeafPool( + "RowExpressionEvaluator")), + execCtx_{std::make_unique( + pool_.get(), + queryCtx_.get())}, + veloxExprConverter_(pool_.get(), &typeParser_), + rowExpressionConverter_(RowExpressionConverter(pool_)) {} + + /// Evaluate expressions sent along endpoint '/v1/expressions'. + void evaluate( + const std::vector>& body, + proxygen::ResponseHandler* downstream); + + protected: + RowExpressionPtr optimizeAndSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + RowExpressionPtr optimizeIfSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + RowExpressionPtr optimizeIsNullSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + RowExpressionPtr optimizeOrSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + RowExpressionPtr optimizeCoalesceSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + /// Optimizes special form expressions. Optimization rules borrowed from + /// Presto function visitSpecialForm() in RowExpressionInterpreter.java. + RowExpressionPtr optimizeSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + /// Converts protocol::RowExpression into a velox expression with constant + /// folding enabled during velox expression compilation. + velox::exec::ExprPtr compileExpression(const RowExpressionPtr& inputRowExpr); + + /// Optimizes and constant folds each expression from input json array and + /// returns an array of expressions that are optimized and constant folded. + /// Uses RowExpressionConverter to convert Velox expression(s) to their + /// corresponding Presto RowExpression(s). + json::array_t evaluateExpressions(json::array_t& input); + + const std::shared_ptr pool_; + const std::shared_ptr queryCtx_{ + facebook::velox::core::QueryCtx::create()}; + const std::unique_ptr execCtx_; + TypeParser typeParser_; + VeloxExprConverter veloxExprConverter_; + RowExpressionConverter rowExpressionConverter_; +}; +} // namespace facebook::presto::expression diff --git a/presto-native-execution/presto_cpp/main/expression/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/expression/tests/CMakeLists.txt new file mode 100644 index 0000000000000..0e95be687a2c1 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/tests/CMakeLists.txt @@ -0,0 +1,26 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_executable(presto_expr_eval_test RowExpressionEvaluatorTest.cpp) + +add_test(presto_expr_eval_test presto_expr_eval_test) + +target_link_libraries( + presto_expr_eval_test + presto_expr_eval + presto_http + velox_exec_test_lib + velox_presto_serializer + GTest::gtest + GTest::gtest_main + ${PROXYGEN_LIBRARIES}) diff --git a/presto-native-execution/presto_cpp/main/expression/tests/RowExpressionEvaluatorTest.cpp b/presto-native-execution/presto_cpp/main/expression/tests/RowExpressionEvaluatorTest.cpp new file mode 100644 index 0000000000000..50f23c0c3cb39 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/tests/RowExpressionEvaluatorTest.cpp @@ -0,0 +1,183 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/expression/RowExpressionEvaluator.h" +#include +#include +#include +#include +#include "presto_cpp/main/common/tests/test_json.h" +#include "presto_cpp/main/http/tests/HttpTestBase.h" +#include "velox/exec/OutputBufferManager.h" +#include "velox/expression/RegisterSpecialForm.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/parse/TypeResolver.h" +#include "velox/vector/VectorStream.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::presto; +using namespace facebook::velox; + +namespace { +std::string getDataPath(const std::string& fileName) { + std::string currentPath = fs::current_path().c_str(); + + if (boost::algorithm::ends_with(currentPath, "fbcode")) { + return currentPath + + "/github/presto-trunk/presto-native-execution/presto_cpp/main/expression/tests/data/" + + fileName; + } + + if (boost::algorithm::ends_with(currentPath, "fbsource")) { + return currentPath + "/third-party/presto_cpp/main/expression/tests/data/" + + fileName; + } + + // CLion runs the tests from cmake-build-release/ or cmake-build-debug/ + // directory. Hard-coded json files are not copied there and test fails with + // file not found. Fixing the path so that we can trigger these tests from + // CLion. + boost::algorithm::replace_all(currentPath, "cmake-build-release/", ""); + boost::algorithm::replace_all(currentPath, "cmake-build-debug/", ""); + + return currentPath + "/data/" + fileName; +} +} // namespace + +class RowExpressionEvaluatorTest + : public ::testing::Test, + public facebook::velox::test::VectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + } + + void SetUp() override { + parse::registerTypeResolver(); + functions::prestosql::registerAllScalarFunctions("presto.default."); + exec::registerFunctionCallToSpecialForms(); + + auto httpServer = std::make_unique( + httpSrvIOExecutor_, + std::make_unique( + folly::SocketAddress("127.0.0.1", 0))); + driverExecutor_ = std::make_unique(4); + rowExpressionEvaluator_ = + std::make_unique(); + httpServer->registerPost( + "/v1/expressions", + [&](proxygen::HTTPMessage* /*message*/, + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + return rowExpressionEvaluator_->evaluate(body, downstream); + }); + httpServerWrapper_ = + std::make_unique(std::move(httpServer)); + auto address = httpServerWrapper_->start().get(); + client_ = clientFactory_.newClient( + address, + std::chrono::milliseconds(100'000), + std::chrono::milliseconds(0), + false, + pool_); + } + + void TearDown() override { + if (httpServerWrapper_) { + httpServerWrapper_->stop(); + } + } + + std::string getHttpBody(const std::unique_ptr& response) { + std::ostringstream oss; + auto iobufs = response->consumeBody(); + for (auto& body : iobufs) { + oss << std::string((const char*)body->data(), body->length()); + } + return oss.str(); + } + + void validateHttpResponse( + const std::string& inputStr, + const std::string& expectedStr) { + http::RequestBuilder() + .method(proxygen::HTTPMethod::POST) + .url("/v1/expressions") + .send(client_.get(), inputStr) + .via(driverExecutor_.get()) + .thenValue( + [expectedStr, this](std::unique_ptr response) { + VELOX_USER_CHECK_EQ( + response->headers()->getStatusCode(), http::kHttpOk); + if (response->hasError()) { + VELOX_USER_FAIL( + "Expression evaluation failed: {}", response->error()); + } + + auto resStr = getHttpBody(response); + auto resJson = json::parse(resStr); + ASSERT_TRUE(resJson.is_array()); + auto expectedJson = json::parse(expectedStr); + ASSERT_TRUE(expectedJson.is_array()); + EXPECT_EQ(expectedJson.size(), resJson.size()); + auto size = resJson.size(); + for (auto i = 0; i < size; i++) { + EXPECT_EQ(resJson[i], expectedJson[i]); + } + }) + .thenError( + folly::tag_t{}, [&](const std::exception& e) { + VLOG(1) << "Expression evaluation failed: " << e.what(); + }); + } + + void testFile(const std::string& prefix) { + std::string input = slurp(getDataPath(fmt::format("{}Input.json", prefix))); + auto inputExpressions = json::parse(input); + std::string output = + slurp(getDataPath(fmt::format("{}Expected.json", prefix))); + auto expectedExpressions = json::parse(output); + + validateHttpResponse(inputExpressions.dump(), expectedExpressions.dump()); + } + + std::unique_ptr rowExpressionEvaluator_; + std::unique_ptr httpServerWrapper_; + HttpClientFactory clientFactory_; + std::shared_ptr client_; + std::shared_ptr httpSrvIOExecutor_{ + std::make_shared(8)}; + std::unique_ptr driverExecutor_; + std::shared_ptr pool_{ + memory::memoryManager()->addLeafPool("RowExpressionEvaluatorTest")}; +}; + +TEST_F(RowExpressionEvaluatorTest, simple) { + // File SimpleExpressions{Input|Expected}.json contain the input and expected + // JSON representing the RowExpressions resulting from the following queries: + // select 1 + 2; + // select abs(-11) + ceil(cast(3.4 as double)) + floor(cast(5.6 as double)); + // select 2 between 1 and 3; + // Simple expression evaluation with constant folding is verified here. + testFile("SimpleExpressions"); +} + +TEST_F(RowExpressionEvaluatorTest, specialFormRewrites) { + // File SpecialExpressions{Input|Expected}.json contain the input and expected + // JSON representing the RowExpressions resulting from the following queries: + // select if(1 < 2, 2, 3); + // select (1 < 2) and (2 < 3); + // select (1 < 2) or (2 < 3); + // Special form expression rewrites are verified here. + testFile("SpecialForm"); +} diff --git a/presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsExpected.json b/presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsExpected.json new file mode 100644 index 0000000000000..e9b014fecd3b2 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsExpected.json @@ -0,0 +1,17 @@ +[ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + }, + { + "@type": "constant", + "type": "double", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAAAAAAAAAADRA" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + } +] diff --git a/presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsInput.json b/presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsInput.json new file mode 100644 index 0000000000000..dd2a0497703eb --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsInput.json @@ -0,0 +1,278 @@ +[ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "ADD", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$add", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAACwAAAA==" + } + ], + "displayName": "NEGATION", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$negation", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + } + ], + "displayName": "abs", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.abs", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "decimal(2,1)", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAACIAAAAAAAAA" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "decimal(2,1)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "ceil", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.ceil", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "ADD", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double", + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$add", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "decimal(2,1)", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAADgAAAAAAAAA" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "decimal(2,1)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "floor", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.floor", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "ADD", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double", + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$add", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "displayName": "BETWEEN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$between", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + } +] diff --git a/presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormExpected.json b/presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormExpected.json new file mode 100644 index 0000000000000..2ce6acb1ab46e --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormExpected.json @@ -0,0 +1,17 @@ +[ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + } +] diff --git a/presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormInput.json b/presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormInput.json new file mode 100644 index 0000000000000..77802722541b0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormInput.json @@ -0,0 +1,193 @@ +[ + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "form": "IF", + "returnType": "integer" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + } + ], + "form": "AND", + "returnType": "boolean" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + } + ], + "form": "OR", + "returnType": "boolean" + } +] diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp index 2c2b2a3c5ea00..791e215fb0f3c 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp @@ -33,32 +33,10 @@ std::string toJsonString(const T& value) { } std::string mapScalarFunction(const std::string& name) { - static const std::unordered_map kFunctionNames = { - // Operator overrides: com.facebook.presto.common.function.OperatorType - {"presto.default.$operator$add", "presto.default.plus"}, - {"presto.default.$operator$between", "presto.default.between"}, - {"presto.default.$operator$divide", "presto.default.divide"}, - {"presto.default.$operator$equal", "presto.default.eq"}, - {"presto.default.$operator$greater_than", "presto.default.gt"}, - {"presto.default.$operator$greater_than_or_equal", "presto.default.gte"}, - {"presto.default.$operator$is_distinct_from", - "presto.default.distinct_from"}, - {"presto.default.$operator$less_than", "presto.default.lt"}, - {"presto.default.$operator$less_than_or_equal", "presto.default.lte"}, - {"presto.default.$operator$modulus", "presto.default.mod"}, - {"presto.default.$operator$multiply", "presto.default.multiply"}, - {"presto.default.$operator$negation", "presto.default.negate"}, - {"presto.default.$operator$not_equal", "presto.default.neq"}, - {"presto.default.$operator$subtract", "presto.default.minus"}, - {"presto.default.$operator$subscript", "presto.default.subscript"}, - // Special form function overrides. - {"presto.default.in", "in"}, - }; - std::string lowerCaseName = boost::to_lower_copy(name); - auto it = kFunctionNames.find(lowerCaseName); - if (it != kFunctionNames.end()) { + auto it = kPrestoOperatorMap.find(lowerCaseName); + if (it != kPrestoOperatorMap.end()) { return it->second; } @@ -102,6 +80,15 @@ std::string getFunctionName(const protocol::SqlFunctionId& functionId) { } // namespace +const std::unordered_map veloxToPrestoOperatorMap() { + std::unordered_map veloxToPrestoOperatorMap; + for (const auto& entry : kPrestoOperatorMap) { + veloxToPrestoOperatorMap[entry.second] = entry.first; + } + veloxToPrestoOperatorMap.insert({"cast", "presto.default.$operator$cast"}); + return veloxToPrestoOperatorMap; +} + velox::variant VeloxExprConverter::getConstantValue( const velox::TypePtr& type, const protocol::Block& block) const { diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h index 6e93c675a55f5..8d20e9f07484e 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h @@ -20,6 +20,30 @@ namespace facebook::presto { +static const std::unordered_map kPrestoOperatorMap = { + // Operator overrides: com.facebook.presto.common.function.OperatorType + {"presto.default.$operator$add", "presto.default.plus"}, + {"presto.default.$operator$between", "presto.default.between"}, + {"presto.default.$operator$divide", "presto.default.divide"}, + {"presto.default.$operator$equal", "presto.default.eq"}, + {"presto.default.$operator$greater_than", "presto.default.gt"}, + {"presto.default.$operator$greater_than_or_equal", "presto.default.gte"}, + {"presto.default.$operator$is_distinct_from", + "presto.default.distinct_from"}, + {"presto.default.$operator$less_than", "presto.default.lt"}, + {"presto.default.$operator$less_than_or_equal", "presto.default.lte"}, + {"presto.default.$operator$modulus", "presto.default.mod"}, + {"presto.default.$operator$multiply", "presto.default.multiply"}, + {"presto.default.$operator$negation", "presto.default.negate"}, + {"presto.default.$operator$not_equal", "presto.default.neq"}, + {"presto.default.$operator$subtract", "presto.default.minus"}, + {"presto.default.$operator$subscript", "presto.default.subscript"}, + // Special form function overrides. + {"presto.default.in", "in"}, +}; + +const std::unordered_map veloxToPrestoOperatorMap(); + class VeloxExprConverter { public: VeloxExprConverter(velox::memory::MemoryPool* pool, TypeParser* typeParser) From 2e7a405800e70fd2fc9d6201e595751e2a570f97 Mon Sep 17 00:00:00 2001 From: Pramod Date: Fri, 8 Nov 2024 13:35:57 +0530 Subject: [PATCH 09/10] Disable tests unsupported in Prestissimo --- .../TestDelegatingExpressionOptimizer.java | 34 ++++++++++++++++++- .../tests/expressions/TestExpressions.java | 1 + 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestDelegatingExpressionOptimizer.java b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestDelegatingExpressionOptimizer.java index bc71240832fa1..338d51d2d4f4d 100644 --- a/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestDelegatingExpressionOptimizer.java +++ b/presto-native-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestDelegatingExpressionOptimizer.java @@ -59,6 +59,7 @@ public class TestDelegatingExpressionOptimizer public void setup() throws Exception { + super.setup(); int port = findRandomPortForWorker(); URI sidecarUri = URI.create("http://127.0.0.1:" + port); Optional> launcher = PrestoNativeQueryRunnerUtils.getExternalWorkerLauncher( @@ -79,7 +80,8 @@ public void tearDown() sidecar.destroyForcibly(); } - @Test + // TODO: Pending on native function namespace manager. + @Test(enabled = false) public void assertLikeOptimizations() { assertOptimizedMatches("unbound_string LIKE bound_pattern", "unbound_string LIKE CAST('%el%' AS varchar)"); @@ -119,6 +121,36 @@ public void testMassiveArray() public void testNonDeterministicFunctionCall() { } + // TODO: apply function is not supported in Presto native. + @Test(enabled = false) + @Override + public void testBind() {} + + // TODO: TIME type is unsupported in Presto native. + @Test(enabled = false) + @Override + public void testLiterals() {} + + // TODO: NULL_IF special form is unsupported in Presto native. + @Test(enabled = false) + @Override + public void testNullIf() {} + + // TODO: Bounded varchar is currently unsupported in Presto native. + @Test(enabled = false) + @Override + public void testCastBigintToBoundedVarchar() {} + + // TODO: current_user function is not implemented in Presto native. + @Test(enabled = false) + @Override + public void testCurrentUser() {} + + // TODO: Non-legacy map subscript is not supported in Presto native. + @Test(enabled = false) + @Override + public void testMapSubscriptMissingKey() {} + @Override protected void assertLike(byte[] value, String pattern, boolean expected) { diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressions.java b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressions.java index 785be015e06b8..a0b4e3888330b 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressions.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressions.java @@ -203,6 +203,7 @@ public TestExpressions() @BeforeClass public void setup() + throws Exception { setupJsonFunctionNamespaceManager(METADATA.getFunctionAndTypeManager()); } From 0630fa2c5a9abb012f18349296673690b856e520 Mon Sep 17 00:00:00 2001 From: Tim Meehan Date: Fri, 8 Nov 2024 12:15:25 -0500 Subject: [PATCH 10/10] sq Add native row expression optimizer --- .../sql/expressions/NativeSidecarExpressionInterpreter.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeSidecarExpressionInterpreter.java b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeSidecarExpressionInterpreter.java index 165ba866b6371..ef10643b8808c 100644 --- a/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeSidecarExpressionInterpreter.java +++ b/presto-native-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeSidecarExpressionInterpreter.java @@ -42,6 +42,7 @@ public class NativeSidecarExpressionInterpreter { + public static final String PRESTO_TIME_ZONE_HEADER = "X-Presto-Time-Zone"; private final NodeManager nodeManager; private final HttpClient httpClient; private final JsonCodec> rowExpressionSerde; @@ -70,6 +71,7 @@ public Map optimizeBatch(ConnectorSession session, Map optimized = httpClient.execute(request, createJsonResponseHandler(rowExpressionSerde));