From be77cbd962c39b5cba33d98adea961231462627b Mon Sep 17 00:00:00 2001
From: Steve Clark <steven.clark@hashicorp.com>
Date: Mon, 2 May 2022 11:48:34 -0400
Subject: [PATCH 1/3] WIP: Address codebase for managed key fixes

---
 builtin/logical/pki/backend.go             |   2 +
 builtin/logical/pki/ca_util.go             | 131 +++++++++++----------
 builtin/logical/pki/crl_test.go            |   3 +-
 builtin/logical/pki/fields.go              |   1 +
 builtin/logical/pki/key_util.go            | 116 ++++++++++++++++++
 builtin/logical/pki/managed_key_util.go    |  16 ++-
 builtin/logical/pki/path_intermediate.go   |   2 +-
 builtin/logical/pki/path_manage_issuers.go |   5 +-
 builtin/logical/pki/path_manage_keys.go    |  98 ++++++++-------
 builtin/logical/pki/path_root.go           |   2 +-
 builtin/logical/pki/storage.go             |  91 +++++++-------
 builtin/logical/pki/storage_migrations.go  |   3 +-
 builtin/logical/pki/storage_test.go        |  17 +--
 sdk/helper/certutil/helpers.go             |  19 +--
 14 files changed, 330 insertions(+), 176 deletions(-)
 create mode 100644 builtin/logical/pki/key_util.go

