diff --git a/client-side.go b/client-side.go index caa38dd7..c53b01d7 100644 --- a/client-side.go +++ b/client-side.go @@ -8,11 +8,11 @@ import ( "net/http" "net/http/httptest" "strconv" - "strings" "net/http/httputil" "github.com/launchdarkly/ld-relay/v6/config" + "github.com/launchdarkly/ld-relay/v6/internal/cors" "github.com/launchdarkly/ld-relay/v6/internal/events" "github.com/launchdarkly/ld-relay/v6/internal/relayenv" "github.com/launchdarkly/ld-relay/v6/internal/util" @@ -57,7 +57,11 @@ func (m clientSideMux) selectClientByUrlParam(next http.Handler) http.Handler { return } - req = req.WithContext(context.WithValue(req.Context(), contextKey, clientCtx)) + reqContext := context.WithValue(req.Context(), contextKey, clientCtx) + // Even though the clientCtx also serves as a CORSContext, we attach it separately here just to keep + // the CORS implementation less reliant on other unrelated implementation details + reqContext = cors.WithCORSContext(reqContext, clientCtx) + req = req.WithContext(reqContext) next.ServeHTTP(w, req) }) } @@ -67,27 +71,6 @@ func (m clientSideMux) getGoals(w http.ResponseWriter, req *http.Request) { clientCtx.proxy.ServeHTTP(w, req) } -var allowedHeadersList = []string{ - "Cache-Control", - "Content-Type", - "Content-Length", - "Accept-Encoding", - "X-LaunchDarkly-User-Agent", - "X-LaunchDarkly-Payload-ID", - "X-LaunchDarkly-Wrapper", - events.EventSchemaHeader, -} - -var allowedHeaders = strings.Join(allowedHeadersList, ",") - -func setCorsHeaders(w http.ResponseWriter, origin string) { - w.Header().Set("Access-Control-Allow-Origin", origin) - w.Header().Set("Access-Control-Allow-Credentials", "false") - w.Header().Set("Access-Control-Max-Age", "300") - w.Header().Set("Access-Control-Allow-Headers", allowedHeaders) - w.Header().Set("Access-Control-Expose-Headers", "Date") -} - const transparent1PixelImgBase64 = "R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7=" var transparent1PixelImg []byte diff --git a/internal/cors/cors.go b/internal/cors/cors.go index df36cab4..0ed7fa73 100644 --- a/internal/cors/cors.go +++ b/internal/cors/cors.go @@ -1,6 +1,67 @@ package cors +import ( + "context" + "net/http" + "strings" + + "github.com/launchdarkly/ld-relay/v6/internal/events" +) + const ( // The default origin string to use in CORS response headers. DefaultAllowedOrigin = "*" ) + +type contextKeyType string + +const ( + contextKey contextKeyType = "context" + maxAge string = "300" +) + +var allowedHeadersList = []string{ + "Cache-Control", + "Content-Type", + "Content-Length", + "Accept-Encoding", + "X-LaunchDarkly-User-Agent", + "X-LaunchDarkly-Payload-ID", + "X-LaunchDarkly-Wrapper", + events.EventSchemaHeader, +} + +var allowedHeaders = strings.Join(allowedHeadersList, ",") + +// RequestContext represents a scope that has a specific set of allowed origins for CORS requests. This +// can be attached to a request context with WithCORSContext(). +type RequestContext interface { + AllowedOrigins() []string +} + +// GetCORSContext returns the CORSContext that has been attached to this Context with WithCORSContext(), +// or nil if none. +func GetCORSContext(ctx context.Context) RequestContext { + if cc, ok := ctx.Value(contextKey).(RequestContext); ok { + return cc + } + return nil +} + +// WithCORSContext returns a copy of the parent context with the specified CORSContext attached. +func WithCORSContext(parent context.Context, cc RequestContext) context.Context { + if cc == nil { + return parent + } + return context.WithValue(parent, contextKey, cc) +} + +// SetCORSHeaders sets a standard set of CORS headers on an HTTP response. This is meant to be the same +// behavior that the LaunchDarkly service endpoints uses for client-side JS requests. +func SetCORSHeaders(w http.ResponseWriter, origin string) { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Credentials", "false") + w.Header().Set("Access-Control-Max-Age", maxAge) + w.Header().Set("Access-Control-Allow-Headers", allowedHeaders) + w.Header().Set("Access-Control-Expose-Headers", "Date") +} diff --git a/internal/cors/cors_test.go b/internal/cors/cors_test.go new file mode 100644 index 00000000..eb006d71 --- /dev/null +++ b/internal/cors/cors_test.go @@ -0,0 +1,43 @@ +package cors + +import ( + "context" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +type mockCORSContext struct{} + +func (m mockCORSContext) AllowedOrigins() []string { + return nil +} + +func TestCORSContext(t *testing.T) { + t.Run("GetCORSContext when there is no RequestContext returns nil", func(t *testing.T) { + assert.Nil(t, GetCORSContext(context.Background())) + }) + + t.Run("WithCORSContext adds RequestContext to context", func(t *testing.T) { + m := mockCORSContext{} + ctx := WithCORSContext(context.Background(), m) + assert.Equal(t, m, GetCORSContext(ctx)) + }) + + t.Run("WithCORSContext has no effect with nil parameter", func(t *testing.T) { + ctx := WithCORSContext(context.Background(), nil) + assert.Equal(t, context.Background(), ctx) + }) + + t.Run("SetCORSHeaders", func(t *testing.T) { + origin := "http://good.cat" + rr := httptest.ResponseRecorder{} + SetCORSHeaders(&rr, origin) + assert.Equal(t, origin, rr.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "false", rr.Header().Get("Access-Control-Allow-Credentials")) + assert.Equal(t, maxAge, rr.Header().Get("Access-Control-Max-Age")) + assert.Equal(t, allowedHeaders, rr.Header().Get("Access-Control-Allow-Headers")) + assert.Equal(t, "Date", rr.Header().Get("Access-Control-Expose-Headers")) + }) +} diff --git a/internal/relayenv/env_context.go b/internal/relayenv/env_context.go index c46a973a..4f6e1299 100644 --- a/internal/relayenv/env_context.go +++ b/internal/relayenv/env_context.go @@ -2,6 +2,7 @@ package relayenv import ( "context" + "io" "net/http" "time" @@ -19,6 +20,8 @@ import ( // connection may take a while, so it is possible for the client and store references to be nil if initialization // is not yet complete. type EnvContext interface { + io.Closer + // GetName returns the configured name of the environment. GetName() string diff --git a/internal/relayenv/env_context_impl.go b/internal/relayenv/env_context_impl.go index 8ebe3f9b..c76e763b 100644 --- a/internal/relayenv/env_context_impl.go +++ b/internal/relayenv/env_context_impl.go @@ -214,3 +214,8 @@ func (c *envContextImpl) GetInitError() error { func (c *envContextImpl) IsSecureMode() bool { return c.secureMode } + +func (c *envContextImpl) Close() error { + // This currently isn't used, but will be used in the future when we can dynamically change environments + return nil +} diff --git a/middleware.go b/middleware.go index f6cc0de9..ed74ccd7 100644 --- a/middleware.go +++ b/middleware.go @@ -6,27 +6,15 @@ import ( "encoding/json" "errors" "net/http" - "regexp" "github.com/gorilla/mux" - "github.com/launchdarkly/ld-relay/v6/config" "github.com/launchdarkly/ld-relay/v6/internal/cors" "github.com/launchdarkly/ld-relay/v6/internal/metrics" "github.com/launchdarkly/ld-relay/v6/internal/relayenv" - "github.com/launchdarkly/ld-relay/v6/internal/version" "gopkg.in/launchdarkly/go-sdk-common.v2/lduser" - ld "gopkg.in/launchdarkly/go-server-sdk.v5" ) -var ( - uuidHeaderPattern = regexp.MustCompile(`^(?:api_key )?((?:[a-z]{3}-)?[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89aAbB][a-f0-9]{3}-[a-f0-9]{12})$`) -) - -type corsContext interface { - AllowedOrigins() []string -} - func chainMiddleware(middlewares ...mux.MiddlewareFunc) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { handler := next @@ -37,56 +25,7 @@ func chainMiddleware(middlewares ...mux.MiddlewareFunc) mux.MiddlewareFunc { } } -type clientMux struct { - clientContextByKey map[config.SDKCredential]relayenv.EnvContext -} - -func (m clientMux) getStatus(w http.ResponseWriter, req *http.Request) { - w.Header().Set("Content-Type", "application/json") - envs := make(map[string]environmentStatus) - - healthy := true - for _, clientCtx := range m.clientContextByKey { - var status environmentStatus - creds := clientCtx.GetCredentials() - status.SdkKey = obscureKey(creds.SDKKey) - if mobileKey, ok := creds.MobileKey.Get(); ok { - status.MobileKey = obscureKey(mobileKey) - } - status.EnvId = creds.EnvironmentID.StringValue() - client := clientCtx.GetClient() - if client == nil || !client.Initialized() { - status.Status = "disconnected" - healthy = false - } else { - status.Status = "connected" - } - envs[clientCtx.GetName()] = status - } - - resp := struct { - Environments map[string]environmentStatus `json:"environments"` - Status string `json:"status"` - Version string `json:"version"` - ClientVersion string `json:"clientVersion"` - }{ - Environments: envs, - Version: version.Version, - ClientVersion: ld.Version, - } - - if healthy { - resp.Status = "healthy" - } else { - resp.Status = "degraded" - } - - data, _ := json.Marshal(resp) - - w.Write(data) -} - -func (m clientMux) selectClientByAuthorizationKey(sdkKind sdkKind) func(http.Handler) http.Handler { +func selectEnvironmentByAuthorizationKey(sdkKind sdkKind, envs RelayEnvironments) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { credential, err := sdkKind.getSDKCredential(req) @@ -95,7 +34,7 @@ func (m clientMux) selectClientByAuthorizationKey(sdkKind sdkKind) func(http.Han return } - clientCtx := m.clientContextByKey[credential] + clientCtx := envs.GetEnvironment(credential) if clientCtx == nil { w.WriteHeader(http.StatusUnauthorized) @@ -168,24 +107,24 @@ func withGauge(handler http.Handler, measure metrics.Measure) http.Handler { func corsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var domains []string - if context, ok := r.Context().Value(contextKey).(corsContext); ok { - domains = context.AllowedOrigins() + if corsContext := cors.GetCORSContext(r.Context()); corsContext != nil { + domains = corsContext.AllowedOrigins() } if len(domains) > 0 { for _, d := range domains { if r.Header.Get("Origin") == d { - setCorsHeaders(w, d) + cors.SetCORSHeaders(w, d) return } } // Not a valid origin, set allowed origin to any allowed origin - setCorsHeaders(w, domains[0]) + cors.SetCORSHeaders(w, domains[0]) } else { origin := cors.DefaultAllowedOrigin if r.Header.Get("Origin") != "" { origin = r.Header.Get("Origin") } - setCorsHeaders(w, origin) + cors.SetCORSHeaders(w, origin) } next.ServeHTTP(w, r) }) diff --git a/middleware_test.go b/middleware_test.go index efa896af..6e03e462 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -19,61 +19,100 @@ func buildPreRoutedRequestWithAuth(key config.SDKCredential) *http.Request { return buildPreRoutedRequest("GET", nil, headers, nil, nil) } -func TestClientMuxRejectsMalformedSDKKeyOrMobileKey(t *testing.T) { - mux := clientMux{ - clientContextByKey: map[config.SDKCredential]relayenv.EnvContext{ - malformedSDKKey: newTestEnvContext("server", false, nil), - malformedMobileKey: newTestEnvContext("mobile", false, nil), - }, +func TestSelectEnvironmentByAuthorizationKey(t *testing.T) { + env1 := newTestEnvContext("env1", false, nil) + env2 := newTestEnvContext("env2", false, nil) + + handlerThatDetectsEnvironment := func(outCh chan<- relayenv.EnvContext) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + outCh <- getClientContext(req) + }) } - req1 := buildPreRoutedRequestWithAuth(malformedSDKKey) - resp1, _ := doRequest(req1, mux.selectClientByAuthorizationKey(serverSdk)(nullHandler())) + t.Run("finds by SDK key", func(t *testing.T) { + envs := testEnvironments{ + testEnvMain.config.SDKKey: env1, + testEnvMobile.config.SDKKey: env2, + } + selector := selectEnvironmentByAuthorizationKey(serverSdk, envs) + envCh := make(chan relayenv.EnvContext, 1) - assert.Equal(t, http.StatusUnauthorized, resp1.StatusCode) + req := buildPreRoutedRequestWithAuth(testEnvMain.config.SDKKey) + resp, _ := doRequest(req, selector(handlerThatDetectsEnvironment(envCh))) - req2 := buildPreRoutedRequestWithAuth(malformedMobileKey) - resp2, _ := doRequest(req2, mux.selectClientByAuthorizationKey(serverSdk)(nullHandler())) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, env1, <-envCh) + }) - assert.Equal(t, http.StatusUnauthorized, resp2.StatusCode) -} + t.Run("finds by mobile key", func(t *testing.T) { + envs := testEnvironments{ + testEnvMain.config.SDKKey: env1, + testEnvMobile.config.SDKKey: env2, + testEnvMobile.config.MobileKey: env2, + } + selector := selectEnvironmentByAuthorizationKey(mobileSdk, envs) + envCh := make(chan relayenv.EnvContext, 1) -func TestClientMuxRejectsUnknownSDKKeyOrMobileKey(t *testing.T) { - mux := clientMux{} + req := buildPreRoutedRequestWithAuth(testEnvMobile.config.MobileKey) + resp, _ := doRequest(req, selector(handlerThatDetectsEnvironment(envCh))) - req1 := buildPreRoutedRequestWithAuth(undefinedSDKKey) - resp1, _ := doRequest(req1, mux.selectClientByAuthorizationKey(serverSdk)(nullHandler())) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, env2, <-envCh) + }) - assert.Equal(t, http.StatusUnauthorized, resp1.StatusCode) + t.Run("rejects unknown SDK key", func(t *testing.T) { + envs := testEnvironments{testEnvMain.config.SDKKey: env1} + selector := selectEnvironmentByAuthorizationKey(serverSdk, envs) - req2 := buildPreRoutedRequestWithAuth(undefinedMobileKey) - resp2, _ := doRequest(req2, mux.selectClientByAuthorizationKey(serverSdk)(nullHandler())) + req1 := buildPreRoutedRequestWithAuth(undefinedSDKKey) + resp1, _ := doRequest(req1, selector(nullHandler())) - assert.Equal(t, http.StatusUnauthorized, resp2.StatusCode) -} + assert.Equal(t, http.StatusUnauthorized, resp1.StatusCode) + }) -func TestClientMuxReturns503IfClientHasNotBeenCreated(t *testing.T) { - ctx := newTestEnvContextWithClientFactory("env", clientFactoryThatFails(errors.New("sorry")), nil) - serverSideMux := clientMux{ - clientContextByKey: map[config.SDKCredential]relayenv.EnvContext{ - testEnvMain.config.SDKKey: ctx, - }, - } - mobileMux := clientMux{ - clientContextByKey: map[config.SDKCredential]relayenv.EnvContext{ - testEnvMobile.config.MobileKey: ctx, - }, - } + t.Run("rejects unknown mobile key", func(t *testing.T) { + envs := testEnvironments{testEnvMain.config.MobileKey: env1} + selector := selectEnvironmentByAuthorizationKey(mobileSdk, envs) + + req1 := buildPreRoutedRequestWithAuth(undefinedMobileKey) + resp1, _ := doRequest(req1, selector(nullHandler())) + + assert.Equal(t, http.StatusUnauthorized, resp1.StatusCode) + }) + + t.Run("rejects malformed SDK key", func(t *testing.T) { + envs := testEnvironments{malformedSDKKey: newTestEnvContext("server", false, nil)} + selector := selectEnvironmentByAuthorizationKey(serverSdk, envs) + + req1 := buildPreRoutedRequestWithAuth(malformedSDKKey) + resp1, _ := doRequest(req1, selector(nullHandler())) + + assert.Equal(t, http.StatusUnauthorized, resp1.StatusCode) + }) + + t.Run("rejects malformed mobile key", func(t *testing.T) { + envs := testEnvironments{ + malformedSDKKey: newTestEnvContext("server", false, nil), + malformedMobileKey: newTestEnvContext("server", false, nil), + } + selector := selectEnvironmentByAuthorizationKey(mobileSdk, envs) + + req1 := buildPreRoutedRequestWithAuth(malformedMobileKey) + resp1, _ := doRequest(req1, selector(nullHandler())) - req1 := buildPreRoutedRequestWithAuth(testEnvMain.config.SDKKey) - resp1, _ := doRequest(req1, serverSideMux.selectClientByAuthorizationKey(serverSdk)(nullHandler())) + assert.Equal(t, http.StatusUnauthorized, resp1.StatusCode) + }) - assert.Equal(t, http.StatusServiceUnavailable, resp1.StatusCode) + t.Run("returns 503 if client has not been created", func(t *testing.T) { + notReadyEnv := newTestEnvContextWithClientFactory("env", clientFactoryThatFails(errors.New("sorry")), nil) + envs := testEnvironments{testEnvMain.config.SDKKey: notReadyEnv} + selector := selectEnvironmentByAuthorizationKey(serverSdk, envs) - req2 := buildPreRoutedRequestWithAuth(testEnvMobile.config.MobileKey) - resp2, _ := doRequest(req2, mobileMux.selectClientByAuthorizationKey(mobileSdk)(nullHandler())) + req := buildPreRoutedRequestWithAuth(testEnvMain.config.SDKKey) + resp, _ := doRequest(req, selector(nullHandler())) - assert.Equal(t, http.StatusServiceUnavailable, resp2.StatusCode) + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + }) } func TestCorsMiddlewareSetsCorrectDefaultHeaders(t *testing.T) { diff --git a/relay.go b/relay.go index 4479e348..b6d6facc 100644 --- a/relay.go +++ b/relay.go @@ -1,29 +1,15 @@ package relay import ( - "crypto/tls" - "errors" - "fmt" "net/http" - "net/http/httputil" - "net/url" - "strings" "gopkg.in/launchdarkly/go-sdk-common.v2/ldtime" "gopkg.in/launchdarkly/go-sdk-common.v2/ldvalue" - "github.com/gorilla/mux" - "github.com/gregjones/httpcache" - - "github.com/launchdarkly/eventsource" "gopkg.in/launchdarkly/go-sdk-common.v2/ldlog" "gopkg.in/launchdarkly/go-sdk-common.v2/ldreason" "github.com/launchdarkly/ld-relay/v6/config" - "github.com/launchdarkly/ld-relay/v6/internal/events" - "github.com/launchdarkly/ld-relay/v6/internal/logging" - "github.com/launchdarkly/ld-relay/v6/internal/metrics" - "github.com/launchdarkly/ld-relay/v6/internal/relayenv" "github.com/launchdarkly/ld-relay/v6/sdkconfig" ) @@ -32,22 +18,12 @@ const ( ldUserAgentHeader = "X-LaunchDarkly-User-Agent" ) -type environmentStatus struct { - SdkKey string `json:"sdkKey"` - EnvId string `json:"envId,omitempty"` - MobileKey string `json:"mobileKey,omitempty"` - Status string `json:"status"` -} - // Relay relays endpoints to and from the LaunchDarkly service type Relay struct { http.Handler - sdkClientMux clientMux - mobileClientMux clientMux - clientSideMux clientSideMux - metricsManager *metrics.Manager - config config.Config - loggers ldlog.Loggers + core *RelayCore + config config.Config + loggers ldlog.Loggers } type evalXResult struct { @@ -64,149 +40,26 @@ type evalXResult struct { // // If any metrics exporters are enabled in c.MetricsConfig, it also registers those in OpenCensus. func NewRelay(c config.Config, loggers ldlog.Loggers, clientFactory sdkconfig.ClientFactoryFunc) (*Relay, error) { - if err := config.ValidateConfig(&c, loggers); err != nil { // in case a not-yet-validated Config was passed to NewRelay - return nil, err - } - - if c.Main.LogLevel.IsDefined() { - loggers.SetMinLevel(c.Main.LogLevel.GetOrElse(ldlog.Info)) - } - - metricsManager, err := metrics.NewManager(c.MetricsConfig, 0, loggers) + core, err := NewRelayCore(c, loggers, clientFactory) if err != nil { - return nil, fmt.Errorf("unable to create metrics manager: %s", err) - } - - makeSSEServer := func() *eventsource.Server { - s := eventsource.NewServer() - s.Gzip = false - s.AllowCORS = true - s.ReplayAll = true - s.MaxConnTime = c.Main.MaxClientConnectionTime.GetOrElse(0) - return s - } - allPublisher := makeSSEServer() - flagsPublisher := makeSSEServer() - pingPublisher := makeSSEServer() - clients := make(map[config.SDKCredential]relayenv.EnvContext) - mobileClients := make(map[config.SDKCredential]relayenv.EnvContext) - - clientSideMux := clientSideMux{ - contextByKey: map[config.SDKCredential]*clientSideContext{}, - } - - if len(c.Environment) == 0 { - return nil, fmt.Errorf("you must specify at least one environment in your configuration") - } - - baseUrl := c.Main.BaseURI.Get() - if baseUrl == nil { - baseUrl, err = url.Parse(config.DefaultBaseURI) - if err != nil { - return nil, errors.New("unexpected error: default base URI is invalid") - } - } - - clientReadyCh := make(chan relayenv.EnvContext, len(c.Environment)) - - for envName, envConfigPtr := range c.Environment { - var envConfig config.EnvConfig - if envConfigPtr != nil { // this is a pointer only because that's how gcfg works; should not be nil - envConfig = *envConfigPtr - } - - dataStoreFactory, err := sdkconfig.ConfigureDataStore(c, envConfig, loggers) - if err != nil { - return nil, err - } - - clientContext, err := relayenv.NewEnvContext( - envName, - envConfig, - c, - clientFactory, - dataStoreFactory, - allPublisher, - flagsPublisher, - pingPublisher, - metricsManager, - loggers, - clientReadyCh, - ) - if err != nil { - return nil, fmt.Errorf(`unable to create client context for "%s": %s`, envName, err) - } - clients[envConfig.SDKKey] = clientContext - if envConfig.MobileKey != "" { - mobileClients[envConfig.MobileKey] = clientContext - } - - if envConfig.EnvID != "" { - allowedOrigins := envConfig.AllowedOrigin.Values() - cachingTransport := httpcache.NewMemoryCacheTransport() - if envConfig.InsecureSkipVerify { - transport := &(*http.DefaultTransport.(*http.Transport)) - transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: envConfig.InsecureSkipVerify} // nolint:gas // allow this because the user has to explicitly enable it - cachingTransport.Transport = transport - } - - proxy := &httputil.ReverseProxy{ - Director: func(r *http.Request) { - url := r.URL - url.Scheme = baseUrl.Scheme - url.Host = baseUrl.Host - r.Host = baseUrl.Hostname() - }, - ModifyResponse: func(r *http.Response) error { - // Leave access control to our own cors middleware - for h := range r.Header { - if strings.HasPrefix(strings.ToLower(h), "access-control") { - r.Header.Del(h) - } - } - return nil - }, - Transport: cachingTransport, - } - - clientSideMux.contextByKey[envConfig.EnvID] = &clientSideContext{ - EnvContext: clientContext, - proxy: proxy, - allowedOrigins: allowedOrigins, - } - } + return nil, err } r := Relay{ - sdkClientMux: clientMux{clientContextByKey: clients}, - mobileClientMux: clientMux{clientContextByKey: mobileClients}, - clientSideMux: clientSideMux, - metricsManager: metricsManager, - config: c, - loggers: loggers, + core: core, + config: c, + loggers: loggers, } if c.Main.ExitAlways { loggers.Info("Running in one-shot mode - will exit immediately after initializing environments") // Just wait until all clients have either started or failed, then exit without bothering // to set up HTTP handlers. - numFinished := 0 - failed := false - for numFinished < len(c.Environment) { - ctx := <-clientReadyCh - numFinished++ - if ctx.GetInitError() != nil { - failed = true - } - } - var err error - if failed { - err = errors.New("one or more environments failed to initialize") - } + err := r.core.WaitForAllClients(0) return &r, err } - isDebugLoggingEnabled := c.Main.LogLevel.GetOrElse(ldlog.Info) <= ldlog.Debug - r.Handler = r.makeHandler(isDebugLoggingEnabled) + + r.Handler = core.MakeRouter() return &r, nil } @@ -214,119 +67,6 @@ func NewRelay(c config.Config, loggers ldlog.Loggers, clientFactory sdkconfig.Cl // // Currently this includes only the metrics components; it does not close SDK clients. func (r *Relay) Close() error { - r.metricsManager.Close() + r.core.Close() return nil } - -func (r *Relay) makeHandler(withRequestLogging bool) http.Handler { - router := mux.NewRouter() - router.Use(logging.GlobalContextLoggersMiddleware(r.loggers)) - if withRequestLogging { - router.Use(logging.RequestLoggerMiddleware(r.loggers)) - } - router.HandleFunc("/status", r.sdkClientMux.getStatus).Methods("GET") - - // Client-side evaluation - clientSideMiddlewareStack := chainMiddleware( - corsMiddleware, - r.clientSideMux.selectClientByUrlParam, - requestCountMiddleware(metrics.BrowserRequests)) - - goalsRouter := router.PathPrefix("/sdk/goals").Subrouter() - goalsRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(goalsRouter)) - goalsRouter.HandleFunc("/{envId}", r.clientSideMux.getGoals).Methods("GET", "OPTIONS") - - clientSideSdkEvalRouter := router.PathPrefix("/sdk/eval/{envId}/").Subrouter() - clientSideSdkEvalRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(clientSideSdkEvalRouter)) - clientSideSdkEvalRouter.HandleFunc("/users/{user}", evaluateAllFeatureFlagsValueOnly(jsClientSdk)).Methods("GET", "OPTIONS") - clientSideSdkEvalRouter.HandleFunc("/user", evaluateAllFeatureFlagsValueOnly(jsClientSdk)).Methods("REPORT", "OPTIONS") - - clientSideSdkEvalXRouter := router.PathPrefix("/sdk/evalx/{envId}/").Subrouter() - clientSideSdkEvalXRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(clientSideSdkEvalXRouter)) - clientSideSdkEvalXRouter.HandleFunc("/users/{user}", evaluateAllFeatureFlags(jsClientSdk)).Methods("GET", "OPTIONS") - clientSideSdkEvalXRouter.HandleFunc("/user", evaluateAllFeatureFlags(jsClientSdk)).Methods("REPORT", "OPTIONS") - - serverSideMiddlewareStack := chainMiddleware( - r.sdkClientMux.selectClientByAuthorizationKey(serverSdk), - requestCountMiddleware(metrics.ServerRequests)) - - serverSideSdkRouter := router.PathPrefix("/sdk/").Subrouter() - // (?)TODO: there is a bug in gorilla mux (see see https://github.com/gorilla/mux/pull/378) that means the middleware below - // because it will not be run if it matches any earlier prefix. Until it is fixed, we have to apply the middleware explicitly - // serverSideSdkRouter.Use(serverSideMiddlewareStack) - - serverSideEvalRouter := serverSideSdkRouter.PathPrefix("/eval/").Subrouter() - serverSideEvalRouter.Handle("/users/{user}", serverSideMiddlewareStack(http.HandlerFunc(evaluateAllFeatureFlagsValueOnly(serverSdk)))).Methods("GET") - serverSideEvalRouter.Handle("/user", serverSideMiddlewareStack(http.HandlerFunc(evaluateAllFeatureFlagsValueOnly(serverSdk)))).Methods("REPORT") - - serverSideEvalXRouter := serverSideSdkRouter.PathPrefix("/evalx/").Subrouter() - serverSideEvalXRouter.Handle("/users/{user}", serverSideMiddlewareStack(http.HandlerFunc(evaluateAllFeatureFlags(serverSdk)))).Methods("GET") - serverSideEvalXRouter.Handle("/user", serverSideMiddlewareStack(http.HandlerFunc(evaluateAllFeatureFlags(serverSdk)))).Methods("REPORT") - - // PHP SDK endpoints - serverSideSdkRouter.Handle("/flags", serverSideMiddlewareStack(http.HandlerFunc(pollAllFlagsHandler))).Methods("GET") - serverSideSdkRouter.Handle("/flags/{key}", serverSideMiddlewareStack(http.HandlerFunc(pollFlagHandler))).Methods("GET") - serverSideSdkRouter.Handle("/segments/{key}", serverSideMiddlewareStack(http.HandlerFunc(pollSegmentHandler))).Methods("GET") - - // Mobile evaluation - mobileMiddlewareStack := chainMiddleware( - r.mobileClientMux.selectClientByAuthorizationKey(mobileSdk), - requestCountMiddleware(metrics.MobileRequests)) - - msdkRouter := router.PathPrefix("/msdk/").Subrouter() - msdkRouter.Use(mobileMiddlewareStack) - - msdkEvalRouter := msdkRouter.PathPrefix("/eval/").Subrouter() - msdkEvalRouter.HandleFunc("/users/{user}", evaluateAllFeatureFlagsValueOnly(mobileSdk)).Methods("GET") - msdkEvalRouter.HandleFunc("/user", evaluateAllFeatureFlagsValueOnly(mobileSdk)).Methods("REPORT") - - msdkEvalXRouter := msdkRouter.PathPrefix("/evalx/").Subrouter() - msdkEvalXRouter.HandleFunc("/users/{user}", evaluateAllFeatureFlags(mobileSdk)).Methods("GET") - msdkEvalXRouter.HandleFunc("/user", evaluateAllFeatureFlags(mobileSdk)).Methods("REPORT") - - mobileStreamRouter := router.PathPrefix("/meval").Subrouter() - mobileStreamRouter.Use(mobileMiddlewareStack, streamingMiddleware) - mobileStreamRouter.Handle("", countMobileConns(pingStreamHandlerWithUser(mobileSdk))).Methods("REPORT") - mobileStreamRouter.Handle("/{user}", countMobileConns(pingStreamHandlerWithUser(mobileSdk))).Methods("GET") - - router.Handle("/mping", r.mobileClientMux.selectClientByAuthorizationKey(mobileSdk)( - countMobileConns(streamingMiddleware(pingStreamHandler())))).Methods("GET") - - clientSidePingRouter := router.PathPrefix("/ping/{envId}").Subrouter() - clientSidePingRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(clientSidePingRouter), streamingMiddleware) - clientSidePingRouter.Handle("", countBrowserConns(pingStreamHandler())).Methods("GET", "OPTIONS") - - clientSideStreamEvalRouter := router.PathPrefix("/eval/{envId}").Subrouter() - clientSideStreamEvalRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(clientSideStreamEvalRouter), streamingMiddleware) - // For now we implement eval as simply ping - clientSideStreamEvalRouter.Handle("/{user}", countBrowserConns(pingStreamHandlerWithUser(jsClientSdk))).Methods("GET", "OPTIONS") - clientSideStreamEvalRouter.Handle("", countBrowserConns(pingStreamHandlerWithUser(jsClientSdk))).Methods("REPORT", "OPTIONS") - - mobileEventsRouter := router.PathPrefix("/mobile").Subrouter() - mobileEventsRouter.Use(mobileMiddlewareStack) - mobileEventsRouter.Handle("/events/bulk", bulkEventHandler(events.MobileSDKEventsEndpoint)).Methods("POST") - mobileEventsRouter.Handle("/events", bulkEventHandler(events.MobileSDKEventsEndpoint)).Methods("POST") - mobileEventsRouter.Handle("", bulkEventHandler(events.MobileSDKEventsEndpoint)).Methods("POST") - mobileEventsRouter.Handle("/events/diagnostic", bulkEventHandler(events.MobileSDKDiagnosticEventsEndpoint)).Methods("POST") - - clientSideBulkEventsRouter := router.PathPrefix("/events/bulk/{envId}").Subrouter() - clientSideBulkEventsRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(clientSideBulkEventsRouter)) - clientSideBulkEventsRouter.Handle("", bulkEventHandler(events.JavaScriptSDKEventsEndpoint)).Methods("POST", "OPTIONS") - - clientSideDiagnosticEventsRouter := router.PathPrefix("/events/diagnostic/{envId}").Subrouter() - clientSideDiagnosticEventsRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(clientSideBulkEventsRouter)) - clientSideDiagnosticEventsRouter.Handle("", bulkEventHandler(events.JavaScriptSDKDiagnosticEventsEndpoint)).Methods("POST", "OPTIONS") - - clientSideImageEventsRouter := router.PathPrefix("/a/{envId}.gif").Subrouter() - clientSideImageEventsRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(clientSideImageEventsRouter)) - clientSideImageEventsRouter.HandleFunc("", getEventsImage).Methods("GET", "OPTIONS") - - serverSideRouter := router.PathPrefix("").Subrouter() - serverSideRouter.Use(serverSideMiddlewareStack) - serverSideRouter.Handle("/bulk", bulkEventHandler(events.ServerSDKEventsEndpoint)).Methods("POST") - serverSideRouter.Handle("/diagnostic", bulkEventHandler(events.ServerSDKDiagnosticEventsEndpoint)).Methods("POST") - serverSideRouter.Handle("/all", countServerConns(streamingMiddleware(allStreamHandler()))).Methods("GET") - serverSideRouter.Handle("/flags", countServerConns(streamingMiddleware(flagsStreamHandler()))).Methods("GET") - - return router -} diff --git a/relay_core.go b/relay_core.go new file mode 100644 index 00000000..f6650a89 --- /dev/null +++ b/relay_core.go @@ -0,0 +1,264 @@ +package relay + +import ( + "crypto/tls" + "errors" + "fmt" + "net/http" + "net/http/httputil" + "net/url" + "strings" + "sync" + "time" + + "github.com/gregjones/httpcache" + + "github.com/launchdarkly/eventsource" + "github.com/launchdarkly/ld-relay/v6/config" + "github.com/launchdarkly/ld-relay/v6/internal/metrics" + "github.com/launchdarkly/ld-relay/v6/internal/relayenv" + "github.com/launchdarkly/ld-relay/v6/sdkconfig" + "gopkg.in/launchdarkly/go-sdk-common.v2/ldlog" +) + +type RelayEnvironments interface { //nolint:golint // yes, we know the package name is also "relay" + GetEnvironment(config.SDKCredential) relayenv.EnvContext + GetAllEnvironments() map[config.SDKKey]relayenv.EnvContext +} + +type RelayCore struct { //nolint:golint // yes, we know the package name is also "relay" + allEnvironments map[config.SDKKey]relayenv.EnvContext + envsByMobileKey map[config.MobileKey]relayenv.EnvContext + envsByEnvID map[config.EnvironmentID]*clientSideContext + metricsManager *metrics.Manager + clientFactory sdkconfig.ClientFactoryFunc + allPublisher *eventsource.Server + flagsPublisher *eventsource.Server + pingPublisher *eventsource.Server + clientInitCh chan relayenv.EnvContext + config config.Config + baseURL url.URL + loggers ldlog.Loggers + lock sync.RWMutex +} + +func NewRelayCore( + c config.Config, + loggers ldlog.Loggers, + clientFactory sdkconfig.ClientFactoryFunc, +) (*RelayCore, error) { + if err := config.ValidateConfig(&c, loggers); err != nil { // in case a not-yet-validated Config was passed to NewRelay + return nil, err + } + + if c.Main.LogLevel.IsDefined() { + loggers.SetMinLevel(c.Main.LogLevel.GetOrElse(ldlog.Info)) + } + + metricsManager, err := metrics.NewManager(c.MetricsConfig, 0, loggers) + if err != nil { + return nil, fmt.Errorf("unable to create metrics manager: %s", err) + } + + clientInitCh := make(chan relayenv.EnvContext, len(c.Environment)) + + r := RelayCore{ + allEnvironments: make(map[config.SDKKey]relayenv.EnvContext), + envsByMobileKey: make(map[config.MobileKey]relayenv.EnvContext), + envsByEnvID: make(map[config.EnvironmentID]*clientSideContext), + metricsManager: metricsManager, + clientFactory: clientFactory, + clientInitCh: clientInitCh, + config: c, + loggers: loggers, + } + + makeSSEServer := func() *eventsource.Server { + s := eventsource.NewServer() + s.Gzip = false + s.AllowCORS = true + s.ReplayAll = true + s.MaxConnTime = c.Main.MaxClientConnectionTime.GetOrElse(0) + return s + } + r.allPublisher = makeSSEServer() + r.flagsPublisher = makeSSEServer() + r.pingPublisher = makeSSEServer() + + if len(c.Environment) == 0 { + return nil, fmt.Errorf("you must specify at least one environment in your configuration") + } + + if c.Main.BaseURI.IsDefined() { + r.baseURL = *c.Main.BaseURI.Get() + } else { + u, err := url.Parse(config.DefaultBaseURI) + if err != nil { + return nil, errors.New("unexpected error: default base URI is invalid") + } + r.baseURL = *u + } + + for envName, envConfig := range c.Environment { + if envConfig == nil { + loggers.Warnf("environment config was nil for environment %q; ignoring", envName) + continue + } + err := r.AddEnvironment(envName, *envConfig) + if err != nil { + for _, env := range r.allEnvironments { + _ = env.Close() + } + return nil, err + } + } + + return &r, nil +} + +func (r *RelayCore) AddEnvironment(envName string, envConfig config.EnvConfig) error { + r.lock.Lock() + defer r.lock.Unlock() + + dataStoreFactory, err := sdkconfig.ConfigureDataStore(r.config, envConfig, r.loggers) + if err != nil { + return err + } + + clientContext, err := relayenv.NewEnvContext( + envName, + envConfig, + r.config, + r.clientFactory, + dataStoreFactory, + r.allPublisher, + r.flagsPublisher, + r.pingPublisher, + r.metricsManager, + r.loggers, + r.clientInitCh, + ) + if err != nil { + return fmt.Errorf(`unable to create client context for "%s": %s`, envName, err) + } + r.allEnvironments[envConfig.SDKKey] = clientContext + if envConfig.MobileKey != "" { + r.envsByMobileKey[envConfig.MobileKey] = clientContext + } + + if envConfig.EnvID != "" { + allowedOrigins := envConfig.AllowedOrigin.Values() + cachingTransport := httpcache.NewMemoryCacheTransport() + if envConfig.InsecureSkipVerify { + tlsConfig := &tls.Config{InsecureSkipVerify: envConfig.InsecureSkipVerify} // nolint:gas // allow this because the user has to explicitly enable it + defaultTransport := http.DefaultTransport.(*http.Transport) + transport := &http.Transport{ // we can't just copy defaultTransport all at once because it has a Mutex + Proxy: defaultTransport.Proxy, + DialContext: defaultTransport.DialContext, + ForceAttemptHTTP2: defaultTransport.ForceAttemptHTTP2, + MaxIdleConns: defaultTransport.MaxIdleConns, + IdleConnTimeout: defaultTransport.IdleConnTimeout, + TLSClientConfig: tlsConfig, + TLSHandshakeTimeout: defaultTransport.TLSHandshakeTimeout, + ExpectContinueTimeout: defaultTransport.ExpectContinueTimeout, + } + cachingTransport.Transport = transport + } + + proxy := &httputil.ReverseProxy{ + Director: func(req *http.Request) { + url := req.URL + url.Scheme = r.baseURL.Scheme + url.Host = r.baseURL.Host + req.Host = r.baseURL.Hostname() + }, + ModifyResponse: func(resp *http.Response) error { + // Leave access control to our own cors middleware + for h := range resp.Header { + if strings.HasPrefix(strings.ToLower(h), "access-control") { + resp.Header.Del(h) + } + } + return nil + }, + Transport: cachingTransport, + } + + r.envsByEnvID[envConfig.EnvID] = &clientSideContext{ + EnvContext: clientContext, + proxy: proxy, + allowedOrigins: allowedOrigins, + } + } + + return nil +} + +func (r *RelayCore) GetEnvironment(credential config.SDKCredential) relayenv.EnvContext { + r.lock.RLock() + defer r.lock.RUnlock() + + switch c := credential.(type) { + case config.SDKKey: + return r.allEnvironments[c] + case config.MobileKey: + return r.envsByMobileKey[c] + case config.EnvironmentID: + return r.envsByEnvID[c] + default: + return nil + } +} + +func (r *RelayCore) GetAllEnvironments() map[config.SDKKey]relayenv.EnvContext { + r.lock.RLock() + defer r.lock.RUnlock() + + ret := make(map[config.SDKKey]relayenv.EnvContext, len(r.allEnvironments)) + for k, v := range r.allEnvironments { + ret[k] = v + } + return ret +} + +func (r *RelayCore) WaitForAllClients(timeout time.Duration) error { + numEnvironments := len(r.allEnvironments) + numFinished := 0 + + var timeoutCh <-chan time.Time + if timeout > 0 { + timer := time.NewTimer(timeout) + defer timer.Stop() + timeoutCh = timer.C + } + + resultCh := make(chan bool, 1) + go func() { + failed := false + for numFinished < numEnvironments { + ctx := <-r.clientInitCh + numFinished++ + if ctx.GetInitError() != nil { + failed = true + } + } + resultCh <- failed + }() + + select { + case failed := <-resultCh: + if failed { + return errors.New("one or more environments failed to initialize") + } + return nil + case <-timeoutCh: + return errors.New("timed out waiting for environments to initialize") + } +} + +func (r *RelayCore) Close() { + r.metricsManager.Close() + for _, env := range r.allEnvironments { + _ = env.Close() + } +} diff --git a/relay_core_routes.go b/relay_core_routes.go new file mode 100644 index 00000000..839bcf2a --- /dev/null +++ b/relay_core_routes.go @@ -0,0 +1,135 @@ +package relay + +import ( + "net/http" + + "github.com/gorilla/mux" + + "gopkg.in/launchdarkly/go-sdk-common.v2/ldlog" + + "github.com/launchdarkly/ld-relay/v6/config" + "github.com/launchdarkly/ld-relay/v6/internal/events" + "github.com/launchdarkly/ld-relay/v6/internal/logging" + "github.com/launchdarkly/ld-relay/v6/internal/metrics" +) + +func (r *RelayCore) MakeRouter() *mux.Router { + clientSideMux := clientSideMux{contextByKey: map[config.SDKCredential]*clientSideContext{}} + for envID, csc := range r.envsByEnvID { + clientSideMux.contextByKey[envID] = csc + } + + router := mux.NewRouter() + router.Use(logging.GlobalContextLoggersMiddleware(r.loggers)) + if r.loggers.GetMinLevel() == ldlog.Debug { + router.Use(logging.RequestLoggerMiddleware(r.loggers)) + } + router.Handle("/status", statusHandler(r)).Methods("GET") + + sdkKeySelector := selectEnvironmentByAuthorizationKey(serverSdk, r) + mobileKeySelector := selectEnvironmentByAuthorizationKey(mobileSdk, r) + + // Client-side evaluation + clientSideMiddlewareStack := chainMiddleware( + corsMiddleware, + clientSideMux.selectClientByUrlParam, + requestCountMiddleware(metrics.BrowserRequests)) + + goalsRouter := router.PathPrefix("/sdk/goals").Subrouter() + goalsRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(goalsRouter)) + goalsRouter.HandleFunc("/{envId}", clientSideMux.getGoals).Methods("GET", "OPTIONS") + + clientSideSdkEvalRouter := router.PathPrefix("/sdk/eval/{envId}/").Subrouter() + clientSideSdkEvalRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(clientSideSdkEvalRouter)) + clientSideSdkEvalRouter.HandleFunc("/users/{user}", evaluateAllFeatureFlagsValueOnly(jsClientSdk)).Methods("GET", "OPTIONS") + clientSideSdkEvalRouter.HandleFunc("/user", evaluateAllFeatureFlagsValueOnly(jsClientSdk)).Methods("REPORT", "OPTIONS") + + clientSideSdkEvalXRouter := router.PathPrefix("/sdk/evalx/{envId}/").Subrouter() + clientSideSdkEvalXRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(clientSideSdkEvalXRouter)) + clientSideSdkEvalXRouter.HandleFunc("/users/{user}", evaluateAllFeatureFlags(jsClientSdk)).Methods("GET", "OPTIONS") + clientSideSdkEvalXRouter.HandleFunc("/user", evaluateAllFeatureFlags(jsClientSdk)).Methods("REPORT", "OPTIONS") + + serverSideMiddlewareStack := chainMiddleware( + sdkKeySelector, + requestCountMiddleware(metrics.ServerRequests)) + + serverSideSdkRouter := router.PathPrefix("/sdk/").Subrouter() + // (?)TODO: there is a bug in gorilla mux (see see https://github.com/gorilla/mux/pull/378) that means the middleware below + // because it will not be run if it matches any earlier prefix. Until it is fixed, we have to apply the middleware explicitly + // serverSideSdkRouter.Use(serverSideMiddlewareStack) + + serverSideEvalRouter := serverSideSdkRouter.PathPrefix("/eval/").Subrouter() + serverSideEvalRouter.Handle("/users/{user}", serverSideMiddlewareStack(http.HandlerFunc(evaluateAllFeatureFlagsValueOnly(serverSdk)))).Methods("GET") + serverSideEvalRouter.Handle("/user", serverSideMiddlewareStack(http.HandlerFunc(evaluateAllFeatureFlagsValueOnly(serverSdk)))).Methods("REPORT") + + serverSideEvalXRouter := serverSideSdkRouter.PathPrefix("/evalx/").Subrouter() + serverSideEvalXRouter.Handle("/users/{user}", serverSideMiddlewareStack(http.HandlerFunc(evaluateAllFeatureFlags(serverSdk)))).Methods("GET") + serverSideEvalXRouter.Handle("/user", serverSideMiddlewareStack(http.HandlerFunc(evaluateAllFeatureFlags(serverSdk)))).Methods("REPORT") + + // PHP SDK endpoints + serverSideSdkRouter.Handle("/flags", serverSideMiddlewareStack(http.HandlerFunc(pollAllFlagsHandler))).Methods("GET") + serverSideSdkRouter.Handle("/flags/{key}", serverSideMiddlewareStack(http.HandlerFunc(pollFlagHandler))).Methods("GET") + serverSideSdkRouter.Handle("/segments/{key}", serverSideMiddlewareStack(http.HandlerFunc(pollSegmentHandler))).Methods("GET") + + // Mobile evaluation + mobileMiddlewareStack := chainMiddleware( + mobileKeySelector, + requestCountMiddleware(metrics.MobileRequests)) + + msdkRouter := router.PathPrefix("/msdk/").Subrouter() + msdkRouter.Use(mobileMiddlewareStack) + + msdkEvalRouter := msdkRouter.PathPrefix("/eval/").Subrouter() + msdkEvalRouter.HandleFunc("/users/{user}", evaluateAllFeatureFlagsValueOnly(mobileSdk)).Methods("GET") + msdkEvalRouter.HandleFunc("/user", evaluateAllFeatureFlagsValueOnly(mobileSdk)).Methods("REPORT") + + msdkEvalXRouter := msdkRouter.PathPrefix("/evalx/").Subrouter() + msdkEvalXRouter.HandleFunc("/users/{user}", evaluateAllFeatureFlags(mobileSdk)).Methods("GET") + msdkEvalXRouter.HandleFunc("/user", evaluateAllFeatureFlags(mobileSdk)).Methods("REPORT") + + mobileStreamRouter := router.PathPrefix("/meval").Subrouter() + mobileStreamRouter.Use(mobileMiddlewareStack, streamingMiddleware) + mobileStreamRouter.Handle("", countMobileConns(pingStreamHandlerWithUser(mobileSdk))).Methods("REPORT") + mobileStreamRouter.Handle("/{user}", countMobileConns(pingStreamHandlerWithUser(mobileSdk))).Methods("GET") + + router.Handle("/mping", mobileKeySelector( + countMobileConns(streamingMiddleware(pingStreamHandler())))).Methods("GET") + + clientSidePingRouter := router.PathPrefix("/ping/{envId}").Subrouter() + clientSidePingRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(clientSidePingRouter), streamingMiddleware) + clientSidePingRouter.Handle("", countBrowserConns(pingStreamHandler())).Methods("GET", "OPTIONS") + + clientSideStreamEvalRouter := router.PathPrefix("/eval/{envId}").Subrouter() + clientSideStreamEvalRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(clientSideStreamEvalRouter), streamingMiddleware) + // For now we implement eval as simply ping + clientSideStreamEvalRouter.Handle("/{user}", countBrowserConns(pingStreamHandlerWithUser(jsClientSdk))).Methods("GET", "OPTIONS") + clientSideStreamEvalRouter.Handle("", countBrowserConns(pingStreamHandlerWithUser(jsClientSdk))).Methods("REPORT", "OPTIONS") + + mobileEventsRouter := router.PathPrefix("/mobile").Subrouter() + mobileEventsRouter.Use(mobileMiddlewareStack) + mobileEventsRouter.Handle("/events/bulk", bulkEventHandler(events.MobileSDKEventsEndpoint)).Methods("POST") + mobileEventsRouter.Handle("/events", bulkEventHandler(events.MobileSDKEventsEndpoint)).Methods("POST") + mobileEventsRouter.Handle("", bulkEventHandler(events.MobileSDKEventsEndpoint)).Methods("POST") + mobileEventsRouter.Handle("/events/diagnostic", bulkEventHandler(events.MobileSDKDiagnosticEventsEndpoint)).Methods("POST") + + clientSideBulkEventsRouter := router.PathPrefix("/events/bulk/{envId}").Subrouter() + clientSideBulkEventsRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(clientSideBulkEventsRouter)) + clientSideBulkEventsRouter.Handle("", bulkEventHandler(events.JavaScriptSDKEventsEndpoint)).Methods("POST", "OPTIONS") + + clientSideDiagnosticEventsRouter := router.PathPrefix("/events/diagnostic/{envId}").Subrouter() + clientSideDiagnosticEventsRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(clientSideBulkEventsRouter)) + clientSideDiagnosticEventsRouter.Handle("", bulkEventHandler(events.JavaScriptSDKDiagnosticEventsEndpoint)).Methods("POST", "OPTIONS") + + clientSideImageEventsRouter := router.PathPrefix("/a/{envId}.gif").Subrouter() + clientSideImageEventsRouter.Use(clientSideMiddlewareStack, mux.CORSMethodMiddleware(clientSideImageEventsRouter)) + clientSideImageEventsRouter.HandleFunc("", getEventsImage).Methods("GET", "OPTIONS") + + serverSideRouter := router.PathPrefix("").Subrouter() + serverSideRouter.Use(serverSideMiddlewareStack) + serverSideRouter.Handle("/bulk", bulkEventHandler(events.ServerSDKEventsEndpoint)).Methods("POST") + serverSideRouter.Handle("/diagnostic", bulkEventHandler(events.ServerSDKDiagnosticEventsEndpoint)).Methods("POST") + serverSideRouter.Handle("/all", countServerConns(streamingMiddleware(allStreamHandler()))).Methods("GET") + serverSideRouter.Handle("/flags", countServerConns(streamingMiddleware(flagsStreamHandler()))).Methods("GET") + + return router +} diff --git a/relay_core_routes_test.go b/relay_core_routes_test.go new file mode 100644 index 00000000..aecc3ff1 --- /dev/null +++ b/relay_core_routes_test.go @@ -0,0 +1,49 @@ +package relay + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + c "github.com/launchdarkly/ld-relay/v6/config" + "gopkg.in/launchdarkly/go-sdk-common.v2/ldlog" + "gopkg.in/launchdarkly/go-sdk-common.v2/ldlogtest" +) + +func TestRequestLogging(t *testing.T) { + url := "http://localhost/status" // must be a route that exists - not-found paths currently aren't logged + + t.Run("requests are not logged by default", func(t *testing.T) { + config := c.Config{ + Environment: makeEnvConfigs(testEnvMain), + } + mockLog := ldlogtest.NewMockLog() + core, err := NewRelayCore(config, mockLog.Loggers, fakeLDClientFactory(true)) + require.NoError(t, err) + defer core.Close() + + handler := core.MakeRouter() + req, _ := http.NewRequest("GET", url, nil) + _, _ = doRequest(req, handler) + + mockLog.AssertMessageMatch(t, false, ldlog.Debug, "method=GET url="+url) + }) + + t.Run("requests are logged when debug logging is enabled", func(t *testing.T) { + config := c.Config{ + Main: c.MainConfig{LogLevel: c.NewOptLogLevel(ldlog.Debug)}, + Environment: makeEnvConfigs(testEnvMain), + } + mockLog := ldlogtest.NewMockLog() + core, err := NewRelayCore(config, mockLog.Loggers, fakeLDClientFactory(true)) + require.NoError(t, err) + defer core.Close() + + handler := core.MakeRouter() + req, _ := http.NewRequest("GET", url, nil) + _, _ = doRequest(req, handler) + + mockLog.AssertMessageMatch(t, true, ldlog.Debug, "method=GET url="+url) + }) +} diff --git a/relay_core_test.go b/relay_core_test.go new file mode 100644 index 00000000..6468d6e4 --- /dev/null +++ b/relay_core_test.go @@ -0,0 +1,115 @@ +package relay + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + c "github.com/launchdarkly/ld-relay/v6/config" + "github.com/launchdarkly/ld-relay/v6/sdkconfig" + "gopkg.in/launchdarkly/go-sdk-common.v2/ldlog" + ld "gopkg.in/launchdarkly/go-server-sdk.v5" +) + +func TestNewRelayCoreRejectsConfigWithContradictoryProperties(t *testing.T) { + // it is an error to enable TLS but not provide a cert or key + config := c.Config{Main: c.MainConfig{TLSEnabled: true}} + core, err := NewRelayCore(config, ldlog.NewDefaultLoggers(), fakeLDClientFactory(true)) + require.Error(t, err) + assert.Contains(t, err.Error(), "TLS cert") + assert.Nil(t, core) +} + +func TestNewRelayCoreRejectsConfigWithNoEnvironments(t *testing.T) { + config := c.Config{} + core, err := NewRelayCore(config, ldlog.NewDefaultLoggers(), fakeLDClientFactory(true)) + require.Error(t, err) + assert.Contains(t, err.Error(), "you must specify at least one environment") + assert.Nil(t, core) +} + +func TestRelayCoreGetEnvironment(t *testing.T) { + config := c.Config{ + Environment: makeEnvConfigs(testEnvMain, testEnvMobile, testEnvClientSide), + } + core, err := NewRelayCore(config, ldlog.NewDefaultLoggers(), fakeLDClientFactory(true)) + require.NoError(t, err) + defer core.Close() + + if assert.NotNil(t, core.GetEnvironment(testEnvMain.config.SDKKey)) { + assert.Equal(t, testEnvMain.name, core.GetEnvironment(testEnvMain.config.SDKKey).GetName()) + } + if assert.NotNil(t, core.GetEnvironment(testEnvMobile.config.SDKKey)) { + assert.Equal(t, testEnvMobile.name, core.GetEnvironment(testEnvMobile.config.SDKKey).GetName()) + } + if assert.NotNil(t, core.GetEnvironment(testEnvClientSide.config.SDKKey)) { + assert.Equal(t, testEnvClientSide.name, core.GetEnvironment(testEnvClientSide.config.SDKKey).GetName()) + } + + if assert.NotNil(t, core.GetEnvironment(testEnvMobile.config.MobileKey)) { + assert.Equal(t, testEnvMobile.name, core.GetEnvironment(testEnvMobile.config.MobileKey).GetName()) + } + + if assert.NotNil(t, core.GetEnvironment(testEnvClientSide.config.EnvID)) { + assert.Equal(t, testEnvClientSide.name, core.GetEnvironment(testEnvClientSide.config.EnvID).GetName()) + } + + assert.Nil(t, core.GetEnvironment(undefinedSDKKey)) + + assert.Nil(t, core.GetEnvironment(unsupportedSDKCredential{})) +} + +func TestRelayCoreGetAllEnvironments(t *testing.T) { + config := c.Config{ + Environment: makeEnvConfigs(testEnvMain, testEnvMobile, testEnvClientSide), + } + core, err := NewRelayCore(config, ldlog.NewDefaultLoggers(), fakeLDClientFactory(true)) + require.NoError(t, err) + defer core.Close() + + envs := core.GetAllEnvironments() + assert.Len(t, envs, 3) + if assert.NotNil(t, envs[testEnvMain.config.SDKKey]) { + assert.Equal(t, testEnvMain.name, envs[testEnvMain.config.SDKKey].GetName()) + } + if assert.NotNil(t, envs[testEnvMobile.config.SDKKey]) { + assert.Equal(t, testEnvMobile.name, envs[testEnvMobile.config.SDKKey].GetName()) + } + if assert.NotNil(t, envs[testEnvClientSide.config.SDKKey]) { + assert.Equal(t, testEnvClientSide.name, envs[testEnvClientSide.config.SDKKey].GetName()) + } +} + +func TestRelayCoreWaitForAllEnvironments(t *testing.T) { + config := c.Config{ + Environment: makeEnvConfigs(testEnvMain, testEnvMobile), + } + + t.Run("returns nil if all environments initialize successfully", func(t *testing.T) { + core, err := NewRelayCore(config, ldlog.NewDefaultLoggers(), fakeLDClientFactory(true)) + require.NoError(t, err) + defer core.Close() + + err = core.WaitForAllClients(time.Second) + assert.NoError(t, err) + }) + + t.Run("returns error if any environment does not initialize successfully", func(t *testing.T) { + oneEnvFails := func(sdkKey c.SDKKey, config ld.Config) (sdkconfig.LDClientContext, error) { + shouldFail := sdkKey == testEnvMobile.config.SDKKey + if shouldFail { + return clientFactoryThatFails(errors.New("sorry"))(sdkKey, config) + } + return fakeLDClientFactory(true)(sdkKey, config) + } + core, err := NewRelayCore(config, ldlog.NewDefaultLoggers(), oneEnvFails) + require.NoError(t, err) + defer core.Close() + + err = core.WaitForAllClients(time.Second) + assert.Error(t, err) + }) +} diff --git a/endpoints.go b/relay_endpoints.go similarity index 100% rename from endpoints.go rename to relay_endpoints.go diff --git a/relay_endpoints_status.go b/relay_endpoints_status.go new file mode 100644 index 00000000..6ac47ad4 --- /dev/null +++ b/relay_endpoints_status.go @@ -0,0 +1,63 @@ +package relay + +import ( + "encoding/json" + "net/http" + + "github.com/launchdarkly/ld-relay/v6/internal/version" + ld "gopkg.in/launchdarkly/go-server-sdk.v5" +) + +type environmentStatus struct { + SdkKey string `json:"sdkKey"` + EnvId string `json:"envId,omitempty"` + MobileKey string `json:"mobileKey,omitempty"` + Status string `json:"status"` +} + +func statusHandler(core *RelayCore) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + envs := make(map[string]environmentStatus) + + healthy := true + for _, clientCtx := range core.GetAllEnvironments() { + var status environmentStatus + creds := clientCtx.GetCredentials() + status.SdkKey = obscureKey(creds.SDKKey) + if mobileKey, ok := creds.MobileKey.Get(); ok { + status.MobileKey = obscureKey(mobileKey) + } + status.EnvId = creds.EnvironmentID.StringValue() + client := clientCtx.GetClient() + if client == nil || !client.Initialized() { + status.Status = "disconnected" + healthy = false + } else { + status.Status = "connected" + } + envs[clientCtx.GetName()] = status + } + + resp := struct { + Environments map[string]environmentStatus `json:"environments"` + Status string `json:"status"` + Version string `json:"version"` + ClientVersion string `json:"clientVersion"` + }{ + Environments: envs, + Version: version.Version, + ClientVersion: ld.Version, + } + + if healthy { + resp.Status = "healthy" + } else { + resp.Status = "degraded" + } + + data, _ := json.Marshal(resp) + + w.Write(data) + }) +} diff --git a/sdk_kinds.go b/sdk_kinds.go index a619a82f..8b8aef88 100644 --- a/sdk_kinds.go +++ b/sdk_kinds.go @@ -3,6 +3,7 @@ package relay import ( "errors" "net/http" + "strings" "github.com/gorilla/mux" @@ -43,12 +44,11 @@ func (s sdkKind) getSDKCredential(req *http.Request) (config.SDKCredential, erro func fetchAuthToken(req *http.Request) (string, error) { authHdr := req.Header.Get("Authorization") - match := uuidHeaderPattern.FindStringSubmatch(authHdr) - - // successfully matched UUID from header - if len(match) == 2 { - return match[1], nil + if strings.HasPrefix(authHdr, "api_key ") { + authHdr = strings.TrimSpace(strings.TrimPrefix(authHdr, "api_key ")) } - - return "", errors.New("no valid token found") + if authHdr == "" || strings.Contains(authHdr, " ") { + return "", errors.New("no valid token found") + } + return authHdr, nil } diff --git a/testutils_components_test.go b/testutils_components_test.go index 7d731c9a..4ef77737 100644 --- a/testutils_components_test.go +++ b/testutils_components_test.go @@ -19,6 +19,22 @@ import ( var emptyStore = sharedtest.NewInMemoryStore() var emptyStoreAdapter = store.NewSSERelayDataStoreAdapterWithExistingStore(emptyStore) +type testEnvironments map[config.SDKCredential]relayenv.EnvContext + +func (t testEnvironments) GetEnvironment(c config.SDKCredential) relayenv.EnvContext { + return t[c] +} + +func (t testEnvironments) GetAllEnvironments() map[config.SDKKey]relayenv.EnvContext { + ret := make(map[config.SDKKey]relayenv.EnvContext) + for k, v := range t { + if sk, ok := k.(config.SDKKey); ok { + ret[sk] = v + } + } + return ret +} + func clientFactoryThatFails(err error) sdkconfig.ClientFactoryFunc { return func(sdkKey config.SDKKey, config ld.Config) (sdkconfig.LDClientContext, error) { return nil, err diff --git a/testutils_values_test.go b/testutils_values_test.go index 136cc8b1..28153f6b 100644 --- a/testutils_values_test.go +++ b/testutils_values_test.go @@ -28,6 +28,10 @@ type testFlag struct { isExperiment bool } +type unsupportedSDKCredential struct{} // implements config.SDKCredential + +func (k unsupportedSDKCredential) GetAuthorizationHeaderValue() string { return "" } + // Returns a key matching the UUID header pattern func key() config.MobileKey { return "mob-ffffffff-ffff-4fff-afff-ffffffffffff" @@ -43,10 +47,9 @@ const ( undefinedMobileKey = config.MobileKey("mob-99999999-9999-4999-8999-999999999999") undefinedEnvID = config.EnvironmentID("999999999999999999999999") - // The "malformed" values do not pass the basic regex match for their types. - malformedSDKKey = config.SDKKey("sdk-no") - malformedMobileKey = config.MobileKey("mob-no") - malformedEnvId = config.EnvironmentID("env-no") + // The "malformed" values contain an unsupported authorization scheme. + malformedSDKKey = config.SDKKey("fake_key sdk-99999999-9999-4999-8999-999999999999") + malformedMobileKey = config.MobileKey("fake_key mob-99999999-9999-4999-8999-999999999999") ) var testEnvMain = testEnv{