feat: allow for decimals to be used as input types for UDFs #3217

Merged (1 commit, Aug 15, 2019)

@@ -19,6 +19,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.confluent.ksql.schema.connect.SqlSchemaFormatter;
import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlException;
import java.util.ArrayList;
import java.util.Collection;
@@ -262,6 +263,7 @@ static final class Parameter {
.put(Type.MAP, Parameter::mapEquals)
.put(Type.ARRAY, Parameter::arrayEquals)
.put(Type.STRUCT, Parameter::structEquals)
.put(Type.BYTES, Parameter::bytesEquals)
.build();

private final Schema schema;
@@ -316,7 +318,6 @@ boolean accepts(final Schema argument, final Map<Schema, Schema> reservedGeneric
return Objects.equals(type, argument.type())
&& CUSTOM_SCHEMA_EQ.getOrDefault(type, (a, b) -> true).test(schema, argument)
&& Objects.equals(schema.version(), argument.version())
&& Objects.equals(schema.parameters(), argument.parameters())
Contributor:
Is it only the decimal type that may contain parameters in its schema? (Trying to understand whether removing this portion of the check is safe.)

Contributor Author:
At the moment, yes - it's the only one that we use that has parameters. I think in general it makes sense to defer the parameter check to the CUSTOM_SCHEMA_EQ instead of checking that they are the same.
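For illustration, here's a minimal sketch of the difference (assuming DecimalUtil records precision and scale as schema parameters on top of Connect's Decimal logical type): strict parameters() equality would reject a DECIMAL(3,1) argument for a DECIMAL(2,1) parameter, while bytesEquals accepts any decimal.

```java
import io.confluent.ksql.util.DecimalUtil;
import org.apache.kafka.connect.data.Schema;

public final class DecimalParameterCheckSketch {

  public static void main(final String[] args) {
    final Schema decimal21 = DecimalUtil.builder(2, 1).build();
    final Schema decimal31 = DecimalUtil.builder(3, 1).build();

    // Assumption: precision/scale are recorded as schema parameters, so strict
    // parameters() equality treats DECIMAL(2, 1) and DECIMAL(3, 1) as different types.
    System.out.println(decimal21.parameters().equals(decimal31.parameters())); // false

    // The bytesEquals check only requires that both sides are decimals.
    System.out.println(DecimalUtil.isDecimal(decimal21)
        && DecimalUtil.isDecimal(decimal31)); // true
  }
}
```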

&& Objects.deepEquals(schema.defaultValue(), argument.defaultValue());
}
// CHECKSTYLE_RULES.ON: BooleanExpressionComplexity
@@ -356,6 +357,13 @@ private static boolean structEquals(final Schema structA, final Schema structB)
|| Objects.equals(structA.fields(), structB.fields());
}

private static boolean bytesEquals(final Schema bytesA, final Schema bytesB) {
// from a UDF parameter perspective, all decimals are the same
// since they can all be cast to BigDecimal - other bytes types
// are not supported in UDFs
return DecimalUtil.isDecimal(bytesA) && DecimalUtil.isDecimal(bytesB);
}

@Override
public String toString() {
return FORMATTER.format(schema) + (isVararg ? "(VARARG)" : "");
@@ -126,7 +126,7 @@ protected String visitBytes(final Schema schema) {
+ DecimalUtil.scale(schema) + ")";
}

throw new KsqlException("Cannot format bytes type: " + schema);
return "BYTES";
Contributor:
It's still the case that the only bytes type we currently support is Decimals, right? I'm having trouble understanding why this change is necessary (besides the pretty error message in `UdfIndexTest#shouldNotFindArbitraryBytesTypes`).

Contributor Author:
Just the pretty error message - if we can't find a UDF, we shouldn't throw "Cannot format bytes type". The other BYTES type that we support is the way we represent generics.

}

private final class Converter implements SchemaWalker.Visitor<String, String> {
@@ -7,6 +7,7 @@

import com.google.common.collect.ImmutableList;
import io.confluent.ksql.function.udf.Kudf;
import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlException;
import java.util.Arrays;
@@ -30,6 +31,8 @@ public class UdfIndexTest {
private static final Schema STRUCT3_PERMUTE = SchemaBuilder.struct().field("d", INT).field("c", INT).build();
private static final Schema MAP1 = SchemaBuilder.map(STRING, STRING).build();
private static final Schema MAP2 = SchemaBuilder.map(STRING, INT).build();
private static final Schema DECIMAL1 = DecimalUtil.builder(2, 1).build();
private static final Schema DECIMAL2 = DecimalUtil.builder(3, 1).build();

private static final Schema GENERIC_LIST = GenericsUtil.array("T").build();
private static final Schema STRING_LIST = SchemaBuilder.array(STRING).build();
@@ -156,6 +159,32 @@ public void shouldChooseCorrectMap() {
assertThat(fun.getFunctionName(), equalTo(EXPECTED));
}

@Test
public void shouldChooseCorrectDecimal() {
// Given:
final KsqlFunction[] functions = new KsqlFunction[]{function(EXPECTED, false, DECIMAL1)};
Arrays.stream(functions).forEach(udfIndex::addFunction);

// When:
final KsqlFunction fun = udfIndex.getFunction(ImmutableList.of(DECIMAL1));

// Then:
assertThat(fun.getFunctionName(), equalTo(EXPECTED));
}

@Test
public void shouldAllowAnyDecimal() {
// Given:
final KsqlFunction[] functions = new KsqlFunction[]{function(EXPECTED, false, DECIMAL1)};
Arrays.stream(functions).forEach(udfIndex::addFunction);

// When:
final KsqlFunction fun = udfIndex.getFunction(ImmutableList.of(DECIMAL2));

// Then:
assertThat(fun.getFunctionName(), equalTo(EXPECTED));
}

@Test
public void shouldChooseCorrectPermutedStruct() {
// Given:
@@ -683,6 +712,20 @@ public void shouldNotMatchNestedGenericMethodWithAlreadyReservedTypes() {
udfIndex.getFunction(ImmutableList.of(INT_LIST, STRING_LIST));
}

@Test
public void shouldNotFindArbitraryBytesTypes() {
// Given:
final KsqlFunction[] functions = new KsqlFunction[]{function(EXPECTED, false, DECIMAL1)};
Arrays.stream(functions).forEach(udfIndex::addFunction);

// Expect:
expectedException.expect(KsqlException.class);
expectedException.expectMessage(is("Function 'name' does not accept parameters of types:"
+ "[BYTES]"));

// When:
udfIndex.getFunction(ImmutableList.of(SchemaBuilder.bytes().build()));
}

private static KsqlFunction function(
final String name,
@@ -102,6 +102,19 @@ public void shouldFormatDecimal() {
assertThat(STRICT.format(DecimalUtil.builder(2, 1).build()), is("DECIMAL(2, 1)"));
}

@Test
public void shouldFormatOptionalBytes() {
assertThat(DEFAULT.format(Schema.OPTIONAL_BYTES_SCHEMA), is("BYTES"));
assertThat(STRICT.format(Schema.OPTIONAL_BYTES_SCHEMA), is("BYTES"));
}


@Test
public void shouldFormatBytes() {
assertThat(DEFAULT.format(Schema.BYTES_SCHEMA), is("BYTES"));
assertThat(STRICT.format(Schema.BYTES_SCHEMA), is("BYTES NOT NULL"));
}

@Test
public void shouldFormatArray() {
// Given:
@@ -27,7 +27,6 @@
import io.confluent.ksql.function.udf.json.JsonExtractStringKudf;
import io.confluent.ksql.function.udf.math.AbsKudf;
import io.confluent.ksql.function.udf.math.CeilKudf;
import io.confluent.ksql.function.udf.math.FloorKudf;
import io.confluent.ksql.function.udf.math.RandomKudf;
import io.confluent.ksql.function.udf.math.RoundKudf;
import io.confluent.ksql.function.udf.string.ConcatKudf;
@@ -238,12 +237,6 @@ private void addMathFunctions() {
"CEIL",
CeilKudf.class));

addBuiltInFunction(KsqlFunction.createLegacyBuiltIn(
Schema.OPTIONAL_FLOAT64_SCHEMA,
Collections.singletonList(Schema.OPTIONAL_FLOAT64_SCHEMA),
"FLOOR",
FloorKudf.class));

addBuiltInFunction(KsqlFunction.createLegacyBuiltIn(
Schema.OPTIONAL_INT64_SCHEMA,
Collections.singletonList(Schema.OPTIONAL_FLOAT64_SCHEMA),
@@ -16,11 +16,13 @@
package io.confluent.ksql.function;

import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlException;
import java.lang.reflect.GenericArrayType;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
@@ -40,6 +42,10 @@ public final class UdfUtil {
.put(long.class, SchemaBuilder::int64)
.put(Double.class, () -> SchemaBuilder.float64().optional())
.put(double.class, SchemaBuilder::float64)
// from the UDF perspective, all Decimal schemas are the same (BigDecimal) in Java
// so we arbitrarily choose DECIMAL(1,0). if we migrate to use a type system dedicated
// for UDFs, we can update this to be a "generic decimal"
.put(BigDecimal.class, () -> DecimalUtil.builder(1, 0).optional())
.build();

private UdfUtil() {
@@ -93,7 +99,9 @@ static Schema getSchemaFromType(final Type type, final String name, final String
schema = GenericsUtil.generic(((TypeVariable) type).getName());
} else {
schema = typeToSchema.getOrDefault(type, () -> handleParametrizedType(type)).get();
schema.name(name);
if (schema.name() == null) {
Contributor:
Why is this change necessary? (Why would the schema already have a name at this point?)

Contributor Author:
This is something that I feel we should change in the long run, but decimals are defined by their schema name (they all have org.apache.kafka.connect.data.Decimal as their schema name). We should not be using the schema name to define the parameter name, but that's out of scope.
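A minimal sketch of what's going on (assuming DecimalUtil builds on Connect's Decimal logical type and its LOGICAL_NAME constant):

```java
import io.confluent.ksql.util.DecimalUtil;
import org.apache.kafka.connect.data.Decimal;
import org.apache.kafka.connect.data.Schema;

public final class DecimalSchemaNameSketch {

  public static void main(final String[] args) {
    final Schema decimal = DecimalUtil.builder(1, 0).build();

    // Decimal schemas already carry the Connect logical-type name; overwriting
    // it with the UDF parameter name would break decimal detection downstream.
    System.out.println(decimal.name()); // org.apache.kafka.connect.data.Decimal
    System.out.println(decimal.name().equals(Decimal.LOGICAL_NAME)); // true
    System.out.println(DecimalUtil.isDecimal(decimal)); // true
  }
}
```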

Contributor:
Ah, my bad. I even followed the Decimal schema builder code through but somehow missed it was setting the name. Thanks for the explanations!

schema.name(name);
}
}

schema.doc(doc);
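To illustrate the BigDecimal mapping added above, a minimal sketch (it assumes the package-private UdfUtil.getSchemaFromType(Type) overload exercised in UdfUtilTest, hence the package declaration; the class name is made up for the example):

```java
package io.confluent.ksql.function;

import io.confluent.ksql.util.DecimalUtil;
import java.math.BigDecimal;
import org.apache.kafka.connect.data.Schema;

public final class BigDecimalMappingSketch {

  public static void main(final String[] args) {
    // Every BigDecimal UDF parameter maps to the same placeholder schema,
    // DECIMAL(1, 0); the UdfIndex then accepts any decimal argument for it.
    final Schema mapped = UdfUtil.getSchemaFromType(BigDecimal.class);

    System.out.println(DecimalUtil.isDecimal(mapped)); // true
    System.out.println(mapped.name().equals(DecimalUtil.builder(1, 0).name())); // true
  }
}
```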
@@ -0,0 +1,51 @@
/*
* Copyright 2019 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"); you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.function.udf.math;

import io.confluent.ksql.function.udf.Udf;
import io.confluent.ksql.function.udf.UdfDescription;
import io.confluent.ksql.function.udf.UdfParameter;
import java.math.BigDecimal;

@UdfDescription(name = "Floor", description = Floor.DESCRIPTION)
public class Floor {

static final String DESCRIPTION = "Returns the largest integer less than or equal to the "
+ "specified numeric expression. NOTE: for backwards compatibility, this returns a DOUBLE "
+ "that has a mantissa of zero.";


@Udf
public Double floor(@UdfParameter final Integer val) {
return (val == null) ? null : Math.floor(val);
}

@Udf
public Double floor(@UdfParameter final Long val) {
return (val == null) ? null : Math.floor(val);
}

@Udf
public Double floor(@UdfParameter final Double val) {
return (val == null) ? null : Math.floor(val);
}

@Udf
public Double floor(@UdfParameter final BigDecimal val) {
return (val == null) ? null : Math.floor(val.doubleValue());
}

}

This file was deleted.

@@ -357,7 +357,7 @@ public void shouldHaveBuiltInUDFRegistered() {
// String UDF
"LCASE", "UCASE", "CONCAT", "TRIM", "IFNULL", "LEN",
// Math UDF
"ABS", "CEIL", "FLOOR", "ROUND", "RANDOM",
"ABS", "CEIL", "ROUND", "RANDOM",
// JSON UDF
"EXTRACTJSONFIELD", "ARRAYCONTAINS",
// Struct UDF
@@ -21,6 +21,7 @@
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalToIgnoringCase;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
@@ -34,6 +35,7 @@
import io.confluent.ksql.function.udf.Udf;
import io.confluent.ksql.function.udf.UdfDescription;
import io.confluent.ksql.function.udf.UdfParameter;
import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlException;
import java.io.File;
@@ -133,6 +135,19 @@ public void shouldLoadStructUdafs() {
equalTo(new Struct(schema).put("A", 1).put("B", 2)));
}

@Test
public void shouldLoadDecimalUdfs() {
// Given:
final Schema schema = DecimalUtil.builder(2, 1).optional().build();

// When:
final KsqlFunction fun = FUNC_REG.getUdfFactory("floor")
.getFunction(ImmutableList.of(schema));

// Then:
assertThat(fun.getFunctionName(), equalToIgnoringCase("floor"));
}

@Test
public void shouldLoadFunctionsFromJarsInPluginDir() {
final UdfFactory toString = FUNC_REG.getUdfFactory("tostring");
@@ -20,8 +20,10 @@
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;

import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlException;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import org.apache.kafka.connect.data.Schema;
@@ -113,6 +115,12 @@ public void shouldGetFloatSchemaForDoublePrimitiveClass() {
equalTo(Schema.FLOAT64_SCHEMA));
}

@Test
public void shouldGetDecimalSchemaForBigDecimalClass() {
assertThat(UdfUtil.getSchemaFromType(BigDecimal.class).name(),
is(DecimalUtil.builder(2, 1).name()));
}

@Test
public void shouldGetMapSchemaFromMapClass() throws NoSuchMethodException {
final Type type = getClass().getDeclaredMethod("mapType", Map.class)
@@ -23,6 +23,7 @@
import static org.hamcrest.Matchers.is;

import io.confluent.ksql.function.InternalFunctionRegistry;
import io.confluent.ksql.function.TestFunctionRegistry;
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.metastore.model.DataSource;
import io.confluent.ksql.metastore.model.MetaStoreMatchers.OptionalMatchers;
@@ -50,7 +51,7 @@ public class LogicalPlannerTest {

@Before
public void init() {
metaStore = MetaStoreFixture.getNewMetaStore(new InternalFunctionRegistry());
metaStore = MetaStoreFixture.getNewMetaStore(TestFunctionRegistry.INSTANCE.get());
ksqlConfig = new KsqlConfig(Collections.emptyMap());
}

@@ -78,6 +78,25 @@
{"topic": "OUTPUT", "value": {"I": 0, "L": 0, "D": 0}},
{"topic": "OUTPUT", "value": {"I": 1, "L": 1, "D": 1}}
]
},
{
"name": "floor",
"statements": [
"CREATE STREAM INPUT (i INT, l BIGINT, d DOUBLE, b DECIMAL(2,1)) WITH (kafka_topic='input', value_format='AVRO');",
"CREATE STREAM OUTPUT AS SELECT floor(i) i, floor(l) l, floor(d) d, floor(b) b FROM INPUT;"
],
"inputs": [
{"topic": "input", "value": {"i": null, "l": null, "d": null}},
{"topic": "input", "value": {"i": -1, "l": -2, "d": -3.1, "b": "-3.1"}},
{"topic": "input", "value": {"i": 0, "l": 0, "d": 0.0, "b": "0.0"}},
{"topic": "input", "value": {"i": 1, "l": 2, "d": 3.1, "b": "3.1"}}
],
"outputs": [
{"topic": "OUTPUT", "value": {"I": null, "L": null, "D": null, "B": null}},
{"topic": "OUTPUT", "value": {"I": -1.0, "L": -2.0, "D": -4.0, "B": -4.0}},
{"topic": "OUTPUT", "value": {"I": 0.0, "L": 0.0, "D": 0.0, "B": 0.0}},
{"topic": "OUTPUT", "value": {"I": 1.0, "L": 2.0, "D": 3.0, "B": 3.0}}
]
}
]
}