Skip to content

Commit

Permalink
feat: add versionless key identifier support (#181)
Browse files Browse the repository at this point in the history
Feat:
- added versionless key/certificate identifier support

Test:
- added unit test
- added E2E test

Docs:
- added tips for versionless id feature
- fixed NOTE marks

Resolves #180 #178 
Signed-off-by: Junjie Gao <[email protected]>

---------

Signed-off-by: Junjie Gao <[email protected]>
  • Loading branch information
JeyJeyGao authored May 29, 2024
1 parent 38be1cc commit f82c165
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ public void Constructor_Invalid()
}

[Fact]
public async void RunAsync_NoSecertsGetPermission()
public async Task RunAsync_NoSecertsGetPermission()
{
// Arrange
var keyId = "https://testvault.vault.azure.net/keys/testkey/123";
Expand All @@ -196,7 +196,7 @@ public async void RunAsync_NoSecertsGetPermission()
}

[Fact]
public async void RunAsync_OtherRequestFailedException()
public async Task RunAsync_OtherRequestFailedException()
{
// Arrange
var keyId = "https://testvault.vault.azure.net/keys/testkey/123";
Expand All @@ -223,7 +223,7 @@ public async void RunAsync_OtherRequestFailedException()
}

[Fact]
public async void RunAsync_SelfSignedWithCaCerts()
public async Task RunAsync_SelfSignedWithCaCerts()
{
// Arrange
var keyId = "https://testvault.vault.azure.net/keys/testkey/123";
Expand Down
67 changes: 57 additions & 10 deletions Notation.Plugin.AzureKeyVault.Tests/KeyVault/KeyVaultClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,48 @@ public void TestConstructorWithKeyVaultUrlNameVersion()
Assert.Equal($"{keyVaultUrl}/keys/{name}/{version}", keyVaultClient.KeyId);
}

[Fact]
public void TestConstructorWithVersionlessKey()
{
string keyVaultUrl = "https://myvault.vault.azure.net";
string name = "my-key";

KeyVaultClient keyVaultClient = new KeyVaultClient(keyVaultUrl, name, null, Credentials.GetCredentials(defaultCredentialType));
Assert.Equal(name, keyVaultClient.Name);
Assert.Null(keyVaultClient.Version);
Assert.Equal($"{keyVaultUrl}/keys/{name}", keyVaultClient.KeyId);

keyVaultClient = new KeyVaultClient($"{keyVaultUrl}/keys/{name}", Credentials.GetCredentials(defaultCredentialType));
Assert.Equal(name, keyVaultClient.Name);
Assert.Null(keyVaultClient.Version);
Assert.Equal($"{keyVaultUrl}/keys/{name}", keyVaultClient.KeyId);
}

[Theory]
[InlineData("")]
[InlineData("https://myvault.vault.azure.net/invalid/my-key/123")]
[InlineData("https://myvault.vault.azure.net/keys/my-key")]
[InlineData("https://myvault.vault.azure.net/keys/my-key/")]
[InlineData("http://myvault.vault.azure.net/keys/my-key/123")]
[InlineData("https://myvault.vault.azure.net/keys")]
[InlineData("https://myvault.vault.azure.net/invalid/my-key/123/1234")]
public void TestConstructorWithInvalidKeyId(string invalidKeyId)
{
Assert.Throws<ValidationException>(() => new KeyVaultClient(invalidKeyId, Credentials.GetCredentials(defaultCredentialType)));
}

[Theory]
[InlineData("", "", "")]
[InlineData("https://myvault.vault.azure.net", "", "")]
[InlineData("https://myvault.vault.azure.net", "my-key", "")]
public void TestConstructorWithInvalidArguments(string keyVaultUrl, string name, string? version)
{
Assert.Throws<ValidationException>(() => new KeyVaultClient(keyVaultUrl, name, version, Credentials.GetCredentials(defaultCredentialType)));
}

[Theory]
[InlineData("")]
public void TestConstructorWithEmptyKeyId(string invalidKeyId)
{
Assert.Throws<ArgumentNullException>(() => new KeyVaultClient(invalidKeyId, Credentials.GetCredentials(defaultCredentialType)));
Assert.Throws<ValidationException>(() => new KeyVaultClient(invalidKeyId, Credentials.GetCredentials(defaultCredentialType)));
}

