Skip to content

Commit

Permalink
Require session for operator resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Oct 15, 2021
1 parent efcddf4 commit 86fd8f6
Show file tree
Hide file tree
Showing 22 changed files with 93 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ default Type getParameterizedType(String baseTypeName, List<TypeSignatureParamet

ResolvedFunction resolveFunction(Session session, QualifiedName name, List<TypeSignatureProvider> parameterTypes);

ResolvedFunction resolveOperator(OperatorType operatorType, List<? extends Type> argumentTypes)
ResolvedFunction resolveOperator(Session session, OperatorType operatorType, List<? extends Type> argumentTypes)
throws OperatorNotFoundException;

default ResolvedFunction getCoercion(Type fromType, Type toType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2258,7 +2258,7 @@ public ResolvedFunction resolveFunction(Session session, QualifiedName name, Lis
}

@Override
public ResolvedFunction resolveOperator(OperatorType operatorType, List<? extends Type> argumentTypes)
public ResolvedFunction resolveOperator(Session session, OperatorType operatorType, List<? extends Type> argumentTypes)
throws OperatorNotFoundException
{
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1422,7 +1422,7 @@ private void analyzeFrameRangeOffset(Expression offsetValue, FrameBound.Type bou
operatorType = ADD;
}
try {
function = metadata.resolveOperator(operatorType, ImmutableList.of(sortKeyType, offsetValueType));
function = metadata.resolveOperator(session, operatorType, ImmutableList.of(sortKeyType, offsetValueType));
}
catch (TrinoException e) {
ErrorCode errorCode = e.getErrorCode();
Expand Down Expand Up @@ -2282,7 +2282,7 @@ private Type getOperator(StackableAstVisitorContext<Context> context, Expression

BoundSignature operatorSignature;
try {
operatorSignature = metadata.resolveOperator(operatorType, argumentTypes.build()).getSignature();
operatorSignature = metadata.resolveOperator(session, operatorType, argumentTypes.build()).getSignature();
}
catch (OperatorNotFoundException e) {
throw semanticException(TYPE_MISMATCH, node, e, "%s", e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2366,7 +2366,7 @@ private Scope analyzeJoinUsing(Join node, List<Identifier> columns, Optional<Sco

// ensure a comparison operator exists for the given types (applying coercions if necessary)
try {
metadata.resolveOperator(OperatorType.EQUAL, ImmutableList.of(
metadata.resolveOperator(session, OperatorType.EQUAL, ImmutableList.of(
leftField.get().getType(), rightField.get().getType()));
}
catch (OperatorNotFoundException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -604,8 +604,8 @@ protected Object visitInPredicate(InPredicate node, Object context)
set = FastutilSetHelper.toFastutilHashSet(
objectSet,
type,
metadata.getScalarFunctionInvoker(metadata.resolveOperator(HASH_CODE, ImmutableList.of(type)), simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(),
metadata.getScalarFunctionInvoker(metadata.resolveOperator(EQUAL, ImmutableList.of(type, type)), simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)).getMethodHandle());
metadata.getScalarFunctionInvoker(metadata.resolveOperator(session, HASH_CODE, ImmutableList.of(type)), simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(),
metadata.getScalarFunctionInvoker(metadata.resolveOperator(session, EQUAL, ImmutableList.of(type, type)), simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)).getMethodHandle());
}
inListCache.put(valueList, set);
}
Expand All @@ -620,7 +620,7 @@ protected Object visitInPredicate(InPredicate node, Object context)
List<Object> values = new ArrayList<>(valueList.getValues().size());
List<Type> types = new ArrayList<>(valueList.getValues().size());

ResolvedFunction equalsOperator = metadata.resolveOperator(OperatorType.EQUAL, types(node.getValue(), valueList));
ResolvedFunction equalsOperator = metadata.resolveOperator(session, OperatorType.EQUAL, types(node.getValue(), valueList));
for (Expression expression : valueList.getValues()) {
// Use process() instead of processWithExceptionHandling() for processing in-list items.
// Do not handle exceptions thrown while processing a single in-list expression,
Expand Down Expand Up @@ -720,7 +720,7 @@ protected Object visitArithmeticUnary(ArithmeticUnaryExpression node, Object con
case PLUS:
return value;
case MINUS:
ResolvedFunction resolvedOperator = metadata.resolveOperator(OperatorType.NEGATION, types(node.getValue()));
ResolvedFunction resolvedOperator = metadata.resolveOperator(session, OperatorType.NEGATION, types(node.getValue()));
InvocationConvention invocationConvention = new InvocationConvention(ImmutableList.of(NEVER_NULL), FAIL_ON_NULL, true, false);
MethodHandle handle = metadata.getScalarFunctionInvoker(resolvedOperator, invocationConvention).getMethodHandle();

Expand Down Expand Up @@ -1402,7 +1402,7 @@ private boolean hasUnresolvedValue(List<Object> values)

private Object invokeOperator(OperatorType operatorType, List<? extends Type> argumentTypes, List<Object> argumentValues)
{
ResolvedFunction operator = metadata.resolveOperator(operatorType, argumentTypes);
ResolvedFunction operator = metadata.resolveOperator(session, operatorType, argumentTypes);
return functionInvoker.invoke(operator, connectorSession, argumentValues);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,6 @@ private RowExpression toRowExpression(Session session, Expression expression, Ma
private static class CanonicalizationVisitor
implements RowExpressionVisitor<RowExpression, Void>
{
public CanonicalizationVisitor()
{
}

@Override
public RowExpression visitCall(CallExpression call, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ private Visitor(
this.types = ImmutableMap.copyOf(requireNonNull(types, "types is null"));
this.layout = layout;
this.session = session;
standardFunctionResolution = new StandardFunctionResolution(metadata);
standardFunctionResolution = new StandardFunctionResolution(session, metadata);
}

private Type getType(Expression node)
Expand Down Expand Up @@ -409,7 +409,7 @@ protected RowExpression visitArithmeticUnary(ArithmeticUnaryExpression node, Voi
return expression;
case MINUS:
return call(
metadata.resolveOperator(NEGATION, ImmutableList.of(expression.getType())),
metadata.resolveOperator(session, NEGATION, ImmutableList.of(expression.getType())),
expression);
}

Expand Down Expand Up @@ -535,7 +535,7 @@ protected RowExpression visitSimpleCaseExpression(SimpleCaseExpression node, Voi
RowExpression operand = process(clause.getOperand(), context);
RowExpression result = process(clause.getResult(), context);

functionDependencies.add(metadata.resolveOperator(EQUAL, ImmutableList.of(value.getType(), operand.getType())));
functionDependencies.add(metadata.resolveOperator(session, EQUAL, ImmutableList.of(value.getType(), operand.getType())));

arguments.add(new SpecialForm(
WHEN,
Expand Down Expand Up @@ -622,9 +622,9 @@ protected RowExpression visitInPredicate(InPredicate node, Void context)
}

List<ResolvedFunction> functionDependencies = ImmutableList.<ResolvedFunction>builder()
.add(metadata.resolveOperator(EQUAL, ImmutableList.of(value.getType(), value.getType())))
.add(metadata.resolveOperator(HASH_CODE, ImmutableList.of(value.getType())))
.add(metadata.resolveOperator(INDETERMINATE, ImmutableList.of(value.getType())))
.add(metadata.resolveOperator(session, EQUAL, ImmutableList.of(value.getType(), value.getType())))
.add(metadata.resolveOperator(session, HASH_CODE, ImmutableList.of(value.getType())))
.add(metadata.resolveOperator(session, INDETERMINATE, ImmutableList.of(value.getType())))
.build();

return new SpecialForm(IN, BOOLEAN, arguments.build(), functionDependencies);
Expand Down Expand Up @@ -665,7 +665,7 @@ protected RowExpression visitNullIfExpression(NullIfExpression node, Void contex
RowExpression first = process(node.getFirst(), context);
RowExpression second = process(node.getSecond(), context);

ResolvedFunction resolvedFunction = metadata.resolveOperator(EQUAL, ImmutableList.of(first.getType(), second.getType()));
ResolvedFunction resolvedFunction = metadata.resolveOperator(session, EQUAL, ImmutableList.of(first.getType(), second.getType()));
List<ResolvedFunction> functionDependencies = ImmutableList.<ResolvedFunction>builder()
.add(resolvedFunction)
.add(metadata.getCoercion(first.getType(), resolvedFunction.getSignature().getArgumentTypes().get(0)))
Expand All @@ -687,7 +687,7 @@ protected RowExpression visitBetweenPredicate(BetweenPredicate node, Void contex
RowExpression max = process(node.getMax(), context);

List<ResolvedFunction> functionDependencies = ImmutableList.<ResolvedFunction>builder()
.add(metadata.resolveOperator(LESS_THAN_OR_EQUAL, ImmutableList.of(value.getType(), max.getType())))
.add(metadata.resolveOperator(session, LESS_THAN_OR_EQUAL, ImmutableList.of(value.getType(), max.getType())))
.build();

return new SpecialForm(
Expand All @@ -709,7 +709,7 @@ protected RowExpression visitSubscriptExpression(SubscriptExpression node, Void
}

return call(
metadata.resolveOperator(SUBSCRIPT, ImmutableList.of(base.getType(), index.getType())),
metadata.resolveOperator(session, SUBSCRIPT, ImmutableList.of(base.getType(), index.getType())),
base,
index);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
package io.trino.sql.relational;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.Type;
import io.trino.sql.tree.ArithmeticBinaryExpression;
import io.trino.sql.tree.ArithmeticBinaryExpression.Operator;
import io.trino.sql.tree.ComparisonExpression;

import static io.trino.spi.function.OperatorType.ADD;
Expand All @@ -34,14 +35,16 @@

public final class StandardFunctionResolution
{
private final Session session;
private final Metadata metadata;

public StandardFunctionResolution(Metadata metadata)
public StandardFunctionResolution(Session session, Metadata metadata)
{
this.session = requireNonNull(session, "session is null");
this.metadata = requireNonNull(metadata, "metadata is null");
}

public ResolvedFunction arithmeticFunction(ArithmeticBinaryExpression.Operator operator, Type leftType, Type rightType)
public ResolvedFunction arithmeticFunction(Operator operator, Type leftType, Type rightType)
{
OperatorType operatorType;
switch (operator) {
Expand All @@ -63,7 +66,7 @@ public ResolvedFunction arithmeticFunction(ArithmeticBinaryExpression.Operator o
default:
throw new IllegalStateException("Unknown arithmetic operator: " + operator);
}
return metadata.resolveOperator(operatorType, ImmutableList.of(leftType, rightType));
return metadata.resolveOperator(session, operatorType, ImmutableList.of(leftType, rightType));
}

public ResolvedFunction comparisonFunction(ComparisonExpression.Operator operator, Type leftType, Type rightType)
Expand All @@ -86,6 +89,6 @@ public ResolvedFunction comparisonFunction(ComparisonExpression.Operator operato
throw new IllegalStateException("Unsupported comparison operator type: " + operator);
}

return metadata.resolveOperator(operatorType, ImmutableList.of(leftType, rightType));
return metadata.resolveOperator(session, operatorType, ImmutableList.of(leftType, rightType));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ public ResolvedFunction resolveFunction(Session session, QualifiedName name, Lis
}

@Override
public ResolvedFunction resolveOperator(OperatorType operatorType, List<? extends Type> argumentTypes)
public ResolvedFunction resolveOperator(Session session, OperatorType operatorType, List<? extends Type> argumentTypes)
throws OperatorNotFoundException
{
throw new UnsupportedOperationException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ public void testIdentityCast()
@Test
public void testExactMatchBeforeCoercion()
{
Metadata metadata = createTestMetadataManager();
TestingFunctionResolution functionResolution = new TestingFunctionResolution();
Metadata metadata = functionResolution.getMetadata();
boolean foundOperator = false;
for (FunctionMetadata function : listOperators(metadata)) {
OperatorType operatorType = unmangleOperator(function.getSignature().getName());
Expand All @@ -83,7 +84,7 @@ public void testExactMatchBeforeCoercion()
List<Type> argumentTypes = function.getSignature().getArgumentTypes().stream()
.map(metadata::getType)
.collect(toImmutableList());
BoundSignature exactOperator = metadata.resolveOperator(operatorType, argumentTypes).getSignature();
BoundSignature exactOperator = functionResolution.resolveOperator(operatorType, argumentTypes).getSignature();
assertEquals(exactOperator.toSignature(), function.getSignature());
foundOperator = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public PageFunctionCompiler getPageFunctionCompiler(int expressionCacheSize)
public ResolvedFunction resolveOperator(OperatorType operatorType, List<? extends Type> argumentTypes)
throws OperatorNotFoundException
{
return inTransaction(session -> metadata.resolveOperator(operatorType, argumentTypes));
return inTransaction(session -> metadata.resolveOperator(session, operatorType, argumentTypes));
}

public ResolvedFunction getCoercion(Type fromType, Type toType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@

import com.google.common.collect.ImmutableList;
import io.airlift.units.DataSize;
import io.trino.metadata.Metadata;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.operator.project.PageProcessor;
import io.trino.spi.Page;
import io.trino.sql.gen.ExpressionCompiler;
import io.trino.sql.gen.PageFunctionCompiler;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.relational.RowExpression;
import io.trino.testing.MaterializedResult;
Expand All @@ -37,7 +36,6 @@
import static io.airlift.units.DataSize.Unit.KILOBYTE;
import static io.trino.RowPagesBuilder.rowPagesBuilder;
import static io.trino.SessionTestUtils.TEST_SESSION;
import static io.trino.metadata.MetadataManager.createTestMetadataManager;
import static io.trino.operator.OperatorAssertion.assertOperatorEquals;
import static io.trino.spi.function.OperatorType.ADD;
import static io.trino.spi.function.OperatorType.EQUAL;
Expand Down Expand Up @@ -83,19 +81,19 @@ public void test()
.addSequencePage(100, 0, 0)
.build();

Metadata metadata = createTestMetadataManager();
TestingFunctionResolution functionResolution = new TestingFunctionResolution();
RowExpression filter = call(
metadata.resolveOperator(LESS_THAN_OR_EQUAL, ImmutableList.of(BIGINT, BIGINT)),
functionResolution.resolveOperator(LESS_THAN_OR_EQUAL, ImmutableList.of(BIGINT, BIGINT)),
field(1, BIGINT),
constant(9L, BIGINT));

RowExpression field0 = field(0, VARCHAR);
RowExpression add5 = call(
metadata.resolveOperator(ADD, ImmutableList.of(BIGINT, BIGINT)),
functionResolution.resolveOperator(ADD, ImmutableList.of(BIGINT, BIGINT)),
field(1, BIGINT),
constant(5L, BIGINT));

ExpressionCompiler compiler = new ExpressionCompiler(metadata, new PageFunctionCompiler(metadata, 0));
ExpressionCompiler compiler = functionResolution.getExpressionCompiler();
Supplier<PageProcessor> processor = compiler.compilePageProcessor(Optional.of(filter), ImmutableList.of(field0, add5));

OperatorFactory operatorFactory = FilterAndProjectOperator.createOperatorFactory(
Expand Down Expand Up @@ -133,13 +131,13 @@ public void testMergeOutput()
.addSequencePage(100, 0, 0)
.build();

Metadata metadata = createTestMetadataManager();
TestingFunctionResolution functionResolution = new TestingFunctionResolution();
RowExpression filter = call(
metadata.resolveOperator(EQUAL, ImmutableList.of(BIGINT, BIGINT)),
functionResolution.resolveOperator(EQUAL, ImmutableList.of(BIGINT, BIGINT)),
field(1, BIGINT),
constant(10L, BIGINT));

ExpressionCompiler compiler = new ExpressionCompiler(metadata, new PageFunctionCompiler(metadata, 0));
ExpressionCompiler compiler = functionResolution.getExpressionCompiler();
Supplier<PageProcessor> processor = compiler.compilePageProcessor(Optional.of(filter), ImmutableList.of(field(1, BIGINT)));

OperatorFactory operatorFactory = FilterAndProjectOperator.createOperatorFactory(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.trino.metadata.Metadata;
import io.trino.metadata.Split;
import io.trino.metadata.SqlScalarFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.operator.index.PageRecordSet;
import io.trino.operator.project.CursorProcessor;
import io.trino.operator.project.PageProcessor;
Expand Down Expand Up @@ -56,7 +57,6 @@
import static io.trino.RowPagesBuilder.rowPagesBuilder;
import static io.trino.SessionTestUtils.TEST_SESSION;
import static io.trino.block.BlockAssertions.toValues;
import static io.trino.metadata.MetadataManager.createTestMetadataManager;
import static io.trino.operator.OperatorAssertion.toMaterializedResult;
import static io.trino.operator.PageAssertions.assertPageEquals;
import static io.trino.operator.project.PageProcessor.MAX_BATCH_SIZE;
Expand All @@ -80,8 +80,8 @@
public class TestScanFilterAndProjectOperator
extends AbstractTestFunctions
{
private final Metadata metadata = createTestMetadataManager();
private final ExpressionCompiler expressionCompiler = new ExpressionCompiler(metadata, new PageFunctionCompiler(metadata, 0));
private final TestingFunctionResolution functionResolution = new TestingFunctionResolution();
private final ExpressionCompiler expressionCompiler = functionResolution.getExpressionCompiler();
private ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%s"));
private ScheduledExecutorService scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s"));

Expand Down Expand Up @@ -140,7 +140,7 @@ public void testPageSourceMergeOutput()
.build();

RowExpression filter = call(
metadata.resolveOperator(EQUAL, ImmutableList.of(BIGINT, BIGINT)),
functionResolution.resolveOperator(EQUAL, ImmutableList.of(BIGINT, BIGINT)),
field(0, BIGINT),
constant(10L, BIGINT));
List<RowExpression> projections = ImmutableList.of(field(0, BIGINT));
Expand Down
Loading

0 comments on commit 86fd8f6

Please sign in to comment.