diff --git a/core/trino-main/src/main/java/io/trino/type/setdigest/BuildSetDigestAggregation.java b/core/trino-main/src/main/java/io/trino/type/setdigest/BuildSetDigestAggregation.java index 6f17380f6057..ecade41a55b5 100644 --- a/core/trino-main/src/main/java/io/trino/type/setdigest/BuildSetDigestAggregation.java +++ b/core/trino-main/src/main/java/io/trino/type/setdigest/BuildSetDigestAggregation.java @@ -14,13 +14,15 @@ package io.trino.type.setdigest; +import io.airlift.slice.Slice; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; import io.trino.spi.function.CombineFunction; import io.trino.spi.function.InputFunction; import io.trino.spi.function.OutputFunction; import io.trino.spi.function.SqlType; -import io.trino.spi.type.StandardTypes; +import io.trino.spi.function.TypeParameter; @AggregationFunction("make_set_digest") public final class BuildSetDigestAggregation @@ -30,7 +32,18 @@ public final class BuildSetDigestAggregation private BuildSetDigestAggregation() {} @InputFunction - public static void input(SetDigestState state, @SqlType(StandardTypes.BIGINT) long value) + @TypeParameter("T") + public static void input(@AggregationState SetDigestState state, @SqlType("T") long value) + { + if (state.getDigest() == null) { + state.setDigest(new SetDigest()); + } + state.getDigest().add(value); + } + + @InputFunction + @TypeParameter("T") + public static void input(@AggregationState SetDigestState state, @SqlType("T") Slice value) { if (state.getDigest() == null) { state.setDigest(new SetDigest()); diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestSetDigestFunctions.java b/core/trino-main/src/test/java/io/trino/sql/query/TestSetDigestFunctions.java index b49da9379387..8aec6613ebb2 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestSetDigestFunctions.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestSetDigestFunctions.java @@ -37,7 +37,7 @@ public void teardown() } @Test - public void testCardinality() + public void testCardinalityForBigintSetDigest() { assertThat(assertions.query( "SELECT cardinality(make_set_digest(value)) " + @@ -45,6 +45,24 @@ public void testCardinality() .matches("VALUES CAST(5 AS BIGINT)"); } + @Test + public void testCardinalityForVarcharSetDigest() + { + assertThat(assertions.query( + "SELECT cardinality(make_set_digest(value)) " + + "FROM (VALUES 'trino', 'sql', 'everything', 'sql', 'trino') T(value)")) + .matches("VALUES CAST(3 AS BIGINT)"); + } + + @Test + public void testCardinalityForDateSetDigest() + { + assertThat(assertions.query( + "SELECT cardinality(make_set_digest(value)) " + + "FROM (VALUES DATE '2001-08-22', DATE '2001-08-22', DATE '2001-08-23') T(value)")) + .matches("VALUES CAST(2 AS BIGINT)"); + } + @Test public void testExactIntersectionCardinality() {