Skip to content

Commit

Permalink
feat: make UDAFs configurable and remove limit on COLLECT_LIST/SET (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
agavra authored Jan 13, 2021
1 parent e2cd29d commit 63ae169
Show file tree
Hide file tree
Showing 33 changed files with 1,088 additions and 187 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@

package io.confluent.ksql.function;

import com.google.common.collect.ImmutableMap;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
Expand All @@ -28,18 +30,41 @@
*/
public class AggregateFunctionInitArguments {

public static final AggregateFunctionInitArguments EMPTY_ARGS =
new AggregateFunctionInitArguments();

private final int udafIndex;
private final List<Object> initArgs;
private final Map<String, ?> config;

public static final AggregateFunctionInitArguments EMPTY_ARGS =
new AggregateFunctionInitArguments();
/**
* Creates init arguments that carry an empty config map.
*
* <p>This constructor should only be used for legacy "built-in" UDAF
* implementations that implement AggregateFunctionFactory directly,
* such as TopKAggregateFunctionFactory. Otherwise, the config will
* not be properly passed through to the aggregate function.
*
* @param index the value stored as the UDAF index
* @param initArgs the initial arguments; wrapped with {@code Arrays.asList} before
*     being handed to the list-based constructor
*/
public AggregateFunctionInitArguments(
final int index,
final Object... initArgs
) {
this(index, ImmutableMap.of(/* not a configurable function */), Arrays.asList(initArgs));
}

public AggregateFunctionInitArguments(final int index, final Object... initArgs) {
this(index, Arrays.asList(initArgs));
/**
* Creates init arguments from a config map and a varargs list of initial arguments.
*
* <p>Delegates to the list-based constructor after wrapping {@code initArgs}
* with {@code Arrays.asList}.
*
* @param index the value stored as the UDAF index
* @param config the configuration later exposed through {@code config()}
* @param initArgs the initial arguments
*/
public AggregateFunctionInitArguments(
final int index,
final Map<String, ?> config,
final Object... initArgs
) {
this(index, config, Arrays.asList(initArgs));
}

public AggregateFunctionInitArguments(final int index, final List<Object> initArgs) {
public AggregateFunctionInitArguments(
final int index,
final Map<String, ?> config,
final List<Object> initArgs
) {
this.udafIndex = index;
this.config = ImmutableMap.copyOf(Objects.requireNonNull(config, "config"));
this.initArgs = Objects.requireNonNull(initArgs);

if (index < 0) {
Expand All @@ -49,6 +74,7 @@ public AggregateFunctionInitArguments(final int index, final List<Object> initAr

/**
* Builds the instance used for {@code EMPTY_ARGS}: index 0, an empty config map
* and an empty list of init arguments.
*/
private AggregateFunctionInitArguments() {
this.udafIndex = 0;
this.config = ImmutableMap.of();
this.initArgs = Collections.emptyList();
}

Expand All @@ -63,4 +89,8 @@ public Object arg(final int i) {
/**
* Returns the initial arguments held by this instance.
*
* <p>The stored list is returned as-is, without a defensive copy.
*
* @return the list of initial arguments
*/
public List<Object> args() {
return initArgs;
}

/**
* Returns the configuration held by this instance.
*
* <p>Both visible constructors store an {@code ImmutableMap} (a copy of the supplied
* map, or an empty map), so the returned map cannot be modified.
*
* @return the configuration map
*/
public Map<String, ?> config() {
return config;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,10 @@ String getName() {
public static final Set<String> SSL_CONFIG_NAMES = sslConfigNames();
public static final Set<String> STREAM_TOPIC_CONFIG_NAMES = streamTopicConfigNames();

/**
* Creates a {@code KsqlConfig} from an empty property map.
*
* @return a new config built with no supplied properties
*/
public static KsqlConfig empty() {
return new KsqlConfig(ImmutableMap.of());
}

/**
 * Picks the config definition that matches the requested generation.
 *
 * @param generation the config generation being resolved
 * @return {@code CURRENT_DEF} for {@code ConfigGeneration.CURRENT}, otherwise {@code LEGACY_DEF}
 */
private static ConfigDef configDef(final ConfigGeneration generation) {
  if (generation == ConfigGeneration.CURRENT) {
    return CURRENT_DEF;
  }
  return LEGACY_DEF;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ private List<SqlType> buildAllParams(
allParams.add(primitiveType);
} catch (final Exception e) {
throw new KsqlFunctionException("Only primitive init arguments are supported by UDAF "
+ getName() + ", but got " + arg);
+ getName() + ", but got " + arg, e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.kafka.common.Configurable;
import org.apache.kafka.common.metrics.Metrics;

class UdafFactoryInvoker implements FunctionSignature {
Expand Down Expand Up @@ -78,6 +79,11 @@ KsqlAggregateFunction createFunction(final AggregateFunctionInitArguments initAr
final Object[] factoryArgs = initArgs.args().toArray();
try {
final Udaf udaf = (Udaf)method.invoke(null, factoryArgs);

if (udaf instanceof Configurable) {
((Configurable) udaf).configure(initArgs.config());
}

final KsqlAggregateFunction function;
if (TableUdaf.class.isAssignableFrom(method.getReturnType())) {
function = new UdafTableAggregateFunction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,98 +15,109 @@

package io.confluent.ksql.function.udaf.array;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import io.confluent.ksql.function.udaf.TableUdaf;
import io.confluent.ksql.function.udaf.UdafDescription;
import io.confluent.ksql.function.udaf.UdafFactory;
import io.confluent.ksql.util.KsqlConstants;
import java.util.List;
import java.util.Map;
import org.apache.kafka.common.Configurable;

@UdafDescription(
name = "collect_list",
description = "Gather all of the values from an input grouping into a single Array field."
+ "\nAlthough this aggregate works on both Stream and Table inputs, the order of entries"
+ " in the result array is not guaranteed when working on Table input data."
+ "\nThis version limits the size of the resultant Array to 1000 entries, beyond which"
+ " any further values will be silently ignored.",
+ "\nYou may limit the size of the resultant Array to N entries, beyond which"
+ " any further values will be silently ignored, by setting the"
+ " ksql.functions.collect_list.limit configuration to N.",
author = KsqlConstants.CONFLUENT_AUTHOR
)
public final class CollectListUdaf {

@VisibleForTesting
static final int LIMIT = 1000;
public static final String LIMIT_CONFIG = "ksql.functions.collect_list.limit";

private CollectListUdaf() {
// just to make the checkstyle happy
}

/**
* Builds a {@code TableUdaf} that accumulates input values into a list, capped at
* {@code LIMIT} entries.
*
* @param <T> the element type being collected
* @return a new collector whose aggregate and mapped result are both {@code List<T>}
*/
private static <T> TableUdaf<T, List<T>, List<T>> listCollector() {
return new TableUdaf<T, List<T>, List<T>>() {

// Each aggregation starts from a fresh, mutable, empty list.
@Override
public List<T> initialize() {
return Lists.newArrayList();
}

// Appends the value in place; once LIMIT entries are held, further values are dropped.
@Override
public List<T> aggregate(final T thisValue, final List<T> aggregate) {
if (aggregate.size() < LIMIT) {
aggregate.add(thisValue);
}
return aggregate;
}

// Appends as many leading entries of aggTwo as still fit under LIMIT; aggOne is mutated
// and returned.
@Override
public List<T> merge(final List<T> aggOne, final List<T> aggTwo) {
final int remainingCapacity = LIMIT - aggOne.size();
aggOne.addAll(aggTwo.subList(0, Math.min(remainingCapacity, aggTwo.size())));
return aggOne;
}

// The aggregate list is already the final result.
@Override
public List<T> map(final List<T> agg) {
return agg;
}

@Override
public List<T> undo(final T valueToUndo, final List<T> aggregateValue) {
// A more ideal solution would remove the value which corresponded to the original insertion
// but keeping track of that is more complex so we just remove the last value for now.
final int lastIndex = aggregateValue.lastIndexOf(valueToUndo);
// If we cannot find the value, that means that we hit the limit and never inserted it, so
// just return.
if (lastIndex < 0) {
return aggregateValue;
}
aggregateValue.remove(lastIndex);
return aggregateValue;
}
};
}

@UdafFactory(description = "collect values of a Bigint field into a single Array")
public static TableUdaf<Long, List<Long>, List<Long>> createCollectListLong() {
return listCollector();
return new Collect<>();
}

@UdafFactory(description = "collect values of an Integer field into a single Array")
public static TableUdaf<Integer, List<Integer>, List<Integer>> createCollectListInt() {
return listCollector();
return new Collect<>();
}

@UdafFactory(description = "collect values of a Double field into a single Array")
public static TableUdaf<Double, List<Double>, List<Double>> createCollectListDouble() {
return listCollector();
return new Collect<>();
}

@UdafFactory(description = "collect values of a String/Varchar field into a single Array")
public static TableUdaf<String, List<String>, List<String>> createCollectListString() {
return listCollector();
return new Collect<>();
}

@UdafFactory(description = "collect values of a Boolean field into a single Array")
public static TableUdaf<Boolean, List<Boolean>, List<Boolean>> createCollectListBool() {
return listCollector();
return new Collect<>();
}

/**
 * Collects input values into a list, optionally capped at a configurable number of entries.
 *
 * <p>The cap is read from {@link CollectListUdaf#LIMIT_CONFIG} in {@link #configure(Map)}.
 * Until configured, or when the configured value is negative, the list is effectively
 * unbounded ({@code Integer.MAX_VALUE}).
 *
 * @param <T> the element type being collected
 */
private static final class Collect<T> implements TableUdaf<T, List<T>, List<T>>, Configurable {

  /** Maximum number of entries kept per aggregate. */
  private int limit = Integer.MAX_VALUE;

  /**
   * Applies the optional {@link CollectListUdaf#LIMIT_CONFIG} setting.
   *
   * <p>A missing setting leaves the current limit untouched. Numeric values and numeric
   * strings are accepted; negative values, and values larger than
   * {@code Integer.MAX_VALUE}, mean "no limit".
   *
   * @param map the function configuration
   * @throws IllegalArgumentException if the setting is neither a number nor a numeric string
   */
  @Override
  public void configure(final Map<String, ?> map) {
    final Object configured = map.get(LIMIT_CONFIG);
    if (configured != null) {
      this.limit = toLimit(configured);
    }
  }

  /** Converts a raw config value into a non-negative limit. */
  private static int toLimit(final Object configured) {
    final long requested;
    if (configured instanceof Number) {
      requested = ((Number) configured).longValue();
    } else if (configured instanceof String) {
      // Config values may arrive as strings rather than numbers.
      try {
        requested = Long.parseLong(((String) configured).trim());
      } catch (final NumberFormatException e) {
        throw new IllegalArgumentException(
            "Invalid value for " + LIMIT_CONFIG + ": " + configured, e);
      }
    } else {
      throw new IllegalArgumentException(
          "Invalid value for " + LIMIT_CONFIG + ": " + configured);
    }

    // Going through long avoids the silent wrap-around that intValue() would apply
    // to values outside the int range.
    final boolean unbounded = requested < 0 || requested > Integer.MAX_VALUE;
    return unbounded ? Integer.MAX_VALUE : (int) requested;
  }

  /** Starts every aggregation from a fresh, mutable, empty list. */
  @Override
  public List<T> initialize() {
    return Lists.newArrayList();
  }

  /**
   * Appends {@code thisValue} in place unless the list already holds {@code limit} entries,
   * in which case the value is dropped.
   */
  @Override
  public List<T> aggregate(final T thisValue, final List<T> aggregate) {
    if (aggregate.size() < limit) {
      aggregate.add(thisValue);
    }
    return aggregate;
  }

  /**
   * Appends as many leading entries of {@code aggTwo} to {@code aggOne} as still fit under
   * the limit, then returns {@code aggOne}.
   */
  @Override
  public List<T> merge(final List<T> aggOne, final List<T> aggTwo) {
    // aggOne can hold more entries than the current limit (for example if it was produced
    // under a different limit setting). Clamp at zero so subList never receives a
    // negative end index.
    final int remainingCapacity = Math.max(0, limit - aggOne.size());
    aggOne.addAll(aggTwo.subList(0, Math.min(remainingCapacity, aggTwo.size())));
    return aggOne;
  }

  /** The aggregate list is already the final result. */
  @Override
  public List<T> map(final List<T> agg) {
    return agg;
  }

  /**
   * Removes the last occurrence of {@code valueToUndo} from the aggregate, if present.
   */
  @Override
  public List<T> undo(final T valueToUndo, final List<T> aggregateValue) {
    // Ideally the entry matching the original insertion would be removed, but tracking
    // that is more complex, so the last matching entry is removed instead.
    final int lastIndex = aggregateValue.lastIndexOf(valueToUndo);
    // A missing value means the limit was hit and it was never inserted: nothing to undo.
    if (lastIndex < 0) {
      return aggregateValue;
    }
    aggregateValue.remove(lastIndex);
    return aggregateValue;
  }
}
}
Loading

0 comments on commit 63ae169

Please sign in to comment.