From 17516247415d183c2be1473caf70cd5f5598563b Mon Sep 17 00:00:00 2001 From: Zhan Yuan <yuanzhanhku@gmail.com> Date: Tue, 25 May 2021 13:44:43 -0700 Subject: [PATCH] Implement ARRAY_NORMALIZE function Normalizes array ``x`` by dividing each element by the p-norm of the array. It is equivalent to `TRANSFORM(x, v -> v / REDUCE(x, 0, (a, v) -> a + POW(ABS(v), p), a -> POW(a, 1 / p)))`. But the reduce part is only executed once to improve performance. Returns null if the array is null or there are null array elements. --- .../src/main/sphinx/functions/array.rst | 7 + ...uiltInTypeAndFunctionNamespaceManager.java | 2 + .../scalar/ArrayNormalizeFunction.java | 137 ++++++++++++++++++ .../scalar/TestArrayNormalizeFunction.java | 106 ++++++++++++++ 4 files changed, 252 insertions(+) create mode 100644 presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNormalizeFunction.java create mode 100644 presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayNormalizeFunction.java diff --git a/presto-docs/src/main/sphinx/functions/array.rst b/presto-docs/src/main/sphinx/functions/array.rst index 767d2acae0e58..5edef93919132 100644 --- a/presto-docs/src/main/sphinx/functions/array.rst +++ b/presto-docs/src/main/sphinx/functions/array.rst @@ -69,6 +69,13 @@ Array Functions Returns the minimum value of input array. +.. function:: array_normalize(x, p) -> array + + Normalizes array ``x`` by dividing each element by the p-norm of the array. + It is equivalent to ``TRANSFORM(x, v -> v / REDUCE(x, 0, (a, v) -> a + POW(ABS(v), p), a -> POW(a, 1 / p)))``, + but the reduce part is only executed once. + Returns null if the array is null or there are null array elements. + .. function:: array_position(x, element) -> bigint Returns the position of the first occurrence of the ``element`` in array ``x`` (or 0 if not found). 
diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 27ed3317c19b2..e5e7682380bd7 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -106,6 +106,7 @@ import com.facebook.presto.operator.scalar.ArrayMinFunction; import com.facebook.presto.operator.scalar.ArrayNgramsFunction; import com.facebook.presto.operator.scalar.ArrayNoneMatchFunction; +import com.facebook.presto.operator.scalar.ArrayNormalizeFunction; import com.facebook.presto.operator.scalar.ArrayNotEqualOperator; import com.facebook.presto.operator.scalar.ArrayPositionFunction; import com.facebook.presto.operator.scalar.ArrayRemoveFunction; @@ -751,6 +752,7 @@ private List<? extends SqlFunction> getBuildInFunctions(FeaturesConfig featuresC .scalar(ArrayAllMatchFunction.class) .scalar(ArrayAnyMatchFunction.class) .scalar(ArrayNoneMatchFunction.class) + .scalar(ArrayNormalizeFunction.class) .scalar(MapDistinctFromOperator.class) .scalar(MapEqualOperator.class) .scalar(MapEntriesFunction.class) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNormalizeFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNormalizeFunction.java new file mode 100644 index 0000000000000..b4975a6adb164 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNormalizeFunction.java @@ -0,0 +1,137 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.operator.scalar;
+
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.block.BlockBuilder;
+import com.facebook.presto.common.type.DoubleType;
+import com.facebook.presto.common.type.RealType;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.spi.PrestoException;
+import com.facebook.presto.spi.function.Description;
+import com.facebook.presto.spi.function.ScalarFunction;
+import com.facebook.presto.spi.function.SqlNullable;
+import com.facebook.presto.spi.function.SqlType;
+import com.facebook.presto.spi.function.TypeParameter;
+import com.facebook.presto.spi.function.TypeParameterSpecialization;
+
+import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING;
+import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
+import static com.facebook.presto.util.Failures.checkCondition;
+import static java.lang.String.format;
+
+/**
+ * Scalar function {@code array_normalize(x, p)}: divides every element of {@code x} by
+ * the p-norm of the array, {@code POW(SUM(POW(ABS(v), p)), 1 / p)}.
+ *
+ * <p>Only REAL and DOUBLE element types are accepted. The result is SQL NULL when any
+ * element is null, for every value of {@code p}. When {@code p} is zero, or when the
+ * accumulated norm is zero, the input array is returned unchanged.
+ */
+@ScalarFunction("array_normalize")
+@Description("Normalizes an array by dividing each element by the p-norm of the array.")
+public final class ArrayNormalizeFunction
+{
+    private static final ValueAccessor DOUBLE_VALUE_ACCESSOR = new DoubleValueAccessor();
+    private static final ValueAccessor REAL_VALUE_ACCESSOR = new RealValueAccessor();
+
+    private ArrayNormalizeFunction() {}
+
+    /**
+     * Specialization for element types whose native container is {@code double}.
+     *
+     * @param elementType element type bound to {@code T}
+     * @param block the array elements
+     * @param p order of the norm; must be non-negative
+     * @return the normalized array, or {@code null} (SQL NULL) if any element is null
+     */
+    @TypeParameter("T")
+    @TypeParameterSpecialization(name = "T", nativeContainerType = double.class)
+    @SqlType("array(T)")
+    @SqlNullable
+    public static Block normalizeDoubleArray(
+            @TypeParameter("T") Type elementType,
+            @SqlType("array(T)") Block block,
+            @SqlType("T") double p)
+    {
+        return normalizeArray(elementType, block, p, DOUBLE_VALUE_ACCESSOR);
+    }
+
+    /**
+     * Specialization for element types whose native container is {@code long}.
+     * {@code p} is decoded as the float whose bit pattern is held in its low 32 bits;
+     * any element type other than REAL is rejected by the shared implementation.
+     *
+     * @param elementType element type bound to {@code T}
+     * @param block the array elements
+     * @param p order of the norm, encoded as float bits; must decode to a non-negative value
+     * @return the normalized array, or {@code null} (SQL NULL) if any element is null
+     */
+    @TypeParameter("T")
+    @TypeParameterSpecialization(name = "T", nativeContainerType = long.class)
+    @SqlType("array(T)")
+    @SqlNullable
+    public static Block normalizeRealArray(
+            @TypeParameter("T") Type elementType,
+            @SqlType("array(T)") Block block,
+            @SqlType("T") long p)
+    {
+        return normalizeArray(elementType, block, Float.intBitsToFloat((int) p), REAL_VALUE_ACCESSOR);
+    }
+
+    /**
+     * Shared implementation behind both specializations.
+     *
+     * @param elementType element type of the array; must be REAL or DOUBLE
+     * @param block the array elements
+     * @param p order of the norm
+     * @param valueAccessor reads and writes elements of {@code elementType} as {@code double}
+     * @return the normalized array, the input itself when {@code p} or the norm is zero,
+     *         or {@code null} when any element is null
+     * @throws PrestoException with {@code FUNCTION_IMPLEMENTATION_MISSING} when the element
+     *         type is neither REAL nor DOUBLE
+     */
+    private static Block normalizeArray(Type elementType, Block block, double p, ValueAccessor valueAccessor)
+    {
+        if (!(elementType instanceof RealType) && !(elementType instanceof DoubleType)) {
+            throw new PrestoException(
+                    FUNCTION_IMPLEMENTATION_MISSING,
+                    format("Unsupported array element type for array_normalize function: %s", elementType.getDisplayName()));
+        }
+        // "p >= 0" is also false for NaN, so a NaN p is rejected here as well.
+        checkCondition(p >= 0, INVALID_FUNCTION_ARGUMENT, "array_normalize only supports non-negative p: %s", p);
+
+        if (p == 0) {
+            // No normalization is applied for p = 0, but a null element must still produce
+            // a NULL result so that null handling does not depend on the value of p.
+            return hasNullElement(block) ? null : block;
+        }
+
+        int elementCount = block.getPositionCount();
+        double pNorm = 0;
+        for (int i = 0; i < elementCount; i++) {
+            if (block.isNull(i)) {
+                return null;
+            }
+            pNorm += Math.pow(Math.abs(valueAccessor.getValue(elementType, block, i)), p);
+        }
+        if (pNorm == 0) {
+            // Every element is zero; return the input rather than dividing by zero.
+            return block;
+        }
+        pNorm = Math.pow(pNorm, 1.0 / p);
+        BlockBuilder blockBuilder = elementType.createBlockBuilder(null, elementCount);
+        for (int i = 0; i < elementCount; i++) {
+            valueAccessor.writeValue(elementType, blockBuilder, valueAccessor.getValue(elementType, block, i) / pNorm);
+        }
+        return blockBuilder.build();
+    }
+
+    /** Returns true when at least one position of {@code block} is null. */
+    private static boolean hasNullElement(Block block)
+    {
+        int elementCount = block.getPositionCount();
+        for (int i = 0; i < elementCount; i++) {
+            if (block.isNull(i)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /** Reads and writes array elements as {@code double}, hiding the element type's native encoding. */
+    private interface ValueAccessor
+    {
+        double getValue(Type elementType, Block block, int position);
+
+        void writeValue(Type elementType, BlockBuilder blockBuilder, double value);
+    }
+
+    /** Accessor for element types read and written through {@code getDouble} / {@code writeDouble}. */
+    private static class DoubleValueAccessor
+            implements ValueAccessor
+    {
+        @Override
+        public double getValue(Type elementType, Block block, int position)
+        {
+            return elementType.getDouble(block, position);
+        }
+
+        @Override
+        public void writeValue(Type elementType, BlockBuilder blockBuilder, double value)
+        {
+            elementType.writeDouble(blockBuilder, value);
+        }
+    }
+
+    /** Accessor for elements stored as float bit patterns inside a {@code long}. */
+    private static class RealValueAccessor
+            implements ValueAccessor
+    {
+        @Override
+        public double getValue(Type elementType, Block block, int position)
+        {
+            return Float.intBitsToFloat((int) elementType.getLong(block, position));
+        }
+
+        @Override
+        public void writeValue(Type elementType, BlockBuilder blockBuilder, double value)
+        {
+            // Narrowing to float happens here; all arithmetic above is done in double.
+            elementType.writeLong(blockBuilder, Float.floatToIntBits((float) value));
+        }
+    }
+}
diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayNormalizeFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayNormalizeFunction.java
new file mode 100644
index 0000000000000..7153033587363
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayNormalizeFunction.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.operator.scalar;
+
+import com.facebook.presto.common.type.ArrayType;
+import com.google.common.collect.ImmutableList;
+import org.testng.annotations.Test;
+
+import static com.facebook.presto.common.type.DoubleType.DOUBLE;
+import static com.facebook.presto.common.type.RealType.REAL;
+import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING;
+import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
+
+/**
+ * SQL-level tests for {@code array_normalize}; each scenario is checked for both
+ * DOUBLE and REAL arrays where applicable.
+ */
+public class TestArrayNormalizeFunction
+        extends AbstractTestFunctions
+{
+    private static final ArrayType DOUBLE_ARRAY = new ArrayType(DOUBLE);
+    private static final ArrayType REAL_ARRAY = new ArrayType(REAL);
+
+    /** With p = 0 the input array comes back unchanged. */
+    @Test
+    public void test0Norm()
+    {
+        assertFunction("array_normalize(ARRAY[1.0E0, 2.0E0, 3.3E0], 0.0E0)", DOUBLE_ARRAY, ImmutableList.of(1.0, 2.0, 3.3));
+        assertFunction("array_normalize(ARRAY[REAL '1.0', REAL '2.0', REAL '3.3'], REAL '0.0')", REAL_ARRAY, ImmutableList.of(1.0f, 2.0f, 3.3f));
+
+        // Negative elements are passed through as well.
+        assertFunction("array_normalize(ARRAY[-1.0E0, 2.0E0, 3.3E0], 0.0E0)", DOUBLE_ARRAY, ImmutableList.of(-1.0, 2.0, 3.3));
+        assertFunction("array_normalize(ARRAY[REAL '-1.0', REAL '2.0', REAL '3.3'], REAL '0.0')", REAL_ARRAY, ImmutableList.of(-1.0f, 2.0f, 3.3f));
+    }
+
+    /** With p = 1 every element is divided by the sum of absolute values. */
+    @Test
+    public void test1Norm()
+    {
+        assertFunction("array_normalize(ARRAY[1.0E0, 2.0E0, 3.0E0], 1.0E0)", DOUBLE_ARRAY, ImmutableList.of(1.0 / 6.0, 2.0 / 6.0, 3.0 / 6.0));
+        assertFunction("array_normalize(ARRAY[REAL '1.0', REAL '2.0', REAL '3.0'], REAL '1.0')", REAL_ARRAY, ImmutableList.of(1.0f / 6.0f, 2.0f / 6.0f, 3.0f / 6.0f));
+
+        // The norm uses absolute values, while the sign of each element is kept.
+        assertFunction("array_normalize(ARRAY[-1.0E0, 2.0E0, 3.0E0], 1.0E0)", DOUBLE_ARRAY, ImmutableList.of(-1.0 / 6.0, 2.0 / 6.0, 3.0 / 6.0));
+        assertFunction("array_normalize(ARRAY[REAL '-1.0', REAL '2.0', REAL '3.0'], REAL '1.0')", REAL_ARRAY, ImmutableList.of(-1.0f / 6.0f, 2.0f / 6.0f, 3.0f / 6.0f));
+    }
+
+    /** With p = 2 every element is divided by the Euclidean length (5 for a 4-3 pair). */
+    @Test
+    public void test2Norm()
+    {
+        assertFunction("array_normalize(ARRAY[4.0E0, 3.0E0], 2.0E0)", DOUBLE_ARRAY, ImmutableList.of(4.0 / 5.0, 3.0 / 5.0));
+        assertFunction("array_normalize(ARRAY[REAL '4.0', REAL '3.0'], REAL '2.0')", REAL_ARRAY, ImmutableList.of(4.0f / 5.0f, 3.0f / 5.0f));
+
+        // The norm uses absolute values, while the sign of each element is kept.
+        assertFunction("array_normalize(ARRAY[-4.0E0, 3.0E0], 2.0E0)", DOUBLE_ARRAY, ImmutableList.of(-4.0 / 5.0, 3.0 / 5.0));
+        assertFunction("array_normalize(ARRAY[REAL '-4.0', REAL '3.0'], REAL '2.0')", REAL_ARRAY, ImmutableList.of(-4.0f / 5.0f, 3.0f / 5.0f));
+    }
+
+    /** A null array, a null p, or a null element all produce a NULL result. */
+    @Test
+    public void testNulls()
+    {
+        assertFunction("array_normalize(null, 2.0E0)", DOUBLE_ARRAY, null);
+        assertFunction("array_normalize(ARRAY[4.0E0, 3.0E0], null)", DOUBLE_ARRAY, null);
+        assertFunction("array_normalize(ARRAY[4.0E0, null], 2.0E0)", DOUBLE_ARRAY, null);
+        assertFunction("array_normalize(ARRAY[REAL '4.0', REAL '3.0'], null)", REAL_ARRAY, null);
+    }
+
+    /** An all-zero array has a zero norm and is returned as is. */
+    @Test
+    public void testArrayOfZeros()
+    {
+        assertFunction("array_normalize(ARRAY[0.0E0, 0.0E0], 1.0E0)", DOUBLE_ARRAY, ImmutableList.of(0.0, 0.0));
+        assertFunction("array_normalize(ARRAY[REAL '0.0', REAL '0.0'], REAL '1.0')", REAL_ARRAY, ImmutableList.of(0.0f, 0.0f));
+
+        assertFunction("array_normalize(ARRAY[0.0E0, 0.0E0], 2.0E0)", DOUBLE_ARRAY, ImmutableList.of(0.0, 0.0));
+        assertFunction("array_normalize(ARRAY[REAL '0.0', REAL '0.0'], REAL '2.0')", REAL_ARRAY, ImmutableList.of(0.0f, 0.0f));
+    }
+
+    /** Element types other than REAL and DOUBLE are rejected. */
+    @Test
+    public void testUnsupportedType()
+    {
+        assertInvalidFunction(
+                "array_normalize(ARRAY[1, 2, 3], 1)",
+                FUNCTION_IMPLEMENTATION_MISSING,
+                "Unsupported array element type for array_normalize function: integer");
+        assertInvalidFunction(
+                "array_normalize(ARRAY['a', 'b', 'c'], 'd')",
+                FUNCTION_IMPLEMENTATION_MISSING,
+                "Unsupported type parameters.*");
+    }
+
+    /** A negative p is an invalid argument for both DOUBLE and REAL. */
+    @Test
+    public void testNegativeP()
+    {
+        assertInvalidFunction(
+                "array_normalize(ARRAY[1.0E0, 2.0E0, 3.3E0], -1.0E0)",
+                INVALID_FUNCTION_ARGUMENT,
+                "array_normalize only supports non-negative p:.*");
+        assertInvalidFunction(
+                "array_normalize(ARRAY[REAL '1.0', REAL '2.0', REAL '3.3'], REAL '-1.0')",
+                INVALID_FUNCTION_ARGUMENT,
+                "array_normalize only supports non-negative p:.*");
+    }
+}