From 11d00a52078dc713f9b01abf78c306df45f10835 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 29 Jul 2023 13:09:59 +0100 Subject: [PATCH 01/24] Move Router code out of the main module. It's currently very hard to unit test the methods of the Router type because it's so tightly coupled with program initialisation stuff in the main module. Moving it into its own module is the first step towards fixing this. --- lib/logcompat.go | 21 +++++++++++++++++ metrics.go => lib/metrics.go | 4 ++-- router.go => lib/router.go | 2 +- router_api.go => lib/router_api.go | 4 ++-- router_test.go => lib/router_test.go | 2 +- main.go | 35 +++++++++------------------- 6 files changed, 38 insertions(+), 30 deletions(-) create mode 100644 lib/logcompat.go rename metrics.go => lib/metrics.go (96%) rename router.go => lib/router.go (99%) rename router_api.go => lib/router_api.go (96%) rename router_test.go => lib/router_test.go (99%) diff --git a/lib/logcompat.go b/lib/logcompat.go new file mode 100644 index 00000000..e9a9c2f9 --- /dev/null +++ b/lib/logcompat.go @@ -0,0 +1,21 @@ +package router + +// TODO: remove this file and use rs/zerolog throughout. + +import "log" + +var EnableDebugOutput bool + +func logWarn(msg ...interface{}) { + log.Println(msg...) +} + +func logInfo(msg ...interface{}) { + log.Println(msg...) +} + +func logDebug(msg ...interface{}) { + if EnableDebugOutput { + log.Println(msg...) + } +} diff --git a/metrics.go b/lib/metrics.go similarity index 96% rename from metrics.go rename to lib/metrics.go index 653da981..50164a25 100644 --- a/metrics.go +++ b/lib/metrics.go @@ -1,4 +1,4 @@ -package main +package router import ( "github.com/prometheus/client_golang/prometheus" @@ -35,7 +35,7 @@ var ( ) ) -func initMetrics() { +func InitMetrics() { prometheus.MustRegister(internalServerErrorCountMetric) prometheus.MustRegister(routeReloadCountMetric) diff --git a/router.go b/lib/router.go similarity index 99% rename from router.go rename to lib/router.go index bda7326d..d95abe39 100644 --- a/router.go +++ b/lib/router.go @@ -1,4 +1,4 @@ -package main +package router import ( "fmt" diff --git a/router_api.go b/lib/router_api.go similarity index 96% rename from router_api.go rename to lib/router_api.go index 8652a3f4..63f55543 100644 --- a/router_api.go +++ b/lib/router_api.go @@ -1,4 +1,4 @@ -package main +package router import ( "encoding/json" @@ -9,7 +9,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" ) -func newAPIHandler(rout *Router) (api http.Handler, err error) { +func NewAPIHandler(rout *Router) (api http.Handler, err error) { mux := http.NewServeMux() mux.HandleFunc("/reload", func(w http.ResponseWriter, r *http.Request) { diff --git a/router_test.go b/lib/router_test.go similarity index 99% rename from router_test.go rename to lib/router_test.go index eb69c7cf..5a11ce88 100644 --- a/router_test.go +++ b/lib/router_test.go @@ -1,4 +1,4 @@ -package main +package router import ( "errors" diff --git a/main.go b/main.go index f14d4d8a..6d1b34e7 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "time" "github.com/alphagov/router/handlers" + router "github.com/alphagov/router/lib" ) var ( @@ -20,7 +21,6 @@ var ( mongoPollInterval = getenvDefault("ROUTER_MONGO_POLL_INTERVAL", "2s") errorLogFile = getenvDefault("ROUTER_ERROR_LOG", "STDERR") tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != "" - enableDebugOutput = os.Getenv("DEBUG") != "" backendConnectTimeout = getenvDefault("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s") backendHeaderTimeout = getenvDefault("ROUTER_BACKEND_HEADER_TIMEOUT", "20s") frontendReadTimeout = getenvDefault("ROUTER_FRONTEND_READ_TIMEOUT", "60s") @@ -62,20 +62,6 @@ func getenvDefault(key string, defaultVal string) string { return val } -func logWarn(msg ...interface{}) { - log.Println(msg...) -} - -func logInfo(msg ...interface{}) { - log.Println(msg...) -} - -func logDebug(msg ...interface{}) { - if enableDebugOutput { - log.Println(msg...) - } -} - func listenAndServeOrFatal(addr string, handler http.Handler, rTimeout time.Duration, wTimeout time.Duration) { srv := &http.Server{ Addr: addr, @@ -111,31 +97,32 @@ func main() { beHeaderTimeout := parseDurationOrFatal(backendHeaderTimeout) mgoPollInterval := parseDurationOrFatal(mongoPollInterval) - initMetrics() + router.EnableDebugOutput = os.Getenv("DEBUG") != "" + router.InitMetrics() - logInfo("router: using frontend read timeout:", feReadTimeout) - logInfo("router: using frontend write timeout:", feWriteTimeout) - logInfo(fmt.Sprintf("router: using GOMAXPROCS value of %d", runtime.GOMAXPROCS(0))) + log.Printf("using frontend read timeout: %v", feReadTimeout) + log.Printf("using frontend write timeout: %v", feWriteTimeout) + log.Printf("using GOMAXPROCS value of %d", runtime.GOMAXPROCS(0)) if tlsSkipVerify { handlers.TLSSkipVerify = true - logWarn("router: Skipping verification of TLS certificates. " + + log.Printf("skipping verification of TLS certificates; " + "Do not use this option in a production environment.") } - rout, err := NewRouter(mongoURL, mongoDbName, mgoPollInterval, beConnectTimeout, beHeaderTimeout, errorLogFile) + rout, err := router.NewRouter(mongoURL, mongoDbName, mgoPollInterval, beConnectTimeout, beHeaderTimeout, errorLogFile) if err != nil { log.Fatal(err) } go rout.SelfUpdateRoutes() go listenAndServeOrFatal(pubAddr, rout, feReadTimeout, feWriteTimeout) - logInfo(fmt.Sprintf("router: listening for requests on %v", pubAddr)) + log.Printf("router: listening for requests on %v", pubAddr) - api, err := newAPIHandler(rout) + api, err := router.NewAPIHandler(rout) if err != nil { log.Fatal(err) } - logInfo(fmt.Sprintf("router: listening for API requests on %v", apiAddr)) + log.Printf("router: listening for API requests on %v", apiAddr) listenAndServeOrFatal(apiAddr, api, feReadTimeout, feWriteTimeout) } From f9eff5a045cf627c3bbbe5b4419d3d83b2a38806 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 29 Jul 2023 22:14:41 +0100 Subject: [PATCH 02/24] Extract the mongo startup script from the Makefile. It's much easier to read and debug as a standalone shell script, plus we get shellcheck coverage. Hopefully this fixes the last of the mongo startup flakiness. --- Makefile | 35 ++++++++------------------- mongo.sh | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 25 deletions(-) create mode 100755 mongo.sh diff --git a/Makefile b/Makefile index 29c106e2..bb4f75f0 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,8 @@ -.PHONY: all build test unit_tests integration_tests clean start_mongo clean_mongo clean_mongo_again +.PHONY: all clean build lint test unit_tests integration_tests start_mongo stop_mongo +.NOTPARALLEL: BINARY ?= router -SHELL := /bin/bash +SHELL := /bin/dash ifdef RELEASE_VERSION VERSION := $(RELEASE_VERSION) @@ -20,34 +21,18 @@ build: lint: golangci-lint run -test: start_mongo unit_tests integration_tests clean_mongo_again +test: lint unit_tests integration_tests unit_tests: build go test -race $$(go list ./... | grep -v integration_tests) -integration_tests: start_mongo build +integration_tests: build start_mongo ROUTER_PUBADDR=localhost:8080 \ ROUTER_APIADDR=localhost:8081 \ go test -race -v ./integration_tests -start_mongo: clean_mongo - @if ! docker run --rm --name router-mongo -dp 27017:27017 mongo:2.4 --replSet rs0 --quiet; then \ - echo 'Failed to start mongo; if using Docker Desktop, try:' ; \ - echo ' - disabling Settings -> Features in development -> Use containerd' ; \ - echo ' - enabling Settings -> Features in development -> Use Rosetta' ; \ - exit 1 ; \ - fi - @echo -n Waiting for mongo - @for n in {1..30}; do \ - if docker exec router-mongo mongo --quiet --eval 'rs.initiate()' >/dev/null 2>&1; then \ - sleep 1; \ - echo ; \ - break ; \ - fi ; \ - echo -n . ; \ - sleep 1 ; \ - done ; \ - -clean_mongo clean_mongo_again: - docker rm -f router-mongo >/dev/null 2>&1 || true - @sleep 1 # Docker doesn't queue commands so it races with itself :( +start_mongo: + ./mongo.sh start + +stop_mongo: + ./mongo.sh stop diff --git a/mongo.sh b/mongo.sh new file mode 100755 index 00000000..3fb51e21 --- /dev/null +++ b/mongo.sh @@ -0,0 +1,73 @@ +#!/bin/dash +set -eu + +usage() { + echo "$0 restart|start|stop" + exit 64 +} + +failure_hints() { + echo ' + Failed to start mongo. If using Docker Desktop: + - Go into Settings -> Features in development + - untick "Use containerd" + - tick "Use Rosetta"' + exit 1 +} + +docker_run() { + docker run --name router-mongo -dp 27017:27017 mongo:2.4 --replSet rs0 --quiet +} + +init_replicaset() { + docker exec router-mongo mongo --quiet --eval 'rs.initiate();' >/dev/null 2>&1 +} + +healthy() { + docker exec router-mongo mongo --quiet --eval \ + 'if (rs.status().members[0].health==1) print("healthy");' \ + 2>&1 | grep healthy >/dev/null +} + +# usage: retry_or_fatal description command-to-try +retry_or_fatal() { + n=10 + echo -n "Waiting up to $n s for $1"; shift + while [ "$n" -ge 0 ]; do + if "$@"; then + echo " done" + return + fi + sleep 1 && echo -n . + done + echo "gave up" + exit 1 +} + +stop() { + if ! docker stop router-mongo >/dev/null 2>&1; then + echo "router-mongo not running" + return + fi + echo -n Waiting for router-mongo container to exit. + docker wait router-mongo >/dev/null || true + docker rm -f router-mongo >/dev/null 2>&1 || true + echo " done" +} + +start() { + if healthy; then + echo router-mongo already running. + return + fi + stop + docker_run || failure_hints + retry_or_fatal "for successful rs.initiate()" init_replicaset + retry_or_fatal "for healthy rs.status()" healthy +} + +case $1 in + start) $1;; + stop) $1;; + *) usage +esac From 6ec74654145df5b540c1e177c277ff221332e847 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 29 Jul 2023 16:33:00 +0100 Subject: [PATCH 03/24] Clean up module-level variables in main.go. --- main.go | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/main.go b/main.go index 6d1b34e7..5858a47a 100644 --- a/main.go +++ b/main.go @@ -13,20 +13,6 @@ import ( router "github.com/alphagov/router/lib" ) -var ( - pubAddr = getenvDefault("ROUTER_PUBADDR", ":8080") - apiAddr = getenvDefault("ROUTER_APIADDR", ":8081") - mongoURL = getenvDefault("ROUTER_MONGO_URL", "127.0.0.1") - mongoDbName = getenvDefault("ROUTER_MONGO_DB", "router") - mongoPollInterval = getenvDefault("ROUTER_MONGO_POLL_INTERVAL", "2s") - errorLogFile = getenvDefault("ROUTER_ERROR_LOG", "STDERR") - tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != "" - backendConnectTimeout = getenvDefault("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s") - backendHeaderTimeout = getenvDefault("ROUTER_BACKEND_HEADER_TIMEOUT", "20s") - frontendReadTimeout = getenvDefault("ROUTER_FRONTEND_READ_TIMEOUT", "60s") - frontendWriteTimeout = getenvDefault("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s") -) - func usage() { helpstring := ` GOV.UK Router %s @@ -42,7 +28,7 @@ ROUTER_MONGO_POLL_INTERVAL=2s Interval to poll mongo for route changes ROUTER_ERROR_LOG=STDERR File to log errors to (in JSON format) DEBUG= Whether to enable debug output - set to anything to enable -Timeouts: (values must be parseable by https://pkg.go.dev/time#ParseDuration +Timeouts: (values must be parseable by https://pkg.go.dev/time#ParseDuration) ROUTER_BACKEND_CONNECT_TIMEOUT=1s Connect timeout when connecting to backends ROUTER_BACKEND_HEADER_TIMEOUT=15s Timeout for backend response headers to be returned @@ -50,7 +36,8 @@ ROUTER_FRONTEND_READ_TIMEOUT=60s See https://cs.opensource.google/go/go/+/mast ROUTER_FRONTEND_WRITE_TIMEOUT=60s See https://cs.opensource.google/go/go/+/master:src/net/http/server.go?q=symbol:WriteTimeout ` fmt.Fprintf(os.Stderr, helpstring, versionInfo(), os.Args[0]) - os.Exit(2) + const ErrUsage = 64 + os.Exit(ErrUsage) } func getenvDefault(key string, defaultVal string) string { @@ -83,6 +70,21 @@ func parseDurationOrFatal(s string) (d time.Duration) { } func main() { + router.EnableDebugOutput = os.Getenv("DEBUG") != "" + var ( + pubAddr = getenvDefault("ROUTER_PUBADDR", ":8080") + apiAddr = getenvDefault("ROUTER_APIADDR", ":8081") + mongoURL = getenvDefault("ROUTER_MONGO_URL", "127.0.0.1") + mongoDbName = getenvDefault("ROUTER_MONGO_DB", "router") + mongoPollInterval = getenvDefault("ROUTER_MONGO_POLL_INTERVAL", "2s") + errorLogFile = getenvDefault("ROUTER_ERROR_LOG", "STDERR") + tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != "" + backendConnectTimeout = getenvDefault("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s") + backendHeaderTimeout = getenvDefault("ROUTER_BACKEND_HEADER_TIMEOUT", "20s") + frontendReadTimeout = getenvDefault("ROUTER_FRONTEND_READ_TIMEOUT", "60s") + frontendWriteTimeout = getenvDefault("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s") + ) + returnVersion := flag.Bool("version", false, "") flag.Usage = usage flag.Parse() @@ -97,7 +99,6 @@ func main() { beHeaderTimeout := parseDurationOrFatal(backendHeaderTimeout) mgoPollInterval := parseDurationOrFatal(mongoPollInterval) - router.EnableDebugOutput = os.Getenv("DEBUG") != "" router.InitMetrics() log.Printf("using frontend read timeout: %v", feReadTimeout) From 586023d9bb988236ca38f56b11629f61434405e6 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 29 Jul 2023 16:36:05 +0100 Subject: [PATCH 04/24] Fix poorly-named debug env vars. --- README.md | 8 ++++---- integration_tests/router_support.go | 2 +- main.go | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 8ceea467..c11eb43e 100644 --- a/README.md +++ b/README.md @@ -59,17 +59,17 @@ make lint ### Debug output -To see debug messages when running tests, set both the `DEBUG` and -`DEBUG_ROUTER` environment variables. +To see debug messages when running tests, set both the `ROUTER_DEBUG` and +`ROUTER_DEBUG_TESTS` environment variables: ```sh -export DEBUG=1 DEBUG_ROUTER=1 +export ROUTER_DEBUG=1 ROUTER_DEBUG_TESTS=1 ``` or equivalently for a single run: ```sh -DEBUG=1 DEBUG_ROUTER=1 make test +ROUTER_DEBUG=1 ROUTER_DEBUG_TESTS=1 make test ``` ### Update the dependencies diff --git a/integration_tests/router_support.go b/integration_tests/router_support.go index 74e3c36f..2daf51f1 100644 --- a/integration_tests/router_support.go +++ b/integration_tests/router_support.go @@ -75,7 +75,7 @@ func startRouter(port, apiPort int, optionalExtraEnv ...envMap) error { } cmd.Env = env.ToEnv() - if os.Getenv("DEBUG_ROUTER") != "" { + if os.Getenv("ROUTER_DEBUG_TESTS") != "" { cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr } diff --git a/main.go b/main.go index 5858a47a..f8c4eeb1 100644 --- a/main.go +++ b/main.go @@ -26,7 +26,7 @@ ROUTER_MONGO_URL=127.0.0.1 Address of mongo cluster (e.g. 'mongo1,mongo2,m ROUTER_MONGO_DB=router Name of mongo database to use ROUTER_MONGO_POLL_INTERVAL=2s Interval to poll mongo for route changes ROUTER_ERROR_LOG=STDERR File to log errors to (in JSON format) -DEBUG= Whether to enable debug output - set to anything to enable +ROUTER_DEBUG= Enable debug output if non-empty Timeouts: (values must be parseable by https://pkg.go.dev/time#ParseDuration) @@ -70,7 +70,7 @@ func parseDurationOrFatal(s string) (d time.Duration) { } func main() { - router.EnableDebugOutput = os.Getenv("DEBUG") != "" + router.EnableDebugOutput = os.Getenv("ROUTER_DEBUG") != "" var ( pubAddr = getenvDefault("ROUTER_PUBADDR", ":8080") apiAddr = getenvDefault("ROUTER_APIADDR", ":8081") From bf04cfe642ef3dab690d34bb657e959bcf904ac9 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 29 Jul 2023 22:04:42 +0100 Subject: [PATCH 05/24] Move version.go out of main and use BuildInfo. Move version.go from main to the router module (lib/) so that it can (later) be accessed from Router itself. Make it use standard debug.BuildInfo that the Go toolchain produced automatically, so we no longer have to mess about with shell scripts and linker flags. This eliminates some unnecessary differences between the Dockerfile and Makefile builds. --- .dockerignore | 14 ++++++++------ .gitignore | 2 +- Dockerfile | 18 ++++++++++++++++-- Makefile | 28 +++++++++++----------------- lib/version.go | 42 ++++++++++++++++++++++++++++++++++++++++++ main.go | 4 ++-- version.go | 13 ------------- 7 files changed, 80 insertions(+), 41 deletions(-) create mode 100644 lib/version.go delete mode 100644 version.go diff --git a/.dockerignore b/.dockerignore index 1c8deaa6..c9ee771f 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,6 +1,8 @@ -.git -.gitignore -Dockerfile -README.md -docs -integration_tests +# +# Only files that are untracked by Git should be added here. +# +# The builder container needs to see a pristine checkout, otherwise +# vcs.modified in the BuildInfo will always be true, i.e. the build will always +# be marked as "dirty". +# +/router diff --git a/.gitignore b/.gitignore index aeed59a5..8421b565 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ -router +/router __build diff --git a/Dockerfile b/Dockerfile index 6929371d..a2a9728d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,22 @@ -FROM golang:1.20-alpine AS builder +ARG go_registry="" +ARG go_version=1.20 +ARG go_tag_suffix=-alpine + +FROM ${go_registry}golang:${go_version}${go_tag_suffix} AS builder ARG TARGETARCH TARGETOS +ARG GOARCH=$TARGETARCH GOOS=$TARGETOS +ARG CGO_ENABLED=0 +ARG GOFLAGS="-trimpath" +ARG go_ldflags="-s -w" +# Go needs git for `-buildvcs`, but the alpine version lacks git :( It's still +# way cheaper to `apk add git` than to pull the Debian-based golang image. +# hadolint ignore=DL3018 +RUN apk add --no-cache git WORKDIR /src COPY . ./ -RUN CGO_ENABLED=0 GOARCH=$TARGETARCH GOOS=$TARGETOS go build -trimpath -ldflags="-s -w" +RUN go build -ldflags="$go_ldflags" && \ + ./router -version && \ + go version -m ./router FROM scratch COPY --from=builder /src/router /bin/router diff --git a/Makefile b/Makefile index bb4f75f0..e8bbb768 100644 --- a/Makefile +++ b/Makefile @@ -1,35 +1,29 @@ -.PHONY: all clean build lint test unit_tests integration_tests start_mongo stop_mongo +.PHONY: all clean build test lint unit_tests integration_tests start_mongo stop_mongo .NOTPARALLEL: -BINARY ?= router +TARGET_MODULE := router +GO_BUILD_ENV := CGO_ENABLED=0 SHELL := /bin/dash -ifdef RELEASE_VERSION -VERSION := $(RELEASE_VERSION) -else -VERSION := $(shell git describe --always | tr -d '\n'; test -z "`git status --porcelain`" || echo '-dirty') -endif - -all: build test +all: build clean: - rm -f $(BINARY) + rm -f $(TARGET_MODULE) build: - go build -ldflags "-X main.version=$(VERSION)" -o $(BINARY) + env $(GO_BUILD_ENV) go build + ./$(TARGET_MODULE) -version + +test: lint unit_tests integration_tests lint: golangci-lint run -test: lint unit_tests integration_tests - -unit_tests: build +unit_tests: go test -race $$(go list ./... | grep -v integration_tests) integration_tests: build start_mongo - ROUTER_PUBADDR=localhost:8080 \ - ROUTER_APIADDR=localhost:8081 \ - go test -race -v ./integration_tests + go test -race -v ./integration_tests start_mongo: ./mongo.sh start diff --git a/lib/version.go b/lib/version.go new file mode 100644 index 00000000..479dee03 --- /dev/null +++ b/lib/version.go @@ -0,0 +1,42 @@ +package router + +import ( + "fmt" + "runtime/debug" +) + +// VersionInfo returns human-readable version information in a format suitable +// for concatenation with other messages. +func VersionInfo() (v string) { + v = "(version info unavailable)" + + bi, ok := debug.ReadBuildInfo() + if !ok { + return + } + + rev, commitTime, dirty := buildSettings(bi.Settings) + if rev == "" { + return + } + + commitTimeOrDirty := "dirty" + if dirty == "false" { + commitTimeOrDirty = commitTime + } + return fmt.Sprintf("built from commit %.8s (%s) using %s", rev, commitTimeOrDirty, bi.GoVersion) +} + +func buildSettings(bs []debug.BuildSetting) (rev, commitTime, dirty string) { + for _, b := range bs { + switch b.Key { + case "vcs.modified": + dirty = b.Value + case "vcs.revision": + rev = b.Value + case "vcs.time": + commitTime = b.Value + } + } + return +} diff --git a/main.go b/main.go index f8c4eeb1..1f731124 100644 --- a/main.go +++ b/main.go @@ -35,7 +35,7 @@ ROUTER_BACKEND_HEADER_TIMEOUT=15s Timeout for backend response headers to be re ROUTER_FRONTEND_READ_TIMEOUT=60s See https://cs.opensource.google/go/go/+/master:src/net/http/server.go?q=symbol:ReadTimeout ROUTER_FRONTEND_WRITE_TIMEOUT=60s See https://cs.opensource.google/go/go/+/master:src/net/http/server.go?q=symbol:WriteTimeout ` - fmt.Fprintf(os.Stderr, helpstring, versionInfo(), os.Args[0]) + fmt.Fprintf(os.Stderr, helpstring, router.VersionInfo(), os.Args[0]) const ErrUsage = 64 os.Exit(ErrUsage) } @@ -89,7 +89,7 @@ func main() { flag.Usage = usage flag.Parse() if *returnVersion { - fmt.Printf("GOV.UK Router %s\n", versionInfo()) + fmt.Printf("GOV.UK Router %s\n", router.VersionInfo()) os.Exit(0) } diff --git a/version.go b/version.go deleted file mode 100644 index 1a7eefc3..00000000 --- a/version.go +++ /dev/null @@ -1,13 +0,0 @@ -package main - -import ( - "fmt" - "runtime" -) - -// populated by -ldflags from Makefile -var version = "unknown" - -func versionInfo() string { - return fmt.Sprintf("build: %s (compiler: %s)", version, runtime.Version()) -} From eb84ce61ae32aa5f0ccbe3486988013ab11699f9 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Mon, 31 Jul 2023 17:28:04 +0100 Subject: [PATCH 06/24] Always print the version info on startup. It's useful to see the version string in the logs on startup. --- main.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index 1f731124..e3574000 100644 --- a/main.go +++ b/main.go @@ -88,8 +88,9 @@ func main() { returnVersion := flag.Bool("version", false, "") flag.Usage = usage flag.Parse() + + fmt.Printf("GOV.UK Router %s\n", router.VersionInfo()) if *returnVersion { - fmt.Printf("GOV.UK Router %s\n", router.VersionInfo()) os.Exit(0) } From f2fd2f8303da58f40cc3eed068af969e9065bf19 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Mon, 31 Jul 2023 00:18:47 +0100 Subject: [PATCH 07/24] Add a make target for updating go module deps. --- Makefile | 5 ++++- README.md | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index e8bbb768..949be474 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: all clean build test lint unit_tests integration_tests start_mongo stop_mongo +.PHONY: all clean build test lint unit_tests integration_tests start_mongo stop_mongo update_deps .NOTPARALLEL: TARGET_MODULE := router @@ -30,3 +30,6 @@ start_mongo: stop_mongo: ./mongo.sh stop + +update_deps: + go get -t -u ./... && go mod tidy && go mod vendor diff --git a/README.md b/README.md index c11eb43e..35b36d0c 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ This project uses [Go Modules](https://github.com/golang/go/wiki/Modules) to ven 1. Update all the dependencies, including test dependencies, in your working copy: ```sh - go get -t -u ./... && go mod tidy && go mod vendor + make update_deps ``` 1. Check for any errors and commit. From eab07923838a3c492f20ceb22c64c188cbaf0760 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 5 Aug 2023 09:16:50 +0100 Subject: [PATCH 08/24] Don't assume the default Prometheus registry. Registering metrics in a global namespace is neither test-friendly nor concurrency-safe. There's also the potential for name clashes to cause panics on startup, which can even happen implicitly on import because we've been calling prometheus.MustRegister() from module init functions. Injecting a prometheus.Registerer will allow us to write proper integration tests (and be able to run them in parallel), rather than almost everything having to be an end-to-end test that forks and execs an external Router binary. Also don't export Prometheus objects from modules. Nothing outside the module should depend on implementation details of the module's metrics. The public interface for reading metrics within the same process is prometheus.Gatherer. For everything else, there's `/metrics`. --- handlers/backend_handler.go | 4 ++-- handlers/backend_handler_test.go | 16 +++++++--------- handlers/handlers.go | 5 ++++- handlers/handlers_suite_test.go | 2 +- handlers/metrics.go | 17 +++++++++-------- handlers/redirect_handler.go | 4 ++-- handlers/redirect_handler_test.go | 18 ++++++++---------- lib/metrics.go | 14 +++++++------- lib/router.go | 6 ++++-- main.go | 2 -- triemux/metrics.go | 12 +++++++----- triemux/mux.go | 6 +++--- triemux/mux_test.go | 4 ++-- triemux/triemux.go | 5 ++++- 14 files changed, 60 insertions(+), 55 deletions(-) diff --git a/handlers/backend_handler.go b/handlers/backend_handler.go index e4ca9a7c..fce644fe 100644 --- a/handlers/backend_handler.go +++ b/handlers/backend_handler.go @@ -134,7 +134,7 @@ func (bt *backendTransport) RoundTrip(req *http.Request) (resp *http.Response, e var responseCode int var startTime = time.Now() - BackendHandlerRequestCountMetric.With(prometheus.Labels{ + backendHandlerRequestCountMetric.With(prometheus.Labels{ "backend_id": bt.backendID, "request_method": req.Method, }).Inc() @@ -142,7 +142,7 @@ func (bt *backendTransport) RoundTrip(req *http.Request) (resp *http.Response, e defer func() { durationSeconds := time.Since(startTime).Seconds() - BackendHandlerResponseDurationSecondsMetric.With(prometheus.Labels{ + backendHandlerResponseDurationSecondsMetric.With(prometheus.Labels{ "backend_id": bt.backendID, "request_method": req.Method, "response_code": fmt.Sprintf("%d", responseCode), diff --git a/handlers/backend_handler_test.go b/handlers/backend_handler_test.go index 2365e275..3618adc3 100644 --- a/handlers/backend_handler_test.go +++ b/handlers/backend_handler_test.go @@ -1,4 +1,4 @@ -package handlers_test +package handlers import ( "io" @@ -15,7 +15,6 @@ import ( promtest "github.com/prometheus/client_golang/prometheus/testutil" prommodel "github.com/prometheus/client_model/go" - "github.com/alphagov/router/handlers" log "github.com/alphagov/router/logger" ) @@ -51,7 +50,7 @@ var _ = Describe("Backend handler", func() { Context("when the backend times out", func() { BeforeEach(func() { - router = handlers.NewBackendHandler( + router = NewBackendHandler( "backend-timeout", backendURL, timeout, timeout, @@ -80,7 +79,7 @@ var _ = Describe("Backend handler", func() { Context("when the backend handles the connection", func() { BeforeEach(func() { - router = handlers.NewBackendHandler( + router = NewBackendHandler( "backend-handle", backendURL, timeout, timeout, @@ -141,15 +140,14 @@ var _ = Describe("Backend handler", func() { Context("metrics", func() { var ( - beforeRequestCountMetric float64 - + beforeRequestCountMetric float64 beforeResponseCountMetric float64 beforeResponseDurationSecondsMetric float64 ) measureRequestCount := func() float64 { return promtest.ToFloat64( - handlers.BackendHandlerRequestCountMetric.With(prometheus.Labels{ + backendHandlerRequestCountMetric.With(prometheus.Labels{ "backend_id": "backend-metrics", "request_method": http.MethodGet, }), @@ -160,7 +158,7 @@ var _ = Describe("Backend handler", func() { var err error metricChan := make(chan prometheus.Metric, 1024) - handlers.BackendHandlerResponseDurationSecondsMetric.Collect(metricChan) + backendHandlerResponseDurationSecondsMetric.Collect(metricChan) close(metricChan) for m := range metricChan { metric := new(prommodel.Metric) @@ -197,7 +195,7 @@ var _ = Describe("Backend handler", func() { } BeforeEach(func() { - router = handlers.NewBackendHandler( + router = NewBackendHandler( "backend-metrics", backendURL, timeout, timeout, diff --git a/handlers/handlers.go b/handlers/handlers.go index df249dda..7dc36f99 100644 --- a/handlers/handlers.go +++ b/handlers/handlers.go @@ -1,5 +1,8 @@ package handlers +import "github.com/prometheus/client_golang/prometheus" + +// TODO: don't use init for this. func init() { - initMetrics() + registerMetrics(prometheus.DefaultRegisterer) } diff --git a/handlers/handlers_suite_test.go b/handlers/handlers_suite_test.go index f3789b51..1ed17f51 100644 --- a/handlers/handlers_suite_test.go +++ b/handlers/handlers_suite_test.go @@ -1,4 +1,4 @@ -package handlers_test +package handlers import ( "testing" diff --git a/handlers/metrics.go b/handlers/metrics.go index 25983a98..cd653872 100644 --- a/handlers/metrics.go +++ b/handlers/metrics.go @@ -5,7 +5,7 @@ import ( ) var ( - RedirectHandlerRedirectCountMetric = prometheus.NewCounterVec( + redirectHandlerRedirectCountMetric = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "router_redirect_handler_redirect_total", Help: "Number of redirects handled by router redirect handlers", @@ -16,7 +16,7 @@ var ( }, ) - BackendHandlerRequestCountMetric = prometheus.NewCounterVec( + backendHandlerRequestCountMetric = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "router_backend_handler_request_total", Help: "Number of requests handled by router backend handlers", @@ -27,7 +27,7 @@ var ( }, ) - BackendHandlerResponseDurationSecondsMetric = prometheus.NewHistogramVec( + backendHandlerResponseDurationSecondsMetric = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Name: "router_backend_handler_response_duration_seconds", Help: "Histogram of response durations by router backend handlers", @@ -40,9 +40,10 @@ var ( ) ) -func initMetrics() { - prometheus.MustRegister(RedirectHandlerRedirectCountMetric) - - prometheus.MustRegister(BackendHandlerRequestCountMetric) - prometheus.MustRegister(BackendHandlerResponseDurationSecondsMetric) +func registerMetrics(r prometheus.Registerer) { + r.MustRegister( + backendHandlerRequestCountMetric, + backendHandlerResponseDurationSecondsMetric, + redirectHandlerRedirectCountMetric, + ) } diff --git a/handlers/redirect_handler.go b/handlers/redirect_handler.go index 80529f97..1bbdf259 100644 --- a/handlers/redirect_handler.go +++ b/handlers/redirect_handler.go @@ -61,7 +61,7 @@ func (handler *redirectHandler) ServeHTTP(writer http.ResponseWriter, request *h target := addGAQueryParam(handler.url, request) http.Redirect(writer, request, target, handler.code) - RedirectHandlerRedirectCountMetric.With(prometheus.Labels{ + redirectHandlerRedirectCountMetric.With(prometheus.Labels{ "redirect_code": fmt.Sprintf("%d", handler.code), "redirect_type": redirectHandlerType, }).Inc() @@ -82,7 +82,7 @@ func (handler *pathPreservingRedirectHandler) ServeHTTP(writer http.ResponseWrit addCacheHeaders(writer) http.Redirect(writer, request, target, handler.code) - RedirectHandlerRedirectCountMetric.With(prometheus.Labels{ + redirectHandlerRedirectCountMetric.With(prometheus.Labels{ "redirect_code": fmt.Sprintf("%d", handler.code), "redirect_type": pathPreservingRedirectHandlerType, }).Inc() diff --git a/handlers/redirect_handler_test.go b/handlers/redirect_handler_test.go index f6b48412..51c93239 100644 --- a/handlers/redirect_handler_test.go +++ b/handlers/redirect_handler_test.go @@ -1,4 +1,4 @@ -package handlers_test +package handlers import ( "fmt" @@ -11,8 +11,6 @@ import ( "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" - - "github.com/alphagov/router/handlers" ) var _ = Describe("A redirect handler", func() { @@ -29,7 +27,7 @@ var _ = Describe("A redirect handler", func() { for _, temporary := range []bool{true, false} { Context(fmt.Sprintf("where preserve=%t, temporary=%t", preserve, temporary), func() { BeforeEach(func() { - handler = handlers.NewRedirectHandler("/source", "/target", preserve, temporary) + handler = NewRedirectHandler("/source", "/target", preserve, temporary) handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil)) }) @@ -53,7 +51,7 @@ var _ = Describe("A redirect handler", func() { Context("where preserve=true", func() { BeforeEach(func() { - handler = handlers.NewRedirectHandler("/source", "/target", true, false) + handler = NewRedirectHandler("/source", "/target", true, false) handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil)) }) @@ -64,7 +62,7 @@ var _ = Describe("A redirect handler", func() { Context("where preserve=false", func() { BeforeEach(func() { - handler = handlers.NewRedirectHandler("/source", "/target", false, false) + handler = NewRedirectHandler("/source", "/target", false, false) }) It("returns only the configured path in the location header", func() { @@ -86,7 +84,7 @@ var _ = Describe("A redirect handler", func() { Entry(nil, true, false, http.StatusMovedPermanently), Entry(nil, true, true, http.StatusFound), func(preserve, temporary bool, expectedStatus int) { - handler = handlers.NewRedirectHandler("/source", "/target", preserve, temporary) + handler = NewRedirectHandler("/source", "/target", preserve, temporary) handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil)) Expect(rr.Result().StatusCode).To(Equal(expectedStatus)) }) @@ -99,12 +97,12 @@ var _ = Describe("A redirect handler", func() { Entry(nil, true, true, "302", "path-preserving-redirect-handler"), func(preserve, temporary bool, codeLabel, typeLabel string) { lbls := prometheus.Labels{"redirect_code": codeLabel, "redirect_type": typeLabel} - before := promtest.ToFloat64(handlers.RedirectHandlerRedirectCountMetric.With(lbls)) + before := promtest.ToFloat64(redirectHandlerRedirectCountMetric.With(lbls)) - handler = handlers.NewRedirectHandler("/source", "/target", preserve, temporary) + handler = NewRedirectHandler("/source", "/target", preserve, temporary) handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil)) - after := promtest.ToFloat64(handlers.RedirectHandlerRedirectCountMetric.With(lbls)) + after := promtest.ToFloat64(redirectHandlerRedirectCountMetric.With(lbls)) Expect(after - before).To(BeNumerically("~", 1.0)) }, ) diff --git a/lib/metrics.go b/lib/metrics.go index 50164a25..a4e52826 100644 --- a/lib/metrics.go +++ b/lib/metrics.go @@ -35,11 +35,11 @@ var ( ) ) -func InitMetrics() { - prometheus.MustRegister(internalServerErrorCountMetric) - - prometheus.MustRegister(routeReloadCountMetric) - prometheus.MustRegister(routeReloadErrorCountMetric) - - prometheus.MustRegister(routesCountMetric) +func registerMetrics(r prometheus.Registerer) { + r.MustRegister( + internalServerErrorCountMetric, + routeReloadCountMetric, + routeReloadErrorCountMetric, + routesCountMetric, + ) } diff --git a/lib/router.go b/lib/router.go index d95abe39..651c0b4a 100644 --- a/lib/router.go +++ b/lib/router.go @@ -80,14 +80,16 @@ func NewRouter(mongoURL, mongoDbName string, mongoPollInterval, beConnTimeout, b if err != nil { return nil, err } + logInfo("router: logging errors as JSON to", logFileName) + + // TODO: avoid using the global registry. + registerMetrics(prometheus.DefaultRegisterer) mongoReadToOptime, err := bson.NewMongoTimestamp(time.Date(1970, time.January, 1, 0, 0, 0, 0, time.UTC), 1) if err != nil { return nil, err } - logInfo("router: logging errors as JSON to", logFileName) - reloadChan := make(chan bool, 1) rt = &Router{ mux: triemux.NewMux(), diff --git a/main.go b/main.go index e3574000..739f0b55 100644 --- a/main.go +++ b/main.go @@ -100,8 +100,6 @@ func main() { beHeaderTimeout := parseDurationOrFatal(backendHeaderTimeout) mgoPollInterval := parseDurationOrFatal(mongoPollInterval) - router.InitMetrics() - log.Printf("using frontend read timeout: %v", feReadTimeout) log.Printf("using frontend write timeout: %v", feWriteTimeout) log.Printf("using GOMAXPROCS value of %d", runtime.GOMAXPROCS(0)) diff --git a/triemux/metrics.go b/triemux/metrics.go index 8a2846e4..e2181be9 100644 --- a/triemux/metrics.go +++ b/triemux/metrics.go @@ -5,14 +5,14 @@ import ( ) var ( - EntryNotFoundCountMetric = prometheus.NewCounter( + entryNotFoundCountMetric = prometheus.NewCounter( prometheus.CounterOpts{ Name: "router_triemux_entry_not_found_total", Help: "Number of triemux lookups for which an entry was not found", }, ) - InternalServiceUnavailableCountMetric = prometheus.NewCounterVec( + internalServiceUnavailableCountMetric = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "router_service_unavailable_error_total", Help: "Number of 503 Service Unavailable errors served by router", @@ -21,7 +21,9 @@ var ( ) ) -func initMetrics() { - prometheus.MustRegister(EntryNotFoundCountMetric) - prometheus.MustRegister(InternalServiceUnavailableCountMetric) +func registerMetrics(r prometheus.Registerer) { + r.MustRegister( + entryNotFoundCountMetric, + internalServiceUnavailableCountMetric, + ) } diff --git a/triemux/mux.go b/triemux/mux.go index 6f5870c1..51418a69 100644 --- a/triemux/mux.go +++ b/triemux/mux.go @@ -48,7 +48,7 @@ func (mux *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { if !isParent { tempChild = "0" } - InternalServiceUnavailableCountMetric.With(prometheus.Labels{ + internalServiceUnavailableCountMetric.With(prometheus.Labels{ "temporary_child": tempChild, }).Inc() return @@ -75,14 +75,14 @@ func (mux *Mux) lookup(path string) (handler http.Handler, ok bool) { val, ok = mux.prefixTrie.GetLongestPrefix(pathSegments) } if !ok { - EntryNotFoundCountMetric.Inc() + entryNotFoundCountMetric.Inc() return nil, false } entry, ok := val.(muxEntry) if !ok { log.Printf("lookup: got value (%v) from trie that wasn't a muxEntry!", val) - EntryNotFoundCountMetric.Inc() + entryNotFoundCountMetric.Inc() return nil, false } diff --git a/triemux/mux_test.go b/triemux/mux_test.go index 5d2c9367..6452b9be 100644 --- a/triemux/mux_test.go +++ b/triemux/mux_test.go @@ -162,13 +162,13 @@ var lookupExamples = []LookupExample{ } func TestLookup(t *testing.T) { - beforeCount := promtest.ToFloat64(EntryNotFoundCountMetric) + beforeCount := promtest.ToFloat64(entryNotFoundCountMetric) for _, ex := range lookupExamples { testLookup(t, ex) } - afterCount := promtest.ToFloat64(EntryNotFoundCountMetric) + afterCount := promtest.ToFloat64(entryNotFoundCountMetric) notFoundCount := afterCount - beforeCount var expectedNotFoundCount int diff --git a/triemux/triemux.go b/triemux/triemux.go index fb15e185..3bea1799 100644 --- a/triemux/triemux.go +++ b/triemux/triemux.go @@ -1,5 +1,8 @@ package triemux +import "github.com/prometheus/client_golang/prometheus" + +// TODO: don't use init for this. func init() { - initMetrics() + registerMetrics(prometheus.DefaultRegisterer) } From a3818f25664cbbf42a97a01246a8f7375847663b Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 5 Aug 2023 17:07:28 +0100 Subject: [PATCH 09/24] Make metric var names less excessively verbose. --- handlers/backend_handler.go | 4 ++-- handlers/backend_handler_test.go | 4 ++-- handlers/metrics.go | 12 ++++++------ handlers/redirect_handler.go | 4 ++-- handlers/redirect_handler_test.go | 4 ++-- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/handlers/backend_handler.go b/handlers/backend_handler.go index fce644fe..7df7663b 100644 --- a/handlers/backend_handler.go +++ b/handlers/backend_handler.go @@ -134,7 +134,7 @@ func (bt *backendTransport) RoundTrip(req *http.Request) (resp *http.Response, e var responseCode int var startTime = time.Now() - backendHandlerRequestCountMetric.With(prometheus.Labels{ + backendRequestCountMetric.With(prometheus.Labels{ "backend_id": bt.backendID, "request_method": req.Method, }).Inc() @@ -142,7 +142,7 @@ func (bt *backendTransport) RoundTrip(req *http.Request) (resp *http.Response, e defer func() { durationSeconds := time.Since(startTime).Seconds() - backendHandlerResponseDurationSecondsMetric.With(prometheus.Labels{ + backendResponseDurationSecondsMetric.With(prometheus.Labels{ "backend_id": bt.backendID, "request_method": req.Method, "response_code": fmt.Sprintf("%d", responseCode), diff --git a/handlers/backend_handler_test.go b/handlers/backend_handler_test.go index 3618adc3..1fc5f45b 100644 --- a/handlers/backend_handler_test.go +++ b/handlers/backend_handler_test.go @@ -147,7 +147,7 @@ var _ = Describe("Backend handler", func() { measureRequestCount := func() float64 { return promtest.ToFloat64( - backendHandlerRequestCountMetric.With(prometheus.Labels{ + backendRequestCountMetric.With(prometheus.Labels{ "backend_id": "backend-metrics", "request_method": http.MethodGet, }), @@ -158,7 +158,7 @@ var _ = Describe("Backend handler", func() { var err error metricChan := make(chan prometheus.Metric, 1024) - backendHandlerResponseDurationSecondsMetric.Collect(metricChan) + backendResponseDurationSecondsMetric.Collect(metricChan) close(metricChan) for m := range metricChan { metric := new(prommodel.Metric) diff --git a/handlers/metrics.go b/handlers/metrics.go index cd653872..9e3c1bb7 100644 --- a/handlers/metrics.go +++ b/handlers/metrics.go @@ -5,7 +5,7 @@ import ( ) var ( - redirectHandlerRedirectCountMetric = prometheus.NewCounterVec( + redirectCountMetric = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "router_redirect_handler_redirect_total", Help: "Number of redirects handled by router redirect handlers", @@ -16,7 +16,7 @@ var ( }, ) - backendHandlerRequestCountMetric = prometheus.NewCounterVec( + backendRequestCountMetric = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "router_backend_handler_request_total", Help: "Number of requests handled by router backend handlers", @@ -27,7 +27,7 @@ var ( }, ) - backendHandlerResponseDurationSecondsMetric = prometheus.NewHistogramVec( + backendResponseDurationSecondsMetric = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Name: "router_backend_handler_response_duration_seconds", Help: "Histogram of response durations by router backend handlers", @@ -42,8 +42,8 @@ var ( func registerMetrics(r prometheus.Registerer) { r.MustRegister( - backendHandlerRequestCountMetric, - backendHandlerResponseDurationSecondsMetric, - redirectHandlerRedirectCountMetric, + backendRequestCountMetric, + backendResponseDurationSecondsMetric, + redirectCountMetric, ) } diff --git a/handlers/redirect_handler.go b/handlers/redirect_handler.go index 1bbdf259..8fcdd7c7 100644 --- a/handlers/redirect_handler.go +++ b/handlers/redirect_handler.go @@ -61,7 +61,7 @@ func (handler *redirectHandler) ServeHTTP(writer http.ResponseWriter, request *h target := addGAQueryParam(handler.url, request) http.Redirect(writer, request, target, handler.code) - redirectHandlerRedirectCountMetric.With(prometheus.Labels{ + redirectCountMetric.With(prometheus.Labels{ "redirect_code": fmt.Sprintf("%d", handler.code), "redirect_type": redirectHandlerType, }).Inc() @@ -82,7 +82,7 @@ func (handler *pathPreservingRedirectHandler) ServeHTTP(writer http.ResponseWrit addCacheHeaders(writer) http.Redirect(writer, request, target, handler.code) - redirectHandlerRedirectCountMetric.With(prometheus.Labels{ + redirectCountMetric.With(prometheus.Labels{ "redirect_code": fmt.Sprintf("%d", handler.code), "redirect_type": pathPreservingRedirectHandlerType, }).Inc() diff --git a/handlers/redirect_handler_test.go b/handlers/redirect_handler_test.go index 51c93239..7a6c6dc1 100644 --- a/handlers/redirect_handler_test.go +++ b/handlers/redirect_handler_test.go @@ -97,12 +97,12 @@ var _ = Describe("A redirect handler", func() { Entry(nil, true, true, "302", "path-preserving-redirect-handler"), func(preserve, temporary bool, codeLabel, typeLabel string) { lbls := prometheus.Labels{"redirect_code": codeLabel, "redirect_type": typeLabel} - before := promtest.ToFloat64(redirectHandlerRedirectCountMetric.With(lbls)) + before := promtest.ToFloat64(redirectCountMetric.With(lbls)) handler = NewRedirectHandler("/source", "/target", preserve, temporary) handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil)) - after := promtest.ToFloat64(redirectHandlerRedirectCountMetric.With(lbls)) + after := promtest.ToFloat64(redirectCountMetric.With(lbls)) Expect(after - before).To(BeNumerically("~", 1.0)) }, ) From cd028461f870b298995e23e8203b329924bb6c67 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 5 Aug 2023 16:45:31 +0100 Subject: [PATCH 10/24] Remove a Tablecloth reference missed in acfdcc7. --- triemux/metrics.go | 3 +-- triemux/mux.go | 13 ++----------- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/triemux/metrics.go b/triemux/metrics.go index e2181be9..86636bd8 100644 --- a/triemux/metrics.go +++ b/triemux/metrics.go @@ -12,12 +12,11 @@ var ( }, ) - internalServiceUnavailableCountMetric = prometheus.NewCounterVec( + internalServiceUnavailableCountMetric = prometheus.NewCounter( prometheus.CounterOpts{ Name: "router_service_unavailable_error_total", Help: "Number of 503 Service Unavailable errors served by router", }, - []string{"temporary_child"}, ) ) diff --git a/triemux/mux.go b/triemux/mux.go index 51418a69..da2f9295 100644 --- a/triemux/mux.go +++ b/triemux/mux.go @@ -6,14 +6,11 @@ package triemux import ( "log" "net/http" - "os" "strings" "sync" "github.com/alphagov/router/logger" "github.com/alphagov/router/trie" - - "github.com/prometheus/client_golang/prometheus" ) type Mux struct { @@ -41,16 +38,10 @@ func (mux *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { if mux.count == 0 { w.WriteHeader(http.StatusServiceUnavailable) logger.NotifySentry(logger.ReportableError{ - Error: logger.RecoveredError{ErrorMessage: "Route table is empty!"}, + Error: logger.RecoveredError{ErrorMessage: "route table is empty"}, Request: r, }) - tempChild, isParent := os.LookupEnv("TEMPORARY_CHILD") - if !isParent { - tempChild = "0" - } - internalServiceUnavailableCountMetric.With(prometheus.Labels{ - "temporary_child": tempChild, - }).Inc() + internalServiceUnavailableCountMetric.Inc() return } From 98e6e0d927e64a2641e9480ff6a952732ed81216 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 5 Aug 2023 17:34:15 +0100 Subject: [PATCH 11/24] Clarify metric descriptions. --- handlers/metrics.go | 6 +++--- lib/metrics.go | 2 +- triemux/metrics.go | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/handlers/metrics.go b/handlers/metrics.go index 9e3c1bb7..77b11565 100644 --- a/handlers/metrics.go +++ b/handlers/metrics.go @@ -8,7 +8,7 @@ var ( redirectCountMetric = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "router_redirect_handler_redirect_total", - Help: "Number of redirects handled by router redirect handlers", + Help: "Number of redirects served by redirect handlers", }, []string{ "redirect_code", @@ -19,7 +19,7 @@ var ( backendRequestCountMetric = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "router_backend_handler_request_total", - Help: "Number of requests handled by router backend handlers", + Help: "Number of requests served by backend handlers", }, []string{ "backend_id", @@ -30,7 +30,7 @@ var ( backendResponseDurationSecondsMetric = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Name: "router_backend_handler_response_duration_seconds", - Help: "Histogram of response durations by router backend handlers", + Help: "Histogram of response durations by backend", }, []string{ "backend_id", diff --git a/lib/metrics.go b/lib/metrics.go index a4e52826..522ad0f4 100644 --- a/lib/metrics.go +++ b/lib/metrics.go @@ -8,7 +8,7 @@ var ( internalServerErrorCountMetric = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "router_internal_server_error_total", - Help: "Number of internal server errors encountered by router", + Help: "Number of 500 Internal Server Error responses originating from Router", }, []string{"host"}, ) diff --git a/triemux/metrics.go b/triemux/metrics.go index 86636bd8..c837e339 100644 --- a/triemux/metrics.go +++ b/triemux/metrics.go @@ -8,14 +8,14 @@ var ( entryNotFoundCountMetric = prometheus.NewCounter( prometheus.CounterOpts{ Name: "router_triemux_entry_not_found_total", - Help: "Number of triemux lookups for which an entry was not found", + Help: "Number of route lookups for which no route was found", }, ) internalServiceUnavailableCountMetric = prometheus.NewCounter( prometheus.CounterOpts{ Name: "router_service_unavailable_error_total", - Help: "Number of 503 Service Unavailable errors served by router", + Help: "Number of 503 Service Unavailable errors originating from Router", }, ) ) From 68cf596a8009d08c19bf81dda95325d6de9c7cf0 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 5 Aug 2023 23:42:41 +0100 Subject: [PATCH 12/24] Pass NewRouter a struct instead of too many params. NewRouter()'s parameter list was already getting out of hand. Copying it straight into struct Router for now just for the sake of convenience. Probably change that soon. --- lib/router.go | 72 +++++++++++++++++++++++++++++---------------------- main.go | 72 +++++++++++++++++++++++++++------------------------ 2 files changed, 79 insertions(+), 65 deletions(-) diff --git a/lib/router.go b/lib/router.go index 651c0b4a..65b01f55 100644 --- a/lib/router.go +++ b/lib/router.go @@ -29,17 +29,30 @@ const ( // Router is a wrapper around an HTTP multiplexer (trie.Mux) which retrieves its // routes from a passed mongo database. +// +// TODO: decouple Router from its database backend. Router should not know +// anything about the database backend. Its representation of the route table +// should be independent of the underlying DBMS. Route should define an +// abstract interface for some other module to be able to bulk-load and +// incrementally update routes. Since Router should not care where its routes +// come from, Route and Backend should not contain bson fields. +// MongoReplicaSet, MongoReplicaSetMember etc. should move out of this module. type Router struct { - mux *triemux.Mux - lock sync.RWMutex - mongoURL string - mongoDbName string - mongoPollInterval time.Duration - backendConnectTimeout time.Duration - backendHeaderTimeout time.Duration - mongoReadToOptime bson.MongoTimestamp - logger logger.Logger - ReloadChan chan bool + mux *triemux.Mux + lock sync.RWMutex + mongoReadToOptime bson.MongoTimestamp + logger logger.Logger + opts Options + ReloadChan chan bool +} + +type Options struct { + MongoURL string + MongoDbName string + MongoPollInterval time.Duration + BackendConnTimeout time.Duration + BackendHeaderTimeout time.Duration + LogFileName string } type Backend struct { @@ -71,16 +84,16 @@ type Route struct { // NewRouter returns a new empty router instance. You will need to call // SelfUpdateRoutes() to initialise the self-update process for routes. -func NewRouter(mongoURL, mongoDbName string, mongoPollInterval, beConnTimeout, beHeaderTimeout time.Duration, logFileName string) (rt *Router, err error) { - logInfo("router: using mongo poll interval:", mongoPollInterval) - logInfo("router: using backend connect timeout:", beConnTimeout) - logInfo("router: using backend header timeout:", beHeaderTimeout) +func NewRouter(o Options) (rt *Router, err error) { + logInfo("router: using mongo poll interval:", o.MongoPollInterval) + logInfo("router: using backend connect timeout:", o.BackendConnTimeout) + logInfo("router: using backend header timeout:", o.BackendHeaderTimeout) - l, err := logger.New(logFileName) + l, err := logger.New(o.LogFileName) if err != nil { return nil, err } - logInfo("router: logging errors as JSON to", logFileName) + logInfo("router: logging errors as JSON to", o.LogFileName) // TODO: avoid using the global registry. registerMetrics(prometheus.DefaultRegisterer) @@ -92,15 +105,11 @@ func NewRouter(mongoURL, mongoDbName string, mongoPollInterval, beConnTimeout, b reloadChan := make(chan bool, 1) rt = &Router{ - mux: triemux.NewMux(), - mongoURL: mongoURL, - mongoPollInterval: mongoPollInterval, - mongoDbName: mongoDbName, - backendConnectTimeout: beConnTimeout, - backendHeaderTimeout: beHeaderTimeout, - mongoReadToOptime: mongoReadToOptime, - logger: l, - ReloadChan: reloadChan, + mux: triemux.NewMux(), + mongoReadToOptime: mongoReadToOptime, + logger: l, + opts: o, + ReloadChan: reloadChan, } go rt.pollAndReload() @@ -137,9 +146,9 @@ func (rt *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { } func (rt *Router) SelfUpdateRoutes() { - logInfo(fmt.Sprintf("router: starting self-update process, polling for route changes every %v", rt.mongoPollInterval)) + logInfo(fmt.Sprintf("router: starting self-update process, polling for route changes every %v", rt.opts.MongoPollInterval)) - tick := time.Tick(rt.mongoPollInterval) + tick := time.Tick(rt.opts.MongoPollInterval) for range tick { logDebug("router: polling MongoDB for changes") @@ -159,9 +168,9 @@ func (rt *Router) pollAndReload() { } }() - logDebug("mgo: connecting to", rt.mongoURL) + logDebug("mgo: connecting to", rt.opts.MongoURL) - sess, err := mgo.Dial(rt.mongoURL) + sess, err := mgo.Dial(rt.opts.MongoURL) if err != nil { logWarn(fmt.Sprintf("mgo: error connecting to MongoDB, skipping update (error: %v)", err)) return @@ -184,7 +193,7 @@ func (rt *Router) pollAndReload() { if rt.shouldReload(currentMongoInstance) { logDebug("router: updates found") - rt.reloadRoutes(sess.DB(rt.mongoDbName), currentMongoInstance.Optime) + rt.reloadRoutes(sess.DB(rt.opts.MongoDbName), currentMongoInstance.Optime) } else { logDebug("router: no updates found") } @@ -290,7 +299,8 @@ func (rt *Router) loadBackends(c *mgo.Collection) (backends map[string]http.Hand backends[backend.BackendID] = handlers.NewBackendHandler( backend.BackendID, backendURL, - rt.backendConnectTimeout, rt.backendHeaderTimeout, + rt.opts.BackendConnTimeout, + rt.opts.BackendHeaderTimeout, rt.logger, ) } diff --git a/main.go b/main.go index 739f0b55..8e636dd4 100644 --- a/main.go +++ b/main.go @@ -40,13 +40,24 @@ ROUTER_FRONTEND_WRITE_TIMEOUT=60s See https://cs.opensource.google/go/go/+/mast os.Exit(ErrUsage) } -func getenvDefault(key string, defaultVal string) string { - val := os.Getenv(key) - if val == "" { - val = defaultVal +func getenv(key string, defaultVal string) string { + if s := os.Getenv(key); s != "" { + return s } + return defaultVal +} - return val +func getenvDuration(key string, defaultVal string) time.Duration { + s := getenv(key, defaultVal) + return mustParseDuration(s) +} + +func mustParseDuration(s string) (d time.Duration) { + d, err := time.ParseDuration(s) + if err != nil { + log.Fatal(err) + } + return } func listenAndServeOrFatal(addr string, handler http.Handler, rTimeout time.Duration, wTimeout time.Duration) { @@ -61,30 +72,7 @@ func listenAndServeOrFatal(addr string, handler http.Handler, rTimeout time.Dura } } -func parseDurationOrFatal(s string) (d time.Duration) { - d, err := time.ParseDuration(s) - if err != nil { - log.Fatal(err) - } - return -} - func main() { - router.EnableDebugOutput = os.Getenv("ROUTER_DEBUG") != "" - var ( - pubAddr = getenvDefault("ROUTER_PUBADDR", ":8080") - apiAddr = getenvDefault("ROUTER_APIADDR", ":8081") - mongoURL = getenvDefault("ROUTER_MONGO_URL", "127.0.0.1") - mongoDbName = getenvDefault("ROUTER_MONGO_DB", "router") - mongoPollInterval = getenvDefault("ROUTER_MONGO_POLL_INTERVAL", "2s") - errorLogFile = getenvDefault("ROUTER_ERROR_LOG", "STDERR") - tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != "" - backendConnectTimeout = getenvDefault("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s") - backendHeaderTimeout = getenvDefault("ROUTER_BACKEND_HEADER_TIMEOUT", "20s") - frontendReadTimeout = getenvDefault("ROUTER_FRONTEND_READ_TIMEOUT", "60s") - frontendWriteTimeout = getenvDefault("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s") - ) - returnVersion := flag.Bool("version", false, "") flag.Usage = usage flag.Parse() @@ -94,11 +82,20 @@ func main() { os.Exit(0) } - feReadTimeout := parseDurationOrFatal(frontendReadTimeout) - feWriteTimeout := parseDurationOrFatal(frontendWriteTimeout) - beConnectTimeout := parseDurationOrFatal(backendConnectTimeout) - beHeaderTimeout := parseDurationOrFatal(backendHeaderTimeout) - mgoPollInterval := parseDurationOrFatal(mongoPollInterval) + router.EnableDebugOutput = os.Getenv("ROUTER_DEBUG") != "" + var ( + pubAddr = getenv("ROUTER_PUBADDR", ":8080") + apiAddr = getenv("ROUTER_APIADDR", ":8081") + mongoURL = getenv("ROUTER_MONGO_URL", "127.0.0.1") + mongoDbName = getenv("ROUTER_MONGO_DB", "router") + mongoPollInterval = getenvDuration("ROUTER_MONGO_POLL_INTERVAL", "2s") + errorLogFile = getenv("ROUTER_ERROR_LOG", "STDERR") + tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != "" + beConnTimeout = getenvDuration("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s") + beHeaderTimeout = getenvDuration("ROUTER_BACKEND_HEADER_TIMEOUT", "20s") + feReadTimeout = getenvDuration("ROUTER_FRONTEND_READ_TIMEOUT", "60s") + feWriteTimeout = getenvDuration("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s") + ) log.Printf("using frontend read timeout: %v", feReadTimeout) log.Printf("using frontend write timeout: %v", feWriteTimeout) @@ -110,7 +107,14 @@ func main() { "Do not use this option in a production environment.") } - rout, err := router.NewRouter(mongoURL, mongoDbName, mgoPollInterval, beConnectTimeout, beHeaderTimeout, errorLogFile) + rout, err := router.NewRouter(router.Options{ + MongoURL: mongoURL, + MongoDbName: mongoDbName, + MongoPollInterval: mongoPollInterval, + BackendConnTimeout: beConnTimeout, + BackendHeaderTimeout: beHeaderTimeout, + LogFileName: errorLogFile, + }) if err != nil { log.Fatal(err) } From d444d646ea8f410cbf74dd7a4dea972d4aa82d24 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sun, 6 Aug 2023 18:51:25 +0100 Subject: [PATCH 13/24] Enable stylecheck linter and fix errors. --- .golangci.yml | 6 ++++++ lib/router.go | 4 ++-- logger/logger.go | 2 +- main.go | 4 ++-- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 0060b70f..916bb7f4 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -20,9 +20,15 @@ linters: - promlinter - reassign - revive + - stylecheck - tenv - tparallel - unconvert - usestdlibvars - wastedassign - zerologlint +linters-settings: + stylecheck: + dot-import-whitelist: + - github.com/onsi/ginkgo/v2 + - github.com/onsi/gomega diff --git a/lib/router.go b/lib/router.go index 65b01f55..54939dd2 100644 --- a/lib/router.go +++ b/lib/router.go @@ -48,7 +48,7 @@ type Router struct { type Options struct { MongoURL string - MongoDbName string + MongoDBName string MongoPollInterval time.Duration BackendConnTimeout time.Duration BackendHeaderTimeout time.Duration @@ -193,7 +193,7 @@ func (rt *Router) pollAndReload() { if rt.shouldReload(currentMongoInstance) { logDebug("router: updates found") - rt.reloadRoutes(sess.DB(rt.opts.MongoDbName), currentMongoInstance.Optime) + rt.reloadRoutes(sess.DB(rt.opts.MongoDBName), currentMongoInstance.Optime) } else { logDebug("router: no updates found") } diff --git a/logger/logger.go b/logger/logger.go index 7bfb7790..74900405 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -58,7 +58,7 @@ func openWriter(output interface{}) (w io.Writer, err error) { } } default: - return nil, fmt.Errorf("Invalid output type %T(%v)", output, output) + return nil, fmt.Errorf("invalid output type %T(%v)", output, output) } return } diff --git a/main.go b/main.go index 8e636dd4..eb27205d 100644 --- a/main.go +++ b/main.go @@ -87,7 +87,7 @@ func main() { pubAddr = getenv("ROUTER_PUBADDR", ":8080") apiAddr = getenv("ROUTER_APIADDR", ":8081") mongoURL = getenv("ROUTER_MONGO_URL", "127.0.0.1") - mongoDbName = getenv("ROUTER_MONGO_DB", "router") + mongoDBName = getenv("ROUTER_MONGO_DB", "router") mongoPollInterval = getenvDuration("ROUTER_MONGO_POLL_INTERVAL", "2s") errorLogFile = getenv("ROUTER_ERROR_LOG", "STDERR") tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != "" @@ -109,7 +109,7 @@ func main() { rout, err := router.NewRouter(router.Options{ MongoURL: mongoURL, - MongoDbName: mongoDbName, + MongoDBName: mongoDBName, MongoPollInterval: mongoPollInterval, BackendConnTimeout: beConnTimeout, BackendHeaderTimeout: beHeaderTimeout, From e1d62af66301aa209fd6b2faa9a6bb355dc07b97 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sun, 6 Aug 2023 19:33:14 +0100 Subject: [PATCH 14/24] Avoid automatically registering metrics globally. Only the main module now touches the default Prometheus registry. We want to be able to refactor some of our unnecessarily-end-to-end tests into integration tests (in the true sense) and unit tests. To do that, we can't have components messing with global state. This also improves separation of concerns generally. For example, the triemux and handlers modules can no longer cause panics simply by being imported if there's a metric name clash. --- handlers/handlers.go | 8 -------- handlers/metrics.go | 2 +- lib/metrics.go | 4 ++++ lib/router.go | 10 +++++++--- main.go | 3 +++ triemux/metrics.go | 2 +- triemux/triemux.go | 8 -------- 7 files changed, 16 insertions(+), 21 deletions(-) delete mode 100644 handlers/handlers.go delete mode 100644 triemux/triemux.go diff --git a/handlers/handlers.go b/handlers/handlers.go deleted file mode 100644 index 7dc36f99..00000000 --- a/handlers/handlers.go +++ /dev/null @@ -1,8 +0,0 @@ -package handlers - -import "github.com/prometheus/client_golang/prometheus" - -// TODO: don't use init for this. -func init() { - registerMetrics(prometheus.DefaultRegisterer) -} diff --git a/handlers/metrics.go b/handlers/metrics.go index 77b11565..a2f3331a 100644 --- a/handlers/metrics.go +++ b/handlers/metrics.go @@ -40,7 +40,7 @@ var ( ) ) -func registerMetrics(r prometheus.Registerer) { +func RegisterMetrics(r prometheus.Registerer) { r.MustRegister( backendRequestCountMetric, backendResponseDurationSecondsMetric, diff --git a/lib/metrics.go b/lib/metrics.go index 522ad0f4..36b0addc 100644 --- a/lib/metrics.go +++ b/lib/metrics.go @@ -1,6 +1,8 @@ package router import ( + "github.com/alphagov/router/handlers" + "github.com/alphagov/router/triemux" "github.com/prometheus/client_golang/prometheus" ) @@ -42,4 +44,6 @@ func registerMetrics(r prometheus.Registerer) { routeReloadErrorCountMetric, routesCountMetric, ) + handlers.RegisterMetrics(r) + triemux.RegisterMetrics(r) } diff --git a/lib/router.go b/lib/router.go index 54939dd2..e41003bd 100644 --- a/lib/router.go +++ b/lib/router.go @@ -82,6 +82,13 @@ type Route struct { Disabled bool `bson:"disabled"` } +// RegisterMetrics registers Prometheus metrics from the router module and the +// modules that it directly depends on. To use the default (global) registry, +// pass prometheus.DefaultRegisterer. +func RegisterMetrics(r prometheus.Registerer) { + registerMetrics(r) +} + // NewRouter returns a new empty router instance. You will need to call // SelfUpdateRoutes() to initialise the self-update process for routes. func NewRouter(o Options) (rt *Router, err error) { @@ -95,9 +102,6 @@ func NewRouter(o Options) (rt *Router, err error) { } logInfo("router: logging errors as JSON to", o.LogFileName) - // TODO: avoid using the global registry. - registerMetrics(prometheus.DefaultRegisterer) - mongoReadToOptime, err := bson.NewMongoTimestamp(time.Date(1970, time.January, 1, 0, 0, 0, 0, time.UTC), 1) if err != nil { return nil, err diff --git a/main.go b/main.go index eb27205d..f2d878e3 100644 --- a/main.go +++ b/main.go @@ -11,6 +11,7 @@ import ( "github.com/alphagov/router/handlers" router "github.com/alphagov/router/lib" + "github.com/prometheus/client_golang/prometheus" ) func usage() { @@ -107,6 +108,8 @@ func main() { "Do not use this option in a production environment.") } + router.RegisterMetrics(prometheus.DefaultRegisterer) + rout, err := router.NewRouter(router.Options{ MongoURL: mongoURL, MongoDBName: mongoDBName, diff --git a/triemux/metrics.go b/triemux/metrics.go index c837e339..187e7177 100644 --- a/triemux/metrics.go +++ b/triemux/metrics.go @@ -20,7 +20,7 @@ var ( ) ) -func registerMetrics(r prometheus.Registerer) { +func RegisterMetrics(r prometheus.Registerer) { r.MustRegister( entryNotFoundCountMetric, internalServiceUnavailableCountMetric, diff --git a/triemux/triemux.go b/triemux/triemux.go deleted file mode 100644 index 3bea1799..00000000 --- a/triemux/triemux.go +++ /dev/null @@ -1,8 +0,0 @@ -package triemux - -import "github.com/prometheus/client_golang/prometheus" - -// TODO: don't use init for this. -func init() { - registerMetrics(prometheus.DefaultRegisterer) -} From 4a1900fdfff7255c353724f91201158024ed5ce8 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 19 Aug 2023 17:45:04 +0100 Subject: [PATCH 15/24] Remove unnecessary envMap type from test helpers. It's way simpler to use cmd.Env. --- integration_tests/envmap.go | 24 ------------------------ integration_tests/integration_test.go | 2 +- integration_tests/proxy_function_test.go | 6 +++--- integration_tests/route_loading_test.go | 5 +++-- integration_tests/router_support.go | 19 ++++++------------- 5 files changed, 13 insertions(+), 43 deletions(-) delete mode 100644 integration_tests/envmap.go diff --git a/integration_tests/envmap.go b/integration_tests/envmap.go deleted file mode 100644 index 42d49fb0..00000000 --- a/integration_tests/envmap.go +++ /dev/null @@ -1,24 +0,0 @@ -package integration - -import ( - "strings" -) - -type envMap map[string]string - -func newEnvMap(env []string) (em envMap) { - em = make(map[string]string, len(env)) - for _, item := range env { - parts := strings.SplitN(item, "=", 2) - em[parts[0]] = parts[1] - } - return -} - -func (em envMap) ToEnv() (env []string) { - env = make([]string, 0, len(em)) - for k, v := range em { - env = append(env, k+"="+v) - } - return -} diff --git a/integration_tests/integration_test.go b/integration_tests/integration_test.go index d92a094b..853d311a 100644 --- a/integration_tests/integration_test.go +++ b/integration_tests/integration_test.go @@ -20,7 +20,7 @@ var _ = BeforeSuite(func() { if err != nil { Fail(err.Error()) } - err = startRouter(3169, 3168) + err = startRouter(3169, 3168, nil) if err != nil { Fail(err.Error()) } diff --git a/integration_tests/proxy_function_test.go b/integration_tests/proxy_function_test.go index bd4500f3..bf30c1da 100644 --- a/integration_tests/proxy_function_test.go +++ b/integration_tests/proxy_function_test.go @@ -42,7 +42,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should log and return a 504 if the connection times out in the configured time", func() { - err := startRouter(3167, 3166, envMap{"ROUTER_BACKEND_CONNECT_TIMEOUT": "0.3s"}) + err := startRouter(3167, 3166, []string{"ROUTER_BACKEND_CONNECT_TIMEOUT=0.3s"}) Expect(err).NotTo(HaveOccurred()) defer stopRouter(3167) @@ -80,7 +80,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { ) BeforeEach(func() { - err := startRouter(3167, 3166, envMap{"ROUTER_BACKEND_HEADER_TIMEOUT": "0.3s"}) + err := startRouter(3167, 3166, []string{"ROUTER_BACKEND_HEADER_TIMEOUT=0.3s"}) Expect(err).NotTo(HaveOccurred()) tarpit1 = startTarpitBackend(time.Second) tarpit2 = startTarpitBackend(100*time.Millisecond, 500*time.Millisecond) @@ -387,7 +387,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { var recorder *ghttp.Server BeforeEach(func() { - err := startRouter(3167, 3166, envMap{"ROUTER_TLS_SKIP_VERIFY": "1"}) + err := startRouter(3167, 3166, []string{"ROUTER_TLS_SKIP_VERIFY=1"}) Expect(err).NotTo(HaveOccurred()) recorder = startRecordingTLSBackend() addBackend("backend", recorder.URL()) diff --git a/integration_tests/route_loading_test.go b/integration_tests/route_loading_test.go index 38f491a8..6408b79c 100644 --- a/integration_tests/route_loading_test.go +++ b/integration_tests/route_loading_test.go @@ -1,6 +1,7 @@ package integration import ( + "fmt" "net/http/httptest" . "github.com/onsi/ginkgo/v2" @@ -82,7 +83,7 @@ var _ = Describe("loading routes from the db", func() { addBackend("backend-3", blackHole) stopRouter(3169) - err := startRouter(3169, 3168, envMap{"BACKEND_URL_backend-3": backend3.URL}) + err := startRouter(3169, 3168, []string{fmt.Sprintf("BACKEND_URL_backend-3=%s", backend3.URL)}) Expect(err).NotTo(HaveOccurred()) addRoute("/oof", NewBackendRoute("backend-3")) @@ -91,7 +92,7 @@ var _ = Describe("loading routes from the db", func() { AfterEach(func() { stopRouter(3169) - err := startRouter(3169, 3168) + err := startRouter(3169, 3168, nil) Expect(err).NotTo(HaveOccurred()) backend3.Close() }) diff --git a/integration_tests/router_support.go b/integration_tests/router_support.go index 2daf51f1..af15f04c 100644 --- a/integration_tests/router_support.go +++ b/integration_tests/router_support.go @@ -52,7 +52,7 @@ func reloadRoutes(optionalPort ...int) { var runningRouters = make(map[int]*exec.Cmd) -func startRouter(port, apiPort int, optionalExtraEnv ...envMap) error { +func startRouter(port, apiPort int, extraEnv []string) error { pubaddr := fmt.Sprintf(":%d", port) apiaddr := fmt.Sprintf(":%d", apiPort) @@ -62,18 +62,11 @@ func startRouter(port, apiPort int, optionalExtraEnv ...envMap) error { } cmd := exec.Command(bin) - env := newEnvMap(os.Environ()) - env["ROUTER_PUBADDR"] = pubaddr - env["ROUTER_APIADDR"] = apiaddr - env["ROUTER_MONGO_DB"] = "router_test" - env["ROUTER_MONGO_POLL_INTERVAL"] = "2s" - env["ROUTER_ERROR_LOG"] = tempLogfile.Name() - if len(optionalExtraEnv) > 0 { - for k, v := range optionalExtraEnv[0] { - env[k] = v - } - } - cmd.Env = env.ToEnv() + cmd.Env = append(cmd.Environ(), "ROUTER_MONGO_DB=router_test") + cmd.Env = append(cmd.Env, fmt.Sprintf("ROUTER_PUBADDR=%s", pubaddr)) + cmd.Env = append(cmd.Env, fmt.Sprintf("ROUTER_APIADDR=%s", apiaddr)) + cmd.Env = append(cmd.Env, fmt.Sprintf("ROUTER_ERROR_LOG=%s", tempLogfile.Name())) + cmd.Env = append(cmd.Env, extraEnv...) if os.Getenv("ROUTER_DEBUG_TESTS") != "" { cmd.Stdout = os.Stdout From 1b264cc8a2076bba930207e35b146c58b02398d0 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 19 Aug 2023 17:48:32 +0100 Subject: [PATCH 16/24] Bind servers under test to local loopback interface. This eliminates the annoying "wants to accept incoming connections" popups on macOS. --- integration_tests/router_support.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/integration_tests/router_support.go b/integration_tests/router_support.go index af15f04c..a963f133 100644 --- a/integration_tests/router_support.go +++ b/integration_tests/router_support.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "os/exec" + "strconv" "syscall" "time" @@ -53,8 +54,9 @@ func reloadRoutes(optionalPort ...int) { var runningRouters = make(map[int]*exec.Cmd) func startRouter(port, apiPort int, extraEnv []string) error { - pubaddr := fmt.Sprintf(":%d", port) - apiaddr := fmt.Sprintf(":%d", apiPort) + host := "localhost" + pubAddr := net.JoinHostPort(host, strconv.Itoa(port)) + apiAddr := net.JoinHostPort(host, strconv.Itoa(apiPort)) bin := os.Getenv("BINARY") if bin == "" { @@ -63,8 +65,8 @@ func startRouter(port, apiPort int, extraEnv []string) error { cmd := exec.Command(bin) cmd.Env = append(cmd.Environ(), "ROUTER_MONGO_DB=router_test") - cmd.Env = append(cmd.Env, fmt.Sprintf("ROUTER_PUBADDR=%s", pubaddr)) - cmd.Env = append(cmd.Env, fmt.Sprintf("ROUTER_APIADDR=%s", apiaddr)) + cmd.Env = append(cmd.Env, fmt.Sprintf("ROUTER_PUBADDR=%s", pubAddr)) + cmd.Env = append(cmd.Env, fmt.Sprintf("ROUTER_APIADDR=%s", apiAddr)) cmd.Env = append(cmd.Env, fmt.Sprintf("ROUTER_ERROR_LOG=%s", tempLogfile.Name())) cmd.Env = append(cmd.Env, extraEnv...) @@ -78,7 +80,7 @@ func startRouter(port, apiPort int, extraEnv []string) error { return err } - waitForServerUp(pubaddr) + waitForServerUp(pubAddr) runningRouters[port] = cmd return nil From 22d8a86d0f917db481cf7dfdc9a39d71808a91ee Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 19 Aug 2023 19:41:59 +0100 Subject: [PATCH 17/24] Get rid of varargs in test helpers. Unnecessary abstraction in test helpers can easily undermine the value of the tests. Get rid of unnecessary logic and just be explicit. Most of this was generated with some sed-foo. --- integration_tests/disabled_routes_test.go | 6 +- integration_tests/error_handling_test.go | 12 +-- integration_tests/gone_test.go | 10 +- integration_tests/http_request_helpers.go | 8 +- integration_tests/integration_test.go | 4 +- integration_tests/metrics_test.go | 2 +- integration_tests/performance_test.go | 16 ++-- integration_tests/proxy_function_test.go | 58 ++++++------ integration_tests/redirect_test.go | 70 +++++++------- integration_tests/reload_api_test.go | 34 +++---- integration_tests/route_loading_test.go | 30 +++--- integration_tests/route_selection_test.go | 110 +++++++++++----------- integration_tests/router_support.go | 22 ++--- 13 files changed, 187 insertions(+), 195 deletions(-) diff --git a/integration_tests/disabled_routes_test.go b/integration_tests/disabled_routes_test.go index 714df121..7c1ae267 100644 --- a/integration_tests/disabled_routes_test.go +++ b/integration_tests/disabled_routes_test.go @@ -11,16 +11,16 @@ var _ = Describe("marking routes as disabled", func() { BeforeEach(func() { addRoute("/unavailable", Route{Handler: "gone", Disabled: true}) addRoute("/something-live", NewRedirectRoute("/somewhere-else")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should return a 503 to the client", func() { - resp := routerRequest("/unavailable") + resp := routerRequest(routerPort, "/unavailable") Expect(resp.StatusCode).To(Equal(503)) }) It("should continue to route other requests", func() { - resp := routerRequest("/something-live") + resp := routerRequest(routerPort, "/something-live") Expect(resp.StatusCode).To(Equal(301)) Expect(resp.Header.Get("Location")).To(Equal("/somewhere-else")) }) diff --git a/integration_tests/error_handling_test.go b/integration_tests/error_handling_test.go index 2241bf19..1c8cf534 100644 --- a/integration_tests/error_handling_test.go +++ b/integration_tests/error_handling_test.go @@ -11,14 +11,14 @@ var _ = Describe("error handling", func() { Describe("handling an empty routing table", func() { BeforeEach(func() { - reloadRoutes() + reloadRoutes(apiPort) }) It("should return a 503 error to the client", func() { - resp := routerRequest("/") + resp := routerRequest(routerPort, "/") Expect(resp.StatusCode).To(Equal(503)) - resp = routerRequest("/foo") + resp = routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(503)) }) }) @@ -26,16 +26,16 @@ var _ = Describe("error handling", func() { Describe("handling a panic", func() { BeforeEach(func() { addRoute("/boom", Route{Handler: "boom"}) - reloadRoutes() + reloadRoutes(apiPort) }) It("should return a 500 error to the client", func() { - resp := routerRequest("/boom") + resp := routerRequest(routerPort, "/boom") Expect(resp.StatusCode).To(Equal(500)) }) It("should log the fact", func() { - routerRequest("/boom") + routerRequest(routerPort, "/boom") logDetails := lastRouterErrorLogEntry() Expect(logDetails.Fields).To(Equal(map[string]interface{}{ diff --git a/integration_tests/gone_test.go b/integration_tests/gone_test.go index 18479bfe..990606cd 100644 --- a/integration_tests/gone_test.go +++ b/integration_tests/gone_test.go @@ -10,25 +10,25 @@ var _ = Describe("Gone routes", func() { BeforeEach(func() { addRoute("/foo", NewGoneRoute()) addRoute("/bar", NewGoneRoute("prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should support an exact gone route", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(410)) Expect(readBody(resp)).To(Equal("410 Gone\n")) - resp = routerRequest("/foo/bar") + resp = routerRequest(routerPort, "/foo/bar") Expect(resp.StatusCode).To(Equal(404)) Expect(readBody(resp)).To(Equal("404 page not found\n")) }) It("should support a prefix gone route", func() { - resp := routerRequest("/bar") + resp := routerRequest(routerPort, "/bar") Expect(resp.StatusCode).To(Equal(410)) Expect(readBody(resp)).To(Equal("410 Gone\n")) - resp = routerRequest("/bar/baz") + resp = routerRequest(routerPort, "/bar/baz") Expect(resp.StatusCode).To(Equal(410)) Expect(readBody(resp)).To(Equal("410 Gone\n")) }) diff --git a/integration_tests/http_request_helpers.go b/integration_tests/http_request_helpers.go index a1eead92..9822f2fa 100644 --- a/integration_tests/http_request_helpers.go +++ b/integration_tests/http_request_helpers.go @@ -15,12 +15,12 @@ import ( // revive:enable:dot-imports ) -func routerRequest(path string, optionalPort ...int) *http.Response { - return doRequest(newRequest("GET", routerURL(path, optionalPort...))) +func routerRequest(port int, path string) *http.Response { + return doRequest(newRequest("GET", routerURL(port, path))) } -func routerRequestWithHeaders(path string, headers map[string]string, optionalPort ...int) *http.Response { - return doRequest(newRequestWithHeaders("GET", routerURL(path, optionalPort...), headers)) +func routerRequestWithHeaders(port int, path string, headers map[string]string) *http.Response { + return doRequest(newRequestWithHeaders("GET", routerURL(port, path), headers)) } func newRequest(method, url string) *http.Request { diff --git a/integration_tests/integration_test.go b/integration_tests/integration_test.go index 853d311a..a5fcbd34 100644 --- a/integration_tests/integration_test.go +++ b/integration_tests/integration_test.go @@ -20,7 +20,7 @@ var _ = BeforeSuite(func() { if err != nil { Fail(err.Error()) } - err = startRouter(3169, 3168, nil) + err = startRouter(routerPort, apiPort, nil) if err != nil { Fail(err.Error()) } @@ -35,6 +35,6 @@ var _ = BeforeEach(func() { }) var _ = AfterSuite(func() { - stopRouter(3169) + stopRouter(routerPort) cleanupTempLogfile() }) diff --git a/integration_tests/metrics_test.go b/integration_tests/metrics_test.go index 7b4bfdc6..ba4812e9 100644 --- a/integration_tests/metrics_test.go +++ b/integration_tests/metrics_test.go @@ -10,7 +10,7 @@ var _ = Describe("/metrics API endpoint", func() { var responseBody string BeforeEach(func() { - resp := doRequest(newRequest("GET", routerAPIURL("/metrics"))) + resp := doRequest(newRequest("GET", routerURL(apiPort, "/metrics"))) Expect(resp.StatusCode).To(Equal(200)) responseBody = readBody(resp) }) diff --git a/integration_tests/performance_test.go b/integration_tests/performance_test.go index f01c4248..08be778e 100644 --- a/integration_tests/performance_test.go +++ b/integration_tests/performance_test.go @@ -26,7 +26,7 @@ var _ = Describe("Performance", func() { addBackend("backend-2", backend2.URL) addRoute("/one", NewBackendRoute("backend-1")) addRoute("/two", NewBackendRoute("backend-2")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { backend1.Close() @@ -48,7 +48,7 @@ var _ = Describe("Performance", func() { case <-stopCh: return case <-ticker.C: - reloadRoutes() + reloadRoutes(apiPort) } }() @@ -62,9 +62,9 @@ var _ = Describe("Performance", func() { defer slowBackend.Close() addBackend("backend-slow", slowBackend.URL) addRoute("/slow", NewBackendRoute("backend-slow")) - reloadRoutes() + reloadRoutes(apiPort) - _, gen := generateLoad([]string{routerURL("/slow")}, 50) + _, gen := generateLoad([]string{routerURL(routerPort, "/slow")}, 50) defer gen.Stop() assertPerformantRouter(backend1, backend2, 50) @@ -75,9 +75,9 @@ var _ = Describe("Performance", func() { It("Router should not cause errors or much latency", func() { addBackend("backend-down", "http://127.0.0.1:3162/") addRoute("/down", NewBackendRoute("backend-down")) - reloadRoutes() + reloadRoutes(apiPort) - _, gen := generateLoad([]string{routerURL("/down")}, 50) + _, gen := generateLoad([]string{routerURL(routerPort, "/down")}, 50) defer gen.Stop() assertPerformantRouter(backend1, backend2, 50) @@ -102,7 +102,7 @@ var _ = Describe("Performance", func() { addBackend("backend-2", backend2.URL) addRoute("/one", NewBackendRoute("backend-1")) addRoute("/two", NewBackendRoute("backend-2")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { backend1.Close() @@ -117,7 +117,7 @@ var _ = Describe("Performance", func() { func assertPerformantRouter(backend1, backend2 *httptest.Server, rps int) { directResultsCh, _ := generateLoad([]string{backend1.URL + "/one", backend2.URL + "/two"}, rps) - routerResultsCh, _ := generateLoad([]string{routerURL("/one"), routerURL("/two")}, rps) + routerResultsCh, _ := generateLoad([]string{routerURL(routerPort, "/one"), routerURL(routerPort, "/two")}, rps) directResults := <-directResultsCh routerResults := <-routerResultsCh diff --git a/integration_tests/proxy_function_test.go b/integration_tests/proxy_function_test.go index bf30c1da..bb3fc002 100644 --- a/integration_tests/proxy_function_test.go +++ b/integration_tests/proxy_function_test.go @@ -20,9 +20,9 @@ var _ = Describe("Functioning as a reverse proxy", func() { It("should return a 502 if the connection to the backend is refused", func() { addBackend("not-running", "http://127.0.0.1:3164/") addRoute("/not-running", NewBackendRoute("not-running")) - reloadRoutes() + reloadRoutes(apiPort) - req, err := http.NewRequest(http.MethodGet, routerURL("/not-running"), nil) + req, err := http.NewRequest(http.MethodGet, routerURL(routerPort, "/not-running"), nil) Expect(err).NotTo(HaveOccurred()) req.Header.Set("X-Varnish", "12345678") @@ -50,7 +50,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { addRoute("/should-time-out", NewBackendRoute("black-hole")) reloadRoutes(3166) - req, err := http.NewRequest(http.MethodGet, routerURL("/should-time-out", 3167), nil) + req, err := http.NewRequest(http.MethodGet, routerURL(3167, "/should-time-out"), nil) Expect(err).NotTo(HaveOccurred()) req.Header.Set("X-Varnish", "12345678") @@ -98,7 +98,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should log and return a 504 if a backend takes longer than the configured response timeout to start returning a response", func() { - req := newRequest(http.MethodGet, routerURL("/tarpit1", 3167)) + req := newRequest(http.MethodGet, routerURL(3167, "/tarpit1")) req.Header.Set("X-Varnish", "12341112") resp := doRequest(req) Expect(resp.StatusCode).To(Equal(504)) @@ -117,7 +117,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should still return the response if the body takes longer than the header timeout", func() { - resp := routerRequest("/tarpit2", 3167) + resp := routerRequest(3167, "/tarpit2") Expect(resp.StatusCode).To(Equal(200)) Expect(readBody(resp)).To(Equal("Tarpit\n")) }) @@ -135,7 +135,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { recorderURL, _ = url.Parse(recorder.URL()) addBackend("backend", recorder.URL()) addRoute("/foo", NewBackendRoute("backend", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { @@ -143,7 +143,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should pass through most http headers to the backend", func() { - resp := routerRequestWithHeaders("/foo", map[string]string{ + resp := routerRequestWithHeaders(routerPort, "/foo", map[string]string{ "Foo": "bar", "User-Agent": "Router test suite 2.7182", }) @@ -156,7 +156,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should set the Host header to the backend hostname", func() { - resp := routerRequestWithHeaders("/foo", map[string]string{ + resp := routerRequestWithHeaders(routerPort, "/foo", map[string]string{ "Host": "www.example.com", }) Expect(resp.StatusCode).To(Equal(200)) @@ -168,7 +168,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { It("should not add a default User-Agent if there isn't one in the request", func() { // Most http libraries add a default User-Agent header. - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) @@ -178,14 +178,14 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should add the client IP to X-Forwardrd-For", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] Expect(beReq.Header.Get("X-Forwarded-For")).To(Equal("127.0.0.1")) - resp = routerRequestWithHeaders("/foo", map[string]string{ + resp = routerRequestWithHeaders(routerPort, "/foo", map[string]string{ "X-Forwarded-For": "10.9.8.7", }) Expect(resp.StatusCode).To(Equal(200)) @@ -199,14 +199,14 @@ var _ = Describe("Functioning as a reverse proxy", func() { // See https://tools.ietf.org/html/rfc2616#section-14.45 It("should add itself to the Via request header for an HTTP/1.1 request", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] Expect(beReq.Header.Get("Via")).To(Equal("1.1 router")) - resp = routerRequestWithHeaders("/foo", map[string]string{ + resp = routerRequestWithHeaders(routerPort, "/foo", map[string]string{ "Via": "1.0 fred, 1.1 barney", }) Expect(resp.StatusCode).To(Equal(200)) @@ -217,7 +217,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should add itself to the Via request header for an HTTP/1.0 request", func() { - req := newRequest(http.MethodGet, routerURL("/foo")) + req := newRequest(http.MethodGet, routerURL(routerPort, "/foo")) resp := doHTTP10Request(req) Expect(resp.StatusCode).To(Equal(200)) @@ -225,7 +225,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { beReq := recorder.ReceivedRequests()[0] Expect(beReq.Header.Get("Via")).To(Equal("1.0 router")) - req = newRequestWithHeaders("GET", routerURL("/foo"), map[string]string{ + req = newRequestWithHeaders("GET", routerURL(routerPort, "/foo"), map[string]string{ "Via": "1.0 fred, 1.1 barney", }) resp = doHTTP10Request(req) @@ -237,14 +237,14 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should add itself to the Via response heaver", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(200)) Expect(resp.Header.Get("Via")).To(Equal("1.1 router")) recorder.AppendHandlers(ghttp.RespondWith(200, "body", http.Header{ "Via": []string{"1.0 fred, 1.1 barney"}, })) - resp = routerRequest("/foo") + resp = routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(200)) Expect(resp.Header.Get("Via")).To(Equal("1.0 fred, 1.1 barney, 1.1 router")) }) @@ -260,7 +260,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { recorder = startRecordingBackend() addBackend("backend", recorder.URL()) addRoute("/foo", NewBackendRoute("backend", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { @@ -273,11 +273,11 @@ var _ = Describe("Functioning as a reverse proxy", func() { ghttp.VerifyRequest("DELETE", "/foo/bar/baz.json"), ) - req := newRequest("POST", routerURL("/foo")) + req := newRequest("POST", routerURL(routerPort, "/foo")) resp := doRequest(req) Expect(resp.StatusCode).To(Equal(200)) - req = newRequest("DELETE", routerURL("/foo/bar/baz.json")) + req = newRequest("DELETE", routerURL(routerPort, "/foo/bar/baz.json")) resp = doRequest(req) Expect(resp.StatusCode).To(Equal(200)) @@ -288,7 +288,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { recorder.AppendHandlers( ghttp.VerifyRequest("GET", "/foo/bar", "baz=qux"), ) - resp := routerRequest("/foo/bar?baz=qux") + resp := routerRequest(routerPort, "/foo/bar?baz=qux") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) @@ -302,7 +302,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { Expect(string(body)).To(Equal("I am the request body. Woohoo!")) }) - req := newRequest("POST", routerURL("/foo")) + req := newRequest("POST", routerURL(routerPort, "/foo")) req.Body = io.NopCloser(strings.NewReader("I am the request body. Woohoo!")) resp := doRequest(req) Expect(resp.StatusCode).To(Equal(200)) @@ -320,7 +320,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { recorder = startRecordingBackend() addBackend("backend", recorder.URL()+"/something") addRoute("/foo/bar", NewBackendRoute("backend", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { @@ -328,7 +328,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should merge the 2 paths", func() { - resp := routerRequest("/foo/bar") + resp := routerRequest(routerPort, "/foo/bar") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) @@ -337,7 +337,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should preserve the request query string", func() { - resp := routerRequest("/foo/bar?baz=qux") + resp := routerRequest(routerPort, "/foo/bar?baz=qux") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) @@ -355,7 +355,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { recorder = startRecordingBackend() addBackend("backend", recorder.URL()) addRoute("/foo", NewBackendRoute("backend", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { @@ -363,7 +363,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should work with incoming HTTP/1.1 requests", func() { - req := newRequest("GET", routerURL("/foo")) + req := newRequest("GET", routerURL(routerPort, "/foo")) resp := doHTTP10Request(req) Expect(resp.StatusCode).To(Equal(200)) @@ -373,7 +373,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should proxy to the backend as HTTP/1.1 requests", func() { - req := newRequest("GET", routerURL("/foo")) + req := newRequest("GET", routerURL(routerPort, "/foo")) resp := doHTTP10Request(req) Expect(resp.StatusCode).To(Equal(200)) @@ -401,7 +401,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should correctly reverse proxy to a HTTPS backend", func() { - req := newRequest("GET", routerURL("/foo", 3167)) + req := newRequest("GET", routerURL(3167, "/foo")) resp := doRequest(req) Expect(resp.StatusCode).To(Equal(200)) diff --git a/integration_tests/redirect_test.go b/integration_tests/redirect_test.go index 92e935f3..91474c21 100644 --- a/integration_tests/redirect_test.go +++ b/integration_tests/redirect_test.go @@ -16,46 +16,46 @@ var _ = Describe("Redirection", func() { addRoute("/query-temp", NewRedirectRoute("/bar?query=true", "exact")) addRoute("/fragment", NewRedirectRoute("/bar#section", "exact")) addRoute("/preserve-query", NewRedirectRoute("/qux", "exact", "permanent", "preserve")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should redirect permanently by default", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(301)) }) It("should redirect temporarily when asked to", func() { - resp := routerRequest("/foo-temp") + resp := routerRequest(routerPort, "/foo-temp") Expect(resp.StatusCode).To(Equal(302)) }) It("should contain the redirect location", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.Header.Get("Location")).To(Equal("/bar")) }) It("should not preserve the query string for the source by default", func() { - resp := routerRequest("/foo?baz=qux") + resp := routerRequest(routerPort, "/foo?baz=qux") Expect(resp.Header.Get("Location")).To(Equal("/bar")) }) It("should preserve the query string for the source if specified", func() { - resp := routerRequest("/preserve-query?foo=bar") + resp := routerRequest(routerPort, "/preserve-query?foo=bar") Expect(resp.Header.Get("Location")).To(Equal("/qux?foo=bar")) }) It("should preserve the query string for the target", func() { - resp := routerRequest("/query-temp") + resp := routerRequest(routerPort, "/query-temp") Expect(resp.Header.Get("Location")).To(Equal("/bar?query=true")) }) It("should preserve the fragment for the target", func() { - resp := routerRequest("/fragment") + resp := routerRequest(routerPort, "/fragment") Expect(resp.Header.Get("Location")).To(Equal("/bar#section")) }) It("should contain cache headers of 30 mins", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.Header.Get("Cache-Control")).To(Equal("max-age=1800, public")) Expect( @@ -73,43 +73,43 @@ var _ = Describe("Redirection", func() { addRoute("/foo", NewRedirectRoute("/bar", "prefix")) addRoute("/foo-temp", NewRedirectRoute("/bar-temp", "prefix", "temporary")) addRoute("/qux", NewRedirectRoute("/baz", "prefix", "temporary", "ignore")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should redirect permanently to the destination", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(301)) Expect(resp.Header.Get("Location")).To(Equal("/bar")) }) It("should redirect temporarily to the destination when asked to", func() { - resp := routerRequest("/foo-temp") + resp := routerRequest(routerPort, "/foo-temp") Expect(resp.StatusCode).To(Equal(302)) Expect(resp.Header.Get("Location")).To(Equal("/bar-temp")) }) It("should preserve extra path sections when redirecting by default", func() { - resp := routerRequest("/foo/baz") + resp := routerRequest(routerPort, "/foo/baz") Expect(resp.Header.Get("Location")).To(Equal("/bar/baz")) }) It("should ignore extra path sections when redirecting if specified", func() { - resp := routerRequest("/qux/quux") + resp := routerRequest(routerPort, "/qux/quux") Expect(resp.Header.Get("Location")).To(Equal("/baz")) }) It("should preserve the query string when redirecting by default", func() { - resp := routerRequest("/foo?baz=qux") + resp := routerRequest(routerPort, "/foo?baz=qux") Expect(resp.Header.Get("Location")).To(Equal("/bar?baz=qux")) }) It("should not preserve the query string when redirecting if specified", func() { - resp := routerRequest("/qux/quux?foo=bar") + resp := routerRequest(routerPort, "/qux/quux?foo=bar") Expect(resp.Header.Get("Location")).To(Equal("/baz")) }) It("should contain cache headers of 30 mins", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.Header.Get("Cache-Control")).To(Equal("max-age=1800, public")) Expect( @@ -123,9 +123,9 @@ var _ = Describe("Redirection", func() { It("should handle path-preserving redirects with special characters", func() { addRoute("/foo%20bar", NewRedirectRoute("/bar%20baz", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) - resp := routerRequest("/foo bar/something") + resp := routerRequest(routerPort, "/foo bar/something") Expect(resp.StatusCode).To(Equal(301)) Expect(resp.Header.Get("Location")).To(Equal("/bar%20baz/something")) }) @@ -137,44 +137,44 @@ var _ = Describe("Redirection", func() { addRoute("/baz", NewRedirectRoute("http://foo.example.com/baz", "exact", "permanent", "preserve")) addRoute("/bar", NewRedirectRoute("http://bar.example.com/bar", "prefix")) addRoute("/qux", NewRedirectRoute("http://bar.example.com/qux", "prefix", "permanent", "ignore")) - reloadRoutes() + reloadRoutes(apiPort) }) Describe("exact redirect", func() { It("should redirect to the external URL", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(resp.Header.Get("Location")).To(Equal("http://foo.example.com/foo")) }) It("should not preserve the query string by default", func() { - resp := routerRequest("/foo?foo=qux") + resp := routerRequest(routerPort, "/foo?foo=qux") Expect(resp.Header.Get("Location")).To(Equal("http://foo.example.com/foo")) }) It("should preserve the query string if specified", func() { - resp := routerRequest("/baz?foo=qux") + resp := routerRequest(routerPort, "/baz?foo=qux") Expect(resp.Header.Get("Location")).To(Equal("http://foo.example.com/baz?foo=qux")) }) }) Describe("prefix redirect", func() { It("should redirect to the external URL", func() { - resp := routerRequest("/bar") + resp := routerRequest(routerPort, "/bar") Expect(resp.Header.Get("Location")).To(Equal("http://bar.example.com/bar")) }) It("should preserve extra path sections when redirecting by default", func() { - resp := routerRequest("/bar/baz") + resp := routerRequest(routerPort, "/bar/baz") Expect(resp.Header.Get("Location")).To(Equal("http://bar.example.com/bar/baz")) }) It("should ignore extra path sections when redirecting if specified", func() { - resp := routerRequest("/qux/baz") + resp := routerRequest(routerPort, "/qux/baz") Expect(resp.Header.Get("Location")).To(Equal("http://bar.example.com/qux")) }) It("should preserve the query string when redirecting", func() { - resp := routerRequest("/bar?baz=qux") + resp := routerRequest(routerPort, "/bar?baz=qux") Expect(resp.Header.Get("Location")).To(Equal("http://bar.example.com/bar?baz=qux")) }) }) @@ -188,41 +188,41 @@ var _ = Describe("Redirection", func() { addRoute("/pay-tax", NewRedirectRoute("https://tax.service.gov.uk/pay", "exact", "permanent", "ignore")) addRoute("/biz-bank", NewRedirectRoute("https://british-business-bank.co.uk", "prefix", "permanent", "ignore")) addRoute("/query-paramed", NewRedirectRoute("https://param.servicegov.uk?included-param=true", "exact", "permanent", "ignore")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should only preserve the _ga parameter when redirecting to service URLs that want to ignore query params", func() { - resp := routerRequest("/foo?_ga=identifier&blah=xyz") + resp := routerRequest(routerPort, "/foo?_ga=identifier&blah=xyz") Expect(resp.Header.Get("Location")).To(Equal("https://hmrc.service.gov.uk/pay?_ga=identifier")) }) It("should retain all params when redirecting to a route that wants them", func() { - resp := routerRequest("/bar?wanted=param&_ga=xyz&blah=xyz") + resp := routerRequest(routerPort, "/bar?wanted=param&_ga=xyz&blah=xyz") Expect(resp.Header.Get("Location")).To(Equal("https://bar.service.gov.uk/bar?wanted=param&_ga=xyz&blah=xyz")) }) It("should preserve the _ga parameter when redirecting to gov.uk URLs", func() { - resp := routerRequest("/baz?_ga=identifier") + resp := routerRequest(routerPort, "/baz?_ga=identifier") Expect(resp.Header.Get("Location")).To(Equal("https://gov.uk/baz-luhrmann?_ga=identifier")) }) It("should preserve the _ga parameter when redirecting to service.gov.uk URLs", func() { - resp := routerRequest("/pay-tax?_ga=12345") + resp := routerRequest(routerPort, "/pay-tax?_ga=12345") Expect(resp.Header.Get("Location")).To(Equal("https://tax.service.gov.uk/pay?_ga=12345")) }) It("should preserve only the first _ga parameter", func() { - resp := routerRequest("/pay-tax/?_ga=12345&_ga=6789") + resp := routerRequest(routerPort, "/pay-tax/?_ga=12345&_ga=6789") Expect(resp.Header.Get("Location")).To(Equal("https://tax.service.gov.uk/pay?_ga=12345")) }) It("should preserve the _ga param when redirecting to british business bank", func() { - resp := routerRequest("/biz-bank?unwanted=param&_ga=12345") + resp := routerRequest(routerPort, "/biz-bank?unwanted=param&_ga=12345") Expect(resp.Header.Get("Location")).To(Equal("https://british-business-bank.co.uk?_ga=12345")) }) It("should preserve the _ga param and any existing query string that the target URL has", func() { - resp := routerRequest("/query-paramed?unwanted_param=blah&_ga=12345") + resp := routerRequest(routerPort, "/query-paramed?unwanted_param=blah&_ga=12345") // https://param.servicegov.uk?included-param=true?unwanted_param=blah&_ga=12345 Expect(resp.Header.Get("Location")).To(Equal("https://param.servicegov.uk?_ga=12345&included-param=true")) }) diff --git a/integration_tests/reload_api_test.go b/integration_tests/reload_api_test.go index 47dd326b..e1cf5981 100644 --- a/integration_tests/reload_api_test.go +++ b/integration_tests/reload_api_test.go @@ -11,23 +11,23 @@ var _ = Describe("reload API endpoint", func() { Describe("request handling", func() { It("should return 202 for POST /reload", func() { - resp := doRequest(newRequest("POST", routerAPIURL("/reload"))) + resp := doRequest(newRequest("POST", routerURL(apiPort, "/reload"))) Expect(resp.StatusCode).To(Equal(202)) Expect(readBody(resp)).To(Equal("Reload queued")) }) It("should return 404 for POST /foo", func() { - resp := doRequest(newRequest("POST", routerAPIURL("/foo"))) + resp := doRequest(newRequest("POST", routerURL(apiPort, "/foo"))) Expect(resp.StatusCode).To(Equal(404)) }) It("should return 404 for POST /reload/foo", func() { - resp := doRequest(newRequest("POST", routerAPIURL("/reload/foo"))) + resp := doRequest(newRequest("POST", routerURL(apiPort, "/reload/foo"))) Expect(resp.StatusCode).To(Equal(404)) }) It("should return 405 for GET /reload", func() { - resp := doRequest(newRequest("GET", routerAPIURL("/reload"))) + resp := doRequest(newRequest("GET", routerURL(apiPort, "/reload"))) Expect(resp.StatusCode).To(Equal(405)) Expect(resp.Header.Get("Allow")).To(Equal("POST")) }) @@ -36,34 +36,34 @@ var _ = Describe("reload API endpoint", func() { addRoute("/foo", NewRedirectRoute("/qux", "prefix")) start := time.Now() - doRequest(newRequest("POST", routerAPIURL("/reload"))) + doRequest(newRequest("POST", routerURL(apiPort, "/reload"))) end := time.Now() duration := end.Sub(start) Expect(duration.Nanoseconds()).To(BeNumerically("<", 5000000)) addRoute("/bar", NewRedirectRoute("/qux", "prefix")) - doRequest(newRequest("POST", routerAPIURL("/reload"))) + doRequest(newRequest("POST", routerURL(apiPort, "/reload"))) Eventually(func() int { - return routerRequest("/foo").StatusCode + return routerRequest(routerPort, "/foo").StatusCode }, time.Second*1).Should(Equal(301)) Eventually(func() int { - return routerRequest("/bar").StatusCode + return routerRequest(routerPort, "/bar").StatusCode }, time.Second*1).Should(Equal(301)) }) }) Describe("healthcheck", func() { It("should return 200 and sting 'OK' on /healthcheck", func() { - resp := doRequest(newRequest("GET", routerAPIURL("/healthcheck"))) + resp := doRequest(newRequest("GET", routerURL(apiPort, "/healthcheck"))) Expect(resp.StatusCode).To(Equal(200)) Expect(readBody(resp)).To(Equal("OK")) }) It("should return 405 for other verbs", func() { - resp := doRequest(newRequest("POST", routerAPIURL("/healthcheck"))) + resp := doRequest(newRequest("POST", routerURL(apiPort, "/healthcheck"))) Expect(resp.StatusCode).To(Equal(405)) Expect(resp.Header.Get("Allow")).To(Equal("GET")) }) @@ -78,8 +78,8 @@ var _ = Describe("reload API endpoint", func() { addRoute("/foo", NewRedirectRoute("/bar", "prefix")) addRoute("/baz", NewRedirectRoute("/qux", "prefix")) addRoute("/foo", NewRedirectRoute("/bar/baz")) - reloadRoutes() - resp := doRequest(newRequest("GET", routerAPIURL("/stats"))) + reloadRoutes(apiPort) + resp := doRequest(newRequest("GET", routerURL(apiPort, "/stats"))) Expect(resp.StatusCode).To(Equal(200)) readJSONBody(resp, &data) }) @@ -93,9 +93,9 @@ var _ = Describe("reload API endpoint", func() { var data map[string]map[string]interface{} BeforeEach(func() { - reloadRoutes() + reloadRoutes(apiPort) - resp := doRequest(newRequest("GET", routerAPIURL("/stats"))) + resp := doRequest(newRequest("GET", routerURL(apiPort, "/stats"))) Expect(resp.StatusCode).To(Equal(200)) readJSONBody(resp, &data) }) @@ -106,7 +106,7 @@ var _ = Describe("reload API endpoint", func() { }) It("should return 405 for other verbs", func() { - resp := doRequest(newRequest("POST", routerAPIURL("/stats"))) + resp := doRequest(newRequest("POST", routerURL(apiPort, "/stats"))) Expect(resp.StatusCode).To(Equal(405)) Expect(resp.Header.Get("Allow")).To(Equal("GET")) }) @@ -117,9 +117,9 @@ var _ = Describe("reload API endpoint", func() { addRoute("/foo", NewRedirectRoute("/bar", "prefix")) addRoute("/baz", NewRedirectRoute("/qux", "prefix")) addRoute("/foo", NewRedirectRoute("/bar/baz")) - reloadRoutes() + reloadRoutes(apiPort) - resp := doRequest(newRequest("GET", routerAPIURL("/memory-stats"))) + resp := doRequest(newRequest("GET", routerURL(apiPort, "/memory-stats"))) Expect(resp.StatusCode).To(Equal(200)) var data map[string]interface{} diff --git a/integration_tests/route_loading_test.go b/integration_tests/route_loading_test.go index 6408b79c..827c1deb 100644 --- a/integration_tests/route_loading_test.go +++ b/integration_tests/route_loading_test.go @@ -31,19 +31,19 @@ var _ = Describe("loading routes from the db", func() { addRoute("/foo", NewBackendRoute("backend-1")) addRoute("/bar", Route{Handler: "fooey"}) addRoute("/baz", NewBackendRoute("backend-2")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should skip the invalid route", func() { - resp := routerRequest("/bar") + resp := routerRequest(routerPort, "/bar") Expect(resp.StatusCode).To(Equal(404)) }) It("should continue to load other routes", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("backend 1")) - resp = routerRequest("/baz") + resp = routerRequest(routerPort, "/baz") Expect(readBody(resp)).To(Equal("backend 2")) }) }) @@ -54,22 +54,22 @@ var _ = Describe("loading routes from the db", func() { addRoute("/bar", NewBackendRoute("backend-non-existent")) addRoute("/baz", NewBackendRoute("backend-2")) addRoute("/qux", NewBackendRoute("backend-1")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should skip the invalid route", func() { - resp := routerRequest("/bar") + resp := routerRequest(routerPort, "/bar") Expect(resp.StatusCode).To(Equal(404)) }) It("should continue to load other routes", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("backend 1")) - resp = routerRequest("/baz") + resp = routerRequest(routerPort, "/baz") Expect(readBody(resp)).To(Equal("backend 2")) - resp = routerRequest("/qux") + resp = routerRequest(routerPort, "/qux") Expect(readBody(resp)).To(Equal("backend 1")) }) }) @@ -82,23 +82,23 @@ var _ = Describe("loading routes from the db", func() { backend3 = startSimpleBackend("backend 3") addBackend("backend-3", blackHole) - stopRouter(3169) - err := startRouter(3169, 3168, []string{fmt.Sprintf("BACKEND_URL_backend-3=%s", backend3.URL)}) + stopRouter(routerPort) + err := startRouter(routerPort, apiPort, []string{fmt.Sprintf("BACKEND_URL_backend-3=%s", backend3.URL)}) Expect(err).NotTo(HaveOccurred()) addRoute("/oof", NewBackendRoute("backend-3")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { - stopRouter(3169) - err := startRouter(3169, 3168, nil) + stopRouter(routerPort) + err := startRouter(routerPort, apiPort, nil) Expect(err).NotTo(HaveOccurred()) backend3.Close() }) It("should send requests to the backend_url provided in the env var", func() { - resp := routerRequest("/oof") + resp := routerRequest(routerPort, "/oof") Expect(resp.StatusCode).To(Equal(200)) Expect(readBody(resp)).To(Equal("backend 3")) }) diff --git a/integration_tests/route_selection_test.go b/integration_tests/route_selection_test.go index c3c089ef..4b60e44b 100644 --- a/integration_tests/route_selection_test.go +++ b/integration_tests/route_selection_test.go @@ -24,7 +24,7 @@ var _ = Describe("Route selection", func() { addRoute("/foo", NewBackendRoute("backend-1")) addRoute("/bar", NewBackendRoute("backend-2")) addRoute("/baz", NewBackendRoute("backend-1")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { backend1.Close() @@ -32,29 +32,29 @@ var _ = Describe("Route selection", func() { }) It("should route a matching request to the corresponding backend", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("backend 1")) - resp = routerRequest("/bar") + resp = routerRequest(routerPort, "/bar") Expect(readBody(resp)).To(Equal("backend 2")) - resp = routerRequest("/baz") + resp = routerRequest(routerPort, "/baz") Expect(readBody(resp)).To(Equal("backend 1")) }) It("should 404 for children of the exact route", func() { - resp := routerRequest("/foo/bar") + resp := routerRequest(routerPort, "/foo/bar") Expect(resp.StatusCode).To(Equal(404)) }) It("should 404 for non-matching requests", func() { - resp := routerRequest("/wibble") + resp := routerRequest(routerPort, "/wibble") Expect(resp.StatusCode).To(Equal(404)) - resp = routerRequest("/") + resp = routerRequest(routerPort, "/") Expect(resp.StatusCode).To(Equal(404)) - resp = routerRequest("/foo.json") + resp = routerRequest(routerPort, "/foo.json") Expect(resp.StatusCode).To(Equal(404)) }) }) @@ -73,7 +73,7 @@ var _ = Describe("Route selection", func() { addRoute("/foo", NewBackendRoute("backend-1", "prefix")) addRoute("/bar", NewBackendRoute("backend-2", "prefix")) addRoute("/baz", NewBackendRoute("backend-1", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { backend1.Close() @@ -81,35 +81,35 @@ var _ = Describe("Route selection", func() { }) It("should route requests for the prefix to the backend", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("backend 1")) - resp = routerRequest("/bar") + resp = routerRequest(routerPort, "/bar") Expect(readBody(resp)).To(Equal("backend 2")) - resp = routerRequest("/baz") + resp = routerRequest(routerPort, "/baz") Expect(readBody(resp)).To(Equal("backend 1")) }) It("should route requests for the children of the prefix to the backend", func() { - resp := routerRequest("/foo/bar") + resp := routerRequest(routerPort, "/foo/bar") Expect(readBody(resp)).To(Equal("backend 1")) - resp = routerRequest("/bar/foo.json") + resp = routerRequest(routerPort, "/bar/foo.json") Expect(readBody(resp)).To(Equal("backend 2")) - resp = routerRequest("/baz/fooey/kablooie") + resp = routerRequest(routerPort, "/baz/fooey/kablooie") Expect(readBody(resp)).To(Equal("backend 1")) }) It("should 404 for non-matching requests", func() { - resp := routerRequest("/wibble") + resp := routerRequest(routerPort, "/wibble") Expect(resp.StatusCode).To(Equal(404)) - resp = routerRequest("/") + resp = routerRequest(routerPort, "/") Expect(resp.StatusCode).To(Equal(404)) - resp = routerRequest("/foo.json") + resp = routerRequest(routerPort, "/foo.json") Expect(resp.StatusCode).To(Equal(404)) }) }) @@ -126,7 +126,7 @@ var _ = Describe("Route selection", func() { addBackend("outer-backend", outer.URL) addBackend("inner-backend", inner.URL) addRoute("/foo", NewBackendRoute("outer-backend", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { outer.Close() @@ -136,21 +136,21 @@ var _ = Describe("Route selection", func() { Describe("with an exact child", func() { BeforeEach(func() { addRoute("/foo/bar", NewBackendRoute("inner-backend")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should route the prefix to the outer backend", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("outer")) }) It("should route the exact child to the inner backend", func() { - resp := routerRequest("/foo/bar") + resp := routerRequest(routerPort, "/foo/bar") Expect(readBody(resp)).To(Equal("inner")) }) It("should route the children of the exact child to the outer backend", func() { - resp := routerRequest("/foo/bar/baz") + resp := routerRequest(routerPort, "/foo/bar/baz") Expect(readBody(resp)).To(Equal("outer")) }) }) @@ -158,29 +158,29 @@ var _ = Describe("Route selection", func() { Describe("with a prefix child", func() { BeforeEach(func() { addRoute("/foo/bar", NewBackendRoute("inner-backend", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) It("should route the outer prefix to the outer backend", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("outer")) }) It("should route the inner prefix to the inner backend", func() { - resp := routerRequest("/foo/bar") + resp := routerRequest(routerPort, "/foo/bar") Expect(readBody(resp)).To(Equal("inner")) }) It("should route the children of the inner prefix to the inner backend", func() { - resp := routerRequest("/foo/bar/baz") + resp := routerRequest(routerPort, "/foo/bar/baz") Expect(readBody(resp)).To(Equal("inner")) }) It("should route other children of the outer prefix to the outer backend", func() { - resp := routerRequest("/foo/baz") + resp := routerRequest(routerPort, "/foo/baz") Expect(readBody(resp)).To(Equal("outer")) - resp = routerRequest("/foo/bar.json") + resp = routerRequest(routerPort, "/foo/bar.json") Expect(readBody(resp)).To(Equal("outer")) }) }) @@ -194,43 +194,43 @@ var _ = Describe("Route selection", func() { addBackend("innerer-backend", innerer.URL) addRoute("/foo/bar", NewBackendRoute("inner-backend")) addRoute("/foo/bar/baz", NewBackendRoute("innerer-backend", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { innerer.Close() }) It("should route the outer prefix to the outer backend", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("outer")) - resp = routerRequest("/foo/baz") + resp = routerRequest(routerPort, "/foo/baz") Expect(readBody(resp)).To(Equal("outer")) - resp = routerRequest("/foo/bar.json") + resp = routerRequest(routerPort, "/foo/bar.json") Expect(readBody(resp)).To(Equal("outer")) }) It("should route the exact route to the inner backend", func() { - resp := routerRequest("/foo/bar") + resp := routerRequest(routerPort, "/foo/bar") Expect(readBody(resp)).To(Equal("inner")) }) It("should route other children of the exact route to the outer backend", func() { - resp := routerRequest("/foo/bar/wibble") + resp := routerRequest(routerPort, "/foo/bar/wibble") Expect(readBody(resp)).To(Equal("outer")) - resp = routerRequest("/foo/bar/baz.json") + resp = routerRequest(routerPort, "/foo/bar/baz.json") Expect(readBody(resp)).To(Equal("outer")) }) It("should route the inner prefix route to the innerer backend", func() { - resp := routerRequest("/foo/bar/baz") + resp := routerRequest(routerPort, "/foo/bar/baz") Expect(readBody(resp)).To(Equal("innerer")) }) It("should route children of the inner prefix route to the innerer backend", func() { - resp := routerRequest("/foo/bar/baz/wibble") + resp := routerRequest(routerPort, "/foo/bar/baz/wibble") Expect(readBody(resp)).To(Equal("innerer")) }) }) @@ -249,7 +249,7 @@ var _ = Describe("Route selection", func() { addBackend("backend-2", backend2.URL) addRoute("/foo", NewBackendRoute("backend-1", "prefix")) addRoute("/foo", NewBackendRoute("backend-2")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { backend1.Close() @@ -257,12 +257,12 @@ var _ = Describe("Route selection", func() { }) It("should route the exact route to the exact backend", func() { - resp := routerRequest("/foo") + resp := routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("backend 2")) }) It("should route children of the route to the prefix backend", func() { - resp := routerRequest("/foo/bar") + resp := routerRequest(routerPort, "/foo/bar") Expect(readBody(resp)).To(Equal("backend 1")) }) }) @@ -287,29 +287,29 @@ var _ = Describe("Route selection", func() { It("should handle an exact route at the root level", func() { addRoute("/", NewBackendRoute("root")) - reloadRoutes() + reloadRoutes(apiPort) - resp := routerRequest("/") + resp := routerRequest(routerPort, "/") Expect(readBody(resp)).To(Equal("root backend")) - resp = routerRequest("/foo") + resp = routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("other backend")) - resp = routerRequest("/bar") + resp = routerRequest(routerPort, "/bar") Expect(resp.StatusCode).To(Equal(404)) }) It("should handle a prefix route at the root level", func() { addRoute("/", NewBackendRoute("root", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) - resp := routerRequest("/") + resp := routerRequest(routerPort, "/") Expect(readBody(resp)).To(Equal("root backend")) - resp = routerRequest("/foo") + resp = routerRequest(routerPort, "/foo") Expect(readBody(resp)).To(Equal("other backend")) - resp = routerRequest("/bar") + resp = routerRequest(routerPort, "/bar") Expect(readBody(resp)).To(Equal("root backend")) }) }) @@ -327,7 +327,7 @@ var _ = Describe("Route selection", func() { addBackend("other", recorder.URL()) addRoute("/", NewBackendRoute("root", "prefix")) addRoute("/foo/bar", NewBackendRoute("other", "prefix")) - reloadRoutes() + reloadRoutes(apiPort) }) AfterEach(func() { root.Close() @@ -335,19 +335,19 @@ var _ = Describe("Route selection", func() { }) It("should not be redirected by our simple test backend", func() { - resp := routerRequest("//") + resp := routerRequest(routerPort, "//") Expect(readBody(resp)).To(Equal("fallthrough")) }) It("should not be redirected by our recorder backend", func() { - resp := routerRequest("/foo/bar/baz//qux") + resp := routerRequest(routerPort, "/foo/bar/baz//qux") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) Expect(recorder.ReceivedRequests()[0].URL.Path).To(Equal("/foo/bar/baz//qux")) }) It("should collapse double slashes when looking up route, but pass request as-is", func() { - resp := routerRequest("/foo//bar") + resp := routerRequest(routerPort, "/foo//bar") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) Expect(recorder.ReceivedRequests()[0].URL.Path).To(Equal("/foo//bar")) @@ -367,9 +367,9 @@ var _ = Describe("Route selection", func() { It("should handle spaces (%20) in paths", func() { addRoute("/foo%20bar", NewBackendRoute("backend")) - reloadRoutes() + reloadRoutes(apiPort) - resp := routerRequest("/foo bar") + resp := routerRequest(routerPort, "/foo bar") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) Expect(recorder.ReceivedRequests()[0].RequestURI).To(Equal("/foo%20bar")) diff --git a/integration_tests/router_support.go b/integration_tests/router_support.go index a963f133..d03d526f 100644 --- a/integration_tests/router_support.go +++ b/integration_tests/router_support.go @@ -16,24 +16,16 @@ import ( // revive:enable:dot-imports ) -func routerURL(path string, optionalPort ...int) string { - port := 3169 - if len(optionalPort) > 0 { - port = optionalPort[0] - } - return fmt.Sprintf("http://127.0.0.1:%d%s", port, path) -} +const ( + routerPort = 3169 + apiPort = 3168 +) -func routerAPIURL(path string) string { - return routerURL(path, 3168) +func routerURL(port int, path string) string { + return fmt.Sprintf("http://127.0.0.1:%d%s", port, path) } -func reloadRoutes(optionalPort ...int) { - port := 3168 - if len(optionalPort) > 0 { - port = optionalPort[0] - } - +func reloadRoutes(port int) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, From 46e13a6882f9551b4f0f12ff68577a68953d5fb0 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 19 Aug 2023 22:02:08 +0100 Subject: [PATCH 18/24] Fix a couple of wonky test names. --- integration_tests/reload_api_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integration_tests/reload_api_test.go b/integration_tests/reload_api_test.go index e1cf5981..bf9184ca 100644 --- a/integration_tests/reload_api_test.go +++ b/integration_tests/reload_api_test.go @@ -56,13 +56,13 @@ var _ = Describe("reload API endpoint", func() { }) Describe("healthcheck", func() { - It("should return 200 and sting 'OK' on /healthcheck", func() { + It("should return HTTP 200 OK on GET", func() { resp := doRequest(newRequest("GET", routerURL(apiPort, "/healthcheck"))) Expect(resp.StatusCode).To(Equal(200)) Expect(readBody(resp)).To(Equal("OK")) }) - It("should return 405 for other verbs", func() { + It("should return HTTP 405 Method Not Allowed on POST", func() { resp := doRequest(newRequest("POST", routerURL(apiPort, "/healthcheck"))) Expect(resp.StatusCode).To(Equal(405)) Expect(resp.Header.Get("Allow")).To(Equal("GET")) From 341963e0c67e6067c1163847d3f9ffbd91bec762 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 19 Aug 2023 22:27:40 +0100 Subject: [PATCH 19/24] Remove an unrelated test assertion. In a test that's supposed to check whether reloading the route table works, asserting that the async reload request took < 5 ms is neither helpful nor tasty, unlike this other [recipe for flakes](https://bakewithsarah.com/chocolate-fudge-flake-cake/). Also don't request the reload twice; it's unnecessary and makes the test slightly less useful, e.g. it would fail to catch a bug where the reload fails the first time but succeeds on retry. --- integration_tests/reload_api_test.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/integration_tests/reload_api_test.go b/integration_tests/reload_api_test.go index bf9184ca..ea50f5a6 100644 --- a/integration_tests/reload_api_test.go +++ b/integration_tests/reload_api_test.go @@ -34,14 +34,6 @@ var _ = Describe("reload API endpoint", func() { It("eventually reloads the routes", func() { addRoute("/foo", NewRedirectRoute("/qux", "prefix")) - - start := time.Now() - doRequest(newRequest("POST", routerURL(apiPort, "/reload"))) - end := time.Now() - duration := end.Sub(start) - - Expect(duration.Nanoseconds()).To(BeNumerically("<", 5000000)) - addRoute("/bar", NewRedirectRoute("/qux", "prefix")) doRequest(newRequest("POST", routerURL(apiPort, "/reload"))) From 794070711651da47498aabed2109c9472a486fab Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 19 Aug 2023 22:43:24 +0100 Subject: [PATCH 20/24] More deflaking: allow 1.5 poll intervals for reload. The default poll interval is 2 s, so wait 3. This doesn't slow down the test; gomega.Eventually polls every 10 ms by default. --- integration_tests/reload_api_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integration_tests/reload_api_test.go b/integration_tests/reload_api_test.go index ea50f5a6..36c8f2cd 100644 --- a/integration_tests/reload_api_test.go +++ b/integration_tests/reload_api_test.go @@ -39,11 +39,11 @@ var _ = Describe("reload API endpoint", func() { Eventually(func() int { return routerRequest(routerPort, "/foo").StatusCode - }, time.Second*1).Should(Equal(301)) + }, time.Second*3).Should(Equal(301)) Eventually(func() int { return routerRequest(routerPort, "/bar").StatusCode - }, time.Second*1).Should(Equal(301)) + }, time.Second*3).Should(Equal(301)) }) }) From 73b9cd9b27f13eecc7b0cc24e410f357705333d6 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 19 Aug 2023 23:16:10 +0100 Subject: [PATCH 21/24] Remove /stats API endpoint. `/stats` is disused and useless now that we have a Prometheus gauge for the route count. It didn't use to return anything else. --- integration_tests/reload_api_test.go | 43 ---------------------------- lib/router.go | 10 ------- lib/router_api.go | 25 ---------------- 3 files changed, 78 deletions(-) diff --git a/integration_tests/reload_api_test.go b/integration_tests/reload_api_test.go index 36c8f2cd..0b52cd17 100644 --- a/integration_tests/reload_api_test.go +++ b/integration_tests/reload_api_test.go @@ -61,49 +61,6 @@ var _ = Describe("reload API endpoint", func() { }) }) - Describe("route stats", func() { - - Context("with some routes loaded", func() { - var data map[string]map[string]interface{} - - BeforeEach(func() { - addRoute("/foo", NewRedirectRoute("/bar", "prefix")) - addRoute("/baz", NewRedirectRoute("/qux", "prefix")) - addRoute("/foo", NewRedirectRoute("/bar/baz")) - reloadRoutes(apiPort) - resp := doRequest(newRequest("GET", routerURL(apiPort, "/stats"))) - Expect(resp.StatusCode).To(Equal(200)) - readJSONBody(resp, &data) - }) - - It("should return the number of routes loaded", func() { - Expect(data["routes"]["count"]).To(BeEquivalentTo(3)) - }) - }) - - Context("with no routes", func() { - var data map[string]map[string]interface{} - - BeforeEach(func() { - reloadRoutes(apiPort) - - resp := doRequest(newRequest("GET", routerURL(apiPort, "/stats"))) - Expect(resp.StatusCode).To(Equal(200)) - readJSONBody(resp, &data) - }) - - It("should return the number of routes loaded", func() { - Expect(data["routes"]["count"]).To(BeEquivalentTo(0)) - }) - }) - - It("should return 405 for other verbs", func() { - resp := doRequest(newRequest("POST", routerURL(apiPort, "/stats"))) - Expect(resp.StatusCode).To(Equal(405)) - Expect(resp.Header.Get("Allow")).To(Equal("GET")) - }) - }) - Describe("memory stats", func() { It("should return memory statistics", func() { addRoute("/foo", NewRedirectRoute("/bar", "prefix")) diff --git a/lib/router.go b/lib/router.go index e41003bd..6144cf41 100644 --- a/lib/router.go +++ b/lib/router.go @@ -393,16 +393,6 @@ func (be *Backend) ParseURL() (*url.URL, error) { return url.Parse(backendURL) } -func (rt *Router) RouteStats() (stats map[string]interface{}) { - rt.lock.RLock() - mux := rt.mux - rt.lock.RUnlock() - - stats = make(map[string]interface{}) - stats["count"] = mux.RouteCount() - return -} - func shouldPreserveSegments(route *Route) bool { switch route.RouteType { case RouteTypeExact: diff --git a/lib/router_api.go b/lib/router_api.go index 63f55543..ab0a4fbe 100644 --- a/lib/router_api.go +++ b/lib/router_api.go @@ -2,7 +2,6 @@ package router import ( "encoding/json" - "fmt" "net/http" "runtime" @@ -47,30 +46,6 @@ func NewAPIHandler(rout *Router) (api http.Handler, err error) { } }) - mux.HandleFunc("/stats", func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - w.Header().Set("Allow", http.MethodGet) - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - - stats := make(map[string]map[string]interface{}) - stats["routes"] = rout.RouteStats() - - jsonData, err := json.MarshalIndent(stats, "", " ") - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - _, err = fmt.Fprintln(w, string(jsonData)) - if err != nil { - logWarn(err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - }) - mux.HandleFunc("/memory-stats", func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { w.Header().Set("Allow", http.MethodGet) From d3ad82cf6494938e90cc5a0f498b5c08cd6d85a6 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 19 Aug 2023 23:22:07 +0100 Subject: [PATCH 22/24] Fix dirty read of rt.mux when fetching route count. --- lib/router.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/router.go b/lib/router.go index 6144cf41..03e56545 100644 --- a/lib/router.go +++ b/lib/router.go @@ -235,14 +235,14 @@ func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimesta backends := rt.loadBackends(db.C("backends")) loadRoutes(db.C("routes"), newmux, backends) + routeCount := newmux.RouteCount() rt.lock.Lock() rt.mux = newmux rt.lock.Unlock() - logInfo(fmt.Sprintf("router: reloaded %d routes", rt.mux.RouteCount())) - - routesCountMetric.Set(float64(rt.mux.RouteCount())) + logInfo(fmt.Sprintf("router: reloaded %d routes", routeCount)) + routesCountMetric.Set(float64(routeCount)) } func (rt *Router) getCurrentMongoInstance(db mongoDatabase) (MongoReplicaSetMember, error) { From 832ef3e810ff65b177f5f64af475afeea63c494d Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sun, 20 Aug 2023 12:36:48 +0100 Subject: [PATCH 23/24] Fix typo in test name. --- integration_tests/proxy_function_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/proxy_function_test.go b/integration_tests/proxy_function_test.go index bb3fc002..d7c93fc5 100644 --- a/integration_tests/proxy_function_test.go +++ b/integration_tests/proxy_function_test.go @@ -177,7 +177,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { Expect(ok).To(BeFalse()) }) - It("should add the client IP to X-Forwardrd-For", func() { + It("should add the client IP to X-Forwarded-For", func() { resp := routerRequest(routerPort, "/foo") Expect(resp.StatusCode).To(Equal(200)) From 46d294206cf46b6ba9d2f7750178b8809744338f Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sun, 20 Aug 2023 16:57:59 +0100 Subject: [PATCH 24/24] Fix bad syntax in examples. --- trie/trie.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trie/trie.go b/trie/trie.go index 10fdc9b0..11eda389 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -23,7 +23,7 @@ func NewTrie() *Trie { // and returns the object if the path exists in the Trie, or nil and a status of // false. Example: // -// if res, ok := trie.Get([]string{"foo", "bar"}), ok { +// if res, ok := trie.Get([]string{"foo", "bar"}); ok { // fmt.Println("Value at /foo/bar was", res) // } func (t *Trie) Get(path []string) (entry interface{}, ok bool) { @@ -50,7 +50,7 @@ func (t *Trie) Get(path []string) (entry interface{}, ok bool) { // longest matching prefix is returned. If nothing matches at all, nil and a // status of false is returned. Example: // -// if res, ok := trie.GetLongestPrefix([]string{"foo", "bar"}), ok { +// if res, ok := trie.GetLongestPrefix([]string{"foo", "bar"}); ok { // fmt.Println("Value at /foo/bar was", res) // } func (t *Trie) GetLongestPrefix(path []string) (entry interface{}, ok bool) {