Skip to content

Commit

Permalink
james/tpm runner handle no tpm (#2066)
Browse files Browse the repository at this point in the history
  • Loading branch information
James-Pickett authored Jan 27, 2025
1 parent 1d64d1c commit 77daf4c
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 4 deletions.
32 changes: 28 additions & 4 deletions ee/tpmrunner/tpmrunner.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
//go:build !darwin
// +build !darwin

package tpmrunner

import (
Expand Down Expand Up @@ -26,6 +29,7 @@ type (
slogger *slog.Logger
interrupt chan struct{}
interrupted atomic.Bool
machineHasTpm atomic.Bool
}

// tpmSignerCreator is an interface for creating and loading TPM signers
Expand Down Expand Up @@ -62,6 +66,9 @@ func New(ctx context.Context, slogger *slog.Logger, store types.GetterSetterDele
signerCreator: defaultTpmSignerCreator{},
}

// assume we have a tpm until we know otherwise
tpmRunner.machineHasTpm.Store(true)

for _, opt := range opts {
opt(tpmRunner)
}
Expand All @@ -78,17 +85,17 @@ func (tr *tpmRunner) Execute() error {

for {
// try to create signer if we don't have one
if tr.signer == nil {
if tr.signer == nil && tr.machineHasTpm.Load() {
ctx := context.Background()
if err := tr.loadOrCreateKeys(ctx); err != nil {
tr.slogger.Log(ctx, slog.LevelError,
tr.slogger.Log(ctx, slog.LevelInfo,
"loading or creating keys in execute loop",
"err", err,
)
}
}

if tr.signer != nil {
if tr.signer != nil || !tr.machineHasTpm.Load() {
retryTicker.Stop()
}

Expand Down Expand Up @@ -119,12 +126,16 @@ func (tr *tpmRunner) Interrupt(_ error) {

// Public returns the public hardware key
func (tr *tpmRunner) Public() crypto.PublicKey {
if !tr.machineHasTpm.Load() {
return nil
}

if tr.signer != nil {
return tr.signer.Public()
}

if err := tr.loadOrCreateKeys(context.Background()); err != nil {
tr.slogger.Log(context.Background(), slog.LevelError,
tr.slogger.Log(context.Background(), slog.LevelInfo,
"loading or creating keys in public call",
"err", err,
)
Expand Down Expand Up @@ -216,6 +227,19 @@ func (tr *tpmRunner) loadOrCreateKeys(ctx context.Context) error {
var err error
priData, pubData, err = tr.signerCreator.CreateKey()
if err != nil {

if isTPMNotFoundErr(err) {
tr.machineHasTpm.Store(false)

tr.slogger.Log(ctx, slog.LevelInfo,
"tpm not found",
"err", err,
)

span.AddEvent("tpm_not_found")
return err
}

thisErr := fmt.Errorf("creating key: %w", err)
traces.SetError(span, thisErr)

Expand Down
10 changes: 10 additions & 0 deletions ee/tpmrunner/tpmrunner_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//go:build linux
// +build linux

package tpmrunner

// isTPMNotFoundErr always return false on linux because we don't yet how to
// detect if a TPM is not found on linux.
func isTPMNotFoundErr(err error) bool {
return false
}
3 changes: 3 additions & 0 deletions ee/tpmrunner/tpmrunner_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
//go:build !darwin
// +build !darwin

package tpmrunner

import (
Expand Down
14 changes: 14 additions & 0 deletions ee/tpmrunner/tpmrunner_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//go:build windows
// +build windows

package tpmrunner

import (
"errors"

"github.com/google/go-tpm/tpmutil/tbs"
)

func isTPMNotFoundErr(err error) bool {
return errors.Is(err, tbs.ErrTPMNotFound)
}
67 changes: 67 additions & 0 deletions ee/tpmrunner/tpmrunner_windows_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
//go:build windows
// +build windows

package tpmrunner

import (
"context"
"errors"
"testing"
"time"

"github.com/google/go-tpm/tpmutil/tbs"
"github.com/kolide/launcher/ee/agent/storage/inmemory"
"github.com/kolide/launcher/ee/tpmrunner/mocks"
"github.com/kolide/launcher/pkg/log/multislogger"
"github.com/stretchr/testify/require"
)

func Test_tpmRunner_windows(t *testing.T) {
t.Parallel()

t.Run("handles no tpm in exectue", func(t *testing.T) {
t.Parallel()

tpmSignerCreatorMock := mocks.NewTpmSignerCreator(t)
tpmRunner, err := New(context.TODO(), multislogger.NewNopLogger(), inmemory.NewStore(), withTpmSignerCreator(tpmSignerCreatorMock))
require.NoError(t, err)

// we should never try again after getting TPMNotFound err
tpmSignerCreatorMock.On("CreateKey").Return(nil, nil, tbs.ErrTPMNotFound).Once()

go func() {
// sleep long enough to get through 2 cycles of execute

// "CreateKey" should only be called once
time.Sleep(3 * time.Second)
tpmRunner.Interrupt(errors.New("test"))
}()

require.NoError(t, tpmRunner.Execute())
require.Nil(t, tpmRunner.Public())
})

t.Run("handles no tpm in Public() call", func(t *testing.T) {
t.Parallel()

tpmSignerCreatorMock := mocks.NewTpmSignerCreator(t)
tpmRunner, err := New(context.TODO(), multislogger.NewNopLogger(), inmemory.NewStore(), withTpmSignerCreator(tpmSignerCreatorMock))
require.NoError(t, err)

// we should never try again after getting TPMNotFound err
tpmSignerCreatorMock.On("CreateKey").Return(nil, nil, tbs.ErrTPMNotFound).Once()

// this is the only time "CreateKey" should be called
require.Nil(t, tpmRunner.Public())

go func() {
// sleep long enough to get through 2 cycles of execute
time.Sleep(3 * time.Second)
tpmRunner.Interrupt(errors.New("test"))
}()

require.NoError(t, tpmRunner.Execute())
require.Nil(t, tpmRunner.Public())
})

}

0 comments on commit 77daf4c

Please sign in to comment.