Skip to content

Commit

Permalink
binder: Simplify ownership of ServerAuthInterceptor's executor. (#11293)
Browse files Browse the repository at this point in the history
Allocating this executor before BinderServer even exists is convoluted and actually leaks if the built server is never actually start()ed. Instead, have BinderServer own this executor directly, with a lifetime from start() until termination. Pass it to the ServerAuthInterceptor via TransportAuthorizationState Attribute instead of at construction time.
  • Loading branch information
jdcormie authored Jun 19, 2024
1 parent c540993 commit 15ad9f5
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,12 @@
package io.grpc.binder.internal;

import android.content.Context;
import androidx.core.content.ContextCompat;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import io.grpc.ServerStreamTracer;
import io.grpc.binder.AndroidComponentAddress;
import io.grpc.binder.BindServiceFlags;
import io.grpc.binder.BinderChannelCredentials;
import io.grpc.binder.HostServices;
import io.grpc.binder.InboundParcelablePolicy;
import io.grpc.binder.SecurityPolicies;
import io.grpc.internal.AbstractTransportTest;
import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.ClientTransportFactory.ClientTransportOptions;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.InternalServer;
Expand Down Expand Up @@ -57,6 +51,8 @@ public final class BinderTransportTest extends AbstractTransportTest {
SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE);
private final ObjectPool<Executor> offloadExecutorPool =
SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR);
private final ObjectPool<Executor> serverExecutorPool =
SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR);

