Skip to content

Commit

Permalink
fix: check for nullpointer exception when jwk key can't be retrieved (#…
Browse files Browse the repository at this point in the history
…3503)

* check for nullpointer ex when jwk key can't be retrieved

Signed-off-by: at670475 <[email protected]>

* add test

Signed-off-by: at670475 <[email protected]>

* address comment

Signed-off-by: at670475 <[email protected]>

---------

Signed-off-by: at670475 <[email protected]>
  • Loading branch information
taban03 authored Apr 11, 2024
1 parent 2228d5b commit 7c00dba
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.jwk.KeyUse;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Clock;
import io.jsonwebtoken.JwtException;
Expand Down Expand Up @@ -105,7 +107,11 @@ void fetchJWKSet() {

private Map<String, Key> processKeys(JWKSet jwkKeys) {
return jwkKeys.getKeys().stream()
.filter(jwkKey -> "sig".equals(jwkKey.getKeyUse().getValue()) && "RSA".equals(jwkKey.getKeyType().getValue()))
.filter(jwkKey -> {
KeyUse keyUse = jwkKey.getKeyUse();
KeyType keyType = jwkKey.getKeyType();
return keyUse != null && keyType != null && "sig".equals(keyUse.getValue()) && "RSA".equals(keyType.getValue());
})
.collect(Collectors.toMap(JWK::getKeyID, jwkKey -> {
try {
return jwkKey.toRSAKey().toRSAPublicKey();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.Requirement;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.jwk.*;
import io.jsonwebtoken.impl.DefaultClock;
import io.jsonwebtoken.impl.FixedClock;
import org.junit.jupiter.api.BeforeEach;
Expand All @@ -37,21 +33,12 @@
import java.security.Key;
import java.text.ParseException;
import java.time.Instant;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.*;

import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
class OIDCTokenProviderTest {
Expand Down Expand Up @@ -217,6 +204,50 @@ void shouldNotModifyJwksUri() {
}
}

@Test
void shouldHandleNullPointer_whenJWKKeyNull() {

JWKSet mockedJwtSet = mock(JWKSet.class);
List<JWK> mockedKeys = new ArrayList<>();
JWK mockedJwk = mock(JWK.class);
when(mockedJwk.getKeyUse()).thenReturn(null);
RSAKey rsaKey = mock(RSAKey.class);
mockedKeys.add(mockedJwk);
when(mockedJwtSet.getKeys()).thenReturn(mockedKeys);

try (MockedStatic<JWKSet> mockedStatic = Mockito.mockStatic(JWKSet.class)) {
mockedStatic.when(() -> JWKSet.load(any(URL.class))).thenReturn(mockedJwtSet);

oidcTokenProvider.fetchJWKSet();

verify(rsaKey, never()).toRSAPublicKey();
} catch (JOSEException e) {
fail("Exception thrown: " + e.getMessage());
}
}

@Test
void shouldHandleNullPointer_whenJWKTypeNull() {

JWKSet mockedJwtSet = mock(JWKSet.class);
List<JWK> mockedKeys = new ArrayList<>();
JWK mockedJwk = mock(JWK.class);
when(mockedJwk.getKeyType()).thenReturn(null);
RSAKey rsaKey = mock(RSAKey.class);
mockedKeys.add(mockedJwk);
when(mockedJwtSet.getKeys()).thenReturn(mockedKeys);

try (MockedStatic<JWKSet> mockedStatic = Mockito.mockStatic(JWKSet.class)) {
mockedStatic.when(() -> JWKSet.load(any(URL.class))).thenReturn(mockedJwtSet);

oidcTokenProvider.fetchJWKSet();

verify(rsaKey, never()).toRSAPublicKey();
} catch (JOSEException e) {
fail("Exception thrown: " + e.getMessage());
}
}

@Test
void throwsCorrectException() throws JOSEException {

Expand Down

0 comments on commit 7c00dba

Please sign in to comment.