Skip to content

Commit

Permalink
Add support for array_least_frequent UDF
Browse files Browse the repository at this point in the history
UDF to return the least frequent element of an array
  • Loading branch information
jainavi17 authored and NikhilCollooru committed Oct 8, 2023
1 parent ed0c6e7 commit 9d121df
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.benchmark;

import com.facebook.presto.testing.LocalQueryRunner;

import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner;

/**
 * Stand-alone SQL benchmark for the {@code ARRAY_LEAST_FREQUENT} function.
 *
 * <p>Running {@link #main(String[])} executes the default query against a local
 * query runner and prints the results through a {@code SimpleLineBenchmarkResultWriter}
 * bound to standard output.
 */
public class SqlArrayLeastFrequentBenchmark
        extends AbstractSqlBenchmark
{
    /** Benchmark name used by {@link #main(String[])}. */
    private static final String DEFAULT_NAME = "sql_array_least_frequent";

    /**
     * Query used by {@link #main(String[])}: cross-joins two unnested sequences
     * (10 x 5000 rows) and, for each row, applies ARRAY_LEAST_FREQUENT to the
     * concatenation of two randomly sized integer sequences.
     */
    private static final String DEFAULT_QUERY = "SELECT ARRAY_LEAST_FREQUENT(x) FROM (SELECT (SEQUENCE(1, random(100)) || SEQUENCE(1, random(100))) AS x FROM (SELECT 1) CROSS JOIN UNNEST(SEQUENCE(1, 10)) CROSS JOIN UNNEST(SEQUENCE(1, 5000)) T(x))";

    /**
     * Creates a benchmark for the given query.
     *
     * @param localQueryRunner runner that executes the query
     * @param query SQL text to benchmark
     * @param name name reported for this benchmark
     */
    public SqlArrayLeastFrequentBenchmark(LocalQueryRunner localQueryRunner, String query, String name)
    {
        // NOTE(review): the two 10s are presumably the warm-up and measured
        // iteration counts -- confirm against the AbstractSqlBenchmark constructor.
        super(localQueryRunner, name, 10, 10, query);
    }

    /**
     * Runs the default benchmark and writes the results to standard output.
     *
     * @param args ignored
     */
    public static void main(String[] args)
    {
        SqlArrayLeastFrequentBenchmark benchmark = new SqlArrayLeastFrequentBenchmark(createLocalQueryRunner(), DEFAULT_QUERY, DEFAULT_NAME);
        benchmark.runBenchmark(new SimpleLineBenchmarkResultWriter(System.out));
    }
}
11 changes: 10 additions & 1 deletion presto-docs/src/main/sphinx/functions/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ Array Functions

Concatenates the elements of the given array using the delimiter and an optional string to replace nulls.

.. function:: array_least_frequent(array(T)) -> array(T)

Returns the least frequent element of an array. If there are multiple elements with the same frequency, the function returns the smallest element.

.. function:: array_least_frequent(array(T), n) -> array(T)

Returns the n least frequent elements of an array. The elements are ordered by increasing frequency.
If two elements have the same frequency, the element with the lower value appears before the element with the higher value.

.. function:: array_max(x) -> x

Returns the maximum value of input array.
Expand Down Expand Up @@ -205,7 +214,7 @@ Array Functions
.. function:: combinations(array(T), n) -> array(array(T))

Returns n-element combinations of the input array.
If the input array has no duplicates, ``combinations`` returns n-element subsets.
If the input array has no duplicates, ``combinations`` returns n-element subsets.
Order of subgroup is deterministic but unspecified. Order of elements within
a subgroup are deterministic but unspecified. ``n`` must not be greater than 5,
and the total size of subgroups generated must be smaller than 100000::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,26 @@ public static String arrayHasDuplicatesVarchar()
return "RETURN cardinality(array_duplicates(input)) > 0";
}

