diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionMetadata.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionMetadata.java index 51b502d3f0b6..d09197f8473d 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionMetadata.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionMetadata.java @@ -23,7 +23,7 @@ public class FunctionMetadata { private final FunctionId functionId; private final Signature signature; - private final String actualName; + private final String canonicalName; private final boolean nullable; private final List argumentDefinitions; private final boolean hidden; @@ -56,7 +56,7 @@ public FunctionMetadata( public FunctionMetadata( Signature signature, - String actualName, + String canonicalName, boolean nullable, List argumentDefinitions, boolean hidden, @@ -66,16 +66,9 @@ public FunctionMetadata( boolean deprecated) { this( - FunctionId.toFunctionId( - new Signature( - actualName, - signature.getTypeVariableConstraints(), - signature.getLongVariableConstraints(), - signature.getReturnType(), - signature.getArgumentTypes(), - signature.isVariableArity())), + FunctionId.toFunctionId(signature), signature, - actualName, + canonicalName, nullable, argumentDefinitions, hidden, @@ -88,7 +81,7 @@ public FunctionMetadata( public FunctionMetadata( FunctionId functionId, Signature signature, - String actualName, + String canonicalName, boolean nullable, List argumentDefinitions, boolean hidden, @@ -99,7 +92,7 @@ public FunctionMetadata( { this.functionId = requireNonNull(functionId, "functionId is null"); this.signature = requireNonNull(signature, "signature is null"); - this.actualName = requireNonNull(actualName, "actualName is null"); + this.canonicalName = requireNonNull(canonicalName, "canonicalName is null"); this.nullable = nullable; this.argumentDefinitions = ImmutableList.copyOf(requireNonNull(argumentDefinitions, "argumentDefinitions is null")); this.hidden = hidden; @@ -110,8 +103,8 @@ public FunctionMetadata( } /** - * Returns {@link FunctionId} under which function is to be registered. It is based on the {@link #getActualName()}, - * which is either the canonical function name or an alias. + * Unique id of this function. + * For aliased functions, each alias must have a different alias. */ public FunctionId getFunctionId() { @@ -119,7 +112,8 @@ public FunctionId getFunctionId() } /** - * Returns function {@link Signature} with canonical name of the function. + * Signature of a matching call site. + * For aliased functions, the signature must use the alias name. */ public Signature getSignature() { @@ -127,12 +121,11 @@ public Signature getSignature() } /** - * Returns the name under which function is registered. Typically same as {@code getSignature().getName()} - * unless this is an alias. + * For aliased functions, the canonical name of the function. */ - public String getActualName() + public String getCanonicalName() { - return actualName; + return canonicalName; } public boolean isNullable() diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java index 2db0552a1825..5fa97a4f6db1 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java @@ -908,7 +908,7 @@ public FunctionMap(FunctionMap map, Collection functions) .putAll(map.functionsByName); functions.stream() .map(SqlFunction::getFunctionMetadata) - .forEach(functionMetadata -> functionsByName.put(QualifiedName.of(functionMetadata.getActualName()), functionMetadata)); + .forEach(functionMetadata -> functionsByName.put(QualifiedName.of(functionMetadata.getSignature().getName()), functionMetadata)); this.functionsByName = functionsByName.build(); // Make sure all functions with the same name are aggregations or none of them are diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index 1aeafab8a616..3fd136c24e51 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -2384,7 +2384,7 @@ public FunctionMetadata getFunctionMetadata(ResolvedFunction resolvedFunction) return new FunctionMetadata( functionMetadata.getFunctionId(), resolvedFunction.getSignature().toSignature(), - functionMetadata.getActualName(), + functionMetadata.getCanonicalName(), functionMetadata.isNullable(), argumentDefinitions, functionMetadata.isHidden(), diff --git a/core/trino-main/src/main/java/io/trino/operator/ParametricFunctionHelpers.java b/core/trino-main/src/main/java/io/trino/operator/ParametricFunctionHelpers.java index 07acba816e78..5aa38b69fa54 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ParametricFunctionHelpers.java +++ b/core/trino-main/src/main/java/io/trino/operator/ParametricFunctionHelpers.java @@ -15,6 +15,7 @@ import io.trino.metadata.FunctionBinding; import io.trino.metadata.FunctionDependencies; +import io.trino.metadata.Signature; import io.trino.operator.annotations.ImplementationDependency; import java.lang.invoke.MethodHandle; @@ -25,11 +26,25 @@ public final class ParametricFunctionHelpers { private ParametricFunctionHelpers() {} - public static MethodHandle bindDependencies(MethodHandle handle, List dependencies, FunctionBinding functionBinding, FunctionDependencies functionDependencies) + public static MethodHandle bindDependencies(MethodHandle handle, + List dependencies, + FunctionBinding functionBinding, + FunctionDependencies functionDependencies) { for (ImplementationDependency dependency : dependencies) { handle = MethodHandles.insertArguments(handle, 0, dependency.resolve(functionBinding, functionDependencies)); } return handle; } + + public static Signature signatureWithName(String name, Signature signature) + { + return new Signature( + name, + signature.getTypeVariableConstraints(), + signature.getLongVariableConstraints(), + signature.getReturnType(), + signature.getArgumentTypes(), + signature.isVariableArity()); + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java index 295688880edf..6304a9d5efb0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java @@ -41,6 +41,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; +import static io.trino.operator.ParametricFunctionHelpers.signatureWithName; import static io.trino.operator.aggregation.AggregationImplementation.Parser.parseImplementation; import static io.trino.operator.annotations.FunctionsParserHelper.parseDescription; import static java.lang.String.format; @@ -83,7 +84,7 @@ public static List parseFunctionDefinitions(Class aggr for (AggregationHeader header : parseHeaders(aggregationDefinition, outputFunction)) { AggregationImplementation onlyImplementation = parseImplementation(aggregationDefinition, header, stateClass, inputFunction, removeInputFunction, outputFunction, combineFunction); ParametricImplementationsGroup implementations = ParametricImplementationsGroup.of(onlyImplementation); - builder.add(new ParametricAggregation(implementations.getSignature(), header, implementations, deprecated)); + builder.add(new ParametricAggregation(signatureWithName(header.getName(), implementations.getSignature()), header, implementations, deprecated)); } } } @@ -109,7 +110,7 @@ public static ParametricAggregation parseFunctionDefinition(Class aggregation } ParametricImplementationsGroup implementations = implementationsBuilder.build(); - return new ParametricAggregation(implementations.getSignature(), header, implementations, deprecated); + return new ParametricAggregation(signatureWithName(header.getName(), implementations.getSignature()), header, implementations, deprecated); } private static AggregationHeader parseHeader(AnnotatedElement aggregationDefinition) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationHeader.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationHeader.java index c3a612b0d7e6..9194c065ba94 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationHeader.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationHeader.java @@ -15,6 +15,7 @@ import java.util.Optional; +import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; public class AggregationHeader @@ -65,4 +66,16 @@ public boolean isHidden() { return hidden; } + + @Override + public String toString() + { + return toStringHelper(this) + .add("name", name) + .add("description", description) + .add("decomposable", decomposable) + .add("orderSensitive", orderSensitive) + .add("hidden", hidden) + .toString(); + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationImplementation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationImplementation.java index 6d333c9e803c..5f4dc5dc6943 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationImplementation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationImplementation.java @@ -323,7 +323,7 @@ private Parser( private AggregationImplementation get() { Signature signature = new Signature( - header.getCanonicalName(), + header.getName(), typeVariableConstraints, longVariableConstraints, returnType, diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java index 8983dc98a69a..53211f95abd4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java @@ -67,7 +67,7 @@ public ParametricAggregation( super( new FunctionMetadata( signature, - details.getName(), + details.getCanonicalName(), true, implementations.getArgumentDefinitions(), details.isHidden(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java index 7c4926647859..918172ac3d6a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java @@ -135,7 +135,7 @@ public static Optional pushAggregationIntoTableScan( List aggregateFunctions = aggregationsList.stream() .map(Entry::getValue) - .map(aggregation -> toAggregateFunction(context, aggregation)) + .map(aggregation -> toAggregateFunction(metadata, context, aggregation)) .collect(toImmutableList()); List aggregationOutputSymbols = aggregationsList.stream() @@ -214,8 +214,9 @@ public static Optional pushAggregationIntoTableScan( assignmentBuilder.build())); } - private static AggregateFunction toAggregateFunction(Context context, AggregationNode.Aggregation aggregation) + private static AggregateFunction toAggregateFunction(Metadata metadata, Context context, AggregationNode.Aggregation aggregation) { + String canonicalName = metadata.getFunctionMetadata(aggregation.getResolvedFunction()).getCanonicalName(); BoundSignature signature = aggregation.getResolvedFunction().getSignature(); ImmutableList.Builder arguments = new ImmutableList.Builder<>(); @@ -231,7 +232,7 @@ private static AggregateFunction toAggregateFunction(Context context, Aggregatio .map(symbol -> new Variable(symbol.getName(), context.getSymbolAllocator().getTypes().get(symbol))); return new AggregateFunction( - signature.getName(), + canonicalName, signature.getReturnType(), arguments.build(), sortBy.orElse(ImmutableList.of()), diff --git a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java index 6b6fdb2ef4cc..b326b7c1c1af 100644 --- a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java @@ -707,7 +707,7 @@ protected Node visitShowFunctions(ShowFunctions node, Void context) List rows = metadata.listFunctions().stream() .filter(function -> !function.isHidden()) .map(function -> row( - new StringLiteral(function.getActualName()), + new StringLiteral(function.getSignature().getName()), new StringLiteral(function.getSignature().getReturnType().toString()), new StringLiteral(Joiner.on(", ").join(function.getSignature().getArgumentTypes())), new StringLiteral(getFunctionType(function)), diff --git a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java index 863099f6ae92..f5597d4cb054 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java @@ -69,6 +69,7 @@ import java.util.List; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.metadata.Signature.typeVariable; @@ -1215,16 +1216,26 @@ public void testAggregateFunctionGetCanonicalName() assertEquals( aggregationOutputFunctions.stream() .map(aggregateFunction -> aggregateFunction.getFunctionMetadata().getSignature().getName()) - .collect(toImmutableList()), - ImmutableList.of("aggregation_output", "aggregation_output", "aggregation_output")); + .collect(toImmutableSet()), + ImmutableSet.of("aggregation_output", "aggregation_output_alias_1", "aggregation_output_alias_2")); + assertEquals( + aggregationOutputFunctions.stream() + .map(aggregateFunction -> aggregateFunction.getFunctionMetadata().getCanonicalName()) + .collect(toImmutableSet()), + ImmutableSet.of("aggregation_output")); List aggregationFunctions = parseFunctionDefinitions(AggregationFunctionWithAlias.class); assertEquals(aggregationFunctions.size(), 3); assertEquals( aggregationFunctions.stream() .map(aggregateFunction -> aggregateFunction.getFunctionMetadata().getSignature().getName()) - .collect(toImmutableList()), - ImmutableList.of("aggregation", "aggregation", "aggregation")); + .collect(toImmutableSet()), + ImmutableSet.of("aggregation", "aggregation_alias_1", "aggregation_alias_2")); + assertEquals( + aggregationFunctions.stream() + .map(aggregateFunction -> aggregateFunction.getFunctionMetadata().getCanonicalName()) + .collect(toImmutableSet()), + ImmutableSet.of("aggregation")); } private static InternalAggregationFunction specializeAggregationFunction(BoundSignature boundSignature, SqlAggregationFunction aggregation)