Skip to content

Commit

Permalink
feature (extra/kms): add WithReader(...) support to GetWrapper(...) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jimlambrt authored Oct 28, 2022
1 parent 92bfd9f commit 3e936a5
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 17 deletions.
25 changes: 14 additions & 11 deletions extras/kms/kms.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ func (k *Kms) GetExternalRootWrapper() (wrapping.Wrapper, error) {
return nil, fmt.Errorf("%s: missing external root wrapper: %w", op, ErrKeyNotFound)
}

// GetWrapper returns a wrapper for the given scope and purpose.
// GetWrapper returns a wrapper for the given scope and purpose. The
// WithReader(...) option is supported for getting a wrapper.
//
// If an optional WithKeyVersionId(...) or WithKeyId(...) is
// passed, it will ensure that the returning wrapper has that key version ID in the
Expand Down Expand Up @@ -228,7 +229,7 @@ func (k *Kms) GetWrapper(ctx context.Context, scopeId string, purpose KeyPurpose
// root for the scope as we'll need it to decrypt the value coming from the
// DB. We don't cache the roots as we expect that after a few calls the
// scope-purpose cache will catch everything in steady-state.
rootWrapper, rootKeyId, err := k.loadRoot(ctx, scopeId)
rootWrapper, rootKeyId, err := k.loadRoot(ctx, scopeId, WithReader(opts.withReader))
if err != nil {
return nil, fmt.Errorf("%s: error loading root key for scope %q: %w", op, scopeId, err)
}
Expand All @@ -240,7 +241,7 @@ func (k *Kms) GetWrapper(ctx context.Context, scopeId string, purpose KeyPurpose
return rootWrapper, nil
}

wrapper, err := k.loadDek(ctx, scopeId, purpose, rootWrapper, rootKeyId)
wrapper, err := k.loadDek(ctx, scopeId, purpose, rootWrapper, rootKeyId, WithReader(opts.withReader))
if err != nil {
return nil, fmt.Errorf("%s: error loading %q for scope %q: %w", op, purpose, scopeId, err)
}
Expand Down Expand Up @@ -479,7 +480,7 @@ func (k *Kms) RotateKeys(ctx context.Context, scopeId string, opt ...Option) err
// since we could have started a local txn, we'll use an anon function for
// all the stmts which should be managed within that possible local txn.
if err := func() error {
rk, err := k.repo.LookupRootKeyByScope(ctx, scopeId, withReader(reader))
rk, err := k.repo.LookupRootKeyByScope(ctx, scopeId, WithReader(reader))
if err != nil {
return fmt.Errorf("%s: unable to load the scope's root key: %w", op, err)
}
Expand All @@ -499,7 +500,7 @@ func (k *Kms) RotateKeys(ctx context.Context, scopeId string, opt ...Option) err
return fmt.Errorf("%s: unable to rotate root key version: %w", op, err)
}

rkvWrapper, _, err := k.loadRoot(ctx, scopeId, withReader(reader))
rkvWrapper, _, err := k.loadRoot(ctx, scopeId, WithReader(reader))
if err != nil {
return fmt.Errorf("%s: unable to load the root key version wrapper: %w", op, err)
}
Expand Down Expand Up @@ -574,7 +575,7 @@ func (k *Kms) RewrapKeys(ctx context.Context, scopeId string, opt ...Option) err
// since we could have started a local txn, we'll use an anon function for
// all the stmts which should be managed within that possible local txn.
if err := func() error {
rk, err := k.repo.LookupRootKeyByScope(ctx, scopeId, withReader(reader))
rk, err := k.repo.LookupRootKeyByScope(ctx, scopeId, WithReader(reader))
if err != nil {
return fmt.Errorf("%s: unable to load the scope's root key: %w", op, err)
}
Expand All @@ -585,7 +586,7 @@ func (k *Kms) RewrapKeys(ctx context.Context, scopeId string, opt ...Option) err
return fmt.Errorf("%s: unable to rewrap root key versions: %w", op, err)
}

rkvWrapper, _, err := k.loadRoot(ctx, scopeId, withReader(reader))
rkvWrapper, _, err := k.loadRoot(ctx, scopeId, WithReader(reader))
if err != nil {
return fmt.Errorf("%s: unable to load the root key version wrapper: %w", op, err)
}
Expand Down Expand Up @@ -740,7 +741,7 @@ func (k *Kms) loadRoot(ctx context.Context, scopeId string, opt ...Option) (_ *m
return nil, "", fmt.Errorf("%s: missing root key wrapper for scope %q: %w", op, scopeId, ErrKeyNotFound)
}
opts := getOpts(opt...)
rootKeyVersions, err := k.repo.ListRootKeyVersions(ctx, externalRootWrapper, rootKeyId, withOrderByVersion(descendingOrderBy), withReader(opts.withReader))
rootKeyVersions, err := k.repo.ListRootKeyVersions(ctx, externalRootWrapper, rootKeyId, withOrderByVersion(descendingOrderBy), WithReader(opts.withReader))
if err != nil {
return nil, "", fmt.Errorf("%s: error looking up root key versions for scope %q: %w", op, scopeId, err)
}
Expand Down Expand Up @@ -774,7 +775,7 @@ func (k *Kms) loadRoot(ctx context.Context, scopeId string, opt ...Option) (_ *m
return pooled, rootKeyId, nil
}

func (k *Kms) loadDek(ctx context.Context, scopeId string, purpose KeyPurpose, rootWrapper wrapping.Wrapper, rootKeyId string) (*multi.PooledWrapper, error) {
func (k *Kms) loadDek(ctx context.Context, scopeId string, purpose KeyPurpose, rootWrapper wrapping.Wrapper, rootKeyId string, opt ...Option) (*multi.PooledWrapper, error) {
const op = "kms.(Kms).loadDek"
if scopeId == "" {
return nil, fmt.Errorf("%s: missing scope id: %w", op, ErrInvalidParameter)
Expand All @@ -789,7 +790,9 @@ func (k *Kms) loadDek(ctx context.Context, scopeId string, purpose KeyPurpose, r
return nil, fmt.Errorf("%s: not a supported key purpose %q: %w", op, purpose, ErrInvalidParameter)
}

keys, err := k.repo.ListDataKeys(ctx, withPurpose(purpose), withRootKeyId(rootKeyId))
opts := getOpts(opt...)

keys, err := k.repo.ListDataKeys(ctx, withPurpose(purpose), withRootKeyId(rootKeyId), WithReader(opts.withReader))
if err != nil {
return nil, fmt.Errorf("%s: error listing keys for purpose %q: %w", op, purpose, err)
}
Expand All @@ -802,7 +805,7 @@ func (k *Kms) loadDek(ctx context.Context, scopeId string, purpose KeyPurpose, r
default:
keyId = keys[0].GetPrivateId()
}
keyVersions, err := k.repo.ListDataKeyVersions(ctx, rootWrapper, keyId, withOrderByVersion(descendingOrderBy))
keyVersions, err := k.repo.ListDataKeyVersions(ctx, rootWrapper, keyId, withOrderByVersion(descendingOrderBy), WithReader(opts.withReader))
if err != nil {
return nil, fmt.Errorf("%s: error looking up %q key versions for scope %q: %w", op, purpose, scopeId, err)
}
Expand Down
24 changes: 24 additions & 0 deletions extras/kms/kms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,30 @@ func TestKms_GetWrapper(t *testing.T) {
assert.NotNil(got)
})
}
t.Run("with-reader", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
db, _ := kms.TestDb(t)
rw := dbw.New(db)
k, err := kms.New(rw, rw, []kms.KeyPurpose{"database", "auth"})
require.NoError(err)
require.NoError(k.AddExternalWrapper(testCtx, kms.KeyPurposeRootKey, wrapper))
testDeleteWhere(t, db, &rootKey{}, "1=1")
require.NoError(k.CreateKeys(testCtx, "global", []kms.KeyPurpose{"database", "auth"}))

emptyDb, _ := kms.TestDb(t)
emptyRw := dbw.New(emptyDb)
emptyKms, err := kms.New(emptyRw, emptyRw, []kms.KeyPurpose{"database", "auth"})
require.NoError(err)
require.NoError(emptyKms.AddExternalWrapper(testCtx, kms.KeyPurposeRootKey, wrapper))

got, err := emptyKms.GetWrapper(testCtx, "global", "database")
assert.Empty(got)
assert.Error(err)

got, err = emptyKms.GetWrapper(testCtx, "global", "database", kms.WithReader(rw))
assert.NotEmpty(got)
assert.NoError(err)
})
}

