Skip to content

Commit

Permalink
Adding safe divide function (#11904)
Browse files Browse the repository at this point in the history
* IMPLY-4344: Adding safe divide function along with testcases and documentation updates

* Changing based on review comments

* Addressing review comments, fixing coding style, docs and spelling

* Checkstyle passes for all code

* Fixing expected results for infinity

* Revert "Fixing expected results for infinity"

This reverts commit 5fd5cd4.

* Updating test result and a space in docs
  • Loading branch information
somu-imply authored Nov 17, 2021
1 parent d76e646 commit 2971078
Show file tree
Hide file tree
Showing 9 changed files with 1,056 additions and 474 deletions.
62 changes: 57 additions & 5 deletions core/src/main/java/org/apache/druid/math/expr/Function.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
/**
* Base interface describing the mechanism used to evaluate a {@link FunctionExpr}. All {@link Function} implementations
* are immutable.
*
* <p>
* Do NOT remove "unused" members in this class. They are used by generated Antlr
*/
@SuppressWarnings("unused")
Expand Down Expand Up @@ -1165,6 +1165,51 @@ public <T> ExprVectorProcessor<T> asVectorProcessor(Expr.VectorInputBindingInspe
}
}

class SafeDivide extends BivariateMathFunction
{
public static final String NAME = "safe_divide";

@Override
public String name()
{
return NAME;
}

@Nullable
@Override
public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
{
return ExpressionTypeConversion.function(
args.get(0).getOutputType(inspector),
args.get(1).getOutputType(inspector)
);
}

@Override
protected ExprEval eval(final long x, final long y)
{
if (y == 0) {
if (x != 0) {
return ExprEval.ofLong(NullHandling.defaultLongValue());
}
return ExprEval.ofLong(0);
}
return ExprEval.ofLong(x / y);
}

@Override
protected ExprEval eval(final double x, final double y)
{
if (y == 0 || Double.isNaN(y)) {
if (x != 0) {
return ExprEval.ofDouble(NullHandling.defaultDoubleValue());
}
return ExprEval.ofDouble(0);
}
return ExprEval.ofDouble(x / y);
}
}

