Skip to content

Commit

Permalink
feat: add COUNT_DISTINCT and allow generics in UDAFs
Browse files Browse the repository at this point in the history
  • Loading branch information
agavra committed Dec 16, 2019
1 parent 04de30e commit 195330b
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 6 deletions.
5 changes: 5 additions & 0 deletions ksql-engine/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@
<version>1.9.0</version>
</dependency>

<dependency>
<groupId>com.clearspring.analytics</groupId>
<artifactId>stream</artifactId>
</dependency>

<!-- Required for running tests -->

<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -84,11 +85,6 @@ class UdafTypes {
/**
 * Builds the {@code ParameterInfo} describing this UDAF's input value.
 *
 * <p>Runs {@code validateStructAnnotation} on {@code inputType} and {@code inSchema}, resolves
 * the parameter schema with {@code getSchemaFromType}, and wraps the result in a
 * {@code ParameterInfo} named {@code "val"} with an empty description.
 *
 * @param inSchema schema text supplied for the input type; passed to both the struct
 *                 annotation validation and the schema resolution
 * @return the parameter info for the input value
 * @throws KsqlException if {@code GenericsUtil.constituentGenerics} reports any entries for
 *                       the resolved schema
 */
// NOTE(review): this text comes from a diff view without +/- markers. The hunk header
// (11 lines -> 6) suggests the generic-type guard below is the part this commit deletes -
// confirm against the post-commit file before relying on the @throws clause above.
ParameterInfo getInputSchema(final String inSchema) {
validateStructAnnotation(inputType, inSchema, "paramSchema");
final ParamType inputSchema = getSchemaFromType(inputType, inSchema);
//Currently, aggregate functions cannot have reified types as input parameters.
if (!GenericsUtil.constituentGenerics(inputSchema).isEmpty()) {
throw new KsqlException("Generic type parameters containing reified types are not currently"
+ " supported. " + functionInfo);
}
// Positional arguments: name, schema, description, and a boolean flag whose meaning is not
// visible in this hunk.
return new ParameterInfo("val", inputSchema, "", false);
}

Expand All @@ -103,7 +99,7 @@ ParamType getOutputSchema(final String outSchema) {
}

private void validateTypes(final Type t) {
if (isUnsupportedType((Class<?>) getRawType(t))) {
if (!(t instanceof TypeVariable) && isUnsupportedType((Class<?>) getRawType(t))) {
throw new KsqlException(String.format(invalidClassErrorMsg, t));
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright 2019 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"; you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.function.udaf.count;

import com.clearspring.analytics.stream.cardinality.HyperLogLog;
import com.clearspring.analytics.stream.cardinality.RegisterSet;
import com.google.common.primitives.Ints;
import io.confluent.ksql.function.udaf.Udaf;
import io.confluent.ksql.function.udaf.UdafDescription;
import io.confluent.ksql.function.udaf.UdafFactory;
import java.util.List;

/**
 * {@code COUNT_DISTINCT}: an approximate distinct-count aggregate built on HyperLogLog.
 *
 * <p>The intermediate aggregate is the HyperLogLog register array, carried between calls as a
 * {@code List<Integer>} because the UDAF framework needs a serializable aggregate type. Each
 * call rebuilds a {@link RegisterSet} around those values, applies the update, and hands the
 * registers back as a list.
 */
@UdafDescription(
    name = "COUNT_DISTINCT",
    description = CountDistinct.DESCRIPTION
)
public class CountDistinct {

  static final String DESCRIPTION = "This function returns the approximate number of distinct "
      + "items found in a group. The implementation is probabilistic with a typical accuracy "
      + "(standard error) of less than 1%.";

  // log2 of the HyperLogLog register count. 14 keeps the standard error below .01 - we can
  // consider making this configurable if the need arises
  private static final int LOG_2_M = 14;

  // Number of registers; derived from LOG_2_M so the two values cannot drift apart.
  private static final int M = 1 << LOG_2_M;

  private CountDistinct() {
  }

  /**
   * Creates the distinct-count UDAF for any input type.
   *
   * @param <T> type of the values being counted
   * @return a UDAF whose aggregate is the register array and whose result is the estimated
   *         number of distinct non-null values
   */
  @UdafFactory(description = "Count distinct")
  public static <T> Udaf<T, List<Integer>, Long> distinct() {
    return countDistinct();
  }

  // NOTE: since our UDAF framework requires the aggregate values to
  // be serializable, and we don't support serialization of native int[],
  // this implementation can be optimized by avoiding conversions between
  // int[] and List<Integer> - since RegisterSet requires an int[], we would
  // need to duplicate a lot of code to get this to be zero-copy
  private static <T> Udaf<T, List<Integer>, Long> countDistinct() {
    return new Udaf<T, List<Integer>, Long>() {

      /** Returns an all-zero register array sized for {@code M} registers. */
      @Override
      public List<Integer> initialize() {
        return Ints.asList(new int[RegisterSet.getSizeForCount(M)]);
      }

      /**
       * Offers {@code current} to the sketch held in {@code aggregate}.
       *
       * <p>Null values are ignored and the aggregate is returned unchanged.
       */
      @Override
      public List<Integer> aggregate(final T current, final List<Integer> aggregate) {
        if (current == null) {
          return aggregate;
        }

        // copy the registers into a fresh int[] that the RegisterSet can work on
        final int[] registers = Ints.toArray(aggregate);
        final RegisterSet set = new RegisterSet(M, registers);

        // this modifies the underlying ints
        toHyperLogLog(set).offer(current);

        // list view over the updated array
        return Ints.asList(registers);
      }

      /** Combines two register arrays into one that represents the union of both inputs. */
      @Override
      public List<Integer> merge(final List<Integer> aggOne, final List<Integer> aggTwo) {
        final RegisterSet merged = new RegisterSet(M, Ints.toArray(aggOne));
        merged.merge(new RegisterSet(M, Ints.toArray(aggTwo)));

        return Ints.asList(merged.bits());
      }

      /** Converts the register array into the estimated distinct count. */
      @Override
      public Long map(final List<Integer> agg) {
        return toHyperLogLog(new RegisterSet(M, Ints.toArray(agg))).cardinality();
      }
    };
  }

  // Wraps an existing register set in a HyperLogLog; the constructor used here is deprecated
  // in the library, hence the suppression.
  @SuppressWarnings("deprecation")
  private static HyperLogLog toHyperLogLog(final RegisterSet set) {
    return new HyperLogLog(LOG_2_M, set);
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright 2019 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"; you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.function.udaf.count;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;

import com.google.common.primitives.Ints;
import io.confluent.ksql.function.udaf.Udaf;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.junit.Test;

/**
 * Unit tests for the {@code COUNT_DISTINCT} UDAF: counting over different input types,
 * null handling, and merging of two partial aggregates.
 */
public class CountDistinctKudafTest {

  @Test
  public void shouldCountStrings() {
    // Given:
    final Udaf<String, List<Integer>, Long> udaf = CountDistinct.distinct();
    final List<String> inputs = IntStream
        .range(0, 100)
        .mapToObj(i -> String.valueOf(i % 4))
        .collect(Collectors.toList());

    // When:
    final List<Integer> result = aggregateAll(udaf, inputs);

    // Then:
    assertThat(udaf.map(result), is(4L));
  }

  @Test
  public void shouldCountList() {
    // Given:
    final Udaf<List<Integer>, List<Integer>, Long> udaf = CountDistinct.distinct();
    final List<List<Integer>> inputs = IntStream
        .range(0, 100)
        .mapToObj(i -> Ints.asList(i % 4))
        .collect(Collectors.toList());

    // When:
    final List<Integer> result = aggregateAll(udaf, inputs);

    // Then:
    assertThat(udaf.map(result), is(4L));
  }

  @Test
  public void shouldIgnoreNulls() {
    // Given:
    final Udaf<String, List<Integer>, Long> udaf = CountDistinct.distinct();

    // When:
    final List<Integer> result = udaf.aggregate(null, udaf.initialize());

    // Then:
    assertThat(udaf.map(result), is(0L));
  }

  @Test
  public void shouldMerge() {
    // Given:
    final Udaf<String, List<Integer>, Long> udaf = CountDistinct.distinct();
    final List<String> inputs = IntStream
        .range(0, 100)
        .mapToObj(i -> String.valueOf(i % 4))
        .collect(Collectors.toList());

    // When:
    final List<Integer> left = aggregateAll(udaf, inputs);
    final List<Integer> right = udaf.aggregate("5", udaf.initialize());

    // Then:
    assertThat(udaf.map(udaf.merge(left, right)), is(5L));
  }

  // Folds every value into a freshly initialized aggregate, in iteration order.
  private static <T> List<Integer> aggregateAll(
      final Udaf<T, List<Integer>, Long> udaf,
      final Iterable<T> values
  ) {
    List<Integer> agg = udaf.initialize();
    for (final T value : values) {
      agg = udaf.aggregate(value, agg);
    }
    return agg;
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"tests": [
{
"name": "count distinct",
"statements": [
"CREATE STREAM TEST (ID varchar, NAME varchar) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE TABLE S2 as SELECT id, count_distinct(name) as count FROM test group by id;"
],
"inputs": [
{"topic": "test_topic", "key": "0", "value": {"id": "foo", "name": "one"}},
{"topic": "test_topic", "key": "0", "value": {"id": "foo", "name": "two"}},
{"topic": "test_topic", "key": "0", "value": {"id": "foo", "name": "one"}},
{"topic": "test_topic", "key": "0", "value": {"id": "foo", "name": "two"}},
{"topic": "test_topic", "key": "0", "value": {"id": "bar", "name": "one"}},
{"topic": "test_topic", "key": "0", "value": {"id": "foo", "name": null}}
],
"outputs": [
{"topic": "S2", "key": "foo" ,"value": {"ID": "foo", "COUNT": 1}},
{"topic": "S2", "key": "foo" ,"value": {"ID": "foo", "COUNT": 2}},
{"topic": "S2", "key": "foo" ,"value": {"ID": "foo", "COUNT": 2}},
{"topic": "S2", "key": "foo" ,"value": {"ID": "foo", "COUNT": 2}},
{"topic": "S2", "key": "bar" ,"value": {"ID": "bar", "COUNT": 1}},
{"topic": "S2", "key": "foo" ,"value": {"ID": "foo", "COUNT": 2}}
]
}
]
}
7 changes: 7 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
<avro.random.generator.version>0.2.2</avro.random.generator.version>
<apache.curator.version>2.9.0</apache.curator.version>
<wiremock.version>2.24.0</wiremock.version>
<clearspring-analytics.version>2.9.5</clearspring-analytics.version>
<skip.docker.build>true</skip.docker.build>
<skip.docker.test>true</skip.docker.test>
</properties>
Expand Down Expand Up @@ -372,6 +373,12 @@
<version>${javax-validation.version}</version>
</dependency>

<dependency>
<groupId>com.clearspring.analytics</groupId>
<artifactId>stream</artifactId>
<version>${clearspring-analytics.version}</version>
</dependency>

<!-- Required for running tests -->
<dependency>
<groupId>junit</groupId>
Expand Down

0 comments on commit 195330b

Please sign in to comment.