Skip to content

Commit

Permalink
Merge pull request #42599 from Malandril/fix-augment-bug
Browse files Browse the repository at this point in the history
Fix augmentors called multiple times for each identity provider
  • Loading branch information
gsmet authored Aug 28, 2024
2 parents 4b1ae89 + 1609bee commit a0906bb
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 10 deletions.
5 changes: 5 additions & 0 deletions extensions/security/runtime/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@
<artifactId>quarkus-junit5-internal</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public Uni<SecurityIdentity> authenticate(AuthenticationRequest request) {
if (providers.size() == 1) {
return handleSingleProvider(providers.get(0), request);
}
return handleProvider(0, (List) providers, request);
return handleProviders((List) providers, request);
} catch (Throwable t) {
return Uni.createFrom().failure(t);
}
Expand Down Expand Up @@ -106,18 +106,30 @@ public SecurityIdentity authenticateBlocking(AuthenticationRequest request) {
throw new IllegalArgumentException(
"No IdentityProviders were registered to handle AuthenticationRequest " + request);
}
return (SecurityIdentity) handleProvider(0, (List) providers, request).await().indefinitely();
return (SecurityIdentity) handleProviders((List) providers, request).await().indefinitely();
}

private <T extends AuthenticationRequest> Uni<SecurityIdentity> handleProvider(int pos,
private <T extends AuthenticationRequest> Uni<SecurityIdentity> handleProviders(
List<IdentityProvider<T>> providers, T request) {
return handleProvider(0, providers, request)
.onItem()
.transformToUni(new Function<SecurityIdentity, Uni<? extends SecurityIdentity>>() {
@Override
public Uni<? extends SecurityIdentity> apply(SecurityIdentity securityIdentity) {
return handleIdentityFromProvider(0, securityIdentity, request.getAttributes());
}
});
}

private <T extends AuthenticationRequest> Uni<SecurityIdentity> handleProvider(int pos, List<IdentityProvider<T>> providers,
T request) {
if (pos == providers.size()) {
//we failed to authentication
log.debug("Authentication failed as providers would authenticate the request");
return Uni.createFrom().failure(new AuthenticationFailedException());
}
IdentityProvider<T> current = providers.get(pos);
Uni<SecurityIdentity> cs = current.authenticate(request, blockingRequestContext)
return current.authenticate(request, blockingRequestContext)
.onItem().transformToUni(new Function<SecurityIdentity, Uni<? extends SecurityIdentity>>() {
@Override
public Uni<SecurityIdentity> apply(SecurityIdentity securityIdentity) {
Expand All @@ -127,12 +139,6 @@ public Uni<SecurityIdentity> apply(SecurityIdentity securityIdentity) {
return handleProvider(pos + 1, providers, request);
}
});
return cs.onItem().transformToUni(new Function<SecurityIdentity, Uni<? extends SecurityIdentity>>() {
@Override
public Uni<? extends SecurityIdentity> apply(SecurityIdentity securityIdentity) {
return handleIdentityFromProvider(0, securityIdentity, request.getAttributes());
}
});
}

private Uni<SecurityIdentity> handleIdentityFromProvider(int pos, SecurityIdentity identity,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package io.quarkus.security.runtime;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import java.util.concurrent.Executors;

Expand All @@ -11,6 +16,7 @@
import io.quarkus.security.identity.IdentityProvider;
import io.quarkus.security.identity.IdentityProviderManager;
import io.quarkus.security.identity.SecurityIdentity;
import io.quarkus.security.identity.SecurityIdentityAugmentor;
import io.quarkus.security.identity.request.BaseAuthenticationRequest;
import io.smallrye.mutiny.Uni;

Expand All @@ -32,6 +38,21 @@ void testIdentityProviderPriority() {
assertEquals(new QuarkusPrincipal("Bob"), identity.getPrincipal());
}

@Test
void testIdentityProviderAugmentOnlyOnce() {
TestSecurityAugmentor augmentor = spy(new TestSecurityAugmentor());
IdentityProviderManager identityProviderManager = QuarkusIdentityProviderManagerImpl.builder()
.addProvider(new TestIdentityProviderSystemFirstPriorityNoop())
.addProvider(new TestIdentityProviderSystemLastPriorityUser())
.addProvider(new AnonymousIdentityProvider())
.addSecurityIdentityAugmentor(augmentor)
.setBlockingExecutor(Executors.newSingleThreadExecutor()).build();
SecurityIdentity identity = identityProviderManager.authenticateBlocking(new TestAuthenticationRequest());
assertEquals(new QuarkusPrincipal("Bob"), identity.getPrincipal());
assertTrue(identity.getRoles().contains("role"));
verify(augmentor, times(1)).augment(any(), any());
}

static class TestAuthenticationRequest extends BaseAuthenticationRequest {
}

Expand Down Expand Up @@ -78,4 +99,26 @@ public int priority() {
return SYSTEM_LAST;
}
}

static class TestIdentityProviderSystemFirstPriorityNoop extends TestIdentityProviderSystemLastPriority {
@Override
public Uni<SecurityIdentity> authenticate(TestAuthenticationRequest request, AuthenticationRequestContext context) {
return Uni.createFrom().nullItem();
}
}

static class TestIdentityProviderSystemLastPriorityUser extends TestIdentityProviderUserFirstPriority {
@Override
public int priority() {
return SYSTEM_LAST;
}
}

private static class TestSecurityAugmentor implements SecurityIdentityAugmentor {
@Override
public Uni<SecurityIdentity> augment(SecurityIdentity securityIdentity,
AuthenticationRequestContext authenticationRequestContext) {
return Uni.createFrom().item(QuarkusSecurityIdentity.builder(securityIdentity).addRole("role").build());
}
}
}

0 comments on commit a0906bb

Please sign in to comment.