Skip to content

Commit

Permalink
fixup to include tests of JWK provider
Browse files Browse the repository at this point in the history
  • Loading branch information
pquerna committed Jan 30, 2025
1 parent e7c6ca8 commit 306f215
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pkg/crypto/crypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,6 @@ func TestEncryptionProviderJWKSymmetric(t *testing.T) {
Bytes: []byte("hunter2"),
}
cipherText, err := provider.Encrypt(ctx, config, plainText)
require.ErrorIs(t, err, jwk.JWKUnsupportedKeyTypeError)
require.ErrorIs(t, err, jwk.ErrJWKUnsupportedKeyType)
require.Nil(t, cipherText)
}
17 changes: 16 additions & 1 deletion pkg/crypto/providers/jwk/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,22 @@ func EncryptECDSA(pubKey *ecdsa.PublicKey, plaintext []byte) ([]byte, error) {
}

func DecryptECDSA(privKey *ecdsa.PrivateKey, ciphertext []byte) ([]byte, error) {
jwe, err := jose.ParseEncrypted(string(ciphertext))
jwe, err := jose.ParseEncryptedCompact(string(ciphertext),
[]jose.KeyAlgorithm{
jose.ECDH_ES,
jose.ECDH_ES_A128KW,
jose.ECDH_ES_A192KW,
jose.ECDH_ES_A256KW,
},
[]jose.ContentEncryption{
jose.A128CBC_HS256,
jose.A192CBC_HS384,
jose.A256CBC_HS512,
jose.A128GCM,
jose.A192GCM,
jose.A256GCM,
},
)
if err != nil {
return nil, fmt.Errorf("jwk-ecdsa: failed to parse ciphertext: %w", err)
}
Expand Down
17 changes: 11 additions & 6 deletions pkg/crypto/providers/jwk/jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/rsa"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"

Expand All @@ -21,8 +22,8 @@ import (
// TODO(morgabra): Fix the circular dependency/entire registry pattern here.
const EncryptionProviderJwk = "baton/jwk/v1"

var JWKInvalidKeyTypeError = fmt.Errorf("jwk: invalid key type")
var JWKUnsupportedKeyTypeError = fmt.Errorf("jwk: unsupported key type")
var ErrJWKInvalidKeyType = errors.New("jwk: invalid key type")
var ErrJWKUnsupportedKeyType = errors.New("jwk: unsupported key type")

func unmarshalJWK(jwkBytes []byte) (*jose.JSONWebKey, error) {
jwk := &jose.JSONWebKey{}
Expand All @@ -44,6 +45,10 @@ func (j *JWKEncryptionProvider) GenerateKey(ctx context.Context) (*v2.Encryption
privKeyJWK := &jose.JSONWebKey{
Key: privKey,
}
return j.marshalKey(ctx, privKeyJWK)
}

func (j *JWKEncryptionProvider) marshalKey(ctx context.Context, privKeyJWK *jose.JSONWebKey) (*v2.EncryptionConfig, []byte, error) {
privKeyJWKBytes, err := privKeyJWK.MarshalJSON()
if err != nil {
return nil, nil, fmt.Errorf("jwk: failed to marshal private key: %w", err)
Expand All @@ -62,7 +67,7 @@ func (j *JWKEncryptionProvider) GenerateKey(ctx context.Context) (*v2.Encryption

return &v2.EncryptionConfig{
Principal: nil,
Provider: "baton/jwk/v1", // TODO(morgabra): Fix the circular dependency/entire registry pattern.
Provider: EncryptionProviderJwk, // TODO(morgabra): Fix the circular dependency/entire registry pattern.
KeyId: kid,
Config: &v2.EncryptionConfig_JwkPublicKeyConfig{
JwkPublicKeyConfig: &v2.EncryptionConfig_JWKPublicKeyConfig{
Expand Down Expand Up @@ -96,7 +101,7 @@ func (j *JWKEncryptionProvider) Encrypt(ctx context.Context, conf *v2.Encryption
return nil, err
}
default:
return nil, JWKUnsupportedKeyTypeError
return nil, ErrJWKUnsupportedKeyType
}

tp, err := thumbprint(jwk)
Expand Down Expand Up @@ -124,7 +129,7 @@ func (j *JWKEncryptionProvider) Decrypt(ctx context.Context, cipherText *v2.Encr
}

if jwk.IsPublic() {
return nil, fmt.Errorf("%w: key is public", JWKInvalidKeyTypeError)
return nil, fmt.Errorf("%w: key is public", ErrJWKInvalidKeyType)
}

decCipherText, err := base64.StdEncoding.DecodeString(string(cipherText.EncryptedBytes))
Expand All @@ -150,7 +155,7 @@ func (j *JWKEncryptionProvider) Decrypt(ctx context.Context, cipherText *v2.Encr
return nil, err
}
default:
return nil, JWKUnsupportedKeyTypeError
return nil, ErrJWKUnsupportedKeyType
}

return &v2.PlaintextData{
Expand Down
146 changes: 146 additions & 0 deletions pkg/crypto/providers/jwk/jwk_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package jwk

import (
"context"
"testing"

"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"

v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2"
"github.com/go-jose/go-jose/v4"
"github.com/stretchr/testify/require"
)

func TestGenerateKey(t *testing.T) {
provider := &JWKEncryptionProvider{}
ctx := context.Background()

config, privKey, err := provider.GenerateKey(ctx)
require.NoError(t, err)
require.NotNil(t, config)
require.NotNil(t, privKey)

jwk, err := unmarshalJWK(privKey)
require.NoError(t, err)
require.False(t, jwk.IsPublic())
}

func TestEncryptDecrypt(t *testing.T) {
provider := &JWKEncryptionProvider{}
ctx := context.Background()

config, privKey, err := provider.GenerateKey(ctx)
require.NoError(t, err)
require.NotNil(t, config)
require.NotNil(t, privKey)

plainText := &v2.PlaintextData{
Name: "test",
Description: "test description",
Schema: "test schema",
Bytes: []byte("test data"),
}

encryptedData, err := provider.Encrypt(ctx, config, plainText)
require.NoError(t, err)
require.NotNil(t, encryptedData)

decryptedData, err := provider.Decrypt(ctx, encryptedData, privKey)
require.NoError(t, err)
require.NotNil(t, decryptedData)
require.Equal(t, plainText.Bytes, decryptedData.Bytes)
}

func TestInvalidKeyType(t *testing.T) {
provider := &JWKEncryptionProvider{}
ctx := context.Background()

_, privKey, err := provider.GenerateKey(ctx)
require.NoError(t, err)

jwk, err := unmarshalJWK(privKey)
require.NoError(t, err)

jwk.Key = "invalid key type"
privKeyBytes, err := jwk.MarshalJSON()
require.NoError(t, err)

plainText := &v2.PlaintextData{
Name: "test",
Description: "test description",
Schema: "test schema",
Bytes: []byte("test data"),
}

_, err = provider.Encrypt(ctx, &v2.EncryptionConfig{
Config: &v2.EncryptionConfig_JwkPublicKeyConfig{
JwkPublicKeyConfig: &v2.EncryptionConfig_JWKPublicKeyConfig{
PubKey: privKeyBytes,
},
},
}, plainText)
require.Error(t, err)
require.Equal(t, ErrJWKUnsupportedKeyType, err)
}

func TestEncryptDecryptECDSAKey(t *testing.T) {
provider := &JWKEncryptionProvider{}
ctx := context.Background()

privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
require.NotNil(t, privKey)

jwk := jose.JSONWebKey{Key: privKey, Algorithm: string(jose.ES256)}
pubConfig, privKeyBytes, err := provider.marshalKey(context.TODO(), &jwk)
require.NoError(t, err)
require.NotNil(t, privKeyBytes)

plainText := &v2.PlaintextData{
Name: "test",
Description: "test description",
Schema: "test schema",
Bytes: []byte("test data"),
}
encryptedData, err := provider.Encrypt(ctx, pubConfig, plainText)
require.NoError(t, err)
require.NotNil(t, encryptedData)

decryptedData, err := provider.Decrypt(ctx, encryptedData, privKeyBytes)
require.NoError(t, err)
require.NotNil(t, decryptedData)
require.Equal(t, plainText.Bytes, decryptedData.Bytes)
}

func TestEncryptDecryptRSA1024Key(t *testing.T) {
provider := &JWKEncryptionProvider{}
ctx := context.Background()

privKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
require.NotNil(t, privKey)

jwk := jose.JSONWebKey{Key: privKey, Use: "sig", Algorithm: string(jose.RS256)}
pubConfig, privKeyBytes, err := provider.marshalKey(context.TODO(), &jwk)
require.NoError(t, err)
require.NotNil(t, privKeyBytes)

plainText := &v2.PlaintextData{
Name: "test",
Description: "test description",
Schema: "test schema",
Bytes: []byte("test data"),
}

encryptedData, err := provider.Encrypt(ctx, pubConfig, plainText)
require.NoError(t, err)
require.NotNil(t, encryptedData)

decryptedData, err := provider.Decrypt(ctx, encryptedData, privKeyBytes)
require.NoError(t, err)
require.NotNil(t, decryptedData)
require.Equal(t, plainText.Bytes, decryptedData.Bytes)
}

0 comments on commit 306f215

Please sign in to comment.