Skip to content

Commit

Permalink
Change type of AVG aggregates to double (apache#15089)
Browse files Browse the repository at this point in the history
The SQL standard is not very restrictive regarding this:

If AVG is specified and DT is exact numeric, then the declared type of the result is an
implementation-defined exact numeric type with precision not less than the precision of DT and scale not
less than the scale of DT.

So using the same type is also OK (i.e. the behavior without this patch);
however, the AVG of 0 and 1 is currently 0 because the integer type is retained.

Postgres, MySQL, Oracle, and Drill seem to increase precision; MSSQL returns 0:
http://sqlfiddle.com/#!9/6f7248/1

I think we should also increase precision, as it's already calculated more precisely.
  • Loading branch information
kgyrtkirk authored and ektravel committed Oct 16, 2023
1 parent 374e908 commit 92ab0e4
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ public void testAvgDailyCountDistinctHllSketch()

final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
1L
1.0
}
);

Expand Down Expand Up @@ -429,11 +429,11 @@ public void testAvgDailyCountDistinctHllSketch()
.setAggregatorSpecs(
NullHandling.replaceWithDefault()
? Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count")
)
: Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a0:count"),
notNull("a0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ public void testAvgDailyCountDistinctThetaSketch()

final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
1L
1.0
}
);

Expand Down Expand Up @@ -334,11 +334,11 @@ public void testAvgDailyCountDistinctThetaSketch()
.setAggregatorSpecs(
NullHandling.replaceWithDefault()
? Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count")
)
: Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a0:count"),
notNull("a0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
Expand Down Expand Up @@ -60,17 +61,17 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
private static final String STDDEV_NAME = "STDDEV";

private static final SqlAggFunction VARIANCE_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(VARIANCE_NAME);
buildSqlVarianceAggFunction(VARIANCE_NAME);
private static final SqlAggFunction VARIANCE_POP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.VAR_POP.name());
buildSqlVarianceAggFunction(SqlKind.VAR_POP.name());
private static final SqlAggFunction VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.VAR_SAMP.name());
buildSqlVarianceAggFunction(SqlKind.VAR_SAMP.name());
private static final SqlAggFunction STDDEV_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(STDDEV_NAME);
buildSqlVarianceAggFunction(STDDEV_NAME);
private static final SqlAggFunction STDDEV_POP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.STDDEV_POP.name());
buildSqlVarianceAggFunction(SqlKind.STDDEV_POP.name());
private static final SqlAggFunction STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.STDDEV_SAMP.name());
buildSqlVarianceAggFunction(SqlKind.STDDEV_SAMP.name());

@Nullable
@Override
Expand Down Expand Up @@ -160,14 +161,15 @@ public Aggregation toDruidAggregation(
}

/**
* Creates a {@link SqlAggFunction} that is the same as {@link org.apache.calcite.sql.fun.SqlAvgAggFunction}
* but with an operand type that accepts variance aggregator objects in addition to numeric inputs.
* Creates a {@link SqlAggFunction}
*
* It accepts variance aggregator objects in addition to numeric inputs.
*/
private static SqlAggFunction buildSqlAvgAggFunction(String name)
private static SqlAggFunction buildSqlVarianceAggFunction(String name)
{
return OperatorConversions
.aggregatorBuilder(name)
.returnTypeInference(ReturnTypes.AVG_AGG_FUNCTION)
.returnTypeInference(ReturnTypes.explicit(SqlTypeName.DOUBLE))
.operandTypeChecker(
OperandTypes.or(
OperandTypes.NUMERIC,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ public void testVarPop()
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
holder1.getVariance(true),
holder2.getVariance(true).doubleValue(),
holder3.getVariance(true).longValue()
holder2.getVariance(true),
holder3.getVariance(true)
}
);
testQuery(
Expand Down Expand Up @@ -219,7 +219,7 @@ public void testVarSamp()
new Object[] {
holder1.getVariance(false),
holder2.getVariance(false).doubleValue(),
holder3.getVariance(false).longValue(),
holder3.getVariance(false),
}
);
testQuery(
Expand Down Expand Up @@ -266,7 +266,7 @@ public void testStdDevPop()
new Object[] {
Math.sqrt(holder1.getVariance(true)),
Math.sqrt(holder2.getVariance(true)),
(long) Math.sqrt(holder3.getVariance(true)),
Math.sqrt(holder3.getVariance(true)),
}
);

Expand Down Expand Up @@ -321,7 +321,7 @@ public void testStdDevSamp()
new Object[]{
Math.sqrt(holder1.getVariance(false)),
Math.sqrt(holder2.getVariance(false)),
(long) Math.sqrt(holder3.getVariance(false)),
Math.sqrt(holder3.getVariance(false)),
}
);

Expand Down Expand Up @@ -374,7 +374,7 @@ public void testStdDevWithVirtualColumns()
new Object[]{
Math.sqrt(holder1.getVariance(false)),
Math.sqrt(holder2.getVariance(false)),
(long) Math.sqrt(holder3.getVariance(false)),
Math.sqrt(holder3.getVariance(false)),
}
);

