Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tpm2: Avoid attempting to open duplicate TPM connections #360

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bootscope/keydata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,7 @@ func (h *mockPlatformKeyDataHandler) RecoverKeysWithAuthKey(data *PlatformKeyDat
return h.recoverKeys(handle, encryptedPayload)
}

func (h *mockPlatformKeyDataHandler) ChangeAuthKey(data *PlatformKeyData, old, new []byte) ([]byte, error) {
func (h *mockPlatformKeyDataHandler) ChangeAuthKey(data *PlatformKeyData, old, new []byte, context any) ([]byte, error) {
if err := h.checkState(); err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion hooks/platform.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (*hooksPlatform) RecoverKeysWithAuthKey(data *secboot.PlatformKeyData, encr
return nil, errors.New("unsupported action")
}

func (*hooksPlatform) ChangeAuthKey(data *secboot.PlatformKeyData, old, new []byte) ([]byte, error) {
func (*hooksPlatform) ChangeAuthKey(data *secboot.PlatformKeyData, old, new []byte, context any) ([]byte, error) {
return nil, errors.New("unsupported action")
}

Expand Down
40 changes: 35 additions & 5 deletions internal/tpm2test/suites.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ import (
)

type tpmTestMixin struct {
TPM *secboot_tpm2.Connection
Transport *Transport
TPM *secboot_tpm2.Connection
Transport *Transport
mockConnection *secboot_tpm2.Connection // Unit tests always consume a connection, although this can be temporarily closed
}

func (m *tpmTestMixin) setUpTest(c *C, suite *tpm2_testutil.TPMTest, open func() (*secboot_tpm2.Connection, *Transport)) (cleanup func(*C)) {
Expand Down Expand Up @@ -68,15 +69,23 @@ func (m *tpmTestMixin) setUpTest(c *C, suite *tpm2_testutil.TPMTest, open func()
// tests existing underlying connection, but don't allow the code
// to fully close the connection - leave this to the test fixture.
// TODO: Support resource managed device concepts in tests.
// XXX: Set the maximum numbner of open connections to 1 when
// https://github.com/canonical/secboot/issues/353 is fixed.
internalDev := newTpmDevice(tpm2_testutil.NewTransportBackedDevice(suite.Transport, false, -1), tpm2_device.DeviceModeDirect, nil, tpm2_device.ErrNoPPI)
internalDev := newTpmDevice(tpm2_testutil.NewTransportBackedDevice(suite.Transport, false, 1), tpm2_device.DeviceModeDirect, nil, tpm2_device.ErrNoPPI)
restoreDefaultDeviceFn := MockDefaultDeviceFn(func(mode tpm2_device.DeviceMode) (tpm2_device.TPMDevice, error) {
c.Assert(mode, Equals, tpm2_device.DeviceModeDirect)
return internalDev, nil
})

// Unit tests always consume a connection
mockConn, err := secboot_tpm2.ConnectToDefaultTPM()
c.Assert(err, IsNil)
m.mockConnection = mockConn
// This connection isn't going to be used, so don't take up a loaded session slot
m.mockConnection.FlushContext(m.mockConnection.HmacSession())

return func(c *C) {
if m.mockConnection != nil {
c.Check(m.mockConnection.Close(), IsNil)
}
restoreDefaultDeviceFn()
c.Check(internalDev.TPMDevice.(*tpm2_testutil.TransportBackedDevice).NumberOpen(), Equals, 0)
c.Check(m.TPM.Close(), IsNil)
Expand All @@ -85,6 +94,27 @@ func (m *tpmTestMixin) setUpTest(c *C, suite *tpm2_testutil.TPMTest, open func()
}
}

// CloseMockConnection closes a mock connection that is opened automatically in test
// setup via the mock device, but not exposed for use by testing. The connection mocks
// the behaviour of having already called ConnectToDefaultTPM for testing APIs that
// accept an already open connection. In order to test APIs that don't already accept
// an open connection, and open their own connection instead, the test should call this
// API to temporarily close the internal mock connection.
//
// It returns a callback to re-open the mock connection again.
func (m *tpmTestMixin) CloseMockConnection(c *C) (restore func()) {
c.Assert(m.mockConnection, NotNil)
c.Check(m.mockConnection.Close(), IsNil)
m.mockConnection = nil
return func() {
mockConn, err := secboot_tpm2.ConnectToDefaultTPM()
c.Assert(err, IsNil)
m.mockConnection = mockConn
// This connection isn't going to be used, so don't take up a loaded session slot
m.mockConnection.FlushContext(m.mockConnection.HmacSession())
}
}

func (m *tpmTestMixin) reinitTPMConnectionFromExisting(c *C, suite *tpm2_testutil.TPMTest) {
tpm, transport, err := newTPMConnectionFromExistingTransport(m.TPM, m.Transport)
c.Assert(err, IsNil)
Expand Down
15 changes: 11 additions & 4 deletions keydata.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,12 @@ type KeyWithPassphraseParams struct {
// AuthKeySize is the size of key to derive from the passphrase for
// use by the platform implementation.
AuthKeySize int

// ChangeAuthKeyContext can be set to the caller to any arbitrary value
// that should be passed to the initial call to
// [PlatformKeyDataHandler.ChangeAuthKey]. The main use for this is to
// permit the tpm2 package to supply an open TPM connection.
ChangeAuthKeyContext any
}

// KeyID is the unique ID for a KeyData object. It is used to facilitate the
Expand Down Expand Up @@ -471,7 +477,7 @@ func (d *KeyData) derivePassphraseKeys(passphrase string) (key, iv, auth []byte,
return key, iv, auth, nil
}

func (d *KeyData) updatePassphrase(payload, oldAuthKey []byte, passphrase string) error {
func (d *KeyData) updatePassphrase(payload, oldAuthKey []byte, passphrase string, platformContext any) error {
handler := handlers[d.data.PlatformName]
if handler == nil {
return ErrNoPlatformHandlerRegistered
Expand All @@ -487,7 +493,7 @@ func (d *KeyData) updatePassphrase(payload, oldAuthKey []byte, passphrase string
return fmt.Errorf("unexpected encryption algorithm \"%s\"", d.data.PassphraseParams.Encryption)
}

handle, err := handler.ChangeAuthKey(d.platformKeyData(), oldAuthKey, authKey)
handle, err := handler.ChangeAuthKey(d.platformKeyData(), oldAuthKey, authKey, platformContext)
if err != nil {
return err
}
Expand Down Expand Up @@ -533,6 +539,7 @@ func (d *KeyData) platformKeyData() *PlatformKeyData {
return &PlatformKeyData{
Generation: d.Generation(),
EncodedHandle: d.data.PlatformHandle,
Role: d.data.Role,
KDFAlg: crypto.Hash(d.data.KDFAlg),
AuthMode: d.AuthMode(),
}
Expand Down Expand Up @@ -702,7 +709,7 @@ func (d *KeyData) ChangePassphrase(oldPassphrase, newPassphrase string) error {
return err
}

if err := d.updatePassphrase(payload, oldKey, newPassphrase); err != nil {
if err := d.updatePassphrase(payload, oldKey, newPassphrase, nil); err != nil {
return processPlatformHandlerError(err)
}

Expand Down Expand Up @@ -796,7 +803,7 @@ func NewKeyDataWithPassphrase(params *KeyWithPassphraseParams, passphrase string
AuthKeySize: params.AuthKeySize,
}

if err := kd.updatePassphrase(kd.data.EncryptedPayload, make([]byte, params.AuthKeySize), passphrase); err != nil {
if err := kd.updatePassphrase(kd.data.EncryptedPayload, make([]byte, params.AuthKeySize), passphrase, params.ChangeAuthKeyContext); err != nil {
return nil, xerrors.Errorf("cannot set passphrase: %w", err)
}

Expand Down
24 changes: 20 additions & 4 deletions keydata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"io"
"io/ioutil"
"math/rand"
"reflect"
"time"

. "github.com/snapcore/secboot"
Expand Down Expand Up @@ -160,7 +161,11 @@ func (h *mockPlatformKeyDataHandler) RecoverKeysWithAuthKey(data *PlatformKeyDat
return h.recoverKeys(handle, encryptedPayload)
}

func (h *mockPlatformKeyDataHandler) ChangeAuthKey(data *PlatformKeyData, old, new []byte) ([]byte, error) {
type mockChangeAuthKeyContextType struct{}

var mockChangeAuthKeyContext = mockChangeAuthKeyContextType{}

func (h *mockPlatformKeyDataHandler) ChangeAuthKey(data *PlatformKeyData, old, new []byte, context any) ([]byte, error) {
if !h.passphraseSupport {
return nil, errors.New("not supported")
}
Expand All @@ -169,6 +174,16 @@ func (h *mockPlatformKeyDataHandler) ChangeAuthKey(data *PlatformKeyData, old, n
return nil, err
}

switch c := context.(type) {
case nil:
case mockChangeAuthKeyContextType:
if c != mockChangeAuthKeyContext {
return nil, errors.New("unexpected context value")
}
default:
return nil, fmt.Errorf("unexpected context type: %v", reflect.TypeOf(context))
}

handle, err := h.unmarshalHandle(data)
if err != nil {
return nil, err
Expand Down Expand Up @@ -356,9 +371,10 @@ func (s *keyDataTestBase) mockProtectKeysWithPassphrase(c *C, primaryKey Primary
}

kpp := &KeyWithPassphraseParams{
KeyParams: *kp,
KDFOptions: kdfOptions,
AuthKeySize: authKeySize,
KeyParams: *kp,
KDFOptions: kdfOptions,
AuthKeySize: authKeySize,
ChangeAuthKeyContext: mockChangeAuthKeyContext,
}

return kpp, unlockKey
Expand Down
2 changes: 1 addition & 1 deletion plainkey/platform.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (*platformKeyDataHandler) RecoverKeysWithAuthKey(data *secboot.PlatformKeyD
return nil, errors.New("unsupported action")
}

func (*platformKeyDataHandler) ChangeAuthKey(data *secboot.PlatformKeyData, old, new []byte) ([]byte, error) {
func (*platformKeyDataHandler) ChangeAuthKey(data *secboot.PlatformKeyData, old, new []byte, context any) ([]byte, error) {
return nil, errors.New("unsupported action")
}

Expand Down
5 changes: 4 additions & 1 deletion platform.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,11 @@ type PlatformKeyDataHandler interface {
// keys. Either value can be nil if passphrase authentication is being enabled (
// where old will be nil) or disabled (where new will be nil).
//
// The use of the context argument isn't defined here - it's passed during
// key construction and the platform is free to use it however it likes.
//
// On success, it should return an updated handle.
ChangeAuthKey(data *PlatformKeyData, old, new []byte) ([]byte, error)
ChangeAuthKey(data *PlatformKeyData, old, new []byte, context any) ([]byte, error)
}

var handlers = make(map[string]PlatformKeyDataHandler)
Expand Down
35 changes: 22 additions & 13 deletions tpm2/platform.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (h *platformKeyDataHandler) recoverKeysCommon(data *secboot.PlatformKeyData
Type: secboot.PlatformHandlerErrorUnavailable,
Err: err}
case err != nil:
return nil, xerrors.Errorf("cannot connect to TPM: %w", err)
return nil, fmt.Errorf("cannot connect to TPM: %w", err)
}
defer tpm.Close()

Expand Down Expand Up @@ -120,17 +120,26 @@ func (h *platformKeyDataHandler) RecoverKeysWithAuthKey(data *secboot.PlatformKe
return h.recoverKeysCommon(data, encryptedPayload, key)
}

func (h *platformKeyDataHandler) ChangeAuthKey(data *secboot.PlatformKeyData, old, new []byte) ([]byte, error) {
tpm, err := ConnectToDefaultTPM()
switch {
case err == ErrNoTPM2Device:
return nil, &secboot.PlatformHandlerError{
Type: secboot.PlatformHandlerErrorUnavailable,
Err: err}
case err != nil:
return nil, xerrors.Errorf("cannot connect to TPM: %w", err)
func (h *platformKeyDataHandler) ChangeAuthKey(data *secboot.PlatformKeyData, old, new []byte, context any) ([]byte, error) {
var tpm *Connection
switch c := context.(type) {
case *Connection:
tpm = c
}

if tpm == nil {
var err error
tpm, err = ConnectToDefaultTPM()
switch {
case err == ErrNoTPM2Device:
return nil, &secboot.PlatformHandlerError{
Type: secboot.PlatformHandlerErrorUnavailable,
Err: err}
case err != nil:
return nil, fmt.Errorf("cannot connect to TPM: %w", err)
}
defer tpm.Close()
}
defer tpm.Close()

var k *SealedKeyData
if err := json.Unmarshal(data.EncodedHandle, &k); err != nil {
Expand All @@ -148,7 +157,7 @@ func (h *platformKeyDataHandler) ChangeAuthKey(data *secboot.PlatformKeyData, ol
}

// Validate the initial key data
_, err = k.validateData(tpm.TPMContext, data.Role)
_, err := k.validateData(tpm.TPMContext, data.Role)
switch {
case isKeyDataError(err):
return nil, &secboot.PlatformHandlerError{
Expand Down Expand Up @@ -231,5 +240,5 @@ func (h *platformKeyDataHandler) ChangeAuthKey(data *secboot.PlatformKeyData, ol
}

func init() {
secboot.RegisterPlatformKeyDataHandler(platformName, &platformKeyDataHandler{})
secboot.RegisterPlatformKeyDataHandler(platformName, new(platformKeyDataHandler))
}
2 changes: 1 addition & 1 deletion tpm2/platform_legacy.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (h *legacyPlatformKeyDataHandler) RecoverKeysWithAuthKey(data *secboot.Plat
return nil, fmt.Errorf("passphrase authentication is not supported for the %s platform", legacyPlatformName)
}

func (h *legacyPlatformKeyDataHandler) ChangeAuthKey(data *secboot.PlatformKeyData, old, new []byte) ([]byte, error) {
func (h *legacyPlatformKeyDataHandler) ChangeAuthKey(data *secboot.PlatformKeyData, old, new []byte, context any) ([]byte, error) {
return nil, fmt.Errorf("passphrase authentication is not supported for the %s platform", legacyPlatformName)
}

Expand Down
8 changes: 8 additions & 0 deletions tpm2/platform_legacy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ func (s *platformLegacySuite) TestRecoverKeys(c *C) {
PCRPolicyCounterHandle: tpm2.HandleNull})
c.Check(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

k, err := NewKeyDataFromSealedKeyObjectFile(keyFile)
c.Assert(err, IsNil)

Expand Down Expand Up @@ -108,6 +110,8 @@ func (s *platformLegacySuite) TestRecoverKeysInvalidPCRPolicy(c *C) {
_, err = s.TPM().PCREvent(s.TPM().PCRHandleContext(7), tpm2.Event("foo"), nil)
c.Check(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

k, err := NewKeyDataFromSealedKeyObjectFile(keyFile)
c.Assert(err, IsNil)

Expand All @@ -129,6 +133,8 @@ func (s *platformLegacySuite) TestRecoverKeysTPMLockout(c *C) {
// Put the TPM in DA lockout mode
c.Check(s.TPM().DictionaryAttackParameters(s.TPM().LockoutHandleContext(), 0, 7200, 86400, nil), IsNil)

s.AddCleanup(s.CloseMockConnection(c))

k, err := NewKeyDataFromSealedKeyObjectFile(keyFile)
c.Assert(err, IsNil)

Expand All @@ -152,6 +158,8 @@ func (s *platformLegacySuite) TestRecoverKeysErrTPMProvisioning(c *C) {
s.EvictControl(c, tpm2.HandleOwner, srk, srk.Handle())
s.HierarchyChangeAuth(c, tpm2.HandleOwner, []byte("foo"))

s.AddCleanup(s.CloseMockConnection(c))

k, err := NewKeyDataFromSealedKeyObjectFile(keyFile)
c.Assert(err, IsNil)

Expand Down
Loading
Loading