private class TestableKeyVaultClient : KeyVaultClient
Expand All @@ -72,7 +99,7 @@ public TestableKeyVaultClient(string keyVaultUrl, string name, string version, C
this._cryptoClient = new Lazy<CryptographyClient>(() => cryptoClient);
}

public TestableKeyVaultClient(string keyVaultUrl, string name, string version, CertificateClient certificateClient, TokenCredential credenital)
public TestableKeyVaultClient(string keyVaultUrl, string name, string? version, CertificateClient certificateClient, TokenCredential credenital)
: base(keyVaultUrl, name, version, credenital)
{
this._certificateClient = new Lazy<CertificateClient>(() => certificateClient);
Expand Down Expand Up @@ -103,6 +130,15 @@ private TestableKeyVaultClient CreateMockedKeyVaultClient(KeyVaultCertificate ce
return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-certificate", "123", mockCertificateClient.Object, Credentials.GetCredentials(defaultCredentialType));
}

private TestableKeyVaultClient CreateMockedKeyVaultClient(KeyVaultCertificateWithPolicy certWithPolicy)
{
var mockCertificateClient = new Mock<CertificateClient>(new Uri("https://fake.vault.azure.net/certificates/fake-certificate/123"), new Mock<TokenCredential>().Object);
mockCertificateClient.Setup(c => c.GetCertificateAsync(It.IsAny<string>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(Response.FromValue(certWithPolicy, new Mock<Response>().Object));

return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-certificate", null, mockCertificateClient.Object, Credentials.GetCredentials(defaultCredentialType));
}

private TestableKeyVaultClient CreateMockedKeyVaultClient(KeyVaultSecret secret)
{
var mockSecretClient = new Mock<SecretClient>(new Uri("https://fake.vault.azure.net/secrets/fake-secret/123"), new Mock<TokenCredential>().Object);
Expand Down Expand Up @@ -158,11 +194,6 @@ public async Task TestSignAsyncThrowsExceptionOnInvalidAlgorithm()
public async Task GetCertificateAsync_ReturnsCertificate()
{
var testCertificate = new X509Certificate2(Path.Combine(Directory.GetCurrentDirectory(), "TestData", "rsa_2048.crt"));
var signResult = CryptographyModelFactory.SignResult(
keyId: "https://fake.vault.azure.net/keys/fake-key/123",
signature: new byte[] { 1, 2, 3 },
algorithm: SignatureAlgorithm.RS384);

var keyVaultCertificate = CertificateModelFactory.KeyVaultCertificate(
properties: CertificateModelFactory.CertificateProperties(version: "123"),
cer: testCertificate.RawData);
Expand All @@ -176,6 +207,22 @@ public async Task GetCertificateAsync_ReturnsCertificate()
Assert.Equal(testCertificate.RawData, certificate.RawData);
}

[Fact]
public async Task GetVersionlessCertificateAsync_ReturnCertificate()
{
var testCertificate = new X509Certificate2(Path.Combine(Directory.GetCurrentDirectory(), "TestData", "rsa_2048.crt"));
var keyVaultCertificateWithPolicy = CertificateModelFactory.KeyVaultCertificateWithPolicy(
properties: CertificateModelFactory.CertificateProperties(version: "123"),
cer: testCertificate.RawData);

var keyVaultClient = CreateMockedKeyVaultClient(keyVaultCertificateWithPolicy);
var certificate = await keyVaultClient.GetCertificateAsync();

Assert.NotNull(certificate);
Assert.IsType<X509Certificate2>(certificate);
Assert.Equal(testCertificate.RawData, certificate.RawData);
}

[Fact]
public async Task GetCertificateAsyncThrowValidationException()
{
Expand All @@ -191,7 +238,7 @@ public async Task GetCertificateAsyncThrowValidationException()

var keyVaultClient = CreateMockedKeyVaultClient(keyVaultCertificate);

await Assert.ThrowsAsync<ValidationException>(async () => await keyVaultClient.GetCertificateAsync());
await Assert.ThrowsAsync<PluginException>(keyVaultClient.GetCertificateAsync);
}

[Fact]
Expand Down
3 changes: 1 addition & 2 deletions Notation.Plugin.AzureKeyVault/Command/DescribeKey.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ public class DescribeKey : IPluginCommand
{
private DescribeKeyRequest _request;
private IKeyVaultClient _keyVaultClient;
private const string invalidInputError = "Invalid input. The valid input format is '{\"contractVersion\":\"1.0\",\"keyId\":\"https://<vaultname>.vault.azure.net/<keys|certificate>/<name>/<version>\"}'";

/// <summary>
/// Constructor to create DescribeKey object from JSON string.
Expand All @@ -23,7 +22,7 @@ public DescribeKey(string inputJson)
var request = JsonSerializer.Deserialize(inputJson, DescribeKeyRequestContext.Default.DescribeKeyRequest);
if (request == null)
{
throw new ValidationException(invalidInputError);
throw new ValidationException("Failed to parse the request in JSON format. Please contact Notation maintainers to resolve the issue.");
}
this._request = request;
this._keyVaultClient = new KeyVaultClient(
Expand Down
2 changes: 1 addition & 1 deletion Notation.Plugin.AzureKeyVault/Command/GenerateSignature.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public GenerateSignature(string inputJson)
var request = JsonSerializer.Deserialize(inputJson, GenerateSignatureRequestContext.Default.GenerateSignatureRequest);
if (request == null)
{
throw new ValidationException("Invalid input");
throw new ValidationException("Failed to parse the request in JSON format. Please contact Notation maintainers to resolve the issue.");
}
this._request = request;
this._keyVaultClient = new KeyVaultClient(
Expand Down
87 changes: 52 additions & 35 deletions Notation.Plugin.AzureKeyVault/KeyVault/KeyVaultClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class KeyVaultClient : IKeyVaultClient
/// <summary>
/// A helper record to store KeyVault metadata.
/// </summary>
private record KeyVaultMetadata(string KeyVaultUrl, string Name, string Version);
private record KeyVaultMetadata(string KeyVaultUrl, string Name, string? Version);

// Certificate client (lazy initialization)
// Protected for unit test
Expand All @@ -43,50 +43,52 @@ private record KeyVaultMetadata(string KeyVaultUrl, string Name, string Version)
protected Lazy<CryptographyClient> _cryptoClient;
// Secret client (lazy initialization)
protected Lazy<SecretClient> _secretClient;
// Error message for invalid input
private const string INVALID_INPUT_ERROR_MSG = "Invalid input. The valid input format is '{\"contractVersion\":\"1.0\",\"keyId\":\"https://<vaultname>.vault.azure.net/<keys|certificate>/<name>/<version>\"}'";

// Key name or certificate name
private string _name;
// Key version or certificate version
private string _version;
private string? _version;
// Key identifier (e.g. https://<vaultname>.vault.azure.net/keys/<name>/<version>)
private string _keyId;

// Internal getters for unit test
internal string Name => _name;
internal string Version => _version;
internal string? Version => _version;
internal string KeyId => _keyId;

/// <summary>
/// Constructor to create AzureKeyVault object from keyVaultUrl, name
/// and version.
/// </summary>
public KeyVaultClient(string keyVaultUrl, string name, string version, TokenCredential credential)
public KeyVaultClient(string keyVaultUrl, string name, string? version, TokenCredential credential)
{
if (string.IsNullOrEmpty(keyVaultUrl))
{
throw new ArgumentNullException(nameof(keyVaultUrl), "KeyVaultUrl must not be null or empty");
throw new ValidationException("Key vault URL must not be null or empty");
}

if (string.IsNullOrEmpty(name))
{
throw new ArgumentNullException(nameof(name), "KeyName must not be null or empty");
throw new ValidationException("Key name must not be null or empty");
}

if (string.IsNullOrEmpty(version))
if (version != null && version == string.Empty)
{
throw new ArgumentNullException(nameof(version), "KeyVersion must not be null or empty");
throw new ValidationException("Key version must not be empty");
}

this._name = name;
this._version = version;
this._keyId = $"{keyVaultUrl}/keys/{name}/{version}";
_name = name;
_version = version;
_keyId = $"{keyVaultUrl}/keys/{name}";
if (version != null)
{
_keyId = $"{_keyId}/{version}";
}

// initialize credential and lazy clients
this._certificateClient = new Lazy<CertificateClient>(() => new CertificateClient(new Uri(keyVaultUrl), credential));
this._cryptoClient = new Lazy<CryptographyClient>(() => new CryptographyClient(new Uri(_keyId), credential));
this._secretClient = new Lazy<SecretClient>(() => new SecretClient(new Uri(keyVaultUrl), credential));
_certificateClient = new Lazy<CertificateClient>(() => new CertificateClient(new Uri(keyVaultUrl), credential));
_cryptoClient = new Lazy<CryptographyClient>(() => new CryptographyClient(new Uri(_keyId), credential));
_secretClient = new Lazy<SecretClient>(() => new SecretClient(new Uri(keyVaultUrl), credential));
}

/// <summary>
Expand Down Expand Up @@ -115,30 +117,36 @@ private static KeyVaultMetadata ParseId(string id)
{
if (string.IsNullOrEmpty(id))
{
throw new ArgumentNullException(nameof(id), "Id must not be null or empty");
throw new ValidationException("Input passed to \"--id\" must not be empty");
}

var uri = new Uri(id);
var uri = new Uri(id.TrimEnd('/'));
// Validate uri
if (uri.Segments.Length != 4)
if (uri.Segments.Length < 3 || uri.Segments.Length > 4)
{
throw new ValidationException(INVALID_INPUT_ERROR_MSG);
throw new ValidationException("Invalid input passed to \"--id\". Please follow this format to input the ID \"https://<vault-name>.vault.azure.net/certificates/<certificate-name>/[certificate-version]\"");
}

if (uri.Segments[1] != "keys/" && uri.Segments[1] != "certificates/")
var type = uri.Segments[1].TrimEnd('/');
if (type != "keys" && type != "certificates")
{
throw new ValidationException(INVALID_INPUT_ERROR_MSG);
throw new ValidationException($"Unsupported key vualt object type {type}.");
}

if (uri.Scheme != "https")
{
throw new ValidationException(INVALID_INPUT_ERROR_MSG);
throw new ValidationException($"Unsupported scheme {uri.Scheme}. The scheme must be https.");
}

string? version = null;
if (uri.Segments.Length == 4)
{
version = uri.Segments[3].TrimEnd('/');
}
return new KeyVaultMetadata(
KeyVaultUrl: $"{uri.Scheme}://{uri.Host}",
Name: uri.Segments[2].TrimEnd('/'),
Version: uri.Segments[3].TrimEnd('/')
Version: version
);
}

Expand All @@ -148,9 +156,10 @@ private static KeyVaultMetadata ParseId(string id)
public async Task<byte[]> SignAsync(SignatureAlgorithm algorithm, byte[] payload)
{
var signResult = await _cryptoClient.Value.SignDataAsync(algorithm, payload);
if (signResult.KeyId != _keyId)

if (!string.IsNullOrEmpty(_version) && signResult.KeyId != _keyId)
{
throw new PluginException($"Invalid keys identifier. The user provides {_keyId} but the response contains {signResult.KeyId} as the keys");
throw new PluginException($"Invalid key identifier. User required {_keyId} does not match {signResult.KeyId} in response. Please ensure the provided key identifier is correct.");
}

if (signResult.Algorithm != algorithm)
Expand All @@ -166,17 +175,25 @@ public async Task<byte[]> SignAsync(SignatureAlgorithm algorithm, byte[] payload
/// </summary>
public async Task<X509Certificate2> GetCertificateAsync()
{
var cert = await _certificateClient.Value.GetCertificateVersionAsync(_name, _version);

// If the version is invalid, the cert will be fallback to
// the latest. So if the version is not the same as the
// requested version, it means the version is invalid.
if (cert.Value.Properties.Version != _version)
KeyVaultCertificate cert;
if (string.IsNullOrEmpty(_version))
{
throw new ValidationException($"Invalid certificate version. The user provides {_version} but the response contains {cert.Value.Properties.Version} as the version");
// If the version is not specified, get the latest version
cert = (await _certificateClient.Value.GetCertificateAsync(_name)).Value;
}

return new X509Certificate2(cert.Value.Cer);
else
{
cert = (await _certificateClient.Value.GetCertificateVersionAsync(_name, _version)).Value;

// If the version is invalid, the cert will be fallback to
// the latest. So if the version is not the same as the
// requested version, it means the version is invalid.
if (cert.Properties.Version != _version)
{
throw new PluginException($"The version specified in the request is {_version} but the version retrieved from Azure Key Vault is {cert.Properties.Version}. Please ensure the version is correct.");
}
}
return new X509Certificate2(cert.Cer);
}

/// <summary>
Expand Down
Loading

0 comments on commit f82c165

Please sign in to comment.