diff --git a/extras/kms/kms.go b/extras/kms/kms.go index cfe27609..31ae828c 100644 --- a/extras/kms/kms.go +++ b/extras/kms/kms.go @@ -174,7 +174,8 @@ func (k *Kms) GetExternalRootWrapper() (wrapping.Wrapper, error) { return nil, fmt.Errorf("%s: missing external root wrapper: %w", op, ErrKeyNotFound) } -// GetWrapper returns a wrapper for the given scope and purpose. +// GetWrapper returns a wrapper for the given scope and purpose. The +// WithReader(...) option is supported for getting a wrapper. // // If an optional WithKeyVersionId(...) or WithKeyId(...) is // passed, it will ensure that the returning wrapper has that key version ID in the @@ -228,7 +229,7 @@ func (k *Kms) GetWrapper(ctx context.Context, scopeId string, purpose KeyPurpose // root for the scope as we'll need it to decrypt the value coming from the // DB. We don't cache the roots as we expect that after a few calls the // scope-purpose cache will catch everything in steady-state. - rootWrapper, rootKeyId, err := k.loadRoot(ctx, scopeId) + rootWrapper, rootKeyId, err := k.loadRoot(ctx, scopeId, WithReader(opts.withReader)) if err != nil { return nil, fmt.Errorf("%s: error loading root key for scope %q: %w", op, scopeId, err) } @@ -240,7 +241,7 @@ func (k *Kms) GetWrapper(ctx context.Context, scopeId string, purpose KeyPurpose return rootWrapper, nil } - wrapper, err := k.loadDek(ctx, scopeId, purpose, rootWrapper, rootKeyId) + wrapper, err := k.loadDek(ctx, scopeId, purpose, rootWrapper, rootKeyId, WithReader(opts.withReader)) if err != nil { return nil, fmt.Errorf("%s: error loading %q for scope %q: %w", op, purpose, scopeId, err) } @@ -479,7 +480,7 @@ func (k *Kms) RotateKeys(ctx context.Context, scopeId string, opt ...Option) err // since we could have started a local txn, we'll use an anon function for // all the stmts which should be managed within that possible local txn. if err := func() error { - rk, err := k.repo.LookupRootKeyByScope(ctx, scopeId, withReader(reader)) + rk, err := k.repo.LookupRootKeyByScope(ctx, scopeId, WithReader(reader)) if err != nil { return fmt.Errorf("%s: unable to load the scope's root key: %w", op, err) } @@ -499,7 +500,7 @@ func (k *Kms) RotateKeys(ctx context.Context, scopeId string, opt ...Option) err return fmt.Errorf("%s: unable to rotate root key version: %w", op, err) } - rkvWrapper, _, err := k.loadRoot(ctx, scopeId, withReader(reader)) + rkvWrapper, _, err := k.loadRoot(ctx, scopeId, WithReader(reader)) if err != nil { return fmt.Errorf("%s: unable to load the root key version wrapper: %w", op, err) } @@ -574,7 +575,7 @@ func (k *Kms) RewrapKeys(ctx context.Context, scopeId string, opt ...Option) err // since we could have started a local txn, we'll use an anon function for // all the stmts which should be managed within that possible local txn. if err := func() error { - rk, err := k.repo.LookupRootKeyByScope(ctx, scopeId, withReader(reader)) + rk, err := k.repo.LookupRootKeyByScope(ctx, scopeId, WithReader(reader)) if err != nil { return fmt.Errorf("%s: unable to load the scope's root key: %w", op, err) } @@ -585,7 +586,7 @@ func (k *Kms) RewrapKeys(ctx context.Context, scopeId string, opt ...Option) err return fmt.Errorf("%s: unable to rewrap root key versions: %w", op, err) } - rkvWrapper, _, err := k.loadRoot(ctx, scopeId, withReader(reader)) + rkvWrapper, _, err := k.loadRoot(ctx, scopeId, WithReader(reader)) if err != nil { return fmt.Errorf("%s: unable to load the root key version wrapper: %w", op, err) } @@ -740,7 +741,7 @@ func (k *Kms) loadRoot(ctx context.Context, scopeId string, opt ...Option) (_ *m return nil, "", fmt.Errorf("%s: missing root key wrapper for scope %q: %w", op, scopeId, ErrKeyNotFound) } opts := getOpts(opt...) - rootKeyVersions, err := k.repo.ListRootKeyVersions(ctx, externalRootWrapper, rootKeyId, withOrderByVersion(descendingOrderBy), withReader(opts.withReader)) + rootKeyVersions, err := k.repo.ListRootKeyVersions(ctx, externalRootWrapper, rootKeyId, withOrderByVersion(descendingOrderBy), WithReader(opts.withReader)) if err != nil { return nil, "", fmt.Errorf("%s: error looking up root key versions for scope %q: %w", op, scopeId, err) } @@ -774,7 +775,7 @@ func (k *Kms) loadRoot(ctx context.Context, scopeId string, opt ...Option) (_ *m return pooled, rootKeyId, nil } -func (k *Kms) loadDek(ctx context.Context, scopeId string, purpose KeyPurpose, rootWrapper wrapping.Wrapper, rootKeyId string) (*multi.PooledWrapper, error) { +func (k *Kms) loadDek(ctx context.Context, scopeId string, purpose KeyPurpose, rootWrapper wrapping.Wrapper, rootKeyId string, opt ...Option) (*multi.PooledWrapper, error) { const op = "kms.(Kms).loadDek" if scopeId == "" { return nil, fmt.Errorf("%s: missing scope id: %w", op, ErrInvalidParameter) @@ -789,7 +790,9 @@ func (k *Kms) loadDek(ctx context.Context, scopeId string, purpose KeyPurpose, r return nil, fmt.Errorf("%s: not a supported key purpose %q: %w", op, purpose, ErrInvalidParameter) } - keys, err := k.repo.ListDataKeys(ctx, withPurpose(purpose), withRootKeyId(rootKeyId)) + opts := getOpts(opt...) + + keys, err := k.repo.ListDataKeys(ctx, withPurpose(purpose), withRootKeyId(rootKeyId), WithReader(opts.withReader)) if err != nil { return nil, fmt.Errorf("%s: error listing keys for purpose %q: %w", op, purpose, err) } @@ -802,7 +805,7 @@ func (k *Kms) loadDek(ctx context.Context, scopeId string, purpose KeyPurpose, r default: keyId = keys[0].GetPrivateId() } - keyVersions, err := k.repo.ListDataKeyVersions(ctx, rootWrapper, keyId, withOrderByVersion(descendingOrderBy)) + keyVersions, err := k.repo.ListDataKeyVersions(ctx, rootWrapper, keyId, withOrderByVersion(descendingOrderBy), WithReader(opts.withReader)) if err != nil { return nil, fmt.Errorf("%s: error looking up %q key versions for scope %q: %w", op, purpose, scopeId, err) } diff --git a/extras/kms/kms_test.go b/extras/kms/kms_test.go index 00333ada..533daeaa 100644 --- a/extras/kms/kms_test.go +++ b/extras/kms/kms_test.go @@ -501,6 +501,30 @@ func TestKms_GetWrapper(t *testing.T) { assert.NotNil(got) }) } + t.Run("with-reader", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + db, _ := kms.TestDb(t) + rw := dbw.New(db) + k, err := kms.New(rw, rw, []kms.KeyPurpose{"database", "auth"}) + require.NoError(err) + require.NoError(k.AddExternalWrapper(testCtx, kms.KeyPurposeRootKey, wrapper)) + testDeleteWhere(t, db, &rootKey{}, "1=1") + require.NoError(k.CreateKeys(testCtx, "global", []kms.KeyPurpose{"database", "auth"})) + + emptyDb, _ := kms.TestDb(t) + emptyRw := dbw.New(emptyDb) + emptyKms, err := kms.New(emptyRw, emptyRw, []kms.KeyPurpose{"database", "auth"}) + require.NoError(err) + require.NoError(emptyKms.AddExternalWrapper(testCtx, kms.KeyPurposeRootKey, wrapper)) + + got, err := emptyKms.GetWrapper(testCtx, "global", "database") + assert.Empty(got) + assert.Error(err) + + got, err = emptyKms.GetWrapper(testCtx, "global", "database", kms.WithReader(rw)) + assert.NotEmpty(got) + assert.NoError(err) + }) } func TestKms_CreateKeys(t *testing.T) { diff --git a/extras/kms/option.go b/extras/kms/option.go index 054c51ca..deb2b7e3 100644 --- a/extras/kms/option.go +++ b/extras/kms/option.go @@ -142,8 +142,8 @@ func WithScopeIds(id ...string) Option { } } -// withReader provides an optional reader -func withReader(r dbw.Reader) Option { +// WithReader provides an optional reader +func WithReader(r dbw.Reader) Option { return func(o *options) { o.withReader = r } diff --git a/extras/kms/repository_data_key_version.go b/extras/kms/repository_data_key_version.go index b62d8bed..4a8c7889 100644 --- a/extras/kms/repository_data_key_version.go +++ b/extras/kms/repository_data_key_version.go @@ -277,13 +277,13 @@ func rewrapDataKeyVersionsTx(ctx context.Context, reader dbw.Reader, writer dbw. if err != nil { return fmt.Errorf("%s: unable to create repo: %w", op, err) } - dks, err := r.ListDataKeys(ctx, withRootKeyId(rootKeyId), withReader(reader)) + dks, err := r.ListDataKeys(ctx, withRootKeyId(rootKeyId), WithReader(reader)) if err != nil { return fmt.Errorf("%s: unable to list the current data keys: %w", op, err) } for _, dk := range dks { var versions []*dataKeyVersion - if err := r.list(ctx, &versions, "data_key_id = ?", []interface{}{dk.PrivateId}, withReader(reader)); err != nil { + if err := r.list(ctx, &versions, "data_key_id = ?", []interface{}{dk.PrivateId}, WithReader(reader)); err != nil { return fmt.Errorf("%s: unable to list the current data key versions: %w", op, err) } for _, v := range versions { @@ -338,7 +338,7 @@ func rotateDataKeyVersionTx(ctx context.Context, reader dbw.Reader, writer dbw.W if err != nil { return fmt.Errorf("%s: unable to create repo: %w", op, err) } - dataKeys, err := r.ListDataKeys(ctx, withPurpose(purpose), withRootKeyId(rootKeyId), withReader(reader)) + dataKeys, err := r.ListDataKeys(ctx, withPurpose(purpose), withRootKeyId(rootKeyId), WithReader(reader)) switch { case err != nil: return fmt.Errorf("%s: unable to lookup data key for %q: %w", op, purpose, err) diff --git a/extras/kms/repository_root_key_version.go b/extras/kms/repository_root_key_version.go index f9005a5a..c89135af 100644 --- a/extras/kms/repository_root_key_version.go +++ b/extras/kms/repository_root_key_version.go @@ -211,7 +211,7 @@ func rewrapRootKeyVersionsTx(ctx context.Context, reader dbw.Reader, writer dbw. return fmt.Errorf("%s: unable to create repo: %w", op, err) } // rewrap the rootKey versions using the scope's root key to find them - rkvs, err := r.ListRootKeyVersions(ctx, rootWrapper, rootKeyId, withReader(reader)) + rkvs, err := r.ListRootKeyVersions(ctx, rootWrapper, rootKeyId, WithReader(reader)) if err != nil { return fmt.Errorf("%s: unable to list root key versions: %w", op, err) }