func TestKms_CreateKeys(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions extras/kms/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ func WithScopeIds(id ...string) Option {
}
}

// withReader provides an optional reader
func withReader(r dbw.Reader) Option {
// WithReader provides an optional reader
func WithReader(r dbw.Reader) Option {
return func(o *options) {
o.withReader = r
}
Expand Down
6 changes: 3 additions & 3 deletions extras/kms/repository_data_key_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ func rewrapDataKeyVersionsTx(ctx context.Context, reader dbw.Reader, writer dbw.
if err != nil {
return fmt.Errorf("%s: unable to create repo: %w", op, err)
}
dks, err := r.ListDataKeys(ctx, withRootKeyId(rootKeyId), withReader(reader))
dks, err := r.ListDataKeys(ctx, withRootKeyId(rootKeyId), WithReader(reader))
if err != nil {
return fmt.Errorf("%s: unable to list the current data keys: %w", op, err)
}
for _, dk := range dks {
var versions []*dataKeyVersion
if err := r.list(ctx, &versions, "data_key_id = ?", []interface{}{dk.PrivateId}, withReader(reader)); err != nil {
if err := r.list(ctx, &versions, "data_key_id = ?", []interface{}{dk.PrivateId}, WithReader(reader)); err != nil {
return fmt.Errorf("%s: unable to list the current data key versions: %w", op, err)
}
for _, v := range versions {
Expand Down Expand Up @@ -338,7 +338,7 @@ func rotateDataKeyVersionTx(ctx context.Context, reader dbw.Reader, writer dbw.W
if err != nil {
return fmt.Errorf("%s: unable to create repo: %w", op, err)
}
dataKeys, err := r.ListDataKeys(ctx, withPurpose(purpose), withRootKeyId(rootKeyId), withReader(reader))
dataKeys, err := r.ListDataKeys(ctx, withPurpose(purpose), withRootKeyId(rootKeyId), WithReader(reader))
switch {
case err != nil:
return fmt.Errorf("%s: unable to lookup data key for %q: %w", op, purpose, err)
Expand Down
2 changes: 1 addition & 1 deletion extras/kms/repository_root_key_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func rewrapRootKeyVersionsTx(ctx context.Context, reader dbw.Reader, writer dbw.
return fmt.Errorf("%s: unable to create repo: %w", op, err)
}
// rewrap the rootKey versions using the scope's root key to find them
rkvs, err := r.ListRootKeyVersions(ctx, rootWrapper, rootKeyId, withReader(reader))
rkvs, err := r.ListRootKeyVersions(ctx, rootWrapper, rootKeyId, WithReader(reader))
if err != nil {
return fmt.Errorf("%s: unable to list root key versions: %w", op, err)
}
Expand Down

0 comments on commit 3e936a5

Please sign in to comment.