class Div extends BivariateMathFunction
{
@Override
Expand Down Expand Up @@ -1932,7 +1977,9 @@ protected ExprEval eval(ExprEval x, ExprEval y)
public Set<Expr> getScalarInputs(List<Expr> args)
{
if (args.get(1).isLiteral()) {
ExpressionType castTo = ExpressionType.fromString(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString()));
ExpressionType castTo = ExpressionType.fromString(StringUtils.toUpperCase(args.get(1)
.getLiteralValue()
.toString()));
switch (castTo.getType()) {
case ARRAY:
return Collections.emptySet();
Expand All @@ -1948,7 +1995,9 @@ public Set<Expr> getScalarInputs(List<Expr> args)
public Set<Expr> getArrayInputs(List<Expr> args)
{
if (args.get(1).isLiteral()) {
ExpressionType castTo = ExpressionType.fromString(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString()));
ExpressionType castTo = ExpressionType.fromString(StringUtils.toUpperCase(args.get(1)
.getLiteralValue()
.toString()));
switch (castTo.getType()) {
case LONG:
case DOUBLE:
Expand Down Expand Up @@ -3237,7 +3286,9 @@ ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
break;
}
}
return index < 0 ? ExprEval.ofLong(NullHandling.replaceWithDefault() ? -1 : null) : ExprEval.ofLong(index + 1);
return index < 0
? ExprEval.ofLong(NullHandling.replaceWithDefault() ? -1 : null)
: ExprEval.ofLong(index + 1);
default:
throw new IAE("Function[%s] 2nd argument must be a a scalar type", name());
}
Expand Down Expand Up @@ -3591,7 +3642,8 @@ public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
name()
);
}
ExpressionType complexType = ExpressionTypeFactory.getInstance().ofComplex((String) args.get(0).getLiteralValue());
ExpressionType complexType = ExpressionTypeFactory.getInstance()
.ofComplex((String) args.get(0).getLiteralValue());
ObjectByteStrategy strategy = Types.getStrategy(complexType.getComplexTypeName());
if (strategy == null) {
throw new IAE(
Expand Down
141 changes: 103 additions & 38 deletions core/src/test/java/org/apache/druid/math/expr/FunctionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,36 @@ public class FunctionTest extends InitializedNullHandlingTest
@BeforeClass
public static void setupClass()
{
Types.registerStrategy(TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(), new TypesTest.PairObjectByteStrategy());
Types.registerStrategy(
TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(),
new TypesTest.PairObjectByteStrategy()
);
}

@Before
public void setup()
{
ImmutableMap.Builder<String, Object> builder = ImmutableMap.<String, Object>builder()
.put("x", "foo")
.put("y", 2)
.put("z", 3.1)
.put("d", 34.56D)
.put("maxLong", Long.MAX_VALUE)
.put("minLong", Long.MIN_VALUE)
.put("f", 12.34F)
.put("nan", Double.NaN)
.put("inf", Double.POSITIVE_INFINITY)
.put("-inf", Double.NEGATIVE_INFINITY)
.put("o", 0)
.put("od", 0D)
.put("of", 0F)
.put("a", new String[] {"foo", "bar", "baz", "foobar"})
.put("b", new Long[] {1L, 2L, 3L, 4L, 5L})
.put("c", new Double[] {3.1, 4.2, 5.3})
.put("someComplex", new TypesTest.NullableLongPair(1L, 2L));
.put("x", "foo")
.put("y", 2)
.put("z", 3.1)
.put("d", 34.56D)
.put("maxLong", Long.MAX_VALUE)
.put("minLong", Long.MIN_VALUE)
.put("f", 12.34F)
.put("nan", Double.NaN)
.put("inf", Double.POSITIVE_INFINITY)
.put("-inf", Double.NEGATIVE_INFINITY)
.put("o", 0)
.put("od", 0D)
.put("of", 0F)
.put("a", new String[]{"foo", "bar", "baz", "foobar"})
.put("b", new Long[]{1L, 2L, 3L, 4L, 5L})
.put("c", new Double[]{3.1, 4.2, 5.3})
.put(
"someComplex",
new TypesTest.NullableLongPair(1L, 2L)
);
bindings = InputBindings.withMap(builder.build());
}

Expand Down Expand Up @@ -350,17 +356,20 @@ public void testArrayCast()
assertArrayExpr("cast([1, 2, 3], 'STRING_ARRAY')", new String[]{"1", "2", "3"});
assertArrayExpr("cast([1, 2, 3], 'DOUBLE_ARRAY')", new Double[]{1.0, 2.0, 3.0});
assertArrayExpr("cast(c, 'LONG_ARRAY')", new Long[]{3L, 4L, 5L});
assertArrayExpr("cast(string_to_array(array_to_string(b, ','), ','), 'LONG_ARRAY')", new Long[]{1L, 2L, 3L, 4L, 5L});
assertArrayExpr(
"cast(string_to_array(array_to_string(b, ','), ','), 'LONG_ARRAY')",
new Long[]{1L, 2L, 3L, 4L, 5L}
);
assertArrayExpr("cast(['1.0', '2.0', '3.0'], 'LONG_ARRAY')", new Long[]{1L, 2L, 3L});
}

@Test
public void testArraySlice()
{
assertArrayExpr("array_slice([1, 2, 3, 4], 1, 3)", new Long[] {2L, 3L});
assertArrayExpr("array_slice([1.0, 2.1, 3.2, 4.3], 2)", new Double[] {3.2, 4.3});
assertArrayExpr("array_slice(['a', 'b', 'c', 'd'], 4, 6)", new String[] {null, null});
assertArrayExpr("array_slice([1, 2, 3, 4], 2, 2)", new Long[] {});
assertArrayExpr("array_slice([1, 2, 3, 4], 1, 3)", new Long[]{2L, 3L});
assertArrayExpr("array_slice([1.0, 2.1, 3.2, 4.3], 2)", new Double[]{3.2, 4.3});
assertArrayExpr("array_slice(['a', 'b', 'c', 'd'], 4, 6)", new String[]{null, null});
assertArrayExpr("array_slice([1, 2, 3, 4], 2, 2)", new Long[]{});
assertArrayExpr("array_slice([1, 2, 3, 4], 5, 7)", null);
assertArrayExpr("array_slice([1, 2, 3, 4], 2, 1)", null);
}
Expand Down Expand Up @@ -438,12 +447,24 @@ public void testRoundWithExtremeNumbers()
assertExpr("round(maxLong)", BigDecimal.valueOf(Long.MAX_VALUE).setScale(0, RoundingMode.HALF_UP).longValue());
assertExpr("round(minLong)", BigDecimal.valueOf(Long.MIN_VALUE).setScale(0, RoundingMode.HALF_UP).longValue());
// overflow
assertExpr("round(maxLong + 1, 1)", BigDecimal.valueOf(Long.MIN_VALUE).setScale(1, RoundingMode.HALF_UP).longValue());
assertExpr(
"round(maxLong + 1, 1)",
BigDecimal.valueOf(Long.MIN_VALUE).setScale(1, RoundingMode.HALF_UP).longValue()
);
// underflow
assertExpr("round(minLong - 1, -2)", BigDecimal.valueOf(Long.MAX_VALUE).setScale(-2, RoundingMode.HALF_UP).longValue());
assertExpr(
"round(minLong - 1, -2)",
BigDecimal.valueOf(Long.MAX_VALUE).setScale(-2, RoundingMode.HALF_UP).longValue()
);

assertExpr("round(CAST(maxLong, 'DOUBLE') + 1, 1)", BigDecimal.valueOf(((double) Long.MAX_VALUE) + 1).setScale(1, RoundingMode.HALF_UP).doubleValue());
assertExpr("round(CAST(minLong, 'DOUBLE') - 1, -2)", BigDecimal.valueOf(((double) Long.MIN_VALUE) - 1).setScale(-2, RoundingMode.HALF_UP).doubleValue());
assertExpr(
"round(CAST(maxLong, 'DOUBLE') + 1, 1)",
BigDecimal.valueOf(((double) Long.MAX_VALUE) + 1).setScale(1, RoundingMode.HALF_UP).doubleValue()
);
assertExpr(
"round(CAST(minLong, 'DOUBLE') - 1, -2)",
BigDecimal.valueOf(((double) Long.MIN_VALUE) - 1).setScale(-2, RoundingMode.HALF_UP).doubleValue()
);
}

