Skip to content

Commit

Permalink
Refactor hashing and security code
Browse files Browse the repository at this point in the history
  • Loading branch information
aditya-radhakrishnan committed Sep 21, 2022
1 parent 393020b commit a69b86a
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 84 deletions.
Original file line number Diff line number Diff line change
@@ -1,20 +1,40 @@
package com.linkedin.metadata.secret;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Base64;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.crypto.Cipher;
import javax.crypto.spec.SecretKeySpec;


public class SecretService {
private static final int LOWERCASE_ASCII_START = 97;
private static final int LOWERCASE_ASCII_END = 122;
public static final String HASHING_ALGORITHM = "SHA-256";

private final String _secret;
private final SecureRandom _secureRandom;
private final Base64.Encoder _encoder;
private final Base64.Decoder _decoder;
private final MessageDigest _messageDigest;

public SecretService(final String secret) {
_secret = secret;
_secureRandom = new SecureRandom();
_encoder = Base64.getEncoder();
_decoder = Base64.getDecoder();
try {
_messageDigest = MessageDigest.getInstance(HASHING_ALGORITHM);
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("Unable to create MessageDigest", e);
}
}

public String encrypt(String value) {
Expand All @@ -33,7 +53,7 @@ public String encrypt(String value) {
}
Cipher cipher = Cipher.getInstance("AES/ECB/PKCS5Padding");
cipher.init(Cipher.ENCRYPT_MODE, secretKey);
return Base64.getEncoder().encodeToString(cipher.doFinal(value.getBytes(StandardCharsets.UTF_8)));
return _encoder.encodeToString(cipher.doFinal(value.getBytes(StandardCharsets.UTF_8)));
} catch (Exception e) {
throw new RuntimeException("Failed to encrypt value using provided secret!", e);
}
Expand All @@ -55,9 +75,40 @@ public String decrypt(String encryptedValue) {
}
Cipher cipher = Cipher.getInstance("AES/ECB/PKCS5PADDING");
cipher.init(Cipher.DECRYPT_MODE, secretKey);
return new String(cipher.doFinal(Base64.getDecoder().decode(encryptedValue)));
return new String(cipher.doFinal(_decoder.decode(encryptedValue)));
} catch (Exception e) {
throw new RuntimeException("Failed to decrypt value using provided secret!", e);
}
}

public String generateUrlSafeToken(int length) {
return _secureRandom.ints(length, LOWERCASE_ASCII_START, LOWERCASE_ASCII_END + 1)
.mapToObj(i -> String.valueOf((char) i))
.collect(Collectors.joining());
}

public String hashString(@Nonnull final String str) {
byte[] hashedBytes = _messageDigest.digest(str.getBytes());
return _encoder.encodeToString(hashedBytes);
}

public byte[] generateSalt(int length) {
byte[] randomBytes = new byte[length];
_secureRandom.nextBytes(randomBytes);
return randomBytes;
}

public String getHashedPassword(@Nonnull byte[] salt, @Nonnull String password) throws IOException {
byte[] saltedPassword = saltPassword(salt, password);
byte[] hashedPassword = _messageDigest.digest(saltedPassword);
return _encoder.encodeToString(hashedPassword);
}

