Skip to content

Commit

Permalink
closes #231, closes #232
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander Furer committed Aug 24, 2021
1 parent 71ab5a2 commit 0550964
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutionException;

import static org.junit.Assert.*;
Expand Down Expand Up @@ -52,7 +53,7 @@ public static void startConsul(){
@AfterClass
public static void clear(){
System.clearProperty("spring.cloud.consul.port");
consul.close();
Optional.ofNullable(consul).ifPresent(ConsulProcess::close);

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,26 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Optional;

import com.google.protobuf.Empty;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.lognet.springboot.grpc.GRpcErrorHandler;
import org.lognet.springboot.grpc.GRpcGlobalInterceptor;
import org.lognet.springboot.grpc.GrpcServerTestBase;
import org.lognet.springboot.grpc.demo.DemoApp;
import org.lognet.springboot.grpc.security.AuthCallCredentials;
Expand Down Expand Up @@ -55,13 +60,52 @@ static class TestCfg extends GrpcSecurityConfigurerAdapter {
public void configure(GrpcSecurity builder) throws Exception {
builder.authorizeRequests()
.withSecuredAnnotation()
.userDetailsService(new InMemoryUserDetailsManager());
.userDetailsService(new InMemoryUserDetailsManager(
User.withDefaultPasswordEncoder()
.username("user")
.password("user")
.authorities("SCOPE_profile")
.build()
));
}
@GRpcGlobalInterceptor
@Bean
public ServerInterceptor customInterceptor(){
return new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
if(io.grpc.examples.SecuredGreeterGrpc.getSayAuthHello2Method().equals(call.getMethodDescriptor())) {
final Status status = Status.ALREADY_EXISTS;
call.close(status, new Metadata());
throw status.asRuntimeException();
}
return next.startCall(call,headers);
}
};
}

}

@SpyBean
private GRpcErrorHandler errorHandler;

