Skip to content

Commit

Permalink
Merge pull request #514 from smallstep/mariano/delete-key
Browse files Browse the repository at this point in the history
Implement DeleteKey on tpmkms
  • Loading branch information
maraino authored May 29, 2024
2 parents 88c45f1 + a889b14 commit 0a6fdaf
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 13 deletions.
4 changes: 2 additions & 2 deletions keyutil/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ func generateECKey(crv string) (crypto.Signer, error) {
}

func generateRSAKey(bits int) (crypto.Signer, error) {
if min := MinRSAKeyBytes * 8; !insecureMode.isSet() && bits < min {
return nil, errors.Errorf("the size of the RSA key should be at least %d bits", min)
if minBits := MinRSAKeyBytes * 8; !insecureMode.isSet() && bits < minBits {
return nil, errors.Errorf("the size of the RSA key should be at least %d bits", minBits)
}

key, err := rsa.GenerateKey(rand.Reader, bits)
Expand Down
76 changes: 65 additions & 11 deletions kms/tpmkms/tpmkms.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,35 @@ func (k *TPMKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons
}, nil
}

// DeleteKey deletes a key identified by name from the TPMKMS.
//
// # Experimental
//
// Notice: This method is EXPERIMENTAL and may be changed or removed in a later
// release.
func (k *TPMKMS) DeleteKey(req *apiv1.DeleteKeyRequest) error {
if req.Name == "" {
return fmt.Errorf("deleteKeyRequest 'name' cannot be empty")
}
properties, err := parseNameURI(req.Name)
if err != nil {
return fmt.Errorf("failed parsing %q: %w", req.Name, err)
}

ctx := context.Background()
if properties.ak {
if err := k.tpm.DeleteAK(ctx, properties.name); err != nil {
return notFoundError(err)
}
} else {
if err := k.tpm.DeleteKey(ctx, properties.name); err != nil {
return notFoundError(err)
}
}

return nil
}

// CreateSigner creates a signer using a key present in the TPM KMS.
//
// The `signingKey` in the [apiv1.CreateSignerRequest] can be used to specify
Expand Down Expand Up @@ -460,7 +489,7 @@ func (k *TPMKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, er
switch {
case properties.name != "":
ctx := context.Background()
key, err := k.tpm.GetKey(ctx, properties.name)
key, err := k.getKey(ctx, properties.name)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -518,7 +547,7 @@ func (k *TPMKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey,
switch {
case properties.name != "":
if properties.ak {
ak, err := k.tpm.GetAK(ctx, properties.name)
ak, err := k.getAK(ctx, properties.name)
if err != nil {
return nil, err
}
Expand All @@ -529,7 +558,7 @@ func (k *TPMKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey,
return akPub, nil
}

key, err := k.tpm.GetKey(ctx, properties.name)
key, err := k.getKey(ctx, properties.name)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -598,13 +627,13 @@ func (k *TPMKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([
ctx := context.Background()
var chain []*x509.Certificate
if properties.ak {
ak, err := k.tpm.GetAK(ctx, properties.name)
ak, err := k.getAK(ctx, properties.name)
if err != nil {
return nil, err
}
chain = ak.CertificateChain()
} else {
key, err := k.tpm.GetKey(ctx, properties.name)
key, err := k.getKey(ctx, properties.name)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -741,7 +770,7 @@ func (k *TPMKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest)

ctx := context.Background()
if properties.ak {
ak, err := k.tpm.GetAK(ctx, properties.name)
ak, err := k.getAK(ctx, properties.name)
if err != nil {
return err
}
Expand All @@ -750,7 +779,7 @@ func (k *TPMKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest)
return fmt.Errorf("failed storing certificate for AK %q: %w", properties.name, err)
}
} else {
key, err := k.tpm.GetKey(ctx, properties.name)
key, err := k.getKey(ctx, properties.name)
if err != nil {
return err
}
Expand Down Expand Up @@ -898,15 +927,15 @@ func (k *TPMKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {

ctx := context.Background()
if properties.ak {
ak, err := k.tpm.GetAK(ctx, properties.name)
ak, err := k.getAK(ctx, properties.name)
if err != nil {
return err
}
if err := ak.SetCertificateChain(ctx, nil); err != nil {
return fmt.Errorf("failed storing certificate for AK %q: %w", properties.name, err)
}
} else {
key, err := k.tpm.GetKey(ctx, properties.name)
key, err := k.getKey(ctx, properties.name)
if err != nil {
return err
}
Expand Down Expand Up @@ -1060,7 +1089,7 @@ func (k *TPMKMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.
var key *tpm.Key
akName := properties.name
if !properties.ak {
key, err = k.tpm.GetKey(ctx, properties.name)
key, err = k.getKey(ctx, properties.name)
if err != nil {
return nil, err
}
Expand All @@ -1070,7 +1099,7 @@ func (k *TPMKMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.
akName = key.AttestedBy()
}

ak, err := k.tpm.GetAK(ctx, akName)
ak, err := k.getAK(ctx, akName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1227,6 +1256,31 @@ func (k *TPMKMS) hasValidIdentity(ak *tpm.AK, ekURL *url.URL) error {
return ErrIdentityCertificateInvalid
}

func (k *TPMKMS) getAK(ctx context.Context, name string) (*tpm.AK, error) {
ak, err := k.tpm.GetAK(ctx, name)
if err != nil {
return nil, notFoundError(err)
}
return ak, nil
}

func (k *TPMKMS) getKey(ctx context.Context, name string) (*tpm.Key, error) {
key, err := k.tpm.GetKey(ctx, name)
if err != nil {
return nil, notFoundError(err)
}
return key, nil
}

func notFoundError(err error) error {
if errors.Is(err, tpm.ErrNotFound) {
return apiv1.NotFoundError{
Message: err.Error(),
}
}
return err
}

// generateKeyID generates a key identifier from the
// SHA256 hash of the public key.
func generateKeyID(pub crypto.PublicKey) ([]byte, error) {
Expand Down
67 changes: 67 additions & 0 deletions kms/tpmkms/tpmkms_simulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,73 @@ func TestTPMKMS_CreateKey(t *testing.T) {
}
}

func TestTPMKMS_DeleteKey(t *testing.T) {
okTPM := newSimulatedTPM(t,
withAK("ak1"), withAK("ak2"),
withKey("key1"), withKey("key2"),
)

validatePending := func(t *testing.T, k *TPMKMS) {
_, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=ak2;ak=true"})
assert.NoError(t, err)
_, err = k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=key2"})
assert.NoError(t, err)
}

type fields struct {
tpm *tpmp.TPM
}
type args struct {
req *apiv1.DeleteKeyRequest
}
tests := []struct {
name string
fields fields
args args
assertion assert.ErrorAssertionFunc
validate func(*testing.T, *TPMKMS)
}{
{"ok", fields{okTPM}, args{&apiv1.DeleteKeyRequest{
Name: "tpmkms:name=key1",
}}, assert.NoError, func(t *testing.T, k *TPMKMS) {
_, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=ak1;ak=true"})
assert.NoError(t, err)
_, err = k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=ak2;ak=true"})
assert.NoError(t, err)
_, err = k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=key1"})
assert.ErrorIs(t, err, apiv1.NotFoundError{})
_, err = k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=key2"})
assert.NoError(t, err)
}},
{"ok ak", fields{okTPM}, args{&apiv1.DeleteKeyRequest{
Name: "tpmkms:name=ak1;ak=true",
}}, assert.NoError, func(t *testing.T, k *TPMKMS) {
_, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=ak1;ak=true"})
assert.ErrorIs(t, err, apiv1.NotFoundError{})
_, err = k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=ak2;ak=true"})
assert.NoError(t, err)
_, err = k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=key1"})
assert.ErrorIs(t, err, apiv1.NotFoundError{})
_, err = k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=key2"})
assert.NoError(t, err)
}},
{"fail name", fields{okTPM}, args{&apiv1.DeleteKeyRequest{Name: ""}}, assert.Error, validatePending},
{"fail not ak", fields{okTPM}, args{&apiv1.DeleteKeyRequest{Name: "tpmkms:name=ak2"}}, assert.Error, validatePending},
{"fail not key", fields{okTPM}, args{&apiv1.DeleteKeyRequest{Name: "tpmkms:name=key2;ak=true"}}, assert.Error, validatePending},
{"fail missing other", fields{okTPM}, args{&apiv1.DeleteKeyRequest{Name: "tpmkms:name=missing"}}, assert.Error, validatePending},
{"fail uri", fields{okTPM}, args{&apiv1.DeleteKeyRequest{Name: "kms:name=key2"}}, assert.Error, validatePending},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &TPMKMS{
tpm: tt.fields.tpm,
}
tt.assertion(t, k.DeleteKey(tt.args.req))
tt.validate(t, k)
})
}
}

func TestTPMKMS_CreateSigner(t *testing.T) {
tpmWithKey := newSimulatedTPM(t, withKey("key1"))

Expand Down
30 changes: 30 additions & 0 deletions kms/tpmkms/tpmkms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ package tpmkms
import (
"context"
"encoding/asn1"
"errors"
"fmt"
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/kms/apiv1"
"go.step.sm/crypto/tpm"
"go.step.sm/crypto/tpm/tss2"
)

Expand Down Expand Up @@ -108,3 +111,30 @@ func Test_parseTSS2(t *testing.T) {
})
}
}

func Test_notFoundError(t *testing.T) {
type args struct {
err error
}
tests := []struct {
name string
args args
assertion assert.ErrorAssertionFunc
}{
{"nil", args{nil}, assert.NoError},
{"tpm not found", args{tpm.ErrNotFound}, func(tt assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, apiv1.NotFoundError{}, i...)
}},
{"tpm not found wrapped", args{fmt.Errorf("some error: %w", tpm.ErrNotFound)}, func(tt assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, apiv1.NotFoundError{}, i...)
}},
{"other", args{tpm.ErrExists}, func(tt assert.TestingT, err error, i ...interface{}) bool {
return assert.False(t, errors.Is(err, apiv1.NotFoundError{}), i...)
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.assertion(t, notFoundError(tt.args.err))
})
}
}

0 comments on commit 0a6fdaf

Please sign in to comment.