diff --git a/go.mod b/go.mod index adfefaa7..53ebffe4 100644 --- a/go.mod +++ b/go.mod @@ -10,12 +10,16 @@ require ( ) require ( + github.com/andybalholm/brotli v1.1.0 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/klauspost/compress v1.17.9 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.55.0 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 0f972aaf..26ec34a1 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,12 @@ +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= @@ -26,6 +30,10 @@ github.com/urfave/cli-docs/v3 v3.0.0-alpha5 h1:H1oWnR2/GN0dNm2PVylws+GxSOD6YOwW/ github.com/urfave/cli-docs/v3 v3.0.0-alpha5/go.mod h1:AIqom6Q60U4tiqHp41i7+/AB2XHgi1WvQ7jOFlccmZ4= github.com/urfave/cli/v3 v3.0.0-alpha9 h1:P0RMy5fQm1AslQS+XCmy9UknDXctOmG/q/FZkUFnJSo= github.com/urfave/cli/v3 v3.0.0-alpha9/go.mod h1:0kK/RUFHyh+yIKSfWxwheGndfnrvYSmYFVeKCh03ZUc= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.55.0 h1:Zkefzgt6a7+bVKHnu/YaYSOPfNYNisSVBo/unVCf8k8= +github.com/valyala/fasthttp v1.55.0/go.mod h1:NkY9JtkrpPKmgwV3HTaS2HWaJss9RSIsRVfcxxoHiOM= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8= diff --git a/internal/cli/perftest/command.go b/internal/cli/perftest/command.go index 4371d631..9265a0fa 100644 --- a/internal/cli/perftest/command.go +++ b/internal/cli/perftest/command.go @@ -55,7 +55,7 @@ func NewCommand(log *logger.Logger) *cli.Command { //nolint:funlen,gocognit Aliases: []string{"perf", "test"}, Hidden: true, Usage: "Simple performance (load) test for the HTTP server", - Action: func(ctx context.Context, c *cli.Command) error { + Action: func(ctx context.Context, c *cli.Command) error { // TODO: use fasthttp.Client var ( perfCtx, cancel = context.WithTimeout(ctx, c.Duration(durationFlag.Name)) startedAt = time.Now() diff --git a/internal/cli/serve/command.go b/internal/cli/serve/command.go index 18a3c130..9d01d602 100644 --- a/internal/cli/serve/command.go +++ b/internal/cli/serve/command.go @@ -290,7 +290,7 @@ func NewCommand(log *logger.Logger) *cli.Command { //nolint:funlen,gocognit,gocy // Run current command. func (cmd *command) Run(ctx context.Context, log *logger.Logger, cfg *config.Config) error { //nolint:funlen - var srv = appHttp.NewServer(ctx, log) + var srv = appHttp.NewServer(log) if err := srv.Register(cfg); err != nil { return err diff --git a/internal/http/handlers/error_page/code.go b/internal/http/handlers/error_page/code.go index 8dfe7ecb..2a5745dc 100644 --- a/internal/http/handlers/error_page/code.go +++ b/internal/http/handlers/error_page/code.go @@ -1,10 +1,11 @@ package error_page import ( - "net/http" "path/filepath" "strconv" "strings" + + "github.com/valyala/fasthttp" ) // extractCodeFromURL extracts the error code from the given URL. @@ -37,11 +38,15 @@ func extractCodeFromURL(url string) (uint16, bool) { func URLContainsCode(url string) (ok bool) { _, ok = extractCodeFromURL(url); return } //nolint:nlreturn // extractCodeFromHeaders extracts the error code from the given headers. -func extractCodeFromHeaders(headers http.Header) (uint16, bool) { +func extractCodeFromHeaders(headers *fasthttp.RequestHeader) (uint16, bool) { + if headers == nil { + return 0, false + } + // https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/ // HTTP status code returned by the request - if value := headers.Get("X-Code"); len(value) > 0 && len(value) <= 3 { - if code, err := strconv.ParseUint(value, 10, 16); err == nil && code > 0 && code < 999 { + if value := headers.Peek("X-Code"); len(value) > 0 && len(value) <= 3 { + if code, err := strconv.ParseUint(string(value), 10, 16); err == nil && code > 0 && code < 999 { return uint16(code), true } } @@ -50,7 +55,7 @@ func extractCodeFromHeaders(headers http.Header) (uint16, bool) { } // HeadersContainCode checks if the given headers contain an error code. -func HeadersContainCode(headers http.Header) (ok bool) { +func HeadersContainCode(headers *fasthttp.RequestHeader) (ok bool) { _, ok = extractCodeFromHeaders(headers) return diff --git a/internal/http/handlers/error_page/code_test.go b/internal/http/handlers/error_page/code_test.go index 0413d3cc..3f4c6918 100644 --- a/internal/http/handlers/error_page/code_test.go +++ b/internal/http/handlers/error_page/code_test.go @@ -1,10 +1,10 @@ package error_page_test import ( - "net/http" "testing" "github.com/stretchr/testify/assert" + "github.com/valyala/fasthttp" "gh.tarampamp.am/error-pages/internal/http/handlers/error_page" ) @@ -36,18 +36,26 @@ func TestURLContainsCode(t *testing.T) { func TestHeadersContainCode(t *testing.T) { t.Parallel() + var mkHeaders = func(key, value string) *fasthttp.RequestHeader { + var out = new(fasthttp.RequestHeader) + + out.Set(key, value) + + return out + } + for name, _tt := range map[string]struct { - giveHeaders http.Header + giveHeaders *fasthttp.RequestHeader wantOk bool }{ - "with code": {giveHeaders: http.Header{"X-Code": {"404"}}, wantOk: true}, + "with code": {giveHeaders: mkHeaders("X-Code", "404"), wantOk: true}, "empty": {giveHeaders: nil}, - "no code": {giveHeaders: http.Header{"X-Code": {""}}}, - "wrong": {giveHeaders: http.Header{"X-Code": {"foo"}}}, - "too big": {giveHeaders: http.Header{"X-Code": {"1000"}}}, - "too small": {giveHeaders: http.Header{"X-Code": {"0"}}}, - "negative": {giveHeaders: http.Header{"X-Code": {"-1"}}}, + "no code": {giveHeaders: mkHeaders("X-Code", "")}, + "wrong": {giveHeaders: mkHeaders("X-Code", "foo")}, + "too big": {giveHeaders: mkHeaders("X-Code", "1000")}, + "too small": {giveHeaders: mkHeaders("X-Code", "0")}, + "negative": {giveHeaders: mkHeaders("X-Code", "-1")}, } { tt := _tt diff --git a/internal/http/handlers/error_page/format.go b/internal/http/handlers/error_page/format.go index 909f90b3..3be430bf 100644 --- a/internal/http/handlers/error_page/format.go +++ b/internal/http/handlers/error_page/format.go @@ -2,10 +2,11 @@ package error_page import ( "math" - "net/http" "slices" "strconv" "strings" + + "github.com/valyala/fasthttp" ) type preferredFormat = byte @@ -21,10 +22,10 @@ const ( // detectPreferredFormatForClient detects the preferred format for the client based on the headers. // It supports the following headers: Content-Type, Accept, X-Format. // If the headers are not set or the format is not recognized, it returns unknownFormat. -func detectPreferredFormatForClient(headers http.Header) preferredFormat { //nolint:funlen,gocognit +func detectPreferredFormatForClient(headers *fasthttp.RequestHeader) preferredFormat { //nolint:funlen,gocognit var contentType, accept string - if contentTypeHeader := strings.TrimSpace(headers.Get("Content-Type")); contentTypeHeader != "" { //nolint:nestif + if contentTypeHeader := strings.TrimSpace(string(headers.Peek("Content-Type"))); contentTypeHeader != "" { //nolint:nestif,lll // https://developer.mozilla.org/docs/Web/HTTP/Headers/Content-Type // text/html; charset=utf-8 // multipart/form-data; boundary=something @@ -38,11 +39,11 @@ func detectPreferredFormatForClient(headers http.Header) preferredFormat { //nol // take the whole value contentType = contentTypeHeader } - } else if xFormatHeader := strings.TrimSpace(headers.Get("X-Format")); xFormatHeader != "" { + } else if xFormatHeader := strings.TrimSpace(string(headers.Peek("X-Format"))); xFormatHeader != "" { // https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/ // Value of the `Accept` header sent by the client accept = xFormatHeader - } else if acceptHeader := strings.TrimSpace(headers.Get("Accept")); acceptHeader != "" { + } else if acceptHeader := strings.TrimSpace(string(headers.Peek("Accept"))); acceptHeader != "" { // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept // text/html, application/xhtml+xml, application/xml;q=0.9, image/webp, */*;q=0.8 // text/html diff --git a/internal/http/handlers/error_page/format_test.go b/internal/http/handlers/error_page/format_test.go index 06fa20d1..c8b31fcb 100644 --- a/internal/http/handlers/error_page/format_test.go +++ b/internal/http/handlers/error_page/format_test.go @@ -1,80 +1,80 @@ package error_page import ( - "net/http" "testing" "github.com/stretchr/testify/assert" + "github.com/valyala/fasthttp" ) func Test_detectPreferredFormatForClient(t *testing.T) { t.Parallel() for name, _tt := range map[string]struct { - giveHeaders http.Header + giveHeaders map[string][]string wantFormat preferredFormat }{ "content type json": { - giveHeaders: http.Header{"Content-Type": {"application/jSoN"}}, + giveHeaders: map[string][]string{"Content-Type": {"application/jSoN"}}, wantFormat: jsonFormat, }, "content type xml": { - giveHeaders: http.Header{"Content-Type": {"application/xml; charset=UTF-8"}}, + giveHeaders: map[string][]string{"Content-Type": {"application/xml; charset=UTF-8"}}, wantFormat: xmlFormat, }, "content type html": { - giveHeaders: http.Header{"Content-Type": {"text/hTmL; charset=utf-8"}}, + giveHeaders: map[string][]string{"Content-Type": {"text/hTmL; charset=utf-8"}}, wantFormat: htmlFormat, }, "content type plain": { - giveHeaders: http.Header{"Content-Type": {"text/plaIN"}}, + giveHeaders: map[string][]string{"Content-Type": {"text/plaIN"}}, wantFormat: plainTextFormat, }, "accept json": { - giveHeaders: http.Header{"Accept": {"application/jsoN,*/*;q=0.8"}}, + giveHeaders: map[string][]string{"Accept": {"application/jsoN,*/*;q=0.8"}}, wantFormat: jsonFormat, }, "accept xml, depends on weight": { - giveHeaders: http.Header{"Accept": {"text/html;q=0.5,application/xhtml+xml;q=0.9,application/xml;q=1,*/*;q=0.8"}}, + giveHeaders: map[string][]string{"Accept": {"text/html;q=0.5,application/xhtml+xml;q=0.9,application/xml;q=1,*/*;q=0.8"}}, wantFormat: xmlFormat, }, "accept json, depends on weight": { - giveHeaders: http.Header{"Accept": {"application/jsoN,*/*;q=0.8"}}, + giveHeaders: map[string][]string{"Accept": {"application/jsoN,*/*;q=0.8"}}, wantFormat: jsonFormat, }, "accept xml": { - giveHeaders: http.Header{"Accept": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}}, + giveHeaders: map[string][]string{"Accept": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}}, wantFormat: xmlFormat, }, "accept html": { - giveHeaders: http.Header{"Accept": {"text/html, application/xhtml+xml, application/xml;q=0.9, image/avif, image/webp, */*;q=0.8"}}, + giveHeaders: map[string][]string{"Accept": {"text/html, application/xhtml+xml, application/xml;q=0.9, image/avif, image/webp, */*;q=0.8"}}, wantFormat: htmlFormat, }, "accept plain": { - giveHeaders: http.Header{"Accept": {"text/plaiN,text/html,application/xml;q=0.9,,,*/*;q=0.8"}}, + giveHeaders: map[string][]string{"Accept": {"text/plaiN,text/html,application/xml;q=0.9,,,*/*;q=0.8"}}, wantFormat: plainTextFormat, }, "accept json, weighted values only": { - giveHeaders: http.Header{"Accept": {"application/jsoN;Q=0.1,text/html;q=1.1,application/xml;q=-1,*/*;q=0.8"}}, + giveHeaders: map[string][]string{"Accept": {"application/jsoN;Q=0.1,text/html;q=1.1,application/xml;q=-1,*/*;q=0.8"}}, wantFormat: jsonFormat, }, "x-format json, depends on weight": { - giveHeaders: http.Header{"X-Format": {"application/jsoN,*/*;q=0.8"}}, + giveHeaders: map[string][]string{"X-Format": {"application/jsoN,*/*;q=0.8"}}, wantFormat: jsonFormat, }, "x-format xml": { - giveHeaders: http.Header{"X-Format": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}}, + giveHeaders: map[string][]string{"X-Format": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}}, wantFormat: xmlFormat, }, "content type has priority over accept": { - giveHeaders: http.Header{"Content-Type": {"text/plain"}, "Accept": {"application/xml"}}, + giveHeaders: map[string][]string{"Content-Type": {"text/plain"}, "Accept": {"application/xml"}}, wantFormat: plainTextFormat, }, "accept has priority over x-format": { - giveHeaders: http.Header{"Accept": {"application/xml"}, "X-Format": {"text/plain"}}, + giveHeaders: map[string][]string{"Accept": {"application/xml"}, "X-Format": {"text/plain"}}, wantFormat: plainTextFormat, }, @@ -82,25 +82,33 @@ func Test_detectPreferredFormatForClient(t *testing.T) { giveHeaders: nil, }, "empty content type": { - giveHeaders: http.Header{"Content-Type": {" "}}, + giveHeaders: map[string][]string{"Content-Type": {" "}}, }, "wrong content type": { - giveHeaders: http.Header{"Content-Type": {"multipart/form-data; boundary=something"}}, + giveHeaders: map[string][]string{"Content-Type": {"multipart/form-data; boundary=something"}}, }, "wrong accept": { - giveHeaders: http.Header{"Accept": {";q=foobar,bar/baz;;;;;application/xml"}}, + giveHeaders: map[string][]string{"Accept": {";q=foobar,bar/baz;;;;;application/xml"}}, }, "none on invalid input": { - giveHeaders: http.Header{"Content-Type": {"foo/bar; charset=utf-8"}, "Accept": {"foo/bar; charset=utf-8"}}, + giveHeaders: map[string][]string{"Content-Type": {"foo/bar; charset=utf-8"}, "Accept": {"foo/bar; charset=utf-8"}}, }, "completely unknown": { - giveHeaders: http.Header{"Content-Type": {"😀"}, "Accept": {"😄"}, "X-Format": {"😍"}}, + giveHeaders: map[string][]string{"Content-Type": {"😀"}, "Accept": {"😄"}, "X-Format": {"😍"}}, }, } { tt := _tt t.Run(name, func(t *testing.T) { - assert.Equal(t, tt.wantFormat, detectPreferredFormatForClient(tt.giveHeaders)) + var headers = new(fasthttp.RequestHeader) + + for key, values := range tt.giveHeaders { + for _, value := range values { + headers.Add(key, value) + } + } + + assert.Equal(t, tt.wantFormat, detectPreferredFormatForClient(headers)) }) } } diff --git a/internal/http/handlers/error_page/handler.go b/internal/http/handlers/error_page/handler.go index d70871d2..74d46ea7 100644 --- a/internal/http/handlers/error_page/handler.go +++ b/internal/http/handlers/error_page/handler.go @@ -7,21 +7,24 @@ import ( "sync/atomic" "time" + "github.com/valyala/fasthttp" + "gh.tarampamp.am/error-pages/internal/config" "gh.tarampamp.am/error-pages/internal/logger" "gh.tarampamp.am/error-pages/internal/template" ) // New creates a new handler that returns an error page with the specified status code and format. -func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,gocognit,gocyclo - const contentTypeHeader = "Content-Type" - - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var code uint16 +func New(cfg *config.Config, log *logger.Logger) fasthttp.RequestHandler { //nolint:funlen,gocognit,gocyclo + return func(ctx *fasthttp.RequestCtx) { + var ( + reqHeaders = &ctx.Request.Header + code uint16 + ) - if fromUrl, okUrl := extractCodeFromURL(r.URL.Path); okUrl { + if fromUrl, okUrl := extractCodeFromURL(string(ctx.RequestURI())); okUrl { code = fromUrl - } else if fromHeader, okHeaders := extractCodeFromHeaders(r.Header); okHeaders { + } else if fromHeader, okHeaders := extractCodeFromHeaders(reqHeaders); okHeaders { code = fromHeader } else { code = cfg.DefaultCodeToRender @@ -35,23 +38,23 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen, httpCode = http.StatusOK } - var format = detectPreferredFormatForClient(r.Header) + var format = detectPreferredFormatForClient(reqHeaders) { // deal with the headers switch format { case jsonFormat: - w.Header().Set(contentTypeHeader, "application/json; charset=utf-8") + ctx.SetContentType("application/json; charset=utf-8") case xmlFormat: - w.Header().Set(contentTypeHeader, "application/xml; charset=utf-8") + ctx.SetContentType("application/xml; charset=utf-8") case htmlFormat: - w.Header().Set(contentTypeHeader, "text/html; charset=utf-8") + ctx.SetContentType("text/html; charset=utf-8") default: - w.Header().Set(contentTypeHeader, "text/plain; charset=utf-8") // plainTextFormat as default + ctx.SetContentType("text/plain; charset=utf-8") // plainTextFormat as default } // https://developers.google.com/search/docs/crawling-indexing/robots-meta-tag // disallow indexing of the error pages - w.Header().Set("X-Robots-Tag", "noindex") + ctx.Response.Header.Set("X-Robots-Tag", "noindex") switch code { case http.StatusRequestTimeout, http.StatusTooEarly, http.StatusTooManyRequests, @@ -59,18 +62,18 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen, http.StatusGatewayTimeout: // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After // tell the client (search crawler) to retry the request after 120 seconds - w.Header().Set("Retry-After", "120") + ctx.Response.Header.Set("Retry-After", "120") } // proxy the headers from the incoming request to the error page response if they are defined in the config for _, proxyHeader := range cfg.ProxyHeaders { - if value := r.Header.Get(proxyHeader); value != "" { - w.Header().Set(proxyHeader, value) + if value := reqHeaders.Peek(proxyHeader); len(value) > 0 { + ctx.Response.Header.SetBytesV(proxyHeader, value) } } } - w.WriteHeader(httpCode) + ctx.SetStatusCode(httpCode) // prepare the template properties for rendering var tplProps = template.Props{ @@ -81,14 +84,14 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen, //nolint:lll if cfg.ShowDetails { // https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/ - tplProps.OriginalURI = r.Header.Get("X-Original-URI") // (ingress-nginx) URI that caused the error - tplProps.Namespace = r.Header.Get("X-Namespace") // (ingress-nginx) namespace where the backend Service is located - tplProps.IngressName = r.Header.Get("X-Ingress-Name") // (ingress-nginx) name of the Ingress where the backend is defined - tplProps.ServiceName = r.Header.Get("X-Service-Name") // (ingress-nginx) name of the Service backing the backend - tplProps.ServicePort = r.Header.Get("X-Service-Port") // (ingress-nginx) port number of the Service backing the backend - tplProps.RequestID = r.Header.Get("X-Request-Id") // (ingress-nginx) unique ID that identifies the request - same as for backend service - tplProps.ForwardedFor = r.Header.Get("X-Forwarded-For") // the value of the `X-Forwarded-For` header - tplProps.Host = r.Host // the value of the `Host` header + tplProps.OriginalURI = string(reqHeaders.Peek("X-Original-URI")) // (ingress-nginx) URI that caused the error + tplProps.Namespace = string(reqHeaders.Peek("X-Namespace")) // (ingress-nginx) namespace where the backend Service is located + tplProps.IngressName = string(reqHeaders.Peek("X-Ingress-Name")) // (ingress-nginx) name of the Ingress where the backend is defined + tplProps.ServiceName = string(reqHeaders.Peek("X-Service-Name")) // (ingress-nginx) name of the Service backing the backend + tplProps.ServicePort = string(reqHeaders.Peek("X-Service-Port")) // (ingress-nginx) port number of the Service backing the backend + tplProps.RequestID = string(reqHeaders.Peek("X-Request-Id")) // (ingress-nginx) unique ID that identifies the request - same as for backend service + tplProps.ForwardedFor = string(reqHeaders.Peek("X-Forwarded-For")) // the value of the `X-Forwarded-For` header + tplProps.Host = string(reqHeaders.Peek("Host")) // the value of the `Host` header } // try to find the code message and description in the config and if not - use the standard status text or fallback @@ -105,18 +108,18 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen, case format == jsonFormat && cfg.Formats.JSON != "": if content, err := template.Render(cfg.Formats.JSON, tplProps); err != nil { j, _ := json.Marshal(fmt.Sprintf("Failed to render the JSON template: %s", err.Error())) - write(w, log, j) + write(ctx, log, j) } else { - write(w, log, content) + write(ctx, log, content) } case format == xmlFormat && cfg.Formats.XML != "": if content, err := template.Render(cfg.Formats.XML, tplProps); err != nil { - write(w, log, fmt.Sprintf( + write(ctx, log, fmt.Sprintf( "\nFailed to render the XML template: %s", err.Error(), )) } else { - write(w, log, content) + write(ctx, log, content) } case format == htmlFormat: @@ -125,16 +128,16 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen, if tpl, found := cfg.Templates.Get(templateName); found { if content, err := template.Render(tpl, tplProps); err != nil { // TODO: add GZIP compression for the HTML content support - write(w, log, fmt.Sprintf( + write(ctx, log, fmt.Sprintf( "\nFailed to render the HTML template %s: %s", templateName, err.Error(), )) } else { - write(w, log, content) + write(ctx, log, content) } } else { - write(w, log, fmt.Sprintf( + write(ctx, log, fmt.Sprintf( "\nTemplate %s not found and cannot be used", templateName, )) } @@ -142,18 +145,18 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen, default: // plainTextFormat as default if cfg.Formats.PlainText != "" { if content, err := template.Render(cfg.Formats.PlainText, tplProps); err != nil { - write(w, log, fmt.Sprintf("Failed to render the PlainText template: %s", err.Error())) + write(ctx, log, fmt.Sprintf("Failed to render the PlainText template: %s", err.Error())) } else { - write(w, log, content) + write(ctx, log, content) } } else { - write(w, log, `The requested content format is not supported. + write(ctx, log, `The requested content format is not supported. Please create an issue on the project's GitHub page to request support for this format. Supported formats: JSON, XML, HTML, Plain Text`) } } - }) + } } var ( @@ -204,7 +207,7 @@ func templateToUse(cfg *config.Config) string { } // write the content to the response writer and log the error if any. -func write[T string | []byte](w http.ResponseWriter, log *logger.Logger, content T) { +func write[T string | []byte](ctx *fasthttp.RequestCtx, log *logger.Logger, content T) { var data []byte if s, ok := any(content).(string); ok { @@ -213,7 +216,7 @@ func write[T string | []byte](w http.ResponseWriter, log *logger.Logger, content data = any(content).([]byte) } - if _, err := w.Write(data); err != nil && log != nil { + if _, err := ctx.Write(data); err != nil && log != nil { log.Error("failed to write the response body", logger.String("content", string(data)), logger.Error(err), diff --git a/internal/http/handlers/error_page/handler_test.go b/internal/http/handlers/error_page/handler_test.go index e6e608f3..2bbbbad7 100644 --- a/internal/http/handlers/error_page/handler_test.go +++ b/internal/http/handlers/error_page/handler_test.go @@ -2,13 +2,14 @@ package error_page_test import ( "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gh.tarampamp.am/error-pages/internal/config" "gh.tarampamp.am/error-pages/internal/http/handlers/error_page" + "gh.tarampamp.am/error-pages/internal/http/httptest" "gh.tarampamp.am/error-pages/internal/logger" ) @@ -26,7 +27,7 @@ func TestHandler(t *testing.T) { }{ "common, plain text": { giveConfig: func() *config.Config { cfg := config.New(); return &cfg }, - giveUrl: "/", + giveUrl: "http://testing/", giveHeaders: map[string]string{"Content-Type": "text/plain"}, wantStatusCode: http.StatusOK, @@ -41,7 +42,7 @@ func TestHandler(t *testing.T) { return &cfg }, - giveUrl: "/", + giveUrl: "http://testing/", giveHeaders: map[string]string{"X-Format": "text/html", "X-Code": "407"}, wantStatusCode: http.StatusOK, @@ -60,7 +61,7 @@ func TestHandler(t *testing.T) { return &cfg }, - giveUrl: "/503.html", + giveUrl: "http://testing/503.html", giveHeaders: map[string]string{"Accept": "application/json", "X-FooBar": "baz"}, wantStatusCode: http.StatusServiceUnavailable, @@ -78,7 +79,7 @@ func TestHandler(t *testing.T) { return &cfg }, - giveUrl: "/500", + giveUrl: "http://testing/500", giveHeaders: map[string]string{"Accept": "application/xml", "X-FooBar": "baz"}, wantStatusCode: http.StatusOK, @@ -96,7 +97,7 @@ func TestHandler(t *testing.T) { return &cfg }, - giveUrl: "/503", + giveUrl: "http://example.com/503", giveHeaders: map[string]string{ "Accept": "application/json", "X-Original-URI": "/foo/bar", @@ -106,7 +107,6 @@ func TestHandler(t *testing.T) { "X-Service-Port": "666", "X-Request-ID": "req-id-777", "X-Forwarded-For": "123.123.123.123:12312", - "Host": "example.com", }, wantStatusCode: http.StatusOK, @@ -133,7 +133,7 @@ func TestHandler(t *testing.T) { return &cfg }, - giveUrl: "/100", + giveUrl: "http://testing/100", giveHeaders: map[string]string{"Accept": "application/json"}, wantStatusCode: http.StatusOK, @@ -148,7 +148,7 @@ func TestHandler(t *testing.T) { return &cfg }, - giveUrl: "/1", + giveUrl: "http://testing/1", giveHeaders: map[string]string{"Accept": "application/json"}, wantStatusCode: http.StatusOK, @@ -159,31 +159,30 @@ func TestHandler(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - var ( - req = httptest.NewRequest(http.MethodGet, tt.giveUrl, http.NoBody) - handler = error_page.New(tt.giveConfig(), logger.NewNop()) - rr = httptest.NewRecorder() - ) + var handler = error_page.New(tt.giveConfig(), logger.NewNop()) + + req, reqErr := http.NewRequest(http.MethodGet, tt.giveUrl, http.NoBody) + require.NoError(t, reqErr) for k, v := range tt.giveHeaders { req.Header.Set(k, v) } - handler.ServeHTTP(rr, req) - - assert.Equal(t, rr.Code, tt.wantStatusCode) + httptest.HandleFastRequest(t, handler, req, func(status int, body string, headers http.Header) { + assert.Equal(t, tt.wantStatusCode, status) - for hName, hWant := range tt.wantHeaders { - for hGot := range rr.Header() { - if hGot == hName { - assert.Contains(t, hWant, rr.Header().Get(hGot)) + for hName, hWant := range tt.wantHeaders { + for hGot := range headers { + if hGot == hName { + assert.Contains(t, hWant, headers.Get(hGot)) + } } } - } - for _, wantBodyInclude := range tt.wantBodyIncludes { - assert.Contains(t, rr.Body.String(), wantBodyInclude) - } + for _, wantBodyInclude := range tt.wantBodyIncludes { + assert.Contains(t, body, wantBodyInclude) + } + }) }) } } @@ -207,19 +206,17 @@ func TestRotationModeOnEachRequest(t *testing.T) { ) for range 300 { - var ( - req = httptest.NewRequest(http.MethodGet, "/", http.NoBody) - rr = httptest.NewRecorder() - ) + req, reqErr := http.NewRequest(http.MethodGet, "http://testing/", http.NoBody) + require.NoError(t, reqErr) req.Header.Set("Accept", "text/html") - handler.ServeHTTP(rr, req) - - if lastResponseBody != rr.Body.String() { - changedTimes++ - lastResponseBody = rr.Body.String() - } + httptest.HandleFastRequest(t, handler, req, func(status int, body string, headers http.Header) { + if lastResponseBody != body { + changedTimes++ + lastResponseBody = body + } + }) } assert.True(t, changedTimes > 30, "the template should be changed at least 30 times") diff --git a/internal/http/handlers/live/handler.go b/internal/http/handlers/live/handler.go index acf41644..a2f2dd33 100644 --- a/internal/http/handlers/live/handler.go +++ b/internal/http/handlers/live/handler.go @@ -2,24 +2,29 @@ package live import ( "net/http" + + "github.com/valyala/fasthttp" ) // New creates a new handler that returns "OK" for GET and HEAD requests. -func New() http.Handler { - var body = []byte("OK\n") +func New() fasthttp.RequestHandler { + var ( + body = []byte("OK\n") + notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n" + ) - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(body) + return func(ctx *fasthttp.RequestCtx) { + switch string(ctx.Method()) { + case fasthttp.MethodGet: + ctx.SetContentType("text/plain; charset=utf-8") + ctx.SetStatusCode(http.StatusOK) + _, _ = ctx.Write(body) - case http.MethodHead: - w.WriteHeader(http.StatusOK) + case fasthttp.MethodHead: + ctx.SetStatusCode(http.StatusOK) default: - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + ctx.Error(notAllowed, http.StatusMethodNotAllowed) } - }) + } } diff --git a/internal/http/handlers/live/handler_test.go b/internal/http/handlers/live/handler_test.go index ce04b065..8cfc47f9 100644 --- a/internal/http/handlers/live/handler_test.go +++ b/internal/http/handlers/live/handler_test.go @@ -2,43 +2,37 @@ package live_test import ( "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "gh.tarampamp.am/error-pages/internal/http/handlers/live" + "gh.tarampamp.am/error-pages/internal/http/httptest" ) func TestServeHTTP(t *testing.T) { t.Parallel() - var handler = live.New() + var ( + handler = live.New() + url = "http://testing" + body = http.NoBody + ) t.Run("get", func(t *testing.T) { - var ( - req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody) - rr = httptest.NewRecorder() - ) - - handler.ServeHTTP(rr, req) - - assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8") - assert.Equal(t, rr.Code, http.StatusOK) - assert.Equal(t, "OK\n", rr.Body.String()) + httptest.HandleFast(t, handler, http.MethodGet, url, body, func(status int, body string, headers http.Header) { + assert.Equal(t, http.StatusOK, status) + assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type")) + assert.Equal(t, "OK\n", body) + }) }) t.Run("head", func(t *testing.T) { - var ( - req = httptest.NewRequest(http.MethodHead, "http://testing", http.NoBody) - rr = httptest.NewRecorder() - ) - - handler.ServeHTTP(rr, req) - - assert.Equal(t, rr.Code, http.StatusOK) - assert.Empty(t, rr.Header().Get("Content-Type")) - assert.Empty(t, rr.Body.Bytes()) + httptest.HandleFast(t, handler, http.MethodHead, url, body, func(status int, body string, headers http.Header) { + assert.Equal(t, http.StatusOK, status) + assert.Empty(t, headers.Get("Content-Type")) + assert.Empty(t, body) + }) }) t.Run("method not allowed", func(t *testing.T) { @@ -48,16 +42,11 @@ func TestServeHTTP(t *testing.T) { http.MethodPost, http.MethodPut, } { - var ( - req = httptest.NewRequest(method, "http://testing", http.NoBody) - rr = httptest.NewRecorder() - ) - - handler.ServeHTTP(rr, req) - - assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8") - assert.Equal(t, rr.Code, http.StatusMethodNotAllowed) - assert.Equal(t, "Method Not Allowed\n", rr.Body.String()) + httptest.HandleFast(t, handler, method, url, body, func(status int, body string, headers http.Header) { + assert.Equal(t, http.StatusMethodNotAllowed, status) + assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type")) + assert.Equal(t, "Method Not Allowed\n", body) + }) } }) } diff --git a/internal/http/handlers/static/handler.go b/internal/http/handlers/static/handler.go index 5d572281..815a8f4c 100644 --- a/internal/http/handlers/static/handler.go +++ b/internal/http/handlers/static/handler.go @@ -3,25 +3,29 @@ package static import ( _ "embed" "net/http" + + "github.com/valyala/fasthttp" ) //go:embed favicon.ico var Favicon []byte // New creates a new handler that returns the provided content for GET and HEAD requests. -func New(content []byte) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - w.Header().Set("Content-Type", http.DetectContentType(content)) - w.WriteHeader(http.StatusOK) - _, _ = w.Write(content) +func New(content []byte) fasthttp.RequestHandler { + var notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n" + + return func(ctx *fasthttp.RequestCtx) { + switch string(ctx.Method()) { + case fasthttp.MethodGet: + ctx.SetContentType(http.DetectContentType(content)) + ctx.SetStatusCode(http.StatusOK) + _, _ = ctx.Write(content) - case http.MethodHead: - w.WriteHeader(http.StatusOK) + case fasthttp.MethodHead: + ctx.SetStatusCode(http.StatusOK) default: - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + ctx.Error(notAllowed, http.StatusMethodNotAllowed) } - }) + } } diff --git a/internal/http/handlers/static/handler_test.go b/internal/http/handlers/static/handler_test.go index e1fde0ba..bbdffcd7 100644 --- a/internal/http/handlers/static/handler_test.go +++ b/internal/http/handlers/static/handler_test.go @@ -2,43 +2,37 @@ package static_test import ( "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "gh.tarampamp.am/error-pages/internal/http/handlers/static" + "gh.tarampamp.am/error-pages/internal/http/httptest" ) func TestServeHTTP(t *testing.T) { t.Parallel() - var handler = static.New([]byte{1, 2, 3}) + var ( + handler = static.New([]byte{1, 2, 3}) + url = "http://testing" + body = http.NoBody + ) t.Run("get", func(t *testing.T) { - var ( - req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody) - rr = httptest.NewRecorder() - ) - - handler.ServeHTTP(rr, req) - - assert.Equal(t, rr.Header().Get("Content-Type"), "application/octet-stream") - assert.Equal(t, rr.Code, http.StatusOK) - assert.Equal(t, rr.Body.Bytes(), []byte{1, 2, 3}) + httptest.HandleFast(t, handler, http.MethodGet, url, body, func(status int, body string, headers http.Header) { + assert.Equal(t, http.StatusOK, status) + assert.Equal(t, "application/octet-stream", headers.Get("Content-Type")) + assert.Equal(t, []byte{1, 2, 3}, []byte(body)) + }) }) t.Run("head", func(t *testing.T) { - var ( - req = httptest.NewRequest(http.MethodHead, "http://testing", http.NoBody) - rr = httptest.NewRecorder() - ) - - handler.ServeHTTP(rr, req) - - assert.Equal(t, rr.Code, http.StatusOK) - assert.Empty(t, rr.Header().Get("Content-Type")) - assert.Empty(t, rr.Body.Bytes()) + httptest.HandleFast(t, handler, http.MethodHead, url, body, func(status int, body string, headers http.Header) { + assert.Equal(t, http.StatusOK, status) + assert.Empty(t, headers.Get("Content-Type")) + assert.Empty(t, body) + }) }) t.Run("method not allowed", func(t *testing.T) { @@ -48,16 +42,11 @@ func TestServeHTTP(t *testing.T) { http.MethodPost, http.MethodPut, } { - var ( - req = httptest.NewRequest(method, "http://testing", http.NoBody) - rr = httptest.NewRecorder() - ) - - handler.ServeHTTP(rr, req) - - assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8") - assert.Equal(t, rr.Code, http.StatusMethodNotAllowed) - assert.Equal(t, "Method Not Allowed\n", rr.Body.String()) + httptest.HandleFast(t, handler, method, url, body, func(status int, body string, headers http.Header) { + assert.Equal(t, http.StatusMethodNotAllowed, status) + assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type")) + assert.Equal(t, "Method Not Allowed\n", body) + }) } }) } @@ -65,16 +54,15 @@ func TestServeHTTP(t *testing.T) { func TestServeHTTP_Favicon(t *testing.T) { t.Parallel() - var ( - handler = static.New(static.Favicon) - - req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody) - rr = httptest.NewRecorder() + httptest.HandleFast(t, + static.New(static.Favicon), + http.MethodGet, + "http://testing", + http.NoBody, + func(status int, body string, headers http.Header) { + assert.Equal(t, http.StatusOK, status) + assert.Equal(t, "image/x-icon", headers.Get("Content-Type")) + assert.Equal(t, static.Favicon, []byte(body)) + }, ) - - handler.ServeHTTP(rr, req) - - assert.Equal(t, rr.Header().Get("Content-Type"), "image/x-icon") - assert.Equal(t, rr.Code, http.StatusOK) - assert.Equal(t, rr.Body.Bytes(), static.Favicon) } diff --git a/internal/http/handlers/version/handler.go b/internal/http/handlers/version/handler.go index 6e4d92da..ecd66a0f 100644 --- a/internal/http/handlers/version/handler.go +++ b/internal/http/handlers/version/handler.go @@ -4,28 +4,32 @@ import ( "encoding/json" "net/http" "strings" + + "github.com/valyala/fasthttp" ) // New creates a handler that returns the version of the service in JSON format. -func New(ver string) http.Handler { +func New(ver string) fasthttp.RequestHandler { var body, _ = json.Marshal(struct { //nolint:errchkjson Version string `json:"version"` }{ Version: strings.TrimSpace(ver), }) - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(body) + var notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n" + + return func(ctx *fasthttp.RequestCtx) { + switch string(ctx.Method()) { + case fasthttp.MethodGet: + ctx.SetContentType("application/json; charset=utf-8") + ctx.SetStatusCode(http.StatusOK) + _, _ = ctx.Write(body) - case http.MethodHead: - w.WriteHeader(http.StatusOK) + case fasthttp.MethodHead: + ctx.SetStatusCode(http.StatusOK) default: - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + ctx.Error(notAllowed, http.StatusMethodNotAllowed) } - }) + } } diff --git a/internal/http/handlers/version/handler_test.go b/internal/http/handlers/version/handler_test.go index d8dea46b..0b50f269 100644 --- a/internal/http/handlers/version/handler_test.go +++ b/internal/http/handlers/version/handler_test.go @@ -2,43 +2,37 @@ package version_test import ( "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "gh.tarampamp.am/error-pages/internal/http/handlers/version" + "gh.tarampamp.am/error-pages/internal/http/httptest" ) func TestServeHTTP(t *testing.T) { t.Parallel() - var handler = version.New("\t\n foo@bar ") + var ( + handler = version.New("\t\n foo@bar ") + url = "http://testing" + body = http.NoBody + ) t.Run("get", func(t *testing.T) { - var ( - req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody) - rr = httptest.NewRecorder() - ) - - handler.ServeHTTP(rr, req) - - assert.Equal(t, rr.Header().Get("Content-Type"), "application/json; charset=utf-8") - assert.Equal(t, rr.Code, http.StatusOK) - assert.Equal(t, rr.Body.String(), `{"version":"foo@bar"}`) + httptest.HandleFast(t, handler, http.MethodGet, url, body, func(status int, body string, headers http.Header) { + assert.Equal(t, http.StatusOK, status) + assert.Equal(t, "application/json; charset=utf-8", headers.Get("Content-Type")) + assert.Equal(t, `{"version":"foo@bar"}`, body) + }) }) t.Run("head", func(t *testing.T) { - var ( - req = httptest.NewRequest(http.MethodHead, "http://testing", http.NoBody) - rr = httptest.NewRecorder() - ) - - handler.ServeHTTP(rr, req) - - assert.Equal(t, rr.Code, http.StatusOK) - assert.Empty(t, rr.Header().Get("Content-Type")) - assert.Empty(t, rr.Body.Bytes()) + httptest.HandleFast(t, handler, http.MethodHead, url, body, func(status int, body string, headers http.Header) { + assert.Equal(t, http.StatusOK, status) + assert.Empty(t, headers.Get("Content-Type")) + assert.Empty(t, body) + }) }) t.Run("method not allowed", func(t *testing.T) { @@ -48,16 +42,11 @@ func TestServeHTTP(t *testing.T) { http.MethodPost, http.MethodPut, } { - var ( - req = httptest.NewRequest(method, "http://testing", http.NoBody) - rr = httptest.NewRecorder() - ) - - handler.ServeHTTP(rr, req) - - assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8") - assert.Equal(t, rr.Code, http.StatusMethodNotAllowed) - assert.Equal(t, "Method Not Allowed\n", rr.Body.String()) + httptest.HandleFast(t, handler, method, url, body, func(status int, body string, headers http.Header) { + assert.Equal(t, http.StatusMethodNotAllowed, status) + assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type")) + assert.Equal(t, "Method Not Allowed\n", body) + }) } }) } diff --git a/internal/http/httptest/httptest.go b/internal/http/httptest/httptest.go index ae806125..32026629 100644 --- a/internal/http/httptest/httptest.go +++ b/internal/http/httptest/httptest.go @@ -1 +1,69 @@ +// Package httptest provides utilities for (fast-)HTTP testing. package httptest + +import ( + "context" + "io" + "net" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttputil" +) + +// HandleFastRequest serves http request using provided fasthttp handler and HTTP request. +func HandleFastRequest( + t *testing.T, + handler fasthttp.RequestHandler, + req *http.Request, + check func(status int, body string, _ http.Header), +) { + t.Helper() + + // create in-memory listener + var ln = fasthttputil.NewInmemoryListener() + defer func() { require.NoError(t, ln.Close()) }() + + // start fasthttp server + go func() { require.NoError(t, fasthttp.Serve(ln, handler)) }() + + // send http request + resp, respErr := (&http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return ln.Dial() }, + }, + }).Do(req) + require.NoError(t, respErr) + + // close response body after the test + defer func() { assert.NoError(t, resp.Body.Close()) }() + + // read response body + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + // check the response + check(resp.StatusCode, string(respBody), resp.Header) +} + +// HandleFast serves http request using provided fasthttp handler. +func HandleFast( + t *testing.T, + handler fasthttp.RequestHandler, + method string, + url string, + body io.Reader, + check func(status int, body string, _ http.Header), +) { + t.Helper() + + // create http request + req, reqErr := http.NewRequest(method, url, body) + require.NoError(t, reqErr) + + // serve http request + HandleFastRequest(t, handler, req, check) +} diff --git a/internal/http/middleware/logreq/middleware.go b/internal/http/middleware/logreq/middleware.go index 467f3e45..9d2caa25 100644 --- a/internal/http/middleware/logreq/middleware.go +++ b/internal/http/middleware/logreq/middleware.go @@ -1,20 +1,24 @@ package logreq import ( - "net/http" "time" + "github.com/valyala/fasthttp" + "gh.tarampamp.am/error-pages/internal/logger" ) -// New creates a middleware for [http.ServeMux] that logs every incoming request. +// New creates a middleware that logs every incoming request. // // The skipper function should return true if the request should be skipped. It's ok to pass nil. -func New(log *logger.Logger, skipper func(*http.Request) bool) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if skipper != nil && skipper(r) { - next.ServeHTTP(w, r) +func New( + log *logger.Logger, + skipper func(*fasthttp.RequestCtx) bool, +) func(fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + if skipper != nil && skipper(ctx) { + next(ctx) return } @@ -23,27 +27,35 @@ func New(log *logger.Logger, skipper func(*http.Request) bool) func(http.Handler defer func() { var fields = []logger.Attr{ - logger.String("useragent", r.UserAgent()), - logger.String("method", r.Method), - logger.String("url", r.URL.String()), - logger.String("referer", r.Referer()), - logger.String("content type", w.Header().Get("Content-Type")), - logger.String("remote addr", r.RemoteAddr), - logger.String("method", r.Method), + logger.Int("status code", ctx.Response.StatusCode()), + logger.String("useragent", string(ctx.UserAgent())), + logger.String("method", string(ctx.Method())), + logger.String("url", string(ctx.RequestURI())), + logger.String("referer", string(ctx.Referer())), + logger.String("content type", string(ctx.Response.Header.ContentType())), + logger.String("remote addr", ctx.RemoteAddr().String()), logger.Duration("duration", time.Since(now).Round(time.Microsecond)), } if log.Level() <= logger.DebugLevel { + var ( + reqHeaders = make(map[string]string) + respHeaders = make(map[string]string) + ) + + ctx.Request.Header.VisitAll(func(key, value []byte) { reqHeaders[string(key)] = string(value) }) + ctx.Response.Header.VisitAll(func(key, value []byte) { respHeaders[string(key)] = string(value) }) + fields = append(fields, - logger.Any("request headers", r.Header.Clone()), - logger.Any("response headers", w.Header().Clone()), + logger.Any("request headers", reqHeaders), + logger.Any("response headers", respHeaders), ) } log.Info("HTTP request processed", fields...) }() - next.ServeHTTP(w, r) - }) + next(ctx) + } } } diff --git a/internal/http/middleware/logreq/middleware_test.go b/internal/http/middleware/logreq/middleware_test.go index 0a93d317..0bef30cc 100644 --- a/internal/http/middleware/logreq/middleware_test.go +++ b/internal/http/middleware/logreq/middleware_test.go @@ -3,11 +3,12 @@ package logreq_test import ( "bytes" "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" + "github.com/valyala/fasthttp" + "gh.tarampamp.am/error-pages/internal/http/httptest" "gh.tarampamp.am/error-pages/internal/http/middleware/logreq" "gh.tarampamp.am/error-pages/internal/logger" ) @@ -19,18 +20,19 @@ func TestNew(t *testing.T) { buf bytes.Buffer log, _ = logger.New(logger.DebugLevel, logger.JSONFormat, &buf) - mw = logreq.New(log, nil) - rr = httptest.NewRecorder() - req = httptest.NewRequest(http.MethodPut, "/foo/bar", http.NoBody) + mw = logreq.New(log, nil) + req, _ = http.NewRequest(http.MethodPut, "http://testing/foo/bar", http.NoBody) ) req.Header.Set("User-Agent", "test") req.Header.Set("Referer", "https://example.com") req.Header.Set("Content-Type", "application/json") - mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })).ServeHTTP(rr, req) + httptest.HandleFastRequest(t, + mw(func(ctx *fasthttp.RequestCtx) { ctx.SetStatusCode(http.StatusOK) }), + req, + func(status int, body string, _ http.Header) { assert.Equal(t, http.StatusOK, status) }, + ) var logRecord = buf.String() diff --git a/internal/http/server.go b/internal/http/server.go index 1ebd0282..82b4aa04 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -2,12 +2,15 @@ package http import ( "context" + "errors" + "fmt" "net" "net/http" - "strconv" "strings" "time" + "github.com/valyala/fasthttp" + "gh.tarampamp.am/error-pages/internal/appmeta" "gh.tarampamp.am/error-pages/internal/config" ep "gh.tarampamp.am/error-pages/internal/http/handlers/error_page" @@ -21,11 +24,11 @@ import ( // Server is an HTTP server for serving error pages. type Server struct { log *logger.Logger - server *http.Server + server *fasthttp.Server } // NewServer creates a new HTTP server. -func NewServer(baseCtx context.Context, log *logger.Logger) Server { +func NewServer(log *logger.Logger) Server { const ( readTimeout = 30 * time.Second writeTimeout = readTimeout + 10*time.Second // should be bigger than the read timeout @@ -34,13 +37,14 @@ func NewServer(baseCtx context.Context, log *logger.Logger) Server { return Server{ log: log, - server: &http.Server{ - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, - ReadHeaderTimeout: readTimeout, - MaxHeaderBytes: maxHeaderBytes, - ErrorLog: logger.NewStdLog(log), - BaseContext: func(net.Listener) context.Context { return baseCtx }, + server: &fasthttp.Server{ + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + ReadBufferSize: maxHeaderBytes, + DisablePreParseMultipartForm: true, + NoDefaultServerHeader: true, + CloseOnShutdown: true, + Logger: logger.NewStdLog(log), }, } } @@ -52,60 +56,78 @@ func (s *Server) Register(cfg *config.Config) error { versionHandler = version.New(appmeta.Version()) errorPagesHandler = ep.New(cfg, s.log) faviconHandler = static.New(static.Favicon) + + notFound = http.StatusText(http.StatusNotFound) + "\n" + notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n" ) - s.server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var url, method = r.URL.Path, r.Method + s.server.Handler = func(ctx *fasthttp.RequestCtx) { + var url, method = string(ctx.RequestURI()), string(ctx.Method()) switch { // live endpoints case url == "/health/live" || url == "/health" || url == "/healthz" || url == "/live": - liveHandler.ServeHTTP(w, r) + liveHandler(ctx) // version endpoint case url == "/version": - versionHandler.ServeHTTP(w, r) + versionHandler(ctx) // favicon.ico endpoint case url == "/favicon.ico": - faviconHandler.ServeHTTP(w, r) + faviconHandler(ctx) // error pages endpoints: // - / // - /{code}.html // - /{code}.htm // - /{code} - case method == http.MethodGet && (url == "/" || ep.URLContainsCode(url) || ep.HeadersContainCode(r.Header)): - errorPagesHandler.ServeHTTP(w, r) + case method == fasthttp.MethodGet && + (url == "/" || ep.URLContainsCode(url) || ep.HeadersContainCode(&ctx.Request.Header)): + errorPagesHandler(ctx) // wrong requests handling default: switch { - case method == http.MethodHead: - w.WriteHeader(http.StatusNotFound) - case method == http.MethodGet: - http.NotFound(w, r) + case method == fasthttp.MethodHead: + ctx.Error(notAllowed, fasthttp.StatusNotFound) + case method == fasthttp.MethodGet: + ctx.Error(notFound, fasthttp.StatusNotFound) default: - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + ctx.Error(notAllowed, fasthttp.StatusMethodNotAllowed) } } - }) + } // apply middleware - s.server.Handler = logreq.New(s.log, func(r *http.Request) bool { + s.server.Handler = logreq.New(s.log, func(ctx *fasthttp.RequestCtx) bool { // skip logging healthcheck and .ico (favicon) requests - return strings.Contains(strings.ToLower(r.UserAgent()), "healthcheck") || - strings.HasSuffix(r.URL.Path, ".ico") + return strings.Contains(strings.ToLower(string(ctx.UserAgent())), "healthcheck") || + strings.HasSuffix(string(ctx.Path()), ".ico") })(s.server.Handler) return nil } // Start server. -func (s *Server) Start(ip string, port uint16) error { - s.server.Addr = ip + ":" + strconv.Itoa(int(port)) +func (s *Server) Start(ip string, port uint16) (err error) { + if net.ParseIP(ip) == nil { + return errors.New("invalid IP address") + } + + var ln net.Listener + + if strings.Count(ip, ":") >= 2 { //nolint:mnd // ipv6 + if ln, err = net.Listen("tcp6", fmt.Sprintf("[%s]:%d", ip, port)); err != nil { + return err + } + } else { // ipv4 + if ln, err = net.Listen("tcp4", fmt.Sprintf("%s:%d", ip, port)); err != nil { + return err + } + } - return s.server.ListenAndServe() + return s.server.Serve(ln) } // Stop server gracefully. @@ -113,5 +135,5 @@ func (s *Server) Stop(timeout time.Duration) error { var ctx, cancel = context.WithTimeout(context.Background(), timeout) defer cancel() - return s.server.Shutdown(ctx) + return s.server.ShutdownWithContext(ctx) } diff --git a/internal/http/server_test.go b/internal/http/server_test.go index 4218311a..d66a75f8 100644 --- a/internal/http/server_test.go +++ b/internal/http/server_test.go @@ -1,7 +1,6 @@ package http_test import ( - "context" "errors" "fmt" "io" @@ -21,7 +20,7 @@ import ( // TestRouting in fact is a test for the whole server, because it tests all the routes and their handlers. func TestRouting(t *testing.T) { var ( - srv = appHttp.NewServer(context.Background(), logger.NewNop()) + srv = appHttp.NewServer(logger.NewNop()) cfg = config.New() ) @@ -296,7 +295,7 @@ func TestRouting(t *testing.T) { assert.Equal(t, http.StatusNotFound, status) assert.Empty(t, body) - assert.Empty(t, headers.Get("Content-Type")) + assert.Contains(t, headers.Get("Content-Type"), "text/plain") } })