diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/auth/JaasAuthProviderTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/auth/JaasAuthProviderTest.java index 6b5a0800811b..db510156959d 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/auth/JaasAuthProviderTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/auth/JaasAuthProviderTest.java @@ -38,6 +38,7 @@ import java.util.Arrays; import java.util.HashSet; import java.util.Set; +import java.util.concurrent.CountDownLatch; import java.util.stream.Stream; import javax.security.auth.Subject; import javax.security.auth.login.LoginContext; @@ -47,9 +48,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; -import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnitRunner; -import org.mockito.stubbing.Answer; @RunWith(MockitoJUnitRunner.class) public class JaasAuthProviderTest { @@ -93,9 +92,9 @@ public void setUp() throws Exception { } @Test - public void shouldAuthenticateWithWildcardAllowedRole() { + public void shouldAuthenticateWithWildcardAllowedRole() throws Exception { // Given: - givenAllowedRoles("*"); + givenAllowedRoles("**"); givenUserRoles(); // When: @@ -106,7 +105,7 @@ public void shouldAuthenticateWithWildcardAllowedRole() { } @Test - public void shouldAuthenticateWithNonWildcardRole() { + public void shouldAuthenticateWithNonWildcardRole() throws Exception { // Given: givenAllowedRoles("user"); givenUserRoles("user"); @@ -119,7 +118,7 @@ public void shouldAuthenticateWithNonWildcardRole() { } @Test - public void shouldAuthenticateWithAdditionalAllowedRoles() { + public void shouldAuthenticateWithAdditionalAllowedRoles() throws Exception { // Given: givenAllowedRoles("user", "other"); givenUserRoles("user"); @@ -132,7 +131,7 @@ public void shouldAuthenticateWithAdditionalAllowedRoles() { } @Test - public void shouldAuthenticateWithExtraRoles() { + public void shouldAuthenticateWithExtraRoles() throws Exception { // Given: givenAllowedRoles("user"); givenUserRoles("user", "other"); @@ -169,7 +168,7 @@ public void shouldFailToAuthenticateOnMissingPassword() { } @Test - public void shouldFailToAuthenticateWithNoRole() { + public void shouldFailToAuthenticateWithNoRole() throws Exception { // Given: givenAllowedRoles("user"); givenUserRoles(); @@ -182,7 +181,7 @@ public void shouldFailToAuthenticateWithNoRole() { } @Test - public void shouldFailToAuthenticateWithNonAllowedRole() { + public void shouldFailToAuthenticateWithNonAllowedRole() throws Exception { // Given: givenAllowedRoles("user"); givenUserRoles("other"); @@ -197,6 +196,8 @@ public void shouldFailToAuthenticateWithNonAllowedRole() { private void givenAllowedRoles(final String... roles) { when(config.getList(KsqlRestConfig.AUTHENTICATION_ROLES_CONFIG)) .thenReturn(Arrays.asList(roles)); + + authProvider = new JaasAuthProvider(server, config, loginContextSupplier); } private void givenUserRoles(final String... roles) { @@ -207,23 +208,32 @@ private void givenUserRoles(final String... roles) { when(subject.getPrincipals()).thenReturn(principals); } - private void verifyAuthorizedSuccessfulLogin() { + private void verifyAuthorizedSuccessfulLogin() throws Exception { verifyLoginSuccessWithAuthorization(true); } - private void verifyUnauthorizedSuccessfulLogin() { + private void verifyUnauthorizedSuccessfulLogin() throws Exception { verifyLoginSuccessWithAuthorization(false); } - private void verifyLoginSuccessWithAuthorization(final boolean isAuthorized) { + private void verifyLoginSuccessWithAuthorization(final boolean isAuthorized) throws Exception { verify(userHandler).handle(userCaptor.capture()); final AsyncResult result = userCaptor.getValue(); assertThat(result.succeeded(), is(true)); + assertThat(result.result(), instanceOf(JaasUser.class)); final JaasUser user = (JaasUser) result.result(); + assertThat(user.getPrincipal(), instanceOf(JaasPrincipal.class)); final JaasPrincipal apiPrincipal = (JaasPrincipal) user.getPrincipal(); assertThat(apiPrincipal.getName(), is(USERNAME)); + + final CountDownLatch latch = new CountDownLatch(1); + user.doIsPermitted("some permission", ar -> { + assertThat(ar.result(), is(isAuthorized)); + latch.countDown(); + }); + latch.await(); } private void verifyLoginFailure(final String expectedMsg) { @@ -241,16 +251,13 @@ private static Principal principalWithName(final String name) { private void handleAsyncExecution() { when(server.getWorkerExecutor()).thenReturn(worker); - doAnswer(new Answer() { - @Override - public Void answer(final InvocationOnMock invocation) throws Throwable { - final Handler> blockingCodeHandler = invocation.getArgument(0); - final Handler> resultHandler = invocation.getArgument(1); - final Promise promise = Promise.promise(); - promise.future().onComplete(resultHandler); - blockingCodeHandler.handle(promise); - return null; - } + doAnswer(invocation -> { + final Handler> blockingCodeHandler = invocation.getArgument(0); + final Handler> resultHandler = invocation.getArgument(1); + final Promise promise = Promise.promise(); + promise.future().onComplete(resultHandler); + blockingCodeHandler.handle(promise); + return null; }).when(worker).executeBlocking(any(), any()); } }