diff --git a/.github/workflows/router-ci.yaml b/.github/workflows/router-ci.yaml index 7ae5bc390c..f0117e3cb9 100644 --- a/.github/workflows/router-ci.yaml +++ b/.github/workflows/router-ci.yaml @@ -156,7 +156,11 @@ jobs: with: cache-dependency-path: | router-tests/go.sum - - uses: nick-fields/retry@v3 + - name: Run Integration tests + working-directory: ./router-tests + run: make test test_params="-run '^Test[^(Flaky)]' --timeout=5m --parallel 10" + - name: Run Flaky Integration tests + uses: nick-fields/retry@v3 with: timeout_minutes: 30 max_attempts: 5 @@ -164,7 +168,7 @@ jobs: retry_on: error command: | cd router-tests - make test test_params="--timeout=5m" + make test test_params="-run '^TestFlaky' --timeout=5m -p 1 --parallel 1" image_scan: if: github.event.pull_request.head.repo.full_name == github.repository diff --git a/demo/go.mod b/demo/go.mod index c4050d3f5a..91dbd74b9d 100644 --- a/demo/go.mod +++ b/demo/go.mod @@ -15,7 +15,7 @@ require ( github.com/wundergraph/cosmo/composition-go v0.0.0-20240124120900-5effe48a4a1d github.com/wundergraph/cosmo/router v0.0.0-20250119174948-4b991294658e github.com/wundergraph/cosmo/router-tests v0.0.0-20241213115435-a249dba8c52a - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.145 + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.146 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 go.opentelemetry.io/otel v1.28.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.23.1 diff --git a/demo/go.sum b/demo/go.sum index 33ab57ca5e..101d3b8436 100644 --- a/demo/go.sum +++ b/demo/go.sum @@ -312,8 +312,8 @@ github.com/wundergraph/cosmo/router v0.0.0-20250119174948-4b991294658e h1:ee4fu7 github.com/wundergraph/cosmo/router v0.0.0-20250119174948-4b991294658e/go.mod h1:ImqCvxvvNOy1UxbuTnFtin/CDBFHoFqrZly3rC2z+e0= github.com/wundergraph/cosmo/router-tests v0.0.0-20241213115435-a249dba8c52a h1:GVLe85f5g+G0IOorDBBNTfm5Ua9DO0vuVY7ReSTOEbQ= github.com/wundergraph/cosmo/router-tests v0.0.0-20241213115435-a249dba8c52a/go.mod h1:I+SFviFnd3BHlPmYn+ckmzQyDB9+/c8RZJo4t6VQAds= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.145 h1:3JuBmRux6YB/UZgh6COvgLXzQhMIsdHV7A02NsYdAVE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.145/go.mod h1:B7eV0Qh8Lop9QzIOQcsvKp3S0ejfC6mgyWoJnI917yQ= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.146 h1:C9+jjMgbU/RJTiFGC0HNHan4LxrY7fIhmbZRoqZryLk= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.146/go.mod h1:B7eV0Qh8Lop9QzIOQcsvKp3S0ejfC6mgyWoJnI917yQ= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/router-tests/cache_warmup_test.go b/router-tests/cache_warmup_test.go index aeae4b6ef1..3925fe43b0 100644 --- a/router-tests/cache_warmup_test.go +++ b/router-tests/cache_warmup_test.go @@ -570,7 +570,8 @@ func TestCacheWarmup(t *testing.T) { }) } -func TestCacheWarmupMetrics(t *testing.T) { +// Is set as Flaky so that when running the tests it will be run separately and retried if it fails +func TestFlakyCacheWarmupMetrics(t *testing.T) { t.Run("should emit planning times metrics during warmup", func(t *testing.T) { t.Parallel() diff --git a/router-tests/complexity_limits_test.go b/router-tests/complexity_limits_test.go index 2aabb798cd..85e47d8f23 100644 --- a/router-tests/complexity_limits_test.go +++ b/router-tests/complexity_limits_test.go @@ -3,6 +3,7 @@ package integration import ( "net/http" "testing" + "time" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router-tests/testenv" @@ -152,6 +153,8 @@ func TestComplexityLimits(t *testing.T) { require.Contains(t, testSpan.Attributes(), otel.WgQueryDepth.Int(3)) require.Contains(t, testSpan.Attributes(), otel.WgQueryDepthCacheHit.Bool(false)) exporter.Reset() + // wait to let cache get consistent + time.Sleep(100 * time.Millisecond) failedRes2, _ := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ Query: `{ employee(id:1) { id details { forename surname } } }`, @@ -163,6 +166,8 @@ func TestComplexityLimits(t *testing.T) { require.Contains(t, testSpan2.Attributes(), otel.WgQueryDepth.Int(3)) require.Contains(t, testSpan2.Attributes(), otel.WgQueryDepthCacheHit.Bool(true)) exporter.Reset() + // wait to let cache get consistent + time.Sleep(100 * time.Millisecond) successRes := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Query: `query { employees { id } }`, @@ -172,6 +177,8 @@ func TestComplexityLimits(t *testing.T) { require.Contains(t, testSpan3.Attributes(), otel.WgQueryDepth.Int(2)) require.Contains(t, testSpan3.Attributes(), otel.WgQueryDepthCacheHit.Bool(false)) exporter.Reset() + // wait to let cache get consistent + time.Sleep(100 * time.Millisecond) successRes2 := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Query: `query { employees { id } }`, diff --git a/router-tests/config_hot_reload_test.go b/router-tests/config_hot_reload_test.go index f69d3e270c..10618aebb4 100644 --- a/router-tests/config_hot_reload_test.go +++ b/router-tests/config_hot_reload_test.go @@ -10,6 +10,7 @@ import ( "github.com/wundergraph/cosmo/router/pkg/routerconfig" "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router-tests/testenv" "github.com/wundergraph/cosmo/router/core" @@ -238,28 +239,35 @@ func TestConfigHotReload(t *testing.T) { }, }, func(t *testing.T, xEnv *testenv.Environment) { - var done atomic.Bool - + var startedReq atomic.Bool go func() { - defer done.Store(true) - + startedReq.Store(true) res, err := xEnv.MakeGraphQLRequestWithContext(context.Background(), testenv.GraphQLRequest{ Query: `{ employees { id } }`, }) require.NoError(t, err) - require.Equal(t, res.Response.StatusCode, 200) - require.Equal(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees'."}],"data":{"employees":null}}`, res.Body) + assert.Equal(t, res.Response.StatusCode, 200) + assert.Equal(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees'."}],"data":{"employees":null}}`, res.Body) }() // Let's wait a bit to make sure all requests are in flight // otherwise the shutdown will be too fast and the wait-group will not be done fully + require.Eventually(t, func() bool { + return startedReq.Load() + }, time.Second*10, time.Millisecond*100) time.Sleep(time.Millisecond * 100) - xEnv.Shutdown() + var done atomic.Bool + go func() { + defer done.Store(true) + + err := xEnv.Router.Shutdown(context.Background()) + assert.ErrorContains(t, err, context.DeadlineExceeded.Error()) + }() require.Eventually(t, func() bool { return done.Load() - }, time.Second*5, time.Millisecond*100) + }, time.Second*20, time.Millisecond*100) }) }) @@ -314,13 +322,15 @@ func TestConfigHotReload(t *testing.T) { // Swap config require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) - err = conn.ReadJSON(&msg) - // Ensure that the connection is closed. In the future, we might want to send a complete message to the client + // If the operation happen fast enough, ensure that the connection is closed. + // In the future, we might want to send a complete message to the client // and wait until in-flight messages are delivered before closing the connection - var wsErr *websocket.CloseError - require.ErrorAs(t, err, &wsErr) + if err != nil { + var wsErr *websocket.CloseError + require.ErrorAs(t, err, &wsErr) + } require.NoError(t, conn.Close()) diff --git a/router-tests/events_config_test.go b/router-tests/events/events_config_test.go similarity index 98% rename from router-tests/events_config_test.go rename to router-tests/events/events_config_test.go index 5fb60fe603..110d0d4cff 100644 --- a/router-tests/events_config_test.go +++ b/router-tests/events/events_config_test.go @@ -1,4 +1,4 @@ -package integration_test +package events_test import ( "github.com/stretchr/testify/assert" diff --git a/router-tests/events/kafka_events_test.go b/router-tests/events/kafka_events_test.go index 9916bfb256..6afdcf7d35 100644 --- a/router-tests/events/kafka_events_test.go +++ b/router-tests/events/kafka_events_test.go @@ -20,6 +20,8 @@ import ( "github.com/wundergraph/cosmo/router/pkg/config" ) +const KafkaWaitTimeout = time.Second * 30 + func TestLocalKafka(t *testing.T) { t.Skip("skip only for local testing") @@ -88,17 +90,17 @@ func TestKafkaEvents(t *testing.T) { go func() { require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) _ = client.Close() }() - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) - xEnv.WaitForMessagesSent(1, time.Second*10) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForMessagesSent(1, KafkaWaitTimeout) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -159,32 +161,32 @@ func TestKafkaEvents(t *testing.T) { require.NoError(t, clientErr) }() - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], ``) // Empty message require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message require.Eventually(t, func() bool { return counter.Load() == 2 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) // Missing entity = Resolver error require.Eventually(t, func() bool { return counter.Load() == 3 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message require.Eventually(t, func() bool { return counter.Load() == 4 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) require.NoError(t, client.Close()) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -238,20 +240,20 @@ func TestKafkaEvents(t *testing.T) { require.NoError(t, clientErr) }() - xEnv.WaitForSubscriptionCount(2, time.Second*10) + xEnv.WaitForSubscriptionCount(2, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) - xEnv.WaitForMessagesSent(2, time.Second*10) + xEnv.WaitForMessagesSent(2, KafkaWaitTimeout) require.Eventually(t, func() bool { return counter.Load() == 2 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) _ = client.Close() - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -330,20 +332,20 @@ func TestKafkaEvents(t *testing.T) { require.NoError(t, clientErr) }() - xEnv.WaitForSubscriptionCount(2, time.Second*10) + xEnv.WaitForSubscriptionCount(2, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) produceKafkaMessage(t, xEnv, topics[1], `{"__typename":"Employee","id": 2,"update":{"name":"foo"}}`) require.Eventually(t, func() bool { return counter.Load() == 4 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) require.NoError(t, client.Close()) - xEnv.WaitForMessagesSent(4, time.Second*10) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForMessagesSent(4, KafkaWaitTimeout) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -398,37 +400,24 @@ func TestKafkaEvents(t *testing.T) { go func() { require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) _ = client.Close() }() - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) - xEnv.WaitForMessagesSent(1, time.Second*10) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForMessagesSent(1, KafkaWaitTimeout) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) t.Run("multipart", func(t *testing.T) { t.Parallel() - assertLineEquals := func(t *testing.T, reader *bufio.Reader, expected string) { - line, _, err := reader.ReadLine() - require.NoError(t, err) - require.Equal(t, expected, string(line)) - } - - assertMultipartPrefix := func(t *testing.T, reader *bufio.Reader) { - assertLineEquals(t, reader, "") - assertLineEquals(t, reader, "--graphql") - assertLineEquals(t, reader, "Content-Type: application/json") - assertLineEquals(t, reader, "") - } - - var multipartHeartbeatInterval = time.Second + var multipartHeartbeatInterval = time.Second * 5 t.Run("subscribe sync", func(t *testing.T) { t.Parallel() @@ -447,10 +436,11 @@ func TestKafkaEvents(t *testing.T) { subscribePayload := []byte(`{"query":"subscription { employeeUpdatedMyKafka(employeeID: 1) { id details { forename surname } }}"}`) - var consumed atomic.Uint32 - var produced atomic.Uint32 + var done atomic.Bool go func() { + defer done.Store(true) + client := http.Client{ Timeout: time.Second * 100, } @@ -461,45 +451,20 @@ func TestKafkaEvents(t *testing.T) { defer resp.Body.Close() reader := bufio.NewReader(resp.Body) - require.Eventually(t, func() bool { - return produced.Load() == 1 - }, time.Second*10, time.Millisecond*100) - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"employeeUpdatedMyKafka\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") - consumed.Add(1) - - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{}") - consumed.Add(1) - - require.Eventually(t, func() bool { - return produced.Load() == 2 - }, time.Second*10, time.Millisecond*100) - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"employeeUpdatedMyKafka\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") - - consumed.Add(1) + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"employeeUpdatedMyKafka\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"employeeUpdatedMyKafka\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) - xEnv.WaitForMessagesSent(1, time.Second*5) - produced.Add(1) - require.Eventually(t, func() bool { - return consumed.Load() == 2 - }, time.Second*10, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) - xEnv.WaitForMessagesSent(3, time.Second*5) // 2 messages + the empty one - produced.Add(1) - require.Eventually(t, func() bool { - return consumed.Load() == 3 - }, time.Second*10, time.Millisecond*100) + require.Eventually(t, done.Load, KafkaWaitTimeout, time.Millisecond*100) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -528,11 +493,10 @@ func TestKafkaEvents(t *testing.T) { defer resp.Body.Close() reader := bufio.NewReader(resp.Body) - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"errors\":[{\"message\":\"operation type 'subscription' is blocked\"}]}}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"errors\":[{\"message\":\"operation type 'subscription' is blocked\"}]}}") - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) }) @@ -585,16 +549,16 @@ func TestKafkaEvents(t *testing.T) { }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -645,16 +609,16 @@ func TestKafkaEvents(t *testing.T) { require.Equal(t, "", string(line)) }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -697,8 +661,8 @@ func TestKafkaEvents(t *testing.T) { require.NoError(t, err) require.Equal(t, "data: {\"errors\":[{\"message\":\"operation type 'subscription' is blocked\"}]}", string(data)) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -738,7 +702,7 @@ func TestKafkaEvents(t *testing.T) { var msg testenv.WebSocketMessage var payload subscriptionPayload - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) var produced atomic.Uint32 var consumed atomic.Uint32 @@ -749,7 +713,7 @@ func TestKafkaEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == MsgCount-11 - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) gErr := conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -763,7 +727,7 @@ func TestKafkaEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == MsgCount-7 - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -777,7 +741,7 @@ func TestKafkaEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == MsgCount-4 - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -791,7 +755,7 @@ func TestKafkaEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == MsgCount-3 - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -805,7 +769,7 @@ func TestKafkaEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == MsgCount-1 - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -822,14 +786,14 @@ func TestKafkaEvents(t *testing.T) { for i := MsgCount; i > 0; i-- { require.Eventually(t, func() bool { return consumed.Load() >= MsgCount-i - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) produced.Add(1) } require.Eventually(t, func() bool { return consumed.Load() == MsgCount && produced.Load() == MsgCount - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) }) }) @@ -869,7 +833,7 @@ func TestKafkaEvents(t *testing.T) { var msg testenv.WebSocketMessage var payload subscriptionPayload - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) var produced atomic.Uint32 var consumed atomic.Uint32 @@ -877,7 +841,7 @@ func TestKafkaEvents(t *testing.T) { go func() { require.Eventually(t, func() bool { return produced.Load() == 1 - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) gErr := conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -891,7 +855,7 @@ func TestKafkaEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 2 - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -905,7 +869,7 @@ func TestKafkaEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 11 - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -919,7 +883,7 @@ func TestKafkaEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 12 - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -936,14 +900,14 @@ func TestKafkaEvents(t *testing.T) { for i := uint32(1); i < 13; i++ { require.Eventually(t, func() bool { return consumed.Load() >= i-1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) produced.Add(1) } require.Eventually(t, func() bool { return consumed.Load() == 12 && produced.Load() == 12 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) }) }) @@ -983,7 +947,7 @@ func TestKafkaEvents(t *testing.T) { var msg testenv.WebSocketMessage var payload subscriptionPayload - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) var produced atomic.Uint32 var consumed atomic.Uint32 @@ -1050,7 +1014,7 @@ func TestKafkaEvents(t *testing.T) { for i := uint32(1); i < 13; i++ { require.Eventually(t, func() bool { return consumed.Load() >= i-1 - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) produced.Add(1) } @@ -1097,7 +1061,7 @@ func TestKafkaEvents(t *testing.T) { var msg testenv.WebSocketMessage var payload subscriptionPayload - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) var counter atomic.Uint32 @@ -1125,7 +1089,7 @@ func TestKafkaEvents(t *testing.T) { require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) }) }) @@ -1186,29 +1150,29 @@ func TestKafkaEvents(t *testing.T) { require.NoError(t, clientErr) }() - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], `{asas`) // Invalid message require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id":1}`) // Correct message require.Eventually(t, func() bool { return counter.Load() == 2 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) // Missing entity = Resolver error require.Eventually(t, func() bool { return counter.Load() == 3 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message require.Eventually(t, func() bool { return counter.Load() == 4 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) require.NoError(t, client.Close()) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) } @@ -1252,7 +1216,7 @@ func produceKafkaMessage(t *testing.T, xEnv *testenv.Environment, topicName stri require.Eventually(t, func() bool { return done.Load() - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) require.NoError(t, pErr) diff --git a/router-tests/events/nats_events_test.go b/router-tests/events/nats_events_test.go index e07cd63196..b44ea68144 100644 --- a/router-tests/events/nats_events_test.go +++ b/router-tests/events/nats_events_test.go @@ -21,10 +21,39 @@ import ( "github.com/wundergraph/cosmo/router/pkg/config" "github.com/hasura/go-graphql-client" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router-tests/testenv" ) +const NatsWaitTimeout = time.Second * 30 + +func assertLineEquals(t *testing.T, reader *bufio.Reader, expected string) { + line, _, err := reader.ReadLine() + assert.NoError(t, err) + assert.Equal(t, expected, string(line)) +} + +func assertMultipartPrefix(t *testing.T, reader *bufio.Reader) { + assertLineEquals(t, reader, "") + assertLineEquals(t, reader, "--graphql") + assertLineEquals(t, reader, "Content-Type: application/json") + assertLineEquals(t, reader, "") +} + +func assertMultipartValueEventually(t *testing.T, reader *bufio.Reader, expected string) { + assert.Eventually(t, func() bool { + assertMultipartPrefix(t, reader) + line, _, err := reader.ReadLine() + assert.NoError(t, err) + if string(line) == "{}" { + return false + } + assert.Equal(t, expected, string(line)) + return true + }, NatsWaitTimeout, time.Millisecond*100) +} + func TestNatsEvents(t *testing.T) { t.Parallel() @@ -71,16 +100,7 @@ func TestNatsEvents(t *testing.T) { require.NoError(t, clientErr) }() - var closed atomic.Bool - go func() { - require.Eventually(t, func() bool { - return subscriptionCalled.Load() == 2 - }, time.Second*20, time.Millisecond*100) - require.NoError(t, client.Close()) - closed.Store(true) - }() - - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Send a mutation to trigger the first subscription resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ @@ -88,6 +108,10 @@ func TestNatsEvents(t *testing.T) { }) require.JSONEq(t, `{"data":{"updateAvailability":{"id":3}}}`, resOne.Body) + assert.Eventually(t, func() bool { + return subscriptionCalled.Load() == 1 + }, NatsWaitTimeout, time.Millisecond*100) + // Trigger the first subscription via NATS err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"id":3,"__typename": "Employee"}`)) require.NoError(t, err) @@ -95,13 +119,19 @@ func TestNatsEvents(t *testing.T) { err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - require.Eventually(t, func() bool { - return closed.Load() - }, time.Second*20, time.Millisecond*100) + var closed atomic.Bool + go func() { + defer closed.Store(true) + assert.Eventually(t, func() bool { + return subscriptionCalled.Load() == 2 + }, NatsWaitTimeout, time.Millisecond*100) + assert.NoError(t, client.Close()) + }() + + assert.Eventually(t, closed.Load, NatsWaitTimeout, time.Millisecond*100) - xEnv.WaitForMessagesSent(2, time.Second*10) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, NatsWaitTimeout) + xEnv.WaitForConnectionCount(0, NatsWaitTimeout) natsLogs := xEnv.Observer().FilterMessageSnippet("Nats").All() require.Len(t, natsLogs, 4) @@ -137,7 +167,7 @@ func TestNatsEvents(t *testing.T) { subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { oldCount := counter.Load() - counter.Add(1) + defer counter.Add(1) if oldCount == 0 { var gqlErr graphql.Errors @@ -162,40 +192,46 @@ func TestNatsEvents(t *testing.T) { require.NoError(t, clientErr) }() - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(``)) // Empty message require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(1, time.Second*10) + require.Eventually(t, func() bool { + return counter.Load() == 1 + }, NatsWaitTimeout, time.Millisecond*100) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) // Correct message require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(2, time.Second*10) + require.Eventually(t, func() bool { + return counter.Load() == 2 + }, NatsWaitTimeout, time.Millisecond*100) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"__typename":"Employee","update":{"name":"foo"}}`)) // Missing id require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(3, time.Second*10) + + require.Eventually(t, func() bool { + return counter.Load() == 3 + }, NatsWaitTimeout, time.Millisecond*100) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) // Correct message require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(4, time.Second*10) require.Eventually(t, func() bool { return counter.Load() == 4 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) require.NoError(t, client.Close()) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, NatsWaitTimeout) + xEnv.WaitForConnectionCount(0, NatsWaitTimeout) }) }) @@ -244,7 +280,7 @@ func TestNatsEvents(t *testing.T) { require.NoError(t, clientErr) }() - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Send a mutation to trigger the subscription @@ -265,32 +301,19 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return counter.Load() == 2 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) require.NoError(t, client.Close()) - xEnv.WaitForMessagesSent(2, time.Second*10) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - //xEnv.WaitForConnectionCount(0, time.Second*10) flaky + xEnv.WaitForMessagesSent(2, NatsWaitTimeout) + xEnv.WaitForSubscriptionCount(0, NatsWaitTimeout) + xEnv.WaitForConnectionCount(0, NatsWaitTimeout) }) }) t.Run("multipart", func(t *testing.T) { t.Parallel() - assertLineEquals := func(t *testing.T, reader *bufio.Reader, expected string) { - line, _, err := reader.ReadLine() - require.NoError(t, err) - require.Equal(t, expected, string(line)) - } - - assertMultipartPrefix := func(t *testing.T, reader *bufio.Reader) { - assertLineEquals(t, reader, "") - assertLineEquals(t, reader, "--graphql") - assertLineEquals(t, reader, "Content-Type: application/json") - assertLineEquals(t, reader, "") - } - heartbeatInterval := 150 * time.Millisecond t.Run("subscribe with multipart responses", func(t *testing.T) { @@ -315,8 +338,6 @@ func TestNatsEvents(t *testing.T) { var consumed atomic.Uint32 go func() { - defer produced.Add(1) - req := xEnv.MakeGraphQLMultipartRequest(http.MethodPost, bytes.NewReader(subscribePayload)) resp, err := xEnv.RouterClient.Do(req) require.NoError(t, err) @@ -331,24 +352,14 @@ func TestNatsEvents(t *testing.T) { reader := bufio.NewReader(resp.Body) - // Read the first part - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"employeeUpdated\":{\"id\":3,\"details\":{\"forename\":\"Stefan\",\"surname\":\"Avram\"}}}}}") - consumed.Add(1) - - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"employeeUpdated\":{\"id\":3,\"details\":{\"forename\":\"Stefan\",\"surname\":\"Avram\"}}}}}") consumed.Add(1) - require.Eventually(t, func() bool { - return produced.Load() == 2 - }, time.Second*5, time.Millisecond*100) - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"employeeUpdated\":{\"id\":3,\"details\":{\"forename\":\"Stefan\",\"surname\":\"Avram\"}}}}}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"employeeUpdated\":{\"id\":3,\"details\":{\"forename\":\"Stefan\",\"surname\":\"Avram\"}}}}}") consumed.Add(1) }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Send a mutation to trigger the subscription @@ -356,11 +367,10 @@ func TestNatsEvents(t *testing.T) { Query: `mutation { updateAvailability(employeeID: 3, isAvailable: true) { id } }`, }) require.JSONEq(t, `{"data":{"updateAvailability":{"id":3}}}`, res.Body) - produced.Add(1) require.Eventually(t, func() bool { - return consumed.Load() == 2 - }, time.Second*10, time.Millisecond*100) + return consumed.Load() == 1 + }, NatsWaitTimeout, time.Millisecond*100) // Trigger the subscription via NATS err := xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"id":3,"__typename": "Employee"}`)) @@ -371,12 +381,12 @@ func TestNatsEvents(t *testing.T) { produced.Add(1) require.Eventually(t, func() bool { - return consumed.Load() == 3 - }, time.Second*10, time.Millisecond*100) + return consumed.Load() == 2 + }, NatsWaitTimeout, time.Millisecond*100) }) }) - t.Run("subscribe with multipart responses http/1", func(t *testing.T) { + t.Run("subscribe with multipart responses http", func(t *testing.T) { t.Parallel() testenv.Run(t, &testenv.Config{ @@ -413,11 +423,11 @@ func TestNatsEvents(t *testing.T) { assertLineEquals(t, reader, "{}") }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) }) }) @@ -431,13 +441,12 @@ func TestNatsEvents(t *testing.T) { subscribePayload := []byte(`{"query":"subscription { countFor(count: 3) }"}`) - var counter atomic.Uint32 + var done atomic.Bool - var client http.Client go func() { - defer counter.Add(1) + defer done.Store(true) - client = http.Client{} + client := http.Client{} req := xEnv.MakeGraphQLMultipartRequest(http.MethodPost, bytes.NewReader(subscribePayload)) resp, err := client.Do(req) require.NoError(t, err) @@ -447,21 +456,15 @@ func TestNatsEvents(t *testing.T) { reader := bufio.NewReader(resp.Body) // Read the first part - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"countFor\":0}}}") - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"countFor\":1}}}") - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"countFor\":2}}}") - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"countFor\":3}}}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"countFor\":0}}}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"countFor\":1}}}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"countFor\":2}}}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"countFor\":3}}}") assertLineEquals(t, reader, "--graphql--") }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) - require.Eventually(t, func() bool { - return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) + require.Eventually(t, done.Load, NatsWaitTimeout, time.Millisecond*100) }) }) @@ -494,8 +497,7 @@ func TestNatsEvents(t *testing.T) { defer resp.Body.Close() reader := bufio.NewReader(resp.Body) - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"errors\":[{\"message\":\"operation type 'subscription' is blocked\"}]}}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"errors\":[{\"message\":\"operation type 'subscription' is blocked\"}]}}") } }) }) @@ -547,7 +549,7 @@ func TestNatsEvents(t *testing.T) { require.Error(t, err, io.EOF) // Subscription closed after one time }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Send a mutation to trigger the subscription @@ -565,7 +567,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) }) }) @@ -579,9 +581,16 @@ func TestNatsEvents(t *testing.T) { subscribePayload := []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id details { forename surname } } }"}`) - var requestCompleted atomic.Bool + var done atomic.Bool + var producerDone atomic.Bool + + waitForProducer := func() { + assert.Eventually(t, producerDone.Load, NatsWaitTimeout, time.Millisecond*100) + producerDone.Store(false) + } go func() { + defer done.Store(true) client := http.Client{} req, err := http.NewRequest(http.MethodPost, xEnv.GraphQLRequestURL(), bytes.NewReader(subscribePayload)) require.NoError(t, err) @@ -597,6 +606,7 @@ func TestNatsEvents(t *testing.T) { defer resp.Body.Close() reader := bufio.NewReader(resp.Body) + waitForProducer() eventNext, _, err := reader.ReadLine() require.NoError(t, err) require.Equal(t, "event: next", string(eventNext)) @@ -607,6 +617,7 @@ func TestNatsEvents(t *testing.T) { require.NoError(t, err) require.Equal(t, "", string(line)) + waitForProducer() eventNext, _, err = reader.ReadLine() require.NoError(t, err) require.Equal(t, "event: next", string(eventNext)) @@ -616,11 +627,9 @@ func TestNatsEvents(t *testing.T) { line, _, err = reader.ReadLine() require.NoError(t, err) require.Equal(t, "", string(line)) - - requestCompleted.Store(true) }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Send a mutation to trigger the subscription @@ -628,17 +637,23 @@ func TestNatsEvents(t *testing.T) { Query: `mutation { updateAvailability(employeeID: 3, isAvailable: true) { id } }`, }) require.JSONEq(t, `{"data":{"updateAvailability":{"id":3}}}`, res.Body) + err := xEnv.NatsConnectionDefault.Flush() + require.NoError(t, err) + producerDone.Store(true) + + assert.Eventually(t, func() bool { + return !producerDone.Load() + }, NatsWaitTimeout, time.Millisecond*100) // Trigger the subscription via NATS - err := xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"id":3,"__typename": "Employee"}`)) + err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"id":3,"__typename": "Employee"}`)) require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) + producerDone.Store(true) - require.Eventually(t, func() bool { - return requestCompleted.Load() - }, time.Second*10, time.Millisecond*100) + require.Eventually(t, done.Load, NatsWaitTimeout, time.Millisecond*100) }) }) @@ -755,7 +770,7 @@ func TestNatsEvents(t *testing.T) { require.Equal(t, "", string(line)) }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Send a mutation to trigger the subscription res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ @@ -772,10 +787,10 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, NatsWaitTimeout) + xEnv.WaitForConnectionCount(0, NatsWaitTimeout) }) }) @@ -902,7 +917,7 @@ func TestNatsEvents(t *testing.T) { var msg testenv.WebSocketMessage var payload subscriptionPayload - xEnv.WaitForSubscriptionCount(1, time.Second*20) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Trigger the first subscription via NATS err = xEnv.NatsConnectionMyNats.Publish(xEnv.GetPubSubName("employeeUpdatedMyNats.12"), []byte(`{"id":13,"__typename":"Employee"}`)) @@ -911,7 +926,7 @@ func TestNatsEvents(t *testing.T) { err = xEnv.NatsConnectionMyNats.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(1, time.Second*10) + xEnv.WaitForMessagesSent(1, NatsWaitTimeout) err = conn.ReadJSON(&msg) require.NoError(t, err) @@ -928,7 +943,7 @@ func TestNatsEvents(t *testing.T) { err = xEnv.NatsConnectionMyNats.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(2, time.Second*10) + xEnv.WaitForMessagesSent(2, NatsWaitTimeout) err = conn.ReadJSON(&msg) require.NoError(t, err) @@ -980,7 +995,7 @@ func TestNatsEvents(t *testing.T) { var msg testenv.WebSocketMessage var payload subscriptionPayload - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Trigger the first subscription via NATS err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.12"), []byte(`{"id":13,"__typename":"Employee"}`)) @@ -1003,7 +1018,7 @@ func TestNatsEvents(t *testing.T) { Type: "complete", }) require.NoError(t, err) - xEnv.WaitForSubscriptionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, NatsWaitTimeout) var complete testenv.WebSocketMessage err = conn.ReadJSON(&complete) @@ -1024,7 +1039,7 @@ func TestNatsEvents(t *testing.T) { Payload: []byte(`{"query":"subscription { employeeUpdatedNatsStream(id: 12) { id }}"}`), }) require.NoError(t, err) - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) err = conn.ReadJSON(&msg) require.NoError(t, err) @@ -1101,7 +1116,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) }) }) @@ -1139,7 +1154,7 @@ func TestNatsEvents(t *testing.T) { var msg testenv.WebSocketMessage var payload subscriptionPayload - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) var produced atomic.Uint32 var consumed atomic.Uint32 @@ -1147,7 +1162,7 @@ func TestNatsEvents(t *testing.T) { go func() { require.Eventually(t, func() bool { return produced.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr := conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1161,7 +1176,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 2 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1175,7 +1190,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 4 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1189,7 +1204,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 5 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1203,7 +1218,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 6 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1217,7 +1232,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 8 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1231,7 +1246,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 9 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1241,11 +1256,11 @@ func TestNatsEvents(t *testing.T) { require.Equal(t, float64(8), payload.Data.FilteredEmployeeUpdated.ID) require.Equal(t, "Nithin", payload.Data.FilteredEmployeeUpdated.Details.Forename) require.Equal(t, "Kumar", payload.Data.FilteredEmployeeUpdated.Details.Surname) - consumed.Add(2) // should skip two messages + consumed.Add(3) // should skip two messages require.Eventually(t, func() bool { return produced.Load() == 12 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1266,9 +1281,8 @@ func TestNatsEvents(t *testing.T) { // Events 1, 3, 4, 5, 7, 8, and 11 should be included for i := uint32(1); i < 13; i++ { require.Eventually(t, func() bool { - return consumed.Load() >= i-1 - }, time.Second*10, time.Millisecond*100) - + return consumed.Load() >= i + }, NatsWaitTimeout, time.Millisecond*100) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.1"), []byte(fmt.Sprintf(`{"id":%d,"__typename":"Employee"}`, i))) require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() @@ -1277,8 +1291,8 @@ func TestNatsEvents(t *testing.T) { } require.Eventually(t, func() bool { - return consumed.Load() == 11 && produced.Load() == 13 - }, time.Second*10, time.Millisecond*100) + return consumed.Load() == 12 && produced.Load() == 13 + }, NatsWaitTimeout, time.Millisecond*100) }) }) @@ -1292,13 +1306,19 @@ func TestNatsEvents(t *testing.T) { subscribePayload := []byte(`{"query":"subscription { filteredEmployeeUpdated(id: 1) { id details { forename surname } } }"}`) - var requestsDone atomic.Bool + var done atomic.Bool + var producerDone atomic.Bool + + waitForProducer := func() { + assert.Eventually(t, producerDone.Load, NatsWaitTimeout, time.Millisecond*100) + producerDone.Store(false) + } tick := make(chan struct{}, 1) - timeout := time.After(time.Second * 10) + timeout := time.After(NatsWaitTimeout) go func() { - defer requestsDone.Store(true) + defer done.Store(true) client := http.Client{} req, gErr := http.NewRequest(http.MethodPost, xEnv.GraphQLRequestURL(), bytes.NewReader(subscribePayload)) @@ -1321,6 +1341,7 @@ func TestNatsEvents(t *testing.T) { reader := bufio.NewReader(resp.Body) + waitForProducer() eventNext, _, gErr := reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1337,6 +1358,7 @@ func TestNatsEvents(t *testing.T) { require.Fail(t, "timeout") } + waitForProducer() eventNext, _, gErr = reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1353,6 +1375,7 @@ func TestNatsEvents(t *testing.T) { require.Fail(t, "timeout") } + waitForProducer() eventNext, _, gErr = reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1369,6 +1392,7 @@ func TestNatsEvents(t *testing.T) { require.Fail(t, "timeout") } + waitForProducer() eventNext, _, gErr = reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1385,6 +1409,7 @@ func TestNatsEvents(t *testing.T) { require.Fail(t, "timeout") } + waitForProducer() eventNext, _, gErr = reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1401,6 +1426,7 @@ func TestNatsEvents(t *testing.T) { require.Fail(t, "timeout") } + waitForProducer() eventNext, _, gErr = reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1417,6 +1443,7 @@ func TestNatsEvents(t *testing.T) { require.Fail(t, "timeout") } + waitForProducer() eventNext, _, gErr = reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1433,6 +1460,7 @@ func TestNatsEvents(t *testing.T) { require.Fail(t, "timeout") } + waitForProducer() eventNext, _, gErr = reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1444,7 +1472,7 @@ func TestNatsEvents(t *testing.T) { require.Equal(t, "", string(line)) }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Trigger the subscription via NATS err := xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.1"), []byte(`{"id":1,"__typename": "Employee"}`)) @@ -1453,6 +1481,8 @@ func TestNatsEvents(t *testing.T) { err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) + producerDone.Store(true) + // Events 1, 3, 4, 5, 7, 8, and 11 should be included for i := 1; i < 13; i++ { @@ -1460,6 +1490,9 @@ func TestNatsEvents(t *testing.T) { case 1, 3, 4, 5, 7, 8, 11: select { case <-tick: + assert.Eventually(t, func() bool { + return !producerDone.Load() + }, NatsWaitTimeout, time.Millisecond*100) case <-timeout: require.Fail(t, "timeout") } @@ -1471,11 +1504,10 @@ func TestNatsEvents(t *testing.T) { err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) + producerDone.Store(true) } - require.Eventually(t, func() bool { - return requestsDone.Load() - }, time.Second*10, time.Millisecond*100) + require.Eventually(t, done.Load, NatsWaitTimeout, time.Millisecond*100) }) }) @@ -1507,19 +1539,25 @@ func TestNatsEvents(t *testing.T) { subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { defer consumed.Add(1) - oldCount := produced.Load() + oldCount := consumed.Load() + require.Eventually(t, func() bool { + return oldCount == produced.Load()-1 + }, NatsWaitTimeout, time.Millisecond*100) - if oldCount == 1 { + if oldCount == 0 { var gqlErr graphql.Errors require.ErrorAs(t, errValue, &gqlErr) - require.Equal(t, "Invalid message received", gqlErr[0].Message) - } else if oldCount == 2 || oldCount == 4 { - require.NoError(t, errValue) - require.JSONEq(t, `{"employeeUpdated":{"id":3,"details":{"forename":"Stefan","surname":"Avram"}}}`, string(dataValue)) - } else if oldCount == 3 { + assert.Equal(t, "Invalid message received", gqlErr[0].Message) + } else if oldCount == 1 { + assert.NoError(t, errValue) + assert.JSONEq(t, `{"employeeUpdated":{"id":3,"details":{"forename":"Stefan","surname":"Avram"}}}`, string(dataValue)) + } else if oldCount == 2 { var gqlErr graphql.Errors require.ErrorAs(t, errValue, &gqlErr) - require.Equal(t, "Cannot return null for non-nullable field 'Subscription.employeeUpdated.id'.", gqlErr[0].Message) + assert.Equal(t, "Cannot return null for non-nullable field 'Subscription.employeeUpdated.id'.", gqlErr[0].Message) + } else if oldCount == 3 { + assert.NoError(t, errValue) + assert.JSONEq(t, `{"employeeUpdated":{"id":3,"details":{"forename":"Stefan","surname":"Avram"}}}`, string(dataValue)) } return nil @@ -1532,53 +1570,53 @@ func TestNatsEvents(t *testing.T) { require.NoError(t, clientErr) }() - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{asas`)) // Invalid message require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(1, time.Second*10) + xEnv.WaitForMessagesSent(1, NatsWaitTimeout) produced.Add(1) require.Eventually(t, func() bool { return consumed.Load() == 1 - }, time.Second*5, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) // Correct message require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(2, time.Second*10) + xEnv.WaitForMessagesSent(2, NatsWaitTimeout) produced.Add(1) require.Eventually(t, func() bool { return consumed.Load() == 2 - }, time.Second*5, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"__typename":"Employee","update":{"name":"foo"}}`)) // Missing id require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(3, time.Second*10) + xEnv.WaitForMessagesSent(3, NatsWaitTimeout) produced.Add(1) require.Eventually(t, func() bool { return consumed.Load() == 3 - }, time.Second*5, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) // Correct message require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(4, time.Second*10) + xEnv.WaitForMessagesSent(4, NatsWaitTimeout) produced.Add(1) require.Eventually(t, func() bool { return consumed.Load() == 4 - }, time.Second*5, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) require.NoError(t, client.Close()) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, NatsWaitTimeout) + xEnv.WaitForConnectionCount(0, NatsWaitTimeout) }) }) } diff --git a/router-tests/go.mod b/router-tests/go.mod index 0ed2233e88..2b734cb1f3 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -26,7 +26,7 @@ require ( github.com/twmb/franz-go/pkg/kadm v1.11.0 github.com/wundergraph/cosmo/demo v0.0.0-20250119174948-4b991294658e github.com/wundergraph/cosmo/router v0.0.0-20250119174948-4b991294658e - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.145 + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.146 go.opentelemetry.io/otel v1.28.0 go.opentelemetry.io/otel/sdk v1.28.0 go.opentelemetry.io/otel/sdk/metric v1.28.0 @@ -60,7 +60,6 @@ require ( github.com/cpuguy83/dockercfg v0.3.1 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.5 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/dgraph-io/ristretto v0.2.0 // indirect github.com/dgraph-io/ristretto/v2 v2.1.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.5.0 // indirect diff --git a/router-tests/go.sum b/router-tests/go.sum index 2bdc071b5e..fc2c1f7366 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -67,8 +67,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE= -github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU= github.com/dgraph-io/ristretto/v2 v2.1.0 h1:59LjpOJLNDULHh8MC4UaegN52lC4JnO2dITsie/Pa8I= github.com/dgraph-io/ristretto/v2 v2.1.0/go.mod h1:uejeqfYXpUomfse0+lO+13ATz4TypQYLJZzBSAemuB4= github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y= @@ -363,8 +361,8 @@ github.com/vektah/gqlparser/v2 v2.5.21 h1:Zw1rG2dr1pRR4wqwbVq4d6+xk2f4ut/yo+hwr4 github.com/vektah/gqlparser/v2 v2.5.21/go.mod h1:xMl+ta8a5M1Yo1A1Iwt/k7gSpscwSnHZdw7tfhEGfTM= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.145 h1:3JuBmRux6YB/UZgh6COvgLXzQhMIsdHV7A02NsYdAVE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.145/go.mod h1:B7eV0Qh8Lop9QzIOQcsvKp3S0ejfC6mgyWoJnI917yQ= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.146 h1:C9+jjMgbU/RJTiFGC0HNHan4LxrY7fIhmbZRoqZryLk= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.146/go.mod h1:B7eV0Qh8Lop9QzIOQcsvKp3S0ejfC6mgyWoJnI917yQ= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/router-tests/telemetry/telemetry_test.go b/router-tests/telemetry/telemetry_test.go index 3b8c5542eb..832e48ed3e 100644 --- a/router-tests/telemetry/telemetry_test.go +++ b/router-tests/telemetry/telemetry_test.go @@ -2,6 +2,7 @@ package telemetry import ( "context" + "github.com/stretchr/testify/assert" "net/http" "regexp" "runtime" @@ -209,6 +210,7 @@ func TestEngineStatisticsTelemetry(t *testing.T) { err = conn.ReadJSON(&res) require.NoError(t, err) + xEnv.WaitForMinMessagesSent(1, time.Second*5) xEnv.AssertEngineStatistics(t, metricReader, testenv.EngineStatisticAssertion{ Subscriptions: 1, Connections: 1, @@ -220,6 +222,7 @@ func TestEngineStatisticsTelemetry(t *testing.T) { err = conn.ReadJSON(&complete) require.NoError(t, err) + xEnv.WaitForMinMessagesSent(2, time.Second*5) xEnv.AssertEngineStatistics(t, metricReader, testenv.EngineStatisticAssertion{ Subscriptions: 1, Connections: 1, @@ -288,6 +291,7 @@ func TestEngineStatisticsTelemetry(t *testing.T) { wg.Wait() xEnv.WaitForSubscriptionCount(2, time.Second*5) + xEnv.WaitForTriggerCount(1, time.Second*5) xEnv.AssertEngineStatistics(t, metricReader, testenv.EngineStatisticAssertion{ Subscriptions: 2, @@ -300,6 +304,7 @@ func TestEngineStatisticsTelemetry(t *testing.T) { err := conn1.ReadJSON(&res) require.NoError(t, err) + xEnv.WaitForMinMessagesSent(1, time.Second*5) xEnv.AssertEngineStatistics(t, metricReader, testenv.EngineStatisticAssertion{ Subscriptions: 2, Connections: 2, @@ -310,6 +315,7 @@ func TestEngineStatisticsTelemetry(t *testing.T) { err = conn2.ReadJSON(&res) require.NoError(t, err) + xEnv.WaitForMinMessagesSent(2, time.Second*5) xEnv.AssertEngineStatistics(t, metricReader, testenv.EngineStatisticAssertion{ Subscriptions: 2, Connections: 2, @@ -324,6 +330,7 @@ func TestEngineStatisticsTelemetry(t *testing.T) { err = conn2.ReadJSON(&complete) require.NoError(t, err) + xEnv.WaitForMinMessagesSent(4, time.Second*5) xEnv.AssertEngineStatistics(t, metricReader, testenv.EngineStatisticAssertion{ Subscriptions: 2, Connections: 2, @@ -460,7 +467,8 @@ func TestEngineStatisticsTelemetry(t *testing.T) { }) } -func TestOperationCacheTelemetry(t *testing.T) { +// Is set as Flaky so that when running the tests it will be run separately and retried if it fails +func TestFlakyOperationCacheTelemetry(t *testing.T) { t.Parallel() const ( @@ -2522,7 +2530,8 @@ func TestOperationCacheTelemetry(t *testing.T) { }) } -func TestRuntimeTelemetry(t *testing.T) { +// Is set as Flaky so that when running the tests it will be run separately and retried if it fails +func TestFlakyRuntimeTelemetry(t *testing.T) { t.Parallel() const employeesIDData = `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}` @@ -2899,7 +2908,8 @@ func TestRuntimeTelemetry(t *testing.T) { }) } -func TestTelemetry(t *testing.T) { +// Is set as Flaky so that when running the tests it will be run separately and retried if it fails +func TestFlakyTelemetry(t *testing.T) { t.Parallel() const employeesIDData = `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}` @@ -4033,7 +4043,7 @@ func TestTelemetry(t *testing.T) { }) require.NoError(t, err) require.Equal(t, `{"data":{"rootFieldWithListArg":["a"]}}`, res.Body) - require.Equal(t, "HIT", res.Response.Header.Get(core.PersistedOperationCacheHeader)) + assert.Equal(t, "HIT", res.Response.Header.Get(core.PersistedOperationCacheHeader)) sn = exporter.GetSpans().Snapshots() @@ -8593,8 +8603,10 @@ func TestTelemetry(t *testing.T) { require.Equal(t, `{"errors":[{"message":"The total number of fields 2 exceeds the limit allowed (1)"}]}`, failedRes2.Body) testSpan2 := integration.RequireSpanWithName(t, exporter, "Operation - Validate") - require.Contains(t, testSpan2.Attributes(), otel.WgQueryTotalFields.Int(2)) - require.Contains(t, testSpan2.Attributes(), otel.WgQueryDepthCacheHit.Bool(true)) + assert.Contains(t, testSpan2.Attributes(), otel.WgQueryTotalFields.Int(2)) + assert.Contains(t, testSpan2.Attributes(), otel.WgQueryDepthCacheHit.Bool(true)) + assert.Equal(t, codes.Unset, testSpan2.Status().Code) + assert.Equal(t, []sdktrace.Event(nil), testSpan2.Events()) exporter.Reset() successRes := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ diff --git a/router-tests/testenv/pubsub.go b/router-tests/testenv/pubsub.go index c44c3389d2..1bf3dcd14e 100644 --- a/router-tests/testenv/pubsub.go +++ b/router-tests/testenv/pubsub.go @@ -96,6 +96,9 @@ func setupNatsData(t testing.TB) (*NatsData, error) { nats.MaxReconnects(10), nats.ReconnectWait(1*time.Second), nats.Timeout(5*time.Second), + nats.ErrorHandler(func(conn *nats.Conn, subscription *nats.Subscription, err error) { + t.Log(err) + }), ) if err != nil { return nil, err diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 011095b77a..607c741713 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -87,10 +87,12 @@ var ( func Run(t *testing.T, cfg *Config, f func(t *testing.T, xEnv *Environment)) { t.Helper() env, err := createTestEnv(t, cfg) + if env != nil { + t.Cleanup(env.Shutdown) + } if err != nil { t.Fatalf("could not create environment: %s", err) } - t.Cleanup(env.Shutdown) f(t, env) if cfg.AssertCacheMetrics != nil { assertCacheMetrics(t, env, cfg.AssertCacheMetrics.BaseGraphAssertions, "") @@ -106,10 +108,12 @@ func Run(t *testing.T, cfg *Config, f func(t *testing.T, xEnv *Environment)) { func RunWithError(t *testing.T, cfg *Config, f func(t *testing.T, xEnv *Environment)) error { t.Helper() env, err := createTestEnv(t, cfg) + if env != nil { + t.Cleanup(env.Shutdown) + } if err != nil { return err } - t.Cleanup(env.Shutdown) f(t, env) if cfg.AssertCacheMetrics != nil { assertCacheMetrics(t, env, cfg.AssertCacheMetrics.BaseGraphAssertions, "") @@ -122,10 +126,12 @@ func Bench(b *testing.B, cfg *Config, f func(b *testing.B, xEnv *Environment)) { b.Helper() b.StopTimer() env, err := createTestEnv(b, cfg) + if env != nil { + b.Cleanup(env.Shutdown) + } if err != nil { b.Fatalf("could not create environment: %s", err) } - b.Cleanup(env.Shutdown) b.StartTimer() f(b, env) if cfg.AssertCacheMetrics != nil { @@ -377,6 +383,27 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { ctx, cancel := context.WithCancelCause(context.Background()) + var ( + logObserver *observer.ObservedLogs + ) + + if oc := cfg.LogObservation; oc.Enabled { + var zCore zapcore.Core + zCore, logObserver = observer.New(oc.LogLevel) + cfg.Logger = logging.NewZapLoggerWithCore(zCore, true) + } else { + ec := zap.NewProductionEncoderConfig() + ec.EncodeDuration = zapcore.SecondsDurationEncoder + ec.TimeKey = "time" + + syncer := zapcore.AddSync(os.Stderr) + cfg.Logger = logging.NewZapLogger(syncer, false, true, zapcore.WarnLevel) + } + + if cfg.AccessLogger == nil { + cfg.AccessLogger = cfg.Logger + } + counters := &SubgraphRequestCount{ Global: atomic.NewInt64(0), Employees: atomic.NewInt64(0), @@ -401,7 +428,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { getPubSubName := GetPubSubNameFn(pubSubPrefix) employees := &Subgraph{ - handler: subgraphs.EmployeesHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.EmployeesHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Employees.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -411,7 +438,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } family := &Subgraph{ - handler: subgraphs.FamilyHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.FamilyHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Family.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -421,7 +448,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } hobbies := &Subgraph{ - handler: subgraphs.HobbiesHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.HobbiesHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Hobbies.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -431,7 +458,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } products := &Subgraph{ - handler: subgraphs.ProductsHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.ProductsHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Products.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -441,7 +468,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } productsFg := &Subgraph{ - handler: subgraphs.ProductsFGHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.ProductsFGHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.ProductsFg.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -451,7 +478,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } test1 := &Subgraph{ - handler: subgraphs.Test1Handler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.Test1Handler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Test1.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -461,7 +488,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } availability := &Subgraph{ - handler: subgraphs.AvailabilityHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.AvailabilityHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Availability.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -471,7 +498,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } mood := &Subgraph{ - handler: subgraphs.MoodHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.MoodHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Mood.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -481,7 +508,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } countries := &Subgraph{ - handler: subgraphs.CountriesHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.CountriesHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Countries.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -556,27 +583,6 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { client = retryClient.StandardClient() } - var ( - logObserver *observer.ObservedLogs - ) - - if oc := cfg.LogObservation; oc.Enabled { - var zCore zapcore.Core - zCore, logObserver = observer.New(oc.LogLevel) - cfg.Logger = logging.NewZapLoggerWithCore(zCore, true) - } else { - ec := zap.NewProductionEncoderConfig() - ec.EncodeDuration = zapcore.SecondsDurationEncoder - ec.TimeKey = "time" - - syncer := zapcore.AddSync(os.Stderr) - cfg.Logger = logging.NewZapLogger(syncer, false, true, zapcore.ErrorLevel) - } - - if cfg.AccessLogger == nil { - cfg.AccessLogger = cfg.Logger - } - kafkaStarted.Wait() rr, err := configureRouter(listenerAddr, cfg, &routerConfig, cdn, natsSetup) @@ -1870,7 +1876,71 @@ func (e *Environment) WaitForTriggerCount(desiredCount uint64, timeout time.Dura } } -func subgraphOptions(ctx context.Context, t testing.TB, natsData *NatsData, pubSubName func(string) string) *subgraphs.SubgraphOptions { +func DeflakeWSReadMessage(t testing.TB, conn *websocket.Conn) (messageType int, p []byte, err error) { + for i := 0; i < 5; i++ { + messageType, p, err = conn.ReadMessage() + if err != nil && strings.Contains(err.Error(), "connection reset by peer") { + t.Log("connection reset by peer found, retrying...") + err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + require.NoError(t, err) + time.Sleep(time.Duration(i*200) * time.Millisecond) + continue + } + break + } + + return messageType, p, err +} + +func DeflakeWSReadJSON(t testing.TB, conn *websocket.Conn, v interface{}) (err error) { + for i := 0; i < 5; i++ { + err = conn.ReadJSON(v) + if err != nil && strings.Contains(err.Error(), "connection reset by peer") { + t.Log("connection reset by peer found, retrying...") + err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + require.NoError(t, err) + time.Sleep(time.Duration(i*200) * time.Millisecond) + continue + } + break + } + + return err +} + +func DeflakeWSWriteMessage(t testing.TB, conn *websocket.Conn, messageType int, data []byte) (err error) { + for i := 0; i < 5; i++ { + err = conn.WriteMessage(messageType, data) + if err != nil && strings.Contains(err.Error(), "connection reset by peer") { + t.Log("connection reset by peer found, retrying...") + err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + require.NoError(t, err) + time.Sleep(time.Duration(i*200) * time.Millisecond) + continue + } + break + } + + return err +} + +func DeflakeWSWriteJSON(t testing.TB, conn *websocket.Conn, v interface{}) (err error) { + for i := 0; i < 5; i++ { + err = conn.WriteJSON(v) + if err != nil && strings.Contains(err.Error(), "connection reset by peer") { + t.Log("connection reset by peer found, retrying...") + err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + require.NoError(t, err) + time.Sleep(time.Duration(i*200) * time.Millisecond) + continue + } + break + } + + return err +} + +func subgraphOptions(ctx context.Context, t testing.TB, logger *zap.Logger, natsData *NatsData, pubSubName func(string) string) *subgraphs.SubgraphOptions { if natsData == nil { return &subgraphs.SubgraphOptions{ NatsPubSubByProviderID: map[string]pubsub_datasource.NatsPubSub{}, @@ -1879,13 +1949,10 @@ func subgraphOptions(ctx context.Context, t testing.TB, natsData *NatsData, pubS } natsPubSubByProviderID := make(map[string]pubsub_datasource.NatsPubSub, len(demoNatsProviders)) for _, sourceName := range demoNatsProviders { - natsConnection, err := nats.Connect(natsData.Server.ClientURL()) - require.NoError(t, err) - - js, err := jetstream.New(natsConnection) + js, err := jetstream.New(natsData.Connections[0]) require.NoError(t, err) - natsPubSubByProviderID[sourceName] = pubsubNats.NewConnector(zap.NewNop(), natsConnection, js, "hostname", "listenaddr").New(ctx) + natsPubSubByProviderID[sourceName] = pubsubNats.NewConnector(logger, natsData.Connections[0], js, "hostname", "listenaddr").New(ctx) } return &subgraphs.SubgraphOptions{ diff --git a/router-tests/websocket_test.go b/router-tests/websocket_test.go index eb82b4d8ca..347b5f5924 100644 --- a/router-tests/websocket_test.go +++ b/router-tests/websocket_test.go @@ -49,20 +49,20 @@ func TestWebSockets(t *testing.T) { testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"{ employees { id } }"}`), }) require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "next", res.Type) require.Equal(t, "1", res.ID) require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, string(res.Payload)) var complete testenv.WebSocketMessage - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "complete", complete.Type) require.Equal(t, "1", complete.ID) @@ -99,20 +99,20 @@ func TestWebSockets(t *testing.T) { "Authorization": []string{"Bearer " + token}, } conn := xEnv.InitGraphQLWebSocketConnection(header, nil, nil) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"{ employees { id startDate } }"}`), }) require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "error", res.Type) require.Equal(t, "1", res.ID) require.Equal(t, `[{"message":"Unauthorized"}]`, string(res.Payload)) var complete testenv.WebSocketMessage - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "complete", complete.Type) require.Equal(t, "1", complete.ID) @@ -149,20 +149,20 @@ func TestWebSockets(t *testing.T) { "Authorization": []string{"Bearer " + token}, } conn := xEnv.InitGraphQLWebSocketConnection(header, nil, nil) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"{ employees { id startDate } }"}`), }) require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "error", res.Type) require.Equal(t, "1", res.ID) require.Equal(t, `[{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",0,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",1,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",2,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",3,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",4,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",5,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",6,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",7,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",8,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",9,"startDate"]}]`, string(res.Payload)) var complete testenv.WebSocketMessage - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "complete", complete.Type) require.Equal(t, "1", complete.ID) @@ -201,7 +201,7 @@ func TestWebSockets(t *testing.T) { "Authorization": []string{"Bearer " + token}, } conn := xEnv.InitGraphQLWebSocketConnection(header, nil, nil) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id details { forename surname } startDate }}"}`), @@ -220,7 +220,7 @@ func TestWebSockets(t *testing.T) { }() var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "error", res.Type) require.Equal(t, "1", res.ID) @@ -261,7 +261,7 @@ func TestWebSockets(t *testing.T) { "Authorization": []string{"Bearer " + token}, } conn := xEnv.InitGraphQLWebSocketConnection(header, nil, nil) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id details { forename surname } startDate }}"}`), @@ -278,7 +278,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) }() var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "error", res.Type) require.Equal(t, "1", res.ID) @@ -321,7 +321,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) initialPayload := []byte(`{"Authorization":"Bearer ` + token + `"}`) conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, initialPayload) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id }}"}`), @@ -340,7 +340,7 @@ func TestWebSockets(t *testing.T) { }() var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "next", res.Type) require.Equal(t, "1", res.ID) @@ -379,7 +379,7 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { require.NoError(t, err) conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id }}"}`), @@ -387,7 +387,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "error", res.Type) payload, err := json.Marshal(res.Payload) @@ -429,7 +429,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) initialPayload := []byte(`{"Authorization": true }`) conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, initialPayload) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id }}"}`), @@ -437,7 +437,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "error", res.Type) payload, err := json.Marshal(res.Payload) @@ -468,7 +468,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) initialPayload := []byte(`{"Authorization":"` + token + `"}`) conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, initialPayload) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id }}"}`), @@ -490,7 +490,7 @@ func TestWebSockets(t *testing.T) { require.Eventually(t, done.Load, time.Second*5, time.Millisecond*100) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "next", res.Type) require.Equal(t, "1", res.ID) @@ -519,18 +519,6 @@ func TestWebSockets(t *testing.T) { expectConnectAndReadCurrentTime(t, xEnv) }) }) - t.Run("subscription with multiple reconnects and netPoll disabled", func(t *testing.T) { - t.Parallel() - - testenv.Run(t, &testenv.Config{ - ModifyEngineExecutionConfiguration: func(engineExecutionConfiguration *config.EngineExecutionConfiguration) { - engineExecutionConfiguration.EnableNetPoll = false - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - expectConnectAndReadCurrentTime(t, xEnv) - expectConnectAndReadCurrentTime(t, xEnv) - }) - }) t.Run("subscription with header propagation", func(t *testing.T) { t.Parallel() @@ -579,27 +567,27 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) defer conn.Close() - _, message, err := conn.ReadMessage() + _, message, err := testenv.DeflakeWSReadMessage(t, conn) require.NoError(t, err) require.Equal(t, `{"type":"connection_init","payload":{"Custom-Auth":"test","extensions":{"upgradeHeaders":{"Authorization":"Bearer test","Canonical-Header-Name":"matches","Reverse-Canonical-Header-Name":"matches as well","X-Custom-Auth":"customAuth"},"upgradeQueryParams":{"token":"Bearer Something"},"initialPayload":{"Custom-Auth":"test"}}}}`, string(message)) - err = conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"connection_ack"}`)) + err = testenv.DeflakeWSWriteMessage(t, conn, websocket.TextMessage, []byte(`{"type":"connection_ack"}`)) require.NoError(t, err) - _, message, err = conn.ReadMessage() + _, message, err = testenv.DeflakeWSReadMessage(t, conn) require.NoError(t, err) require.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription{currentTime {unixTime timeStamp}}","extensions":{"upgradeHeaders":{"Authorization":"Bearer test","Canonical-Header-Name":"matches","Reverse-Canonical-Header-Name":"matches as well","X-Custom-Auth":"customAuth"},"upgradeQueryParams":{"token":"Bearer Something"},"initialPayload":{"Custom-Auth":"test"}}}}`, string(message)) - err = conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"next","id":"1","payload":{"data":{"currentTime":{"unixTime":1,"timeStamp":"2021-09-01T12:00:00Z"}}}}`)) + err = testenv.DeflakeWSWriteMessage(t, conn, websocket.TextMessage, []byte(`{"type":"next","id":"1","payload":{"data":{"currentTime":{"unixTime":1,"timeStamp":"2021-09-01T12:00:00Z"}}}}`)) require.NoError(t, err) - _, message, err = conn.ReadMessage() + _, message, err = testenv.DeflakeWSReadMessage(t, conn) if errors.Is(err, websocket.ErrCloseSent) { return } require.Equal(t, `{"id":"1","type":"complete"}`, string(message)) - err = conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"complete","id":"1"}`)) + err = testenv.DeflakeWSWriteMessage(t, conn, websocket.TextMessage, []byte(`{"type":"complete","id":"1"}`)) require.NoError(t, err) }) }, @@ -629,7 +617,7 @@ func TestWebSockets(t *testing.T) { }, []byte(`{"Custom-Auth":"test"}`), ) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), @@ -639,7 +627,7 @@ func TestWebSockets(t *testing.T) { var payload currentTimePayload // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "next", msg.Type) @@ -648,7 +636,7 @@ func TestWebSockets(t *testing.T) { require.Equal(t, float64(1), payload.Data.CurrentTime.UnixTime) // Sending a complete must stop the subscription - err = conn.WriteJSON(&testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "complete", }) @@ -657,7 +645,7 @@ func TestWebSockets(t *testing.T) { var complete testenv.WebSocketMessage err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) require.NoError(t, err) - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "1", complete.ID) require.Equal(t, "complete", complete.Type) @@ -720,29 +708,29 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) defer conn.Close() - _, message, err := conn.ReadMessage() + _, message, err := testenv.DeflakeWSReadMessage(t, conn) require.NoError(t, err) message = jsonparser.Delete(message, "payload", "extensions", "upgradeHeaders", "Sec-Websocket-Key") // Sec-Websocket-Key is a random value require.Equal(t, `{"type":"connection_init","payload":{"Custom-Auth":"test","extensions":{"upgradeHeaders":{"Authorization":"Bearer test","Canonical-Header-Name":"matches","Connection":"Upgrade","Ignored":"ignored","Not-Allowlisted-But-Forwarded":"but still part of the origin upgrade request","Reverse-Canonical-Header-Name":"matches as well","Sec-Websocket-Protocol":"graphql-transport-ws","Sec-Websocket-Version":"13","Upgrade":"websocket","User-Agent":"Go-http-client/1.1","X-Custom-Auth":"customAuth"},"upgradeQueryParams":{"ignored":"ignored","token":"Bearer Something","x-custom-auth":"customAuth"},"initialPayload":{"Custom-Auth":"test"}}}}`, string(message)) - err = conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"connection_ack"}`)) + err = testenv.DeflakeWSWriteMessage(t, conn, websocket.TextMessage, []byte(`{"type":"connection_ack"}`)) require.NoError(t, err) - _, message, err = conn.ReadMessage() + _, message, err = testenv.DeflakeWSReadMessage(t, conn) require.NoError(t, err) message = jsonparser.Delete(message, "payload", "extensions", "upgradeHeaders", "Sec-Websocket-Key") // Sec-Websocket-Key is a random value require.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription{currentTime {unixTime timeStamp}}","extensions":{"upgradeHeaders":{"Authorization":"Bearer test","Canonical-Header-Name":"matches","Connection":"Upgrade","Ignored":"ignored","Not-Allowlisted-But-Forwarded":"but still part of the origin upgrade request","Reverse-Canonical-Header-Name":"matches as well","Sec-Websocket-Protocol":"graphql-transport-ws","Sec-Websocket-Version":"13","Upgrade":"websocket","User-Agent":"Go-http-client/1.1","X-Custom-Auth":"customAuth"},"upgradeQueryParams":{"ignored":"ignored","token":"Bearer Something","x-custom-auth":"customAuth"},"initialPayload":{"Custom-Auth":"test"}}}}`, string(message)) - err = conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"next","id":"1","payload":{"data":{"currentTime":{"unixTime":1,"timeStamp":"2021-09-01T12:00:00Z"}}}}`)) + err = testenv.DeflakeWSWriteMessage(t, conn, websocket.TextMessage, []byte(`{"type":"next","id":"1","payload":{"data":{"currentTime":{"unixTime":1,"timeStamp":"2021-09-01T12:00:00Z"}}}}`)) require.NoError(t, err) - _, message, err = conn.ReadMessage() + _, message, err = testenv.DeflakeWSReadMessage(t, conn) if errors.Is(err, websocket.ErrCloseSent) { return } require.Equal(t, `{"id":"1","type":"complete"}`, string(message)) - err = conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"complete","id":"1"}`)) + err = testenv.DeflakeWSWriteMessage(t, conn, websocket.TextMessage, []byte(`{"type":"complete","id":"1"}`)) require.NoError(t, err) }) }, @@ -773,7 +761,7 @@ func TestWebSockets(t *testing.T) { }, []byte(`{"Custom-Auth":"test"}`), ) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), @@ -783,7 +771,7 @@ func TestWebSockets(t *testing.T) { var payload currentTimePayload // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "next", msg.Type) @@ -792,7 +780,7 @@ func TestWebSockets(t *testing.T) { require.Equal(t, float64(1), payload.Data.CurrentTime.UnixTime) // Sending a complete must stop the subscription - err = conn.WriteJSON(&testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "complete", }) @@ -801,7 +789,7 @@ func TestWebSockets(t *testing.T) { var complete testenv.WebSocketMessage err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) require.NoError(t, err) - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "1", complete.ID) require.Equal(t, "complete", complete.Type) @@ -883,7 +871,7 @@ func TestWebSockets(t *testing.T) { conn := xEnv.InitGraphQLWebSocketConnection(http.Header{ "Authorization": []string{"Bearer test"}, }, nil, []byte(`{"Custom-Auth":"test"}`)) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), @@ -893,7 +881,7 @@ func TestWebSockets(t *testing.T) { var payload currentTimePayload // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "next", msg.Type) @@ -902,7 +890,7 @@ func TestWebSockets(t *testing.T) { require.Equal(t, float64(1), payload.Data.CurrentTime.UnixTime) // Sending a complete must stop the subscription - err = conn.WriteJSON(&testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "complete", }) @@ -911,14 +899,14 @@ func TestWebSockets(t *testing.T) { var complete testenv.WebSocketMessage err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "1", complete.ID) require.Equal(t, "complete", complete.Type) err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - _, _, err = conn.ReadMessage() + _, _, err = testenv.DeflakeWSReadMessage(t, conn) require.Error(t, err) var netErr net.Error if errors.As(err, &netErr) { @@ -999,7 +987,7 @@ func TestWebSockets(t *testing.T) { conn := xEnv.InitGraphQLWebSocketConnection(http.Header{ "Authorization": []string{"Bearer test"}, }, nil, []byte(`{"Custom-Auth":"test"}`)) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), @@ -1009,7 +997,7 @@ func TestWebSockets(t *testing.T) { var payload currentTimePayload // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "next", msg.Type) @@ -1018,7 +1006,7 @@ func TestWebSockets(t *testing.T) { require.Equal(t, float64(1), payload.Data.CurrentTime.UnixTime) // Sending a complete must stop the subscription - err = conn.WriteJSON(&testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "complete", }) @@ -1027,14 +1015,14 @@ func TestWebSockets(t *testing.T) { var complete testenv.WebSocketMessage err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "1", complete.ID) require.Equal(t, "complete", complete.Type) err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - _, _, err = conn.ReadMessage() + _, _, err = testenv.DeflakeWSReadMessage(t, conn) require.Error(t, err) var netErr net.Error if errors.As(err, &netErr) { @@ -1064,7 +1052,7 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), @@ -1073,7 +1061,7 @@ func TestWebSockets(t *testing.T) { var msg testenv.WebSocketMessage // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "error", msg.Type) @@ -1103,7 +1091,7 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), @@ -1112,7 +1100,7 @@ func TestWebSockets(t *testing.T) { var msg testenv.WebSocketMessage // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "error", msg.Type) @@ -1139,7 +1127,7 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { returnsError }"}`), @@ -1148,7 +1136,7 @@ func TestWebSockets(t *testing.T) { var msg testenv.WebSocketMessage // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "error", msg.Type) @@ -1178,7 +1166,7 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { returnsError }"}`), @@ -1187,7 +1175,7 @@ func TestWebSockets(t *testing.T) { var msg testenv.WebSocketMessage // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "error", msg.Type) @@ -1331,7 +1319,7 @@ func TestWebSockets(t *testing.T) { testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { does_not_exist }"}`), @@ -1340,7 +1328,7 @@ func TestWebSockets(t *testing.T) { err = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) require.NoError(t, err) var msg testenv.WebSocketMessage - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "error", msg.Type) // Payload should be an array of GraphQLError @@ -1539,14 +1527,14 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, []byte(`{"123": 456, "extensions": {"hello": "world"}}`)) var err error - err = conn.WriteJSON(&testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { initialPayload(repeat:3) }"}`), }) require.NoError(t, err) var msg testenv.WebSocketMessage - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, `{"data":{"initialPayload":{"123":456,"extensions":{"initialPayload":{"123":456,"extensions":{"hello":"world"}}}}}}`, string(msg.Payload)) }) @@ -1559,14 +1547,14 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, []byte(`{"123": 456, "extensions": {"hello": "world"}}`)) var err error - err = conn.WriteJSON(&testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { initialPayload(repeat:3) }"}`), }) require.NoError(t, err) var msg testenv.WebSocketMessage - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, `{"data":{"initialPayload":{"123":456,"extensions":{"initialPayload":{"123":456,"extensions":{"hello":"world"}}}}}}`, string(msg.Payload)) }) @@ -1581,20 +1569,20 @@ func TestWebSockets(t *testing.T) { conn := xEnv.InitGraphQLWebSocketConnection(map[string][]string{ "X-Feature-Flag": {"myff"}, }, nil, nil) - err := conn.WriteJSON(testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"{ employees { id productCount } }"}`), }) require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "next", res.Type) require.Equal(t, "1", res.ID) require.JSONEq(t, `{"data":{"employees":[{"id":1,"productCount":5},{"id":2,"productCount":2},{"id":3,"productCount":2},{"id":4,"productCount":3},{"id":5,"productCount":2},{"id":7,"productCount":0},{"id":8,"productCount":2},{"id":10,"productCount":3},{"id":11,"productCount":1},{"id":12,"productCount":4}]}}`, string(res.Payload)) var complete testenv.WebSocketMessage - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "complete", complete.Type) require.Equal(t, "1", complete.ID) @@ -1607,14 +1595,14 @@ func TestWebSockets(t *testing.T) { testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"{ employees { id productCount } }"}`), }) require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "error", res.Type) require.Equal(t, "1", res.ID) @@ -1635,7 +1623,7 @@ func TestWebSockets(t *testing.T) { }, }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { does_not_exist }"}`), @@ -1643,7 +1631,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) // Discard the first message var msg testenv.WebSocketMessage - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) xEnv.Shutdown() _, _, err = conn.NextReader() @@ -1664,7 +1652,7 @@ func TestWebSockets(t *testing.T) { }, }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { does_not_exist }"}`), @@ -1672,7 +1660,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) // Discard the first message var msg testenv.WebSocketMessage - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) xEnv.Shutdown() _, _, err = conn.NextReader() @@ -1692,14 +1680,14 @@ func TestWebSockets(t *testing.T) { }, }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, []byte(`{"123":456,"extensions":{"hello":"world"}}`)) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { initialPayload(repeat:3) }"}`), }) require.NoError(t, err) var msg testenv.WebSocketMessage - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, `{"data":{"initialPayload":{"123":456,"extensions":{"initialPayload":{"123":456,"extensions":{"hello":"world"}}}}}}`, string(msg.Payload)) }) @@ -1714,14 +1702,14 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { // "extensions" in the request should override the "extensions" in initial payload conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, []byte(`{"123":456,"extensions":{"hello":"world"}}`)) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { initialPayload(repeat:3) }","extensions":{"hello":"world2"}}`), }) require.NoError(t, err) var msg testenv.WebSocketMessage - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, `{"data":{"initialPayload":{"123":456,"extensions":{"hello":"world2","initialPayload":{"123":456,"extensions":{"hello":"world"}}}}}}`, string(msg.Payload)) }) @@ -1743,7 +1731,7 @@ func TestWebSockets(t *testing.T) { Type: "subscribe", Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id details { forename surname } } }"}`), } - err := conn.WriteJSON(&sub1) + err := testenv.DeflakeWSWriteJSON(t, conn, &sub1) require.NoError(t, err) sub2 := testenv.WebSocketMessage{ @@ -1751,7 +1739,7 @@ func TestWebSockets(t *testing.T) { Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), } - err = conn.WriteJSON(&sub2) + err = testenv.DeflakeWSWriteJSON(t, conn, &sub2) require.NoError(t, err) xEnv.WaitForSubscriptionCount(2, time.Second*5) @@ -1769,7 +1757,7 @@ func TestWebSockets(t *testing.T) { var msg testenv.WebSocketMessage for { - err := conn.ReadJSON(&msg) + err := testenv.DeflakeWSReadJSON(t, conn, &msg) if err != nil { return } @@ -1786,10 +1774,10 @@ func TestWebSockets(t *testing.T) { ID: "1", Type: "complete", } - err = conn.WriteJSON(&stop) + err = testenv.DeflakeWSWriteJSON(t, conn, &stop) require.NoError(t, err) var complete testenv.WebSocketMessage - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "1", complete.ID) require.Equal(t, "complete", complete.Type) @@ -1804,10 +1792,10 @@ func TestWebSockets(t *testing.T) { ID: "2", Type: "complete", } - err = conn.WriteJSON(&stop) + err = testenv.DeflakeWSWriteJSON(t, conn, &stop) require.NoError(t, err) var complete testenv.WebSocketMessage - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "2", complete.ID) require.Equal(t, "complete", complete.Type) @@ -1818,7 +1806,7 @@ func TestWebSockets(t *testing.T) { terminate := testenv.WebSocketMessage{ Type: "connection_terminate", } - err = conn.WriteJSON(&terminate) + err = testenv.DeflakeWSWriteJSON(t, conn, &terminate) require.NoError(t, err) _, _, err = conn.NextReader() require.Error(t, err) @@ -1896,19 +1884,19 @@ func TestWebSockets(t *testing.T) { } conn := xEnv.InitAbsintheWebSocketConnection(nil, json.RawMessage(`["1", "1", "__absinthe__:control", "phx_join", {}]`)) - err := conn.WriteJSON(json.RawMessage(`["1", "1", "__absinthe__:control", "doc", {"query":"subscription { currentTime { unixTime timeStamp }}" }]`)) + err := testenv.DeflakeWSWriteJSON(t, conn, json.RawMessage(`["1", "1", "__absinthe__:control", "doc", {"query":"subscription { currentTime { unixTime timeStamp }}" }]`)) require.NoError(t, err) var msg json.RawMessage var payload currentTimePayload // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) h := sha256.New() h.Write([]byte("1")) operationId := new(big.Int).SetBytes(h.Sum(nil)) require.Equal(t, string(msg), fmt.Sprintf(`["1","1","__absinthe__:control","phx_reply",{"status":"ok","response":{"subscriptionId":"__absinthe__:doc:1:%s"}}]`, operationId)) - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Contains(t, string(msg), `["1","1","__absinthe__:control","subscription:data"`) var data []json.RawMessage @@ -1920,7 +1908,7 @@ func TestWebSockets(t *testing.T) { unix1 := payload.Result.Data.CurrentTime.UnixTime - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Contains(t, string(msg), `["1","1","__absinthe__:control","subscription:data"`) err = json.Unmarshal(msg, &data) @@ -1933,19 +1921,19 @@ func TestWebSockets(t *testing.T) { require.Greater(t, unix2, unix1) // Sending a complete must stop the subscription - err = conn.WriteJSON(json.RawMessage(`["1", "1", "__absinthe__:control", "phx_leave", {}]`)) + err = testenv.DeflakeWSWriteJSON(t, conn, json.RawMessage(`["1", "1", "__absinthe__:control", "phx_leave", {}]`)) require.NoError(t, err) var complete json.RawMessage err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, string(complete), fmt.Sprintf(`["1","","__absinthe__:control","phx_reply",{"status":"ok","response":{"subscriptionId":"__absinthe__:doc:1:%s"}}]`, operationId)) err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - _, _, err = conn.ReadMessage() + _, _, err = testenv.DeflakeWSReadMessage(t, conn) require.Error(t, err) var netErr net.Error if errors.As(err, &netErr) { @@ -1972,7 +1960,7 @@ func TestWebSockets(t *testing.T) { })}, }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime } }"}`), @@ -1980,7 +1968,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "next", res.Type) require.Equal(t, "1", res.ID) @@ -1991,6 +1979,21 @@ func TestWebSockets(t *testing.T) { }) } +func TestFlakyWebSockets(t *testing.T) { + t.Run("subscription with multiple reconnects and netPoll disabled", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyEngineExecutionConfiguration: func(engineExecutionConfiguration *config.EngineExecutionConfiguration) { + engineExecutionConfiguration.EnableNetPoll = false + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + expectConnectAndReadCurrentTime(t, xEnv) + expectConnectAndReadCurrentTime(t, xEnv) + }) + }) +} + func expectConnectAndReadCurrentTime(t *testing.T, xEnv *testenv.Environment) { type currentTimePayload struct { Data struct { @@ -2004,7 +2007,7 @@ func expectConnectAndReadCurrentTime(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) defer conn.Close() - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), @@ -2014,7 +2017,7 @@ func expectConnectAndReadCurrentTime(t *testing.T, xEnv *testenv.Environment) { var payload currentTimePayload // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "next", msg.Type) @@ -2023,7 +2026,7 @@ func expectConnectAndReadCurrentTime(t *testing.T, xEnv *testenv.Environment) { unix1 := payload.Data.CurrentTime.UnixTime - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "next", msg.Type) @@ -2034,7 +2037,7 @@ func expectConnectAndReadCurrentTime(t *testing.T, xEnv *testenv.Environment) { require.Greater(t, unix2, unix1) // Sending a complete must stop the subscription - err = conn.WriteJSON(&testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "complete", }) @@ -2043,14 +2046,14 @@ func expectConnectAndReadCurrentTime(t *testing.T, xEnv *testenv.Environment) { var complete testenv.WebSocketMessage err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "1", complete.ID) require.Equal(t, "complete", complete.Type) err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - _, _, err = conn.ReadMessage() + _, _, err = testenv.DeflakeWSReadMessage(t, conn) require.Error(t, err) var netErr net.Error if errors.As(err, &netErr) { diff --git a/router/go.mod b/router/go.mod index 0b7d3c3a0c..240ff1747c 100644 --- a/router/go.mod +++ b/router/go.mod @@ -31,7 +31,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/twmb/franz-go v1.16.1 - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.145 + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.146 // Do not upgrade, it renames attributes we rely on go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 go.opentelemetry.io/contrib/propagators/b3 v1.23.0 diff --git a/router/go.sum b/router/go.sum index 5dabb1661f..c2694f8900 100644 --- a/router/go.sum +++ b/router/go.sum @@ -275,8 +275,8 @@ github.com/vektah/gqlparser/v2 v2.5.16 h1:1gcmLTvs3JLKXckwCwlUagVn/IlV2bwqle0vJ0 github.com/vektah/gqlparser/v2 v2.5.16/go.mod h1:1lz1OeCqgQbQepsGxPVywrjdBHW2T08PUS3pJqepRww= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.145 h1:3JuBmRux6YB/UZgh6COvgLXzQhMIsdHV7A02NsYdAVE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.145/go.mod h1:B7eV0Qh8Lop9QzIOQcsvKp3S0ejfC6mgyWoJnI917yQ= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.146 h1:C9+jjMgbU/RJTiFGC0HNHan4LxrY7fIhmbZRoqZryLk= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.146/go.mod h1:B7eV0Qh8Lop9QzIOQcsvKp3S0ejfC6mgyWoJnI917yQ= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 h1:aFJWCqJMNjENlcleuuOkGAPH82y0yULBScfXcIEdS24= diff --git a/router/pkg/pubsub/nats/nats.go b/router/pkg/pubsub/nats/nats.go index 921b199e40..f7aaeda978 100644 --- a/router/pkg/pubsub/nats/nats.go +++ b/router/pkg/pubsub/nats/nats.go @@ -271,8 +271,19 @@ func (p *natsPubSub) Shutdown(ctx context.Context) error { p.logger.Error("error draining NATS connection", zap.Error(drainErr)) } - // Wait for all subscriptions to be closed - p.closeWg.Wait() + // Wait for all subscriptions to be closed until timeout + timeout := time.Second * 5 + done := make(chan struct{}) + go func() { + p.closeWg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(timeout): + p.logger.Warn("timeout reached before the connection has been drained") + } return err }