diff --git a/builtin/logical/pki/backend.go b/builtin/logical/pki/backend.go
index 3629dab1572e..e168bb8ea94d 100644
--- a/builtin/logical/pki/backend.go
+++ b/builtin/logical/pki/backend.go
@@ -165,6 +165,7 @@ func Backend(conf *logical.BackendConfig) *backend {
 	b.tidyCASGuard = new(uint32)
 	b.tidyStatus = &tidyStatus{state: tidyStatusInactive}
 	b.storage = conf.StorageView
+	b.backendUuid = conf.BackendUUID
 
 	b.pkiStorageVersion.Store(0)
 
@@ -175,6 +176,7 @@ func Backend(conf *logical.BackendConfig) *backend {
 type backend struct {
 	*framework.Backend
 
+	backendUuid       string
 	storage           logical.Storage
 	crlLifetime       time.Duration
 	revokeStorageLock sync.RWMutex
diff --git a/builtin/logical/pki/ca_util.go b/builtin/logical/pki/ca_util.go
index 5e81c7fe6031..680139a56c13 100644
--- a/builtin/logical/pki/ca_util.go
+++ b/builtin/logical/pki/ca_util.go
@@ -5,7 +5,6 @@ import (
 	"crypto"
 	"crypto/ecdsa"
 	"crypto/rsa"
-	"encoding/pem"
 	"errors"
 	"fmt"
 	"io"
@@ -38,8 +37,8 @@ func (b *backend) getGenerationParams(ctx context.Context, storage logical.Stora
 			`the "format" path parameter must be "pem", "der", or "pem_bundle"`)
 		return
 	}
-
-	keyType, keyBits, err := getKeyTypeAndBitsForRole(ctx, b, storage, data, mountPoint)
+	mkc := newManagedKeyContext(ctx, b, mountPoint)
+	keyType, keyBits, err := getKeyTypeAndBitsForRole(mkc, storage, data)
 	if err != nil {
 		errorResp = logical.ErrorResponse(err.Error())
 		return
@@ -77,7 +76,11 @@ func (b *backend) getGenerationParams(ctx context.Context, storage logical.Stora
 
 func generateCABundle(ctx context.Context, b *backend, input *inputBundle, data *certutil.CreationBundle, randomSource io.Reader) (*certutil.ParsedCertBundle, error) {
 	if kmsRequested(input) {
-		return generateManagedKeyCABundle(ctx, b, input, data, randomSource)
+		keyId, err := getManagedKeyId(input.apiData)
+		if err != nil {
+			return nil, err
+		}
+		return generateManagedKeyCABundle(ctx, b, input, keyId, data, randomSource)
 	}
 
 	if existingKeyRequested(input) {
@@ -85,7 +88,21 @@ func generateCABundle(ctx context.Context, b *backend, input *inputBundle, data
 		if err != nil {
 			return nil, err
 		}
-		return certutil.CreateCertificateWithKeyGenerator(data, randomSource, existingGeneratePrivateKey(ctx, input.req.Storage, keyRef))
+
+		keyEntry, err := getExistingKeyFromRef(ctx, input.req.Storage, keyRef)
+		if err != nil {
+			return nil, err
+		}
+
+		if keyEntry.isManagedPrivateKey() {
+			keyId, err := keyEntry.getManagedKeyUUID()
+			if err != nil {
+				return nil, err
+			}
+			return generateManagedKeyCABundle(ctx, b, input, keyId, data, randomSource)
+		}
+
+		return certutil.CreateCertificateWithKeyGenerator(data, randomSource, existingKeyGeneratorFromBytes(keyEntry))
 	}
 
 	return certutil.CreateCertificateWithRandomSource(data, randomSource)
@@ -93,7 +110,12 @@ func generateCABundle(ctx context.Context, b *backend, input *inputBundle, data
 
 func generateCSRBundle(ctx context.Context, b *backend, input *inputBundle, data *certutil.CreationBundle, addBasicConstraints bool, randomSource io.Reader) (*certutil.ParsedCSRBundle, error) {
 	if kmsRequested(input) {
-		return generateManagedKeyCSRBundle(ctx, b, input, data, addBasicConstraints, randomSource)
+		keyId, err := getManagedKeyId(input.apiData)
+		if err != nil {
+			return nil, err
+		}
+
+		return generateManagedKeyCSRBundle(ctx, b, input, keyId, data, addBasicConstraints, randomSource)
 	}
 
 	if existingKeyRequested(input) {
@@ -101,7 +123,21 @@ func generateCSRBundle(ctx context.Context, b *backend, input *inputBundle, data
 		if err != nil {
 			return nil, err
 		}
-		return certutil.CreateCSRWithKeyGenerator(data, addBasicConstraints, randomSource, existingGeneratePrivateKey(ctx, input.req.Storage, keyRef))
+
+		key, err := getExistingKeyFromRef(ctx, input.req.Storage, keyRef)
+		if err != nil {
+			return nil, err
+		}
+
+		if key.isManagedPrivateKey() {
+			keyId, err := key.getManagedKeyUUID()
+			if err != nil {
+				return nil, err
+			}
+			return generateManagedKeyCSRBundle(ctx, b, input, keyId, data, addBasicConstraints, randomSource)
+		}
+
+		return certutil.CreateCSRWithKeyGenerator(data, addBasicConstraints, randomSource, existingKeyGeneratorFromBytes(key))
 	}
 
 	return certutil.CreateCSRWithRandomSource(data, addBasicConstraints, randomSource)
@@ -114,7 +150,7 @@ func parseCABundle(ctx context.Context, b *backend, req *logical.Request, bundle
 	return bundle.ToParsedCertBundle()
 }
 
-func getKeyTypeAndBitsForRole(ctx context.Context, b *backend, storage logical.Storage, data *framework.FieldData, mountPoint string) (string, int, error) {
+func getKeyTypeAndBitsForRole(mkc managedKeyContext, storage logical.Storage, data *framework.FieldData) (string, int, error) {
 	exportedStr := data.Get("exported").(string)
 	var keyType string
 	var keyBits int
@@ -138,7 +174,12 @@ func getKeyTypeAndBitsForRole(ctx context.Context, b *backend, storage logical.S
 
 	var pubKey crypto.PublicKey
 	if kmsRequestedFromFieldData(data) {
-		pubKeyManagedKey, err := getManagedKeyPublicKey(ctx, b, data, mountPoint)
+		keyId, err := getManagedKeyId(data)
+		if err != nil {
+			return "", 0, errors.New("unable to determine managed key id" + err.Error())
+		}
+
+		pubKeyManagedKey, err := getManagedKeyPublicKey(mkc, keyId)
 		if err != nil {
 			return "", 0, errors.New("failed to lookup public key from managed key: " + err.Error())
 		}
@@ -146,95 +187,67 @@ func getKeyTypeAndBitsForRole(ctx context.Context, b *backend, storage logical.S
 	}
 
 	if existingKeyRequestedFromFieldData(data) {
-		existingPubKey, err := getExistingPublicKey(ctx, storage, data)
+		existingPubKey, err := getExistingPublicKey(mkc, storage, data)
 		if err != nil {
 			return "", 0, errors.New("failed to lookup public key from existing key: " + err.Error())
 		}
 		pubKey = existingPubKey
 	}
 
-	return getKeyTypeAndBitsFromPublicKeyForRole(pubKey)
+	privateKeyType, keyBits, err := getKeyTypeAndBitsFromPublicKeyForRole(pubKey)
+	return string(privateKeyType), keyBits, err
 }
 
-func getExistingPublicKey(ctx context.Context, s logical.Storage, data *framework.FieldData) (crypto.PublicKey, error) {
+func getExistingPublicKey(mkc managedKeyContext, s logical.Storage, data *framework.FieldData) (crypto.PublicKey, error) {
 	keyRef, err := getKeyRefWithErr(data)
 	if err != nil {
 		return nil, err
 	}
-	id, err := resolveKeyReference(ctx, s, keyRef)
-	if err != nil {
-		return nil, err
-	}
-	key, err := fetchKeyById(ctx, s, id)
+	id, err := resolveKeyReference(mkc.ctx, s, keyRef)
 	if err != nil {
 		return nil, err
 	}
-	signer, err := key.GetSigner()
+	key, err := fetchKeyById(mkc.ctx, s, id)
 	if err != nil {
 		return nil, err
 	}
-	return signer.Public(), nil
+	return getPublicKey(mkc, key)
 }
 
-func getKeyTypeAndBitsFromPublicKeyForRole(pubKey crypto.PublicKey) (string, int, error) {
-	var keyType string
+func getKeyTypeAndBitsFromPublicKeyForRole(pubKey crypto.PublicKey) (certutil.PrivateKeyType, int, error) {
+	var keyType certutil.PrivateKeyType
 	var keyBits int
 
 	switch pubKey.(type) {
 	case *rsa.PublicKey:
-		keyType = "rsa"
+		keyType = certutil.RSAPrivateKey
 		keyBits = certutil.GetPublicKeySize(pubKey)
 	case *ecdsa.PublicKey:
-		keyType = "ec"
+		keyType = certutil.ECPrivateKey
 	case *ed25519.PublicKey:
-		keyType = "ed25519"
+		keyType = certutil.Ed25519PrivateKey
 	default:
-		return "", 0, fmt.Errorf("unsupported public key: %#v", pubKey)
+		return certutil.UnknownPrivateKey, 0, fmt.Errorf("unsupported public key: %#v", pubKey)
 	}
 	return keyType, keyBits, nil
 }
 
-func getManagedKeyPublicKey(ctx context.Context, b *backend, data *framework.FieldData, mountPoint string) (crypto.PublicKey, error) {
-	keyId, err := getManagedKeyId(data)
+func getExistingKeyFromRef(ctx context.Context, s logical.Storage, keyRef string) (*keyEntry, error) {
+	keyId, err := resolveKeyReference(ctx, s, keyRef)
 	if err != nil {
-		return nil, errors.New("unable to determine managed key id")
-	}
-	// Determine key type and key bits from the managed public key
-	var pubKey crypto.PublicKey
-	err = withManagedPKIKey(ctx, b, keyId, mountPoint, func(ctx context.Context, key logical.ManagedSigningKey) error {
-		pubKey, err = key.GetPublicKey(ctx)
-		if err != nil {
-			return err
-		}
-
-		return nil
-	})
-	if err != nil {
-		return nil, errors.New("failed to lookup public key from managed key: " + err.Error())
+		return nil, err
 	}
-	return pubKey, nil
+	return fetchKeyById(ctx, s, keyId)
 }
 
-func existingGeneratePrivateKey(ctx context.Context, s logical.Storage, keyRef string) certutil.KeyGenerator {
-	return func(keyType string, keyBits int, container certutil.ParsedPrivateKeyContainer, _ io.Reader) error {
-		keyId, err := resolveKeyReference(ctx, s, keyRef)
-		if err != nil {
-			return err
-		}
-		key, err := fetchKeyById(ctx, s, keyId)
+func existingKeyGeneratorFromBytes(key *keyEntry) certutil.KeyGenerator {
+	return func(_ string, _ int, container certutil.ParsedPrivateKeyContainer, _ io.Reader) error {
+		signer, _, pemBytes, err := getSignerFromKeyEntryBytes(key)
 		if err != nil {
 			return err
 		}
-		signer, err := key.GetSigner()
-		if err != nil {
-			return err
-		}
-		privateKeyType := certutil.GetPrivateKeyTypeFromSigner(signer)
-		if privateKeyType == certutil.UnknownPrivateKey {
-			return errors.New("unknown private key type loaded from key id: " + keyId.String())
-		}
-		blk, _ := pem.Decode([]byte(key.PrivateKey))
-		container.SetParsedPrivateKey(signer, privateKeyType, blk.Bytes)
+
+		container.SetParsedPrivateKey(signer, key.PrivateKeyType, pemBytes.Bytes)
 		return nil
 	}
 }
diff --git a/builtin/logical/pki/crl_test.go b/builtin/logical/pki/crl_test.go
index 368a412ca8a6..20ed3403e15d 100644
--- a/builtin/logical/pki/crl_test.go
+++ b/builtin/logical/pki/crl_test.go
@@ -131,10 +131,11 @@ func TestBackend_CRL_EnableDisable(t *testing.T) {
 func TestBackend_Secondary_CRL_Rebuilding(t *testing.T) {
 	ctx := context.Background()
 	b, s := createBackendWithStorage(t)
+	mkc := newManagedKeyContext(ctx, b, "test")
 
 	// Write out the issuer/key to storage without going through the api call as replication would.
 	bundle := genCertBundle(t, b, s)
-	issuer, _, err := writeCaBundle(ctx, s, bundle, "", "")
+	issuer, _, err := writeCaBundle(mkc, s, bundle, "", "")
 	require.NoError(t, err)
 
 	// Just to validate, before we call the invalidate function, make sure our CRL has not been generated
diff --git a/builtin/logical/pki/fields.go b/builtin/logical/pki/fields.go
index 45c329bd8acf..b53ddea36664 100644
--- a/builtin/logical/pki/fields.go
+++ b/builtin/logical/pki/fields.go
@@ -8,6 +8,7 @@ const (
 	keyRefParam    = "key_ref"
 	keyIdParam     = "key_id"
 	keyTypeParam   = "key_type"
+	keyBitsParam   = "key_bits"
 )
 
 // addIssueAndSignCommonFields adds fields common to both CA and non-CA issuing
diff --git a/builtin/logical/pki/key_util.go b/builtin/logical/pki/key_util.go
new file mode 100644
index 000000000000..7c8ba83f6175
--- /dev/null
+++ b/builtin/logical/pki/key_util.go
@@ -0,0 +1,116 @@
+package pki
+
+import (
+	"context"
+	"crypto"
+	"encoding/pem"
+	"errors"
+	"fmt"
+	"github.com/hashicorp/vault/sdk/helper/certutil"
+	"github.com/hashicorp/vault/sdk/helper/errutil"
+	"github.com/hashicorp/vault/sdk/logical"
+)
+
+type managedKeyContext struct {
+	ctx        context.Context
+	b          *backend
+	mountPoint string
+}
+
+func newManagedKeyContext(ctx context.Context, b *backend, mountPoint string) managedKeyContext {
+	return managedKeyContext{
+		ctx:        ctx,
+		b:          b,
+		mountPoint: mountPoint,
+	}
+}
+
+func comparePublicKey(ctx managedKeyContext, key *keyEntry, publicKey crypto.PublicKey) (bool, error) {
+	publicKeyForKeyEntry, err := getPublicKey(ctx, key)
+	if err != nil {
+		return false, err
+	}
+
+	return certutil.ComparePublicKeysAndType(publicKeyForKeyEntry, publicKey)
+}
+
+func getPublicKey(mkc managedKeyContext, key *keyEntry) (crypto.PublicKey, error) {
+	if key.PrivateKeyType == certutil.ManagedPrivateKey {
+		keyId, err := extractManagedKeyId([]byte(key.PrivateKey))
+		if err != nil {
+			return nil, err
+		}
+		return getManagedKeyPublicKey(mkc, keyId)
+	}
+
+	signer, _, _, err := getSignerFromKeyEntryBytes(key)
+	if err != nil {
+		return nil, err
+	}
+	return signer.Public(), nil
+}
+
+func getSignerFromKeyEntryBytes(key *keyEntry) (crypto.Signer, certutil.BlockType, *pem.Block, error) {
+	if key.PrivateKeyType == certutil.UnknownPrivateKey {
+		return nil, certutil.UnknownBlock, nil, errutil.InternalError{Err: fmt.Sprintf("unsupported unknown private key type for key: %s (%s)", key.ID, key.Name)}
+	}
+
+	if key.PrivateKeyType == certutil.ManagedPrivateKey {
+		return nil, certutil.UnknownBlock, nil, errutil.InternalError{Err: fmt.Sprintf("can not get a signer from a managed key: %s (%s)", key.ID, key.Name)}
+	}
+
+	bytes, blockType, blk, err := getSignerFromBytes([]byte(key.PrivateKey))
+	if err != nil {
+		return nil, certutil.UnknownBlock, nil, errutil.InternalError{Err: fmt.Sprintf("failed parsing key entry bytes for key id: %s (%s): %s", key.ID, key.Name, err.Error())}
+	}
+
+	return bytes, blockType, blk, nil
+}
+
+func getSignerFromBytes(keyBytes []byte) (crypto.Signer, certutil.BlockType, *pem.Block, error) {
+	pemBlock, _ := pem.Decode(keyBytes)
+	if pemBlock == nil {
+		return nil, certutil.UnknownBlock, pemBlock, errutil.InternalError{Err: "no data found in PEM block"}
+	}
+
+	signer, blk, err := certutil.ParseDERKey(pemBlock.Bytes)
+	if err != nil {
+		return nil, certutil.UnknownBlock, pemBlock, errutil.InternalError{Err: fmt.Sprintf("failed to parse PEM block: %s", err.Error())}
+	}
+	return signer, blk, pemBlock, nil
+}
+
+func getManagedKeyPublicKey(mkc managedKeyContext, keyId managedKeyId) (crypto.PublicKey, error) {
+	// Determine key type and key bits from the managed public key
+	var pubKey crypto.PublicKey
+	err := withManagedPKIKey(mkc.ctx, mkc.b, keyId, mkc.mountPoint, func(ctx context.Context, key logical.ManagedSigningKey) error {
+		var myErr error
+		pubKey, myErr = key.GetPublicKey(ctx)
+		if myErr != nil {
+			return myErr
+		}
+
+		return nil
+	})
+	if err != nil {
+		return nil, errors.New("failed to lookup public key from managed key: " + err.Error())
+	}
+	return pubKey, nil
+}
+
+func importKeyFromBytes(mkc managedKeyContext, s logical.Storage, keyValue string, keyName string) (*keyEntry, bool, error) {
+	signer, _, _, err := getSignerFromBytes([]byte(keyValue))
+	if err != nil {
+		return nil, false, err
+	}
+	privateKeyType := certutil.GetPrivateKeyTypeFromSigner(signer)
+	if privateKeyType == certutil.UnknownPrivateKey {
+		return nil, false, errors.New("unsupported private key type within pem bundle")
+	}
+
+	key, existed, err := importKey(mkc, s, keyValue, keyName, privateKeyType)
+	if err != nil {
+		return nil, false, err
+	}
+	return key, existed, nil
+}
diff --git a/builtin/logical/pki/managed_key_util.go b/builtin/logical/pki/managed_key_util.go
index 45d80d643dcd..a69ac056bfc4 100644
--- a/builtin/logical/pki/managed_key_util.go
+++ b/builtin/logical/pki/managed_key_util.go
@@ -13,18 +13,26 @@ import (
 
 var errEntOnly = errors.New("managed keys are supported within enterprise edition only")
 
-func generateManagedKeyCABundle(_ context.Context, _ *backend, _ *inputBundle, _ *certutil.CreationBundle, _ io.Reader) (*certutil.ParsedCertBundle, error) {
+func generateManagedKeyCABundle(ctx context.Context, b *backend, input *inputBundle, keyId managedKeyId, data *certutil.CreationBundle, randomSource io.Reader) (bundle *certutil.ParsedCertBundle, err error) {
 	return nil, errEntOnly
 }
 
-func generateManagedKeyCSRBundle(_ context.Context, _ *backend, _ *inputBundle, _ *certutil.CreationBundle, _ bool, _ io.Reader) (*certutil.ParsedCSRBundle, error) {
+func generateManagedKeyCSRBundle(ctx context.Context, b *backend, input *inputBundle, keyId managedKeyId, data *certutil.CreationBundle, addBasicConstraints bool, randomSource io.Reader) (bundle *certutil.ParsedCSRBundle, err error) {
 	return nil, errEntOnly
 }
 
-func parseManagedKeyCABundle(_ context.Context, _ *backend, _ *logical.Request, _ *certutil.CertBundle) (*certutil.ParsedCertBundle, error) {
+func parseManagedKeyCABundle(ctx context.Context, b *backend, req *logical.Request, bundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) {
 	return nil, errEntOnly
 }
 
-func withManagedPKIKey(_ context.Context, _ *backend, _ managedKeyId, _ string, _ logical.ManagedSigningKeyConsumer) error {
+func withManagedPKIKey(ctx context.Context, b *backend, keyId managedKeyId, mountPoint string, f logical.ManagedSigningKeyConsumer) error {
 	return errEntOnly
 }
+
+func extractManagedKeyId(privateKeyBytes []byte) (UUIDKey, error) {
+	return "", errEntOnly
+}
+
+func createKmsKeyBundle(mkc managedKeyContext, keyId managedKeyId) (certutil.KeyBundle, certutil.PrivateKeyType, error) {
+	return certutil.KeyBundle{}, certutil.UnknownPrivateKey, errEntOnly
+}
diff --git a/builtin/logical/pki/path_intermediate.go b/builtin/logical/pki/path_intermediate.go
index a0b69e1a481c..bf0b416bf4a4 100644
--- a/builtin/logical/pki/path_intermediate.go
+++ b/builtin/logical/pki/path_intermediate.go
@@ -130,7 +130,7 @@ func (b *backend) pathGenerateIntermediate(ctx context.Context, req *logical.Req
 		}
 	}
 
-	myKey, _, err := importKey(ctx, req.Storage, csrb.PrivateKey, keyName)
+	myKey, _, err := importKey(newManagedKeyContext(ctx, b, req.MountPoint), req.Storage, csrb.PrivateKey, keyName, csrb.PrivateKeyType)
 	if err != nil {
 		return nil, err
 	}
diff --git a/builtin/logical/pki/path_manage_issuers.go b/builtin/logical/pki/path_manage_issuers.go
index 0138755c8196..ea7abcfa08c1 100644
--- a/builtin/logical/pki/path_manage_issuers.go
+++ b/builtin/logical/pki/path_manage_issuers.go
@@ -181,9 +181,10 @@ func (b *backend) pathImportIssuers(ctx context.Context, req *logical.Request, d
 		return logical.ErrorResponse("private keys found in the PEM bundle but not allowed by the path; use /issuers/import/bundle"), nil
 	}
 
+	mkc := newManagedKeyContext(ctx, b, req.MountPoint)
 	for keyIndex, keyPem := range keys {
 		// Handle import of private key.
-		key, existing, err := importKey(ctx, req.Storage, keyPem, "")
+		key, existing, err := importKeyFromBytes(mkc, req.Storage, keyPem, "")
 		if err != nil {
 			return logical.ErrorResponse(fmt.Sprintf("Error parsing key %v: %v", keyIndex, err)), nil
 		}
@@ -194,7 +195,7 @@ func (b *backend) pathImportIssuers(ctx context.Context, req *logical.Request, d
 	}
 
 	for certIndex, certPem := range issuers {
-		cert, existing, err := importIssuer(ctx, req.Storage, certPem, "")
+		cert, existing, err := importIssuer(mkc, req.Storage, certPem, "")
 		if err != nil {
 			return logical.ErrorResponse(fmt.Sprintf("Error parsing issuer %v: %v\n%v", certIndex, err, certPem)), nil
 		}
diff --git a/builtin/logical/pki/path_manage_keys.go b/builtin/logical/pki/path_manage_keys.go
index 04195aae2ed5..cce0fde52bd1 100644
--- a/builtin/logical/pki/path_manage_keys.go
+++ b/builtin/logical/pki/path_manage_keys.go
@@ -11,7 +11,7 @@ import (
 
 func pathGenerateKey(b *backend) *framework.Path {
 	return &framework.Path{
-		Pattern: "keys/generate/(internal|exported)",
+		Pattern: "keys/generate/(internal|exported|kms)",
 
 		Fields: map[string]*framework.FieldSchema{
 			keyNameParam: {
@@ -23,15 +23,27 @@ func pathGenerateKey(b *backend) *framework.Path {
 				Default:     "rsa",
 				Description: `Type of the secret key to generate`,
 			},
-			"key_bits": {
+			keyBitsParam: {
 				Type:        framework.TypeInt,
 				Default:     2048,
 				Description: `Type of the secret key to generate`,
 			},
+			"managed_key_name": {
+				Type: framework.TypeString,
+				Description: `The name of the managed key to use when the exported
+type is kms. When kms type is the key type, this field or managed_key_id
+is required. Ignored for other types.`,
+			},
+			"managed_key_id": {
+				Type: framework.TypeString,
+				Description: `The name of the managed key to use when the exported
+type is kms. When kms type is the key type, this field or managed_key_name
+is required. Ignored for other types.`,
+			},
 		},
 
 		Operations: map[logical.Operation]framework.OperationHandler{
-			logical.CreateOperation: &framework.PathOperation{
+			logical.UpdateOperation: &framework.PathOperation{
 				Callback:                    b.pathGenerateKeyHandler,
 				ForwardPerformanceStandby:   true,
 				ForwardPerformanceSecondary: true,
@@ -58,60 +70,60 @@ func (b *backend) pathGenerateKeyHandler(ctx context.Context, req *logical.Reque
 	if err != nil { // Fail Immediately if Key Name is in Use, etc...
 		return nil, err
 	}
-	keyType := data.Get(keyTypeParam).(string)
-	keyBits := data.Get("key_bits").(int)
 
+	mkc := newManagedKeyContext(ctx, b, req.MountPoint)
+	exportPrivateKey := false
+	var keyBundle certutil.KeyBundle
+	var actualPrivateKeyType certutil.PrivateKeyType
 	switch {
+	case strings.HasSuffix(req.Path, "/exported"):
+		exportPrivateKey = true
+		fallthrough
 	case strings.HasSuffix(req.Path, "/internal"):
+		keyType := data.Get(keyTypeParam).(string)
+		keyBits := data.Get(keyBitsParam).(int)
+
 		// Internal key generation, stored in storage
-		keyBundle, err := certutil.GetKeyBundleFromKeyGenerator(keyType, keyBits, nil)
-		if err != nil {
-			return nil, err
-		}
-		privateKeyPemString, err := keyBundle.ToPrivateKeyPemString()
-		if err != nil {
-			return nil, err
-		}
-		key, _, err := importKey(ctx, req.Storage, privateKeyPemString, keyName)
-		if err != nil {
-			return nil, err
-		}
-		resp := logical.Response{
-			Data: map[string]interface{}{
-				keyIdParam:   key.ID,
-				keyNameParam: key.Name,
-				keyTypeParam: key.PrivateKeyType,
-			},
-		}
-		return &resp, nil
-	case strings.HasSuffix(req.Path, "/exported"):
-		// Same as internal key generation but we return the generated key
-		keyBundle, err := certutil.GetKeyBundleFromKeyGenerator(keyType, keyBits, nil)
+		keyBundle, err = certutil.CreateKeyBundle(keyType, keyBits, b.GetRandomReader())
 		if err != nil {
 			return nil, err
 		}
-		privateKeyPemString, err := keyBundle.ToPrivateKeyPemString()
+
+		actualPrivateKeyType = keyBundle.PrivateKeyType
+	case strings.HasSuffix(req.Path, "/kms"):
+		keyId, err := getManagedKeyId(data)
 		if err != nil {
 			return nil, err
 		}
-		key, _, err := importKey(ctx, req.Storage, privateKeyPemString, keyName)
+
+		keyBundle, actualPrivateKeyType, err = createKmsKeyBundle(mkc, keyId)
 		if err != nil {
 			return nil, err
 		}
-		resp := logical.Response{
-			Data: map[string]interface{}{
-				keyIdParam:    key.ID,
-				keyNameParam:  key.Name,
-				keyTypeParam:  key.PrivateKeyType,
-				"private_key": privateKeyPemString,
-			},
-		}
-		return &resp, nil
-	case strings.HasSuffix(req.Path, "/kms"):
-		return nil, errEntOnly
 	default:
 		return logical.ErrorResponse("Unknown type of key to generate"), nil
 	}
+
+	privateKeyPemString, err := keyBundle.ToPrivateKeyPemString()
+	if err != nil {
+		return nil, err
+	}
+
+	key, _, err := importKey(mkc, req.Storage, privateKeyPemString, keyName, keyBundle.PrivateKeyType)
+	if err != nil {
+		return nil, err
+	}
+	responseData := map[string]interface{}{
+		keyIdParam:   key.ID,
+		keyNameParam: key.Name,
+		keyTypeParam: string(actualPrivateKeyType),
+	}
+	if exportPrivateKey {
+		responseData["private_key"] = privateKeyPemString
+	}
+	return &logical.Response{
+		Data: responseData,
+	}, nil
 }
 
 func pathImportKey(b *backend) *framework.Path {
@@ -161,7 +173,8 @@ func (b *backend) pathImportKeyHandler(ctx context.Context, req *logical.Request
 	keyValue := keyValueInterface.(string)
 	keyName := data.Get(keyNameParam).(string)
 
-	key, existed, err := importKey(ctx, req.Storage, keyValue, keyName)
+	mkc := newManagedKeyContext(ctx, b, req.MountPoint)
+	key, existed, err := importKeyFromBytes(mkc, req.Storage, keyValue, keyName)
 	if err != nil {
 		return logical.ErrorResponse(err.Error()), nil
 	}
@@ -171,7 +184,6 @@ func (b *backend) pathImportKeyHandler(ctx context.Context, req *logical.Request
 			keyIdParam:   key.ID,
 			keyNameParam: key.Name,
 			keyTypeParam: key.PrivateKeyType,
-			"backing":    "", // This would show up as "Managed" in "type"
 		},
 	}
 
diff --git a/builtin/logical/pki/path_root.go b/builtin/logical/pki/path_root.go
index 66ac5b0a2666..121c7c4b5b39 100644
--- a/builtin/logical/pki/path_root.go
+++ b/builtin/logical/pki/path_root.go
@@ -199,7 +199,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
 	}
 
 	// Store it as the CA bundle
-	myIssuer, myKey, err := writeCaBundle(ctx, req.Storage, cb, issuerName, keyName)
+	myIssuer, myKey, err := writeCaBundle(newManagedKeyContext(ctx, b, req.MountPoint), req.Storage, cb, issuerName, keyName)
 	if err != nil {
 		return nil, err
 	}
diff --git a/builtin/logical/pki/storage.go b/builtin/logical/pki/storage.go
index 74812364a5e5..a91afbf2d05b 100644
--- a/builtin/logical/pki/storage.go
+++ b/builtin/logical/pki/storage.go
@@ -2,7 +2,6 @@ package pki
 
 import (
 	"context"
-	"crypto"
 	"crypto/x509"
 	"encoding/pem"
 	"fmt"
@@ -55,6 +54,17 @@ type keyEntry struct {
 	PrivateKey     string                  `json:"private_key" structs:"private_key" mapstructure:"private_key"`
 }
 
+func (e keyEntry) getManagedKeyUUID() (UUIDKey, error) {
+	if !e.isManagedPrivateKey() {
+		return "", errutil.InternalError{Err: "getManagedKeyId called on a key id %s (%s) "}
+	}
+	return extractManagedKeyId([]byte(e.PrivateKey))
+}
+
+func (e keyEntry) isManagedPrivateKey() bool {
+	return e.PrivateKeyType == certutil.ManagedPrivateKey
+}
+
 type issuerEntry struct {
 	ID                   issuerID                  `json:"id" structs:"id" mapstructure:"id"`
 	Name                 string                    `json:"name" structs:"name" mapstructure:"name"`
@@ -79,11 +89,6 @@ type issuerConfigEntry struct {
 	DefaultIssuerId issuerID `json:"default" structs:"default" mapstructure:"default"`
 }
 
-func (k keyEntry) GetSigner() (crypto.Signer, error) {
-	signer, _, err := certutil.ParsePEMKey(k.PrivateKey)
-	return signer, err
-}
-
 func listKeys(ctx context.Context, s logical.Storage) ([]keyID, error) {
 	strList, err := s.List(ctx, keyPrefix)
 	if err != nil {
@@ -149,7 +154,7 @@ func deleteKey(ctx context.Context, s logical.Storage, id keyID) (bool, error) {
 	return wasDefault, s.Delete(ctx, keyPrefix+id.String())
 }
 
-func importKey(ctx context.Context, s logical.Storage, keyValue string, keyName string) (*keyEntry, bool, error) {
+func importKey(mkc managedKeyContext, s logical.Storage, keyValue string, keyName string, keyType certutil.PrivateKeyType) (*keyEntry, bool, error) {
 	// importKey imports the specified PEM-format key (from keyValue) into
 	// the new PKI storage format. The first return field is a reference to
 	// the new key; the second is whether or not the key already existed
@@ -164,22 +169,13 @@ func importKey(ctx context.Context, s logical.Storage, keyValue string, keyName
 	// Before we can import a known key, we first need to know if the key
 	// exists in storage already. This means iterating through all known
 	// keys and comparing their private value against this value.
-	knownKeys, err := listKeys(ctx, s)
-	if err != nil {
-		return nil, false, err
-	}
-
-	// Before we return below, we need to iterate over _all_ issuers and see if
-	// one of them has a missing KeyId link, and if so, point it back to
-	// ourselves. We fetch the list of issuers up front, even when don't need
-	// it, to give ourselves a better chance of succeeding below.
-	knownIssuers, err := listIssuers(ctx, s)
+	knownKeys, err := listKeys(mkc.ctx, s)
 	if err != nil {
 		return nil, false, err
 	}
 
 	for _, identifier := range knownKeys {
-		existingKey, err := fetchKeyById(ctx, s, identifier)
+		existingKey, err := fetchKeyById(mkc.ctx, s, identifier)
 		if err != nil {
 			return nil, false, err
 		}
@@ -197,25 +193,30 @@ func importKey(ctx context.Context, s logical.Storage, keyValue string, keyName
 	result.ID = genKeyId()
 	result.Name = keyName
 	result.PrivateKey = keyValue
+	result.PrivateKeyType = keyType
 
-	// Extracting the signer is necessary for two reasons: first, to get the
-	// public key for comparison with existing issuers; second, to get the
-	// corresponding private key type.
-	keySigner, err := result.GetSigner()
+	keyPublic, err := getPublicKey(mkc, &result)
 	if err != nil {
 		return nil, false, err
 	}
-	keyPublic := keySigner.Public()
-	result.PrivateKeyType = certutil.GetPrivateKeyTypeFromSigner(keySigner)
 
 	// Finally, we can write the key to storage.
-	if err := writeKey(ctx, s, result); err != nil {
+	if err := writeKey(mkc.ctx, s, result); err != nil {
+		return nil, false, err
+	}
+
+	// Before we return below, we need to iterate over _all_ issuers and see if
+	// one of them has a missing KeyId link, and if so, point it back to
+	// ourselves. We fetch the list of issuers up front, even when don't need
+	// it, to give ourselves a better chance of succeeding below.
+	knownIssuers, err := listIssuers(mkc.ctx, s)
+	if err != nil {
 		return nil, false, err
 	}
 
 	// Now, for each issuer, try and compute the issuer<->key link if missing.
 	for _, identifier := range knownIssuers {
-		existingIssuer, err := fetchIssuerById(ctx, s, identifier)
+		existingIssuer, err := fetchIssuerById(mkc.ctx, s, identifier)
 		if err != nil {
 			return nil, false, err
 		}
@@ -243,7 +244,7 @@ func importKey(ctx context.Context, s logical.Storage, keyValue string, keyName
 			// These public keys are equal, so this key entry must be the
 			// corresponding private key to this issuer; update it accordingly.
 			existingIssuer.KeyID = result.ID
-			if err := writeIssuer(ctx, s, existingIssuer); err != nil {
+			if err := writeIssuer(mkc.ctx, s, existingIssuer); err != nil {
 				return nil, false, err
 			}
 		}
@@ -251,12 +252,12 @@ func importKey(ctx context.Context, s logical.Storage, keyValue string, keyName
 
 	// If there was no prior default value set and/or we had no known
 	// keys when we started, set this key as default.
-	keyDefaultSet, err := isDefaultKeySet(ctx, s)
+	keyDefaultSet, err := isDefaultKeySet(mkc.ctx, s)
 	if err != nil {
 		return nil, false, err
 	}
 	if len(knownKeys) == 0 || !keyDefaultSet {
-		if err = updateDefaultKeyId(ctx, s, result.ID); err != nil {
+		if err = updateDefaultKeyId(mkc.ctx, s, result.ID); err != nil {
 			return nil, false, err
 		}
 	}
@@ -384,7 +385,7 @@ func deleteIssuer(ctx context.Context, s logical.Storage, id issuerID) (bool, er
 	return wasDefault, s.Delete(ctx, issuerPrefix+id.String())
 }
 
-func importIssuer(ctx context.Context, s logical.Storage, certValue string, issuerName string) (*issuerEntry, bool, error) {
+func importIssuer(ctx managedKeyContext, s logical.Storage, certValue string, issuerName string) (*issuerEntry, bool, error) {
 	// importIssuers imports the specified PEM-format certificate (from
 	// certValue) into the new PKI storage format. The first return field is a
 	// reference to the new issuer; the second is whether or not the issuer
@@ -409,7 +410,7 @@ func importIssuer(ctx context.Context, s logical.Storage, certValue string, issu
 	// Before we can import a known issuer, we first need to know if the issuer
 	// exists in storage already. This means iterating through all known
 	// issuers and comparing their private value against this value.
-	knownIssuers, err := listIssuers(ctx, s)
+	knownIssuers, err := listIssuers(ctx.ctx, s)
 	if err != nil {
 		return nil, false, err
 	}
@@ -418,13 +419,13 @@ func importIssuer(ctx context.Context, s logical.Storage, certValue string, issu
 	// one of them a public key matching this certificate, and if so, update our
 	// link accordingly. We fetch the list of keys up front, even may not need
 	// it, to give ourselves a better chance of succeeding below.
-	knownKeys, err := listKeys(ctx, s)
+	knownKeys, err := listKeys(ctx.ctx, s)
 	if err != nil {
 		return nil, false, err
 	}
 
 	for _, identifier := range knownIssuers {
-		existingIssuer, err := fetchIssuerById(ctx, s, identifier)
+		existingIssuer, err := fetchIssuerById(ctx.ctx, s, identifier)
 		if err != nil {
 			return nil, false, err
 		}
@@ -470,18 +471,12 @@ func importIssuer(ctx context.Context, s logical.Storage, certValue string, issu
 	// writing issuer to storage as we won't need to update the key, only
 	// the issuer.
 	for _, identifier := range knownKeys {
-		existingKey, err := fetchKeyById(ctx, s, identifier)
-		if err != nil {
-			return nil, false, err
-		}
-
-		// Fetch the signer for its Public() value.
-		signer, err := existingKey.GetSigner()
+		existingKey, err := fetchKeyById(ctx.ctx, s, identifier)
 		if err != nil {
 			return nil, false, err
 		}
 
-		equal, err := certutil.ComparePublicKeysAndType(issuerCert.PublicKey, signer.Public())
+		equal, err := comparePublicKey(ctx, existingKey, issuerCert.PublicKey)
 		if err != nil {
 			return nil, false, err
 		}
@@ -498,18 +493,18 @@ func importIssuer(ctx context.Context, s logical.Storage, certValue string, issu
 
 	// Finally, rebuild the chains. In this process, because the provided
 	// reference issuer is non-nil, we'll save this issuer to storage.
-	if err := rebuildIssuersChains(ctx, s, &result); err != nil {
+	if err := rebuildIssuersChains(ctx.ctx, s, &result); err != nil {
 		return nil, false, err
 	}
 
 	// If there was no prior default value set and/or we had no known
 	// issuers when we started, set this issuer as default.
-	issuerDefaultSet, err := isDefaultIssuerSet(ctx, s)
+	issuerDefaultSet, err := isDefaultIssuerSet(ctx.ctx, s)
 	if err != nil {
 		return nil, false, err
 	}
 	if len(knownIssuers) == 0 || !issuerDefaultSet {
-		if err = updateDefaultIssuerId(ctx, s, result.ID); err != nil {
+		if err = updateDefaultIssuerId(ctx.ctx, s, result.ID); err != nil {
 			return nil, false, err
 		}
 	}
@@ -692,19 +687,19 @@ func fetchCertBundleByIssuerId(ctx context.Context, s logical.Storage, id issuer
 	return issuer, &bundle, nil
 }
 
-func writeCaBundle(ctx context.Context, s logical.Storage, caBundle *certutil.CertBundle, issuerName string, keyName string) (*issuerEntry, *keyEntry, error) {
-	myKey, _, err := importKey(ctx, s, caBundle.PrivateKey, keyName)
+func writeCaBundle(mkc managedKeyContext, s logical.Storage, caBundle *certutil.CertBundle, issuerName string, keyName string) (*issuerEntry, *keyEntry, error) {
+	myKey, _, err := importKey(mkc, s, caBundle.PrivateKey, keyName, caBundle.PrivateKeyType)
 	if err != nil {
 		return nil, nil, err
 	}
 
-	myIssuer, _, err := importIssuer(ctx, s, caBundle.Certificate, issuerName)
+	myIssuer, _, err := importIssuer(mkc, s, caBundle.Certificate, issuerName)
 	if err != nil {
 		return nil, nil, err
 	}
 
 	for _, cert := range caBundle.CAChain {
-		if _, _, err = importIssuer(ctx, s, cert, ""); err != nil {
+		if _, _, err = importIssuer(mkc, s, cert, ""); err != nil {
 			return nil, nil, err
 		}
 	}
diff --git a/builtin/logical/pki/storage_migrations.go b/builtin/logical/pki/storage_migrations.go
index c5679b53c1d7..3b3f75267465 100644
--- a/builtin/logical/pki/storage_migrations.go
+++ b/builtin/logical/pki/storage_migrations.go
@@ -77,7 +77,8 @@ func migrateStorage(ctx context.Context, b *backend, s logical.Storage) error {
 
 	b.Logger().Info("performing PKI migration to new keys/issuers layout")
 	if migrationInfo.legacyBundle != nil {
-		anIssuer, aKey, err := writeCaBundle(ctx, s, migrationInfo.legacyBundle, "current", "current")
+		mkc := newManagedKeyContext(ctx, b, b.backendUuid)
+		anIssuer, aKey, err := writeCaBundle(mkc, s, migrationInfo.legacyBundle, "current", "current")
 		if err != nil {
 			return err
 		}
diff --git a/builtin/logical/pki/storage_test.go b/builtin/logical/pki/storage_test.go
index 8b752cd73d00..1d2d93acc783 100644
--- a/builtin/logical/pki/storage_test.go
+++ b/builtin/logical/pki/storage_test.go
@@ -2,7 +2,6 @@ package pki
 
 import (
 	"context"
-	"crypto/rand"
 	"strings"
 	"testing"
 
@@ -94,6 +93,8 @@ func Test_IssuerRoundTrip(t *testing.T) {
 
 func Test_KeysIssuerImport(t *testing.T) {
 	b, s := createBackendWithStorage(t)
+	mkc := newManagedKeyContext(ctx, b, "test")
+
 	issuer1, key1 := genIssuerAndKey(t, b, s)
 	issuer2, key2 := genIssuerAndKey(t, b, s)
 
@@ -103,21 +104,21 @@ func Test_KeysIssuerImport(t *testing.T) {
 	issuer1.ID = ""
 	issuer1.KeyID = ""
 
-	key1Ref1, existing, err := importKey(ctx, s, key1.PrivateKey, "key1")
+	key1Ref1, existing, err := importKey(mkc, s, key1.PrivateKey, "key1", key1.PrivateKeyType)
 	require.NoError(t, err)
 	require.False(t, existing)
 	require.Equal(t, strings.TrimSpace(key1.PrivateKey), strings.TrimSpace(key1Ref1.PrivateKey))
 
 	// Make sure if we attempt to re-import the same private key, no import/updates occur.
 	// So the existing flag should be set to true, and we do not update the existing Name field.
-	key1Ref2, existing, err := importKey(ctx, s, key1.PrivateKey, "ignore-me")
+	key1Ref2, existing, err := importKey(mkc, s, key1.PrivateKey, "ignore-me", key1.PrivateKeyType)
 	require.NoError(t, err)
 	require.True(t, existing)
 	require.Equal(t, key1.PrivateKey, key1Ref1.PrivateKey)
 	require.Equal(t, key1Ref1.ID, key1Ref2.ID)
 	require.Equal(t, key1Ref1.Name, key1Ref2.Name)
 
-	issuer1Ref1, existing, err := importIssuer(ctx, s, issuer1.Certificate, "issuer1")
+	issuer1Ref1, existing, err := importIssuer(mkc, s, issuer1.Certificate, "issuer1")
 	require.NoError(t, err)
 	require.False(t, existing)
 	require.Equal(t, strings.TrimSpace(issuer1.Certificate), strings.TrimSpace(issuer1Ref1.Certificate))
@@ -126,7 +127,7 @@ func Test_KeysIssuerImport(t *testing.T) {
 
 	// Make sure if we attempt to re-import the same issuer, no import/updates occur.
 	// So the existing flag should be set to true, and we do not update the existing Name field.
-	issuer1Ref2, existing, err := importIssuer(ctx, s, issuer1.Certificate, "ignore-me")
+	issuer1Ref2, existing, err := importIssuer(mkc, s, issuer1.Certificate, "ignore-me")
 	require.NoError(t, err)
 	require.True(t, existing)
 	require.Equal(t, strings.TrimSpace(issuer1.Certificate), strings.TrimSpace(issuer1Ref1.Certificate))
@@ -141,7 +142,7 @@ func Test_KeysIssuerImport(t *testing.T) {
 	require.NoError(t, err)
 
 	// Same double import tests as above, but make sure if the previous was created through writeIssuer not importIssuer.
-	issuer2Ref, existing, err := importIssuer(ctx, s, issuer2.Certificate, "ignore-me")
+	issuer2Ref, existing, err := importIssuer(mkc, s, issuer2.Certificate, "ignore-me")
 	require.NoError(t, err)
 	require.True(t, existing)
 	require.Equal(t, strings.TrimSpace(issuer2.Certificate), strings.TrimSpace(issuer2Ref.Certificate))
@@ -150,7 +151,7 @@ func Test_KeysIssuerImport(t *testing.T) {
 	require.Equal(t, issuer2.KeyID, issuer2Ref.KeyID)
 
 	// Same double import tests as above, but make sure if the previous was created through writeKey not importKey.
-	key2Ref, existing, err := importKey(ctx, s, key2.PrivateKey, "ignore-me")
+	key2Ref, existing, err := importKey(mkc, s, key2.PrivateKey, "ignore-me", key2.PrivateKeyType)
 	require.NoError(t, err)
 	require.True(t, existing)
 	require.Equal(t, key2.PrivateKey, key2Ref.PrivateKey)
@@ -207,7 +208,7 @@ func genCertBundle(t *testing.T, b *backend, s logical.Storage) *certutil.CertBu
 		apiData: apiData,
 		role:    role,
 	}
-	parsedCertBundle, err := generateCert(ctx, b, input, nil, true, rand.Reader)
+	parsedCertBundle, err := generateCert(ctx, b, input, nil, true, b.GetRandomReader())
 
 	require.NoError(t, err)
 	certBundle, err := parsedCertBundle.ToCertBundle()
diff --git a/sdk/helper/certutil/helpers.go b/sdk/helper/certutil/helpers.go
index b9d3c61bc7f6..99bed25402ce 100644
--- a/sdk/helper/certutil/helpers.go
+++ b/sdk/helper/certutil/helpers.go
@@ -1229,16 +1229,19 @@ func GetPublicKeySize(key crypto.PublicKey) int {
 	return -1
 }
 
-func GetKeyBundleFromKeyGenerator(keyType string, keyBits int, keyGenerator KeyGenerator) (KeyBundle, error) {
-	result := KeyBundle{}
-
-	if keyGenerator == nil {
-		keyGenerator = generatePrivateKey
-	}
+// CreateKeyBundle create a KeyBundle struct object which includes a generated key
+// of keyType with keyBits leveraging the randomness from randReader.
+func CreateKeyBundle(keyType string, keyBits int, randReader io.Reader) (KeyBundle, error) {
+	return CreateKeyBundleWithKeyGenerator(keyType, keyBits, randReader, generatePrivateKey)
+}
 
-	if err := keyGenerator(keyType, keyBits, &result, nil); err != nil {
+// CreateKeyBundleWithKeyGenerator create a KeyBundle struct object which includes
+// a generated key of keyType with keyBits leveraging the randomness from randReader and
+// delegates the actual key generation to keyGenerator
+func CreateKeyBundleWithKeyGenerator(keyType string, keyBits int, randReader io.Reader, keyGenerator KeyGenerator) (KeyBundle, error) {
+	result := KeyBundle{}
+	if err := keyGenerator(keyType, keyBits, &result, randReader); err != nil {
 		return result, err
 	}
-
 	return result, nil
 }

From fd5e1a737987d0c253b49889519dd9208f8ca536 Mon Sep 17 00:00:00 2001
From: Steve Clark <steven.clark@hashicorp.com>
Date: Mon, 2 May 2022 13:56:40 -0400
Subject: [PATCH 2/3] Add proper public key comparison for better managed key
 support to importKeys

---
 builtin/logical/pki/key_util.go | 10 ++++++++++
 builtin/logical/pki/storage.go  | 25 ++++++++++++++++++++++++-
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/builtin/logical/pki/key_util.go b/builtin/logical/pki/key_util.go
index 7c8ba83f6175..3aa7d670dbb1 100644
--- a/builtin/logical/pki/key_util.go
+++ b/builtin/logical/pki/key_util.go
@@ -6,6 +6,7 @@ import (
 	"encoding/pem"
 	"errors"
 	"fmt"
+
 	"github.com/hashicorp/vault/sdk/helper/certutil"
 	"github.com/hashicorp/vault/sdk/helper/errutil"
 	"github.com/hashicorp/vault/sdk/logical"
@@ -98,6 +99,15 @@ func getManagedKeyPublicKey(mkc managedKeyContext, keyId managedKeyId) (crypto.P
 	return pubKey, nil
 }
 
+func getPublicKeyFromBytes(keyBytes []byte) (crypto.PublicKey, error) {
+	signer, _, _, err := getSignerFromBytes(keyBytes)
+	if err != nil {
+		return nil, errutil.InternalError{Err: fmt.Sprintf("failed parsing key bytes: %s", err.Error())}
+	}
+
+	return signer.Public(), nil
+}
+
 func importKeyFromBytes(mkc managedKeyContext, s logical.Storage, keyValue string, keyName string) (*keyEntry, bool, error) {
 	signer, _, _, err := getSignerFromBytes([]byte(keyValue))
 	if err != nil {
diff --git a/builtin/logical/pki/storage.go b/builtin/logical/pki/storage.go
index a91afbf2d05b..a6b93d173936 100644
--- a/builtin/logical/pki/storage.go
+++ b/builtin/logical/pki/storage.go
@@ -2,6 +2,7 @@ package pki
 
 import (
 	"context"
+	"crypto"
 	"crypto/x509"
 	"encoding/pem"
 	"fmt"
@@ -174,13 +175,35 @@ func importKey(mkc managedKeyContext, s logical.Storage, keyValue string, keyNam
 		return nil, false, err
 	}
 
+	// Get our public key from the current inbound key, to compare against all the other keys.
+	var pkForImportingKey crypto.PublicKey
+	if keyType == certutil.ManagedPrivateKey {
+		managedKeyUUID, err := extractManagedKeyId([]byte(keyValue))
+		if err != nil {
+			return nil, false, errutil.InternalError{Err: fmt.Sprintf("failed extracting managed key uuid from key: %v", err)}
+		}
+		pkForImportingKey, err = getManagedKeyPublicKey(mkc, managedKeyUUID)
+		if err != nil {
+			return nil, false, err
+		}
+	} else {
+		pkForImportingKey, err = getPublicKeyFromBytes([]byte(keyValue))
+		if err != nil {
+			return nil, false, err
+		}
+	}
+
 	for _, identifier := range knownKeys {
 		existingKey, err := fetchKeyById(mkc.ctx, s, identifier)
 		if err != nil {
 			return nil, false, err
 		}
+		areEqual, err := comparePublicKey(mkc, existingKey, pkForImportingKey)
+		if err != nil {
+			return nil, false, err
+		}
 
-		if existingKey.PrivateKey == keyValue {
+		if areEqual {
 			// Here, we don't need to stitch together the issuer entries,
 			// because the last run should've done that for us (or, when
 			// importing an issuer).

From 68b7410ea6407ffe3593ff2542400933c11eb60e Mon Sep 17 00:00:00 2001
From: Steve Clark <steven.clark@hashicorp.com>
Date: Mon, 2 May 2022 16:05:07 -0400
Subject: [PATCH 3/3] Remove redundant public key fetching within PKI
 importKeys

---
 builtin/logical/pki/storage.go | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/builtin/logical/pki/storage.go b/builtin/logical/pki/storage.go
index a6b93d173936..9908bb75fccc 100644
--- a/builtin/logical/pki/storage.go
+++ b/builtin/logical/pki/storage.go
@@ -218,11 +218,6 @@ func importKey(mkc managedKeyContext, s logical.Storage, keyValue string, keyNam
 	result.PrivateKey = keyValue
 	result.PrivateKeyType = keyType
 
-	keyPublic, err := getPublicKey(mkc, &result)
-	if err != nil {
-		return nil, false, err
-	}
-
 	// Finally, we can write the key to storage.
 	if err := writeKey(mkc.ctx, s, result); err != nil {
 		return nil, false, err
@@ -258,7 +253,7 @@ func importKey(mkc managedKeyContext, s logical.Storage, keyValue string, keyNam
 			return nil, false, err
 		}
 
-		equal, err := certutil.ComparePublicKeysAndType(cert.PublicKey, keyPublic)
+		equal, err := certutil.ComparePublicKeysAndType(cert.PublicKey, pkForImportingKey)
 		if err != nil {
 			return nil, false, err
 		}