Skip to content

Commit

Permalink
Remove redundant cast in join clause
Browse files Browse the repository at this point in the history
  • Loading branch information
feilong-liu committed Oct 12, 2023
1 parent 627165a commit 8d9dfa9
Show file tree
Hide file tree
Showing 12 changed files with 389 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ public final class SystemSessionProperties
public static final String REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION = "rewrite_constant_array_contains_to_in_expression";
public static final String INFER_INEQUALITY_PREDICATES = "infer_inequality_predicates";
public static final String ENABLE_HISTORY_BASED_SCALED_WRITER = "enable_history_based_scaled_writer";
public static final String REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN = "remove_redundant_cast_to_varchar_in_join";

// TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future.
public static final String NATIVE_SIMPLIFIED_EXPRESSION_EVALUATION_ENABLED = "native_simplified_expression_evaluation_enabled";
Expand Down Expand Up @@ -1769,6 +1770,11 @@ public SystemSessionProperties(
ENABLE_HISTORY_BASED_SCALED_WRITER,
"Enable setting the initial number of tasks for scaled writers with HBO",
featuresConfig.isUseHBOForScaledWriters(),
false),
booleanProperty(
REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN,
"If both left and right side of join clause are varchar cast from int/bigint, remove the cast here",
featuresConfig.isRemoveRedundantCastToVarcharInJoin(),
false));
}

Expand Down Expand Up @@ -2948,4 +2954,9 @@ public static boolean useHBOForScaledWriters(Session session)
{
return session.getSystemProperty(ENABLE_HISTORY_BASED_SCALED_WRITER, Boolean.class);
}

/**
 * Reads the {@code REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN} system property from the given session.
 *
 * @param session session whose system properties are consulted
 * @return the boolean value of the property for this session
 */
public static boolean isRemoveRedundantCastToVarcharInJoinEnabled(Session session)
{
    Boolean enabled = session.getSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, Boolean.class);
    return enabled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ public class FeaturesConfig
private boolean preProcessMetadataCalls;
private boolean useHBOForScaledWriters;

private boolean removeRedundantCastToVarcharInJoin = true;

public enum PartitioningPrecisionStrategy
{
// Let Presto decide when to repartition
Expand Down Expand Up @@ -2816,4 +2818,17 @@ public FeaturesConfig setUseHBOForScaledWriters(boolean useHBOForScaledWriters)
this.useHBOForScaledWriters = useHBOForScaledWriters;
return this;
}

/**
 * Returns the current value of the flag that controls removal of redundant casts to varchar in join clauses.
 *
 * @return the configured flag value
 */
public boolean isRemoveRedundantCastToVarcharInJoin()
{
    return this.removeRedundantCastToVarcharInJoin;
}

/**
 * Sets the flag that controls removal of redundant casts to varchar in join clauses.
 * Bound to the {@code optimizer.remove-redundant-cast-to-varchar-in-join} configuration property.
 *
 * @param removeRedundantCastToVarcharInJoin new value of the flag
 * @return this {@code FeaturesConfig} instance, so that setter calls can be chained
 */
