Skip to content

Commit

Permalink
Speed up resolved function to/from qualified name resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Mar 15, 2022
1 parent ad8d214 commit f02c527
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 40 deletions.
6 changes: 6 additions & 0 deletions core/trino-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,12 @@
<artifactId>commons-math3</artifactId>
</dependency>

<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-pool2</artifactId>
<version>2.11.1</version>
</dependency>

<dependency>
<groupId>org.apache.lucene</groupId>
<artifactId>lucene-analyzers-common</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,45 @@
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Splitter;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.compress.zstd.ZstdCompressor;
import io.airlift.compress.zstd.ZstdDecompressor;
import io.airlift.json.JsonCodec;
import io.airlift.json.JsonCodecFactory;
import io.airlift.json.ObjectMapperProvider;
import io.trino.collect.cache.NonEvictableCache;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeId;
import io.trino.spi.type.TypeSignature;
import io.trino.sql.tree.QualifiedName;
import io.trino.type.TypeDeserializer;
import io.trino.type.TypeSignatureDeserializer;
import io.trino.type.TypeSignatureKeyDeserializer;
import io.trino.util.ThreadSafeCompressorDecompressor;

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.io.BaseEncoding.base32Hex;
import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.metadata.ResolvedFunction.ResolvedFunctionDecoder.resolveQualifiedName;
import static java.lang.Math.toIntExact;
import static java.nio.ByteBuffer.allocate;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;