@Test
public void originalCustomInterceptorStatusIsPreserved() {
AuthCallCredentials callCredentials = new AuthCallCredentials(
AuthHeader.builder()
.basic("user","user".getBytes(StandardCharsets.UTF_8))
);



final StatusRuntimeException statusRuntimeException = Assert.assertThrows(StatusRuntimeException.class, () -> {
io.grpc.examples.SecuredGreeterGrpc.newBlockingStub(selectedChanel)
.withCallCredentials(callCredentials)
.sayAuthHello2(Empty.newBuilder().build()).getMessage();
});
assertThat(statusRuntimeException.getStatus().getCode(), Matchers.is(Status.Code.ALREADY_EXISTS));
verifyZeroInteractions(errorHandler);
}
@Test
public void unsupportedAuthSchemeShouldThrowUnauthenticatedException() {
AuthCallCredentials callCredentials = new AuthCallCredentials(
Expand Down
2 changes: 1 addition & 1 deletion grpc-spring-boot-starter-native-demo/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ buildscript {

dependencies {
classpath("org.springframework.boot:spring-boot-gradle-plugin:${springBootVersion}")
classpath("org.springframework.experimental:spring-aot-gradle-plugin:0.10.0")
classpath("org.springframework.experimental:spring-aot-gradle-plugin:0.10.3")
}
}

Expand Down
2 changes: 1 addition & 1 deletion grpc-spring-boot-starter/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ jar {


// generate native config by executing test app in graalvm docker image
// then grab generating json files and filter to include class from `org.lognet.springboot.grpc` package only
// then grab generated json files and filter to include class from `org.lognet.springboot.grpc` package only
dependsOn testDependencyFatJar
File aotDir = new File(buildDir, "native-configs");
File transformed = new File(aotDir, "transformed")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,11 @@
import io.grpc.StatusRuntimeException;

public interface FailureHandlingServerInterceptor extends ServerInterceptor {
default StatusRuntimeException closeCall(Object o, GRpcErrorHandler errorHandler, ServerCall<?, ?> call, Metadata headers, final Status status, Exception exception){
default void closeCall(Object o, GRpcErrorHandler errorHandler, ServerCall<?, ?> call, Metadata headers, final Status status, Exception exception){

final Metadata responseHeaders = new Metadata();
Status statusToSend;
if(null==o){
statusToSend = errorHandler.handle(status, exception, headers, responseHeaders);
}else {
statusToSend = errorHandler.handle(o,status, exception, headers, responseHeaders);
}

Status statusToSend = errorHandler.handle(o,status, exception, headers, responseHeaders);
call.close(statusToSend, responseHeaders);
return statusToSend.asRuntimeException(responseHeaders);

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import io.grpc.Metadata;
import io.grpc.Status;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class GRpcErrorHandler {

/**
Expand All @@ -27,6 +29,7 @@ public Status handle(Status status, Exception exception, Metadata requestHeaders
* @return
*/
public Status handle(Object message,Status status, Exception exception, Metadata requestHeaders, Metadata responseHeaders) {
log.error("Got error with status {} ",status.getCode().name(),exception);
return status.withDescription(exception.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,9 @@
public class SecurityInterceptor extends AbstractSecurityInterceptor implements FailureHandlingServerInterceptor, Ordered {


private final GrpcSecurityMetadataSource securedMethods;


private final GrpcSecurityMetadataSource securedMethods;

private final AuthenticationSchemeSelector schemeSelector;
private final AuthenticationSchemeSelector schemeSelector;

private GRpcServerProperties.SecurityProperties.Auth authCfg;

Expand All @@ -46,7 +44,8 @@ public SecurityInterceptor(GrpcSecurityMetadataSource securedMethods, Authentica

@Autowired
public void setErrorHandler(Optional<GRpcErrorHandler> errorHandler) {
this.errorHandler = errorHandler.orElseGet(()->new GRpcErrorHandler() {});
this.errorHandler = errorHandler.orElseGet(() -> new GRpcErrorHandler() {
});
}

public void setConfig(GRpcServerProperties.SecurityProperties.Auth authCfg) {
Expand All @@ -55,7 +54,7 @@ public void setConfig(GRpcServerProperties.SecurityProperties.Auth authCfg) {

@Override
public int getOrder() {
return Optional.ofNullable(authCfg.getInterceptorOrder()).orElse(Ordered.HIGHEST_PRECEDENCE);
return Optional.ofNullable(authCfg.getInterceptorOrder()).orElse(Ordered.HIGHEST_PRECEDENCE+1);
}

@Override
Expand All @@ -80,44 +79,51 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
.orElse(headers.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)));

try {
final Authentication authentication = null == authorization ? null :
schemeSelector.getAuthScheme(authorization)
.orElseThrow(() -> new RuntimeException("Can't get authentication from authorization header"));
final Context grpcSecurityContext;
try {
grpcSecurityContext = setupGRpcSecurityContext(call, authorization);
} catch (AccessDeniedException e) {
return fail(next, call, headers, Status.PERMISSION_DENIED, e);
} catch (Exception e) {
return fail(next, call, headers, Status.UNAUTHENTICATED, e);
}
return Contexts.interceptCall(grpcSecurityContext, call, headers, next);
} finally {
SecurityContextHolder.getContext().setAuthentication(null);
}

SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authentication);
SecurityContextHolder.setContext(context);

beforeInvocation(call.getMethodDescriptor());
}

Context ctx = Context.current()
.withValue(GrpcSecurity.AUTHENTICATION_CONTEXT_KEY, SecurityContextHolder.getContext().getAuthentication());
private Context setupGRpcSecurityContext(ServerCall<?, ?> call, CharSequence authorization) {
final Authentication authentication = null == authorization ? null :
schemeSelector.getAuthScheme(authorization)
.orElseThrow(() -> new RuntimeException("Can't get authentication from authorization header"));

return Contexts.interceptCall(ctx, call, headers, next);
} catch (AccessDeniedException e) {
return fail(next, call, headers, Status.PERMISSION_DENIED, e);
} catch (Exception e) {
return fail(next, call, headers, Status.UNAUTHENTICATED, e);
} finally {
SecurityContextHolder.getContext().setAuthentication(null);
}
SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authentication);
SecurityContextHolder.setContext(context);

beforeInvocation(call.getMethodDescriptor());

return Context.current()
.withValue(GrpcSecurity.AUTHENTICATION_CONTEXT_KEY, SecurityContextHolder.getContext().getAuthentication());
}

private <RespT, ReqT> ServerCall.Listener<ReqT> fail(ServerCallHandler<ReqT, RespT> next, ServerCall<ReqT, RespT> call, Metadata headers,final Status status, Exception exception) {
private <RespT, ReqT> ServerCall.Listener<ReqT> fail(ServerCallHandler<ReqT, RespT> next, ServerCall<ReqT, RespT> call, Metadata headers, final Status status, Exception exception) {

if (authCfg.isFailFast()) {
throw closeCall(null,errorHandler,call,headers,status,exception);
closeCall(null, errorHandler, call, headers, status, exception);
return new ServerCall.Listener<ReqT>() {
// noop
};

} else {

return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(next.startCall(call,headers)) {
} else {
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(next.startCall(call, headers)) {
@Override
public void onMessage(ReqT message) {
throw closeCall(message, errorHandler, call, headers, status, exception);


closeCall(message, errorHandler, call, headers, status, exception);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
public void sendMessage(RespT message) {
final Set<ConstraintViolation<RespT>> violations = validator.validate(message, ResponseMessage.class);
if (!violations.isEmpty()) {
throw closeCall(message, errorHandler, delegate(), headers, Status.FAILED_PRECONDITION, new ConstraintViolationException(violations));
closeCall(message, errorHandler, delegate(), headers, Status.FAILED_PRECONDITION, new ConstraintViolationException(violations));
} else {
super.sendMessage(message);
}
Expand All @@ -59,7 +59,7 @@ public void sendMessage(RespT message) {
public void onMessage(ReqT message) {
final Set<ConstraintViolation<ReqT>> violations = validator.validate(message, RequestMessage.class);
if (!violations.isEmpty()) {
throw closeCall(message,errorHandler,call,headers,Status.INVALID_ARGUMENT,new ConstraintViolationException(violations));
closeCall(message,errorHandler,call,headers,Status.INVALID_ARGUMENT,new ConstraintViolationException(violations));

} else {
super.onMessage(message);
Expand Down

0 comments on commit 0550964

Please sign in to comment.