From da4792dcd2a4890072ece6dea354d5a3a1253d9c Mon Sep 17 00:00:00 2001 From: findinpath Date: Thu, 17 Jun 2021 00:31:42 +0200 Subject: [PATCH] Add support for generic parameters in `make_set_digest` function A widely applied use case for the MinHash algorithm is checking the similarity of two texts. Therefore, this change adds support for creating a `setdigest` aggregation for (among other types) varchar slices. --- .../setdigest/BuildSetDigestAggregation.java | 17 ++++++++++++++-- .../sql/query/TestSetDigestFunctions.java | 20 ++++++++++++++++++- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/type/setdigest/BuildSetDigestAggregation.java b/core/trino-main/src/main/java/io/trino/type/setdigest/BuildSetDigestAggregation.java index 6f17380f6057..ecade41a55b5 100644 --- a/core/trino-main/src/main/java/io/trino/type/setdigest/BuildSetDigestAggregation.java +++ b/core/trino-main/src/main/java/io/trino/type/setdigest/BuildSetDigestAggregation.java @@ -14,13 +14,15 @@ package io.trino.type.setdigest; +import io.airlift.slice.Slice; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; import io.trino.spi.function.CombineFunction; import io.trino.spi.function.InputFunction; import io.trino.spi.function.OutputFunction; import io.trino.spi.function.SqlType; -import io.trino.spi.type.StandardTypes; +import io.trino.spi.function.TypeParameter; @AggregationFunction("make_set_digest") public final class BuildSetDigestAggregation @@ -30,7 +32,18 @@ public final class BuildSetDigestAggregation private BuildSetDigestAggregation() {} @InputFunction - public static void input(SetDigestState state, @SqlType(StandardTypes.BIGINT) long value) + @TypeParameter("T") + public static void input(@AggregationState SetDigestState state, @SqlType("T") long value) + { + if (state.getDigest() == null) { + state.setDigest(new SetDigest()); + } + 
state.getDigest().add(value); + } + + @InputFunction + @TypeParameter("T") + public static void input(@AggregationState SetDigestState state, @SqlType("T") Slice value) { if (state.getDigest() == null) { state.setDigest(new SetDigest()); diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestSetDigestFunctions.java b/core/trino-main/src/test/java/io/trino/sql/query/TestSetDigestFunctions.java index b49da9379387..8aec6613ebb2 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestSetDigestFunctions.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestSetDigestFunctions.java @@ -37,7 +37,7 @@ public void teardown() } @Test - public void testCardinality() + public void testCardinalityForBigintSetDigest() { assertThat(assertions.query( "SELECT cardinality(make_set_digest(value)) " + @@ -45,6 +45,24 @@ public void testCardinality() .matches("VALUES CAST(5 AS BIGINT)"); } + @Test + public void testCardinalityForVarcharSetDigest() + { + assertThat(assertions.query( + "SELECT cardinality(make_set_digest(value)) " + + "FROM (VALUES 'trino', 'sql', 'everything', 'sql', 'trino') T(value)")) + .matches("VALUES CAST(3 AS BIGINT)"); + } + + @Test + public void testCardinalityForDateSetDigest() + { + assertThat(assertions.query( + "SELECT cardinality(make_set_digest(value)) " + + "FROM (VALUES DATE '2001-08-22', DATE '2001-08-22', DATE '2001-08-23') T(value)")) + .matches("VALUES CAST(2 AS BIGINT)"); + } + @Test public void testExactIntersectionCardinality() {