Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tpm2: Avoid attempting to open duplicate TPM connections #360

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions internal/tpm2test/suites.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ import (
)

type tpmTestMixin struct {
TPM *secboot_tpm2.Connection
Transport *Transport
TPM *secboot_tpm2.Connection
Transport *Transport
mockConnection *secboot_tpm2.Connection // Unit tests always consume a connection, although this can be temporarily closed
}

func (m *tpmTestMixin) setUpTest(c *C, suite *tpm2_testutil.TPMTest, open func() (*secboot_tpm2.Connection, *Transport)) (cleanup func(*C)) {
Expand Down Expand Up @@ -68,15 +69,23 @@ func (m *tpmTestMixin) setUpTest(c *C, suite *tpm2_testutil.TPMTest, open func()
// tests existing underlying connection, but don't allow the code
// to fully close the connection - leave this to the test fixture.
// TODO: Support resource managed device concepts in tests.
// XXX: Set the maximum numbner of open connections to 1 when
// https://github.com/canonical/secboot/issues/353 is fixed.
internalDev := newTpmDevice(tpm2_testutil.NewTransportBackedDevice(suite.Transport, false, -1), tpm2_device.DeviceModeDirect, nil, tpm2_device.ErrNoPPI)
internalDev := newTpmDevice(tpm2_testutil.NewTransportBackedDevice(suite.Transport, false, 1), tpm2_device.DeviceModeDirect, nil, tpm2_device.ErrNoPPI)
restoreDefaultDeviceFn := MockDefaultDeviceFn(func(mode tpm2_device.DeviceMode) (tpm2_device.TPMDevice, error) {
c.Assert(mode, Equals, tpm2_device.DeviceModeDirect)
return internalDev, nil
})

// Unit tests always consume a connection
mockConn, err := secboot_tpm2.ConnectToDefaultTPM()
c.Assert(err, IsNil)
m.mockConnection = mockConn
// This connection isn't going to be used, so don't take up a loaded session slot
m.mockConnection.FlushContext(m.mockConnection.HmacSession())

return func(c *C) {
if m.mockConnection != nil {
c.Check(m.mockConnection.Close(), IsNil)
}
restoreDefaultDeviceFn()
c.Check(internalDev.TPMDevice.(*tpm2_testutil.TransportBackedDevice).NumberOpen(), Equals, 0)
c.Check(m.TPM.Close(), IsNil)
Expand All @@ -85,6 +94,27 @@ func (m *tpmTestMixin) setUpTest(c *C, suite *tpm2_testutil.TPMTest, open func()
}
}

// CloseMockConnection closes a mock connection that is opened automatically in test
// setup via the mock device, but not exposed for use by testing. The connection mocks
// the behaviour of having already called ConnectToDefaultTPM for testing APIs that
// accept an already open connection. In order to test APIs that don't already accept
// an open connection, and open their own connection instead, the test should call this
// API to temporarily close the internal mock connection.
//
// It returns a callback to re-open the mock connection again.
func (m *tpmTestMixin) CloseMockConnection(c *C) (restore func()) {
c.Assert(m.mockConnection, NotNil)
c.Check(m.mockConnection.Close(), IsNil)
m.mockConnection = nil
return func() {
mockConn, err := secboot_tpm2.ConnectToDefaultTPM()
c.Assert(err, IsNil)
m.mockConnection = mockConn
// This connection isn't going to be used, so don't take up a loaded session slot
m.mockConnection.FlushContext(m.mockConnection.HmacSession())
}
}

func (m *tpmTestMixin) reinitTPMConnectionFromExisting(c *C, suite *tpm2_testutil.TPMTest) {
tpm, transport, err := newTPMConnectionFromExistingTransport(m.TPM, m.Transport)
c.Assert(err, IsNil)
Expand Down
1 change: 1 addition & 0 deletions keydata.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@ func (d *KeyData) platformKeyData() *PlatformKeyData {
return &PlatformKeyData{
Generation: d.Generation(),
EncodedHandle: d.data.PlatformHandle,
Role: d.data.Role,
KDFAlg: crypto.Hash(d.data.KDFAlg),
AuthMode: d.AuthMode(),
}
Expand Down
61 changes: 41 additions & 20 deletions tpm2/platform.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,37 @@ import (

const platformName = "tpm2"

type platformKeyDataHandler struct{}
var mainPlatformKeyDataHandler *platformKeyDataHandler

type platformKeyDataHandler struct {
tpm *Connection
}

// tpm returns a currently open or new TPM connection. The caller mustn't call
// the Close method on the returned connection - it should call the returned close
// callback instead.
func (h *platformKeyDataHandler) tpmConnection() (conn *Connection, close func() error, err error) {
if h.tpm != nil {
return h.tpm, func() error { return nil }, nil
}

tpm, err := ConnectToDefaultTPM()
switch {
case err == ErrNoTPM2Device:
return nil, nil, &secboot.PlatformHandlerError{
Type: secboot.PlatformHandlerErrorUnavailable,
Err: err}
case err != nil:
return nil, nil, fmt.Errorf("cannot connect to TPM: %w", err)
}

h.tpm = tpm

return tpm, func() error {
h.tpm = nil
return tpm.Close()
}, nil
}

func (h *platformKeyDataHandler) recoverKeysCommon(data *secboot.PlatformKeyData, encryptedPayload, authKey []byte) ([]byte, error) {
if data.Generation < 0 || int64(data.Generation) > math.MaxUint32 {
Expand Down Expand Up @@ -67,16 +97,11 @@ func (h *platformKeyDataHandler) recoverKeysCommon(data *secboot.PlatformKeyData
Err: fmt.Errorf("invalid key data version: %d", k.data.Version())}
}

tpm, err := ConnectToDefaultTPM()
switch {
case err == ErrNoTPM2Device:
return nil, &secboot.PlatformHandlerError{
Type: secboot.PlatformHandlerErrorUnavailable,
Err: err}
case err != nil:
return nil, xerrors.Errorf("cannot connect to TPM: %w", err)
tpm, closeTpm, err := h.tpmConnection()
if err != nil {
return nil, err
}
defer tpm.Close()
defer closeTpm()

symKey, err := k.unsealDataFromTPM(tpm.TPMContext, authKey, tpm.HmacSession())
if err != nil {
Expand Down Expand Up @@ -121,16 +146,11 @@ func (h *platformKeyDataHandler) RecoverKeysWithAuthKey(data *secboot.PlatformKe
}

func (h *platformKeyDataHandler) ChangeAuthKey(data *secboot.PlatformKeyData, old, new []byte) ([]byte, error) {
tpm, err := ConnectToDefaultTPM()
switch {
case err == ErrNoTPM2Device:
return nil, &secboot.PlatformHandlerError{
Type: secboot.PlatformHandlerErrorUnavailable,
Err: err}
case err != nil:
return nil, xerrors.Errorf("cannot connect to TPM: %w", err)
tpm, closeTpm, err := h.tpmConnection()
if err != nil {
return nil, err
}
defer tpm.Close()
defer closeTpm()

var k *SealedKeyData
if err := json.Unmarshal(data.EncodedHandle, &k); err != nil {
Expand Down Expand Up @@ -231,5 +251,6 @@ func (h *platformKeyDataHandler) ChangeAuthKey(data *secboot.PlatformKeyData, ol
}

func init() {
secboot.RegisterPlatformKeyDataHandler(platformName, &platformKeyDataHandler{})
mainPlatformKeyDataHandler = new(platformKeyDataHandler)
secboot.RegisterPlatformKeyDataHandler(platformName, mainPlatformKeyDataHandler)
}
8 changes: 8 additions & 0 deletions tpm2/platform_legacy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ func (s *platformLegacySuite) TestRecoverKeys(c *C) {
PCRPolicyCounterHandle: tpm2.HandleNull})
c.Check(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

k, err := NewKeyDataFromSealedKeyObjectFile(keyFile)
c.Assert(err, IsNil)

Expand Down Expand Up @@ -108,6 +110,8 @@ func (s *platformLegacySuite) TestRecoverKeysInvalidPCRPolicy(c *C) {
_, err = s.TPM().PCREvent(s.TPM().PCRHandleContext(7), tpm2.Event("foo"), nil)
c.Check(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

k, err := NewKeyDataFromSealedKeyObjectFile(keyFile)
c.Assert(err, IsNil)

Expand All @@ -129,6 +133,8 @@ func (s *platformLegacySuite) TestRecoverKeysTPMLockout(c *C) {
// Put the TPM in DA lockout mode
c.Check(s.TPM().DictionaryAttackParameters(s.TPM().LockoutHandleContext(), 0, 7200, 86400, nil), IsNil)

s.AddCleanup(s.CloseMockConnection(c))

k, err := NewKeyDataFromSealedKeyObjectFile(keyFile)
c.Assert(err, IsNil)

Expand All @@ -152,6 +158,8 @@ func (s *platformLegacySuite) TestRecoverKeysErrTPMProvisioning(c *C) {
s.EvictControl(c, tpm2.HandleOwner, srk, srk.Handle())
s.HierarchyChangeAuth(c, tpm2.HandleOwner, []byte("foo"))

s.AddCleanup(s.CloseMockConnection(c))

k, err := NewKeyDataFromSealedKeyObjectFile(keyFile)
c.Assert(err, IsNil)

Expand Down
28 changes: 28 additions & 0 deletions tpm2/platform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ func (s *platformSuite) TestRecoverKeysIntegrated(c *C) {
k, primaryKey, unlockKey, err := NewTPMProtectedKey(s.TPM(), params)
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

unlockKeyUnsealed, primaryKeyUnsealed, err := k.RecoverKeys()
c.Check(err, IsNil)
c.Check(unlockKeyUnsealed, DeepEquals, unlockKey)
Expand All @@ -102,6 +104,8 @@ func (s *platformSuite) TestRecoverKeysWithPassphraseIntegrated(c *C) {
k, primaryKey, unlockKey, err := NewTPMPassphraseProtectedKey(s.TPM(), passphraseParams, "passphrase")
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

unlockKeyUnsealed, primaryKeyUnsealed, err := k.RecoverKeysWithPassphrase("passphrase")
c.Check(err, IsNil)
c.Check(unlockKeyUnsealed, DeepEquals, unlockKey)
Expand All @@ -123,6 +127,8 @@ func (s *platformSuite) TestRecoverKeysWithPassphraseIntegratedPBKDF2(c *C) {
k, primaryKey, unlockKey, err := NewTPMPassphraseProtectedKey(s.TPM(), passphraseParams, "passphrase")
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

unlockKeyUnsealed, primaryKeyUnsealed, err := k.RecoverKeysWithPassphrase("passphrase")
c.Check(err, IsNil)
c.Check(unlockKeyUnsealed, DeepEquals, unlockKey)
Expand All @@ -146,6 +152,8 @@ func (s *platformSuite) TestRecoverKeysWithBadPassphraseIntegrated(c *C) {
k, _, _, err := NewTPMPassphraseProtectedKey(s.TPM(), passphraseParams, "passphrase")
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

_, _, err = k.RecoverKeysWithPassphrase("1234")
c.Check(err, Equals, secboot.ErrInvalidPassphrase)
}
Expand All @@ -164,6 +172,8 @@ func (s *platformSuite) TestChangePassphraseIntegrated(c *C) {
k, primaryKey, unlockKey, err := NewTPMPassphraseProtectedKey(s.TPM(), passphraseParams, "passphrase")
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

c.Check(k.ChangePassphrase("passphrase", "1234"), IsNil)

unlockKeyUnsealed, primaryKeyUnsealed, err := k.RecoverKeysWithPassphrase("1234")
Expand All @@ -186,6 +196,8 @@ func (s *platformSuite) TestChangePassphraseWithBadPassphraseIntegrated(c *C) {
k, primaryKey, unlockKey, err := NewTPMPassphraseProtectedKey(s.TPM(), passphraseParams, "passphrase")
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

c.Check(k.ChangePassphrase("1234", "1234"), Equals, secboot.ErrInvalidPassphrase)

unlockKeyUnsealed, primaryKeyUnsealed, err := k.RecoverKeysWithPassphrase("passphrase")
Expand All @@ -211,6 +223,8 @@ func (s *platformSuite) testRecoverKeys(c *C, params *ProtectKeyParams) {
k, primaryKey, unlockKey, err := NewTPMProtectedKey(s.TPM(), params)
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down Expand Up @@ -277,6 +291,8 @@ func (s *platformSuite) testRecoverKeysNoValidSRK(c *C, prepareSrk func()) {

prepareSrk()

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down Expand Up @@ -331,6 +347,8 @@ func (s *platformSuite) testRecoverKeysImportable(c *C, params *ProtectKeyParams
k, primaryKey, unlockKey, err := NewExternalTPMProtectedKey(srkPub, params)
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down Expand Up @@ -405,6 +423,8 @@ func (s *platformSuite) testRecoverKeysUnsealErrorHandling(c *C, prepare func(*s

prepare(k, primaryKey)

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down Expand Up @@ -530,6 +550,8 @@ func (s *platformSuite) TestRecoverKeysWithAuthKey(c *C) {
k, primaryKey, unlockKey, err := NewTPMProtectedKey(s.TPM(), params)
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down Expand Up @@ -613,6 +635,8 @@ func (s *platformSuite) TestRecoverKeysWithIncorrectAuthKey(c *C) {
k, _, _, err := NewTPMProtectedKey(s.TPM(), params)
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down Expand Up @@ -691,6 +715,8 @@ func (s *platformSuite) TestChangeAuthKeyWithIncorrectAuthKey(c *C) {
k, _, _, err := NewTPMProtectedKey(s.TPM(), params)
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down Expand Up @@ -771,6 +797,8 @@ func (s *platformSuite) TestRecoverKeysWithAuthKeyTPMLockout(c *C) {
k, _, _, err := NewTPMProtectedKey(s.TPM(), params)
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down
9 changes: 7 additions & 2 deletions tpm2/seal.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,13 @@ func makeKeyDataNoAuth(skd *SealedKeyData, role string, encryptedPayload []byte,
})
}

func makeKeyDataWithPassphraseConstructor(kdfOptions secboot.KDFOptions, passphrase string) keyDataConstructor {
func makeKeyDataWithPassphraseConstructor(tpm *Connection, kdfOptions secboot.KDFOptions, passphrase string) keyDataConstructor {
return func(skd *SealedKeyData, role string, encryptedPayload []byte, kdfAlg crypto.Hash) (*secboot.KeyData, error) {
// Avoid trying to reopen another connection when setting the user auth value
orig := mainPlatformKeyDataHandler.tpm
mainPlatformKeyDataHandler.tpm = tpm
defer func() { mainPlatformKeyDataHandler.tpm = orig }()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

depending how things are used this could create a data race, do we need a global mutex around some of the handler usage?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that might be a good idea - I'll add that. We might want to actually hold the lock throughout the whole code section that runs with the connection set, so that any other code that wants to set the connection will have to wait, else things might get a bit confusing if there are multiple connections around.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whilst trying to figure out the best way to do this, and given that I'd like to reduce existing data races, I ended up taking a different approach by permitting the open connection to be passed to NewKeyDataWithPassphrase and subsequently back to the tpm2 platform handler via a new context argument to ChangeAuthKey. It seemed like the safest way to do it in the end.


return secbootNewKeyDataWithPassphrase(&secboot.KeyWithPassphraseParams{
KeyParams: secboot.KeyParams{
Handle: skd,
Expand Down Expand Up @@ -306,5 +311,5 @@ func NewTPMPassphraseProtectedKey(tpm *Connection, params *PassphraseProtectKeyP
AuthMode: secboot.AuthModePassphrase,
Role: params.Role,
PcrProfile: params.PCRProfile,
}, sealer, makeKeyDataWithPassphraseConstructor(params.KDFOptions, passphrase), tpm.HmacSession())
}, sealer, makeKeyDataWithPassphraseConstructor(tpm, params.KDFOptions, passphrase), tpm.HmacSession())
}
Loading
Loading