@Override
@After
Expand All @@ -69,11 +65,13 @@ public void tearDown() throws InterruptedException {
protected InternalServer newServer(List<ServerStreamTracer.Factory> streamTracerFactories) {
AndroidComponentAddress addr = HostServices.allocateService(appContext);

BinderServer binderServer = new BinderServer.Builder()
.setListenAddress(addr)
.setExecutorServicePool(executorServicePool)
.setStreamTracerFactories(streamTracerFactories)
.build();
BinderServer binderServer =
new BinderServer.Builder()
.setListenAddress(addr)
.setExecutorPool(serverExecutorPool)
.setExecutorServicePool(executorServicePool)
.setStreamTracerFactories(streamTracerFactories)
.build();

HostServices.configureService(addr,
HostServices.serviceParamsBuilder()
Expand Down
8 changes: 2 additions & 6 deletions binder/src/main/java/io/grpc/binder/BinderServerBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@
import io.grpc.binder.internal.BinderTransportSecurity;
import io.grpc.internal.FixedObjectPool;
import io.grpc.internal.ServerImplBuilder;
import io.grpc.internal.ObjectPool;

import java.io.File;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;

/**
Expand Down Expand Up @@ -163,10 +161,8 @@ public Server build() {
checkState(!isBuilt, "BinderServerBuilder can only be used to build one server instance.");
isBuilt = true;
// We install the security interceptor last, so it's closest to the transport.
ObjectPool<? extends Executor> executorPool = serverImplBuilder.getExecutorPool();
Executor executor = executorPool.getObject();
BinderTransportSecurity.installAuthInterceptor(this, executor);
internalBuilder.setTerminationListener(() -> executorPool.returnObject(executor));
BinderTransportSecurity.installAuthInterceptor(this);
internalBuilder.setExecutorPool(serverImplBuilder.getExecutorPool());
return super.build();
}
}
46 changes: 30 additions & 16 deletions binder/src/main/java/io/grpc/binder/internal/BinderServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import java.io.IOException;
import java.net.SocketAddress;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
Expand All @@ -63,30 +64,34 @@ public final class BinderServer implements InternalServer, LeakSafeOneWayBinder.
private static final Logger logger = Logger.getLogger(BinderServer.class.getName());

private final ObjectPool<ScheduledExecutorService> executorServicePool;
private final ObjectPool<? extends Executor> executorPool;
private final ImmutableList<ServerStreamTracer.Factory> streamTracerFactories;
private final AndroidComponentAddress listenAddress;
private final LeakSafeOneWayBinder hostServiceBinder;
private final BinderTransportSecurity.ServerPolicyChecker serverPolicyChecker;
private final InboundParcelablePolicy inboundParcelablePolicy;
private final Runnable terminationListener;

@GuardedBy("this")
private ServerListener listener;

@GuardedBy("this")
private ScheduledExecutorService executorService;

@Nullable // Before start() and after termination.
@GuardedBy("this")
private Executor executor;

@GuardedBy("this")
private boolean shutdown;

private BinderServer(Builder builder) {
this.listenAddress = checkNotNull(builder.listenAddress);
this.executorPool = checkNotNull(builder.executorPool);
this.executorServicePool = builder.executorServicePool;
this.streamTracerFactories =
ImmutableList.copyOf(checkNotNull(builder.streamTracerFactories, "streamTracerFactories"));
this.serverPolicyChecker = BinderInternal.createPolicyChecker(builder.serverSecurityPolicy);
this.inboundParcelablePolicy = builder.inboundParcelablePolicy;
this.terminationListener = builder.terminationListener;
hostServiceBinder = new LeakSafeOneWayBinder(this);
}

Expand All @@ -97,8 +102,9 @@ public IBinder getHostBinder() {

@Override
public synchronized void start(ServerListener serverListener) throws IOException {
listener = new ActiveTransportTracker(serverListener, terminationListener);
listener = new ActiveTransportTracker(serverListener, this::onTerminated);
executorService = executorServicePool.getObject();
executor = executorPool.getObject();
}

@Override
Expand Down Expand Up @@ -129,10 +135,15 @@ public synchronized void shutdown() {
// Break the connection to the binder. We'll receive no more transactions.
hostServiceBinder.setHandler(GoAwayHandler.INSTANCE);
listener.serverShutdown();
// TODO(jdcormie): Shouldn't this happen in onTerminated()? Is this even used anywhere?
executorService = executorServicePool.returnObject(executorService);
}
}

private synchronized void onTerminated() {
executor = executorPool.returnObject(executor);
}

@Override
public String toString() {
return "BinderServer[" + listenAddress + "]";
Expand Down Expand Up @@ -161,7 +172,11 @@ public synchronized boolean handleTransaction(int code, Parcel parcel) {
.set(BinderTransport.REMOTE_UID, callingUid)
.set(BinderTransport.SERVER_AUTHORITY, listenAddress.getAuthority())
.set(BinderTransport.INBOUND_PARCELABLE_POLICY, inboundParcelablePolicy);
BinderTransportSecurity.attachAuthAttrs(attrsBuilder, callingUid, serverPolicyChecker);
BinderTransportSecurity.attachAuthAttrs(
attrsBuilder,
callingUid,
serverPolicyChecker,
checkNotNull(executor, "Not started?"));
// Create a new transport and let our listener know about it.
BinderTransport.BinderServerTransport transport =
new BinderTransport.BinderServerTransport(
Expand Down Expand Up @@ -202,12 +217,12 @@ public boolean handleTransaction(int code, Parcel parcel) {
public static class Builder {
@Nullable AndroidComponentAddress listenAddress;
@Nullable List<? extends ServerStreamTracer.Factory> streamTracerFactories;
@Nullable ObjectPool<? extends Executor> executorPool;

ObjectPool<ScheduledExecutorService> executorServicePool =
SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE);
ServerSecurityPolicy serverSecurityPolicy = SecurityPolicies.serverInternalOnly();
InboundParcelablePolicy inboundParcelablePolicy = InboundParcelablePolicy.DEFAULT;
Runnable terminationListener = () -> {};

public BinderServer build() {
return new BinderServer(this);
Expand Down Expand Up @@ -236,6 +251,16 @@ public Builder setStreamTracerFactories(List<? extends ServerStreamTracer.Factor
return this;
}

/**
* Sets the executor to be used for calling into the application.
*
* <p>Required.
*/
public Builder setExecutorPool(ObjectPool<? extends Executor> executorPool) {
this.executorPool = executorPool;
return this;
}

/**
* Sets the executor to be used for scheduling channel timers.
*
Expand Down Expand Up @@ -266,16 +291,5 @@ public Builder setInboundParcelablePolicy(InboundParcelablePolicy inboundParcela
this.inboundParcelablePolicy = checkNotNull(inboundParcelablePolicy, "inboundParcelablePolicy");
return this;
}

/**
* Installs a callback that will be invoked when this server is {@link #shutdown()} and all of
* its transports are terminated.
*
* <p>Optional.
*/
public Builder setTerminationListener(Runnable terminationListener) {
this.terminationListener = checkNotNull(terminationListener, "terminationListener");
return this;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;

import io.grpc.Attributes;
import io.grpc.Internal;
import io.grpc.Metadata;
Expand All @@ -32,7 +31,6 @@
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.internal.GrpcAttributes;

import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
Expand All @@ -57,11 +55,10 @@ private BinderTransportSecurity() {}
* Install a security policy on an about-to-be created server.
*
* @param serverBuilder The ServerBuilder being used to create the server.
* @param executor The executor in which the authorization result will be handled.
*/
@Internal
public static void installAuthInterceptor(ServerBuilder<?> serverBuilder, Executor executor) {
serverBuilder.intercept(new ServerAuthInterceptor(executor));
public static void installAuthInterceptor(ServerBuilder<?> serverBuilder) {
serverBuilder.intercept(new ServerAuthInterceptor());
}

/**
Expand All @@ -71,14 +68,18 @@ public static void installAuthInterceptor(ServerBuilder<?> serverBuilder, Execut
* @param builder The {@link Attributes.Builder} for the transport being created.
* @param remoteUid The remote UID of the transport.
* @param serverPolicyChecker The policy checker for this transport.
* @param executor used for calling into the application. Must outlive the transport.
*/
@Internal
public static void attachAuthAttrs(
Attributes.Builder builder, int remoteUid, ServerPolicyChecker serverPolicyChecker) {
Attributes.Builder builder,
int remoteUid,
ServerPolicyChecker serverPolicyChecker,
Executor executor) {
builder
.set(
TRANSPORT_AUTHORIZATION_STATE,
new TransportAuthorizationState(remoteUid, serverPolicyChecker))
new TransportAuthorizationState(remoteUid, serverPolicyChecker, executor))
.set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY);
}

Expand All @@ -88,25 +89,20 @@ public static void attachAuthAttrs(
*/
private static final class ServerAuthInterceptor implements ServerInterceptor {

private final Executor executor;

ServerAuthInterceptor(Executor executor) {
this.executor = executor;
}

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
TransportAuthorizationState transportAuthState =
call.getAttributes().get(TRANSPORT_AUTHORIZATION_STATE);
ListenableFuture<Status> authStatusFuture =
call.getAttributes()
.get(TRANSPORT_AUTHORIZATION_STATE)
.checkAuthorization(call.getMethodDescriptor());
transportAuthState.checkAuthorization(call.getMethodDescriptor());

// Most SecurityPolicy will have synchronous implementations that provide an
// immediately-resolved Future. In that case, short-circuit to avoid unnecessary allocations
// and asynchronous code if the authorization result is already present.
if (!authStatusFuture.isDone()) {
return newServerCallListenerForPendingAuthResult(authStatusFuture, call, headers, next);
return newServerCallListenerForPendingAuthResult(
authStatusFuture, transportAuthState.executor, call, headers, next);
}

Status authStatus;
Expand All @@ -130,31 +126,33 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
}

private <ReqT, RespT> ServerCall.Listener<ReqT> newServerCallListenerForPendingAuthResult(
ListenableFuture<Status> authStatusFuture,
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
ListenableFuture<Status> authStatusFuture,
Executor executor,
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
PendingAuthListener<ReqT, RespT> listener = new PendingAuthListener<>();
Futures.addCallback(
authStatusFuture,
new FutureCallback<Status>() {
@Override
public void onSuccess(Status authStatus) {
if (!authStatus.isOk()) {
call.close(authStatus, new Metadata());
return;
}

listener.startCall(call, headers, next);
}

@Override
public void onFailure(Throwable t) {
call.close(
Status.INTERNAL.withCause(t).withDescription("Authorization future failed"),
new Metadata());
}
}, executor);
authStatusFuture,
new FutureCallback<Status>() {
@Override
public void onSuccess(Status authStatus) {
if (!authStatus.isOk()) {
call.close(authStatus, new Metadata());
return;
}

listener.startCall(call, headers, next);
}

@Override
public void onFailure(Throwable t) {
call.close(
Status.INTERNAL.withCause(t).withDescription("Authorization future failed"),
new Metadata());
}
},
executor);
return listener;
}
}
Expand All @@ -167,10 +165,16 @@ private static final class TransportAuthorizationState {
private final int uid;
private final ServerPolicyChecker serverPolicyChecker;
private final ConcurrentHashMap<String, ListenableFuture<Status>> serviceAuthorization;
private final Executor executor;

TransportAuthorizationState(int uid, ServerPolicyChecker serverPolicyChecker) {
/**
* @param executor used for calling into the application. Must outlive the transport.
*/
TransportAuthorizationState(
int uid, ServerPolicyChecker serverPolicyChecker, Executor executor) {
this.uid = uid;
this.serverPolicyChecker = serverPolicyChecker;
this.executor = executor;
serviceAuthorization = new ConcurrentHashMap<>(8);
}

Expand Down

0 comments on commit 15ad9f5

Please sign in to comment.