Commit

restarts seem to be working this way

zackattack01 committed Dec 18, 2024
1 parent d2ceff5 commit a1d2603
Showing 3 changed files with 218 additions and 2 deletions.
1 change: 1 addition & 0 deletions ee/agent/types/registration.go
@@ -11,3 +11,4 @@ type RegistrationTracker interface {
RegistrationIDs() []string
SetRegistrationIDs(registrationIDs []string) error
}
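
For reference, a tracker satisfying this interface could be as small as the sketch below. This is an editor's illustration rather than code from this commit; the mutex-guarded slice and the copy-on-read behavior are assumptions about what a reasonable implementation would do (imports assumed: slices, sync).

// Editor's sketch (not part of this commit): a minimal RegistrationTracker.
type memoryRegistrationTracker struct {
	mu  sync.Mutex
	ids []string
}

func (m *memoryRegistrationTracker) RegistrationIDs() []string {
	m.mu.Lock()
	defer m.mu.Unlock()
	return slices.Clone(m.ids) // copy so callers cannot mutate internal state
}

func (m *memoryRegistrationTracker) SetRegistrationIDs(registrationIDs []string) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.ids = slices.Clone(registrationIDs)
	return nil
}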

68 changes: 66 additions & 2 deletions pkg/osquery/runtime/runner.go
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log/slog"
"slices"
"sync"
"time"

@@ -27,6 +28,7 @@ type Runner struct {
serviceClient service.KolideService // shared service client for communication between osquery instance and Kolide SaaS
opts []OsqueryInstanceOption // global options applying to all osquery instances
shutdown chan struct{}
rerunRequired bool
interrupted bool
}

@@ -38,6 +40,7 @@ func New(k types.Knapsack, serviceClient service.KolideService, opts ...OsqueryI
knapsack: k,
serviceClient: serviceClient,
shutdown: make(chan struct{}),
rerunRequired: false,
opts: opts,
}

@@ -49,6 +52,31 @@
}

func (r *Runner) Run() error {
for {
// if our instances ever exit unexpectedly, return immediately
if err := r.runRegisteredInstances(); err != nil {
return err
}

// if we're in a state that requires re-running all registered instances,
// reset the flag and start them again
if r.rerunRequired {
r.rerunRequired = false
continue
}

// otherwise, exit cleanly
return nil
}
}

func (r *Runner) runRegisteredInstances() error {
// clear the internal instances map so fresh instances are added back as we runInstance;
// this prevents old instances from sticking around if a registrationID is ever removed
r.instanceLock.Lock()
r.instances = make(map[string]*OsqueryInstance)
r.instanceLock.Unlock()

// Create a group to track the workers running each instance
wg, ctx := errgroup.WithContext(context.Background())

@@ -334,7 +362,43 @@ func (r *Runner) InstanceStatuses() map[string]types.InstanceStatus {
return instanceStatuses
}

-func (r *Runner) UpdateRegistrationIDs(registrationIDs []string) error {
-// TODO: detect any difference in reg IDs and shut down/spin up instances accordingly
+func (r *Runner) UpdateRegistrationIDs(newRegistrationIDs []string) error {
slices.Sort(newRegistrationIDs)
existingRegistrationIDs := r.registrationIds
slices.Sort(existingRegistrationIDs)

if slices.Equal(newRegistrationIDs, existingRegistrationIDs) {
r.slogger.Log(context.TODO(), slog.LevelDebug,
"skipping runner restarts for updated registration IDs, no changes detected",
)

return nil
}

r.slogger.Log(context.TODO(), slog.LevelDebug,
"detected changes to registrationIDs, will restart runner instances",
"previous_registration_ids", existingRegistrationIDs,
"new_registration_ids", newRegistrationIDs,
)

// we know there are changes, safe to update the internal registrationIDs now
r.registrationIds = newRegistrationIDs
// mark rerun as required so that we can safely shut down all workers and have the changes
// picked back up from within the main Run function
r.rerunRequired = true

if err := r.Shutdown(); err != nil {
r.slogger.Log(context.TODO(), slog.LevelWarn,
"could not shut down runner instances for restart after registration changes",
"err", err,
)

return err
}

// reset the shutdown channel and interrupted state
r.shutdown = make(chan struct{})
r.interrupted = false

return nil
}
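
Taken together, Run and UpdateRegistrationIDs form a restart loop: UpdateRegistrationIDs flips rerunRequired and calls Shutdown, runRegisteredInstances returns nil because the shutdown was requested rather than caused by an instance failure, and the loop in Run sees the flag and starts every registered instance back up. Below is a hedged sketch of how a caller might drive this; the launcher's real wiring differs, and newIDs, k, serviceClient, and slogger are assumed inputs.

// Editor's sketch (not part of this commit): driving the restart loop.
runner := New(k, serviceClient)

go func() {
	// Run blocks until shutdown; a non-nil error means an instance
	// exited unexpectedly rather than via a requested shutdown.
	if err := runner.Run(); err != nil {
		slogger.Log(context.TODO(), slog.LevelError,
			"runner exited with error",
			"err", err,
		)
	}
}()

// Later, when the set of registrations changes upstream, this shuts the
// workers down and lets Run start them again with the new IDs.
if err := runner.UpdateRegistrationIDs(newIDs); err != nil {
	slogger.Log(context.TODO(), slog.LevelError,
		"could not apply updated registration IDs",
		"err", err,
	)
}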
151 changes: 151 additions & 0 deletions pkg/osquery/runtime/runtime_test.go
@@ -593,6 +593,157 @@ func TestExtensionIsCleanedUp(t *testing.T) {
<-timer1.C
}

func TestMultipleInstancesWithUpdatedRegistrationIDs(t *testing.T) {
t.Parallel()
rootDirectory := testRootDirectory(t)

logBytes, slogger := setUpTestSlogger()

k := typesMocks.NewKnapsack(t)
k.On("RegistrationIDs").Return([]string{types.DefaultRegistrationID})
k.On("OsqueryHealthcheckStartupDelay").Return(0 * time.Second).Maybe()
k.On("WatchdogEnabled").Return(false)
k.On("RegisterChangeObserver", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything)
k.On("Slogger").Return(slogger)
k.On("LatestOsquerydPath", mock.Anything).Return(testOsqueryBinaryDirectory)
k.On("RootDirectory").Return(rootDirectory).Maybe()
k.On("OsqueryFlags").Return([]string{})
k.On("OsqueryVerbose").Return(true)
k.On("LoggingInterval").Return(5 * time.Minute).Maybe()
k.On("LogMaxBytesPerBatch").Return(0).Maybe()
k.On("Transport").Return("jsonrpc").Maybe()
k.On("ReadEnrollSecret").Return("", nil).Maybe()
setUpMockStores(t, k)
serviceClient := mockServiceClient()

runner := New(k, serviceClient)

// Start the instance
go runner.Run()
waitHealthy(t, runner, logBytes)

// Confirm the default instance was started
require.Contains(t, runner.instances, types.DefaultRegistrationID)
require.NotNil(t, runner.instances[types.DefaultRegistrationID].stats)
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.StartTime, "start time should be added to default instance stats on start up")
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ConnectTime, "connect time should be added to default instance stats on start up")

// confirm only the default instance has started
require.Equal(t, 1, len(runner.instances))

// Confirm instance statuses are reported correctly
instanceStatuses := runner.InstanceStatuses()
require.Contains(t, instanceStatuses, types.DefaultRegistrationID)
require.Equal(t, instanceStatuses[types.DefaultRegistrationID], types.InstanceStatusHealthy)

// Add in an extra instance
extraRegistrationId := ulid.New()
require.NoError(t, runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID, extraRegistrationId}))
waitHealthy(t, runner, logBytes)
updatedInstanceStatuses := runner.InstanceStatuses()
// verify that rerunRequired has been reset for any future changes
require.False(t, runner.rerunRequired)
// now verify both instances are reported
require.Equal(t, 2, len(runner.instances))
require.Contains(t, updatedInstanceStatuses, types.DefaultRegistrationID)
require.Contains(t, updatedInstanceStatuses, extraRegistrationId)
// Confirm the additional instance was started and is healthy
require.NotNil(t, runner.instances[extraRegistrationId].stats)
require.NotEmpty(t, runner.instances[extraRegistrationId].stats.StartTime, "start time should be added to secondary instance stats on start up")
require.NotEmpty(t, runner.instances[extraRegistrationId].stats.ConnectTime, "connect time should be added to secondary instance stats on start up")
require.Equal(t, updatedInstanceStatuses[extraRegistrationId], types.InstanceStatusHealthy)

// update registration IDs one more time, this time removing the additional registration
originalDefaultInstanceStartTime := runner.instances[types.DefaultRegistrationID].stats.StartTime
require.NoError(t, runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID}))
waitHealthy(t, runner, logBytes)

// now verify only the default instance remains
require.Equal(t, 1, len(runner.instances))
// Confirm the default instance was started and is healthy
require.Contains(t, runner.instances, types.DefaultRegistrationID)
require.NotNil(t, runner.instances[types.DefaultRegistrationID].stats)
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.StartTime, "start time should be added to default instance stats on start up")
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ConnectTime, "connect time should be added to default instance stats on start up")
// verify that rerunRequired has been reset for any future changes
require.False(t, runner.rerunRequired)
// verify the default instance was restarted
require.NotEqual(t, originalDefaultInstanceStartTime, runner.instances[types.DefaultRegistrationID].stats.StartTime)

waitShutdown(t, runner, logBytes)

// Confirm the remaining default instance exited
require.Contains(t, runner.instances, types.DefaultRegistrationID)
require.NotNil(t, runner.instances[types.DefaultRegistrationID].stats)
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ExitTime, "exit time should be added to default instance stats on shutdown")
}

func TestUpdatingRegistrationIDsOnlyRestartsForChanges(t *testing.T) {
t.Parallel()
rootDirectory := testRootDirectory(t)

logBytes, slogger := setUpTestSlogger()
extraRegistrationId := ulid.New()

k := typesMocks.NewKnapsack(t)
k.On("RegistrationIDs").Return([]string{types.DefaultRegistrationID, extraRegistrationId})
k.On("OsqueryHealthcheckStartupDelay").Return(0 * time.Second).Maybe()
k.On("WatchdogEnabled").Return(false)
k.On("RegisterChangeObserver", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything)
k.On("Slogger").Return(slogger)
k.On("LatestOsquerydPath", mock.Anything).Return(testOsqueryBinaryDirectory)
k.On("RootDirectory").Return(rootDirectory).Maybe()
k.On("OsqueryFlags").Return([]string{})
k.On("OsqueryVerbose").Return(true)
k.On("LoggingInterval").Return(5 * time.Minute).Maybe()
k.On("LogMaxBytesPerBatch").Return(0).Maybe()
k.On("Transport").Return("jsonrpc").Maybe()
k.On("ReadEnrollSecret").Return("", nil).Maybe()
setUpMockStores(t, k)
serviceClient := mockServiceClient()

runner := New(k, serviceClient)

// Start the instance
go runner.Run()
waitHealthy(t, runner, logBytes)

require.Equal(t, 2, len(runner.instances))
// Confirm the default instance was started
require.Contains(t, runner.instances, types.DefaultRegistrationID)
require.NotNil(t, runner.instances[types.DefaultRegistrationID].stats)
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.StartTime, "start time should be added to default instance stats on start up")
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ConnectTime, "connect time should be added to default instance stats on start up")
// note the original start time
defaultInstanceStartTime := runner.instances[types.DefaultRegistrationID].stats.StartTime

// Confirm the extra instance was started
require.Contains(t, runner.instances, extraRegistrationId)
require.NotNil(t, runner.instances[extraRegistrationId].stats)
require.NotEmpty(t, runner.instances[extraRegistrationId].stats.StartTime, "start time should be added to extra instance stats on start up")
require.NotEmpty(t, runner.instances[extraRegistrationId].stats.ConnectTime, "connect time should be added to extra instance stats on start up")
// note the original start time
extraInstanceStartTime := runner.instances[extraRegistrationId].stats.StartTime

// rerun with identical registrationIDs in swapped order and verify that the instances are not restarted
require.NoError(t, runner.UpdateRegistrationIDs([]string{extraRegistrationId, types.DefaultRegistrationID}))
waitHealthy(t, runner, logBytes)

require.Equal(t, 2, len(runner.instances))
require.Equal(t, extraInstanceStartTime, runner.instances[extraRegistrationId].stats.StartTime)
require.Equal(t, defaultInstanceStartTime, runner.instances[types.DefaultRegistrationID].stats.StartTime)

waitShutdown(t, runner, logBytes)

// Confirm both instances exited
require.Contains(t, runner.instances, types.DefaultRegistrationID)
require.NotNil(t, runner.instances[types.DefaultRegistrationID].stats)
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ExitTime, "exit time should be added to default instance stats on shutdown")
require.Contains(t, runner.instances, extraRegistrationId)
require.NotNil(t, runner.instances[extraRegistrationId].stats)
require.NotEmpty(t, runner.instances[extraRegistrationId].stats.ExitTime, "exit time should be added to secondary instance stats on shutdown")
}
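
The no-op path exercised by this second test hinges on the sort-then-compare step in UpdateRegistrationIDs: sorting both slices first makes the comparison order-insensitive, so a reordered but otherwise identical ID list does not trigger a restart. A standalone illustration of that behavior follows (the ID values are invented for the example; imports assumed: fmt, slices). Note that slices.Sort sorts in place, so the runner's internal slice also ends up sorted as a side effect of the comparison.

// Editor's sketch (not part of this commit): order-insensitive comparison.
existingIDs := []string{"secondary", "default"}
updatedIDs := []string{"default", "secondary"}

slices.Sort(existingIDs)
slices.Sort(updatedIDs)

fmt.Println(slices.Equal(existingIDs, updatedIDs)) // true: same IDs, order ignored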

// sets up an osquery instance with a running extension to be used in tests.
func setupOsqueryInstanceForTests(t *testing.T) (runner *Runner, logBytes *threadsafebuffer.ThreadSafeBuffer, teardown func()) {
rootDirectory := testRootDirectory(t)
