Skip to content

Commit

Permalink
Add PeerUidTestHelper to allow in process servers to use peer uid in …
Browse files Browse the repository at this point in the history
…tests
  • Loading branch information
abtom committed Mar 14, 2024
1 parent 8a9ce99 commit 099283d
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 0 deletions.
55 changes: 55 additions & 0 deletions binder/src/main/java/io/grpc/binder/PeerUidTestHelper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package io.grpc.binder;

import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.Metadata.AsciiMarshaller;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;

/** Class which helps set up {@link PeerUids} to be used in tests. */
public class PeerUidTestHelper {

private static final Metadata.AsciiMarshaller<Integer> MARSHALLER =
new AsciiMarshaller<Integer>() {
@Override
public String toAsciiString(Integer value) {
return value.toString();
}

@Override
public Integer parseAsciiString(String serialized) {
return Integer.parseInt(serialized);
}
};

/** The UID of the calling package is set with the value of this key. */
public static final Metadata.Key<Integer> UID_KEY =
Metadata.Key.of("binder-remote-uid-for-unit-testing", MARSHALLER);

/**
* Creates an interceptor that associates the {@link PeerUids#REMOTE_PEER} key in the request
* {@link Context} with a UID provided by the client in the {@link #UID_KEY} request header, if
* present.
*
* <p>The returned interceptor works with any gRPC transport but is meant for in-process unit
* testing of gRPC/binder services that depend on {@link PeerUids}.
*/
public static ServerInterceptor newTestPeerIdentifyingServerInterceptor() {
return new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
Context context = Context.current();
if (headers.containsKey(UID_KEY)) {
context = context.withValue(PeerUids.REMOTE_PEER, new PeerUid(headers.get(UID_KEY)));
}

return Contexts.interceptCall(context, call, headers, next);
}
};
}

private PeerUidTestHelper() {}
}
120 changes: 120 additions & 0 deletions binder/src/test/java/io/grpc/binder/PeerUidTestHelperTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package io.grpc.binder;

import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8;

import com.google.common.collect.ImmutableList;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientInterceptors;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.ServerInterceptors;
import io.grpc.ServerServiceDefinition;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.MetadataUtils;
import io.grpc.stub.ServerCalls;
import io.grpc.testing.GrpcCleanupRule;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public final class PeerUidTestHelperTest {
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();

private static final int FAKE_UID = 12345;

private final AtomicReference<PeerUid> clientUidCapture = new AtomicReference<>();

@Test
public void keyPopulatedWithInterceptorAndHeader() throws Exception {
makeServiceCall(/* includeInterceptor= */ true, /* includeUidInHeader= */ true, FAKE_UID);
assertThat(clientUidCapture.get()).isEqualTo(new PeerUid(FAKE_UID));
}

@Test
public void keyNotPopulatedWithInterceptorAndNoHeader() throws Exception {
makeServiceCall(/* includeInterceptor= */ true, /* includeUidInHeader= */ false, /* uid= */ -1);
assertThat(clientUidCapture.get()).isNull();
}

@Test
public void keyNotPopulatedWithoutInterceptorAndWithHeader() throws Exception {
makeServiceCall(
/* includeInterceptor= */ false, /* includeUidInHeader= */ true, /* uid= */ FAKE_UID);
assertThat(clientUidCapture.get()).isNull();
}

private final MethodDescriptor<String, String> method =
MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE)
.setFullMethodName("test/method")
.setType(MethodDescriptor.MethodType.UNARY)
.build();

private void makeServiceCall(boolean includeInterceptor, boolean includeUidInHeader, int uid)
throws Exception {
ServerCallHandler<String, String> callHandler =
ServerCalls.asyncUnaryCall(
(req, respObserver) -> {
clientUidCapture.set(PeerUids.REMOTE_PEER.get());
respObserver.onNext(req);
respObserver.onCompleted();
});
ImmutableList<ServerInterceptor> interceptors;
if (includeInterceptor) {
interceptors = ImmutableList.of(PeerUidTestHelper.newTestPeerIdentifyingServerInterceptor());
} else {
interceptors = ImmutableList.of();
}
ServerServiceDefinition serviceDef =
ServerInterceptors.intercept(
ServerServiceDefinition.builder("test").addMethod(method, callHandler).build(),
interceptors);

InProcessServerBuilder server =
InProcessServerBuilder.forName("test").directExecutor().addService(serviceDef);

grpcCleanup.register(server.build().start());

Channel channel = InProcessChannelBuilder.forName("test").directExecutor().build();
grpcCleanup.register((ManagedChannel) channel);

if (includeUidInHeader) {
Metadata header = new Metadata();
header.put(PeerUidTestHelper.UID_KEY, uid);
channel =
ClientInterceptors.intercept(channel, MetadataUtils.newAttachHeadersInterceptor(header));
}

ClientCalls.blockingUnaryCall(channel, method, CallOptions.DEFAULT, "hello");
}

private static class StringMarshaller implements MethodDescriptor.Marshaller<String> {
public static final StringMarshaller INSTANCE = new StringMarshaller();

@Override
public InputStream stream(String value) {
return new ByteArrayInputStream(value.getBytes(UTF_8));
}

@Override
public String parse(InputStream stream) {
try {
return new String(stream.readAllBytes(), UTF_8);
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}
}
}

0 comments on commit 099283d

Please sign in to comment.