Skip to content

Commit

Permalink
Address review comments - Change groups from List<String> to Set<String>
Browse files Browse the repository at this point in the history
Signed-off-by: Marko Strukelj <[email protected]>
  • Loading branch information
mstruk committed Jan 14, 2022
1 parent 2e2d824 commit a2fe9bf
Show file tree
Hide file tree
Showing 10 changed files with 35 additions and 35 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ See [JsonPathFilterQuery JavaDoc](oauth-common/src/main/java/io/strimzi/kafka/oa

When using custom authorization (by installing a custom authorizer) you may want to take user's group membership into account when making the authorization decisions.
One way is to obtain and inspect a parsed JWT token from `io.strimzi.kafka.oauth.server.OAuthKafkaPrincipal` object available through `AuthorizableRequestContext` passed to your `authorize()` method.
Another way is to configure group extraction at authentication time, and get groups as a list of principals from `OAuthKafkaPrincipal` object.
Another way is to configure group extraction at authentication time, and get groups as a set of principals from `OAuthKafkaPrincipal` object.
There are two configuration parameters for configuring group extraction:

- `oauth.groups.claim` (e.g.: `$.roles.client-roles.kafka`)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.util.List;
import java.util.Set;

/**
* This extension of OAuthBearerToken provides a way to associate any additional information with the token
Expand Down Expand Up @@ -38,13 +38,11 @@ public interface BearerTokenWithPayload extends OAuthBearerToken {
void setPayload(Object payload);

/**
* Get groups associated with this token (principal). Logically, groups should be considered a Set.
* However, depending on the infrastructure (e.g. authorization server used), the order may be predictable and configurable,
* and could be used during authorization in a custom authorizer.
* Get groups associated with this token (principal).
*
* @return The groups for the user
*/
List<String> getGroups();
Set<String> getGroups();

/**
* The token claims as a JSON object. For JWT tokens it contains the content of the JWT Payload part of the token.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class TokenInfo {
Expand All @@ -26,15 +25,15 @@ public class TokenInfo {
private Set<String> scopes = new HashSet<>();
private long expiresAt;
private String principal;
private List<String> groups;
private Set<String> groups;
private long issuedAt;
private ObjectNode payload;

public TokenInfo(JsonNode payload, String token, String principal) {
this(payload, token, principal, null);
}

public TokenInfo(JsonNode payload, String token, String principal, List<String> groups) {
public TokenInfo(JsonNode payload, String token, String principal, Set<String> groups) {
this(token,
payload.has(SCOPE) ? payload.get(SCOPE).asText() : null,
principal,
Expand All @@ -48,10 +47,10 @@ public TokenInfo(JsonNode payload, String token, String principal, List<String>
this.payload = (ObjectNode) payload;
}

public TokenInfo(String token, String scope, String principal, List<String> groups, long issuedAtMs, long expiresAtMs) {
public TokenInfo(String token, String scope, String principal, Set<String> groups, long issuedAtMs, long expiresAtMs) {
this.token = token;
this.principal = principal;
this.groups = groups != null ? Collections.unmodifiableList(groups) : null;
this.groups = groups != null ? Collections.unmodifiableSet(groups) : null;
this.issuedAt = issuedAtMs;
this.expiresAt = expiresAtMs;

Expand All @@ -78,7 +77,7 @@ public String principal() {
return principal;
}

public List<String> groups() {
public Set<String> groups() {
return groups;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -364,7 +365,7 @@ public TokenInfo validate(String token) {
}

String principal = extractPrincipal(t);
List<String> groups = extractGroups(t);
Set<String> groups = extractGroups(t);
return new TokenInfo(t, token, principal, groups);
}

Expand All @@ -383,7 +384,7 @@ private String extractPrincipal(JsonNode tokenJson) {
return principal;
}

private List<String> extractGroups(JsonNode tokenJson) {
private Set<String> extractGroups(JsonNode tokenJson) {
if (groupsQuery == null) {
return null;
}
Expand All @@ -393,9 +394,9 @@ private List<String> extractGroups(JsonNode tokenJson) {
}
List<String> groups = JSONUtil.asListOfString(result, groupsDelimiter != null ? groupsDelimiter : ",");
// sanitize the result
groups = groups.stream().map(String::trim).filter(v -> !v.isEmpty()).collect(Collectors.toList());
Set<String> groupSet = groups.stream().map(String::trim).filter(v -> !v.isEmpty()).collect(Collectors.toSet());

return groups.isEmpty() ? null : groups;
return groupSet.isEmpty() ? null : groupSet;
}

@SuppressWarnings({"deprecation", "unchecked"})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.net.URISyntaxException;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import static io.strimzi.kafka.oauth.common.HttpUtil.post;
Expand Down Expand Up @@ -256,7 +257,7 @@ public TokenInfo validate(String token) {
}
}
performOptionalChecks(response);
List<String> groups = null;
Set<String> groups = null;
if (groupsMatcher != null) {
groups = extractGroupsFromResponse(response);

Expand All @@ -271,15 +272,15 @@ public TokenInfo validate(String token) {
return new TokenInfo(token, scopes, principal, groups, iat, expiresMillis);
}

private List<String> extractGroupsFromResponse(JsonNode userInfoJson) {
private Set<String> extractGroupsFromResponse(JsonNode userInfoJson) {
JsonNode result = groupsMatcher.apply(userInfoJson);
if (result == null) {
return null;
}
List<String> groups = JSONUtil.asListOfString(result, groupsDelimiter != null ? groupsDelimiter : ",");

// sanitize the result
return groups.stream().map(String::trim).filter(v -> !v.isEmpty()).collect(Collectors.toList());
return groups.stream().map(String::trim).filter(v -> !v.isEmpty()).collect(Collectors.toSet());
}

private JsonNode getUserInfoEndpointResponse(String token) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ public void setPayload(Object value) {
}

@Override
public List<String> getGroups() {
public Set<String> getGroups() {
return ti.groups();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;

import java.util.Collections;
import java.util.List;
import java.util.Set;

import static io.strimzi.kafka.oauth.common.LogUtil.mask;

Expand All @@ -24,32 +24,32 @@
public final class OAuthKafkaPrincipal extends KafkaPrincipal {

private final BearerTokenWithPayload jwt;
private final List<String> groups;
private final Set<String> groups;

public OAuthKafkaPrincipal(String principalType, String name) {
this(principalType, name, (List<String>) null);
this(principalType, name, (Set<String>) null);
}

public OAuthKafkaPrincipal(String principalType, String name, List<String> groups) {
public OAuthKafkaPrincipal(String principalType, String name, Set<String> groups) {
super(principalType, name);
this.jwt = null;

this.groups = groups == null ? null : Collections.unmodifiableList(groups);
this.groups = groups == null ? null : Collections.unmodifiableSet(groups);
}

public OAuthKafkaPrincipal(String principalType, String name, BearerTokenWithPayload jwt) {
super(principalType, name);
this.jwt = jwt;
List<String> parsedGroups = jwt.getGroups();
Set<String> parsedGroups = jwt.getGroups();

this.groups = parsedGroups == null ? null : Collections.unmodifiableList(parsedGroups);
this.groups = parsedGroups == null ? null : Collections.unmodifiableSet(parsedGroups);
}

public BearerTokenWithPayload getJwt() {
return jwt;
}

public List<String> getGroups() {
public Set<String> getGroups() {
return groups;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,20 @@

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class MockBearerTokenWithPayload implements BearerTokenWithPayload {


private final String principalName;
private final List<String> groups;
private final Set<String> groups;
private final long createTime;
private final long lifetime;
private final Set<String> scopes;
private final String token;
private Object payload;

MockBearerTokenWithPayload(String principalName, List<String> groups, long createTime, long lifetime, String scope, String token, Object payload) {
MockBearerTokenWithPayload(String principalName, Set<String> groups, long createTime, long lifetime, String scope, String token, Object payload) {
this.principalName = principalName;
this.groups = groups;
this.createTime = createTime;
Expand Down Expand Up @@ -51,7 +50,7 @@ public void setPayload(Object payload) {
}

@Override
public List<String> getGroups() {
public Set<String> getGroups() {
return groups;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;

public class OAuthKafkaPrincipalTest {

@Test
public void testEquals() {

BearerTokenWithPayload token = new MockBearerTokenWithPayload("service-account-my-client", Arrays.asList("group1", "group2"),
BearerTokenWithPayload token = new MockBearerTokenWithPayload("service-account-my-client", new HashSet(Arrays.asList("group1", "group2")),
System.currentTimeMillis(), System.currentTimeMillis() + 60000, null, "BEARER-TOKEN-9823eh982u", "Whatever");
OAuthKafkaPrincipal principal = new OAuthKafkaPrincipal("User", "service-account-my-client", token);


BearerTokenWithPayload token2 = new MockBearerTokenWithPayload("bob", Collections.emptyList(),
BearerTokenWithPayload token2 = new MockBearerTokenWithPayload("bob", Collections.emptySet(),
System.currentTimeMillis(), System.currentTimeMillis() + 60000, null, "BEARER-TOKEN-0000dd0000", null);
OAuthKafkaPrincipal principal2 = new OAuthKafkaPrincipal("User", "service-account-my-client", token2);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -103,7 +104,7 @@ private void testNonOAuthUserWithDelegate(Authorizer authorizer, MockAuthorizer
private void testOAuthUserWithDelegate(Authorizer authorizer, MockAuthorizer delegateAuthorizer) throws Exception {

// Prepare condition after mock OAuth athentication with valid token
TokenInfo tokenInfo = new TokenInfo("accesstoken123", null, "User:bob", Arrays.asList("group1", "group2"),
TokenInfo tokenInfo = new TokenInfo("accesstoken123", null, "User:bob", new HashSet(Arrays.asList("group1", "group2")),
System.currentTimeMillis() - 100000,
System.currentTimeMillis() + 100000);
BearerTokenWithPayload token = new JaasServerOauthValidatorCallbackHandler.BearerTokenWithPayloadImpl(tokenInfo);
Expand Down

0 comments on commit a2fe9bf

Please sign in to comment.