/**
 * SQL body of {@code array_least_frequent(array(T))}.
 *
 * <p>The returned SQL evaluates to NULL when the input is NULL or has no non-null
 * elements. Otherwise it counts each non-null element with ARRAY_FREQUENCY, turns
 * every map entry into a {@code ROW(frequency, element)}, sorts those rows in
 * ascending order, keeps only the first one and projects it back to its element.
 * The result is therefore a single-element array holding the least frequent
 * element, with ties resolved in favour of the smallest element.
 *
 * @return the SQL text implementing the function
 */
@SqlInvokedScalarFunction(value = "array_least_frequent", deterministic = true, calledOnNullInput = true)
@Description("Determines the least frequent element in the array. If there are multiple elements, the function returns the smallest element")
@TypeParameter("T")
@SqlParameter(name = "input", type = "array(T)")
@SqlType("array<T>")
public static String array_least_frequent()
{
    // Rows are built as (frequency, element) so that ARRAY_SORT orders by
    // frequency first and by element value second.
    final String sqlBody = "RETURN IF(COALESCE(CARDINALITY(REMOVE_NULLS(input)), 0) = 0, NULL, TRANSFORM(SLICE(ARRAY_SORT(TRANSFORM(MAP_ENTRIES(ARRAY_FREQUENCY(REMOVE_NULLS(input))), x -> ROW(x[2], x[1]))), 1, 1), x -> x[2]))";
    return sqlBody;
}

/**
 * SQL body of {@code array_least_frequent(array(T), n)}.
 *
 * <p>The returned SQL fails with {@code 'n must be greater than or equal to 0'}
 * when {@code n} is negative, and evaluates to NULL when the input is NULL or has
 * no non-null elements. Otherwise it counts each non-null element with
 * ARRAY_FREQUENCY, turns every map entry into a {@code ROW(frequency, element)},
 * sorts those rows in ascending order, keeps the first {@code n} of them and
 * projects each back to its element. The result is ordered by ascending frequency;
 * elements with equal frequency are ordered by ascending element value.
 *
 * @return the SQL text implementing the function
 */
@SqlInvokedScalarFunction(value = "array_least_frequent", deterministic = true, calledOnNullInput = true)
@Description("Determines the n least frequent elements in the array, ordered by ascending frequency and then by ascending element value.")
@TypeParameter("T")
@SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "n", type = "bigint")})
@SqlType("array<T>")
public static String array_n_least_frequent()
{
    // Rows are built as (frequency, element) so that ARRAY_SORT orders by
    // frequency first and by element value second; SLICE then keeps the first n.
    return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), IF(COALESCE(CARDINALITY(REMOVE_NULLS(input)), 0) = 0, NULL, TRANSFORM(SLICE(ARRAY_SORT(TRANSFORM(MAP_ENTRIES(ARRAY_FREQUENCY(REMOVE_NULLS(input))), x -> ROW(x[2], x[1]))), 1, n), x -> x[2])))";
}

@SqlInvokedScalarFunction(value = "array_max_by", deterministic = true, calledOnNullInput = true)
@Description("Get the maximum value of array, by using a specific transformation function")
@TypeParameter("T")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,55 @@ public void testArrayDuplicates()
assertInvalidFunction("array_duplicates(array[(1, null), (null, 2), (null, null)])", StandardErrorCode.NOT_SUPPORTED, "map key cannot be null or contain nulls");
}

/**
 * Exercises the single-argument {@code ARRAY_LEAST_FREQUENT}: plain inputs,
 * frequency ties, empty and NULL inputs, NULL elements, and row-typed elements.
 */