@Test
Expand Down Expand Up @@ -643,7 +664,10 @@ public void testSizeForatInvalidArgumentType()
Assert.assertTrue(NullHandling.sqlCompatible() ? true : false);
}
catch (IAE e) {
Assert.assertEquals("Function[human_readable_binary_byte_format] needs a number as its first argument", e.getMessage());
Assert.assertEquals(
"Function[human_readable_binary_byte_format] needs a number as its first argument",
e.getMessage()
);
}

try {
Expand All @@ -655,7 +679,10 @@ public void testSizeForatInvalidArgumentType()
Assert.assertTrue(false);
}
catch (IAE e) {
Assert.assertEquals("Function[human_readable_binary_byte_format] needs an integer as its second argument", e.getMessage());
Assert.assertEquals(
"Function[human_readable_binary_byte_format] needs an integer as its second argument",
e.getMessage()
);
}

try {
Expand All @@ -667,7 +694,10 @@ public void testSizeForatInvalidArgumentType()
Assert.assertTrue(false);
}
catch (IAE e) {
Assert.assertEquals("Function[human_readable_binary_byte_format] needs an integer as its second argument", e.getMessage());
Assert.assertEquals(
"Function[human_readable_binary_byte_format] needs an integer as its second argument",
e.getMessage()
);
}

try {
Expand All @@ -679,7 +709,10 @@ public void testSizeForatInvalidArgumentType()
Assert.assertTrue(false);
}
catch (IAE e) {
Assert.assertEquals("Function[human_readable_binary_byte_format] needs an integer as its second argument", e.getMessage());
Assert.assertEquals(
"Function[human_readable_binary_byte_format] needs an integer as its second argument",
e.getMessage()
);
}
}

