From 17516247415d183c2be1473caf70cd5f5598563b Mon Sep 17 00:00:00 2001
From: Zhan Yuan <yuanzhanhku@gmail.com>
Date: Tue, 25 May 2021 13:44:43 -0700
Subject: [PATCH] Implement ARRAY_NORMALIZE function

Normalizes array ``x`` by dividing each element by the p-norm of the array.
It is equivalent to
  `TRANSFORM(x, v -> v / REDUCE(x, 0, (a, v) -> a + POW(ABS(v), p), a -> POW(a, 1 / p)))`.
But the reduce part is only executed once to improve performance.
Returns null if the array is null or there are null array elements.
---
 .../src/main/sphinx/functions/array.rst       |   7 +
 ...uiltInTypeAndFunctionNamespaceManager.java |   2 +
 .../scalar/ArrayNormalizeFunction.java        | 137 ++++++++++++++++++
 .../scalar/TestArrayNormalizeFunction.java    | 106 ++++++++++++++
 4 files changed, 252 insertions(+)
 create mode 100644 presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNormalizeFunction.java
 create mode 100644 presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayNormalizeFunction.java

diff --git a/presto-docs/src/main/sphinx/functions/array.rst b/presto-docs/src/main/sphinx/functions/array.rst
index 767d2acae0e58..5edef93919132 100644
--- a/presto-docs/src/main/sphinx/functions/array.rst
+++ b/presto-docs/src/main/sphinx/functions/array.rst
@@ -69,6 +69,13 @@ Array Functions
 
     Returns the minimum value of input array.
 
+.. function:: array_normalize(x, p) -> array
+
+    Normalizes array ``x`` by dividing each element by the p-norm of the array.
+    It is equivalent to ``TRANSFORM(x, v -> v / REDUCE(x, 0, (a, v) -> a + POW(ABS(v), p), a -> POW(a, 1 / p)))``,
+    but the reduce part is only executed once.
+    Returns null if the array is null or there are null array elements.
+
 .. function:: array_position(x, element) -> bigint
 
     Returns the position of the first occurrence of the ``element`` in array ``x`` (or 0 if not found).
diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java
index 27ed3317c19b2..e5e7682380bd7 100644
--- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java
+++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java
@@ -106,6 +106,7 @@
 import com.facebook.presto.operator.scalar.ArrayMinFunction;
 import com.facebook.presto.operator.scalar.ArrayNgramsFunction;
 import com.facebook.presto.operator.scalar.ArrayNoneMatchFunction;
+import com.facebook.presto.operator.scalar.ArrayNormalizeFunction;
 import com.facebook.presto.operator.scalar.ArrayNotEqualOperator;
 import com.facebook.presto.operator.scalar.ArrayPositionFunction;
 import com.facebook.presto.operator.scalar.ArrayRemoveFunction;
