Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace Expression with RowExpression in assignment #12747

Merged
merged 11 commits into from
Jun 18, 2019
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@

import com.facebook.presto.Session;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.Expression;

import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.sql.planner.plan.Patterns.project;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression;
import static java.util.Objects.requireNonNull;

public class ProjectStatsRule
Expand Down Expand Up @@ -53,8 +55,14 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(ProjectNode node, StatsPro
PlanNodeStatsEstimate.Builder calculatedStats = PlanNodeStatsEstimate.builder()
.setOutputRowCount(sourceStats.getOutputRowCount());

for (Map.Entry<VariableReferenceExpression, Expression> entry : node.getAssignments().entrySet()) {
calculatedStats.addVariableStatistics(entry.getKey(), scalarStatsCalculator.calculate(entry.getValue(), sourceStats, session, types));
for (Map.Entry<VariableReferenceExpression, RowExpression> entry : node.getAssignments().entrySet()) {
RowExpression expression = entry.getValue();
if (isExpression(expression)) {
calculatedStats.addVariableStatistics(entry.getKey(), scalarStatsCalculator.calculate(castToExpression(expression), sourceStats, session, types));
}
else {
calculatedStats.addVariableStatistics(entry.getKey(), scalarStatsCalculator.calculate(expression, sourceStats, session));
}
}
return Optional.of(calculatedStats.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static com.google.common.base.Predicates.in;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Maps.transformValues;
import static java.util.Objects.requireNonNull;

/**
Expand Down Expand Up @@ -164,7 +165,7 @@ public Expression visitProject(ProjectNode node, Void context)

Expression underlyingPredicate = node.getSource().accept(this, context);

List<Expression> projectionEqualities = node.getAssignments().entrySet().stream()
List<Expression> projectionEqualities = transformValues(node.getAssignments().getMap(), OriginalExpressionUtils::castToExpression).entrySet().stream()
.filter(VARIABLE_MATCHES_EXPRESSION.negate())
.map(VARIABLE_ENTRY_TO_EQUALITY)
.collect(toImmutableList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.google.common.collect.ImmutableList;

import java.util.List;
Expand Down Expand Up @@ -112,7 +111,7 @@ public Void visitFilter(FilterNode node, ImmutableList.Builder<RowExpression> co
@Override
public Void visitProject(ProjectNode node, ImmutableList.Builder<RowExpression> context)
{
context.addAll(node.getAssignments().getExpressions().stream().map(OriginalExpressionUtils::castToRowExpression).collect(toImmutableList()));
context.addAll(node.getAssignments().getExpressions().stream().collect(toImmutableList()));
return super.visitProject(node, context);
}

Expand All @@ -136,7 +135,6 @@ public Void visitApply(ApplyNode node, ImmutableList.Builder<RowExpression> cont
context.addAll(node.getSubqueryAssignments()
.getExpressions()
.stream()
.map(OriginalExpressionUtils::castToRowExpression)
.collect(toImmutableList()));
return super.visitApply(node, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@
import static com.facebook.presto.SystemSessionProperties.getTaskWriterCount;
import static com.facebook.presto.SystemSessionProperties.isExchangeCompressionEnabled;
import static com.facebook.presto.SystemSessionProperties.isSpillEnabled;
import static com.facebook.presto.execution.warnings.WarningCollector.NOOP;
import static com.facebook.presto.operator.DistinctLimitOperator.DistinctLimitOperatorFactory;
import static com.facebook.presto.operator.NestedLoopBuildOperator.NestedLoopBuildOperatorFactory;
import static com.facebook.presto.operator.NestedLoopJoinOperator.NestedLoopJoinOperatorFactory;
Expand All @@ -240,7 +239,6 @@
import static com.facebook.presto.spi.relation.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.TypeUtils.writeNativeValue;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static com.facebook.presto.sql.gen.LambdaBytecodeGenerator.compileLambdaProvider;
import static com.facebook.presto.sql.planner.RowExpressionInterpreter.rowExpressionInterpreter;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION;
Expand All @@ -250,6 +248,7 @@
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.FINAL;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.PARTIAL;
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT;
Expand All @@ -271,11 +270,9 @@
import static com.google.common.collect.DiscreteDomain.integers;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.concat;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Range.closedOpen;
import static io.airlift.units.DataSize.Unit.BYTE;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;
import static java.util.stream.IntStream.range;

Expand Down Expand Up @@ -1140,7 +1137,7 @@ public PhysicalOperation visitFilter(FilterNode node, LocalExecutionPlanContext
RowExpression filterExpression = node.getPredicate();
List<VariableReferenceExpression> outputVariables = node.getOutputVariables();

return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), Assignments.identity(outputVariables), outputVariables);
return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), identityAssignments(outputVariables), outputVariables);
}

@Override
Expand Down Expand Up @@ -1211,30 +1208,15 @@ private PhysicalOperation visitScanFilterAndProject(
Map<VariableReferenceExpression, Integer> outputMappings = outputMappingsBuilder.build();

// compiler uses inputs instead of symbols, so rewrite the expressions first

List<Expression> projections = new ArrayList<>();
for (VariableReferenceExpression variable : outputVariables) {
projections.add(assignments.get(variable));
}

Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(
context.getSession(),
metadata,
sqlParser,
context.getTypes(),
concat(assignments.getExpressions()),
emptyList(),
NOOP,
false);

List<RowExpression> translatedProjections = projections.stream()
.map(expression -> toRowExpression(expression, expressionTypes, sourceLayout))
List<RowExpression> projections = outputVariables.stream()
.map(assignments::get)
.map(expression -> bindChannels(expression, sourceLayout))
.collect(toImmutableList());

try {
if (columns != null) {
Supplier<CursorProcessor> cursorProcessor = expressionCompiler.compileCursorProcessor(filterExpression, translatedProjections, sourceNode.getId());
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId));
Supplier<CursorProcessor> cursorProcessor = expressionCompiler.compileCursorProcessor(filterExpression, projections, sourceNode.getId());
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, projections, Optional.of(context.getStageId() + "_" + planNodeId));

SourceOperatorFactory operatorFactory = new ScanFilterAndProjectOperatorFactory(
context.getNextOperatorId(),
Expand All @@ -1244,20 +1226,20 @@ private PhysicalOperation visitScanFilterAndProject(
cursorProcessor,
pageProcessor,
columns,
getTypes(projections, expressionTypes),
projections.stream().map(RowExpression::getType).collect(toImmutableList()),
getFilterAndProjectMinOutputPageSize(session),
getFilterAndProjectMinOutputPageRowCount(session));

return new PhysicalOperation(operatorFactory, outputMappings, context, stageExecutionDescriptor.isScanGroupedExecution(sourceNode.getId()) ? GROUPED_EXECUTION : UNGROUPED_EXECUTION);
}
else {
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId));
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, projections, Optional.of(context.getStageId() + "_" + planNodeId));

OperatorFactory operatorFactory = new FilterAndProjectOperator.FilterAndProjectOperatorFactory(
context.getNextOperatorId(),
planNodeId,
pageProcessor,
getTypes(projections, expressionTypes),
projections.stream().map(RowExpression::getType).collect(toImmutableList()),
getFilterAndProjectMinOutputPageSize(session),
getFilterAndProjectMinOutputPageRowCount(session));

Expand Down Expand Up @@ -2729,14 +2711,6 @@ private OperatorFactory createHashAggregationOperatorFactory(
}
}

private static List<Type> getTypes(List<Expression> expressions, Map<NodeRef<Expression>, Type> expressionTypes)
{
return expressions.stream()
.map(NodeRef::of)
.map(expressionTypes::get)
.collect(toImmutableList());
}

private static TableFinisher createTableFinisher(Session session, TableFinishNode node, Metadata metadata)
{
WriterTarget target = node.getTarget();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
import static com.facebook.presto.sql.planner.plan.TableWriterNode.WriterTarget;
import static com.facebook.presto.sql.planner.sanity.PlanSanityChecker.DISTRIBUTED_PLAN_SANITY_CHECKER;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand Down Expand Up @@ -358,19 +359,19 @@ private RelationPlan createInsertPlan(Analysis analysis, Insert insertStatement)
int index = insert.getColumns().indexOf(columns.get(column.getName()));
if (index < 0) {
Expression cast = new Cast(new NullLiteral(), column.getType().getTypeSignature().toString());
assignments.put(output, cast);
assignments.put(output, castToRowExpression(cast));
}
else {
Symbol input = plan.getSymbol(index);
Type tableType = column.getType();
Type queryType = symbolAllocator.getTypes().get(input);

if (queryType.equals(tableType) || metadata.getTypeManager().isTypeOnlyCoercion(queryType, tableType)) {
assignments.put(output, input.toSymbolReference());
assignments.put(output, castToRowExpression(input.toSymbolReference()));
}
else {
Expression cast = new Cast(input.toSymbolReference(), tableType.getTypeSignature().toString());
assignments.put(output, cast);
assignments.put(output, castToRowExpression(cast));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.List;
import java.util.Map;

import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;
import static java.util.Objects.requireNonNull;

class PlanBuilder
Expand Down Expand Up @@ -105,13 +106,13 @@ public PlanBuilder appendProjections(Iterable<Expression> expressions, SymbolAll

// add an identity projection for underlying plan
for (VariableReferenceExpression variable : getRoot().getOutputVariables()) {
projections.put(variable, new SymbolReference(variable.getName()));
projections.put(variable, castToRowExpression(new SymbolReference(variable.getName())));
}

ImmutableMap.Builder<VariableReferenceExpression, Expression> newTranslations = ImmutableMap.builder();
for (Expression expression : expressions) {
VariableReferenceExpression variable = symbolAllocator.newVariable(expression, getAnalysis().getTypeWithCoercions(expression));
projections.put(variable, translations.rewrite(expression));
projections.put(variable, castToRowExpression(translations.rewrite(expression)));
newTranslations.put(variable, expression);
}
// Now append the new translations into the TranslationMap
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.predicate.TupleDomain;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding;
Expand All @@ -69,7 +70,6 @@
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.planner.sanity.PlanSanityChecker;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -109,6 +109,7 @@
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPLICATE;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.gatheringExchange;
import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.jsonFragmentPlan;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
Expand Down Expand Up @@ -560,13 +561,13 @@ private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, Rewrite
private PartitioningVariableAssignments assignPartitioningVariables(Partitioning partitioning)
{
ImmutableList.Builder<VariableReferenceExpression> variables = ImmutableList.builder();
ImmutableMap.Builder<VariableReferenceExpression, Expression> constants = ImmutableMap.builder();
ImmutableMap.Builder<VariableReferenceExpression, RowExpression> constants = ImmutableMap.builder();
for (ArgumentBinding argumentBinding : partitioning.getArguments()) {
VariableReferenceExpression variable;
if (argumentBinding.isConstant()) {
ConstantExpression constant = argumentBinding.getConstant();
Expression expression = literalEncoder.toExpression(constant.getValue(), constant.getType());
hellium01 marked this conversation as resolved.
Show resolved Hide resolved
variable = symbolAllocator.newVariable(expression, constant.getType());
RowExpression expression = constant(constant.getValue(), constant.getType());
hellium01 marked this conversation as resolved.
Show resolved Hide resolved
variable = symbolAllocator.newVariable("constant_partition", constant.getType());
constants.put(variable, expression);
}
else {
Expand Down Expand Up @@ -632,7 +633,7 @@ private TableFinishNode createTemporaryTableWrite(
List<VariableReferenceExpression> outputs,
List<List<VariableReferenceExpression>> inputs,
List<PlanNode> sources,
Map<VariableReferenceExpression, Expression> constantExpressions,
Map<VariableReferenceExpression, RowExpression> constantExpressions,
PartitioningMetadata partitioningMetadata)
{
if (!constantExpressions.isEmpty()) {
Expand All @@ -656,8 +657,8 @@ private TableFinishNode createTemporaryTableWrite(
sources = sources.stream()
.map(source -> {
Assignments.Builder assignments = Assignments.builder();
assignments.putIdentities(source.getOutputVariables());
constantVariables.forEach(variable -> assignments.put(variable, constantExpressions.get(variable)));
source.getOutputVariables().forEach(variable -> assignments.put(variable, new VariableReferenceExpression(variable.getName(), variable.getType())));
constantVariables.forEach(symbol -> assignments.put(symbol, constantExpressions.get(symbol)));
hellium01 marked this conversation as resolved.
Show resolved Hide resolved
return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
})
.collect(toImmutableList());
Expand Down Expand Up @@ -1216,9 +1217,9 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext<Void> context)
private static class PartitioningVariableAssignments
{
private final List<VariableReferenceExpression> variables;
private final Map<VariableReferenceExpression, Expression> constants;
private final Map<VariableReferenceExpression, RowExpression> constants;

private PartitioningVariableAssignments(List<VariableReferenceExpression> variables, Map<VariableReferenceExpression, Expression> constants)
private PartitioningVariableAssignments(List<VariableReferenceExpression> variables, Map<VariableReferenceExpression, RowExpression> constants)
{
this.variables = ImmutableList.copyOf(requireNonNull(variables, "variables is null"));
this.constants = ImmutableMap.copyOf(requireNonNull(constants, "constants is null"));
Expand All @@ -1232,7 +1233,7 @@ public List<VariableReferenceExpression> getVariables()
return variables;
}

public Map<VariableReferenceExpression, Expression> getConstants()
public Map<VariableReferenceExpression, RowExpression> getConstants()
{
return constants;
}
Expand Down
Loading