Expand All @@ -692,7 +725,10 @@ public void testSizeFormatInvalidPrecision()
Assert.assertTrue(false);
}
catch (IAE e) {
Assert.assertEquals("Given precision[9223372036854775807] of Function[human_readable_binary_byte_format] must be in the range of [0,3]", e.getMessage());
Assert.assertEquals(
"Given precision[9223372036854775807] of Function[human_readable_binary_byte_format] must be in the range of [0,3]",
e.getMessage()
);
}

try {
Expand All @@ -701,7 +737,10 @@ public void testSizeFormatInvalidPrecision()
Assert.assertTrue(false);
}
catch (IAE e) {
Assert.assertEquals("Given precision[-9223372036854775808] of Function[human_readable_binary_byte_format] must be in the range of [0,3]", e.getMessage());
Assert.assertEquals(
"Given precision[-9223372036854775808] of Function[human_readable_binary_byte_format] must be in the range of [0,3]",
e.getMessage()
);
}

try {
Expand All @@ -710,7 +749,10 @@ public void testSizeFormatInvalidPrecision()
Assert.assertTrue(false);
}
catch (IAE e) {
Assert.assertEquals("Given precision[-1] of Function[human_readable_binary_byte_format] must be in the range of [0,3]", e.getMessage());
Assert.assertEquals(
"Given precision[-1] of Function[human_readable_binary_byte_format] must be in the range of [0,3]",
e.getMessage()
);
}

try {
Expand All @@ -719,7 +761,10 @@ public void testSizeFormatInvalidPrecision()
Assert.assertTrue(false);
}
catch (IAE e) {
Assert.assertEquals("Given precision[4] of Function[human_readable_binary_byte_format] must be in the range of [0,3]", e.getMessage());
Assert.assertEquals(
"Given precision[4] of Function[human_readable_binary_byte_format] must be in the range of [0,3]",
e.getMessage()
);
}
}

Expand All @@ -732,6 +777,21 @@ public void testSizeFormatInvalidArgumentSize()
.eval(bindings);
}

@Test
public void testSafeDivide()
{
// happy path maths
assertExpr("safe_divide(3, 1)", 3L);
assertExpr("safe_divide(4.5, 2)", 2.25);
assertExpr("safe_divide(3, 0)", NullHandling.defaultLongValue());
assertExpr("safe_divide(1, 0.0)", NullHandling.defaultDoubleValue());
// NaN and Infinity cases
assertExpr("safe_divide(NaN, 0.0)", NullHandling.defaultDoubleValue());
assertExpr("safe_divide(0, NaN)", 0.0);
assertExpr("safe_divide(0, POSITIVE_INFINITY)", NullHandling.defaultLongValue());
assertExpr("safe_divide(POSITIVE_INFINITY,0)", NullHandling.defaultLongValue());
}

@Test
public void testBitwise()
{
Expand Down Expand Up @@ -763,7 +823,10 @@ public void testBitwise()
Assert.fail("Did not throw IllegalArgumentException");
}
catch (IllegalArgumentException e) {
Assert.assertEquals("Possible data truncation, param [461168601842738800000000000000.000000] is out of long value range", e.getMessage());
Assert.assertEquals(
"Possible data truncation, param [461168601842738800000000000000.000000] is out of long value range",
e.getMessage()
);
}

// doubles are cast
Expand Down Expand Up @@ -845,7 +908,8 @@ public void testComplexDecodeBaseWrongArgCount()
public void testComplexDecodeBaseArg0BadType()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[complex_decode_base64] first argument must be constant 'STRING' expression containing a valid complex type name");
expectedException.expectMessage(
"Function[complex_decode_base64] first argument must be constant 'STRING' expression containing a valid complex type name");
assertExpr(
"complex_decode_base64(1, string)",
null
Expand All @@ -856,7 +920,8 @@ public void testComplexDecodeBaseArg0BadType()
public void testComplexDecodeBaseArg0Unknown()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[complex_decode_base64] first argument must be a valid complex type name, unknown complex type [COMPLEX<unknown>]");
expectedException.expectMessage(
"Function[complex_decode_base64] first argument must be a valid complex type name, unknown complex type [COMPLEX<unknown>]");
assertExpr(
"complex_decode_base64('unknown', string)",
null
Expand Down
Loading

0 comments on commit 2971078

Please sign in to comment.