@Test
public void testArrayLeastFrequent()
{
    ArrayType integerArray = new ArrayType(INTEGER);
    ArrayType unknownArray = new ArrayType(UNKNOWN);
    ArrayType varchar1Array = new ArrayType(createVarcharType(1));

    // Plain inputs, including ties on frequency (the smallest element is expected)
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [1, 2, 2, 3, 3, 3])", integerArray, ImmutableList.of(1));
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY ['a', 'b', 'b', 'c', 'c', 'c'])", varchar1Array, ImmutableList.of("a"));
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [1, 1, 2, 2, 3, 3])", integerArray, ImmutableList.of(1));
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'])", new ArrayType(DOUBLE), asList(1.0d));
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY ['abc', 'bc', 'aaa'])", new ArrayType(createVarcharType(3)), ImmutableList.of("aaa"));
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY ['', '', ' '])", varchar1Array, ImmutableList.of(" "));

    // Empty array yields NULL
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [])", unknownArray, null);

    // NULL input and NULL elements: NULLs are ignored, all-NULL input yields NULL
    assertFunction("ARRAY_LEAST_FREQUENT(null)", unknownArray, null);
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [NULL])", unknownArray, null);
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [1, 2, 2, NULL])", integerArray, ImmutableList.of(1));
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [NULL, NULL, NULL])", unknownArray, null);

    // Row-typed elements
    RowType rowType = RowType.from(ImmutableList.of(RowType.field(INTEGER), RowType.field(INTEGER)));
    ArrayType rowArray = new ArrayType(rowType);
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [ROW(1, 2), ROW(2, 3), ROW(2, 3)])", rowArray, ImmutableList.of(ImmutableList.of(1, 2)));
}

/**
 * Exercises the two-argument {@code ARRAY_LEAST_FREQUENT(array, n)}: plain inputs,
 * empty and NULL inputs, {@code n = 0}, negative {@code n}, {@code n} larger than
 * the number of distinct elements, and row-typed elements.
 */
@Test
public void testArrayNLeastFrequent()
{
    ArrayType integerArray = new ArrayType(INTEGER);
    ArrayType unknownArray = new ArrayType(UNKNOWN);
    ArrayType varchar1Array = new ArrayType(createVarcharType(1));

    // Plain inputs
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [1, 2, 2, 3, 3, 3], 2)", integerArray, ImmutableList.of(1, 2));
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY ['a', 'b', 'b', 'c', 'c', 'c'], 3)", varchar1Array, ImmutableList.of("a", "b", "c"));
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [1, 1, 2, 2, 3, 3], 1)", integerArray, ImmutableList.of(1));
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'], 2)", new ArrayType(DOUBLE), asList(1.0d, 2.0d));
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY ['abc', 'bc', 'aaa'], 3)", new ArrayType(createVarcharType(3)), ImmutableList.of("aaa", "abc", "bc"));
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY ['', '', ' '], 1)", varchar1Array, ImmutableList.of(" "));

    // Empty array yields NULL
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [], 2)", unknownArray, null);

    // NULL input and all-NULL arrays yield NULL, even when n is 0
    assertFunction("ARRAY_LEAST_FREQUENT(null, 3)", unknownArray, null);
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [NULL], 0)", unknownArray, null);
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [NULL, NULL, NULL], 1)", unknownArray, null);

    // n = 0 with at least one non-null element yields an empty array
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [1, 2, 2, NULL], 0)", integerArray, emptyList());

    // n < 0 is rejected
    assertInvalidFunction("ARRAY_LEAST_FREQUENT(ARRAY ['a', 'b', 'b', 'c', 'c', 'c'], -1)", StandardErrorCode.GENERIC_USER_ERROR, "n must be greater than or equal to 0");

    // n greater than the number of distinct non-null elements returns all of them
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [1, 2, 2, 3, 3, 3, -1], 5)", integerArray, ImmutableList.of(-1, 1, 2, 3));

    // Row-typed elements
    RowType rowType = RowType.from(ImmutableList.of(RowType.field(INTEGER), RowType.field(INTEGER)));
    ArrayType rowArray = new ArrayType(rowType);
    assertFunction("ARRAY_LEAST_FREQUENT(ARRAY [ROW(1, 2), ROW(2, 3), ROW(2, 3)], 2)", rowArray, ImmutableList.of(ImmutableList.of(1, 2), ImmutableList.of(2, 3)));
}

@Test
public void testArrayMaxBy()
{
Expand Down

0 comments on commit 9d121df

Please sign in to comment.