Expand Down Expand Up @@ -543,7 +543,7 @@ public void testEmptyTimeseriesResults()
),
ImmutableList.of(
NullHandling.replaceWithDefault()
? new Object[]{0.0, 0.0, 0.0, 0.0, 0L, 0L, 0L, 0L}
? new Object[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}
: new Object[]{null, null, null, null, null, null, null, null}
)
);
Expand Down Expand Up @@ -623,7 +623,7 @@ public void testGroupByAggregatorDefaultValues()
),
ImmutableList.of(
NullHandling.replaceWithDefault()
? new Object[]{"a", 0.0, 0.0, 0.0, 0.0, 0L, 0L, 0L, 0L}
? new Object[]{"a", 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}
: new Object[]{"a", null, null, null, null, null, null, null, null}
)
);
Expand Down Expand Up @@ -688,9 +688,9 @@ public void assertResultsEquals(String sql, List<Object[]> expectedResults, List
Assert.assertEquals(expectedResult.length, result.length);
for (int j = 0; j < expectedResult.length; j++) {
if (expectedResult[j] instanceof Float) {
Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-10);
Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-5);
} else if (expectedResult[j] instanceof Double) {
Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-10);
Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-5);
} else {
Assert.assertEquals(expectedResult[j], result[j]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,7 @@ public RelDataType deriveAvgAggType(
final RelDataType argumentType
)
{
// Widen all averages to 64-bits regardless of the size of the inputs.

if (SqlTypeName.INT_TYPES.contains(argumentType.getSqlTypeName())) {
return Calcites.createSqlTypeWithNullability(typeFactory, SqlTypeName.BIGINT, argumentType.isNullable());
} else {
return Calcites.createSqlTypeWithNullability(typeFactory, SqlTypeName.DOUBLE, argumentType.isNullable());
}
return Calcites.createSqlTypeWithNullability(typeFactory, SqlTypeName.DOUBLE, argumentType.isNullable());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.LongMaxAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.aggregation.any.DoubleAnyAggregatorFactory;
import org.apache.druid.query.aggregation.any.LongAnyAggregatorFactory;
import org.apache.druid.query.aggregation.cardinality.CardinalityAggregatorFactory;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator;
Expand Down Expand Up @@ -127,7 +128,7 @@ public void testCorrelatedSubquery(Map<String, Object> queryContext)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("d1", "_d0"))
.setAggregatorSpecs(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
useDefault
? new CountAggregatorFactory("_a0:count")
: new FilteredAggregatorFactory(
Expand Down Expand Up @@ -158,15 +159,15 @@ public void testCorrelatedSubquery(Map<String, Object> queryContext)
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("country", "d0"))
.setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0"))
.setAggregatorSpecs(new DoubleAnyAggregatorFactory("a0", "j0._a0"))
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
ImmutableList.of(
new Object[]{"India", 2L},
new Object[]{"USA", 1L},
new Object[]{"canada", 3L}
new Object[]{"India", 2.0},
new Object[]{"USA", 1.0},
new Object[]{"canada", 3.0}
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ public void testParamsInInformationSchema()
+ "WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?",
ImmutableList.of(),
ImmutableList.of(
new Object[]{8L, 1249L, 156L, -5L, 1111L}
new Object[]{8L, 1249L, 156.125, -5L, 1111L}
),
ImmutableList.of(
new SqlParameter(SqlType.VARCHAR, "druid"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ public void testAggregatorsOnInformationSchemaColumns()
+ "WHERE TABLE_SCHEMA = 'druid' AND TABLE_NAME = 'foo'",
ImmutableList.of(),
ImmutableList.of(
new Object[]{8L, 1249L, 156L, -5L, 1111L}
new Object[]{8L, 1249L, 156.125, -5L, 1111L}
)
);
}
Expand Down Expand Up @@ -4942,7 +4942,7 @@ public void testSimpleAggregations()
new CountAggregatorFactory("a1"),
notNull("dim1")
),
new LongSumAggregatorFactory("a2:sum", "cnt"),
new DoubleSumAggregatorFactory("a2:sum", "cnt"),
new CountAggregatorFactory("a2:count"),
new LongSumAggregatorFactory("a3", "cnt"),
new LongMinAggregatorFactory("a4", "cnt"),
Expand All @@ -4964,7 +4964,7 @@ public void testSimpleAggregations()
new CountAggregatorFactory("a2"),
notNull("dim1")
),
new LongSumAggregatorFactory("a3:sum", "cnt"),
new DoubleSumAggregatorFactory("a3:sum", "cnt"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a3:count"),
notNull("cnt")
Expand Down Expand Up @@ -5014,10 +5014,10 @@ public void testSimpleAggregations()
),
NullHandling.replaceWithDefault() ?
ImmutableList.of(
new Object[]{6L, 6L, 5L, 1L, 6L, 8L, 3L, 6L, ((1 + 1.7) / 6)}
new Object[]{6L, 6L, 5L, 1.0, 6L, 8L, 3L, 6L, ((1 + 1.7) / 6)}
) :
ImmutableList.of(
new Object[]{6L, 6L, 6L, 1L, 6L, 8L, 4L, 3L, ((1 + 1.7) / 3)}
new Object[]{6L, 6L, 6L, 1.0, 6L, 8L, 4L, 3L, ((1 + 1.7) / 3)}
)
);
}
Expand Down Expand Up @@ -7429,11 +7429,11 @@ public void testAvgDailyCountDistinct()
.setAggregatorSpecs(
useDefault
? aggregators(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count")
)
: aggregators(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a0:count"),
notNull("a0")
Expand All @@ -7455,7 +7455,7 @@ public void testAvgDailyCountDistinct()
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{1L})
ImmutableList.of(new Object[]{1.0})
);
}

