diff --git a/binder/src/androidTest/java/io/grpc/binder/BinderSecurityTest.java b/binder/src/androidTest/java/io/grpc/binder/BinderSecurityTest.java index 18d5cd99d74..42796e2caed 100644 --- a/binder/src/androidTest/java/io/grpc/binder/BinderSecurityTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/BinderSecurityTest.java @@ -26,6 +26,7 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.protobuf.Empty; import io.grpc.CallOptions; import io.grpc.ManagedChannel; @@ -45,6 +46,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + import javax.annotation.Nullable; import org.junit.After; import org.junit.Before; @@ -76,6 +79,7 @@ public void setupServiceDefinitionsAndMethods() { MethodDescriptor.newBuilder(marshaller, marshaller) .setFullMethodName(name) .setType(MethodDescriptor.MethodType.UNARY) + .setSampledToLocalTracing(true) .build(); ServerCallHandler callHandler = ServerCalls.asyncUnaryCall( @@ -139,12 +143,16 @@ private void assertCallSuccess(MethodDescriptor method) { .isNotNull(); } - private void assertCallFailure(MethodDescriptor method, Status status) { + @CanIgnoreReturnValue + private StatusRuntimeException assertCallFailure( + MethodDescriptor method, Status status) { try { ClientCalls.blockingUnaryCall(channel, method, CallOptions.DEFAULT, null); - fail(); + fail("Expected call to " + method.getFullMethodName() + " to fail but it succeeded."); + throw new AssertionError(); // impossible } catch (StatusRuntimeException sre) { assertThat(sre.getStatus().getCode()).isEqualTo(status.getCode()); + return sre; } } @@ -172,6 +180,70 @@ public void testServerDisallowsCalls() throws Exception { } } + @Test + public void testFailedFuturesPropagateOriginalException() throws Exception { + String errorMessage = "something went wrong"; + IllegalStateException originalException = new IllegalStateException(errorMessage); + createChannel( + ServerSecurityPolicy.newBuilder() + .servicePolicy("foo", new AsyncSecurityPolicy() { + @Override + ListenableFuture checkAuthorizationAsync(int uid) { + return Futures.immediateFailedFuture(originalException); + } + }) + .build(), + SecurityPolicies.internalOnly()); + MethodDescriptor method = methods.get("foo/method0"); + + StatusRuntimeException sre = assertCallFailure(method, Status.INTERNAL); + assertThat(sre.getStatus().getDescription()).contains(errorMessage); + } + + @Test + public void testFailedFuturesAreNotCachedPermanently() throws Exception { + AtomicReference firstAttempt = new AtomicReference<>(true); + createChannel( + ServerSecurityPolicy.newBuilder() + .servicePolicy("foo", new AsyncSecurityPolicy() { + @Override + ListenableFuture checkAuthorizationAsync(int uid) { + if (firstAttempt.getAndSet(false)) { + return Futures.immediateFailedFuture(new IllegalStateException()); + } + return Futures.immediateFuture(Status.OK); + } + }) + .build(), + SecurityPolicies.internalOnly()); + MethodDescriptor method = methods.get("foo/method0"); + + assertCallFailure(method, Status.INTERNAL); + assertCallSuccess(method); + } + + @Test + public void testCancelledFuturesAreNotCachedPermanently() throws Exception { + AtomicReference firstAttempt = new AtomicReference<>(true); + createChannel( + ServerSecurityPolicy.newBuilder() + .servicePolicy("foo", new AsyncSecurityPolicy() { + @Override + ListenableFuture checkAuthorizationAsync(int uid) { + if (firstAttempt.getAndSet(false)) { + return Futures.immediateCancelledFuture(); + } + return Futures.immediateFuture(Status.OK); + } + }) + .build(), + SecurityPolicies.internalOnly()); + MethodDescriptor method = methods.get("foo/method0"); + + assertCallFailure(method, Status.INTERNAL); + assertCallSuccess(method); + } + @Test public void testClientDoesntTrustServer() throws Exception { createChannel(SecurityPolicies.serverInternalOnly(), policy((uid) -> false)); diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java index 1866bf54a47..56464d58a4b 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java @@ -19,6 +19,7 @@ import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; import io.grpc.Attributes; import io.grpc.Internal; @@ -32,6 +33,7 @@ import io.grpc.Status; import io.grpc.internal.GrpcAttributes; +import java.util.concurrent.CancellationException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; @@ -110,9 +112,13 @@ public ServerCall.Listener interceptCall( Status authStatus; try { authStatus = Futures.getDone(authStatusFuture); - } catch (ExecutionException e) { + } catch (ExecutionException | CancellationException e) { // Failed futures are treated as an internal error rather than a security rejection. authStatus = Status.INTERNAL.withCause(e); + @Nullable String message = e.getMessage(); + if (message != null) { + authStatus = authStatus.withDescription(message); + } } if (authStatus.isOk()) { @@ -179,6 +185,8 @@ ListenableFuture checkAuthorization(MethodDescriptor method) { if (useCache) { @Nullable ListenableFuture authorization = serviceAuthorization.get(serviceName); if (authorization != null) { + // Authorization check exists and is a pending or successful future (even if for a + // failed authorization). return authorization; } } @@ -193,6 +201,15 @@ ListenableFuture checkAuthorization(MethodDescriptor method) { serverPolicyChecker.checkAuthorizationForServiceAsync(uid, serviceName); if (useCache) { serviceAuthorization.putIfAbsent(serviceName, authorization); + Futures.addCallback(authorization, new FutureCallback() { + @Override + public void onSuccess(Status result) {} + + @Override + public void onFailure(Throwable t) { + serviceAuthorization.remove(serviceName, authorization); + } + }, MoreExecutors.directExecutor()); } return authorization; }