@Config("optimizer.remove-redundant-cast-to-varchar-in-join")
@ConfigDescription("If both left and right side of join clause are varchar cast from int/bigint, remove the cast")
public FeaturesConfig setRemoveRedundantCastToVarcharInJoin(boolean removeRedundantCastToVarcharInJoin)
{
this.removeRedundantCastToVarcharInJoin = removeRedundantCastToVarcharInJoin;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample;
import com.facebook.presto.sql.planner.iterative.rule.RemoveIdentityProjectionsBelowProjection;
import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantAggregateDistinct;
import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantCastToVarcharInJoinClause;
import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantDistinct;
import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantDistinctLimit;
import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections;
Expand Down Expand Up @@ -459,6 +460,12 @@ public PlanOptimizers(
new RemoveRedundantIdentityProjections(),
new TransformCorrelatedSingleRowSubqueryToProject())),
new CheckSubqueryNodesAreRewritten(),
new IterativeOptimizer(
metadata,
ruleStats,
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(new RemoveRedundantCastToVarcharInJoinClause(metadata.getFunctionAndTypeManager()))),
new IterativeOptimizer(
metadata,
ruleStats,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.isRemoveRedundantCastToVarcharInJoinEnabled;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.sql.planner.PlannerUtils.addProjections;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.facebook.presto.sql.relational.Expressions.castToBigInt;

/**
* Remove redundant cast to varchar in join condition for queries like `select * from orders o join customer c on cast(o.custkey as varchar) = cast(c.custkey as varchar)`
* Transform from
* <pre>
* - Join
* left_cast = right_cast
* - Project
* left_cast := cast(lkey as varchar)
* - TableScan
* lkey BIGINT
* - Project
* right_cast := cast(rkey as varchar)
* - TableScan
* rkey BIGINT
*
* </pre>
* into
* <pre>
* - Join
* new_lkey = new_rkey
* - Project
* left_cast := cast(lkey as varchar)
* new_lkey := lkey
* - TableScan
* lkey BIGINT
* - Project
* right_cast := cast(rkey as varchar)
* new_rkey := rkey
* - TableScan
* rkey BIGINT
* </pre>
* We will rely on optimizations later to remove unnecessary cast (if not used) and identity projection here.
* <p>
* Notice that we do not apply a similar optimization to join conditions of the form `cast(bigint as varchar) = varchar`. In general such a condition could be converted to
* `bigint = try_cast(varchar as bigint)`, because if the varchar cannot be converted to bigint, try_cast returns null and the row will not match anyway. However, a special case is
* a varchar that begins with 0: `select cast(92 as varchar) = '092'` is false, but `select 92 = try_cast('092' as bigint)` returns true.
*/
public class RemoveRedundantCastToVarcharInJoinClause
implements Rule<JoinNode>
{
private static final List<Type> TYPE_SUPPORTED = ImmutableList.of(INTEGER, BIGINT);
private final FunctionAndTypeManager functionAndTypeManager;
private final FunctionResolution functionResolution;

public RemoveRedundantCastToVarcharInJoinClause(FunctionAndTypeManager functionAndTypeManager)
{
this.functionAndTypeManager = functionAndTypeManager;
this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
}

@Override
public boolean isEnabled(Session session)
{
return isRemoveRedundantCastToVarcharInJoinEnabled(session);
}

@Override
public Pattern<JoinNode> getPattern()
{
return join();
}

@Override
public Result apply(JoinNode node, Captures captures, Context context)
{
PlanNode leftInput = context.getLookup().resolve(node.getLeft());
PlanNode rightInput = context.getLookup().resolve(node.getRight());
if (!(leftInput instanceof ProjectNode) || !(rightInput instanceof ProjectNode)) {
return Result.empty();
}
ProjectNode leftProject = (ProjectNode) leftInput;
ProjectNode rightProject = (ProjectNode) rightInput;

ImmutableList.Builder<JoinNode.EquiJoinClause> joinClauseBuilder = ImmutableList.builder();
ImmutableMap.Builder<VariableReferenceExpression, RowExpression> newLeftAssignmentsBuilder = ImmutableMap.builder();
ImmutableMap.Builder<VariableReferenceExpression, RowExpression> newRightAssignmentsBuilder = ImmutableMap.builder();
boolean isChanged = false;
for (JoinNode.EquiJoinClause equiJoinClause : node.getCriteria()) {
RowExpression leftProjectAssignment = leftProject.getAssignments().getMap().get(equiJoinClause.getLeft());
RowExpression rightProjectAssignment = rightProject.getAssignments().getMap().get(equiJoinClause.getRight());
if (!isSupportedCast(leftProjectAssignment) || !isSupportedCast(rightProjectAssignment)) {
joinClauseBuilder.add(equiJoinClause);
continue;
}

RowExpression leftAssignment = ((CallExpression) leftProjectAssignment).getArguments().get(0);
RowExpression rightAssignment = ((CallExpression) rightProjectAssignment).getArguments().get(0);

if (!leftAssignment.getType().equals(rightAssignment.getType())) {
leftAssignment = castToBigInt(functionAndTypeManager, leftAssignment);
rightAssignment = castToBigInt(functionAndTypeManager, rightAssignment);
}

VariableReferenceExpression newLeft = context.getVariableAllocator().newVariable(leftAssignment);
newLeftAssignmentsBuilder.put(newLeft, leftAssignment);

VariableReferenceExpression newRight = context.getVariableAllocator().newVariable(rightAssignment);
newRightAssignmentsBuilder.put(newRight, rightAssignment);

joinClauseBuilder.add(new JoinNode.EquiJoinClause(newLeft, newRight));
isChanged = true;
}

if (!isChanged) {
return Result.empty();
}

newLeftAssignmentsBuilder.putAll(leftProject.getAssignments().getMap());
Map<VariableReferenceExpression, RowExpression> newLeftAssignments = newLeftAssignmentsBuilder.build();
newRightAssignmentsBuilder.putAll(rightProject.getAssignments().getMap());
Map<VariableReferenceExpression, RowExpression> newRightAssignments = newRightAssignmentsBuilder.build();

PlanNode newLeftProject = addProjections(leftProject.getSource(), context.getIdAllocator(), newLeftAssignments);
PlanNode newRightProject = addProjections(rightProject.getSource(), context.getIdAllocator(), newRightAssignments);

return Result.ofPlanNode(new JoinNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), node.getType(), newLeftProject, newRightProject, joinClauseBuilder.build(), node.getOutputVariables(), node.getFilter(), Optional.empty(), Optional.empty(), node.getDistributionType(), node.getDynamicFilters()));
}

