Skip to content

Commit

Permalink
Add dereference pushdown support in ElasticSearch
Browse files Browse the repository at this point in the history
  • Loading branch information
striderarun committed Jan 17, 2025
1 parent 358fa73 commit 6b5a861
Show file tree
Hide file tree
Showing 23 changed files with 1,645 additions and 17 deletions.
7 changes: 7 additions & 0 deletions plugin/trino-elasticsearch/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,13 @@
</exclusions>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-spi</artifactId>
<type>test-jar</type>
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-testing</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.elasticsearch;

import com.google.common.collect.ImmutableList;
import io.trino.plugin.elasticsearch.client.IndexMetadata;
import io.trino.plugin.elasticsearch.decoders.IdColumnDecoder;
import io.trino.plugin.elasticsearch.decoders.ScoreColumnDecoder;
Expand Down Expand Up @@ -86,7 +87,7 @@ public ColumnMetadata getMetadata()
public ColumnHandle getColumnHandle()
{
return new ElasticsearchColumnHandle(
name,
ImmutableList.of(name),
type,
elasticsearchType,
decoderDescriptor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,19 @@
*/
package io.trino.plugin.elasticsearch;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import io.trino.plugin.elasticsearch.client.IndexMetadata;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.type.Type;

import java.util.List;

import static java.util.Objects.requireNonNull;

public record ElasticsearchColumnHandle(
String name,
List<String> path,
Type type,
IndexMetadata.Type elasticsearchType,
DecoderDescriptor decoderDescriptor,
Expand All @@ -29,12 +34,18 @@ public record ElasticsearchColumnHandle(
{
public ElasticsearchColumnHandle
{
requireNonNull(name, "name is null");
path = ImmutableList.copyOf(path);
requireNonNull(type, "type is null");
requireNonNull(elasticsearchType, "elasticsearchType is null");
requireNonNull(decoderDescriptor, "decoderDescriptor is null");
}

@JsonIgnore
public String name()
{
return Joiner.on('.').join(path);
}

@Override
public String toString()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import com.google.inject.Inject;
import io.airlift.slice.Slice;
import io.trino.plugin.base.expression.ConnectorExpressions;
import io.trino.plugin.base.projection.ApplyProjectionUtil;
import io.trino.plugin.base.projection.ApplyProjectionUtil.ProjectedColumnRepresentation;
import io.trino.plugin.elasticsearch.client.ElasticsearchClient;
import io.trino.plugin.elasticsearch.client.IndexMetadata;
import io.trino.plugin.elasticsearch.client.IndexMetadata.DateTimeType;
Expand All @@ -41,6 +43,7 @@
import io.trino.plugin.elasticsearch.decoders.VarcharDecoder;
import io.trino.plugin.elasticsearch.ptf.RawQuery.RawQueryFunctionHandle;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.Assignment;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorMetadata;
Expand All @@ -52,13 +55,15 @@
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
import io.trino.spi.connector.LimitApplicationResult;
import io.trino.spi.connector.ProjectionApplicationResult;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.SchemaTablePrefix;
import io.trino.spi.connector.TableColumnsMetadata;
import io.trino.spi.connector.TableFunctionApplicationResult;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Constant;
import io.trino.spi.expression.FieldDereference;
import io.trino.spi.expression.Variable;
import io.trino.spi.function.table.ConnectorTableFunctionHandle;
import io.trino.spi.predicate.Domain;
Expand Down Expand Up @@ -93,12 +98,16 @@
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.base.Verify.verifyNotNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterators.singletonIterator;
import static io.airlift.slice.SliceUtf8.getCodePointAt;
import static io.airlift.slice.SliceUtf8.lengthOfCodePoint;
import static io.trino.plugin.base.projection.ApplyProjectionUtil.extractSupportedProjectedColumns;
import static io.trino.plugin.base.projection.ApplyProjectionUtil.replaceWithNewVariables;
import static io.trino.plugin.elasticsearch.ElasticsearchTableHandle.Type.QUERY;
import static io.trino.plugin.elasticsearch.ElasticsearchTableHandle.Type.SCAN;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
Expand All @@ -118,6 +127,7 @@
import static java.util.Collections.emptyIterator;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

public class ElasticsearchMetadata
implements ConnectorMetadata
Expand All @@ -133,7 +143,7 @@ public class ElasticsearchMetadata
private static final Map<String, ColumnHandle> PASSTHROUGH_QUERY_COLUMNS = ImmutableMap.of(
PASSTHROUGH_QUERY_RESULT_COLUMN_NAME,
new ElasticsearchColumnHandle(
PASSTHROUGH_QUERY_RESULT_COLUMN_NAME,
ImmutableList.of(PASSTHROUGH_QUERY_RESULT_COLUMN_NAME),
VARCHAR,
new IndexMetadata.PrimitiveType("text"),
new VarcharDecoder.Descriptor(PASSTHROUGH_QUERY_RESULT_COLUMN_NAME),
Expand Down Expand Up @@ -258,7 +268,7 @@ private Map<String, ColumnHandle> makeColumnHandles(List<IndexMetadata.Field> fi
for (IndexMetadata.Field field : fields) {
TypeAndDecoder converted = toTrino(field);
result.put(field.name(), new ElasticsearchColumnHandle(
field.name(),
ImmutableList.of(field.name()),
converted.type(),
field.type(),
converted.decoderDescriptor(),
Expand Down Expand Up @@ -489,7 +499,8 @@ public Optional<LimitApplicationResult<ConnectorTableHandle>> applyLimit(Connect
handle.constraint(),
handle.regexes(),
handle.query(),
OptionalLong.of(limit));
OptionalLong.of(limit),
ImmutableSet.of());

return Optional.of(new LimitApplicationResult<>(handle, false, false));
}
Expand Down Expand Up @@ -565,7 +576,8 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
newDomain,
newRegexes,
handle.query(),
handle.limit());
handle.limit(),
ImmutableSet.of());

return Optional.of(new ConstraintApplicationResult<>(handle, TupleDomain.withColumnDomains(unsupported), newExpression, false));
}
Expand Down Expand Up @@ -655,6 +667,126 @@ private static boolean isPassthroughQuery(ElasticsearchTableHandle table)
return table.type().equals(QUERY);
}

@Override
public Optional<ProjectionApplicationResult<ConnectorTableHandle>> applyProjection(
ConnectorSession session,
ConnectorTableHandle handle,
List<ConnectorExpression> projections,
Map<String, ColumnHandle> assignments)
{
// Create projected column representations for supported sub expressions. Simple column references and chain of
// dereferences on a variable are supported right now.
Set<ConnectorExpression> projectedExpressions = projections.stream()
.flatMap(expression -> extractSupportedProjectedColumns(expression, ElasticsearchMetadata::isSupportedForPushdown).stream())
.collect(toImmutableSet());

Map<ConnectorExpression, ProjectedColumnRepresentation> columnProjections = projectedExpressions.stream()
.collect(toImmutableMap(identity(), ApplyProjectionUtil::createProjectedColumnRepresentation));

ElasticsearchTableHandle elasticsearchTableHandle = (ElasticsearchTableHandle) handle;

// all references are simple variables
if (columnProjections.values().stream().allMatch(ProjectedColumnRepresentation::isVariable)) {
Set<ElasticsearchColumnHandle> projectedColumns = assignments.values().stream()
.map(ElasticsearchColumnHandle.class::cast)
.collect(toImmutableSet());
if (elasticsearchTableHandle.columns().equals(projectedColumns)) {
return Optional.empty();
}
List<Assignment> assignmentsList = assignments.entrySet().stream()
.map(assignment -> new Assignment(
assignment.getKey(),
assignment.getValue(),
((ElasticsearchColumnHandle) assignment.getValue()).type()))
.collect(toImmutableList());

return Optional.of(new ProjectionApplicationResult<>(
elasticsearchTableHandle.withColumns(projectedColumns),
projections,
assignmentsList,
false));
}

Map<String, Assignment> newAssignments = new HashMap<>();
ImmutableMap.Builder<ConnectorExpression, Variable> newVariablesBuilder = ImmutableMap.builder();
ImmutableSet.Builder<ElasticsearchColumnHandle> columns = ImmutableSet.builder();

for (Map.Entry<ConnectorExpression, ProjectedColumnRepresentation> entry : columnProjections.entrySet()) {
ConnectorExpression expression = entry.getKey();
ProjectedColumnRepresentation projectedColumn = entry.getValue();

ElasticsearchColumnHandle baseColumnHandle = (ElasticsearchColumnHandle) assignments.get(projectedColumn.getVariable().getName());
ElasticsearchColumnHandle projectedColumnHandle = projectColumn(baseColumnHandle, projectedColumn.getDereferenceIndices(), expression.getType());
String projectedColumnName = projectedColumnHandle.name();

Variable projectedColumnVariable = new Variable(projectedColumnName, expression.getType());
Assignment newAssignment = new Assignment(projectedColumnName, projectedColumnHandle, expression.getType());
newAssignments.putIfAbsent(projectedColumnName, newAssignment);

newVariablesBuilder.put(expression, projectedColumnVariable);
columns.add(projectedColumnHandle);
}

// Modify projections to refer to new variables
Map<ConnectorExpression, Variable> newVariables = newVariablesBuilder.buildOrThrow();
List<ConnectorExpression> newProjections = projections.stream()
.map(expression -> replaceWithNewVariables(expression, newVariables))
.collect(toImmutableList());

List<Assignment> outputAssignments = newAssignments.values().stream().collect(toImmutableList());
return Optional.of(new ProjectionApplicationResult<>(
elasticsearchTableHandle.withColumns(columns.build()),
newProjections,
outputAssignments,
false));
}

private static boolean isSupportedForPushdown(ConnectorExpression connectorExpression)
{
if (connectorExpression instanceof Variable) {
return true;
}
if (connectorExpression instanceof FieldDereference fieldDereference) {
RowType rowType = (RowType) fieldDereference.getTarget().getType();
RowType.Field field = rowType.getFields().get(fieldDereference.getField());
return field.getName().isPresent();
}
return false;
}

private static ElasticsearchColumnHandle projectColumn(ElasticsearchColumnHandle baseColumn, List<Integer> indices, Type projectedColumnType)
{
if (indices.isEmpty()) {
return baseColumn;
}
ImmutableList.Builder<String> path = ImmutableList.builder();
path.addAll(baseColumn.path());

DecoderDescriptor decoderDescriptor = baseColumn.decoderDescriptor();
IndexMetadata.Type elasticsearchType = baseColumn.elasticsearchType();
Type type = baseColumn.type();

for (int index : indices) {
verify(type instanceof RowType, "type should be Row type");
RowType rowType = (RowType) type;
RowType.Field field = rowType.getFields().get(index);
path.add(field.getName()
.orElseThrow(() -> new TrinoException(NOT_SUPPORTED, "ROW type does not have field name declared: " + rowType)));
type = field.getType();

verify(decoderDescriptor instanceof RowDecoder.Descriptor, "decoderDescriptor should be RowDecoder.Descriptor type");
decoderDescriptor = ((RowDecoder.Descriptor) decoderDescriptor).getFields().get(index).getDescriptor();
elasticsearchType = ((IndexMetadata.ObjectType) elasticsearchType).fields().get(index).type();
}

return new ElasticsearchColumnHandle(
path.build(),
projectedColumnType,
elasticsearchType,
decoderDescriptor,
supportsPredicates(elasticsearchType, projectedColumnType));
}

@Override
public Optional<TableFunctionApplicationResult<ConnectorTableHandle>> applyTableFunction(ConnectorSession session, ConnectorTableFunctionHandle handle)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
package io.trino.plugin.elasticsearch;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.predicate.TupleDomain;

import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.stream.Collectors;

import static java.util.Objects.requireNonNull;
Expand All @@ -32,7 +34,8 @@ public record ElasticsearchTableHandle(
TupleDomain<ColumnHandle> constraint,
Map<String, String> regexes,
Optional<String> query,
OptionalLong limit)
OptionalLong limit,
Set<ElasticsearchColumnHandle> columns)
implements ConnectorTableHandle
{
public enum Type
Expand All @@ -49,7 +52,21 @@ public ElasticsearchTableHandle(Type type, String schema, String index, Optional
TupleDomain.all(),
ImmutableMap.of(),
query,
OptionalLong.empty());
OptionalLong.empty(),
ImmutableSet.of());
}

public ElasticsearchTableHandle withColumns(Set<ElasticsearchColumnHandle> columns)
{
return new ElasticsearchTableHandle(
type,
schema,
index,
constraint,
regexes,
query,
limit,
columns);
}

public ElasticsearchTableHandle
Expand All @@ -58,7 +75,8 @@ public ElasticsearchTableHandle(Type type, String schema, String index, Optional
requireNonNull(schema, "schema is null");
requireNonNull(index, "index is null");
requireNonNull(constraint, "constraint is null");
regexes = ImmutableMap.copyOf(requireNonNull(regexes, "regexes is null"));
regexes = ImmutableMap.copyOf(regexes);
columns = ImmutableSet.copyOf(columns);
requireNonNull(query, "query is null");
requireNonNull(limit, "limit is null");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,17 @@ public Page getNextPage()
Map<String, Object> document = hit.getSourceAsMap();

for (int i = 0; i < decoders.size(); i++) {
String field = columns.get(i).name();
decoders.get(i).decode(hit, () -> getField(document, field), columnBuilders[i]);
ElasticsearchColumnHandle columnHandle = columns.get(i);
if (columnHandle.path().size() == 1) {
decoders.get(i).decode(hit, () -> getField(document, columnHandle.path().getFirst()), columnBuilders[i]);
continue;
}
Map<String, Object> resolvedField = resolveField(document, columnHandle);
decoders.get(i)
.decode(
hit,
() -> resolvedField == null ? null : getField(resolvedField, columnHandle.path().getLast()),
columnBuilders[i]);
}

if (hit.getSourceRef() != null) {
Expand All @@ -181,6 +190,23 @@ public Page getNextPage()
return new Page(blocks);
}

private static Map<String, Object> resolveField(Map<String, Object> document, ElasticsearchColumnHandle columnHandle)
{
if (document == null) {
return null;
}
Map<String, Object> value = (Map<String, Object>) getField(document, columnHandle.path().getFirst());
if (value != null) {
for (int i = 1; i < columnHandle.path().size() - 1; i++) {
value = (Map<String, Object>) getField(value, columnHandle.path().get(i));
if (value == null) {
break;
}
}
}
return value;
}

public static Object getField(Map<String, Object> document, String field)
{
Object value = document.get(field);
Expand Down
Loading

0 comments on commit 6b5a861

Please sign in to comment.