diff --git a/stream.go b/stream.go index e0c90ba..59e67aa 100644 --- a/stream.go +++ b/stream.go @@ -6,6 +6,7 @@ import ( "io" "io/ioutil" "net/http" + "net/url" "sync" "time" ) @@ -14,11 +15,12 @@ import ( // It will try and reconnect if the connection is lost, respecting both // received retry delays and event id's. type Stream struct { - c *http.Client - req *http.Request - lastEventID string - readTimeout time.Duration - retryDelay *retryDelayStrategy + c *http.Client + req *http.Request + queryParamsFunc *func(existing url.Values) url.Values + lastEventID string + readTimeout time.Duration + retryDelay *retryDelayStrategy // Events emits the events received by the stream Events chan Event // Errors emits any errors encountered while reading events from the stream. @@ -187,6 +189,10 @@ func newStream(request *http.Request, configuredOptions streamOptions) *Stream { closer: make(chan struct{}), } + if configuredOptions.queryParamsFunc != nil { + stream.queryParamsFunc = configuredOptions.queryParamsFunc + } + if configuredOptions.errorHandler == nil { // The Errors channel is only used if there is no error handler. stream.Errors = make(chan error) @@ -231,6 +237,9 @@ func (stream *Stream) connect() (io.ReadCloser, error) { stream.req.Header.Set("Last-Event-ID", stream.lastEventID) } req := *stream.req + if stream.queryParamsFunc != nil { + req.URL.RawQuery = (*stream.queryParamsFunc)(req.URL.Query()).Encode() + } // All but the initial connection will need to regenerate the body if stream.connections > 0 && req.GetBody != nil { diff --git a/stream_options.go b/stream_options.go index b1f9d31..f81935a 100644 --- a/stream_options.go +++ b/stream_options.go @@ -2,6 +2,7 @@ package eventsource import ( "net/http" + "net/url" "time" ) @@ -16,6 +17,7 @@ type streamOptions struct { retryResetInterval time.Duration initialRetryTimeout time.Duration errorHandler StreamErrorHandler + queryParamsFunc *func(existing url.Values) url.Values } // StreamOption is a common interface for optional configuration parameters that can be @@ -24,6 +26,22 @@ type StreamOption interface { apply(s *streamOptions) error } +type dynamicQueryParamsOption struct { + queryParamsFunc func(existing url.Values) url.Values +} + +func (o dynamicQueryParamsOption) apply(s *streamOptions) error { + s.queryParamsFunc = &o.queryParamsFunc + return nil +} + +// StreamOptionDynamicQueryParams returns an option that sets a function to +// generate query parameters each time the stream needs to make a fresh +// connection. +func StreamOptionDynamicQueryParams(f func(existing url.Values) url.Values) StreamOption { + return dynamicQueryParamsOption{queryParamsFunc: f} +} + type readTimeoutOption struct { timeout time.Duration } diff --git a/stream_requests_test.go b/stream_requests_test.go index b4bb723..d98fb52 100644 --- a/stream_requests_test.go +++ b/stream_requests_test.go @@ -4,6 +4,8 @@ import ( "bytes" "net/http" "net/http/httptest" + "net/url" + "strconv" "testing" "time" @@ -45,6 +47,77 @@ func TestStreamSendsLastEventID(t *testing.T) { assert.Equal(t, lastID, r0.Request.Header.Get("Last-Event-ID")) } +func TestCanReplaceStreamQueryParameters(t *testing.T) { + streamHandler, streamControl := httphelpers.SSEHandler(nil) + defer streamControl.Close() + handler, requestsCh := httphelpers.RecordingHandler(streamHandler) + + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + option := StreamOptionDynamicQueryParams(func(existing url.Values) url.Values { + return url.Values{ + "filter": []string{"my-custom-filter"}, + "basis": []string{"last-known-basis"}, + } + }) + + stream := mustSubscribe(t, httpServer.URL, option) + defer stream.Close() + + r0 := <-requestsCh + assert.Equal(t, "my-custom-filter", r0.Request.URL.Query().Get("filter")) + assert.Equal(t, "last-known-basis", r0.Request.URL.Query().Get("basis")) +} + +func TestCanUpdateStreamQueryParameters(t *testing.T) { + streamHandler, streamControl := httphelpers.SSEHandler(nil) + defer streamControl.Close() + handler, requestsCh := httphelpers.RecordingHandler(streamHandler) + + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + option := StreamOptionDynamicQueryParams(func(existing url.Values) url.Values { + if existing.Has("count") { + count, _ := strconv.Atoi(existing.Get("count")) + + if count == 1 { + existing.Set("count", strconv.Itoa(count+1)) + return existing + } + + return url.Values{} + } + + return url.Values{ + "initial": []string{"payload is set"}, + "count": []string{"1"}, + } + }) + + stream := mustSubscribe(t, httpServer.URL, option, StreamOptionInitialRetry(time.Millisecond)) + defer stream.Close() + + r0 := <-requestsCh + assert.Equal(t, "payload is set", r0.Request.URL.Query().Get("initial")) + assert.Equal(t, "1", r0.Request.URL.Query().Get("count")) + + streamControl.EndAll() + <-stream.Errors // Accept the error to unblock the retry handler + + r1 := <-requestsCh + assert.Equal(t, "payload is set", r1.Request.URL.Query().Get("initial")) + assert.Equal(t, "2", r1.Request.URL.Query().Get("count")) + + streamControl.EndAll() + <-stream.Errors // Accept the error to unblock the retry handler + + r2 := <-requestsCh + assert.False(t, r2.Request.URL.Query().Has("initial")) + assert.False(t, r2.Request.URL.Query().Has("count")) +} + func TestStreamReconnectWithRequestBodySendsBodyTwice(t *testing.T) { body := []byte("my-body")