public class ResolvedFunction
{
private static final JsonCodec<ResolvedFunction> SERIALIZE_JSON_CODEC = new JsonCodecFactory().jsonCodec(ResolvedFunction.class);
private static final String PREFIX = "@";
private final BoundSignature signature;
private final FunctionId functionId;
Expand All @@ -55,6 +63,7 @@ public class ResolvedFunction
private final FunctionNullability functionNullability;
private final Map<TypeSignature, Type> typeDependencies;
private final Set<ResolvedFunction> functionDependencies;
private final QualifiedName qualifiedName;

@JsonCreator
public ResolvedFunction(
Expand All @@ -73,6 +82,7 @@ public ResolvedFunction(
this.functionNullability = requireNonNull(functionNullability, "nullability is null");
this.typeDependencies = ImmutableMap.copyOf(requireNonNull(typeDependencies, "typeDependencies is null"));
this.functionDependencies = ImmutableSet.copyOf(requireNonNull(functionDependencies, "functionDependencies is null"));
this.qualifiedName = resolveQualifiedName(this);
checkArgument(functionNullability.getArgumentNullable().size() == signature.getArgumentTypes().size(), "signature and functionNullability must have same argument count");
}

Expand Down Expand Up @@ -125,17 +135,7 @@ public static boolean isResolved(QualifiedName name)

public QualifiedName toQualifiedName()
{
byte[] json = SERIALIZE_JSON_CODEC.toJsonBytes(this);

// json can be large so use zstd to compress
ZstdCompressor compressor = new ZstdCompressor();
byte[] compressed = new byte[compressor.maxCompressedLength(json.length)];
int outputSize = compressor.compress(json, 0, json.length, compressed, 0, compressed.length);

// names are case insensitive, so use base32 instead of base64
String base32 = base32Hex().encode(compressed, 0, outputSize);
// add name so expressions are still readable
return QualifiedName.of(PREFIX + signature.getName() + PREFIX + base32);
return qualifiedName;
}

public static String extractFunctionName(QualifiedName qualifiedName)
Expand Down Expand Up @@ -181,6 +181,11 @@ public String toString()

public static class ResolvedFunctionDecoder
{
private static final JsonCodec<ResolvedFunction> SERIALIZE_JSON_CODEC = new JsonCodecFactory().jsonCodec(ResolvedFunction.class);
private static final ThreadSafeCompressorDecompressor COMPRESSOR_DECOMPRESSOR = new ThreadSafeCompressorDecompressor(ZstdCompressor::new, ZstdDecompressor::new);
private static final NonEvictableCache<QualifiedName, ResolvedFunction> resolvedFunctionsCache = buildNonEvictableCache(CacheBuilder.newBuilder().maximumSize(1024));
private static final NonEvictableCache<ResolvedFunction, QualifiedName> qualifiedFunctionsCache = buildNonEvictableCache(CacheBuilder.newBuilder().maximumSize(1024));

private final JsonCodec<ResolvedFunction> jsonCodec;

public ResolvedFunctionDecoder(Function<TypeId, Type> typeLoader)
Expand All @@ -196,26 +201,49 @@ Type.class, new TypeDeserializer(typeLoader),

public Optional<ResolvedFunction> fromQualifiedName(QualifiedName qualifiedName)
{
String data = qualifiedName.getSuffix();
if (!data.startsWith(PREFIX)) {
if (!qualifiedName.getSuffix().startsWith(PREFIX)) {
return Optional.empty();
}
List<String> parts = Splitter.on(PREFIX).splitToList(data.subSequence(1, data.length()));
checkArgument(parts.size() == 2, "Expected encoded resolved function to contain two parts: %s", qualifiedName);
String name = parts.get(0);

try {
return Optional.of(resolvedFunctionsCache.get(qualifiedName, () -> deserialize(qualifiedName)));
}
catch (ExecutionException e) {
throw new RuntimeException(e);
}
}

static QualifiedName resolveQualifiedName(ResolvedFunction function)
{
try {
return qualifiedFunctionsCache.get(function, () -> serialize(function));
}
catch (ExecutionException e) {
throw new RuntimeException(e);
}
}

private ResolvedFunction deserialize(QualifiedName qualifiedName)
{
String data = qualifiedName.getSuffix();
List<String> parts = Splitter.on(PREFIX).splitToList(data.substring(1));
checkArgument(parts.size() == 2, "Expected encoded resolved function to contain two parts: %s", qualifiedName);
String base32 = parts.get(1);
// name may have been lower cased, but base32 decoder requires upper case
base32 = base32.toUpperCase(ENGLISH);
byte[] compressed = base32Hex().decode(base32);

byte[] json = new byte[toIntExact(ZstdDecompressor.getDecompressedSize(compressed, 0, compressed.length))];
new ZstdDecompressor().decompress(compressed, 0, compressed.length, json, 0, json.length);
ByteBuffer decompressed = allocate(toIntExact(ZstdDecompressor.getDecompressedSize(compressed, 0, compressed.length)));
COMPRESSOR_DECOMPRESSOR.decompress(ByteBuffer.wrap(compressed), decompressed);
return jsonCodec.fromJson(Arrays.copyOf(decompressed.array(), decompressed.position()));
}

ResolvedFunction resolvedFunction = jsonCodec.fromJson(json);
checkArgument(resolvedFunction.getSignature().getName().equalsIgnoreCase(name),
"Expected decoded function to have name %s, but name is %s", resolvedFunction.getSignature().getName(), name);
return Optional.of(resolvedFunction);
static QualifiedName serialize(ResolvedFunction function)
{
byte[] value = SERIALIZE_JSON_CODEC.toJsonBytes(function);
ByteBuffer compressed = allocate(COMPRESSOR_DECOMPRESSOR.maxCompressedLength(value.length));
COMPRESSOR_DECOMPRESSOR.compress(ByteBuffer.wrap(value), compressed);
return QualifiedName.of(PREFIX + function.signature.getName() + PREFIX + base32Hex().encode(compressed.array(), 0, compressed.position()));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.util;

import io.airlift.compress.Compressor;
import io.airlift.compress.Decompressor;
import io.airlift.compress.MalformedInputException;
import org.apache.commons.pool2.DestroyMode;
import org.apache.commons.pool2.ObjectPool;
import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.PooledObjectFactory;
import org.apache.commons.pool2.impl.DefaultPooledObject;
import org.apache.commons.pool2.impl.SoftReferenceObjectPool;

import java.nio.ByteBuffer;
import java.util.function.Function;
import java.util.function.Supplier;

import static java.util.Objects.requireNonNull;

public class ThreadSafeCompressorDecompressor
implements Compressor, Decompressor
{
private final SoftReferenceObjectPool<Compressor> compressors;
private final SoftReferenceObjectPool<Decompressor> decompressors;

public ThreadSafeCompressorDecompressor(Supplier<Compressor> compressorFactory, Supplier<Decompressor> decompressorFactory)
{
this.compressors = new SoftReferenceObjectPool<>(new CompressorObjectFactory<>(requireNonNull(compressorFactory, "compressorFactory is null")));
this.decompressors = new SoftReferenceObjectPool<>(new CompressorObjectFactory<>(requireNonNull(decompressorFactory, "decompressorFactory is null")));
}

@Override
public int maxCompressedLength(int uncompressedSize)
{
return with(compressors, compressor -> compressor.maxCompressedLength(uncompressedSize));
}

@Override
public int compress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength)
{
return with(compressors, compressor -> compressor.compress(input, inputOffset, inputLength, output, outputOffset, maxOutputLength));
}

@Override
public void compress(ByteBuffer input, ByteBuffer output)
{
with(compressors, compressor -> {
compressor.compress(input, output);
return true;
});
}

@Override
public int decompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength) throws MalformedInputException
{
return with(decompressors, decompressor -> decompressor.decompress(input, inputOffset, inputLength, output, outputOffset, maxOutputLength));
}

@Override
public void decompress(ByteBuffer input, ByteBuffer output) throws MalformedInputException
{
with(decompressors, decompressor ->
{
decompressor.decompress(input, output);
return true;
});
}

private <T, U> T with(ObjectPool<U> pool, Function<U, T> call)
{
U delegate = null;
try {
delegate = pool.borrowObject();
return call.apply(delegate);
}
catch (Exception e) {
throw new RuntimeException(e);
}
finally {
if (delegate != null) {
try {
pool.returnObject(delegate);
}
catch (Exception ignored) {
}
}
}
}

private static class CompressorObjectFactory<T>
implements PooledObjectFactory<T>
{
private final Supplier<T> delegate;

CompressorObjectFactory(Supplier<T> delegate)
{
this.delegate = requireNonNull(delegate, "delegate is null");
}

@Override
public void activateObject(PooledObject<T> pooledObject)
{
}

@Override
public void destroyObject(PooledObject<T> pooledObject)
{
}

@Override
public void destroyObject(PooledObject<T> pooledObject, DestroyMode destroyMode)
{
destroyObject(pooledObject);
}

@Override
public PooledObject<T> makeObject()
{
return new DefaultPooledObject<>(delegate.get());
}

@Override
public void passivateObject(PooledObject<T> pooledObject)
{
}

@Override
public boolean validateObject(PooledObject<T> pooledObject)
{
return true;
}
}
}
Loading

0 comments on commit f02c527

Please sign in to comment.