diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveLogicalPlanner.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveLogicalPlanner.java index 9413dc0018885..bf7f7ef948c0d 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveLogicalPlanner.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveLogicalPlanner.java @@ -159,7 +159,8 @@ protected QueryRunner createQueryRunner() { return HiveQueryRunner.createQueryRunner( ImmutableList.of(ORDERS, LINE_ITEM, CUSTOMER, NATION), - ImmutableMap.of("experimental.pushdown-subfields-enabled", "true"), + ImmutableMap.of("experimental.pushdown-subfields-enabled", "true", + "pushdown-subfields-from-lambda-enabled", "true"), Optional.empty()); } @@ -1209,6 +1210,330 @@ private void assertPushdownSubscripts(String tableName) ImmutableMap.of("z", toSubfields("z[1][2]"))); } + @Test + public void testPushDownSubfieldsFromLambdas() + { + final String tableName = "test_pushdown_subfields_from_array_lambda"; + try { + assertUpdate("CREATE TABLE " + tableName + "(id bigint, " + + "a array(bigint), " + + "b array(array(varchar)), " + + "mi map(int,array(row(a1 bigint, a2 double))), " + + "mv map(varchar,array(row(a1 bigint, a2 double))), " + + "m1 map(int,row(a1 bigint, a2 double)), " + + "m2 map(int,row(a1 bigint, a2 double)), " + + "m3 map(int,row(a1 bigint, a2 double)), " + + "r row(a array(row(a1 bigint, a2 double)), i bigint, d row(d1 bigint, d2 double)), " + + "y array(row(a bigint, b varchar, c double, d row(d1 bigint, d2 double))), " + + "yy array(row(a bigint, b varchar, c double, d row(d1 bigint, d2 double))), " + + "yyy array(row(a bigint, b varchar, c double, d row(d1 bigint, d2 double))), " + + "am array(map(int,row(a1 bigint, a2 double))), " + + "aa array(array(row(a1 bigint, a2 double))), " + + "z array(array(row(p bigint, e row(e1 bigint, e2 varchar)))))"); + + // functions that are not outputting all subfields + // all_match + assertPushdownSubfields("SELECT ALL_MATCH(y, x -> x.a > 0) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a"))); + assertPushdownSubfields("SELECT ALL_MATCH(y, x -> true) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].$"))); + // any_match + assertPushdownSubfields("SELECT ANY_MATCH(y, x -> x.a > 0) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a"))); + // none_match + assertPushdownSubfields("SELECT NONE_MATCH(y, x -> x.d.d1 > 0) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].d.d1"))); + // transform + assertPushdownSubfields("SELECT TRANSFORM(y, x -> x.d.d1) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].d.d1"))); + + // map_zip_with + assertPushdownSubfields("SELECT MAP_ZIP_WITH(m1, m2, (k, v1, v2) -> v1.a1 + v2.a2) FROM " + tableName, tableName, + ImmutableMap.of("m1", toSubfields("m1[*].a1"), "m2", toSubfields("m2[*].a2"))); + + assertPushdownSubfields("SELECT ANY_MATCH(MAP_VALUES(MAP_ZIP_WITH(m1, m2, (k, v1, v2) -> CAST(ROW(v1.a1, v2.a2) AS ROW(a1 BIGINT, a2 DOUBLE)))), x -> x.a1 > 100) FROM " + tableName, tableName, + ImmutableMap.of("m1", toSubfields("m1[*].a1"), "m2", toSubfields("m2[*].a2"))); + + // transform_values + assertPushdownSubfields("SELECT TRANSFORM_VALUES(m1, (k, v) -> v.a1 * 1000) FROM " + tableName, tableName, + ImmutableMap.of("m1", toSubfields("m1[*].a1"))); + + assertPushdownSubfields("SELECT TRANSFORM_VALUES(MAP_FILTER(m1, (k,v) -> v.a2 > 100), (k, v) -> v.a1 * 1000) FROM " + tableName, tableName, + ImmutableMap.of("m1", toSubfields("m1[*].a1", "m1[*].a2"))); + + assertPushdownSubfields("SELECT MAP_FILTER(TRANSFORM_VALUES(m1, (k, v) -> CAST(ROW(v.a1 * 1000, v.a1) AS ROW(a1 BIGINT, a2 DOUBLE))), (k,v) -> v.a2 > 100) FROM " + tableName, tableName, + ImmutableMap.of("m1", toSubfields("m1[*].a1"))); + + // map_values + assertPushdownSubfields("SELECT ANY_MATCH(MAP_VALUES(m1), x -> x.a1 > 100) FROM " + tableName, tableName, + ImmutableMap.of("m1", toSubfields("m1[*].a1"))); + + // cardinality + assertPushdownSubfields("SELECT CARDINALITY(y) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].$"))); + + assertPushdownSubfields("SELECT CARDINALITY(FLATTEN(z)) FROM " + tableName, tableName, + ImmutableMap.of("z", toSubfields("z[*][*].$"))); + + assertPushdownSubfields("SELECT CARDINALITY(FILTER(y, x -> true)) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].$"))); + + assertPushdownSubfields("SELECT CARDINALITY(FILTER(y, x -> position('9' IN x.b) > 0)) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].b"))); + + assertPushdownSubfields("SELECT CARDINALITY(m1) FROM " + tableName, tableName, + ImmutableMap.of("m1", toSubfields("m1[*].$"))); + + assertPushdownSubfields("SELECT CARDINALITY(MAP_FILTER(m1, (k,v) -> v.a2 > 100)) FROM " + tableName, tableName, + ImmutableMap.of("m1", toSubfields("m1[*].a2"))); + + assertPushdownSubfields("SELECT CARDINALITY(MAP_FILTER(m1, (k,v) -> v.a1 > 100)) FROM " + tableName, tableName, + ImmutableMap.of("m1", toSubfields("m1[*].a1"))); + + // transform + assertPushdownSubfields("SELECT TRANSFORM(y, x -> ROW(x.a, x.d.d1)) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a", "y[*].d.d1"))); + + assertPushdownSubfields("SELECT TRANSFORM(r.a, x -> x.a1) FROM " + tableName, tableName, + ImmutableMap.of("r", toSubfields("r.a[*].a1"))); + + // zip_with + assertPushdownSubfields("SELECT ZIP_WITH(y, yy, (x, xx) -> ROW(x.a, xx.d.d1)) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a"), "yy", toSubfields("yy[*].d.d1"))); + + // functions that outputing all subfields and accept functional parameter + + //filter + assertPushdownSubfields("SELECT FILTER(y, x -> x.a > 0) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields())); + + assertPushdownSubfields("SELECT CARDINALITY(FILTER(y, x -> x.a > 0)) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a"))); + + // slice + assertPushdownSubfields("SELECT TRANSFORM(SLICE(y, 1, 5), x -> x.a) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a"))); + + //array_sort + assertPushdownSubfields("SELECT ARRAY_SORT(y, (l, r) -> IF(l.a < r.a, 1, IF(l.a = r.a, 0, -1))) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields())); + + assertPushdownSubfields("SELECT ANY_MATCH(SLICE(ARRAY_SORT(y, (l, r) -> IF(l.a < r.a, 1, IF(l.a = r.a, 0, -1))), 1, 3), x -> x.c > 0) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a", "y[*].c"))); + + assertPushdownSubfields("SELECT TRANSFORM(ARRAY_SORT(y), x -> x.a) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields())); // ARRAY_SORT(y) function accesses all the subfields as there is no lambda function provided. + + // combinations + assertPushdownSubfields("SELECT TRANSFORM(COMBINATIONS(y, 3), x -> x[1].a) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a"))); + + assertPushdownSubfields("SELECT TRANSFORM(FLATTEN(COMBINATIONS(y, 3)), x -> x.a) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a"))); + + assertPushdownSubfields("SELECT TRANSFORM(COMBINATIONS(FLATTEN(z), 3), x -> x[1].p) FROM " + tableName, tableName, + ImmutableMap.of("z", toSubfields("z[*][*].p"))); + + // flatten + assertPushdownSubfields("SELECT TRANSFORM(FLATTEN(z), x -> x.p) FROM " + tableName, tableName, + ImmutableMap.of("z", toSubfields("z[*][*].p"))); + + // reverse + assertPushdownSubfields("SELECT TRANSFORM(REVERSE(y), x -> x.a) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a"))); + + // shuffle + assertPushdownSubfields("SELECT TRANSFORM(SHUFFLE(y), x -> x.a) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a"))); + + // trim_array + assertPushdownSubfields("SELECT TRANSFORM(TRIM_ARRAY(y, 5), x -> x.a) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a"))); + + // concat + assertPushdownSubfields("SELECT TRANSFORM(y || yy, x -> x.d.d1) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].d.d1"), "yy", toSubfields("yy[*].d.d1"))); + + assertPushdownSubfields("SELECT TRANSFORM(CONCAT(y, yy) , x -> x.d.d1) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].d.d1"), "yy", toSubfields("yy[*].d.d1"))); + + assertPushdownSubfields("SELECT TRANSFORM(CONCAT(y, yy, yyy) , x -> x.d.d1) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].d.d1"), "yy", toSubfields("yy[*].d.d1"), "yyy", toSubfields("yyy[*].d.d1"))); + + // map_concat + assertPushdownSubfields("SELECT ANY_MATCH(MAP_VALUES(MAP_CONCAT(m1, m2, m3)), x -> x.a1 > 100) FROM " + tableName, tableName, + ImmutableMap.of("m1", toSubfields("m1[*].a1"), "m2", toSubfields("m2[*].a1"), "m3", toSubfields("m3[*].a1"))); + + // map + assertPushdownSubfields("SELECT ANY_MATCH(MAP_VALUES(MAP(a, y)), x -> x.a > 100) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a"))); + + assertPushdownSubfields("SELECT ANY_MATCH(MAP_VALUES(CAST(MAP() AS MAP(VARCHAR, ROW(a BIGINT)))), x -> x.a > 100) FROM " + tableName, tableName, + ImmutableMap.of()); + + // map_filter + assertPushdownSubfields("SELECT ANY_MATCH(MAP_VALUES(MAP_FILTER(m1, (k,v) -> v.a2 > 100)), x -> x.a1 > 100) FROM " + tableName, tableName, + ImmutableMap.of("m1", toSubfields("m1[*].a1", "m1[*].a2"))); + + assertPushdownSubfields("SELECT MAP_FILTER(m1, (k,v) -> v.a2 > 100) FROM " + tableName, tableName, + ImmutableMap.of("m1", toSubfields())); + + // map_remove_null_values + assertPushdownSubfields("SELECT ANY_MATCH(MAP_VALUES(MAP_REMOVE_NULL_VALUES(m1)), x -> x.a1 > 100) FROM " + tableName, tableName, + ImmutableMap.of("m1", toSubfields("m1[*].a1"))); + + // map_subset + assertPushdownSubfields("SELECT ANY_MATCH(MAP_VALUES(map_subset(m1, ARRAY[1,2,3])), x -> x.a1 > 100) FROM " + tableName, tableName, + ImmutableMap.of("m1", toSubfields("m1[*].a1"))); + + // map_top_n_values + assertPushdownSubfields("SELECT ANY_MATCH(MAP_TOP_N_VALUES(m1, 10, (x, y) -> IF(x.a1 < y.a1, -1, IF(x.a1 = y.a1, 0, 1))), x -> x.a2 > 100) FROM " + tableName, tableName, + ImmutableMap.of("m1", toSubfields("m1[*].a1", "m1[*].a2"))); + + // Simple test of different column type of the array argument + assertPushdownSubfields("SELECT ALL_MATCH(y, x -> x.d.d1 > 0) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].d.d1"))); + + assertPushdownSubfields("SELECT ANY_MATCH(r.a, x -> x.a1 > 0) FROM " + tableName, tableName, + ImmutableMap.of("r", toSubfields("r.a[*].a1"))); + + assertPushdownSubfields("SELECT ANY_MATCH(z[1], x -> x.e.e1 > 0) FROM " + tableName, tableName, + ImmutableMap.of("z", toSubfields("z[1][*].e.e1"))); + + assertPushdownSubfields("SELECT ANY_MATCH(mi[1], x -> x.a1 > 0) FROM " + tableName, tableName, + ImmutableMap.of("mi", toSubfields("mi[1][*].a1"))); + + assertPushdownSubfields("SELECT ANY_MATCH(mv['a'], x -> x.a1 > 0) FROM " + tableName, tableName, + ImmutableMap.of("mv", toSubfields("mv[\"a\"][*].a1"))); + + assertPushdownSubfields("SELECT ANY_MATCH(aa, x -> x[1].a1 > 0) FROM " + tableName, tableName, + ImmutableMap.of("aa", toSubfields("aa[*][1].a1"))); + + assertPushdownSubfields("SELECT ANY_MATCH(am, x -> x[1].a1 > 0) FROM " + tableName, tableName, + ImmutableMap.of("am", toSubfields("am[*][1].a1"))); + + // element_at + assertPushdownSubfields("SELECT ANY_MATCH(ELEMENT_AT(mi, 42), x -> x.a1 > 100) FROM " + tableName, tableName, + ImmutableMap.of("mi", toSubfields("mi[42][*].a1"))); + + assertPushdownSubfields("SELECT ANY_MATCH(ELEMENT_AT(mv, '42'), x -> x.a1 > 100) FROM " + tableName, tableName, + ImmutableMap.of("mv", toSubfields("mv[\"42\"][*].a1"))); + + // Queries that reference variables in different arguments + assertPushdownSubfields("SELECT ANY_MATCH(SLICE(y, 1, mv['42'][1].a1), x -> x.a > 100) FROM " + tableName, tableName, + ImmutableMap.of("mv", toSubfields("mv[\"42\"][1].a1"), "y", toSubfields("y[*].a"))); + + // Special form expressions that can hide the access to the entire struct ('and', 'or' are not interesting) + // equal + assertPushdownSubfields("SELECT ANY_MATCH(y, x -> x = (row(1, '', 1.0, row(1, 1.0)))) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields())); // 'equal' effectively accesses the entire struct of the element of 'y' + + // coalesce + assertPushdownSubfields("SELECT ANY_MATCH(ZIP_WITH(y, yy, (x, xx) -> COALESCE(x,xx)), x -> x.a > 0) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields(), "yy", toSubfields())); // + + assertPushdownSubfields("SELECT TRANSFORM(y, x -> COALESCE(x, row(1, '', 1.0, row(1, 1.0)))) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields())); // 'coalesce' effectively includes the entire subfield to returned values + + assertPushdownSubfields("SELECT ANY_MATCH(COALESCE(y, yy), x -> x.a > 0) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a"), "yy", toSubfields("yy[*].a"))); + + // in + assertPushdownSubfields("SELECT ANY_MATCH(y, x -> x IN (row(1, '', 1.0, row(1, 1.0)))) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields())); // 'in' effectively accesses the entire struct of the element of 'y' + + // row_construction + assertPushdownSubfields("SELECT TRANSFORM(y, x -> ROW(x)) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields())); // entire struct of the element of 'y' was included to the output + + assertPushdownSubfields("SELECT ZIP_WITH(y, yy, (x, xx) -> ROW(x,xx)) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields(), "yy", toSubfields())); // entire struct of the elements of 'y' and 'yy' was included to the output + + // switch + assertPushdownSubfields("SELECT TRANSFORM(y, x -> CASE x WHEN row(1, '', 1.0, row(1, 1.0)) THEN true ELSE false END) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields())); // entire struct of the element of 'y' was accessed + + // if + assertPushdownSubfields("SELECT TRANSFORM(y, x -> IF(x = row(1, '', 1.0, row(1, 1.0)), 1, 0)) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields())); // entire struct of the element of 'y' was accessed + + // is_null + assertPushdownSubfields("SELECT ANY_MATCH(y, x -> x IS NULL) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].$"))); // only checks whether entire struct is null and does not access struct subfields + + // bind + assertPushdownSubfields("SELECT ANY_MATCH(y, x -> x.a > mv['42'][1].a1) FROM " + tableName, tableName, // missing support for extracting subfields from BIND expression + ImmutableMap.of("mv", toSubfields("mv[\"42\"][1].a1"))); // In fact, we are accessing only y.a and mv['a'].[*].a1 + + // Queries that lack full support + + // Special form expression + // cast + assertPushdownSubfields("SELECT ANY_MATCH(TRANSFORM(r.a, x -> cast(x AS row(quantity bigint, price double))), x -> x.quantity > 100) FROM " + tableName, tableName, + ImmutableMap.of("r", toSubfields("r.a"))); // in fact, we are accessing only r.a[*].a1 + + // WHERE clause + // lambda subfield extraction from WHERE clause is not supported when hive.pushdown-filter-enabled=true. The entire field will be included in + // OrcSelectivePageSourceFactory by SubfieldExtractor in presto-hive module. + assertPushdownSubfields("SELECT a FROM " + tableName + " WHERE ALL_MATCH(y, x -> x.a > 0)", tableName, + ImmutableMap.of("y", toSubfields("y[*].a"))); // In fact, we are accessing only y[*].a. + + assertPushdownSubfields("SELECT TRANSFORM(y, x-> x.d.d1) FROM " + tableName + " WHERE ALL_MATCH(y, x -> x.a > 0)", tableName, + ImmutableMap.of("y", toSubfields("y[*].d.d1", "y[*].a"))); // In OrcSelectivePageSourceFactory, + // 'remainingPredicate' will contain 'ALL_MATCH(y, x -> x.a > 0)' and RequiredSubfieldsExtractor will extract the entire 'y' column that will lead to correct results + // though without any optimization + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } + + @Test + public void testMergingSubfields() + { + final String tableName = "test_merging_subfields"; + try { + assertUpdate("CREATE TABLE " + tableName + "(id bigint, " + + "r row(a array(row(a1 bigint, a2 double)), i bigint, d row(d1 bigint, d2 double)), " + + "y array(row(a bigint, b varchar, c double, d row(d1 bigint, d2 double))))"); + + // NoSubfields + assertPushdownSubfields("SELECT CARDINALITY(y) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].$"))); + + assertPushdownSubfields("SELECT CARDINALITY(y), ALL_MATCH(y, x -> x.d.d1 > 0) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].d.d1"))); + + assertPushdownSubfields("SELECT CARDINALITY(y) FROM " + tableName + " WHERE ALL_MATCH(y, x -> x.d.d1 > 0)", tableName, + ImmutableMap.of("y", toSubfields("y[*].d.d1"))); + + assertPushdownSubfields("SELECT CARDINALITY(y), TRANSFORM(y, x -> x.a) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a"))); + + assertPushdownSubfields("SELECT TRANSFORM(y, x -> x.a), CARDINALITY(y) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a"))); + + assertPushdownSubfields("SELECT CARDINALITY(y), ALL_MATCH(y, x -> x.a > 0), TRANSFORM(y, x -> x.d.d1) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a", "y[*].d.d1"))); + + assertPushdownSubfields("SELECT ALL_MATCH(y, x -> x.a > 0) FROM " + tableName + " WHERE CARDINALITY(y) > 42", tableName, + ImmutableMap.of("y", toSubfields("y[*].a"))); + + assertPushdownSubfields("SELECT TRANSFORM(y, x -> x.d), ALL_MATCH(y, x -> x.d.d1 > 0), CARDINALITY(y) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].d"))); + + assertPushdownSubfields("SELECT ALL_MATCH(y, x -> x.a > 0), CARDINALITY(y) FROM " + tableName, tableName, + ImmutableMap.of("y", toSubfields("y[*].a"))); + + // AllSubfields + assertPushdownSubfields("SELECT r.a[1].a1 FROM " + tableName + " WHERE CARDINALITY(r.a) > 42", tableName, + ImmutableMap.of("r", toSubfields("r.a[*].$", "r.a[1].a1"))); // File format reader will handle this situation and extract only r.a[*].a1 + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } + @Test public void testPushdownSubfields() { diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestLambdaSubfieldPruning.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestLambdaSubfieldPruning.java new file mode 100644 index 0000000000000..19473ec5261f7 --- /dev/null +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestLambdaSubfieldPruning.java @@ -0,0 +1,200 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive; + +import com.facebook.presto.Session; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.airlift.tpch.TpchTable; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.hive.HiveQueryRunner.HIVE_CATALOG; +import static com.facebook.presto.hive.HiveSessionProperties.PUSHDOWN_FILTER_ENABLED; + +@Test(singleThreaded = true) +public class TestLambdaSubfieldPruning + extends AbstractTestQueryFramework +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + DistributedQueryRunner queryRunner = HiveQueryRunner.createQueryRunner( + ImmutableList.of(TpchTable.LINE_ITEM), + ImmutableMap.of( + "experimental.pushdown-subfields-enabled", "true", + "pushdown-subfields-from-lambda-enabled", "true", + "experimental.pushdown-dereference-enabled", "true"), + "sql-standard", + ImmutableMap.builder() + .put("hive.pushdown-filter-enabled", "true") + .put("hive.parquet.pushdown-filter-enabled", "false") // Parquet does not support selective reader yet. + .build(), + Optional.empty()); + + return createLineItemExTable(queryRunner); + } + + @Override + protected QueryRunner createExpectedQueryRunner() + throws Exception + { + DistributedQueryRunner queryRunner = HiveQueryRunner.createQueryRunner( + ImmutableList.of(TpchTable.LINE_ITEM), + ImmutableMap.of( + "experimental.pushdown-subfields-enabled", "false", + "pushdown-subfields-from-lambda-enabled", "false", + "experimental.pushdown-dereference-enabled", "true"), + "sql-standard", + ImmutableMap.builder() + .put("hive.pushdown-filter-enabled", "true") + .put("hive.parquet.pushdown-filter-enabled", "false") + .build(), + Optional.empty()); + + return createLineItemExTable(queryRunner); + } + + private static DistributedQueryRunner createLineItemExTable(DistributedQueryRunner queryRunner) + { + for (String fileFormatName : ImmutableSet.of("ORC", "DWRF", "PARQUET")) { + queryRunner.execute(noPushdownFilter(queryRunner.getDefaultSession()), + "CREATE TABLE lineitem_ex_" + fileFormatName + " ( \n" + + " array_of_varchar_keys, \n" + + " array_of_rows, \n" + + " array_of_non_null_rows, \n" + + " array_of_array_of_rows, \n" + + " row_with_array_of_rows, \n" + + " row_with_map_varchar_key_row_value, \n" + + " map_varchar_key_row_value, \n" + + " map_varchar_key_array_of_row_value, \n" + + " array_of_map_entries_varchar_key_row_value, \n" + + " END_OF_LIST \n" + + ") WITH (format = '" + fileFormatName + "') AS \n" + + "SELECT \n" + + " ARRAY['orderkey', 'linenumber', 'partkey'] AS array_of_varchar_keys, \n" + + " IF (orderkey % 49 = 0, NULL, CAST(ARRAY[ \n" + + " ROW(IF (orderkey % 17 = 0, NULL, orderkey), comment), \n" + + " ROW(IF (linenumber % 7 = 0, NULL, linenumber), upper(comment)), \n" + + " ROW(IF (partkey % 5 = 0, NULL, partkey), shipmode) \n" + + " ] AS ARRAY(ROW(itemdata BIGINT, comment VARCHAR)))) AS array_of_rows, \n" + + " CAST(ARRAY[ \n" + + " ROW(orderkey, comment), \n" + + " ROW(linenumber, upper(comment)), \n" + + " ROW(partkey, shipmode) \n" + + " ] AS ARRAY(ROW(itemdata BIGINT, comment VARCHAR))) AS array_of_non_null_rows, \n" + + " IF (orderkey % 49 = 0, NULL, CAST(ARRAY[ARRAY[ \n" + + " ROW(IF (orderkey % 17 = 0, NULL, orderkey), comment), \n" + + " ROW(IF (linenumber % 7 = 0, NULL, linenumber), upper(comment)), \n" + + " ROW(IF (partkey % 5 = 0, NULL, partkey), shipmode) \n" + + " ]] AS ARRAY(ARRAY(ROW(itemdata BIGINT, comment VARCHAR))))) AS array_of_array_of_rows, \n" + + " CAST(ROW(ARRAY[ROW(orderkey, comment)]) AS ROW(array_of_rows ARRAY(ROW(itemdata BIGINT, comment VARCHAR)))) row_with_array_of_rows, \n" + + " CAST(ROW(MAP_FROM_ENTRIES(ARRAY[ \n" + + " ROW('orderdata', ROW(IF (orderkey % 17 = 0, 1, orderkey), linenumber, partkey)), \n" + + " ROW('orderdata_ex', ROW(orderkey + 100, linenumber + 100, partkey + 100))])) \n" + + " AS ROW(map_varchar_key_row_value MAP(VARCHAR, ROW(orderkey BIGINT, linenumber BIGINT, partkey BIGINT)))) row_with_map_varchar_key_row_value, \n" + + " CAST(MAP_FROM_ENTRIES(ARRAY[ \n" + + " ROW('orderdata', ROW(IF (orderkey % 17 = 0, 1, orderkey), linenumber, partkey)), \n" + + " ROW('orderdata_ex', ROW(orderkey + 100, linenumber + 100, partkey + 100))]) \n" + + " AS MAP(VARCHAR, ROW(orderkey BIGINT, linenumber BIGINT, partkey BIGINT))) AS map_varchar_key_row_value, \n" + + " CAST(MAP_FROM_ENTRIES(ARRAY[ \n" + + " ROW('orderdata', IF (orderkey % 13 = 0, NULL, ARRAY[ROW( IF (orderkey % 17 = 0, NULL, orderkey), linenumber, partkey)])), \n" + + " ROW('orderdata_ex', ARRAY[ ROW(orderkey + 100, linenumber + 100, partkey + 100)])]) \n" + + " AS MAP(VARCHAR, ARRAY(ROW(orderkey BIGINT, linenumber BIGINT, partkey BIGINT)))) AS map_varchar_key_array_of_row_value, \n" + + " CAST(ARRAY[ \n" + + " ROW('orderdata', IF (orderkey % 13 = 0, NULL, ROW( IF (orderkey % 17 = 0, NULL, orderkey), linenumber, partkey))), \n" + + " ROW('orderdata_ex', ROW(orderkey + 100, linenumber + 100, partkey + 100))] \n" + + " AS ARRAY(ROW(key VARCHAR, value ROW(orderkey BIGINT, linenumber BIGINT, partkey BIGINT)))) AS array_of_map_entries_varchar_key_row_value, \n" + + " true AS END_OF_LIST \n" + + + "FROM lineitem \n"); + } + return queryRunner; + } + + private static Session noPushdownFilter(Session session) + { + return Session.builder(session) + .setCatalogSessionProperty(HIVE_CATALOG, PUSHDOWN_FILTER_ENABLED, "false") + .build(); + } + + @Test + public void testPushDownSubfieldsFromLambdas() + { + for (String fileFormatName : ImmutableSet.of("ORC", "DWRF", "PARQUET")) { + testPushDownSubfieldsFromLambdas("lineitem_ex_" + fileFormatName); + } + } + + private void testPushDownSubfieldsFromLambdas(String tableName) + { + // functions that are not outputting all subfields + assertQuery("SELECT ALL_MATCH(array_of_rows, x -> x.itemdata > 0) FROM " + tableName); + assertQuery("SELECT ALL_MATCH(array_of_rows, x -> true) FROM " + tableName); + assertQuery("SELECT ALL_MATCH(array_of_rows, x -> x IS NULL) FROM " + tableName); + assertQuery("SELECT ANY_MATCH(array_of_rows, x -> x.itemdata > 0) FROM " + tableName); + assertQuery("SELECT NONE_MATCH(array_of_rows, x -> x.itemdata > 0) FROM " + tableName); + assertQuery("SELECT TRANSFORM(array_of_rows, x -> x.itemdata > 0) FROM " + tableName); + assertQuery("SELECT CARDINALITY(array_of_rows) FROM " + tableName); + assertQuery("SELECT CARDINALITY(FLATTEN(array_of_array_of_rows)) FROM " + tableName); + assertQuery("SELECT CARDINALITY(FILTER(array_of_rows, x -> POSITION('T' IN x.comment) > 0)) FROM " + tableName); + assertQuery("SELECT TRANSFORM(row_with_array_of_rows.array_of_rows, x -> CAST(ROW(x.comment, x.comment) AS ROW(d1 VARCHAR, d2 VARCHAR))) FROM " + tableName); + assertQuery("SELECT TRANSFORM_VALUES(map_varchar_key_row_value, (k,v) -> v.orderkey) FROM " + tableName); + assertQuery("SELECT ZIP_WITH(array_of_rows, row_with_array_of_rows.array_of_rows, (x, y) -> CAST(ROW(x.itemdata, y.comment) AS ROW(d1 BIGINT, d2 VARCHAR))) FROM " + tableName); + assertQuery("SELECT MAP_ZIP_WITH(map_varchar_key_row_value, row_with_map_varchar_key_row_value.map_varchar_key_row_value, (k, v1, v2) -> v1.orderkey + v2.orderkey) FROM " + tableName); + + // functions that outputing all subfields and accept functional parameter + assertQuery("SELECT FILTER(array_of_rows, x -> POSITION('T' IN x.comment) > 0) FROM " + tableName); + assertQuery("SELECT ARRAY_SORT(array_of_rows, (l, r) -> IF(l.itemdata < r.itemdata, 1, IF(l.itemdata = r.itemdata, 0, -1))) FROM " + tableName); + assertQuery("SELECT ANY_MATCH(SLICE(ARRAY_SORT(array_of_rows, (l, r) -> IF(l.itemdata < r.itemdata, 1, IF(l.itemdata = r.itemdata, 0, -1))), 1, 3), x -> x.itemdata > 0) FROM " + tableName); + assertQuery("SELECT TRANSFORM(ARRAY_SORT(array_of_non_null_rows), x -> x.itemdata) FROM " + tableName); + assertQuery("SELECT TRANSFORM(COMBINATIONS(array_of_non_null_rows, 3), x -> x[1].itemdata) FROM " + tableName); + assertQuery("SELECT TRANSFORM(FLATTEN(array_of_array_of_rows), x -> x.itemdata) FROM " + tableName); + assertQuery("SELECT TRANSFORM(REVERSE(array_of_rows), x -> x.itemdata) FROM " + tableName); + assertQuery("SELECT ARRAY_SORT(TRANSFORM(SHUFFLE(array_of_non_null_rows), x -> x.itemdata)) FROM " + tableName); + assertQuery("SELECT TRANSFORM(SLICE(array_of_rows, 1, 5), x -> x.itemdata) FROM " + tableName); + assertQuery("SELECT TRANSFORM(TRIM_ARRAY(array_of_rows, 2), x -> x.itemdata) FROM " + tableName); + assertQuery("SELECT TRANSFORM(CONCAT(array_of_rows, row_with_array_of_rows.array_of_rows), x -> x.itemdata) FROM " + tableName); + assertQuery("SELECT TRANSFORM(array_of_rows || row_with_array_of_rows.array_of_rows, x -> x.itemdata) FROM " + tableName); + assertQuery("SELECT TRANSFORM_VALUES(MAP_CONCAT(map_varchar_key_row_value, row_with_map_varchar_key_row_value.map_varchar_key_row_value), (k,v) -> v.orderkey) FROM " + tableName); + assertQuery("SELECT TRANSFORM_VALUES(MAP_REMOVE_NULL_VALUES(row_with_map_varchar_key_row_value.map_varchar_key_row_value), (k,v) -> v.orderkey) FROM " + tableName); + assertQuery("SELECT TRANSFORM_VALUES(MAP_SUBSET(row_with_map_varchar_key_row_value.map_varchar_key_row_value, ARRAY['orderdata_ex']), (k,v) -> v.orderkey) FROM " + tableName); + assertQuery("SELECT ANY_MATCH(MAP_VALUES(map_varchar_key_row_value), x -> x.orderkey % 2 = 0) FROM " + tableName); + assertQuery("SELECT ANY_MATCH(MAP_TOP_N_VALUES(map_varchar_key_row_value, 10, (x, y) -> IF(x.orderkey < y.orderkey, -1, IF(x.orderkey = y.orderkey, 0, 1))), x -> x.orderkey % 2 = 0) FROM " + tableName); + assertQuery("SELECT ANY_MATCH(MAP_VALUES(map_varchar_key_row_value), x -> x.orderkey % 2 = 0) FROM " + tableName); + + // Simple test of different column type of the array argument + assertQuery("SELECT ALL_MATCH(array_of_rows, x -> x.itemdata > 0) FROM " + tableName); + assertQuery("SELECT ALL_MATCH(row_with_array_of_rows.array_of_rows, x -> x.itemdata > 0) FROM " + tableName); + assertQuery("SELECT ALL_MATCH(array_of_array_of_rows[1], x -> x.itemdata > 0) FROM " + tableName); + assertQuery("SELECT ALL_MATCH(map_varchar_key_array_of_row_value['orderdata'], x -> x.orderkey > 0) FROM " + tableName); + + // element_at + assertQuery("SELECT ALL_MATCH(ELEMENT_AT(map_varchar_key_array_of_row_value, 'orderdata'), x -> x.orderkey > 0) FROM " + tableName); + assertQuery("SELECT ALL_MATCH(ELEMENT_AT(array_of_array_of_rows, 1), x -> x.itemdata > 0) FROM " + tableName); + + // Queries that reference variables in lambdas + assertQuery("SELECT ALL_MATCH(SLICE(array_of_rows, 1, row_with_array_of_rows.array_of_rows[1].itemdata), x -> x.itemdata > 0) FROM " + tableName); + assertQuery("SELECT ALL_MATCH(array_of_rows, x -> x.itemdata > row_with_array_of_rows.array_of_rows[1].itemdata) FROM " + tableName); + + // Queries that lack full support + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index f58380510e656..440bfa20c0a46 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -189,6 +189,7 @@ public final class SystemSessionProperties public static final String OPTIMIZE_CONSTANT_GROUPING_KEYS = "optimize_constant_grouping_keys"; public static final String MAX_CONCURRENT_MATERIALIZATIONS = "max_concurrent_materializations"; public static final String PUSHDOWN_SUBFIELDS_ENABLED = "pushdown_subfields_enabled"; + public static final String PUSHDOWN_SUBFIELDS_FROM_LAMBDA_ENABLED = "pushdown_subfields_from_lambda_enabled"; public static final String TABLE_WRITER_MERGE_OPERATOR_ENABLED = "table_writer_merge_operator_enabled"; public static final String INDEX_LOADER_TIMEOUT = "index_loader_timeout"; public static final String OPTIMIZED_REPARTITIONING_ENABLED = "optimized_repartitioning"; @@ -1044,6 +1045,11 @@ public SystemSessionProperties( "Experimental: enable subfield pruning", featuresConfig.isPushdownSubfieldsEnabled(), false), + booleanProperty( + PUSHDOWN_SUBFIELDS_FROM_LAMBDA_ENABLED, + "Enable subfield pruning from lambdas", + featuresConfig.isPushdownSubfieldsFromLambdaEnabled(), + false), booleanProperty( PUSHDOWN_DEREFERENCE_ENABLED, "Experimental: enable dereference pushdown", @@ -2403,6 +2409,11 @@ public static boolean isPushdownSubfieldsEnabled(Session session) return session.getSystemProperty(PUSHDOWN_SUBFIELDS_ENABLED, Boolean.class); } + public static boolean isPushdownSubfieldsFromArrayLambdasEnabled(Session session) + { + return session.getSystemProperty(PUSHDOWN_SUBFIELDS_FROM_LAMBDA_ENABLED, Boolean.class); + } + public static boolean isPushdownDereferenceEnabled(Session session) { return session.getSystemProperty(PUSHDOWN_DEREFERENCE_ENABLED, Boolean.class); diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 7bfe03b587f78..0da4b3e5fd3fa 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -1081,7 +1081,8 @@ public FunctionMetadata getFunctionMetadata(FunctionHandle functionHandle) signature.getKind(), JAVA, function.isDeterministic(), - function.isCalledOnNullInput()); + function.isCalledOnNullInput(), + function.getComplexTypeFunctionDescriptor()); } else if (function instanceof SqlInvokedFunction) { SqlInvokedFunction sqlFunction = (SqlInvokedFunction) function; @@ -1096,7 +1097,8 @@ else if (function instanceof SqlInvokedFunction) { SQL, function.isDeterministic(), function.isCalledOnNullInput(), - sqlFunction.getVersion()); + sqlFunction.getVersion(), + sqlFunction.getComplexTypeFunctionDescriptor()); } else { return new FunctionMetadata( @@ -1106,7 +1108,8 @@ else if (function instanceof SqlInvokedFunction) { signature.getKind(), JAVA, function.isDeterministic(), - function.isCalledOnNullInput()); + function.isCalledOnNullInput(), + function.getComplexTypeFunctionDescriptor()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/SqlScalarFunction.java b/presto-main/src/main/java/com/facebook/presto/metadata/SqlScalarFunction.java index 48793ae7afe22..901a62564332f 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/SqlScalarFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/SqlScalarFunction.java @@ -15,8 +15,10 @@ import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.operator.scalar.BuiltInScalarFunctionImplementation; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.Signature; +import static com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor.defaultFunctionDescriptor; import static com.facebook.presto.spi.function.FunctionKind.SCALAR; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; @@ -25,11 +27,18 @@ public abstract class SqlScalarFunction extends BuiltInFunction { private final Signature signature; + private final ComplexTypeFunctionDescriptor descriptor; - protected SqlScalarFunction(Signature signature) + protected SqlScalarFunction(Signature signature, ComplexTypeFunctionDescriptor descriptor) { this.signature = requireNonNull(signature, "signature is null"); checkArgument(signature.getKind() == SCALAR, "function kind must be SCALAR"); + this.descriptor = requireNonNull(descriptor, "descriptor is null"); + } + + protected SqlScalarFunction(Signature signature) + { + this(signature, defaultFunctionDescriptor()); } @Override @@ -38,6 +47,12 @@ public final Signature getSignature() return signature; } + @Override + public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor() + { + return descriptor; + } + public abstract BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager); public static PolymorphicScalarFunctionBuilder builder(Class clazz, OperatorType operatorType) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayAllMatchFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayAllMatchFunction.java index 4bf8974a5f885..98943e896e649 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayAllMatchFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayAllMatchFunction.java @@ -16,10 +16,15 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; +import com.facebook.presto.spi.function.ScalarFunctionLambdaArgumentDescriptor; +import com.facebook.presto.spi.function.ScalarFunctionLambdaDescriptor; import com.facebook.presto.spi.function.SqlNullable; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.StaticMethodPointer; import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.function.TypeParameterSpecialization; import io.airlift.slice.Slice; @@ -27,7 +32,15 @@ import static java.lang.Boolean.FALSE; @Description("Returns true if all elements of the array match the given predicate") -@ScalarFunction(value = "all_match") +@ScalarFunction(value = "all_match", descriptor = @ScalarFunctionDescriptor( + outputToInputTransformationFunction = {@StaticMethodPointer(clazz = ComplexTypeFunctionDescriptor.class, method = "clearRequiredSubfields")}, + lambdaDescriptors = { + @ScalarFunctionLambdaDescriptor( + callArgumentIndex = 1, + lambdaArgumentDescriptors = { + @ScalarFunctionLambdaArgumentDescriptor( + lambdaArgumentIndex = 0, + callArgumentIndex = 0)})})) public final class ArrayAllMatchFunction { private ArrayAllMatchFunction() {} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayAnyMatchFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayAnyMatchFunction.java index 029c0caf9a325..1f5043a549fad 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayAnyMatchFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayAnyMatchFunction.java @@ -16,10 +16,15 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; +import com.facebook.presto.spi.function.ScalarFunctionLambdaArgumentDescriptor; +import com.facebook.presto.spi.function.ScalarFunctionLambdaDescriptor; import com.facebook.presto.spi.function.SqlNullable; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.StaticMethodPointer; import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.function.TypeParameterSpecialization; import io.airlift.slice.Slice; @@ -27,7 +32,15 @@ import static java.lang.Boolean.TRUE; @Description("Returns true if the array contains one or more elements that match the given predicate") -@ScalarFunction(value = "any_match") +@ScalarFunction(value = "any_match", descriptor = @ScalarFunctionDescriptor( + outputToInputTransformationFunction = {@StaticMethodPointer(clazz = ComplexTypeFunctionDescriptor.class, method = "clearRequiredSubfields")}, + lambdaDescriptors = { + @ScalarFunctionLambdaDescriptor( + callArgumentIndex = 1, + lambdaArgumentDescriptors = { + @ScalarFunctionLambdaArgumentDescriptor( + lambdaArgumentIndex = 0, + callArgumentIndex = 0)})})) public final class ArrayAnyMatchFunction { private ArrayAnyMatchFunction() {} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayCardinalityFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayCardinalityFunction.java index 58baf6c0d475a..4865cf6721892 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayCardinalityFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayCardinalityFunction.java @@ -15,13 +15,19 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.StaticMethodPointer; import com.facebook.presto.spi.function.TypeParameter; @Description("Returns the cardinality (length) of the array") -@ScalarFunction("cardinality") +@ScalarFunction(value = "cardinality", descriptor = @ScalarFunctionDescriptor( + isAccessingInputValues = false, + outputToInputTransformationFunction = {@StaticMethodPointer(clazz = ComplexTypeFunctionDescriptor.class, method = "clearRequiredSubfields")}, + lambdaDescriptors = {})) public final class ArrayCardinalityFunction { private ArrayCardinalityFunction() {} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayCombinationsFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayCombinationsFunction.java index bc2d442340c22..dc4a62aff07cb 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayCombinationsFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayCombinationsFunction.java @@ -21,9 +21,12 @@ import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.StaticMethodPointer; import com.facebook.presto.spi.function.TypeParameter; import com.google.common.annotations.VisibleForTesting; @@ -39,7 +42,10 @@ import static java.util.Arrays.setAll; @Description("Returns n-element combinations from array") -@ScalarFunction("combinations") +@ScalarFunction(value = "combinations", descriptor = @ScalarFunctionDescriptor( + isAccessingInputValues = false, + outputToInputTransformationFunction = {@StaticMethodPointer(clazz = ComplexTypeFunctionDescriptor.class, method = "removeSecondPathElement")}, + lambdaDescriptors = {})) public final class ArrayCombinationsFunction { private ArrayCombinationsFunction() {} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConcatFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConcatFunction.java index 0b63a11e11d18..f64ecefa10993 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConcatFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConcatFunction.java @@ -22,6 +22,7 @@ import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunctionVisibility; @@ -29,6 +30,7 @@ import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; +import java.util.Optional; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.DEFAULT_NAMESPACE; @@ -51,6 +53,8 @@ public final class ArrayConcatFunction private static final MethodHandle METHOD_HANDLE = methodHandle(ArrayConcatFunction.class, "concat", Type.class, Block[].class); + private final ComplexTypeFunctionDescriptor descriptor; + private ArrayConcatFunction() { super(new Signature( @@ -61,6 +65,12 @@ private ArrayConcatFunction() parseTypeSignature("array(E)"), ImmutableList.of(parseTypeSignature("array(E)")), true)); + descriptor = new ComplexTypeFunctionDescriptor( + false, + ImmutableList.of(), + Optional.empty(), + Optional.of(ComplexTypeFunctionDescriptor::allSubfieldsRequired), + getSignature()); } @Override @@ -81,6 +91,12 @@ public String getDescription() return DESCRIPTION; } + @Override + public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor() + { + return descriptor; + } + @Override public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFilterFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFilterFunction.java index 74495c2061138..e229030203fb1 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFilterFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFilterFunction.java @@ -18,6 +18,9 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; +import com.facebook.presto.spi.function.ScalarFunctionLambdaArgumentDescriptor; +import com.facebook.presto.spi.function.ScalarFunctionLambdaDescriptor; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.function.TypeParameterSpecialization; @@ -26,7 +29,14 @@ import static java.lang.Boolean.TRUE; @Description("return array containing elements that match the given predicate") -@ScalarFunction(value = "filter", deterministic = false) +@ScalarFunction(value = "filter", deterministic = false, descriptor = @ScalarFunctionDescriptor( + lambdaDescriptors = { + @ScalarFunctionLambdaDescriptor( + callArgumentIndex = 1, + lambdaArgumentDescriptors = { + @ScalarFunctionLambdaArgumentDescriptor( + lambdaArgumentIndex = 0, + callArgumentIndex = 0)})})) public final class ArrayFilterFunction { private ArrayFilterFunction() {} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFlattenFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFlattenFunction.java index 07574db7f595e..3a3f537dc209a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFlattenFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFlattenFunction.java @@ -22,12 +22,15 @@ import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.SqlScalarFunction; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunctionVisibility; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import java.lang.invoke.MethodHandle; +import java.util.Optional; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.DEFAULT_NAMESPACE; @@ -44,6 +47,7 @@ public class ArrayFlattenFunction public static final ArrayFlattenFunction ARRAY_FLATTEN_FUNCTION = new ArrayFlattenFunction(); private static final String FUNCTION_NAME = "flatten"; private static final MethodHandle METHOD_HANDLE = methodHandle(ArrayFlattenFunction.class, FUNCTION_NAME, Type.class, Type.class, Block.class); + private final ComplexTypeFunctionDescriptor descriptor; private ArrayFlattenFunction() { @@ -54,6 +58,12 @@ private ArrayFlattenFunction() parseTypeSignature("array(E)"), ImmutableList.of(parseTypeSignature("array(array(E))")), false)); + descriptor = new ComplexTypeFunctionDescriptor( + false, + ImmutableList.of(), + Optional.of(ImmutableSet.of(0)), + Optional.of(ComplexTypeFunctionDescriptor::prependAllSubscripts), + getSignature()); } @Override @@ -74,6 +84,12 @@ public String getDescription() return "Flattens the given array"; } + @Override + public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor() + { + return descriptor; + } + @Override public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNoneMatchFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNoneMatchFunction.java index b3bcacd103a30..6fe1d559cb29b 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNoneMatchFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNoneMatchFunction.java @@ -16,16 +16,29 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; +import com.facebook.presto.spi.function.ScalarFunctionLambdaArgumentDescriptor; +import com.facebook.presto.spi.function.ScalarFunctionLambdaDescriptor; import com.facebook.presto.spi.function.SqlNullable; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.StaticMethodPointer; import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.function.TypeParameterSpecialization; import io.airlift.slice.Slice; @Description("Returns true if all elements of the array don't match the given predicate") -@ScalarFunction(value = "none_match") +@ScalarFunction(value = "none_match", descriptor = @ScalarFunctionDescriptor( + outputToInputTransformationFunction = {@StaticMethodPointer(clazz = ComplexTypeFunctionDescriptor.class, method = "clearRequiredSubfields")}, + lambdaDescriptors = { + @ScalarFunctionLambdaDescriptor( + callArgumentIndex = 1, + lambdaArgumentDescriptors = { + @ScalarFunctionLambdaArgumentDescriptor( + lambdaArgumentIndex = 0, + callArgumentIndex = 0)})})) public final class ArrayNoneMatchFunction { private ArrayNoneMatchFunction() {} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayReverseFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayReverseFunction.java index cc5555c227a99..85653db745aa9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayReverseFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayReverseFunction.java @@ -18,11 +18,12 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; -@ScalarFunction("reverse") @Description("Returns an array which has the reversed order of the given array.") +@ScalarFunction(value = "reverse", descriptor = @ScalarFunctionDescriptor(isAccessingInputValues = false)) public final class ArrayReverseFunction { @TypeParameter("E") diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayShuffleFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayShuffleFunction.java index 53da9ace28b53..dff971c95a573 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayShuffleFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayShuffleFunction.java @@ -18,13 +18,14 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; import java.util.concurrent.ThreadLocalRandom; -@ScalarFunction(value = "shuffle", deterministic = false) @Description("Generates a random permutation of the given array.") +@ScalarFunction(value = "shuffle", deterministic = false, descriptor = @ScalarFunctionDescriptor(isAccessingInputValues = false)) public final class ArrayShuffleFunction { private static final int INITIAL_LENGTH = 128; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySliceFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySliceFunction.java index e37f7f0384b9a..8b645528a5f90 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySliceFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySliceFunction.java @@ -19,13 +19,14 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.util.Failures.checkCondition; -@ScalarFunction("slice") +@ScalarFunction(value = "slice", descriptor = @ScalarFunctionDescriptor(isAccessingInputValues = false)) @Description("Subsets an array given an offset (1-indexed) and length") public final class ArraySliceFunction { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySortComparatorFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySortComparatorFunction.java index 6e986fa3d7f30..0be7d15564201 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySortComparatorFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySortComparatorFunction.java @@ -19,6 +19,9 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; +import com.facebook.presto.spi.function.ScalarFunctionLambdaArgumentDescriptor; +import com.facebook.presto.spi.function.ScalarFunctionLambdaDescriptor; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.function.TypeParameterSpecialization; @@ -32,7 +35,17 @@ import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.util.Failures.checkCondition; -@ScalarFunction("array_sort") +@ScalarFunction(value = "array_sort", descriptor = @ScalarFunctionDescriptor( + lambdaDescriptors = { + @ScalarFunctionLambdaDescriptor( + callArgumentIndex = 1, + lambdaArgumentDescriptors = { + @ScalarFunctionLambdaArgumentDescriptor( + lambdaArgumentIndex = 0, + callArgumentIndex = 0), + @ScalarFunctionLambdaArgumentDescriptor( + lambdaArgumentIndex = 1, + callArgumentIndex = 0)})})) @Description("Sorts the given array with a lambda comparator.") public final class ArraySortComparatorFunction { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java index 1f67c06b98a91..b524a822d1137 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java @@ -32,13 +32,20 @@ import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.SqlScalarFunction; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.FunctionKind; +import com.facebook.presto.spi.function.LambdaArgumentDescriptor; +import com.facebook.presto.spi.function.LambdaDescriptor; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunctionVisibility; import com.facebook.presto.sql.gen.lambda.UnaryFunctionInterface; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.primitives.Primitives; +import java.util.Optional; + import static com.facebook.presto.bytecode.Access.FINAL; import static com.facebook.presto.bytecode.Access.PRIVATE; import static com.facebook.presto.bytecode.Access.PUBLIC; @@ -70,6 +77,8 @@ public final class ArrayTransformFunction { public static final ArrayTransformFunction ARRAY_TRANSFORM_FUNCTION = new ArrayTransformFunction(); + private final ComplexTypeFunctionDescriptor descriptor; + private ArrayTransformFunction() { super(new Signature( @@ -80,6 +89,12 @@ private ArrayTransformFunction() parseTypeSignature("array(U)"), ImmutableList.of(parseTypeSignature("array(T)"), parseTypeSignature("function(T,U)")), false)); + descriptor = new ComplexTypeFunctionDescriptor( + true, + ImmutableList.of(new LambdaDescriptor(1, ImmutableMap.of(0, new LambdaArgumentDescriptor(0, ComplexTypeFunctionDescriptor::prependAllSubscripts)))), + Optional.of(ImmutableSet.of(0)), + Optional.of(ComplexTypeFunctionDescriptor::clearRequiredSubfields), + getSignature()); } @Override @@ -114,6 +129,12 @@ public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariab methodHandle(generatedClass, "transform", Block.class, UnaryFunctionInterface.class)); } + @Override + public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor() + { + return descriptor; + } + private static Class generateTransform(Type inputType, Type outputType) { CallSiteBinder binder = new CallSiteBinder(); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTrimFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTrimFunction.java index 469b88f741afb..12b8ffe4656e7 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTrimFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTrimFunction.java @@ -18,6 +18,7 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; @@ -25,7 +26,9 @@ import static com.facebook.presto.util.Failures.checkCondition; import static java.lang.Math.toIntExact; -@ScalarFunction("trim_array") +@ScalarFunction(value = "trim_array", descriptor = @ScalarFunctionDescriptor( + isAccessingInputValues = false, + lambdaDescriptors = {})) @Description("Remove elements from the end of array") public final class ArrayTrimFunction { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapCardinalityFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapCardinalityFunction.java index bbe0b9d4c4e2a..4aaa4e167c553 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapCardinalityFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapCardinalityFunction.java @@ -15,12 +15,18 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.StaticMethodPointer; import com.facebook.presto.spi.function.TypeParameter; -@ScalarFunction("cardinality") +@ScalarFunction(value = "cardinality", descriptor = @ScalarFunctionDescriptor( + isAccessingInputValues = false, + outputToInputTransformationFunction = {@StaticMethodPointer(clazz = ComplexTypeFunctionDescriptor.class, method = "clearRequiredSubfields")}, + lambdaDescriptors = {})) @Description("Returns the cardinality (the number of key-value pairs) of the map") public final class MapCardinalityFunction { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java index 02d199e7981f5..e0af8cb522616 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java @@ -28,6 +28,7 @@ import com.facebook.presto.operator.aggregation.OptimizedTypedSet; import com.facebook.presto.operator.project.SelectedPositions; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunctionVisibility; @@ -36,6 +37,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.DEFAULT_NAMESPACE; @@ -59,6 +61,8 @@ public final class MapConcatFunction private static final MethodHandle METHOD_HANDLE = methodHandle(MapConcatFunction.class, "mapConcat", MapType.class, Block[].class); + private final ComplexTypeFunctionDescriptor descriptor; + private MapConcatFunction() { super(new Signature( @@ -69,6 +73,12 @@ private MapConcatFunction() parseTypeSignature("map(K,V)"), ImmutableList.of(parseTypeSignature("map(K,V)")), true)); + descriptor = new ComplexTypeFunctionDescriptor( + false, + ImmutableList.of(), + Optional.empty(), + Optional.of(ComplexTypeFunctionDescriptor::allSubfieldsRequired), + getSignature()); } @Override @@ -89,6 +99,12 @@ public String getDescription() return DESCRIPTION; } + @Override + public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor() + { + return descriptor; + } + @Override public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java index 4dfb3de0863e1..0dee32f9b7102 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java @@ -30,10 +30,12 @@ import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunctionVisibility; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import java.lang.invoke.MethodHandle; import java.util.Optional; @@ -62,6 +64,8 @@ public final class MapConstructor { public static final MapConstructor MAP_CONSTRUCTOR = new MapConstructor(); + private final ComplexTypeFunctionDescriptor descriptor; + private static final MethodHandle METHOD_HANDLE = methodHandle( MapConstructor.class, "createMap", @@ -87,6 +91,12 @@ public MapConstructor() TypeSignature.parseTypeSignature("map(K,V)"), ImmutableList.of(TypeSignature.parseTypeSignature("array(K)"), TypeSignature.parseTypeSignature("array(V)")), false)); + descriptor = new ComplexTypeFunctionDescriptor( + false, + ImmutableList.of(), + Optional.of(ImmutableSet.of(1)), + Optional.of(ComplexTypeFunctionDescriptor::allSubfieldsRequired), + getSignature()); } @Override @@ -107,6 +117,12 @@ public String getDescription() return DESCRIPTION; } + @Override + public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor() + { + return descriptor; + } + @Override public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapFilterFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapFilterFunction.java index b8c9bd7238c62..ab5177bcd4765 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapFilterFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapFilterFunction.java @@ -30,6 +30,10 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.function.CodegenScalarFunction; import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.IntArray; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; +import com.facebook.presto.spi.function.ScalarFunctionLambdaArgumentDescriptor; +import com.facebook.presto.spi.function.ScalarFunctionLambdaDescriptor; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.sql.gen.SqlTypeBytecodeExpression; @@ -64,7 +68,15 @@ public final class MapFilterFunction { private MapFilterFunction() {} - @CodegenScalarFunction(value = "map_filter", deterministic = false) + @CodegenScalarFunction(value = "map_filter", deterministic = false, descriptor = @ScalarFunctionDescriptor( + argumentIndicesContainingMapOrArray = @IntArray(0), + lambdaDescriptors = { + @ScalarFunctionLambdaDescriptor( + callArgumentIndex = 1, + lambdaArgumentDescriptors = { + @ScalarFunctionLambdaArgumentDescriptor( + lambdaArgumentIndex = 1, + callArgumentIndex = 0)})})) @Description("return map containing entries that match the given predicate") @TypeParameter("K") @TypeParameter("V") diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapSubsetFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapSubsetFunction.java index abc360666d9e1..8449715f48e46 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapSubsetFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapSubsetFunction.java @@ -1,4 +1,4 @@ -/* + /* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -19,9 +19,12 @@ import com.facebook.presto.operator.aggregation.TypedSet; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; -@ScalarFunction("map_subset") +@ScalarFunction(value = "map_subset", descriptor = @ScalarFunctionDescriptor( + isAccessingInputValues = false, + lambdaDescriptors = {})) @Description("returns a map where the keys are a subset of the given array of keys") public final class MapSubsetFunction { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java index 7f456fea3fd0d..a884c9d360700 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java @@ -37,16 +37,22 @@ import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.spi.ErrorCodeSupplier; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.FunctionKind; +import com.facebook.presto.spi.function.LambdaArgumentDescriptor; +import com.facebook.presto.spi.function.LambdaDescriptor; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunctionVisibility; import com.facebook.presto.sql.gen.SqlTypeBytecodeExpression; import com.facebook.presto.sql.gen.lambda.BinaryFunctionInterface; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.primitives.Primitives; import java.lang.invoke.MethodHandle; +import java.util.Optional; import static com.facebook.presto.bytecode.Access.FINAL; import static com.facebook.presto.bytecode.Access.PRIVATE; @@ -85,6 +91,8 @@ public final class MapTransformValueFunction { public static final MapTransformValueFunction MAP_TRANSFORM_VALUE_FUNCTION = new MapTransformValueFunction(); + private final ComplexTypeFunctionDescriptor descriptor; + private MapTransformValueFunction() { super(new Signature( @@ -95,6 +103,12 @@ private MapTransformValueFunction() parseTypeSignature("map(K,V2)"), ImmutableList.of(parseTypeSignature("map(K,V1)"), parseTypeSignature("function(K,V1,V2)")), false)); + descriptor = new ComplexTypeFunctionDescriptor( + true, + ImmutableList.of(new LambdaDescriptor(1, ImmutableMap.of(1, new LambdaArgumentDescriptor(0, ComplexTypeFunctionDescriptor::prependAllSubscripts)))), + Optional.of(ImmutableSet.of(0)), + Optional.of(ComplexTypeFunctionDescriptor::clearRequiredSubfields), + getSignature()); } @Override @@ -115,6 +129,12 @@ public String getDescription() return "apply lambda to each entry of the map and transform the value"; } + @Override + public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor() + { + return descriptor; + } + @Override public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapValues.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapValues.java index adca86b8d6df6..dc89acb0b2591 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapValues.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapValues.java @@ -18,10 +18,12 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; -@ScalarFunction("map_values") +@ScalarFunction(value = "map_values", descriptor = @ScalarFunctionDescriptor( + isAccessingInputValues = false)) @Description("Returns the values of the given map(K,V) as an array") public final class MapValues { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapZipWithFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapZipWithFunction.java index f443238d06c8e..100c161d28f16 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapZipWithFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapZipWithFunction.java @@ -27,13 +27,19 @@ import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.FunctionKind; +import com.facebook.presto.spi.function.LambdaArgumentDescriptor; +import com.facebook.presto.spi.function.LambdaDescriptor; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunctionVisibility; import com.facebook.presto.sql.gen.lambda.LambdaFunctionInterface; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import java.lang.invoke.MethodHandle; +import java.util.Optional; import static com.facebook.presto.common.block.MethodHandleUtil.compose; import static com.facebook.presto.common.block.MethodHandleUtil.nativeValueGetter; @@ -57,6 +63,8 @@ public final class MapZipWithFunction { public static final MapZipWithFunction MAP_ZIP_WITH_FUNCTION = new MapZipWithFunction(); + private final ComplexTypeFunctionDescriptor descriptor; + private static final MethodHandle METHOD_HANDLE = methodHandle(MapZipWithFunction.class, "mapZipWith", Type.class, Type.class, Type.class, MapType.class, MethodHandle.class, MethodHandle.class, MethodHandle.class, Block.class, Block.class, MapZipWithLambda.class); private MapZipWithFunction() { @@ -68,6 +76,14 @@ private MapZipWithFunction() parseTypeSignature("map(K,V3)"), ImmutableList.of(parseTypeSignature("map(K,V1)"), parseTypeSignature("map(K,V2)"), parseTypeSignature("function(K,V1,V2,V3)")), false)); + descriptor = new ComplexTypeFunctionDescriptor( + true, + ImmutableList.of(new LambdaDescriptor(2, ImmutableMap.of( + 1, new LambdaArgumentDescriptor(0, ComplexTypeFunctionDescriptor::prependAllSubscripts), + 2, new LambdaArgumentDescriptor(1, ComplexTypeFunctionDescriptor::prependAllSubscripts)))), + Optional.of(ImmutableSet.of(0, 1)), + Optional.of(ComplexTypeFunctionDescriptor::clearRequiredSubfields), + getSignature()); } @Override @@ -88,6 +104,12 @@ public String getDescription() return "merge two maps into a single map by applying the lambda function to the pair of values with the same key"; } + @Override + public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor() + { + return descriptor; + } + @Override public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java index 3950c075e9fe4..7a836fbd2b3c3 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java @@ -19,6 +19,7 @@ import com.facebook.presto.operator.ParametricImplementationsGroup; import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunctionVisibility; import com.google.common.annotations.VisibleForTesting; @@ -38,6 +39,7 @@ public class ParametricScalar { private final ScalarHeader details; private final ParametricImplementationsGroup implementations; + private final ComplexTypeFunctionDescriptor descriptor; public ParametricScalar( Signature signature, @@ -47,6 +49,12 @@ public ParametricScalar( super(signature); this.details = requireNonNull(details); this.implementations = requireNonNull(implementations); + this.descriptor = new ComplexTypeFunctionDescriptor( + details.getFunctionDescriptor().isAccessingInputValues(), + details.getFunctionDescriptor().getLambdaDescriptors(), + details.getFunctionDescriptor().getArgumentIndicesContainingMapOrArray(), + details.getFunctionDescriptor().getOutputToInputTransformationFunction(), + signature); } @Override @@ -114,4 +122,10 @@ public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariab throw new PrestoException(FUNCTION_IMPLEMENTATION_MISSING, format("Unsupported type parameters (%s) for %s", boundVariables, getSignature())); } + + @Override + public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor() + { + return descriptor; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java index b6b5e33fa6c00..8aa40cf3ff11a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.scalar; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.SqlFunctionVisibility; import java.util.Optional; @@ -23,13 +24,15 @@ public class ScalarHeader private final SqlFunctionVisibility visibility; private final boolean deterministic; private final boolean calledOnNullInput; + private final ComplexTypeFunctionDescriptor functionDescriptor; - public ScalarHeader(Optional description, SqlFunctionVisibility visibility, boolean deterministic, boolean calledOnNullInput) + public ScalarHeader(Optional description, SqlFunctionVisibility visibility, boolean deterministic, boolean calledOnNullInput, ComplexTypeFunctionDescriptor functionDescriptor) { this.description = description; this.visibility = visibility; this.deterministic = deterministic; this.calledOnNullInput = calledOnNullInput; + this.functionDescriptor = functionDescriptor; } public Optional getDescription() @@ -51,4 +54,9 @@ public boolean isCalledOnNullInput() { return calledOnNullInput; } + + public ComplexTypeFunctionDescriptor getFunctionDescriptor() + { + return functionDescriptor; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipWithFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipWithFunction.java index 47c840ed5e7e8..1ce21b169d21e 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipWithFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipWithFunction.java @@ -21,13 +21,19 @@ import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.SqlScalarFunction; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.FunctionKind; +import com.facebook.presto.spi.function.LambdaArgumentDescriptor; +import com.facebook.presto.spi.function.LambdaDescriptor; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunctionVisibility; import com.facebook.presto.sql.gen.lambda.BinaryFunctionInterface; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import java.lang.invoke.MethodHandle; +import java.util.Optional; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.common.type.TypeUtils.readNativeValue; @@ -46,9 +52,10 @@ public final class ZipWithFunction extends SqlScalarFunction { public static final ZipWithFunction ZIP_WITH_FUNCTION = new ZipWithFunction(); - private static final MethodHandle METHOD_HANDLE = methodHandle(ZipWithFunction.class, "zipWith", Type.class, Type.class, ArrayType.class, Block.class, Block.class, BinaryFunctionInterface.class); + private final ComplexTypeFunctionDescriptor descriptor; + private ZipWithFunction() { super(new Signature( @@ -59,6 +66,14 @@ private ZipWithFunction() parseTypeSignature("array(R)"), ImmutableList.of(parseTypeSignature("array(T)"), parseTypeSignature("array(U)"), parseTypeSignature("function(T,U,R)")), false)); + descriptor = new ComplexTypeFunctionDescriptor( + true, + ImmutableList.of(new LambdaDescriptor(2, ImmutableMap.of( + 0, new LambdaArgumentDescriptor(0, ComplexTypeFunctionDescriptor::prependAllSubscripts), + 1, new LambdaArgumentDescriptor(1, ComplexTypeFunctionDescriptor::prependAllSubscripts)))), + Optional.of(ImmutableSet.of(0, 1)), + Optional.of(ComplexTypeFunctionDescriptor::clearRequiredSubfields), + getSignature()); } @Override @@ -95,6 +110,12 @@ public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariab METHOD_HANDLE.bindTo(leftElementType).bindTo(rightElementType).bindTo(outputArrayType)); } + @Override + public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor() + { + return descriptor; + } + public static Block zipWith( Type leftElementType, Type rightElementType, diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/CodegenScalarFromAnnotationsParser.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/CodegenScalarFromAnnotationsParser.java index f773497fd08bb..49804e1d440ae 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/CodegenScalarFromAnnotationsParser.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/CodegenScalarFromAnnotationsParser.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.CodegenScalarFunction; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.spi.function.IsNull; @@ -56,6 +57,7 @@ import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentProperty.valueTypeArgumentProperty; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.NullConvention.RETURN_NULL_ON_NULL; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.NullConvention.USE_BOXED_TYPE; +import static com.facebook.presto.operator.scalar.annotations.FunctionDescriptorParser.parseFunctionDescriptor; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static com.facebook.presto.spi.function.Signature.withVariadicBound; import static com.facebook.presto.util.Failures.checkCondition; @@ -121,8 +123,13 @@ private static SqlScalarFunction createSqlScalarFunction(Method method) parseTypeSignature(method.getAnnotation(SqlType.class).value()), Arrays.stream(method.getParameters()).map(p -> parseTypeSignature(p.getAnnotation(SqlType.class).value())).collect(toImmutableList()), false); - - return new SqlScalarFunction(signature) + ComplexTypeFunctionDescriptor functionDescriptor = parseFunctionDescriptor(codegenScalarFunction.descriptor()); + return new SqlScalarFunction(signature, new ComplexTypeFunctionDescriptor( + functionDescriptor.isAccessingInputValues(), + functionDescriptor.getLambdaDescriptors(), + functionDescriptor.getArgumentIndicesContainingMapOrArray(), + functionDescriptor.getOutputToInputTransformationFunction(), + signature)) { @Override public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/FunctionDescriptorParser.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/FunctionDescriptorParser.java new file mode 100644 index 0000000000000..ebf3ae125cc9e --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/FunctionDescriptorParser.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar.annotations; + +import com.facebook.presto.common.Subfield; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; +import com.facebook.presto.spi.function.LambdaArgumentDescriptor; +import com.facebook.presto.spi.function.LambdaDescriptor; +import com.facebook.presto.spi.function.ScalarFunctionDescriptor; +import com.facebook.presto.spi.function.ScalarFunctionLambdaArgumentDescriptor; +import com.facebook.presto.spi.function.ScalarFunctionLambdaDescriptor; +import com.facebook.presto.spi.function.StaticMethodPointer; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkArgument; + +public class FunctionDescriptorParser +{ + private FunctionDescriptorParser() {} + public static ComplexTypeFunctionDescriptor parseFunctionDescriptor(ScalarFunctionDescriptor descriptor) + { + if (descriptor.outputToInputTransformationFunction().length > 1) { + throw new IllegalArgumentException("outputToInputTransformationFunction must contain at most 1 element."); + } + return new ComplexTypeFunctionDescriptor( + descriptor.isAccessingInputValues(), + parseLambdaDescriptors(descriptor.lambdaDescriptors()), + descriptor.argumentIndicesContainingMapOrArray().length == 1 ? + Optional.of(ImmutableSet.copyOf(Arrays.stream(descriptor.argumentIndicesContainingMapOrArray()[0].value()).iterator())) : Optional.empty(), + descriptor.outputToInputTransformationFunction().length == 1 ? + Optional.of(parseSubfieldTransformationFunction(descriptor.outputToInputTransformationFunction()[0])) : + Optional.empty()); + } + + private static Function, Set> parseSubfieldTransformationFunction(StaticMethodPointer staticMethodPointer) + { + Method subfieldTransformationMethod; + try { + subfieldTransformationMethod = ((Class) staticMethodPointer.clazz()).getDeclaredMethod(staticMethodPointer.method(), Set.class); + } + catch (NoSuchMethodException e) { + return null; + } + checkSubfieldTransformFunctionTypeSignature(subfieldTransformationMethod); + + return (Set subfields) -> { + try { + return (Set) subfieldTransformationMethod.invoke(null, subfields); + } + catch (IllegalAccessException | InvocationTargetException e) { + return ComplexTypeFunctionDescriptor.allSubfieldsRequired(subfields); + } + }; + } + + private static void checkSubfieldTransformFunctionTypeSignature(Method subfieldTransformationMethod) + { + { + String errorMessage = "Subfield transformation function must accept a single parameter of type java.util.Set"; + Type[] inputTypes = subfieldTransformationMethod.getGenericParameterTypes(); + checkArgument(inputTypes.length == 1, errorMessage); + checkTypeIsSetOfSubfields(inputTypes[0], errorMessage); + } + { + String errorMessage = "Subfield transformation function return type must be java.util.Set"; + Type type = subfieldTransformationMethod.getGenericReturnType(); + checkTypeIsSetOfSubfields(type, errorMessage); + } + } + + private static void checkTypeIsSetOfSubfields(Type type, String errorMessage) + { + final ParameterizedType setType = (ParameterizedType) type; + checkArgument(setType.getRawType().equals(Set.class), errorMessage); + checkArgument(setType.getActualTypeArguments()[0].equals(Subfield.class), errorMessage); + } + + private static List parseLambdaDescriptors(ScalarFunctionLambdaDescriptor[] lambdaDescriptors) + { + ImmutableList.Builder lambdaDescriptorBuilder = ImmutableList.builder(); + for (ScalarFunctionLambdaDescriptor lambdaDescriptor : lambdaDescriptors) { + lambdaDescriptorBuilder.add( + new LambdaDescriptor(lambdaDescriptor.callArgumentIndex(), + parseLambdaArgumentDescriptors(lambdaDescriptor.lambdaArgumentDescriptors()))); + } + return lambdaDescriptorBuilder.build(); + } + + private static Map parseLambdaArgumentDescriptors(ScalarFunctionLambdaArgumentDescriptor[] lambdaArgumentToCallArgumentIndexMapEntries) + { + ImmutableMap.Builder lambdaArgumentToCallArgumentIndexMap = ImmutableMap.builder(); + for (ScalarFunctionLambdaArgumentDescriptor entry : lambdaArgumentToCallArgumentIndexMapEntries) { + lambdaArgumentToCallArgumentIndexMap.put(entry.lambdaArgumentIndex(), + new LambdaArgumentDescriptor(entry.callArgumentIndex(), parseSubfieldTransformationFunction(entry.lambdaArgumentToInputTransformationFunction()))); + } + return lambdaArgumentToCallArgumentIndexMap.build(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementationHeader.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementationHeader.java index 0a83dc4233d49..3680d5bbfb974 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementationHeader.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementationHeader.java @@ -16,6 +16,7 @@ import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.operator.scalar.ScalarHeader; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.SqlFunctionVisibility; @@ -28,10 +29,12 @@ import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.DEFAULT_NAMESPACE; import static com.facebook.presto.operator.annotations.FunctionsParserHelper.parseDescription; +import static com.facebook.presto.operator.scalar.annotations.FunctionDescriptorParser.parseFunctionDescriptor; import static com.facebook.presto.spi.function.SqlFunctionVisibility.HIDDEN; import static com.google.common.base.CaseFormat.LOWER_CAMEL; import static com.google.common.base.CaseFormat.LOWER_UNDERSCORE; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.builder; import static java.util.Objects.requireNonNull; public class ScalarImplementationHeader @@ -77,19 +80,29 @@ public static List fromAnnotatedElement(AnnotatedEle ScalarOperator scalarOperator = annotated.getAnnotation(ScalarOperator.class); Optional description = parseDescription(annotated); - ImmutableList.Builder builder = ImmutableList.builder(); + ImmutableList.Builder builder = builder(); if (scalarFunction != null) { + ComplexTypeFunctionDescriptor functionDescriptor = parseFunctionDescriptor(scalarFunction.descriptor()); String baseName = scalarFunction.value().isEmpty() ? camelToSnake(annotatedName(annotated)) : scalarFunction.value(); - builder.add(new ScalarImplementationHeader(baseName, new ScalarHeader(description, scalarFunction.visibility(), scalarFunction.deterministic(), scalarFunction.calledOnNullInput()))); + builder.add(new ScalarImplementationHeader(baseName, new ScalarHeader(description, scalarFunction.visibility(), scalarFunction.deterministic(), + scalarFunction.calledOnNullInput(), functionDescriptor))); for (String alias : scalarFunction.alias()) { - builder.add(new ScalarImplementationHeader(alias, new ScalarHeader(description, scalarFunction.visibility(), scalarFunction.deterministic(), scalarFunction.calledOnNullInput()))); + builder.add(new ScalarImplementationHeader(alias, new ScalarHeader(description, scalarFunction.visibility(), scalarFunction.deterministic(), + scalarFunction.calledOnNullInput(), functionDescriptor))); } } if (scalarOperator != null) { - builder.add(new ScalarImplementationHeader(scalarOperator.value(), new ScalarHeader(description, HIDDEN, true, scalarOperator.value().isCalledOnNullInput()))); + builder.add(new ScalarImplementationHeader( + scalarOperator.value(), + new ScalarHeader( + description, + HIDDEN, + true, + scalarOperator.value().isCalledOnNullInput(), + parseFunctionDescriptor(scalarOperator.descriptor())))); } List result = builder.build(); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java index fa6eeb8d8e6f5..92c8db2502ff9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java @@ -17,6 +17,7 @@ import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.CodegenScalarFunction; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.Parameter; import com.facebook.presto.spi.function.RoutineCharacteristics; @@ -175,7 +176,8 @@ else if (method.isAnnotationPresent(SqlParameters.class)) { body, notVersioned(), SCALAR, - Optional.empty())) + Optional.empty(), + ComplexTypeFunctionDescriptor.defaultFunctionDescriptor())) .collect(toImmutableList()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 0161e29a09bf1..3d58f6479b739 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -193,6 +193,7 @@ public class FeaturesConfig private boolean optimizedRepartitioningEnabled; private boolean pushdownSubfieldsEnabled; + private boolean pushdownSubfieldsFromLambdaEnabled; private boolean tableWriterMergeOperatorEnabled = true; @@ -1841,6 +1842,19 @@ public boolean isPushdownSubfieldsEnabled() return pushdownSubfieldsEnabled; } + @Config("pushdown-subfields-from-lambda-enabled") + @ConfigDescription("Enable subfield pruning from lambda expressions") + public FeaturesConfig setPushdownSubfieldsFromLambdaEnabled(boolean pushdownSubfieldsFromLambdaEnabled) + { + this.pushdownSubfieldsFromLambdaEnabled = pushdownSubfieldsFromLambdaEnabled; + return this; + } + + public boolean isPushdownSubfieldsFromLambdaEnabled() + { + return pushdownSubfieldsFromLambdaEnabled; + } + @Config("experimental.pushdown-dereference-enabled") @ConfigDescription("Experimental: enable dereference pushdown") public FeaturesConfig setPushdownDereferenceEnabled(boolean pushdownDereferenceEnabled) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java index 246b6ff87445d..4e1b0e9c9329f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java @@ -17,15 +17,21 @@ import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.Subfield; import com.facebook.presto.common.Subfield.NestedField; +import com.facebook.presto.common.Subfield.PathElement; import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.MapType; import com.facebook.presto.common.type.RowType; import com.facebook.presto.expressions.DefaultRowExpressionTraversalVisitor; +import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; +import com.facebook.presto.spi.function.LambdaArgumentDescriptor; +import com.facebook.presto.spi.function.LambdaDescriptor; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.DistinctLimitNode; @@ -42,6 +48,7 @@ import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.LambdaDefinitionExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; @@ -64,30 +71,42 @@ import com.facebook.presto.sql.relational.RowExpressionOptimizer; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.IntStream; import static com.facebook.presto.SystemSessionProperties.isLegacyUnnest; import static com.facebook.presto.SystemSessionProperties.isPushdownSubfieldsEnabled; +import static com.facebook.presto.SystemSessionProperties.isPushdownSubfieldsFromArrayLambdasEnabled; import static com.facebook.presto.common.Subfield.allSubscripts; +import static com.facebook.presto.common.Subfield.noSubfield; import static com.facebook.presto.common.type.Varchars.isVarcharType; import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.DEFAULT_NAMESPACE; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.DEREFERENCE; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toList; public class PushdownSubfields implements PlanOptimizer { + public static final QualifiedObjectName CARDINALITY = QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "cardinality"); + public static final QualifiedObjectName ELEMENT_AT = QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "element_at"); + public static final QualifiedObjectName CAST = QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "$operator$cast"); private final Metadata metadata; private boolean isEnabledForTesting; @@ -138,7 +157,12 @@ public Rewriter(Session session, Metadata metadata) this.metadata = requireNonNull(metadata, "metadata is null"); this.functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); this.expressionOptimizer = new RowExpressionOptimizer(metadata); - this.subfieldExtractor = new SubfieldExtractor(functionResolution, expressionOptimizer, session.toConnectorSession()); + this.subfieldExtractor = new SubfieldExtractor( + functionResolution, + expressionOptimizer, + session.toConnectorSession(), + metadata.getFunctionAndTypeManager(), + isPushdownSubfieldsFromArrayLambdasEnabled(session)); } @Override @@ -264,7 +288,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext context) continue; } - Optional subfield = toSubfield(expression, functionResolution, expressionOptimizer, session.toConnectorSession()); + Optional subfield = toSubfield(expression, functionResolution, expressionOptimizer, session.toConnectorSession(), metadata.getFunctionAndTypeManager()); if (subfield.isPresent()) { context.get().addAssignment(variable, subfield.get()); continue; @@ -328,15 +352,24 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext conte String columnName = getColumnName(session, metadata, node.getTable(), entry.getValue()); + List subfieldsWithoutNoSubfield = subfields.stream().filter(subfield -> !containsNoSubfieldPathElement(subfield)).collect(toList()); + List subfieldsWithNoSubfield = subfields.stream().filter(subfield -> containsNoSubfieldPathElement(subfield)).collect(toList()); + // Prune subfields: if one subfield is a prefix of another subfield, keep the shortest one. // Example: {a.b.c, a.b} -> {a.b} - List columnSubfields = subfields.stream() - .filter(subfield -> !prefixExists(subfield, subfields)) + List columnSubfields = subfieldsWithoutNoSubfield.stream() + .filter(subfield -> !prefixExists(subfield, subfieldsWithoutNoSubfield)) .map(Subfield::getPath) .map(path -> new Subfield(columnName, path)) - .collect(toImmutableList()); + .collect(toList()); + + columnSubfields.addAll(subfieldsWithNoSubfield.stream() + .filter(subfield -> !isPrefixOf(dropNoSubfield(subfield), subfieldsWithoutNoSubfield)) + .map(Subfield::getPath) + .map(path -> new Subfield(columnName, path)) + .collect(toList())); - newAssignments.put(variable, entry.getValue().withRequiredSubfields(columnSubfields)); + newAssignments.put(variable, entry.getValue().withRequiredSubfields(ImmutableList.copyOf(columnSubfields))); } return new TableScanNode( @@ -426,7 +459,7 @@ public PlanNode visitUnnest(UnnestNode node, RewriteContext context) found = true; matchingSubfields.stream() .map(Subfield::getPath) - .map(path -> new Subfield(container.getName(), ImmutableList.builder() + .map(path -> new Subfield(container.getName(), ImmutableList.builder() .add(allSubscripts()) .addAll(path) .build())) @@ -481,11 +514,27 @@ private boolean isRowType(VariableReferenceExpression variable) return variable.getType() instanceof ArrayType && ((ArrayType) variable.getType()).getElementType() instanceof RowType; } + private static Subfield dropNoSubfield(Subfield subfield) + { + return new Subfield(subfield.getRootName(), + subfield.getPath().stream().filter(pathElement -> !(pathElement instanceof Subfield.NoSubfield)).collect(toImmutableList())); + } + + private static boolean containsNoSubfieldPathElement(Subfield subfield) + { + return subfield.getPath().stream().anyMatch(pathElement -> pathElement instanceof Subfield.NoSubfield); + } + private static boolean prefixExists(Subfield subfieldPath, Collection subfieldPaths) { return subfieldPaths.stream().anyMatch(path -> path.isPrefix(subfieldPath)); } + private static boolean isPrefixOf(Subfield subfieldPath, Collection subfieldPaths) + { + return subfieldPaths.stream().anyMatch(subfieldPath::isPrefix); + } + private static String getColumnName(Session session, Metadata metadata, TableHandle tableHandle, ColumnHandle columnHandle) { return metadata.getColumnMetadata(session, tableHandle, columnHandle).getName(); @@ -495,7 +544,8 @@ private static Optional toSubfield( RowExpression expression, StandardFunctionResolution functionResolution, ExpressionOptimizer expressionOptimizer, - ConnectorSession connectorSession) + ConnectorSession connectorSession, + FunctionAndTypeManager functionAndTypeManager) { ImmutableList.Builder elements = ImmutableList.builder(); while (true) { @@ -527,7 +577,8 @@ private static Optional toSubfield( } return Optional.empty(); } - if (expression instanceof CallExpression && functionResolution.isSubscriptFunction(((CallExpression) expression).getFunctionHandle())) { + if (expression instanceof CallExpression && + isSubscriptOrElementAtFunction((CallExpression) expression, functionResolution, functionAndTypeManager)) { List arguments = ((CallExpression) expression).getArguments(); RowExpression indexExpression = expressionOptimizer.optimize( arguments.get(1), @@ -569,44 +620,196 @@ private static final class SubfieldExtractor private final StandardFunctionResolution functionResolution; private final ExpressionOptimizer expressionOptimizer; private final ConnectorSession connectorSession; - - private SubfieldExtractor(StandardFunctionResolution functionResolution, ExpressionOptimizer expressionOptimizer, ConnectorSession connectorSession) + private final FunctionAndTypeManager functionAndTypeManager; + private final boolean isPushDownSubfieldsFromLambdasEnabled; + + private SubfieldExtractor( + StandardFunctionResolution functionResolution, + ExpressionOptimizer expressionOptimizer, + ConnectorSession connectorSession, + FunctionAndTypeManager functionAndTypeManager, + boolean isPushDownSubfieldsFromLambdasEnabled) { this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.expressionOptimizer = requireNonNull(expressionOptimizer, "expressionOptimizer is null"); - this.connectorSession = requireNonNull(connectorSession, "connectorSession is null"); + this.connectorSession = connectorSession; + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + this.isPushDownSubfieldsFromLambdasEnabled = isPushDownSubfieldsFromLambdasEnabled; } @Override public Void visitCall(CallExpression call, Context context) { - if (!functionResolution.isSubscriptFunction(call.getFunctionHandle())) { + ComplexTypeFunctionDescriptor functionDescriptor = functionAndTypeManager.getFunctionMetadata(call.getFunctionHandle()).getDescriptor(); + if (isSubscriptOrElementAtFunction(call, functionResolution, functionAndTypeManager)) { + Optional subfield = toSubfield(call, functionResolution, expressionOptimizer, connectorSession, functionAndTypeManager); + if (subfield.isPresent()) { + if (context.isPruningLambdaSubfieldsPossible()) { + addRequiredLambdaSubfields(context, subfield.get()); + } + else { + context.subfields.add(subfield.get()); + } + } + else { + call.getArguments().forEach(argument -> argument.accept(this, context)); + } + return null; + } + if (!isPushDownSubfieldsFromLambdasEnabled) { + context.setLambdaSubfields(Context.ALL_SUBFIELDS_OF_ARRAY_ELEMENT_OR_MAP_VALUE); call.getArguments().forEach(argument -> argument.accept(this, context)); return null; } + Set lambdaSubfieldsOriginal = context.getLambdaSubfields(); + if ((functionDescriptor.isAccessingInputValues() && functionDescriptor.getLambdaDescriptors().isEmpty())) { + // If function internally accesses the input values we cannot prune any unaccessed lambda subfields since we do not know what subfields function accessed. + context.giveUpOnCollectingLambdaSubfields(); + } - // visit subscript expressions only - Optional subfield = toSubfield(call, functionResolution, expressionOptimizer, connectorSession); - if (subfield.isPresent()) { - context.subfields.add(subfield.get()); + // We need to apply output to input transformation function in order to make sense of all lambda subfields accessed in outer functions w.r.t. the + // input of the current function. + if (functionDescriptor.getOutputToInputTransformationFunction().isPresent()) { + Set transformedLambdaSubfields = + functionDescriptor.getOutputToInputTransformationFunction().get().apply(context.getLambdaSubfields()); + context.setLambdaSubfields(ImmutableSet.copyOf(transformedLambdaSubfields)); } - else { - call.getArguments().forEach(argument -> argument.accept(this, context)); + + Set argumentIndicesContainingMapOrArray = functionDescriptor.getArgumentIndicesContainingMapOrArray() + .orElse(IntStream.range(0, call.getArguments().size()) + .filter(argIndex -> isMapOrArrayOfRowType(call.getArguments().get(argIndex))) + .boxed() + .collect(toImmutableSet())); + + // All the lambda subfields collected in outer functions relate only to the arguments of the function specified in + // functionDescriptor.argumentIndicesContainingMapOrArray. + Map> lambdaSubfieldsFromOuterFunctions = argumentIndicesContainingMapOrArray.stream() + .collect(toImmutableMap(callArgumentIndex -> callArgumentIndex, unused -> ImmutableSet.copyOf(context.getLambdaSubfields()))); + + // If the function accepts lambdas, add all the lambda subfields from each lambda. + Map> lambdaSubfieldsFromCurrentFunction = ImmutableMap.of(); + for (LambdaDescriptor lambdaDescriptor : functionDescriptor.getLambdaDescriptors()) { + Optional>> lambdaSubfields = collectLambdaSubfields(call, lambdaDescriptor); + if (!lambdaSubfields.isPresent()) { + context.giveUpOnCollectingLambdaSubfields(); + call.getArguments().forEach(argument -> argument.accept(this, context)); + return null; + } + lambdaSubfieldsFromCurrentFunction = merge(lambdaSubfieldsFromCurrentFunction, lambdaSubfields.get()); + } + + Map> lambdaSubfields = merge(lambdaSubfieldsFromOuterFunctions, lambdaSubfieldsFromCurrentFunction); + + lambdaSubfields = addNoSubfieldIfNoAccessedSubfieldsFound(call, lambdaSubfields); + + // We need to continue visiting the function arguments and collect all lambda subfields in inner function calls as well as non-lambda subfields in all + // function arguments. Once reached the leaf node, we will try to prune the subfields of the input field, subscript, or subfield. + for (int callArgumentIndex = 0; callArgumentIndex < call.getArguments().size(); callArgumentIndex++) { + // Since context is global during the traversal of all the nodes in expression tree, we need to pass lambda subfields only to those function + // arguments that they relate to. + if (lambdaSubfields.containsKey(callArgumentIndex)) { + context.setLambdaSubfields(lambdaSubfields.get(callArgumentIndex)); + } + else { + context.setLambdaSubfields(Context.ALL_SUBFIELDS_OF_ARRAY_ELEMENT_OR_MAP_VALUE); + } + call.getArguments().get(callArgumentIndex).accept(this, context); } + + // When we are done with inner calls (child nodes) we need to restore lambda subfields we received from parent expression to handle such situations like + // in example below + // SELECT * FROM my_table WHERE ANY_MATCH(column1, x -> x.ds > '2023-01-01') AND ALL_MATCH(column2, x -> STRPOS(x.comment, 'Presto') > 0) + // After we are done with ANY_MATCH, we need to restore the lambda subfields to what we received from parent node 'AND' so that it does not collide with + // lambda subfields of ALL_MATCH function. + context.setLambdaSubfields(lambdaSubfieldsOriginal); return null; } + private static Map> merge(Map> s1, Map> s2) + { + Map> result = new HashMap<>(s1); + s2.forEach((callArgumentIndex, subfields) -> result.merge( + callArgumentIndex, + subfields, + (lambdaSubfields1, lambdaSubfields2) -> ImmutableSet.builder().addAll(lambdaSubfields1).addAll(lambdaSubfields2).build())); + return ImmutableMap.copyOf(result); + } + + private static Map> addNoSubfieldIfNoAccessedSubfieldsFound(CallExpression call, Map> argumentIndexToLambdaSubfieldsMap) + { + ImmutableMap.Builder> argumentIndexToLambdaSubfieldsMapBuilder = ImmutableMap.builder(); + for (Integer callArgumentIndex : argumentIndexToLambdaSubfieldsMap.keySet()) { + if (!argumentIndexToLambdaSubfieldsMap.get(callArgumentIndex).isEmpty()) { + argumentIndexToLambdaSubfieldsMapBuilder.put(callArgumentIndex, argumentIndexToLambdaSubfieldsMap.get(callArgumentIndex)); + } + else { + RowExpression argument = call.getArguments().get(callArgumentIndex); + if (isMapOrArrayOfRowType(argument)) { + argumentIndexToLambdaSubfieldsMapBuilder.put(callArgumentIndex, ImmutableSet.of(new Subfield("", ImmutableList.of(allSubscripts(), noSubfield())))); + } + } + } + return argumentIndexToLambdaSubfieldsMapBuilder.build(); + } + + private static boolean isMapOrArrayOfRowType(RowExpression argument) + { + return (argument.getType() instanceof ArrayType && ((ArrayType) argument.getType()).getElementType() instanceof RowType) || + (argument.getType() instanceof MapType && ((MapType) argument.getType()).getValueType() instanceof RowType); + } + + private Optional>> collectLambdaSubfields(CallExpression call, LambdaDescriptor lambdaDescriptor) + { + Map> argumentIndexToLambdaSubfieldsMap = new HashMap<>(); + if (!(call.getArguments().get(lambdaDescriptor.getCallArgumentIndex()) instanceof LambdaDefinitionExpression)) { + // In this case, we cannot prune the subfields because the function can potentially access all subfields + return Optional.empty(); + } + LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) call.getArguments().get(lambdaDescriptor.getCallArgumentIndex()); + + Context subContext = new Context(); + lambda.getBody().accept(this, subContext); + for (int lambdaArgumentIndex : lambdaDescriptor.getLambdaArgumentDescriptors().keySet()) { + final LambdaArgumentDescriptor lambdaArgumentDescriptor = lambdaDescriptor.getLambdaArgumentDescriptors().get(lambdaArgumentIndex); + int callArgumentIndex = lambdaArgumentDescriptor.getCallArgumentIndex(); + argumentIndexToLambdaSubfieldsMap.putIfAbsent(callArgumentIndex, new HashSet<>()); + String root = lambda.getArguments().get(lambdaArgumentIndex); + if (subContext.variables.stream().anyMatch(variable -> variable.getName().equals(root))) { + // The entire struct was accessed. + return Optional.empty(); + } + Set transformedLambdaSubfields = lambdaArgumentDescriptor.getLambdaArgumentToInputTransformationFunction().apply( + subContext.subfields.stream() + .filter(x -> x.getRootName().equals(root)) + .collect(toImmutableSet())); + argumentIndexToLambdaSubfieldsMap.get(callArgumentIndex).addAll(transformedLambdaSubfields); + } + return Optional.of(ImmutableMap.copyOf(argumentIndexToLambdaSubfieldsMap)); + } + @Override public Void visitSpecialForm(SpecialFormExpression specialForm, Context context) { - if (specialForm.getForm() != DEREFERENCE) { + if (specialForm.getForm() == IS_NULL) { + if (specialForm.getArguments().get(0) instanceof VariableReferenceExpression && specialForm.getArguments().get(0).getType() instanceof RowType) { + context.subfields.add(new Subfield(((VariableReferenceExpression) specialForm.getArguments().get(0)).getName(), ImmutableList.of(noSubfield()))); + return null; + } + } + else if (specialForm.getForm() != DEREFERENCE) { specialForm.getArguments().forEach(argument -> argument.accept(this, context)); return null; } - Optional subfield = toSubfield(specialForm, functionResolution, expressionOptimizer, connectorSession); + Optional subfield = toSubfield(specialForm, functionResolution, expressionOptimizer, connectorSession, functionAndTypeManager); + if (subfield.isPresent()) { - context.subfields.add(subfield.get()); + if (context.isPruningLambdaSubfieldsPossible()) { + addRequiredLambdaSubfields(context, subfield.get()); + } + else { + context.subfields.add(subfield.get()); + } } else { specialForm.getArguments().forEach(argument -> argument.accept(this, context)); @@ -614,9 +817,33 @@ public Void visitSpecialForm(SpecialFormExpression specialForm, Context context) return null; } + /** + * Adds lambda subfields from the context to the list of the required subfields of the field/subscript/subfield provided in parameter 'input'. This function should be + * invoked + * once we reached leaf node while visiting the expression tree. Effectively, it prunes all unaccessed subfields of the 'input'. + * + * @param context - SubfieldExtractor context + * @param input - input field, subscript, or subfield, for which lambda subfields were collected. + */ + private void addRequiredLambdaSubfields(Context context, Subfield input) + { + Set lambdaSubfields = context.getLambdaSubfields(); + for (Subfield lambdaSubfield : lambdaSubfields) { + List newPath = ImmutableList.builder() + .addAll(input.getPath()) + .addAll(lambdaSubfield.getPath()) + .build(); + context.subfields.add(new Subfield(input.getRootName(), newPath)); + } + } + @Override public Void visitVariableReference(VariableReferenceExpression reference, Context context) { + if (context.isPruningLambdaSubfieldsPossible()) { + addRequiredLambdaSubfields(context, toSubfield(reference, functionResolution, expressionOptimizer, connectorSession, functionAndTypeManager).get()); + return null; + } context.variables.add(reference); return null; } @@ -624,9 +851,11 @@ public Void visitVariableReference(VariableReferenceExpression reference, Contex private static final class Context { + public static final Set ALL_SUBFIELDS_OF_ARRAY_ELEMENT_OR_MAP_VALUE = ImmutableSet.of(new Subfield("", ImmutableList.of(allSubscripts()))); // Variables whose subfields cannot be pruned private final Set variables = new HashSet<>(); private final Set subfields = new HashSet<>(); + private Set lambdaSubfields = ALL_SUBFIELDS_OF_ARRAY_ELEMENT_OR_MAP_VALUE; private void addAssignment(VariableReferenceExpression variable, VariableReferenceExpression otherVariable) { @@ -656,7 +885,7 @@ private void addAssignment(VariableReferenceExpression variable, Subfield subfie matchingSubfields.stream() .map(Subfield::getPath) - .map(path -> new Subfield(subfield.getRootName(), ImmutableList.builder() + .map(path -> new Subfield(subfield.getRootName(), ImmutableList.builder() .addAll(subfield.getPath()) .addAll(path) .build())) @@ -669,6 +898,37 @@ private List findSubfields(String rootName) .filter(subfield -> rootName.equals(subfield.getRootName())) .collect(toImmutableList()); } + + public void setLambdaSubfields(Set lambdaSubfields) + { + this.lambdaSubfields = lambdaSubfields; + } + + public Set getLambdaSubfields() + { + return lambdaSubfields; + } + + private void giveUpOnCollectingLambdaSubfields() + { + setLambdaSubfields(ALL_SUBFIELDS_OF_ARRAY_ELEMENT_OR_MAP_VALUE); + } + + private boolean isPruningLambdaSubfieldsPossible() + { + return !getLambdaSubfields().isEmpty() && + getLambdaSubfields().stream() + .noneMatch( + subfield -> subfield.getPath().stream() + .skip(subfield.getPath().size() - 1) + .anyMatch(pathElement -> pathElement.equals(allSubscripts()))); + } } } + + private static boolean isSubscriptOrElementAtFunction(CallExpression expression, StandardFunctionResolution functionResolution, FunctionAndTypeManager functionAndTypeManager) + { + return functionResolution.isSubscriptFunction(expression.getFunctionHandle()) || + functionAndTypeManager.getFunctionMetadata(expression.getFunctionHandle()).getName().equals(ELEMENT_AT); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index d8522aaa7d03f..3a7bd96889c65 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -169,6 +169,7 @@ public void testDefaults() .setOptimizeConstantGroupingKeys(true) .setMaxConcurrentMaterializations(3) .setPushdownSubfieldsEnabled(false) + .setPushdownSubfieldsFromLambdaEnabled(false) .setPushdownDereferenceEnabled(false) .setTableWriterMergeOperatorEnabled(true) .setIndexLoaderTimeout(new Duration(20, SECONDS)) @@ -362,6 +363,7 @@ public void testExplicitPropertyMappings() .put("optimizer.optimize-constant-grouping-keys", "false") .put("max-concurrent-materializations", "5") .put("experimental.pushdown-subfields-enabled", "true") + .put("pushdown-subfields-from-lambda-enabled", "true") .put("experimental.pushdown-dereference-enabled", "true") .put("experimental.table-writer-merge-operator-enabled", "false") .put("index-loader-timeout", "10s") @@ -552,6 +554,7 @@ public void testExplicitPropertyMappings() .setOptimizeConstantGroupingKeys(false) .setMaxConcurrentMaterializations(5) .setPushdownSubfieldsEnabled(true) + .setPushdownSubfieldsFromLambdaEnabled(true) .setPushdownDereferenceEnabled(true) .setTableWriterMergeOperatorEnabled(false) .setIndexLoaderTimeout(new Duration(10, SECONDS)) diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/CodegenScalarFunction.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/CodegenScalarFunction.java index 90f8c390a6fd1..28351daf6ca87 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/CodegenScalarFunction.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/CodegenScalarFunction.java @@ -33,4 +33,6 @@ boolean deterministic() default true; boolean calledOnNullInput() default false; + + ScalarFunctionDescriptor descriptor() default @ScalarFunctionDescriptor; } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/ComplexTypeFunctionDescriptor.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/ComplexTypeFunctionDescriptor.java new file mode 100644 index 0000000000000..55ec13ca1ac3b --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/ComplexTypeFunctionDescriptor.java @@ -0,0 +1,234 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import com.facebook.presto.common.Subfield; +import com.facebook.presto.common.type.TypeSignature; + +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static com.facebook.presto.common.Subfield.allSubscripts; +import static com.facebook.presto.common.Utils.checkArgument; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; +import static java.util.Collections.unmodifiableList; +import static java.util.Collections.unmodifiableSet; +import static java.util.Objects.requireNonNull; + +/** + * Contains properties that describe how the function operates on Map or Array inputs. + */ +public class ComplexTypeFunctionDescriptor +{ + public static final List MAP_AND_ARRAY = unmodifiableList(Arrays.asList("map", "array")); + public static final ComplexTypeFunctionDescriptor DEFAULT = new ComplexTypeFunctionDescriptor( + true, + emptyList(), + Optional.of(emptySet()), + Optional.of(ComplexTypeFunctionDescriptor::allSubfieldsRequired)); + + /** + * Indicates whether the function accessing subfields. + */ + private final boolean isAccessingInputValues; + + /** + * Set of indices of the function arguments containing map or array arguments. Those arguments are important because all accessed subfields collected so far relate only to + * those map or array arguments and will be passed only to those arguments during the expression analysis phase. + * If argumentIndicesContainingMapOrArray is Optional.empty(), it indicates that accessed subfields collected so far relate to all function arguments + * are of the map or array types. For the vast majority of function, this value should be used. + * If the value of argumentIndicesContainingMapOrArray is present, it indicates that accessed subfields collected so far relate only to subset of the arguments. + * For example, in MapConstructor function accessed map value subfield from outer call relate only to second argument and therefore for this + * argumentIndicesContainingMapOrArray needs to be set to Optional.of(ImmutableSet.of(1)). + */ + private final Optional> argumentIndicesContainingMapOrArray; + + /** + * Contains the transformation function to convert the output back to the input elements of the array or map. + * If outputToInputTransformationFunction is Optional.empty(), it indicates that transformation is not required and equivalent to the identity function + */ + private final Optional, Set>> outputToInputTransformationFunction; + + /** + * Contains the description of all lambdas that this function accepts. + * If function does not accept any lambda parameter, then lambdaDescriptors should be an empty list. + */ + private final List lambdaDescriptors; + + public ComplexTypeFunctionDescriptor( + boolean isAccessingInputValues, + List lambdaDescriptors, + Optional> argumentIndicesContainingMapOrArray, + Optional, Set>> outputToInputTransformationFunction, + Signature signature) + { + this(isAccessingInputValues, lambdaDescriptors, argumentIndicesContainingMapOrArray, outputToInputTransformationFunction, signature.getArgumentTypes()); + } + public ComplexTypeFunctionDescriptor( + boolean isAccessingInputValues, + List lambdaDescriptors, + Optional> argumentIndicesContainingMapOrArray, + Optional, Set>> outputToInputTransformationFunction, + List argumentTypes) + { + this(isAccessingInputValues, lambdaDescriptors, argumentIndicesContainingMapOrArray, outputToInputTransformationFunction); + if (argumentIndicesContainingMapOrArray.isPresent()) { + checkArgument(argumentIndicesContainingMapOrArray.get().stream().allMatch(index -> index >= 0 && + index < argumentTypes.size() && + MAP_AND_ARRAY.contains(argumentTypes.get(index).getBase().toLowerCase(Locale.ENGLISH)))); + } + for (LambdaDescriptor lambdaDescriptor : lambdaDescriptors) { + checkArgument(lambdaDescriptor.getCallArgumentIndex() >= 0 && argumentTypes.get(lambdaDescriptor.getCallArgumentIndex()).isFunction()); + checkArgument(lambdaDescriptor.getLambdaArgumentDescriptors().keySet().stream().allMatch( + argumentIndex -> argumentIndex >= 0 && argumentIndex < argumentTypes.size())); + for (Integer lambdaArgumentIndex : lambdaDescriptor.getLambdaArgumentDescriptors().keySet()) { + checkArgument(lambdaArgumentIndex >= 0 && + lambdaArgumentIndex < argumentTypes.get(lambdaDescriptor.getCallArgumentIndex()).getParameters().size() - 1); + LambdaArgumentDescriptor lambdaArgumentDescriptor = lambdaDescriptor.getLambdaArgumentDescriptors().get(lambdaArgumentIndex); + checkArgument(lambdaArgumentDescriptor.getCallArgumentIndex() >= 0 && + lambdaArgumentDescriptor.getCallArgumentIndex() < argumentTypes.size()); + } + } + } + + public ComplexTypeFunctionDescriptor( + boolean isAccessingInputValues, + List lambdaDescriptors, + Optional> argumentIndicesContainingMapOrArray, + Optional, Set>> outputToInputTransformationFunction) + { + requireNonNull(argumentIndicesContainingMapOrArray, "argumentIndicesContainingMapOrArray is null"); + this.isAccessingInputValues = isAccessingInputValues; + this.lambdaDescriptors = unmodifiableList(requireNonNull(lambdaDescriptors, "lambdaDescriptors is null")); + this.argumentIndicesContainingMapOrArray = argumentIndicesContainingMapOrArray.isPresent() ? + Optional.of(unmodifiableSet(argumentIndicesContainingMapOrArray.get())) : + Optional.empty(); + this.outputToInputTransformationFunction = requireNonNull(outputToInputTransformationFunction, "outputToInputTransformationFunction is null"); + } + + public static ComplexTypeFunctionDescriptor defaultFunctionDescriptor() + { + return DEFAULT; + } + + public boolean isAccessingInputValues() + { + return isAccessingInputValues; + } + + public Optional> getArgumentIndicesContainingMapOrArray() + { + return argumentIndicesContainingMapOrArray; + } + + public List getLambdaDescriptors() + { + return lambdaDescriptors; + } + + public boolean isAcceptingLambdaArgument() + { + return !lambdaDescriptors.isEmpty(); + } + + public Optional, Set>> getOutputToInputTransformationFunction() + { + return outputToInputTransformationFunction; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ComplexTypeFunctionDescriptor that = (ComplexTypeFunctionDescriptor) o; + return isAccessingInputValues == that.isAccessingInputValues && + Objects.equals(argumentIndicesContainingMapOrArray, that.argumentIndicesContainingMapOrArray) && + Objects.equals(outputToInputTransformationFunction, that.outputToInputTransformationFunction) && + Objects.equals(lambdaDescriptors, that.lambdaDescriptors); + } + + @Override + public int hashCode() + { + return Objects.hash(isAccessingInputValues, argumentIndicesContainingMapOrArray, outputToInputTransformationFunction, lambdaDescriptors); + } + + /** + * Adds allSubscripts on top of the path for every subfield in 'subfields'. + * + * @param subfields set of Subfield to transform + * @return transformed copy of the input set of subfields with allSubscripts. + */ + public static Set prependAllSubscripts(Set subfields) + { + return subfields.stream().map(subfield -> new Subfield(subfield.getRootName(), + unmodifiableList( + Stream.concat( + Arrays.asList(allSubscripts()).stream(), + subfield.getPath().stream()).collect(Collectors.toList())))) + .collect(Collectors.toSet()); + } + + /** + * Transformation function that overrides all lambda subfields from outer functions with the single subfield with allSubscripts in its path. + * Essentially, it instructs to include all subfields of the array element or map value. This function is most commonly used with the function that + * returns the entire value from its input or accesses input values internally. + * + * @return one subfield with allSubscripts in its path. + */ + public static Set allSubfieldsRequired(Set subfields) + { + if (subfields.isEmpty()) { + return unmodifiableSet(Stream.of(new Subfield("", Arrays.asList(allSubscripts()))).collect(Collectors.toSet())); + } + return subfields; + } + + /** + * Transformation function that removes any previously accessed subfields. This function is most commonly used with the function that do not return values from its input. + * + * @return empty set. + */ + public static Set clearRequiredSubfields(Set ignored) + { + return emptySet(); + } + + /** + * Removes the second path element from every subfield in 'subfields'. + * + * @param subfields set of Subfield to transform + * @return transformed copy of the input set of subfields with removed the second path element. + */ + public static Set removeSecondPathElement(Set subfields) + { + return subfields.stream().map(subfield -> new Subfield(subfield.getRootName(), + unmodifiableList( + Stream.concat(Arrays.asList(subfield.getPath().get(0)).stream(), subfield.getPath().stream().skip(2)).collect(Collectors.toList())))) + .collect(Collectors.toSet()); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadata.java index a4a6127fee134..a82a744b32fbe 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadata.java @@ -23,6 +23,7 @@ import java.util.Objects; import java.util.Optional; +import static com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor.defaultFunctionDescriptor; import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; import static java.util.Collections.unmodifiableList; import static java.util.Objects.requireNonNull; @@ -40,6 +41,7 @@ public class FunctionMetadata private final boolean deterministic; private final boolean calledOnNullInput; private final FunctionVersion version; + private final ComplexTypeFunctionDescriptor descriptor; public FunctionMetadata( QualifiedObjectName name, @@ -53,6 +55,19 @@ public FunctionMetadata( this(name, Optional.empty(), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType, deterministic, calledOnNullInput, notVersioned()); } + public FunctionMetadata( + QualifiedObjectName name, + List argumentTypes, + TypeSignature returnType, + FunctionKind functionKind, + FunctionImplementationType implementationType, + boolean deterministic, + boolean calledOnNullInput, + ComplexTypeFunctionDescriptor functionDescriptor) + { + this(name, Optional.empty(), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType, deterministic, calledOnNullInput, notVersioned(), functionDescriptor); + } + public FunctionMetadata( QualifiedObjectName name, List argumentTypes, @@ -68,6 +83,23 @@ public FunctionMetadata( this(name, Optional.empty(), argumentTypes, Optional.of(argumentNames), returnType, functionKind, Optional.of(language), implementationType, deterministic, calledOnNullInput, version); } + public FunctionMetadata( + QualifiedObjectName name, + List argumentTypes, + List argumentNames, + TypeSignature returnType, + FunctionKind functionKind, + Language language, + FunctionImplementationType implementationType, + boolean deterministic, + boolean calledOnNullInput, + FunctionVersion version, + ComplexTypeFunctionDescriptor functionDescriptor) + { + this(name, Optional.empty(), argumentTypes, Optional.of(argumentNames), returnType, functionKind, Optional.of(language), implementationType, deterministic, + calledOnNullInput, version, functionDescriptor); + } + public FunctionMetadata( OperatorType operatorType, List argumentTypes, @@ -80,6 +112,19 @@ public FunctionMetadata( this(operatorType.getFunctionName(), Optional.of(operatorType), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType, deterministic, calledOnNullInput, notVersioned()); } + public FunctionMetadata( + OperatorType operatorType, + List argumentTypes, + TypeSignature returnType, + FunctionKind functionKind, + FunctionImplementationType implementationType, + boolean deterministic, + boolean calledOnNullInput, + ComplexTypeFunctionDescriptor functionDescriptor) + { + this(operatorType.getFunctionName(), Optional.of(operatorType), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType, deterministic, calledOnNullInput, notVersioned(), functionDescriptor); + } + private FunctionMetadata( QualifiedObjectName name, Optional operatorType, @@ -92,6 +137,35 @@ private FunctionMetadata( boolean deterministic, boolean calledOnNullInput, FunctionVersion version) + { + this( + name, + operatorType, + argumentTypes, + argumentNames, + returnType, + functionKind, + language, + implementationType, + deterministic, + calledOnNullInput, + version, + defaultFunctionDescriptor()); + } + + private FunctionMetadata( + QualifiedObjectName name, + Optional operatorType, + List argumentTypes, + Optional> argumentNames, + TypeSignature returnType, + FunctionKind functionKind, + Optional language, + FunctionImplementationType implementationType, + boolean deterministic, + boolean calledOnNullInput, + FunctionVersion version, + ComplexTypeFunctionDescriptor functionDescriptor) { this.name = requireNonNull(name, "name is null"); this.operatorType = requireNonNull(operatorType, "operatorType is null"); @@ -104,8 +178,14 @@ private FunctionMetadata( this.deterministic = deterministic; this.calledOnNullInput = calledOnNullInput; this.version = requireNonNull(version, "version is null"); + requireNonNull(functionDescriptor, "functionDescriptor is null"); + this.descriptor = new ComplexTypeFunctionDescriptor( + functionDescriptor.isAccessingInputValues(), + functionDescriptor.getLambdaDescriptors(), + functionDescriptor.getArgumentIndicesContainingMapOrArray(), + functionDescriptor.getOutputToInputTransformationFunction(), + argumentTypes); } - public FunctionKind getFunctionKind() { return functionKind; @@ -161,6 +241,11 @@ public FunctionVersion getVersion() return version; } + public ComplexTypeFunctionDescriptor getDescriptor() + { + return descriptor; + } + @Override public boolean equals(Object obj) { @@ -181,12 +266,13 @@ public boolean equals(Object obj) Objects.equals(this.implementationType, other.implementationType) && Objects.equals(this.deterministic, other.deterministic) && Objects.equals(this.calledOnNullInput, other.calledOnNullInput) && - Objects.equals(this.version, other.version); + Objects.equals(this.version, other.version) && + Objects.equals(this.descriptor, other.descriptor); } @Override public int hashCode() { - return Objects.hash(name, operatorType, argumentTypes, argumentNames, returnType, functionKind, language, implementationType, deterministic, calledOnNullInput, version); + return Objects.hash(name, operatorType, argumentTypes, argumentNames, returnType, functionKind, language, implementationType, deterministic, calledOnNullInput, version, descriptor); } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/IntArray.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/IntArray.java new file mode 100644 index 0000000000000..ffd3332a6c6e5 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/IntArray.java @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.TYPE; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@Target({METHOD, TYPE}) +public @interface IntArray { + int[] value() default {}; +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/LambdaArgumentDescriptor.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/LambdaArgumentDescriptor.java new file mode 100644 index 0000000000000..2e6eaee14f94d --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/LambdaArgumentDescriptor.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import com.facebook.presto.common.Subfield; + +import java.util.Objects; +import java.util.Set; +import java.util.function.Function; + +import static java.util.Objects.requireNonNull; + +public class LambdaArgumentDescriptor +{ + /** + * Index of the function argument that contains the function input (Array or Map), to which this lambda argument relate to. + */ + final int callArgumentIndex; + + /** + * Contains the transformation function between the subfields of this lambda argument and the input of the function. + * + * The reason this transformation is needed because the input of the function is the Array or Map, while the lambda arguments are the element of the array or key/value of + * the map. Specifically for map, the transformation function for the lambda argument of the map key will be different from the transformation of the map value. + * If transformation succeeded, then the returned value contains the transformed set of lambda subfields. Otherwise, the function must return Optional.empty() + * value. + */ + private final Function, Set> lambdaArgumentToInputTransformationFunction; + + public LambdaArgumentDescriptor(int callArgumentIndex, Function, Set> lambdaArgumentToInputTransformationFunction) + { + this.callArgumentIndex = callArgumentIndex; + this.lambdaArgumentToInputTransformationFunction = requireNonNull(lambdaArgumentToInputTransformationFunction, "lambdaArgumentToInputTransformationFunction is null"); + } + + public int getCallArgumentIndex() + { + return callArgumentIndex; + } + + public Function, Set> getLambdaArgumentToInputTransformationFunction() + { + return lambdaArgumentToInputTransformationFunction; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + LambdaArgumentDescriptor that = (LambdaArgumentDescriptor) o; + return callArgumentIndex == that.callArgumentIndex && Objects.equals(lambdaArgumentToInputTransformationFunction, that.lambdaArgumentToInputTransformationFunction); + } + + @Override + public int hashCode() + { + return Objects.hash(callArgumentIndex, lambdaArgumentToInputTransformationFunction); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/LambdaDescriptor.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/LambdaDescriptor.java new file mode 100644 index 0000000000000..2a933d943e62f --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/LambdaDescriptor.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import java.util.Map; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public class LambdaDescriptor +{ + /** + * Index of the argument in the Call expression of the lambda function that this LambdaDescriptor represents + */ + final int callArgumentIndex; + + /** + * Map of lambda argument descriptors where the key corresponds to the index in the list of lambda argument and value is the descriptor of the argument. + */ + final Map lambdaArgumentDescriptors; + + public LambdaDescriptor(int callArgumentIndex, Map lambdaArgumentDescriptors) + { + this.callArgumentIndex = callArgumentIndex; + this.lambdaArgumentDescriptors = requireNonNull(lambdaArgumentDescriptors, "lambdaArgumentDescriptors is null"); + } + + public int getCallArgumentIndex() + { + return callArgumentIndex; + } + + public Map getLambdaArgumentDescriptors() + { + return lambdaArgumentDescriptors; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + LambdaDescriptor that = (LambdaDescriptor) o; + return callArgumentIndex == that.callArgumentIndex && Objects.equals(lambdaArgumentDescriptors, that.lambdaArgumentDescriptors); + } + + @Override + public int hashCode() + { + return Objects.hash(callArgumentIndex, lambdaArgumentDescriptors); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunction.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunction.java index 273f060278c38..58ef8d5dac96e 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunction.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunction.java @@ -34,4 +34,6 @@ boolean deterministic() default true; boolean calledOnNullInput() default false; + + ScalarFunctionDescriptor descriptor() default @ScalarFunctionDescriptor; } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunctionDescriptor.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunctionDescriptor.java new file mode 100644 index 0000000000000..1b1cc0bc96a2b --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunctionDescriptor.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.TYPE; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@Target({METHOD, TYPE}) +public @interface ScalarFunctionDescriptor +{ + /** + * Indicates whether the function accessing subfields. + */ + boolean isAccessingInputValues() default true; + + /** + * Set of indices of the function arguments containing map or array arguments. Those arguments are important because all accessed subfields collected so far relate only to + * those map or array arguments and will be passed only to those arguments during the expression analysis phase. + * If argumentIndicesContainingMapOrArray is empty array, it indicates that all function arguments of the map or array types (similarly to + * Optional.empty() value in ComplexTypeFunctionDescriptor.argumentIndicesContainingMapOrArray). For the vast majority of function, this value should + * be used. + * If argumentIndicesContainingMapOrArray is a non-empty array, it must contain a single value of @IntArray type containing the array of the arguments where + * accessed subfields need to be passed to. + */ + IntArray[] argumentIndicesContainingMapOrArray() default {}; + + /** + * Contains the transformation function to convert the output back to the input elements of the array or map. + */ + StaticMethodPointer[] outputToInputTransformationFunction() default {@StaticMethodPointer(clazz = ComplexTypeFunctionDescriptor.class, method = "allSubfieldsRequired")}; + + /** + * Contains the description of all lambdas that this function accepts. + * If function does not accept any lambda parameter, then lambdaDescriptors should be an empty list. + */ + ScalarFunctionLambdaDescriptor[] lambdaDescriptors() default {}; +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunctionLambdaArgumentDescriptor.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunctionLambdaArgumentDescriptor.java new file mode 100644 index 0000000000000..11d5748aadaac --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunctionLambdaArgumentDescriptor.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.TYPE; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@Target({METHOD, TYPE}) +public @interface ScalarFunctionLambdaArgumentDescriptor +{ + /** + * Index of the argument in the lambda argument list. Example: for the lambda (x, y) -> x + y, lambdaArgumentIndex for x and y are 0 and 1 respectively. + */ + int lambdaArgumentIndex(); + + /** + * Index of the function argument that contains the function input (Array or Map), to which this lambda argument relate to. + */ + int callArgumentIndex(); + + /** + * Contains the transformation function between the subfields of this lambda argument and the input of the function. + * + * The reason this transformation is needed because the input of the function is the Array or Map, while the lambda arguments are the element of the array or key/value of + * the map. Specifically for map, the transformation function for the lambda argument of the map key will be different from the transformation of the map value. + * If transformation succeeded, then the returned value contains the transformed set of lambda subfields. Otherwise, the function must return Optional.empty() + * value. + */ + StaticMethodPointer lambdaArgumentToInputTransformationFunction() default @StaticMethodPointer( + clazz = ComplexTypeFunctionDescriptor.class, method = "prependAllSubscripts"); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunctionLambdaDescriptor.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunctionLambdaDescriptor.java new file mode 100644 index 0000000000000..51ec612b5d248 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunctionLambdaDescriptor.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.TYPE; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@Target({METHOD, TYPE}) +public @interface ScalarFunctionLambdaDescriptor +{ + /** + * Index of the argument in the Call expression of the lambda function that this LambdaDescriptor represents + */ + int callArgumentIndex(); + + /** + * Map of lambda argument descriptors where the key corresponds to the index in the list of lambda argument and value is the descriptor of the argument. + */ + ScalarFunctionLambdaArgumentDescriptor[] lambdaArgumentDescriptors(); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarOperator.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarOperator.java index 96c5958b10701..d72cc295db6c3 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarOperator.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarOperator.java @@ -27,4 +27,6 @@ public @interface ScalarOperator { OperatorType value(); + + ScalarFunctionDescriptor descriptor() default @ScalarFunctionDescriptor; } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunction.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunction.java index a0ec06b138506..1cfa9f1236fb4 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunction.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunction.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.spi.function; +import static com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor.defaultFunctionDescriptor; + public interface SqlFunction { Signature getSignature(); @@ -24,4 +26,9 @@ public interface SqlFunction boolean isCalledOnNullInput(); String getDescription(); + + default ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor() + { + return defaultFunctionDescriptor(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlInvokedFunction.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlInvokedFunction.java index 17e720a580277..184ee64c5c8e3 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlInvokedFunction.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlInvokedFunction.java @@ -27,6 +27,7 @@ import java.util.Objects; import java.util.Optional; +import static com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor.defaultFunctionDescriptor; import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE; import static com.facebook.presto.spi.function.FunctionKind.SCALAR; import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; @@ -52,6 +53,7 @@ public class SqlInvokedFunction private final SqlFunctionId functionId; private final FunctionVersion functionVersion; private final Optional functionHandle; + private final ComplexTypeFunctionDescriptor descriptor; /** * Metadata required for Aggregation Functions @@ -77,6 +79,7 @@ public SqlInvokedFunction( this.functionVersion = notVersioned(); this.functionHandle = Optional.empty(); this.aggregationMetadata = Optional.empty(); + this.descriptor = defaultFunctionDescriptor(); } // This constructor creates a SCALAR SqlInvokedFunction @@ -89,7 +92,7 @@ public SqlInvokedFunction( String body, FunctionVersion version) { - this(functionName, parameters, emptyList(), returnType, description, routineCharacteristics, body, version, SCALAR, Optional.empty()); + this(functionName, parameters, emptyList(), returnType, description, routineCharacteristics, body, version, SCALAR, Optional.empty(), defaultFunctionDescriptor()); } public SqlInvokedFunction( @@ -103,8 +106,9 @@ public SqlInvokedFunction( FunctionKind kind, Optional aggregationMetadata) { - this(functionName, parameters, emptyList(), returnType, description, routineCharacteristics, body, version, kind, aggregationMetadata); + this(functionName, parameters, emptyList(), returnType, description, routineCharacteristics, body, version, kind, aggregationMetadata, defaultFunctionDescriptor()); } + public SqlInvokedFunction( QualifiedObjectName functionName, List parameters, @@ -115,7 +119,8 @@ public SqlInvokedFunction( String body, FunctionVersion version, FunctionKind kind, - Optional aggregationMetadata) + Optional aggregationMetadata, + ComplexTypeFunctionDescriptor descriptor) { this.parameters = requireNonNull(parameters, "parameters is null"); this.description = requireNonNull(description, "description is null"); @@ -130,6 +135,7 @@ public SqlInvokedFunction( this.functionId = new SqlFunctionId(functionName, argumentTypes); this.functionVersion = requireNonNull(version, "version is null"); this.functionHandle = version.hasVersion() ? Optional.of(new SqlFunctionHandle(this.functionId, version.toString())) : Optional.empty(); + this.descriptor = requireNonNull(descriptor, "descriptor is null"); this.aggregationMetadata = requireNonNull(aggregationMetadata, "aggregationMetadata is null"); if ((kind == AGGREGATE && !aggregationMetadata.isPresent()) || (kind != AGGREGATE && aggregationMetadata.isPresent())) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/StaticMethodPointer.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/StaticMethodPointer.java new file mode 100644 index 0000000000000..537de62267d33 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/StaticMethodPointer.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.TYPE; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@Target({METHOD, TYPE}) +public @interface StaticMethodPointer +{ + Class clazz(); + String method(); +}