Skip to content

Commit

Permalink
Change FunctionMetadata signature to be the actual caller signature
Browse files Browse the repository at this point in the history
Remove actual name field since signature carries actual name
Add canonical name field so push down can use canonical name
  • Loading branch information
dain committed Oct 9, 2021
1 parent f382a17 commit 01584b5
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public class FunctionMetadata
{
private final FunctionId functionId;
private final Signature signature;
private final String actualName;
private final String canonicalName;
private final boolean nullable;
private final List<FunctionArgumentDefinition> argumentDefinitions;
private final boolean hidden;
Expand Down Expand Up @@ -56,7 +56,7 @@ public FunctionMetadata(

public FunctionMetadata(
Signature signature,
String actualName,
String canonicalName,
boolean nullable,
List<FunctionArgumentDefinition> argumentDefinitions,
boolean hidden,
Expand All @@ -66,16 +66,9 @@ public FunctionMetadata(
boolean deprecated)
{
this(
FunctionId.toFunctionId(
new Signature(
actualName,
signature.getTypeVariableConstraints(),
signature.getLongVariableConstraints(),
signature.getReturnType(),
signature.getArgumentTypes(),
signature.isVariableArity())),
FunctionId.toFunctionId(signature),
signature,
actualName,
canonicalName,
nullable,
argumentDefinitions,
hidden,
Expand All @@ -88,7 +81,7 @@ public FunctionMetadata(
public FunctionMetadata(
FunctionId functionId,
Signature signature,
String actualName,
String canonicalName,
boolean nullable,
List<FunctionArgumentDefinition> argumentDefinitions,
boolean hidden,
Expand All @@ -99,7 +92,7 @@ public FunctionMetadata(
{
this.functionId = requireNonNull(functionId, "functionId is null");
this.signature = requireNonNull(signature, "signature is null");
this.actualName = requireNonNull(actualName, "actualName is null");
this.canonicalName = requireNonNull(canonicalName, "canonicalName is null");
this.nullable = nullable;
this.argumentDefinitions = ImmutableList.copyOf(requireNonNull(argumentDefinitions, "argumentDefinitions is null"));
this.hidden = hidden;
Expand All @@ -110,29 +103,29 @@ public FunctionMetadata(
}

/**
* Returns {@link FunctionId} under which function is to be registered. It is based on the {@link #getActualName()},
* which is either the canonical function name or an alias.
* Unique id of this function.
* For aliased functions, each alias must have a different alias.
*/
public FunctionId getFunctionId()
{
return functionId;
}

/**
* Returns function {@link Signature} with canonical name of the function.
* Signature of a matching call site.
* For aliased functions, the signature must use the alias name.
*/
public Signature getSignature()
{
return signature;
}

/**
* Returns the name under which function is registered. Typically same as {@code getSignature().getName()}
* unless this is an alias.
* For aliased functions, the canonical name of the function.
*/
public String getActualName()
public String getCanonicalName()
{
return actualName;
return canonicalName;
}

public boolean isNullable()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,7 @@ public FunctionMap(FunctionMap map, Collection<? extends SqlFunction> functions)
.putAll(map.functionsByName);
functions.stream()
.map(SqlFunction::getFunctionMetadata)
.forEach(functionMetadata -> functionsByName.put(QualifiedName.of(functionMetadata.getActualName()), functionMetadata));
.forEach(functionMetadata -> functionsByName.put(QualifiedName.of(functionMetadata.getSignature().getName()), functionMetadata));
this.functionsByName = functionsByName.build();

// Make sure all functions with the same name are aggregations or none of them are
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2384,7 +2384,7 @@ public FunctionMetadata getFunctionMetadata(ResolvedFunction resolvedFunction)
return new FunctionMetadata(
functionMetadata.getFunctionId(),
resolvedFunction.getSignature().toSignature(),
functionMetadata.getActualName(),
functionMetadata.getCanonicalName(),
functionMetadata.isNullable(),
argumentDefinitions,
functionMetadata.isHidden(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import io.trino.metadata.FunctionBinding;
import io.trino.metadata.FunctionDependencies;
import io.trino.metadata.Signature;
import io.trino.operator.annotations.ImplementationDependency;

import java.lang.invoke.MethodHandle;
Expand All @@ -25,11 +26,25 @@ public final class ParametricFunctionHelpers
{
private ParametricFunctionHelpers() {}

public static MethodHandle bindDependencies(MethodHandle handle, List<ImplementationDependency> dependencies, FunctionBinding functionBinding, FunctionDependencies functionDependencies)
public static MethodHandle bindDependencies(MethodHandle handle,
List<ImplementationDependency> dependencies,
FunctionBinding functionBinding,
FunctionDependencies functionDependencies)
{
for (ImplementationDependency dependency : dependencies) {
handle = MethodHandles.insertArguments(handle, 0, dependency.resolve(functionBinding, functionDependencies));
}
return handle;
}

public static Signature signatureWithName(String name, Signature signature)
{
return new Signature(
name,
signature.getTypeVariableConstraints(),
signature.getLongVariableConstraints(),
signature.getReturnType(),
signature.getArgumentTypes(),
signature.isVariableArity());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.operator.ParametricFunctionHelpers.signatureWithName;
import static io.trino.operator.aggregation.AggregationImplementation.Parser.parseImplementation;
import static io.trino.operator.annotations.FunctionsParserHelper.parseDescription;
import static java.lang.String.format;
Expand Down Expand Up @@ -83,7 +84,7 @@ public static List<ParametricAggregation> parseFunctionDefinitions(Class<?> aggr
for (AggregationHeader header : parseHeaders(aggregationDefinition, outputFunction)) {
AggregationImplementation onlyImplementation = parseImplementation(aggregationDefinition, header, stateClass, inputFunction, removeInputFunction, outputFunction, combineFunction);
ParametricImplementationsGroup<AggregationImplementation> implementations = ParametricImplementationsGroup.of(onlyImplementation);
builder.add(new ParametricAggregation(implementations.getSignature(), header, implementations, deprecated));
builder.add(new ParametricAggregation(signatureWithName(header.getName(), implementations.getSignature()), header, implementations, deprecated));
}
}
}
Expand All @@ -109,7 +110,7 @@ public static ParametricAggregation parseFunctionDefinition(Class<?> aggregation
}

ParametricImplementationsGroup<AggregationImplementation> implementations = implementationsBuilder.build();
return new ParametricAggregation(implementations.getSignature(), header, implementations, deprecated);
return new ParametricAggregation(signatureWithName(header.getName(), implementations.getSignature()), header, implementations, deprecated);
}

private static AggregationHeader parseHeader(AnnotatedElement aggregationDefinition)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import java.util.Optional;

import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;

public class AggregationHeader
Expand Down Expand Up @@ -65,4 +66,16 @@ public boolean isHidden()
{
return hidden;
}

@Override
public String toString()
{
return toStringHelper(this)
.add("name", name)
.add("description", description)
.add("decomposable", decomposable)
.add("orderSensitive", orderSensitive)
.add("hidden", hidden)
.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ private Parser(
private AggregationImplementation get()
{
Signature signature = new Signature(
header.getCanonicalName(),
header.getName(),
typeVariableConstraints,
longVariableConstraints,
returnType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public ParametricAggregation(
super(
new FunctionMetadata(
signature,
details.getName(),
details.getCanonicalName(),
true,
implementations.getArgumentDefinitions(),
details.isHidden(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ public static Optional<PlanNode> pushAggregationIntoTableScan(

List<AggregateFunction> aggregateFunctions = aggregationsList.stream()
.map(Entry::getValue)
.map(aggregation -> toAggregateFunction(context, aggregation))
.map(aggregation -> toAggregateFunction(metadata, context, aggregation))
.collect(toImmutableList());

List<Symbol> aggregationOutputSymbols = aggregationsList.stream()
Expand Down Expand Up @@ -214,8 +214,9 @@ public static Optional<PlanNode> pushAggregationIntoTableScan(
assignmentBuilder.build()));
}

private static AggregateFunction toAggregateFunction(Context context, AggregationNode.Aggregation aggregation)
private static AggregateFunction toAggregateFunction(Metadata metadata, Context context, AggregationNode.Aggregation aggregation)
{
String canonicalName = metadata.getFunctionMetadata(aggregation.getResolvedFunction()).getCanonicalName();
BoundSignature signature = aggregation.getResolvedFunction().getSignature();

ImmutableList.Builder<ConnectorExpression> arguments = new ImmutableList.Builder<>();
Expand All @@ -231,7 +232,7 @@ private static AggregateFunction toAggregateFunction(Context context, Aggregatio
.map(symbol -> new Variable(symbol.getName(), context.getSymbolAllocator().getTypes().get(symbol)));

return new AggregateFunction(
signature.getName(),
canonicalName,
signature.getReturnType(),
arguments.build(),
sortBy.orElse(ImmutableList.of()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ protected Node visitShowFunctions(ShowFunctions node, Void context)
List<Expression> rows = metadata.listFunctions().stream()
.filter(function -> !function.isHidden())
.map(function -> row(
new StringLiteral(function.getActualName()),
new StringLiteral(function.getSignature().getName()),
new StringLiteral(function.getSignature().getReturnType().toString()),
new StringLiteral(Joiner.on(", ").join(function.getSignature().getArgumentTypes())),
new StringLiteral(getFunctionType(function)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
import java.util.List;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.metadata.MetadataManager.createTestMetadataManager;
import static io.trino.metadata.Signature.typeVariable;
Expand Down Expand Up @@ -1215,16 +1216,26 @@ public void testAggregateFunctionGetCanonicalName()
assertEquals(
aggregationOutputFunctions.stream()
.map(aggregateFunction -> aggregateFunction.getFunctionMetadata().getSignature().getName())
.collect(toImmutableList()),
ImmutableList.of("aggregation_output", "aggregation_output", "aggregation_output"));
.collect(toImmutableSet()),
ImmutableSet.of("aggregation_output", "aggregation_output_alias_1", "aggregation_output_alias_2"));
assertEquals(
aggregationOutputFunctions.stream()
.map(aggregateFunction -> aggregateFunction.getFunctionMetadata().getCanonicalName())
.collect(toImmutableSet()),
ImmutableSet.of("aggregation_output"));

List<ParametricAggregation> aggregationFunctions = parseFunctionDefinitions(AggregationFunctionWithAlias.class);
assertEquals(aggregationFunctions.size(), 3);
assertEquals(
aggregationFunctions.stream()
.map(aggregateFunction -> aggregateFunction.getFunctionMetadata().getSignature().getName())
.collect(toImmutableList()),
ImmutableList.of("aggregation", "aggregation", "aggregation"));
.collect(toImmutableSet()),
ImmutableSet.of("aggregation", "aggregation_alias_1", "aggregation_alias_2"));
assertEquals(
aggregationFunctions.stream()
.map(aggregateFunction -> aggregateFunction.getFunctionMetadata().getCanonicalName())
.collect(toImmutableSet()),
ImmutableSet.of("aggregation"));
}

private static InternalAggregationFunction specializeAggregationFunction(BoundSignature boundSignature, SqlAggregationFunction aggregation)
Expand Down

0 comments on commit 01584b5

Please sign in to comment.