Skip to content

Commit

Permalink
Implement ARRAY_NORMALIZE function
Browse files Browse the repository at this point in the history
Normalizes array ``x`` by dividing each element by the p-norm of the array.
It is equivalent to
  `TRANSFORM(x, v -> v / REDUCE(x, 0, (a, v) -> a + POW(ABS(v), p), a -> POW(a, 1 / p)))`.
But the reduce part is only executed once to improve performance.
Returns null if the array is null or there are null array elements.
  • Loading branch information
yuanzhanhku committed Jun 3, 2021
1 parent faf0cfe commit 1751624
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 0 deletions.
7 changes: 7 additions & 0 deletions presto-docs/src/main/sphinx/functions/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ Array Functions

Returns the minimum value of input array.

.. function:: array_normalize(x, p) -> array

Normalizes array ``x`` by dividing each element by the p-norm of the array.
It is equivalent to ``TRANSFORM(x, v -> v / REDUCE(x, 0, (a, v) -> a + POW(ABS(v), p), a -> POW(a, 1 / p)))``,
but the reduce part is only executed once.
Returns null if the array is null or there are null array elements.

.. function:: array_position(x, element) -> bigint

Returns the position of the first occurrence of the ``element`` in array ``x`` (or 0 if not found).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
import com.facebook.presto.operator.scalar.ArrayMinFunction;
import com.facebook.presto.operator.scalar.ArrayNgramsFunction;
import com.facebook.presto.operator.scalar.ArrayNoneMatchFunction;
import com.facebook.presto.operator.scalar.ArrayNormalizeFunction;
import com.facebook.presto.operator.scalar.ArrayNotEqualOperator;
import com.facebook.presto.operator.scalar.ArrayPositionFunction;
import com.facebook.presto.operator.scalar.ArrayRemoveFunction;
Expand Down Expand Up @@ -751,6 +752,7 @@ private List<? extends SqlFunction> getBuildInFunctions(FeaturesConfig featuresC
.scalar(ArrayAllMatchFunction.class)
.scalar(ArrayAnyMatchFunction.class)
.scalar(ArrayNoneMatchFunction.class)
.scalar(ArrayNormalizeFunction.class)
.scalar(MapDistinctFromOperator.class)
.scalar(MapEqualOperator.class)
.scalar(MapEntriesFunction.class)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.scalar;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.common.type.RealType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.SqlNullable;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;
import com.facebook.presto.spi.function.TypeParameterSpecialization;

import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.util.Failures.checkCondition;
import static java.lang.String.format;

/**
 * Implements the {@code array_normalize(x, p)} scalar function, which divides every
 * element of an array by the p-norm of that array, i.e. by
 * {@code (sum(|v|^p))^(1/p)} taken over all elements {@code v}.
 *
 * <p>Behavior of the shared implementation:
 * <ul>
 * <li>Only REAL and DOUBLE element types are accepted; any other element type fails
 * with {@code FUNCTION_IMPLEMENTATION_MISSING}.</li>
 * <li>{@code p} must satisfy {@code p >= 0}; otherwise the call fails with
 * {@code INVALID_FUNCTION_ARGUMENT}.</li>
 * <li>If any array element is null, the result is null, regardless of {@code p}.</li>
 * <li>If {@code p} is 0, or the accumulated sum of powers is 0 (for example an empty or
 * all-zero array), the input array is returned unchanged.</li>
 * </ul>
 */
@ScalarFunction("array_normalize")
@Description("Normalizes an array by dividing each element by the p-norm of the array.")
public final class ArrayNormalizeFunction
{
    private static final ValueAccessor DOUBLE_VALUE_ACCESSOR = new DoubleValueAccessor();
    private static final ValueAccessor REAL_VALUE_ACCESSOR = new RealValueAccessor();

    private ArrayNormalizeFunction() {}

    /**
     * Entry point for element types whose native container is {@code double}.
     *
     * @param elementType the array element type bound to {@code T}
     * @param block the array elements
     * @param p the order of the norm
     * @return the normalized array, the unchanged input, or null (see class documentation)
     */
    @TypeParameter("T")
    @TypeParameterSpecialization(name = "T", nativeContainerType = double.class)
    @SqlType("array(T)")
    @SqlNullable
    public static Block normalizeDoubleArray(
            @TypeParameter("T") Type elementType,
            @SqlType("array(T)") Block block,
            @SqlType("T") double p)
    {
        return normalizeArray(elementType, block, p, DOUBLE_VALUE_ACCESSOR);
    }

    /**
     * Entry point for element types whose native container is {@code long}.
     * For REAL, {@code p} arrives as the float's raw int bits widened to a long and is
     * decoded here before the shared implementation runs.
     *
     * @param elementType the array element type bound to {@code T}
     * @param block the array elements
     * @param p the order of the norm, encoded as float bits
     * @return the normalized array, the unchanged input, or null (see class documentation)
     */
    @TypeParameter("T")
    @TypeParameterSpecialization(name = "T", nativeContainerType = long.class)
    @SqlType("array(T)")
    @SqlNullable
    public static Block normalizeRealArray(
            @TypeParameter("T") Type elementType,
            @SqlType("array(T)") Block block,
            @SqlType("T") long p)
    {
        return normalizeArray(elementType, block, Float.intBitsToFloat((int) p), REAL_VALUE_ACCESSOR);
    }

    /**
     * Shared implementation behind both entry points.
     *
     * @param elementType the array element type; must be REAL or DOUBLE
     * @param block the array elements
     * @param p the order of the norm; must be non-negative
     * @param valueAccessor reads and writes elements of {@code elementType} as doubles
     * @return the normalized array, the unchanged input when {@code p} or the sum of
     *         powers is 0, or null when any element is null
     */
    private static Block normalizeArray(Type elementType, Block block, double p, ValueAccessor valueAccessor)
    {
        // The specializations are selected by native container type only, so other types
        // can still reach this method (e.g. an integer array); reject them explicitly.
        if (!(elementType instanceof RealType) && !(elementType instanceof DoubleType)) {
            throw new PrestoException(
                    FUNCTION_IMPLEMENTATION_MISSING,
                    format("Unsupported array element type for array_normalize function: %s", elementType.getDisplayName()));
        }
        // A NaN p also fails this comparison and is rejected.
        checkCondition(p >= 0, INVALID_FUNCTION_ARGUMENT, "array_normalize only supports non-negative p: %s", p);

        if (p == 0) {
            // The 0-norm case leaves the values untouched, but the documented contract
            // (null result when any element is null) must still hold on this path.
            return containsNull(block) ? null : block;
        }

        int elementCount = block.getPositionCount();
        double pNorm = 0;
        for (int i = 0; i < elementCount; i++) {
            if (block.isNull(i)) {
                return null;
            }
            pNorm += Math.pow(Math.abs(valueAccessor.getValue(elementType, block, i)), p);
        }
        if (pNorm == 0) {
            // Empty or all-zero array: return the input rather than dividing by zero.
            return block;
        }
        pNorm = Math.pow(pNorm, 1.0 / p);
        BlockBuilder blockBuilder = elementType.createBlockBuilder(null, elementCount);
        for (int i = 0; i < elementCount; i++) {
            // No null check needed here; the first pass already returned on any null.
            valueAccessor.writeValue(elementType, blockBuilder, valueAccessor.getValue(elementType, block, i) / pNorm);
        }
        return blockBuilder.build();
    }

    /**
     * Returns true if any position of {@code block} is null.
     */
    private static boolean containsNull(Block block)
    {
        int elementCount = block.getPositionCount();
        for (int i = 0; i < elementCount; i++) {
            if (block.isNull(i)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Reads and writes array elements as {@code double}, hiding how the element type
     * stores its values.
     */
    private interface ValueAccessor
    {
        double getValue(Type elementType, Block block, int position);

        void writeValue(Type elementType, BlockBuilder blockBuilder, double value);
    }

    /**
     * Accessor for element types read and written directly as doubles.
     */
    private static class DoubleValueAccessor
            implements ValueAccessor
    {
        @Override
        public double getValue(Type elementType, Block block, int position)
        {
            return elementType.getDouble(block, position);
        }

        @Override
        public void writeValue(Type elementType, BlockBuilder blockBuilder, double value)
        {
            elementType.writeDouble(blockBuilder, value);
        }
    }

    /**
     * Accessor for element types stored as float bits in a long: values are decoded
     * with {@link Float#intBitsToFloat(int)} on read and narrowed to float and
     * re-encoded with {@link Float#floatToIntBits(float)} on write.
     */
    private static class RealValueAccessor
            implements ValueAccessor
    {
        @Override
        public double getValue(Type elementType, Block block, int position)
        {
            return Float.intBitsToFloat((int) elementType.getLong(block, position));
        }

        @Override
        public void writeValue(Type elementType, BlockBuilder blockBuilder, double value)
        {
            elementType.writeLong(blockBuilder, Float.floatToIntBits((float) value));
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.scalar;

import com.facebook.presto.common.type.ArrayType;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;

import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.RealType.REAL;
import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;

/**
 * Tests for the {@code array_normalize} scalar function over DOUBLE and REAL arrays.
 */
public class TestArrayNormalizeFunction
        extends AbstractTestFunctions
{
    // Expected result types shared by every assertion below.
    private static final ArrayType DOUBLE_ARRAY = new ArrayType(DOUBLE);
    private static final ArrayType REAL_ARRAY = new ArrayType(REAL);

    /** With p = 0 the input array is expected back unchanged. */
    @Test
    public void test0Norm()
    {
        assertFunction(
                "array_normalize(ARRAY[1.0E0, 2.0E0, 3.3E0], 0.0E0)",
                DOUBLE_ARRAY,
                ImmutableList.of(1.0, 2.0, 3.3));
        assertFunction(
                "array_normalize(ARRAY[REAL '1.0', REAL '2.0', REAL '3.3'], REAL '0.0')",
                REAL_ARRAY,
                ImmutableList.of(1.0f, 2.0f, 3.3f));

        // Same check with a negative element in the array.
        assertFunction(
                "array_normalize(ARRAY[-1.0E0, 2.0E0, 3.3E0], 0.0E0)",
                DOUBLE_ARRAY,
                ImmutableList.of(-1.0, 2.0, 3.3));
        assertFunction(
                "array_normalize(ARRAY[REAL '-1.0', REAL '2.0', REAL '3.3'], REAL '0.0')",
                REAL_ARRAY,
                ImmutableList.of(-1.0f, 2.0f, 3.3f));
    }

    /** With p = 1 each element is divided by the sum of absolute values. */
    @Test
    public void test1Norm()
    {
        assertFunction(
                "array_normalize(ARRAY[1.0E0, 2.0E0, 3.0E0], 1.0E0)",
                DOUBLE_ARRAY,
                ImmutableList.of(1.0 / 6.0, 2.0 / 6.0, 3.0 / 6.0));
        assertFunction(
                "array_normalize(ARRAY[REAL '1.0', REAL '2.0', REAL '3.0'], REAL '1.0')",
                REAL_ARRAY,
                ImmutableList.of(1.0f / 6.0f, 2.0f / 6.0f, 3.0f / 6.0f));

        // Same check with a negative element in the array.
        assertFunction(
                "array_normalize(ARRAY[-1.0E0, 2.0E0, 3.0E0], 1.0E0)",
                DOUBLE_ARRAY,
                ImmutableList.of(-1.0 / 6.0, 2.0 / 6.0, 3.0 / 6.0));
        assertFunction(
                "array_normalize(ARRAY[REAL '-1.0', REAL '2.0', REAL '3.0'], REAL '1.0')",
                REAL_ARRAY,
                ImmutableList.of(-1.0f / 6.0f, 2.0f / 6.0f, 3.0f / 6.0f));
    }

    /** With p = 2 each element is divided by the Euclidean length (5 for a 3-4-5 triangle). */
    @Test
    public void test2Norm()
    {
        assertFunction(
                "array_normalize(ARRAY[4.0E0, 3.0E0], 2.0E0)",
                DOUBLE_ARRAY,
                ImmutableList.of(4.0 / 5.0, 3.0 / 5.0));
        assertFunction(
                "array_normalize(ARRAY[REAL '4.0', REAL '3.0'], REAL '2.0')",
                REAL_ARRAY,
                ImmutableList.of(4.0f / 5.0f, 3.0f / 5.0f));

        // Same check with a negative element in the array.
        assertFunction(
                "array_normalize(ARRAY[-4.0E0, 3.0E0], 2.0E0)",
                DOUBLE_ARRAY,
                ImmutableList.of(-4.0 / 5.0, 3.0 / 5.0));
        assertFunction(
                "array_normalize(ARRAY[REAL '-4.0', REAL '3.0'], REAL '2.0')",
                REAL_ARRAY,
                ImmutableList.of(-4.0f / 5.0f, 3.0f / 5.0f));
    }

    /** A null array, a null p, or a null element each yield a null result. */
    @Test
    public void testNulls()
    {
        assertFunction("array_normalize(null, 2.0E0)", DOUBLE_ARRAY, null);
        assertFunction("array_normalize(ARRAY[4.0E0, 3.0E0], null)", DOUBLE_ARRAY, null);
        assertFunction("array_normalize(ARRAY[4.0E0, null], 2.0E0)", DOUBLE_ARRAY, null);
        assertFunction("array_normalize(ARRAY[REAL '4.0', REAL '3.0'], null)", REAL_ARRAY, null);
    }

    /** An all-zero array is expected back as zeros rather than failing on a zero norm. */
    @Test
    public void testArrayOfZeros()
    {
        assertFunction(
                "array_normalize(ARRAY[0.0E0, 0.0E0], 1.0E0)",
                DOUBLE_ARRAY,
                ImmutableList.of(0.0, 0.0));
        assertFunction(
                "array_normalize(ARRAY[REAL '0.0', REAL '0.0'], REAL '1.0')",
                REAL_ARRAY,
                ImmutableList.of(0.0f, 0.0f));

        assertFunction(
                "array_normalize(ARRAY[0.0E0, 0.0E0], 2.0E0)",
                DOUBLE_ARRAY,
                ImmutableList.of(0.0, 0.0));
        assertFunction(
                "array_normalize(ARRAY[REAL '0.0', REAL '0.0'], REAL '2.0')",
                REAL_ARRAY,
                ImmutableList.of(0.0f, 0.0f));
    }

    /** Element types other than REAL and DOUBLE are rejected. */
    @Test
    public void testUnsupportedType()
    {
        // Integer arrays are rejected with the function's own error message.
        assertInvalidFunction(
                "array_normalize(ARRAY[1, 2, 3], 1)",
                FUNCTION_IMPLEMENTATION_MISSING,
                "Unsupported array element type for array_normalize function: integer");
        // Varchar arrays are rejected with a different, generic message.
        assertInvalidFunction(
                "array_normalize(ARRAY['a', 'b', 'c'], 'd')",
                FUNCTION_IMPLEMENTATION_MISSING,
                "Unsupported type parameters.*");
    }

    /** A negative p is rejected as an invalid argument. */
    @Test
    public void testNegativeP()
    {
        assertInvalidFunction(
                "array_normalize(ARRAY[1.0E0, 2.0E0, 3.3E0], -1.0E0)",
                INVALID_FUNCTION_ARGUMENT,
                "array_normalize only supports non-negative p:.*");
        assertInvalidFunction(
                "array_normalize(ARRAY[REAL '1.0', REAL '2.0', REAL '3.3'], REAL '-1.0')",
                INVALID_FUNCTION_ARGUMENT,
                "array_normalize only supports non-negative p:.*");
    }
}

0 comments on commit 1751624

Please sign in to comment.