Skip to content

Commit

Permalink
Reduce the scope of the ThreadLocals used in SecurityManagerTestSupport
Browse files Browse the repository at this point in the history
Individual test methods can scope the usage of a SecurityManger and Subject to specific methods.
  • Loading branch information
bdemers committed Jun 17, 2022
1 parent 1aa024b commit ca87a3a
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 195 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ public Logical logical() {
}
};

assertThrows(UnauthenticatedException.class, () -> handler.assertAuthorized(requiresPermissionAnnotation));
runWithSubject(subject -> {
assertThrows(UnauthenticatedException.class, () -> handler.assertAuthorized(requiresPermissionAnnotation));
});
}

//Added to satisfy SHIRO-146
Expand All @@ -82,7 +84,9 @@ public Logical logical() {
}
};

assertThrows(UnauthenticatedException.class, () -> handler.assertAuthorized(requiresPermissionAnnotation));
runWithSubject(subject -> {
assertThrows(UnauthenticatedException.class, () -> handler.assertAuthorized(requiresPermissionAnnotation));
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ public Logical logical() {
}
};

assertThrows(UnauthenticatedException.class, () -> handler.assertAuthorized(requiresRolesAnnotation));
runWithSubject(subject -> {
assertThrows(UnauthenticatedException.class, () -> handler.assertAuthorized(requiresRolesAnnotation));
});
}

//Added to satisfy SHIRO-146
Expand All @@ -87,7 +89,9 @@ public Logical logical() {
}
};

assertThrows(UnauthenticatedException.class, () -> handler.assertAuthorized(requiresRolesAnnotation));
runWithSubject(subject -> {
assertThrows(UnauthenticatedException.class, () -> handler.assertAuthorized(requiresRolesAnnotation));
});
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,13 @@ public void testSubmitRunnable() {

final SubjectAwareExecutorService executor = new SubjectAwareExecutorService(mockExecutorService);

Runnable testRunnable = () -> System.out.println("Hello World");
runWithSubject(subject -> {
Runnable testRunnable = () -> System.out.println("Hello World");

executor.submit(testRunnable);
SubjectRunnable subjectRunnable = captor.getValue();
Assertions.assertNotNull(subjectRunnable);
executor.submit(testRunnable);
SubjectRunnable subjectRunnable = captor.getValue();
Assertions.assertNotNull(subjectRunnable);
});
}

