From be77cbd962c39b5cba33d98adea961231462627b Mon Sep 17 00:00:00 2001 From: Steve Clark <steven.clark@hashicorp.com> Date: Mon, 2 May 2022 11:48:34 -0400 Subject: [PATCH 1/3] WIP: Address codebase for managed key fixes --- builtin/logical/pki/backend.go | 2 + builtin/logical/pki/ca_util.go | 131 +++++++++++---------- builtin/logical/pki/crl_test.go | 3 +- builtin/logical/pki/fields.go | 1 + builtin/logical/pki/key_util.go | 116 ++++++++++++++++++ builtin/logical/pki/managed_key_util.go | 16 ++- builtin/logical/pki/path_intermediate.go | 2 +- builtin/logical/pki/path_manage_issuers.go | 5 +- builtin/logical/pki/path_manage_keys.go | 98 ++++++++------- builtin/logical/pki/path_root.go | 2 +- builtin/logical/pki/storage.go | 91 +++++++------- builtin/logical/pki/storage_migrations.go | 3 +- builtin/logical/pki/storage_test.go | 17 +-- sdk/helper/certutil/helpers.go | 19 +-- 14 files changed, 330 insertions(+), 176 deletions(-) create mode 100644 builtin/logical/pki/key_util.go diff --git a/builtin/logical/pki/backend.go b/builtin/logical/pki/backend.go index 3629dab1572e..e168bb8ea94d 100644 --- a/builtin/logical/pki/backend.go +++ b/builtin/logical/pki/backend.go @@ -165,6 +165,7 @@ func Backend(conf *logical.BackendConfig) *backend { b.tidyCASGuard = new(uint32) b.tidyStatus = &tidyStatus{state: tidyStatusInactive} b.storage = conf.StorageView + b.backendUuid = conf.BackendUUID b.pkiStorageVersion.Store(0) @@ -175,6 +176,7 @@ func Backend(conf *logical.BackendConfig) *backend { type backend struct { *framework.Backend + backendUuid string storage logical.Storage crlLifetime time.Duration revokeStorageLock sync.RWMutex diff --git a/builtin/logical/pki/ca_util.go b/builtin/logical/pki/ca_util.go index 5e81c7fe6031..680139a56c13 100644 --- a/builtin/logical/pki/ca_util.go +++ b/builtin/logical/pki/ca_util.go @@ -5,7 +5,6 @@ import ( "crypto" "crypto/ecdsa" "crypto/rsa" - "encoding/pem" "errors" "fmt" "io" @@ -38,8 +37,8 @@ func (b *backend) getGenerationParams(ctx context.Context, storage logical.Stora `the "format" path parameter must be "pem", "der", or "pem_bundle"`) return } - - keyType, keyBits, err := getKeyTypeAndBitsForRole(ctx, b, storage, data, mountPoint) + mkc := newManagedKeyContext(ctx, b, mountPoint) + keyType, keyBits, err := getKeyTypeAndBitsForRole(mkc, storage, data) if err != nil { errorResp = logical.ErrorResponse(err.Error()) return @@ -77,7 +76,11 @@ func (b *backend) getGenerationParams(ctx context.Context, storage logical.Stora func generateCABundle(ctx context.Context, b *backend, input *inputBundle, data *certutil.CreationBundle, randomSource io.Reader) (*certutil.ParsedCertBundle, error) { if kmsRequested(input) { - return generateManagedKeyCABundle(ctx, b, input, data, randomSource) + keyId, err := getManagedKeyId(input.apiData) + if err != nil { + return nil, err + } + return generateManagedKeyCABundle(ctx, b, input, keyId, data, randomSource) } if existingKeyRequested(input) { @@ -85,7 +88,21 @@ func generateCABundle(ctx context.Context, b *backend, input *inputBundle, data if err != nil { return nil, err } - return certutil.CreateCertificateWithKeyGenerator(data, randomSource, existingGeneratePrivateKey(ctx, input.req.Storage, keyRef)) + + keyEntry, err := getExistingKeyFromRef(ctx, input.req.Storage, keyRef) + if err != nil { + return nil, err + } + + if keyEntry.isManagedPrivateKey() { + keyId, err := keyEntry.getManagedKeyUUID() + if err != nil { + return nil, err + } + return generateManagedKeyCABundle(ctx, b, input, keyId, data, randomSource) + } + + return certutil.CreateCertificateWithKeyGenerator(data, randomSource, existingKeyGeneratorFromBytes(keyEntry)) } return certutil.CreateCertificateWithRandomSource(data, randomSource) @@ -93,7 +110,12 @@ func generateCABundle(ctx context.Context, b *backend, input *inputBundle, data func generateCSRBundle(ctx context.Context, b *backend, input *inputBundle, data *certutil.CreationBundle, addBasicConstraints bool, randomSource io.Reader) (*certutil.ParsedCSRBundle, error) { if kmsRequested(input) { - return generateManagedKeyCSRBundle(ctx, b, input, data, addBasicConstraints, randomSource) + keyId, err := getManagedKeyId(input.apiData) + if err != nil { + return nil, err + } + + return generateManagedKeyCSRBundle(ctx, b, input, keyId, data, addBasicConstraints, randomSource) } if existingKeyRequested(input) { @@ -101,7 +123,21 @@ func generateCSRBundle(ctx context.Context, b *backend, input *inputBundle, data if err != nil { return nil, err } - return certutil.CreateCSRWithKeyGenerator(data, addBasicConstraints, randomSource, existingGeneratePrivateKey(ctx, input.req.Storage, keyRef)) + + key, err := getExistingKeyFromRef(ctx, input.req.Storage, keyRef) + if err != nil { + return nil, err + } + + if key.isManagedPrivateKey() { + keyId, err := key.getManagedKeyUUID() + if err != nil { + return nil, err + } + return generateManagedKeyCSRBundle(ctx, b, input, keyId, data, addBasicConstraints, randomSource) + } + + return certutil.CreateCSRWithKeyGenerator(data, addBasicConstraints, randomSource, existingKeyGeneratorFromBytes(key)) } return certutil.CreateCSRWithRandomSource(data, addBasicConstraints, randomSource) @@ -114,7 +150,7 @@ func parseCABundle(ctx context.Context, b *backend, req *logical.Request, bundle return bundle.ToParsedCertBundle() } -func getKeyTypeAndBitsForRole(ctx context.Context, b *backend, storage logical.Storage, data *framework.FieldData, mountPoint string) (string, int, error) { +func getKeyTypeAndBitsForRole(mkc managedKeyContext, storage logical.Storage, data *framework.FieldData) (string, int, error) { exportedStr := data.Get("exported").(string) var keyType string var keyBits int @@ -138,7 +174,12 @@ func getKeyTypeAndBitsForRole(ctx context.Context, b *backend, storage logical.S var pubKey crypto.PublicKey if kmsRequestedFromFieldData(data) { - pubKeyManagedKey, err := getManagedKeyPublicKey(ctx, b, data, mountPoint) + keyId, err := getManagedKeyId(data) + if err != nil { + return "", 0, errors.New("unable to determine managed key id" + err.Error()) + } + + pubKeyManagedKey, err := getManagedKeyPublicKey(mkc, keyId) if err != nil { return "", 0, errors.New("failed to lookup public key from managed key: " + err.Error()) } @@ -146,95 +187,67 @@ func getKeyTypeAndBitsForRole(ctx context.Context, b *backend, storage logical.S } if existingKeyRequestedFromFieldData(data) { - existingPubKey, err := getExistingPublicKey(ctx, storage, data) + existingPubKey, err := getExistingPublicKey(mkc, storage, data) if err != nil { return "", 0, errors.New("failed to lookup public key from existing key: " + err.Error()) } pubKey = existingPubKey } - return getKeyTypeAndBitsFromPublicKeyForRole(pubKey) + privateKeyType, keyBits, err := getKeyTypeAndBitsFromPublicKeyForRole(pubKey) + return string(privateKeyType), keyBits, err } -func getExistingPublicKey(ctx context.Context, s logical.Storage, data *framework.FieldData) (crypto.PublicKey, error) { +func getExistingPublicKey(mkc managedKeyContext, s logical.Storage, data *framework.FieldData) (crypto.PublicKey, error) { keyRef, err := getKeyRefWithErr(data) if err != nil { return nil, err } - id, err := resolveKeyReference(ctx, s, keyRef) - if err != nil { - return nil, err - } - key, err := fetchKeyById(ctx, s, id) + id, err := resolveKeyReference(mkc.ctx, s, keyRef) if err != nil { return nil, err } - signer, err := key.GetSigner() + key, err := fetchKeyById(mkc.ctx, s, id) if err != nil { return nil, err } - return signer.Public(), nil + return getPublicKey(mkc, key) } -func getKeyTypeAndBitsFromPublicKeyForRole(pubKey crypto.PublicKey) (string, int, error) { - var keyType string +func getKeyTypeAndBitsFromPublicKeyForRole(pubKey crypto.PublicKey) (certutil.PrivateKeyType, int, error) { + var keyType certutil.PrivateKeyType var keyBits int switch pubKey.(type) { case *rsa.PublicKey: - keyType = "rsa" + keyType = certutil.RSAPrivateKey keyBits = certutil.GetPublicKeySize(pubKey) case *ecdsa.PublicKey: - keyType = "ec" + keyType = certutil.ECPrivateKey case *ed25519.PublicKey: - keyType = "ed25519" + keyType = certutil.Ed25519PrivateKey default: - return "", 0, fmt.Errorf("unsupported public key: %#v", pubKey) + return certutil.UnknownPrivateKey, 0, fmt.Errorf("unsupported public key: %#v", pubKey) } return keyType, keyBits, nil } -func getManagedKeyPublicKey(ctx context.Context, b *backend, data *framework.FieldData, mountPoint string) (crypto.PublicKey, error) { - keyId, err := getManagedKeyId(data) +func getExistingKeyFromRef(ctx context.Context, s logical.Storage, keyRef string) (*keyEntry, error) { + keyId, err := resolveKeyReference(ctx, s, keyRef) if err != nil { - return nil, errors.New("unable to determine managed key id") - } - // Determine key type and key bits from the managed public key - var pubKey crypto.PublicKey - err = withManagedPKIKey(ctx, b, keyId, mountPoint, func(ctx context.Context, key logical.ManagedSigningKey) error { - pubKey, err = key.GetPublicKey(ctx) - if err != nil { - return err - } - - return nil - }) - if err != nil { - return nil, errors.New("failed to lookup public key from managed key: " + err.Error()) + return nil, err } - return pubKey, nil + return fetchKeyById(ctx, s, keyId) } -func existingGeneratePrivateKey(ctx context.Context, s logical.Storage, keyRef string) certutil.KeyGenerator { - return func(keyType string, keyBits int, container certutil.ParsedPrivateKeyContainer, _ io.Reader) error { - keyId, err := resolveKeyReference(ctx, s, keyRef) - if err != nil { - return err - } - key, err := fetchKeyById(ctx, s, keyId) +func existingKeyGeneratorFromBytes(key *keyEntry) certutil.KeyGenerator { + return func(_ string, _ int, container certutil.ParsedPrivateKeyContainer, _ io.Reader) error { + signer, _, pemBytes, err := getSignerFromKeyEntryBytes(key) if err != nil { return err } - signer, err := key.GetSigner() - if err != nil { - return err - } - privateKeyType := certutil.GetPrivateKeyTypeFromSigner(signer) - if privateKeyType == certutil.UnknownPrivateKey { - return errors.New("unknown private key type loaded from key id: " + keyId.String()) - } - blk, _ := pem.Decode([]byte(key.PrivateKey)) - container.SetParsedPrivateKey(signer, privateKeyType, blk.Bytes) + + container.SetParsedPrivateKey(signer, key.PrivateKeyType, pemBytes.Bytes) return nil } } diff --git a/builtin/logical/pki/crl_test.go b/builtin/logical/pki/crl_test.go index 368a412ca8a6..20ed3403e15d 100644 --- a/builtin/logical/pki/crl_test.go +++ b/builtin/logical/pki/crl_test.go @@ -131,10 +131,11 @@ func TestBackend_CRL_EnableDisable(t *testing.T) { func TestBackend_Secondary_CRL_Rebuilding(t *testing.T) { ctx := context.Background() b, s := createBackendWithStorage(t) + mkc := newManagedKeyContext(ctx, b, "test") // Write out the issuer/key to storage without going through the api call as replication would. bundle := genCertBundle(t, b, s) - issuer, _, err := writeCaBundle(ctx, s, bundle, "", "") + issuer, _, err := writeCaBundle(mkc, s, bundle, "", "") require.NoError(t, err) // Just to validate, before we call the invalidate function, make sure our CRL has not been generated diff --git a/builtin/logical/pki/fields.go b/builtin/logical/pki/fields.go index 45c329bd8acf..b53ddea36664 100644 --- a/builtin/logical/pki/fields.go +++ b/builtin/logical/pki/fields.go @@ -8,6 +8,7 @@ const ( keyRefParam = "key_ref" keyIdParam = "key_id" keyTypeParam = "key_type" + keyBitsParam = "key_bits" ) // addIssueAndSignCommonFields adds fields common to both CA and non-CA issuing diff --git a/builtin/logical/pki/key_util.go b/builtin/logical/pki/key_util.go new file mode 100644 index 000000000000..7c8ba83f6175 --- /dev/null +++ b/builtin/logical/pki/key_util.go @@ -0,0 +1,116 @@ +package pki + +import ( + "context" + "crypto" + "encoding/pem" + "errors" + "fmt" + "github.com/hashicorp/vault/sdk/helper/certutil" + "github.com/hashicorp/vault/sdk/helper/errutil" + "github.com/hashicorp/vault/sdk/logical" +) + +type managedKeyContext struct { + ctx context.Context + b *backend + mountPoint string +} + +func newManagedKeyContext(ctx context.Context, b *backend, mountPoint string) managedKeyContext { + return managedKeyContext{ + ctx: ctx, + b: b, + mountPoint: mountPoint, + } +} + +func comparePublicKey(ctx managedKeyContext, key *keyEntry, publicKey crypto.PublicKey) (bool, error) { + publicKeyForKeyEntry, err := getPublicKey(ctx, key) + if err != nil { + return false, err + } + + return certutil.ComparePublicKeysAndType(publicKeyForKeyEntry, publicKey) +} + +func getPublicKey(mkc managedKeyContext, key *keyEntry) (crypto.PublicKey, error) { + if key.PrivateKeyType == certutil.ManagedPrivateKey { + keyId, err := extractManagedKeyId([]byte(key.PrivateKey)) + if err != nil { + return nil, err + } + return getManagedKeyPublicKey(mkc, keyId) + } + + signer, _, _, err := getSignerFromKeyEntryBytes(key) + if err != nil { + return nil, err + } + return signer.Public(), nil +} + +func getSignerFromKeyEntryBytes(key *keyEntry) (crypto.Signer, certutil.BlockType, *pem.Block, error) { + if key.PrivateKeyType == certutil.UnknownPrivateKey { + return nil, certutil.UnknownBlock, nil, errutil.InternalError{Err: fmt.Sprintf("unsupported unknown private key type for key: %s (%s)", key.ID, key.Name)} + } + + if key.PrivateKeyType == certutil.ManagedPrivateKey { + return nil, certutil.UnknownBlock, nil, errutil.InternalError{Err: fmt.Sprintf("can not get a signer from a managed key: %s (%s)", key.ID, key.Name)} + } + + bytes, blockType, blk, err := getSignerFromBytes([]byte(key.PrivateKey)) + if err != nil { + return nil, certutil.UnknownBlock, nil, errutil.InternalError{Err: fmt.Sprintf("failed parsing key entry bytes for key id: %s (%s): %s", key.ID, key.Name, err.Error())} + } + + return bytes, blockType, blk, nil +} + +func getSignerFromBytes(keyBytes []byte) (crypto.Signer, certutil.BlockType, *pem.Block, error) { + pemBlock, _ := pem.Decode(keyBytes) + if pemBlock == nil { + return nil, certutil.UnknownBlock, pemBlock, errutil.InternalError{Err: "no data found in PEM block"} + } + + signer, blk, err := certutil.ParseDERKey(pemBlock.Bytes) + if err != nil { + return nil, certutil.UnknownBlock, pemBlock, errutil.InternalError{Err: fmt.Sprintf("failed to parse PEM block: %s", err.Error())} + } + return signer, blk, pemBlock, nil +} + +func getManagedKeyPublicKey(mkc managedKeyContext, keyId managedKeyId) (crypto.PublicKey, error) { + // Determine key type and key bits from the managed public key + var pubKey crypto.PublicKey + err := withManagedPKIKey(mkc.ctx, mkc.b, keyId, mkc.mountPoint, func(ctx context.Context, key logical.ManagedSigningKey) error { + var myErr error + pubKey, myErr = key.GetPublicKey(ctx) + if myErr != nil { + return myErr + } + + return nil + }) + if err != nil { + return nil, errors.New("failed to lookup public key from managed key: " + err.Error()) + } + return pubKey, nil +} + +func importKeyFromBytes(mkc managedKeyContext, s logical.Storage, keyValue string, keyName string) (*keyEntry, bool, error) { + signer, _, _, err := getSignerFromBytes([]byte(keyValue)) + if err != nil { + return nil, false, err + } + privateKeyType := certutil.GetPrivateKeyTypeFromSigner(signer) + if privateKeyType == certutil.UnknownPrivateKey { + return nil, false, errors.New("unsupported private key type within pem bundle") + } + + key, existed, err := importKey(mkc, s, keyValue, keyName, privateKeyType) + if err != nil { + return nil, false, err + } + return key, existed, nil +} diff --git a/builtin/logical/pki/managed_key_util.go b/builtin/logical/pki/managed_key_util.go index 45d80d643dcd..a69ac056bfc4 100644 --- a/builtin/logical/pki/managed_key_util.go +++ b/builtin/logical/pki/managed_key_util.go @@ -13,18 +13,26 @@ import ( var errEntOnly = errors.New("managed keys are supported within enterprise edition only") -func generateManagedKeyCABundle(_ context.Context, _ *backend, _ *inputBundle, _ *certutil.CreationBundle, _ io.Reader) (*certutil.ParsedCertBundle, error) { +func generateManagedKeyCABundle(ctx context.Context, b *backend, input *inputBundle, keyId managedKeyId, data *certutil.CreationBundle, randomSource io.Reader) (bundle *certutil.ParsedCertBundle, err error) { return nil, errEntOnly } -func generateManagedKeyCSRBundle(_ context.Context, _ *backend, _ *inputBundle, _ *certutil.CreationBundle, _ bool, _ io.Reader) (*certutil.ParsedCSRBundle, error) { +func generateManagedKeyCSRBundle(ctx context.Context, b *backend, input *inputBundle, keyId managedKeyId, data *certutil.CreationBundle, addBasicConstraints bool, randomSource io.Reader) (bundle *certutil.ParsedCSRBundle, err error) { return nil, errEntOnly } -func parseManagedKeyCABundle(_ context.Context, _ *backend, _ *logical.Request, _ *certutil.CertBundle) (*certutil.ParsedCertBundle, error) { +func parseManagedKeyCABundle(ctx context.Context, b *backend, req *logical.Request, bundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) { return nil, errEntOnly } -func withManagedPKIKey(_ context.Context, _ *backend, _ managedKeyId, _ string, _ logical.ManagedSigningKeyConsumer) error { +func withManagedPKIKey(ctx context.Context, b *backend, keyId managedKeyId, mountPoint string, f logical.ManagedSigningKeyConsumer) error { return errEntOnly } + +func extractManagedKeyId(privateKeyBytes []byte) (UUIDKey, error) { + return "", errEntOnly +} + +func createKmsKeyBundle(mkc managedKeyContext, keyId managedKeyId) (certutil.KeyBundle, certutil.PrivateKeyType, error) { + return certutil.KeyBundle{}, certutil.UnknownPrivateKey, errEntOnly +} diff --git a/builtin/logical/pki/path_intermediate.go b/builtin/logical/pki/path_intermediate.go index a0b69e1a481c..bf0b416bf4a4 100644 --- a/builtin/logical/pki/path_intermediate.go +++ b/builtin/logical/pki/path_intermediate.go @@ -130,7 +130,7 @@ func (b *backend) pathGenerateIntermediate(ctx context.Context, req *logical.Req } } - myKey, _, err := importKey(ctx, req.Storage, csrb.PrivateKey, keyName) + myKey, _, err := importKey(newManagedKeyContext(ctx, b, req.MountPoint), req.Storage, csrb.PrivateKey, keyName, csrb.PrivateKeyType) if err != nil { return nil, err } diff --git a/builtin/logical/pki/path_manage_issuers.go b/builtin/logical/pki/path_manage_issuers.go index 0138755c8196..ea7abcfa08c1 100644 --- a/builtin/logical/pki/path_manage_issuers.go +++ b/builtin/logical/pki/path_manage_issuers.go @@ -181,9 +181,10 @@ func (b *backend) pathImportIssuers(ctx context.Context, req *logical.Request, d return logical.ErrorResponse("private keys found in the PEM bundle but not allowed by the path; use /issuers/import/bundle"), nil } + mkc := newManagedKeyContext(ctx, b, req.MountPoint) for keyIndex, keyPem := range keys { // Handle import of private key. - key, existing, err := importKey(ctx, req.Storage, keyPem, "") + key, existing, err := importKeyFromBytes(mkc, req.Storage, keyPem, "") if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error parsing key %v: %v", keyIndex, err)), nil } @@ -194,7 +195,7 @@ func (b *backend) pathImportIssuers(ctx context.Context, req *logical.Request, d } for certIndex, certPem := range issuers { - cert, existing, err := importIssuer(ctx, req.Storage, certPem, "") + cert, existing, err := importIssuer(mkc, req.Storage, certPem, "") if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error parsing issuer %v: %v\n%v", certIndex, err, certPem)), nil } diff --git a/builtin/logical/pki/path_manage_keys.go b/builtin/logical/pki/path_manage_keys.go index 04195aae2ed5..cce0fde52bd1 100644 --- a/builtin/logical/pki/path_manage_keys.go +++ b/builtin/logical/pki/path_manage_keys.go @@ -11,7 +11,7 @@ import ( func pathGenerateKey(b *backend) *framework.Path { return &framework.Path{ - Pattern: "keys/generate/(internal|exported)", + Pattern: "keys/generate/(internal|exported|kms)", Fields: map[string]*framework.FieldSchema{ keyNameParam: { @@ -23,15 +23,27 @@ func pathGenerateKey(b *backend) *framework.Path { Default: "rsa", Description: `Type of the secret key to generate`, }, - "key_bits": { + keyBitsParam: { Type: framework.TypeInt, Default: 2048, Description: `Type of the secret key to generate`, }, + "managed_key_name": { + Type: framework.TypeString, + Description: `The name of the managed key to use when the exported +type is kms. When kms type is the key type, this field or managed_key_id +is required. Ignored for other types.`, + }, + "managed_key_id": { + Type: framework.TypeString, + Description: `The name of the managed key to use when the exported +type is kms. When kms type is the key type, this field or managed_key_name +is required. Ignored for other types.`, + }, }, Operations: map[logical.Operation]framework.OperationHandler{ - logical.CreateOperation: &framework.PathOperation{ + logical.UpdateOperation: &framework.PathOperation{ Callback: b.pathGenerateKeyHandler, ForwardPerformanceStandby: true, ForwardPerformanceSecondary: true, @@ -58,60 +70,60 @@ func (b *backend) pathGenerateKeyHandler(ctx context.Context, req *logical.Reque if err != nil { // Fail Immediately if Key Name is in Use, etc... return nil, err } - keyType := data.Get(keyTypeParam).(string) - keyBits := data.Get("key_bits").(int) + mkc := newManagedKeyContext(ctx, b, req.MountPoint) + exportPrivateKey := false + var keyBundle certutil.KeyBundle + var actualPrivateKeyType certutil.PrivateKeyType switch { + case strings.HasSuffix(req.Path, "/exported"): + exportPrivateKey = true + fallthrough case strings.HasSuffix(req.Path, "/internal"): + keyType := data.Get(keyTypeParam).(string) + keyBits := data.Get(keyBitsParam).(int) + // Internal key generation, stored in storage - keyBundle, err := certutil.GetKeyBundleFromKeyGenerator(keyType, keyBits, nil) - if err != nil { - return nil, err - } - privateKeyPemString, err := keyBundle.ToPrivateKeyPemString() - if err != nil { - return nil, err - } - key, _, err := importKey(ctx, req.Storage, privateKeyPemString, keyName) - if err != nil { - return nil, err - } - resp := logical.Response{ - Data: map[string]interface{}{ - keyIdParam: key.ID, - keyNameParam: key.Name, - keyTypeParam: key.PrivateKeyType, - }, - } - return &resp, nil - case strings.HasSuffix(req.Path, "/exported"): - // Same as internal key generation but we return the generated key - keyBundle, err := certutil.GetKeyBundleFromKeyGenerator(keyType, keyBits, nil) + keyBundle, err = certutil.CreateKeyBundle(keyType, keyBits, b.GetRandomReader()) if err != nil { return nil, err } - privateKeyPemString, err := keyBundle.ToPrivateKeyPemString() + + actualPrivateKeyType = keyBundle.PrivateKeyType + case strings.HasSuffix(req.Path, "/kms"): + keyId, err := getManagedKeyId(data) if err != nil { return nil, err } - key, _, err := importKey(ctx, req.Storage, privateKeyPemString, keyName) + + keyBundle, actualPrivateKeyType, err = createKmsKeyBundle(mkc, keyId) if err != nil { return nil, err } - resp := logical.Response{ - Data: map[string]interface{}{ - keyIdParam: key.ID, - keyNameParam: key.Name, - keyTypeParam: key.PrivateKeyType, - "private_key": privateKeyPemString, - }, - } - return &resp, nil - case strings.HasSuffix(req.Path, "/kms"): - return nil, errEntOnly default: return logical.ErrorResponse("Unknown type of key to generate"), nil } + + privateKeyPemString, err := keyBundle.ToPrivateKeyPemString() + if err != nil { + return nil, err + } + + key, _, err := importKey(mkc, req.Storage, privateKeyPemString, keyName, keyBundle.PrivateKeyType) + if err != nil { + return nil, err + } + responseData := map[string]interface{}{ + keyIdParam: key.ID, + keyNameParam: key.Name, + keyTypeParam: string(actualPrivateKeyType), + } + if exportPrivateKey { + responseData["private_key"] = privateKeyPemString + } + return &logical.Response{ + Data: responseData, + }, nil } func pathImportKey(b *backend) *framework.Path { @@ -161,7 +173,8 @@ func (b *backend) pathImportKeyHandler(ctx context.Context, req *logical.Request keyValue := keyValueInterface.(string) keyName := data.Get(keyNameParam).(string) - key, existed, err := importKey(ctx, req.Storage, keyValue, keyName) + mkc := newManagedKeyContext(ctx, b, req.MountPoint) + key, existed, err := importKeyFromBytes(mkc, req.Storage, keyValue, keyName) if err != nil { return logical.ErrorResponse(err.Error()), nil } @@ -171,7 +184,6 @@ func (b *backend) pathImportKeyHandler(ctx context.Context, req *logical.Request keyIdParam: key.ID, keyNameParam: key.Name, keyTypeParam: key.PrivateKeyType, - "backing": "", // This would show up as "Managed" in "type" }, } diff --git a/builtin/logical/pki/path_root.go b/builtin/logical/pki/path_root.go index 66ac5b0a2666..121c7c4b5b39 100644 --- a/builtin/logical/pki/path_root.go +++ b/builtin/logical/pki/path_root.go @@ -199,7 +199,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request, } // Store it as the CA bundle - myIssuer, myKey, err := writeCaBundle(ctx, req.Storage, cb, issuerName, keyName) + myIssuer, myKey, err := writeCaBundle(newManagedKeyContext(ctx, b, req.MountPoint), req.Storage, cb, issuerName, keyName) if err != nil { return nil, err } diff --git a/builtin/logical/pki/storage.go b/builtin/logical/pki/storage.go index 74812364a5e5..a91afbf2d05b 100644 --- a/builtin/logical/pki/storage.go +++ b/builtin/logical/pki/storage.go @@ -2,7 +2,6 @@ package pki import ( "context" - "crypto" "crypto/x509" "encoding/pem" "fmt" @@ -55,6 +54,17 @@ type keyEntry struct { PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` } +func (e keyEntry) getManagedKeyUUID() (UUIDKey, error) { + if !e.isManagedPrivateKey() { + return "", errutil.InternalError{Err: "getManagedKeyId called on a key id %s (%s) "} + } + return extractManagedKeyId([]byte(e.PrivateKey)) +} + +func (e keyEntry) isManagedPrivateKey() bool { + return e.PrivateKeyType == certutil.ManagedPrivateKey +} + type issuerEntry struct { ID issuerID `json:"id" structs:"id" mapstructure:"id"` Name string `json:"name" structs:"name" mapstructure:"name"` @@ -79,11 +89,6 @@ type issuerConfigEntry struct { DefaultIssuerId issuerID `json:"default" structs:"default" mapstructure:"default"` } -func (k keyEntry) GetSigner() (crypto.Signer, error) { - signer, _, err := certutil.ParsePEMKey(k.PrivateKey) - return signer, err -} - func listKeys(ctx context.Context, s logical.Storage) ([]keyID, error) { strList, err := s.List(ctx, keyPrefix) if err != nil { @@ -149,7 +154,7 @@ func deleteKey(ctx context.Context, s logical.Storage, id keyID) (bool, error) { return wasDefault, s.Delete(ctx, keyPrefix+id.String()) } -func importKey(ctx context.Context, s logical.Storage, keyValue string, keyName string) (*keyEntry, bool, error) { +func importKey(mkc managedKeyContext, s logical.Storage, keyValue string, keyName string, keyType certutil.PrivateKeyType) (*keyEntry, bool, error) { // importKey imports the specified PEM-format key (from keyValue) into // the new PKI storage format. The first return field is a reference to // the new key; the second is whether or not the key already existed @@ -164,22 +169,13 @@ func importKey(ctx context.Context, s logical.Storage, keyValue string, keyName // Before we can import a known key, we first need to know if the key // exists in storage already. This means iterating through all known // keys and comparing their private value against this value. - knownKeys, err := listKeys(ctx, s) - if err != nil { - return nil, false, err - } - - // Before we return below, we need to iterate over _all_ issuers and see if - // one of them has a missing KeyId link, and if so, point it back to - // ourselves. We fetch the list of issuers up front, even when don't need - // it, to give ourselves a better chance of succeeding below. - knownIssuers, err := listIssuers(ctx, s) + knownKeys, err := listKeys(mkc.ctx, s) if err != nil { return nil, false, err } for _, identifier := range knownKeys { - existingKey, err := fetchKeyById(ctx, s, identifier) + existingKey, err := fetchKeyById(mkc.ctx, s, identifier) if err != nil { return nil, false, err } @@ -197,25 +193,30 @@ func importKey(ctx context.Context, s logical.Storage, keyValue string, keyName result.ID = genKeyId() result.Name = keyName result.PrivateKey = keyValue + result.PrivateKeyType = keyType - // Extracting the signer is necessary for two reasons: first, to get the - // public key for comparison with existing issuers; second, to get the - // corresponding private key type. - keySigner, err := result.GetSigner() + keyPublic, err := getPublicKey(mkc, &result) if err != nil { return nil, false, err } - keyPublic := keySigner.Public() - result.PrivateKeyType = certutil.GetPrivateKeyTypeFromSigner(keySigner) // Finally, we can write the key to storage. - if err := writeKey(ctx, s, result); err != nil { + if err := writeKey(mkc.ctx, s, result); err != nil { + return nil, false, err + } + + // Before we return below, we need to iterate over _all_ issuers and see if + // one of them has a missing KeyId link, and if so, point it back to + // ourselves. We fetch the list of issuers up front, even when don't need + // it, to give ourselves a better chance of succeeding below. + knownIssuers, err := listIssuers(mkc.ctx, s) + if err != nil { return nil, false, err } // Now, for each issuer, try and compute the issuer<->key link if missing. for _, identifier := range knownIssuers { - existingIssuer, err := fetchIssuerById(ctx, s, identifier) + existingIssuer, err := fetchIssuerById(mkc.ctx, s, identifier) if err != nil { return nil, false, err } @@ -243,7 +244,7 @@ func importKey(ctx context.Context, s logical.Storage, keyValue string, keyName // These public keys are equal, so this key entry must be the // corresponding private key to this issuer; update it accordingly. existingIssuer.KeyID = result.ID - if err := writeIssuer(ctx, s, existingIssuer); err != nil { + if err := writeIssuer(mkc.ctx, s, existingIssuer); err != nil { return nil, false, err } } @@ -251,12 +252,12 @@ func importKey(ctx context.Context, s logical.Storage, keyValue string, keyName // If there was no prior default value set and/or we had no known // keys when we started, set this key as default. - keyDefaultSet, err := isDefaultKeySet(ctx, s) + keyDefaultSet, err := isDefaultKeySet(mkc.ctx, s) if err != nil { return nil, false, err } if len(knownKeys) == 0 || !keyDefaultSet { - if err = updateDefaultKeyId(ctx, s, result.ID); err != nil { + if err = updateDefaultKeyId(mkc.ctx, s, result.ID); err != nil { return nil, false, err } } @@ -384,7 +385,7 @@ func deleteIssuer(ctx context.Context, s logical.Storage, id issuerID) (bool, er return wasDefault, s.Delete(ctx, issuerPrefix+id.String()) } -func importIssuer(ctx context.Context, s logical.Storage, certValue string, issuerName string) (*issuerEntry, bool, error) { +func importIssuer(ctx managedKeyContext, s logical.Storage, certValue string, issuerName string) (*issuerEntry, bool, error) { // importIssuers imports the specified PEM-format certificate (from // certValue) into the new PKI storage format. The first return field is a // reference to the new issuer; the second is whether or not the issuer @@ -409,7 +410,7 @@ func importIssuer(ctx context.Context, s logical.Storage, certValue string, issu // Before we can import a known issuer, we first need to know if the issuer // exists in storage already. This means iterating through all known // issuers and comparing their private value against this value. - knownIssuers, err := listIssuers(ctx, s) + knownIssuers, err := listIssuers(ctx.ctx, s) if err != nil { return nil, false, err } @@ -418,13 +419,13 @@ func importIssuer(ctx context.Context, s logical.Storage, certValue string, issu // one of them a public key matching this certificate, and if so, update our // link accordingly. We fetch the list of keys up front, even may not need // it, to give ourselves a better chance of succeeding below. - knownKeys, err := listKeys(ctx, s) + knownKeys, err := listKeys(ctx.ctx, s) if err != nil { return nil, false, err } for _, identifier := range knownIssuers { - existingIssuer, err := fetchIssuerById(ctx, s, identifier) + existingIssuer, err := fetchIssuerById(ctx.ctx, s, identifier) if err != nil { return nil, false, err } @@ -470,18 +471,12 @@ func importIssuer(ctx context.Context, s logical.Storage, certValue string, issu // writing issuer to storage as we won't need to update the key, only // the issuer. for _, identifier := range knownKeys { - existingKey, err := fetchKeyById(ctx, s, identifier) - if err != nil { - return nil, false, err - } - - // Fetch the signer for its Public() value. - signer, err := existingKey.GetSigner() + existingKey, err := fetchKeyById(ctx.ctx, s, identifier) if err != nil { return nil, false, err } - equal, err := certutil.ComparePublicKeysAndType(issuerCert.PublicKey, signer.Public()) + equal, err := comparePublicKey(ctx, existingKey, issuerCert.PublicKey) if err != nil { return nil, false, err } @@ -498,18 +493,18 @@ func importIssuer(ctx context.Context, s logical.Storage, certValue string, issu // Finally, rebuild the chains. In this process, because the provided // reference issuer is non-nil, we'll save this issuer to storage. - if err := rebuildIssuersChains(ctx, s, &result); err != nil { + if err := rebuildIssuersChains(ctx.ctx, s, &result); err != nil { return nil, false, err } // If there was no prior default value set and/or we had no known // issuers when we started, set this issuer as default. - issuerDefaultSet, err := isDefaultIssuerSet(ctx, s) + issuerDefaultSet, err := isDefaultIssuerSet(ctx.ctx, s) if err != nil { return nil, false, err } if len(knownIssuers) == 0 || !issuerDefaultSet { - if err = updateDefaultIssuerId(ctx, s, result.ID); err != nil { + if err = updateDefaultIssuerId(ctx.ctx, s, result.ID); err != nil { return nil, false, err } } @@ -692,19 +687,19 @@ func fetchCertBundleByIssuerId(ctx context.Context, s logical.Storage, id issuer return issuer, &bundle, nil } -func writeCaBundle(ctx context.Context, s logical.Storage, caBundle *certutil.CertBundle, issuerName string, keyName string) (*issuerEntry, *keyEntry, error) { - myKey, _, err := importKey(ctx, s, caBundle.PrivateKey, keyName) +func writeCaBundle(mkc managedKeyContext, s logical.Storage, caBundle *certutil.CertBundle, issuerName string, keyName string) (*issuerEntry, *keyEntry, error) { + myKey, _, err := importKey(mkc, s, caBundle.PrivateKey, keyName, caBundle.PrivateKeyType) if err != nil { return nil, nil, err } - myIssuer, _, err := importIssuer(ctx, s, caBundle.Certificate, issuerName) + myIssuer, _, err := importIssuer(mkc, s, caBundle.Certificate, issuerName) if err != nil { return nil, nil, err } for _, cert := range caBundle.CAChain { - if _, _, err = importIssuer(ctx, s, cert, ""); err != nil { + if _, _, err = importIssuer(mkc, s, cert, ""); err != nil { return nil, nil, err } } diff --git a/builtin/logical/pki/storage_migrations.go b/builtin/logical/pki/storage_migrations.go index c5679b53c1d7..3b3f75267465 100644 --- a/builtin/logical/pki/storage_migrations.go +++ b/builtin/logical/pki/storage_migrations.go @@ -77,7 +77,8 @@ func migrateStorage(ctx context.Context, b *backend, s logical.Storage) error { b.Logger().Info("performing PKI migration to new keys/issuers layout") if migrationInfo.legacyBundle != nil { - anIssuer, aKey, err := writeCaBundle(ctx, s, migrationInfo.legacyBundle, "current", "current") + mkc := newManagedKeyContext(ctx, b, b.backendUuid) + anIssuer, aKey, err := writeCaBundle(mkc, s, migrationInfo.legacyBundle, "current", "current") if err != nil { return err } diff --git a/builtin/logical/pki/storage_test.go b/builtin/logical/pki/storage_test.go index 8b752cd73d00..1d2d93acc783 100644 --- a/builtin/logical/pki/storage_test.go +++ b/builtin/logical/pki/storage_test.go @@ -2,7 +2,6 @@ package pki import ( "context" - "crypto/rand" "strings" "testing" @@ -94,6 +93,8 @@ func Test_IssuerRoundTrip(t *testing.T) { func Test_KeysIssuerImport(t *testing.T) { b, s := createBackendWithStorage(t) + mkc := newManagedKeyContext(ctx, b, "test") + issuer1, key1 := genIssuerAndKey(t, b, s) issuer2, key2 := genIssuerAndKey(t, b, s) @@ -103,21 +104,21 @@ func Test_KeysIssuerImport(t *testing.T) { issuer1.ID = "" issuer1.KeyID = "" - key1Ref1, existing, err := importKey(ctx, s, key1.PrivateKey, "key1") + key1Ref1, existing, err := importKey(mkc, s, key1.PrivateKey, "key1", key1.PrivateKeyType) require.NoError(t, err) require.False(t, existing) require.Equal(t, strings.TrimSpace(key1.PrivateKey), strings.TrimSpace(key1Ref1.PrivateKey)) // Make sure if we attempt to re-import the same private key, no import/updates occur. // So the existing flag should be set to true, and we do not update the existing Name field. - key1Ref2, existing, err := importKey(ctx, s, key1.PrivateKey, "ignore-me") + key1Ref2, existing, err := importKey(mkc, s, key1.PrivateKey, "ignore-me", key1.PrivateKeyType) require.NoError(t, err) require.True(t, existing) require.Equal(t, key1.PrivateKey, key1Ref1.PrivateKey) require.Equal(t, key1Ref1.ID, key1Ref2.ID) require.Equal(t, key1Ref1.Name, key1Ref2.Name) - issuer1Ref1, existing, err := importIssuer(ctx, s, issuer1.Certificate, "issuer1") + issuer1Ref1, existing, err := importIssuer(mkc, s, issuer1.Certificate, "issuer1") require.NoError(t, err) require.False(t, existing) require.Equal(t, strings.TrimSpace(issuer1.Certificate), strings.TrimSpace(issuer1Ref1.Certificate)) @@ -126,7 +127,7 @@ func Test_KeysIssuerImport(t *testing.T) { // Make sure if we attempt to re-import the same issuer, no import/updates occur. // So the existing flag should be set to true, and we do not update the existing Name field. - issuer1Ref2, existing, err := importIssuer(ctx, s, issuer1.Certificate, "ignore-me") + issuer1Ref2, existing, err := importIssuer(mkc, s, issuer1.Certificate, "ignore-me") require.NoError(t, err) require.True(t, existing) require.Equal(t, strings.TrimSpace(issuer1.Certificate), strings.TrimSpace(issuer1Ref1.Certificate)) @@ -141,7 +142,7 @@ func Test_KeysIssuerImport(t *testing.T) { require.NoError(t, err) // Same double import tests as above, but make sure if the previous was created through writeIssuer not importIssuer. - issuer2Ref, existing, err := importIssuer(ctx, s, issuer2.Certificate, "ignore-me") + issuer2Ref, existing, err := importIssuer(mkc, s, issuer2.Certificate, "ignore-me") require.NoError(t, err) require.True(t, existing) require.Equal(t, strings.TrimSpace(issuer2.Certificate), strings.TrimSpace(issuer2Ref.Certificate)) @@ -150,7 +151,7 @@ func Test_KeysIssuerImport(t *testing.T) { require.Equal(t, issuer2.KeyID, issuer2Ref.KeyID) // Same double import tests as above, but make sure if the previous was created through writeKey not importKey. - key2Ref, existing, err := importKey(ctx, s, key2.PrivateKey, "ignore-me") + key2Ref, existing, err := importKey(mkc, s, key2.PrivateKey, "ignore-me", key2.PrivateKeyType) require.NoError(t, err) require.True(t, existing) require.Equal(t, key2.PrivateKey, key2Ref.PrivateKey) @@ -207,7 +208,7 @@ func genCertBundle(t *testing.T, b *backend, s logical.Storage) *certutil.CertBu apiData: apiData, role: role, } - parsedCertBundle, err := generateCert(ctx, b, input, nil, true, rand.Reader) + parsedCertBundle, err := generateCert(ctx, b, input, nil, true, b.GetRandomReader()) require.NoError(t, err) certBundle, err := parsedCertBundle.ToCertBundle() diff --git a/sdk/helper/certutil/helpers.go b/sdk/helper/certutil/helpers.go index b9d3c61bc7f6..99bed25402ce 100644 --- a/sdk/helper/certutil/helpers.go +++ b/sdk/helper/certutil/helpers.go @@ -1229,16 +1229,19 @@ func GetPublicKeySize(key crypto.PublicKey) int { return -1 } -func GetKeyBundleFromKeyGenerator(keyType string, keyBits int, keyGenerator KeyGenerator) (KeyBundle, error) { - result := KeyBundle{} - - if keyGenerator == nil { - keyGenerator = generatePrivateKey - } +// CreateKeyBundle create a KeyBundle struct object which includes a generated key +// of keyType with keyBits leveraging the randomness from randReader. +func CreateKeyBundle(keyType string, keyBits int, randReader io.Reader) (KeyBundle, error) { + return CreateKeyBundleWithKeyGenerator(keyType, keyBits, randReader, generatePrivateKey) +} - if err := keyGenerator(keyType, keyBits, &result, nil); err != nil { +// CreateKeyBundleWithKeyGenerator create a KeyBundle struct object which includes +// a generated key of keyType with keyBits leveraging the randomness from randReader and +// delegates the actual key generation to keyGenerator +func CreateKeyBundleWithKeyGenerator(keyType string, keyBits int, randReader io.Reader, keyGenerator KeyGenerator) (KeyBundle, error) { + result := KeyBundle{} + if err := keyGenerator(keyType, keyBits, &result, randReader); err != nil { return result, err } - return result, nil } From fd5e1a737987d0c253b49889519dd9208f8ca536 Mon Sep 17 00:00:00 2001 From: Steve Clark <steven.clark@hashicorp.com> Date: Mon, 2 May 2022 13:56:40 -0400 Subject: [PATCH 2/3] Add proper public key comparison for better managed key support to importKeys --- builtin/logical/pki/key_util.go | 10 ++++++++++ builtin/logical/pki/storage.go | 25 ++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/builtin/logical/pki/key_util.go b/builtin/logical/pki/key_util.go index 7c8ba83f6175..3aa7d670dbb1 100644 --- a/builtin/logical/pki/key_util.go +++ b/builtin/logical/pki/key_util.go @@ -6,6 +6,7 @@ import ( "encoding/pem" "errors" "fmt" + "github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/errutil" "github.com/hashicorp/vault/sdk/logical" @@ -98,6 +99,15 @@ func getManagedKeyPublicKey(mkc managedKeyContext, keyId managedKeyId) (crypto.P return pubKey, nil } +func getPublicKeyFromBytes(keyBytes []byte) (crypto.PublicKey, error) { + signer, _, _, err := getSignerFromBytes(keyBytes) + if err != nil { + return nil, errutil.InternalError{Err: fmt.Sprintf("failed parsing key bytes: %s", err.Error())} + } + + return signer.Public(), nil +} + func importKeyFromBytes(mkc managedKeyContext, s logical.Storage, keyValue string, keyName string) (*keyEntry, bool, error) { signer, _, _, err := getSignerFromBytes([]byte(keyValue)) if err != nil { diff --git a/builtin/logical/pki/storage.go b/builtin/logical/pki/storage.go index a91afbf2d05b..a6b93d173936 100644 --- a/builtin/logical/pki/storage.go +++ b/builtin/logical/pki/storage.go @@ -2,6 +2,7 @@ package pki import ( "context" + "crypto" "crypto/x509" "encoding/pem" "fmt" @@ -174,13 +175,35 @@ func importKey(mkc managedKeyContext, s logical.Storage, keyValue string, keyNam return nil, false, err } + // Get our public key from the current inbound key, to compare against all the other keys. + var pkForImportingKey crypto.PublicKey + if keyType == certutil.ManagedPrivateKey { + managedKeyUUID, err := extractManagedKeyId([]byte(keyValue)) + if err != nil { + return nil, false, errutil.InternalError{Err: fmt.Sprintf("failed extracting managed key uuid from key: %v", err)} + } + pkForImportingKey, err = getManagedKeyPublicKey(mkc, managedKeyUUID) + if err != nil { + return nil, false, err + } + } else { + pkForImportingKey, err = getPublicKeyFromBytes([]byte(keyValue)) + if err != nil { + return nil, false, err + } + } + for _, identifier := range knownKeys { existingKey, err := fetchKeyById(mkc.ctx, s, identifier) if err != nil { return nil, false, err } + areEqual, err := comparePublicKey(mkc, existingKey, pkForImportingKey) + if err != nil { + return nil, false, err + } - if existingKey.PrivateKey == keyValue { + if areEqual { // Here, we don't need to stitch together the issuer entries, // because the last run should've done that for us (or, when // importing an issuer). From 68b7410ea6407ffe3593ff2542400933c11eb60e Mon Sep 17 00:00:00 2001 From: Steve Clark <steven.clark@hashicorp.com> Date: Mon, 2 May 2022 16:05:07 -0400 Subject: [PATCH 3/3] Remove redundant public key fetching within PKI importKeys --- builtin/logical/pki/storage.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/builtin/logical/pki/storage.go b/builtin/logical/pki/storage.go index a6b93d173936..9908bb75fccc 100644 --- a/builtin/logical/pki/storage.go +++ b/builtin/logical/pki/storage.go @@ -218,11 +218,6 @@ func importKey(mkc managedKeyContext, s logical.Storage, keyValue string, keyNam result.PrivateKey = keyValue result.PrivateKeyType = keyType - keyPublic, err := getPublicKey(mkc, &result) - if err != nil { - return nil, false, err - } - // Finally, we can write the key to storage. if err := writeKey(mkc.ctx, s, result); err != nil { return nil, false, err @@ -258,7 +253,7 @@ func importKey(mkc managedKeyContext, s logical.Storage, keyValue string, keyNam return nil, false, err } - equal, err := certutil.ComparePublicKeysAndType(cert.PublicKey, keyPublic) + equal, err := certutil.ComparePublicKeysAndType(cert.PublicKey, pkForImportingKey) if err != nil { return nil, false, err }