@@ -751,6 +752,7 @@ private List<? extends SqlFunction> getBuildInFunctions(FeaturesConfig featuresC
                 .scalar(ArrayAllMatchFunction.class)
                 .scalar(ArrayAnyMatchFunction.class)
                 .scalar(ArrayNoneMatchFunction.class)
+                .scalar(ArrayNormalizeFunction.class)
                 .scalar(MapDistinctFromOperator.class)
                 .scalar(MapEqualOperator.class)
                 .scalar(MapEntriesFunction.class)
diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNormalizeFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNormalizeFunction.java
new file mode 100644
index 0000000000000..b4975a6adb164
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNormalizeFunction.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.operator.scalar;
+
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.block.BlockBuilder;
+import com.facebook.presto.common.type.DoubleType;
+import com.facebook.presto.common.type.RealType;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.spi.PrestoException;
+import com.facebook.presto.spi.function.Description;
+import com.facebook.presto.spi.function.ScalarFunction;
+import com.facebook.presto.spi.function.SqlNullable;
+import com.facebook.presto.spi.function.SqlType;
+import com.facebook.presto.spi.function.TypeParameter;
+import com.facebook.presto.spi.function.TypeParameterSpecialization;
+
+import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING;
+import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
+import static com.facebook.presto.util.Failures.checkCondition;
+import static java.lang.String.format;
+
+/**
+ * Implements {@code array_normalize(x, p)}: divides every element of {@code x} by the p-norm of {@code x},
+ * that is by {@code (sum(|v|^p))^(1/p)}. The norm is computed once per array.
+ *
+ * Only REAL and DOUBLE element types are accepted. The result is null when the array contains a null
+ * element. For {@code p = 0}, or when the computed norm is zero, the input array is returned unchanged.
+ */
+@ScalarFunction("array_normalize")
+@Description("Normalizes an array by dividing each element by the p-norm of the array.")
+public final class ArrayNormalizeFunction
+{
+    private static final ValueAccessor DOUBLE_VALUE_ACCESSOR = new DoubleValueAccessor();
+    private static final ValueAccessor REAL_VALUE_ACCESSOR = new RealValueAccessor();
+
+    private ArrayNormalizeFunction() {}
+
+    // Specialization for element types whose native container is double.
+    @TypeParameter("T")
+    @TypeParameterSpecialization(name = "T", nativeContainerType = double.class)
+    @SqlType("array(T)")
+    @SqlNullable
+    public static Block normalizeDoubleArray(
+            @TypeParameter("T") Type elementType,
+            @SqlType("array(T)") Block block,
+            @SqlType("T") double p)
+    {
+        return normalizeArray(elementType, block, p, DOUBLE_VALUE_ACCESSOR);
+    }
+
+    // Specialization for element types whose native container is long. p is decoded from float bits
+    // here, which is only meaningful for REAL; any other long-backed element type is rejected by
+    // normalizeArray before p is used.
+    @TypeParameter("T")
+    @TypeParameterSpecialization(name = "T", nativeContainerType = long.class)
+    @SqlType("array(T)")
+    @SqlNullable
+    public static Block normalizeRealArray(
+            @TypeParameter("T") Type elementType,
+            @SqlType("array(T)") Block block,
+            @SqlType("T") long p)
+    {
+        return normalizeArray(elementType, block, Float.intBitsToFloat((int) p), REAL_VALUE_ACCESSOR);
+    }
+
+    /**
+     * Shared implementation behind both specializations.
+     *
+     * @return a new block holding the scaled elements; the input block itself when no scaling is
+     * needed ({@code p == 0} or a zero norm); or null when any element is null
+     * @throws PrestoException FUNCTION_IMPLEMENTATION_MISSING when the element type is neither REAL nor
+     * DOUBLE; INVALID_FUNCTION_ARGUMENT when {@code p >= 0} does not hold (negative or NaN)
+     */
+    private static Block normalizeArray(Type elementType, Block block, double p, ValueAccessor valueAccessor)
+    {
+        if (!(elementType instanceof RealType) && !(elementType instanceof DoubleType)) {
+            throw new PrestoException(
+                    FUNCTION_IMPLEMENTATION_MISSING,
+                    format("Unsupported array element type for array_normalize function: %s", elementType.getDisplayName()));
+        }
+        checkCondition(p >= 0, INVALID_FUNCTION_ARGUMENT, "array_normalize only supports non-negative p: %s", p);
+
+        // A null element makes the result null for every p, so check before the p == 0 shortcut.
+        if (hasNullElement(block)) {
+            return null;
+        }
+        if (p == 0) {
+            return block;
+        }
+
+        int elementCount = block.getPositionCount();
+        double pNorm = 0;
+        for (int i = 0; i < elementCount; i++) {
+            pNorm += Math.pow(Math.abs(valueAccessor.getValue(elementType, block, i)), p);
+        }
+        if (pNorm == 0) {
+            // Empty or all-zero array: nothing to scale, and this avoids dividing by zero.
+            return block;
+        }
+        pNorm = Math.pow(pNorm, 1.0 / p);
+        BlockBuilder blockBuilder = elementType.createBlockBuilder(null, elementCount);
+        for (int i = 0; i < elementCount; i++) {
+            valueAccessor.writeValue(elementType, blockBuilder, valueAccessor.getValue(elementType, block, i) / pNorm);
+        }
+        return blockBuilder.build();
+    }
+
+    private static boolean hasNullElement(Block block)
+    {
+        for (int i = 0; i < block.getPositionCount(); i++) {
+            if (block.isNull(i)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Reads and writes array elements as {@code double}, hiding how the element type stores them.
+     */
+    private interface ValueAccessor
+    {
+        double getValue(Type elementType, Block block, int position);
+
+        void writeValue(Type elementType, BlockBuilder blockBuilder, double value);
+    }
+
+    private static class DoubleValueAccessor
+            implements ValueAccessor
+    {
+        @Override
+        public double getValue(Type elementType, Block block, int position)
+        {
+            return elementType.getDouble(block, position);
+        }
+
+        @Override
+        public void writeValue(Type elementType, BlockBuilder blockBuilder, double value)
+        {
+            elementType.writeDouble(blockBuilder, value);
+        }
+    }
+
+    // Elements are read and written as a long carrying float bits, converted on the way in and out.
+    private static class RealValueAccessor
+            implements ValueAccessor
+    {
+        @Override
+        public double getValue(Type elementType, Block block, int position)
+        {
+            return Float.intBitsToFloat((int) elementType.getLong(block, position));
+        }
+
+        @Override
+        public void writeValue(Type elementType, BlockBuilder blockBuilder, double value)
+        {
+            elementType.writeLong(blockBuilder, Float.floatToIntBits((float) value));
+        }
+    }
+}
diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayNormalizeFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayNormalizeFunction.java
new file mode 100644
index 0000000000000..7153033587363
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayNormalizeFunction.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.operator.scalar;
+
+import com.facebook.presto.common.type.ArrayType;
+import com.google.common.collect.ImmutableList;
+import org.testng.annotations.Test;
+
+import static com.facebook.presto.common.type.DoubleType.DOUBLE;
+import static com.facebook.presto.common.type.RealType.REAL;
+import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING;
+import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
+
+/**
+ * Tests array_normalize on DOUBLE and REAL arrays: p = 0, 1, 2, nulls, all-zero arrays, unsupported types and negative p.
+ */
+public class TestArrayNormalizeFunction
+        extends AbstractTestFunctions
+{
+    @Test
+    public void test0Norm()
+    {
+        assertFunction("array_normalize(ARRAY[1.0E0, 2.0E0, 3.3E0], 0.0E0)", new ArrayType(DOUBLE), ImmutableList.of(1.0, 2.0, 3.3));
+        assertFunction("array_normalize(ARRAY[REAL '1.0', REAL '2.0', REAL '3.3'], REAL '0.0')", new ArrayType(REAL), ImmutableList.of(1.0f, 2.0f, 3.3f));
+
+        // Test with negative element.
+        assertFunction("array_normalize(ARRAY[-1.0E0, 2.0E0, 3.3E0], 0.0E0)", new ArrayType(DOUBLE), ImmutableList.of(-1.0, 2.0, 3.3));
+        assertFunction("array_normalize(ARRAY[REAL '-1.0', REAL '2.0', REAL '3.3'], REAL '0.0')", new ArrayType(REAL), ImmutableList.of(-1.0f, 2.0f, 3.3f));
+    }
+
+    @Test
+    public void test1Norm()
+    {
+        assertFunction("array_normalize(ARRAY[1.0E0, 2.0E0, 3.0E0], 1.0E0)", new ArrayType(DOUBLE), ImmutableList.of(1.0 / 6.0, 2.0 / 6.0, 3.0 / 6.0));
+        assertFunction("array_normalize(ARRAY[REAL '1.0', REAL '2.0', REAL '3.0'], REAL '1.0')", new ArrayType(REAL), ImmutableList.of(1.0f / 6.0f, 2.0f / 6.0f, 3.0f / 6.0f));
+
+        // Test with negative element.
+        assertFunction("array_normalize(ARRAY[-1.0E0, 2.0E0, 3.0E0], 1.0E0)", new ArrayType(DOUBLE), ImmutableList.of(-1.0 / 6.0, 2.0 / 6.0, 3.0 / 6.0));
+        assertFunction("array_normalize(ARRAY[REAL '-1.0', REAL '2.0', REAL '3.0'], REAL '1.0')", new ArrayType(REAL), ImmutableList.of(-1.0f / 6.0f, 2.0f / 6.0f, 3.0f / 6.0f));
+    }
+
+    @Test
+    public void test2Norm()
+    {
+        assertFunction("array_normalize(ARRAY[4.0E0, 3.0E0], 2.0E0)", new ArrayType(DOUBLE), ImmutableList.of(4.0 / 5.0, 3.0 / 5.0));
+        assertFunction("array_normalize(ARRAY[REAL '4.0', REAL '3.0'], REAL '2.0')", new ArrayType(REAL), ImmutableList.of(4.0f / 5.0f, 3.0f / 5.0f));
+
+        // Test with negative element.
+        assertFunction("array_normalize(ARRAY[-4.0E0, 3.0E0], 2.0E0)", new ArrayType(DOUBLE), ImmutableList.of(-4.0 / 5.0, 3.0 / 5.0));
+        assertFunction("array_normalize(ARRAY[REAL '-4.0', REAL '3.0'], REAL '2.0')", new ArrayType(REAL), ImmutableList.of(-4.0f / 5.0f, 3.0f / 5.0f));
+    }
+
+    @Test
+    public void testNulls()
+    {
+        assertFunction("array_normalize(null, 2.0E0)", new ArrayType(DOUBLE), null);
+        assertFunction("array_normalize(ARRAY[4.0E0, 3.0E0], null)", new ArrayType(DOUBLE), null);
+        assertFunction("array_normalize(ARRAY[4.0E0, null], 2.0E0)", new ArrayType(DOUBLE), null);
+        assertFunction("array_normalize(ARRAY[REAL '4.0', REAL '3.0'], null)", new ArrayType(REAL), null);
+        assertFunction("array_normalize(ARRAY[REAL '4.0', null], REAL '2.0')", new ArrayType(REAL), null);
+    }
+
+    @Test
+    public void testArrayOfZeros()
+    {
+        assertFunction("array_normalize(ARRAY[0.0E0, 0.0E0], 1.0E0)", new ArrayType(DOUBLE), ImmutableList.of(0.0, 0.0));
+        assertFunction("array_normalize(ARRAY[REAL '0.0', REAL '0.0'], REAL '1.0')", new ArrayType(REAL), ImmutableList.of(0.0f, 0.0f));
+
+        assertFunction("array_normalize(ARRAY[0.0E0, 0.0E0], 2.0E0)", new ArrayType(DOUBLE), ImmutableList.of(0.0, 0.0));
+        assertFunction("array_normalize(ARRAY[REAL '0.0', REAL '0.0'], REAL '2.0')", new ArrayType(REAL), ImmutableList.of(0.0f, 0.0f));
+    }
+
+    @Test
+    public void testUnsupportedType()
+    {
+        assertInvalidFunction(
+                "array_normalize(ARRAY[1, 2, 3], 1)",
+                FUNCTION_IMPLEMENTATION_MISSING,
+                "Unsupported array element type for array_normalize function: integer");
+        assertInvalidFunction(
+                "array_normalize(ARRAY['a', 'b', 'c'], 'd')",
+                FUNCTION_IMPLEMENTATION_MISSING,
+                "Unsupported type parameters.*");
+    }
+
+    @Test
+    public void testNegativeP()
+    {
+        assertInvalidFunction(
+                "array_normalize(ARRAY[1.0E0, 2.0E0, 3.3E0], -1.0E0)",
+                INVALID_FUNCTION_ARGUMENT,
+                "array_normalize only supports non-negative p:.*");
+        assertInvalidFunction(
+                "array_normalize(ARRAY[REAL '1.0', REAL '2.0', REAL '3.3'], REAL '-1.0')",
+                INVALID_FUNCTION_ARGUMENT,
+                "array_normalize only supports non-negative p:.*");
+    }
+}