diff --git a/node/pkg/aggregator/app.go b/node/pkg/aggregator/app.go index 796e9915b..e3b64d089 100644 --- a/node/pkg/aggregator/app.go +++ b/node/pkg/aggregator/app.go @@ -12,6 +12,7 @@ import ( "bisonai.com/miko/node/pkg/db" errorSentinel "bisonai.com/miko/node/pkg/error" + "bisonai.com/miko/node/pkg/utils/condition" pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/libp2p/go-libp2p/core/host" "github.com/rs/zerolog/log" @@ -153,6 +154,11 @@ func (a *App) startAggregator(ctx context.Context, aggregator *Aggregator) error return nil } + isReady := func() bool { + return a.isLocalAggregateReady(aggregator.ID) + } + condition.WaitForCondition(ctx, isReady) + nodeCtx, cancel := context.WithCancel(ctx) aggregator.nodeCtx = nodeCtx aggregator.nodeCancel = cancel @@ -360,3 +366,8 @@ func (a *App) handleMessage(ctx context.Context, msg bus.Message) { return } } + +func (a *App) isLocalAggregateReady(confidId int32) bool { + _, ok := a.LatestLocalAggregates.Load(confidId) + return ok +} diff --git a/node/pkg/error/sentinel.go b/node/pkg/error/sentinel.go index 882c1013c..0b04dd2e8 100644 --- a/node/pkg/error/sentinel.go +++ b/node/pkg/error/sentinel.go @@ -241,4 +241,6 @@ var ( ErrLogscribeConsumerServiceNotProvided = &CustomError{Service: LogscribeConsumer, Code: InvalidInputError, Message: "Service field not provided in logscribeconsumer"} ErrLogscribeConsumerInvalidLevel = &CustomError{Service: LogscribeConsumer, Code: InvalidInputError, Message: "Invalid log level provided to logscribeconsumer"} ErrLogscribeConsumerEndpointUnresponsive = &CustomError{Service: LogscribeConsumer, Code: NetworkError, Message: "Logscribe endpoint unresponsive"} + + ErrConditionTimedOut = &CustomError{Service: Others, Code: InternalError, Message: "Condition timed out"} ) diff --git a/node/pkg/utils/condition/condition.go b/node/pkg/utils/condition/condition.go new file mode 100644 index 000000000..9552e63e9 --- /dev/null +++ b/node/pkg/utils/condition/condition.go @@ -0,0 +1,41 @@ +package condition + +import ( + "context" + "time" + + errorsentinel "bisonai.com/miko/node/pkg/error" +) + +// can be blocking infinitely if condition is not met, use with caution +func WaitForCondition(ctx context.Context, condition func() bool) { + for { + if condition() { + return + } + + select { + case <-ctx.Done(): + return + default: + time.Sleep(500 * time.Millisecond) + } + } +} + +func WaitForConditionWithTimeout(ctx context.Context, timeout time.Duration, condition func() bool) error { + for { + if condition() { + return nil + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(timeout): + return errorsentinel.ErrConditionTimedOut + default: + time.Sleep(500 * time.Millisecond) + } + } +} diff --git a/node/pkg/utils/tests/condition_test.go b/node/pkg/utils/tests/condition_test.go new file mode 100644 index 000000000..a15154b09 --- /dev/null +++ b/node/pkg/utils/tests/condition_test.go @@ -0,0 +1,41 @@ +package tests + +import ( + "context" + "testing" + "time" + + "bisonai.com/miko/node/pkg/utils/condition" + "github.com/stretchr/testify/assert" +) + +func TestWaitForCondition_Success(t *testing.T) { + conditionMet := false + + testCond := func() bool { + return conditionMet + } + + go func() { + time.Sleep(100 * time.Millisecond) + conditionMet = true + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + condition.WaitForCondition(ctx, testCond) + + assert.True(t, conditionMet) +} + +func TestWaitForCondition_Timeout(t *testing.T) { + testCond := func() bool { + return false + } + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + condition.WaitForCondition(ctx, testCond) +}