Expand Down Expand Up @@ -9641,7 +9641,7 @@ public void testTimeseriesEmptyResultsAggregatorDefaultValues()
new LongSumAggregatorFactory("a6", "l1"),
new LongMaxAggregatorFactory("a7", "l1"),
new LongMinAggregatorFactory("a8", "l1"),
new LongSumAggregatorFactory("a9:sum", "l1"),
new DoubleSumAggregatorFactory("a9:sum", "l1"),
useDefault
? new CountAggregatorFactory("a9:count")
: new FilteredAggregatorFactory(
Expand Down Expand Up @@ -9690,7 +9690,7 @@ public void testTimeseriesEmptyResultsAggregatorDefaultValues()
0L,
Long.MIN_VALUE,
Long.MAX_VALUE,
0L,
Double.NaN,
Double.NaN
}
: new Object[]{0L, 0L, 0L, null, null, null, null, null, null, null, null}
Expand Down Expand Up @@ -9936,7 +9936,7 @@ public void testGroupByAggregatorDefaultValues()
equality("dim1", "nonexistent", ColumnType.STRING)
),
new FilteredAggregatorFactory(
new LongSumAggregatorFactory("a9:sum", "l1"),
new DoubleSumAggregatorFactory("a9:sum", "l1"),
equality("dim1", "nonexistent", ColumnType.STRING)
),
useDefault
Expand Down Expand Up @@ -10005,7 +10005,7 @@ public void testGroupByAggregatorDefaultValues()
0L,
Long.MIN_VALUE,
Long.MAX_VALUE,
0L,
Double.NaN,
Double.NaN
}
: new Object[]{"a", 0L, 0L, 0L, null, null, null, null, null, null, null, null}
Expand Down Expand Up @@ -13147,7 +13147,7 @@ public void testCountAndAverageByConstantVirtualColumn()
new CountAggregatorFactory("a0"),
notNull("v0")
),
new LongSumAggregatorFactory("a1:sum", "v1", null, TestExprMacroTable.INSTANCE),
new DoubleSumAggregatorFactory("a1:sum", "v1", null, TestExprMacroTable.INSTANCE),
new CountAggregatorFactory("a1:count")
);
virtualColumns = ImmutableList.of(
Expand All @@ -13160,7 +13160,7 @@ public void testCountAndAverageByConstantVirtualColumn()
new CountAggregatorFactory("a0"),
notNull("v0")
),
new LongSumAggregatorFactory("a1:sum", "v1"),
new DoubleSumAggregatorFactory("a1:sum", "v1"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a1:count"),
notNull("v1")
Expand Down Expand Up @@ -13204,7 +13204,7 @@ public void testCountAndAverageByConstantVirtualColumn()
.build()
),
ImmutableList.of(
new Object[]{"ab", 1L, 325323L}
new Object[]{"ab", 1L, 325323.0}
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.druid.query.ResourceLimitExceededException;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.LongMaxAggregatorFactory;
import org.apache.druid.query.aggregation.LongMinAggregatorFactory;
Expand Down Expand Up @@ -558,14 +559,14 @@ public void testMinMaxAvgDailyCountWithLimit()
aggregators(
new LongMaxAggregatorFactory("_a0", "a0"),
new LongMinAggregatorFactory("_a1", "a0"),
new LongSumAggregatorFactory("_a2:sum", "a0"),
new DoubleSumAggregatorFactory("_a2:sum", "a0"),
new CountAggregatorFactory("_a2:count"),
new LongMaxAggregatorFactory("_a3", "d0"),
new CountAggregatorFactory("_a4")
) : aggregators(
new LongMaxAggregatorFactory("_a0", "a0"),
new LongMinAggregatorFactory("_a1", "a0"),
new LongSumAggregatorFactory("_a2:sum", "a0"),
new DoubleSumAggregatorFactory("_a2:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a2:count"),
notNull("a0")
Expand All @@ -590,7 +591,7 @@ public void testMinMaxAvgDailyCountWithLimit()
.setContext(queryContext)
.build()
),
ImmutableList.of(new Object[]{1L, 1L, 1L, 978480000L, 6L})
ImmutableList.of(new Object[]{1L, 1L, 1.0, 978480000L, 6L})
);
}

Expand Down

0 comments on commit 92ab0e4

Please sign in to comment.