private boolean isSupportedCast(RowExpression rowExpression)
{
if (rowExpression instanceof CallExpression && functionResolution.isCastFunction(((CallExpression) rowExpression).getFunctionHandle())) {
CallExpression cast = (CallExpression) rowExpression;
return TYPE_SUPPORTED.contains(cast.getArguments().get(0).getType()) && cast.getType() instanceof VarcharType;
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.CastType;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.SourceLocation;
import com.facebook.presto.spi.function.FunctionHandle;
Expand All @@ -39,6 +40,7 @@
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH;
Expand Down Expand Up @@ -161,6 +163,14 @@ public static CallExpression callOperator(FunctionAndTypeResolver functionAndTyp
return call(operatorType.name(), functionHandle, returnType, arguments);
}

/**
 * Wraps the given expression in a CAST to BIGINT, unless its type already equals BIGINT,
 * in which case the expression itself is returned.
 *
 * @param functionAndTypeManager used to look up the cast function from the expression's type to BIGINT
 * @param rowExpression expression to convert
 * @return the original expression, or a call expression casting it to BIGINT
 */
public static RowExpression castToBigInt(FunctionAndTypeManager functionAndTypeManager, RowExpression rowExpression)
{
    Type sourceType = rowExpression.getType();
    boolean alreadyBigint = sourceType.equals(BIGINT);
    if (!alreadyBigint) {
        FunctionHandle castHandle = functionAndTypeManager.lookupCast(CastType.CAST, sourceType, BIGINT);
        return call("CAST", castHandle, BIGINT, rowExpression);
    }
    return rowExpression;
}

public static RowExpression searchedCaseExpression(List<RowExpression> whenClauses, Optional<RowExpression> defaultValue)
{
// We rewrite this as - CASE true WHEN p1 THEN v1 WHEN p2 THEN v2 .. ELSE v END
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import java.util.function.Function;

import static com.facebook.airlift.json.JsonCodec.listJsonCodec;
import static com.facebook.presto.SystemSessionProperties.REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN;
import static com.facebook.presto.testing.LocalQueryRunner.queryRunnerWithInitialTransaction;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static com.facebook.presto.testing.TestingTaskContext.createTaskContext;
Expand Down Expand Up @@ -85,6 +86,7 @@ private void setUp(Supplier<List<Driver>> driversSupplier)
.setCatalog("tpch")
.setSchema("tiny")
.setSystemProperty("task_default_concurrency", "1")
.setSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, "false")
.build();

localQueryRunner = queryRunnerWithInitialTransaction(session);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ public void testDefaults()
.setInferInequalityPredicates(false)
.setPullUpExpressionFromLambdaEnabled(false)
.setRewriteConstantArrayContainsToInEnabled(false)
.setUseHBOForScaledWriters(false));
.setUseHBOForScaledWriters(false)
.setRemoveRedundantCastToVarcharInJoin(true));
}

@Test
Expand Down Expand Up @@ -443,6 +444,7 @@ public void testExplicitPropertyMappings()
.put("optimizer.pull-up-expression-from-lambda", "true")
.put("optimizer.rewrite-constant-array-contains-to-in", "true")
.put("optimizer.use-hbo-for-scaled-writers", "true")
.put("optimizer.remove-redundant-cast-to-varchar-in-join", "false")
.build();

FeaturesConfig expected = new FeaturesConfig()
Expand Down Expand Up @@ -635,7 +637,8 @@ public void testExplicitPropertyMappings()
.setInferInequalityPredicates(true)
.setPullUpExpressionFromLambdaEnabled(true)
.setRewriteConstantArrayContainsToInEnabled(true)
.setUseHBOForScaledWriters(true);
.setUseHBOForScaledWriters(true)
.setRemoveRedundantCastToVarcharInJoin(false);
assertFullMapping(properties, expected);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,10 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses
return NO_MATCH;
}

return match(SymbolAliases.builder()
MatchResult result = match(SymbolAliases.builder()
.putAll(Maps.transformValues(outputSymbolAliases, index -> createSymbolReference(valuesNode.getOutputVariables().get(index))))
.build());
return result;
}

@Override
Expand Down
Loading

0 comments on commit 8d9dfa9

Please sign in to comment.