From 2ad0fe8dd25f03828bd5e506f9b7b7d9aa3cbe72 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Tue, 28 Aug 2018 14:49:19 -0400 Subject: [PATCH] Move logic around a bit to avoid holding locks when not necessary Also, ensure we are error checking the rand call --- vault/barrier_aes_gcm.go | 154 +++++++++++++++++++++------------- vault/barrier_aes_gcm_test.go | 21 ++++- 2 files changed, 112 insertions(+), 63 deletions(-) diff --git a/vault/barrier_aes_gcm.go b/vault/barrier_aes_gcm.go index 001bd3b71a5f..c3912cebd122 100644 --- a/vault/barrier_aes_gcm.go +++ b/vault/barrier_aes_gcm.go @@ -7,6 +7,7 @@ import ( "crypto/rand" "crypto/subtle" "encoding/binary" + "errors" "fmt" "strings" "sync" @@ -154,7 +155,10 @@ func (b *AESGCMBarrier) persistKeyring(ctx context.Context, keyring *Keyring) er } // Encrypt the barrier init value - value := b.encrypt(keyringPath, initialKeyTerm, gcm, keyringBuf) + value, err := b.encrypt(keyringPath, initialKeyTerm, gcm, keyringBuf) + if err != nil { + return err + } // Create the keyring physical entry pe := &physical.Entry{ @@ -183,7 +187,10 @@ func (b *AESGCMBarrier) persistKeyring(ctx context.Context, keyring *Keyring) er if err != nil { return err } - value = b.encrypt(masterKeyPath, activeKey.Term, aead, keyBuf) + value, err = b.encrypt(masterKeyPath, activeKey.Term, aead, keyBuf) + if err != nil { + return err + } // Update the masterKeyPath for standby instances pe = &physical.Entry{ @@ -253,7 +260,13 @@ func (b *AESGCMBarrier) ReloadKeyring(ctx context.Context) error { // Ensure that the keyring exists. This should never happen, // and indicates something really bad has happened. if out == nil { - return fmt.Errorf("keyring unexpectedly missing") + return errors.New("keyring unexpectedly missing") + } + + // Verify the term is always just one + term := binary.BigEndian.Uint32(out.Value[:4]) + if term != initialKeyTerm { + return errors.New("term mis-match") } // Decrypt the barrier init key @@ -340,6 +353,12 @@ func (b *AESGCMBarrier) Unseal(ctx context.Context, key []byte) error { return errwrap.Wrapf("failed to check for keyring: {{err}}", err) } if out != nil { + // Verify the term is always just one + term := binary.BigEndian.Uint32(out.Value[:4]) + if term != initialKeyTerm { + return errors.New("term mis-match") + } + // Decrypt the barrier init key plain, err := b.decrypt(keyringPath, gcm, out.Value) defer memzero(plain) @@ -371,6 +390,12 @@ func (b *AESGCMBarrier) Unseal(ctx context.Context, key []byte) error { return ErrBarrierNotInit } + // Verify the term is always just one + term := binary.BigEndian.Uint32(out.Value[:4]) + if term != initialKeyTerm { + return errors.New("term mis-match") + } + // Decrypt the barrier init key plain, err := b.decrypt(barrierInitPath, gcm, out.Value) if err != nil { @@ -494,7 +519,10 @@ func (b *AESGCMBarrier) CreateUpgrade(ctx context.Context, term uint32) error { } key := fmt.Sprintf("%s%d", keyringUpgradePrefix, prevTerm) - value := b.encrypt(key, prevTerm, primary, buf) + value, err := b.encrypt(key, prevTerm, primary, buf) + if err != nil { + return err + } // Create upgrade key pe := &physical.Entry{ Key: key, @@ -637,20 +665,25 @@ func (b *AESGCMBarrier) updateMasterKeyCommon(key []byte) (*Keyring, error) { func (b *AESGCMBarrier) Put(ctx context.Context, entry *Entry) error { defer metrics.MeasureSince([]string{"barrier", "put"}, time.Now()) b.l.RLock() - defer b.l.RUnlock() if b.sealed { + b.l.RUnlock() return ErrBarrierSealed } term := b.keyring.ActiveTerm() primary, err := b.aeadForTerm(term) + b.l.RUnlock() if err != nil { return err } + value, err := b.encrypt(entry.Key, term, primary, entry.Value) + if err != nil { + return err + } pe := &physical.Entry{ Key: entry.Key, - Value: b.encrypt(entry.Key, term, primary, entry.Value), + Value: value, SealWrap: entry.SealWrap, } return b.backend.Put(ctx, pe) @@ -660,21 +693,38 @@ func (b *AESGCMBarrier) Put(ctx context.Context, entry *Entry) error { func (b *AESGCMBarrier) Get(ctx context.Context, key string) (*Entry, error) { defer metrics.MeasureSince([]string{"barrier", "get"}, time.Now()) b.l.RLock() - defer b.l.RUnlock() if b.sealed { + b.l.RUnlock() return nil, ErrBarrierSealed } // Read the key from the backend pe, err := b.backend.Get(ctx, key) if err != nil { + b.l.RUnlock() return nil, err } else if pe == nil { + b.l.RUnlock() return nil, nil } + // Verify the term + term := binary.BigEndian.Uint32(pe.Value[:4]) + + // Get the GCM by term + // It is expensive to do this first but it is not a + // normal case that this won't match + gcm, err := b.aeadForTerm(term) + b.l.RUnlock() + if err != nil { + return nil, err + } + if gcm == nil { + return nil, fmt.Errorf("no decryption key available for term %d", term) + } + // Decrypt the ciphertext - plain, err := b.decryptKeyring(key, pe.Value) + plain, err := b.decrypt(key, gcm, pe.Value) if err != nil { return nil, errwrap.Wrapf("decryption failed: {{err}}", err) } @@ -692,8 +742,9 @@ func (b *AESGCMBarrier) Get(ctx context.Context, key string) (*Entry, error) { func (b *AESGCMBarrier) Delete(ctx context.Context, key string) error { defer metrics.MeasureSince([]string{"barrier", "delete"}, time.Now()) b.l.RLock() - defer b.l.RUnlock() - if b.sealed { + sealed := b.sealed + b.l.RUnlock() + if sealed { return ErrBarrierSealed } @@ -705,8 +756,9 @@ func (b *AESGCMBarrier) Delete(ctx context.Context, key string) error { func (b *AESGCMBarrier) List(ctx context.Context, prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"barrier", "list"}, time.Now()) b.l.RLock() - defer b.l.RUnlock() - if b.sealed { + sealed := b.sealed + b.l.RUnlock() + if sealed { return nil, ErrBarrierSealed } @@ -765,7 +817,7 @@ func (b *AESGCMBarrier) aeadFromKey(key []byte) (cipher.AEAD, error) { } // encrypt is used to encrypt a value -func (b *AESGCMBarrier) encrypt(path string, term uint32, gcm cipher.AEAD, plain []byte) []byte { +func (b *AESGCMBarrier) encrypt(path string, term uint32, gcm cipher.AEAD, plain []byte) ([]byte, error) { // Allocate the output buffer with room for tern, version byte, // nonce, GCM tag and the plaintext capacity := termSize + 1 + gcm.NonceSize() + gcm.Overhead() + len(plain) @@ -780,7 +832,13 @@ func (b *AESGCMBarrier) encrypt(path string, term uint32, gcm cipher.AEAD, plain // Generate a random nonce nonce := out[5 : 5+gcm.NonceSize()] - rand.Read(nonce) + n, err := rand.Read(nonce) + if err != nil { + return nil, err + } + if n != len(nonce) { + return nil, errors.New("unable to read enough random bytes to fill gcm nonce") + } // Seal the output switch b.currentAESGCMVersionByte { @@ -792,53 +850,16 @@ func (b *AESGCMBarrier) encrypt(path string, term uint32, gcm cipher.AEAD, plain panic("Unknown AESGCM version") } - return out + return out, nil } -// decrypt is used to decrypt a value +// decrypt is used to decrypt a value using the keyring func (b *AESGCMBarrier) decrypt(path string, gcm cipher.AEAD, cipher []byte) ([]byte, error) { - // Verify the term is always just one - term := binary.BigEndian.Uint32(cipher[:4]) - if term != initialKeyTerm { - return nil, fmt.Errorf("term mis-match") - } - // Capture the parts nonce := cipher[5 : 5+gcm.NonceSize()] raw := cipher[5+gcm.NonceSize():] out := make([]byte, 0, len(raw)-gcm.NonceSize()) - // Verify the cipher byte and attempt to open - switch cipher[4] { - case AESGCMVersion1: - return gcm.Open(out, nonce, raw, nil) - case AESGCMVersion2: - return gcm.Open(out, nonce, raw, []byte(path)) - default: - return nil, fmt.Errorf("version bytes mis-match") - } -} - -// decryptKeyring is used to decrypt a value using the keyring -func (b *AESGCMBarrier) decryptKeyring(path string, cipher []byte) ([]byte, error) { - // Verify the term - term := binary.BigEndian.Uint32(cipher[:4]) - - // Get the GCM by term - // It is expensive to do this first but it is not a - // normal case that this won't match - gcm, err := b.aeadForTerm(term) - if err != nil { - return nil, err - } - if gcm == nil { - return nil, fmt.Errorf("no decryption key available for term %d", term) - } - - nonce := cipher[5 : 5+gcm.NonceSize()] - raw := cipher[5+gcm.NonceSize():] - out := make([]byte, 0, len(raw)-gcm.NonceSize()) - // Attempt to open switch cipher[4] { case AESGCMVersion1: @@ -860,13 +881,15 @@ func (b *AESGCMBarrier) Encrypt(ctx context.Context, key string, plaintext []byt term := b.keyring.ActiveTerm() primary, err := b.aeadForTerm(term) + b.l.RUnlock() if err != nil { - b.l.RUnlock() return nil, err } - ciphertext := b.encrypt(key, term, primary, plaintext) - b.l.RUnlock() + ciphertext, err := b.encrypt(key, term, primary, plaintext) + if err != nil { + return nil, err + } return ciphertext, nil } @@ -878,14 +901,27 @@ func (b *AESGCMBarrier) Decrypt(ctx context.Context, key string, ciphertext []by return nil, ErrBarrierSealed } + // Verify the term + term := binary.BigEndian.Uint32(ciphertext[:4]) + + // Get the GCM by term + // It is expensive to do this first but it is not a + // normal case that this won't match + gcm, err := b.aeadForTerm(term) + b.l.RUnlock() + if err != nil { + return nil, err + } + if gcm == nil { + return nil, fmt.Errorf("no decryption key available for term %d", term) + } + // Decrypt the ciphertext - plain, err := b.decryptKeyring(key, ciphertext) + plain, err := b.decrypt(key, gcm, ciphertext) if err != nil { - b.l.RUnlock() return nil, errwrap.Wrapf("decryption failed: {{err}}", err) } - b.l.RUnlock() return plain, nil } diff --git a/vault/barrier_aes_gcm_test.go b/vault/barrier_aes_gcm_test.go index 71270e262649..33a6a83c90dd 100644 --- a/vault/barrier_aes_gcm_test.go +++ b/vault/barrier_aes_gcm_test.go @@ -125,7 +125,10 @@ func TestAESGCMBarrier_BackwardsCompatible(t *testing.T) { // Protect with master key master, _ := b.GenerateKey() gcm, _ := b.aeadFromKey(master) - value := b.encrypt(barrierInitPath, initialKeyTerm, gcm, buf) + value, err := b.encrypt(barrierInitPath, initialKeyTerm, gcm, buf) + if err != nil { + t.Fatal(err) + } // Write to the physical backend pe := &physical.Entry{ @@ -136,9 +139,13 @@ func TestAESGCMBarrier_BackwardsCompatible(t *testing.T) { // Create a fake key gcm, _ = b.aeadFromKey(encrypt) + value, err = b.encrypt("test/foo", initialKeyTerm, gcm, []byte("test")) + if err != nil { + t.Fatal(err) + } pe = &physical.Entry{ Key: "test/foo", - Value: b.encrypt("test/foo", initialKeyTerm, gcm, []byte("test")), + Value: value, } inm.Put(context.Background(), pe) @@ -429,8 +436,14 @@ func TestEncrypt_Unique(t *testing.T) { term := b.keyring.ActiveTerm() primary, _ := b.aeadForTerm(term) - first := b.encrypt("test", term, primary, entry.Value) - second := b.encrypt("test", term, primary, entry.Value) + first, err := b.encrypt("test", term, primary, entry.Value) + if err != nil { + t.Fatal(err) + } + second, err := b.encrypt("test", term, primary, entry.Value) + if err != nil { + t.Fatal(err) + } if bytes.Equal(first, second) == true { t.Fatalf("improper random seeding detected")