private static class DummyFuture<V> implements Future<V> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public void testExecute() {
final SubjectAwareExecutor executor = new SubjectAwareExecutor(targetMockExecutor);

Runnable work = () -> System.out.println("Hello World");
executor.execute(work);
runWithSubject(subject -> executor.execute(work));

//* ensure the target Executor receives a SubjectRunnable instance that retains the subject identity:
//(this is what verifies the test is valid):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
import org.apache.shiro.realm.text.IniRealm;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.util.ThreadContext;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;

import java.util.Objects;

/**
* Utility methods for use by Shiro test case subclasses. You can use these methods as examples for your own
Expand Down Expand Up @@ -64,13 +64,38 @@ protected Subject createAndBindTestSubject() {
return SecurityUtils.getSubject();
}

@BeforeEach
public void setup() {
createAndBindTestSubject();
protected Subject createSubject() {
SecurityManager securityManager = createTestSecurityManager();
return new Subject.Builder(securityManager).buildSubject();
}

/**
* Associates the {@code consumer} with the {@code subject} and executes. If an exeception was thrown by the
* consumer, it is re-thrown by this method.
* @param subject The subject to bind to the current thread.
* @param consumer The block of code to run under the context of the subject.
* @throws Exception propagates any exception thrown by the consumer.
*/
protected void runWithSubject(Subject subject, SubjectConsumer consumer) {
Exception exception = subject.execute(() -> {
try {
consumer.accept(subject);
return null;
} catch (Exception e) {
return e;
}
});
if (Objects.nonNull(exception)) {
throw new RuntimeException("Test execution threw exception", exception);
}
}

protected void runWithSubject(SubjectConsumer consumer) {
runWithSubject(createSubject(), consumer);
}

@AfterEach
public void teardown() {
ThreadContext.remove();
@FunctionalInterface
protected interface SubjectConsumer {
void accept(Subject subject) throws Exception;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ class BearerHttpFilterAuthenticationTest extends SecurityManagerTestSupport {

HttpServletRequest request = mockRequest()
HttpServletResponse response = mockResponse()

AuthenticationToken token = testFilter.createToken(request, response)
assertThat(token, CoreMatchers.instanceOf(BearerToken.class))
assertThat(token.getPrincipal(), Matchers.is(""))

runWithSubject({
AuthenticationToken token = testFilter.createToken(request, response)
assertThat(token, CoreMatchers.instanceOf(BearerToken.class))
assertThat(token.getPrincipal(), Matchers.is(""))
})

verify(request, response)
}
Expand All @@ -57,10 +59,12 @@ class BearerHttpFilterAuthenticationTest extends SecurityManagerTestSupport {

HttpServletRequest request = mockRequest("")
HttpServletResponse response = mockResponse()

AuthenticationToken token = testFilter.createToken(request, response)
assertThat(token, CoreMatchers.instanceOf(BearerToken.class))
assertThat(token.getPrincipal(), Matchers.is(""))

runWithSubject({
AuthenticationToken token = testFilter.createToken(request, response)
assertThat(token, CoreMatchers.instanceOf(BearerToken.class))
assertThat(token.getPrincipal(), Matchers.is(""))
})

verify(request, response)
}
Expand All @@ -72,9 +76,11 @@ class BearerHttpFilterAuthenticationTest extends SecurityManagerTestSupport {
HttpServletRequest request = mockRequest("some-value")
HttpServletResponse response = mockResponse()

AuthenticationToken token = testFilter.createToken(request, response)
assertThat(token, CoreMatchers.instanceOf(BearerToken.class))
assertThat(token.getPrincipal(), Matchers.is("some-value"))
runWithSubject({
AuthenticationToken token = testFilter.createToken(request, response)
assertThat(token, CoreMatchers.instanceOf(BearerToken.class))
assertThat(token.getPrincipal(), Matchers.is("some-value"))
})

verify(request, response)
}
Expand All @@ -86,9 +92,11 @@ class BearerHttpFilterAuthenticationTest extends SecurityManagerTestSupport {
HttpServletRequest request = mockRequest(" ")
HttpServletResponse response = mockResponse()

AuthenticationToken token = testFilter.createToken(request, response)
assertThat(token, CoreMatchers.instanceOf(BearerToken.class))
assertThat(token.getPrincipal(), Matchers.is(""))
runWithSubject({
AuthenticationToken token = testFilter.createToken(request, response)
assertThat(token, CoreMatchers.instanceOf(BearerToken.class))
assertThat(token.getPrincipal(), Matchers.is(""))
})

verify(request, response)
}
Expand All @@ -104,9 +112,11 @@ class BearerHttpFilterAuthenticationTest extends SecurityManagerTestSupport {
HttpServletResponse response = createMock(HttpServletResponse.class)
replay(response)

String[] methods = [ "POST", "PUT", "DELETE" ]
boolean accessAllowed = testFilter.isAccessAllowed(request, response, methods)
assertThat("Access not allowed for GET", accessAllowed)
runWithSubject({
String[] methods = ["POST", "PUT", "DELETE"]
boolean accessAllowed = testFilter.isAccessAllowed(request, response, methods)
assertThat("Access not allowed for GET", accessAllowed)
})
verify(request, response)
}

Expand All @@ -120,9 +130,11 @@ class BearerHttpFilterAuthenticationTest extends SecurityManagerTestSupport {

HttpServletResponse response = mockResponse()

String[] methods = [ "POST", "PUT", "DELETE" ]
boolean accessAllowed = testFilter.isAccessAllowed(request, response, methods)
assertThat("Access allowed for POST", !accessAllowed)
runWithSubject({
String[] methods = ["POST", "PUT", "DELETE"]
boolean accessAllowed = testFilter.isAccessAllowed(request, response, methods)
assertThat("Access allowed for POST", !accessAllowed)
})
}

@Test
Expand All @@ -135,9 +147,11 @@ class BearerHttpFilterAuthenticationTest extends SecurityManagerTestSupport {

HttpServletResponse response = mockResponse()

String[] mappedValue = ["permissive"]
boolean accessAllowed = testFilter.isAccessAllowed(request, response, mappedValue)
assertThat("Access allowed for GET", !accessAllowed) // login attempt should always be false
runWithSubject({
String[] mappedValue = ["permissive"]
boolean accessAllowed = testFilter.isAccessAllowed(request, response, mappedValue)
assertThat("Access allowed for GET", !accessAllowed) // login attempt should always be false
})
}

@Test
Expand All @@ -152,9 +166,11 @@ class BearerHttpFilterAuthenticationTest extends SecurityManagerTestSupport {

HttpServletResponse response = mockResponse()

String[] mappedValue = ["permissive"]
boolean accessAllowed = testFilter.isAccessAllowed(request, response, mappedValue)
assertThat("Access should be allowed for GET", accessAllowed) // non-login attempt, return true
runWithSubject({
String[] mappedValue = ["permissive"]
boolean accessAllowed = testFilter.isAccessAllowed(request, response, mappedValue)
assertThat("Access should be allowed for GET", accessAllowed) // non-login attempt, return true
})
}

@Test
Expand All @@ -167,9 +183,11 @@ class BearerHttpFilterAuthenticationTest extends SecurityManagerTestSupport {

HttpServletResponse response = mockResponse()

String[] mappedValue = ["permissive", "POST", "PUT", "DELETE" ]
boolean accessAllowed = testFilter.isAccessAllowed(request, response, mappedValue)
assertThat("Access allowed for POST", !accessAllowed)
runWithSubject({
String[] mappedValue = ["permissive", "POST", "PUT", "DELETE"]
boolean accessAllowed = testFilter.isAccessAllowed(request, response, mappedValue)
assertThat("Access allowed for POST", !accessAllowed)
})
}

static private String createAuthorizationHeader(String token) {
Expand Down
Loading

0 comments on commit ca87a3a

Please sign in to comment.