byte[] saltPassword(@Nonnull byte[] salt, @Nonnull String password) throws IOException {
byte[] passwordBytes = password.getBytes();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
byteArrayOutputStream.write(salt);
byteArrayOutputStream.write(passwordBytes);
return byteArrayOutputStream.toByteArray();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,26 @@
import com.linkedin.mxe.MetadataChangeProposal;
import com.linkedin.r2.RemoteInvocationException;
import java.net.URISyntaxException;
import java.security.MessageDigest;
import java.util.Base64;
import java.util.Collections;
import java.util.Objects;
import java.util.UUID;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import static com.linkedin.metadata.Constants.*;
import static com.linkedin.metadata.entity.AspectUtils.*;


@Slf4j
@RequiredArgsConstructor
public class InviteTokenService {
private static final String HASHING_ALGORITHM = "SHA-256";
private static final String ROLE_FIELD_NAME = "role";
private static final String HAS_ROLE_FIELD_NAME = "hasRole";
private final EntityClient _entityClient;
private final SecretService _secretService;
private final MessageDigest _messageDigest;
private final Base64.Encoder _encoder;

public InviteTokenService(@Nonnull EntityClient entityClient, @Nonnull SecretService secretService) throws Exception {
_entityClient = Objects.requireNonNull(entityClient, "entityClient must not be null");
_secretService = Objects.requireNonNull(secretService, "secretService must not be null");
_messageDigest = MessageDigest.getInstance(HASHING_ALGORITHM);
_encoder = Base64.getEncoder();
}

public Urn getInviteTokenUrn(@Nonnull final String inviteTokenStr) throws URISyntaxException {
String hashedInviteTokenStr = hashString(inviteTokenStr);
String hashedInviteTokenStr = _secretService.hashString(inviteTokenStr);
String inviteTokenUrnStr = String.format("urn:li:inviteToken:%s", hashedInviteTokenStr);
return Urn.createFromString(inviteTokenUrnStr);
}
Expand Down Expand Up @@ -94,11 +82,6 @@ public String getInviteToken(@Nullable final String roleUrnStr, boolean regenera
return _secretService.decrypt(inviteToken.getToken());
}

private String hashString(@Nonnull final String str) {
byte[] hashedBytes = _messageDigest.digest(str.getBytes());
return _encoder.encodeToString(hashedBytes);
}

private com.linkedin.identity.InviteToken getInviteTokenEntity(@Nonnull final Urn inviteTokenUrn,
@Nonnull final Authentication authentication) throws RemoteInvocationException, URISyntaxException {
final EntityResponse inviteTokenEntity =
Expand Down Expand Up @@ -159,8 +142,8 @@ private Filter createInviteTokenFilter(@Nonnull final String roleUrnStr) {
@Nonnull
private String createInviteToken(@Nullable final String roleUrnStr, @Nonnull final Authentication authentication)
throws Exception {
String inviteTokenStr = UUID.randomUUID().toString();
String hashedInviteTokenStr = hashString(inviteTokenStr);
String inviteTokenStr = _secretService.generateUrlSafeToken(INVITE_TOKEN_LENGTH);
String hashedInviteTokenStr = _secretService.hashString(inviteTokenStr);
InviteTokenKey inviteTokenKey = new InviteTokenKey();
inviteTokenKey.setId(hashedInviteTokenStr);
com.linkedin.identity.InviteToken inviteTokenAspect =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,12 @@
import com.linkedin.metadata.secret.SecretService;
import com.linkedin.metadata.utils.GenericRecordUtils;
import com.linkedin.mxe.MetadataChangeProposal;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.time.Instant;
import java.util.Base64;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nonnull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import static com.linkedin.metadata.Constants.*;
Expand All @@ -31,25 +27,13 @@
* Service responsible for creating, updating and authenticating native DataHub users.
*/
@Slf4j
@RequiredArgsConstructor
public class NativeUserService {
private static final int SALT_TOKEN_LENGTH = 16;
private static final String HASHING_ALGORITHM = "SHA-256";
private static final long ONE_DAY_MILLIS = TimeUnit.DAYS.toMillis(1);

private final EntityService _entityService;
private final EntityClient _entityClient;
private final SecretService _secretService;
private final SecureRandom _secureRandom;
private final MessageDigest _messageDigest;

public NativeUserService(@Nonnull EntityService entityService, @Nonnull EntityClient entityClient,
@Nonnull SecretService secretService) throws Exception {
_entityService = Objects.requireNonNull(entityService, "entityService must not be null!");
_entityClient = Objects.requireNonNull(entityClient, "entityClient must not be null!");
_secretService = Objects.requireNonNull(secretService, "secretService must not be null!");
_secureRandom = new SecureRandom();
_messageDigest = MessageDigest.getInstance(HASHING_ALGORITHM);
}

public void createNativeUser(@Nonnull String userUrnString, @Nonnull String fullName, @Nonnull String email,
@Nonnull String title, @Nonnull String password, @Nonnull Authentication authentication) throws Exception {
Expand Down Expand Up @@ -110,10 +94,10 @@ void updateCorpUserCredentials(@Nonnull Urn userUrn, @Nonnull String password, @
throws Exception {
// Construct corpUserCredentials
CorpUserCredentials corpUserCredentials = new CorpUserCredentials();
final byte[] salt = getRandomBytes(SALT_TOKEN_LENGTH);
final byte[] salt = _secretService.generateSalt(SALT_TOKEN_LENGTH);
String encryptedSalt = _secretService.encrypt(Base64.getEncoder().encodeToString(salt));
corpUserCredentials.setSalt(encryptedSalt);
String hashedPassword = getHashedPassword(salt, password);
String hashedPassword = _secretService.getHashedPassword(salt, password);
corpUserCredentials.setHashedPassword(hashedPassword);

// Ingest corpUserCredentials MCP
Expand All @@ -138,7 +122,7 @@ public String generateNativeUserPasswordResetToken(@Nonnull String userUrnString
throw new RuntimeException("User does not exist or is a non-native user!");
}
// Add reset token to CorpUserCredentials
String passwordResetToken = generateRandomLowercaseToken();
String passwordResetToken = _secretService.generateUrlSafeToken(PASSWORD_RESET_TOKEN_LENGTH);
corpUserCredentials.setPasswordResetToken(_secretService.encrypt(passwordResetToken));

long expirationTime = Instant.now().plusMillis(ONE_DAY_MILLIS).toEpochMilli();
Expand Down Expand Up @@ -186,10 +170,10 @@ public void resetCorpUserCredentials(@Nonnull String userUrnString, @Nonnull Str
}

// Construct corpUserCredentials
final byte[] salt = getRandomBytes(SALT_TOKEN_LENGTH);
final byte[] salt = _secretService.generateSalt(SALT_TOKEN_LENGTH);
String encryptedSalt = _secretService.encrypt(Base64.getEncoder().encodeToString(salt));
corpUserCredentials.setSalt(encryptedSalt);
String hashedPassword = getHashedPassword(salt, password);
String hashedPassword = _secretService.getHashedPassword(salt, password);
corpUserCredentials.setHashedPassword(hashedPassword);

// Ingest corpUserCredentials MCP
Expand All @@ -202,30 +186,6 @@ public void resetCorpUserCredentials(@Nonnull String userUrnString, @Nonnull Str
_entityClient.ingestProposal(corpUserCredentialsProposal, authentication);
}

byte[] getRandomBytes(int length) {
byte[] randomBytes = new byte[length];
_secureRandom.nextBytes(randomBytes);
return randomBytes;
}

String generateRandomLowercaseToken() {
return UUID.randomUUID().toString();
}

byte[] saltPassword(@Nonnull byte[] salt, @Nonnull String password) throws IOException {
byte[] passwordBytes = password.getBytes();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
byteArrayOutputStream.write(salt);
byteArrayOutputStream.write(passwordBytes);
return byteArrayOutputStream.toByteArray();
}

public String getHashedPassword(@Nonnull byte[] salt, @Nonnull String password) throws IOException {
byte[] saltedPassword = saltPassword(salt, password);
byte[] hashedPassword = _messageDigest.digest(saltedPassword);
return Base64.getEncoder().encodeToString(hashedPassword);
}

public boolean doesPasswordMatch(@Nonnull String userUrnString, @Nonnull String password) throws Exception {
Objects.requireNonNull(userUrnString, "userUrnSting must not be null!");
Objects.requireNonNull(password, "Password must not be null!");
Expand All @@ -240,7 +200,7 @@ public boolean doesPasswordMatch(@Nonnull String userUrnString, @Nonnull String
String decryptedSalt = _secretService.decrypt(corpUserCredentials.getSalt());
byte[] salt = Base64.getDecoder().decode(decryptedSalt);
String storedHashedPassword = corpUserCredentials.getHashedPassword();
String hashedPassword = getHashedPassword(salt, password);
String hashedPassword = _secretService.getHashedPassword(salt, password);
return storedHashedPassword.equals(hashedPassword);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public class InviteTokenServiceTest {
private static final String INVITE_TOKEN_URN_STRING = "urn:li:inviteToken:admin-invite-token";
private static final String ROLE_URN_STRING = "urn:li:dataHubRole:Admin";
private static final String INVITE_TOKEN_STRING = "inviteToken";
private static final String HASHED_INVITE_TOKEN_STRING = "hashedInviteToken";
private static final String ENCRYPTED_INVITE_TOKEN_STRING = "encryptedInviteToken";
private static final String DATAHUB_SYSTEM_CLIENT_ID = "__datahub_system";
private static final Authentication SYSTEM_AUTHENTICATION =
Expand Down Expand Up @@ -127,6 +128,8 @@ public void getInviteTokenRegenerate() throws Exception {
searchResult.setEntities(new SearchEntityArray());
when(_entityClient.filter(eq(INVITE_TOKEN_ENTITY_NAME), any(), any(), anyInt(), anyInt(),
eq(SYSTEM_AUTHENTICATION))).thenReturn(searchResult);
when(_secretService.generateUrlSafeToken(anyInt())).thenReturn(INVITE_TOKEN_STRING);
when(_secretService.hashString(anyString())).thenReturn(HASHED_INVITE_TOKEN_STRING);
when(_secretService.encrypt(anyString())).thenReturn(ENCRYPTED_INVITE_TOKEN_STRING);

_inviteTokenService.getInviteToken(null, true, SYSTEM_AUTHENTICATION);
Expand All @@ -139,6 +142,8 @@ public void getInviteTokenEmptySearchResult() throws Exception {
searchResult.setEntities(new SearchEntityArray());
when(_entityClient.filter(eq(INVITE_TOKEN_ENTITY_NAME), any(), any(), anyInt(), anyInt(),
eq(SYSTEM_AUTHENTICATION))).thenReturn(searchResult);
when(_secretService.generateUrlSafeToken(anyInt())).thenReturn(INVITE_TOKEN_STRING);
when(_secretService.hashString(anyString())).thenReturn(HASHED_INVITE_TOKEN_STRING);
when(_secretService.encrypt(anyString())).thenReturn(ENCRYPTED_INVITE_TOKEN_STRING);

_inviteTokenService.getInviteToken(null, false, SYSTEM_AUTHENTICATION);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ public class NativeUserServiceTest {
private static final String EMAIL = "[email protected]";
private static final String TITLE = "Data Scientist";
private static final String PASSWORD = "password";
private static final String HASHED_PASSWORD = "hashedPassword";
private static final String ENCRYPTED_INVITE_TOKEN = "encryptedInviteroToken";
private static final String RESET_TOKEN = "inviteToken";
private static final String ENCRYPTED_RESET_TOKEN = "encryptedInviteToken";
private static final byte[] SALT = "salt".getBytes();
private static final String ENCRYPTED_SALT = "encryptedSalt";
private static final Urn USER_URN = new CorpuserUrn(EMAIL);
private static final long ONE_DAY_MILLIS = TimeUnit.DAYS.toMillis(1);
Expand All @@ -50,16 +52,6 @@ public void setupTest() throws Exception {
_nativeUserService = new NativeUserService(_entityService, _entityClient, _secretService);
}

@Test
public void testConstructor() throws Exception {
assertThrows(() -> new NativeUserService(null, _entityClient, _secretService));
assertThrows(() -> new NativeUserService(_entityService, null, _secretService));
assertThrows(() -> new NativeUserService(_entityService, _entityClient, null));

// Succeeds!
new NativeUserService(_entityService, _entityClient, _secretService);
}

@Test
public void testCreateNativeUserNullArguments() {
assertThrows(
Expand All @@ -85,7 +77,9 @@ public void testCreateNativeUserUserAlreadyExists() throws Exception {
@Test
public void testCreateNativeUserPasses() throws Exception {
when(_entityService.exists(any())).thenReturn(false);
when(_secretService.generateSalt(anyInt())).thenReturn(SALT);
when(_secretService.encrypt(any())).thenReturn(ENCRYPTED_SALT);
when(_secretService.getHashedPassword(any(), any())).thenReturn(HASHED_PASSWORD);

_nativeUserService.createNativeUser(USER_URN_STRING, FULL_NAME, EMAIL, TITLE, PASSWORD, SYSTEM_AUTHENTICATION);
}
Expand All @@ -104,7 +98,9 @@ public void testUpdateCorpUserStatusPasses() throws Exception {

@Test
public void testUpdateCorpUserCredentialsPasses() throws Exception {
when(_secretService.generateSalt(anyInt())).thenReturn(SALT);
when(_secretService.encrypt(any())).thenReturn(ENCRYPTED_SALT);
when(_secretService.getHashedPassword(any(), any())).thenReturn(HASHED_PASSWORD);

_nativeUserService.updateCorpUserCredentials(USER_URN, PASSWORD, SYSTEM_AUTHENTICATION);
verify(_entityClient).ingestProposal(any(), any());
Expand Down Expand Up @@ -209,6 +205,7 @@ public void testResetCorpUserCredentialsPasses() throws Exception {
when(mockCorpUserCredentialsAspect.getPasswordResetTokenExpirationTimeMillis()).thenReturn(
Instant.now().plusMillis(ONE_DAY_MILLIS).toEpochMilli());
when(_secretService.decrypt(eq(ENCRYPTED_RESET_TOKEN))).thenReturn(RESET_TOKEN);
when(_secretService.generateSalt(anyInt())).thenReturn(SALT);
when(_secretService.encrypt(any())).thenReturn(ENCRYPTED_SALT);

_nativeUserService.resetCorpUserCredentials(USER_URN_STRING, PASSWORD, RESET_TOKEN, SYSTEM_AUTHENTICATION);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ public class Constants {

public static final String DEFAULT_RUN_ID = "no-run-id-provided";

public static final String GLOBAL_INVITE_TOKEN = "urn:li:inviteToken:global";

/**
* Entities
*/
Expand Down Expand Up @@ -236,6 +234,10 @@ public class Constants {

// Invite Token
public static final String INVITE_TOKEN_ASPECT_NAME = "inviteToken";
public static final int INVITE_TOKEN_LENGTH = 32;
public static final int SALT_TOKEN_LENGTH = 16;
public static final int PASSWORD_RESET_TOKEN_LENGTH = 32;


// Relationships
public static final String IS_MEMBER_OF_GROUP_RELATIONSHIP_NAME = "IsMemberOfGroup";
Expand Down

0 comments on commit a69b86a

Please sign in to comment.