Skip to content

Commit

Permalink
SQL: Allow look-ahead resolution of aliases for WHERE clause (#38450)
Browse files Browse the repository at this point in the history
Aliases defined in SELECT (Project or Aggregate) are now resolved in the
following WHERE clause. The Analyzer has been enhanced to identify this
rule and replace the field accordingly.

Close #29983
  • Loading branch information
costin authored Feb 6, 2019
1 parent 6ff4a8c commit 1a02445
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ protected Iterable<RuleExecutor<LogicalPlan>.Batch> batches() {
new ResolveRefs(),
new ResolveOrdinalInOrderByAndGroupBy(),
new ResolveMissingRefs(),
new ResolveFilterRefs(),
new ResolveFunctions(),
new ResolveAliases(),
new ProjectedAggregations(),
Expand Down Expand Up @@ -762,6 +763,68 @@ private static UnresolvedAttribute resolveMetadataToMessage(UnresolvedAttribute
}
}

//
// Resolve aliases defined in SELECT that are referred inside the WHERE clause:
// SELECT int AS i FROM t WHERE i > 10
//
// As such, identify all project and aggregates that have a Filter child
// and look at any resoled aliases that match and replace them.
private class ResolveFilterRefs extends AnalyzeRule<LogicalPlan> {

@Override
protected LogicalPlan rule(LogicalPlan plan) {
if (plan instanceof Project) {
Project p = (Project) plan;
if (p.child() instanceof Filter) {
Filter f = (Filter) p.child();
Expression condition = f.condition();
if (condition.resolved() == false && f.childrenResolved() == true) {
Expression newCondition = replaceAliases(condition, p.projections());
if (newCondition != condition) {
return new Project(p.source(), new Filter(f.source(), f.child(), newCondition), p.projections());
}
}
}
}

if (plan instanceof Aggregate) {
Aggregate a = (Aggregate) plan;
if (a.child() instanceof Filter) {
Filter f = (Filter) a.child();
Expression condition = f.condition();
if (condition.resolved() == false && f.childrenResolved() == true) {
Expression newCondition = replaceAliases(condition, a.aggregates());
if (newCondition != condition) {
return new Aggregate(a.source(), new Filter(f.source(), f.child(), newCondition), a.groupings(),
a.aggregates());
}
}
}
}

return plan;
}

private Expression replaceAliases(Expression condition, List<? extends NamedExpression> named) {
List<Alias> aliases = new ArrayList<>();
named.forEach(n -> {
if (n instanceof Alias) {
aliases.add((Alias) n);
}
});

return condition.transformDown(u -> {
boolean qualified = u.qualifier() != null;
for (Alias alias : aliases) {
if (qualified ? Objects.equals(alias.qualifiedName(), u.qualifiedName()) : Objects.equals(alias.name(), u.name())) {
return alias;
}
}
return u;
}, UnresolvedAttribute.class);
}
}

// to avoid creating duplicate functions
// this rule does two iterations
// 1. collect all functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.EsField;

Expand Down Expand Up @@ -75,6 +75,10 @@ public String qualifier() {
return qualifier;
}

public String qualifiedName() {
return qualifier == null ? name() : qualifier + "." + name();
}

@Override
public Nullability nullable() {
return child.nullable();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -638,5 +638,16 @@ public void testMaxOnKeywordGroupByHavingUnsupported() {
assertEquals("1:52: HAVING filter is unsupported for function [MAX(keyword)]",
error("SELECT MAX(keyword) FROM test GROUP BY text HAVING MAX(keyword) > 10"));
}
}

public void testProjectAliasInFilter() {
accept("SELECT int AS i FROM test WHERE i > 10");
}

public void testAggregateAliasInFilter() {
accept("SELECT int AS i FROM test WHERE i > 10 GROUP BY i HAVING MAX(i) > 10");
}

public void testProjectUnresolvedAliasInFilter() {
assertEquals("1:8: Unknown column [tni]", error("SELECT tni AS i FROM test WHERE i > 10 GROUP BY i"));
}
}

0 comments on commit 1a02445

Please sign in to comment.