Skip to content

Commit

Permalink
Add a zero-copy deserializer to gRPC Read (#564)
Browse files Browse the repository at this point in the history
  • Loading branch information
veblush authored Jul 28, 2021
1 parent 4ac88ba commit a87b89e
Show file tree
Hide file tree
Showing 5 changed files with 354 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,17 @@
import com.google.common.hash.Hashing;
import com.google.google.storage.v1.GetObjectMediaRequest;
import com.google.google.storage.v1.GetObjectMediaResponse;
import com.google.google.storage.v1.StorageGrpc;
import com.google.google.storage.v1.StorageGrpc.StorageBlockingStub;
import com.google.protobuf.ByteString;
import io.grpc.Context;
import io.grpc.Context.CancellableContext;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SeekableByteChannel;
Expand All @@ -55,6 +58,16 @@ public class GoogleCloudStorageGrpcReadChannel implements SeekableByteChannel {
private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();
protected static final String METADATA_FIELDS = "contentEncoding,generation,size";

// ZeroCopy version of GetObjectMedia Method
private static final ZeroCopyMessageMarshaller getObjectMediaResponseMarshaller =
new ZeroCopyMessageMarshaller(GetObjectMediaResponse.getDefaultInstance());
private static final MethodDescriptor<GetObjectMediaRequest, GetObjectMediaResponse>
getObjectMediaMethod =
StorageGrpc.getGetObjectMediaMethod().toBuilder()
.setResponseMarshaller(getObjectMediaResponseMarshaller)
.build();
private static final boolean useZeroCopyMarshaller = ZeroCopyReadinessChecker.isReady();

private volatile StorageBlockingStub stub;

private final StorageStubProvider stubProvider;
Expand Down Expand Up @@ -84,6 +97,10 @@ public class GoogleCloudStorageGrpcReadChannel implements SeekableByteChannel {

private int bufferedContentReadOffset = 0;

// InputStream that backs bufferedContent. This needs to be closed when bufferedContent is no
// longer needed.
@Nullable private InputStream streamForBufferedContent = null;

// The streaming read operation. If null, there is not an in-flight read in progress.
@Nullable private Iterator<GetObjectMediaResponse> resIterator = null;

Expand Down Expand Up @@ -331,8 +348,7 @@ private int readBufferedContentInto(ByteBuffer byteBuffer) {
if (remainingBufferedContentLargerThanByteBuffer) {
bufferedContentReadOffset += bytesToWrite;
} else {
bufferedContent = null;
bufferedContentReadOffset = 0;
invalidateBufferedContent();
}

return bytesToWrite;
Expand Down Expand Up @@ -417,7 +433,10 @@ private int readObjectContentFromGCS(ByteBuffer byteBuffer) throws IOException {
int bytesRead = 0;
while (moreServerContent() && byteBuffer.hasRemaining()) {
GetObjectMediaResponse res = resIterator.next();

// When zero-copy mashaller is used, the stream that backs GetObjectMediaResponse
// should be closed when the mssage is no longed needed so that all buffers in the
// stream can be reclaimed. If zero-copy is not used, stream will be null.
InputStream stream = getObjectMediaResponseMarshaller.popStream(res);
ByteString content = res.getChecksummedData().getContent();
if (bytesToSkipBeforeReading >= 0 && bytesToSkipBeforeReading < content.size()) {
content = res.getChecksummedData().getContent().substring((int) bytesToSkipBeforeReading);
Expand All @@ -426,6 +445,9 @@ private int readObjectContentFromGCS(ByteBuffer byteBuffer) throws IOException {
} else if (bytesToSkipBeforeReading >= content.size()) {
positionInGrpcStream += content.size();
bytesToSkipBeforeReading -= content.size();
if (stream != null) {
stream.close();
}
continue;
}

Expand All @@ -441,8 +463,15 @@ private int readObjectContentFromGCS(ByteBuffer byteBuffer) throws IOException {
positionInGrpcStream += bytesToWrite;

if (responseSizeLargerThanRemainingBuffer) {
invalidateBufferedContent();
bufferedContent = content;
bufferedContentReadOffset = bytesToWrite;
// This is to keep the stream alive for the message backed by this.
streamForBufferedContent = stream;
} else {
if (stream != null) {
stream.close();
}
}
}
return bytesRead;
Expand Down Expand Up @@ -513,10 +542,19 @@ private void requestObjectMedia(OptionalLong bytesToRead) throws IOException {
try {
requestContext = Context.current().withCancellation();
Context toReattach = requestContext.attach();
StorageBlockingStub blockingStub =
stub.withDeadlineAfter(readOptions.getGrpcReadTimeoutMillis(), MILLISECONDS);
try {
resIterator =
stub.withDeadlineAfter(readOptions.getGrpcReadTimeoutMillis(), MILLISECONDS)
.getObjectMedia(request);
if (useZeroCopyMarshaller) {
resIterator =
io.grpc.stub.ClientCalls.blockingServerStreamingCall(
blockingStub.getChannel(),
getObjectMediaMethod,
blockingStub.getCallOptions(),
request);
} else {
resIterator = blockingStub.getObjectMedia(request);
}
} finally {
requestContext.detach(toReattach);
}
Expand Down Expand Up @@ -631,9 +669,7 @@ public SeekableByteChannel position(long newPosition) throws IOException {

// Reset any ongoing read operations or local data caches.
cancelCurrentRequest();
bufferedContent = null;
bufferedContentReadOffset = 0;
bytesToSkipBeforeReading = 0;
invalidateBufferedContent();

positionInGrpcStream = newPosition;
return this;
Expand All @@ -660,6 +696,7 @@ public boolean isOpen() {
@Override
public void close() {
cancelCurrentRequest();
invalidateBufferedContent();
channelIsOpen = false;
}

Expand All @@ -670,4 +707,17 @@ public String toString() {
.add("generation", objectGeneration)
.toString();
}

private void invalidateBufferedContent() {
bufferedContent = null;
bufferedContentReadOffset = 0;
if (streamForBufferedContent != null) {
try {
streamForBufferedContent.close();
streamForBufferedContent = null;
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright 2021 Google Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.hadoop.gcsio;

import com.google.protobuf.ByteString;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.MessageLite;
import com.google.protobuf.Parser;
import com.google.protobuf.UnsafeByteOperations;
import io.grpc.Detachable;
import io.grpc.HasByteBuffer;
import io.grpc.KnownLength;
import io.grpc.MethodDescriptor.PrototypeMarshaller;
import io.grpc.Status;
import io.grpc.protobuf.lite.ProtoLiteUtils;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;

/**
* Custom gRPC marshaller to use zero memory copy feature of gRPC when deserializing messages. This
* achieves zero-copy by deserializing proto messages pointing to the buffers in the input stream to
* avoid memory copy so stream should live as long as the message can be referenced. Hence, it
* exposes the input stream to applications (through popStream) and applications are responsible to
* close it when it's no longer needed. Otherwise, it'd cause memory leak.
*/
class ZeroCopyMessageMarshaller<T extends MessageLite> implements PrototypeMarshaller<T> {
private Map<T, InputStream> unclosedStreams =
Collections.synchronizedMap(new IdentityHashMap<>());
private final Parser<T> parser;
private final PrototypeMarshaller<T> marshaller;

ZeroCopyMessageMarshaller(T defaultInstance) {
parser = (Parser<T>) defaultInstance.getParserForType();
marshaller = (PrototypeMarshaller<T>) ProtoLiteUtils.marshaller(defaultInstance);
}

@Override
public Class<T> getMessageClass() {
return marshaller.getMessageClass();
}

@Override
public T getMessagePrototype() {
return marshaller.getMessagePrototype();
}

@Override
public InputStream stream(T value) {
return marshaller.stream(value);
}

@Override
public T parse(InputStream stream) {
try {
if (stream instanceof KnownLength
&& stream instanceof Detachable
&& stream instanceof HasByteBuffer
&& ((HasByteBuffer) stream).byteBufferSupported()) {
int size = stream.available();
// Stream is now detached here and should be closed later.
stream = ((Detachable) stream).detach();
// This mark call is to keep buffer while traversing buffers using skip.
stream.mark(size);
List<ByteString> byteStrings = new ArrayList<>();
while (stream.available() != 0) {
ByteBuffer buffer = ((HasByteBuffer) stream).getByteBuffer();
byteStrings.add(UnsafeByteOperations.unsafeWrap(buffer));
stream.skip(buffer.remaining());
}
stream.reset();
CodedInputStream codedInputStream = ByteString.copyFrom(byteStrings).newCodedInput();
codedInputStream.enableAliasing(true);
codedInputStream.setSizeLimit(Integer.MAX_VALUE);
// fast path (no memory copy)
T message;
try {
message = parseFrom(codedInputStream);
} catch (InvalidProtocolBufferException ipbe) {
stream.close();
throw Status.INTERNAL
.withDescription("Invalid protobuf byte sequence")
.withCause(ipbe)
.asRuntimeException();
}
unclosedStreams.put(message, stream);
return message;
}
} catch (IOException e) {
throw new RuntimeException(e);
}
// slow path
return marshaller.parse(stream);
}

private T parseFrom(CodedInputStream stream) throws InvalidProtocolBufferException {
T message = parser.parseFrom(stream);
try {
stream.checkLastTagWas(0);
return message;
} catch (InvalidProtocolBufferException e) {
e.setUnfinishedMessage(message);
throw e;
}
}

// Application needs to call this function to get the stream for the message and
// call stream.close() function to return it to the pool.
public InputStream popStream(T message) {
return unclosedStreams.remove(message);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright 2021 Google Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.hadoop.gcsio;

import com.google.common.flogger.GoogleLogger;
import com.google.protobuf.MessageLite;
import io.grpc.KnownLength;

/**
* Checker to test whether a zero-copy masharller is available from the versions of gRPC and
* Protobuf.
*/
class ZeroCopyReadinessChecker {
private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();
private static final boolean isZeroCopyReady;

static {
// Check whether io.grpc.Detachable exists?
boolean detachableClassExists = false;
try {
// Try to load Detachable interface in the package where KnownLength is in.
// This can be done directly by looking up io.grpc.Detachable but rather
// done indirectly to handle the case where gRPC is being shaded in a
// different package.
String knownLengthClassName = KnownLength.class.getName();
String detachableClassName =
knownLengthClassName.substring(0, knownLengthClassName.lastIndexOf('.') + 1)
+ "Detachable";
Class<?> detachableClass = Class.forName(detachableClassName);
detachableClassExists = (detachableClass != null);
} catch (ClassNotFoundException ex) {
logger.atFine().withCause(ex).log("io.grpc.Detachable not found");
}
// Check whether com.google.protobuf.UnsafeByteOperations exists?
boolean unsafeByteOperationsClassExists = false;
try {
// Same above
String messageLiteClassName = MessageLite.class.getName();
String unsafeByteOperationsClassName =
messageLiteClassName.substring(0, messageLiteClassName.lastIndexOf('.') + 1)
+ "UnsafeByteOperations";
Class<?> unsafeByteOperationsClass = Class.forName(unsafeByteOperationsClassName);
unsafeByteOperationsClassExists = (unsafeByteOperationsClass != null);
} catch (ClassNotFoundException ex) {
logger.atFine().withCause(ex).log("com.google.protobuf.UnsafeByteOperations not found");
}
isZeroCopyReady = detachableClassExists && unsafeByteOperationsClassExists;
}

public static boolean isReady() {
return isZeroCopyReady;
}
}
Loading

0 comments on commit a87b89e

Please sign in to comment.