From 6b5a86198cdea6598394a862033949c6fcdf3748 Mon Sep 17 00:00:00 2001 From: striderarun Date: Thu, 28 Nov 2024 18:42:27 -0800 Subject: [PATCH] Add dereference pushdown support in ElasticSearch --- plugin/trino-elasticsearch/pom.xml | 7 + .../plugin/elasticsearch/BuiltinColumns.java | 3 +- .../ElasticsearchColumnHandle.java | 15 +- .../elasticsearch/ElasticsearchMetadata.java | 140 +++++- .../ElasticsearchTableHandle.java | 24 +- .../elasticsearch/ScanQueryPageSource.java | 30 +- .../elasticsearch/decoders/ArrayDecoder.java | 20 + .../elasticsearch/decoders/BigintDecoder.java | 20 + .../decoders/BooleanDecoder.java | 20 + .../elasticsearch/decoders/DoubleDecoder.java | 20 + .../decoders/IntegerDecoder.java | 20 + .../decoders/IpAddressDecoder.java | 21 + .../decoders/RawJsonDecoder.java | 20 + .../elasticsearch/decoders/RealDecoder.java | 20 + .../elasticsearch/decoders/RowDecoder.java | 43 ++ .../decoders/TimestampDecoder.java | 20 + .../decoders/TinyintDecoder.java | 20 + .../decoders/VarbinaryDecoder.java | 20 + .../decoders/VarcharDecoder.java | 20 + .../BaseElasticsearchConnectorTest.java | 467 +++++++++++++++++- ...ticsearchComplexTypePredicatePushDown.java | 379 ++++++++++++++ ...tElasticsearchProjectionPushdownPlans.java | 305 ++++++++++++ .../TestElasticsearchQueryBuilder.java | 8 +- 23 files changed, 1645 insertions(+), 17 deletions(-) create mode 100644 plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchComplexTypePredicatePushDown.java create mode 100644 plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchProjectionPushdownPlans.java diff --git a/plugin/trino-elasticsearch/pom.xml b/plugin/trino-elasticsearch/pom.xml index edfc70f61e22..26c44f998ab5 100644 --- a/plugin/trino-elasticsearch/pom.xml +++ b/plugin/trino-elasticsearch/pom.xml @@ -303,6 +303,13 @@ + + io.trino + trino-spi + test-jar + test + + io.trino trino-testing diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/BuiltinColumns.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/BuiltinColumns.java index c46dd7646124..d197f8917a17 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/BuiltinColumns.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/BuiltinColumns.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.elasticsearch; +import com.google.common.collect.ImmutableList; import io.trino.plugin.elasticsearch.client.IndexMetadata; import io.trino.plugin.elasticsearch.decoders.IdColumnDecoder; import io.trino.plugin.elasticsearch.decoders.ScoreColumnDecoder; @@ -86,7 +87,7 @@ public ColumnMetadata getMetadata() public ColumnHandle getColumnHandle() { return new ElasticsearchColumnHandle( - name, + ImmutableList.of(name), type, elasticsearchType, decoderDescriptor, diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchColumnHandle.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchColumnHandle.java index c6b6432d6b9b..419bd1f55ba4 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchColumnHandle.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchColumnHandle.java @@ -13,14 +13,19 @@ */ package io.trino.plugin.elasticsearch; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.google.common.base.Joiner; +import 
com.google.common.collect.ImmutableList; import io.trino.plugin.elasticsearch.client.IndexMetadata; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.type.Type; +import java.util.List; + import static java.util.Objects.requireNonNull; public record ElasticsearchColumnHandle( - String name, + List path, Type type, IndexMetadata.Type elasticsearchType, DecoderDescriptor decoderDescriptor, @@ -29,12 +34,18 @@ public record ElasticsearchColumnHandle( { public ElasticsearchColumnHandle { - requireNonNull(name, "name is null"); + path = ImmutableList.copyOf(path); requireNonNull(type, "type is null"); requireNonNull(elasticsearchType, "elasticsearchType is null"); requireNonNull(decoderDescriptor, "decoderDescriptor is null"); } + @JsonIgnore + public String name() + { + return Joiner.on('.').join(path); + } + @Override public String toString() { diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java index 09fce8b9484c..3e10b551c849 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java @@ -19,6 +19,8 @@ import com.google.inject.Inject; import io.airlift.slice.Slice; import io.trino.plugin.base.expression.ConnectorExpressions; +import io.trino.plugin.base.projection.ApplyProjectionUtil; +import io.trino.plugin.base.projection.ApplyProjectionUtil.ProjectedColumnRepresentation; import io.trino.plugin.elasticsearch.client.ElasticsearchClient; import io.trino.plugin.elasticsearch.client.IndexMetadata; import io.trino.plugin.elasticsearch.client.IndexMetadata.DateTimeType; @@ -41,6 +43,7 @@ import io.trino.plugin.elasticsearch.decoders.VarcharDecoder; import io.trino.plugin.elasticsearch.ptf.RawQuery.RawQueryFunctionHandle; import io.trino.spi.TrinoException; +import io.trino.spi.connector.Assignment; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ConnectorMetadata; @@ -52,6 +55,7 @@ import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.LimitApplicationResult; +import io.trino.spi.connector.ProjectionApplicationResult; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.connector.TableColumnsMetadata; @@ -59,6 +63,7 @@ import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Constant; +import io.trino.spi.expression.FieldDereference; import io.trino.spi.expression.Variable; import io.trino.spi.function.table.ConnectorTableFunctionHandle; import io.trino.spi.predicate.Domain; @@ -93,12 +98,16 @@ import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; import static com.google.common.base.Verify.verifyNotNull; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterators.singletonIterator; import static io.airlift.slice.SliceUtf8.getCodePointAt; import static io.airlift.slice.SliceUtf8.lengthOfCodePoint; +import static 
io.trino.plugin.base.projection.ApplyProjectionUtil.extractSupportedProjectedColumns; +import static io.trino.plugin.base.projection.ApplyProjectionUtil.replaceWithNewVariables; import static io.trino.plugin.elasticsearch.ElasticsearchTableHandle.Type.QUERY; import static io.trino.plugin.elasticsearch.ElasticsearchTableHandle.Type.SCAN; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; @@ -118,6 +127,7 @@ import static java.util.Collections.emptyIterator; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; public class ElasticsearchMetadata implements ConnectorMetadata @@ -133,7 +143,7 @@ public class ElasticsearchMetadata private static final Map PASSTHROUGH_QUERY_COLUMNS = ImmutableMap.of( PASSTHROUGH_QUERY_RESULT_COLUMN_NAME, new ElasticsearchColumnHandle( - PASSTHROUGH_QUERY_RESULT_COLUMN_NAME, + ImmutableList.of(PASSTHROUGH_QUERY_RESULT_COLUMN_NAME), VARCHAR, new IndexMetadata.PrimitiveType("text"), new VarcharDecoder.Descriptor(PASSTHROUGH_QUERY_RESULT_COLUMN_NAME), @@ -258,7 +268,7 @@ private Map makeColumnHandles(List fi for (IndexMetadata.Field field : fields) { TypeAndDecoder converted = toTrino(field); result.put(field.name(), new ElasticsearchColumnHandle( - field.name(), + ImmutableList.of(field.name()), converted.type(), field.type(), converted.decoderDescriptor(), @@ -489,7 +499,8 @@ public Optional> applyLimit(Connect handle.constraint(), handle.regexes(), handle.query(), - OptionalLong.of(limit)); + OptionalLong.of(limit), + ImmutableSet.of()); return Optional.of(new LimitApplicationResult<>(handle, false, false)); } @@ -565,7 +576,8 @@ public Optional> applyFilter(C newDomain, newRegexes, handle.query(), - handle.limit()); + handle.limit(), + ImmutableSet.of()); return Optional.of(new ConstraintApplicationResult<>(handle, TupleDomain.withColumnDomains(unsupported), newExpression, false)); } @@ -655,6 +667,126 @@ private static boolean isPassthroughQuery(ElasticsearchTableHandle table) return table.type().equals(QUERY); } + @Override + public Optional> applyProjection( + ConnectorSession session, + ConnectorTableHandle handle, + List projections, + Map assignments) + { + // Create projected column representations for supported sub expressions. Simple column references and chain of + // dereferences on a variable are supported right now. 
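+ // For example, for a hypothetical mapping "address" ROW(city VARCHAR, geo ROW(lat DOUBLE, lon DOUBLE)),
+ // the projection address.geo.lat arrives here as FieldDereference(FieldDereference(Variable("address"), 1), 0)
+ // and is collapsed further below into a single ElasticsearchColumnHandle with path ["address", "geo", "lat"],
+ // so only that leaf field needs to be decoded from the document.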
+ Set projectedExpressions = projections.stream() + .flatMap(expression -> extractSupportedProjectedColumns(expression, ElasticsearchMetadata::isSupportedForPushdown).stream()) + .collect(toImmutableSet()); + + Map columnProjections = projectedExpressions.stream() + .collect(toImmutableMap(identity(), ApplyProjectionUtil::createProjectedColumnRepresentation)); + + ElasticsearchTableHandle elasticsearchTableHandle = (ElasticsearchTableHandle) handle; + + // all references are simple variables + if (columnProjections.values().stream().allMatch(ProjectedColumnRepresentation::isVariable)) { + Set projectedColumns = assignments.values().stream() + .map(ElasticsearchColumnHandle.class::cast) + .collect(toImmutableSet()); + if (elasticsearchTableHandle.columns().equals(projectedColumns)) { + return Optional.empty(); + } + List assignmentsList = assignments.entrySet().stream() + .map(assignment -> new Assignment( + assignment.getKey(), + assignment.getValue(), + ((ElasticsearchColumnHandle) assignment.getValue()).type())) + .collect(toImmutableList()); + + return Optional.of(new ProjectionApplicationResult<>( + elasticsearchTableHandle.withColumns(projectedColumns), + projections, + assignmentsList, + false)); + } + + Map newAssignments = new HashMap<>(); + ImmutableMap.Builder newVariablesBuilder = ImmutableMap.builder(); + ImmutableSet.Builder columns = ImmutableSet.builder(); + + for (Map.Entry entry : columnProjections.entrySet()) { + ConnectorExpression expression = entry.getKey(); + ProjectedColumnRepresentation projectedColumn = entry.getValue(); + + ElasticsearchColumnHandle baseColumnHandle = (ElasticsearchColumnHandle) assignments.get(projectedColumn.getVariable().getName()); + ElasticsearchColumnHandle projectedColumnHandle = projectColumn(baseColumnHandle, projectedColumn.getDereferenceIndices(), expression.getType()); + String projectedColumnName = projectedColumnHandle.name(); + + Variable projectedColumnVariable = new Variable(projectedColumnName, expression.getType()); + Assignment newAssignment = new Assignment(projectedColumnName, projectedColumnHandle, expression.getType()); + newAssignments.putIfAbsent(projectedColumnName, newAssignment); + + newVariablesBuilder.put(expression, projectedColumnVariable); + columns.add(projectedColumnHandle); + } + + // Modify projections to refer to new variables + Map newVariables = newVariablesBuilder.buildOrThrow(); + List newProjections = projections.stream() + .map(expression -> replaceWithNewVariables(expression, newVariables)) + .collect(toImmutableList()); + + List outputAssignments = newAssignments.values().stream().collect(toImmutableList()); + return Optional.of(new ProjectionApplicationResult<>( + elasticsearchTableHandle.withColumns(columns.build()), + newProjections, + outputAssignments, + false)); + } + + private static boolean isSupportedForPushdown(ConnectorExpression connectorExpression) + { + if (connectorExpression instanceof Variable) { + return true; + } + if (connectorExpression instanceof FieldDereference fieldDereference) { + RowType rowType = (RowType) fieldDereference.getTarget().getType(); + RowType.Field field = rowType.getFields().get(fieldDereference.getField()); + return field.getName().isPresent(); + } + return false; + } + + private static ElasticsearchColumnHandle projectColumn(ElasticsearchColumnHandle baseColumn, List indices, Type projectedColumnType) + { + if (indices.isEmpty()) { + return baseColumn; + } + ImmutableList.Builder path = ImmutableList.builder(); + path.addAll(baseColumn.path()); + + 
DecoderDescriptor decoderDescriptor = baseColumn.decoderDescriptor(); + IndexMetadata.Type elasticsearchType = baseColumn.elasticsearchType(); + Type type = baseColumn.type(); + + for (int index : indices) { + verify(type instanceof RowType, "type should be Row type"); + RowType rowType = (RowType) type; + RowType.Field field = rowType.getFields().get(index); + path.add(field.getName() + .orElseThrow(() -> new TrinoException(NOT_SUPPORTED, "ROW type does not have field name declared: " + rowType))); + type = field.getType(); + + verify(decoderDescriptor instanceof RowDecoder.Descriptor, "decoderDescriptor should be RowDecoder.Descriptor type"); + decoderDescriptor = ((RowDecoder.Descriptor) decoderDescriptor).getFields().get(index).getDescriptor(); + elasticsearchType = ((IndexMetadata.ObjectType) elasticsearchType).fields().get(index).type(); + } + + return new ElasticsearchColumnHandle( + path.build(), + projectedColumnType, + elasticsearchType, + decoderDescriptor, + supportsPredicates(elasticsearchType, projectedColumnType)); + } + @Override public Optional> applyTableFunction(ConnectorSession session, ConnectorTableFunctionHandle handle) { diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchTableHandle.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchTableHandle.java index c4bae910a90c..2550986de68a 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchTableHandle.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchTableHandle.java @@ -14,6 +14,7 @@ package io.trino.plugin.elasticsearch; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.predicate.TupleDomain; @@ -21,6 +22,7 @@ import java.util.Map; import java.util.Optional; import java.util.OptionalLong; +import java.util.Set; import java.util.stream.Collectors; import static java.util.Objects.requireNonNull; @@ -32,7 +34,8 @@ public record ElasticsearchTableHandle( TupleDomain constraint, Map regexes, Optional query, - OptionalLong limit) + OptionalLong limit, + Set columns) implements ConnectorTableHandle { public enum Type @@ -49,7 +52,21 @@ public ElasticsearchTableHandle(Type type, String schema, String index, Optional TupleDomain.all(), ImmutableMap.of(), query, - OptionalLong.empty()); + OptionalLong.empty(), + ImmutableSet.of()); + } + + public ElasticsearchTableHandle withColumns(Set columns) + { + return new ElasticsearchTableHandle( + type, + schema, + index, + constraint, + regexes, + query, + limit, + columns); } public ElasticsearchTableHandle @@ -58,7 +75,8 @@ public ElasticsearchTableHandle(Type type, String schema, String index, Optional requireNonNull(schema, "schema is null"); requireNonNull(index, "index is null"); requireNonNull(constraint, "constraint is null"); - regexes = ImmutableMap.copyOf(requireNonNull(regexes, "regexes is null")); + regexes = ImmutableMap.copyOf(regexes); + columns = ImmutableSet.copyOf(columns); requireNonNull(query, "query is null"); requireNonNull(limit, "limit is null"); } diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ScanQueryPageSource.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ScanQueryPageSource.java index 6e4ccc6c3e99..e32f722ee52e 100644 --- 
a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ScanQueryPageSource.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ScanQueryPageSource.java @@ -159,8 +159,17 @@ public Page getNextPage() Map document = hit.getSourceAsMap(); for (int i = 0; i < decoders.size(); i++) { - String field = columns.get(i).name(); - decoders.get(i).decode(hit, () -> getField(document, field), columnBuilders[i]); + ElasticsearchColumnHandle columnHandle = columns.get(i); + if (columnHandle.path().size() == 1) { + decoders.get(i).decode(hit, () -> getField(document, columnHandle.path().getFirst()), columnBuilders[i]); + continue; + } + Map resolvedField = resolveField(document, columnHandle); + decoders.get(i) + .decode( + hit, + () -> resolvedField == null ? null : getField(resolvedField, columnHandle.path().getLast()), + columnBuilders[i]); } if (hit.getSourceRef() != null) { @@ -181,6 +190,23 @@ public Page getNextPage() return new Page(blocks); } + private static Map resolveField(Map document, ElasticsearchColumnHandle columnHandle) + { + if (document == null) { + return null; + } + Map value = (Map) getField(document, columnHandle.path().getFirst()); + if (value != null) { + for (int i = 1; i < columnHandle.path().size() - 1; i++) { + value = (Map) getField(value, columnHandle.path().get(i)); + if (value == null) { + break; + } + } + } + return value; + } + public static Object getField(Map document, String field) { Object value = document.get(field); diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/ArrayDecoder.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/ArrayDecoder.java index 2c6bd89b500c..29d829c5aa34 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/ArrayDecoder.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/ArrayDecoder.java @@ -21,6 +21,7 @@ import org.elasticsearch.search.SearchHit; import java.util.List; +import java.util.Objects; import java.util.function.Supplier; public class ArrayDecoder @@ -71,5 +72,24 @@ public Decoder createDecoder() { return new ArrayDecoder(elementDescriptor.createDecoder()); } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Descriptor that = (Descriptor) o; + return Objects.equals(this.elementDescriptor, that.elementDescriptor); + } + + @Override + public int hashCode() + { + return elementDescriptor.hashCode(); + } } } diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/BigintDecoder.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/BigintDecoder.java index 5543361744ac..737ff875abf5 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/BigintDecoder.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/BigintDecoder.java @@ -20,6 +20,7 @@ import io.trino.spi.block.BlockBuilder; import org.elasticsearch.search.SearchHit; +import java.util.Objects; import java.util.function.Supplier; import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; @@ -86,5 +87,24 @@ public Decoder createDecoder() { return new BigintDecoder(path); } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + 
Descriptor that = (Descriptor) o; + return Objects.equals(this.path, that.path); + } + + @Override + public int hashCode() + { + return path.hashCode(); + } } } diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/BooleanDecoder.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/BooleanDecoder.java index 80895325e19d..459e61d9cf79 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/BooleanDecoder.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/BooleanDecoder.java @@ -20,6 +20,7 @@ import io.trino.spi.block.BlockBuilder; import org.elasticsearch.search.SearchHit; +import java.util.Objects; import java.util.function.Supplier; import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; @@ -85,5 +86,24 @@ public Decoder createDecoder() { return new BooleanDecoder(path); } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Descriptor that = (Descriptor) o; + return Objects.equals(this.path, that.path); + } + + @Override + public int hashCode() + { + return path.hashCode(); + } } } diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/DoubleDecoder.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/DoubleDecoder.java index c0c27083c8cc..861ef43545f8 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/DoubleDecoder.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/DoubleDecoder.java @@ -20,6 +20,7 @@ import io.trino.spi.block.BlockBuilder; import org.elasticsearch.search.SearchHit; +import java.util.Objects; import java.util.function.Supplier; import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; @@ -91,5 +92,24 @@ public Decoder createDecoder() { return new DoubleDecoder(path); } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Descriptor that = (Descriptor) o; + return Objects.equals(this.path, that.path); + } + + @Override + public int hashCode() + { + return path.hashCode(); + } } } diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/IntegerDecoder.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/IntegerDecoder.java index 358107531029..4237fbe7e0a9 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/IntegerDecoder.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/IntegerDecoder.java @@ -20,6 +20,7 @@ import io.trino.spi.block.BlockBuilder; import org.elasticsearch.search.SearchHit; +import java.util.Objects; import java.util.function.Supplier; import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; @@ -96,5 +97,24 @@ public Decoder createDecoder() { return new IntegerDecoder(path); } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Descriptor that = (Descriptor) o; + return Objects.equals(this.path, that.path); + } + + @Override + public int hashCode() + { + return path.hashCode(); + } } } diff --git 
a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/IpAddressDecoder.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/IpAddressDecoder.java index bf2612b61f16..ec5f8423c35a 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/IpAddressDecoder.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/IpAddressDecoder.java @@ -24,6 +24,7 @@ import io.trino.spi.type.Type; import org.elasticsearch.search.SearchHit; +import java.util.Objects; import java.util.function.Supplier; import static io.airlift.slice.Slices.wrappedBuffer; @@ -120,5 +121,25 @@ public Decoder createDecoder() { return new IpAddressDecoder(path, ipAddressType); } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Descriptor that = (Descriptor) o; + return Objects.equals(this.path, that.path) + && Objects.equals(this.ipAddressType, that.ipAddressType); + } + + @Override + public int hashCode() + { + return Objects.hash(path, ipAddressType); + } } } diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/RawJsonDecoder.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/RawJsonDecoder.java index 141113b39388..39e08ae74ca8 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/RawJsonDecoder.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/RawJsonDecoder.java @@ -24,6 +24,7 @@ import io.trino.spi.block.BlockBuilder; import org.elasticsearch.search.SearchHit; +import java.util.Objects; import java.util.function.Supplier; import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; @@ -85,5 +86,24 @@ public Decoder createDecoder() { return new RawJsonDecoder(path); } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Descriptor that = (Descriptor) o; + return Objects.equals(this.path, that.path); + } + + @Override + public int hashCode() + { + return path.hashCode(); + } } } diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/RealDecoder.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/RealDecoder.java index c9a09131a184..9ac858b971b8 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/RealDecoder.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/RealDecoder.java @@ -20,6 +20,7 @@ import io.trino.spi.block.BlockBuilder; import org.elasticsearch.search.SearchHit; +import java.util.Objects; import java.util.function.Supplier; import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; @@ -91,5 +92,24 @@ public Decoder createDecoder() { return new RealDecoder(path); } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Descriptor that = (Descriptor) o; + return Objects.equals(this.path, that.path); + } + + @Override + public int hashCode() + { + return path.hashCode(); + } } } diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/RowDecoder.java 
b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/RowDecoder.java index abcaf7b3e75b..65b5e93af49e 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/RowDecoder.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/RowDecoder.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.function.Supplier; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -103,6 +104,28 @@ public Decoder createDecoder() .map(field -> field.getDescriptor().createDecoder()) .collect(toImmutableList())); } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null) { + return false; + } + if (!(o instanceof Descriptor descriptor)) { + return false; + } + return descriptor.path.equals(this.path) + && descriptor.fields.equals(this.fields); + } + + @Override + public int hashCode() + { + return Objects.hash(path, fields); + } } public static class NameAndDescriptor @@ -128,5 +151,25 @@ public DecoderDescriptor getDescriptor() { return descriptor; } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + NameAndDescriptor that = (NameAndDescriptor) o; + return Objects.equals(this.name, that.name) + && Objects.equals(this.descriptor, that.descriptor); + } + + @Override + public int hashCode() + { + return Objects.hash(name, descriptor); + } } } diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/TimestampDecoder.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/TimestampDecoder.java index 5d3480586f07..ba755035bfd9 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/TimestampDecoder.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/TimestampDecoder.java @@ -24,6 +24,7 @@ import java.time.Instant; import java.time.LocalDateTime; +import java.util.Objects; import java.util.function.Supplier; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; @@ -114,5 +115,24 @@ public Decoder createDecoder() { return new TimestampDecoder(path); } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Descriptor that = (Descriptor) o; + return Objects.equals(this.path, that.path); + } + + @Override + public int hashCode() + { + return path.hashCode(); + } } } diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/TinyintDecoder.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/TinyintDecoder.java index ac6c5f3802e2..b24ccbf80a81 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/TinyintDecoder.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/TinyintDecoder.java @@ -20,6 +20,7 @@ import io.trino.spi.block.BlockBuilder; import org.elasticsearch.search.SearchHit; +import java.util.Objects; import java.util.function.Supplier; import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; @@ -95,5 +96,24 @@ public Decoder createDecoder() { return new TinyintDecoder(path); } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || 
getClass() != o.getClass()) { + return false; + } + Descriptor that = (Descriptor) o; + return Objects.equals(this.path, that.path); + } + + @Override + public int hashCode() + { + return path.hashCode(); + } } } diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/VarbinaryDecoder.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/VarbinaryDecoder.java index a81ce6a78c36..0b006f2755cb 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/VarbinaryDecoder.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/VarbinaryDecoder.java @@ -22,6 +22,7 @@ import org.elasticsearch.search.SearchHit; import java.util.Base64; +import java.util.Objects; import java.util.function.Supplier; import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; @@ -76,5 +77,24 @@ public Decoder createDecoder() { return new VarbinaryDecoder(path); } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Descriptor that = (Descriptor) o; + return Objects.equals(this.path, that.path); + } + + @Override + public int hashCode() + { + return path.hashCode(); + } } } diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/VarcharDecoder.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/VarcharDecoder.java index 70dd2d175538..c7c4f2c14976 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/VarcharDecoder.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/decoders/VarcharDecoder.java @@ -21,6 +21,7 @@ import io.trino.spi.block.BlockBuilder; import org.elasticsearch.search.SearchHit; +import java.util.Objects; import java.util.function.Supplier; import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; @@ -75,5 +76,24 @@ public Decoder createDecoder() { return new VarcharDecoder(path); } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Descriptor that = (Descriptor) o; + return Objects.equals(this.path, that.path); + } + + @Override + public int hashCode() + { + return path.hashCode(); + } } } diff --git a/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/BaseElasticsearchConnectorTest.java b/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/BaseElasticsearchConnectorTest.java index f28ffe64c262..6f76a251642f 100644 --- a/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/BaseElasticsearchConnectorTest.java +++ b/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/BaseElasticsearchConnectorTest.java @@ -34,8 +34,10 @@ import java.io.IOException; import java.time.LocalDateTime; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Stream; import static io.trino.spi.StandardErrorCode.INVALID_COLUMN_REFERENCE; import static io.trino.spi.type.DoubleType.DOUBLE; @@ -103,6 +105,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) SUPPORTS_SET_COLUMN_TYPE, SUPPORTS_TOPN_PUSHDOWN, SUPPORTS_UPDATE -> false; + case SUPPORTS_DEREFERENCE_PUSHDOWN -> true; default -> super.hasBehavior(connectorBehavior); }; } @@ -1402,7 +1405,8 @@ public void testNestedTimestamps() .matches("VALUES " 
+ "(TIMESTAMP '1970-01-01 00:00:00.000')," + "(TIMESTAMP '1970-01-01 00:00:00.001')," + - "(TIMESTAMP '1970-01-01 01:01:00.000')"); + "(TIMESTAMP '1970-01-01 01:01:00.000')") + .isFullyPushedDown(); } @Test @@ -2070,6 +2074,467 @@ FROM TABLE(%s.system.raw_query( .failure().hasMessageContaining("json_parse_exception"); } + @Test + void testSimpleProjectionPushdown() + throws IOException + { + String tableName = "test_projection_pushdown_" + randomNameSuffix(); + + createIndex(tableName); + index(tableName, ImmutableMap.builder() + .put("id", 1L) + .put("root", ImmutableMap.builder() + .put("f1", 1L) + .put("f2", 2L) + .buildOrThrow()) + .buildOrThrow()); + + Map record2 = new HashMap<>(); + record2.put("id", 2L); + record2.put( "root", null); + index(tableName, record2); + + Map record32 = new HashMap<>(); + record32.put("f1", null); + record32.put("f2", 4L); + + Map record3 = new HashMap<>(); + record3.put("id", 3L); + record3.put("row", record32); + index(tableName, record3); + + String selectQuery = "SELECT id, root.f1 FROM " + tableName; + String expectedResult = "VALUES (BIGINT '1', BIGINT '1'), (BIGINT '2', NULL), (BIGINT '3', NULL)"; + + // With Projection Pushdown enabled + assertThat(query(selectQuery)) + .matches(expectedResult) + .isFullyPushedDown(); + + deleteIndex(tableName); + } + + @Test + void testProjectionPushdownWithCaseSensitiveField() + throws IOException + { + String tableName = "test_projection_with_case_sensitive_field_" + randomNameSuffix();; + @Language("JSON") + String properties = + """ + { + "properties": { + "id": { + "type": "integer" + }, + "a": { + "properties": { + "UPPER_CASE": { + "type": "integer" + }, + "lower_case": { + "type": "integer" + }, + "MiXeD_cAsE": { + "type": "integer" + } + } + } + } + } + """; + + createIndex(tableName, properties); + index(tableName, ImmutableMap.builder() + .put("id", 1L) + .put("a", ImmutableMap.builder() + .put("UPPER_CASE", 2) + .put("lower_case", 3) + .put("MiXeD_cAsE", 4) + .buildOrThrow()) + .buildOrThrow()); + index(tableName, ImmutableMap.builder() + .put("id", 5L) + .put("a", ImmutableMap.builder() + .put("UPPER_CASE", 6) + .put("lower_case", 7) + .put("MiXeD_cAsE", 8) + .buildOrThrow()) + .buildOrThrow()); + + String expected = "VALUES (2, 3, 4), (6, 7, 8)"; + assertThat(query("SELECT a.UPPER_CASE, a.lower_case, a.MiXeD_cAsE FROM " + tableName)) + .matches(expected) + .isFullyPushedDown(); + assertThat(query("SELECT a.upper_case, a.lower_case, a.mixed_case FROM " + tableName)) + .matches(expected) + .isFullyPushedDown(); + assertThat(query("SELECT a.UPPER_CASE, a.LOWER_CASE, a.MIXED_CASE FROM " + tableName)) + .matches(expected) + .isFullyPushedDown(); + + deleteIndex(tableName); + } + + @Test + void testProjectionPushdownWithMultipleRows() + throws IOException + { + String tableName = "test_projection_pushdown_multiple_rows_" + randomNameSuffix(); + @Language("JSON") + String properties = + """ + { + "properties": { + "id": { + "type": "integer" + }, + "nested1": { + "properties": { + "child1": { + "type": "integer" + }, + "child2": { + "type": "text" + }, + "child3": { + "type": "integer" + } + } + }, + "nested2": { + "properties": { + "child1": { + "type": "double" + }, + "child2": { + "type": "boolean" + }, + "child3": { + "type": "date" + } + } + } + } + } + """; + + createIndex(tableName, properties); + index(tableName, ImmutableMap.builder() + .put("id", 1) + .put("nested1", ImmutableMap.builder() + .put("child1", 10) + .put("child2", "a") + .put("child3", 100) + .buildOrThrow()) + .put("nested2", 
ImmutableMap.builder() + .put("child1", 10.10d) + .put("child2", true) + .put("child3", "2023-04-19") + .buildOrThrow()) + .buildOrThrow()); + index(tableName, ImmutableMap.builder() + .put("id", 2) + .put("nested1", ImmutableMap.builder() + .put("child1", 20) + .put("child2", "b") + .put("child3", 200) + .buildOrThrow()) + .put("nested2", ImmutableMap.builder() + .put("child1", 20.20d) + .put("child2", false) + .put("child3", "1990-04-20") + .buildOrThrow()) + .buildOrThrow()); + + Map record3 = new HashMap<>(); + Map record3Nested1 = new HashMap<>(); + record3Nested1.put("child1", 40); + record3Nested1.put("child2", null); + record3Nested1.put("child3", 400); + record3.put("id", 4); + record3.put("nested1", record3Nested1); + record3.put("nested2", null); + index(tableName, record3); + + Map record4 = new HashMap<>(); + Map record4Nested2 = new HashMap<>(); + record4Nested2.put("child1", null); + record4Nested2.put("child2", true); + record4Nested2.put("child3", null); + record4.put("id", 5); + record4.put("nested1", null); + record4.put("nested2", record4Nested2); + index(tableName, record4); + + // Select one field from one row field + assertThat(query("SELECT id, nested1.child1 FROM " + tableName)) + .matches("VALUES (1, 10), (2, 20), (4, 40), (5, NULL)") + .isFullyPushedDown(); + assertThat(query("SELECT nested2.child3, id FROM " + tableName)) + // Use timestamp instead of date as connector converts source date to timestamp + .matches("VALUES (TIMESTAMP '2023-04-19 00:00:00.000', 1), (TIMESTAMP '1990-04-20 00:00:00.000', 2), (NULL, 4), (NULL, 5)") + .isFullyPushedDown(); + + // Select one field each from multiple row fields + assertThat(query("SELECT nested2.child1, id, nested1.child2 FROM " + tableName)) + .skippingTypesCheck() + .matches("VALUES (DOUBLE '10.10', 1, 'a'), (DOUBLE '20.20', 2, 'b'), (NULL, 4, NULL), (NULL, 5, NULL)") + .isFullyPushedDown(); + + // Select multiple fields from one row field + assertThat(query("SELECT nested1.child3, id, nested1.child2 FROM " + tableName)) + .skippingTypesCheck() + .matches("VALUES (100, 1, 'a'), (200, 2, 'b'), (400, 4, NULL), (NULL, 5, NULL)") + .isFullyPushedDown(); + assertThat(query("SELECT nested2.child2, nested2.child3, id FROM " + tableName)) + // Use timestamp instead of date as connector converts source date to timestamp + .matches("VALUES (true, TIMESTAMP '2023-04-19 00:00:00.000' , 1), (false, TIMESTAMP '1990-04-20 00:00:00.000', 2), (NULL, NULL, 4), (true, NULL, 5)") + .isFullyPushedDown(); + + // Select multiple fields from multiple row fields + assertThat(query("SELECT id, nested2.child1, nested1.child3, nested2.child2, nested1.child1 FROM " + tableName)) + .matches("VALUES (1, DOUBLE '10.10', 100, true, 10), (2, DOUBLE '20.20', 200, false, 20), (4, NULL, 400, NULL, 40), (5, NULL, NULL, true, NULL)") + .isFullyPushedDown(); + + // Select only nested fields + assertThat(query("SELECT nested2.child2, nested1.child3 FROM " + tableName)) + .matches("VALUES (true, 100), (false, 200), (NULL, 400), (true, NULL)") + .isFullyPushedDown(); + + deleteIndex(tableName); + } + + @Test + void testProjectionPushdownWithNestedData() + throws IOException + { + String tableName = "test_highly_nested_data_" + randomNameSuffix(); + index(tableName, ImmutableMap.builder() + .put("id", 1) + .put("row1_t", ImmutableMap.builder() + .put("f1", 2) + .put("f2", 3) + .put("row2_t", ImmutableMap.builder() + .put("f1", 4) + .put("f2", 5) + .put("row3_t", ImmutableMap.builder() + .put("f1", 6) + .put("f2", 7) + .buildOrThrow()) + .buildOrThrow()) + 
.buildOrThrow()) + .buildOrThrow()); + index(tableName, ImmutableMap.builder() + .put("id", 11) + .put("row1_t", ImmutableMap.builder() + .put("f1", 12) + .put("f2", 13) + .put("row2_t", ImmutableMap.builder() + .put("f1", 14) + .put("f2", 15) + .put("row3_t", ImmutableMap.builder() + .put("f1", 16) + .put("f2", 17) + .buildOrThrow()) + .buildOrThrow()) + .buildOrThrow()) + .buildOrThrow()); + index(tableName, ImmutableMap.builder() + .put("id", 21) + .put("row1_t", ImmutableMap.builder() + .put("f1", 22) + .put("f2", 23) + .put("row2_t", ImmutableMap.builder() + .put("f1", 24) + .put("f2", 25) + .put("row3_t", ImmutableMap.builder() + .put("f1", 26) + .put("f2", 27) + .buildOrThrow()) + .buildOrThrow()) + .buildOrThrow()) + .buildOrThrow()); + + // Test select projected columns, with and without their parent column + assertThat(query("SELECT id, row1_t.row2_t.row3_t.f2 FROM " + tableName)).matches("VALUES (BIGINT '1', BIGINT '7'), (BIGINT '11', BIGINT '17'), (BIGINT '21', BIGINT '27')"); + assertThat(query("SELECT id, row1_t.row2_t.row3_t.f2, CAST(row1_t AS JSON) FROM " + tableName)) + .matches( + "VALUES (BIGINT '1', BIGINT '7', JSON '%s'), " + .formatted( + """ + { + "f1": 2, + "f2": 3, + "row2_t": { + "f1": 4, + "f2": 5, + "row3_t": { + "f1": 6, + "f2": 7 + } + } + } + """) + + "(BIGINT '11', BIGINT '17', JSON '%s'), " + .formatted( + """ + { + "f1": 12, + "f2": 13, + "row2_t": { + "f1": 14, + "f2": 15, + "row3_t": { + "f1": 16, + "f2": 17 + } + } + } + """) + + "(BIGINT '21', BIGINT '27', JSON '%s')" + .formatted( + """ + { + "f1": 22, + "f2": 23, + "row2_t": { + "f1": 24, + "f2": 25, + "row3_t": { + "f1": 26, + "f2": 27 + } + } + } + """)); + + // Test predicates on immediate child column and deeper nested column + assertThat(query("SELECT id, CAST(row1_t.row2_t.row3_t AS JSON) FROM " + tableName + " WHERE row1_t.row2_t.row3_t.f2 = 27")) + .matches("VALUES (BIGINT '21', JSON '%s')" + .formatted( + """ + { + "f1": 26, + "f2": 27 + } + """)); + assertThat(query("SELECT id, CAST(row1_t.row2_t.row3_t AS JSON) FROM " + tableName + " WHERE row1_t.row2_t.row3_t.f2 > 20")) + .matches("VALUES (BIGINT '21', JSON '%s')" + .formatted( + """ + { + "f1": 26, + "f2": 27 + } + """)); + assertThat(query("SELECT id, CAST(row1_t AS JSON) FROM " + tableName + " WHERE row1_t.row2_t.row3_t.f2 = 27")) + .matches("VALUES (BIGINT '21', JSON '%s')" + .formatted( + """ + { + "f1": 22, + "f2": 23, + "row2_t": { + "f1": 24, + "f2": 25, + "row3_t": { + "f1": 26, + "f2": 27 + } + } + } + """)); + assertThat(query("SELECT id, CAST(row1_t AS JSON) FROM " + tableName + " WHERE row1_t.row2_t.row3_t.f2 > 20")) + .matches("VALUES (BIGINT '21', JSON '%s')" + .formatted( + """ + { + "f1": 22, + "f2": 23, + "row2_t": { + "f1": 24, + "f2": 25, + "row3_t": { + "f1": 26, + "f2": 27 + } + } + } + """)); + + // Test predicates on parent columns + assertThat(query("SELECT id, row1_t.row2_t.row3_t.f1 FROM " + tableName + " WHERE row1_t.row2_t.row3_t = ROW(16, 17)")) + .matches("VALUES (BIGINT '11', BIGINT '16')"); + assertThat(query("SELECT id, row1_t.row2_t.row3_t.f1 FROM " + tableName + " WHERE row1_t = ROW(22, 23, ROW(24, 25, ROW(26, 27)))")) + .matches("VALUES (BIGINT '21', BIGINT '26')"); + + deleteIndex(tableName); + } + + @Test + void testDereferencePushdownWithNestedFieldsIncludingArrays() + throws IOException + { + String tableName = "test_dereference_pushdown_" + randomNameSuffix(); + index(tableName, ImmutableMap.builder() + .put("array_string_field", ImmutableList.builder() + .addAll(Stream.of("trino", "the", "lean", 
"machine-ohs")::iterator) + .build()) + .put("object_field_outer", ImmutableMap.builder() + .put("array_string_field", ImmutableList.builder() + .addAll(Stream.of("trino", "the", "lean", "machine-ohs")::iterator) + .build()) + .put("string_field_outer", "sample") + .put("int_field_outer", 44) + .put("object_field_inner", ImmutableMap.builder() + .put("int_field_inner", 432) + .buildOrThrow()) + .buildOrThrow()) + .put("long_field", 314159265359L) + .put("id_field", "564e6982-88ee-4498-aa98-df9e3f6b6109") + .put("timestamp_field", "1987-09-17T06:22:48.000Z") + .put("object_field", ImmutableMap.builder() + .put("array_string_field", ImmutableList.builder() + .addAll(Stream.of("trino", "the", "lean", "machine-ohs")::iterator) + .build()) + .put("string_field", "sample") + .put("int_field", 2) + .put("object_field_2", ImmutableMap.builder() + .put("array_string_field", ImmutableList.builder() + .addAll(Stream.of("trino", "the", "lean", "machine-ohs")::iterator) + .build()) + .put("string_field2", "sample") + .put("int_field", 33) + .put("object_field_3", ImmutableMap.builder() + .put("array_string_field", ImmutableList.builder() + .addAll(Stream.of("trino", "the", "lean", "machine-ohs")::iterator) + .build()) + .put("string_field3", "some value") + .put("int_field3", 55) + .buildOrThrow()) + .buildOrThrow()) + .buildOrThrow()) + .buildOrThrow()); + + HashMap innerRecord = new HashMap<>(); + innerRecord.put("object_field_inner", null); + index(tableName, ImmutableMap.builder() + .put("long_field", 11122233L) + .put("object_field_outer", innerRecord) + .buildOrThrow()); + + assertThat(query("select id_field, object_field.object_field_2.object_field_3.string_field3 from " + tableName + " where object_field_outer.int_field_outer=44")) + .skippingTypesCheck() + .matches("VALUES ('564e6982-88ee-4498-aa98-df9e3f6b6109', 'some value')"); + assertThat(query("select object_field_outer.int_field_outer from " + tableName + " where object_field_outer.int_field_outer=44")) + .matches("VALUES CAST(44 as BIGINT)"); + assertThat(query("select long_field, id_field, object_field.object_field_2.object_field_3.string_field3, object_field_outer.object_field_inner.int_field_inner from " + tableName + " where long_field=11122233")) + .skippingTypesCheck() + .matches("VALUES (CAST(11122233 AS BIGINT), NULL, NULL, NULL)"); + deleteIndex(tableName); + } + protected void assertTableDoesNotExist(String name) { String catalogName = getSession().getCatalog().orElseThrow(); diff --git a/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchComplexTypePredicatePushDown.java b/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchComplexTypePredicatePushDown.java new file mode 100644 index 000000000000..e176679f173e --- /dev/null +++ b/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchComplexTypePredicatePushDown.java @@ -0,0 +1,379 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.elasticsearch; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.json.ObjectMapperProvider; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.QueryRunner; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.RestHighLevelClient; +import org.intellij.lang.annotations.Language; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static io.trino.plugin.elasticsearch.ElasticsearchServer.ELASTICSEARCH_8_IMAGE; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; + +final class TestElasticsearchComplexTypePredicatePushDown + extends AbstractTestQueryFramework +{ + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapperProvider().get(); + + private ElasticsearchServer elasticsearch; + private RestHighLevelClient client; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + elasticsearch = closeAfterClass(new ElasticsearchServer(ELASTICSEARCH_8_IMAGE)); + client = closeAfterClass(elasticsearch.getClient()); + return ElasticsearchQueryRunner.builder(elasticsearch).build(); + } + + @Test + void testRowTypeOnlyNullsRowGroupPruning() + throws IOException + { + String tableName = "test_primitive_column_nulls_pruning_" + randomNameSuffix(); + @Language("JSON") + String properties = + """ + { + "properties": { + "col": { + "type": "long" + } + } + } + """; + StringBuilder payload = new StringBuilder(); + for (int i = 0; i < 4096; i++) { + Map document = new HashMap<>(); + document.put("col", null); + Map indexPayload = ImmutableMap.of("index", ImmutableMap.of("_index", tableName, "_id", String.valueOf(System.nanoTime()))); + String jsonDocument = OBJECT_MAPPER.writeValueAsString(document);; + String jsonIndex = OBJECT_MAPPER.writeValueAsString(indexPayload); + payload.append(jsonIndex).append("\n").append(jsonDocument).append("\n"); + } + + createIndex(tableName, properties); + bulkIndex(tableName, payload.toString()); + + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col IS NOT NULL"); + + tableName = "test_nested_column_nulls_pruning_" + randomNameSuffix(); + properties = + """ + { + "_meta": { + "trino": { + "col": { + "b": { + "isArray": true + } + } + } + }, + "properties": { + "col": { + "properties": { + "a": { + "type": "long" + }, + "b" : { + "type" : "double" + } + } + } + } + } + """; + // Nested column `a` has nulls count of 4096 and contains only nulls + // Nested column `b` also has nulls count of 4096, but it contains non nulls as well + Random random = new Random(); + payload = new StringBuilder(); + for (int i = 0; i < 4096; i++) { + Map document = new HashMap<>(); + Map inner = new HashMap<>(); + inner.put("a", null); + List bArray = new ArrayList<>(); + bArray.add(null); + bArray.add(random.nextDouble()); + inner.put("b", bArray); + document.put("col", inner); + Map indexPayload = ImmutableMap.of("index", ImmutableMap.of("_index", tableName, "_id", String.valueOf(System.nanoTime()))); + + String jsonDocument = OBJECT_MAPPER.writeValueAsString(document); + String jsonIndex = OBJECT_MAPPER.writeValueAsString(indexPayload); + 
payload.append(jsonIndex).append("\n").append(jsonDocument).append("\n"); + } + + createIndex(tableName, properties); + bulkIndex(tableName, payload.toString()); + + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col.a IS NOT NULL"); + + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col.a IS NULL", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(4096)); + + // no predicate push down for the entire array type + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col.b IS NOT NULL", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(4096)); + + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col.b IS NULL", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + // no predicate push down for entire ROW + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col IS NOT NULL", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(4096)); + + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col IS NULL", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + } + + @Test + void testRowTypeRowGroupPruning() + throws IOException + { + String tableName = "test_nested_column_pruning_" + randomNameSuffix(); + @Language("JSON") + String properties = + """ + { + "properties": { + "col1Row": { + "properties": { + "a": { + "type": "long" + }, + "b": { + "type": "long" + }, + "c": { + "properties": { + "c1": { + "type": "long" + }, + "c2": { + "properties": { + "c21": { + "type": "long" + }, + "c22": { + "type": "long" + } + } + } + } + } + } + } + } + } + """; + + int a = 2; + int b = 100; + int c1 = 1; + int c21 = 5; + int c22 = 6; + + StringBuilder payload = new StringBuilder(); + for (int i=0; i<10000; i++) { + Map document = ImmutableMap.builder() + .put("col1Row", ImmutableMap.builder() + .put("a", a) + .put("b", b) + .put("c", ImmutableMap.builder() + .put("c1", c1) + .put("c2", ImmutableMap.builder() + .put("c21", c21) + .put("c22", c22) + .buildOrThrow()) + .buildOrThrow()) + .buildOrThrow()) + .buildOrThrow(); + Map indexPayload = ImmutableMap.of("index", ImmutableMap.of("_index", tableName, "_id", String.valueOf(System.nanoTime()))); + String jsonDocument = OBJECT_MAPPER.writeValueAsString(document); + String jsonIndex = OBJECT_MAPPER.writeValueAsString(indexPayload); + payload.append(jsonIndex).append("\n").append(jsonDocument).append("\n"); + + a = a + 2; + c1 = c1 + 1; + c21 = c21 + 5; + c22 = c22 + 6; + } + + createIndex(tableName, properties); + bulkIndex(tableName, payload.toString()); + + // no data read since the row dereference predicate is pushed down + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.a = -1"); + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.a IS NULL"); + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.c.c2.c22 = -1"); + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.a = -1 AND col1ROW.b = -1 AND col1ROW.c.c1 = -1 AND col1Row.c.c2.c22 = -1"); + + // read all since 
predicate case matches with the data + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col1Row.b = 100", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(10000)); + + // no predicate push down for matching with ROW type, as file format only stores stats for primitives + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col1Row.c = ROW(-1, ROW(-1, -1))", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col1Row.c = ROW(-1, ROW(-1, -1)) OR col1Row.a = -1 ", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + // no data read since the row group get filtered by primitives in the predicate + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.c = ROW(-1, ROW(-1, -1)) AND col1Row.a = -1 "); + + // no predicate push down for entire ROW, as file format only stores stats for primitives + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col1Row = ROW(-1, -1, ROW(-1, ROW(-1, -1)))", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + deleteIndex(tableName); + } + + @Test + void testArrayTypeRowGroupPruning() + throws IOException + { + String tableName = "test_nested_column_pruning_" + randomNameSuffix(); + @Language("JSON") + String properties = + """ + { + "_meta": { + "trino": { + "colArray": { + "isArray": true + } + } + }, + "properties": { + "colArray": { + "type": "long" + } + } + } + """; + + StringBuilder payload = new StringBuilder(); + for (int i=0; i<10000; i++) { + Map document = ImmutableMap.builder() + .put("colArray", ImmutableList.builder() + .add(100L) + .add(200L) + .build()) + .buildOrThrow(); + Map indexPayload = ImmutableMap.of("index", ImmutableMap.of("_index", tableName, "_id", String.valueOf(System.nanoTime()))); + + String jsonDocument = OBJECT_MAPPER.writeValueAsString(document); + String jsonIndex = OBJECT_MAPPER.writeValueAsString(indexPayload); + payload.append(jsonIndex).append("\n").append(jsonDocument).append("\n"); + } + createIndex(tableName, properties); + bulkIndex(tableName, payload.toString()); + + // no predicate push down for ARRAY type dereference + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE colArray[1] = -1", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + // no predicate push down for entire ARRAY type + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE colArray = ARRAY[-1, -1]", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + deleteIndex(tableName); + } + + private void createIndex(String indexName, @Language("JSON") String properties) + throws IOException + { + String mappings = indexMapping(properties); + Request request = new Request("PUT", "/" + indexName); + request.setJsonEntity(mappings); + client.getLowLevelClient().performRequest(request); + } + + private static String 
+
+    private static String indexMapping(@Language("JSON") String properties)
+    {
+        return "{\"mappings\": " + properties + "}";
+    }
+
+    private void bulkIndex(String index, String payload)
+            throws IOException
+    {
+        String endpoint = format("%s?refresh", bulkEndpoint(index));
+        Request request = new Request("PUT", endpoint);
+        request.setJsonEntity(payload);
+        client.getLowLevelClient().performRequest(request);
+    }
+
+    private static String bulkEndpoint(String index)
+    {
+        return format("/%s/_bulk", index);
+    }
+
+    private void deleteIndex(String indexName)
+            throws IOException
+    {
+        Request request = new Request("DELETE", "/" + indexName);
+        client.getLowLevelClient().performRequest(request);
+    }
+}
diff --git a/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchProjectionPushdownPlans.java b/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchProjectionPushdownPlans.java
new file mode 100644
index 000000000000..06c2fa2ab304
--- /dev/null
+++ b/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchProjectionPushdownPlans.java
@@ -0,0 +1,305 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.elasticsearch;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.net.HostAndPort;
+import io.airlift.json.ObjectMapperProvider;
+import io.trino.Session;
+import io.trino.metadata.QualifiedObjectName;
+import io.trino.metadata.ResolvedFunction;
+import io.trino.metadata.TableHandle;
+import io.trino.metadata.TestingFunctionResolution;
+import io.trino.plugin.elasticsearch.client.IndexMetadata;
+import io.trino.plugin.elasticsearch.decoders.BigintDecoder;
+import io.trino.spi.connector.ColumnHandle;
+import io.trino.spi.function.OperatorType;
+import io.trino.spi.predicate.Domain;
+import io.trino.spi.predicate.TupleDomain;
+import io.trino.spi.type.RowType;
+import io.trino.spi.type.Type;
+import io.trino.sql.ir.Call;
+import io.trino.sql.ir.Comparison;
+import io.trino.sql.ir.Constant;
+import io.trino.sql.ir.FieldReference;
+import io.trino.sql.ir.Reference;
+import io.trino.sql.planner.assertions.BasePushdownPlanTest;
+import io.trino.sql.planner.assertions.PlanMatchPattern;
+import io.trino.testing.PlanTester;
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.RestHighLevelClient;
+import org.intellij.lang.annotations.Language;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
+import org.junit.jupiter.api.parallel.Execution;
+
+import java.io.File;
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+import static com.google.common.base.Predicates.equalTo;
+import static com.google.common.io.Resources.getResource;
+import static io.airlift.testing.Closeables.closeAllSuppress;
+import static io.trino.plugin.elasticsearch.ElasticsearchServer.ELASTICSEARCH_8_IMAGE;
+import static io.trino.spi.type.BigintType.BIGINT;
+import static io.trino.spi.type.IntegerType.INTEGER;
+import static io.trino.sql.ir.Comparison.Operator.EQUAL;
+import static io.trino.sql.planner.assertions.PlanMatchPattern.any;
+import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree;
+import static io.trino.sql.planner.assertions.PlanMatchPattern.expression;
+import static io.trino.sql.planner.assertions.PlanMatchPattern.filter;
+import static io.trino.sql.planner.assertions.PlanMatchPattern.project;
+import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan;
+import static io.trino.sql.planner.plan.JoinType.INNER;
+import static io.trino.testing.TestingNames.randomNameSuffix;
+import static io.trino.testing.TestingSession.testSessionBuilder;
+import static java.lang.String.format;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
+import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT;
+
+@TestInstance(PER_CLASS)
+@Execution(CONCURRENT)
+final class TestElasticsearchProjectionPushdownPlans
+        extends BasePushdownPlanTest
+{
+    private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution();
+    private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BIGINT, BIGINT));
+    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapperProvider().get();
+    private static final String CATALOG = "elasticsearch";
+    private static final String SCHEMA = "test";
+    public static final String USER = "elastic_user";
+    public static final String PASSWORD = "123456";
+
+    private ElasticsearchServer elasticsearch;
+    private RestHighLevelClient client;
+
+    @Override
+    protected PlanTester createPlanTester()
+    {
+        Session session = testSessionBuilder()
+                .setCatalog(CATALOG)
+                .setSchema(SCHEMA)
+                .build();
+
+        PlanTester planTester = PlanTester.create(session);
+
+        try {
+            elasticsearch = new ElasticsearchServer(ELASTICSEARCH_8_IMAGE);
+        }
+        catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+        HostAndPort address = elasticsearch.getAddress();
+        client = elasticsearch.getClient();
+
+        try {
+            planTester.installPlugin(new ElasticsearchPlugin());
+            planTester.createCatalog(
+                    CATALOG,
+                    "elasticsearch",
+                    ImmutableMap.<String, String>builder()
+                            .put("elasticsearch.host", address.getHost())
+                            .put("elasticsearch.port", Integer.toString(address.getPort()))
+                            .put("elasticsearch.ignore-publish-address", "true")
+                            .put("elasticsearch.default-schema-name", SCHEMA)
+                            .put("elasticsearch.scroll-size", "1000")
+                            .put("elasticsearch.scroll-timeout", "1m")
+                            .put("elasticsearch.request-timeout", "2m")
+                            .put("elasticsearch.tls.enabled", "true")
+                            .put("elasticsearch.tls.truststore-path", new File(getResource("truststore.jks").toURI()).getPath())
+                            .put("elasticsearch.tls.truststore-password", "123456")
+                            .put("elasticsearch.tls.verify-hostnames", "false")
+                            .put("elasticsearch.security", "PASSWORD")
+                            .put("elasticsearch.auth.user", USER)
+                            .put("elasticsearch.auth.password", PASSWORD)
+                            .buildOrThrow());
+        }
+        catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+        catch (Throwable e) {
+            closeAllSuppress(e, planTester);
+            throw e;
+        }
+        return planTester;
+    }
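
An illustrative aside, not part of the patch: with path-based column handles, a dereference such as col0.x is represented by a handle whose path holds the parent field and the child field, and name() joins the elements with dots to form the Elasticsearch field path. The handles that projectColumn() builds in testDereferencePushdown below are therefore equivalent to roughly the following (the names and values mirror the test data; nothing new is introduced):

    // Illustrative only: the projected handle for the dereference col0.x
    ElasticsearchColumnHandle columnX = new ElasticsearchColumnHandle(
            ImmutableList.of("col0", "x"),            // dereference path: parent field, then child field
            BIGINT,                                   // Trino type of the projected field
            new IndexMetadata.PrimitiveType("long"),  // underlying Elasticsearch mapping type
            new BigintDecoder.Descriptor("col0.x"),   // decoder keyed by the dotted field path
            true);                                    // predicates on this primitive field can be pushed down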
+    {
+        elasticsearch.close();
+        elasticsearch = null;
+        client.close();
+        client = null;
+    }
+
+    @Test
+    void testDereferencePushdown()
+            throws IOException
+    {
+        String tableName = "test_simple_projection_pushdown" + randomNameSuffix();
+        QualifiedObjectName completeTableName = new QualifiedObjectName(CATALOG, SCHEMA, tableName);
+
+        index(tableName, ImmutableMap.<String, Object>builder()
+                .put("col0", ImmutableMap.<String, Object>builder()
+                        .put("x", 5L)
+                        .put("y", 6L)
+                        .buildOrThrow())
+                .put("col1", 5L)
+                .buildOrThrow());
+
+        Session session = getPlanTester().getDefaultSession();
+
+        Optional<TableHandle> tableHandle = getTableHandle(session, completeTableName);
+        assertThat(tableHandle).as("expected the table handle to be present").isPresent();
+
+        ElasticsearchTableHandle elasticsearchTableHandle = (ElasticsearchTableHandle) tableHandle.get().connectorHandle();
+        Map<String, ColumnHandle> columns = getColumnHandles(session, completeTableName);
+
+        ElasticsearchColumnHandle column0Handle = (ElasticsearchColumnHandle) columns.get("col0");
+        ElasticsearchColumnHandle column1Handle = (ElasticsearchColumnHandle) columns.get("col1");
+
+        ElasticsearchColumnHandle columnX = projectColumn(ImmutableList.of(column0Handle.path().getFirst(), "x"), BIGINT, new IndexMetadata.PrimitiveType("long"), new BigintDecoder.Descriptor("col0.x"), true);
+        ElasticsearchColumnHandle columnY = projectColumn(ImmutableList.of(column0Handle.path().getFirst(), "y"), BIGINT, new IndexMetadata.PrimitiveType("long"), new BigintDecoder.Descriptor("col0.y"), true);
+
+        // Simple projection pushdown
+        assertPlan(
+                "SELECT col0.x expr_x, col0.y expr_y FROM " + tableName,
+                any(
+                        tableScan(
+                                equalTo(elasticsearchTableHandle.withColumns(Set.of(columnX, columnY))),
+                                TupleDomain.all(),
+                                ImmutableMap.of("col0.x", equalTo(columnX), "col0.y", equalTo(columnY)))));
+
+        // Projection and predicate pushdown
+        assertPlan(
+                "SELECT col0.x FROM " + tableName + " WHERE col0.x = col1 + 3 and col0.y = 2",
+                anyTree(
+                        filter(
+                                new Comparison(EQUAL, new Reference(BIGINT, "x"), new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "col1"), new Constant(BIGINT, 3L)))),
+                                tableScan(
+                                        table -> {
+                                            ElasticsearchTableHandle actualTableHandle = (ElasticsearchTableHandle) table;
+                                            TupleDomain<ColumnHandle> constraint = actualTableHandle.constraint();
+                                            return actualTableHandle.columns().equals(ImmutableSet.of(column1Handle, columnX))
+                                                    && constraint.equals(TupleDomain.withColumnDomains(ImmutableMap.of(columnY, Domain.singleValue(BIGINT, 2L))));
+                                        },
+                                        TupleDomain.all(),
+                                        ImmutableMap.of("col1", equalTo(column1Handle), "x", equalTo(columnX))))));
+
+        // Projection and predicate pushdown with overlapping columns
+        assertPlan(
+                "SELECT col0, col0.y expr_y FROM " + tableName + " WHERE col0.x = 5",
+                anyTree(
+                        tableScan(
+                                table -> {
+                                    ElasticsearchTableHandle actualTableHandle = (ElasticsearchTableHandle) table;
+                                    TupleDomain<ColumnHandle> constraint = actualTableHandle.constraint();
+                                    return actualTableHandle.columns().equals(ImmutableSet.of(column0Handle, columnY))
+                                            && constraint.equals(TupleDomain.withColumnDomains(ImmutableMap.of(columnX, Domain.singleValue(BIGINT, 5L))));
+                                },
+                                TupleDomain.all(),
+                                ImmutableMap.of("col0", equalTo(column0Handle), "y", equalTo(columnY)))));
+
+        // Projection and predicate pushdown with joins
+        assertPlan(
+                "SELECT T.col0.x, T.col0, T.col0.y FROM " + tableName + " T join " + tableName + " S on T.col1 = S.col1 WHERE T.col0.x = 2",
+                anyTree(
+                        project(
+                                ImmutableMap.of(
+                                        "expr_0_x", expression(new FieldReference(new Reference(RowType.anonymousRow(INTEGER), "expr_0"), 0)),
+                                        "expr_0", expression(new Reference(RowType.anonymousRow(INTEGER), "expr_0")),
+                                        "expr_0_y", expression(new FieldReference(new Reference(RowType.anonymousRow(INTEGER, INTEGER), "expr_0"), 1))),
+                                PlanMatchPattern.join(INNER, builder -> builder
+                                        .equiCriteria("t_expr_1", "s_expr_1")
+                                        .left(
+                                                anyTree(
+                                                        tableScan(
+                                                                table -> {
+                                                                    ElasticsearchTableHandle actualTableHandle = (ElasticsearchTableHandle) table;
+                                                                    TupleDomain<ColumnHandle> constraint = actualTableHandle.constraint();
+                                                                    Set<ElasticsearchColumnHandle> expectedProjections = ImmutableSet.of(column0Handle, column1Handle);
+                                                                    TupleDomain<ColumnHandle> expectedConstraint = TupleDomain.withColumnDomains(
+                                                                            ImmutableMap.of(columnX, Domain.singleValue(BIGINT, 2L)));
+                                                                    return actualTableHandle.columns().equals(expectedProjections)
+                                                                            && constraint.equals(expectedConstraint);
+                                                                },
+                                                                TupleDomain.all(),
+                                                                ImmutableMap.of("expr_0", equalTo(column0Handle), "t_expr_1", equalTo(column1Handle)))))
+                                        .right(
+                                                anyTree(
+                                                        tableScan(
+                                                                equalTo(elasticsearchTableHandle.withColumns(Set.of(column1Handle))),
+                                                                TupleDomain.all(),
+                                                                ImmutableMap.of("s_expr_1", equalTo(column1Handle)))))))));
+        deleteIndex(tableName);
+    }
+
+    private static ElasticsearchColumnHandle projectColumn(List<String> path, Type projectedColumnType, IndexMetadata.Type elasticsearchType, DecoderDescriptor decoderDescriptor, boolean supportsPredicates)
+    {
+        return new ElasticsearchColumnHandle(
+                path,
+                projectedColumnType,
+                elasticsearchType,
+                decoderDescriptor,
+                supportsPredicates);
+    }
+
+    private void createIndex(String indexName, @Language("JSON") String properties)
+            throws IOException
+    {
+        String mappings = indexMapping(properties);
+        Request request = new Request("PUT", "/" + indexName);
+        request.setJsonEntity(mappings);
+        client.getLowLevelClient().performRequest(request);
+    }
+
+    private static String indexMapping(@Language("JSON") String properties)
+    {
+        return "{\"mappings\": " + properties + "}";
+    }
+
+    private void index(String index, Map<String, Object> document)
+            throws IOException
+    {
+        String json = OBJECT_MAPPER.writeValueAsString(document);
+        String endpoint = format("%s?refresh", indexEndpoint(index, String.valueOf(System.nanoTime())));
+
+        Request request = new Request("PUT", endpoint);
+        request.setJsonEntity(json);
+        client.getLowLevelClient().performRequest(request);
+    }
+
+    private static String indexEndpoint(String index, String docId)
+    {
+        return format("/%s/_doc/%s", index, docId);
+    }
+
+    private void deleteIndex(String indexName)
+            throws IOException
+    {
+        Request request = new Request("DELETE", "/" + indexName);
+        client.getLowLevelClient().performRequest(request);
+    }
+}
diff --git a/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchQueryBuilder.java b/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchQueryBuilder.java
index f903541e8ca5..6b6e5d164f9c 100644
--- a/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchQueryBuilder.java
+++ b/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchQueryBuilder.java
@@ -42,10 +42,10 @@ public class TestElasticsearchQueryBuilder
 {
-    private static final ElasticsearchColumnHandle NAME = new ElasticsearchColumnHandle("name", VARCHAR, new IndexMetadata.PrimitiveType("text"), new VarcharDecoder.Descriptor("name"), true);
-    private static final ElasticsearchColumnHandle AGE = new ElasticsearchColumnHandle("age", INTEGER, new IndexMetadata.PrimitiveType("int"), new IntegerDecoder.Descriptor("age"), true);
-    private static final ElasticsearchColumnHandle SCORE = new ElasticsearchColumnHandle("score", DOUBLE, new IndexMetadata.PrimitiveType("double"), new DoubleDecoder.Descriptor("score"), true);
-    private static final ElasticsearchColumnHandle LENGTH = new ElasticsearchColumnHandle("length", DOUBLE, new IndexMetadata.PrimitiveType("double"), new DoubleDecoder.Descriptor("length"), true);
+    private static final ElasticsearchColumnHandle NAME = new ElasticsearchColumnHandle(ImmutableList.of("name"), VARCHAR, new IndexMetadata.PrimitiveType("text"), new VarcharDecoder.Descriptor("name"), true);
+    private static final ElasticsearchColumnHandle AGE = new ElasticsearchColumnHandle(ImmutableList.of("age"), INTEGER, new IndexMetadata.PrimitiveType("int"), new IntegerDecoder.Descriptor("age"), true);
+    private static final ElasticsearchColumnHandle SCORE = new ElasticsearchColumnHandle(ImmutableList.of("score"), DOUBLE, new IndexMetadata.PrimitiveType("double"), new DoubleDecoder.Descriptor("score"), true);
+    private static final ElasticsearchColumnHandle LENGTH = new ElasticsearchColumnHandle(ImmutableList.of("length"), DOUBLE, new IndexMetadata.PrimitiveType("double"), new DoubleDecoder.Descriptor("length"), true);
 
     @Test
     public void testMatchAll()