Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update credential type logic #164

Merged
merged 2 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 3 additions & 19 deletions Notation.Plugin.AzureKeyVault.Tests/KeyVault/CredentialsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ namespace Notation.Plugin.AzureKeyVault.Credential.Tests
public class CredentialsTests
{
[Theory]
[InlineData("default")]
[InlineData(null)]
[InlineData("environment")]
[InlineData("workloadid")]
[InlineData("managedid")]
[InlineData("azurecli")]
public void GetCredentials_WithValidCredentialType_ReturnsExpectedCredential(string credentialType)
public void GetCredentials_WithValidCredentialType_ReturnsExpectedCredential(string? credentialType)
{
// Act
var result = Credentials.GetCredentials(credentialType);
Expand All @@ -30,23 +30,7 @@ public void GetCredentials_WithInvalidCredentialType_ThrowsValidationException()

// Act & Assert
var ex = Assert.Throws<ValidationException>(() => Credentials.GetCredentials(invalidCredentialType));
Assert.Equal($"Invalid credential key: {invalidCredentialType}", ex.Message);
}

[Fact]
public void GetCredentials_WithPluginConfig_ReturnsExpectedCredential()
{
// Arrange
var pluginConfig = new Dictionary<string, string>
{
{ "credential_type", "default" }
};

// Act
var result = Credentials.GetCredentials(pluginConfig);

// Assert
Assert.IsAssignableFrom<TokenCredential>(result);
Assert.Equal($"Invalid credential type: {invalidCredentialType}", ex.Message);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ namespace Notation.Plugin.AzureKeyVault.Client.Tests
{
public class KeyVaultClientTests
{
private string? defaultCredentialType = null;

[Fact]
public void TestConstructorWithKeyId()
{
string keyId = "https://myvault.vault.azure.net/keys/my-key/123";

KeyVaultClient keyVaultClient = new KeyVaultClient(keyId, Credentials.GetCredentials("default"));
KeyVaultClient keyVaultClient = new KeyVaultClient(keyId, Credentials.GetCredentials(defaultCredentialType));

Assert.Equal("my-key", keyVaultClient.Name);
Assert.Equal("123", keyVaultClient.Version);
Expand All @@ -38,7 +40,7 @@ public void TestConstructorWithKeyVaultUrlNameVersion()
string name = "my-key";
string version = "123";

KeyVaultClient keyVaultClient = new KeyVaultClient(keyVaultUrl, name, version, Credentials.GetCredentials("default"));
KeyVaultClient keyVaultClient = new KeyVaultClient(keyVaultUrl, name, version, Credentials.GetCredentials(defaultCredentialType));

Assert.Equal(name, keyVaultClient.Name);
Assert.Equal(version, keyVaultClient.Version);
Expand All @@ -52,14 +54,14 @@ public void TestConstructorWithKeyVaultUrlNameVersion()
[InlineData("http://myvault.vault.azure.net/keys/my-key/123")]
public void TestConstructorWithInvalidKeyId(string invalidKeyId)
{
Assert.Throws<ValidationException>(() => new KeyVaultClient(invalidKeyId, Credentials.GetCredentials("default")));
Assert.Throws<ValidationException>(() => new KeyVaultClient(invalidKeyId, Credentials.GetCredentials(defaultCredentialType)));
}

[Theory]
[InlineData("")]
public void TestConstructorWithEmptyKeyId(string invalidKeyId)
{
Assert.Throws<ArgumentNullException>(() => new KeyVaultClient(invalidKeyId, Credentials.GetCredentials("default")));
Assert.Throws<ArgumentNullException>(() => new KeyVaultClient(invalidKeyId, Credentials.GetCredentials(defaultCredentialType)));
}

private class TestableKeyVaultClient : KeyVaultClient
Expand Down Expand Up @@ -89,7 +91,7 @@ private TestableKeyVaultClient CreateMockedKeyVaultClient(SignResult signResult)
mockCryptoClient.Setup(c => c.SignDataAsync(It.IsAny<SignatureAlgorithm>(), It.IsAny<byte[]>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(signResult);

return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-key", "123", mockCryptoClient.Object, Credentials.GetCredentials("default"));
return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-key", "123", mockCryptoClient.Object, Credentials.GetCredentials(defaultCredentialType));
}

private TestableKeyVaultClient CreateMockedKeyVaultClient(KeyVaultCertificate certificate)
Expand All @@ -98,15 +100,15 @@ private TestableKeyVaultClient CreateMockedKeyVaultClient(KeyVaultCertificate ce
mockCertificateClient.Setup(c => c.GetCertificateVersionAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(Response.FromValue(certificate, new Mock<Response>().Object));

return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-certificate", "123", mockCertificateClient.Object, Credentials.GetCredentials("default"));
return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-certificate", "123", mockCertificateClient.Object, Credentials.GetCredentials(defaultCredentialType));
}

private TestableKeyVaultClient CreateMockedKeyVaultClient(KeyVaultSecret secret)
{
var mockSecretClient = new Mock<SecretClient>(new Uri("https://fake.vault.azure.net/secrets/fake-secret/123"), new Mock<TokenCredential>().Object);
mockSecretClient.Setup(c => c.GetSecretAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(Response.FromValue(secret, new Mock<Response>().Object));
return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-certificate", "123", mockSecretClient.Object, Credentials.GetCredentials("default"));
return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-certificate", "123", mockSecretClient.Object, Credentials.GetCredentials(defaultCredentialType));
}

[Fact]
Expand Down
23 changes: 11 additions & 12 deletions Notation.Plugin.AzureKeyVault/KeyVault/Credentials.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ public class Credentials
/// </summary>
public const string CredentialTypeKey = "credential_type";
/// <summary>
/// Default credential name.
/// </summary>
public const string DefaultCredentialName = "default";
/// <summary>
/// Environment credential name.
/// </summary>
public const string EnvironmentCredentialName = "environment";
Expand All @@ -34,13 +30,16 @@ public class Credentials
/// <summary>
/// Get the credential based on the credential type.
/// </summary>
public static TokenCredential GetCredentials(string credentialType)
public static TokenCredential GetCredentials(string? credentialType)
{
if (credentialType == null)
{
return new DefaultAzureCredential();
}

credentialType = credentialType.ToLower();
switch (credentialType)
{
case DefaultCredentialName:
return new DefaultAzureCredential();
case EnvironmentCredentialName:
return new EnvironmentCredential();
case WorkloadIdentityCredentialName:
Expand All @@ -50,7 +49,7 @@ public static TokenCredential GetCredentials(string credentialType)
case AzureCliCredentialName:
return new AzureCliCredential();
default:
throw new ValidationException($"Invalid credential key: {credentialType}");
throw new ValidationException($"Invalid credential type: {credentialType}");
}
}

Expand All @@ -59,9 +58,9 @@ public static TokenCredential GetCredentials(string credentialType)
/// </summary>
public static TokenCredential GetCredentials(Dictionary<string, string>? pluginConfig)
{
var credentialName = pluginConfig?.GetValueOrDefault(CredentialTypeKey, DefaultCredentialName) ??
DefaultCredentialName;
return GetCredentials(credentialName);
string? credentialType = null;
pluginConfig?.TryGetValue(CredentialTypeKey, out credentialType);
return GetCredentials(credentialType);
}
}
}
}
4 changes: 2 additions & 2 deletions docs/plugin-config.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ notation sign <registry>/<repository>@<digest> \

## credential_type
Set the preferred credential type. Currently, the following credential types are supported:
- [default](https://learn.microsoft.com/dotnet/api/azure.identity.defaultazurecredential?view=azure-dotnet)
- [environment](https://learn.microsoft.com/dotnet/api/azure.identity.environmentcredential?view=azure-dotnet)
- [workloadid](https://learn.microsoft.com/dotnet/api/azure.identity.workloadidentitycredential?view=azure-dotnet)
- [managedid](https://learn.microsoft.com/dotnet/api/azure.identity.managedidentitycredential?view=azure-dotnet)
- [azurecli](https://learn.microsoft.com/dotnet/api/azure.identity.azureclicredential?view=azure-dotnet)

Default: **default** (default credential)
Default: please see the [Default Azure Credential](https://learn.microsoft.com/dotnet/api/azure.identity.defaultazurecredential?view=azure-dotnet) for details on automatically trying a list of credential types.

Example
```
Expand Down
Loading