Skip to content

Commit

Permalink
refactor: use AWS UserId as AuthorizationID
Browse files Browse the repository at this point in the history
- AuthorizationID is now required
- BREAKING CHANGE: ARN is no longer used
- BREAKING CHANGE: AuthenticationID is no longer used
- Clients must specify the UserID of their own credentials (can
be retrieved from AWS STS). This is to ensure that clients are
logging in as the appropriate principal in Kafka.
  • Loading branch information
kjdelisle committed Jun 19, 2019
1 parent 0d53480 commit 3ac6bbd
Show file tree
Hide file tree
Showing 10 changed files with 259 additions and 166 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>com.stack.security.auth.aws</groupId>
<artifactId>kafka-auth-aws-iam</artifactId>
<version>0.3.0</version>
<version>0.4.0</version>
<packaging>jar</packaging>
<properties>
<maven.compiler.source>8</maven.compiler.source>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public class AwsIamAuthenticateCallback implements Callback {
private boolean authenticated;

/**
* Creates a callback with the password provided by the client
* Creates a callback with the AWS credentials provided by the client
*
* @param accessKeyId The AWS Access Key ID provided by the client during
* SASL/AWS authentication
Expand All @@ -27,8 +27,9 @@ public class AwsIamAuthenticateCallback implements Callback {
public AwsIamAuthenticateCallback(String accessKeyId, String secretAccessKey, String sessionToken) {
setAccessKeyId(accessKeyId);
setSecretAccessKey(secretAccessKey);
this.secretAccessKey = secretAccessKey.toCharArray();
this.sessionToken = sessionToken.toCharArray();
if (sessionToken != null && !sessionToken.isBlank()) {
setSessionToken(sessionToken);
}
}

/**
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProviderChain;
import com.amazonaws.auth.AWSSessionCredentials;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.GetCallerIdentityRequest;
import com.amazonaws.services.securitytoken.model.GetCallerIdentityResult;

public class AwsIamSaslClient implements SaslClient {
Expand All @@ -35,8 +32,7 @@ public class AwsIamSaslClient implements SaslClient {
protected ScheduledExecutorService executor;

protected CallbackHandler cbh;
protected String authorizationID; // Should be the ARN
protected String authenticationID; // Should also be the ARN
protected String authorizationID; // The Unique UserId from AWS STS
protected byte[] accessKeyId;
protected byte[] secretAccessKey;
protected byte[] sessionToken;
Expand Down Expand Up @@ -111,20 +107,15 @@ public byte[] evaluateChallenge(byte[] challengeData) throws SaslException {
// client.
private void setCredentials(AWSCredentials credentials) {
// Use the STS service to find the ARN of our own credentials.
// The ARN is our User principal for Kafka, so we need to provide it.
// NOTE: The server will independently verify with AWS!
AWSSecurityTokenService service = this.stsBuilder.withCredentials(new AWSStaticCredentialsProvider(credentials))
.build();
GetCallerIdentityResult result = service.getCallerIdentity(new GetCallerIdentityRequest());
// They're the same thing.
this.authenticationID = this.authorizationID = result.getArn();
GetCallerIdentityResult result = AwsIamUtilities.getCallerIdentity(stsBuilder, credentials);
this.authorizationID = AwsIamUtilities.getUniqueIdentity(result);
this.accessKeyId = credentials.getAWSAccessKeyId().getBytes(UTF_8);
this.secretAccessKey = credentials.getAWSSecretKey().getBytes(UTF_8);
if (credentials instanceof AWSSessionCredentials) {
AWSSessionCredentials sessionCreds = (AWSSessionCredentials) credentials;
this.sessionToken = sessionCreds.getSessionToken().getBytes(UTF_8);
}
service.shutdown();
credentials = null;
}

Expand All @@ -138,35 +129,23 @@ protected final byte[] generateAnswer() throws SaslException {
}

try {
byte[] authz = (authorizationID != null) ? authorizationID.getBytes(UTF_8) : null;
byte[] auth = authenticationID.getBytes(UTF_8);
byte[] authz = authorizationID.getBytes(UTF_8);

/*
* Answer should be the length of the authentication, authorization (if not
* null), accessKeyId, secretAccessKey, and sessionToken (if not null) plus the
* number of null separator bytes between them all.
*/
byte[] answer;
if (sessionToken != null && authz != null) {
answer = new byte[accessKeyId.length + secretAccessKey.length + auth.length + authz.length + sessionToken.length
+ 5];
} else if (sessionToken != null) {
answer = new byte[accessKeyId.length + secretAccessKey.length + auth.length + sessionToken.length + 4];
} else if (authz != null) {
answer = new byte[accessKeyId.length + secretAccessKey.length + auth.length + authz.length + 4];
if (authz != null && sessionToken != null) {
answer = new byte[authz.length + accessKeyId.length + secretAccessKey.length + sessionToken.length + 4];
} else {
answer = new byte[accessKeyId.length + secretAccessKey.length + auth.length + 3];
answer = new byte[authz.length + accessKeyId.length + secretAccessKey.length + 3];
}

int pos = 0;
if (authz != null) {
System.arraycopy(authz, 0, answer, 0, authz.length);
pos = authz.length;
answer[pos++] = SEP;
}

System.arraycopy(auth, 0, answer, pos, auth.length);
pos += auth.length;
System.arraycopy(authz, 0, answer, 0, authz.length);
pos = authz.length;
answer[pos++] = SEP;

System.arraycopy(accessKeyId, 0, answer, pos, accessKeyId.length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
import javax.security.sasl.SaslServer;
import javax.security.sasl.SaslServerFactory;

import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.stack.security.auth.aws.AwsIamAuthenticateCallback;

import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;

import com.stack.security.auth.aws.AwsIamAuthenticateCallback;

/**
* Simple SaslServer implementation for SASL/AWS. Checks the provided AWS
* credentials against the AWS STS service and compares the returned identity
Expand All @@ -31,6 +32,7 @@ public class AwsIamSaslServer implements SaslServer {
public static final String AWS_MECHANISM = "AWS";

private final AuthenticateCallbackHandler callbackHandler;
private AWSSecurityTokenServiceClientBuilder builder;
private boolean complete;
private String authorizationId;

Expand All @@ -41,6 +43,15 @@ public AwsIamSaslServer(CallbackHandler callbackHandler) {
this.callbackHandler = (AuthenticateCallbackHandler) callbackHandler;
}

/**
* This overload is primarily for testing purposes (as a passthrough for the
* AwsIamUtilities call).
*/
public AwsIamSaslServer(CallbackHandler callbackHandler, AWSSecurityTokenServiceClientBuilder builder) {
this(callbackHandler);
this.builder = builder;
}

/**
* @throws SaslAuthenticationException if username/password combination is
* invalid or if the requested authorization
Expand All @@ -63,28 +74,27 @@ public byte[] evaluateResponse(byte[] responseBytes) throws SaslAuthenticationEx
String response = new String(responseBytes, StandardCharsets.UTF_8);
List<String> tokens = extractTokens(response);
String authorizationIdFromClient = tokens.get(0);
String arn = tokens.get(1);
String accessKeyId = tokens.get(2);
String secretAccessKey = tokens.get(3);
String accessKeyId = tokens.get(1);
String secretAccessKey = tokens.get(2);
String sessionToken;
try {
sessionToken = tokens.get(4);
sessionToken = tokens.get(3);
} catch (Throwable e) {
// Ignore the exception, just set the token to empty string.
sessionToken = "";
}

if (arn.isEmpty()) {
throw new SaslAuthenticationException("Authentication failed: arn not specified");
if (authorizationIdFromClient.isBlank()) {
throw new SaslAuthenticationException("Authentication failed: authorizationId not specified");
}
if (accessKeyId.isEmpty()) {
if (accessKeyId.isBlank()) {
throw new SaslAuthenticationException("Authentication failed: accessKeyId not specified");
}
if (secretAccessKey.isEmpty()) {
if (secretAccessKey.isBlank()) {
throw new SaslAuthenticationException("Authentication failed: secretAccessKey not specified");
}

NameCallback nameCallback = new NameCallback("arn", arn);
NameCallback nameCallback = new NameCallback("authorizationId", authorizationIdFromClient);
AwsIamAuthenticateCallback authenticateCallback = new AwsIamAuthenticateCallback(accessKeyId, secretAccessKey,
sessionToken);
try {
Expand All @@ -95,11 +105,12 @@ public byte[] evaluateResponse(byte[] responseBytes) throws SaslAuthenticationEx
}
if (!authenticateCallback.authenticated())
throw new SaslAuthenticationException("Authentication failed: Invalid AWS credentials");
if (!authorizationIdFromClient.isEmpty() && !authorizationIdFromClient.equals(arn))
throw new SaslAuthenticationException(
"Authentication failed: Client requested an authorization id that is different from username");

this.authorizationId = arn;
if (this.builder != null) {
this.authorizationId = AwsIamUtilities.getUniqueIdentity(builder, accessKeyId, secretAccessKey, sessionToken);
} else {
this.authorizationId = AwsIamUtilities.getUniqueIdentity(accessKeyId, secretAccessKey, sessionToken);
}

complete = true;
return new byte[0];
Expand All @@ -108,21 +119,23 @@ public byte[] evaluateResponse(byte[] responseBytes) throws SaslAuthenticationEx
private List<String> extractTokens(String string) {
List<String> tokens = new ArrayList<>();
int startIndex = 0;
for (int i = 0; i < 6; ++i) {
for (int i = 0; i < 5; ++i) {
int endIndex = string.indexOf("\u0000", startIndex);
if (endIndex == -1) {
String remaining = string.substring(startIndex);
if (!remaining.equals("")) {
tokens.add(remaining);
if (startIndex < string.length()) {
String remaining = string.substring(startIndex);
if (!remaining.equals("")) {
tokens.add(remaining);
}
}
break;
}
tokens.add(string.substring(startIndex, endIndex));
startIndex = endIndex + 1;
}

if (tokens.size() < 4 || tokens.size() > 5)
throw new SaslAuthenticationException("Invalid SASL/AWS response: expected 4 or 5 tokens, got " + tokens.size());
if (tokens.size() < 3 || tokens.size() > 4)
throw new SaslAuthenticationException("Invalid SASL/AWS response: expected 3 or 4 tokens, got " + tokens.size());

return tokens;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,51 +66,48 @@ public void configure(Map<String, ?> configs, String mechanism, List<AppConfigur

@Override
public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
String arn = null;
String authorizationId = null;
for (Callback callback : callbacks) {
if (callback instanceof NameCallback)
arn = ((NameCallback) callback).getDefaultName();
authorizationId = ((NameCallback) callback).getDefaultName();
else if (callback instanceof AwsIamAuthenticateCallback) {
AwsIamAuthenticateCallback awsIamCallback = (AwsIamAuthenticateCallback) callback;
boolean authenticated = authenticate(arn, awsIamCallback.getAccessKeyId(), awsIamCallback.getSecretAccessKey(),
awsIamCallback.getSessionToken());
boolean authenticated = authenticate(authorizationId, awsIamCallback.getAccessKeyId(),
awsIamCallback.getSecretAccessKey(), awsIamCallback.getSessionToken());
awsIamCallback.authenticated(authenticated);
} else
throw new UnsupportedCallbackException(callback);
}
}

protected boolean authenticate(String arn, char[] accessKeyId, char[] secretAccessKey, char[] sessionToken) {
protected boolean authenticate(String authorizationId, char[] accessKeyId, char[] secretAccessKey,
char[] sessionToken) {

// At a minimum, the ARN, Access Key ID and Secret Access Key MUST be defined!
if (arn == null || accessKeyId == null || secretAccessKey == null) {
// At a minimum, the authorizationId, Access Key ID and Secret Access Key MUST
// be defined!
if (authorizationId == null || accessKeyId == null || secretAccessKey == null) {
return false;
}

AWSCredentials awsCreds;
String accessKeyIdString = new String(accessKeyId);
String secretAccessKeyString = new String(secretAccessKey);
String sessionTokenString = new String(sessionToken);
String sessionTokenString = sessionToken == null ? "" : new String(sessionToken);

if (!sessionTokenString.isEmpty()) {
awsCreds = new BasicSessionCredentials(accessKeyIdString, secretAccessKeyString, sessionTokenString);
} else {
awsCreds = new BasicAWSCredentials(accessKeyIdString, secretAccessKeyString);
if (authorizationId.isBlank() || accessKeyIdString.isBlank() || secretAccessKeyString.isBlank()) {
return false;
}
AWSSecurityTokenService service = builder.withCredentials(new AWSStaticCredentialsProvider(awsCreds)).build();
// As an added measure of safety, the server can specify what AWS Account ID it
// expects to see as a part of the caller's identity.
String expectedAwsAccountId = JaasContext.configEntryOption(jaasConfigEntries, AWS_ACCOUNT_ID,
AwsIamLoginModule.class.getName());

// Check the credentials with AWS STS and GetCallerIdentity.

GetCallerIdentityRequest request = new GetCallerIdentityRequest();
GetCallerIdentityResult result = service.getCallerIdentity(request);
GetCallerIdentityResult result = AwsIamUtilities.getCallerIdentity(builder, accessKeyIdString,
secretAccessKeyString, sessionTokenString);

// Both the ARN returned by the credentials, and the configured account ID need
// to match!
if (result.getArn().equals(arn) && result.getAccount().equals(expectedAwsAccountId)) {
if (result.getUserId().equals(authorizationId) && result.getAccount().equals(expectedAwsAccountId)) {
return true;
} else {
return false;
Expand Down
Loading

0 comments on commit 3ac6bbd

Please sign in to comment.