From 8632f80f349260b09273c250f9b82bb1a1afd02f Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Fri, 11 Jun 2021 14:35:55 -0700 Subject: [PATCH 1/2] Impl stddev and variance function in SQL and PPL (#115) * impl variance frontend and backend * Support construct AggregationResponseParser during Aggregator build stage * add var and varp for PPL Signed-off-by: penghuo * add UT Signed-off-by: penghuo * fix UT Signed-off-by: penghuo * fix doc format Signed-off-by: penghuo * fix doc format Signed-off-by: penghuo * fix the doc Signed-off-by: penghuo * add stddev_samp and stddev_pop Signed-off-by: penghuo * fix UT coverage * address comments Signed-off-by: penghuo --- core/build.gradle | 1 + .../sql/analysis/ExpressionAnalyzer.java | 3 +- .../org/opensearch/sql/expression/DSL.java | 16 ++ .../aggregation/AggregatorFunction.java | 52 ++++ .../aggregation/StdDevAggregator.java | 110 +++++++++ .../aggregation/VarianceAggregator.java | 109 +++++++++ .../function/BuiltinFunctionName.java | 30 +++ .../sql/analysis/ExpressionAnalyzerTest.java | 8 + .../aggregation/StdDevAggregatorTest.java | 182 ++++++++++++++ .../aggregation/VarianceAggregatorTest.java | 190 +++++++++++++++ docs/user/dql/aggregations.rst | 222 ++++++++++++++++++ docs/user/dql/window.rst | 86 ++++++- docs/user/ppl/cmd/stats.rst | 168 +++++++++++++ .../correctness/queries/aggregation.txt | 6 +- .../resources/correctness/queries/window.txt | 12 + .../sql/opensearch/response/agg/Utils.java | 2 +- .../dsl/MetricAggregationBuilder.java | 30 +++ .../dsl/MetricAggregationBuilderTest.java | 73 ++++++ ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 6 +- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 2 +- .../ppl/parser/AstExpressionBuilderTest.java | 84 +++++++ sql/src/main/antlr/OpenSearchSQLLexer.g4 | 7 + sql/src/main/antlr/OpenSearchSQLParser.g4 | 2 +- .../sql/parser/AstExpressionBuilderTest.java | 21 ++ 24 files changed, 1414 insertions(+), 8 deletions(-) create mode 100644 
core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java create mode 100644 core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/aggregation/StdDevAggregatorTest.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/aggregation/VarianceAggregatorTest.java diff --git a/core/build.gradle b/core/build.gradle index 69acf5cef3..1c6c0c0481 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -51,6 +51,7 @@ dependencies { compile group: 'org.springframework', name: 'spring-beans', version: '5.2.5.RELEASE' compile group: 'org.apache.commons', name: 'commons-lang3', version: '3.10' compile group: 'com.facebook.presto', name: 'presto-matching', version: '0.240' + compile group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1' compile project(':common') testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 0f207c0374..d5c1538b77 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -155,7 +155,8 @@ public Expression visitNot(Not node, AnalysisContext context) { @Override public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext context) { - Optional builtinFunctionName = BuiltinFunctionName.of(node.getFuncName()); + Optional builtinFunctionName = + BuiltinFunctionName.ofAggregation(node.getFuncName()); if (builtinFunctionName.isPresent()) { Expression arg = node.getField().accept(this, context); Aggregator aggregator = (Aggregator) repository.compile( diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 31050afc87..560414592c 100644 --- 
a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -500,6 +500,22 @@ public Aggregator count(Expression... expressions) { return aggregate(BuiltinFunctionName.COUNT, expressions); } + public Aggregator varSamp(Expression... expressions) { + return aggregate(BuiltinFunctionName.VARSAMP, expressions); + } + + public Aggregator varPop(Expression... expressions) { + return aggregate(BuiltinFunctionName.VARPOP, expressions); + } + + public Aggregator stddevSamp(Expression... expressions) { + return aggregate(BuiltinFunctionName.STDDEV_SAMP, expressions); + } + + public Aggregator stddevPop(Expression... expressions) { + return aggregate(BuiltinFunctionName.STDDEV_POP, expressions); + } + public RankingWindowFunction rowNumber() { return (RankingWindowFunction) repository.compile( BuiltinFunctionName.ROW_NUMBER.getName(), Collections.emptyList()); diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java index a6be7378f7..640ae8a934 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java @@ -35,6 +35,10 @@ import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.data.type.ExprCoreType.TIME; import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; +import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevPopulation; +import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevSample; +import static org.opensearch.sql.expression.aggregation.VarianceAggregator.variancePopulation; +import static org.opensearch.sql.expression.aggregation.VarianceAggregator.varianceSample; import com.google.common.collect.ImmutableMap; import java.util.Collections; @@ -68,6 
+72,10 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(count()); repository.register(min()); repository.register(max()); + repository.register(varSamp()); + repository.register(varPop()); + repository.register(stddevSamp()); + repository.register(stddevPop()); } private static FunctionResolver avg() { @@ -159,4 +167,48 @@ private static FunctionResolver max() { .build() ); } + + private static FunctionResolver varSamp() { + FunctionName functionName = BuiltinFunctionName.VARSAMP.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + arguments -> varianceSample(arguments, DOUBLE)) + .build() + ); + } + + private static FunctionResolver varPop() { + FunctionName functionName = BuiltinFunctionName.VARPOP.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + arguments -> variancePopulation(arguments, DOUBLE)) + .build() + ); + } + + private static FunctionResolver stddevSamp() { + FunctionName functionName = BuiltinFunctionName.STDDEV_SAMP.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + arguments -> stddevSample(arguments, DOUBLE)) + .build() + ); + } + + private static FunctionResolver stddevPop() { + FunctionName functionName = BuiltinFunctionName.STDDEV_POP.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + arguments -> stddevPopulation(arguments, DOUBLE)) + .build() + ); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java new file mode 
100644 index 0000000000..0cd8494449 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java @@ -0,0 +1,110 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.expression.aggregation; + +import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; +import static org.opensearch.sql.utils.ExpressionUtils.format; + +import java.util.ArrayList; +import java.util.List; +import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation; +import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +/** + * StandardDeviation Aggregator. + */ +public class StdDevAggregator extends Aggregator { + + private final boolean isSampleStdDev; + + /** + * Build Population Standard Deviation {@link StdDevAggregator}. + */ + public static Aggregator stddevPopulation(List arguments, + ExprCoreType returnType) { + return new StdDevAggregator(false, arguments, returnType); + } + + /** + * Build Sample Standard Deviation {@link StdDevAggregator}. + */ + public static Aggregator stddevSample(List arguments, + ExprCoreType returnType) { + return new StdDevAggregator(true, arguments, returnType); + } + + /** + * StdDevAggregator constructor. 
+ * + * @param isSampleStdDev true for sample standard deviation aggregator, false for population + * standard deviation aggregator. + * @param arguments aggregator arguments. + * @param returnType aggregator return types. + */ + public StdDevAggregator( + Boolean isSampleStdDev, List arguments, ExprCoreType returnType) { + super( + isSampleStdDev + ? BuiltinFunctionName.STDDEV_SAMP.getName() + : BuiltinFunctionName.STDDEV_POP.getName(), + arguments, + returnType); + this.isSampleStdDev = isSampleStdDev; + } + + @Override + public StdDevAggregator.StdDevState create() { + return new StdDevAggregator.StdDevState(isSampleStdDev); + } + + @Override + protected StdDevAggregator.StdDevState iterate(ExprValue value, + StdDevAggregator.StdDevState state) { + state.evaluate(value); + return state; + } + + @Override + public String toString() { + return StringUtils.format( + "%s(%s)", isSampleStdDev ? "stddev_samp" : "stddev_pop", format(getArguments())); + } + + protected static class StdDevState implements AggregationState { + + private final StandardDeviation standardDeviation; + + private final List values = new ArrayList<>(); + + public StdDevState(boolean isSampleStdDev) { + this.standardDeviation = new StandardDeviation(isSampleStdDev); + } + + public void evaluate(ExprValue value) { + values.add(value.doubleValue()); + } + + @Override + public ExprValue result() { + return values.size() == 0 + ? ExprNullValue.of() + : doubleValue(standardDeviation.evaluate(values.stream().mapToDouble(d -> d).toArray())); + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java new file mode 100644 index 0000000000..bd9f0948f6 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java @@ -0,0 +1,109 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). 
+ * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.expression.aggregation; + +import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; +import static org.opensearch.sql.utils.ExpressionUtils.format; + +import java.util.ArrayList; +import java.util.List; +import org.apache.commons.math3.stat.descriptive.moment.Variance; +import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +/** + * Variance Aggregator. + */ +public class VarianceAggregator extends Aggregator { + + private final boolean isSampleVariance; + + /** + * Build Population Variance {@link VarianceAggregator}. + */ + public static Aggregator variancePopulation(List arguments, + ExprCoreType returnType) { + return new VarianceAggregator(false, arguments, returnType); + } + + /** + * Build Sample Variance {@link VarianceAggregator}. + */ + public static Aggregator varianceSample(List arguments, + ExprCoreType returnType) { + return new VarianceAggregator(true, arguments, returnType); + } + + /** + * VarianceAggregator constructor. + * + * @param isSampleVariance true for sample variance aggregator, false for population variance + * aggregator. + * @param arguments aggregator arguments. + * @param returnType aggregator return types. 
+ */ + public VarianceAggregator( + Boolean isSampleVariance, List arguments, ExprCoreType returnType) { + super( + isSampleVariance + ? BuiltinFunctionName.VARSAMP.getName() + : BuiltinFunctionName.VARPOP.getName(), + arguments, + returnType); + this.isSampleVariance = isSampleVariance; + } + + @Override + public VarianceState create() { + return new VarianceState(isSampleVariance); + } + + @Override + protected VarianceState iterate(ExprValue value, VarianceState state) { + state.evaluate(value); + return state; + } + + @Override + public String toString() { + return StringUtils.format( + "%s(%s)", isSampleVariance ? "var_samp" : "var_pop", format(getArguments())); + } + + protected static class VarianceState implements AggregationState { + + private final Variance variance; + + private final List values = new ArrayList<>(); + + public VarianceState(boolean isSampleVariance) { + this.variance = new Variance(isSampleVariance); + } + + public void evaluate(ExprValue value) { + values.add(value.doubleValue()); + } + + @Override + public ExprValue result() { + return values.size() == 0 + ? 
ExprNullValue.of() + : doubleValue(variance.evaluate(values.stream().mapToDouble(d -> d).toArray())); + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 0210161abe..24e65d4b5d 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -12,6 +12,7 @@ package org.opensearch.sql.expression.function; import com.google.common.collect.ImmutableMap; +import java.util.Locale; import java.util.Map; import java.util.Optional; import lombok.Getter; @@ -126,6 +127,14 @@ public enum BuiltinFunctionName { COUNT(FunctionName.of("count")), MIN(FunctionName.of("min")), MAX(FunctionName.of("max")), + // sample variance + VARSAMP(FunctionName.of("var_samp")), + // population variance + VARPOP(FunctionName.of("var_pop")), + // sample standard deviation. + STDDEV_SAMP(FunctionName.of("stddev_samp")), + // population standard deviation. + STDDEV_POP(FunctionName.of("stddev_pop")), /** * Text Functions. 
@@ -189,7 +198,28 @@ public enum BuiltinFunctionName { ALL_NATIVE_FUNCTIONS = builder.build(); } + private static final Map AGGREGATION_FUNC_MAPPING = + new ImmutableMap.Builder() + .put("max", BuiltinFunctionName.MAX) + .put("min", BuiltinFunctionName.MIN) + .put("avg", BuiltinFunctionName.AVG) + .put("count", BuiltinFunctionName.COUNT) + .put("sum", BuiltinFunctionName.SUM) + .put("var_pop", BuiltinFunctionName.VARPOP) + .put("var_samp", BuiltinFunctionName.VARSAMP) + .put("variance", BuiltinFunctionName.VARPOP) + .put("std", BuiltinFunctionName.STDDEV_POP) + .put("stddev", BuiltinFunctionName.STDDEV_POP) + .put("stddev_pop", BuiltinFunctionName.STDDEV_POP) + .put("stddev_samp", BuiltinFunctionName.STDDEV_SAMP) + .build(); + public static Optional of(String str) { return Optional.ofNullable(ALL_NATIVE_FUNCTIONS.getOrDefault(FunctionName.of(str), null)); } + + public static Optional ofAggregation(String functionName) { + return Optional.ofNullable( + AGGREGATION_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null)); + } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index aa8d2b12de..8cb7288273 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -292,6 +292,14 @@ public void aggregation_filter() { ); } + @Test + public void variance_mapto_varPop() { + assertAnalyzeEqual( + dsl.varPop(DSL.ref("integer_value", INTEGER)), + AstDSL.aggregate("variance", qualifiedName("integer_value")) + ); + } + protected Expression analyze(UnresolvedExpression unresolvedExpression) { return expressionAnalyzer.analyze(unresolvedExpression, analysisContext); } diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/StdDevAggregatorTest.java 
b/core/src/test/java/org/opensearch/sql/expression/aggregation/StdDevAggregatorTest.java new file mode 100644 index 0000000000..ef085a81d3 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/StdDevAggregatorTest.java @@ -0,0 +1,182 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.expression.aggregation; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; +import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; +import static org.opensearch.sql.data.model.ExprValueUtils.missingValue; +import static org.opensearch.sql.data.model.ExprValueUtils.nullValue; +import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.expression.DSL.ref; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import 
org.opensearch.sql.storage.bindingtuple.BindingTuple; + +@ExtendWith(MockitoExtension.class) +public class StdDevAggregatorTest extends AggregationTest { + + @Mock + Expression expression; + + @Mock + ExprValue tupleValue; + + @Mock + BindingTuple tuple; + + @Test + public void stddev_sample_field_expression() { + ExprValue result = + stddevSample(integerValue(1), integerValue(2), integerValue(3), integerValue(4)); + assertEquals(1.2909944487358056, result.value()); + } + + @Test + public void stddev_population_field_expression() { + ExprValue result = + stddevPop(integerValue(1), integerValue(2), integerValue(3), integerValue(4)); + assertEquals(1.118033988749895, result.value()); + } + + @Test + public void stddev_sample_arithmetic_expression() { + ExprValue result = + aggregation( + dsl.stddevSamp(dsl.multiply(ref("integer_value", INTEGER), DSL.literal(10))), tuples); + assertEquals(12.909944487358056, result.value()); + } + + @Test + public void stddev_population_arithmetic_expression() { + ExprValue result = + aggregation( + dsl.stddevPop(dsl.multiply(ref("integer_value", INTEGER), DSL.literal(10))), tuples); + assertEquals(11.180339887498949, result.value()); + } + + @Test + public void filtered_stddev_sample() { + ExprValue result = + aggregation( + dsl.stddevSamp(ref("integer_value", INTEGER)) + .condition(dsl.greater(ref("integer_value", INTEGER), DSL.literal(1))), + tuples); + assertEquals(1.0, result.value()); + } + + @Test + public void filtered_stddev_population() { + ExprValue result = + aggregation( + dsl.stddevPop(ref("integer_value", INTEGER)) + .condition(dsl.greater(ref("integer_value", INTEGER), DSL.literal(1))), + tuples); + assertEquals(0.816496580927726, result.value()); + } + + @Test + public void stddev_sample_with_missing() { + ExprValue result = stddevSample(integerValue(2), integerValue(1), missingValue()); + assertEquals(0.7071067811865476, result.value()); + } + + @Test + public void stddev_population_with_missing() { + ExprValue 
result = stddevPop(integerValue(2), integerValue(1), missingValue()); + assertEquals(0.5, result.value()); + } + + @Test + public void stddev_sample_with_null() { + ExprValue result = stddevSample(doubleValue(3d), doubleValue(4d), nullValue()); + assertEquals(0.7071067811865476, result.value()); + } + + @Test + public void stddev_pop_with_null() { + ExprValue result = stddevPop(doubleValue(3d), doubleValue(4d), nullValue()); + assertEquals(0.5, result.value()); + } + + @Test + public void stddev_sample_with_all_missing_or_null() { + ExprValue result = stddevSample(missingValue(), nullValue()); + assertTrue(result.isNull()); + } + + @Test + public void stddev_pop_with_all_missing_or_null() { + ExprValue result = stddevPop(missingValue(), nullValue()); + assertTrue(result.isNull()); + } + + @Test + public void stddev_sample_to_string() { + Aggregator aggregator = dsl.stddevSamp(ref("integer_value", INTEGER)); + assertEquals("stddev_samp(integer_value)", aggregator.toString()); + } + + @Test + public void stddev_pop_to_string() { + Aggregator aggregator = dsl.stddevPop(ref("integer_value", INTEGER)); + assertEquals("stddev_pop(integer_value)", aggregator.toString()); + } + + @Test + public void stddev_sample_nested_to_string() { + Aggregator avgAggregator = + dsl.stddevSamp( + dsl.multiply( + ref("integer_value", INTEGER), DSL.literal(ExprValueUtils.integerValue(10)))); + assertEquals( + String.format("stddev_samp(*(%s, %d))", ref("integer_value", INTEGER), 10), + avgAggregator.toString()); + } + + private ExprValue stddevSample(ExprValue value, ExprValue... values) { + when(expression.valueOf(any())).thenReturn(value, values); + when(expression.type()).thenReturn(DOUBLE); + return aggregation(dsl.stddevSamp(expression), mockTuples(value, values)); + } + + private ExprValue stddevPop(ExprValue value, ExprValue... 
values) { + when(expression.valueOf(any())).thenReturn(value, values); + when(expression.type()).thenReturn(DOUBLE); + return aggregation(dsl.stddevPop(expression), mockTuples(value, values)); + } + + private List mockTuples(ExprValue value, ExprValue... values) { + List mockTuples = new ArrayList<>(); + when(tupleValue.bindingTuples()).thenReturn(tuple); + mockTuples.add(tupleValue); + for (ExprValue exprValue : values) { + mockTuples.add(tupleValue); + } + return mockTuples; + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/VarianceAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/VarianceAggregatorTest.java new file mode 100644 index 0000000000..09fb8b8012 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/VarianceAggregatorTest.java @@ -0,0 +1,190 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package org.opensearch.sql.expression.aggregation; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; +import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; +import static org.opensearch.sql.data.model.ExprValueUtils.missingValue; +import static org.opensearch.sql.data.model.ExprValueUtils.nullValue; +import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.expression.DSL.ref; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.storage.bindingtuple.BindingTuple; + +@ExtendWith(MockitoExtension.class) +public class VarianceAggregatorTest extends AggregationTest { + + @Mock Expression expression; + + @Mock ExprValue tupleValue; + + @Mock BindingTuple tuple; + + @Test + public void variance_sample_field_expression() { + ExprValue result = + varianceSample(integerValue(1), integerValue(2), integerValue(3), integerValue(4)); + assertEquals(1.6666666666666667, result.value()); + } + + @Test + public void variance_population_field_expression() { + ExprValue result = + variancePop(integerValue(1), integerValue(2), integerValue(3), integerValue(4)); + assertEquals(1.25, result.value()); + } + + 
@Test + public void variance_sample_arithmetic_expression() { + ExprValue result = + aggregation( + dsl.varSamp(dsl.multiply(ref("integer_value", INTEGER), DSL.literal(10))), tuples); + assertEquals(166.66666666666666, result.value()); + } + + @Test + public void variance_pop_arithmetic_expression() { + ExprValue result = + aggregation( + dsl.varPop(dsl.multiply(ref("integer_value", INTEGER), DSL.literal(10))), tuples); + assertEquals(125d, result.value()); + } + + @Test + public void filtered_variance_sample() { + ExprValue result = + aggregation( + dsl.varSamp(ref("integer_value", INTEGER)) + .condition(dsl.greater(ref("integer_value", INTEGER), DSL.literal(1))), + tuples); + assertEquals(1.0, result.value()); + } + + @Test + public void filtered_variance_pop() { + ExprValue result = + aggregation( + dsl.varPop(ref("integer_value", INTEGER)) + .condition(dsl.greater(ref("integer_value", INTEGER), DSL.literal(1))), + tuples); + assertEquals(0.6666666666666666, result.value()); + } + + @Test + public void variance_sample_with_missing() { + ExprValue result = varianceSample(integerValue(2), integerValue(1), missingValue()); + assertEquals(0.5, result.value()); + } + + @Test + public void variance_population_with_missing() { + ExprValue result = variancePop(integerValue(2), integerValue(1), missingValue()); + assertEquals(0.25, result.value()); + } + + @Test + public void variance_sample_with_null() { + ExprValue result = varianceSample(doubleValue(3d), doubleValue(4d), nullValue()); + assertEquals(0.5, result.value()); + } + + @Test + public void variance_pop_with_null() { + ExprValue result = variancePop(doubleValue(3d), doubleValue(4d), nullValue()); + assertEquals(0.25, result.value()); + } + + @Test + public void variance_sample_with_all_missing_or_null() { + ExprValue result = varianceSample(missingValue(), nullValue()); + assertTrue(result.isNull()); + } + + @Test + public void variance_pop_with_all_missing_or_null() { + ExprValue result = 
variancePop(missingValue(), nullValue()); + assertTrue(result.isNull()); + } + + @Test + public void valueOf() { + ExpressionEvaluationException exception = + assertThrows( + ExpressionEvaluationException.class, + () -> dsl.avg(ref("double_value", DOUBLE)).valueOf(valueEnv())); + assertEquals("can't evaluate on aggregator: avg", exception.getMessage()); + } + + @Test + public void variance_sample_to_string() { + Aggregator avgAggregator = dsl.varSamp(ref("integer_value", INTEGER)); + assertEquals("var_samp(integer_value)", avgAggregator.toString()); + } + + @Test + public void variance_pop_to_string() { + Aggregator avgAggregator = dsl.varPop(ref("integer_value", INTEGER)); + assertEquals("var_pop(integer_value)", avgAggregator.toString()); + } + + @Test + public void variance_sample_nested_to_string() { + Aggregator avgAggregator = + dsl.varSamp( + dsl.multiply( + ref("integer_value", INTEGER), DSL.literal(ExprValueUtils.integerValue(10)))); + assertEquals( + String.format("var_samp(*(%s, %d))", ref("integer_value", INTEGER), 10), + avgAggregator.toString()); + } + + private ExprValue varianceSample(ExprValue value, ExprValue... values) { + when(expression.valueOf(any())).thenReturn(value, values); + when(expression.type()).thenReturn(DOUBLE); + return aggregation(dsl.varSamp(expression), mockTuples(value, values)); + } + + private ExprValue variancePop(ExprValue value, ExprValue... values) { + when(expression.valueOf(any())).thenReturn(value, values); + when(expression.type()).thenReturn(DOUBLE); + return aggregation(dsl.varPop(expression), mockTuples(value, values)); + } + + private List mockTuples(ExprValue value, ExprValue... 
values) { + List mockTuples = new ArrayList<>(); + when(tupleValue.bindingTuples()).thenReturn(tuple); + mockTuples.add(tupleValue); + for (ExprValue exprValue : values) { + mockTuples.add(tupleValue); + } + return mockTuples; + } +} diff --git a/docs/user/dql/aggregations.rst b/docs/user/dql/aggregations.rst index 98b565e1ec..1d6d172981 100644 --- a/docs/user/dql/aggregations.rst +++ b/docs/user/dql/aggregations.rst @@ -135,6 +135,228 @@ Besides regular identifiers, ``COUNT`` aggregate function also accepts arguments 2. ``COUNT(*)`` will count the number of all its input rows. 3. ``COUNT(1)`` is same as ``COUNT(*)`` because any non-null literal will count. +Aggregation Functions +===================== + +COUNT +----- + +Description +>>>>>>>>>>> + +Usage: Returns a count of the number of expr in the rows retrieved by a SELECT statement. + +Example:: + + os> SELECT gender, count(*) as countV FROM accounts GROUP BY gender; + fetched rows / total rows = 2/2 + +----------+----------+ + | gender | countV | + |----------+----------| + | F | 1 | + | M | 3 | + +----------+----------+ + +SUM +--- + +Description +>>>>>>>>>>> + +Usage: SUM(expr). Returns the sum of expr. + +Example:: + + os> SELECT gender, sum(age) as sumV FROM accounts GROUP BY gender; + fetched rows / total rows = 2/2 + +----------+--------+ + | gender | sumV | + |----------+--------| + | F | 28 | + | M | 101 | + +----------+--------+ + +AVG +--- + +Description +>>>>>>>>>>> + +Usage: AVG(expr). Returns the average value of expr. + +Example:: + + os> SELECT gender, avg(age) as avgV FROM accounts GROUP BY gender; + fetched rows / total rows = 2/2 + +----------+--------------------+ + | gender | avgV | + |----------+--------------------| + | F | 28.0 | + | M | 33.666666666666664 | + +----------+--------------------+ + +MAX +--- + +Description +>>>>>>>>>>> + +Usage: MAX(expr). Returns the maximum value of expr. 
+ +Example:: + + os> SELECT max(age) as maxV FROM accounts; + fetched rows / total rows = 1/1 + +--------+ + | maxV | + |--------| + | 36 | + +--------+ + +MIN +--- + +Description +>>>>>>>>>>> + +Usage: MIN(expr). Returns the minimum value of expr. + +Example:: + + os> SELECT min(age) as minV FROM accounts; + fetched rows / total rows = 1/1 + +--------+ + | minV | + |--------| + | 28 | + +--------+ + +VAR_POP +------- + +Description +>>>>>>>>>>> + +Usage: VAR_POP(expr). Returns the population variance of expr. + +Example:: + + os> SELECT var_pop(age) as varV FROM accounts; + fetched rows / total rows = 1/1 + +--------+ + | varV | + |--------| + | 8.1875 | + +--------+ + +VAR_SAMP +-------- + +Description +>>>>>>>>>>> + +Usage: VAR_SAMP(expr). Returns the sample variance of expr. + +Example:: + + os> SELECT var_samp(age) as varV FROM accounts; + fetched rows / total rows = 1/1 + +--------------------+ + | varV | + |--------------------| + | 10.916666666666666 | + +--------------------+ + +VARIANCE +-------- + +Description +>>>>>>>>>>> + +Usage: VARIANCE(expr). Returns the population variance of expr. VARIANCE() is a synonym of the VAR_POP() function. + +Example:: + + os> SELECT variance(age) as varV FROM accounts; + fetched rows / total rows = 1/1 + +--------+ + | varV | + |--------| + | 8.1875 | + +--------+ + +STDDEV_POP +---------- + +Description +>>>>>>>>>>> + +Usage: STDDEV_POP(expr). Returns the population standard deviation of expr. + +Example:: + + os> SELECT stddev_pop(age) as stddevV FROM accounts; + fetched rows / total rows = 1/1 + +--------------------+ + | stddevV | + |--------------------| + | 2.8613807855648994 | + +--------------------+ + +STDDEV_SAMP +----------- + +Description +>>>>>>>>>>> + +Usage: STDDEV_SAMP(expr). Returns the sample standard deviation of expr. 
+ +Example:: + + os> SELECT stddev_samp(age) as stddevV FROM accounts; + fetched rows / total rows = 1/1 + +-------------------+ + | stddevV | + |-------------------| + | 3.304037933599835 | + +-------------------+ + +STD +--- + +Description +>>>>>>>>>>> + +Usage: STD(expr). Returns the population standard deviation of expr. STD() is a synonym of the STDDEV_POP() function. + +Example:: + + os> SELECT std(age) as stddevV FROM accounts; + fetched rows / total rows = 1/1 + +--------------------+ + | stddevV | + |--------------------| + | 2.8613807855648994 | + +--------------------+ + +STDDEV +------ + +Description +>>>>>>>>>>> + +Usage: STDDEV(expr). Returns the population standard deviation of expr. STDDEV() is a synonym of the STDDEV_POP() function. + +Example:: + + os> SELECT stddev(age) as stddevV FROM accounts; + fetched rows / total rows = 1/1 + +--------------------+ + | stddevV | + |--------------------| + | 2.8613807855648994 | + +--------------------+ + HAVING Clause ============= diff --git a/docs/user/dql/window.rst b/docs/user/dql/window.rst index 6d71f0637a..feb2aaa44e 100644 --- a/docs/user/dql/window.rst +++ b/docs/user/dql/window.rst @@ -20,7 +20,7 @@ A window function consists of 2 pieces: a function and a window definition. A wi There are three categories of common window functions: -1. **Aggregate Functions**: COUNT(), MIN(), MAX(), AVG() and SUM(). +1. **Aggregate Functions**: COUNT(), MIN(), MAX(), AVG(), SUM(), STDDEV_POP(), STDDEV_SAMP(), VAR_POP() and VAR_SAMP(). 2. **Ranking Functions**: ROW_NUMBER(), RANK(), DENSE_RANK(), PERCENT_RANK() and NTILE(). 3. **Analytic Functions**: CUME_DIST(), LAG() and LEAD(). @@ -146,6 +146,90 @@ Here is an example for ``SUM`` function:: | M | 39225 | 49091 | +----------+-----------+-------+ +STDDEV_POP +---------- + +Here is an example for ``STDDEV_POP`` function:: + + os> SELECT + ... gender, balance, + ... STDDEV_POP(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS val + ... 
FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+--------------------+ + | gender | balance | val | + |----------+-----------+--------------------| + | F | 32838 | 0.0 | + | M | 4180 | 0.0 | + | M | 5686 | 753.0 | + | M | 39225 | 16177.091422406222 | + +----------+-----------+--------------------+ + +STDDEV_SAMP +----------- + +Here is an example for ``STDDEV_SAMP`` function:: + + os> SELECT + ... gender, balance, + ... STDDEV_SAMP(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS val + ... FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+--------------------+ + | gender | balance | val | + |----------+-----------+--------------------| + | F | 32838 | 0.0 | + | M | 4180 | 0.0 | + | M | 5686 | 1064.9028124669405 | + | M | 39225 | 19812.809753624886 | + +----------+-----------+--------------------+ + +VAR_POP +------- + +Here is an example for ``VAR_POP`` function:: + + os> SELECT + ... gender, balance, + ... VAR_POP(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS val + ... FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+--------------------+ + | gender | balance | val | + |----------+-----------+--------------------| + | F | 32838 | 0.0 | + | M | 4180 | 0.0 | + | M | 5686 | 567009.0 | + | M | 39225 | 261698286.88888893 | + +----------+-----------+--------------------+ + +VAR_SAMP +-------- + +Here is an example for ``VAR_SAMP`` function:: + + os> SELECT + ... gender, balance, + ... VAR_SAMP(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS val + ... 
FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+-------------------+ + | gender | balance | val | + |----------+-----------+-------------------| + | F | 32838 | 0.0 | + | M | 4180 | 0.0 | + | M | 5686 | 1134018.0 | + | M | 39225 | 392547430.3333334 | + +----------+-----------+-------------------+ + Ranking Functions ================= diff --git a/docs/user/ppl/cmd/stats.rst b/docs/user/ppl/cmd/stats.rst index 3aca304fcd..f6dad255ef 100644 --- a/docs/user/ppl/cmd/stats.rst +++ b/docs/user/ppl/cmd/stats.rst @@ -38,6 +38,174 @@ stats ... [by-clause]... * aggregation: mandatory. A aggregation function. The argument of aggregation must be field. * by-clause: optional. The one or more fields to group the results by. **Default**: If no is specified, the stats command returns only one row, which is the aggregation over the entire result set. + +Aggregation Functions +===================== + +COUNT +----- + +Description +>>>>>>>>>>> + +Usage: Returns a count of the number of expr in the rows retrieved by the query. + +Example:: + + os> source=accounts | stats count(); + fetched rows / total rows = 1/1 + +-----------+ + | count() | + |-----------| + | 4 | + +-----------+ + +SUM +--- + +Description +>>>>>>>>>>> + +Usage: SUM(expr). Returns the sum of expr. + +Example:: + + os> source=accounts | stats sum(age) by gender; + fetched rows / total rows = 2/2 + +------------+----------+ + | sum(age) | gender | + |------------+----------| + | 28 | F | + | 101 | M | + +------------+----------+ + +AVG +--- + +Description +>>>>>>>>>>> + +Usage: AVG(expr). Returns the average value of expr. + +Example:: + + os> source=accounts | stats avg(age) by gender; + fetched rows / total rows = 2/2 + +--------------------+----------+ + | avg(age) | gender | + |--------------------+----------| + | 28.0 | F | + | 33.666666666666664 | M | + +--------------------+----------+ + +MAX +--- + +Description +>>>>>>>>>>> + +Usage: MAX(expr). 
Returns the maximum value of expr. + +Example:: + + os> source=accounts | stats max(age); + fetched rows / total rows = 1/1 + +------------+ + | max(age) | + |------------| + | 36 | + +------------+ + +MIN +--- + +Description +>>>>>>>>>>> + +Usage: MIN(expr). Returns the minimum value of expr. + +Example:: + + os> source=accounts | stats min(age); + fetched rows / total rows = 1/1 + +------------+ + | min(age) | + |------------| + | 28 | + +------------+ + +VAR_SAMP +-------- + +Description +>>>>>>>>>>> + +Usage: VAR_SAMP(expr). Returns the sample variance of expr. + +Example:: + + os> source=accounts | stats var_samp(age); + fetched rows / total rows = 1/1 + +--------------------+ + | var_samp(age) | + |--------------------| + | 10.916666666666666 | + +--------------------+ + +VAR_POP +------- + +Description +>>>>>>>>>>> + +Usage: VAR_POP(expr). Returns the population variance of expr. + +Example:: + + os> source=accounts | stats var_pop(age); + fetched rows / total rows = 1/1 + +----------------+ + | var_pop(age) | + |----------------| + | 8.1875 | + +----------------+ + +STDDEV_SAMP +----------- + +Description +>>>>>>>>>>> + +Usage: STDDEV_SAMP(expr). Returns the sample standard deviation of expr. + +Example:: + + os> source=accounts | stats stddev_samp(age); + fetched rows / total rows = 1/1 + +--------------------+ + | stddev_samp(age) | + |--------------------| + | 3.304037933599835 | + +--------------------+ + +STDDEV_POP +---------- + +Description +>>>>>>>>>>> + +Usage: STDDEV_POP(expr). Returns the population standard deviation of expr. 
+ +Example:: + + os> source=accounts | stats stddev_pop(age); + fetched rows / total rows = 1/1 + +--------------------+ + | stddev_pop(age) | + |--------------------| + | 2.8613807855648994 | + +--------------------+ + Example 1: Calculate the count of events ======================================== diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index 6c6e5b73a1..45aa658783 100644 --- a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -5,4 +5,8 @@ SELECT SUM(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT MAX(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT MAX(timestamp) FROM opensearch_dashboards_sample_data_flights SELECT MIN(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights -SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights \ No newline at end of file +SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights +SELECT VAR_POP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights +SELECT VAR_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights +SELECT STDDEV_POP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights +SELECT STDDEV_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights \ No newline at end of file diff --git a/integ-test/src/test/resources/correctness/queries/window.txt b/integ-test/src/test/resources/correctness/queries/window.txt index a8d134a254..c3f2715322 100644 --- a/integ-test/src/test/resources/correctness/queries/window.txt +++ b/integ-test/src/test/resources/correctness/queries/window.txt @@ -9,10 +9,18 @@ SELECT DistanceMiles, SUM(DistanceMiles) OVER () AS num FROM opensearch_dashboar SELECT DistanceMiles, AVG(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, 
MAX(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, MIN(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights +SELECT AvgTicketPrice, STDDEV_POP(AvgTicketPrice) OVER () AS num FROM opensearch_dashboards_sample_data_flights +SELECT AvgTicketPrice, STDDEV_SAMP(AvgTicketPrice) OVER () AS num FROM opensearch_dashboards_sample_data_flights +SELECT AvgTicketPrice, VAR_POP(AvgTicketPrice) OVER () AS num FROM opensearch_dashboards_sample_data_flights +SELECT AvgTicketPrice, VAR_SAMP(AvgTicketPrice) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT FlightDelayMin, DistanceMiles, SUM(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights SELECT FlightDelayMin, DistanceMiles, AVG(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights SELECT FlightDelayMin, DistanceMiles, MAX(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights SELECT FlightDelayMin, DistanceMiles, MIN(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights +SELECT FlightDelayMin, AvgTicketPrice, STDDEV_POP(AvgTicketPrice) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights ORDER BY FlightDelayMin +SELECT FlightDelayMin, AvgTicketPrice, STDDEV_SAMP(AvgTicketPrice) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights ORDER BY FlightDelayMin +SELECT FlightDelayMin, AvgTicketPrice, VAR_POP(AvgTicketPrice) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights ORDER BY FlightDelayMin +SELECT FlightDelayMin, AvgTicketPrice, VAR_SAMP(AvgTicketPrice) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights ORDER BY FlightDelayMin SELECT user, RANK() OVER (ORDER BY user) AS rnk FROM opensearch_dashboards_sample_data_ecommerce 
SELECT user, DENSE_RANK() OVER (ORDER BY user) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, COUNT(day_of_week_i) OVER (ORDER BY user) AS cnt FROM opensearch_dashboards_sample_data_ecommerce @@ -20,6 +28,8 @@ SELECT user, SUM(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dash SELECT user, AVG(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MAX(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MIN(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce +SELECT user, STDDEV_POP(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce ORDER BY user +SELECT user, VAR_POP(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce ORDER BY user SELECT user, RANK() OVER (ORDER BY user DESC) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, DENSE_RANK() OVER (ORDER BY user DESC) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, COUNT(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS cnt FROM opensearch_dashboards_sample_data_ecommerce @@ -27,6 +37,8 @@ SELECT user, SUM(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS nu SELECT user, AVG(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MAX(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MIN(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce +SELECT user, STDDEV_POP(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce ORDER BY user +SELECT user, VAR_POP(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM 
opensearch_dashboards_sample_data_ecommerce ORDER BY user SELECT customer_gender, user, ROW_NUMBER() OVER (PARTITION BY customer_gender ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT customer_gender, user, RANK() OVER (PARTITION BY customer_gender ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT customer_gender, user, DENSE_RANK() OVER (PARTITION BY customer_gender ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java index 53fd66ceef..28b9d41e83 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java @@ -19,7 +19,7 @@ public class Utils { /** * Utils to handle Nan Value. - * @return null if is Nan value. + * @return null if is Nan. */ public static Object handleNanValue(double value) { return Double.isNaN(value) ? 
null : value; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 0dbfec02c1..3d40258288 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -37,6 +37,7 @@ import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; +import org.opensearch.search.aggregations.metrics.ExtendedStats; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; @@ -46,6 +47,7 @@ import org.opensearch.sql.opensearch.response.agg.FilterParser; import org.opensearch.sql.opensearch.response.agg.MetricParser; import org.opensearch.sql.opensearch.response.agg.SingleValueParser; +import org.opensearch.sql.opensearch.response.agg.StatsParser; import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @@ -124,6 +126,34 @@ public Pair visitNamedAggregator( condition, name, new SingleValueParser(name)); + case "var_samp": + return make( + AggregationBuilders.extendedStats(name), + expression, + condition, + name, + new StatsParser(ExtendedStats::getVarianceSampling,name)); + case "var_pop": + return make( + AggregationBuilders.extendedStats(name), + expression, + condition, + name, + new StatsParser(ExtendedStats::getVariancePopulation,name)); + case "stddev_samp": + return make( + AggregationBuilders.extendedStats(name), + 
expression, + condition, + name, + new StatsParser(ExtendedStats::getStdDeviationSampling,name)); + case "stddev_pop": + return make( + AggregationBuilders.extendedStats(name), + expression, + condition, + name, + new StatsParser(ExtendedStats::getStdDeviationPopulation,name)); default: throw new IllegalStateException( String.format("unsupported aggregator %s", node.getFunctionName().getFunctionName())); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index 85b3bd5a65..95a2383475 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -35,6 +35,10 @@ import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.expression.DSL.named; import static org.opensearch.sql.expression.DSL.ref; +import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevPopulation; +import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevSample; +import static org.opensearch.sql.expression.aggregation.VarianceAggregator.variancePopulation; +import static org.opensearch.sql.expression.aggregation.VarianceAggregator.varianceSample; import com.fasterxml.jackson.databind.ObjectMapper; import java.util.Arrays; @@ -53,6 +57,7 @@ import org.opensearch.sql.expression.aggregation.MinAggregator; import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.aggregation.SumAggregator; +import org.opensearch.sql.expression.aggregation.VarianceAggregator; import org.opensearch.sql.expression.function.FunctionName; import 
org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @@ -185,6 +190,74 @@ void should_build_max_aggregation() { new MaxAggregator(Arrays.asList(ref("age", INTEGER)), INTEGER))))); } + @Test + void should_build_varPop_aggregation() { + assertEquals( + "{\n" + + " \"var_pop(age)\" : {\n" + + " \"extended_stats\" : {\n" + + " \"field\" : \"age\",\n" + + " \"sigma\" : 2.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Arrays.asList( + named("var_pop(age)", + variancePopulation(Arrays.asList(ref("age", INTEGER)), INTEGER))))); + } + + @Test + void should_build_varSamp_aggregation() { + assertEquals( + "{\n" + + " \"var_samp(age)\" : {\n" + + " \"extended_stats\" : {\n" + + " \"field\" : \"age\",\n" + + " \"sigma\" : 2.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Arrays.asList( + named("var_samp(age)", + varianceSample(Arrays.asList(ref("age", INTEGER)), INTEGER))))); + } + + @Test + void should_build_stddevPop_aggregation() { + assertEquals( + "{\n" + + " \"stddev_pop(age)\" : {\n" + + " \"extended_stats\" : {\n" + + " \"field\" : \"age\",\n" + + " \"sigma\" : 2.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Arrays.asList( + named("stddev_pop(age)", + stddevPopulation(Arrays.asList(ref("age", INTEGER)), INTEGER))))); + } + + @Test + void should_build_stddevSamp_aggregation() { + assertEquals( + "{\n" + + " \"stddev_samp(age)\" : {\n" + + " \"extended_stats\" : {\n" + + " \"field\" : \"age\",\n" + + " \"sigma\" : 2.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Arrays.asList( + named("stddev_samp(age)", + stddevSample(Arrays.asList(ref("age", INTEGER)), INTEGER))))); + } + @Test void should_throw_exception_for_unsupported_aggregator() { when(aggregator.getFunctionName()).thenReturn(new FunctionName("unsupported_agg")); diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 3874a0a50e..cb665f6c88 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ 
b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -151,8 +151,10 @@ STDEV: 'STDEV'; STDEVP: 'STDEVP'; SUM: 'SUM'; SUMSQ: 'SUMSQ'; -VAR: 'VAR'; -VARP: 'VARP'; +VAR_SAMP: 'VAR_SAMP'; +VAR_POP: 'VAR_POP'; +STDDEV_SAMP: 'STDDEV_SAMP'; +STDDEV_POP: 'STDDEV_POP'; PERCENTILE: 'PERCENTILE'; FIRST: 'FIRST'; LAST: 'LAST'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 77aecf5a44..d552ad0756 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -139,7 +139,7 @@ statsFunction ; statsFunctionName - : AVG | COUNT | SUM | MIN | MAX + : AVG | COUNT | SUM | MIN | MAX | VAR_SAMP | VAR_POP | STDDEV_SAMP | STDDEV_POP ; percentileAggFunction diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index 07ad97401e..71ef692abf 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -335,6 +335,90 @@ public void testAggFuncCallExpr() { )); } + @Test + public void testVarAggregationShouldPass() { + assertEqual("source=t | stats var_samp(a) by b", + agg( + relation("t"), + exprList( + alias( + "var_samp(a)", + aggregate("var_samp", field("a")) + ) + ), + emptyList(), + exprList( + alias( + "b", + field("b") + )), + defaultStatsArgs() + )); + } + + @Test + public void testVarpAggregationShouldPass() { + assertEqual("source=t | stats var_pop(a) by b", + agg( + relation("t"), + exprList( + alias( + "var_pop(a)", + aggregate("var_pop", field("a")) + ) + ), + emptyList(), + exprList( + alias( + "b", + field("b") + )), + defaultStatsArgs() + )); + } + + @Test + public void testStdDevAggregationShouldPass() { + assertEqual("source=t | stats stddev_samp(a) by b", + agg( + relation("t"), + exprList( + alias( + "stddev_samp(a)", + aggregate("stddev_samp", 
field("a")) + ) + ), + emptyList(), + exprList( + alias( + "b", + field("b") + )), + defaultStatsArgs() + )); + } + + @Test + public void testStdDevPAggregationShouldPass() { + assertEqual("source=t | stats stddev_pop(a) by b", + agg( + relation("t"), + exprList( + alias( + "stddev_pop(a)", + aggregate("stddev_pop", field("a")) + ) + ), + emptyList(), + exprList( + alias( + "b", + field("b") + )), + defaultStatsArgs() + )); + } + @Test public void testPercentileAggFuncExpr() { assertEqual("source=t | stats percentile<1>(a)", diff --git a/sql/src/main/antlr/OpenSearchSQLLexer.g4 b/sql/src/main/antlr/OpenSearchSQLLexer.g4 index 94f8e7c87a..426c77cf06 100644 --- a/sql/src/main/antlr/OpenSearchSQLLexer.g4 +++ b/sql/src/main/antlr/OpenSearchSQLLexer.g4 @@ -126,6 +126,13 @@ COUNT: 'COUNT'; MAX: 'MAX'; MIN: 'MIN'; SUM: 'SUM'; +VAR_POP: 'VAR_POP'; +VAR_SAMP: 'VAR_SAMP'; +VARIANCE: 'VARIANCE'; +STD: 'STD'; +STDDEV: 'STDDEV'; +STDDEV_POP: 'STDDEV_POP'; +STDDEV_SAMP: 'STDDEV_SAMP'; // Common function Keywords diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 0ad08781bf..18c75b94ff 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -345,7 +345,7 @@ filterClause ; aggregationFunctionName - : AVG | COUNT | SUM | MIN | MAX + : AVG | COUNT | SUM | MIN | MAX | VAR_POP | VAR_SAMP | VARIANCE | STD | STDDEV | STDDEV_POP | STDDEV_SAMP ; mathematicalFunctionName diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index a3c8494e7a..e4e8028f05 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -410,6 +410,27 @@ public void filteredAggregation() { ); } + @Test + public void canBuildVarSamp() { + assertEquals( + aggregate("var_samp", 
qualifiedName("age")), + buildExprAst("var_samp(age)")); + } + + @Test + public void canBuildVarPop() { + assertEquals( + aggregate("var_pop", qualifiedName("age")), + buildExprAst("var_pop(age)")); + } + + @Test + public void canBuildVariance() { + assertEquals( + aggregate("variance", qualifiedName("age")), + buildExprAst("variance(age)")); + } + private Node buildExprAst(String expr) { OpenSearchSQLLexer lexer = new OpenSearchSQLLexer(new CaseInsensitiveCharStream(expr)); OpenSearchSQLParser parser = new OpenSearchSQLParser(new CommonTokenStream(lexer)); From 9ff27939f9c3532263ea3cfe23f2722c70482236 Mon Sep 17 00:00:00 2001 From: Chloe Date: Fri, 11 Jun 2021 15:45:20 -0700 Subject: [PATCH 2/2] Fix the aggregation filter missing in named aggregators (#123) * Take the condition expression as property to the named aggregator when wrapping the delegated aggregator Signed-off-by: chloe-zh * update Signed-off-by: chloe-zh * Added test case where filtered agg is not pushed down Signed-off-by: chloe-zh * update Signed-off-by: chloe-zh * update Signed-off-by: chloe-zh --- .../aggregation/NamedAggregator.java | 3 ++ .../opensearch/sql/analysis/AnalyzerTest.java | 40 +++++++++++++++++++ .../org/opensearch/sql/sql/AggregationIT.java | 40 +++++++++++++++++++ .../queries/{subquries.txt => subqueries.txt} | 0 4 files changed, 83 insertions(+) create mode 100644 integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java rename integ-test/src/test/resources/correctness/queries/{subquries.txt => subqueries.txt} (100%) diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java index a1bf2b9961..346bd2d28c 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java @@ -54,6 +54,8 @@ public class NamedAggregator extends Aggregator { /** 
* NamedAggregator. + * The aggregator property {@link #condition} is inherited by the named aggregator + * to avoid errors introduced by property inconsistency. * * @param name name * @param delegated delegated @@ -64,6 +66,7 @@ public NamedAggregator( super(delegated.getFunctionName(), delegated.getArguments(), delegated.returnType); this.name = name; this.delegated = delegated; + this.condition = delegated.condition; } @Override diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 9b42c70e32..fc45f34ffe 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -36,6 +36,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.compare; import static org.opensearch.sql.ast.dsl.AstDSL.field; import static org.opensearch.sql.ast.dsl.AstDSL.filter; +import static org.opensearch.sql.ast.dsl.AstDSL.filteredAggregate; import static org.opensearch.sql.ast.dsl.AstDSL.function; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; @@ -624,4 +625,43 @@ public void limit_offset() { ) ); } + + /** + * SELECT COUNT(NAME) FILTER(WHERE age > 1) FROM test.
+ * This test is to verify that the aggregator properties are taken + * when wrapping it to {@link org.opensearch.sql.expression.aggregation.NamedAggregator} + */ + @Test + public void named_aggregator_with_condition() { + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.aggregation( + LogicalPlanDSL.relation("schema"), + ImmutableList.of( + DSL.named("count(string_value) filter(where integer_value > 1)", + dsl.count(DSL.ref("string_value", STRING)).condition(dsl.greater(DSL.ref( + "integer_value", INTEGER), DSL.literal(1)))) + ), + emptyList() + ), + DSL.named("count(string_value) filter(where integer_value > 1)", DSL.ref( + "count(string_value) filter(where integer_value > 1)", INTEGER)) + ), + AstDSL.project( + AstDSL.agg( + AstDSL.relation("schema"), + ImmutableList.of( + alias("count(string_value) filter(where integer_value > 1)", filteredAggregate( + "count", qualifiedName("string_value"), function( + ">", qualifiedName("integer_value"), intLiteral(1))))), + emptyList(), + emptyList(), + emptyList() + ), + AstDSL.alias("count(string_value) filter(where integer_value > 1)", filteredAggregate( + "count", qualifiedName("string_value"), function( + ">", qualifiedName("integer_value"), intLiteral(1)))) + ) + ); + } } diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java new file mode 100644 index 0000000000..3cbb222afe --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ * + */ + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +import java.io.IOException; +import org.json.JSONObject; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.legacy.SQLIntegTestCase; + +public class AggregationIT extends SQLIntegTestCase { + @Override + protected void init() throws Exception { + loadIndex(Index.BANK); + } + + @Test + void filteredAggregateWithSubquery() throws IOException { + JSONObject response = executeQuery( + "SELECT COUNT(*) FILTER(WHERE age > 35) FROM (SELECT * FROM " + TEST_INDEX_BANK + + ") AS a"); + verifySchema(response, schema("COUNT(*)", null, "integer")); + verifyDataRows(response, rows(3)); + } +} diff --git a/integ-test/src/test/resources/correctness/queries/subquries.txt b/integ-test/src/test/resources/correctness/queries/subqueries.txt similarity index 100% rename from integ-test/src/test/resources/correctness/queries/subquries.txt rename to integ-test/src/test/resources/correctness/queries/subqueries.txt