diff --git a/middleware.go b/middleware.go index f208dd2..eb5e709 100644 --- a/middleware.go +++ b/middleware.go @@ -42,12 +42,20 @@ func Logger(next http.Handler) http.Handler { } // Readiness - middleware for the readiness probe -func Readiness(isReady *atomic.Value) http.HandlerFunc { - return func(w http.ResponseWriter, _ *http.Request) { - if isReady == nil || !isReady.Load().(bool) { - http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable) - return - } - w.WriteHeader(http.StatusOK) +func Readiness(endpoint string, isReady *atomic.Value) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" && strings.EqualFold(r.URL.Path, endpoint) { + if isReady == nil || !isReady.Load().(bool) { + ErrorResponse(w, r, http.StatusServiceUnavailable, nil, "") + return + } + + OkResponse(w) + return + } + + h.ServeHTTP(w, r) + }) } } diff --git a/middleware_test.go b/middleware_test.go index 7a5cc00..c9cabf7 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -32,7 +32,9 @@ func TestReadiness(t *testing.T) { isReady := &atomic.Value{} isReady.Store(false) - ts := httptest.NewServer(Readiness(isReady)) + ts := httptest.NewServer(Readiness("/", isReady)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + JsonResponse(w, []byte("OK")) + }))) defer ts.Close() resp, err := http.Get(ts.URL) diff --git a/server.go b/server.go index 9f7c4bc..9e3f7b1 100644 --- a/server.go +++ b/server.go @@ -3,6 +3,7 @@ package rest import ( "context" "fmt" + "github.com/go-chi/chi/v5" "log" "net/http" "sync/atomic" @@ -27,6 +28,12 @@ func handler(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(".")) } +func ready(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(".")) +} + // Run - will initialize server and run it on provided port func (s *Server) Run(router http.Handler) error { if s.Port == 0 { @@ -48,10 +55,10 @@ func (s *Server) Run(router http.Handler) error { } if router == nil { - mux := http.NewServeMux() + mux := chi.NewRouter() + mux.Use(Readiness("/readiness", s.IsReady)) mux.HandleFunc("/ping", handler) mux.HandleFunc("/liveness", handler) - mux.HandleFunc("/readiness", Readiness(s.IsReady)) router = mux }