From a1d260375cc7d8b65501c48d9d988f47c3b95b33 Mon Sep 17 00:00:00 2001 From: zack olson Date: Wed, 18 Dec 2024 17:02:10 -0500 Subject: [PATCH] restarts seem to be working this way --- ee/agent/types/registration.go | 1 + pkg/osquery/runtime/runner.go | 68 ++++++++++++- pkg/osquery/runtime/runtime_test.go | 151 ++++++++++++++++++++++++++++ 3 files changed, 218 insertions(+), 2 deletions(-) diff --git a/ee/agent/types/registration.go b/ee/agent/types/registration.go index 1c39ef1f7..e6563e195 100644 --- a/ee/agent/types/registration.go +++ b/ee/agent/types/registration.go @@ -11,3 +11,4 @@ type RegistrationTracker interface { RegistrationIDs() []string SetRegistrationIDs(registrationIDs []string) error } + \ No newline at end of file diff --git a/pkg/osquery/runtime/runner.go b/pkg/osquery/runtime/runner.go index fe6493faf..424bdd524 100644 --- a/pkg/osquery/runtime/runner.go +++ b/pkg/osquery/runtime/runner.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "slices" "sync" "time" @@ -27,6 +28,7 @@ type Runner struct { serviceClient service.KolideService // shared service client for communication between osquery instance and Kolide SaaS opts []OsqueryInstanceOption // global options applying to all osquery instances shutdown chan struct{} + rerunRequired bool interrupted bool } @@ -38,6 +40,7 @@ func New(k types.Knapsack, serviceClient service.KolideService, opts ...OsqueryI knapsack: k, serviceClient: serviceClient, shutdown: make(chan struct{}), + rerunRequired: false, opts: opts, } @@ -49,6 +52,31 @@ func New(k types.Knapsack, serviceClient service.KolideService, opts ...OsqueryI } func (r *Runner) Run() error { + for { + // if our instances ever exit unexpectedly, return immediately + if err := r.runRegisteredInstances(); err != nil { + return err + } + + // if we're in a state that required re-running all registered instances, + // reset the field and do that + if r.rerunRequired { + r.rerunRequired = false + continue + } + + // otherwise, exit cleanly + return nil + } +} + +func (r *Runner) runRegisteredInstances() error { + // clear the internal instances to add back in fresh as we runInstance, + // this prevents old instances from sticking around if a registrationID is ever removed + r.instanceLock.Lock() + r.instances = make(map[string]*OsqueryInstance) + r.instanceLock.Unlock() + // Create a group to track the workers running each instance wg, ctx := errgroup.WithContext(context.Background()) @@ -334,7 +362,43 @@ func (r *Runner) InstanceStatuses() map[string]types.InstanceStatus { return instanceStatuses } -func (r *Runner) UpdateRegistrationIDs(registrationIDs []string) error { - // TODO: detect any difference in reg IDs and shut down/spin up instances accordingly +func (r *Runner) UpdateRegistrationIDs(newRegistrationIDs []string) error { + slices.Sort(newRegistrationIDs) + existingRegistrationIDs := r.registrationIds + slices.Sort(existingRegistrationIDs) + + if slices.Equal(newRegistrationIDs, existingRegistrationIDs) { + r.slogger.Log(context.TODO(), slog.LevelDebug, + "skipping runner restarts for updated registration IDs, no changes detected", + ) + + return nil + } + + r.slogger.Log(context.TODO(), slog.LevelDebug, + "detected changes to registrationIDs, will restart runner instances", + "previous_registration_ids", existingRegistrationIDs, + "new_registration_ids", newRegistrationIDs, + ) + + // we know there are changes, safe to update the internal registrationIDs now + r.registrationIds = newRegistrationIDs + // mark rerun as required so that we can safely shutdown all workers and have the changes + // picked back up from within the main Run function + r.rerunRequired = true + + if err := r.Shutdown(); err != nil { + r.slogger.Log(context.TODO(), slog.LevelWarn, + "could not shut down runner instances for restart after registration changes", + "err", err, + ) + + return err + } + + // reset the shutdown channel and interrupted state + r.shutdown = make(chan struct{}) + r.interrupted = false + return nil } diff --git a/pkg/osquery/runtime/runtime_test.go b/pkg/osquery/runtime/runtime_test.go index 498cb22c3..8b61d1e46 100644 --- a/pkg/osquery/runtime/runtime_test.go +++ b/pkg/osquery/runtime/runtime_test.go @@ -593,6 +593,157 @@ func TestExtensionIsCleanedUp(t *testing.T) { <-timer1.C } +func TestMultipleInstancesWithUpdatedRegistrationIDs(t *testing.T) { + t.Parallel() + rootDirectory := testRootDirectory(t) + + logBytes, slogger := setUpTestSlogger() + + k := typesMocks.NewKnapsack(t) + k.On("RegistrationIDs").Return([]string{types.DefaultRegistrationID}) + k.On("OsqueryHealthcheckStartupDelay").Return(0 * time.Second).Maybe() + k.On("WatchdogEnabled").Return(false) + k.On("RegisterChangeObserver", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything) + k.On("Slogger").Return(slogger) + k.On("LatestOsquerydPath", mock.Anything).Return(testOsqueryBinaryDirectory) + k.On("RootDirectory").Return(rootDirectory).Maybe() + k.On("OsqueryFlags").Return([]string{}) + k.On("OsqueryVerbose").Return(true) + k.On("LoggingInterval").Return(5 * time.Minute).Maybe() + k.On("LogMaxBytesPerBatch").Return(0).Maybe() + k.On("Transport").Return("jsonrpc").Maybe() + k.On("ReadEnrollSecret").Return("", nil).Maybe() + setUpMockStores(t, k) + serviceClient := mockServiceClient() + + runner := New(k, serviceClient) + + // Start the instance + go runner.Run() + waitHealthy(t, runner, logBytes) + + // Confirm the default instance was started + require.Contains(t, runner.instances, types.DefaultRegistrationID) + require.NotNil(t, runner.instances[types.DefaultRegistrationID].stats) + require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.StartTime, "start time should be added to default instance stats on start up") + require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ConnectTime, "connect time should be added to default instance stats on start up") + + // confirm only the default instance has started + require.Equal(t, 1, len(runner.instances)) + + // Confirm instance statuses are reported correctly + instanceStatuses := runner.InstanceStatuses() + require.Contains(t, instanceStatuses, types.DefaultRegistrationID) + require.Equal(t, instanceStatuses[types.DefaultRegistrationID], types.InstanceStatusHealthy) + + // Add in an extra instance + extraRegistrationId := ulid.New() + runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID, extraRegistrationId}) + waitHealthy(t, runner, logBytes) + updatedInstanceStatuses := runner.InstanceStatuses() + // verify that rerunRequired has been reset for any future changes + require.False(t, runner.rerunRequired) + // now verify both instances are reported + require.Equal(t, 2, len(runner.instances)) + require.Contains(t, updatedInstanceStatuses, types.DefaultRegistrationID) + require.Contains(t, updatedInstanceStatuses, extraRegistrationId) + // Confirm the additional instance was started and is healthy + require.NotNil(t, runner.instances[extraRegistrationId].stats) + require.NotEmpty(t, runner.instances[extraRegistrationId].stats.StartTime, "start time should be added to secondary instance stats on start up") + require.NotEmpty(t, runner.instances[extraRegistrationId].stats.ConnectTime, "connect time should be added to secondary instance stats on start up") + require.Equal(t, updatedInstanceStatuses[extraRegistrationId], types.InstanceStatusHealthy) + + // update registration IDs one more time, this time removing the additional registration + originalDefaultInstanceStartTime := runner.instances[extraRegistrationId].stats.StartTime + runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID}) + waitHealthy(t, runner, logBytes) + + // now verify only the default instance remains + require.Equal(t, 1, len(runner.instances)) + // Confirm the default instance was started and is healthy + require.Contains(t, runner.instances, types.DefaultRegistrationID) + require.NotNil(t, runner.instances[types.DefaultRegistrationID].stats) + require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.StartTime, "start time should be added to default instance stats on start up") + require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ConnectTime, "connect time should be added to default instance stats on start up") + // verify that rerunRequired has been reset for any future changes + require.False(t, runner.rerunRequired) + // verify the default instance was restarted + require.NotEqual(t, originalDefaultInstanceStartTime, runner.instances[types.DefaultRegistrationID].stats.StartTime) + + waitShutdown(t, runner, logBytes) + + // Confirm both instances exited + require.Contains(t, runner.instances, types.DefaultRegistrationID) + require.NotNil(t, runner.instances[types.DefaultRegistrationID].stats) + require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ExitTime, "exit time should be added to default instance stats on shutdown") +} + +func TestUpdatingRegistrationIDsOnlyRestartsForChanges(t *testing.T) { + t.Parallel() + rootDirectory := testRootDirectory(t) + + logBytes, slogger := setUpTestSlogger() + extraRegistrationId := ulid.New() + + k := typesMocks.NewKnapsack(t) + k.On("RegistrationIDs").Return([]string{types.DefaultRegistrationID, extraRegistrationId}) + k.On("OsqueryHealthcheckStartupDelay").Return(0 * time.Second).Maybe() + k.On("WatchdogEnabled").Return(false) + k.On("RegisterChangeObserver", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything) + k.On("Slogger").Return(slogger) + k.On("LatestOsquerydPath", mock.Anything).Return(testOsqueryBinaryDirectory) + k.On("RootDirectory").Return(rootDirectory).Maybe() + k.On("OsqueryFlags").Return([]string{}) + k.On("OsqueryVerbose").Return(true) + k.On("LoggingInterval").Return(5 * time.Minute).Maybe() + k.On("LogMaxBytesPerBatch").Return(0).Maybe() + k.On("Transport").Return("jsonrpc").Maybe() + k.On("ReadEnrollSecret").Return("", nil).Maybe() + setUpMockStores(t, k) + serviceClient := mockServiceClient() + + runner := New(k, serviceClient) + + // Start the instance + go runner.Run() + waitHealthy(t, runner, logBytes) + + require.Equal(t, 2, len(runner.instances)) + // Confirm the default instance was started + require.Contains(t, runner.instances, types.DefaultRegistrationID) + require.NotNil(t, runner.instances[types.DefaultRegistrationID].stats) + require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.StartTime, "start time should be added to default instance stats on start up") + require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ConnectTime, "connect time should be added to default instance stats on start up") + // note the original start time + defaultInstanceStartTime := runner.instances[types.DefaultRegistrationID].stats.StartTime + + // Confirm the extra instance was started + require.Contains(t, runner.instances, extraRegistrationId) + require.NotNil(t, runner.instances[extraRegistrationId].stats) + require.NotEmpty(t, runner.instances[extraRegistrationId].stats.StartTime, "start time should be added to extra instance stats on start up") + require.NotEmpty(t, runner.instances[extraRegistrationId].stats.ConnectTime, "connect time should be added to extra instance stats on start up") + // note the original start time + extraInstanceStartTime := runner.instances[extraRegistrationId].stats.StartTime + + // rerun with identical registrationIDs in swapped order and verify that the instances are not restarted + runner.UpdateRegistrationIDs([]string{extraRegistrationId, types.DefaultRegistrationID}) + waitHealthy(t, runner, logBytes) + + require.Equal(t, 2, len(runner.instances)) + require.Equal(t, extraInstanceStartTime, runner.instances[extraRegistrationId].stats.StartTime) + require.Equal(t, defaultInstanceStartTime, runner.instances[types.DefaultRegistrationID].stats.StartTime) + + waitShutdown(t, runner, logBytes) + + // Confirm both instances exited + require.Contains(t, runner.instances, types.DefaultRegistrationID) + require.NotNil(t, runner.instances[types.DefaultRegistrationID].stats) + require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ExitTime, "exit time should be added to default instance stats on shutdown") + require.Contains(t, runner.instances, extraRegistrationId) + require.NotNil(t, runner.instances[extraRegistrationId].stats) + require.NotEmpty(t, runner.instances[extraRegistrationId].stats.ExitTime, "exit time should be added to secondary instance stats on shutdown") +} + // sets up an osquery instance with a running extension to be used in tests. func setupOsqueryInstanceForTests(t *testing.T) (runner *Runner, logBytes *threadsafebuffer.ThreadSafeBuffer, teardown func()) { rootDirectory := testRootDirectory(t)