diff --git a/.dockerignore b/.dockerignore index 1c8deaa6..c9ee771f 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,6 +1,8 @@ -.git -.gitignore -Dockerfile -README.md -docs -integration_tests +# +# Only files that are untracked by Git should be added here. +# +# The builder container needs to see a pristine checkout, otherwise +# vcs.modified in the BuildInfo will always be true, i.e. the build will always +# be marked as "dirty". +# +/router diff --git a/.gitignore b/.gitignore index aeed59a5..8421b565 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ -router +/router __build diff --git a/.golangci.yml b/.golangci.yml index 0060b70f..916bb7f4 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -20,9 +20,15 @@ linters: - promlinter - reassign - revive + - stylecheck - tenv - tparallel - unconvert - usestdlibvars - wastedassign - zerologlint +linters-settings: + stylecheck: + dot-import-whitelist: + - github.com/onsi/ginkgo/v2 + - github.com/onsi/gomega diff --git a/Dockerfile b/Dockerfile index 6929371d..a2a9728d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,22 @@ -FROM golang:1.20-alpine AS builder +ARG go_registry="" +ARG go_version=1.20 +ARG go_tag_suffix=-alpine + +FROM ${go_registry}golang:${go_version}${go_tag_suffix} AS builder ARG TARGETARCH TARGETOS +ARG GOARCH=$TARGETARCH GOOS=$TARGETOS +ARG CGO_ENABLED=0 +ARG GOFLAGS="-trimpath" +ARG go_ldflags="-s -w" +# Go needs git for `-buildvcs`, but the alpine version lacks git :( It's still +# way cheaper to `apk add git` than to pull the Debian-based golang image. +# hadolint ignore=DL3018 +RUN apk add --no-cache git WORKDIR /src COPY . ./ -RUN CGO_ENABLED=0 GOARCH=$TARGETARCH GOOS=$TARGETOS go build -trimpath -ldflags="-s -w" +RUN go build -ldflags="$go_ldflags" && \ + ./router -version && \ + go version -m ./router FROM scratch COPY --from=builder /src/router /bin/router diff --git a/Makefile b/Makefile index 29c106e2..949be474 100644 --- a/Makefile +++ b/Makefile @@ -1,53 +1,35 @@ -.PHONY: all build test unit_tests integration_tests clean start_mongo clean_mongo clean_mongo_again +.PHONY: all clean build test lint unit_tests integration_tests start_mongo stop_mongo update_deps +.NOTPARALLEL: -BINARY ?= router -SHELL := /bin/bash +TARGET_MODULE := router +GO_BUILD_ENV := CGO_ENABLED=0 +SHELL := /bin/dash -ifdef RELEASE_VERSION -VERSION := $(RELEASE_VERSION) -else -VERSION := $(shell git describe --always | tr -d '\n'; test -z "`git status --porcelain`" || echo '-dirty') -endif - -all: build test +all: build clean: - rm -f $(BINARY) + rm -f $(TARGET_MODULE) build: - go build -ldflags "-X main.version=$(VERSION)" -o $(BINARY) + env $(GO_BUILD_ENV) go build + ./$(TARGET_MODULE) -version + +test: lint unit_tests integration_tests lint: golangci-lint run -test: start_mongo unit_tests integration_tests clean_mongo_again - -unit_tests: build +unit_tests: go test -race $$(go list ./... | grep -v integration_tests) -integration_tests: start_mongo build - ROUTER_PUBADDR=localhost:8080 \ - ROUTER_APIADDR=localhost:8081 \ - go test -race -v ./integration_tests - -start_mongo: clean_mongo - @if ! docker run --rm --name router-mongo -dp 27017:27017 mongo:2.4 --replSet rs0 --quiet; then \ - echo 'Failed to start mongo; if using Docker Desktop, try:' ; \ - echo ' - disabling Settings -> Features in development -> Use containerd' ; \ - echo ' - enabling Settings -> Features in development -> Use Rosetta' ; \ - exit 1 ; \ - fi - @echo -n Waiting for mongo - @for n in {1..30}; do \ - if docker exec router-mongo mongo --quiet --eval 'rs.initiate()' >/dev/null 2>&1; then \ - sleep 1; \ - echo ; \ - break ; \ - fi ; \ - echo -n . ; \ - sleep 1 ; \ - done ; \ - -clean_mongo clean_mongo_again: - docker rm -f router-mongo >/dev/null 2>&1 || true - @sleep 1 # Docker doesn't queue commands so it races with itself :( +integration_tests: build start_mongo + go test -race -v ./integration_tests + +start_mongo: + ./mongo.sh start + +stop_mongo: + ./mongo.sh stop + +update_deps: + go get -t -u ./... && go mod tidy && go mod vendor diff --git a/README.md b/README.md index 8ceea467..35b36d0c 100644 --- a/README.md +++ b/README.md @@ -59,17 +59,17 @@ make lint ### Debug output -To see debug messages when running tests, set both the `DEBUG` and -`DEBUG_ROUTER` environment variables. +To see debug messages when running tests, set both the `ROUTER_DEBUG` and +`ROUTER_DEBUG_TESTS` environment variables: ```sh -export DEBUG=1 DEBUG_ROUTER=1 +export ROUTER_DEBUG=1 ROUTER_DEBUG_TESTS=1 ``` or equivalently for a single run: ```sh -DEBUG=1 DEBUG_ROUTER=1 make test +ROUTER_DEBUG=1 ROUTER_DEBUG_TESTS=1 make test ``` ### Update the dependencies @@ -79,7 +79,7 @@ This project uses [Go Modules](https://github.com/golang/go/wiki/Modules) to ven 1. Update all the dependencies, including test dependencies, in your working copy: ```sh - go get -t -u ./... && go mod tidy && go mod vendor + make update_deps ``` 1. Check for any errors and commit. diff --git a/handlers/backend_handler.go b/handlers/backend_handler.go index e4ca9a7c..7df7663b 100644 --- a/handlers/backend_handler.go +++ b/handlers/backend_handler.go @@ -134,7 +134,7 @@ func (bt *backendTransport) RoundTrip(req *http.Request) (resp *http.Response, e var responseCode int var startTime = time.Now() - BackendHandlerRequestCountMetric.With(prometheus.Labels{ + backendRequestCountMetric.With(prometheus.Labels{ "backend_id": bt.backendID, "request_method": req.Method, }).Inc() @@ -142,7 +142,7 @@ func (bt *backendTransport) RoundTrip(req *http.Request) (resp *http.Response, e defer func() { durationSeconds := time.Since(startTime).Seconds() - BackendHandlerResponseDurationSecondsMetric.With(prometheus.Labels{ + backendResponseDurationSecondsMetric.With(prometheus.Labels{ "backend_id": bt.backendID, "request_method": req.Method, "response_code": fmt.Sprintf("%d", responseCode), diff --git a/handlers/backend_handler_test.go b/handlers/backend_handler_test.go index 2365e275..1fc5f45b 100644 --- a/handlers/backend_handler_test.go +++ b/handlers/backend_handler_test.go @@ -1,4 +1,4 @@ -package handlers_test +package handlers import ( "io" @@ -15,7 +15,6 @@ import ( promtest "github.com/prometheus/client_golang/prometheus/testutil" prommodel "github.com/prometheus/client_model/go" - "github.com/alphagov/router/handlers" log "github.com/alphagov/router/logger" ) @@ -51,7 +50,7 @@ var _ = Describe("Backend handler", func() { Context("when the backend times out", func() { BeforeEach(func() { - router = handlers.NewBackendHandler( + router = NewBackendHandler( "backend-timeout", backendURL, timeout, timeout, @@ -80,7 +79,7 @@ var _ = Describe("Backend handler", func() { Context("when the backend handles the connection", func() { BeforeEach(func() { - router = handlers.NewBackendHandler( + router = NewBackendHandler( "backend-handle", backendURL, timeout, timeout, @@ -141,15 +140,14 @@ var _ = Describe("Backend handler", func() { Context("metrics", func() { var ( - beforeRequestCountMetric float64 - + beforeRequestCountMetric float64 beforeResponseCountMetric float64 beforeResponseDurationSecondsMetric float64 ) measureRequestCount := func() float64 { return promtest.ToFloat64( - handlers.BackendHandlerRequestCountMetric.With(prometheus.Labels{ + backendRequestCountMetric.With(prometheus.Labels{ "backend_id": "backend-metrics", "request_method": http.MethodGet, }), @@ -160,7 +158,7 @@ var _ = Describe("Backend handler", func() { var err error metricChan := make(chan prometheus.Metric, 1024) - handlers.BackendHandlerResponseDurationSecondsMetric.Collect(metricChan) + backendResponseDurationSecondsMetric.Collect(metricChan) close(metricChan) for m := range metricChan { metric := new(prommodel.Metric) @@ -197,7 +195,7 @@ var _ = Describe("Backend handler", func() { } BeforeEach(func() { - router = handlers.NewBackendHandler( + router = NewBackendHandler( "backend-metrics", backendURL, timeout, timeout, diff --git a/handlers/handlers.go b/handlers/handlers.go deleted file mode 100644 index df249dda..00000000 --- a/handlers/handlers.go +++ /dev/null @@ -1,5 +0,0 @@ -package handlers - -func init() { - initMetrics() -} diff --git a/handlers/handlers_suite_test.go b/handlers/handlers_suite_test.go index f3789b51..1ed17f51 100644 --- a/handlers/handlers_suite_test.go +++ b/handlers/handlers_suite_test.go @@ -1,4 +1,4 @@ -package handlers_test +package handlers import ( "testing" diff --git a/handlers/metrics.go b/handlers/metrics.go index 25983a98..a2f3331a 100644 --- a/handlers/metrics.go +++ b/handlers/metrics.go @@ -5,10 +5,10 @@ import ( ) var ( - RedirectHandlerRedirectCountMetric = prometheus.NewCounterVec( + redirectCountMetric = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "router_redirect_handler_redirect_total", - Help: "Number of redirects handled by router redirect handlers", + Help: "Number of redirects served by redirect handlers", }, []string{ "redirect_code", @@ -16,10 +16,10 @@ var ( }, ) - BackendHandlerRequestCountMetric = prometheus.NewCounterVec( + backendRequestCountMetric = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "router_backend_handler_request_total", - Help: "Number of requests handled by router backend handlers", + Help: "Number of requests served by backend handlers", }, []string{ "backend_id", @@ -27,10 +27,10 @@ var ( }, ) - BackendHandlerResponseDurationSecondsMetric = prometheus.NewHistogramVec( + backendResponseDurationSecondsMetric = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Name: "router_backend_handler_response_duration_seconds", - Help: "Histogram of response durations by router backend handlers", + Help: "Histogram of response durations by backend", }, []string{ "backend_id", @@ -40,9 +40,10 @@ var ( ) ) -func initMetrics() { - prometheus.MustRegister(RedirectHandlerRedirectCountMetric) - - prometheus.MustRegister(BackendHandlerRequestCountMetric) - prometheus.MustRegister(BackendHandlerResponseDurationSecondsMetric) +func RegisterMetrics(r prometheus.Registerer) { + r.MustRegister( + backendRequestCountMetric, + backendResponseDurationSecondsMetric, + redirectCountMetric, + ) } diff --git a/handlers/redirect_handler.go b/handlers/redirect_handler.go index 80529f97..8fcdd7c7 100644 --- a/handlers/redirect_handler.go +++ b/handlers/redirect_handler.go @@ -61,7 +61,7 @@ func (handler *redirectHandler) ServeHTTP(writer http.ResponseWriter, request *h target := addGAQueryParam(handler.url, request) http.Redirect(writer, request, target, handler.code) - RedirectHandlerRedirectCountMetric.With(prometheus.Labels{ + redirectCountMetric.With(prometheus.Labels{ "redirect_code": fmt.Sprintf("%d", handler.code), "redirect_type": redirectHandlerType, }).Inc() @@ -82,7 +82,7 @@ func (handler *pathPreservingRedirectHandler) ServeHTTP(writer http.ResponseWrit addCacheHeaders(writer) http.Redirect(writer, request, target, handler.code) - RedirectHandlerRedirectCountMetric.With(prometheus.Labels{ + redirectCountMetric.With(prometheus.Labels{ "redirect_code": fmt.Sprintf("%d", handler.code), "redirect_type": pathPreservingRedirectHandlerType, }).Inc() diff --git a/handlers/redirect_handler_test.go b/handlers/redirect_handler_test.go index f6b48412..7a6c6dc1 100644 --- a/handlers/redirect_handler_test.go +++ b/handlers/redirect_handler_test.go @@ -1,4 +1,4 @@ -package handlers_test +package handlers import ( "fmt" @@ -11,8 +11,6 @@ import ( "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" - - "github.com/alphagov/router/handlers" ) var _ = Describe("A redirect handler", func() { @@ -29,7 +27,7 @@ var _ = Describe("A redirect handler", func() { for _, temporary := range []bool{true, false} { Context(fmt.Sprintf("where preserve=%t, temporary=%t", preserve, temporary), func() { BeforeEach(func() { - handler = handlers.NewRedirectHandler("/source", "/target", preserve, temporary) + handler = NewRedirectHandler("/source", "/target", preserve, temporary) handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil)) }) @@ -53,7 +51,7 @@ var _ = Describe("A redirect handler", func() { Context("where preserve=true", func() { BeforeEach(func() { - handler = handlers.NewRedirectHandler("/source", "/target", true, false) + handler = NewRedirectHandler("/source", "/target", true, false) handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil)) }) @@ -64,7 +62,7 @@ var _ = Describe("A redirect handler", func() { Context("where preserve=false", func() { BeforeEach(func() { - handler = handlers.NewRedirectHandler("/source", "/target", false, false) + handler = NewRedirectHandler("/source", "/target", false, false) }) It("returns only the configured path in the location header", func() { @@ -86,7 +84,7 @@ var _ = Describe("A redirect handler", func() { Entry(nil, true, false, http.StatusMovedPermanently), Entry(nil, true, true, http.StatusFound), func(preserve, temporary bool, expectedStatus int) { - handler = handlers.NewRedirectHandler("/source", "/target", preserve, temporary) + handler = NewRedirectHandler("/source", "/target", preserve, temporary) handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil)) Expect(rr.Result().StatusCode).To(Equal(expectedStatus)) }) @@ -99,12 +97,12 @@ var _ = Describe("A redirect handler", func() { Entry(nil, true, true, "302", "path-preserving-redirect-handler"), func(preserve, temporary bool, codeLabel, typeLabel string) { lbls := prometheus.Labels{"redirect_code": codeLabel, "redirect_type": typeLabel} - before := promtest.ToFloat64(handlers.RedirectHandlerRedirectCountMetric.With(lbls)) + before := promtest.ToFloat64(redirectCountMetric.With(lbls)) - handler = handlers.NewRedirectHandler("/source", "/target", preserve, temporary) + handler = NewRedirectHandler("/source", "/target", preserve, temporary) handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil)) - after := promtest.ToFloat64(handlers.RedirectHandlerRedirectCountMetric.With(lbls)) + after := promtest.ToFloat64(redirectCountMetric.With(lbls)) Expect(after - before).To(BeNumerically("~", 1.0)) }, ) diff --git a/integration_tests/disabled_routes_test.go b/integration_tests/disabled_routes_test.go index 714df121..7c1ae267 100644 --- a/integration_tests/disabled_routes_test.go +++ b/integration_tests/disabled_routes_test.go @@ -11,16 +11,16 @@ var _ = Describe("marking routes as disabled", func() { BeforeEach(func() { addRoute("/unavailable", Route{Handler: "gone", Disabled: true}) addRoute("/something-live", NewRedirectRoute("/somewhere-else")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should return a 503 to the client", func() { - resp := routerRequest("/unavailable") + resp := routerRequest(routerPort, "/unavailable") Expect(resp.StatusCode).To(Equal(503)) }) It("should continue to route other requests", func() { - resp := routerRequest("/something-live") + resp := routerRequest(routerPort, "/something-live") Expect(resp.StatusCode).To(Equal(301)) Expect(resp.Header.Get("Location")).To(Equal("/somewhere-else")) }) diff --git a/integration_tests/envmap.go b/integration_tests/envmap.go deleted file mode 100644 index 42d49fb0..00000000 --- a/integration_tests/envmap.go +++ /dev/null @@ -1,24 +0,0 @@ -package integration - -import ( - "strings" -) - -type envMap map[string]string - -func newEnvMap(env []string) (em envMap) { - em = make(map[string]string, len(env)) - for _, item := range env { - parts := strings.SplitN(item, "=", 2) - em[parts[0]] = parts[1] - } - return -} - -func (em envMap) ToEnv() (env []string) { - env = make([]string, 0, len(em)) - for k, v := range em { - env = append(env, k+"="+v) - } - return -} diff --git a/integration_tests/error_handling_test.go b/integration_tests/error_handling_test.go index 2241bf19..1c8cf534 100644 --- a/integration_tests/error_handling_test.go +++ b/integration_tests/error_handling_test.go @@ -11,14 +11,14 @@ var _ = Describe("error handling", func() { Describe("handling an empty routing table", func() { BeforeEach(func() { - reloadRoutes() + reloadRoutes(apiPort) }) It("should return a 503 error to the client", func() { - resp := routerRequest("/") + resp := routerRequest(routerPort, "/") Expect(resp.StatusCode).To(Equal(503)) - resp = routerRequest("/foo") + resp = routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(503)) }) }) @@ -26,16 +26,16 @@ var _ = Describe("error handling", func() { Describe("handling a panic", func() { BeforeEach(func() { addRoute("/boom", Route{Handler: "boom"}) - reloadRoutes() + reloadRoutes(apiPort) }) It("should return a 500 error to the client", func() { - resp := routerRequest("/boom") + resp := routerRequest(routerPort, "/boom") Expect(resp.StatusCode).To(Equal(500)) }) It("should log the fact", func() { - routerRequest("/boom") + routerRequest(routerPort, "/boom") logDetails := lastRouterErrorLogEntry() Expect(logDetails.Fields).To(Equal(map[string]interface{}{ diff --git a/integration_tests/gone_test.go b/integration_tests/gone_test.go index 18479bfe..990606cd 100644 --- a/integration_tests/gone_test.go +++ b/integration_tests/gone_test.go @@ -10,25 +10,25 @@ var _ = Describe("Gone routes", func() { BeforeEach(func() { addRoute("/foo", NewGoneRoute()) addRoute("/bar", NewGoneRoute("prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should support an exact gone route", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(410)) Expect(readBody(resp)).To(Equal("410 Gone\n")) - resp = routerRequest("/foo/bar") + resp = routerRequest(routerPort, "/foo/bar") Expect(resp.StatusCode).To(Equal(404)) Expect(readBody(resp)).To(Equal("404 page not found\n")) }) It("should support a prefix gone route", func() { - resp := routerRequest("/bar") + resp := routerRequest(routerPort, "/bar") Expect(resp.StatusCode).To(Equal(410)) Expect(readBody(resp)).To(Equal("410 Gone\n")) - resp = routerRequest("/bar/baz") + resp = routerRequest(routerPort, "/bar/baz") Expect(resp.StatusCode).To(Equal(410)) Expect(readBody(resp)).To(Equal("410 Gone\n")) }) diff --git a/integration_tests/http_request_helpers.go b/integration_tests/http_request_helpers.go index a1eead92..9822f2fa 100644 --- a/integration_tests/http_request_helpers.go +++ b/integration_tests/http_request_helpers.go @@ -15,12 +15,12 @@ import ( // revive:enable:dot-imports ) -func routerRequest(path string, optionalPort ...int) *http.Response { - return doRequest(newRequest("GET", routerURL(path, optionalPort...))) +func routerRequest(port int, path string) *http.Response { + return doRequest(newRequest("GET", routerURL(port, path))) } -func routerRequestWithHeaders(path string, headers map[string]string, optionalPort ...int) *http.Response { - return doRequest(newRequestWithHeaders("GET", routerURL(path, optionalPort...), headers)) +func routerRequestWithHeaders(port int, path string, headers map[string]string) *http.Response { + return doRequest(newRequestWithHeaders("GET", routerURL(port, path), headers)) } func newRequest(method, url string) *http.Request { diff --git a/integration_tests/integration_test.go b/integration_tests/integration_test.go index d92a094b..a5fcbd34 100644 --- a/integration_tests/integration_test.go +++ b/integration_tests/integration_test.go @@ -20,7 +20,7 @@ var _ = BeforeSuite(func() { if err != nil { Fail(err.Error()) } - err = startRouter(3169, 3168) + err = startRouter(routerPort, apiPort, nil) if err != nil { Fail(err.Error()) } @@ -35,6 +35,6 @@ var _ = BeforeEach(func() { }) var _ = AfterSuite(func() { - stopRouter(3169) + stopRouter(routerPort) cleanupTempLogfile() }) diff --git a/integration_tests/metrics_test.go b/integration_tests/metrics_test.go index 7b4bfdc6..ba4812e9 100644 --- a/integration_tests/metrics_test.go +++ b/integration_tests/metrics_test.go @@ -10,7 +10,7 @@ var _ = Describe("/metrics API endpoint", func() { var responseBody string BeforeEach(func() { - resp := doRequest(newRequest("GET", routerAPIURL("/metrics"))) + resp := doRequest(newRequest("GET", routerURL(apiPort, "/metrics"))) Expect(resp.StatusCode).To(Equal(200)) responseBody = readBody(resp) }) diff --git a/integration_tests/performance_test.go b/integration_tests/performance_test.go index f01c4248..08be778e 100644 --- a/integration_tests/performance_test.go +++ b/integration_tests/performance_test.go @@ -26,7 +26,7 @@ var _ = Describe("Performance", func() { addBackend("backend-2", backend2.URL) addRoute("/one", NewBackendRoute("backend-1")) addRoute("/two", NewBackendRoute("backend-2")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { backend1.Close() @@ -48,7 +48,7 @@ var _ = Describe("Performance", func() { case <-stopCh: return case <-ticker.C: - reloadRoutes() + reloadRoutes(apiPort) } }() @@ -62,9 +62,9 @@ var _ = Describe("Performance", func() { defer slowBackend.Close() addBackend("backend-slow", slowBackend.URL) addRoute("/slow", NewBackendRoute("backend-slow")) - reloadRoutes() + reloadRoutes(apiPort) - _, gen := generateLoad([]string{routerURL("/slow")}, 50) + _, gen := generateLoad([]string{routerURL(routerPort, "/slow")}, 50) defer gen.Stop() assertPerformantRouter(backend1, backend2, 50) @@ -75,9 +75,9 @@ var _ = Describe("Performance", func() { It("Router should not cause errors or much latency", func() { addBackend("backend-down", "http://127.0.0.1:3162/") addRoute("/down", NewBackendRoute("backend-down")) - reloadRoutes() + reloadRoutes(apiPort) - _, gen := generateLoad([]string{routerURL("/down")}, 50) + _, gen := generateLoad([]string{routerURL(routerPort, "/down")}, 50) defer gen.Stop() assertPerformantRouter(backend1, backend2, 50) @@ -102,7 +102,7 @@ var _ = Describe("Performance", func() { addBackend("backend-2", backend2.URL) addRoute("/one", NewBackendRoute("backend-1")) addRoute("/two", NewBackendRoute("backend-2")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { backend1.Close() @@ -117,7 +117,7 @@ var _ = Describe("Performance", func() { func assertPerformantRouter(backend1, backend2 *httptest.Server, rps int) { directResultsCh, _ := generateLoad([]string{backend1.URL + "/one", backend2.URL + "/two"}, rps) - routerResultsCh, _ := generateLoad([]string{routerURL("/one"), routerURL("/two")}, rps) + routerResultsCh, _ := generateLoad([]string{routerURL(routerPort, "/one"), routerURL(routerPort, "/two")}, rps) directResults := <-directResultsCh routerResults := <-routerResultsCh diff --git a/integration_tests/proxy_function_test.go b/integration_tests/proxy_function_test.go index bd4500f3..d7c93fc5 100644 --- a/integration_tests/proxy_function_test.go +++ b/integration_tests/proxy_function_test.go @@ -20,9 +20,9 @@ var _ = Describe("Functioning as a reverse proxy", func() { It("should return a 502 if the connection to the backend is refused", func() { addBackend("not-running", "http://127.0.0.1:3164/") addRoute("/not-running", NewBackendRoute("not-running")) - reloadRoutes() + reloadRoutes(apiPort) - req, err := http.NewRequest(http.MethodGet, routerURL("/not-running"), nil) + req, err := http.NewRequest(http.MethodGet, routerURL(routerPort, "/not-running"), nil) Expect(err).NotTo(HaveOccurred()) req.Header.Set("X-Varnish", "12345678") @@ -42,7 +42,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should log and return a 504 if the connection times out in the configured time", func() { - err := startRouter(3167, 3166, envMap{"ROUTER_BACKEND_CONNECT_TIMEOUT": "0.3s"}) + err := startRouter(3167, 3166, []string{"ROUTER_BACKEND_CONNECT_TIMEOUT=0.3s"}) Expect(err).NotTo(HaveOccurred()) defer stopRouter(3167) @@ -50,7 +50,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { addRoute("/should-time-out", NewBackendRoute("black-hole")) reloadRoutes(3166) - req, err := http.NewRequest(http.MethodGet, routerURL("/should-time-out", 3167), nil) + req, err := http.NewRequest(http.MethodGet, routerURL(3167, "/should-time-out"), nil) Expect(err).NotTo(HaveOccurred()) req.Header.Set("X-Varnish", "12345678") @@ -80,7 +80,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { ) BeforeEach(func() { - err := startRouter(3167, 3166, envMap{"ROUTER_BACKEND_HEADER_TIMEOUT": "0.3s"}) + err := startRouter(3167, 3166, []string{"ROUTER_BACKEND_HEADER_TIMEOUT=0.3s"}) Expect(err).NotTo(HaveOccurred()) tarpit1 = startTarpitBackend(time.Second) tarpit2 = startTarpitBackend(100*time.Millisecond, 500*time.Millisecond) @@ -98,7 +98,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should log and return a 504 if a backend takes longer than the configured response timeout to start returning a response", func() { - req := newRequest(http.MethodGet, routerURL("/tarpit1", 3167)) + req := newRequest(http.MethodGet, routerURL(3167, "/tarpit1")) req.Header.Set("X-Varnish", "12341112") resp := doRequest(req) Expect(resp.StatusCode).To(Equal(504)) @@ -117,7 +117,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should still return the response if the body takes longer than the header timeout", func() { - resp := routerRequest("/tarpit2", 3167) + resp := routerRequest(3167, "/tarpit2") Expect(resp.StatusCode).To(Equal(200)) Expect(readBody(resp)).To(Equal("Tarpit\n")) }) @@ -135,7 +135,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { recorderURL, _ = url.Parse(recorder.URL()) addBackend("backend", recorder.URL()) addRoute("/foo", NewBackendRoute("backend", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { @@ -143,7 +143,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should pass through most http headers to the backend", func() { - resp := routerRequestWithHeaders("/foo", map[string]string{ + resp := routerRequestWithHeaders(routerPort, "/foo", map[string]string{ "Foo": "bar", "User-Agent": "Router test suite 2.7182", }) @@ -156,7 +156,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should set the Host header to the backend hostname", func() { - resp := routerRequestWithHeaders("/foo", map[string]string{ + resp := routerRequestWithHeaders(routerPort, "/foo", map[string]string{ "Host": "www.example.com", }) Expect(resp.StatusCode).To(Equal(200)) @@ -168,7 +168,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { It("should not add a default User-Agent if there isn't one in the request", func() { // Most http libraries add a default User-Agent header. - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) @@ -177,15 +177,15 @@ var _ = Describe("Functioning as a reverse proxy", func() { Expect(ok).To(BeFalse()) }) - It("should add the client IP to X-Forwardrd-For", func() { - resp := routerRequest("/foo") + It("should add the client IP to X-Forwarded-For", func() { + resp := routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] Expect(beReq.Header.Get("X-Forwarded-For")).To(Equal("127.0.0.1")) - resp = routerRequestWithHeaders("/foo", map[string]string{ + resp = routerRequestWithHeaders(routerPort, "/foo", map[string]string{ "X-Forwarded-For": "10.9.8.7", }) Expect(resp.StatusCode).To(Equal(200)) @@ -199,14 +199,14 @@ var _ = Describe("Functioning as a reverse proxy", func() { // See https://tools.ietf.org/html/rfc2616#section-14.45 It("should add itself to the Via request header for an HTTP/1.1 request", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] Expect(beReq.Header.Get("Via")).To(Equal("1.1 router")) - resp = routerRequestWithHeaders("/foo", map[string]string{ + resp = routerRequestWithHeaders(routerPort, "/foo", map[string]string{ "Via": "1.0 fred, 1.1 barney", }) Expect(resp.StatusCode).To(Equal(200)) @@ -217,7 +217,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should add itself to the Via request header for an HTTP/1.0 request", func() { - req := newRequest(http.MethodGet, routerURL("/foo")) + req := newRequest(http.MethodGet, routerURL(routerPort, "/foo")) resp := doHTTP10Request(req) Expect(resp.StatusCode).To(Equal(200)) @@ -225,7 +225,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { beReq := recorder.ReceivedRequests()[0] Expect(beReq.Header.Get("Via")).To(Equal("1.0 router")) - req = newRequestWithHeaders("GET", routerURL("/foo"), map[string]string{ + req = newRequestWithHeaders("GET", routerURL(routerPort, "/foo"), map[string]string{ "Via": "1.0 fred, 1.1 barney", }) resp = doHTTP10Request(req) @@ -237,14 +237,14 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should add itself to the Via response heaver", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(200)) Expect(resp.Header.Get("Via")).To(Equal("1.1 router")) recorder.AppendHandlers(ghttp.RespondWith(200, "body", http.Header{ "Via": []string{"1.0 fred, 1.1 barney"}, })) - resp = routerRequest("/foo") + resp = routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(200)) Expect(resp.Header.Get("Via")).To(Equal("1.0 fred, 1.1 barney, 1.1 router")) }) @@ -260,7 +260,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { recorder = startRecordingBackend() addBackend("backend", recorder.URL()) addRoute("/foo", NewBackendRoute("backend", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { @@ -273,11 +273,11 @@ var _ = Describe("Functioning as a reverse proxy", func() { ghttp.VerifyRequest("DELETE", "/foo/bar/baz.json"), ) - req := newRequest("POST", routerURL("/foo")) + req := newRequest("POST", routerURL(routerPort, "/foo")) resp := doRequest(req) Expect(resp.StatusCode).To(Equal(200)) - req = newRequest("DELETE", routerURL("/foo/bar/baz.json")) + req = newRequest("DELETE", routerURL(routerPort, "/foo/bar/baz.json")) resp = doRequest(req) Expect(resp.StatusCode).To(Equal(200)) @@ -288,7 +288,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { recorder.AppendHandlers( ghttp.VerifyRequest("GET", "/foo/bar", "baz=qux"), ) - resp := routerRequest("/foo/bar?baz=qux") + resp := routerRequest(routerPort, "/foo/bar?baz=qux") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) @@ -302,7 +302,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { Expect(string(body)).To(Equal("I am the request body. Woohoo!")) }) - req := newRequest("POST", routerURL("/foo")) + req := newRequest("POST", routerURL(routerPort, "/foo")) req.Body = io.NopCloser(strings.NewReader("I am the request body. Woohoo!")) resp := doRequest(req) Expect(resp.StatusCode).To(Equal(200)) @@ -320,7 +320,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { recorder = startRecordingBackend() addBackend("backend", recorder.URL()+"/something") addRoute("/foo/bar", NewBackendRoute("backend", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { @@ -328,7 +328,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should merge the 2 paths", func() { - resp := routerRequest("/foo/bar") + resp := routerRequest(routerPort, "/foo/bar") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) @@ -337,7 +337,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should preserve the request query string", func() { - resp := routerRequest("/foo/bar?baz=qux") + resp := routerRequest(routerPort, "/foo/bar?baz=qux") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) @@ -355,7 +355,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { recorder = startRecordingBackend() addBackend("backend", recorder.URL()) addRoute("/foo", NewBackendRoute("backend", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { @@ -363,7 +363,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should work with incoming HTTP/1.1 requests", func() { - req := newRequest("GET", routerURL("/foo")) + req := newRequest("GET", routerURL(routerPort, "/foo")) resp := doHTTP10Request(req) Expect(resp.StatusCode).To(Equal(200)) @@ -373,7 +373,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should proxy to the backend as HTTP/1.1 requests", func() { - req := newRequest("GET", routerURL("/foo")) + req := newRequest("GET", routerURL(routerPort, "/foo")) resp := doHTTP10Request(req) Expect(resp.StatusCode).To(Equal(200)) @@ -387,7 +387,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { var recorder *ghttp.Server BeforeEach(func() { - err := startRouter(3167, 3166, envMap{"ROUTER_TLS_SKIP_VERIFY": "1"}) + err := startRouter(3167, 3166, []string{"ROUTER_TLS_SKIP_VERIFY=1"}) Expect(err).NotTo(HaveOccurred()) recorder = startRecordingTLSBackend() addBackend("backend", recorder.URL()) @@ -401,7 +401,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should correctly reverse proxy to a HTTPS backend", func() { - req := newRequest("GET", routerURL("/foo", 3167)) + req := newRequest("GET", routerURL(3167, "/foo")) resp := doRequest(req) Expect(resp.StatusCode).To(Equal(200)) diff --git a/integration_tests/redirect_test.go b/integration_tests/redirect_test.go index 92e935f3..91474c21 100644 --- a/integration_tests/redirect_test.go +++ b/integration_tests/redirect_test.go @@ -16,46 +16,46 @@ var _ = Describe("Redirection", func() { addRoute("/query-temp", NewRedirectRoute("/bar?query=true", "exact")) addRoute("/fragment", NewRedirectRoute("/bar#section", "exact")) addRoute("/preserve-query", NewRedirectRoute("/qux", "exact", "permanent", "preserve")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should redirect permanently by default", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(301)) }) It("should redirect temporarily when asked to", func() { - resp := routerRequest("/foo-temp") + resp := routerRequest(routerPort, "/foo-temp") Expect(resp.StatusCode).To(Equal(302)) }) It("should contain the redirect location", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.Header.Get("Location")).To(Equal("/bar")) }) It("should not preserve the query string for the source by default", func() { - resp := routerRequest("/foo?baz=qux") + resp := routerRequest(routerPort, "/foo?baz=qux") Expect(resp.Header.Get("Location")).To(Equal("/bar")) }) It("should preserve the query string for the source if specified", func() { - resp := routerRequest("/preserve-query?foo=bar") + resp := routerRequest(routerPort, "/preserve-query?foo=bar") Expect(resp.Header.Get("Location")).To(Equal("/qux?foo=bar")) }) It("should preserve the query string for the target", func() { - resp := routerRequest("/query-temp") + resp := routerRequest(routerPort, "/query-temp") Expect(resp.Header.Get("Location")).To(Equal("/bar?query=true")) }) It("should preserve the fragment for the target", func() { - resp := routerRequest("/fragment") + resp := routerRequest(routerPort, "/fragment") Expect(resp.Header.Get("Location")).To(Equal("/bar#section")) }) It("should contain cache headers of 30 mins", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.Header.Get("Cache-Control")).To(Equal("max-age=1800, public")) Expect( @@ -73,43 +73,43 @@ var _ = Describe("Redirection", func() { addRoute("/foo", NewRedirectRoute("/bar", "prefix")) addRoute("/foo-temp", NewRedirectRoute("/bar-temp", "prefix", "temporary")) addRoute("/qux", NewRedirectRoute("/baz", "prefix", "temporary", "ignore")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should redirect permanently to the destination", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(301)) Expect(resp.Header.Get("Location")).To(Equal("/bar")) }) It("should redirect temporarily to the destination when asked to", func() { - resp := routerRequest("/foo-temp") + resp := routerRequest(routerPort, "/foo-temp") Expect(resp.StatusCode).To(Equal(302)) Expect(resp.Header.Get("Location")).To(Equal("/bar-temp")) }) It("should preserve extra path sections when redirecting by default", func() { - resp := routerRequest("/foo/baz") + resp := routerRequest(routerPort, "/foo/baz") Expect(resp.Header.Get("Location")).To(Equal("/bar/baz")) }) It("should ignore extra path sections when redirecting if specified", func() { - resp := routerRequest("/qux/quux") + resp := routerRequest(routerPort, "/qux/quux") Expect(resp.Header.Get("Location")).To(Equal("/baz")) }) It("should preserve the query string when redirecting by default", func() { - resp := routerRequest("/foo?baz=qux") + resp := routerRequest(routerPort, "/foo?baz=qux") Expect(resp.Header.Get("Location")).To(Equal("/bar?baz=qux")) }) It("should not preserve the query string when redirecting if specified", func() { - resp := routerRequest("/qux/quux?foo=bar") + resp := routerRequest(routerPort, "/qux/quux?foo=bar") Expect(resp.Header.Get("Location")).To(Equal("/baz")) }) It("should contain cache headers of 30 mins", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.Header.Get("Cache-Control")).To(Equal("max-age=1800, public")) Expect( @@ -123,9 +123,9 @@ var _ = Describe("Redirection", func() { It("should handle path-preserving redirects with special characters", func() { addRoute("/foo%20bar", NewRedirectRoute("/bar%20baz", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) - resp := routerRequest("/foo bar/something") + resp := routerRequest(routerPort, "/foo bar/something") Expect(resp.StatusCode).To(Equal(301)) Expect(resp.Header.Get("Location")).To(Equal("/bar%20baz/something")) }) @@ -137,44 +137,44 @@ var _ = Describe("Redirection", func() { addRoute("/baz", NewRedirectRoute("http://foo.example.com/baz", "exact", "permanent", "preserve")) addRoute("/bar", NewRedirectRoute("http://bar.example.com/bar", "prefix")) addRoute("/qux", NewRedirectRoute("http://bar.example.com/qux", "prefix", "permanent", "ignore")) - reloadRoutes() + reloadRoutes(apiPort) }) Describe("exact redirect", func() { It("should redirect to the external URL", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.Header.Get("Location")).To(Equal("http://foo.example.com/foo")) }) It("should not preserve the query string by default", func() { - resp := routerRequest("/foo?foo=qux") + resp := routerRequest(routerPort, "/foo?foo=qux") Expect(resp.Header.Get("Location")).To(Equal("http://foo.example.com/foo")) }) It("should preserve the query string if specified", func() { - resp := routerRequest("/baz?foo=qux") + resp := routerRequest(routerPort, "/baz?foo=qux") Expect(resp.Header.Get("Location")).To(Equal("http://foo.example.com/baz?foo=qux")) }) }) Describe("prefix redirect", func() { It("should redirect to the external URL", func() { - resp := routerRequest("/bar") + resp := routerRequest(routerPort, "/bar") Expect(resp.Header.Get("Location")).To(Equal("http://bar.example.com/bar")) }) It("should preserve extra path sections when redirecting by default", func() { - resp := routerRequest("/bar/baz") + resp := routerRequest(routerPort, "/bar/baz") Expect(resp.Header.Get("Location")).To(Equal("http://bar.example.com/bar/baz")) }) It("should ignore extra path sections when redirecting if specified", func() { - resp := routerRequest("/qux/baz") + resp := routerRequest(routerPort, "/qux/baz") Expect(resp.Header.Get("Location")).To(Equal("http://bar.example.com/qux")) }) It("should preserve the query string when redirecting", func() { - resp := routerRequest("/bar?baz=qux") + resp := routerRequest(routerPort, "/bar?baz=qux") Expect(resp.Header.Get("Location")).To(Equal("http://bar.example.com/bar?baz=qux")) }) }) @@ -188,41 +188,41 @@ var _ = Describe("Redirection", func() { addRoute("/pay-tax", NewRedirectRoute("https://tax.service.gov.uk/pay", "exact", "permanent", "ignore")) addRoute("/biz-bank", NewRedirectRoute("https://british-business-bank.co.uk", "prefix", "permanent", "ignore")) addRoute("/query-paramed", NewRedirectRoute("https://param.servicegov.uk?included-param=true", "exact", "permanent", "ignore")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should only preserve the _ga parameter when redirecting to service URLs that want to ignore query params", func() { - resp := routerRequest("/foo?_ga=identifier&blah=xyz") + resp := routerRequest(routerPort, "/foo?_ga=identifier&blah=xyz") Expect(resp.Header.Get("Location")).To(Equal("https://hmrc.service.gov.uk/pay?_ga=identifier")) }) It("should retain all params when redirecting to a route that wants them", func() { - resp := routerRequest("/bar?wanted=param&_ga=xyz&blah=xyz") + resp := routerRequest(routerPort, "/bar?wanted=param&_ga=xyz&blah=xyz") Expect(resp.Header.Get("Location")).To(Equal("https://bar.service.gov.uk/bar?wanted=param&_ga=xyz&blah=xyz")) }) It("should preserve the _ga parameter when redirecting to gov.uk URLs", func() { - resp := routerRequest("/baz?_ga=identifier") + resp := routerRequest(routerPort, "/baz?_ga=identifier") Expect(resp.Header.Get("Location")).To(Equal("https://gov.uk/baz-luhrmann?_ga=identifier")) }) It("should preserve the _ga parameter when redirecting to service.gov.uk URLs", func() { - resp := routerRequest("/pay-tax?_ga=12345") + resp := routerRequest(routerPort, "/pay-tax?_ga=12345") Expect(resp.Header.Get("Location")).To(Equal("https://tax.service.gov.uk/pay?_ga=12345")) }) It("should preserve only the first _ga parameter", func() { - resp := routerRequest("/pay-tax/?_ga=12345&_ga=6789") + resp := routerRequest(routerPort, "/pay-tax/?_ga=12345&_ga=6789") Expect(resp.Header.Get("Location")).To(Equal("https://tax.service.gov.uk/pay?_ga=12345")) }) It("should preserve the _ga param when redirecting to british business bank", func() { - resp := routerRequest("/biz-bank?unwanted=param&_ga=12345") + resp := routerRequest(routerPort, "/biz-bank?unwanted=param&_ga=12345") Expect(resp.Header.Get("Location")).To(Equal("https://british-business-bank.co.uk?_ga=12345")) }) It("should preserve the _ga param and any existing query string that the target URL has", func() { - resp := routerRequest("/query-paramed?unwanted_param=blah&_ga=12345") + resp := routerRequest(routerPort, "/query-paramed?unwanted_param=blah&_ga=12345") // https://param.servicegov.uk?included-param=true?unwanted_param=blah&_ga=12345 Expect(resp.Header.Get("Location")).To(Equal("https://param.servicegov.uk?_ga=12345&included-param=true")) }) diff --git a/integration_tests/reload_api_test.go b/integration_tests/reload_api_test.go index 47dd326b..0b52cd17 100644 --- a/integration_tests/reload_api_test.go +++ b/integration_tests/reload_api_test.go @@ -11,102 +11,51 @@ var _ = Describe("reload API endpoint", func() { Describe("request handling", func() { It("should return 202 for POST /reload", func() { - resp := doRequest(newRequest("POST", routerAPIURL("/reload"))) + resp := doRequest(newRequest("POST", routerURL(apiPort, "/reload"))) Expect(resp.StatusCode).To(Equal(202)) Expect(readBody(resp)).To(Equal("Reload queued")) }) It("should return 404 for POST /foo", func() { - resp := doRequest(newRequest("POST", routerAPIURL("/foo"))) + resp := doRequest(newRequest("POST", routerURL(apiPort, "/foo"))) Expect(resp.StatusCode).To(Equal(404)) }) It("should return 404 for POST /reload/foo", func() { - resp := doRequest(newRequest("POST", routerAPIURL("/reload/foo"))) + resp := doRequest(newRequest("POST", routerURL(apiPort, "/reload/foo"))) Expect(resp.StatusCode).To(Equal(404)) }) It("should return 405 for GET /reload", func() { - resp := doRequest(newRequest("GET", routerAPIURL("/reload"))) + resp := doRequest(newRequest("GET", routerURL(apiPort, "/reload"))) Expect(resp.StatusCode).To(Equal(405)) Expect(resp.Header.Get("Allow")).To(Equal("POST")) }) It("eventually reloads the routes", func() { addRoute("/foo", NewRedirectRoute("/qux", "prefix")) - - start := time.Now() - doRequest(newRequest("POST", routerAPIURL("/reload"))) - end := time.Now() - duration := end.Sub(start) - - Expect(duration.Nanoseconds()).To(BeNumerically("<", 5000000)) - addRoute("/bar", NewRedirectRoute("/qux", "prefix")) - doRequest(newRequest("POST", routerAPIURL("/reload"))) + doRequest(newRequest("POST", routerURL(apiPort, "/reload"))) Eventually(func() int { - return routerRequest("/foo").StatusCode - }, time.Second*1).Should(Equal(301)) + return routerRequest(routerPort, "/foo").StatusCode + }, time.Second*3).Should(Equal(301)) Eventually(func() int { - return routerRequest("/bar").StatusCode - }, time.Second*1).Should(Equal(301)) + return routerRequest(routerPort, "/bar").StatusCode + }, time.Second*3).Should(Equal(301)) }) }) Describe("healthcheck", func() { - It("should return 200 and sting 'OK' on /healthcheck", func() { - resp := doRequest(newRequest("GET", routerAPIURL("/healthcheck"))) + It("should return HTTP 200 OK on GET", func() { + resp := doRequest(newRequest("GET", routerURL(apiPort, "/healthcheck"))) Expect(resp.StatusCode).To(Equal(200)) Expect(readBody(resp)).To(Equal("OK")) }) - It("should return 405 for other verbs", func() { - resp := doRequest(newRequest("POST", routerAPIURL("/healthcheck"))) - Expect(resp.StatusCode).To(Equal(405)) - Expect(resp.Header.Get("Allow")).To(Equal("GET")) - }) - }) - - Describe("route stats", func() { - - Context("with some routes loaded", func() { - var data map[string]map[string]interface{} - - BeforeEach(func() { - addRoute("/foo", NewRedirectRoute("/bar", "prefix")) - addRoute("/baz", NewRedirectRoute("/qux", "prefix")) - addRoute("/foo", NewRedirectRoute("/bar/baz")) - reloadRoutes() - resp := doRequest(newRequest("GET", routerAPIURL("/stats"))) - Expect(resp.StatusCode).To(Equal(200)) - readJSONBody(resp, &data) - }) - - It("should return the number of routes loaded", func() { - Expect(data["routes"]["count"]).To(BeEquivalentTo(3)) - }) - }) - - Context("with no routes", func() { - var data map[string]map[string]interface{} - - BeforeEach(func() { - reloadRoutes() - - resp := doRequest(newRequest("GET", routerAPIURL("/stats"))) - Expect(resp.StatusCode).To(Equal(200)) - readJSONBody(resp, &data) - }) - - It("should return the number of routes loaded", func() { - Expect(data["routes"]["count"]).To(BeEquivalentTo(0)) - }) - }) - - It("should return 405 for other verbs", func() { - resp := doRequest(newRequest("POST", routerAPIURL("/stats"))) + It("should return HTTP 405 Method Not Allowed on POST", func() { + resp := doRequest(newRequest("POST", routerURL(apiPort, "/healthcheck"))) Expect(resp.StatusCode).To(Equal(405)) Expect(resp.Header.Get("Allow")).To(Equal("GET")) }) @@ -117,9 +66,9 @@ var _ = Describe("reload API endpoint", func() { addRoute("/foo", NewRedirectRoute("/bar", "prefix")) addRoute("/baz", NewRedirectRoute("/qux", "prefix")) addRoute("/foo", NewRedirectRoute("/bar/baz")) - reloadRoutes() + reloadRoutes(apiPort) - resp := doRequest(newRequest("GET", routerAPIURL("/memory-stats"))) + resp := doRequest(newRequest("GET", routerURL(apiPort, "/memory-stats"))) Expect(resp.StatusCode).To(Equal(200)) var data map[string]interface{} diff --git a/integration_tests/route_loading_test.go b/integration_tests/route_loading_test.go index 38f491a8..827c1deb 100644 --- a/integration_tests/route_loading_test.go +++ b/integration_tests/route_loading_test.go @@ -1,6 +1,7 @@ package integration import ( + "fmt" "net/http/httptest" . "github.com/onsi/ginkgo/v2" @@ -30,19 +31,19 @@ var _ = Describe("loading routes from the db", func() { addRoute("/foo", NewBackendRoute("backend-1")) addRoute("/bar", Route{Handler: "fooey"}) addRoute("/baz", NewBackendRoute("backend-2")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should skip the invalid route", func() { - resp := routerRequest("/bar") + resp := routerRequest(routerPort, "/bar") Expect(resp.StatusCode).To(Equal(404)) }) It("should continue to load other routes", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("backend 1")) - resp = routerRequest("/baz") + resp = routerRequest(routerPort, "/baz") Expect(readBody(resp)).To(Equal("backend 2")) }) }) @@ -53,22 +54,22 @@ var _ = Describe("loading routes from the db", func() { addRoute("/bar", NewBackendRoute("backend-non-existent")) addRoute("/baz", NewBackendRoute("backend-2")) addRoute("/qux", NewBackendRoute("backend-1")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should skip the invalid route", func() { - resp := routerRequest("/bar") + resp := routerRequest(routerPort, "/bar") Expect(resp.StatusCode).To(Equal(404)) }) It("should continue to load other routes", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("backend 1")) - resp = routerRequest("/baz") + resp = routerRequest(routerPort, "/baz") Expect(readBody(resp)).To(Equal("backend 2")) - resp = routerRequest("/qux") + resp = routerRequest(routerPort, "/qux") Expect(readBody(resp)).To(Equal("backend 1")) }) }) @@ -81,23 +82,23 @@ var _ = Describe("loading routes from the db", func() { backend3 = startSimpleBackend("backend 3") addBackend("backend-3", blackHole) - stopRouter(3169) - err := startRouter(3169, 3168, envMap{"BACKEND_URL_backend-3": backend3.URL}) + stopRouter(routerPort) + err := startRouter(routerPort, apiPort, []string{fmt.Sprintf("BACKEND_URL_backend-3=%s", backend3.URL)}) Expect(err).NotTo(HaveOccurred()) addRoute("/oof", NewBackendRoute("backend-3")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { - stopRouter(3169) - err := startRouter(3169, 3168) + stopRouter(routerPort) + err := startRouter(routerPort, apiPort, nil) Expect(err).NotTo(HaveOccurred()) backend3.Close() }) It("should send requests to the backend_url provided in the env var", func() { - resp := routerRequest("/oof") + resp := routerRequest(routerPort, "/oof") Expect(resp.StatusCode).To(Equal(200)) Expect(readBody(resp)).To(Equal("backend 3")) }) diff --git a/integration_tests/route_selection_test.go b/integration_tests/route_selection_test.go index c3c089ef..4b60e44b 100644 --- a/integration_tests/route_selection_test.go +++ b/integration_tests/route_selection_test.go @@ -24,7 +24,7 @@ var _ = Describe("Route selection", func() { addRoute("/foo", NewBackendRoute("backend-1")) addRoute("/bar", NewBackendRoute("backend-2")) addRoute("/baz", NewBackendRoute("backend-1")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { backend1.Close() @@ -32,29 +32,29 @@ var _ = Describe("Route selection", func() { }) It("should route a matching request to the corresponding backend", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("backend 1")) - resp = routerRequest("/bar") + resp = routerRequest(routerPort, "/bar") Expect(readBody(resp)).To(Equal("backend 2")) - resp = routerRequest("/baz") + resp = routerRequest(routerPort, "/baz") Expect(readBody(resp)).To(Equal("backend 1")) }) It("should 404 for children of the exact route", func() { - resp := routerRequest("/foo/bar") + resp := routerRequest(routerPort, "/foo/bar") Expect(resp.StatusCode).To(Equal(404)) }) It("should 404 for non-matching requests", func() { - resp := routerRequest("/wibble") + resp := routerRequest(routerPort, "/wibble") Expect(resp.StatusCode).To(Equal(404)) - resp = routerRequest("/") + resp = routerRequest(routerPort, "/") Expect(resp.StatusCode).To(Equal(404)) - resp = routerRequest("/foo.json") + resp = routerRequest(routerPort, "/foo.json") Expect(resp.StatusCode).To(Equal(404)) }) }) @@ -73,7 +73,7 @@ var _ = Describe("Route selection", func() { addRoute("/foo", NewBackendRoute("backend-1", "prefix")) addRoute("/bar", NewBackendRoute("backend-2", "prefix")) addRoute("/baz", NewBackendRoute("backend-1", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { backend1.Close() @@ -81,35 +81,35 @@ var _ = Describe("Route selection", func() { }) It("should route requests for the prefix to the backend", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("backend 1")) - resp = routerRequest("/bar") + resp = routerRequest(routerPort, "/bar") Expect(readBody(resp)).To(Equal("backend 2")) - resp = routerRequest("/baz") + resp = routerRequest(routerPort, "/baz") Expect(readBody(resp)).To(Equal("backend 1")) }) It("should route requests for the children of the prefix to the backend", func() { - resp := routerRequest("/foo/bar") + resp := routerRequest(routerPort, "/foo/bar") Expect(readBody(resp)).To(Equal("backend 1")) - resp = routerRequest("/bar/foo.json") + resp = routerRequest(routerPort, "/bar/foo.json") Expect(readBody(resp)).To(Equal("backend 2")) - resp = routerRequest("/baz/fooey/kablooie") + resp = routerRequest(routerPort, "/baz/fooey/kablooie") Expect(readBody(resp)).To(Equal("backend 1")) }) It("should 404 for non-matching requests", func() { - resp := routerRequest("/wibble") + resp := routerRequest(routerPort, "/wibble") Expect(resp.StatusCode).To(Equal(404)) - resp = routerRequest("/") + resp = routerRequest(routerPort, "/") Expect(resp.StatusCode).To(Equal(404)) - resp = routerRequest("/foo.json") + resp = routerRequest(routerPort, "/foo.json") Expect(resp.StatusCode).To(Equal(404)) }) }) @@ -126,7 +126,7 @@ var _ = Describe("Route selection", func() { addBackend("outer-backend", outer.URL) addBackend("inner-backend", inner.URL) addRoute("/foo", NewBackendRoute("outer-backend", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { outer.Close() @@ -136,21 +136,21 @@ var _ = Describe("Route selection", func() { Describe("with an exact child", func() { BeforeEach(func() { addRoute("/foo/bar", NewBackendRoute("inner-backend")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should route the prefix to the outer backend", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("outer")) }) It("should route the exact child to the inner backend", func() { - resp := routerRequest("/foo/bar") + resp := routerRequest(routerPort, "/foo/bar") Expect(readBody(resp)).To(Equal("inner")) }) It("should route the children of the exact child to the outer backend", func() { - resp := routerRequest("/foo/bar/baz") + resp := routerRequest(routerPort, "/foo/bar/baz") Expect(readBody(resp)).To(Equal("outer")) }) }) @@ -158,29 +158,29 @@ var _ = Describe("Route selection", func() { Describe("with a prefix child", func() { BeforeEach(func() { addRoute("/foo/bar", NewBackendRoute("inner-backend", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should route the outer prefix to the outer backend", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("outer")) }) It("should route the inner prefix to the inner backend", func() { - resp := routerRequest("/foo/bar") + resp := routerRequest(routerPort, "/foo/bar") Expect(readBody(resp)).To(Equal("inner")) }) It("should route the children of the inner prefix to the inner backend", func() { - resp := routerRequest("/foo/bar/baz") + resp := routerRequest(routerPort, "/foo/bar/baz") Expect(readBody(resp)).To(Equal("inner")) }) It("should route other children of the outer prefix to the outer backend", func() { - resp := routerRequest("/foo/baz") + resp := routerRequest(routerPort, "/foo/baz") Expect(readBody(resp)).To(Equal("outer")) - resp = routerRequest("/foo/bar.json") + resp = routerRequest(routerPort, "/foo/bar.json") Expect(readBody(resp)).To(Equal("outer")) }) }) @@ -194,43 +194,43 @@ var _ = Describe("Route selection", func() { addBackend("innerer-backend", innerer.URL) addRoute("/foo/bar", NewBackendRoute("inner-backend")) addRoute("/foo/bar/baz", NewBackendRoute("innerer-backend", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { innerer.Close() }) It("should route the outer prefix to the outer backend", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("outer")) - resp = routerRequest("/foo/baz") + resp = routerRequest(routerPort, "/foo/baz") Expect(readBody(resp)).To(Equal("outer")) - resp = routerRequest("/foo/bar.json") + resp = routerRequest(routerPort, "/foo/bar.json") Expect(readBody(resp)).To(Equal("outer")) }) It("should route the exact route to the inner backend", func() { - resp := routerRequest("/foo/bar") + resp := routerRequest(routerPort, "/foo/bar") Expect(readBody(resp)).To(Equal("inner")) }) It("should route other children of the exact route to the outer backend", func() { - resp := routerRequest("/foo/bar/wibble") + resp := routerRequest(routerPort, "/foo/bar/wibble") Expect(readBody(resp)).To(Equal("outer")) - resp = routerRequest("/foo/bar/baz.json") + resp = routerRequest(routerPort, "/foo/bar/baz.json") Expect(readBody(resp)).To(Equal("outer")) }) It("should route the inner prefix route to the innerer backend", func() { - resp := routerRequest("/foo/bar/baz") + resp := routerRequest(routerPort, "/foo/bar/baz") Expect(readBody(resp)).To(Equal("innerer")) }) It("should route children of the inner prefix route to the innerer backend", func() { - resp := routerRequest("/foo/bar/baz/wibble") + resp := routerRequest(routerPort, "/foo/bar/baz/wibble") Expect(readBody(resp)).To(Equal("innerer")) }) }) @@ -249,7 +249,7 @@ var _ = Describe("Route selection", func() { addBackend("backend-2", backend2.URL) addRoute("/foo", NewBackendRoute("backend-1", "prefix")) addRoute("/foo", NewBackendRoute("backend-2")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { backend1.Close() @@ -257,12 +257,12 @@ var _ = Describe("Route selection", func() { }) It("should route the exact route to the exact backend", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("backend 2")) }) It("should route children of the route to the prefix backend", func() { - resp := routerRequest("/foo/bar") + resp := routerRequest(routerPort, "/foo/bar") Expect(readBody(resp)).To(Equal("backend 1")) }) }) @@ -287,29 +287,29 @@ var _ = Describe("Route selection", func() { It("should handle an exact route at the root level", func() { addRoute("/", NewBackendRoute("root")) - reloadRoutes() + reloadRoutes(apiPort) - resp := routerRequest("/") + resp := routerRequest(routerPort, "/") Expect(readBody(resp)).To(Equal("root backend")) - resp = routerRequest("/foo") + resp = routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("other backend")) - resp = routerRequest("/bar") + resp = routerRequest(routerPort, "/bar") Expect(resp.StatusCode).To(Equal(404)) }) It("should handle a prefix route at the root level", func() { addRoute("/", NewBackendRoute("root", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) - resp := routerRequest("/") + resp := routerRequest(routerPort, "/") Expect(readBody(resp)).To(Equal("root backend")) - resp = routerRequest("/foo") + resp = routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("other backend")) - resp = routerRequest("/bar") + resp = routerRequest(routerPort, "/bar") Expect(readBody(resp)).To(Equal("root backend")) }) }) @@ -327,7 +327,7 @@ var _ = Describe("Route selection", func() { addBackend("other", recorder.URL()) addRoute("/", NewBackendRoute("root", "prefix")) addRoute("/foo/bar", NewBackendRoute("other", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { root.Close() @@ -335,19 +335,19 @@ var _ = Describe("Route selection", func() { }) It("should not be redirected by our simple test backend", func() { - resp := routerRequest("//") + resp := routerRequest(routerPort, "//") Expect(readBody(resp)).To(Equal("fallthrough")) }) It("should not be redirected by our recorder backend", func() { - resp := routerRequest("/foo/bar/baz//qux") + resp := routerRequest(routerPort, "/foo/bar/baz//qux") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) Expect(recorder.ReceivedRequests()[0].URL.Path).To(Equal("/foo/bar/baz//qux")) }) It("should collapse double slashes when looking up route, but pass request as-is", func() { - resp := routerRequest("/foo//bar") + resp := routerRequest(routerPort, "/foo//bar") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) Expect(recorder.ReceivedRequests()[0].URL.Path).To(Equal("/foo//bar")) @@ -367,9 +367,9 @@ var _ = Describe("Route selection", func() { It("should handle spaces (%20) in paths", func() { addRoute("/foo%20bar", NewBackendRoute("backend")) - reloadRoutes() + reloadRoutes(apiPort) - resp := routerRequest("/foo bar") + resp := routerRequest(routerPort, "/foo bar") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) Expect(recorder.ReceivedRequests()[0].RequestURI).To(Equal("/foo%20bar")) diff --git a/integration_tests/router_support.go b/integration_tests/router_support.go index 74e3c36f..d03d526f 100644 --- a/integration_tests/router_support.go +++ b/integration_tests/router_support.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "os/exec" + "strconv" "syscall" "time" @@ -15,24 +16,16 @@ import ( // revive:enable:dot-imports ) -func routerURL(path string, optionalPort ...int) string { - port := 3169 - if len(optionalPort) > 0 { - port = optionalPort[0] - } - return fmt.Sprintf("http://127.0.0.1:%d%s", port, path) -} +const ( + routerPort = 3169 + apiPort = 3168 +) -func routerAPIURL(path string) string { - return routerURL(path, 3168) +func routerURL(port int, path string) string { + return fmt.Sprintf("http://127.0.0.1:%d%s", port, path) } -func reloadRoutes(optionalPort ...int) { - port := 3168 - if len(optionalPort) > 0 { - port = optionalPort[0] - } - +func reloadRoutes(port int) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, @@ -52,9 +45,10 @@ func reloadRoutes(optionalPort ...int) { var runningRouters = make(map[int]*exec.Cmd) -func startRouter(port, apiPort int, optionalExtraEnv ...envMap) error { - pubaddr := fmt.Sprintf(":%d", port) - apiaddr := fmt.Sprintf(":%d", apiPort) +func startRouter(port, apiPort int, extraEnv []string) error { + host := "localhost" + pubAddr := net.JoinHostPort(host, strconv.Itoa(port)) + apiAddr := net.JoinHostPort(host, strconv.Itoa(apiPort)) bin := os.Getenv("BINARY") if bin == "" { @@ -62,20 +56,13 @@ func startRouter(port, apiPort int, optionalExtraEnv ...envMap) error { } cmd := exec.Command(bin) - env := newEnvMap(os.Environ()) - env["ROUTER_PUBADDR"] = pubaddr - env["ROUTER_APIADDR"] = apiaddr - env["ROUTER_MONGO_DB"] = "router_test" - env["ROUTER_MONGO_POLL_INTERVAL"] = "2s" - env["ROUTER_ERROR_LOG"] = tempLogfile.Name() - if len(optionalExtraEnv) > 0 { - for k, v := range optionalExtraEnv[0] { - env[k] = v - } - } - cmd.Env = env.ToEnv() + cmd.Env = append(cmd.Environ(), "ROUTER_MONGO_DB=router_test") + cmd.Env = append(cmd.Env, fmt.Sprintf("ROUTER_PUBADDR=%s", pubAddr)) + cmd.Env = append(cmd.Env, fmt.Sprintf("ROUTER_APIADDR=%s", apiAddr)) + cmd.Env = append(cmd.Env, fmt.Sprintf("ROUTER_ERROR_LOG=%s", tempLogfile.Name())) + cmd.Env = append(cmd.Env, extraEnv...) - if os.Getenv("DEBUG_ROUTER") != "" { + if os.Getenv("ROUTER_DEBUG_TESTS") != "" { cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr } @@ -85,7 +72,7 @@ func startRouter(port, apiPort int, optionalExtraEnv ...envMap) error { return err } - waitForServerUp(pubaddr) + waitForServerUp(pubAddr) runningRouters[port] = cmd return nil diff --git a/lib/logcompat.go b/lib/logcompat.go new file mode 100644 index 00000000..e9a9c2f9 --- /dev/null +++ b/lib/logcompat.go @@ -0,0 +1,21 @@ +package router + +// TODO: remove this file and use rs/zerolog throughout. + +import "log" + +var EnableDebugOutput bool + +func logWarn(msg ...interface{}) { + log.Println(msg...) +} + +func logInfo(msg ...interface{}) { + log.Println(msg...) +} + +func logDebug(msg ...interface{}) { + if EnableDebugOutput { + log.Println(msg...) + } +} diff --git a/metrics.go b/lib/metrics.go similarity index 65% rename from metrics.go rename to lib/metrics.go index 653da981..36b0addc 100644 --- a/metrics.go +++ b/lib/metrics.go @@ -1,6 +1,8 @@ -package main +package router import ( + "github.com/alphagov/router/handlers" + "github.com/alphagov/router/triemux" "github.com/prometheus/client_golang/prometheus" ) @@ -8,7 +10,7 @@ var ( internalServerErrorCountMetric = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "router_internal_server_error_total", - Help: "Number of internal server errors encountered by router", + Help: "Number of 500 Internal Server Error responses originating from Router", }, []string{"host"}, ) @@ -35,11 +37,13 @@ var ( ) ) -func initMetrics() { - prometheus.MustRegister(internalServerErrorCountMetric) - - prometheus.MustRegister(routeReloadCountMetric) - prometheus.MustRegister(routeReloadErrorCountMetric) - - prometheus.MustRegister(routesCountMetric) +func registerMetrics(r prometheus.Registerer) { + r.MustRegister( + internalServerErrorCountMetric, + routeReloadCountMetric, + routeReloadErrorCountMetric, + routesCountMetric, + ) + handlers.RegisterMetrics(r) + triemux.RegisterMetrics(r) } diff --git a/router.go b/lib/router.go similarity index 82% rename from router.go rename to lib/router.go index bda7326d..03e56545 100644 --- a/router.go +++ b/lib/router.go @@ -1,4 +1,4 @@ -package main +package router import ( "fmt" @@ -29,17 +29,30 @@ const ( // Router is a wrapper around an HTTP multiplexer (trie.Mux) which retrieves its // routes from a passed mongo database. +// +// TODO: decouple Router from its database backend. Router should not know +// anything about the database backend. Its representation of the route table +// should be independent of the underlying DBMS. Route should define an +// abstract interface for some other module to be able to bulk-load and +// incrementally update routes. Since Router should not care where its routes +// come from, Route and Backend should not contain bson fields. +// MongoReplicaSet, MongoReplicaSetMember etc. should move out of this module. type Router struct { - mux *triemux.Mux - lock sync.RWMutex - mongoURL string - mongoDbName string - mongoPollInterval time.Duration - backendConnectTimeout time.Duration - backendHeaderTimeout time.Duration - mongoReadToOptime bson.MongoTimestamp - logger logger.Logger - ReloadChan chan bool + mux *triemux.Mux + lock sync.RWMutex + mongoReadToOptime bson.MongoTimestamp + logger logger.Logger + opts Options + ReloadChan chan bool +} + +type Options struct { + MongoURL string + MongoDBName string + MongoPollInterval time.Duration + BackendConnTimeout time.Duration + BackendHeaderTimeout time.Duration + LogFileName string } type Backend struct { @@ -69,36 +82,38 @@ type Route struct { Disabled bool `bson:"disabled"` } +// RegisterMetrics registers Prometheus metrics from the router module and the +// modules that it directly depends on. To use the default (global) registry, +// pass prometheus.DefaultRegisterer. +func RegisterMetrics(r prometheus.Registerer) { + registerMetrics(r) +} + // NewRouter returns a new empty router instance. You will need to call // SelfUpdateRoutes() to initialise the self-update process for routes. -func NewRouter(mongoURL, mongoDbName string, mongoPollInterval, beConnTimeout, beHeaderTimeout time.Duration, logFileName string) (rt *Router, err error) { - logInfo("router: using mongo poll interval:", mongoPollInterval) - logInfo("router: using backend connect timeout:", beConnTimeout) - logInfo("router: using backend header timeout:", beHeaderTimeout) +func NewRouter(o Options) (rt *Router, err error) { + logInfo("router: using mongo poll interval:", o.MongoPollInterval) + logInfo("router: using backend connect timeout:", o.BackendConnTimeout) + logInfo("router: using backend header timeout:", o.BackendHeaderTimeout) - l, err := logger.New(logFileName) + l, err := logger.New(o.LogFileName) if err != nil { return nil, err } + logInfo("router: logging errors as JSON to", o.LogFileName) mongoReadToOptime, err := bson.NewMongoTimestamp(time.Date(1970, time.January, 1, 0, 0, 0, 0, time.UTC), 1) if err != nil { return nil, err } - logInfo("router: logging errors as JSON to", logFileName) - reloadChan := make(chan bool, 1) rt = &Router{ - mux: triemux.NewMux(), - mongoURL: mongoURL, - mongoPollInterval: mongoPollInterval, - mongoDbName: mongoDbName, - backendConnectTimeout: beConnTimeout, - backendHeaderTimeout: beHeaderTimeout, - mongoReadToOptime: mongoReadToOptime, - logger: l, - ReloadChan: reloadChan, + mux: triemux.NewMux(), + mongoReadToOptime: mongoReadToOptime, + logger: l, + opts: o, + ReloadChan: reloadChan, } go rt.pollAndReload() @@ -135,9 +150,9 @@ func (rt *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { } func (rt *Router) SelfUpdateRoutes() { - logInfo(fmt.Sprintf("router: starting self-update process, polling for route changes every %v", rt.mongoPollInterval)) + logInfo(fmt.Sprintf("router: starting self-update process, polling for route changes every %v", rt.opts.MongoPollInterval)) - tick := time.Tick(rt.mongoPollInterval) + tick := time.Tick(rt.opts.MongoPollInterval) for range tick { logDebug("router: polling MongoDB for changes") @@ -157,9 +172,9 @@ func (rt *Router) pollAndReload() { } }() - logDebug("mgo: connecting to", rt.mongoURL) + logDebug("mgo: connecting to", rt.opts.MongoURL) - sess, err := mgo.Dial(rt.mongoURL) + sess, err := mgo.Dial(rt.opts.MongoURL) if err != nil { logWarn(fmt.Sprintf("mgo: error connecting to MongoDB, skipping update (error: %v)", err)) return @@ -182,7 +197,7 @@ func (rt *Router) pollAndReload() { if rt.shouldReload(currentMongoInstance) { logDebug("router: updates found") - rt.reloadRoutes(sess.DB(rt.mongoDbName), currentMongoInstance.Optime) + rt.reloadRoutes(sess.DB(rt.opts.MongoDBName), currentMongoInstance.Optime) } else { logDebug("router: no updates found") } @@ -220,14 +235,14 @@ func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimesta backends := rt.loadBackends(db.C("backends")) loadRoutes(db.C("routes"), newmux, backends) + routeCount := newmux.RouteCount() rt.lock.Lock() rt.mux = newmux rt.lock.Unlock() - logInfo(fmt.Sprintf("router: reloaded %d routes", rt.mux.RouteCount())) - - routesCountMetric.Set(float64(rt.mux.RouteCount())) + logInfo(fmt.Sprintf("router: reloaded %d routes", routeCount)) + routesCountMetric.Set(float64(routeCount)) } func (rt *Router) getCurrentMongoInstance(db mongoDatabase) (MongoReplicaSetMember, error) { @@ -288,7 +303,8 @@ func (rt *Router) loadBackends(c *mgo.Collection) (backends map[string]http.Hand backends[backend.BackendID] = handlers.NewBackendHandler( backend.BackendID, backendURL, - rt.backendConnectTimeout, rt.backendHeaderTimeout, + rt.opts.BackendConnTimeout, + rt.opts.BackendHeaderTimeout, rt.logger, ) } @@ -377,16 +393,6 @@ func (be *Backend) ParseURL() (*url.URL, error) { return url.Parse(backendURL) } -func (rt *Router) RouteStats() (stats map[string]interface{}) { - rt.lock.RLock() - mux := rt.mux - rt.lock.RUnlock() - - stats = make(map[string]interface{}) - stats["count"] = mux.RouteCount() - return -} - func shouldPreserveSegments(route *Route) bool { switch route.RouteType { case RouteTypeExact: diff --git a/router_api.go b/lib/router_api.go similarity index 71% rename from router_api.go rename to lib/router_api.go index 8652a3f4..ab0a4fbe 100644 --- a/router_api.go +++ b/lib/router_api.go @@ -1,15 +1,14 @@ -package main +package router import ( "encoding/json" - "fmt" "net/http" "runtime" "github.com/prometheus/client_golang/prometheus/promhttp" ) -func newAPIHandler(rout *Router) (api http.Handler, err error) { +func NewAPIHandler(rout *Router) (api http.Handler, err error) { mux := http.NewServeMux() mux.HandleFunc("/reload", func(w http.ResponseWriter, r *http.Request) { @@ -47,30 +46,6 @@ func newAPIHandler(rout *Router) (api http.Handler, err error) { } }) - mux.HandleFunc("/stats", func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - w.Header().Set("Allow", http.MethodGet) - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - - stats := make(map[string]map[string]interface{}) - stats["routes"] = rout.RouteStats() - - jsonData, err := json.MarshalIndent(stats, "", " ") - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - _, err = fmt.Fprintln(w, string(jsonData)) - if err != nil { - logWarn(err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - }) - mux.HandleFunc("/memory-stats", func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { w.Header().Set("Allow", http.MethodGet) diff --git a/router_test.go b/lib/router_test.go similarity index 99% rename from router_test.go rename to lib/router_test.go index eb69c7cf..5a11ce88 100644 --- a/router_test.go +++ b/lib/router_test.go @@ -1,4 +1,4 @@ -package main +package router import ( "errors" diff --git a/lib/version.go b/lib/version.go new file mode 100644 index 00000000..479dee03 --- /dev/null +++ b/lib/version.go @@ -0,0 +1,42 @@ +package router + +import ( + "fmt" + "runtime/debug" +) + +// VersionInfo returns human-readable version information in a format suitable +// for concatenation with other messages. +func VersionInfo() (v string) { + v = "(version info unavailable)" + + bi, ok := debug.ReadBuildInfo() + if !ok { + return + } + + rev, commitTime, dirty := buildSettings(bi.Settings) + if rev == "" { + return + } + + commitTimeOrDirty := "dirty" + if dirty == "false" { + commitTimeOrDirty = commitTime + } + return fmt.Sprintf("built from commit %.8s (%s) using %s", rev, commitTimeOrDirty, bi.GoVersion) +} + +func buildSettings(bs []debug.BuildSetting) (rev, commitTime, dirty string) { + for _, b := range bs { + switch b.Key { + case "vcs.modified": + dirty = b.Value + case "vcs.revision": + rev = b.Value + case "vcs.time": + commitTime = b.Value + } + } + return +} diff --git a/logger/logger.go b/logger/logger.go index 7bfb7790..74900405 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -58,7 +58,7 @@ func openWriter(output interface{}) (w io.Writer, err error) { } } default: - return nil, fmt.Errorf("Invalid output type %T(%v)", output, output) + return nil, fmt.Errorf("invalid output type %T(%v)", output, output) } return } diff --git a/main.go b/main.go index f14d4d8a..f2d878e3 100644 --- a/main.go +++ b/main.go @@ -10,21 +10,8 @@ import ( "time" "github.com/alphagov/router/handlers" -) - -var ( - pubAddr = getenvDefault("ROUTER_PUBADDR", ":8080") - apiAddr = getenvDefault("ROUTER_APIADDR", ":8081") - mongoURL = getenvDefault("ROUTER_MONGO_URL", "127.0.0.1") - mongoDbName = getenvDefault("ROUTER_MONGO_DB", "router") - mongoPollInterval = getenvDefault("ROUTER_MONGO_POLL_INTERVAL", "2s") - errorLogFile = getenvDefault("ROUTER_ERROR_LOG", "STDERR") - tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != "" - enableDebugOutput = os.Getenv("DEBUG") != "" - backendConnectTimeout = getenvDefault("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s") - backendHeaderTimeout = getenvDefault("ROUTER_BACKEND_HEADER_TIMEOUT", "20s") - frontendReadTimeout = getenvDefault("ROUTER_FRONTEND_READ_TIMEOUT", "60s") - frontendWriteTimeout = getenvDefault("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s") + router "github.com/alphagov/router/lib" + "github.com/prometheus/client_golang/prometheus" ) func usage() { @@ -40,40 +27,38 @@ ROUTER_MONGO_URL=127.0.0.1 Address of mongo cluster (e.g. 'mongo1,mongo2,m ROUTER_MONGO_DB=router Name of mongo database to use ROUTER_MONGO_POLL_INTERVAL=2s Interval to poll mongo for route changes ROUTER_ERROR_LOG=STDERR File to log errors to (in JSON format) -DEBUG= Whether to enable debug output - set to anything to enable +ROUTER_DEBUG= Enable debug output if non-empty -Timeouts: (values must be parseable by https://pkg.go.dev/time#ParseDuration +Timeouts: (values must be parseable by https://pkg.go.dev/time#ParseDuration) ROUTER_BACKEND_CONNECT_TIMEOUT=1s Connect timeout when connecting to backends ROUTER_BACKEND_HEADER_TIMEOUT=15s Timeout for backend response headers to be returned ROUTER_FRONTEND_READ_TIMEOUT=60s See https://cs.opensource.google/go/go/+/master:src/net/http/server.go?q=symbol:ReadTimeout ROUTER_FRONTEND_WRITE_TIMEOUT=60s See https://cs.opensource.google/go/go/+/master:src/net/http/server.go?q=symbol:WriteTimeout ` - fmt.Fprintf(os.Stderr, helpstring, versionInfo(), os.Args[0]) - os.Exit(2) + fmt.Fprintf(os.Stderr, helpstring, router.VersionInfo(), os.Args[0]) + const ErrUsage = 64 + os.Exit(ErrUsage) } -func getenvDefault(key string, defaultVal string) string { - val := os.Getenv(key) - if val == "" { - val = defaultVal +func getenv(key string, defaultVal string) string { + if s := os.Getenv(key); s != "" { + return s } - - return val -} - -func logWarn(msg ...interface{}) { - log.Println(msg...) + return defaultVal } -func logInfo(msg ...interface{}) { - log.Println(msg...) +func getenvDuration(key string, defaultVal string) time.Duration { + s := getenv(key, defaultVal) + return mustParseDuration(s) } -func logDebug(msg ...interface{}) { - if enableDebugOutput { - log.Println(msg...) +func mustParseDuration(s string) (d time.Duration) { + d, err := time.ParseDuration(s) + if err != nil { + log.Fatal(err) } + return } func listenAndServeOrFatal(addr string, handler http.Handler, rTimeout time.Duration, wTimeout time.Duration) { @@ -88,54 +73,63 @@ func listenAndServeOrFatal(addr string, handler http.Handler, rTimeout time.Dura } } -func parseDurationOrFatal(s string) (d time.Duration) { - d, err := time.ParseDuration(s) - if err != nil { - log.Fatal(err) - } - return -} - func main() { returnVersion := flag.Bool("version", false, "") flag.Usage = usage flag.Parse() + + fmt.Printf("GOV.UK Router %s\n", router.VersionInfo()) if *returnVersion { - fmt.Printf("GOV.UK Router %s\n", versionInfo()) os.Exit(0) } - feReadTimeout := parseDurationOrFatal(frontendReadTimeout) - feWriteTimeout := parseDurationOrFatal(frontendWriteTimeout) - beConnectTimeout := parseDurationOrFatal(backendConnectTimeout) - beHeaderTimeout := parseDurationOrFatal(backendHeaderTimeout) - mgoPollInterval := parseDurationOrFatal(mongoPollInterval) - - initMetrics() - - logInfo("router: using frontend read timeout:", feReadTimeout) - logInfo("router: using frontend write timeout:", feWriteTimeout) - logInfo(fmt.Sprintf("router: using GOMAXPROCS value of %d", runtime.GOMAXPROCS(0))) + router.EnableDebugOutput = os.Getenv("ROUTER_DEBUG") != "" + var ( + pubAddr = getenv("ROUTER_PUBADDR", ":8080") + apiAddr = getenv("ROUTER_APIADDR", ":8081") + mongoURL = getenv("ROUTER_MONGO_URL", "127.0.0.1") + mongoDBName = getenv("ROUTER_MONGO_DB", "router") + mongoPollInterval = getenvDuration("ROUTER_MONGO_POLL_INTERVAL", "2s") + errorLogFile = getenv("ROUTER_ERROR_LOG", "STDERR") + tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != "" + beConnTimeout = getenvDuration("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s") + beHeaderTimeout = getenvDuration("ROUTER_BACKEND_HEADER_TIMEOUT", "20s") + feReadTimeout = getenvDuration("ROUTER_FRONTEND_READ_TIMEOUT", "60s") + feWriteTimeout = getenvDuration("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s") + ) + + log.Printf("using frontend read timeout: %v", feReadTimeout) + log.Printf("using frontend write timeout: %v", feWriteTimeout) + log.Printf("using GOMAXPROCS value of %d", runtime.GOMAXPROCS(0)) if tlsSkipVerify { handlers.TLSSkipVerify = true - logWarn("router: Skipping verification of TLS certificates. " + + log.Printf("skipping verification of TLS certificates; " + "Do not use this option in a production environment.") } - rout, err := NewRouter(mongoURL, mongoDbName, mgoPollInterval, beConnectTimeout, beHeaderTimeout, errorLogFile) + router.RegisterMetrics(prometheus.DefaultRegisterer) + + rout, err := router.NewRouter(router.Options{ + MongoURL: mongoURL, + MongoDBName: mongoDBName, + MongoPollInterval: mongoPollInterval, + BackendConnTimeout: beConnTimeout, + BackendHeaderTimeout: beHeaderTimeout, + LogFileName: errorLogFile, + }) if err != nil { log.Fatal(err) } go rout.SelfUpdateRoutes() go listenAndServeOrFatal(pubAddr, rout, feReadTimeout, feWriteTimeout) - logInfo(fmt.Sprintf("router: listening for requests on %v", pubAddr)) + log.Printf("router: listening for requests on %v", pubAddr) - api, err := newAPIHandler(rout) + api, err := router.NewAPIHandler(rout) if err != nil { log.Fatal(err) } - logInfo(fmt.Sprintf("router: listening for API requests on %v", apiAddr)) + log.Printf("router: listening for API requests on %v", apiAddr) listenAndServeOrFatal(apiAddr, api, feReadTimeout, feWriteTimeout) } diff --git a/mongo.sh b/mongo.sh new file mode 100755 index 00000000..3fb51e21 --- /dev/null +++ b/mongo.sh @@ -0,0 +1,73 @@ +#!/bin/dash +set -eu + +usage() { + echo "$0 restart|start|stop" + exit 64 +} + +failure_hints() { + echo ' + Failed to start mongo. If using Docker Desktop: + - Go into Settings -> Features in development + - untick "Use containerd" + - tick "Use Rosetta"' + exit 1 +} + +docker_run() { + docker run --name router-mongo -dp 27017:27017 mongo:2.4 --replSet rs0 --quiet +} + +init_replicaset() { + docker exec router-mongo mongo --quiet --eval 'rs.initiate();' >/dev/null 2>&1 +} + +healthy() { + docker exec router-mongo mongo --quiet --eval \ + 'if (rs.status().members[0].health==1) print("healthy");' \ + 2>&1 | grep healthy >/dev/null +} + +# usage: retry_or_fatal description command-to-try +retry_or_fatal() { + n=10 + echo -n "Waiting up to $n s for $1"; shift + while [ "$n" -ge 0 ]; do + if "$@"; then + echo " done" + return + fi + sleep 1 && echo -n . + done + echo "gave up" + exit 1 +} + +stop() { + if ! docker stop router-mongo >/dev/null 2>&1; then + echo "router-mongo not running" + return + fi + echo -n Waiting for router-mongo container to exit. + docker wait router-mongo >/dev/null || true + docker rm -f router-mongo >/dev/null 2>&1 || true + echo " done" +} + +start() { + if healthy; then + echo router-mongo already running. + return + fi + stop + docker_run || failure_hints + retry_or_fatal "for successful rs.initiate()" init_replicaset + retry_or_fatal "for healthy rs.status()" healthy +} + +case $1 in + start) $1;; + stop) $1;; + *) usage +esac diff --git a/trie/trie.go b/trie/trie.go index 10fdc9b0..11eda389 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -23,7 +23,7 @@ func NewTrie() *Trie { // and returns the object if the path exists in the Trie, or nil and a status of // false. Example: // -// if res, ok := trie.Get([]string{"foo", "bar"}), ok { +// if res, ok := trie.Get([]string{"foo", "bar"}); ok { // fmt.Println("Value at /foo/bar was", res) // } func (t *Trie) Get(path []string) (entry interface{}, ok bool) { @@ -50,7 +50,7 @@ func (t *Trie) Get(path []string) (entry interface{}, ok bool) { // longest matching prefix is returned. If nothing matches at all, nil and a // status of false is returned. Example: // -// if res, ok := trie.GetLongestPrefix([]string{"foo", "bar"}), ok { +// if res, ok := trie.GetLongestPrefix([]string{"foo", "bar"}); ok { // fmt.Println("Value at /foo/bar was", res) // } func (t *Trie) GetLongestPrefix(path []string) (entry interface{}, ok bool) { diff --git a/triemux/metrics.go b/triemux/metrics.go index 8a2846e4..187e7177 100644 --- a/triemux/metrics.go +++ b/triemux/metrics.go @@ -5,23 +5,24 @@ import ( ) var ( - EntryNotFoundCountMetric = prometheus.NewCounter( + entryNotFoundCountMetric = prometheus.NewCounter( prometheus.CounterOpts{ Name: "router_triemux_entry_not_found_total", - Help: "Number of triemux lookups for which an entry was not found", + Help: "Number of route lookups for which no route was found", }, ) - InternalServiceUnavailableCountMetric = prometheus.NewCounterVec( + internalServiceUnavailableCountMetric = prometheus.NewCounter( prometheus.CounterOpts{ Name: "router_service_unavailable_error_total", - Help: "Number of 503 Service Unavailable errors served by router", + Help: "Number of 503 Service Unavailable errors originating from Router", }, - []string{"temporary_child"}, ) ) -func initMetrics() { - prometheus.MustRegister(EntryNotFoundCountMetric) - prometheus.MustRegister(InternalServiceUnavailableCountMetric) +func RegisterMetrics(r prometheus.Registerer) { + r.MustRegister( + entryNotFoundCountMetric, + internalServiceUnavailableCountMetric, + ) } diff --git a/triemux/mux.go b/triemux/mux.go index 6f5870c1..da2f9295 100644 --- a/triemux/mux.go +++ b/triemux/mux.go @@ -6,14 +6,11 @@ package triemux import ( "log" "net/http" - "os" "strings" "sync" "github.com/alphagov/router/logger" "github.com/alphagov/router/trie" - - "github.com/prometheus/client_golang/prometheus" ) type Mux struct { @@ -41,16 +38,10 @@ func (mux *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { if mux.count == 0 { w.WriteHeader(http.StatusServiceUnavailable) logger.NotifySentry(logger.ReportableError{ - Error: logger.RecoveredError{ErrorMessage: "Route table is empty!"}, + Error: logger.RecoveredError{ErrorMessage: "route table is empty"}, Request: r, }) - tempChild, isParent := os.LookupEnv("TEMPORARY_CHILD") - if !isParent { - tempChild = "0" - } - InternalServiceUnavailableCountMetric.With(prometheus.Labels{ - "temporary_child": tempChild, - }).Inc() + internalServiceUnavailableCountMetric.Inc() return } @@ -75,14 +66,14 @@ func (mux *Mux) lookup(path string) (handler http.Handler, ok bool) { val, ok = mux.prefixTrie.GetLongestPrefix(pathSegments) } if !ok { - EntryNotFoundCountMetric.Inc() + entryNotFoundCountMetric.Inc() return nil, false } entry, ok := val.(muxEntry) if !ok { log.Printf("lookup: got value (%v) from trie that wasn't a muxEntry!", val) - EntryNotFoundCountMetric.Inc() + entryNotFoundCountMetric.Inc() return nil, false } diff --git a/triemux/mux_test.go b/triemux/mux_test.go index 5d2c9367..6452b9be 100644 --- a/triemux/mux_test.go +++ b/triemux/mux_test.go @@ -162,13 +162,13 @@ var lookupExamples = []LookupExample{ } func TestLookup(t *testing.T) { - beforeCount := promtest.ToFloat64(EntryNotFoundCountMetric) + beforeCount := promtest.ToFloat64(entryNotFoundCountMetric) for _, ex := range lookupExamples { testLookup(t, ex) } - afterCount := promtest.ToFloat64(EntryNotFoundCountMetric) + afterCount := promtest.ToFloat64(entryNotFoundCountMetric) notFoundCount := afterCount - beforeCount var expectedNotFoundCount int diff --git a/triemux/triemux.go b/triemux/triemux.go deleted file mode 100644 index fb15e185..00000000 --- a/triemux/triemux.go +++ /dev/null @@ -1,5 +0,0 @@ -package triemux - -func init() { - initMetrics() -} diff --git a/version.go b/version.go deleted file mode 100644 index 1a7eefc3..00000000 --- a/version.go +++ /dev/null @@ -1,13 +0,0 @@ -package main - -import ( - "fmt" - "runtime" -) - -// populated by -ldflags from Makefile -var version = "unknown" - -func versionInfo() string { - return fmt.Sprintf("build: %s (compiler: %s)", version, runtime.Version()) -}