Skip to content

Commit

Permalink
tpm2: Avoid attempting to avoid opening duplicate TPM connections
Browse files Browse the repository at this point in the history
When using NewTPMPassphraseProtectedKey, it is passed an already open
connection. But the core secboot package calls back into the platform
handler to set the user auth value, which opens a second connection.
When using the direct device (/dev/tpm0), this doesn't work.

This updates the test suite to make sure that this scenario is caught
and results in an error. It also fixes the problem by providing the open
TPM connection to the singleton tpm2 platform handler so that it uses
this rather than trying to open a new connection.
  • Loading branch information
chrisccoulson committed Jan 16, 2025
1 parent 31629d7 commit 610b686
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 27 deletions.
40 changes: 35 additions & 5 deletions internal/tpm2test/suites.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ import (
)

type tpmTestMixin struct {
TPM *secboot_tpm2.Connection
Transport *Transport
TPM *secboot_tpm2.Connection
Transport *Transport
mockConnection *secboot_tpm2.Connection // Unit tests always consume a connection, although this can be temporarily closed
}

func (m *tpmTestMixin) setUpTest(c *C, suite *tpm2_testutil.TPMTest, open func() (*secboot_tpm2.Connection, *Transport)) (cleanup func(*C)) {
Expand Down Expand Up @@ -68,15 +69,23 @@ func (m *tpmTestMixin) setUpTest(c *C, suite *tpm2_testutil.TPMTest, open func()
// tests existing underlying connection, but don't allow the code
// to fully close the connection - leave this to the test fixture.
// TODO: Support resource managed device concepts in tests.
// XXX: Set the maximum numbner of open connections to 1 when
// https://github.com/canonical/secboot/issues/353 is fixed.
internalDev := newTpmDevice(tpm2_testutil.NewTransportBackedDevice(suite.Transport, false, -1), tpm2_device.DeviceModeDirect, nil, tpm2_device.ErrNoPPI)
internalDev := newTpmDevice(tpm2_testutil.NewTransportBackedDevice(suite.Transport, false, 1), tpm2_device.DeviceModeDirect, nil, tpm2_device.ErrNoPPI)
restoreDefaultDeviceFn := MockDefaultDeviceFn(func(mode tpm2_device.DeviceMode) (tpm2_device.TPMDevice, error) {
c.Assert(mode, Equals, tpm2_device.DeviceModeDirect)
return internalDev, nil
})

// Unit tests always consume a connection
mockConn, err := secboot_tpm2.ConnectToDefaultTPM()
c.Assert(err, IsNil)
m.mockConnection = mockConn
// This connection isn't going to be used, so don't take up a loaded session slot
m.mockConnection.FlushContext(m.mockConnection.HmacSession())

return func(c *C) {
if m.mockConnection != nil {
c.Check(m.mockConnection.Close(), IsNil)
}
restoreDefaultDeviceFn()
c.Check(internalDev.TPMDevice.(*tpm2_testutil.TransportBackedDevice).NumberOpen(), Equals, 0)
c.Check(m.TPM.Close(), IsNil)
Expand All @@ -85,6 +94,27 @@ func (m *tpmTestMixin) setUpTest(c *C, suite *tpm2_testutil.TPMTest, open func()
}
}

// CloseMockConnection closes a mock connection that is opened automatically in test
// setup via the mock device, but not exposed for use by testing. The connection mocks
// the behaviour of having already called ConnectToDefaultTPM for testing APIs that
// accept an already open connection. In order to test APIs that don't already accept
// an open connection, and open their own connection instead, the test should call this
// API to temporarily close the internal mock connection.
//
// It returns a callback to re-open the mock connection again.
func (m *tpmTestMixin) CloseMockConnection(c *C) (restore func()) {
c.Assert(m.mockConnection, NotNil)
c.Check(m.mockConnection.Close(), IsNil)
m.mockConnection = nil
return func() {
mockConn, err := secboot_tpm2.ConnectToDefaultTPM()
c.Assert(err, IsNil)
m.mockConnection = mockConn
// This connection isn't going to be used, so don't take up a loaded session slot
m.mockConnection.FlushContext(m.mockConnection.HmacSession())
}
}

func (m *tpmTestMixin) reinitTPMConnectionFromExisting(c *C, suite *tpm2_testutil.TPMTest) {
tpm, transport, err := newTPMConnectionFromExistingTransport(m.TPM, m.Transport)
c.Assert(err, IsNil)
Expand Down
61 changes: 41 additions & 20 deletions tpm2/platform.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,37 @@ import (

const platformName = "tpm2"

type platformKeyDataHandler struct{}
var mainPlatformKeyDataHandler *platformKeyDataHandler

type platformKeyDataHandler struct {
tpm *Connection
}

// tpm returns a currently open or new TPM connection. The caller mustn't call
// the Close method on the returned connection - it should call the returned close
// callback instead.
func (h *platformKeyDataHandler) tpmConnection() (conn *Connection, close func() error, err error) {
if h.tpm != nil {
return h.tpm, func() error { return nil }, nil
}

tpm, err := ConnectToDefaultTPM()
switch {
case err == ErrNoTPM2Device:
return nil, nil, &secboot.PlatformHandlerError{
Type: secboot.PlatformHandlerErrorUnavailable,
Err: err}
case err != nil:
return nil, nil, fmt.Errorf("cannot connect to TPM: %w", err)
}

h.tpm = tpm

return tpm, func() error {
h.tpm = nil
return tpm.Close()
}, nil
}

func (h *platformKeyDataHandler) recoverKeysCommon(data *secboot.PlatformKeyData, encryptedPayload, authKey []byte) ([]byte, error) {
if data.Generation < 0 || int64(data.Generation) > math.MaxUint32 {
Expand Down Expand Up @@ -67,16 +97,11 @@ func (h *platformKeyDataHandler) recoverKeysCommon(data *secboot.PlatformKeyData
Err: fmt.Errorf("invalid key data version: %d", k.data.Version())}
}

tpm, err := ConnectToDefaultTPM()
switch {
case err == ErrNoTPM2Device:
return nil, &secboot.PlatformHandlerError{
Type: secboot.PlatformHandlerErrorUnavailable,
Err: err}
case err != nil:
return nil, xerrors.Errorf("cannot connect to TPM: %w", err)
tpm, closeTpm, err := h.tpmConnection()
if err != nil {
return nil, err
}
defer tpm.Close()
defer closeTpm()

symKey, err := k.unsealDataFromTPM(tpm.TPMContext, authKey, tpm.HmacSession())
if err != nil {
Expand Down Expand Up @@ -121,16 +146,11 @@ func (h *platformKeyDataHandler) RecoverKeysWithAuthKey(data *secboot.PlatformKe
}

func (h *platformKeyDataHandler) ChangeAuthKey(data *secboot.PlatformKeyData, old, new []byte) ([]byte, error) {
tpm, err := ConnectToDefaultTPM()
switch {
case err == ErrNoTPM2Device:
return nil, &secboot.PlatformHandlerError{
Type: secboot.PlatformHandlerErrorUnavailable,
Err: err}
case err != nil:
return nil, xerrors.Errorf("cannot connect to TPM: %w", err)
tpm, closeTpm, err := h.tpmConnection()
if err != nil {
return nil, err
}
defer tpm.Close()
defer closeTpm()

var k *SealedKeyData
if err := json.Unmarshal(data.EncodedHandle, &k); err != nil {
Expand Down Expand Up @@ -231,5 +251,6 @@ func (h *platformKeyDataHandler) ChangeAuthKey(data *secboot.PlatformKeyData, ol
}

func init() {
secboot.RegisterPlatformKeyDataHandler(platformName, &platformKeyDataHandler{})
mainPlatformKeyDataHandler = new(platformKeyDataHandler)
secboot.RegisterPlatformKeyDataHandler(platformName, mainPlatformKeyDataHandler)
}
8 changes: 8 additions & 0 deletions tpm2/platform_legacy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ func (s *platformLegacySuite) TestRecoverKeys(c *C) {
PCRPolicyCounterHandle: tpm2.HandleNull})
c.Check(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

k, err := NewKeyDataFromSealedKeyObjectFile(keyFile)
c.Assert(err, IsNil)

Expand Down Expand Up @@ -108,6 +110,8 @@ func (s *platformLegacySuite) TestRecoverKeysInvalidPCRPolicy(c *C) {
_, err = s.TPM().PCREvent(s.TPM().PCRHandleContext(7), tpm2.Event("foo"), nil)
c.Check(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

k, err := NewKeyDataFromSealedKeyObjectFile(keyFile)
c.Assert(err, IsNil)

Expand All @@ -129,6 +133,8 @@ func (s *platformLegacySuite) TestRecoverKeysTPMLockout(c *C) {
// Put the TPM in DA lockout mode
c.Check(s.TPM().DictionaryAttackParameters(s.TPM().LockoutHandleContext(), 0, 7200, 86400, nil), IsNil)

s.AddCleanup(s.CloseMockConnection(c))

k, err := NewKeyDataFromSealedKeyObjectFile(keyFile)
c.Assert(err, IsNil)

Expand All @@ -152,6 +158,8 @@ func (s *platformLegacySuite) TestRecoverKeysErrTPMProvisioning(c *C) {
s.EvictControl(c, tpm2.HandleOwner, srk, srk.Handle())
s.HierarchyChangeAuth(c, tpm2.HandleOwner, []byte("foo"))

s.AddCleanup(s.CloseMockConnection(c))

k, err := NewKeyDataFromSealedKeyObjectFile(keyFile)
c.Assert(err, IsNil)

Expand Down
28 changes: 28 additions & 0 deletions tpm2/platform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ func (s *platformSuite) TestRecoverKeysIntegrated(c *C) {
k, primaryKey, unlockKey, err := NewTPMProtectedKey(s.TPM(), params)
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

unlockKeyUnsealed, primaryKeyUnsealed, err := k.RecoverKeys()
c.Check(err, IsNil)
c.Check(unlockKeyUnsealed, DeepEquals, unlockKey)
Expand All @@ -102,6 +104,8 @@ func (s *platformSuite) TestRecoverKeysWithPassphraseIntegrated(c *C) {
k, primaryKey, unlockKey, err := NewTPMPassphraseProtectedKey(s.TPM(), passphraseParams, "passphrase")
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

unlockKeyUnsealed, primaryKeyUnsealed, err := k.RecoverKeysWithPassphrase("passphrase")
c.Check(err, IsNil)
c.Check(unlockKeyUnsealed, DeepEquals, unlockKey)
Expand All @@ -123,6 +127,8 @@ func (s *platformSuite) TestRecoverKeysWithPassphraseIntegratedPBKDF2(c *C) {
k, primaryKey, unlockKey, err := NewTPMPassphraseProtectedKey(s.TPM(), passphraseParams, "passphrase")
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

unlockKeyUnsealed, primaryKeyUnsealed, err := k.RecoverKeysWithPassphrase("passphrase")
c.Check(err, IsNil)
c.Check(unlockKeyUnsealed, DeepEquals, unlockKey)
Expand All @@ -146,6 +152,8 @@ func (s *platformSuite) TestRecoverKeysWithBadPassphraseIntegrated(c *C) {
k, _, _, err := NewTPMPassphraseProtectedKey(s.TPM(), passphraseParams, "passphrase")
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

_, _, err = k.RecoverKeysWithPassphrase("1234")
c.Check(err, Equals, secboot.ErrInvalidPassphrase)
}
Expand All @@ -164,6 +172,8 @@ func (s *platformSuite) TestChangePassphraseIntegrated(c *C) {
k, primaryKey, unlockKey, err := NewTPMPassphraseProtectedKey(s.TPM(), passphraseParams, "passphrase")
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

c.Check(k.ChangePassphrase("passphrase", "1234"), IsNil)

unlockKeyUnsealed, primaryKeyUnsealed, err := k.RecoverKeysWithPassphrase("1234")
Expand All @@ -186,6 +196,8 @@ func (s *platformSuite) TestChangePassphraseWithBadPassphraseIntegrated(c *C) {
k, primaryKey, unlockKey, err := NewTPMPassphraseProtectedKey(s.TPM(), passphraseParams, "passphrase")
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

c.Check(k.ChangePassphrase("1234", "1234"), Equals, secboot.ErrInvalidPassphrase)

unlockKeyUnsealed, primaryKeyUnsealed, err := k.RecoverKeysWithPassphrase("passphrase")
Expand All @@ -211,6 +223,8 @@ func (s *platformSuite) testRecoverKeys(c *C, params *ProtectKeyParams) {
k, primaryKey, unlockKey, err := NewTPMProtectedKey(s.TPM(), params)
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down Expand Up @@ -277,6 +291,8 @@ func (s *platformSuite) testRecoverKeysNoValidSRK(c *C, prepareSrk func()) {

prepareSrk()

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down Expand Up @@ -331,6 +347,8 @@ func (s *platformSuite) testRecoverKeysImportable(c *C, params *ProtectKeyParams
k, primaryKey, unlockKey, err := NewExternalTPMProtectedKey(srkPub, params)
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down Expand Up @@ -405,6 +423,8 @@ func (s *platformSuite) testRecoverKeysUnsealErrorHandling(c *C, prepare func(*s

prepare(k, primaryKey)

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down Expand Up @@ -530,6 +550,8 @@ func (s *platformSuite) TestRecoverKeysWithAuthKey(c *C) {
k, primaryKey, unlockKey, err := NewTPMProtectedKey(s.TPM(), params)
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down Expand Up @@ -613,6 +635,8 @@ func (s *platformSuite) TestRecoverKeysWithIncorrectAuthKey(c *C) {
k, _, _, err := NewTPMProtectedKey(s.TPM(), params)
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down Expand Up @@ -691,6 +715,8 @@ func (s *platformSuite) TestChangeAuthKeyWithIncorrectAuthKey(c *C) {
k, _, _, err := NewTPMProtectedKey(s.TPM(), params)
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down Expand Up @@ -771,6 +797,8 @@ func (s *platformSuite) TestRecoverKeysWithAuthKeyTPMLockout(c *C) {
k, _, _, err := NewTPMProtectedKey(s.TPM(), params)
c.Assert(err, IsNil)

s.AddCleanup(s.CloseMockConnection(c))

var platformHandle json.RawMessage
c.Check(k.UnmarshalPlatformHandle(&platformHandle), IsNil)

Expand Down
9 changes: 7 additions & 2 deletions tpm2/seal.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,13 @@ func makeKeyDataNoAuth(skd *SealedKeyData, role string, encryptedPayload []byte,
})
}

func makeKeyDataWithPassphraseConstructor(kdfOptions secboot.KDFOptions, passphrase string) keyDataConstructor {
func makeKeyDataWithPassphraseConstructor(tpm *Connection, kdfOptions secboot.KDFOptions, passphrase string) keyDataConstructor {
return func(skd *SealedKeyData, role string, encryptedPayload []byte, kdfAlg crypto.Hash) (*secboot.KeyData, error) {
// Avoid trying to reopen another connection when setting the user auth value
orig := mainPlatformKeyDataHandler.tpm
mainPlatformKeyDataHandler.tpm = tpm
defer func() { mainPlatformKeyDataHandler.tpm = orig }()

return secbootNewKeyDataWithPassphrase(&secboot.KeyWithPassphraseParams{
KeyParams: secboot.KeyParams{
Handle: skd,
Expand Down Expand Up @@ -306,5 +311,5 @@ func NewTPMPassphraseProtectedKey(tpm *Connection, params *PassphraseProtectKeyP
AuthMode: secboot.AuthModePassphrase,
Role: params.Role,
PcrProfile: params.PCRProfile,
}, sealer, makeKeyDataWithPassphraseConstructor(params.KDFOptions, passphrase), tpm.HmacSession())
}, sealer, makeKeyDataWithPassphraseConstructor(tpm, params.KDFOptions, passphrase), tpm.HmacSession())
}
4 changes: 4 additions & 0 deletions tpm2/seal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ func (s *sealSuite) testProtectKeyWithTPM(c *C, params *ProtectKeyParams) {
c.Check(primaryKey, DeepEquals, params.PrimaryKey)
}

s.AddCleanup(s.CloseMockConnection(c))

unlockKeyUnsealed, primaryKeyUnsealed, err := k.RecoverKeys()
c.Check(err, IsNil)
c.Check(unlockKeyUnsealed, DeepEquals, unlockKey)
Expand Down Expand Up @@ -401,6 +403,8 @@ func (s *sealSuite) testProtectKeyWithExternalStorageKey(c *C, params *ProtectKe
c.Check(primaryKey, DeepEquals, params.PrimaryKey)
}

s.AddCleanup(s.CloseMockConnection(c))

unlockKeyUnsealed, primaryKeyUnsealed, err := k.RecoverKeys()
c.Check(err, IsNil)
c.Check(unlockKeyUnsealed, DeepEquals, unlockKey)
Expand Down
3 changes: 3 additions & 0 deletions tpm2/tpm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,21 @@ func (s *tpmSuiteCommon) testConnectToDefaultTPM(c *C, hasEncryption bool) {
}

func (s *tpmSuiteSimulator) TestConnectToDefaultTPMUnprovisioned(c *C) {
s.AddCleanup(s.CloseMockConnection(c))
s.testConnectToDefaultTPM(c, false)
}

func (s *tpmSuite) TestConnectToDefaultTPMProvisioned(c *C) {
c.Check(s.TPM().EnsureProvisioned(ProvisionModeWithoutLockout, nil),
testutil.InSlice(Equals), []error{ErrTPMProvisioningRequiresLockout, nil})
s.AddCleanup(s.CloseMockConnection(c))
s.testConnectToDefaultTPM(c, true)
}

func (s *tpmSuite) TestConnectToDefaultTPMInvalidEK(c *C) {
primary := s.CreatePrimary(c, tpm2.HandleOwner, tpm2_testutil.NewRSAKeyTemplate(templates.KeyUsageDecrypt, nil))
s.EvictControl(c, tpm2.HandleOwner, primary, tcg.EKHandle)
s.AddCleanup(s.CloseMockConnection(c))
s.testConnectToDefaultTPM(c, false)
}

Expand Down
Loading

0 comments on commit 610b686

Please sign in to comment.