Skip to content

Commit

Permalink
Remove subfield pruning for lambda for parametric functions
Browse files Browse the repository at this point in the history
Remove subfield pruning for lambda for parametric functions
  • Loading branch information
rmarduga authored and pranjalssh committed Oct 4, 2023
1 parent 7a6f376 commit e4d6044
Show file tree
Hide file tree
Showing 32 changed files with 28 additions and 775 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -146,55 +146,19 @@ public void testPushDownSubfieldsFromLambdas()

private void testPushDownSubfieldsFromLambdas(String tableName)
{
// functions that are not outputting all subfields
assertQuery("SELECT ALL_MATCH(array_of_rows, x -> x.itemdata > 0) FROM " + tableName);
assertQuery("SELECT ALL_MATCH(array_of_rows, x -> true) FROM " + tableName);
assertQuery("SELECT ALL_MATCH(array_of_rows, x -> x IS NULL) FROM " + tableName);
assertQuery("SELECT ANY_MATCH(array_of_rows, x -> x.itemdata > 0) FROM " + tableName);
assertQuery("SELECT NONE_MATCH(array_of_rows, x -> x.itemdata > 0) FROM " + tableName);
assertQuery("SELECT TRANSFORM(array_of_rows, x -> x.itemdata > 0) FROM " + tableName);
assertQuery("SELECT CARDINALITY(array_of_rows) FROM " + tableName);
assertQuery("SELECT CARDINALITY(FLATTEN(array_of_array_of_rows)) FROM " + tableName);
assertQuery("SELECT CARDINALITY(FILTER(array_of_rows, x -> POSITION('T' IN x.comment) > 0)) FROM " + tableName);
assertQuery("SELECT TRANSFORM(row_with_array_of_rows.array_of_rows, x -> CAST(ROW(x.comment, x.comment) AS ROW(d1 VARCHAR, d2 VARCHAR))) FROM " + tableName);
assertQuery("SELECT TRANSFORM_VALUES(map_varchar_key_row_value, (k,v) -> v.orderkey) FROM " + tableName);
assertQuery("SELECT ZIP_WITH(array_of_rows, row_with_array_of_rows.array_of_rows, (x, y) -> CAST(ROW(x.itemdata, y.comment) AS ROW(d1 BIGINT, d2 VARCHAR))) FROM " + tableName);
assertQuery("SELECT MAP_ZIP_WITH(map_varchar_key_row_value, row_with_map_varchar_key_row_value.map_varchar_key_row_value, (k, v1, v2) -> v1.orderkey + v2.orderkey) FROM " + tableName);

// functions that outputing all subfields and accept functional parameter
assertQuery("SELECT FILTER(array_of_rows, x -> POSITION('T' IN x.comment) > 0) FROM " + tableName);
assertQuery("SELECT ARRAY_SORT(array_of_rows, (l, r) -> IF(l.itemdata < r.itemdata, 1, IF(l.itemdata = r.itemdata, 0, -1))) FROM " + tableName);
assertQuery("SELECT ANY_MATCH(SLICE(ARRAY_SORT(array_of_rows, (l, r) -> IF(l.itemdata < r.itemdata, 1, IF(l.itemdata = r.itemdata, 0, -1))), 1, 3), x -> x.itemdata > 0) FROM " + tableName);
assertQuery("SELECT TRANSFORM(ARRAY_SORT(array_of_non_null_rows), x -> x.itemdata) FROM " + tableName);
assertQuery("SELECT TRANSFORM(COMBINATIONS(array_of_non_null_rows, 3), x -> x[1].itemdata) FROM " + tableName);
assertQuery("SELECT TRANSFORM(FLATTEN(array_of_array_of_rows), x -> x.itemdata) FROM " + tableName);
assertQuery("SELECT TRANSFORM(REVERSE(array_of_rows), x -> x.itemdata) FROM " + tableName);
assertQuery("SELECT ARRAY_SORT(TRANSFORM(SHUFFLE(array_of_non_null_rows), x -> x.itemdata)) FROM " + tableName);
assertQuery("SELECT TRANSFORM(SLICE(array_of_rows, 1, 5), x -> x.itemdata) FROM " + tableName);
assertQuery("SELECT TRANSFORM(TRIM_ARRAY(array_of_rows, 2), x -> x.itemdata) FROM " + tableName);
assertQuery("SELECT TRANSFORM(CONCAT(array_of_rows, row_with_array_of_rows.array_of_rows), x -> x.itemdata) FROM " + tableName);
assertQuery("SELECT TRANSFORM(array_of_rows || row_with_array_of_rows.array_of_rows, x -> x.itemdata) FROM " + tableName);
assertQuery("SELECT TRANSFORM_VALUES(MAP_CONCAT(map_varchar_key_row_value, row_with_map_varchar_key_row_value.map_varchar_key_row_value), (k,v) -> v.orderkey) FROM " + tableName);
assertQuery("SELECT TRANSFORM_VALUES(MAP_REMOVE_NULL_VALUES(row_with_map_varchar_key_row_value.map_varchar_key_row_value), (k,v) -> v.orderkey) FROM " + tableName);
assertQuery("SELECT TRANSFORM_VALUES(MAP_SUBSET(row_with_map_varchar_key_row_value.map_varchar_key_row_value, ARRAY['orderdata_ex']), (k,v) -> v.orderkey) FROM " + tableName);
assertQuery("SELECT ANY_MATCH(MAP_VALUES(map_varchar_key_row_value), x -> x.orderkey % 2 = 0) FROM " + tableName);
assertQuery("SELECT ANY_MATCH(MAP_TOP_N_VALUES(map_varchar_key_row_value, 10, (x, y) -> IF(x.orderkey < y.orderkey, -1, IF(x.orderkey = y.orderkey, 0, 1))), x -> x.orderkey % 2 = 0) FROM " + tableName);
assertQuery("SELECT ANY_MATCH(MAP_VALUES(map_varchar_key_row_value), x -> x.orderkey % 2 = 0) FROM " + tableName);

// Simple test of different column type of the array argument
assertQuery("SELECT ALL_MATCH(array_of_rows, x -> x.itemdata > 0) FROM " + tableName);
assertQuery("SELECT ALL_MATCH(row_with_array_of_rows.array_of_rows, x -> x.itemdata > 0) FROM " + tableName);
assertQuery("SELECT ALL_MATCH(array_of_array_of_rows[1], x -> x.itemdata > 0) FROM " + tableName);
assertQuery("SELECT ALL_MATCH(map_varchar_key_array_of_row_value['orderdata'], x -> x.orderkey > 0) FROM " + tableName);

// element_at
assertQuery("SELECT ALL_MATCH(ELEMENT_AT(map_varchar_key_array_of_row_value, 'orderdata'), x -> x.orderkey > 0) FROM " + tableName);
assertQuery("SELECT ALL_MATCH(ELEMENT_AT(array_of_array_of_rows, 1), x -> x.itemdata > 0) FROM " + tableName);

// Queries that reference variables in lambdas
assertQuery("SELECT ALL_MATCH(SLICE(array_of_rows, 1, row_with_array_of_rows.array_of_rows[1].itemdata), x -> x.itemdata > 0) FROM " + tableName);
assertQuery("SELECT ALL_MATCH(array_of_rows, x -> x.itemdata > row_with_array_of_rows.array_of_rows[1].itemdata) FROM " + tableName);

// Queries that lack full support
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,18 @@
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarFunctionDescriptor;
import com.facebook.presto.spi.function.ScalarFunctionLambdaArgumentDescriptor;
import com.facebook.presto.spi.function.ScalarFunctionLambdaDescriptor;
import com.facebook.presto.spi.function.SqlNullable;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.StaticMethodPointer;
import com.facebook.presto.spi.function.TypeParameter;
import com.facebook.presto.spi.function.TypeParameterSpecialization;
import io.airlift.slice.Slice;

import static java.lang.Boolean.FALSE;

@Description("Returns true if all elements of the array match the given predicate")
@ScalarFunction(value = "all_match", descriptor = @ScalarFunctionDescriptor(
outputToInputTransformationFunction = {@StaticMethodPointer(clazz = ComplexTypeFunctionDescriptor.class, method = "clearRequiredSubfields")},
lambdaDescriptors = {
@ScalarFunctionLambdaDescriptor(
callArgumentIndex = 1,
lambdaArgumentDescriptors = {
@ScalarFunctionLambdaArgumentDescriptor(
lambdaArgumentIndex = 0,
callArgumentIndex = 0)})}))
@ScalarFunction(value = "all_match")
public final class ArrayAllMatchFunction
{
private ArrayAllMatchFunction() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,18 @@
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarFunctionDescriptor;
import com.facebook.presto.spi.function.ScalarFunctionLambdaArgumentDescriptor;
import com.facebook.presto.spi.function.ScalarFunctionLambdaDescriptor;
import com.facebook.presto.spi.function.SqlNullable;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.StaticMethodPointer;
import com.facebook.presto.spi.function.TypeParameter;
import com.facebook.presto.spi.function.TypeParameterSpecialization;
import io.airlift.slice.Slice;

import static java.lang.Boolean.TRUE;

@Description("Returns true if the array contains one or more elements that match the given predicate")
@ScalarFunction(value = "any_match", descriptor = @ScalarFunctionDescriptor(
outputToInputTransformationFunction = {@StaticMethodPointer(clazz = ComplexTypeFunctionDescriptor.class, method = "clearRequiredSubfields")},
lambdaDescriptors = {
@ScalarFunctionLambdaDescriptor(
callArgumentIndex = 1,
lambdaArgumentDescriptors = {
@ScalarFunctionLambdaArgumentDescriptor(
lambdaArgumentIndex = 0,
callArgumentIndex = 0)})}))
@ScalarFunction(value = "any_match")
public final class ArrayAnyMatchFunction
{
private ArrayAnyMatchFunction() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,13 @@

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarFunctionDescriptor;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.StaticMethodPointer;
import com.facebook.presto.spi.function.TypeParameter;

@Description("Returns the cardinality (length) of the array")
@ScalarFunction(value = "cardinality", descriptor = @ScalarFunctionDescriptor(
isAccessingInputValues = false,
outputToInputTransformationFunction = {@StaticMethodPointer(clazz = ComplexTypeFunctionDescriptor.class, method = "clearRequiredSubfields")},
lambdaDescriptors = {}))
@ScalarFunction("cardinality")
public final class ArrayCardinalityFunction
{
private ArrayCardinalityFunction() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,9 @@
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarFunctionDescriptor;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.StaticMethodPointer;
import com.facebook.presto.spi.function.TypeParameter;
import com.google.common.annotations.VisibleForTesting;

Expand All @@ -42,10 +39,7 @@
import static java.util.Arrays.setAll;

@Description("Returns n-element combinations from array")
@ScalarFunction(value = "combinations", descriptor = @ScalarFunctionDescriptor(
isAccessingInputValues = false,
outputToInputTransformationFunction = {@StaticMethodPointer(clazz = ComplexTypeFunctionDescriptor.class, method = "removeSecondPathElement")},
lambdaDescriptors = {}))
@ScalarFunction("combinations")
public final class ArrayCombinationsFunction
{
private ArrayCombinationsFunction() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarFunctionDescriptor;
import com.facebook.presto.spi.function.ScalarFunctionLambdaArgumentDescriptor;
import com.facebook.presto.spi.function.ScalarFunctionLambdaDescriptor;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;
import com.facebook.presto.spi.function.TypeParameterSpecialization;
Expand All @@ -29,14 +26,7 @@
import static java.lang.Boolean.TRUE;

@Description("return array containing elements that match the given predicate")
@ScalarFunction(value = "filter", deterministic = false, descriptor = @ScalarFunctionDescriptor(
lambdaDescriptors = {
@ScalarFunctionLambdaDescriptor(
callArgumentIndex = 1,
lambdaArgumentDescriptors = {
@ScalarFunctionLambdaArgumentDescriptor(
lambdaArgumentIndex = 0,
callArgumentIndex = 0)})}))
@ScalarFunction(value = "filter", deterministic = false)
public final class ArrayFilterFunction
{
private ArrayFilterFunction() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,16 @@
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarFunctionDescriptor;
import com.facebook.presto.spi.function.ScalarFunctionLambdaArgumentDescriptor;
import com.facebook.presto.spi.function.ScalarFunctionLambdaDescriptor;
import com.facebook.presto.spi.function.SqlNullable;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.StaticMethodPointer;
import com.facebook.presto.spi.function.TypeParameter;
import com.facebook.presto.spi.function.TypeParameterSpecialization;
import io.airlift.slice.Slice;

@Description("Returns true if all elements of the array don't match the given predicate")
@ScalarFunction(value = "none_match", descriptor = @ScalarFunctionDescriptor(
outputToInputTransformationFunction = {@StaticMethodPointer(clazz = ComplexTypeFunctionDescriptor.class, method = "clearRequiredSubfields")},
lambdaDescriptors = {
@ScalarFunctionLambdaDescriptor(
callArgumentIndex = 1,
lambdaArgumentDescriptors = {
@ScalarFunctionLambdaArgumentDescriptor(
lambdaArgumentIndex = 0,
callArgumentIndex = 0)})}))
@ScalarFunction(value = "none_match")
public final class ArrayNoneMatchFunction
{
private ArrayNoneMatchFunction() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarFunctionDescriptor;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;

@ScalarFunction("reverse")
@Description("Returns an array which has the reversed order of the given array.")
@ScalarFunction(value = "reverse", descriptor = @ScalarFunctionDescriptor(isAccessingInputValues = false))
public final class ArrayReverseFunction
{
@TypeParameter("E")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarFunctionDescriptor;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;

import java.util.concurrent.ThreadLocalRandom;

@ScalarFunction(value = "shuffle", deterministic = false)
@Description("Generates a random permutation of the given array.")
@ScalarFunction(value = "shuffle", deterministic = false, descriptor = @ScalarFunctionDescriptor(isAccessingInputValues = false))
public final class ArrayShuffleFunction
{
private static final int INITIAL_LENGTH = 128;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarFunctionDescriptor;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;

import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.util.Failures.checkCondition;

@ScalarFunction(value = "slice", descriptor = @ScalarFunctionDescriptor(isAccessingInputValues = false))
@ScalarFunction("slice")
@Description("Subsets an array given an offset (1-indexed) and length")
public final class ArraySliceFunction
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarFunctionDescriptor;
import com.facebook.presto.spi.function.ScalarFunctionLambdaArgumentDescriptor;
import com.facebook.presto.spi.function.ScalarFunctionLambdaDescriptor;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;
import com.facebook.presto.spi.function.TypeParameterSpecialization;
Expand All @@ -35,17 +32,7 @@
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.util.Failures.checkCondition;

@ScalarFunction(value = "array_sort", descriptor = @ScalarFunctionDescriptor(
lambdaDescriptors = {
@ScalarFunctionLambdaDescriptor(
callArgumentIndex = 1,
lambdaArgumentDescriptors = {
@ScalarFunctionLambdaArgumentDescriptor(
lambdaArgumentIndex = 0,
callArgumentIndex = 0),
@ScalarFunctionLambdaArgumentDescriptor(
lambdaArgumentIndex = 1,
callArgumentIndex = 0)})}))
@ScalarFunction("array_sort")
@Description("Sorts the given array with a lambda comparator.")
public final class ArraySortComparatorFunction
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,14 @@
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarFunctionDescriptor;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;

import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.util.Failures.checkCondition;
import static java.lang.Math.toIntExact;

@ScalarFunction(value = "trim_array", descriptor = @ScalarFunctionDescriptor(
isAccessingInputValues = false,
lambdaDescriptors = {}))
@ScalarFunction("trim_array")
@Description("Remove elements from the end of array")
public final class ArrayTrimFunction
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,12 @@

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarFunctionDescriptor;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.StaticMethodPointer;
import com.facebook.presto.spi.function.TypeParameter;

@ScalarFunction(value = "cardinality", descriptor = @ScalarFunctionDescriptor(
isAccessingInputValues = false,
outputToInputTransformationFunction = {@StaticMethodPointer(clazz = ComplexTypeFunctionDescriptor.class, method = "clearRequiredSubfields")},
lambdaDescriptors = {}))
@ScalarFunction("cardinality")
@Description("Returns the cardinality (the number of key-value pairs) of the map")
public final class MapCardinalityFunction
{
Expand Down
Loading

0 comments on commit e4d6044

Please sign in to comment.