From a87b89ed0a03dad533a17220cdfb841ae7824483 Mon Sep 17 00:00:00 2001 From: Esun Kim Date: Wed, 28 Jul 2021 18:12:49 +0000 Subject: [PATCH] Add a zero-copy deserializer to gRPC Read (#564) --- .../GoogleCloudStorageGrpcReadChannel.java | 68 +++++++-- .../gcsio/ZeroCopyMessageMarshaller.java | 132 ++++++++++++++++++ .../gcsio/ZeroCopyReadinessChecker.java | 67 +++++++++ .../gcsio/ZeroCopyMessageMarshallerTest.java | 95 +++++++++++++ pom.xml | 2 +- 5 files changed, 354 insertions(+), 10 deletions(-) create mode 100644 gcsio/src/main/java/com/google/cloud/hadoop/gcsio/ZeroCopyMessageMarshaller.java create mode 100644 gcsio/src/main/java/com/google/cloud/hadoop/gcsio/ZeroCopyReadinessChecker.java create mode 100644 gcsio/src/test/java/com/google/cloud/hadoop/gcsio/ZeroCopyMessageMarshallerTest.java diff --git a/gcsio/src/main/java/com/google/cloud/hadoop/gcsio/GoogleCloudStorageGrpcReadChannel.java b/gcsio/src/main/java/com/google/cloud/hadoop/gcsio/GoogleCloudStorageGrpcReadChannel.java index 2347fc7159..c730401b18 100644 --- a/gcsio/src/main/java/com/google/cloud/hadoop/gcsio/GoogleCloudStorageGrpcReadChannel.java +++ b/gcsio/src/main/java/com/google/cloud/hadoop/gcsio/GoogleCloudStorageGrpcReadChannel.java @@ -35,14 +35,17 @@ import com.google.common.hash.Hashing; import com.google.google.storage.v1.GetObjectMediaRequest; import com.google.google.storage.v1.GetObjectMediaResponse; +import com.google.google.storage.v1.StorageGrpc; import com.google.google.storage.v1.StorageGrpc.StorageBlockingStub; import com.google.protobuf.ByteString; import io.grpc.Context; import io.grpc.Context.CancellableContext; +import io.grpc.MethodDescriptor; import io.grpc.Status; import io.grpc.StatusRuntimeException; import java.io.EOFException; import java.io.IOException; +import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.nio.channels.SeekableByteChannel; @@ -55,6 +58,16 @@ public class GoogleCloudStorageGrpcReadChannel 
implements SeekableByteChannel { private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); protected static final String METADATA_FIELDS = "contentEncoding,generation,size"; + // ZeroCopy version of GetObjectMedia Method + private static final ZeroCopyMessageMarshaller getObjectMediaResponseMarshaller = + new ZeroCopyMessageMarshaller(GetObjectMediaResponse.getDefaultInstance()); + private static final MethodDescriptor + getObjectMediaMethod = + StorageGrpc.getGetObjectMediaMethod().toBuilder() + .setResponseMarshaller(getObjectMediaResponseMarshaller) + .build(); + private static final boolean useZeroCopyMarshaller = ZeroCopyReadinessChecker.isReady(); + private volatile StorageBlockingStub stub; private final StorageStubProvider stubProvider; @@ -84,6 +97,10 @@ public class GoogleCloudStorageGrpcReadChannel implements SeekableByteChannel { private int bufferedContentReadOffset = 0; + // InputStream that backs bufferedContent. This needs to be closed when bufferedContent is no + // longer needed. + @Nullable private InputStream streamForBufferedContent = null; + // The streaming read operation. If null, there is not an in-flight read in progress. @Nullable private Iterator resIterator = null; @@ -331,8 +348,7 @@ private int readBufferedContentInto(ByteBuffer byteBuffer) { if (remainingBufferedContentLargerThanByteBuffer) { bufferedContentReadOffset += bytesToWrite; } else { - bufferedContent = null; - bufferedContentReadOffset = 0; + invalidateBufferedContent(); } return bytesToWrite; @@ -417,7 +433,10 @@ private int readObjectContentFromGCS(ByteBuffer byteBuffer) throws IOException { int bytesRead = 0; while (moreServerContent() && byteBuffer.hasRemaining()) { GetObjectMediaResponse res = resIterator.next(); - + // When zero-copy marshaller is used, the stream that backs GetObjectMediaResponse + // should be closed when the message is no longer needed so that all buffers in the + // stream can be reclaimed. 
If zero-copy is not used, stream will be null. + InputStream stream = getObjectMediaResponseMarshaller.popStream(res); ByteString content = res.getChecksummedData().getContent(); if (bytesToSkipBeforeReading >= 0 && bytesToSkipBeforeReading < content.size()) { content = res.getChecksummedData().getContent().substring((int) bytesToSkipBeforeReading); @@ -426,6 +445,9 @@ private int readObjectContentFromGCS(ByteBuffer byteBuffer) throws IOException { } else if (bytesToSkipBeforeReading >= content.size()) { positionInGrpcStream += content.size(); bytesToSkipBeforeReading -= content.size(); + if (stream != null) { + stream.close(); + } continue; } @@ -441,8 +463,15 @@ private int readObjectContentFromGCS(ByteBuffer byteBuffer) throws IOException { positionInGrpcStream += bytesToWrite; if (responseSizeLargerThanRemainingBuffer) { + invalidateBufferedContent(); bufferedContent = content; bufferedContentReadOffset = bytesToWrite; + // This is to keep the stream alive for the message backed by this. 
+ streamForBufferedContent = stream; + } else { + if (stream != null) { + stream.close(); + } } } return bytesRead; @@ -513,10 +542,19 @@ private void requestObjectMedia(OptionalLong bytesToRead) throws IOException { try { requestContext = Context.current().withCancellation(); Context toReattach = requestContext.attach(); + StorageBlockingStub blockingStub = + stub.withDeadlineAfter(readOptions.getGrpcReadTimeoutMillis(), MILLISECONDS); try { - resIterator = - stub.withDeadlineAfter(readOptions.getGrpcReadTimeoutMillis(), MILLISECONDS) - .getObjectMedia(request); + if (useZeroCopyMarshaller) { + resIterator = + io.grpc.stub.ClientCalls.blockingServerStreamingCall( + blockingStub.getChannel(), + getObjectMediaMethod, + blockingStub.getCallOptions(), + request); + } else { + resIterator = blockingStub.getObjectMedia(request); + } } finally { requestContext.detach(toReattach); } @@ -631,9 +669,7 @@ public SeekableByteChannel position(long newPosition) throws IOException { // Reset any ongoing read operations or local data caches. 
cancelCurrentRequest(); - bufferedContent = null; - bufferedContentReadOffset = 0; - bytesToSkipBeforeReading = 0; + invalidateBufferedContent(); positionInGrpcStream = newPosition; return this; @@ -660,6 +696,7 @@ public boolean isOpen() { @Override public void close() { cancelCurrentRequest(); + invalidateBufferedContent(); channelIsOpen = false; } @@ -670,4 +707,17 @@ public String toString() { .add("generation", objectGeneration) .toString(); } + + private void invalidateBufferedContent() { + bufferedContent = null; + bufferedContentReadOffset = 0; + if (streamForBufferedContent != null) { + try { + streamForBufferedContent.close(); + streamForBufferedContent = null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } } diff --git a/gcsio/src/main/java/com/google/cloud/hadoop/gcsio/ZeroCopyMessageMarshaller.java b/gcsio/src/main/java/com/google/cloud/hadoop/gcsio/ZeroCopyMessageMarshaller.java new file mode 100644 index 0000000000..8fde949d20 --- /dev/null +++ b/gcsio/src/main/java/com/google/cloud/hadoop/gcsio/ZeroCopyMessageMarshaller.java @@ -0,0 +1,132 @@ +/* + * Copyright 2021 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.cloud.hadoop.gcsio; + +import com.google.protobuf.ByteString; +import com.google.protobuf.CodedInputStream; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.MessageLite; +import com.google.protobuf.Parser; +import com.google.protobuf.UnsafeByteOperations; +import io.grpc.Detachable; +import io.grpc.HasByteBuffer; +import io.grpc.KnownLength; +import io.grpc.MethodDescriptor.PrototypeMarshaller; +import io.grpc.Status; +import io.grpc.protobuf.lite.ProtoLiteUtils; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; + +/** + * Custom gRPC marshaller to use zero memory copy feature of gRPC when deserializing messages. This + * achieves zero-copy by deserializing proto messages pointing to the buffers in the input stream to + * avoid memory copy so stream should live as long as the message can be referenced. Hence, it + * exposes the input stream to applications (through popStream) and applications are responsible to + * close it when it's no longer needed. Otherwise, it'd cause memory leak. 
+ */ +class ZeroCopyMessageMarshaller implements PrototypeMarshaller { + private Map unclosedStreams = + Collections.synchronizedMap(new IdentityHashMap<>()); + private final Parser parser; + private final PrototypeMarshaller marshaller; + + ZeroCopyMessageMarshaller(T defaultInstance) { + parser = (Parser) defaultInstance.getParserForType(); + marshaller = (PrototypeMarshaller) ProtoLiteUtils.marshaller(defaultInstance); + } + + @Override + public Class getMessageClass() { + return marshaller.getMessageClass(); + } + + @Override + public T getMessagePrototype() { + return marshaller.getMessagePrototype(); + } + + @Override + public InputStream stream(T value) { + return marshaller.stream(value); + } + + @Override + public T parse(InputStream stream) { + try { + if (stream instanceof KnownLength + && stream instanceof Detachable + && stream instanceof HasByteBuffer + && ((HasByteBuffer) stream).byteBufferSupported()) { + int size = stream.available(); + // Stream is now detached here and should be closed later. + stream = ((Detachable) stream).detach(); + // This mark call is to keep buffer while traversing buffers using skip. 
+ stream.mark(size); + List byteStrings = new ArrayList<>(); + while (stream.available() != 0) { + ByteBuffer buffer = ((HasByteBuffer) stream).getByteBuffer(); + byteStrings.add(UnsafeByteOperations.unsafeWrap(buffer)); + stream.skip(buffer.remaining()); + } + stream.reset(); + CodedInputStream codedInputStream = ByteString.copyFrom(byteStrings).newCodedInput(); + codedInputStream.enableAliasing(true); + codedInputStream.setSizeLimit(Integer.MAX_VALUE); + // fast path (no memory copy) + T message; + try { + message = parseFrom(codedInputStream); + } catch (InvalidProtocolBufferException ipbe) { + stream.close(); + throw Status.INTERNAL + .withDescription("Invalid protobuf byte sequence") + .withCause(ipbe) + .asRuntimeException(); + } + unclosedStreams.put(message, stream); + return message; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + // slow path + return marshaller.parse(stream); + } + + private T parseFrom(CodedInputStream stream) throws InvalidProtocolBufferException { + T message = parser.parseFrom(stream); + try { + stream.checkLastTagWas(0); + return message; + } catch (InvalidProtocolBufferException e) { + e.setUnfinishedMessage(message); + throw e; + } + } + + // Application needs to call this function to get the stream for the message and + // call stream.close() function to return it to the pool. + public InputStream popStream(T message) { + return unclosedStreams.remove(message); + } +} diff --git a/gcsio/src/main/java/com/google/cloud/hadoop/gcsio/ZeroCopyReadinessChecker.java b/gcsio/src/main/java/com/google/cloud/hadoop/gcsio/ZeroCopyReadinessChecker.java new file mode 100644 index 0000000000..5207a3e331 --- /dev/null +++ b/gcsio/src/main/java/com/google/cloud/hadoop/gcsio/ZeroCopyReadinessChecker.java @@ -0,0 +1,67 @@ +/* + * Copyright 2021 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.hadoop.gcsio; + +import com.google.common.flogger.GoogleLogger; +import com.google.protobuf.MessageLite; +import io.grpc.KnownLength; + +/** + * Checker to test whether a zero-copy marshaller is available from the versions of gRPC and + * Protobuf. + */ +class ZeroCopyReadinessChecker { + private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); + private static final boolean isZeroCopyReady; + + static { + // Check whether io.grpc.Detachable exists? + boolean detachableClassExists = false; + try { + // Try to load Detachable interface in the package where KnownLength is in. + // This can be done directly by looking up io.grpc.Detachable but rather + // done indirectly to handle the case where gRPC is being shaded in a + // different package. + String knownLengthClassName = KnownLength.class.getName(); + String detachableClassName = + knownLengthClassName.substring(0, knownLengthClassName.lastIndexOf('.') + 1) + + "Detachable"; + Class detachableClass = Class.forName(detachableClassName); + detachableClassExists = (detachableClass != null); + } catch (ClassNotFoundException ex) { + logger.atFine().withCause(ex).log("io.grpc.Detachable not found"); + } + // Check whether com.google.protobuf.UnsafeByteOperations exists? 
+ boolean unsafeByteOperationsClassExists = false; + try { + // Same above + String messageLiteClassName = MessageLite.class.getName(); + String unsafeByteOperationsClassName = + messageLiteClassName.substring(0, messageLiteClassName.lastIndexOf('.') + 1) + + "UnsafeByteOperations"; + Class unsafeByteOperationsClass = Class.forName(unsafeByteOperationsClassName); + unsafeByteOperationsClassExists = (unsafeByteOperationsClass != null); + } catch (ClassNotFoundException ex) { + logger.atFine().withCause(ex).log("com.google.protobuf.UnsafeByteOperations not found"); + } + isZeroCopyReady = detachableClassExists && unsafeByteOperationsClassExists; + } + + public static boolean isReady() { + return isZeroCopyReady; + } +} diff --git a/gcsio/src/test/java/com/google/cloud/hadoop/gcsio/ZeroCopyMessageMarshallerTest.java b/gcsio/src/test/java/com/google/cloud/hadoop/gcsio/ZeroCopyMessageMarshallerTest.java new file mode 100644 index 0000000000..b2b7671036 --- /dev/null +++ b/gcsio/src/test/java/com/google/cloud/hadoop/gcsio/ZeroCopyMessageMarshallerTest.java @@ -0,0 +1,95 @@ +/* + * Copyright 2021 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the + * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.cloud.hadoop.gcsio; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.google.storage.v1.GetObjectRequest; +import io.grpc.StatusRuntimeException; +import io.grpc.internal.ReadableBuffer; +import io.grpc.internal.ReadableBuffers; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Arrays; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ZeroCopyMessageMarshallerTest { + private GetObjectRequest REQUEST = + GetObjectRequest.newBuilder().setBucket("b").setObject("o").build(); + + private ZeroCopyMessageMarshaller createMarshaller() { + return new ZeroCopyMessageMarshaller<>(GetObjectRequest.getDefaultInstance()); + } + + private byte[] dropLastOneByte(byte[] bytes) { + return Arrays.copyOfRange(bytes, 0, bytes.length - 1); + } + + private InputStream createInputStream(byte[] bytes, boolean isZeroCopyable) { + ReadableBuffer buffer = + isZeroCopyable ? 
ReadableBuffers.wrap(ByteBuffer.wrap(bytes)) : ReadableBuffers.wrap(bytes); + return ReadableBuffers.openStream(buffer, true); + } + + @Test + public void testParseOnFastPath() throws IOException { + InputStream stream = createInputStream(REQUEST.toByteArray(), true); + ZeroCopyMessageMarshaller marshaller = createMarshaller(); + GetObjectRequest request = marshaller.parse(stream); + assertThat(request).isEqualTo(REQUEST); + InputStream stream2 = marshaller.popStream(request); + assertThat(stream2).isNotNull(); + stream2.close(); + InputStream stream3 = marshaller.popStream(request); + assertThat(stream3).isNull(); + } + + @Test + public void testParseOnSlowPath() { + InputStream stream = createInputStream(REQUEST.toByteArray(), false); + ZeroCopyMessageMarshaller marshaller = createMarshaller(); + GetObjectRequest request = marshaller.parse(stream); + assertThat(request).isEqualTo(REQUEST); + InputStream stream2 = marshaller.popStream(request); + assertThat(stream2).isNull(); + } + + @Test + public void testParseBrokenMessageOnFastPath() { + InputStream stream = createInputStream(dropLastOneByte(REQUEST.toByteArray()), true); + ZeroCopyMessageMarshaller marshaller = createMarshaller(); + assertThrows( + StatusRuntimeException.class, + () -> { + marshaller.parse(stream); + }); + } + + @Test + public void testParseBrokenMessageOnSlowPath() { + InputStream stream = createInputStream(dropLastOneByte(REQUEST.toByteArray()), false); + ZeroCopyMessageMarshaller marshaller = createMarshaller(); + assertThrows( + StatusRuntimeException.class, + () -> { + marshaller.parse(stream); + }); + } +} diff --git a/pom.xml b/pom.xml index ba1ccb3208..49070d6b55 100644 --- a/pom.xml +++ b/pom.xml @@ -99,7 +99,7 @@ 1.39.2-sp.1 1.31.5 3.17.3 - 1.38.1 + 1.39.0 2.10.1 3.2.2