diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 3c833fe7..cd7930bf 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -11,12 +11,12 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: - go-version: 1.19 + go-version: ">=1.20.4" - name: Build run: go build -v ./... diff --git a/UPSTREAM b/UPSTREAM index a500091a..421c8312 100644 --- a/UPSTREAM +++ b/UPSTREAM @@ -1 +1 @@ -go1.19.6 +go1.20.4 diff --git a/cgi/child.go b/cgi/child.go index 1e7f5a15..98211b7f 100644 --- a/cgi/child.go +++ b/cgi/child.go @@ -83,10 +83,12 @@ func RequestFromMap(params map[string]string) (*http.Request, error) { // Copy "HTTP_FOO_BAR" variables to "Foo-Bar" Headers for k, v := range params { - if !strings.HasPrefix(k, "HTTP_") || k == "HTTP_HOST" { + if k == "HTTP_HOST" { continue } - r.Header.Add(strings.ReplaceAll(k[5:], "_", "-"), v) + if after, found := strings.CutPrefix(k, "HTTP_"); found { + r.Header.Add(strings.ReplaceAll(after, "_", "-"), v) + } } uriStr := params["REQUEST_URI"] diff --git a/cgi/host.go b/cgi/host.go index 9058c55e..7a51bf81 100644 --- a/cgi/host.go +++ b/cgi/host.go @@ -138,7 +138,6 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { env := []string{ "SERVER_SOFTWARE=go", - "SERVER_NAME=" + req.Host, "SERVER_PROTOCOL=HTTP/1.1", "HTTP_HOST=" + req.Host, "GATEWAY_INTERFACE=CGI/1.1", @@ -158,6 +157,12 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { env = append(env, "REMOTE_ADDR="+req.RemoteAddr, "REMOTE_HOST="+req.RemoteAddr) } + if hostDomain, _, err := net.SplitHostPort(req.Host); err == nil { + env = append(env, "SERVER_NAME="+hostDomain) + } else { + env = append(env, "SERVER_NAME="+req.Host) + } + if req.TLS != nil { env = append(env, "HTTPS=on") } diff --git a/cgi/host_test.go b/cgi/host_test.go index e546e86a..1fa6ab22 100644 --- a/cgi/host_test.go +++ b/cgi/host_test.go @@ -8,7 +8,6 @@ package cgi import ( "bufio" - "bytes" "fmt" "io" "net" @@ -115,7 +114,7 @@ func TestCGIBasicGet(t *testing.T) { "param-a": "b", "param-foo": "bar", "env-GATEWAY_INTERFACE": "CGI/1.1", - "env-HTTP_HOST": "example.com", + "env-HTTP_HOST": "example.com:80", "env-PATH_INFO": "", "env-QUERY_STRING": "foo=bar&a=b", "env-REMOTE_ADDR": "1.2.3.4", @@ -129,7 +128,7 @@ func TestCGIBasicGet(t *testing.T) { "env-SERVER_PORT": "80", "env-SERVER_SOFTWARE": "go", } - replay := runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) + replay := runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com:80\n\n", expectedMap) if expected, got := "text/html", replay.Header().Get("Content-Type"); got != expected { t.Errorf("got a Content-Type of %q; expected %q", got, expected) @@ -541,7 +540,7 @@ func TestEnvOverride(t *testing.T) { func TestHandlerStderr(t *testing.T) { check(t) - var stderr bytes.Buffer + var stderr strings.Builder h := &Handler{ Path: "testdata/test.cgi", Root: "/test.cgi", diff --git a/client.go b/client.go index 0abcd7f7..ed536969 100644 --- a/client.go +++ b/client.go @@ -22,6 +22,7 @@ import ( "sort" "strings" "sync" + "sync/atomic" "time" "github.com/ooni/oohttp/internal/ascii" @@ -362,7 +363,7 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi initialReqCancel := req.Cancel // the user's original Request.Cancel, if any var cancelCtx func() - if oldCtx := req.Context(); timeBeforeContextDeadline(deadline, oldCtx) { + if timeBeforeContextDeadline(deadline, oldCtx) { req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) } @@ -392,7 +393,7 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi } timer := time.NewTimer(time.Until(deadline)) - var timedOut atomicBool + var timedOut atomic.Bool go func() { select { @@ -400,14 +401,14 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi doCancel() timer.Stop() case <-timer.C: - timedOut.setTrue() + timedOut.Store(true) doCancel() case <-stopTimerCh: timer.Stop() } }() - return stopTimer, timedOut.isSet + return stopTimer, timedOut.Load } // See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt @@ -499,7 +500,7 @@ func (c *Client) checkRedirect(req *Request, via []*Request) error { } // redirectBehavior describes what should happen when the -// client encounters a 3xx status code from the server +// client encounters a 3xx status code from the server. func redirectBehavior(reqMethod string, resp *Response, ireq *Request) (redirectMethod string, shouldRedirect, includeBody bool) { switch resp.StatusCode { case 301, 302, 303: diff --git a/client_test.go b/client_test.go index 4a55f2cd..883c79ea 100644 --- a/client_test.go +++ b/client_test.go @@ -66,11 +66,9 @@ func (w chanWriter) Write(p []byte) (n int, err error) { return len(p), nil } -func TestClient(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(robotsTxtHandler) - defer ts.Close() +func TestClient(t *testing.T) { run(t, testClient) } +func testClient(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, robotsTxtHandler).ts c := ts.Client() r, err := c.Get(ts.URL) @@ -86,14 +84,9 @@ func TestClient(t *testing.T) { } } -func TestClientHead_h1(t *testing.T) { testClientHead(t, h1Mode) } -func TestClientHead_h2(t *testing.T) { testClientHead(t, h2Mode) } - -func testClientHead(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, robotsTxtHandler) - defer cst.close() - +func TestClientHead(t *testing.T) { run(t, testClientHead) } +func testClientHead(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, robotsTxtHandler) r, err := cst.c.Head(cst.ts.URL) if err != nil { t.Fatal(err) @@ -199,11 +192,10 @@ func TestPostFormRequestFormat(t *testing.T) { } } -func TestClientRedirects(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestClientRedirects(t *testing.T) { run(t, testClientRedirects) } +func testClientRedirects(t *testing.T, mode testMode) { var ts *httptest.Server - ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts = newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { n, _ := strconv.Atoi(r.FormValue("n")) // Test Referer header. (7 is arbitrary position to test at) if n == 7 { @@ -216,8 +208,7 @@ func TestClientRedirects(t *testing.T) { return } fmt.Fprintf(w, "n=%d", n) - })) - defer ts.Close() + })).ts c := ts.Client() _, err := c.Get(ts.URL) @@ -298,13 +289,11 @@ func TestClientRedirects(t *testing.T) { } // Tests that Client redirects' contexts are derived from the original request's context. -func TestClientRedirectContext(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestClientRedirectsContext(t *testing.T) { run(t, testClientRedirectsContext) } +func testClientRedirectsContext(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { Redirect(w, r, "/", StatusTemporaryRedirect) - })) - defer ts.Close() + })).ts ctx, cancel := context.WithCancel(context.Background()) c := ts.Client() @@ -372,7 +361,9 @@ func TestPostRedirects(t *testing.T) { `POST /?code=404 "c404"`, } want := strings.Join(wantSegments, "\n") - testRedirectsByMethod(t, "POST", postRedirectTests, want) + run(t, func(t *testing.T, mode testMode) { + testRedirectsByMethod(t, mode, "POST", postRedirectTests, want) + }) } func TestDeleteRedirects(t *testing.T) { @@ -409,17 +400,18 @@ func TestDeleteRedirects(t *testing.T) { `DELETE /?code=404 "c404"`, } want := strings.Join(wantSegments, "\n") - testRedirectsByMethod(t, "DELETE", deleteRedirectTests, want) + run(t, func(t *testing.T, mode testMode) { + testRedirectsByMethod(t, mode, "DELETE", deleteRedirectTests, want) + }) } -func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, want string) { - defer afterTest(t) +func testRedirectsByMethod(t *testing.T, mode testMode, method string, table []redirectTest, want string) { var log struct { sync.Mutex bytes.Buffer } var ts *httptest.Server - ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts = newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { log.Lock() slurp, _ := io.ReadAll(r.Body) fmt.Fprintf(&log.Buffer, "%s %s %q", r.Method, r.RequestURI, slurp) @@ -444,8 +436,7 @@ func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, wa } w.WriteHeader(code) } - })) - defer ts.Close() + })).ts c := ts.Client() for _, tt := range table { @@ -490,12 +481,11 @@ func removeCommonLines(a, b string) (asuffix, bsuffix string, commonLines int) { } } -func TestClientRedirectUseResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestClientRedirectUseResponse(t *testing.T) { run(t, testClientRedirectUseResponse) } +func testClientRedirectUseResponse(t *testing.T, mode testMode) { const body = "Hello, world." var ts *httptest.Server - ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts = newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if strings.Contains(r.URL.Path, "/other") { io.WriteString(w, "wrong body") } else { @@ -503,8 +493,7 @@ func TestClientRedirectUseResponse(t *testing.T) { w.WriteHeader(StatusFound) io.WriteString(w, body) } - })) - defer ts.Close() + })).ts c := ts.Client() c.CheckRedirect = func(req *Request, via []*Request) error { @@ -532,18 +521,16 @@ func TestClientRedirectUseResponse(t *testing.T) { // Issues 17773 and 49281: don't follow a 3xx if the response doesn't // have a Location header. -func TestClientRedirectNoLocation(t *testing.T) { +func TestClientRedirectNoLocation(t *testing.T) { run(t, testClientRedirectNoLocation) } +func testClientRedirectNoLocation(t *testing.T, mode testMode) { for _, code := range []int{301, 308} { t.Run(fmt.Sprint(code), func(t *testing.T) { setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Foo", "Bar") w.WriteHeader(code) })) - defer ts.Close() - c := ts.Client() - res, err := c.Get(ts.URL) + res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } @@ -559,15 +546,13 @@ func TestClientRedirectNoLocation(t *testing.T) { } // Don't follow a 307/308 if we can't resent the request body. -func TestClientRedirect308NoGetBody(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestClientRedirect308NoGetBody(t *testing.T) { run(t, testClientRedirect308NoGetBody) } +func testClientRedirect308NoGetBody(t *testing.T, mode testMode) { const fakeURL = "https://localhost:1234/" // won't be hit - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Location", fakeURL) w.WriteHeader(308) - })) - defer ts.Close() + })).ts req, err := NewRequest("POST", ts.URL, strings.NewReader("some body")) if err != nil { t.Fatal(err) @@ -658,12 +643,10 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie { return j.perURL[u.Host] } -func TestRedirectCookiesJar(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestRedirectCookiesJar(t *testing.T) { run(t, testRedirectCookiesJar) } +func testRedirectCookiesJar(t *testing.T, mode testMode) { var ts *httptest.Server - ts = httptest.NewServer(echoCookiesRedirectHandler) - defer ts.Close() + ts = newClientServerTest(t, mode, echoCookiesRedirectHandler).ts c := ts.Client() c.Jar = new(TestJar) u, _ := url.Parse(ts.URL) @@ -695,9 +678,9 @@ func matchReturnedCookies(t *testing.T, expected, given []*Cookie) { } } -func TestJarCalls(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestJarCalls(t *testing.T) { run(t, testJarCalls, []testMode{http1Mode}) } +func testJarCalls(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { pathSuffix := r.RequestURI[1:] if r.RequestURI == "/nosetcookie" { return // don't set cookies for this path @@ -706,8 +689,7 @@ func TestJarCalls(t *testing.T) { if r.RequestURI == "/" { Redirect(w, r, "http://secondhost.fake/secondpath", 302) } - })) - defer ts.Close() + })).ts jar := new(RecordingJar) c := ts.Client() c.Jar = jar @@ -756,20 +738,16 @@ func (j *RecordingJar) logf(format string, args ...any) { fmt.Fprintf(&j.log, format, args...) } -func TestStreamingGet_h1(t *testing.T) { testStreamingGet(t, h1Mode) } -func TestStreamingGet_h2(t *testing.T) { testStreamingGet(t, h2Mode) } - -func testStreamingGet(t *testing.T, h2 bool) { - defer afterTest(t) +func TestStreamingGet(t *testing.T) { run(t, testStreamingGet) } +func testStreamingGet(t *testing.T, mode testMode) { say := make(chan string) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() for str := range say { w.Write([]byte(str)) w.(Flusher).Flush() } })) - defer cst.close() c := cst.c res, err := c.Get(cst.ts.URL) @@ -810,11 +788,10 @@ func (c *writeCountingConn) Write(p []byte) (int, error) { // TestClientWrites verifies that client requests are buffered and we // don't send a TCP packet per line of the http request + body. -func TestClientWrites(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - })) - defer ts.Close() +func TestClientWrites(t *testing.T) { run(t, testClientWrites, []testMode{http1Mode}) } +func testClientWrites(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + })).ts writes := 0 dialer := func(netz string, addr string) (net.Conn, error) { @@ -846,11 +823,12 @@ func TestClientWrites(t *testing.T) { } func TestClientInsecureTransport(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testClientInsecureTransport, []testMode{https1Mode, http2Mode}) +} +func testClientInsecureTransport(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) - })) + })).ts errc := make(chanWriter, 10) // but only expecting 1 ts.Config.ErrorLog = log.New(errc, "", 0) defer ts.Close() @@ -897,15 +875,15 @@ func TestClientErrorWithRequestURI(t *testing.T) { } func TestClientWithCorrectTLSServerName(t *testing.T) { - defer afterTest(t) - + run(t, testClientWithCorrectTLSServerName, []testMode{https1Mode, http2Mode}) +} +func testClientWithCorrectTLSServerName(t *testing.T, mode testMode) { const serverName = "example.com" - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS.ServerName != serverName { t.Errorf("expected client to set ServerName %q, got: %q", serverName, r.TLS.ServerName) } - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).TLSClientConfig.ServerName = serverName @@ -915,9 +893,10 @@ func TestClientWithCorrectTLSServerName(t *testing.T) { } func TestClientWithIncorrectTLSServerName(t *testing.T) { - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer ts.Close() + run(t, testClientWithIncorrectTLSServerName, []testMode{https1Mode, http2Mode}) +} +func testClientWithIncorrectTLSServerName(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts errc := make(chanWriter, 10) // but only expecting 1 ts.Config.ErrorLog = log.New(errc, "", 0) @@ -950,11 +929,12 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { // // The httptest.Server has a cert with "example.com" as its name. func TestTransportUsesTLSConfigServerName(t *testing.T) { - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportUsesTLSConfigServerName, []testMode{https1Mode, http2Mode}) +} +func testTransportUsesTLSConfigServerName(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -970,11 +950,12 @@ func TestTransportUsesTLSConfigServerName(t *testing.T) { } func TestResponseSetsTLSConnectionState(t *testing.T) { - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testResponseSetsTLSConnectionState, []testMode{https1Mode}) +} +func testResponseSetsTLSConnectionState(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -1000,10 +981,11 @@ func TestResponseSetsTLSConnectionState(t *testing.T) { // to determine that the server is speaking HTTP. // See golang.org/issue/11111. func TestHTTPSClientDetectsHTTPServer(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + run(t, testHTTPSClientDetectsHTTPServer, []testMode{http1Mode}) +} +func testHTTPSClientDetectsHTTPServer(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts ts.Config.ErrorLog = quietLog - defer ts.Close() _, err := Get(strings.Replace(ts.URL, "http", "https", 1)) if got := err.Error(); !strings.Contains(got, "HTTP response to HTTPS client") { @@ -1012,22 +994,13 @@ func TestHTTPSClientDetectsHTTPServer(t *testing.T) { } // Verify Response.ContentLength is populated. https://golang.org/issue/4126 -func TestClientHeadContentLength_h1(t *testing.T) { - testClientHeadContentLength(t, h1Mode) -} - -func TestClientHeadContentLength_h2(t *testing.T) { - testClientHeadContentLength(t, h2Mode) -} - -func testClientHeadContentLength(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestClientHeadContentLength(t *testing.T) { run(t, testClientHeadContentLength) } +func testClientHeadContentLength(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if v := r.FormValue("cl"); v != "" { w.Header().Set("Content-Length", v) } })) - defer cst.close() tests := []struct { suffix string want int64 @@ -1055,11 +1028,10 @@ func testClientHeadContentLength(t *testing.T, h2 bool) { } } -func TestEmptyPasswordAuth(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestEmptyPasswordAuth(t *testing.T) { run(t, testEmptyPasswordAuth) } +func testEmptyPasswordAuth(t *testing.T, mode testMode) { gopher := "gopher" - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { auth := r.Header.Get("Authorization") if strings.HasPrefix(auth, "Basic ") { encoded := auth[6:] @@ -1075,7 +1047,7 @@ func TestEmptyPasswordAuth(t *testing.T) { } else { t.Errorf("Invalid auth %q", auth) } - })) + })).ts defer ts.Close() req, err := NewRequest("GET", ts.URL, nil) if err != nil { @@ -1204,19 +1176,14 @@ func TestStripPasswordFromError(t *testing.T) { } } -func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) } -func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) } - -func testClientTimeout(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func TestClientTimeout(t *testing.T) { run(t, testClientTimeout) } +func testClientTimeout(t *testing.T, mode testMode) { var ( mu sync.Mutex nonce string // a unique per-request string sawSlowNonce bool // true if the handler saw /slow?nonce= ) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _ = r.ParseForm() if r.URL.Path == "/" { Redirect(w, r, "/slow?nonce="+r.Form.Get("nonce"), StatusFound) @@ -1237,7 +1204,6 @@ func testClientTimeout(t *testing.T, h2 bool) { return } })) - defer cst.close() // Try to trigger a timeout after reading part of the response body. // The initial timeout is emprically usually long enough on a decently fast @@ -1301,18 +1267,13 @@ func testClientTimeout(t *testing.T, h2 bool) { } } -func TestClientTimeout_Headers_h1(t *testing.T) { testClientTimeout_Headers(t, h1Mode) } -func TestClientTimeout_Headers_h2(t *testing.T) { testClientTimeout_Headers(t, h2Mode) } - // Client.Timeout firing before getting to the body -func testClientTimeout_Headers(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestClientTimeout_Headers(t *testing.T) { run(t, testClientTimeout_Headers) } +func testClientTimeout_Headers(t *testing.T, mode testMode) { donec := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-donec }), optQuietLog) - defer cst.close() // Note that we use a channel send here and not a close. // The race detector doesn't know that we're waiting for a timeout // and thinks that the waitgroup inside httptest.Server is added to concurrently @@ -1345,18 +1306,15 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) { // Issue 16094: if Client.Timeout is set but not hit, a Timeout error shouldn't be // returned. -func TestClientTimeoutCancel(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestClientTimeoutCancel(t *testing.T) { run(t, testClientTimeoutCancel) } +func testClientTimeoutCancel(t *testing.T, mode testMode) { testDone := make(chan struct{}) ctx, cancel := context.WithCancel(context.Background()) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() <-testDone })) - defer cst.close() defer close(testDone) cst.c.Timeout = 1 * time.Hour @@ -1373,18 +1331,12 @@ func TestClientTimeoutCancel(t *testing.T) { } } -func TestClientTimeoutDoesNotExpire_h1(t *testing.T) { testClientTimeoutDoesNotExpire(t, h1Mode) } -func TestClientTimeoutDoesNotExpire_h2(t *testing.T) { testClientTimeoutDoesNotExpire(t, h2Mode) } - // Issue 49366: if Client.Timeout is set but not hit, no error should be returned. -func testClientTimeoutDoesNotExpire(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestClientTimeoutDoesNotExpire(t *testing.T) { run(t, testClientTimeoutDoesNotExpire) } +func testClientTimeoutDoesNotExpire(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("body")) })) - defer cst.close() cst.c.Timeout = 1 * time.Hour req, _ := NewRequest("GET", cst.ts.URL, nil) @@ -1400,19 +1352,15 @@ func testClientTimeoutDoesNotExpire(t *testing.T, h2 bool) { } } -func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) } -func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) } -func testClientRedirectEatsBody(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestClientRedirectEatsBody_h1(t *testing.T) { run(t, testClientRedirectEatsBody) } +func testClientRedirectEatsBody(t *testing.T, mode testMode) { saw := make(chan string, 2) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { saw <- r.RemoteAddr if r.URL.Path == "/" { Redirect(w, r, "/foo", StatusFound) // which includes a body } })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -1512,13 +1460,14 @@ func TestClientRedirectResponseWithoutRequest(t *testing.T) { } // Issue 4800: copy (some) headers when Client follows a redirect. -func TestClientCopyHeadersOnRedirect(t *testing.T) { +func TestClientCopyHeadersOnRedirect(t *testing.T) { run(t, testClientCopyHeadersOnRedirect) } +func testClientCopyHeadersOnRedirect(t *testing.T, mode testMode) { const ( ua = "some-agent/1.2" xfoo = "foo-val" ) var ts2URL string - ts1 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts1 := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { want := Header{ "User-Agent": []string{ua}, "X-Foo": []string{xfoo}, @@ -1533,12 +1482,10 @@ func TestClientCopyHeadersOnRedirect(t *testing.T) { } else { w.Header().Set("Result", "ok") } - })) - defer ts1.Close() - ts2 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + })).ts + ts2 := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { Redirect(w, r, ts1.URL, StatusFound) - })) - defer ts2.Close() + })).ts ts2URL = ts2.URL c := ts1.Client() @@ -1573,22 +1520,24 @@ func TestClientCopyHeadersOnRedirect(t *testing.T) { } // Issue 22233: copy host when Client follows a relative redirect. -func TestClientCopyHostOnRedirect(t *testing.T) { +func TestClientCopyHostOnRedirect(t *testing.T) { run(t, testClientCopyHostOnRedirect) } +func testClientCopyHostOnRedirect(t *testing.T, mode testMode) { // Virtual hostname: should not receive any request. - virtual := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + virtual := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Errorf("Virtual host received request %v", r.URL) w.WriteHeader(403) io.WriteString(w, "should not see this response") - })) + })).ts defer virtual.Close() virtualHost := strings.TrimPrefix(virtual.URL, "http://") + virtualHost = strings.TrimPrefix(virtualHost, "https://") t.Logf("Virtual host is %v", virtualHost) // Actual hostname: should not receive any request. const wantBody = "response body" var tsURL string var tsHost string - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { switch r.URL.Path { case "/": // Relative redirect. @@ -1620,10 +1569,10 @@ func TestClientCopyHostOnRedirect(t *testing.T) { t.Errorf("Serving unexpected path %q", r.URL.Path) w.WriteHeader(404) } - })) - defer ts.Close() + })).ts tsURL = ts.URL tsHost = strings.TrimPrefix(ts.URL, "http://") + tsHost = strings.TrimPrefix(tsHost, "https://") t.Logf("Server host is %v", tsHost) c := ts.Client() @@ -1643,7 +1592,8 @@ func TestClientCopyHostOnRedirect(t *testing.T) { } // Issue 17494: cookies should be altered when Client follows redirects. -func TestClientAltersCookiesOnRedirect(t *testing.T) { +func TestClientAltersCookiesOnRedirect(t *testing.T) { run(t, testClientAltersCookiesOnRedirect) } +func testClientAltersCookiesOnRedirect(t *testing.T, mode testMode) { cookieMap := func(cs []*Cookie) map[string][]string { m := make(map[string][]string) for _, c := range cs { @@ -1652,7 +1602,7 @@ func TestClientAltersCookiesOnRedirect(t *testing.T) { return m } - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { var want map[string][]string got := cookieMap(r.Cookies()) @@ -1707,8 +1657,7 @@ func TestClientAltersCookiesOnRedirect(t *testing.T) { if !reflect.DeepEqual(got, want) { t.Errorf("redirect %s, Cookie = %v, want %v", c.Value, got, want) } - })) - defer ts.Close() + })).ts jar, _ := cookiejar.New(nil) c := ts.Client() @@ -1780,10 +1729,8 @@ func TestShouldCopyHeaderOnRedirect(t *testing.T) { } } -func TestClientRedirectTypes(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestClientRedirectTypes(t *testing.T) { run(t, testClientRedirectTypes) } +func testClientRedirectTypes(t *testing.T, mode testMode) { tests := [...]struct { method string serverStatus int @@ -1828,11 +1775,10 @@ func TestClientRedirectTypes(t *testing.T) { handlerc := make(chan HandlerFunc, 1) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { h := <-handlerc h(rw, req) - })) - defer ts.Close() + })).ts c := ts.Client() for i, tt := range tests { @@ -1888,18 +1834,16 @@ func (b issue18239Body) Close() error { // Issue 18239: make sure the Transport doesn't retry requests with bodies // if Request.GetBody is not defined. -func TestTransportBodyReadError(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportBodyReadError(t *testing.T) { run(t, testTransportBodyReadError) } +func testTransportBodyReadError(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.URL.Path == "/ping" { return } buf := make([]byte, 1) n, err := r.Body.Read(buf) w.Header().Set("X-Body-Read", fmt.Sprintf("%v, %v", n, err)) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -1983,22 +1927,13 @@ func TestClientPropagatesTimeoutToContext(t *testing.T) { c.Get("https://example.tld/") } -func TestClientDoCanceledVsTimeout_h1(t *testing.T) { - testClientDoCanceledVsTimeout(t, h1Mode) -} - -func TestClientDoCanceledVsTimeout_h2(t *testing.T) { - testClientDoCanceledVsTimeout(t, h2Mode) -} - // Issue 33545: lock-in the behavior promised by Client.Do's // docs about request cancellation vs timing out. -func testClientDoCanceledVsTimeout(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestClientDoCanceledVsTimeout(t *testing.T) { run(t, testClientDoCanceledVsTimeout) } +func testClientDoCanceledVsTimeout(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello, World!")) })) - defer cst.close() cases := []string{"timeout", "canceled"} @@ -2074,13 +2009,11 @@ func TestClientPopulatesNilResponseBody(t *testing.T) { } // Issue 40382: Client calls Close multiple times on Request.Body. -func TestClientCallsCloseOnlyOnce(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestClientCallsCloseOnlyOnce(t *testing.T) { run(t, testClientCallsCloseOnlyOnce) } +func testClientCallsCloseOnlyOnce(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) })) - defer cst.close() // Issue occurred non-deterministically: needed to occur after a successful // write (into TCP buffer) but before end of body. @@ -2130,17 +2063,15 @@ func (b *issue40382Body) Close() error { return nil } -func TestProbeZeroLengthBody(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestProbeZeroLengthBody(t *testing.T) { run(t, testProbeZeroLengthBody) } +func testProbeZeroLengthBody(t *testing.T, mode testMode) { reqc := make(chan struct{}) - cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { close(reqc) if _, err := io.Copy(w, r.Body); err != nil { t.Errorf("error copying request body: %v", err) } })) - defer cst.close() bodyr, bodyw := io.Pipe() var gotBody string diff --git a/clientserver_test.go b/clientserver_test.go index 33490e0e..54a00a5c 100644 --- a/clientserver_test.go +++ b/clientserver_test.go @@ -36,8 +36,65 @@ import ( "github.com/ooni/oohttp/httputil" ) +type testMode string + +const ( + http1Mode = testMode("h1") // HTTP/1.1 + https1Mode = testMode("https1") // HTTPS/1.1 + http2Mode = testMode("h2") // HTTP/2 +) + +type testNotParallelOpt struct{} + +var ( + testNotParallel = testNotParallelOpt{} +) + +type TBRun[T any] interface { + testing.TB + Run(string, func(T)) bool +} + +// run runs a client/server test in a variety of test configurations. +// +// Tests execute in HTTP/1.1 and HTTP/2 modes by default. +// To run in a different set of configurations, pass a []testMode option. +// +// Tests call t.Parallel() by default. +// To disable parallel execution, pass the testNotParallel option. +func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) { + t.Helper() + modes := []testMode{http1Mode, http2Mode} + parallel := true + for _, opt := range opts { + switch opt := opt.(type) { + case []testMode: + modes = opt + case testNotParallelOpt: + parallel = false + default: + t.Fatalf("unknown option type %T", opt) + } + } + if t, ok := any(t).(*testing.T); ok && parallel { + setParallel(t) + } + for _, mode := range modes { + t.Run(string(mode), func(t T) { + t.Helper() + if t, ok := any(t).(*testing.T); ok && parallel { + setParallel(t) + } + t.Cleanup(func() { + afterTest(t) + }) + f(t, mode) + }) + } +} + type clientServerTest struct { - t *testing.T + t testing.TB h2 bool h Handler ts *httptest.Server @@ -70,11 +127,6 @@ func (t *clientServerTest) scheme() string { return "http" } -const ( - h1Mode = false - h2Mode = true -) - var optQuietLog = func(ts *httptest.Server) { ts.Config.ErrorLog = quietLog } @@ -85,23 +137,33 @@ func optWithServerLog(lg *log.Logger) func(*httptest.Server) { } } -func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...any) *clientServerTest { - if h2 { +// newClientServerTest creates and starts an httptest.Server. +// +// The mode parameter selects the implementation to test: +// HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use +// the 'run' function, which will start a subtests for each tested mode. +// +// The vararg opts parameter can include functions to configure the +// test server or transport. +// +// func(*httptest.Server) // run before starting the server +// func(*http.Transport) +func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest { + if mode == http2Mode { CondSkipHTTP2(t) } cst := &clientServerTest{ t: t, - h2: h2, + h2: mode == http2Mode, h: h, - tr: &Transport{}, } - cst.c = &Client{Transport: cst.tr} cst.ts = httptest.NewUnstartedServer(h) + var transportFuncs []func(*Transport) for _, opt := range opts { switch opt := opt.(type) { case func(*Transport): - opt(cst.tr) + transportFuncs = append(transportFuncs, opt) case func(*httptest.Server): opt(cst.ts) default: @@ -109,60 +171,97 @@ func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...any) *clientS } } - if !h2 { - cst.ts.Start() - return cst + if cst.ts.Config.ErrorLog == nil { + cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0) } - ExportHttp2ConfigureServer(cst.ts.Config, nil) - cst.ts.TLS = cst.ts.Config.TLSConfig - cst.ts.StartTLS() - cst.tr.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: true, + switch mode { + case http1Mode: + cst.ts.Start() + case https1Mode: + cst.ts.StartTLS() + case http2Mode: + ExportHttp2ConfigureServer(cst.ts.Config, nil) + cst.ts.TLS = cst.ts.Config.TLSConfig + cst.ts.StartTLS() + default: + t.Fatalf("unknown test mode %v", mode) } - if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { - t.Fatal(err) + cst.c = cst.ts.Client() + cst.tr = cst.c.Transport.(*Transport) + if mode == http2Mode { + if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { + t.Fatal(err) + } + } + for _, f := range transportFuncs { + f(cst.tr) } + t.Cleanup(func() { + cst.close() + }) return cst } +type testLogWriter struct { + t testing.TB +} + +func (w testLogWriter) Write(b []byte) (int, error) { + w.t.Logf("server log: %v", strings.TrimSpace(string(b))) + return len(b), nil +} + // Testing the newClientServerTest helper itself. func TestNewClientServerTest(t *testing.T) { + run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode}) +} +func testNewClientServerTest(t *testing.T, mode testMode) { var got struct { sync.Mutex - log []string + proto string + hasTLS bool } h := HandlerFunc(func(w ResponseWriter, r *Request) { got.Lock() defer got.Unlock() - got.log = append(got.log, r.Proto) + got.proto = r.Proto + got.hasTLS = r.TLS != nil }) - for _, v := range [2]bool{false, true} { - cst := newClientServerTest(t, v, h) - if _, err := cst.c.Head(cst.ts.URL); err != nil { - t.Fatal(err) - } - cst.close() + cst := newClientServerTest(t, mode, h) + if _, err := cst.c.Head(cst.ts.URL); err != nil { + t.Fatal(err) + } + var wantProto string + var wantTLS bool + switch mode { + case http1Mode: + wantProto = "HTTP/1.1" + wantTLS = false + case https1Mode: + wantProto = "HTTP/1.1" + wantTLS = true + case http2Mode: + wantProto = "HTTP/2.0" + wantTLS = true } - got.Lock() // no need to unlock - if want := []string{"HTTP/1.1", "HTTP/2.0"}; !reflect.DeepEqual(got.log, want) { - t.Errorf("got %q; want %q", got.log, want) + if got.proto != wantProto { + t.Errorf("req.Proto = %q, want %q", got.proto, wantProto) + } + if got.hasTLS != wantTLS { + t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS) } } -func TestChunkedResponseHeaders_h1(t *testing.T) { testChunkedResponseHeaders(t, h1Mode) } -func TestChunkedResponseHeaders_h2(t *testing.T) { testChunkedResponseHeaders(t, h2Mode) } - -func testChunkedResponseHeaders(t *testing.T, h2 bool) { - defer afterTest(t) +func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) } +func testChunkedResponseHeaders(t *testing.T, mode testMode) { log.SetOutput(io.Discard) // is noisy otherwise defer log.SetOutput(os.Stderr) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted w.(Flusher).Flush() fmt.Fprintf(w, "I am a chunked response.") })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -173,7 +272,7 @@ func testChunkedResponseHeaders(t *testing.T, h2 bool) { t.Errorf("expected ContentLength of %d; got %d", e, g) } wantTE := []string{"chunked"} - if h2 { + if mode == http2Mode { wantTE = nil } if !reflect.DeepEqual(res.TransferEncoding, wantTE) { @@ -205,9 +304,9 @@ func (tt h12Compare) reqFunc() reqFunc { func (tt h12Compare) run(t *testing.T) { setParallel(t) - cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...) + cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...) defer cst1.close() - cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...) + cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...) defer cst2.close() res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL) @@ -460,12 +559,9 @@ func TestH12_AutoGzip_Disabled(t *testing.T) { // Test304Responses verifies that 304s don't declare that they're // chunking in their response headers and aren't allowed to produce // output. -func Test304Responses_h1(t *testing.T) { test304Responses(t, h1Mode) } -func Test304Responses_h2(t *testing.T) { test304Responses(t, h2Mode) } - -func test304Responses(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func Test304Responses(t *testing.T) { run(t, test304Responses) } +func test304Responses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) _, err := w.Write([]byte("illegal body")) if err != ErrBodyNotAllowed { @@ -529,20 +625,17 @@ func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int6 // Tests that closing the Request.Cancel channel also while still // reading the response body. Issue 13159. -func TestCancelRequestMidBody_h1(t *testing.T) { testCancelRequestMidBody(t, h1Mode) } -func TestCancelRequestMidBody_h2(t *testing.T) { testCancelRequestMidBody(t, h2Mode) } -func testCancelRequestMidBody(t *testing.T, h2 bool) { - defer afterTest(t) +func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) } +func testCancelRequestMidBody(t *testing.T, mode testMode) { unblock := make(chan bool) didFlush := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, "Hello") w.(Flusher).Flush() didFlush <- true <-unblock io.WriteString(w, ", world.") })) - defer cst.close() defer close(unblock) req, _ := NewRequest("GET", cst.ts.URL, nil) @@ -578,12 +671,9 @@ func testCancelRequestMidBody(t *testing.T, h2 bool) { } // Tests that clients can send trailers to a server and that the server can read them. -func TestTrailersClientToServer_h1(t *testing.T) { testTrailersClientToServer(t, h1Mode) } -func TestTrailersClientToServer_h2(t *testing.T) { testTrailersClientToServer(t, h2Mode) } - -func testTrailersClientToServer(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) } +func testTrailersClientToServer(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { var decl []string for k := range r.Trailer { decl = append(decl, k) @@ -606,7 +696,6 @@ func testTrailersClientToServer(t *testing.T, h2 bool) { r.Trailer.Get("Client-Trailer-B")) } })) - defer cst.close() var req *Request req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader( @@ -633,15 +722,20 @@ func testTrailersClientToServer(t *testing.T, h2 bool) { } // Tests that servers send trailers to a client and that the client can read them. -func TestTrailersServerToClient_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, false) } -func TestTrailersServerToClient_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, false) } -func TestTrailersServerToClient_Flush_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, true) } -func TestTrailersServerToClient_Flush_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, true) } +func TestTrailersServerToClient(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTrailersServerToClient(t, mode, false) + }) +} +func TestTrailersServerToClientFlush(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTrailersServerToClient(t, mode, true) + }) +} -func testTrailersServerToClient(t *testing.T, h2, flush bool) { - defer afterTest(t) +func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) { const body = "Some body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") w.Header().Add("Trailer", "Server-Trailer-C") @@ -658,7 +752,6 @@ func testTrailersServerToClient(t *testing.T, h2, flush bool) { w.Header().Set("Server-Trailer-C", "valuec") // skipping B w.Header().Set("Server-Trailer-NotDeclared", "should be omitted") })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -669,7 +762,7 @@ func testTrailersServerToClient(t *testing.T, h2, flush bool) { "Content-Type": {"text/plain; charset=utf-8"}, } wantLen := -1 - if h2 && !flush { + if mode == http2Mode && !flush { // In HTTP/1.1, any use of trailers forces HTTP/1.1 // chunking and a flush at the first write. That's // unnecessary with HTTP/2's framing, so the server @@ -709,16 +802,12 @@ func testTrailersServerToClient(t *testing.T, h2, flush bool) { } // Don't allow a Body.Read after Body.Close. Issue 13648. -func TestResponseBodyReadAfterClose_h1(t *testing.T) { testResponseBodyReadAfterClose(t, h1Mode) } -func TestResponseBodyReadAfterClose_h2(t *testing.T) { testResponseBodyReadAfterClose(t, h2Mode) } - -func testResponseBodyReadAfterClose(t *testing.T, h2 bool) { - defer afterTest(t) +func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) } +func testResponseBodyReadAfterClose(t *testing.T, mode testMode) { const body = "Some body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, body) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -730,13 +819,11 @@ func testResponseBodyReadAfterClose(t *testing.T, h2 bool) { } } -func TestConcurrentReadWriteReqBody_h1(t *testing.T) { testConcurrentReadWriteReqBody(t, h1Mode) } -func TestConcurrentReadWriteReqBody_h2(t *testing.T) { testConcurrentReadWriteReqBody(t, h2Mode) } -func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { - defer afterTest(t) +func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) } +func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) { const reqBody = "some request body" const resBody = "some response body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { var wg sync.WaitGroup wg.Add(2) didRead := make(chan bool, 1) @@ -755,7 +842,7 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { // Write in another goroutine. go func() { defer wg.Done() - if !h2 { + if mode != http2Mode { // our HTTP/1 implementation intentionally // doesn't permit writes during read (mostly // due to it being undefined); if that is ever @@ -766,7 +853,6 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { }() wg.Wait() })) - defer cst.close() req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody)) req.Header.Add("Expect", "100-continue") // just to complicate things res, err := cst.c.Do(req) @@ -783,15 +869,12 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { } } -func TestConnectRequest_h1(t *testing.T) { testConnectRequest(t, h1Mode) } -func TestConnectRequest_h2(t *testing.T) { testConnectRequest(t, h2Mode) } -func testConnectRequest(t *testing.T, h2 bool) { - defer afterTest(t) +func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) } +func testConnectRequest(t *testing.T, mode testMode) { gotc := make(chan *Request, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotc <- r })) - defer cst.close() u, err := url.Parse(cst.ts.URL) if err != nil { @@ -841,17 +924,14 @@ func testConnectRequest(t *testing.T, h2 bool) { } } -func TestTransportUserAgent_h1(t *testing.T) { testTransportUserAgent(t, h1Mode) } -func TestTransportUserAgent_h2(t *testing.T) { testTransportUserAgent(t, h2Mode) } -func testTransportUserAgent(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) } +func testTransportUserAgent(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%q", r.Header["User-Agent"]) })) - defer cst.close() either := func(a, b string) string { - if h2 { + if mode == http2Mode { return b } return a @@ -902,19 +982,22 @@ func testTransportUserAgent(t *testing.T, h2 bool) { } } -func TestStarRequestFoo_h1(t *testing.T) { testStarRequest(t, "FOO", h1Mode) } -func TestStarRequestFoo_h2(t *testing.T) { testStarRequest(t, "FOO", h2Mode) } -func TestStarRequestOptions_h1(t *testing.T) { testStarRequest(t, "OPTIONS", h1Mode) } -func TestStarRequestOptions_h2(t *testing.T) { testStarRequest(t, "OPTIONS", h2Mode) } -func testStarRequest(t *testing.T, method string, h2 bool) { - defer afterTest(t) +func TestStarRequestMethod(t *testing.T) { + for _, method := range []string{"FOO", "OPTIONS"} { + t.Run(method, func(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testStarRequest(t, method, mode) + }) + }) + } +} +func testStarRequest(t *testing.T, method string, mode testMode) { gotc := make(chan *Request, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("foo", "bar") gotc <- r w.(Flusher).Flush() })) - defer cst.close() u, err := url.Parse(cst.ts.URL) if err != nil { @@ -973,9 +1056,10 @@ func testStarRequest(t *testing.T, method string, h2 bool) { // Issue 13957 func TestTransportDiscardsUnneededConns(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode}) +} +func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello, %v", r.RemoteAddr) })) defer cst.close() @@ -1059,20 +1143,19 @@ func TestTransportDiscardsUnneededConns(t *testing.T) { } // tests that Transport doesn't retain a pointer to the provided request. -func TestTransportGCRequest_Body_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, true) } -func TestTransportGCRequest_Body_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, true) } -func TestTransportGCRequest_NoBody_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, false) } -func TestTransportGCRequest_NoBody_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, false) } -func testTransportGCRequest(t *testing.T, h2, body bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportGCRequest(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) }) + t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) }) + }) +} +func testTransportGCRequest(t *testing.T, mode testMode, body bool) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.ReadAll(r.Body) if body { io.WriteString(w, "Hello.") } })) - defer cst.close() didGC := make(chan struct{}) (func() { @@ -1104,19 +1187,11 @@ func testTransportGCRequest(t *testing.T, h2, body bool) { } } -func TestTransportRejectsInvalidHeaders_h1(t *testing.T) { - testTransportRejectsInvalidHeaders(t, h1Mode) -} -func TestTransportRejectsInvalidHeaders_h2(t *testing.T) { - testTransportRejectsInvalidHeaders(t, h2Mode) -} -func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) } +func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Handler saw headers: %q", r.Header) }), optQuietLog) - defer cst.close() cst.tr.DisableKeepAlives = true tests := []struct { @@ -1162,27 +1237,22 @@ func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { } } -func TestInterruptWithPanic_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, "boom") } -func TestInterruptWithPanic_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, "boom") } -func TestInterruptWithPanic_nil_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, nil) } -func TestInterruptWithPanic_nil_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, nil) } -func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) { - testInterruptWithPanic(t, h1Mode, ErrAbortHandler) -} -func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) { - testInterruptWithPanic(t, h2Mode, ErrAbortHandler) +func TestInterruptWithPanic(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") }) + t.Run("nil", func(t *testing.T) { testInterruptWithPanic(t, mode, nil) }) + t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) }) + }) } -func testInterruptWithPanic(t *testing.T, h2 bool, panicValue any) { - setParallel(t) +func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) { const msg = "hello" - defer afterTest(t) testDone := make(chan struct{}) defer close(testDone) var errorLog lockedBytesBuffer gotHeaders := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, msg) w.(Flusher).Flush() @@ -1194,7 +1264,6 @@ func testInterruptWithPanic(t *testing.T, h2 bool, panicValue any) { }), func(ts *httptest.Server) { ts.Config.ErrorLog = log.New(&errorLog, "", 0) }) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1275,15 +1344,11 @@ func TestH12_AutoGzipWithDumpResponse(t *testing.T) { } // Issue 14607 -func TestCloseIdleConnections_h1(t *testing.T) { testCloseIdleConnections(t, h1Mode) } -func TestCloseIdleConnections_h2(t *testing.T) { testCloseIdleConnections(t, h2Mode) } -func testCloseIdleConnections(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) } +func testCloseIdleConnections(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Addr", r.RemoteAddr) })) - defer cst.close() get := func() string { res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -1321,15 +1386,11 @@ func (r testErrorReader) Read(p []byte) (n int, err error) { return 0, io.EOF } -func TestNoSniffExpectRequestBody_h1(t *testing.T) { testNoSniffExpectRequestBody(t, h1Mode) } -func TestNoSniffExpectRequestBody_h2(t *testing.T) { testNoSniffExpectRequestBody(t, h2Mode) } - -func testNoSniffExpectRequestBody(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) } +func testNoSniffExpectRequestBody(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusUnauthorized) })) - defer cst.close() // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it. cst.tr.ExpectContinueTimeout = 10 * time.Second @@ -1350,18 +1411,15 @@ func testNoSniffExpectRequestBody(t *testing.T, h2 bool) { } } -func TestServerUndeclaredTrailers_h1(t *testing.T) { testServerUndeclaredTrailers(t, h1Mode) } -func TestServerUndeclaredTrailers_h2(t *testing.T) { testServerUndeclaredTrailers(t, h2Mode) } -func testServerUndeclaredTrailers(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) } +func testServerUndeclaredTrailers(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Foo", "Bar") w.Header().Set("Trailer:Foo", "Baz") w.(Flusher).Flush() w.Header().Add("Trailer:Foo", "Baz2") w.Header().Set("Trailer:Bar", "Quux") })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1382,8 +1440,10 @@ func testServerUndeclaredTrailers(t *testing.T, h2 bool) { } func TestBadResponseAfterReadingBody(t *testing.T) { - defer afterTest(t) - cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testBadResponseAfterReadingBody, []testMode{http1Mode}) +} +func testBadResponseAfterReadingBody(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := io.Copy(io.Discard, r.Body) if err != nil { t.Fatal(err) @@ -1395,7 +1455,6 @@ func TestBadResponseAfterReadingBody(t *testing.T) { defer c.Close() fmt.Fprintln(c, "some bogus crap") })) - defer cst.close() closes := 0 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) @@ -1408,12 +1467,10 @@ func TestBadResponseAfterReadingBody(t *testing.T) { } } -func TestWriteHeader0_h1(t *testing.T) { testWriteHeader0(t, h1Mode) } -func TestWriteHeader0_h2(t *testing.T) { testWriteHeader0(t, h2Mode) } -func testWriteHeader0(t *testing.T, h2 bool) { - defer afterTest(t) +func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) } +func testWriteHeader0(t *testing.T, mode testMode) { gotpanic := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(gotpanic) defer func() { if e := recover(); e != nil { @@ -1432,7 +1489,6 @@ func testWriteHeader0(t *testing.T, h2 bool) { }() w.WriteHeader(0) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1447,15 +1503,17 @@ func testWriteHeader0(t *testing.T, h2 bool) { // Issue 23010: don't be super strict checking WriteHeader's code if // it's not even valid to call WriteHeader then anyway. -func TestWriteHeaderNoCodeCheck_h1(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, false) } -func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, true) } -func TestWriteHeaderNoCodeCheck_h2(t *testing.T) { testWriteHeaderAfterWrite(t, h2Mode, false) } -func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { - setParallel(t) - defer afterTest(t) - +func TestWriteHeaderNoCodeCheck(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testWriteHeaderAfterWrite(t, mode, false) + }) +} +func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { + testWriteHeaderAfterWrite(t, http1Mode, true) +} +func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) { var errorLog lockedBytesBuffer - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if hijack { conn, _, _ := w.(Hijacker).Hijack() defer conn.Close() @@ -1471,7 +1529,6 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { }), func(ts *httptest.Server) { ts.Config.ErrorLog = log.New(&errorLog, "", 0) }) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1486,7 +1543,7 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { } // Also check the stderr output: - if h2 { + if mode == http2Mode { // TODO: also emit this log message for HTTP/2? // We historically haven't, so don't check. return @@ -1502,14 +1559,14 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { } func TestBidiStreamReverseProxy(t *testing.T) { - setParallel(t) - defer afterTest(t) - backend := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testBidiStreamReverseProxy, []testMode{http2Mode}) +} +func testBidiStreamReverseProxy(t *testing.T, mode testMode) { + backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if _, err := io.Copy(w, r.Body); err != nil { log.Printf("bidi backend copy: %v", err) } })) - defer backend.close() backURL, err := url.Parse(backend.ts.URL) if err != nil { @@ -1517,10 +1574,9 @@ func TestBidiStreamReverseProxy(t *testing.T) { } rp := httputil.NewSingleHostReverseProxy(backURL) rp.Transport = backend.tr - proxy := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { rp.ServeHTTP(w, r) })) - defer proxy.close() bodyRes := make(chan any, 1) // error or hash.Hash pr, pw := io.Pipe() @@ -1587,15 +1643,10 @@ func TestH12_WebSocketUpgrade(t *testing.T) { }.run(t) } -func TestIdentityTransferEncoding_h1(t *testing.T) { testIdentityTransferEncoding(t, h1Mode) } -func TestIdentityTransferEncoding_h2(t *testing.T) { testIdentityTransferEncoding(t, h2Mode) } - -func testIdentityTransferEncoding(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) } +func testIdentityTransferEncoding(t *testing.T, mode testMode) { const body = "body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotBody, _ := io.ReadAll(r.Body) if got, want := string(gotBody), body; got != want { t.Errorf("got request body = %q; want %q", got, want) @@ -1605,7 +1656,6 @@ func testIdentityTransferEncoding(t *testing.T, h2 bool) { w.(Flusher).Flush() io.WriteString(w, body) })) - defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body)) res, err := cst.c.Do(req) if err != nil { @@ -1621,14 +1671,11 @@ func testIdentityTransferEncoding(t *testing.T, h2 bool) { } } -func TestEarlyHintsRequest_h1(t *testing.T) { testEarlyHintsRequest(t, h1Mode) } -func TestEarlyHintsRequest_h2(t *testing.T) { testEarlyHintsRequest(t, h2Mode) } -func testEarlyHintsRequest(t *testing.T, h2 bool) { - defer afterTest(t) - +func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) } +func testEarlyHintsRequest(t *testing.T, mode testMode) { var wg sync.WaitGroup wg.Add(1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { h := w.Header() h.Add("Content-Length", "123") // must be ignored @@ -1643,7 +1690,6 @@ func testEarlyHintsRequest(t *testing.T, h2 bool) { w.Write([]byte("Hello")) })) - defer cst.close() checkLinkHeaders := func(t *testing.T, expected, got []string) { t.Helper() diff --git a/cookie.go b/cookie.go index 42b7823b..ca4a8722 100644 --- a/cookie.go +++ b/cookie.go @@ -74,6 +74,7 @@ func readSetCookies(h Header) []*Cookie { if !ok { continue } + name = textproto.TrimString(name) if !isCookieNameValid(name) { continue } @@ -247,7 +248,7 @@ func (c *Cookie) Valid() error { if !isCookieNameValid(c.Name) { return errors.New("http: invalid Cookie.Name") } - if !validCookieExpires(c.Expires) { + if !c.Expires.IsZero() && !validCookieExpires(c.Expires) { return errors.New("http: invalid Cookie.Expires") } for i := 0; i < len(c.Value); i++ { @@ -273,7 +274,7 @@ func (c *Cookie) Valid() error { // readCookies parses all "Cookie" values from the header h and // returns the successfully parsed Cookies. // -// if filter isn't empty, only cookies of that name are returned +// if filter isn't empty, only cookies of that name are returned. func readCookies(h Header, filter string) []*Cookie { lines := h["Cookie"] if len(lines) == 0 { @@ -292,6 +293,7 @@ func readCookies(h Header, filter string) []*Cookie { continue } name, val, _ := strings.Cut(part, "=") + name = textproto.TrimString(name) if !isCookieNameValid(name) { continue } diff --git a/cookie_test.go b/cookie_test.go index ccc5f980..e5bd46a7 100644 --- a/cookie_test.go +++ b/cookie_test.go @@ -5,7 +5,6 @@ package http import ( - "bytes" "encoding/json" "fmt" "log" @@ -151,7 +150,7 @@ var writeSetCookiesTests = []struct { func TestWriteSetCookies(t *testing.T) { defer log.SetOutput(os.Stderr) - var logbuf bytes.Buffer + var logbuf strings.Builder log.SetOutput(&logbuf) for i, tt := range writeSetCookiesTests { @@ -352,6 +351,12 @@ var readSetCookiesTests = []struct { Header{"Set-Cookie": {`special-8=","`}}, []*Cookie{{Name: "special-8", Value: ",", Raw: `special-8=","`}}, }, + // Make sure we can properly read back the Set-Cookie headers + // for names containing spaces: + { + Header{"Set-Cookie": {`special-9 =","`}}, + []*Cookie{{Name: "special-9", Value: ",", Raw: `special-9 =","`}}, + }, // TODO(bradfitz): users have reported seeing this in the // wild, but do browsers handle it? RFC 6265 just says "don't @@ -476,7 +481,7 @@ func TestSetCookieDoubleQuotes(t *testing.T) { func TestCookieSanitizeValue(t *testing.T) { defer log.SetOutput(os.Stderr) - var logbuf bytes.Buffer + var logbuf strings.Builder log.SetOutput(&logbuf) tests := []struct { @@ -508,7 +513,7 @@ func TestCookieSanitizeValue(t *testing.T) { func TestCookieSanitizePath(t *testing.T) { defer log.SetOutput(os.Stderr) - var logbuf bytes.Buffer + var logbuf strings.Builder log.SetOutput(&logbuf) tests := []struct { @@ -536,11 +541,14 @@ func TestCookieValid(t *testing.T) { }{ {nil, false}, {&Cookie{Name: ""}, false}, - {&Cookie{Name: "invalid-expires"}, false}, {&Cookie{Name: "invalid-value", Value: "foo\"bar"}, false}, {&Cookie{Name: "invalid-path", Path: "/foo;bar/"}, false}, {&Cookie{Name: "invalid-domain", Domain: "example.com:80"}, false}, - {&Cookie{Name: "valid", Value: "foo", Path: "/bar", Domain: "example.com", Expires: time.Unix(0, 0)}, true}, + {&Cookie{Name: "invalid-expiry", Value: "", Expires: time.Date(1600, 1, 1, 1, 1, 1, 1, time.UTC)}, false}, + {&Cookie{Name: "valid-empty"}, true}, + {&Cookie{Name: "valid-expires", Value: "foo", Path: "/bar", Domain: "example.com", Expires: time.Unix(0, 0)}, true}, + {&Cookie{Name: "valid-max-age", Value: "foo", Path: "/bar", Domain: "example.com", MaxAge: 60}, true}, + {&Cookie{Name: "valid-all-fields", Value: "foo", Path: "/bar", Domain: "example.com", Expires: time.Unix(0, 0), MaxAge: 0}, true}, } for _, tt := range tests { diff --git a/cookiejar/jar.go b/cookiejar/jar.go index 26370a8f..cfee9704 100644 --- a/cookiejar/jar.go +++ b/cookiejar/jar.go @@ -215,8 +215,8 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { if len(s[i].Path) != len(s[j].Path) { return len(s[i].Path) > len(s[j].Path) } - if !s[i].Creation.Equal(s[j].Creation) { - return s[i].Creation.Before(s[j].Creation) + if ret := s[i].Creation.Compare(s[j].Creation); ret != 0 { + return ret < 0 } return s[i].seqNum < s[j].seqNum }) @@ -466,7 +466,7 @@ func (j *Jar) domainAndType(host, domain string) (string, bool, error) { // dot in the domain-attribute before processing the cookie. // // Most browsers don't do that for IP addresses, only curl - // version 7.54) and and IE (version 11) do not reject a + // version 7.54) and IE (version 11) do not reject a // Set-Cookie: a=1; domain=.127.0.0.1 // This leading dot is optional and serves only as hint for // humans to indicate that a cookie with "domain=.bbc.co.uk" diff --git a/export_test.go b/export_test.go index 205ca83f..fb5ab939 100644 --- a/export_test.go +++ b/export_test.go @@ -60,7 +60,7 @@ func init() { } } -func CondSkipHTTP2(t *testing.T) { +func CondSkipHTTP2(t testing.TB) { if omitBundledHTTP2 { t.Skip("skipping HTTP/2 test when nethttpomithttp2 build tag in use") } @@ -72,8 +72,6 @@ var ( ) func SetReadLoopBeforeNextReadHook(f func()) { - testHookMu.Lock() - defer testHookMu.Unlock() unnilTestHook(&f) testHookReadLoopBeforeNextRead = f } diff --git a/fs.go b/fs.go index 0a94de91..c21cef95 100644 --- a/fs.go +++ b/fs.go @@ -257,81 +257,95 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, Error(w, err.Error(), StatusInternalServerError) return } + if size < 0 { + // Should never happen but just to be sure + Error(w, "negative content size computed", StatusInternalServerError) + return + } // handle Content-Range header. sendSize := size var sendContent io.Reader = content - if size >= 0 { - ranges, err := parseRange(rangeReq, size) - if err != nil { - if err == errNoOverlap { - w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size)) - } + ranges, err := parseRange(rangeReq, size) + switch err { + case nil: + case errNoOverlap: + if size == 0 { + // Some clients add a Range header to all requests to + // limit the size of the response. If the file is empty, + // ignore the range header and respond with a 200 rather + // than a 416. + ranges = nil + break + } + w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size)) + fallthrough + default: + Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) + return + } + + if sumRangesSize(ranges) > size { + // The total number of bytes in all the ranges + // is larger than the size of the file by + // itself, so this is probably an attack, or a + // dumb client. Ignore the range request. + ranges = nil + } + switch { + case len(ranges) == 1: + // RFC 7233, Section 4.1: + // "If a single part is being transferred, the server + // generating the 206 response MUST generate a + // Content-Range header field, describing what range + // of the selected representation is enclosed, and a + // payload consisting of the range. + // ... + // A server MUST NOT generate a multipart response to + // a request for a single range, since a client that + // does not request multiple parts might not support + // multipart responses." + ra := ranges[0] + if _, err := content.Seek(ra.start, io.SeekStart); err != nil { Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) return } - if sumRangesSize(ranges) > size { - // The total number of bytes in all the ranges - // is larger than the size of the file by - // itself, so this is probably an attack, or a - // dumb client. Ignore the range request. - ranges = nil - } - switch { - case len(ranges) == 1: - // RFC 7233, Section 4.1: - // "If a single part is being transferred, the server - // generating the 206 response MUST generate a - // Content-Range header field, describing what range - // of the selected representation is enclosed, and a - // payload consisting of the range. - // ... - // A server MUST NOT generate a multipart response to - // a request for a single range, since a client that - // does not request multiple parts might not support - // multipart responses." - ra := ranges[0] - if _, err := content.Seek(ra.start, io.SeekStart); err != nil { - Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) - return - } - sendSize = ra.length - code = StatusPartialContent - w.Header().Set("Content-Range", ra.contentRange(size)) - case len(ranges) > 1: - sendSize = rangesMIMESize(ranges, ctype, size) - code = StatusPartialContent - - pr, pw := io.Pipe() - mw := multipart.NewWriter(pw) - w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary()) - sendContent = pr - defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish. - go func() { - for _, ra := range ranges { - part, err := mw.CreatePart(ra.mimeHeader(ctype, size)) - if err != nil { - pw.CloseWithError(err) - return - } - if _, err := content.Seek(ra.start, io.SeekStart); err != nil { - pw.CloseWithError(err) - return - } - if _, err := io.CopyN(part, content, ra.length); err != nil { - pw.CloseWithError(err) - return - } + sendSize = ra.length + code = StatusPartialContent + w.Header().Set("Content-Range", ra.contentRange(size)) + case len(ranges) > 1: + sendSize = rangesMIMESize(ranges, ctype, size) + code = StatusPartialContent + + pr, pw := io.Pipe() + mw := multipart.NewWriter(pw) + w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary()) + sendContent = pr + defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish. + go func() { + for _, ra := range ranges { + part, err := mw.CreatePart(ra.mimeHeader(ctype, size)) + if err != nil { + pw.CloseWithError(err) + return } - mw.Close() - pw.Close() - }() - } + if _, err := content.Seek(ra.start, io.SeekStart); err != nil { + pw.CloseWithError(err) + return + } + if _, err := io.CopyN(part, content, ra.length); err != nil { + pw.CloseWithError(err) + return + } + } + mw.Close() + pw.Close() + }() + } - w.Header().Set("Accept-Ranges", "bytes") - if w.Header().Get("Content-Encoding") == "" { - w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10)) - } + w.Header().Set("Accept-Ranges", "bytes") + if w.Header().Get("Content-Encoding") == "" { + w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10)) } w.WriteHeader(code) @@ -434,7 +448,7 @@ func checkIfUnmodifiedSince(r *Request, modtime time.Time) condResult { // The Last-Modified header truncates sub-second precision so // the modtime needs to be truncated too. modtime = modtime.Truncate(time.Second) - if modtime.Before(t) || modtime.Equal(t) { + if ret := modtime.Compare(t); ret <= 0 { return condTrue } return condFalse @@ -485,7 +499,7 @@ func checkIfModifiedSince(r *Request, modtime time.Time) condResult { // The Last-Modified header truncates sub-second precision so // the modtime needs to be truncated too. modtime = modtime.Truncate(time.Second) - if modtime.Before(t) || modtime.Equal(t) { + if ret := modtime.Compare(t); ret <= 0 { return condFalse } return condTrue @@ -645,7 +659,6 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec defer ff.Close() dd, err := ff.Stat() if err == nil { - name = index d = dd f = ff } @@ -821,6 +834,7 @@ func (f ioFile) Readdir(count int) ([]fs.FileInfo, error) { // FS converts fsys to a FileSystem implementation, // for use with FileServer and NewFileTransport. +// The files provided by fsys must implement io.Seeker. func FS(fsys fs.FS) FileSystem { return ioFS{fsys} } diff --git a/fs_test.go b/fs_test.go index 19a2e3c9..baa0705d 100644 --- a/fs_test.go +++ b/fs_test.go @@ -68,13 +68,11 @@ var ServeFileRangeTests = []struct { {r: "bytes=100-1000", code: StatusRequestedRangeNotSatisfiable}, } -func TestServeFile(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServeFile(t *testing.T) { run(t, testServeFile) } +func testServeFile(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/file") - })) - defer ts.Close() + })).ts c := ts.Client() var err error @@ -220,6 +218,27 @@ func TestServeFileDirPanicEmptyPath(t *testing.T) { } } +// Tests that ranges are ignored with serving empty content. (Issue 54794) +func TestServeContentWithEmptyContentIgnoreRanges(t *testing.T) { + for _, r := range []string{ + "bytes=0-128", + "bytes=1-", + } { + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Range", r) + ServeContent(rec, req, "nothing", time.Now(), bytes.NewReader(nil)) + res := rec.Result() + if res.StatusCode != 200 { + t.Errorf("code = %v; want 200", res.Status) + } + bodyLen := rec.Body.Len() + if bodyLen != 0 { + t.Errorf("body.Len() = %v; want 0", res.Status) + } + } +} + var fsRedirectTestData = []struct { original, redirect string }{ @@ -228,13 +247,12 @@ var fsRedirectTestData = []struct { {"/test/testdata/file/", "/test/testdata/file"}, } -func TestFSRedirect(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(StripPrefix("/test", FileServer(Dir(".")))) - defer ts.Close() +func TestFSRedirect(t *testing.T) { run(t, testFSRedirect) } +func testFSRedirect(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, StripPrefix("/test", FileServer(Dir(".")))).ts for _, data := range fsRedirectTestData { - res, err := Get(ts.URL + data.original) + res, err := ts.Client().Get(ts.URL + data.original) if err != nil { t.Fatal(err) } @@ -278,8 +296,8 @@ func TestFileServerCleans(t *testing.T) { } } -func TestFileServerEscapesNames(t *testing.T) { - defer afterTest(t) +func TestFileServerEscapesNames(t *testing.T) { run(t, testFileServerEscapesNames) } +func testFileServerEscapesNames(t *testing.T, mode testMode) { const dirListPrefix = "
\n"
 	const dirListSuffix = "\n
\n" tests := []struct { @@ -304,11 +322,10 @@ func TestFileServerEscapesNames(t *testing.T) { fs[fmt.Sprintf("/%d/%s", i, test.name)] = testFile } - ts := httptest.NewServer(FileServer(&fs)) - defer ts.Close() + ts := newClientServerTest(t, mode, FileServer(&fs)).ts for i, test := range tests { url := fmt.Sprintf("%s/%d", ts.URL, i) - res, err := Get(url) + res, err := ts.Client().Get(url) if err != nil { t.Fatalf("test %q: Get: %v", test.name, err) } @@ -327,8 +344,8 @@ func TestFileServerEscapesNames(t *testing.T) { } } -func TestFileServerSortsNames(t *testing.T) { - defer afterTest(t) +func TestFileServerSortsNames(t *testing.T) { run(t, testFileServerSortsNames) } +func testFileServerSortsNames(t *testing.T, mode testMode) { const contents = "I am a fake file" dirMod := time.Unix(123, 0).UTC() fileMod := time.Unix(1000000000, 0).UTC() @@ -351,10 +368,9 @@ func TestFileServerSortsNames(t *testing.T) { }, } - ts := httptest.NewServer(FileServer(&fs)) - defer ts.Close() + ts := newClientServerTest(t, mode, FileServer(&fs)).ts - res, err := Get(ts.URL) + res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatalf("Get: %v", err) } @@ -377,16 +393,15 @@ func mustRemoveAll(dir string) { } } -func TestFileServerImplicitLeadingSlash(t *testing.T) { - defer afterTest(t) +func TestFileServerImplicitLeadingSlash(t *testing.T) { run(t, testFileServerImplicitLeadingSlash) } +func testFileServerImplicitLeadingSlash(t *testing.T, mode testMode) { tempDir := t.TempDir() if err := os.WriteFile(filepath.Join(tempDir, "foo.txt"), []byte("Hello world"), 0644); err != nil { t.Fatalf("WriteFile: %v", err) } - ts := httptest.NewServer(StripPrefix("/bar/", FileServer(Dir(tempDir)))) - defer ts.Close() + ts := newClientServerTest(t, mode, StripPrefix("/bar/", FileServer(Dir(tempDir)))).ts get := func(suffix string) string { - res, err := Get(ts.URL + suffix) + res, err := ts.Client().Get(ts.URL + suffix) if err != nil { t.Fatalf("Get %s: %v", suffix, err) } @@ -455,10 +470,10 @@ func TestEmptyDirOpenCWD(t *testing.T) { test(Dir("./")) } -func TestServeFileContentType(t *testing.T) { - defer afterTest(t) +func TestServeFileContentType(t *testing.T) { run(t, testServeFileContentType) } +func testServeFileContentType(t *testing.T, mode testMode) { const ctype = "icecream/chocolate" - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { switch r.FormValue("override") { case "1": w.Header().Set("Content-Type", ctype) @@ -467,10 +482,9 @@ func TestServeFileContentType(t *testing.T) { w.Header()["Content-Type"] = []string{} } ServeFile(w, r, "testdata/file") - })) - defer ts.Close() + })).ts get := func(override string, want []string) { - resp, err := Get(ts.URL + "?override=" + override) + resp, err := ts.Client().Get(ts.URL + "?override=" + override) if err != nil { t.Fatal(err) } @@ -484,13 +498,12 @@ func TestServeFileContentType(t *testing.T) { get("2", nil) } -func TestServeFileMimeType(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServeFileMimeType(t *testing.T) { run(t, testServeFileMimeType) } +func testServeFileMimeType(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/style.css") - })) - defer ts.Close() - resp, err := Get(ts.URL) + })).ts + resp, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) } @@ -501,13 +514,12 @@ func TestServeFileMimeType(t *testing.T) { } } -func TestServeFileFromCWD(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServeFileFromCWD(t *testing.T) { run(t, testServeFileFromCWD) } +func testServeFileFromCWD(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "fs_test.go") - })) - defer ts.Close() - r, err := Get(ts.URL) + })).ts + r, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) } @@ -518,14 +530,13 @@ func TestServeFileFromCWD(t *testing.T) { } // Issue 13996 -func TestServeDirWithoutTrailingSlash(t *testing.T) { +func TestServeDirWithoutTrailingSlash(t *testing.T) { run(t, testServeDirWithoutTrailingSlash) } +func testServeDirWithoutTrailingSlash(t *testing.T, mode testMode) { e := "/testdata/" - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, ".") - })) - defer ts.Close() - r, err := Get(ts.URL + "/testdata") + })).ts + r, err := ts.Client().Get(ts.URL + "/testdata") if err != nil { t.Fatal(err) } @@ -537,11 +548,9 @@ func TestServeDirWithoutTrailingSlash(t *testing.T) { // Tests that ServeFile doesn't add a Content-Length if a Content-Encoding is // specified. -func TestServeFileWithContentEncoding_h1(t *testing.T) { testServeFileWithContentEncoding(t, h1Mode) } -func TestServeFileWithContentEncoding_h2(t *testing.T) { testServeFileWithContentEncoding(t, h2Mode) } -func testServeFileWithContentEncoding(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServeFileWithContentEncoding(t *testing.T) { run(t, testServeFileWithContentEncoding) } +func testServeFileWithContentEncoding(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "foo") ServeFile(w, r, "testdata/file") @@ -554,7 +563,6 @@ func testServeFileWithContentEncoding(t *testing.T, h2 bool) { // Content-Length and test ServeFile only, flush here. w.(Flusher).Flush() })) - defer cst.close() resp, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -567,11 +575,9 @@ func testServeFileWithContentEncoding(t *testing.T, h2 bool) { // Tests that ServeFile does not generate representation metadata when // file has not been modified, as per RFC 7232 section 4.1. -func TestServeFileNotModified_h1(t *testing.T) { testServeFileNotModified(t, h1Mode) } -func TestServeFileNotModified_h2(t *testing.T) { testServeFileNotModified(t, h2Mode) } -func testServeFileNotModified(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServeFileNotModified(t *testing.T) { run(t, testServeFileNotModified) } +func testServeFileNotModified(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Encoding", "foo") w.Header().Set("Etag", `"123"`) @@ -586,7 +592,6 @@ func testServeFileNotModified(t *testing.T, h2 bool) { // Content-Length and test ServeFile only, flush here. w.(Flusher).Flush() })) - defer cst.close() req, err := NewRequest("GET", cst.ts.URL, nil) if err != nil { t.Fatal(err) @@ -619,9 +624,8 @@ func testServeFileNotModified(t *testing.T, h2 bool) { } } -func TestServeIndexHtml(t *testing.T) { - defer afterTest(t) - +func TestServeIndexHtml(t *testing.T) { run(t, testServeIndexHtml) } +func testServeIndexHtml(t *testing.T, mode testMode) { for i := 0; i < 2; i++ { var h Handler var name string @@ -635,11 +639,10 @@ func TestServeIndexHtml(t *testing.T) { } t.Run(name, func(t *testing.T) { const want = "index.html says hello\n" - ts := httptest.NewServer(h) - defer ts.Close() + ts := newClientServerTest(t, mode, h).ts for _, path := range []string{"/testdata/", "/testdata/index.html"} { - res, err := Get(ts.URL + path) + res, err := ts.Client().Get(ts.URL + path) if err != nil { t.Fatal(err) } @@ -656,14 +659,14 @@ func TestServeIndexHtml(t *testing.T) { } } -func TestServeIndexHtmlFS(t *testing.T) { - defer afterTest(t) +func TestServeIndexHtmlFS(t *testing.T) { run(t, testServeIndexHtmlFS) } +func testServeIndexHtmlFS(t *testing.T, mode testMode) { const want = "index.html says hello\n" - ts := httptest.NewServer(FileServer(Dir("."))) + ts := newClientServerTest(t, mode, FileServer(Dir("."))).ts defer ts.Close() for _, path := range []string{"/testdata/", "/testdata/index.html"} { - res, err := Get(ts.URL + path) + res, err := ts.Client().Get(ts.URL + path) if err != nil { t.Fatal(err) } @@ -678,10 +681,9 @@ func TestServeIndexHtmlFS(t *testing.T) { } } -func TestFileServerZeroByte(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(FileServer(Dir("."))) - defer ts.Close() +func TestFileServerZeroByte(t *testing.T) { run(t, testFileServerZeroByte) } +func testFileServerZeroByte(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, FileServer(Dir("."))).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -703,18 +705,9 @@ func TestFileServerZeroByte(t *testing.T) { } } -func TestFileServerNamesEscape(t *testing.T) { - t.Run("h1", func(t *testing.T) { - testFileServerNamesEscape(t, h1Mode) - }) - t.Run("h2", func(t *testing.T) { - testFileServerNamesEscape(t, h2Mode) - }) -} -func testFileServerNamesEscape(t *testing.T, h2 bool) { - defer afterTest(t) - ts := newClientServerTest(t, h2, FileServer(Dir("testdata"))).ts - defer ts.Close() +func TestFileServerNamesEscape(t *testing.T) { run(t, testFileServerNamesEscape) } +func testFileServerNamesEscape(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, FileServer(Dir("testdata"))).ts for _, path := range []string{ "/../testdata/file", "/NUL", // don't read from device files on Windows @@ -796,8 +789,8 @@ func (fsys fakeFS) Open(name string) (File, error) { return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil } -func TestDirectoryIfNotModified(t *testing.T) { - defer afterTest(t) +func TestDirectoryIfNotModified(t *testing.T) { run(t, testDirectoryIfNotModified) } +func testDirectoryIfNotModified(t *testing.T, mode testMode) { const indexContents = "I am a fake index.html file" fileMod := time.Unix(1000000000, 0).UTC() fileModStr := fileMod.Format(TimeFormat) @@ -816,10 +809,9 @@ func TestDirectoryIfNotModified(t *testing.T) { "/index.html": indexFile, } - ts := httptest.NewServer(FileServer(fs)) - defer ts.Close() + ts := newClientServerTest(t, mode, FileServer(fs)).ts - res, err := Get(ts.URL) + res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) } @@ -871,8 +863,8 @@ func mustStat(t *testing.T, fileName string) fs.FileInfo { return fi } -func TestServeContent(t *testing.T) { - defer afterTest(t) +func TestServeContent(t *testing.T) { run(t, testServeContent) } +func testServeContent(t *testing.T, mode testMode) { type serveParam struct { name string modtime time.Time @@ -881,7 +873,7 @@ func TestServeContent(t *testing.T) { etag string } servec := make(chan serveParam, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { p := <-servec if p.etag != "" { w.Header().Set("ETag", p.etag) @@ -890,8 +882,7 @@ func TestServeContent(t *testing.T) { w.Header().Set("Content-Type", p.contentType) } ServeContent(w, r, p.name, p.modtime, p.content) - })) - defer ts.Close() + })).ts type testCase struct { // One of file or content must be set: @@ -1200,8 +1191,8 @@ type issue12991File struct{ File } func (issue12991File) Stat() (fs.FileInfo, error) { return nil, fs.ErrPermission } func (issue12991File) Close() error { return nil } -func TestServeContentErrorMessages(t *testing.T) { - defer afterTest(t) +func TestServeContentErrorMessages(t *testing.T) { run(t, testServeContentErrorMessages) } +func testServeContentErrorMessages(t *testing.T, mode testMode) { fs := fakeFS{ "/500": &fakeFileInfo{ err: errors.New("random error"), @@ -1210,8 +1201,7 @@ func TestServeContentErrorMessages(t *testing.T) { err: &fs.PathError{Err: fs.ErrPermission}, }, } - ts := httptest.NewServer(FileServer(fs)) - defer ts.Close() + ts := newClientServerTest(t, mode, FileServer(fs)).ts c := ts.Client() for _, code := range []int{403, 404, 500} { res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, code)) @@ -1260,7 +1250,7 @@ func TestLinuxSendfile(t *testing.T) { } defer os.Remove(filepath) - var buf bytes.Buffer + var buf strings.Builder child := exec.Command("strace", "-f", "-q", os.Args[0], "-test.run=TestLinuxSendfileChild") child.ExtraFiles = append(child.ExtraFiles, lnf) child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...) @@ -1329,20 +1319,20 @@ func TestLinuxSendfileChild(*testing.T) { // Issues 18984, 49552: tests that requests for paths beyond files return not-found errors func TestFileServerNotDirError(t *testing.T) { - defer afterTest(t) - t.Run("Dir", func(t *testing.T) { - testFileServerNotDirError(t, func(path string) FileSystem { return Dir(path) }) - }) - t.Run("FS", func(t *testing.T) { - testFileServerNotDirError(t, func(path string) FileSystem { return FS(os.DirFS(path)) }) + run(t, func(t *testing.T, mode testMode) { + t.Run("Dir", func(t *testing.T) { + testFileServerNotDirError(t, mode, func(path string) FileSystem { return Dir(path) }) + }) + t.Run("FS", func(t *testing.T) { + testFileServerNotDirError(t, mode, func(path string) FileSystem { return FS(os.DirFS(path)) }) + }) }) } -func testFileServerNotDirError(t *testing.T, newfs func(string) FileSystem) { - ts := httptest.NewServer(FileServer(newfs("testdata"))) - defer ts.Close() +func testFileServerNotDirError(t *testing.T, mode testMode, newfs func(string) FileSystem) { + ts := newClientServerTest(t, mode, FileServer(newfs("testdata"))).ts - res, err := Get(ts.URL + "/index.html/not-a-file") + res, err := ts.Client().Get(ts.URL + "/index.html/not-a-file") if err != nil { t.Fatal(err) } @@ -1446,19 +1436,11 @@ func Test_scanETag(t *testing.T) { // Issue 40940: Ensure that we only accept non-negative suffix-lengths // in "Range": "bytes=-N", and should reject "bytes=--2". -func TestServeFileRejectsInvalidSuffixLengths_h1(t *testing.T) { - testServeFileRejectsInvalidSuffixLengths(t, h1Mode) +func TestServeFileRejectsInvalidSuffixLengths(t *testing.T) { + run(t, testServeFileRejectsInvalidSuffixLengths, []testMode{http1Mode, https1Mode, http2Mode}) } -func TestServeFileRejectsInvalidSuffixLengths_h2(t *testing.T) { - testServeFileRejectsInvalidSuffixLengths(t, h2Mode) -} - -func testServeFileRejectsInvalidSuffixLengths(t *testing.T, h2 bool) { - defer afterTest(t) - cst := httptest.NewUnstartedServer(FileServer(Dir("testdata"))) - cst.EnableHTTP2 = h2 - cst.StartTLS() - defer cst.Close() +func testServeFileRejectsInvalidSuffixLengths(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, FileServer(Dir("testdata"))).ts tests := []struct { r string diff --git a/go.mod b/go.mod index bcba652e..b4fc2252 100644 --- a/go.mod +++ b/go.mod @@ -1,7 +1,7 @@ module github.com/ooni/oohttp -go 1.18 +go 1.20 -require golang.org/x/net v0.4.0 +require golang.org/x/net v0.10.0 -require golang.org/x/text v0.5.0 // indirect +require golang.org/x/text v0.9.0 // indirect diff --git a/go.sum b/go.sum index 0b4a2b6e..a3fe7ead 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,4 @@ -golang.org/x/net v0.4.0 h1:Q5QPcMlvfxFTAPV0+07Xz/MpK9NTXu2VDUuy0FeMfaU= -golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= -golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= -golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= diff --git a/h2_bundle.go b/h2_bundle.go index e17de361..529c43cd 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -30,6 +30,7 @@ import ( "errors" "fmt" "io" + "io/fs" "log" "math" mathrand "math/rand" @@ -859,7 +860,6 @@ func (p *http2clientConnPool) getStartDialLocked(ctx context.Context, addr strin func (c *http2dialCall) dial(ctx context.Context, addr string) { const singleUse = false // shared conn c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse) - close(c.done) c.p.mu.Lock() delete(c.p.dialing, addr) @@ -867,6 +867,8 @@ func (c *http2dialCall) dial(ctx context.Context, addr string) { c.p.addConnLocked(addr, c.res) } c.p.mu.Unlock() + + close(c.done) } // addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't @@ -1351,7 +1353,7 @@ const http2frameHeaderLen = 9 var http2padZeros = make([]byte, 255) // zeros for padding // A FrameType is a registered frame type as defined in -// http://http2.github.io/http2-spec/#rfc.section.11.2 +// https://httpwg.org/specs/rfc7540.html#rfc.section.11.2 type http2FrameType uint8 const ( @@ -1474,7 +1476,7 @@ func http2typeFrameParser(t http2FrameType) http2frameParser { // A FrameHeader is the 9 byte header of all HTTP/2 frames. // -// See http://http2.github.io/http2-spec/#FrameHeader +// See https://httpwg.org/specs/rfc7540.html#FrameHeader type http2FrameHeader struct { valid bool // caller can access []byte fields in the Frame @@ -1906,7 +1908,7 @@ func (fr *http2Framer) checkFrameOrder(f http2Frame) error { // A DataFrame conveys arbitrary, variable-length sequences of octets // associated with a stream. -// See http://http2.github.io/http2-spec/#rfc.section.6.1 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.1 type http2DataFrame struct { http2FrameHeader data []byte @@ -2029,7 +2031,7 @@ func (f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad // endpoints communicate, such as preferences and constraints on peer // behavior. // -// See http://http2.github.io/http2-spec/#SETTINGS +// See https://httpwg.org/specs/rfc7540.html#SETTINGS type http2SettingsFrame struct { http2FrameHeader p []byte @@ -2168,7 +2170,7 @@ func (f *http2Framer) WriteSettingsAck() error { // A PingFrame is a mechanism for measuring a minimal round trip time // from the sender, as well as determining whether an idle connection // is still functional. -// See http://http2.github.io/http2-spec/#rfc.section.6.7 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.7 type http2PingFrame struct { http2FrameHeader Data [8]byte @@ -2201,7 +2203,7 @@ func (f *http2Framer) WritePing(ack bool, data [8]byte) error { } // A GoAwayFrame informs the remote peer to stop creating streams on this connection. -// See http://http2.github.io/http2-spec/#rfc.section.6.8 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.8 type http2GoAwayFrame struct { http2FrameHeader LastStreamID uint32 @@ -2265,7 +2267,7 @@ func http2parseUnknownFrame(_ *http2frameCache, fh http2FrameHeader, countError } // A WindowUpdateFrame is used to implement flow control. -// See http://http2.github.io/http2-spec/#rfc.section.6.9 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.9 type http2WindowUpdateFrame struct { http2FrameHeader Increment uint32 // never read with high bit set @@ -2454,7 +2456,7 @@ func (f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { } // A PriorityFrame specifies the sender-advised priority of a stream. -// See http://http2.github.io/http2-spec/#rfc.section.6.3 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.3 type http2PriorityFrame struct { http2FrameHeader http2PriorityParam @@ -2524,7 +2526,7 @@ func (f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) error } // A RSTStreamFrame allows for abnormal termination of a stream. -// See http://http2.github.io/http2-spec/#rfc.section.6.4 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.4 type http2RSTStreamFrame struct { http2FrameHeader ErrCode http2ErrCode @@ -2556,7 +2558,7 @@ func (f *http2Framer) WriteRSTStream(streamID uint32, code http2ErrCode) error { } // A ContinuationFrame is used to continue a sequence of header block fragments. -// See http://http2.github.io/http2-spec/#rfc.section.6.10 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.10 type http2ContinuationFrame struct { http2FrameHeader headerFragBuf []byte @@ -2597,7 +2599,7 @@ func (f *http2Framer) WriteContinuation(streamID uint32, endHeaders bool, header } // A PushPromiseFrame is used to initiate a server stream. -// See http://http2.github.io/http2-spec/#rfc.section.6.6 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.6 type http2PushPromiseFrame struct { http2FrameHeader PromiseID uint32 @@ -3184,7 +3186,14 @@ func http2buildCommonHeaderMaps() { "accept-language", "accept-ranges", "age", + "access-control-allow-credentials", + "access-control-allow-headers", + "access-control-allow-methods", "access-control-allow-origin", + "access-control-expose-headers", + "access-control-max-age", + "access-control-request-headers", + "access-control-request-method", "allow", "authorization", "cache-control", @@ -3210,6 +3219,7 @@ func http2buildCommonHeaderMaps() { "link", "location", "max-forwards", + "origin", "proxy-authenticate", "proxy-authorization", "range", @@ -3225,6 +3235,8 @@ func http2buildCommonHeaderMaps() { "vary", "via", "www-authenticate", + "x-forwarded-for", + "x-forwarded-proto", } http2commonLowerHeader = make(map[string]string, len(common)) http2commonCanonHeader = make(map[string]string, len(common)) @@ -3243,6 +3255,14 @@ func http2lowerHeader(v string) (lower string, ascii bool) { return http2asciiToLower(v) } +func http2canonicalHeader(v string) string { + http2buildCommonHeaderMapsOnce() + if s, ok := http2commonCanonHeader[v]; ok { + return s + } + return CanonicalHeaderKey(v) +} + var ( http2VerboseLogs bool http2logFrameWrites bool @@ -3268,14 +3288,14 @@ const ( http2ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" // SETTINGS_MAX_FRAME_SIZE default - // http://http2.github.io/http2-spec/#rfc.section.6.5.2 + // https://httpwg.org/specs/rfc7540.html#rfc.section.6.5.2 http2initialMaxFrameSize = 16384 // NextProtoTLS is the NPN/ALPN protocol negotiated during // HTTP/2's TLS setup. http2NextProtoTLS = "h2" - // http://http2.github.io/http2-spec/#SettingValues + // https://httpwg.org/specs/rfc7540.html#SettingValues http2initialHeaderTableSize = 4096 http2initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size @@ -3324,7 +3344,7 @@ func (st http2streamState) String() string { // Setting is a setting parameter: which setting it is, and its value. type http2Setting struct { // ID is which setting is being set. - // See http://http2.github.io/http2-spec/#SettingValues + // See https://httpwg.org/specs/rfc7540.html#SettingFormat ID http2SettingID // Val is the value. @@ -3356,7 +3376,7 @@ func (s http2Setting) Valid() error { } // A SettingID is an HTTP/2 setting as defined in -// http://http2.github.io/http2-spec/#iana-settings +// https://httpwg.org/specs/rfc7540.html#iana-settings type http2SettingID uint16 const ( @@ -3817,6 +3837,19 @@ type http2Server struct { // the HTTP/2 spec's recommendations. MaxConcurrentStreams uint32 + // MaxDecoderHeaderTableSize optionally specifies the http2 + // SETTINGS_HEADER_TABLE_SIZE to send in the initial settings frame. It + // informs the remote endpoint of the maximum size of the header compression + // table used to decode header blocks, in octets. If zero, the default value + // of 4096 is used. + MaxDecoderHeaderTableSize uint32 + + // MaxEncoderHeaderTableSize optionally specifies an upper limit for the + // header compression table used for encoding request headers. Received + // SETTINGS_HEADER_TABLE_SIZE settings are capped at this limit. If zero, + // the default value of 4096 is used. + MaxEncoderHeaderTableSize uint32 + // MaxReadFrameSize optionally specifies the largest frame // this server is willing to read. A valid value is between // 16k and 16M, inclusive. If zero or otherwise invalid, a @@ -3862,7 +3895,7 @@ type http2Server struct { } func (s *http2Server) initialConnRecvWindowSize() int32 { - if s.MaxUploadBufferPerConnection > http2initialWindowSize { + if s.MaxUploadBufferPerConnection >= http2initialWindowSize { return s.MaxUploadBufferPerConnection } return 1 << 20 @@ -3889,6 +3922,20 @@ func (s *http2Server) maxConcurrentStreams() uint32 { return http2defaultMaxStreams } +func (s *http2Server) maxDecoderHeaderTableSize() uint32 { + if v := s.MaxDecoderHeaderTableSize; v > 0 { + return v + } + return http2initialHeaderTableSize +} + +func (s *http2Server) maxEncoderHeaderTableSize() uint32 { + if v := s.MaxEncoderHeaderTableSize; v > 0 { + return v + } + return http2initialHeaderTableSize +} + // maxQueuedControlFrames is the maximum number of control frames like // SETTINGS, PING and RST_STREAM that will be queued for writing before // the connection is closed to prevent memory exhaustion attacks. @@ -4034,6 +4081,20 @@ type http2ServeConnOpts struct { // requests. If nil, BaseConfig.Handler is used. If BaseConfig // or BaseConfig.Handler is nil, http.DefaultServeMux is used. Handler Handler + + // UpgradeRequest is an initial request received on a connection + // undergoing an h2c upgrade. The request body must have been + // completely read from the connection before calling ServeConn, + // and the 101 Switching Protocols response written. + UpgradeRequest *Request + + // Settings is the decoded contents of the HTTP2-Settings header + // in an h2c upgrade request. + Settings []byte + + // SawClientPreface is set if the HTTP/2 connection preface + // has already been read from the connection. + SawClientPreface bool } func (o *http2ServeConnOpts) context() context.Context { @@ -4099,9 +4160,9 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { advMaxStreams: s.maxConcurrentStreams(), initialStreamSendWindowSize: http2initialWindowSize, maxFrameSize: http2initialMaxFrameSize, - headerTableSize: http2initialHeaderTableSize, serveG: http2newGoroutineLock(), pushEnabled: true, + sawClientPreface: opts.SawClientPreface, } s.state.registerConn(sc) @@ -4128,12 +4189,13 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { sc.flow.add(http2initialWindowSize) sc.inflow.add(http2initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) + sc.hpackEncoder.SetMaxDynamicTableSizeLimit(s.maxEncoderHeaderTableSize()) fr := http2NewFramer(sc.bw, c) if s.CountError != nil { fr.countError = s.CountError } - fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) + fr.ReadMetaHeaders = hpack.NewDecoder(s.maxDecoderHeaderTableSize(), nil) fr.MaxHeaderListSize = sc.maxHeaderListSize() fr.SetMaxReadFrameSize(s.maxReadFrameSize()) sc.framer = fr @@ -4184,9 +4246,27 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { } } + if opts.Settings != nil { + fr := &http2SettingsFrame{ + http2FrameHeader: http2FrameHeader{valid: true}, + p: opts.Settings, + } + if err := fr.ForeachSetting(sc.processSetting); err != nil { + sc.rejectConn(http2ErrCodeProtocol, "invalid settings") + return + } + opts.Settings = nil + } + if hook := http2testHookGetServerConn; hook != nil { hook(sc) } + + if opts.UpgradeRequest != nil { + sc.upgradeRequest(opts.UpgradeRequest) + opts.UpgradeRequest = nil + } + sc.serve() } @@ -4231,6 +4311,7 @@ type http2serverConn struct { // Everything following is owned by the serve loop; use serveG.check(): serveG http2goroutineLock // used to verify funcs are on serve() pushEnabled bool + sawClientPreface bool // preface has already been read, used in h2c upgrade sawFirstSettings bool // got the initial SETTINGS frame after the preface needToSendSettingsAck bool unackedSettings int // how many SETTINGS have we sent without ACKs? @@ -4244,7 +4325,6 @@ type http2serverConn struct { streams map[uint32]*http2stream initialStreamSendWindowSize int32 maxFrameSize int32 - headerTableSize uint32 peerMaxHeaderListSize uint32 // zero means unknown (default) canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case canonHeaderKeysSize int // canonHeader keys size in bytes @@ -4308,7 +4388,9 @@ type http2stream struct { resetQueued bool // RST_STREAM queued for write; set by sc.resetStream gotTrailerHeader bool // HEADER frame for trailers was seen wroteHeaders bool // whether we wrote headers (not status 100) + readDeadline *time.Timer // nil if unused writeDeadline *time.Timer // nil if unused + closeErr error // set before cw is closed trailer Header // accumulated trailers reqTrailer Header // handler's Request.Trailer @@ -4554,6 +4636,7 @@ func (sc *http2serverConn) serve() { {http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, {http2SettingMaxConcurrentStreams, sc.advMaxStreams}, {http2SettingMaxHeaderListSize, sc.maxHeaderListSize()}, + {http2SettingHeaderTableSize, sc.srv.maxDecoderHeaderTableSize()}, {http2SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())}, }, }) @@ -4640,6 +4723,8 @@ func (sc *http2serverConn) serve() { } case *http2startPushRequest: sc.startPush(v) + case func(*http2serverConn): + v(sc) default: panic(fmt.Sprintf("unexpected type %T", v)) } @@ -4702,6 +4787,9 @@ var http2errPrefaceTimeout = errors.New("timeout waiting for client preface") // returns errPrefaceTimeout on timeout, or an error if the greeting // is invalid. func (sc *http2serverConn) readPreface() error { + if sc.sawClientPreface { + return nil + } errc := make(chan error, 1) go func() { // Read the client preface @@ -5152,6 +5240,21 @@ func (sc *http2serverConn) processFrame(f http2Frame) error { sc.sawFirstSettings = true } + // Discard frames for streams initiated after the identified last + // stream sent in a GOAWAY, or all frames after sending an error. + // We still need to return connection-level flow control for DATA frames. + // RFC 9113 Section 6.8. + if sc.inGoAway && (sc.goAwayCode != http2ErrCodeNo || f.Header().StreamID > sc.maxClientStreamID) { + + if f, ok := f.(*http2DataFrame); ok { + if sc.inflow.available() < int32(f.Length) { + return sc.countError("data_flow", http2streamError(f.Header().StreamID, http2ErrCodeFlowControl)) + } + sc.sendWindowUpdate(nil, int(f.Length)) // conn-level + } + return nil + } + switch f := f.(type) { case *http2SettingsFrame: return sc.processSettings(f) @@ -5194,9 +5297,6 @@ func (sc *http2serverConn) processPing(f *http2PingFrame) error { // PROTOCOL_ERROR." return sc.countError("ping_on_stream", http2ConnectionError(http2ErrCodeProtocol)) } - if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { - return nil - } sc.writeFrame(http2FrameWriteRequest{write: http2writePingAck{f}}) return nil } @@ -5258,6 +5358,9 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) } st.state = http2stateClosed + if st.readDeadline != nil { + st.readDeadline.Stop() + } if st.writeDeadline != nil { st.writeDeadline.Stop() } @@ -5283,6 +5386,14 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { p.CloseWithError(err) } + if e, ok := err.(http2StreamError); ok { + if e.Cause != nil { + err = e.Cause + } else { + err = http2errStreamClosed + } + } + st.closeErr = err st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc sc.writeSched.CloseStream(st.id) } @@ -5325,7 +5436,6 @@ func (sc *http2serverConn) processSetting(s http2Setting) error { } switch s.ID { case http2SettingHeaderTableSize: - sc.headerTableSize = s.Val sc.hpackEncoder.SetMaxDynamicTableSize(s.Val) case http2SettingEnablePush: sc.pushEnabled = s.Val != 0 @@ -5379,16 +5489,6 @@ func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error { func (sc *http2serverConn) processData(f *http2DataFrame) error { sc.serveG.check() id := f.Header().StreamID - if sc.inGoAway && (sc.goAwayCode != http2ErrCodeNo || id > sc.maxClientStreamID) { - // Discard all DATA frames if the GOAWAY is due to an - // error, or: - // - // Section 6.8: After sending a GOAWAY frame, the sender - // can discard frames for streams initiated by the - // receiver with identifiers higher than the identified - // last stream. - return nil - } data := f.Data() state, st := sc.state(id) @@ -5441,6 +5541,12 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { // Sender sending more than they'd declared? if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { + if sc.inflow.available() < int32(f.Length) { + return sc.countError("data_flow", http2streamError(id, http2ErrCodeFlowControl)) + } + sc.inflow.take(int32(f.Length)) + sc.sendWindowUpdate(nil, int(f.Length)) // conn-level + st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) // RFC 7540, sec 8.1.2.6: A request or response is also malformed if the // value of a content-length header field does not equal the sum of the @@ -5525,19 +5631,27 @@ func (st *http2stream) copyTrailersToHandlerRequest() { } } +// onReadTimeout is run on its own goroutine (from time.AfterFunc) +// when the stream's ReadTimeout has fired. +func (st *http2stream) onReadTimeout() { + // Wrap the ErrDeadlineExceeded to avoid callers depending on us + // returning the bare error. + st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded)) +} + // onWriteTimeout is run on its own goroutine (from time.AfterFunc) // when the stream's WriteTimeout has fired. func (st *http2stream) onWriteTimeout() { - st.sc.writeFrameFromHandler(http2FrameWriteRequest{write: http2streamError(st.id, http2ErrCodeInternal)}) + st.sc.writeFrameFromHandler(http2FrameWriteRequest{write: http2StreamError{ + StreamID: st.id, + Code: http2ErrCodeInternal, + Cause: os.ErrDeadlineExceeded, + }}) } func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { sc.serveG.check() id := f.StreamID - if sc.inGoAway { - // Ignore. - return nil - } // http://tools.ietf.org/html/rfc7540#section-5.1.1 // Streams initiated by a client MUST use odd-numbered stream // identifiers. [...] An endpoint that receives an unexpected @@ -5640,12 +5754,35 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { // (in Go 1.8), though. That's a more sane option anyway. if sc.hs.ReadTimeout != 0 { sc.conn.SetReadDeadline(time.Time{}) + if st.body != nil { + st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) + } } go sc.runHandler(rw, req, handler) return nil } +func (sc *http2serverConn) upgradeRequest(req *Request) { + sc.serveG.check() + id := uint32(1) + sc.maxClientStreamID = id + st := sc.newStream(id, 0, http2stateHalfClosedRemote) + st.reqTrailer = req.Trailer + if st.reqTrailer != nil { + st.trailer = make(Header) + } + rw := sc.newResponseWriter(st, req) + + // Disable any read deadline set by the net/http package + // prior to the upgrade. + if sc.hs.ReadTimeout != 0 { + sc.conn.SetReadDeadline(time.Time{}) + } + + go sc.runHandler(rw, req, sc.handler.ServeHTTP) +} + func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { sc := st.sc sc.serveG.check() @@ -5688,9 +5825,6 @@ func (sc *http2serverConn) checkPriority(streamID uint32, p http2PriorityParam) } func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { - if sc.inGoAway { - return nil - } if err := sc.checkPriority(f.StreamID, f.http2PriorityParam); err != nil { return err } @@ -5871,6 +6005,11 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re } req = req.WithContext(st.ctx) + rw := sc.newResponseWriter(st, req) + return rw, req, nil +} + +func (sc *http2serverConn) newResponseWriter(st *http2stream, req *Request) *http2responseWriter { rws := http2responseWriterStatePool.Get().(*http2responseWriterState) bwSave := rws.bw *rws = http2responseWriterState{} // zero all the fields @@ -5879,10 +6018,7 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re rws.bw.Reset(http2chunkWriter{rws}) rws.stream = st rws.req = req - rws.body = body - - rw := &http2responseWriter{rws: rws} - return rw, req, nil + return &http2responseWriter{rws: rws} } // Run on its own goroutine. @@ -5890,6 +6026,9 @@ func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, han didPanic := true defer func() { rw.rws.stream.cancelCtx() + if req.MultipartForm != nil { + req.MultipartForm.RemoveAll() + } if didPanic { e := recover() sc.writeFrameFromHandler(http2FrameWriteRequest{ @@ -6001,7 +6140,7 @@ func (sc *http2serverConn) sendWindowUpdate(st *http2stream, n int) { // a larger Read than this. Very unlikely, but we handle it here // rather than elsewhere for now. const maxUint31 = 1<<31 - 1 - for n >= maxUint31 { + for n > maxUint31 { sc.sendWindowUpdate32(st, maxUint31) n -= maxUint31 } @@ -6097,7 +6236,6 @@ type http2responseWriterState struct { // immutable within a request: stream *http2stream req *Request - body *http2requestBody // to close at end of request, if DATA frames didn't conn *http2serverConn // TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc @@ -6122,7 +6260,15 @@ type http2responseWriterState struct { type http2chunkWriter struct{ rws *http2responseWriterState } -func (cw http2chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) } +func (cw http2chunkWriter) Write(p []byte) (n int, err error) { + n, err = cw.rws.writeChunk(p) + if err == http2errStreamClosed { + // If writing failed because the stream has been closed, + // return the reason it was closed. + err = cw.rws.stream.closeErr + } + return n, err +} func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 } @@ -6161,6 +6307,10 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { rws.writeHeader(200) } + if rws.handlerDone { + rws.promoteUndeclaredTrailers() + } + isHeadResp := rws.req.Method == "HEAD" if !rws.sentHeader { rws.sentHeader = true @@ -6232,10 +6382,6 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { return 0, nil } - if rws.handlerDone { - rws.promoteUndeclaredTrailers() - } - // only send trailers if they have actually been defined by the // server handler. hasNonemptyTrailers := rws.hasNonemptyTrailers() @@ -6316,23 +6462,85 @@ func (rws *http2responseWriterState) promoteUndeclaredTrailers() { } } +func (w *http2responseWriter) SetReadDeadline(deadline time.Time) error { + st := w.rws.stream + if !deadline.IsZero() && deadline.Before(time.Now()) { + // If we're setting a deadline in the past, reset the stream immediately + // so writes after SetWriteDeadline returns will fail. + st.onReadTimeout() + return nil + } + w.rws.conn.sendServeMsg(func(sc *http2serverConn) { + if st.readDeadline != nil { + if !st.readDeadline.Stop() { + // Deadline already exceeded, or stream has been closed. + return + } + } + if deadline.IsZero() { + st.readDeadline = nil + } else if st.readDeadline == nil { + st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout) + } else { + st.readDeadline.Reset(deadline.Sub(time.Now())) + } + }) + return nil +} + +func (w *http2responseWriter) SetWriteDeadline(deadline time.Time) error { + st := w.rws.stream + if !deadline.IsZero() && deadline.Before(time.Now()) { + // If we're setting a deadline in the past, reset the stream immediately + // so writes after SetWriteDeadline returns will fail. + st.onWriteTimeout() + return nil + } + w.rws.conn.sendServeMsg(func(sc *http2serverConn) { + if st.writeDeadline != nil { + if !st.writeDeadline.Stop() { + // Deadline already exceeded, or stream has been closed. + return + } + } + if deadline.IsZero() { + st.writeDeadline = nil + } else if st.writeDeadline == nil { + st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout) + } else { + st.writeDeadline.Reset(deadline.Sub(time.Now())) + } + }) + return nil +} + func (w *http2responseWriter) Flush() { + w.FlushError() +} + +func (w *http2responseWriter) FlushError() error { rws := w.rws if rws == nil { panic("Header called after Handler finished") } + var err error if rws.bw.Buffered() > 0 { - if err := rws.bw.Flush(); err != nil { - // Ignore the error. The frame writer already knows. - return - } + err = rws.bw.Flush() } else { // The bufio.Writer won't call chunkWriter.Write // (writeChunk with zero bytes, so we have to do it // ourselves to force the HTTP response header and/or // final DATA frame (with END_STREAM) to be sent. - rws.writeChunk(nil) + _, err = http2chunkWriter{rws}.Write(nil) + if err == nil { + select { + case <-rws.stream.cw: + err = rws.stream.closeErr + default: + } + } } + return err } func (w *http2responseWriter) CloseNotify() <-chan bool { @@ -6817,13 +7025,23 @@ const ( // A Transport internally caches connections to servers. It is safe // for concurrent use by multiple goroutines. type http2Transport struct { - // DialTLS specifies an optional dial function for creating - // TLS connections for requests. + // DialTLSContext specifies an optional dial function with context for + // creating TLS connections for requests. // - // If DialTLS is nil, tls.Dial is used. + // If DialTLSContext and DialTLS is nil, tls.Dial is used. // // If the returned net.Conn has a ConnectionState method like tls.Conn, // it will be used to set http.Response.TLS. + DialTLSContext func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) + + // DialTLS specifies an optional dial function for creating + // TLS connections for requests. + // + // If DialTLSContext and DialTLS is nil, tls.Dial is used. + // + // Deprecated: Use DialTLSContext instead, which allows the transport + // to cancel dials as soon as they are no longer needed. + // If both are set, DialTLSContext takes priority. DialTLS func(network, addr string, cfg *tls.Config) (net.Conn, error) // TLSClientConfig specifies the TLS configuration to use with @@ -6857,6 +7075,28 @@ type http2Transport struct { // to mean no limit. MaxHeaderListSize uint32 + // MaxReadFrameSize is the http2 SETTINGS_MAX_FRAME_SIZE to send in the + // initial settings frame. It is the size in bytes of the largest frame + // payload that the sender is willing to receive. If 0, no setting is + // sent, and the value is provided by the peer, which should be 16384 + // according to the spec: + // https://datatracker.ietf.org/doc/html/rfc7540#section-6.5.2. + // Values are bounded in the range 16k to 16M. + MaxReadFrameSize uint32 + + // MaxDecoderHeaderTableSize optionally specifies the http2 + // SETTINGS_HEADER_TABLE_SIZE to send in the initial settings frame. It + // informs the remote endpoint of the maximum size of the header compression + // table used to decode header blocks, in octets. If zero, the default value + // of 4096 is used. + MaxDecoderHeaderTableSize uint32 + + // MaxEncoderHeaderTableSize optionally specifies an upper limit for the + // header compression table used for encoding request headers. Received + // SETTINGS_HEADER_TABLE_SIZE settings are capped at this limit. If zero, + // the default value of 4096 is used. + MaxEncoderHeaderTableSize uint32 + // StrictMaxConcurrentStreams controls whether the server's // SETTINGS_MAX_CONCURRENT_STREAMS should be respected // globally. If false, new TCP connections are created to the @@ -6910,6 +7150,19 @@ func (t *http2Transport) maxHeaderListSize() uint32 { return t.MaxHeaderListSize } +func (t *http2Transport) maxFrameReadSize() uint32 { + if t.MaxReadFrameSize == 0 { + return 0 // use the default provided by the peer + } + if t.MaxReadFrameSize < http2minMaxFrameSize { + return http2minMaxFrameSize + } + if t.MaxReadFrameSize > http2maxFrameSize { + return http2maxFrameSize + } + return t.MaxReadFrameSize +} + func (t *http2Transport) disableCompression() bool { return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression) } @@ -6998,7 +7251,8 @@ func (t *http2Transport) initConnPool() { // HTTP/2 server. type http2ClientConn struct { t *http2Transport - tconn net.Conn // usually *tls.Conn, except specialized impls + tconn net.Conn // usually *tls.Conn, except specialized impls + tconnClosed bool tlsState *tls.ConnectionState // nil only for specialized impls reused uint32 // whether conn is being reused; atomic singleUse bool // whether being used for a single http.Request @@ -7031,10 +7285,11 @@ type http2ClientConn struct { lastActive time.Time lastIdle time.Time // time last idle // Settings from peer: (also guarded by wmu) - maxFrameSize uint32 - maxConcurrentStreams uint32 - peerMaxHeaderListSize uint64 - initialWindowSize uint32 + maxFrameSize uint32 + maxConcurrentStreams uint32 + peerMaxHeaderListSize uint64 + peerMaxHeaderTableSize uint32 + initialWindowSize uint32 // reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests. // Write to reqHeaderMu to lock it, read from it to unlock. @@ -7084,8 +7339,8 @@ type http2clientStream struct { readErr error // sticky read error; owned by transportResponseBody.Read reqBody io.ReadCloser - reqBodyContentLength int64 // -1 means unknown - reqBodyClosed bool // body has been closed; guarded by cc.mu + reqBodyContentLength int64 // -1 means unknown + reqBodyClosed chan struct{} // guarded by cc.mu; non-nil on Close, closed when done // owned by writeRequest: sentEndStream bool // sent an END_STREAM flag to the peer @@ -7125,9 +7380,8 @@ func (cs *http2clientStream) abortStreamLocked(err error) { cs.abortErr = err close(cs.abort) }) - if cs.reqBody != nil && !cs.reqBodyClosed { - cs.reqBody.Close() - cs.reqBodyClosed = true + if cs.reqBody != nil { + cs.closeReqBodyLocked() } // TODO(dneil): Clean up tests where cs.cc.cond is nil. if cs.cc.cond != nil { @@ -7140,13 +7394,24 @@ func (cs *http2clientStream) abortRequestBodyWrite() { cc := cs.cc cc.mu.Lock() defer cc.mu.Unlock() - if cs.reqBody != nil && !cs.reqBodyClosed { - cs.reqBody.Close() - cs.reqBodyClosed = true + if cs.reqBody != nil && cs.reqBodyClosed == nil { + cs.closeReqBodyLocked() cc.cond.Broadcast() } } +func (cs *http2clientStream) closeReqBodyLocked() { + if cs.reqBodyClosed != nil { + return + } + cs.reqBodyClosed = make(chan struct{}) + reqBodyClosed := cs.reqBodyClosed + go func() { + cs.reqBody.Close() + close(reqBodyClosed) + }() +} + type http2stickyErrWriter struct { conn net.Conn timeout time.Duration @@ -7231,6 +7496,15 @@ func http2authorityAddr(scheme string, authority string) (addr string) { return net.JoinHostPort(host, port) } +var http2retryBackoffHook func(time.Duration) *time.Timer + +func http2backoffNewTimer(d time.Duration) *time.Timer { + if http2retryBackoffHook != nil { + return http2retryBackoffHook(d) + } + return time.NewTimer(d) +} + // RoundTripOpt is like RoundTrip, but takes options. func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Response, error) { if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { @@ -7256,11 +7530,14 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res } backoff := float64(uint(1) << (uint(retry) - 1)) backoff += backoff * (0.1 * mathrand.Float64()) + d := time.Second * time.Duration(backoff) + timer := http2backoffNewTimer(d) select { - case <-time.After(time.Second * time.Duration(backoff)): + case <-timer.C: t.vlogf("RoundTrip retrying after failure: %v", err) continue case <-req.Context().Done(): + timer.Stop() err = req.Context().Err() } } @@ -7343,7 +7620,7 @@ func (t *http2Transport) dialClientConn(ctx context.Context, addr string, single if err != nil { return nil, err } - tconn, err := t.dialTLS(ctx)("tcp", addr, t.newTLSConfig(host)) + tconn, err := t.dialTLS(ctx, "tcp", addr, t.newTLSConfig(host)) if err != nil { return nil, err } @@ -7364,24 +7641,25 @@ func (t *http2Transport) newTLSConfig(host string) *tls.Config { return cfg } -func (t *http2Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) { - if t.DialTLS != nil { - return t.DialTLS +func (t *http2Transport) dialTLS(ctx context.Context, network, addr string, tlsCfg *tls.Config) (net.Conn, error) { + if t.DialTLSContext != nil { + return t.DialTLSContext(ctx, network, addr, tlsCfg) + } else if t.DialTLS != nil { + return t.DialTLS(network, addr, tlsCfg) } - return func(network, addr string, cfg *tls.Config) (net.Conn, error) { - tlsCn, err := t.dialTLSWithContext(ctx, network, addr, cfg) - if err != nil { - return nil, err - } - state := tlsCn.ConnectionState() - if p := state.NegotiatedProtocol; p != http2NextProtoTLS { - return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2NextProtoTLS) - } - if !state.NegotiatedProtocolIsMutual { - return nil, errors.New("http2: could not negotiate protocol mutually") - } - return tlsCn, nil + + tlsCn, err := t.dialTLSWithContext(ctx, network, addr, tlsCfg) + if err != nil { + return nil, err + } + state := tlsCn.ConnectionState() + if p := state.NegotiatedProtocol; p != http2NextProtoTLS { + return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2NextProtoTLS) + } + if !state.NegotiatedProtocolIsMutual { + return nil, errors.New("http2: could not negotiate protocol mutually") } + return tlsCn, nil } // disableKeepAlives reports whether connections should be closed as @@ -7397,6 +7675,20 @@ func (t *http2Transport) expectContinueTimeout() time.Duration { return t.t1.ExpectContinueTimeout } +func (t *http2Transport) maxDecoderHeaderTableSize() uint32 { + if v := t.MaxDecoderHeaderTableSize; v > 0 { + return v + } + return http2initialHeaderTableSize +} + +func (t *http2Transport) maxEncoderHeaderTableSize() uint32 { + if v := t.MaxEncoderHeaderTableSize; v > 0 { + return v + } + return http2initialHeaderTableSize +} + func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { return t.newClientConn(c, t.disableKeepAlives()) } @@ -7437,15 +7729,19 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client }) cc.br = bufio.NewReader(c) cc.fr = http2NewFramer(cc.bw, cc.br) + if t.maxFrameReadSize() != 0 { + cc.fr.SetMaxReadFrameSize(t.maxFrameReadSize()) + } if t.CountError != nil { cc.fr.countError = t.CountError } - cc.fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) + maxHeaderTableSize := t.maxDecoderHeaderTableSize() + cc.fr.ReadMetaHeaders = hpack.NewDecoder(maxHeaderTableSize, nil) cc.fr.MaxHeaderListSize = t.maxHeaderListSize() - // TODO: SetMaxDynamicTableSize, SetMaxDynamicTableSizeLimit on - // henc in response to SETTINGS frames? cc.henc = hpack.NewEncoder(&cc.hbuf) + cc.henc.SetMaxDynamicTableSizeLimit(t.maxEncoderHeaderTableSize()) + cc.peerMaxHeaderTableSize = http2initialHeaderTableSize if t.AllowHTTP { cc.nextStreamID = 3 @@ -7460,9 +7756,15 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client {ID: http2SettingEnablePush, Val: 0}, {ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow}, } + if max := t.maxFrameReadSize(); max != 0 { + initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxFrameSize, Val: max}) + } if max := t.maxHeaderListSize(); max != 0 { initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max}) } + if maxHeaderTableSize != http2initialHeaderTableSize { + initialSettings = append(initialSettings, http2Setting{ID: http2SettingHeaderTableSize, Val: maxHeaderTableSize}) + } cc.bw.Write(http2clientPreface) cc.fr.WriteSettings(initialSettings...) @@ -7661,10 +7963,10 @@ func (cc *http2ClientConn) onIdleTimeout() { cc.closeIfIdle() } -func (cc *http2ClientConn) closeConn() error { +func (cc *http2ClientConn) closeConn() { t := time.AfterFunc(250*time.Millisecond, cc.forceCloseConn) defer t.Stop() - return cc.tconn.Close() + cc.tconn.Close() } // A tls.Conn.Close can hang for a long time if the peer is unresponsive. @@ -7730,7 +8032,8 @@ func (cc *http2ClientConn) Shutdown(ctx context.Context) error { http2shutdownEnterWaitStateHook() select { case <-done: - return cc.closeConn() + cc.closeConn() + return nil case <-ctx.Done(): cc.mu.Lock() // Free the goroutine above @@ -7767,7 +8070,7 @@ func (cc *http2ClientConn) sendGoAway() error { // closes the client connection immediately. In-flight requests are interrupted. // err is sent to streams. -func (cc *http2ClientConn) closeForError(err error) error { +func (cc *http2ClientConn) closeForError(err error) { cc.mu.Lock() cc.closed = true for _, cs := range cc.streams { @@ -7775,7 +8078,7 @@ func (cc *http2ClientConn) closeForError(err error) error { } cc.cond.Broadcast() cc.mu.Unlock() - return cc.closeConn() + cc.closeConn() } // Close closes the client connection immediately. @@ -7783,16 +8086,17 @@ func (cc *http2ClientConn) closeForError(err error) error { // In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. func (cc *http2ClientConn) Close() error { err := errors.New("http2: client connection force closed via ClientConn.Close") - return cc.closeForError(err) + cc.closeForError(err) + return nil } // closes the client connection immediately. In-flight requests are interrupted. -func (cc *http2ClientConn) closeForLostPing() error { +func (cc *http2ClientConn) closeForLostPing() { err := errors.New("http2: client connection lost") if f := cc.t.CountError; f != nil { f("conn_close_lost_ping") } - return cc.closeForError(err) + cc.closeForError(err) } // errRequestCanceled is a copy of net/http's errRequestCanceled because it's not @@ -7802,7 +8106,7 @@ var http2errRequestCanceled = errors.New("net/http: request canceled") func http2commaSeparatedTrailers(req *Request) (string, error) { keys := make([]string, 0, len(req.Trailer)) for k := range req.Trailer { - k = CanonicalHeaderKey(k) + k = http2canonicalHeader(k) switch k { case "Transfer-Encoding", "Trailer", "Content-Length": return "", fmt.Errorf("invalid Trailer key %q", k) @@ -8170,11 +8474,19 @@ func (cs *http2clientStream) cleanupWriteRequest(err error) { // and in multiple cases: server replies <=299 and >299 // while still writing request body cc.mu.Lock() + mustCloseBody := false + if cs.reqBody != nil && cs.reqBodyClosed == nil { + mustCloseBody = true + cs.reqBodyClosed = make(chan struct{}) + } bodyClosed := cs.reqBodyClosed - cs.reqBodyClosed = true cc.mu.Unlock() - if !bodyClosed && cs.reqBody != nil { + if mustCloseBody { cs.reqBody.Close() + close(bodyClosed) + } + if bodyClosed != nil { + <-bodyClosed } if err != nil && cs.sentEndStream { @@ -8331,7 +8643,7 @@ func (cs *http2clientStream) writeRequestBody(req *Request) (err error) { var sawEOF bool for !sawEOF { - n, err := body.Read(buf[:len(buf)]) + n, err := body.Read(buf) if hasContentLen { remainLen -= int64(n) if remainLen == 0 && err == nil { @@ -8354,7 +8666,7 @@ func (cs *http2clientStream) writeRequestBody(req *Request) (err error) { } if err != nil { cc.mu.Lock() - bodyClosed := cs.reqBodyClosed + bodyClosed := cs.reqBodyClosed != nil cc.mu.Unlock() switch { case bodyClosed: @@ -8449,7 +8761,7 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er if cc.closed { return 0, http2errClientConnClosed } - if cs.reqBodyClosed { + if cs.reqBodyClosed != nil { return 0, http2errStopReqBodyWrite } select { @@ -8634,7 +8946,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail // Header list size is ok. Write the headers. enumerateHeaders(func(name, value string) { - name, ascii := http2asciiToLower(name) + name, ascii := http2lowerHeader(name) if !ascii { // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header // field names have to be ASCII characters (just as in HTTP/1.x). @@ -8687,7 +8999,7 @@ func (cc *http2ClientConn) encodeTrailers(trailer Header) ([]byte, error) { } for k, vv := range trailer { - lowKey, ascii := http2asciiToLower(k) + lowKey, ascii := http2lowerHeader(k) if !ascii { // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header // field names have to be ASCII characters (just as in HTTP/1.x). @@ -8745,7 +9057,7 @@ func (cc *http2ClientConn) forgetStreamID(id uint32) { // wake up RoundTrip if there is a pending request. cc.cond.Broadcast() - closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() + closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { if http2VerboseLogs { cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2) @@ -8821,6 +9133,7 @@ func (rl *http2clientConnReadLoop) cleanup() { err = io.ErrUnexpectedEOF } cc.closed = true + for _, cs := range cc.streams { select { case <-cs.peerClosed: @@ -9019,7 +9332,7 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http Status: status + " " + StatusText(statusCode), } for _, hf := range regularFields { - key := CanonicalHeaderKey(hf.Name) + key := http2canonicalHeader(hf.Name) if key == "Trailer" { t := res.Trailer if t == nil { @@ -9027,7 +9340,7 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http res.Trailer = t } http2foreachHeaderElement(hf.Value, func(v string) { - t[CanonicalHeaderKey(v)] = nil + t[http2canonicalHeader(v)] = nil }) } else { vv := header[key] @@ -9132,7 +9445,7 @@ func (rl *http2clientConnReadLoop) processTrailers(cs *http2clientStream, f *htt trailer := make(Header) for _, hf := range f.RegularFields() { - key := CanonicalHeaderKey(hf.Name) + key := http2canonicalHeader(hf.Name) trailer[key] = append(trailer[key], hf.Value) } cs.trailer = trailer @@ -9414,7 +9727,6 @@ func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { if fn := cc.t.CountError; fn != nil { fn("recv_goaway_" + f.ErrCode.stringToken()) } - } cc.setGoAway(f) return nil @@ -9479,8 +9791,10 @@ func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame) cc.cond.Broadcast() cc.initialWindowSize = s.Val + case http2SettingHeaderTableSize: + cc.henc.SetMaxDynamicTableSize(s.Val) + cc.peerMaxHeaderTableSize = s.Val default: - // TODO(bradfitz): handle more settings? SETTINGS_HEADER_TABLE_SIZE probably. cc.vlogf("Unhandled Setting: %v", s) } return nil @@ -9707,7 +10021,11 @@ func (gz *http2gzipReader) Read(p []byte) (n int, err error) { } func (gz *http2gzipReader) Close() error { - return gz.body.Close() + if err := gz.body.Close(); err != nil { + return err + } + gz.zerr = fs.ErrClosed + return nil } type http2errorReader struct{ err error } @@ -9771,7 +10089,7 @@ func http2traceGotConn(req *Request, cc *http2ClientConn, reused bool) { cc.mu.Lock() ci.WasIdle = len(cc.streams) == 0 && reused if ci.WasIdle && !cc.lastActive.IsZero() { - ci.IdleTime = time.Now().Sub(cc.lastActive) + ci.IdleTime = time.Since(cc.lastActive) } cc.mu.Unlock() @@ -10784,16 +11102,15 @@ func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority ht func (ws *http2priorityWriteScheduler) Push(wr http2FrameWriteRequest) { var n *http2priorityNode - if id := wr.StreamID(); id == 0 { + if wr.isControl() { n = &ws.root } else { + id := wr.StreamID() n = ws.nodes[id] if n == nil { // id is an idle or closed stream. wr should not be a HEADERS or - // DATA frame. However, wr can be a RST_STREAM. In this case, we - // push wr onto the root, rather than creating a new priorityNode, - // since RST_STREAM is tiny and the stream's priority is unknown - // anyway. See issue #17919. + // DATA frame. In other case, we push wr onto the root, rather + // than creating a new priorityNode. if wr.DataSize() > 0 { panic("add DATA on non-open stream") } diff --git a/h2_error.go b/h2_error.go new file mode 100644 index 00000000..0391d31e --- /dev/null +++ b/h2_error.go @@ -0,0 +1,38 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !nethttpomithttp2 +// +build !nethttpomithttp2 + +package http + +import ( + "reflect" +) + +func (e http2StreamError) As(target any) bool { + dst := reflect.ValueOf(target).Elem() + dstType := dst.Type() + if dstType.Kind() != reflect.Struct { + return false + } + src := reflect.ValueOf(e) + srcType := src.Type() + numField := srcType.NumField() + if dstType.NumField() != numField { + return false + } + for i := 0; i < numField; i++ { + sf := srcType.Field(i) + df := dstType.Field(i) + if sf.Name != df.Name || !sf.Type.ConvertibleTo(df.Type) { + return false + } + } + for i := 0; i < numField; i++ { + df := dst.Field(i) + df.Set(src.Field(i).Convert(df.Type())) + } + return true +} diff --git a/h2_error_test.go b/h2_error_test.go new file mode 100644 index 00000000..0d85e2f3 --- /dev/null +++ b/h2_error_test.go @@ -0,0 +1,44 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !nethttpomithttp2 +// +build !nethttpomithttp2 + +package http + +import ( + "errors" + "fmt" + "testing" +) + +type externalStreamErrorCode uint32 + +type externalStreamError struct { + StreamID uint32 + Code externalStreamErrorCode + Cause error +} + +func (e externalStreamError) Error() string { + return fmt.Sprintf("ID %v, code %v", e.StreamID, e.Code) +} + +func TestStreamError(t *testing.T) { + var target externalStreamError + streamErr := http2streamError(42, http2ErrCodeProtocol) + ok := errors.As(streamErr, &target) + if !ok { + t.Fatalf("errors.As failed") + } + if target.StreamID != streamErr.StreamID { + t.Errorf("got StreamID %v, expected %v", target.StreamID, streamErr.StreamID) + } + if target.Cause != streamErr.Cause { + t.Errorf("got Cause %v, expected %v", target.Cause, streamErr.Cause) + } + if uint32(target.Code) != uint32(streamErr.Code) { + t.Errorf("got Code %v, expected %v", target.Code, streamErr.Code) + } +} diff --git a/header_test.go b/header_test.go index 3a1b1dd1..0121dcfd 100644 --- a/header_test.go +++ b/header_test.go @@ -7,6 +7,7 @@ package http import ( "bytes" "reflect" + "strings" "testing" "time" ) @@ -103,7 +104,7 @@ var headerWriteTests = []struct { } func TestHeaderWrite(t *testing.T) { - var buf bytes.Buffer + var buf strings.Builder for i, test := range headerWriteTests { test.h.WriteSubset(&buf, test.exclude) if buf.String() != test.expected { diff --git a/httptrace/trace_test.go b/httptrace/trace_test.go index bb57ada8..6efa1f79 100644 --- a/httptrace/trace_test.go +++ b/httptrace/trace_test.go @@ -5,13 +5,13 @@ package httptrace import ( - "bytes" "context" + "strings" "testing" ) func TestWithClientTrace(t *testing.T) { - var buf bytes.Buffer + var buf strings.Builder connectStart := func(b byte) func(network, addr string) { return func(network, addr string) { buf.WriteByte(b) @@ -37,7 +37,7 @@ func TestWithClientTrace(t *testing.T) { } func TestCompose(t *testing.T) { - var buf bytes.Buffer + var buf strings.Builder var testNum int connectStart := func(b byte) func(network, addr string) { diff --git a/httputil/dump.go b/httputil/dump.go index 31b72e88..6b24fc05 100644 --- a/httputil/dump.go +++ b/httputil/dump.go @@ -60,7 +60,7 @@ func (b neverEnding) Read(p []byte) (n int, err error) { return len(p), nil } -// outGoingLength is a copy of the unexported +// outgoingLength is a copy of the unexported // (*http.Request).outgoingLength method. func outgoingLength(req *http.Request) int64 { if req.Body == nil || req.Body == http.NoBody { @@ -259,9 +259,6 @@ func DumpRequest(req *http.Request, body bool) ([]byte, error) { if len(req.TransferEncoding) > 0 { fmt.Fprintf(&b, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ",")) } - if req.Close { - fmt.Fprintf(&b, "Connection: close\r\n") - } err = req.Header.WriteSubset(&b, reqWriteExcludeHeaderDump) if err != nil { diff --git a/httputil/dump_test.go b/httputil/dump_test.go index 711362e9..fad0855c 100644 --- a/httputil/dump_test.go +++ b/httputil/dump_test.go @@ -237,6 +237,19 @@ var dumpTests = []dumpTest{ "Transfer-Encoding: chunked\r\n" + "Accept-Encoding: gzip\r\n\r\n", }, + + // Issue 54616: request with Connection header doesn't result in duplicate header. + { + GetReq: func() *http.Request { + return mustReadRequest("GET / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: close\r\n\r\n") + }, + NoBody: true, + WantDump: "GET / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: close\r\n\r\n", + }, } func TestDumpRequest(t *testing.T) { @@ -511,7 +524,7 @@ func TestDumpRequestOutIssue38352(t *testing.T) { select { case <-out: case <-time.After(timeout): - b := &bytes.Buffer{} + b := &strings.Builder{} fmt.Fprintf(b, "deadlock detected on iteration %d after %s with delay: %v\n", i, timeout, delay) pprof.Lookup("goroutine").WriteTo(b, 1) t.Fatal(b.String()) diff --git a/httputil/example_test.go b/httputil/example_test.go index 4c26d2a4..b3c0ed99 100644 --- a/httputil/example_test.go +++ b/httputil/example_test.go @@ -104,7 +104,12 @@ func ExampleReverseProxy() { if err != nil { log.Fatal(err) } - frontendProxy := httptest.NewServer(httputil.NewSingleHostReverseProxy(rpURL)) + frontendProxy := httptest.NewServer(&httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + r.SetXForwarded() + r.SetURL(rpURL) + }, + }) defer frontendProxy.Close() resp, err := http.Get(frontendProxy.URL) diff --git a/httputil/reverseproxy.go b/httputil/reverseproxy.go index 56a8367f..7ba3e0b5 100644 --- a/httputil/reverseproxy.go +++ b/httputil/reverseproxy.go @@ -8,6 +8,7 @@ package httputil import ( "context" + "errors" "fmt" "io" "log" @@ -19,34 +20,138 @@ import ( "sync" "time" - "github.com/ooni/oohttp" + http "github.com/ooni/oohttp" + "github.com/ooni/oohttp/httptrace" "github.com/ooni/oohttp/internal/ascii" "golang.org/x/net/http/httpguts" ) +// A ProxyRequest contains a request to be rewritten by a ReverseProxy. +type ProxyRequest struct { + // In is the request received by the proxy. + // The Rewrite function must not modify In. + In *http.Request + + // Out is the request which will be sent by the proxy. + // The Rewrite function may modify or replace this request. + // Hop-by-hop headers are removed from this request + // before Rewrite is called. + Out *http.Request +} + +// SetURL routes the outbound request to the scheme, host, and base path +// provided in target. If the target's path is "/base" and the incoming +// request was for "/dir", the target request will be for "/base/dir". +// +// SetURL rewrites the outbound Host header to match the target's host. +// To preserve the inbound request's Host header (the default behavior +// of NewSingleHostReverseProxy): +// +// rewriteFunc := func(r *httputil.ProxyRequest) { +// r.SetURL(url) +// r.Out.Host = r.In.Host +// } +func (r *ProxyRequest) SetURL(target *url.URL) { + rewriteRequestURL(r.Out, target) + r.Out.Host = "" +} + +// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host, and +// X-Forwarded-Proto headers of the outbound request. +// +// - The X-Forwarded-For header is set to the client IP address. +// - The X-Forwarded-Host header is set to the host name requested +// by the client. +// - The X-Forwarded-Proto header is set to "http" or "https", depending +// on whether the inbound request was made on a TLS-enabled connection. +// +// If the outbound request contains an existing X-Forwarded-For header, +// SetXForwarded appends the client IP address to it. To append to the +// inbound request's X-Forwarded-For header (the default behavior of +// ReverseProxy when using a Director function), copy the header +// from the inbound request before calling SetXForwarded: +// +// rewriteFunc := func(r *httputil.ProxyRequest) { +// r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"] +// r.SetXForwarded() +// } +func (r *ProxyRequest) SetXForwarded() { + clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr) + if err == nil { + prior := r.Out.Header["X-Forwarded-For"] + if len(prior) > 0 { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + r.Out.Header.Set("X-Forwarded-For", clientIP) + } else { + r.Out.Header.Del("X-Forwarded-For") + } + r.Out.Header.Set("X-Forwarded-Host", r.In.Host) + if r.In.TLS == nil { + r.Out.Header.Set("X-Forwarded-Proto", "http") + } else { + r.Out.Header.Set("X-Forwarded-Proto", "https") + } +} + // ReverseProxy is an HTTP Handler that takes an incoming request and // sends it to another server, proxying the response back to the // client. // -// ReverseProxy by default sets the client IP as the value of the -// X-Forwarded-For header. -// -// If an X-Forwarded-For header already exists, the client IP is -// appended to the existing values. As a special case, if the header -// exists in the Request.Header map but has a nil value (such as when -// set by the Director func), the X-Forwarded-For header is -// not modified. -// -// To prevent IP spoofing, be sure to delete any pre-existing -// X-Forwarded-For header coming from the client or -// an untrusted proxy. +// 1xx responses are forwarded to the client if the underlying +// transport supports ClientTrace.Got1xxResponse. type ReverseProxy struct { - // Director must be a function which modifies + // Rewrite must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + // Rewrite must not access the provided ProxyRequest + // or its contents after returning. + // + // The Forwarded, X-Forwarded, X-Forwarded-Host, + // and X-Forwarded-Proto headers are removed from the + // outbound request before Rewrite is called. See also + // the ProxyRequest.SetXForwarded method. + // + // Unparsable query parameters are removed from the + // outbound request before Rewrite is called. + // The Rewrite function may copy the inbound URL's + // RawQuery to the outbound URL to preserve the original + // parameter string. Note that this can lead to security + // issues if the proxy's interpretation of query parameters + // does not match that of the downstream server. + // + // At most one of Rewrite or Director may be set. + Rewrite func(*ProxyRequest) + + // Director is a function which modifies // the request into a new request to be sent // using Transport. Its response is then copied // back to the original client unmodified. // Director must not access the provided Request // after returning. + // + // By default, the X-Forwarded-For header is set to the + // value of the client IP address. If an X-Forwarded-For + // header already exists, the client IP is appended to the + // existing values. As a special case, if the header + // exists in the Request.Header map but has a nil value + // (such as when set by the Director func), the X-Forwarded-For + // header is not modified. + // + // To prevent IP spoofing, be sure to delete any pre-existing + // X-Forwarded-For header coming from the client or + // an untrusted proxy. + // + // Hop-by-hop headers are removed from the request after + // Director returns, which can remove headers added by + // Director. Use a Rewrite function instead to ensure + // modifications to the request are preserved. + // + // Unparsable query parameters are removed from the outbound + // request if Request.Form is set after Director returns. + // + // At most one of Rewrite or Director may be set. Director func(*http.Request) // The transport used to perform proxy requests. @@ -138,28 +243,41 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) { // URLs to the scheme, host, and base path provided in target. If the // target's path is "/base" and the incoming request was for "/dir", // the target request will be for /base/dir. +// // NewSingleHostReverseProxy does not rewrite the Host header. -// To rewrite Host headers, use ReverseProxy directly with a custom -// Director policy. +// +// To customize the ReverseProxy behavior beyond what +// NewSingleHostReverseProxy provides, use ReverseProxy directly +// with a Rewrite function. The ProxyRequest SetURL method +// may be used to route the outbound request. (Note that SetURL, +// unlike NewSingleHostReverseProxy, rewrites the Host header +// of the outbound request by default.) +// +// proxy := &ReverseProxy{ +// Rewrite: func(r *ProxyRequest) { +// r.SetURL(target) +// r.Out.Host = r.In.Host // if desired +// } +// } func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { - targetQuery := target.RawQuery director := func(req *http.Request) { - req.URL.Scheme = target.Scheme - req.URL.Host = target.Host - req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL) - if targetQuery == "" || req.URL.RawQuery == "" { - req.URL.RawQuery = targetQuery + req.URL.RawQuery - } else { - req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery - } - if _, ok := req.Header["User-Agent"]; !ok { - // explicitly disable User-Agent so it's not set to default value - req.Header.Set("User-Agent", "") - } + rewriteRequestURL(req, target) } return &ReverseProxy{Director: director} } +func rewriteRequestURL(req *http.Request, target *url.URL) { + targetQuery := target.RawQuery + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } +} + func copyHeader(dst, src http.Header) { for k, vv := range src { for _, v := range vv { @@ -260,9 +378,16 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate } - p.Director(outreq) - if outreq.Form != nil { - outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery) + if (p.Director != nil) == (p.Rewrite != nil) { + p.getErrorHandler()(rw, req, errors.New("ReverseProxy must have exactly one of Director or Rewrite set")) + return + } + + if p.Director != nil { + p.Director(outreq) + if outreq.Form != nil { + outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery) + } } outreq.Close = false @@ -271,20 +396,13 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType)) return } - removeConnectionHeaders(outreq.Header) - - // Remove hop-by-hop headers to the backend. Especially - // important is "Connection" because we want a persistent - // connection, regardless of what the client sent to us. - for _, h := range hopHeaders { - outreq.Header.Del(h) - } + removeHopByHopHeaders(outreq.Header) // Issue 21096: tell backend applications that care about trailer support // that we support trailers. (We do, but we don't go out of our way to // advertise that unless the incoming client request thought it was worth // mentioning.) Note that we look at req.Header, not outreq.Header, since - // the latter has passed through removeConnectionHeaders. + // the latter has passed through removeHopByHopHeaders. if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") { outreq.Header.Set("Te", "trailers") } @@ -296,20 +414,62 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { outreq.Header.Set("Upgrade", reqUpType) } - if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { - // If we aren't the first proxy retain prior - // X-Forwarded-For information as a comma+space - // separated list and fold multiple headers into one. - prior, ok := outreq.Header["X-Forwarded-For"] - omit := ok && prior == nil // Issue 38079: nil now means don't populate the header - if len(prior) > 0 { - clientIP = strings.Join(prior, ", ") + ", " + clientIP + if p.Rewrite != nil { + // Strip client-provided forwarding headers. + // The Rewrite func may use SetXForwarded to set new values + // for these or copy the previous values from the inbound request. + outreq.Header.Del("Forwarded") + outreq.Header.Del("X-Forwarded-For") + outreq.Header.Del("X-Forwarded-Host") + outreq.Header.Del("X-Forwarded-Proto") + + // Remove unparsable query parameters from the outbound request. + outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery) + + pr := &ProxyRequest{ + In: req, + Out: outreq, } - if !omit { - outreq.Header.Set("X-Forwarded-For", clientIP) + p.Rewrite(pr) + outreq = pr.Out + } else { + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + prior, ok := outreq.Header["X-Forwarded-For"] + omit := ok && prior == nil // Issue 38079: nil now means don't populate the header + if len(prior) > 0 { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + if !omit { + outreq.Header.Set("X-Forwarded-For", clientIP) + } } } + if _, ok := outreq.Header["User-Agent"]; !ok { + // If the outbound request doesn't have a User-Agent header set, + // don't send the default Go HTTP client User-Agent. + outreq.Header.Set("User-Agent", "") + } + + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + h := rw.Header() + copyHeader(h, http.Header(header)) + rw.WriteHeader(code) + + // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses + for k := range h { + delete(h, k) + } + + return nil + }, + } + outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) + res, err := transport.RoundTrip(outreq) if err != nil { p.getErrorHandler()(rw, outreq, err) @@ -325,11 +485,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - removeConnectionHeaders(res.Header) - - for _, h := range hopHeaders { - res.Header.Del(h) - } + removeHopByHopHeaders(res.Header) if !p.modifyResponse(rw, res, outreq) { return @@ -408,9 +564,9 @@ func shouldPanicOnCopyError(req *http.Request) bool { return false } -// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. -// See RFC 7230, section 6.1 -func removeConnectionHeaders(h http.Header) { +// removeHopByHopHeaders removes hop-by-hop headers. +func removeHopByHopHeaders(h http.Header) { + // RFC 7230, section 6.1: Remove headers listed in the "Connection" header. for _, f := range h["Connection"] { for _, sf := range strings.Split(f, ",") { if sf = textproto.TrimString(sf); sf != "" { @@ -418,6 +574,12 @@ func removeConnectionHeaders(h http.Header) { } } } + // RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers. + // This behavior is superseded by the RFC 7230 Connection header, but + // preserve it for backwards compatibility. + for _, f := range hopHeaders { + h.Del(f) + } } // flushInterval returns the p.FlushInterval value, conditionally diff --git a/httputil/reverseproxy_test.go b/httputil/reverseproxy_test.go index 4b5dcac6..aeff7117 100644 --- a/httputil/reverseproxy_test.go +++ b/httputil/reverseproxy_test.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "log" + "net/textproto" "net/url" "os" "reflect" @@ -24,8 +25,9 @@ import ( "testing" "time" - "github.com/ooni/oohttp" + http "github.com/ooni/oohttp" "github.com/ooni/oohttp/httptest" + "github.com/ooni/oohttp/httptrace" "github.com/ooni/oohttp/internal/ascii" ) @@ -320,7 +322,6 @@ func TestXForwardedFor(t *testing.T) { defer frontend.Close() getReq, _ := http.NewRequest("GET", frontend.URL, nil) - getReq.Host = "some-name" getReq.Header.Set("Connection", "close") getReq.Header.Set("X-Forwarded-For", prevForwardedFor) getReq.Close = true @@ -370,6 +371,46 @@ func TestXForwardedFor_Omit(t *testing.T) { res.Body.Close() } +func TestReverseProxyRewriteStripsForwarded(t *testing.T) { + headers := []string{ + "Forwarded", + "X-Forwarded-For", + "X-Forwarded-Host", + "X-Forwarded-Proto", + } + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for _, h := range headers { + if v := r.Header.Get(h); v != "" { + t.Errorf("got %v header: %q", h, v) + } + } + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(backendURL) + }, + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Close = true + for _, h := range headers { + getReq.Header.Set(h, "x") + } + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() +} + var proxyQueryTests = []struct { baseSuffix string // suffix to add to backend URL reqSuffix string // suffix to add to frontend's request URL @@ -575,46 +616,38 @@ func TestNilBody(t *testing.T) { // Issue 15524 func TestUserAgentHeader(t *testing.T) { - const explicitUA = "explicit UA" + var gotUA string backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/noua" { - if c := r.Header.Get("User-Agent"); c != "" { - t.Errorf("handler got non-empty User-Agent header %q", c) - } - return - } - if c := r.Header.Get("User-Agent"); c != explicitUA { - t.Errorf("handler got unexpected User-Agent header %q", c) - } + gotUA = r.Header.Get("User-Agent") })) defer backend.Close() backendURL, err := url.Parse(backend.URL) if err != nil { t.Fatal(err) } - proxyHandler := NewSingleHostReverseProxy(backendURL) + + proxyHandler := new(ReverseProxy) proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + proxyHandler.Director = func(req *http.Request) { + req.URL = backendURL + } frontend := httptest.NewServer(proxyHandler) defer frontend.Close() frontendClient := frontend.Client() - getReq, _ := http.NewRequest("GET", frontend.URL, nil) - getReq.Header.Set("User-Agent", explicitUA) - getReq.Close = true - res, err := frontendClient.Do(getReq) - if err != nil { - t.Fatalf("Get: %v", err) - } - res.Body.Close() - - getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil) - getReq.Header.Set("User-Agent", "") - getReq.Close = true - res, err = frontendClient.Do(getReq) - if err != nil { - t.Fatalf("Get: %v", err) + for _, sentUA := range []string{"explicit UA", ""} { + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Set("User-Agent", sentUA) + getReq.Close = true + res, err := frontendClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() + if got, want := gotUA, sentUA; got != want { + t.Errorf("got forwarded User-Agent %q, want %q", got, want) + } } - res.Body.Close() } type bufferPool struct { @@ -1030,13 +1063,14 @@ func TestClonesRequestHeaders(t *testing.T) { } rp.ServeHTTP(httptest.NewRecorder(), req) - if req.Header.Get("From-Director") == "1" { - t.Error("Director header mutation modified caller's request") - } - if req.Header.Get("X-Forwarded-For") != "" { - t.Error("X-Forward-For header mutation modified caller's request") + for _, h := range []string{ + "From-Director", + "X-Forwarded-For", + } { + if req.Header.Get(h) != "" { + t.Errorf("%v header mutation modified caller's request", h) + } } - } type roundTripperFunc func(req *http.Request) (*http.Response, error) @@ -1049,7 +1083,7 @@ func TestModifyResponseClosesBody(t *testing.T) { req, _ := http.NewRequest("GET", "http://foo.tld/", nil) req.RemoteAddr = "1.2.3.4:56789" closeCheck := new(checkCloser) - logBuf := new(bytes.Buffer) + logBuf := new(strings.Builder) outErr := errors.New("ModifyResponse error") rp := &ReverseProxy{ Director: func(req *http.Request) {}, @@ -1489,6 +1523,40 @@ func TestUnannouncedTrailer(t *testing.T) { } +func TestSetURL(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.Host)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(backendURL) + }, + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + res, err := frontendClient.Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Reading body: %v", err) + } + + if got, want := string(body), backendURL.Host; got != want { + t.Errorf("backend got Host %q, want %q", got, want) + } +} + func TestSingleJoinSlash(t *testing.T) { tests := []struct { slasha string @@ -1539,6 +1607,111 @@ func TestJoinURLPath(t *testing.T) { } } +func TestReverseProxyRewriteReplacesOut(t *testing.T) { + const content = "response_content" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(content)) + })) + defer backend.Close() + proxyHandler := &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.Out, _ = http.NewRequest("GET", backend.URL, nil) + }, + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + res, err := frontend.Client().Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + if got, want := string(body), content; got != want { + t.Errorf("got response %q, want %q", got, want) + } +} + +func Test1xxResponses(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Add("Link", "; rel=preload; as=style") + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusProcessing) + + w.Write([]byte("Hello")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + checkLinkHeaders := func(t *testing.T, expected, got []string) { + t.Helper() + + if len(expected) != len(got) { + t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) + } + + for i := range expected { + if i >= len(got) { + t.Errorf("Expected %q link header; got nothing", expected[i]) + + continue + } + + if expected[i] != got[i] { + t.Errorf("Expected %q link header; got %q", expected[i], got[i]) + } + } + } + + var respCounter uint8 + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch code { + case http.StatusEarlyHints: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script"}, header["Link"]) + case http.StatusProcessing: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, header["Link"]) + default: + t.Error("Unexpected 1xx response") + } + + respCounter++ + + return nil + }, + } + req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil) + + res, err := frontendClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + + defer res.Body.Close() + + if respCounter != 2 { + t.Errorf("Expected 2 1xx responses; got %d", respCounter) + } + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, res.Header["Link"]) + + body, _ := io.ReadAll(res.Body) + if string(body) != "Hello" { + t.Errorf("Read body %q; want Hello", body) + } +} + const ( testWantsCleanQuery = true testWantsRawQuery = false @@ -1569,6 +1742,27 @@ func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) { }) } +func TestReverseProxyQueryParameterSmugglingRewrite(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy { + return &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(u) + }, + } + }) +} + +func TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy { + return &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(u) + r.Out.URL.RawQuery = r.In.URL.RawQuery + }, + } + }) +} + func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) { const content = "response_content" backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/safefilepath/path_other.go b/internal/safefilepath/path_other.go index f93da186..974e7751 100644 --- a/internal/safefilepath/path_other.go +++ b/internal/safefilepath/path_other.go @@ -11,7 +11,7 @@ import "runtime" func fromFS(path string) (string, error) { if runtime.GOOS == "plan9" { if len(path) > 0 && path[0] == '#' { - return path, errInvalidPath + return "", errInvalidPath } } for i := range path { diff --git a/readrequest_test.go b/readrequest_test.go index c64f9c45..5aaf3b9f 100644 --- a/readrequest_test.go +++ b/readrequest_test.go @@ -416,7 +416,7 @@ func TestReadRequest(t *testing.T) { req.Body = nil testName := fmt.Sprintf("Test %d (%q)", i, tt.Raw) diff(t, testName, req, tt.Req) - var bout bytes.Buffer + var bout strings.Builder if rbody != nil { _, err := io.Copy(&bout, rbody) if err != nil { @@ -458,12 +458,10 @@ Content-Length: 5 1234`)}, // golang.org/issue/22464 - {"leading_space_in_header", reqBytes(`HEAD / HTTP/1.1 - Host: foo -Content-Length: 5`)}, - {"leading_tab_in_header", reqBytes(`HEAD / HTTP/1.1 -` + "\t" + `Host: foo -Content-Length: 5`)}, + {"leading_space_in_header", reqBytes(`GET / HTTP/1.1 + Host: foo`)}, + {"leading_tab_in_header", reqBytes(`GET / HTTP/1.1 +` + "\t" + `Host: foo`)}, } func TestReadRequest_Bad(t *testing.T) { diff --git a/request.go b/request.go index 74f3ce7b..2bcc77f4 100644 --- a/request.go +++ b/request.go @@ -49,9 +49,12 @@ type ProtocolError struct { func (pe *ProtocolError) Error() string { return pe.ErrorString } var ( - // ErrNotSupported is returned by the Push method of Pusher - // implementations to indicate that HTTP/2 Push support is not - // available. + // ErrNotSupported indicates that a feature is not supported. + // + // It is returned by ResponseController methods to indicate that + // the handler does not support the method, and by the Push method + // of Pusher implementations to indicate that HTTP/2 Push support + // is not available. ErrNotSupported = &ProtocolError{"feature not supported"} // Deprecated: ErrUnexpectedTrailer is no longer returned by @@ -317,14 +320,14 @@ type Request struct { Response *Response // ctx is either the client or server context. It should only - // be modified via copying the whole Request using WithContext. + // be modified via copying the whole Request using Clone or WithContext. // It is unexported to prevent people from using Context wrong // and mutating the contexts held by callers of the same request. ctx context.Context } // Context returns the request's context. To change the context, use -// WithContext. +// Clone or WithContext. // // The returned context is always non-nil; it defaults to the // background context. @@ -349,9 +352,7 @@ func (r *Request) Context() context.Context { // sending the request, and reading the response headers and body. // // To create a new request with a context, use NewRequestWithContext. -// To change the context of a request, such as an incoming request you -// want to modify before sending back out, use Request.Clone. Between -// those two uses, it's rare to need WithContext. +// To make a deep copy of a request with a new context, use Request.Clone. func (r *Request) WithContext(ctx context.Context) *Request { if ctx == nil { panic("nil context") @@ -418,6 +419,9 @@ var ErrNoCookie = errors.New("http: named cookie not present") // If multiple cookies match the given name, only one cookie will // be returned. func (r *Request) Cookie(name string) (*Cookie, error) { + if name == "" { + return nil, ErrNoCookie + } for _, c := range readCookies(r.Header, name) { return c, nil } @@ -1029,6 +1033,8 @@ func ReadRequest(b *bufio.Reader) (*Request, error) { func readRequest(b *bufio.Reader) (req *Request, err error) { tp := newTextprotoReader(b) + defer putTextprotoReader(tp) + req = new(Request) // First line: GET /index.html HTTP/1.0 @@ -1037,7 +1043,6 @@ func readRequest(b *bufio.Reader) (req *Request, err error) { return nil, err } defer func() { - putTextprotoReader(tp) if err == io.EOF { err = io.ErrUnexpectedEOF } @@ -1168,7 +1173,8 @@ func (l *maxBytesReader) Read(p []byte) (n int, err error) { // If they asked for a 32KB byte read but only 5 bytes are // remaining, no need to read 32KB. 6 bytes will answer the // question of the whether we hit the limit or go past it. - if int64(len(p)) > l.n+1 { + // 0 < len(p) < 2^63 + if int64(len(p))-1 > l.n { p = p[:l.n+1] } n, err = l.r.Read(p) diff --git a/request_test.go b/request_test.go index f199ae72..e89297a1 100644 --- a/request_test.go +++ b/request_test.go @@ -22,7 +22,6 @@ import ( "testing" . "github.com/ooni/oohttp" - "github.com/ooni/oohttp/httptest" ) func TestQuery(t *testing.T) { @@ -290,10 +289,11 @@ Content-Type: text/plain // the payload size and the internal leeway buffer size of 10MiB overflows, that we // correctly return an error. func TestMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T) { - defer afterTest(t) - + run(t, testMaxInt64ForMultipartFormMaxMemoryOverflow) +} +func testMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T, mode testMode) { payloadSize := 1 << 10 - cst := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { // The combination of: // MaxInt64 + payloadSize + (internal spare of 10MiB) // triggers the overflow. See issue https://golang.org/issue/40430/ @@ -301,8 +301,7 @@ func TestMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T) { Error(rw, err.Error(), StatusBadRequest) return } - })) - defer cst.Close() + })).ts fBuf := new(bytes.Buffer) mw := multipart.NewWriter(fBuf) mf, err := mw.CreateFormFile("file", "myfile.txt") @@ -330,11 +329,9 @@ func TestMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T) { } } -func TestRedirect_h1(t *testing.T) { testRedirect(t, h1Mode) } -func TestRedirect_h2(t *testing.T) { testRedirect(t, h2Mode) } -func testRedirect(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestRequestRedirect(t *testing.T) { run(t, testRequestRedirect) } +func testRequestRedirect(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { switch r.URL.Path { case "/": w.Header().Set("Location", "/foo/") @@ -345,7 +342,6 @@ func testRedirect(t *testing.T, h2 bool) { w.WriteHeader(StatusBadRequest) } })) - defer cst.close() var end = regexp.MustCompile("/foo/$") r, err := cst.c.Get(cst.ts.URL) @@ -812,7 +808,7 @@ func TestStarRequest(t *testing.T) { clientReq := *req clientReq.Body = nil - var out bytes.Buffer + var out strings.Builder if err := clientReq.Write(&out); err != nil { t.Fatal(err) } @@ -820,7 +816,7 @@ func TestStarRequest(t *testing.T) { if strings.Contains(out.String(), "chunked") { t.Error("wrote chunked request; want no body") } - back, err := ReadRequest(bufio.NewReader(bytes.NewReader(out.Bytes()))) + back, err := ReadRequest(bufio.NewReader(strings.NewReader(out.String()))) if err != nil { t.Fatal(err) } @@ -832,7 +828,7 @@ func TestStarRequest(t *testing.T) { t.Errorf("Original request doesn't match Request read back.") t.Logf("Original: %#v", req) t.Logf("Original.URL: %#v", req.URL) - t.Logf("Wrote: %s", out.Bytes()) + t.Logf("Wrote: %s", out.String()) t.Logf("Read back (doesn't match Original): %#v", back) } } @@ -979,6 +975,12 @@ func TestMaxBytesReaderDifferentLimits(t *testing.T) { wantN: len(testStr), wantErr: false, }, + 10: { /* Issue 54408 */ + limit: int64(1<<63 - 1), + lenP: len(testStr), + wantN: len(testStr), + wantErr: false, + }, } for i, tt := range tests { rc := MaxBytesReader(nil, io.NopCloser(strings.NewReader(testStr)), tt.limit) @@ -1030,19 +1032,10 @@ func TestRequestCloneTransferEncoding(t *testing.T) { } } -func TestNoPanicOnRoundTripWithBasicAuth_h1(t *testing.T) { - testNoPanicWithBasicAuth(t, h1Mode) -} - -func TestNoPanicOnRoundTripWithBasicAuth_h2(t *testing.T) { - testNoPanicWithBasicAuth(t, h2Mode) -} - // Issue 34878: verify we don't panic when including basic auth (Go 1.13 regression) -func testNoPanicWithBasicAuth(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer cst.close() +func TestNoPanicOnRoundTripWithBasicAuth(t *testing.T) { run(t, testNoPanicWithBasicAuth) } +func testNoPanicWithBasicAuth(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})) u, err := url.Parse(cst.ts.URL) if err != nil { @@ -1163,7 +1156,7 @@ func testMultipartFile(t *testing.T, req *Request, key, expectFilename, expectCo if fh.Filename != expectFilename { t.Errorf("filename = %q, want %q", fh.Filename, expectFilename) } - var b bytes.Buffer + var b strings.Builder _, err = io.Copy(&b, f) if err != nil { t.Fatal("copying contents:", err) @@ -1174,6 +1167,47 @@ func testMultipartFile(t *testing.T, req *Request, key, expectFilename, expectCo return f } +// Issue 53181: verify Request.Cookie return the correct Cookie. +// Return ErrNoCookie instead of the first cookie when name is "". +func TestRequestCookie(t *testing.T) { + for _, tt := range []struct { + name string + value string + expectedErr error + }{ + { + name: "foo", + value: "bar", + expectedErr: nil, + }, + { + name: "", + expectedErr: ErrNoCookie, + }, + } { + req, err := NewRequest("GET", "http://example.com/", nil) + if err != nil { + t.Fatal(err) + } + req.AddCookie(&Cookie{Name: tt.name, Value: tt.value}) + c, err := req.Cookie(tt.name) + if err != tt.expectedErr { + t.Errorf("got %v, want %v", err, tt.expectedErr) + } + + // skip if error occurred. + if err != nil { + continue + } + if c.Value != tt.value { + t.Errorf("got %v, want %v", c.Value, tt.value) + } + if c.Name != tt.name { + t.Errorf("got %s, want %v", tt.name, c.Name) + } + } +} + const ( fileaContents = "This is a test file." filebContents = "Another test file." @@ -1282,11 +1316,6 @@ Host: localhost:8080 `) } -const ( - withTLS = true - noTLS = false -) - func BenchmarkFileAndServer_1KB(b *testing.B) { benchmarkFileAndServer(b, 1<<10) } @@ -1314,16 +1343,12 @@ func benchmarkFileAndServer(b *testing.B, n int64) { b.Fatalf("Failed to copy %d bytes: %v", n, err) } - b.Run("NoTLS", func(b *testing.B) { - runFileAndServerBenchmarks(b, noTLS, f, n) - }) - - b.Run("TLS", func(b *testing.B) { - runFileAndServerBenchmarks(b, withTLS, f, n) - }) + run(b, func(b *testing.B, mode testMode) { + runFileAndServerBenchmarks(b, mode, f, n) + }, []testMode{http1Mode, https1Mode, http2Mode}) } -func runFileAndServerBenchmarks(b *testing.B, tlsOption bool, f *os.File, n int64) { +func runFileAndServerBenchmarks(b *testing.B, mode testMode, f *os.File, n int64) { handler := HandlerFunc(func(rw ResponseWriter, req *Request) { defer req.Body.Close() nc, err := io.Copy(io.Discard, req.Body) @@ -1336,14 +1361,8 @@ func runFileAndServerBenchmarks(b *testing.B, tlsOption bool, f *os.File, n int6 } }) - var cst *httptest.Server - if tlsOption == withTLS { - cst = httptest.NewTLSServer(handler) - } else { - cst = httptest.NewServer(handler) - } + cst := newClientServerTest(b, mode, handler).ts - defer cst.Close() b.ResetTimer() for i := 0; i < b.N; i++ { // Perform some setup. diff --git a/requestwrite_test.go b/requestwrite_test.go index bdc1e3c5..380ae9de 100644 --- a/requestwrite_test.go +++ b/requestwrite_test.go @@ -629,7 +629,7 @@ func TestRequestWrite(t *testing.T) { tt.Req.Header = make(Header) } - var braw bytes.Buffer + var braw strings.Builder err := tt.Req.Write(&braw) if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.WantError); g != e { t.Errorf("writing #%d, err = %q, want %q", i, g, e) @@ -649,7 +649,7 @@ func TestRequestWrite(t *testing.T) { if tt.WantProxy != "" { setBody() - var praw bytes.Buffer + var praw strings.Builder err = tt.Req.WriteProxy(&praw) if err != nil { t.Errorf("WriteProxy #%d: %s", i, err) @@ -815,7 +815,7 @@ func TestRequestWriteClosesBody(t *testing.T) { if err != nil { t.Fatal(err) } - buf := new(bytes.Buffer) + buf := new(strings.Builder) if err := req.Write(buf); err != nil { t.Error(err) } diff --git a/response_test.go b/response_test.go index 73278365..0deeba1d 100644 --- a/response_test.go +++ b/response_test.go @@ -597,7 +597,7 @@ func TestReadResponse(t *testing.T) { rbody := resp.Body resp.Body = nil diff(t, fmt.Sprintf("#%d Response", i), resp, &tt.Resp) - var bout bytes.Buffer + var bout strings.Builder if rbody != nil { _, err = io.Copy(&bout, rbody) if err != nil { @@ -810,7 +810,7 @@ func TestResponseStatusStutter(t *testing.T) { ProtoMajor: 1, ProtoMinor: 3, } - var buf bytes.Buffer + var buf strings.Builder r.Write(&buf) if strings.Contains(buf.String(), "123 123") { t.Errorf("stutter in status: %s", buf.String()) @@ -830,7 +830,7 @@ func TestResponseContentLengthShortBody(t *testing.T) { if res.ContentLength != 123 { t.Fatalf("Content-Length = %d; want 123", res.ContentLength) } - var buf bytes.Buffer + var buf strings.Builder n, err := io.Copy(&buf, res.Body) if n != int64(len(shortBody)) { t.Errorf("Copied %d bytes; want %d, len(%q)", n, len(shortBody), shortBody) @@ -973,19 +973,6 @@ func matchErr(err error, wantErr any) error { return fmt.Errorf("%v; want %v", err, wantErr) } -func TestNeedsSniff(t *testing.T) { - // needsSniff returns true with an empty response. - r := &response{} - if got, want := r.needsSniff(), true; got != want { - t.Errorf("needsSniff = %t; want %t", got, want) - } - // needsSniff returns false when Content-Type = nil. - r.handlerHeader = Header{"Content-Type": nil} - if got, want := r.needsSniff(), false; got != want { - t.Errorf("needsSniff empty Content-Type = %t; want %t", got, want) - } -} - // A response should only write out single Connection: close header. Tests #19499. func TestResponseWritesOnlySingleConnectionClose(t *testing.T) { const connectionCloseHeader = "Connection: close" @@ -1003,7 +990,7 @@ func TestResponseWritesOnlySingleConnectionClose(t *testing.T) { t.Fatalf("ReadResponse failed %v", err) } - var buf2 bytes.Buffer + var buf2 strings.Builder if err = res.Write(&buf2); err != nil { t.Fatalf("Write failed %v", err) } diff --git a/responsecontroller.go b/responsecontroller.go new file mode 100644 index 00000000..018bdc00 --- /dev/null +++ b/responsecontroller.go @@ -0,0 +1,122 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bufio" + "fmt" + "net" + "time" +) + +// A ResponseController is used by an HTTP handler to control the response. +// +// A ResponseController may not be used after the Handler.ServeHTTP method has returned. +type ResponseController struct { + rw ResponseWriter +} + +// NewResponseController creates a ResponseController for a request. +// +// The ResponseWriter should be the original value passed to the Handler.ServeHTTP method, +// or have an Unwrap method returning the original ResponseWriter. +// +// If the ResponseWriter implements any of the following methods, the ResponseController +// will call them as appropriate: +// +// Flush() +// FlushError() error // alternative Flush returning an error +// Hijack() (net.Conn, *bufio.ReadWriter, error) +// SetReadDeadline(deadline time.Time) error +// SetWriteDeadline(deadline time.Time) error +// +// If the ResponseWriter does not support a method, ResponseController returns +// an error matching ErrNotSupported. +func NewResponseController(rw ResponseWriter) *ResponseController { + return &ResponseController{rw} +} + +type rwUnwrapper interface { + Unwrap() ResponseWriter +} + +// Flush flushes buffered data to the client. +func (c *ResponseController) Flush() error { + rw := c.rw + for { + switch t := rw.(type) { + case interface{ FlushError() error }: + return t.FlushError() + case Flusher: + t.Flush() + return nil + case rwUnwrapper: + rw = t.Unwrap() + default: + return errNotSupported() + } + } +} + +// Hijack lets the caller take over the connection. +// See the Hijacker interface for details. +func (c *ResponseController) Hijack() (net.Conn, *bufio.ReadWriter, error) { + rw := c.rw + for { + switch t := rw.(type) { + case Hijacker: + return t.Hijack() + case rwUnwrapper: + rw = t.Unwrap() + default: + return nil, nil, errNotSupported() + } + } +} + +// SetReadDeadline sets the deadline for reading the entire request, including the body. +// Reads from the request body after the deadline has been exceeded will return an error. +// A zero value means no deadline. +// +// Setting the read deadline after it has been exceeded will not extend it. +func (c *ResponseController) SetReadDeadline(deadline time.Time) error { + rw := c.rw + for { + switch t := rw.(type) { + case interface{ SetReadDeadline(time.Time) error }: + return t.SetReadDeadline(deadline) + case rwUnwrapper: + rw = t.Unwrap() + default: + return errNotSupported() + } + } +} + +// SetWriteDeadline sets the deadline for writing the response. +// Writes to the response body after the deadline has been exceeded will not block, +// but may succeed if the data has been buffered. +// A zero value means no deadline. +// +// Setting the write deadline after it has been exceeded will not extend it. +func (c *ResponseController) SetWriteDeadline(deadline time.Time) error { + rw := c.rw + for { + switch t := rw.(type) { + case interface{ SetWriteDeadline(time.Time) error }: + return t.SetWriteDeadline(deadline) + case rwUnwrapper: + rw = t.Unwrap() + default: + return errNotSupported() + } + } +} + +// errNotSupported returns an error that Is ErrNotSupported, +// but is not == to it. +func errNotSupported() error { + return fmt.Errorf("%w", ErrNotSupported) +} diff --git a/responsecontroller_test.go b/responsecontroller_test.go new file mode 100644 index 00000000..014737e7 --- /dev/null +++ b/responsecontroller_test.go @@ -0,0 +1,266 @@ +package http_test + +import ( + "errors" + "fmt" + "io" + "os" + "sync" + "testing" + "time" + + . "github.com/ooni/oohttp" +) + +func TestResponseControllerFlush(t *testing.T) { run(t, testResponseControllerFlush) } +func testResponseControllerFlush(t *testing.T, mode testMode) { + continuec := make(chan struct{}) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ctl := NewResponseController(w) + w.Write([]byte("one")) + if err := ctl.Flush(); err != nil { + t.Errorf("ctl.Flush() = %v, want nil", err) + return + } + <-continuec + w.Write([]byte("two")) + })) + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatalf("unexpected connection error: %v", err) + } + defer res.Body.Close() + + buf := make([]byte, 16) + n, err := res.Body.Read(buf) + close(continuec) + if err != nil || string(buf[:n]) != "one" { + t.Fatalf("Body.Read = %q, %v, want %q, nil", string(buf[:n]), err, "one") + } + + got, err := io.ReadAll(res.Body) + if err != nil || string(got) != "two" { + t.Fatalf("Body.Read = %q, %v, want %q, nil", string(got), err, "two") + } +} + +func TestResponseControllerHijack(t *testing.T) { run(t, testResponseControllerHijack) } +func testResponseControllerHijack(t *testing.T, mode testMode) { + const header = "X-Header" + const value = "set" + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ctl := NewResponseController(w) + c, _, err := ctl.Hijack() + if mode == http2Mode { + if err == nil { + t.Errorf("ctl.Hijack = nil, want error") + } + w.Header().Set(header, value) + return + } + if err != nil { + t.Errorf("ctl.Hijack = _, _, %v, want _, _, nil", err) + return + } + fmt.Fprintf(c, "HTTP/1.0 200 OK\r\n%v: %v\r\nContent-Length: 0\r\n\r\n", header, value) + })) + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + if got, want := res.Header.Get(header), value; got != want { + t.Errorf("response header %q = %q, want %q", header, got, want) + } +} + +func TestResponseControllerSetPastWriteDeadline(t *testing.T) { + run(t, testResponseControllerSetPastWriteDeadline) +} +func testResponseControllerSetPastWriteDeadline(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ctl := NewResponseController(w) + w.Write([]byte("one")) + if err := ctl.Flush(); err != nil { + t.Errorf("before setting deadline: ctl.Flush() = %v, want nil", err) + } + if err := ctl.SetWriteDeadline(time.Now().Add(-10 * time.Second)); err != nil { + t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err) + } + + w.Write([]byte("two")) + if err := ctl.Flush(); err == nil { + t.Errorf("after setting deadline: ctl.Flush() = nil, want non-nil") + } + // Connection errors are sticky, so resetting the deadline does not permit + // making more progress. We might want to change this in the future, but verify + // the current behavior for now. If we do change this, we'll want to make sure + // to do so only for writing the response body, not headers. + if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Hour)); err != nil { + t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err) + } + w.Write([]byte("three")) + if err := ctl.Flush(); err == nil { + t.Errorf("after resetting deadline: ctl.Flush() = nil, want non-nil") + } + })) + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatalf("unexpected connection error: %v", err) + } + defer res.Body.Close() + b, _ := io.ReadAll(res.Body) + if string(b) != "one" { + t.Errorf("unexpected body: %q", string(b)) + } +} + +func TestResponseControllerSetFutureWriteDeadline(t *testing.T) { + run(t, testResponseControllerSetFutureWriteDeadline) +} +func testResponseControllerSetFutureWriteDeadline(t *testing.T, mode testMode) { + errc := make(chan error, 1) + startwritec := make(chan struct{}) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ctl := NewResponseController(w) + w.WriteHeader(200) + if err := ctl.Flush(); err != nil { + t.Errorf("ctl.Flush() = %v, want nil", err) + } + <-startwritec // don't set the deadline until the client reads response headers + if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Millisecond)); err != nil { + t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err) + } + _, err := io.Copy(w, neverEnding('a')) + errc <- err + })) + + res, err := cst.c.Get(cst.ts.URL) + close(startwritec) + if err != nil { + t.Fatalf("unexpected connection error: %v", err) + } + defer res.Body.Close() + _, err = io.Copy(io.Discard, res.Body) + if err == nil { + t.Errorf("client reading from truncated request body: got nil error, want non-nil") + } + err = <-errc // io.Copy error + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err) + } +} + +func TestResponseControllerSetPastReadDeadline(t *testing.T) { + run(t, testResponseControllerSetPastReadDeadline) +} +func testResponseControllerSetPastReadDeadline(t *testing.T, mode testMode) { + readc := make(chan struct{}) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ctl := NewResponseController(w) + b := make([]byte, 3) + n, err := io.ReadFull(r.Body, b) + b = b[:n] + if err != nil || string(b) != "one" { + t.Errorf("before setting read deadline: Read = %v, %q, want nil, %q", err, string(b), "one") + return + } + if err := ctl.SetReadDeadline(time.Now()); err != nil { + t.Errorf("ctl.SetReadDeadline() = %v, want nil", err) + return + } + b, err = io.ReadAll(r.Body) + if err == nil || string(b) != "" { + t.Errorf("after setting read deadline: Read = %q, nil, want error", string(b)) + } + close(readc) + // Connection errors are sticky, so resetting the deadline does not permit + // making more progress. We might want to change this in the future, but verify + // the current behavior for now. + if err := ctl.SetReadDeadline(time.Time{}); err != nil { + t.Errorf("ctl.SetReadDeadline() = %v, want nil", err) + return + } + b, err = io.ReadAll(r.Body) + if err == nil { + t.Errorf("after resetting read deadline: Read = %q, nil, want error", string(b)) + } + })) + + pr, pw := io.Pipe() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + pw.Write([]byte("one")) + <-readc + pw.Write([]byte("two")) + pw.Close() + }() + defer wg.Wait() + res, err := cst.c.Post(cst.ts.URL, "text/foo", pr) + if err == nil { + defer res.Body.Close() + } +} + +func TestResponseControllerSetFutureReadDeadline(t *testing.T) { + run(t, testResponseControllerSetFutureReadDeadline) +} +func testResponseControllerSetFutureReadDeadline(t *testing.T, mode testMode) { + respBody := "response body" + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) { + ctl := NewResponseController(w) + if err := ctl.SetReadDeadline(time.Now().Add(1 * time.Millisecond)); err != nil { + t.Errorf("ctl.SetReadDeadline() = %v, want nil", err) + } + _, err := io.Copy(io.Discard, req.Body) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err) + } + w.Write([]byte(respBody)) + })) + pr, pw := io.Pipe() + res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + if string(got) != respBody || err != nil { + t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody) + } + pw.Close() +} + +type wrapWriter struct { + ResponseWriter +} + +func (w wrapWriter) Unwrap() ResponseWriter { + return w.ResponseWriter +} + +func TestWrappedResponseController(t *testing.T) { run(t, testWrappedResponseController) } +func testWrappedResponseController(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w = wrapWriter{w} + ctl := NewResponseController(w) + if err := ctl.Flush(); err != nil { + t.Errorf("ctl.Flush() = %v, want nil", err) + } + if err := ctl.SetReadDeadline(time.Time{}); err != nil { + t.Errorf("ctl.SetReadDeadline() = %v, want nil", err) + } + if err := ctl.SetWriteDeadline(time.Time{}); err != nil { + t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err) + } + })) + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatalf("unexpected connection error: %v", err) + } + io.Copy(io.Discard, res.Body) + defer res.Body.Close() +} diff --git a/responsewrite_test.go b/responsewrite_test.go index 1cc87b94..226ad722 100644 --- a/responsewrite_test.go +++ b/responsewrite_test.go @@ -5,7 +5,6 @@ package http import ( - "bytes" "io" "strings" "testing" @@ -276,7 +275,7 @@ func TestResponseWrite(t *testing.T) { for i := range respWriteTests { tt := &respWriteTests[i] - var braw bytes.Buffer + var braw strings.Builder err := tt.Resp.Write(&braw) if err != nil { t.Errorf("error writing #%d: %s", i, err) diff --git a/serve_test.go b/serve_test.go index 9b8b8491..50c7b15f 100644 --- a/serve_test.go +++ b/serve_test.go @@ -19,6 +19,7 @@ import ( "io" "log" "math/rand" + "mime/multipart" "net" "net/url" "os" @@ -146,7 +147,7 @@ func newHandlerTest(h Handler) handlerTest { func (ht *handlerTest) rawResponse(req string) string { reqb := reqBytes(req) - var output bytes.Buffer + var output strings.Builder conn := &rwTestConn{ Reader: bytes.NewReader(reqb), Writer: &output, @@ -244,15 +245,13 @@ var vtests = []struct { {"http://someHost.com/someDir", "/someDir/"}, } -func TestHostHandlers(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) } +func testHostHandlers(t *testing.T, mode testMode) { mux := NewServeMux() for _, h := range handlers { mux.Handle(h.pattern, stringHandler(h.msg)) } - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -485,9 +484,9 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { // properly sets the query string in the redirect URL. // See Issue 17841. func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { - setParallel(t) - defer afterTest(t) - + run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode}) +} +func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) { writeBackQuery := func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.URL.RawQuery) } @@ -500,8 +499,7 @@ func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { fmt.Fprintf(w, "%s:bar", r.URL.RawQuery) }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts tests := [...]struct { path string @@ -544,7 +542,6 @@ func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { func TestServeWithSlashRedirectForHostPatterns(t *testing.T) { setParallel(t) - defer afterTest(t) mux := NewServeMux() mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/")) @@ -576,9 +573,6 @@ func TestServeWithSlashRedirectForHostPatterns(t *testing.T) { {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""}, } - ts := httptest.NewServer(mux) - defer ts.Close() - for i, tt := range tests { req, _ := NewRequest(tt.method, tt.url, nil) w := httptest.NewRecorder() @@ -600,13 +594,10 @@ func TestServeWithSlashRedirectForHostPatterns(t *testing.T) { } } -func TestShouldRedirectConcurrency(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) } +func testShouldRedirectConcurrency(t *testing.T, mode testMode) { mux := NewServeMux() - ts := httptest.NewServer(mux) - defer ts.Close() + newClientServerTest(t, mode, mux) mux.HandleFunc("/", func(w ResponseWriter, r *Request) {}) } @@ -654,13 +645,12 @@ func benchmarkServeMux(b *testing.B, runHandler bool) { } } -func TestServerTimeouts(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) } +func testServerTimeouts(t *testing.T, mode testMode) { // Try three times, with increasing timeouts. tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second} for i, timeout := range tries { - err := testServerTimeouts(timeout) + err := testServerTimeoutsWithTimeout(t, timeout, mode) if err == nil { return } @@ -672,16 +662,15 @@ func TestServerTimeouts(t *testing.T) { t.Fatal("all attempts failed") } -func testServerTimeouts(timeout time.Duration) error { +func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error { reqNum := 0 - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ fmt.Fprintf(res, "req=%d", reqNum) - })) - ts.Config.ReadTimeout = timeout - ts.Config.WriteTimeout = timeout - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ReadTimeout = timeout + ts.Config.WriteTimeout = timeout + }).ts // Hit the HTTP server successfully. c := ts.Client() @@ -746,23 +735,90 @@ func testServerTimeouts(timeout time.Duration) error { return nil } +func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout) } +func testServerReadTimeout(t *testing.T, mode testMode) { + respBody := "response body" + for timeout := 5 * time.Millisecond; ; timeout *= 2 { + cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { + _, err := io.Copy(io.Discard, req.Body) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err) + } + res.Write([]byte(respBody)) + }), func(ts *httptest.Server) { + ts.Config.ReadHeaderTimeout = -1 // don't time out while reading headers + ts.Config.ReadTimeout = timeout + }) + pr, pw := io.Pipe() + res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr) + if err != nil { + t.Logf("Get error, retrying: %v", err) + cst.close() + continue + } + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + if string(got) != respBody || err != nil { + t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody) + } + pw.Close() + break + } +} + +func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout) } +func testServerWriteTimeout(t *testing.T, mode testMode) { + for timeout := 5 * time.Millisecond; ; timeout *= 2 { + errc := make(chan error, 2) + cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { + errc <- nil + _, err := io.Copy(res, neverEnding('a')) + errc <- err + }), func(ts *httptest.Server) { + ts.Config.WriteTimeout = timeout + }) + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + // Probably caused by the write timeout expiring before the handler runs. + t.Logf("Get error, retrying: %v", err) + cst.close() + continue + } + defer res.Body.Close() + _, err = io.Copy(io.Discard, res.Body) + if err == nil { + t.Errorf("client reading from truncated request body: got nil error, want non-nil") + } + select { + case <-errc: + err = <-errc // io.Copy error + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err) + } + return + default: + // The write timeout expired before the handler started. + t.Logf("handler didn't run, retrying") + cst.close() + } + } +} + // Test that the HTTP/2 server handles Server.WriteTimeout (Issue 18437) -func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { +func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) { + run(t, testWriteDeadlineExtendedOnNewRequest) +} +func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {})) - ts.Config.WriteTimeout = 250 * time.Millisecond - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}), + func(ts *httptest.Server) { + ts.Config.WriteTimeout = 250 * time.Millisecond + }, + ).ts c := ts.Client() - if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { - t.Fatal(err) - } for i := 1; i <= 3; i++ { req, err := NewRequest("GET", ts.URL, nil) @@ -783,9 +839,6 @@ func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { t.Fatalf("http2 Get #%d: %v", i, err) } r.Body.Close() - if r.ProtoMajor != 2 { - t.Fatalf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } time.Sleep(ts.Config.WriteTimeout / 2) } } @@ -808,33 +861,31 @@ func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) { } // Test that the HTTP/2 server RSTs stream on slow write. -func TestHTTP2WriteDeadlineEnforcedPerStream(t *testing.T) { +func TestWriteDeadlineEnforcedPerStream(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } setParallel(t) - defer afterTest(t) - tryTimeouts(t, testHTTP2WriteDeadlineEnforcedPerStream) + run(t, func(t *testing.T, mode testMode) { + tryTimeouts(t, func(timeout time.Duration) error { + return testWriteDeadlineEnforcedPerStream(t, mode, timeout) + }) + }) } -func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { +func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error { reqNum := 0 - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ if reqNum == 1 { return // first request succeeds } time.Sleep(timeout) // second request times out - })) - ts.Config.WriteTimeout = timeout / 2 - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.WriteTimeout = timeout / 2 + }).ts c := ts.Client() - if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { - return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err) - } req, err := NewRequest("GET", ts.URL, nil) if err != nil { @@ -842,12 +893,9 @@ func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { } r, err := c.Do(req) if err != nil { - return fmt.Errorf("http2 Get #1: %v", err) + return fmt.Errorf("Get #1: %v", err) } r.Body.Close() - if r.ProtoMajor != 2 { - return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } req, err = NewRequest("GET", ts.URL, nil) if err != nil { @@ -856,45 +904,42 @@ func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { r, err = c.Do(req) if err == nil { r.Body.Close() - if r.ProtoMajor != 2 { - return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } - return fmt.Errorf("http2 Get #2 expected error, got nil") + return fmt.Errorf("Get #2 expected error, got nil") } - expected := "stream ID 3; INTERNAL_ERROR" // client IDs are odd, second stream should be 3 - if !strings.Contains(err.Error(), expected) { - return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err) + if mode == http2Mode { + expected := "stream ID 3; INTERNAL_ERROR" // client IDs are odd, second stream should be 3 + if !strings.Contains(err.Error(), expected) { + return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err) + } } return nil } // Test that the HTTP/2 server does not send RST when WriteDeadline not set. -func TestHTTP2NoWriteDeadline(t *testing.T) { +func TestNoWriteDeadline(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } setParallel(t) defer afterTest(t) - tryTimeouts(t, testHTTP2NoWriteDeadline) + run(t, func(t *testing.T, mode testMode) { + tryTimeouts(t, func(timeout time.Duration) error { + return testNoWriteDeadline(t, mode, timeout) + }) + }) } -func testHTTP2NoWriteDeadline(timeout time.Duration) error { +func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error { reqNum := 0 - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ if reqNum == 1 { return // first request succeeds } time.Sleep(timeout) // second request timesout - })) - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - defer ts.Close() + })).ts c := ts.Client() - if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { - return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err) - } for i := 0; i < 2; i++ { req, err := NewRequest("GET", ts.URL, nil) @@ -903,12 +948,9 @@ func testHTTP2NoWriteDeadline(timeout time.Duration) error { } r, err := c.Do(req) if err != nil { - return fmt.Errorf("http2 Get #%d: %v", i, err) + return fmt.Errorf("Get #%d: %v", i, err) } r.Body.Close() - if r.ProtoMajor != 2 { - return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } } return nil } @@ -916,15 +958,14 @@ func testHTTP2NoWriteDeadline(timeout time.Duration) error { // golang.org/issue/4741 -- setting only a write timeout that triggers // shouldn't cause a handler to block forever on reads (next HTTP // request) that will never happen. -func TestOnlyWriteTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) } +func testOnlyWriteTimeout(t *testing.T, mode testMode) { var ( mu sync.RWMutex conn net.Conn ) var afterTimeoutErrc = make(chan error, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) { buf := make([]byte, 512<<10) _, err := w.Write(buf) if err != nil { @@ -940,10 +981,9 @@ func TestOnlyWriteTimeout(t *testing.T) { conn.SetWriteDeadline(time.Now().Add(-30 * time.Second)) _, err = w.Write(buf) afterTimeoutErrc <- err - })) - ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn} - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn} + }).ts c := ts.Client() @@ -990,9 +1030,12 @@ func (l trackLastConnListener) Accept() (c net.Conn, err error) { } // TestIdentityResponse verifies that a handler can unset -func TestIdentityResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) } +func testIdentityResponse(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/56019") + } + handler := HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Length", "3") rw.Header().Set("Transfer-Encoding", req.FormValue("te")) @@ -1010,9 +1053,7 @@ func TestIdentityResponse(t *testing.T) { } }) - ts := httptest.NewServer(handler) - defer ts.Close() - + ts := newClientServerTest(t, mode, handler).ts c := ts.Client() // Note: this relies on the assumption (which is true) that @@ -1046,6 +1087,10 @@ func TestIdentityResponse(t *testing.T) { } res.Body.Close() + if mode != http1Mode { + return + } + // Verify that the connection is closed when the declared Content-Length // is larger than what the handler wrote. conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -1068,9 +1113,7 @@ func TestIdentityResponse(t *testing.T) { func testTCPConnectionCloses(t *testing.T, req string, h Handler) { setParallel(t) - defer afterTest(t) - s := httptest.NewServer(h) - defer s.Close() + s := newClientServerTest(t, http1Mode, h).ts conn, err := net.Dial("tcp", s.Listener.Addr().String()) if err != nil { @@ -1112,9 +1155,7 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) { func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) { setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(handler) - defer ts.Close() + ts := newClientServerTest(t, http1Mode, handler).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) @@ -1190,14 +1231,12 @@ func TestHTTP10KeepAlive304Response(t *testing.T) { } // Issue 15703 -func TestKeepAliveFinalChunkWithEOF(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, false /* h1 */, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) } +func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() // force chunked encoding w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}")) })) - defer cst.close() type data struct { Addr string } @@ -1220,16 +1259,11 @@ func TestKeepAliveFinalChunkWithEOF(t *testing.T) { } } -func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) } -func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) } - -func testSetsRemoteAddr(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) } +func testSetsRemoteAddr(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.RemoteAddr) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -1274,17 +1308,18 @@ func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr { // Issue 12943 func TestServerAllowsBlockingRemoteAddr(t *testing.T) { - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { - fmt.Fprintf(w, "RA:%s", r.RemoteAddr) - })) + run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode}) +} +func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) { conns := make(chan net.Conn) - ts.Listener = &blockingRemoteAddrListener{ - Listener: ts.Listener, - conns: conns, - } - ts.Start() - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "RA:%s", r.RemoteAddr) + }), func(ts *httptest.Server) { + ts.Listener = &blockingRemoteAddrListener{ + Listener: ts.Listener, + conns: conns, + } + }).ts c := ts.Client() c.Timeout = time.Second @@ -1349,13 +1384,9 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) { // TestHeadResponses verifies that all MIME type sniffing and Content-Length // counting of GET requests also happens on HEAD requests. -func TestHeadResponses_h1(t *testing.T) { testHeadResponses(t, h1Mode) } -func TestHeadResponses_h2(t *testing.T) { testHeadResponses(t, h2Mode) } - -func testHeadResponses(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) } +func testHeadResponses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("")) if err != nil { t.Errorf("ResponseWriter.Write: %v", err) @@ -1367,7 +1398,6 @@ func testHeadResponses(t *testing.T, h2 bool) { t.Errorf("Copy(ResponseWriter, ...): %v", err) } })) - defer cst.close() res, err := cst.c.Head(cst.ts.URL) if err != nil { t.Error(err) @@ -1391,14 +1421,16 @@ func testHeadResponses(t *testing.T, h2 bool) { } func TestTLSHandshakeTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode}) +} +func testTLSHandshakeTimeout(t *testing.T, mode testMode) { errc := make(chanWriter, 10) // but only expecting 1 - ts.Config.ReadTimeout = 250 * time.Millisecond - ts.Config.ErrorLog = log.New(errc, "", 0) - ts.StartTLS() - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), + func(ts *httptest.Server) { + ts.Config.ReadTimeout = 250 * time.Millisecond + ts.Config.ErrorLog = log.New(errc, "", 0) + }, + ).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) @@ -1421,19 +1453,18 @@ func TestTLSHandshakeTimeout(t *testing.T) { } } -func TestTLSServer(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) } +func testTLSServer(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS != nil { w.Header().Set("X-TLS-Set", "true") if r.TLS.HandshakeComplete { w.Header().Set("X-TLS-HandshakeComplete", "true") } } - })) - ts.Config.ErrorLog = log.New(io.Discard, "", 0) - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(io.Discard, "", 0) + }).ts // Connect an idle TCP connection to this server before we run // our real tests. This idle connection used to block forever @@ -1526,14 +1557,15 @@ func TestServeTLS(t *testing.T) { // Test that the HTTPS server nicely rejects plaintext HTTP/1.x requests. func TestTLSServerRejectHTTPRequests(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode}) +} +func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Error("unexpected HTTPS request") - })) - var errBuf bytes.Buffer - ts.Config.ErrorLog = log.New(&errBuf, "", 0) - defer ts.Close() + }), func(ts *httptest.Server) { + var errBuf bytes.Buffer + ts.Config.ErrorLog = log.New(&errBuf, "", 0) + }).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) @@ -1725,11 +1757,9 @@ var serverExpectTests = []serverExpectTest{ // Tests that the server responds to the "Expect" request header // correctly. -// http2 test: TestServer_Response_Automatic100Continue -func TestServerExpect(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) } +func testServerExpect(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // Note using r.FormValue("readbody") because for POST // requests that would read from r.Body, which we only // conditionally want to do. @@ -1739,8 +1769,7 @@ func TestServerExpect(t *testing.T) { } else { w.WriteHeader(StatusUnauthorized) } - })) - defer ts.Close() + })).ts runTest := func(test serverExpectTest) { conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -2239,11 +2268,8 @@ func (c cancelableTimeoutContext) Err() error { return nil } -func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) } -func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) } -func testTimeoutHandler(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) } +func testTimeoutHandler(t *testing.T, mode testMode) { sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -2253,8 +2279,7 @@ func testTimeoutHandler(t *testing.T, h2 bool) { }) ctx, cancel := context.WithCancel(context.Background()) h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) - cst := newClientServerTest(t, h2, h) - defer cst.close() + cst := newClientServerTest(t, mode, h) // Succeed without timing out: sendHi <- true @@ -2300,10 +2325,8 @@ func testTimeoutHandler(t *testing.T, h2 bool) { } // See issues 8209 and 8414. -func TestTimeoutHandlerRace(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) } +func testTimeoutHandlerRace(t *testing.T, mode testMode) { delayHi := HandlerFunc(func(w ResponseWriter, r *Request) { ms, _ := strconv.Atoi(r.URL.Path[1:]) if ms == 0 { @@ -2315,8 +2338,7 @@ func TestTimeoutHandlerRace(t *testing.T) { } }) - ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, "")) - defer ts.Close() + ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts c := ts.Client() @@ -2345,16 +2367,13 @@ func TestTimeoutHandlerRace(t *testing.T) { // See issues 8209 and 8414. // Both issues involved panics in the implementation of TimeoutHandler. -func TestTimeoutHandlerRaceHeader(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) } +func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) { delay204 := HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(204) }) - ts := httptest.NewServer(TimeoutHandler(delay204, time.Nanosecond, "")) - defer ts.Close() + ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts var wg sync.WaitGroup gate := make(chan bool, 50) @@ -2385,9 +2404,8 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { } // Issue 9162 -func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) } +func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) { sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -2398,8 +2416,7 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { }) ctx, cancel := context.WithCancel(context.Background()) h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) - cst := newClientServerTest(t, h1Mode, h) - defer cst.close() + cst := newClientServerTest(t, mode, h) // Succeed without timing out: sendHi <- true @@ -2443,15 +2460,17 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { // Issue 14568. func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { + run(t, testTimeoutHandlerStartTimerWhenServing) +} +func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping sleeping test in -short mode") } - defer afterTest(t) var handler HandlerFunc = func(w ResponseWriter, _ *Request) { w.WriteHeader(StatusNoContent) } timeout := 300 * time.Millisecond - ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) + ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts defer ts.Close() c := ts.Client() @@ -2470,9 +2489,8 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { } } -func TestTimeoutHandlerContextCanceled(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) } +func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) { writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Type", "text/plain") @@ -2492,7 +2510,7 @@ func TestTimeoutHandlerContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() h := NewTestTimeoutHandler(sayHi, ctx) - cst := newClientServerTest(t, h1Mode, h) + cst := newClientServerTest(t, mode, h) defer cst.close() res, err := cst.c.Get(cst.ts.URL) @@ -2512,15 +2530,13 @@ func TestTimeoutHandlerContextCanceled(t *testing.T) { } // https://golang.org/issue/15948 -func TestTimeoutHandlerEmptyResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) } +func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) { var handler HandlerFunc = func(w ResponseWriter, _ *Request) { // No response. } timeout := 300 * time.Millisecond - ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) - defer ts.Close() + ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts c := ts.Client() @@ -2539,7 +2555,9 @@ func TestTimeoutHandlerPanicRecovery(t *testing.T) { wrapper := func(h Handler) Handler { return TimeoutHandler(h, time.Second, "") } - testHandlerPanic(t, false, false, wrapper, "intentional death for testing") + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, false, mode, wrapper, "intentional death for testing") + }, testNotParallel) } func TestRedirectBadPath(t *testing.T) { @@ -2657,17 +2675,10 @@ func TestRedirectContentTypeAndBody(t *testing.T) { // connection immediately. But when it re-uses the connection, it typically closes // the previous request's body, which is not optimal for zero-lengthed bodies, // as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF. -func TestZeroLengthPostAndResponse_h1(t *testing.T) { - testZeroLengthPostAndResponse(t, h1Mode) -} -func TestZeroLengthPostAndResponse_h2(t *testing.T) { - testZeroLengthPostAndResponse(t, h2Mode) -} +func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) } -func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) { +func testZeroLengthPostAndResponse(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { all, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("handler ReadAll: %v", err) @@ -2677,7 +2688,6 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { } rw.Header().Set("Content-Length", "0") })) - defer cst.close() req, err := NewRequest("POST", cst.ts.URL, strings.NewReader("")) if err != nil { @@ -2704,42 +2714,35 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { } } -func TestHandlerPanicNil_h1(t *testing.T) { testHandlerPanic(t, false, h1Mode, nil, nil) } -func TestHandlerPanicNil_h2(t *testing.T) { testHandlerPanic(t, false, h2Mode, nil, nil) } - -func TestHandlerPanic_h1(t *testing.T) { - testHandlerPanic(t, false, h1Mode, nil, "intentional death for testing") +func TestHandlerPanicNil(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, false, mode, nil, nil) + }, testNotParallel) } -func TestHandlerPanic_h2(t *testing.T) { - testHandlerPanic(t, false, h2Mode, nil, "intentional death for testing") + +func TestHandlerPanic(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, false, mode, nil, "intentional death for testing") + }, testNotParallel) } func TestHandlerPanicWithHijack(t *testing.T) { // Only testing HTTP/1, and our http2 server doesn't support hijacking. - testHandlerPanic(t, true, h1Mode, nil, "intentional death for testing") + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, true, mode, nil, "intentional death for testing") + }, []testMode{http1Mode}) } -func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue any) { - defer afterTest(t) - // Unlike the other tests that set the log output to io.Discard - // to quiet the output, this test uses a pipe. The pipe serves three - // purposes: +func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) { + // Direct log output to a pipe. // - // 1) The log.Print from the http server (generated by the caught - // panic) will go to the pipe instead of stderr, making the - // output quiet. + // We read from the pipe to verify that the handler actually caught the panic + // and logged something. // - // 2) We read from the pipe to verify that the handler - // actually caught the panic and logged something. - // - // 3) The blocking Read call prevents this TestHandlerPanic - // function from exiting before the HTTP server handler - // finishes crashing. If this text function exited too - // early (and its defer log.SetOutput(os.Stderr) ran), - // then the crash output could spill into the next test. + // We use a pipe rather than a buffer, because when testing connection hijacking + // server shutdown doesn't wait for the hijacking handler to return, so the + // log may occur after the server has shut down. pr, pw := io.Pipe() - log.SetOutput(pw) - defer log.SetOutput(os.Stderr) defer pw.Close() var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) { @@ -2755,12 +2758,11 @@ func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) H if wrapper != nil { handler = wrapper(handler) } - cst := newClientServerTest(t, h2, handler) - defer cst.close() + cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(pw, "", 0) + }) - // Do a blocking read on the log output pipe so its logging - // doesn't bleed into the next test. But wait only 5 seconds - // for it. + // Do a blocking read on the log output pipe. done := make(chan bool, 1) go func() { buf := make([]byte, 4<<10) @@ -2781,10 +2783,16 @@ func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) H return } + var delay time.Duration + if deadline, ok := t.Deadline(); ok { + delay = time.Until(deadline) + } else { + delay = 5 * time.Second + } select { case <-done: return - case <-time.After(5 * time.Second): + case <-time.After(delay): t.Fatal("expected server handler to log an error") } } @@ -2799,9 +2807,11 @@ func (w terrorWriter) Write(p []byte) (int, error) { // Issue 16456: allow writing 0 bytes on hijacked conn to test hijack // without any log spam. func TestServerWriteHijackZeroBytes(t *testing.T) { - defer afterTest(t) + run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode}) +} +func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) { done := make(chan struct{}) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(done) w.(Flusher).Flush() conn, _, err := w.(Hijacker).Hijack() @@ -2814,10 +2824,9 @@ func TestServerWriteHijackZeroBytes(t *testing.T) { if err != ErrHijacked { t.Errorf("Write error = %v; want ErrHijacked", err) } - })) - ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0) - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0) + }).ts c := ts.Client() res, err := c.Get(ts.URL) @@ -2832,19 +2841,23 @@ func TestServerWriteHijackZeroBytes(t *testing.T) { } } -func TestServerNoDate_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Date") } -func TestServerNoDate_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Date") } -func TestServerNoContentType_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Content-Type") } -func TestServerNoContentType_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Content-Type") } +func TestServerNoDate(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testServerNoHeader(t, mode, "Date") + }) +} -func testServerNoHeader(t *testing.T, h2 bool, header string) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerContentType(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testServerNoHeader(t, mode, "Content-Type") + }) +} + +func testServerNoHeader(t *testing.T, mode testMode, header string) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()[header] = nil io.WriteString(w, "foo") // non-empty })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -2855,15 +2868,13 @@ func testServerNoHeader(t *testing.T, h2 bool, header string) { } } -func TestStripPrefix(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) } +func testStripPrefix(t *testing.T, mode testMode) { h := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Path", r.URL.Path) w.Header().Set("X-RawPath", r.URL.RawPath) }) - ts := httptest.NewServer(StripPrefix("/foo/bar", h)) - defer ts.Close() + ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts c := ts.Client() @@ -2913,15 +2924,11 @@ func TestStripPrefixNotModifyRequest(t *testing.T) { } } -func TestRequestLimit_h1(t *testing.T) { testRequestLimit(t, h1Mode) } -func TestRequestLimit_h2(t *testing.T) { testRequestLimit(t, h2Mode) } -func testRequestLimit(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) } +func testRequestLimit(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Fatalf("didn't expect to get request in Handler") }), optQuietLog) - defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, nil) var bytesPerHeader = len("header12345: val12345\r\n") for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ { @@ -2931,7 +2938,7 @@ func testRequestLimit(t *testing.T, h2 bool) { if res != nil { defer res.Body.Close() } - if h2 { + if mode == http2Mode { // In HTTP/2, the result depends on a race. If the client has received the // server's SETTINGS before RoundTrip starts sending the request, then RoundTrip // will fail with an error. Otherwise, the client should receive a 431 from the @@ -2973,13 +2980,10 @@ func (cr countReader) Read(p []byte) (n int, err error) { return } -func TestRequestBodyLimit_h1(t *testing.T) { testRequestBodyLimit(t, h1Mode) } -func TestRequestBodyLimit_h2(t *testing.T) { testRequestBodyLimit(t, h2Mode) } -func testRequestBodyLimit(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) } +func testRequestBodyLimit(t *testing.T, mode testMode) { const limit = 1 << 20 - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = MaxBytesReader(w, r.Body, limit) n, err := io.Copy(io.Discard, r.Body) if err == nil { @@ -2996,7 +3000,6 @@ func testRequestBodyLimit(t *testing.T, h2 bool) { t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit) } })) - defer cst.close() nWritten := new(int64) req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200)) @@ -3020,13 +3023,12 @@ func testRequestBodyLimit(t *testing.T, h2 bool) { // TestClientWriteShutdown tests that if the client shuts down the write // side of their TCP connection, the server doesn't send a 400 Bad Request. -func TestClientWriteShutdown(t *testing.T) { +func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) } +func testClientWriteShutdown(t *testing.T, mode testMode) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/17906") } - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) @@ -3071,12 +3073,12 @@ func TestServerBufferedChunking(t *testing.T) { // closing the TCP connection, causing the client to get a RST. // See https://golang.org/issue/3595 func TestServerGracefulClose(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testServerGracefulClose, []testMode{http1Mode}) +} +func testServerGracefulClose(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, "bye", StatusUnauthorized) - })) - defer ts.Close() + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3114,11 +3116,9 @@ func TestServerGracefulClose(t *testing.T) { <-writeErr } -func TestCaseSensitiveMethod_h1(t *testing.T) { testCaseSensitiveMethod(t, h1Mode) } -func TestCaseSensitiveMethod_h2(t *testing.T) { testCaseSensitiveMethod(t, h2Mode) } -func testCaseSensitiveMethod(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) } +func testCaseSensitiveMethod(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "get" { t.Errorf(`Got method %q; want "get"`, r.Method) } @@ -3139,8 +3139,10 @@ func testCaseSensitiveMethod(t *testing.T, h2 bool) { // response, the net/http package adds a "Content-Length: 0" response // header. func TestContentLengthZero(t *testing.T) { - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {})) - defer ts.Close() + run(t, testContentLengthZero, []testMode{http1Mode}) +} +func testContentLengthZero(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} { conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -3167,15 +3169,17 @@ func TestContentLengthZero(t *testing.T) { } func TestCloseNotifier(t *testing.T) { - defer afterTest(t) + run(t, testCloseNotifier, []testMode{http1Mode}) +} +func testCloseNotifier(t *testing.T, mode testMode) { gotReq := make(chan bool, 1) sawClose := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { gotReq <- true cc := rw.(CloseNotifier).CloseNotify() <-cc sawClose <- true - })) + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("error dialing: %v", err) @@ -3209,11 +3213,12 @@ For: // // Issue 13165 (where it used to deadlock), but behavior changed in Issue 23921. func TestCloseNotifierPipelined(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testCloseNotifierPipelined, []testMode{http1Mode}) +} +func testCloseNotifierPipelined(t *testing.T, mode testMode) { gotReq := make(chan bool, 2) sawClose := make(chan bool, 2) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { gotReq <- true cc := rw.(CloseNotifier).CloseNotify() select { @@ -3222,8 +3227,7 @@ func TestCloseNotifierPipelined(t *testing.T) { case <-time.After(100 * time.Millisecond): } sawClose <- true - })) - defer ts.Close() + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("error dialing: %v", err) @@ -3293,12 +3297,14 @@ func TestCloseNotifierChanLeak(t *testing.T) { // Issue 9763. // HTTP/1-only test. (http2 doesn't have Hijack) func TestHijackAfterCloseNotifier(t *testing.T) { - defer afterTest(t) + run(t, testHijackAfterCloseNotifier, []testMode{http1Mode}) +} +func testHijackAfterCloseNotifier(t *testing.T, mode testMode) { script := make(chan string, 2) script <- "closenotify" script <- "hijack" close(script) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { plan := <-script switch plan { default: @@ -3321,13 +3327,12 @@ func TestHijackAfterCloseNotifier(t *testing.T) { c.Close() return } - })) - defer ts.Close() - res1, err := Get(ts.URL) + })).ts + res1, err := ts.Client().Get(ts.URL) if err != nil { log.Fatal(err) } - res2, err := Get(ts.URL) + res2, err := ts.Client().Get(ts.URL) if err != nil { log.Fatal(err) } @@ -3339,12 +3344,13 @@ func TestHijackAfterCloseNotifier(t *testing.T) { } func TestHijackBeforeRequestBodyRead(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode}) +} +func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) { var requestBody = bytes.Repeat([]byte("a"), 1<<20) bodyOkay := make(chan bool, 1) gotCloseNotify := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(bodyOkay) // caller will read false if nothing else reqBody := r.Body @@ -3371,8 +3377,7 @@ func TestHijackBeforeRequestBodyRead(t *testing.T) { case <-time.After(5 * time.Second): gotCloseNotify <- false } - })) - defer ts.Close() + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3392,14 +3397,14 @@ func TestHijackBeforeRequestBodyRead(t *testing.T) { } } -func TestOptions(t *testing.T) { +func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) } +func testOptions(t *testing.T, mode testMode) { uric := make(chan string, 2) // only expect 1, but leave space for 2 mux := NewServeMux() mux.HandleFunc("/", func(w ResponseWriter, r *Request) { uric <- r.RequestURI }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3444,6 +3449,37 @@ func TestOptions(t *testing.T) { } } +func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) } +func testOptionsHandler(t *testing.T, mode testMode) { + rc := make(chan *Request, 1) + + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + rc <- r + }), func(ts *httptest.Server) { + ts.Config.DisableGeneralOptionsHandler = true + }).ts + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n")) + if err != nil { + t.Fatal(err) + } + + select { + case got := <-rc: + if got.Method != "OPTIONS" || got.RequestURI != "*" { + t.Errorf("Expected OPTIONS * request, got %v", got) + } + case <-time.After(5 * time.Second): + t.Error("timeout") + } +} + // Tests regarding the ordering of Write, WriteHeader, Header, and // Flush calls. In Go 1.0, rw.WriteHeader immediately flushed the // (*response).header to the wire. In Go 1.1, the actual wire flush is @@ -3663,7 +3699,7 @@ func TestAcceptMaxFds(t *testing.T) { func TestWriteAfterHijack(t *testing.T) { req := reqBytes("GET / HTTP/1.1\nHost: golang.org") - var buf bytes.Buffer + var buf strings.Builder wrotec := make(chan bool, 1) conn := &rwTestConn{ Reader: bytes.NewReader(req), @@ -3725,12 +3761,12 @@ func TestDoubleHijack(t *testing.T) { // optimization and is pointless if dealing with a // badly behaved client. func TestHTTP10ConnectionHeader(t *testing.T) { - defer afterTest(t) - + run(t, testHTTP10ConnectionHeader, []testMode{http1Mode}) +} +func testHTTP10ConnectionHeader(t *testing.T, mode testMode) { mux := NewServeMux() mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {})) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts // net/http uses HTTP/1.1 for requests, so write requests manually tests := []struct { @@ -3777,14 +3813,11 @@ func TestHTTP10ConnectionHeader(t *testing.T) { } // See golang.org/issue/5660 -func TestServerReaderFromOrder_h1(t *testing.T) { testServerReaderFromOrder(t, h1Mode) } -func TestServerReaderFromOrder_h2(t *testing.T) { testServerReaderFromOrder(t, h2Mode) } -func testServerReaderFromOrder(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) } +func testServerReaderFromOrder(t *testing.T, mode testMode) { pr, pw := io.Pipe() const size = 3 << 20 - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Type", "text/plain") // prevent sniffing path done := make(chan bool) go func() { @@ -3804,7 +3837,6 @@ func testServerReaderFromOrder(t *testing.T, h2 bool) { pw.Close() <-done })) - defer cst.close() req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size)) if err != nil { @@ -3878,16 +3910,10 @@ func TestContentTypeOkayOn204(t *testing.T) { // proxy). So then two people own that Request.Body (both the server // and the http client), and both think they can close it on failure. // Therefore, all incoming server requests Bodies need to be thread-safe. -func TestTransportAndServerSharedBodyRace_h1(t *testing.T) { - testTransportAndServerSharedBodyRace(t, h1Mode) +func TestTransportAndServerSharedBodyRace(t *testing.T) { + run(t, testTransportAndServerSharedBodyRace) } -func TestTransportAndServerSharedBodyRace_h2(t *testing.T) { - testTransportAndServerSharedBodyRace(t, h2Mode) -} -func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) { const bodySize = 1 << 20 // errorf is like t.Errorf, but also writes to println. When @@ -3901,7 +3927,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { } unblockBackend := make(chan bool) - backend := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { gone := rw.(CloseNotifier).CloseNotify() didCopy := make(chan any) go func() { @@ -3928,7 +3954,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { backendRespc := make(chan *Response, 1) var proxy *clientServerTest - proxy = newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { req2, _ := NewRequest("POST", backend.ts.URL, req.Body) req2.ContentLength = bodySize cancel := make(chan struct{}) @@ -3948,7 +3974,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { // Try to cause a race: Both the Transport and the proxy handler's Server // will try to read/close req.Body (aka req2.Body) - if h2 { + if mode == http2Mode { close(cancel) } else { proxy.c.Transport.(*Transport).CancelRequest(req2) @@ -3992,22 +4018,23 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { // cause the Handler goroutine's Request.Body.Close to block. // See issue 7121. func TestRequestBodyCloseDoesntBlock(t *testing.T) { + run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode}) +} +func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in -short mode") } - defer afterTest(t) readErrCh := make(chan error, 1) errCh := make(chan error, 2) - server := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { go func(body io.Reader) { _, err := body.Read(make([]byte, 100)) readErrCh <- err }(req.Body) time.Sleep(500 * time.Millisecond) - })) - defer server.Close() + })).ts closeConn := make(chan bool) defer close(closeConn) @@ -4070,9 +4097,8 @@ func TestAppendTime(t *testing.T) { } } -func TestServerConnState(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) } +func testServerConnState(t *testing.T, mode testMode) { handler := map[string]func(w ResponseWriter, r *Request){ "/": func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello.") @@ -4138,37 +4164,36 @@ func TestServerConnState(t *testing.T) { // next call to wantLog. } - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { handler[r.URL.Path](w, r) - })) + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(io.Discard, "", 0) + ts.Config.ConnState = func(c net.Conn, state ConnState) { + if c == nil { + t.Errorf("nil conn seen in state %s", state) + return + } + sl := <-activeLog + if sl.active == nil && state == StateNew { + sl.active = c + } else if sl.active != c { + t.Errorf("unexpected conn in state %s", state) + activeLog <- sl + return + } + sl.got = append(sl.got, state) + if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) { + close(sl.complete) + sl.complete = nil + } + activeLog <- sl + } + }).ts defer func() { activeLog <- &stateLog{} // If the test failed, allow any remaining ConnState callbacks to complete. ts.Close() }() - ts.Config.ErrorLog = log.New(io.Discard, "", 0) - ts.Config.ConnState = func(c net.Conn, state ConnState) { - if c == nil { - t.Errorf("nil conn seen in state %s", state) - return - } - sl := <-activeLog - if sl.active == nil && state == StateNew { - sl.active = c - } else if sl.active != c { - t.Errorf("unexpected conn in state %s", state) - activeLog <- sl - return - } - sl.got = append(sl.got, state) - if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) { - close(sl.complete) - sl.complete = nil - } - activeLog <- sl - } - - ts.Start() c := ts.Client() mustGet := func(url string, headers ...string) { @@ -4250,13 +4275,15 @@ func TestServerConnState(t *testing.T) { }, StateNew, StateActive, StateIdle, StateClosed) } -func TestServerKeepAlivesEnabled(t *testing.T) { - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - ts.Config.SetKeepAlivesEnabled(false) - ts.Start() - defer ts.Close() - res, err := Get(ts.URL) +func TestServerKeepAlivesEnabledResultClose(t *testing.T) { + run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode}) +} +func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + }), func(ts *httptest.Server) { + ts.Config.SetKeepAlivesEnabled(false) + }).ts + res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) } @@ -4267,16 +4294,12 @@ func TestServerKeepAlivesEnabled(t *testing.T) { } // golang.org/issue/7856 -func TestServerEmptyBodyRace_h1(t *testing.T) { testServerEmptyBodyRace(t, h1Mode) } -func TestServerEmptyBodyRace_h2(t *testing.T) { testServerEmptyBodyRace(t, h2Mode) } -func testServerEmptyBodyRace(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) } +func testServerEmptyBodyRace(t *testing.T, mode testMode) { var n int32 - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { atomic.AddInt32(&n, 1) }), optQuietLog) - defer cst.close() var wg sync.WaitGroup const reqs = 20 for i := 0; i < reqs; i++ { @@ -4357,9 +4380,9 @@ func TestCloseWrite(t *testing.T) { // fixed. // // So add an explicit test for this. -func TestServerFlushAndHijack(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) } +func testServerFlushAndHijack(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, "Hello, ") w.(Flusher).Flush() conn, buf, _ := w.(Hijacker).Hijack() @@ -4370,8 +4393,7 @@ func TestServerFlushAndHijack(t *testing.T) { if err := conn.Close(); err != nil { t.Error(err) } - })) - defer ts.Close() + })).ts res, err := Get(ts.URL) if err != nil { t.Fatal(err) @@ -4393,20 +4415,21 @@ func TestServerFlushAndHijack(t *testing.T) { // To test, verify we don't timeout or see fewer unique client // addresses (== unique connections) than requests. func TestServerKeepAliveAfterWriteError(t *testing.T) { + run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode}) +} +func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in -short mode") } - defer afterTest(t) const numReq = 3 addrc := make(chan string, numReq) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { addrc <- r.RemoteAddr time.Sleep(500 * time.Millisecond) w.(Flusher).Flush() - })) - ts.Config.WriteTimeout = 250 * time.Millisecond - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.WriteTimeout = 250 * time.Millisecond + }).ts errc := make(chan error, numReq) go func() { @@ -4450,12 +4473,13 @@ func TestServerKeepAliveAfterWriteError(t *testing.T) { // Issue 9987: shouldn't add automatic Content-Length (or // Content-Type) if a Transfer-Encoding was set by the handler. func TestNoContentLengthIfTransferEncoding(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode}) +} +func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Transfer-Encoding", "foo") io.WriteString(w, "") - })) - defer ts.Close() + })).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) @@ -4465,7 +4489,7 @@ func TestNoContentLengthIfTransferEncoding(t *testing.T) { t.Fatal(err) } bs := bufio.NewScanner(c) - var got bytes.Buffer + var got strings.Builder for bs.Scan() { if strings.TrimSpace(bs.Text()) == "" { break @@ -4554,7 +4578,7 @@ GET /should-be-ignored HTTP/1.1 Host: foo `) - var buf bytes.Buffer + var buf strings.Builder conn := &rwTestConn{ Reader: bytes.NewReader(req), Writer: &buf, @@ -4603,15 +4627,12 @@ func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) { } } -func TestHandlerSetsBodyNil_h1(t *testing.T) { testHandlerSetsBodyNil(t, h1Mode) } -func TestHandlerSetsBodyNil_h2(t *testing.T) { testHandlerSetsBodyNil(t, h2Mode) } -func testHandlerSetsBodyNil(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) } +func testHandlerSetsBodyNil(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = nil fmt.Fprintf(w, "%v", r.RemoteAddr) })) - defer cst.close() get := func() string { res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -4701,9 +4722,11 @@ func TestServerValidatesHostHeader(t *testing.T) { } func TestServerHandlersCanHandleH2PRI(t *testing.T) { + run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode}) +} +func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) { const upgradeResponse = "upgrade here" - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, br, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) @@ -4725,8 +4748,7 @@ func TestServerHandlersCanHandleH2PRI(t *testing.T) { return } io.WriteString(conn, upgradeResponse) - })) - defer ts.Close() + })).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -4793,17 +4815,12 @@ func TestServerValidatesHeaders(t *testing.T) { } } -func TestServerRequestContextCancel_ServeHTTPDone_h1(t *testing.T) { - testServerRequestContextCancel_ServeHTTPDone(t, h1Mode) -} -func TestServerRequestContextCancel_ServeHTTPDone_h2(t *testing.T) { - testServerRequestContextCancel_ServeHTTPDone(t, h2Mode) +func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) { + run(t, testServerRequestContextCancel_ServeHTTPDone) } -func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) { ctxc := make(chan context.Context, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ctx := r.Context() select { case <-ctx.Done(): @@ -4812,7 +4829,6 @@ func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { } ctxc <- ctx })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -4831,16 +4847,16 @@ func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { // is always blocked in a Read call so it notices the EOF from the client. // See issues 15927 and 15224. func TestServerRequestContextCancel_ConnClose(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode}) +} +func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) { inHandler := make(chan struct{}) handlerDone := make(chan struct{}) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { close(inHandler) <-r.Context().Done() close(handlerDone) - })) - defer ts.Close() + })).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) @@ -4852,23 +4868,17 @@ func TestServerRequestContextCancel_ConnClose(t *testing.T) { <-handlerDone } -func TestServerContext_ServerContextKey_h1(t *testing.T) { - testServerContext_ServerContextKey(t, h1Mode) -} -func TestServerContext_ServerContextKey_h2(t *testing.T) { - testServerContext_ServerContextKey(t, h2Mode) +func TestServerContext_ServerContextKey(t *testing.T) { + run(t, testServerContext_ServerContextKey) } -func testServerContext_ServerContextKey(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func testServerContext_ServerContextKey(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ctx := r.Context() got := ctx.Value(ServerContextKey) if _, ok := got.(*Server); !ok { t.Errorf("context value = %T; want *http.Server", got) } })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -4876,20 +4886,14 @@ func testServerContext_ServerContextKey(t *testing.T, h2 bool) { res.Body.Close() } -func TestServerContext_LocalAddrContextKey_h1(t *testing.T) { - testServerContext_LocalAddrContextKey(t, h1Mode) +func TestServerContext_LocalAddrContextKey(t *testing.T) { + run(t, testServerContext_LocalAddrContextKey) } -func TestServerContext_LocalAddrContextKey_h2(t *testing.T) { - testServerContext_LocalAddrContextKey(t, h2Mode) -} -func testServerContext_LocalAddrContextKey(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) { ch := make(chan any, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ch <- r.Context().Value(LocalAddrContextKey) })) - defer cst.close() if _, err := cst.c.Head(cst.ts.URL); err != nil { t.Fatal(err) } @@ -4942,16 +4946,19 @@ func TestHandlerSetTransferEncodingGzip(t *testing.T) { } func BenchmarkClientServer(b *testing.B) { + run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode}) +} +func benchmarkClientServer(b *testing.B, mode testMode) { b.ReportAllocs() b.StopTimer() - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { fmt.Fprintf(rw, "Hello world.\n") - })) - defer ts.Close() + })).ts b.StartTimer() + c := ts.Client() for i := 0; i < b.N; i++ { - res, err := Get(ts.URL) + res, err := c.Get(ts.URL) if err != nil { b.Fatal("Get:", err) } @@ -4969,33 +4976,21 @@ func BenchmarkClientServer(b *testing.B) { b.StopTimer() } -func BenchmarkClientServerParallel4(b *testing.B) { - benchmarkClientServerParallel(b, 4, false) -} - -func BenchmarkClientServerParallel64(b *testing.B) { - benchmarkClientServerParallel(b, 64, false) -} - -func BenchmarkClientServerParallelTLS4(b *testing.B) { - benchmarkClientServerParallel(b, 4, true) -} - -func BenchmarkClientServerParallelTLS64(b *testing.B) { - benchmarkClientServerParallel(b, 64, true) +func BenchmarkClientServerParallel(b *testing.B) { + for _, parallelism := range []int{4, 64} { + b.Run(fmt.Sprint(parallelism), func(b *testing.B) { + run(b, func(b *testing.B, mode testMode) { + benchmarkClientServerParallel(b, parallelism, mode) + }, []testMode{http1Mode, https1Mode, http2Mode}) + }) + } } -func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) { +func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) { b.ReportAllocs() - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { fmt.Fprintf(rw, "Hello world.\n") - })) - if useTLS { - ts.StartTLS() - } else { - ts.Start() - } - defer ts.Close() + })).ts b.ResetTimer() b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { @@ -5385,15 +5380,15 @@ Host: golang.org } } -func BenchmarkCloseNotifier(b *testing.B) { +func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) } +func benchmarkCloseNotifier(b *testing.B, mode testMode) { b.ReportAllocs() b.StopTimer() sawClose := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { <-rw.(CloseNotifier).CloseNotify() sawClose <- true - })) - defer ts.Close() + })).ts tot := time.NewTimer(5 * time.Second) defer tot.Stop() b.StartTimer() @@ -5429,20 +5424,18 @@ func TestConcurrentServerServe(t *testing.T) { } } -func TestServerIdleTimeout(t *testing.T) { +func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) } +func testServerIdleTimeout(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.Copy(io.Discard, r.Body) io.WriteString(w, r.RemoteAddr) - })) - ts.Config.ReadHeaderTimeout = 1 * time.Second - ts.Config.IdleTimeout = 2 * time.Second - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ReadHeaderTimeout = 1 * time.Second + ts.Config.IdleTimeout = 2 * time.Second + }).ts c := ts.Client() get := func() string { @@ -5497,12 +5490,12 @@ func get(t *testing.T, c *Client, url string) string { // Tests that calls to Server.SetKeepAlivesEnabled(false) closes any // currently-open connections. func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode}) +} +func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, r.RemoteAddr) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -5541,16 +5534,8 @@ func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { } } -func TestServerShutdown_h1(t *testing.T) { - testServerShutdown(t, h1Mode) -} -func TestServerShutdown_h2(t *testing.T) { - testServerShutdown(t, h2Mode) -} - -func testServerShutdown(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) } +func testServerShutdown(t *testing.T, mode testMode) { var doShutdown func() // set later var doStateCount func() var shutdownRes = make(chan error, 1) @@ -5566,10 +5551,9 @@ func testServerShutdown(t *testing.T, h2 bool) { time.Sleep(20 * time.Millisecond) io.WriteString(w, r.RemoteAddr) }) - cst := newClientServerTest(t, h2, handler, func(srv *httptest.Server) { + cst := newClientServerTest(t, mode, handler, func(srv *httptest.Server) { srv.Config.RegisterOnShutdown(func() { gotOnShutdown <- struct{}{} }) }) - defer cst.close() doShutdown = func() { shutdownRes <- cst.ts.Config.Shutdown(context.Background()) @@ -5599,24 +5583,22 @@ func testServerShutdown(t *testing.T, h2 bool) { } } -func TestServerShutdownStateNew(t *testing.T) { +func TestServerShutdownStateNew(t *testing.T) { run(t, testServerShutdownStateNew) } +func testServerShutdownStateNew(t *testing.T, mode testMode) { if testing.Short() { t.Skip("test takes 5-6 seconds; skipping in short mode") } - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { - // nothing. - })) var connAccepted sync.WaitGroup - ts.Config.ConnState = func(conn net.Conn, state ConnState) { - if state == StateNew { - connAccepted.Done() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + // nothing. + }), func(ts *httptest.Server) { + ts.Config.ConnState = func(conn net.Conn, state ConnState) { + if state == StateNew { + connAccepted.Done() + } } - } - ts.Start() - defer ts.Close() + }).ts // Start a connection but never write to it. connAccepted.Add(1) @@ -5678,16 +5660,14 @@ func TestServerCloseDeadlock(t *testing.T) { // Issue 17717: tests that Server.SetKeepAlivesEnabled is respected by // both HTTP/1 and HTTP/2. -func TestServerKeepAlivesEnabled_h1(t *testing.T) { testServerKeepAlivesEnabled(t, h1Mode) } -func TestServerKeepAlivesEnabled_h2(t *testing.T) { testServerKeepAlivesEnabled(t, h2Mode) } -func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { - if h2 { +func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) } +func testServerKeepAlivesEnabled(t *testing.T, mode testMode) { + if mode == http2Mode { restore := ExportSetH2GoawayTimeout(10 * time.Millisecond) defer restore() } // Not parallel: messes with global variable. (http2goAwayTimeout) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {})) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})) defer cst.close() srv := cst.ts.Config srv.SetKeepAlivesEnabled(false) @@ -5724,9 +5704,8 @@ func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { // Issue 18447: test that the Server's ReadTimeout is stopped while // the server's doing its 1-byte background read between requests, // waiting for the connection to maybe close. -func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) } +func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) { runTimeSensitiveTest(t, []time.Duration{ 10 * time.Millisecond, 50 * time.Millisecond, @@ -5734,17 +5713,16 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { time.Second, 2 * time.Second, }, func(t *testing.T, timeout time.Duration) error { - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { select { case <-time.After(2 * timeout): fmt.Fprint(w, "ok") case <-r.Context().Done(): fmt.Fprint(w, r.Context().Err()) } - })) - ts.Config.ReadTimeout = timeout - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ReadTimeout = timeout + }).ts c := ts.Client() @@ -5764,6 +5742,58 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { }) } +// Issue 54784: test that the Server's ReadHeaderTimeout only starts once the +// beginning of a request has been received, rather than including time the +// connection spent idle. +func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) { + run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode}) +} +func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) { + runTimeSensitiveTest(t, []time.Duration{ + 10 * time.Millisecond, + 50 * time.Millisecond, + 250 * time.Millisecond, + time.Second, + 2 * time.Second, + }, func(t *testing.T, timeout time.Duration) error { + ts := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) { + ts.Config.ReadHeaderTimeout = timeout + ts.Config.IdleTimeout = 0 // disable idle timeout + }).ts + + // rather than using an http.Client, create a single connection, so that + // we can ensure this connection is not closed. + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("dial failed: %v", err) + } + br := bufio.NewReader(conn) + defer conn.Close() + + if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil { + return fmt.Errorf("writing first request failed: %v", err) + } + + if _, err := ReadResponse(br, nil); err != nil { + return fmt.Errorf("first response (before timeout) failed: %v", err) + } + + // wait for longer than the server's ReadHeaderTimeout, and then send + // another request + time.Sleep(timeout * 3 / 2) + + if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil { + return fmt.Errorf("writing second request failed: %v", err) + } + + if _, err := ReadResponse(br, nil); err != nil { + return fmt.Errorf("second response (after timeout) failed: %v", err) + } + + return nil + }) +} + // runTimeSensitiveTest runs test with the provided durations until one passes. // If they all fail, t.Fatal is called with the last one's duration and error value. func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) { @@ -5778,20 +5808,71 @@ func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t * } } +// Issue 18535: test that the Server doesn't try to do a background +// read if it's already done one. +func TestServerDuplicateBackgroundRead(t *testing.T) { + run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode}) +} +func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) { + goroutines := 5 + requests := 2000 + if testing.Short() { + goroutines = 3 + requests = 100 + } + + hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts + + reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n") + + var wg sync.WaitGroup + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + cn, err := net.Dial("tcp", hts.Listener.Addr().String()) + if err != nil { + t.Error(err) + return + } + defer cn.Close() + + wg.Add(1) + go func() { + defer wg.Done() + io.Copy(io.Discard, cn) + }() + + for j := 0; j < requests; j++ { + if t.Failed() { + return + } + _, err := cn.Write(reqBytes) + if err != nil { + t.Error(err) + return + } + } + }() + } + wg.Wait() +} + // Test that the bufio.Reader returned by Hijack includes any buffered // byte (from the Server's backgroundRead) in its buffer. We want the // Handler code to be able to tell that a byte is available via // bufio.Reader.Buffered(), without resorting to Reading it // (potentially blocking) to get at it. func TestServerHijackGetsBackgroundByte(t *testing.T) { + run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode}) +} +func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/18657") } - setParallel(t) - defer afterTest(t) done := make(chan struct{}) inHandler := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(done) // Tell the client to send more data after the GET request. @@ -5814,8 +5895,7 @@ func TestServerHijackGetsBackgroundByte(t *testing.T) { t.Error("context unexpectedly canceled") default: } - })) - defer ts.Close() + })).ts cn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -5844,14 +5924,15 @@ func TestServerHijackGetsBackgroundByte(t *testing.T) { // immediate 1MB of data to the server to fill up the server's 4KB // buffer. func TestServerHijackGetsBackgroundByte_big(t *testing.T) { + run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode}) +} +func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/18657") } - setParallel(t) - defer afterTest(t) done := make(chan struct{}) const size = 8 << 10 - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(done) conn, buf, err := w.(Hijacker).Hijack() @@ -5875,8 +5956,7 @@ func TestServerHijackGetsBackgroundByte_big(t *testing.T) { } else if !allX { t.Errorf("read %q; want %d 'x'", slurp, size) } - })) - defer ts.Close() + })).ts cn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -6012,73 +6092,27 @@ func TestStripPortFromHost(t *testing.T) { } } -func TestServerContexts(t *testing.T) { - setParallel(t) - defer afterTest(t) - type baseKey struct{} - type connKey struct{} - ch := make(chan context.Context, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { - ch <- r.Context() - })) - ts.Config.BaseContext = func(ln net.Listener) context.Context { - if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { - t.Errorf("unexpected onceClose listener type %T", ln) - } - return context.WithValue(context.Background(), baseKey{}, "base") - } - ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - if got, want := ctx.Value(baseKey{}), "base"; got != want { - t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) - } - return context.WithValue(ctx, connKey{}, "conn") - } - ts.Start() - defer ts.Close() - res, err := ts.Client().Get(ts.URL) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - ctx := <-ch - if got, want := ctx.Value(baseKey{}), "base"; got != want { - t.Errorf("base context key = %#v; want %q", got, want) - } - if got, want := ctx.Value(connKey{}), "conn"; got != want { - t.Errorf("conn context key = %#v; want %q", got, want) - } -} - -func TestServerContextsHTTP2(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerContexts(t *testing.T) { run(t, testServerContexts) } +func testServerContexts(t *testing.T, mode testMode) { type baseKey struct{} type connKey struct{} ch := make(chan context.Context, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { - if r.ProtoMajor != 2 { - t.Errorf("unexpected HTTP/1.x request") - } + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { ch <- r.Context() - })) - ts.Config.BaseContext = func(ln net.Listener) context.Context { - if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { - t.Errorf("unexpected onceClose listener type %T", ln) + }), func(ts *httptest.Server) { + ts.Config.BaseContext = func(ln net.Listener) context.Context { + if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { + t.Errorf("unexpected onceClose listener type %T", ln) + } + return context.WithValue(context.Background(), baseKey{}, "base") } - return context.WithValue(context.Background(), baseKey{}, "base") - } - ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - if got, want := ctx.Value(baseKey{}), "base"; got != want { - t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) + ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + if got, want := ctx.Value(baseKey{}), "base"; got != want { + t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) + } + return context.WithValue(ctx, connKey{}, "conn") } - return context.WithValue(ctx, connKey{}, "conn") - } - ts.TLS = &tls.Config{ - NextProtos: []string{"h2", "http/1.1"}, - } - ts.StartTLS() - defer ts.Close() - ts.Client().Transport.(*Transport).ForceAttemptHTTP2 = true + }).ts res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) @@ -6095,20 +6129,20 @@ func TestServerContextsHTTP2(t *testing.T) { // Issue 35750: check ConnContext not modifying context for other connections func TestConnContextNotModifyingAllContexts(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testConnContextNotModifyingAllContexts) +} +func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) { type connKey struct{} - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { rw.Header().Set("Connection", "close") - })) - ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - if got := ctx.Value(connKey{}); got != nil { - t.Errorf("in ConnContext, unexpected context key = %#v", got) + }), func(ts *httptest.Server) { + ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + if got := ctx.Value(connKey{}); got != nil { + t.Errorf("in ConnContext, unexpected context key = %#v", got) + } + return context.WithValue(ctx, connKey{}, "conn") } - return context.WithValue(ctx, connKey{}, "conn") - } - ts.Start() - defer ts.Close() + }).ts var res *Response var err error @@ -6129,10 +6163,12 @@ func TestConnContextNotModifyingAllContexts(t *testing.T) { // Issue 30710: ensure that as per the spec, a server responds // with 501 Not Implemented for unsupported transfer-encodings. func TestUnsupportedTransferEncodingsReturn501(t *testing.T) { - cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode}) +} +func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello, World!")) - })) - defer cst.Close() + })).ts serverURL, err := url.Parse(cst.URL) if err != nil { @@ -6143,7 +6179,7 @@ func TestUnsupportedTransferEncodingsReturn501(t *testing.T) { "fugazi", "foo-bar", "unknown", - "\rchunked", + `" chunked"`, } for _, badTE := range unsupportedTEs { @@ -6167,19 +6203,9 @@ func TestUnsupportedTransferEncodingsReturn501(t *testing.T) { } } -func TestContentEncodingNoSniffing_h1(t *testing.T) { - testContentEncodingNoSniffing(t, h1Mode) -} - -func TestContentEncodingNoSniffing_h2(t *testing.T) { - testContentEncodingNoSniffing(t, h2Mode) -} - // Issue 31753: don't sniff when Content-Encoding is set -func testContentEncodingNoSniffing(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) } +func testContentEncodingNoSniffing(t *testing.T, mode testMode) { type setting struct { name string body []byte @@ -6242,13 +6268,12 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) { for _, tt := range settings { t.Run(tt.name, func(t *testing.T) { - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { if tt.contentEncoding != nil { rw.Header().Set("Content-Encoding", tt.contentEncoding.(string)) } rw.Write(tt.body) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -6274,13 +6299,13 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) { // Issue 30803: ensure that TimeoutHandler logs spurious // WriteHeader calls, for consistency with other Handlers. func TestTimeoutHandlerSuperfluousLogs(t *testing.T) { + run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode}) +} +func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - timeoutMsg := "timed out here!" tests := []struct { @@ -6321,7 +6346,7 @@ func TestTimeoutHandlerSuperfluousLogs(t *testing.T) { exitHandler <- true } - logBuf := new(bytes.Buffer) + logBuf := new(strings.Builder) srvLog := log.New(logBuf, "", 0) // When expecting to timeout, we'll keep the duration short. dur := 20 * time.Millisecond @@ -6330,7 +6355,7 @@ func TestTimeoutHandlerSuperfluousLogs(t *testing.T) { dur = 10 * time.Second } th := TimeoutHandler(sh, dur, timeoutMsg) - cst := newClientServerTest(t, h1Mode /* the test is protocol-agnostic */, th, optWithServerLog(srvLog)) + cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog)) defer cst.close() res, err := cst.c.Get(cst.ts.URL) @@ -6397,15 +6422,16 @@ func BenchmarkResponseStatusLine(b *testing.B) { } }) } + func TestDisableKeepAliveUpgrade(t *testing.T) { + run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode}) +} +func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - - s := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "someProto") w.WriteHeader(StatusSwitchingProtocols) @@ -6418,10 +6444,9 @@ func TestDisableKeepAliveUpgrade(t *testing.T) { // Copy from the *bufio.ReadWriter, which may contain buffered data. // Copy to the net.Conn, to avoid buffering the output. io.Copy(c, buf) - })) - s.Config.SetKeepAlivesEnabled(false) - s.Start() - defer s.Close() + }), func(ts *httptest.Server) { + ts.Config.SetKeepAlivesEnabled(false) + }).ts cl := s.Client() cl.Transport.(*Transport).DisableKeepAlives = true @@ -6490,21 +6515,21 @@ func TestQuerySemicolon(t *testing.T) { {"?a=1;x=good;x=bad", "", "good", true}, } - for _, tt := range tests { - t.Run(tt.query+"/allow=false", func(t *testing.T) { - allowSemicolons := false - testQuerySemicolon(t, tt.query, tt.xNoSemicolons, allowSemicolons, tt.warning) - }) - t.Run(tt.query+"/allow=true", func(t *testing.T) { - allowSemicolons, expectWarning := true, false - testQuerySemicolon(t, tt.query, tt.xWithSemicolons, allowSemicolons, expectWarning) - }) - } + run(t, func(t *testing.T, mode testMode) { + for _, tt := range tests { + t.Run(tt.query+"/allow=false", func(t *testing.T) { + allowSemicolons := false + testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.warning) + }) + t.Run(tt.query+"/allow=true", func(t *testing.T) { + allowSemicolons, expectWarning := true, false + testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectWarning) + }) + } + }) } -func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolons, expectWarning bool) { - setParallel(t) - +func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectWarning bool) { writeBackX := func(w ResponseWriter, r *Request) { x := r.URL.Query().Get("x") if expectWarning { @@ -6527,11 +6552,10 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon h = AllowQuerySemicolons(h) } - ts := httptest.NewUnstartedServer(h) - logBuf := &bytes.Buffer{} - ts.Config.ErrorLog = log.New(logBuf, "", 0) - ts.Start() - defer ts.Close() + logBuf := &strings.Builder{} + ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(logBuf, "", 0) + }).ts req, _ := NewRequest("GET", ts.URL+query, nil) res, err := ts.Client().Do(req) @@ -6566,13 +6590,15 @@ func TestMaxBytesHandler(t *testing.T) { for _, requestSize := range []int64{100, 1_000, 1_000_000} { t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize), func(t *testing.T) { - testMaxBytesHandler(t, maxSize, requestSize) + run(t, func(t *testing.T, mode testMode) { + testMaxBytesHandler(t, mode, maxSize, requestSize) + }) }) } } } -func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { +func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) { var ( handlerN int64 handlerErr error @@ -6583,7 +6609,7 @@ func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { io.Copy(w, &buf) }) - ts := httptest.NewServer(MaxBytesHandler(echo, maxSize)) + ts := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize)).ts defer ts.Close() c := ts.Client() @@ -6650,36 +6676,87 @@ func TestProcessing(t *testing.T) { } } +func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) } +func testParseFormCleanup(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/20253") + } + + const maxMemory = 1024 + const key = "file" + + if runtime.GOOS == "windows" { + // Windows sometimes refuses to remove a file that was just closed. + t.Skip("https://go.dev/issue/25965") + } + + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + r.ParseMultipartForm(maxMemory) + f, _, err := r.FormFile(key) + if err != nil { + t.Errorf("r.FormFile(%q) = %v", key, err) + return + } + of, ok := f.(*os.File) + if !ok { + t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f) + return + } + w.Write([]byte(of.Name())) + })) + + fBuf := new(bytes.Buffer) + mw := multipart.NewWriter(fBuf) + mf, err := mw.CreateFormFile(key, "myfile.txt") + if err != nil { + t.Fatal(err) + } + if _, err := mf.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil { + t.Fatal(err) + } + if err := mw.Close(); err != nil { + t.Fatal(err) + } + req, err := NewRequest("POST", cst.ts.URL, fBuf) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", mw.FormDataContentType()) + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + fname, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + cst.close() + if _, err := os.Stat(string(fname)); !errors.Is(err, os.ErrNotExist) { + t.Errorf("file %q exists after HTTP handler returned", string(fname)) + } +} + func TestHeadBody(t *testing.T) { const identityMode = false const chunkedMode = true - t.Run("h1", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h1Mode, identityMode, "HEAD") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h1Mode, chunkedMode, "HEAD") }) - }) - t.Run("h2", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h2Mode, identityMode, "HEAD") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h2Mode, chunkedMode, "HEAD") }) + run(t, func(t *testing.T, mode testMode) { + t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") }) + t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") }) }) } func TestGetBody(t *testing.T) { const identityMode = false const chunkedMode = true - t.Run("h1", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h1Mode, identityMode, "GET") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h1Mode, chunkedMode, "GET") }) - }) - t.Run("h2", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h2Mode, identityMode, "GET") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h2Mode, chunkedMode, "GET") }) + run(t, func(t *testing.T, mode testMode) { + t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") }) + t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") }) }) } -func testHeadBody(t *testing.T, h2, chunked bool, method string) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { b, err := io.ReadAll(r.Body) if err != nil { t.Errorf("server reading body: %v", err) diff --git a/server.go b/server.go index fa5ca8a5..a0d6b57e 100644 --- a/server.go +++ b/server.go @@ -292,9 +292,9 @@ type conn struct { // on this connection, if any. lastMethod string - curReq atomic.Value // of *response (which has a Request in it) + curReq atomic.Pointer[response] // (which has a Request in it) - curState struct{ atomic uint64 } // packed (unixtime<<8|uint8(ConnState)) + curState atomic.Uint64 // packed (unixtime<<8|uint8(ConnState)) // mu guards hijackedv mu sync.Mutex @@ -394,11 +394,11 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) { return } -func (cw *chunkWriter) flush() { +func (cw *chunkWriter) flush() error { if !cw.wroteHeader { cw.writeHeader(nil) } - cw.res.conn.bufw.Flush() + return cw.res.conn.bufw.Flush() } func (cw *chunkWriter) close() { @@ -429,14 +429,14 @@ type response struct { wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive" wantsClose bool // HTTP request has Connection "close" - // canWriteContinue is a boolean value accessed as an atomic int32 - // that says whether or not a 100 Continue header can be written - // to the connection. + // canWriteContinue is an atomic boolean that says whether or + // not a 100 Continue header can be written to the + // connection. // writeContinueMu must be held while writing the header. - // These two fields together synchronize the body reader - // (the expectContinueReader, which wants to write 100 Continue) + // These two fields together synchronize the body reader (the + // expectContinueReader, which wants to write 100 Continue) // against the main writer. - canWriteContinue atomicBool + canWriteContinue atomic.Bool writeContinueMu sync.Mutex w *bufio.Writer // buffers output in chunks to chunkWriter @@ -474,7 +474,7 @@ type response struct { // written. trailers []string - handlerDone atomicBool // set true when the handler exits + handlerDone atomic.Bool // set true when the handler exits // Buffers for Date, Content-Length, and status code dateBuf [len(TimeFormat)]byte @@ -485,7 +485,15 @@ type response struct { // TODO(bradfitz): this is currently (for Go 1.8) always // non-nil. Make this lazily-created again as it used to be? closeNotifyCh chan bool - didCloseNotify int32 // atomic (only 0->1 winner should send) + didCloseNotify atomic.Bool // atomic (only false->true winner should send) +} + +func (c *response) SetReadDeadline(deadline time.Time) error { + return c.conn.rwc.SetReadDeadline(deadline) +} + +func (c *response) SetWriteDeadline(deadline time.Time) error { + return c.conn.rwc.SetWriteDeadline(deadline) } // TrailerPrefix is a magic prefix for ResponseWriter.Header map keys @@ -508,11 +516,11 @@ const TrailerPrefix = "Trailer:" func (w *response) finalTrailers() Header { var t Header for k, vv := range w.handlerHeader { - if strings.HasPrefix(k, TrailerPrefix) { + if kk, found := strings.CutPrefix(k, TrailerPrefix); found { if t == nil { t = make(Header) } - t[strings.TrimPrefix(k, TrailerPrefix)] = vv + t[kk] = vv } } for _, k := range w.trailers { @@ -526,12 +534,6 @@ func (w *response) finalTrailers() Header { return t } -type atomicBool int32 - -func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } -func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } -func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } - // declareTrailer is called for each Trailer header when the // response header is written. It notes that a header will need to be // written in the trailers at the end of the response. @@ -554,12 +556,6 @@ func (w *response) requestTooLarge() { } } -// needsSniff reports whether a Content-Type still needs to be sniffed. -func (w *response) needsSniff() bool { - _, haveType := w.handlerHeader["Content-Type"] - return !w.cw.wroteHeader && !haveType && w.written < sniffLen -} - // writerOnly hides an io.Writer value's optional ReadFrom method // from io.Copy. type writerOnly struct { @@ -748,8 +744,8 @@ func (cr *connReader) handleReadError(_ error) { // may be called from multiple goroutines. func (cr *connReader) closeNotify() { - res, _ := cr.conn.curReq.Load().(*response) - if res != nil && atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) { + res := cr.conn.curReq.Load() + if res != nil && !res.didCloseNotify.Swap(true) { res.closeNotifyCh <- true } } @@ -897,34 +893,34 @@ func (srv *Server) tlsHandshakeTimeout() time.Duration { type expectContinueReader struct { resp *response readCloser io.ReadCloser - closed atomicBool - sawEOF atomicBool + closed atomic.Bool + sawEOF atomic.Bool } func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { - if ecr.closed.isSet() { + if ecr.closed.Load() { return 0, ErrBodyReadAfterClose } w := ecr.resp - if !w.wroteContinue && w.canWriteContinue.isSet() && !w.conn.hijacked() { + if !w.wroteContinue && w.canWriteContinue.Load() && !w.conn.hijacked() { w.wroteContinue = true w.writeContinueMu.Lock() - if w.canWriteContinue.isSet() { + if w.canWriteContinue.Load() { w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n") w.conn.bufw.Flush() - w.canWriteContinue.setFalse() + w.canWriteContinue.Store(false) } w.writeContinueMu.Unlock() } n, err = ecr.readCloser.Read(p) if err == io.EOF { - ecr.sawEOF.setTrue() + ecr.sawEOF.Store(true) } return } func (ecr *expectContinueReader) Close() error { - ecr.closed.setTrue() + ecr.closed.Store(true) return ecr.readCloser.Close() } @@ -1151,9 +1147,9 @@ func (w *response) WriteHeader(code int) { // Handle informational headers if code >= 100 && code <= 199 { // Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read() - if code == 100 && w.canWriteContinue.isSet() { + if code == 100 && w.canWriteContinue.Load() { w.writeContinueMu.Lock() - w.canWriteContinue.setFalse() + w.canWriteContinue.Store(false) w.writeContinueMu.Unlock() } @@ -1311,7 +1307,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { // send a Content-Length header. // Further, we don't send an automatic Content-Length if they // set a Transfer-Encoding, because they're generally incompatible. - if w.handlerDone.isSet() && !trailers && !hasTE && bodyAllowedForStatus(w.status) && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) { + if w.handlerDone.Load() && !trailers && !hasTE && bodyAllowedForStatus(w.status) && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) { w.contentLength = int64(len(p)) setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10) } @@ -1353,7 +1349,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { // because we don't know if the next bytes on the wire will be // the body-following-the-timer or the subsequent request. // See Issue 11549. - if ecr, ok := w.req.Body.(*expectContinueReader); ok && !ecr.sawEOF.isSet() { + if ecr, ok := w.req.Body.(*expectContinueReader); ok && !ecr.sawEOF.Load() { w.closeAfterReply = true } @@ -1611,13 +1607,13 @@ func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err er return 0, ErrHijacked } - if w.canWriteContinue.isSet() { + if w.canWriteContinue.Load() { // Body reader wants to write 100 Continue but hasn't yet. // Tell it not to. The store must be done while holding the lock // because the lock makes sure that there is not an active write // this very moment. w.writeContinueMu.Lock() - w.canWriteContinue.setFalse() + w.canWriteContinue.Store(false) w.writeContinueMu.Unlock() } @@ -1643,7 +1639,7 @@ func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err er } func (w *response) finishRequest() { - w.handlerDone.setTrue() + w.handlerDone.Store(true) if !w.wroteHeader { w.WriteHeader(StatusOK) @@ -1699,11 +1695,19 @@ func (w *response) closedRequestBodyEarly() bool { } func (w *response) Flush() { + w.FlushError() +} + +func (w *response) FlushError() error { if !w.wroteHeader { w.WriteHeader(StatusOK) } - w.w.Flush() - w.cw.flush() + err := w.w.Flush() + e2 := w.cw.flush() + if err == nil { + err = e2 + } + return err } func (c *conn) finalFlush() { @@ -1786,7 +1790,7 @@ func (c *conn) setState(nc net.Conn, state ConnState, runHook bool) { panic("internal error") } packedState := uint64(time.Now().Unix()<<8) | uint64(state) - atomic.StoreUint64(&c.curState.atomic, packedState) + c.curState.Store(packedState) if !runHook { return } @@ -1796,7 +1800,7 @@ func (c *conn) setState(nc net.Conn, state ConnState, runHook bool) { } func (c *conn) getState() (state ConnState, unixSec int64) { - packedState := atomic.LoadUint64(&c.curState.atomic) + packedState := c.curState.Load() return ConnState(packedState & 0xff), int64(packedState >> 8) } @@ -1964,7 +1968,7 @@ func (c *conn) serve(ctx context.Context) { if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 { // Wrap the Body reader with one that replies on the connection req.Body = &expectContinueReader{readCloser: req.Body, resp: w} - w.canWriteContinue.setTrue() + w.canWriteContinue.Store(true) } } else if req.Header.get("Expect") != "" { w.sendExpectationFailed() @@ -1994,6 +1998,7 @@ func (c *conn) serve(ctx context.Context) { return } w.finishRequest() + c.rwc.SetWriteDeadline(time.Time{}) if !w.shouldReuseConnection() { if w.requestBodyLimitHit || w.closedRequestBodyEarly() { c.closeWriteAndWait() @@ -2001,7 +2006,7 @@ func (c *conn) serve(ctx context.Context) { return } c.setState(c.rwc, StateIdle, runHooks) - c.curReq.Store((*response)(nil)) + c.curReq.Store(nil) if !w.conn.server.doKeepAlives() { // We're in shutdown mode. We might've replied @@ -2013,10 +2018,18 @@ func (c *conn) serve(ctx context.Context) { if d := c.server.idleTimeout(); d != 0 { c.rwc.SetReadDeadline(time.Now().Add(d)) - if _, err := c.bufr.Peek(4); err != nil { - return - } + } else { + c.rwc.SetReadDeadline(time.Time{}) } + + // Wait for the connection to become readable again before trying to + // read the next request. This prevents a ReadHeaderTimeout or + // ReadTimeout from starting until the first bytes of the next request + // have been received. + if _, err := c.bufr.Peek(4); err != nil { + return + } + c.rwc.SetReadDeadline(time.Time{}) } } @@ -2042,7 +2055,7 @@ func (w *response) sendExpectationFailed() { // Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter // and a Hijacker. func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { - if w.handlerDone.isSet() { + if w.handlerDone.Load() { panic("net/http: Hijack called after ServeHTTP finished") } if w.wroteHeader { @@ -2064,7 +2077,7 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { } func (w *response) CloseNotify() <-chan bool { - if w.handlerDone.isSet() { + if w.handlerDone.Load() { panic("net/http: CloseNotify called after ServeHTTP finished") } return w.closeNotifyCh @@ -2595,6 +2608,10 @@ type Server struct { Handler Handler // handler to invoke, http.DefaultServeMux if nil + // DisableGeneralOptionsHandler, if true, passes "OPTIONS *" requests to the Handler, + // otherwise responds with 200 OK and Content-Length: 0. + DisableGeneralOptionsHandler bool + // TLSConfig optionally provides a TLS configuration for use // by ServeTLS and ListenAndServeTLS. Note that this value is // cloned by ServeTLS and ListenAndServeTLS, so it's not @@ -2678,46 +2695,20 @@ type Server struct { // value. ConnContext func(ctx context.Context, c net.Conn) context.Context - inShutdown atomicBool // true when server is in shutdown + inShutdown atomic.Bool // true when server is in shutdown - disableKeepAlives int32 // accessed atomically. + disableKeepAlives atomic.Bool nextProtoOnce sync.Once // guards setupHTTP2_* init nextProtoErr error // result of http2.ConfigureServer if used mu sync.Mutex listeners map[*net.Listener]struct{} activeConn map[*conn]struct{} - doneChan chan struct{} onShutdown []func() listenerGroup sync.WaitGroup } -func (s *Server) getDoneChan() <-chan struct{} { - s.mu.Lock() - defer s.mu.Unlock() - return s.getDoneChanLocked() -} - -func (s *Server) getDoneChanLocked() chan struct{} { - if s.doneChan == nil { - s.doneChan = make(chan struct{}) - } - return s.doneChan -} - -func (s *Server) closeDoneChanLocked() { - ch := s.getDoneChanLocked() - select { - case <-ch: - // Already closed. Don't close again. - default: - // Safe to close here. We're the only closer, guarded - // by s.mu. - close(ch) - } -} - // Close immediately closes all active net.Listeners and any // connections in state StateNew, StateActive, or StateIdle. For a // graceful shutdown, use Shutdown. @@ -2728,10 +2719,9 @@ func (s *Server) closeDoneChanLocked() { // Close returns any error returned from closing the Server's // underlying Listener(s). func (srv *Server) Close() error { - srv.inShutdown.setTrue() + srv.inShutdown.Store(true) srv.mu.Lock() defer srv.mu.Unlock() - srv.closeDoneChanLocked() err := srv.closeListenersLocked() // Unlock srv.mu while waiting for listenerGroup. @@ -2779,11 +2769,10 @@ const shutdownPollIntervalMax = 500 * time.Millisecond // Once Shutdown has been called on a server, it may not be reused; // future calls to methods such as Serve will return ErrServerClosed. func (srv *Server) Shutdown(ctx context.Context) error { - srv.inShutdown.setTrue() + srv.inShutdown.Store(true) srv.mu.Lock() lnerr := srv.closeListenersLocked() - srv.closeDoneChanLocked() for _, f := range srv.onShutdown { go f() } @@ -2927,17 +2916,17 @@ func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) { if handler == nil { handler = DefaultServeMux } - if req.RequestURI == "*" && req.Method == "OPTIONS" { + if !sh.srv.DisableGeneralOptionsHandler && req.RequestURI == "*" && req.Method == "OPTIONS" { handler = globalOptionsHandler{} } if req.URL != nil && strings.Contains(req.URL.RawQuery, ";") { - var allowQuerySemicolonsInUse int32 + var allowQuerySemicolonsInUse atomic.Bool req = req.WithContext(context.WithValue(req.Context(), silenceSemWarnContextKey, func() { - atomic.StoreInt32(&allowQuerySemicolonsInUse, 1) + allowQuerySemicolonsInUse.Store(true) })) defer func() { - if atomic.LoadInt32(&allowQuerySemicolonsInUse) == 0 { + if !allowQuerySemicolonsInUse.Load() { sh.srv.logf("http: URL query contains semicolon, which is no longer a supported separator; parts of the query may be stripped when parsed; see golang.org/issue/25192") } }() @@ -3068,10 +3057,8 @@ func (srv *Server) Serve(l net.Listener) error { for { rw, err := l.Accept() if err != nil { - select { - case <-srv.getDoneChan(): + if srv.shuttingDown() { return ErrServerClosed - default: } if ne, ok := err.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { @@ -3198,11 +3185,11 @@ func (s *Server) readHeaderTimeout() time.Duration { } func (s *Server) doKeepAlives() bool { - return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown() + return !s.disableKeepAlives.Load() && !s.shuttingDown() } func (s *Server) shuttingDown() bool { - return s.inShutdown.isSet() + return s.inShutdown.Load() } // SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled. @@ -3211,10 +3198,10 @@ func (s *Server) shuttingDown() bool { // shutting down should disable them. func (srv *Server) SetKeepAlivesEnabled(v bool) { if v { - atomic.StoreInt32(&srv.disableKeepAlives, 0) + srv.disableKeepAlives.Store(false) return } - atomic.StoreInt32(&srv.disableKeepAlives, 1) + srv.disableKeepAlives.Store(true) // Close idle HTTP/1 conns: srv.closeIdleConns() diff --git a/sniff_test.go b/sniff_test.go index 728c823c..fd6d6946 100644 --- a/sniff_test.go +++ b/sniff_test.go @@ -89,13 +89,9 @@ func TestDetectContentType(t *testing.T) { } } -func TestServerContentType_h1(t *testing.T) { testServerContentType(t, h1Mode) } -func TestServerContentType_h2(t *testing.T) { testServerContentType(t, h2Mode) } - -func testServerContentType(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerContentTypeSniff(t *testing.T) { run(t, testServerContentTypeSniff) } +func testServerContentTypeSniff(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { i, _ := strconv.Atoi(r.FormValue("i")) tt := sniffTests[i] n, err := w.Write(tt.data) @@ -135,15 +131,12 @@ func testServerContentType(t *testing.T, h2 bool) { // Issue 5953: shouldn't sniff if the handler set a Content-Type header, // even if it's the empty string. -func TestServerIssue5953_h1(t *testing.T) { testServerIssue5953(t, h1Mode) } -func TestServerIssue5953_h2(t *testing.T) { testServerIssue5953(t, h2Mode) } -func testServerIssue5953(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerIssue5953(t *testing.T) { run(t, testServerIssue5953) } +func testServerIssue5953(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()["Content-Type"] = []string{""} fmt.Fprintf(w, "hi") })) - defer cst.close() resp, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -174,11 +167,8 @@ func (b *byteAtATimeReader) Read(p []byte) (n int, err error) { return 1, nil } -func TestContentTypeWithVariousSources_h1(t *testing.T) { testContentTypeWithVariousSources(t, h1Mode) } -func TestContentTypeWithVariousSources_h2(t *testing.T) { testContentTypeWithVariousSources(t, h2Mode) } -func testContentTypeWithVariousSources(t *testing.T, h2 bool) { - defer afterTest(t) - +func TestContentTypeWithVariousSources(t *testing.T) { run(t, testContentTypeWithVariousSources) } +func testContentTypeWithVariousSources(t *testing.T, mode testMode) { const ( input = "\n\n\t\n" expected = "text/html; charset=utf-8" @@ -240,8 +230,7 @@ func testContentTypeWithVariousSources(t *testing.T, h2 bool) { }, }} { t.Run(test.name, func(t *testing.T) { - cst := newClientServerTest(t, h2, HandlerFunc(test.handler)) - defer cst.close() + cst := newClientServerTest(t, mode, HandlerFunc(test.handler)) resp, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -266,12 +255,9 @@ func testContentTypeWithVariousSources(t *testing.T, h2 bool) { } } -func TestSniffWriteSize_h1(t *testing.T) { testSniffWriteSize(t, h1Mode) } -func TestSniffWriteSize_h2(t *testing.T) { testSniffWriteSize(t, h2Mode) } -func testSniffWriteSize(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestSniffWriteSize(t *testing.T) { run(t, testSniffWriteSize) } +func testSniffWriteSize(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { size, _ := strconv.Atoi(r.FormValue("size")) written, err := io.WriteString(w, strings.Repeat("a", size)) if err != nil { @@ -282,7 +268,6 @@ func testSniffWriteSize(t *testing.T, h2 bool) { t.Errorf("write of %d bytes wrote %d bytes", size, written) } })) - defer cst.close() for _, size := range []int{0, 1, 200, 600, 999, 1000, 1023, 1024, 512 << 10, 1 << 20} { res, err := cst.c.Get(fmt.Sprintf("%s/?size=%d", cst.ts.URL, size)) if err != nil { diff --git a/tools/compare.bash b/tools/compare.bash index 7d861247..17d6796c 100755 --- a/tools/compare.bash +++ b/tools/compare.bash @@ -9,6 +9,4 @@ test -d $upstreamrepo || git clone git@github.com:golang/go.git $upstreamrepo git pull git checkout $TAG ) -for file in $(cd $upstreamrepo/src/net/http && find . -type f -name \*.go); do - git diff --no-index $upstreamrepo/src/net/http/$file $file || true -done +diff -ur $upstreamrepo/src/net/http . diff --git a/transfer.go b/transfer.go index 25cceda4..cc7d98d0 100644 --- a/transfer.go +++ b/transfer.go @@ -600,7 +600,7 @@ func readTransfer(msg any, r *bufio.Reader) (err error) { return nil } -// Checks whether chunked is part of the encodings stack +// Checks whether chunked is part of the encodings stack. func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } // Checks whether the encoding is explicitly "identity". @@ -738,7 +738,7 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, // Determine whether to hang up after sending a request and body, or // receiving a response and body -// 'header' is the request headers +// 'header' is the request headers. func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool { if major < 1 { return true @@ -757,7 +757,7 @@ func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool { return hasClose } -// Parse the trailer header +// Parse the trailer header. func fixTrailer(header Header, chunked bool) (Header, error) { vv, ok := header["Trailer"] if !ok { @@ -1081,7 +1081,7 @@ var nopCloserWriterToType = reflect.TypeOf(io.NopCloser(struct { }{})) // unwrapNopCloser return the underlying reader and true if r is a NopCloser -// else it return false +// else it return false. func unwrapNopCloser(r io.Reader) (underlyingReader io.Reader, isNopCloser bool) { switch reflect.TypeOf(r) { case nopCloserType, nopCloserWriterToType: diff --git a/transport.go b/transport.go index 56ee28de..b422a919 100644 --- a/transport.go +++ b/transport.go @@ -37,8 +37,8 @@ import ( // DefaultTransport is the default implementation of Transport and is // used by DefaultClient. It establishes network connections as needed // and caches them for reuse by subsequent calls. It uses HTTP proxies -// as directed by the $HTTP_PROXY and $NO_PROXY (or $http_proxy and -// $no_proxy) environment variables. +// as directed by the environment variables HTTP_PROXY, HTTPS_PROXY +// and NO_PROXY (or the lowercase versions thereof). var DefaultTransport RoundTripper = &Transport{ Proxy: ProxyFromEnvironment, DialContext: defaultTransportDialContext(&net.Dialer{ @@ -119,6 +119,11 @@ type Transport struct { // If Proxy is nil or returns a nil *URL, no proxy is used. Proxy func(*Request) (*url.URL, error) + // OnProxyConnectResponse is called when the Transport gets an HTTP response from + // a proxy for a CONNECT request. It's called before the check for a 200 OK response. + // If it returns an error, the request fails with that error. + OnProxyConnectResponse func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error + // DialContext specifies the dial function for creating unencrypted TCP connections. // If DialContext is nil (and the deprecated Dial below is also nil), // then the transport dials using package net. @@ -315,6 +320,7 @@ func (t *Transport) Clone() *Transport { t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) t2 := &Transport{ Proxy: t.Proxy, + OnProxyConnectResponse: t.OnProxyConnectResponse, DialContext: t.DialContext, Dial: t.Dial, DialTLS: t.DialTLS, @@ -426,8 +432,8 @@ func (t *Transport) onceSetNextProtoDefaults() { // ProxyFromEnvironment returns the URL of the proxy to use for a // given request, as indicated by the environment variables // HTTP_PROXY, HTTPS_PROXY and NO_PROXY (or the lowercase versions -// thereof). HTTPS_PROXY takes precedence over HTTP_PROXY for https -// requests. +// thereof). Requests use the proxy from the environment variable +// matching their scheme, unless excluded by NO_PROXY. // // The environment values may be either a complete URL or a // "host[:port]", in which case the "http" scheme is assumed. @@ -814,14 +820,12 @@ func (t *Transport) cancelRequest(key cancelKey, err error) bool { // var ( - // proxyConfigOnce guards proxyConfig envProxyOnce sync.Once envProxyFuncValue func(*url.URL) (*url.URL, error) ) -// defaultProxyConfig returns a ProxyConfig value looked up -// from the environment. This mitigates expensive lookups -// on some platforms (e.g. Windows). +// envProxyFunc returns a function that reads the +// environment variable to determine the proxy address. func envProxyFunc() func(*url.URL) (*url.URL, error) { envProxyOnce.Do(func() { envProxyFuncValue = httpproxy.FromEnvironment().ProxyFunc() @@ -1722,6 +1726,14 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers conn.Close() return nil, err } + + if t.OnProxyConnectResponse != nil { + err = t.OnProxyConnectResponse(ctx, cm.proxyURL, connectReq, resp) + if err != nil { + return nil, err + } + } + if resp.StatusCode != 200 { _, text, ok := strings.Cut(resp.Status, " ") conn.Close() @@ -2049,7 +2061,7 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte if pc.nwrite == startBytesWritten { return nothingWrittenError{err} } - return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %v", err) + return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %w", err) } return err } @@ -2256,7 +2268,7 @@ func (pc *persistConn) readLoopPeekFailLocked(peekErr error) { // common case. pc.closeLocked(errServerClosedIdle) } else { - pc.closeLocked(fmt.Errorf("readLoopPeekFailLocked: %v", peekErr)) + pc.closeLocked(fmt.Errorf("readLoopPeekFailLocked: %w", peekErr)) } } @@ -2390,6 +2402,10 @@ type nothingWrittenError struct { error } +func (nwe nothingWrittenError) Unwrap() error { + return nwe.error +} + func (pc *persistConn) writeLoop() { defer close(pc.writeLoopDone) for { @@ -2627,7 +2643,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err req.logf("writeErrCh resv: %T/%#v", err, err) } if err != nil { - pc.close(fmt.Errorf("write error: %v", err)) + pc.close(fmt.Errorf("write error: %w", err)) return nil, pc.mapRoundTripError(req, startBytesWritten, err) } if d := pc.t.ResponseHeaderTimeout; d > 0 { @@ -2729,7 +2745,7 @@ var portMap = map[string]string{ "socks5": "1080", } -// canonicalAddr returns url.Host but always with a ":port" suffix +// canonicalAddr returns url.Host but always with a ":port" suffix. func canonicalAddr(url *url.URL) string { addr := url.Hostname() if v, err := idnaASCII(addr); err == nil { diff --git a/transport_test.go b/transport_test.go index c2520f35..0ea2d412 100644 --- a/transport_test.go +++ b/transport_test.go @@ -134,12 +134,11 @@ func (tcs *testConnSet) check(t *testing.T) { } } -func TestReuseRequest(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) } +func testReuseRequest(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("{}")) - })) - defer ts.Close() + })).ts c := ts.Client() req, _ := NewRequest("GET", ts.URL, nil) @@ -164,10 +163,9 @@ func TestReuseRequest(t *testing.T) { // Two subsequent requests and verify their response is the same. // The response from the server is our own IP:port -func TestTransportKeepAlives(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() +func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) } +func testTransportKeepAlives(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() for _, disableKeepAlive := range []bool{false, true} { @@ -196,9 +194,10 @@ func TestTransportKeepAlives(t *testing.T) { } func TestTransportConnectionCloseOnResponse(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + run(t, testTransportConnectionCloseOnResponse) +} +func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts connSet, testDial := makeTestDial(t) @@ -252,9 +251,10 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { // describes the source source connection it got (remote port number + // address of its net.Conn). func TestTransportConnectionCloseOnRequest(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode}) +} +func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts connSet, testDial := makeTestDial(t) @@ -316,9 +316,10 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { // send Connection: close. // HTTP/1-only (Connection: close doesn't exist in h2) func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode}) +} +func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() c.Transport.(*Transport).DisableKeepAlives = true @@ -336,6 +337,9 @@ func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { // Test that Transport only sends one "Connection: close", regardless of // how "close" was indicated. func TestTransportRespectRequestWantsClose(t *testing.T) { + run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode}) +} +func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) { tests := []struct { disableKeepAlives bool close bool @@ -349,9 +353,7 @@ func TestTransportRespectRequestWantsClose(t *testing.T) { for _, tc := range tests { t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close), func(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives @@ -386,9 +388,10 @@ func TestTransportRespectRequestWantsClose(t *testing.T) { } func TestTransportIdleCacheKeys(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + run(t, testTransportIdleCacheKeys, []testMode{http1Mode}) +} +func testTransportIdleCacheKeys(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -419,12 +422,12 @@ func TestTransportIdleCacheKeys(t *testing.T) { // Tests that the HTTP transport re-uses connections when a client // reads to the end of a response Body without closing it. -func TestTransportReadToEndReusesConn(t *testing.T) { - defer afterTest(t) +func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) } +func testTransportReadToEndReusesConn(t *testing.T, mode testMode) { const msg = "foobar" var addrSeen map[string]int - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { addrSeen[r.RemoteAddr]++ if r.URL.Path == "/chunked/" { w.WriteHeader(200) @@ -434,16 +437,13 @@ func TestTransportReadToEndReusesConn(t *testing.T) { w.WriteHeader(200) } w.Write([]byte(msg)) - })) - defer ts.Close() - - buf := make([]byte, len(msg)) + })).ts for pi, path := range []string{"/content-length/", "/chunked/"} { wantLen := []int{len(msg), -1}[pi] addrSeen = make(map[string]int) for i := 0; i < 3; i++ { - res, err := Get(ts.URL + path) + res, err := ts.Client().Get(ts.URL + path) if err != nil { t.Errorf("Get %s: %v", path, err) continue @@ -458,9 +458,9 @@ func TestTransportReadToEndReusesConn(t *testing.T) { if res.ContentLength != int64(wantLen) { t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen) } - n, err := res.Body.Read(buf) - if n != len(msg) || err != io.EOF { - t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg)) + got, err := io.ReadAll(res.Body) + if string(got) != msg || err != nil { + t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg) } } if len(addrSeen) != 1 { @@ -470,13 +470,15 @@ func TestTransportReadToEndReusesConn(t *testing.T) { } func TestTransportMaxPerHostIdleConns(t *testing.T) { - defer afterTest(t) + run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode}) +} +func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) { stop := make(chan struct{}) // stop marks the exit of main Test goroutine defer close(stop) resch := make(chan string) gotReq := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotReq <- true var msg string select { @@ -489,8 +491,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { t.Errorf("Write: %v", err) return } - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -558,14 +559,15 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { } func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportMaxConnsPerHostIncludeDialInProgress) +} +func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("foo")) if err != nil { t.Fatalf("Write: %v", err) } - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) dialStarted := make(chan struct{}) @@ -625,7 +627,9 @@ func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { } func TestTransportMaxConnsPerHost(t *testing.T) { - defer afterTest(t) + run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode}) +} +func testTransportMaxConnsPerHost(t *testing.T, mode testMode) { CondSkipHTTP2(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -635,115 +639,101 @@ func TestTransportMaxConnsPerHost(t *testing.T) { } }) - testMaxConns := func(scheme string, ts *httptest.Server) { - defer ts.Close() - - c := ts.Client() - tr := c.Transport.(*Transport) - tr.MaxConnsPerHost = 1 - if err := ExportHttp2ConfigureTransport(tr); err != nil { - t.Fatalf("ExportHttp2ConfigureTransport: %v", err) - } - - mu := sync.Mutex{} - var conns []net.Conn - var dialCnt, gotConnCnt, tlsHandshakeCnt int32 - tr.Dial = func(network, addr string) (net.Conn, error) { - atomic.AddInt32(&dialCnt, 1) - c, err := net.Dial(network, addr) - mu.Lock() - defer mu.Unlock() - conns = append(conns, c) - return c, err - } - - doReq := func() { - trace := &httptrace.ClientTrace{ - GotConn: func(connInfo httptrace.GotConnInfo) { - if !connInfo.Reused { - atomic.AddInt32(&gotConnCnt, 1) - } - }, - TLSHandshakeStart: func() { - atomic.AddInt32(&tlsHandshakeCnt, 1) - }, - } - req, _ := NewRequest("GET", ts.URL, nil) - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + ts := newClientServerTest(t, mode, h).ts + c := ts.Client() + tr := c.Transport.(*Transport) + tr.MaxConnsPerHost = 1 - resp, err := c.Do(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, err = io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read body failed: %v", err) - } - } + mu := sync.Mutex{} + var conns []net.Conn + var dialCnt, gotConnCnt, tlsHandshakeCnt int32 + tr.Dial = func(network, addr string) (net.Conn, error) { + atomic.AddInt32(&dialCnt, 1) + c, err := net.Dial(network, addr) + mu.Lock() + defer mu.Unlock() + conns = append(conns, c) + return c, err + } - wg := sync.WaitGroup{} - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - doReq() - }() + doReq := func() { + trace := &httptrace.ClientTrace{ + GotConn: func(connInfo httptrace.GotConnInfo) { + if !connInfo.Reused { + atomic.AddInt32(&gotConnCnt, 1) + } + }, + TLSHandshakeStart: func() { + atomic.AddInt32(&tlsHandshakeCnt, 1) + }, } - wg.Wait() + req, _ := NewRequest("GET", ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) - expected := int32(tr.MaxConnsPerHost) - if dialCnt != expected { - t.Errorf("round 1: too many dials (%s): %d != %d", scheme, dialCnt, expected) - } - if gotConnCnt != expected { - t.Errorf("round 1: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) + resp, err := c.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) } - if ts.TLS != nil && tlsHandshakeCnt != expected { - t.Errorf("round 1: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body failed: %v", err) } + } - if t.Failed() { - t.FailNow() - } + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + doReq() + }() + } + wg.Wait() - mu.Lock() - for _, c := range conns { - c.Close() - } - conns = nil - mu.Unlock() - tr.CloseIdleConnections() + expected := int32(tr.MaxConnsPerHost) + if dialCnt != expected { + t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected) + } + if gotConnCnt != expected { + t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected) + } - doReq() - expected++ - if dialCnt != expected { - t.Errorf("round 2: too many dials (%s): %d", scheme, dialCnt) - } - if gotConnCnt != expected { - t.Errorf("round 2: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) - } - if ts.TLS != nil && tlsHandshakeCnt != expected { - t.Errorf("round 2: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) - } + if t.Failed() { + t.FailNow() } - testMaxConns("http", httptest.NewServer(h)) - testMaxConns("https", httptest.NewTLSServer(h)) + mu.Lock() + for _, c := range conns { + c.Close() + } + conns = nil + mu.Unlock() + tr.CloseIdleConnections() - ts := httptest.NewUnstartedServer(h) - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - testMaxConns("http2", ts) + doReq() + expected++ + if dialCnt != expected { + t.Errorf("round 2: too many dials: %d", dialCnt) + } + if gotConnCnt != expected { + t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected) + } } func TestTransportRemovesDeadIdleConnections(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode}) +} +func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, r.RemoteAddr) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -788,10 +778,10 @@ func TestTransportRemovesDeadIdleConnections(t *testing.T) { // Test that the Transport notices when a server hangs up on its // unexpectedly (a keep-alive connection is closed). func TestTransportServerClosingUnexpectedly(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode}) +} +func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() fetch := func(n, retries int) string { @@ -845,11 +835,13 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { // Test for https://golang.org/issue/2616 (appropriate issue number) // This fails pretty reliably with GOMAXPROCS=100 or something high. func TestStressSurpriseServerCloses(t *testing.T) { - defer afterTest(t) + run(t, testStressSurpriseServerCloses, []testMode{http1Mode}) +} +func testStressSurpriseServerCloses(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping test in short mode") } - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "5") w.Header().Set("Content-Type", "text/plain") w.Write([]byte("Hello")) @@ -857,8 +849,7 @@ func TestStressSurpriseServerCloses(t *testing.T) { conn, buf, _ := w.(Hijacker).Hijack() buf.Flush() conn.Close() - })) - defer ts.Close() + })).ts c := ts.Client() // Do a bunch of traffic from different goroutines. Send to activityc @@ -905,16 +896,15 @@ func TestStressSurpriseServerCloses(t *testing.T) { // TestTransportHeadResponses verifies that we deal with Content-Lengths // with no bodies properly -func TestTransportHeadResponses(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) } +func testTransportHeadResponses(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) } w.Header().Set("Content-Length", "123") w.WriteHeader(200) - })) - defer ts.Close() + })).ts c := ts.Client() for i := 0; i < 2; i++ { @@ -940,16 +930,17 @@ func TestTransportHeadResponses(t *testing.T) { // TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding // on responses to HEAD requests. func TestTransportHeadChunkedResponse(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel) +} +func testTransportHeadChunkedResponse(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) } w.Header().Set("Transfer-Encoding", "chunked") // client should ignore w.Header().Set("x-client-ipport", r.RemoteAddr) w.WriteHeader(200) - })) - defer ts.Close() + })).ts c := ts.Client() // Ensure that we wait for the readLoop to complete before @@ -990,11 +981,10 @@ var roundTripTests = []struct { } // Test that the modification made to the Request by the RoundTripper is cleaned up -func TestRoundTripGzip(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) } +func testRoundTripGzip(t *testing.T, mode testMode) { const responseBody = "test response body" - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { accept := req.Header.Get("Accept-Encoding") if expect := req.FormValue("expect_accept"); accept != expect { t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q", @@ -1009,8 +999,7 @@ func TestRoundTripGzip(t *testing.T) { rw.Header().Set("Content-Encoding", accept) rw.Write([]byte(responseBody)) } - })) - defer ts.Close() + })).ts tr := ts.Client().Transport.(*Transport) for i, test := range roundTripTests { @@ -1054,12 +1043,14 @@ func TestRoundTripGzip(t *testing.T) { } -func TestTransportGzip(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) } +func testTransportGzip(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/56020") + } const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" const nRandBytes = 1024 * 1024 - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { if req.Method == "HEAD" { if g := req.Header.Get("Accept-Encoding"); g != "" { t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g) @@ -1086,8 +1077,7 @@ func TestTransportGzip(t *testing.T) { io.CopyN(gz, rand.Reader, nRandBytes) } gz.Close() - })) - defer ts.Close() + })).ts c := ts.Client() for _, chunked := range []string{"1", "0"} { @@ -1152,10 +1142,10 @@ func TestTransportGzip(t *testing.T) { // If a request has Expect:100-continue header, the request blocks sending body until the first response. // Premature consumption of the request body should not be occurred. func TestTransportExpect100Continue(t *testing.T) { - setParallel(t) - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + run(t, testTransportExpect100Continue, []testMode{http1Mode}) +} +func testTransportExpect100Continue(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { switch req.URL.Path { case "/100": // This endpoint implicitly responds 100 Continue and reads body. @@ -1193,8 +1183,7 @@ func TestTransportExpect100Continue(t *testing.T) { conn.Close() } - })) - defer ts.Close() + })).ts tests := []struct { path string @@ -1241,7 +1230,9 @@ func TestTransportExpect100Continue(t *testing.T) { } func TestSOCKS5Proxy(t *testing.T) { - defer afterTest(t) + run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode}) +} +func testSOCKS5Proxy(t *testing.T, mode testMode) { ch := make(chan string, 1) l := newLocalListener(t) defer l.Close() @@ -1321,12 +1312,7 @@ func TestSOCKS5Proxy(t *testing.T) { }) for _, useTLS := range []bool{false, true} { t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) { - var ts *httptest.Server - if useTLS { - ts = httptest.NewTLSServer(h) - } else { - ts = httptest.NewServer(h) - } + ts := newClientServerTest(t, mode, h).ts go proxy(t) c := ts.Client() c.Transport.(*Transport).Proxy = ProxyURL(pu) @@ -1358,16 +1344,16 @@ func TestSOCKS5Proxy(t *testing.T) { func TestTransportProxy(t *testing.T) { defer afterTest(t) - testCases := []struct{ httpsSite, httpsProxy bool }{ - {false, false}, - {false, true}, - {true, false}, - {true, true}, + testCases := []struct{ siteMode, proxyMode testMode }{ + {http1Mode, http1Mode}, + {http1Mode, https1Mode}, + {https1Mode, http1Mode}, + {https1Mode, https1Mode}, } for _, testCase := range testCases { - httpsSite := testCase.httpsSite - httpsProxy := testCase.httpsProxy - t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) { + siteMode := testCase.siteMode + proxyMode := testCase.proxyMode + t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) { siteCh := make(chan *Request, 1) h1 := HandlerFunc(func(w ResponseWriter, r *Request) { siteCh <- r @@ -1413,18 +1399,8 @@ func TestTransportProxy(t *testing.T) { }() } }) - var ts *httptest.Server - if httpsSite { - ts = httptest.NewTLSServer(h1) - } else { - ts = httptest.NewServer(h1) - } - var proxy *httptest.Server - if httpsProxy { - proxy = httptest.NewTLSServer(h2) - } else { - proxy = httptest.NewServer(h2) - } + ts := newClientServerTest(t, siteMode, h1).ts + proxy := newClientServerTest(t, proxyMode, h2).ts pu, err := url.Parse(proxy.URL) if err != nil { @@ -1435,7 +1411,7 @@ func TestTransportProxy(t *testing.T) { // If only one server is HTTPS, c must be derived from that server in order // to ensure that it is configured to use the fake root CA from testcert.go. c := proxy.Client() - if httpsSite { + if siteMode == https1Mode { c = ts.Client() } @@ -1452,7 +1428,7 @@ func TestTransportProxy(t *testing.T) { c.Transport.(*Transport).CloseIdleConnections() ts.Close() proxy.Close() - if httpsSite { + if siteMode == https1Mode { // First message should be a CONNECT, asking for a socket to the real server, if got.Method != "CONNECT" { t.Errorf("Wrong method for secure proxying: %q", got.Method) @@ -1488,6 +1464,98 @@ func TestTransportProxy(t *testing.T) { } } +func TestOnProxyConnectResponse(t *testing.T) { + + var tcases = []struct { + proxyStatusCode int + err error + }{ + { + StatusOK, + nil, + }, + { + StatusForbidden, + errors.New("403"), + }, + } + for _, tcase := range tcases { + h1 := HandlerFunc(func(w ResponseWriter, r *Request) { + + }) + + h2 := HandlerFunc(func(w ResponseWriter, r *Request) { + // Implement an entire CONNECT proxy + if r.Method == "CONNECT" { + if tcase.proxyStatusCode != StatusOK { + w.WriteHeader(tcase.proxyStatusCode) + return + } + hijacker, ok := w.(Hijacker) + if !ok { + t.Errorf("hijack not allowed") + return + } + clientConn, _, err := hijacker.Hijack() + if err != nil { + t.Errorf("hijacking failed") + return + } + res := &Response{ + StatusCode: StatusOK, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(Header), + } + + targetConn, err := net.Dial("tcp", r.URL.Host) + if err != nil { + t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err) + return + } + + if err := res.Write(clientConn); err != nil { + t.Errorf("Writing 200 OK failed: %v", err) + return + } + + go io.Copy(targetConn, clientConn) + go func() { + io.Copy(clientConn, targetConn) + targetConn.Close() + }() + } + }) + ts := newClientServerTest(t, https1Mode, h1).ts + proxy := newClientServerTest(t, https1Mode, h2).ts + + pu, err := url.Parse(proxy.URL) + if err != nil { + t.Fatal(err) + } + + c := proxy.Client() + + c.Transport.(*Transport).Proxy = ProxyURL(pu) + c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error { + if proxyURL.String() != pu.String() { + t.Errorf("proxy url got %s, want %s", proxyURL, pu) + } + + if "https://"+connectReq.URL.String() != ts.URL { + t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL) + } + return tcase.err + } + if _, err := c.Head(ts.URL); err != nil { + if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) { + t.Errorf("got %v, want %v", err, tcase.err) + } + } + } +} + // Issue 28012: verify that the Transport closes its TCP connection to http proxies // when they're slow to reply to HTTPS CONNECT responses. func TestTransportProxyHTTPSConnectLeak(t *testing.T) { @@ -1601,10 +1669,10 @@ func TestTransportDialPreservesNetOpProxyError(t *testing.T) { // (A bug caused dialConn to instead write the per-request Proxy-Authorization // header through to the shared Header instance, introducing a data race.) func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) { - setParallel(t) - defer afterTest(t) - - proxy := httptest.NewTLSServer(NotFoundHandler()) + run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader) +} +func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) { + proxy := newClientServerTest(t, mode, NotFoundHandler()).ts defer proxy.Close() c := proxy.Client() @@ -1638,13 +1706,12 @@ func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) { // client gets the same value back. This is more cute than anything, // but checks that we don't recurse forever, and checks that // Content-Encoding is removed. -func TestTransportGzipRecursive(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) } +func testTransportGzipRecursive(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "gzip") w.Write(rgz) - })) - defer ts.Close() + })).ts c := ts.Client() res, err := c.Get(ts.URL) @@ -1666,13 +1733,12 @@ func TestTransportGzipRecursive(t *testing.T) { // golang.org/issue/7750: request fails when server replies with // a short gzip body -func TestTransportGzipShort(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) } +func testTransportGzipShort(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "gzip") w.Write([]byte{0x1f, 0x8b}) - })) - defer ts.Close() + })).ts c := ts.Client() res, err := c.Get(ts.URL) @@ -1702,19 +1768,23 @@ func waitNumGoroutine(nmax int) int { // tests that persistent goroutine connections shut down when no longer desired. func TestTransportPersistConnLeak(t *testing.T) { + run(t, testTransportPersistConnLeak, testNotParallel) +} +func testTransportPersistConnLeak(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("flaky in HTTP/2") + } // Not parallel: counts goroutines - defer afterTest(t) const numReq = 25 gotReqCh := make(chan bool, numReq) unblockCh := make(chan bool, numReq) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotReqCh <- true <-unblockCh w.Header().Set("Content-Length", "0") w.WriteHeader(204) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -1772,11 +1842,16 @@ func TestTransportPersistConnLeak(t *testing.T) { // golang.org/issue/4531: Transport leaks goroutines when // request.ContentLength is explicitly short func TestTransportPersistConnLeakShortBody(t *testing.T) { + run(t, testTransportPersistConnLeakShortBody, testNotParallel) +} +func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("flaky in HTTP/2") + } + // Not parallel: measures goroutines. - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - })) - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -1850,9 +1925,10 @@ func (d *countingDialer) Read() (total, live int64) { } func TestTransportPersistConnLeakNeverIdle(t *testing.T) { - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode}) +} +func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // Close every connection so that it cannot be kept alive. conn, _, err := w.(Hijacker).Hijack() if err != nil { @@ -1860,8 +1936,7 @@ func TestTransportPersistConnLeakNeverIdle(t *testing.T) { return } conn.Close() - })) - defer ts.Close() + })).ts var d countingDialer c := ts.Client() @@ -1922,13 +1997,17 @@ func (cc *contextCounter) Read() (live int64) { } func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) { - defer afterTest(t) + run(t, testTransportPersistConnContextLeakMaxConnsPerHost) +} +func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/56021") + } - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { runtime.Gosched() w.WriteHeader(StatusOK) - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).MaxConnsPerHost = 1 @@ -1978,16 +2057,15 @@ func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) { } // This used to crash; https://golang.org/issue/3266 -func TestTransportIdleConnCrash(t *testing.T) { - defer afterTest(t) +func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) } +func testTransportIdleConnCrash(t *testing.T, mode testMode) { var tr *Transport unblockCh := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockCh tr.CloseIdleConnections() - })) - defer ts.Close() + })).ts c := ts.Client() tr = c.Transport.(*Transport) @@ -2009,16 +2087,15 @@ func TestTransportIdleConnCrash(t *testing.T) { // before the response body has been read. This was a regression // which sadly lacked a triggering test. The large response body made // the old race easier to trigger. -func TestIssue3644(t *testing.T) { - defer afterTest(t) +func TestIssue3644(t *testing.T) { run(t, testIssue3644) } +func testIssue3644(t *testing.T, mode testMode) { const numFoos = 5000 - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") for i := 0; i < numFoos; i++ { w.Write([]byte("foo ")) } - })) - defer ts.Close() + })).ts c := ts.Client() res, err := c.Get(ts.URL) if err != nil { @@ -2036,14 +2113,12 @@ func TestIssue3644(t *testing.T) { // Test that a client receives a server's reply, even if the server doesn't read // the entire request body. -func TestIssue3595(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestIssue3595(t *testing.T) { run(t, testIssue3595) } +func testIssue3595(t *testing.T, mode testMode) { const deniedMsg = "sorry, denied." - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, deniedMsg, StatusUnauthorized) - })) - defer ts.Close() + })).ts c := ts.Client() res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) if err != nil { @@ -2061,12 +2136,11 @@ func TestIssue3595(t *testing.T) { // From https://golang.org/issue/4454 , // "client fails to handle requests with no body and chunked encoding" -func TestChunkedNoContent(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) } +func testChunkedNoContent(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) - })) - defer ts.Close() + })).ts c := ts.Client() for _, closeBody := range []bool{true, false} { @@ -2085,17 +2159,18 @@ func TestChunkedNoContent(t *testing.T) { } func TestTransportConcurrency(t *testing.T) { + run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode}) +} +func testTransportConcurrency(t *testing.T, mode testMode) { // Not parallel: uses global test hooks. - defer afterTest(t) maxProcs, numReqs := 16, 500 if testing.Short() { maxProcs, numReqs = 4, 50 } defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%v", r.FormValue("echo")) - })) - defer ts.Close() + })).ts var wg sync.WaitGroup wg.Add(numReqs) @@ -2146,67 +2221,46 @@ func TestTransportConcurrency(t *testing.T) { wg.Wait() } -func TestIssue4191_InfiniteGetTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) - const debug = false +func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) } +func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) { mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { io.Copy(w, neverEnding('a')) }) - ts := httptest.NewServer(mux) - defer ts.Close() - timeout := 100 * time.Millisecond + ts := newClientServerTest(t, mode, mux).ts + connc := make(chan net.Conn, 1) c := ts.Client() c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { conn, err := net.Dial(n, addr) if err != nil { return nil, err } - conn.SetDeadline(time.Now().Add(timeout)) - if debug { - conn = NewLoggingConn("client", conn) + select { + case connc <- conn: + default: } return conn, nil } - getFailed := false - nRuns := 5 - if testing.Short() { - nRuns = 1 - } - for i := 0; i < nRuns; i++ { - if debug { - println("run", i+1, "of", nRuns) - } - sres, err := c.Get(ts.URL + "/get") - if err != nil { - if !getFailed { - // Make the timeout longer, once. - getFailed = true - t.Logf("increasing timeout") - i-- - timeout *= 10 - continue - } - t.Errorf("Error issuing GET: %v", err) - break - } - _, err = io.Copy(io.Discard, sres.Body) - if err == nil { - t.Errorf("Unexpected successful copy") - break - } + res, err := c.Get(ts.URL + "/get") + if err != nil { + t.Fatalf("Error issuing GET: %v", err) } - if debug { - println("tests complete; waiting for handlers to finish") + defer res.Body.Close() + + conn := <-connc + conn.SetDeadline(time.Now().Add(1 * time.Millisecond)) + _, err = io.Copy(io.Discard, res.Body) + if err == nil { + t.Errorf("Unexpected successful copy") } } func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode}) +} +func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) { const debug = false mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { @@ -2216,7 +2270,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { defer r.Body.Close() io.Copy(io.Discard, r.Body) }) - ts := httptest.NewServer(mux) + ts := newClientServerTest(t, mode, mux).ts timeout := 100 * time.Millisecond c := ts.Client() @@ -2269,9 +2323,8 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { ts.Close() } -func TestTransportResponseHeaderTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) } +func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping timeout test in -short mode") } @@ -2284,8 +2337,7 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { inHandler <- true time.Sleep(2 * time.Second) }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts c := ts.Client() c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond @@ -2341,18 +2393,18 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { } func TestTransportCancelRequest(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportCancelRequest, []testMode{http1Mode}) +} +func testTransportCancelRequest(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping test in -short mode") } unblockc := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello") w.(Flusher).Flush() // send headers and some body <-unblockc - })) - defer ts.Close() + })).ts defer close(unblockc) c := ts.Client() @@ -2394,17 +2446,14 @@ func TestTransportCancelRequest(t *testing.T) { } } -func testTransportCancelRequestInDo(t *testing.T, body io.Reader) { - setParallel(t) - defer afterTest(t) +func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) { if testing.Short() { t.Skip("skipping test in -short mode") } unblockc := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockc - })) - defer ts.Close() + })).ts defer close(unblockc) c := ts.Client() @@ -2431,11 +2480,15 @@ func testTransportCancelRequestInDo(t *testing.T, body io.Reader) { } func TestTransportCancelRequestInDo(t *testing.T) { - testTransportCancelRequestInDo(t, nil) + run(t, func(t *testing.T, mode testMode) { + testTransportCancelRequestInDo(t, mode, nil) + }, []testMode{http1Mode}) } func TestTransportCancelRequestWithBodyInDo(t *testing.T) { - testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0})) + run(t, func(t *testing.T, mode testMode) { + testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0})) + }, []testMode{http1Mode}) } func TestTransportCancelRequestInDial(t *testing.T) { @@ -2443,7 +2496,7 @@ func TestTransportCancelRequestInDial(t *testing.T) { if testing.Short() { t.Skip("skipping test in -short mode") } - var logbuf bytes.Buffer + var logbuf strings.Builder eventLog := log.New(&logbuf, "", 0) unblockDial := make(chan bool) @@ -2496,19 +2549,17 @@ Get = Get "http://something.no-network.tld/": net/http: request canceled while w } } -func TestCancelRequestWithChannel(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) } +func testCancelRequestWithChannel(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping test in -short mode") } unblockc := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello") w.(Flusher).Flush() // send headers and some body <-unblockc - })) - defer ts.Close() + })).ts defer close(unblockc) c := ts.Client() @@ -2554,19 +2605,20 @@ func TestCancelRequestWithChannel(t *testing.T) { } func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) { - testCancelRequestWithChannelBeforeDo(t, false) + run(t, func(t *testing.T, mode testMode) { + testCancelRequestWithChannelBeforeDo(t, mode, false) + }) } func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) { - testCancelRequestWithChannelBeforeDo(t, true) + run(t, func(t *testing.T, mode testMode) { + testCancelRequestWithChannelBeforeDo(t, mode, true) + }) } -func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { - setParallel(t) - defer afterTest(t) +func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) { unblockc := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockc - })) - defer ts.Close() + })).ts defer close(unblockc) c := ts.Client() @@ -2641,11 +2693,11 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) { // golang.org/issue/3672 -- Client can't close HTTP stream // Calling Close on a Response.Body used to just read until EOF. // Now it actually closes the TCP connection. -func TestTransportCloseResponseBody(t *testing.T) { - defer afterTest(t) +func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) } +func testTransportCloseResponseBody(t *testing.T, mode testMode) { writeErr := make(chan error, 1) msg := []byte("young\n") - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { for { _, err := w.Write(msg) if err != nil { @@ -2654,8 +2706,7 @@ func TestTransportCloseResponseBody(t *testing.T) { } w.(Flusher).Flush() } - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -2760,10 +2811,8 @@ func TestTransportEmptyMethod(t *testing.T) { } } -func TestTransportSocketLateBinding(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) } +func testTransportSocketLateBinding(t *testing.T, mode testMode) { mux := NewServeMux() fooGate := make(chan bool, 1) mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) { @@ -2774,8 +2823,7 @@ func TestTransportSocketLateBinding(t *testing.T) { mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) { w.Header().Set("bar-ipport", r.RemoteAddr) }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts dialGate := make(chan bool, 1) c := ts.Client() @@ -2919,18 +2967,18 @@ Content-Length: %d // Issue 17739: the HTTP client must ignore any unknown 1xx // informational responses before the actual response. func TestTransportIgnore1xxResponses(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportIgnore1xxResponses, []testMode{http1Mode}) +} +func testTransportIgnore1xxResponses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello")) buf.Flush() conn.Close() })) - defer cst.close() cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway - var got bytes.Buffer + var got strings.Builder req, _ := NewRequest("GET", cst.ts.URL, nil) req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ @@ -2948,14 +2996,15 @@ func TestTransportIgnore1xxResponses(t *testing.T) { res.Write(&got) want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello" if got.String() != want { - t.Errorf(" got: %q\nwant: %q\n", got.Bytes(), want) + t.Errorf(" got: %q\nwant: %q\n", got.String(), want) } } func TestTransportLimits1xxResponses(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportLimits1xxResponses, []testMode{http1Mode}) +} +func testTransportLimits1xxResponses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() for i := 0; i < 10; i++ { buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n")) @@ -2964,7 +3013,6 @@ func TestTransportLimits1xxResponses(t *testing.T) { buf.Flush() conn.Close() })) - defer cst.close() cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway res, err := cst.c.Get(cst.ts.URL) @@ -2981,16 +3029,16 @@ func TestTransportLimits1xxResponses(t *testing.T) { // Issue 26161: the HTTP client must treat 101 responses // as the final response. func TestTransportTreat101Terminal(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportTreat101Terminal, []testMode{http1Mode}) +} +func testTransportTreat101Terminal(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n")) buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) buf.Flush() conn.Close() })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -3014,7 +3062,7 @@ type proxyFromEnvTest struct { } func (t proxyFromEnvTest) String() string { - var buf bytes.Buffer + var buf strings.Builder space := func() { if buf.Len() > 0 { buf.WriteByte(' ') @@ -3122,16 +3170,18 @@ func TestProxyFromEnvironmentLowerCase(t *testing.T) { } func TestIdleConnChannelLeak(t *testing.T) { + run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel) +} +func testIdleConnChannelLeak(t *testing.T, mode testMode) { // Not parallel: uses global test hooks. var mu sync.Mutex var n int - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() n++ mu.Unlock() - })) - defer ts.Close() + })).ts const nReqs = 5 didRead := make(chan bool, nReqs) @@ -3179,11 +3229,12 @@ func TestIdleConnChannelLeak(t *testing.T) { // body into a ReadCloser if it's a Closer, and that the Transport // then closes it. func TestTransportClosesRequestBody(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportClosesRequestBody, []testMode{http1Mode}) +} +func testTransportClosesRequestBody(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.Copy(io.Discard, r.Body) - })) - defer ts.Close() + })).ts c := ts.Client() @@ -3260,10 +3311,11 @@ func TestTransportTLSHandshakeTimeout(t *testing.T) { // Trying to repro golang.org/issue/3514 func TestTLSServerClosesConnection(t *testing.T) { - defer afterTest(t) - + run(t, testTLSServerClosesConnection, []testMode{https1Mode}) +} +func testTLSServerClosesConnection(t *testing.T, mode testMode) { closedc := make(chan bool, 1) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if strings.Contains(r.URL.Path, "/keep-alive-then-die") { conn, _, _ := w.(Hijacker).Hijack() conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) @@ -3272,8 +3324,7 @@ func TestTLSServerClosesConnection(t *testing.T) { return } fmt.Fprintf(w, "hello") - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -3344,8 +3395,9 @@ func (c byteFromChanReader) Read(p []byte) (n int, err error) { // questionable state. // golang.org/issue/7569 func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}) +} +func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) { var sconn struct { sync.Mutex c net.Conn @@ -3364,7 +3416,7 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { } defer closeConn() - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method == "GET" { io.WriteString(w, "bar") return @@ -3375,8 +3427,7 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { sconn.Unlock() conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive go io.Copy(io.Discard, conn) - })) - defer ts.Close() + })).ts c := ts.Client() const bodySize = 256 << 10 @@ -3409,9 +3460,9 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { // Tests that we don't leak Transport persistConn.readLoop goroutines // when a server hangs up immediately after saying it would keep-alive. -func TestTransportIssue10457(t *testing.T) { - defer afterTest(t) // used to fail in goroutine leak check - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) } +func testTransportIssue10457(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // Send a response with no body, keep-alive // (implicit), and then lie and immediately close the // connection. This forces the Transport's readLoop to @@ -3420,8 +3471,7 @@ func TestTransportIssue10457(t *testing.T) { conn, _, _ := w.(Hijacker).Hijack() conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive conn.Close() - })) - defer ts.Close() + })).ts c := ts.Client() res, err := c.Get(ts.URL) @@ -3462,6 +3512,9 @@ func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) } // This automatically prevents an infinite resend loop because we'll run out of // the cached keep-alive connections eventually. func TestRetryRequestsOnError(t *testing.T) { + run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode}) +} +func testRetryRequestsOnError(t *testing.T, mode testMode) { newRequest := func(method, urlStr string, body io.Reader) *Request { req, err := NewRequest(method, urlStr, body) if err != nil { @@ -3532,11 +3585,9 @@ func TestRetryRequestsOnError(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - defer afterTest(t) - var ( mu sync.Mutex - logbuf bytes.Buffer + logbuf strings.Builder ) logf := func(format string, args ...any) { mu.Lock() @@ -3545,11 +3596,10 @@ func TestRetryRequestsOnError(t *testing.T) { logbuf.WriteByte('\n') } - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { logf("Handler") w.Header().Set("X-Status", "ok") - })) - defer ts.Close() + })).ts var writeNumAtomic int32 c := ts.Client() @@ -3619,15 +3669,13 @@ Handler } // Issue 6981 -func TestTransportClosesBodyOnError(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) } +func testTransportClosesBodyOnError(t *testing.T, mode testMode) { readBody := make(chan error, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := io.ReadAll(r.Body) readBody <- err - })) - defer ts.Close() + })).ts c := ts.Client() fakeErr := errors.New("fake error") didClose := make(chan bool, 1) @@ -3667,17 +3715,17 @@ func TestTransportClosesBodyOnError(t *testing.T) { } func TestTransportDialTLS(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode}) +} +func testTransportDialTLS(t *testing.T, mode testMode) { var mu sync.Mutex // guards following var gotReq, didDial bool - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() gotReq = true mu.Unlock() - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) { mu.Lock() @@ -3704,19 +3752,17 @@ func TestTransportDialTLS(t *testing.T) { } } -func TestTransportDialContext(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) } +func testTransportDialContext(t *testing.T, mode testMode) { var mu sync.Mutex // guards following var gotReq bool var receivedContext context.Context - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() gotReq = true mu.Unlock() - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { mu.Lock() @@ -3745,18 +3791,18 @@ func TestTransportDialContext(t *testing.T) { } func TestTransportDialTLSContext(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode}) +} +func testTransportDialTLSContext(t *testing.T, mode testMode) { var mu sync.Mutex // guards following var gotReq bool var receivedContext context.Context - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() gotReq = true mu.Unlock() - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { mu.Lock() @@ -3878,6 +3924,9 @@ func TestTransportTraceGotConnH2IdleConns(t *testing.T) { } func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { + run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode}) +} +func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } @@ -3887,8 +3936,7 @@ func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { tr.MaxIdleConnsPerHost = 1 tr.IdleConnTimeout = 10 * time.Millisecond } - cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc) - defer cst.close() + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc) if _, err := cst.c.Get(cst.ts.URL); err != nil { t.Fatalf("got error: %s", err) @@ -3919,13 +3967,12 @@ func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { // implicitly ask for gzip support. If they want that, they need to do it // on their own. // golang.org/issue/8923 -func TestTransportRangeAndGzip(t *testing.T) { - defer afterTest(t) +func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) } +func testTransportRangeAndGzip(t *testing.T, mode testMode) { reqc := make(chan *Request, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { reqc <- r - })) - defer ts.Close() + })).ts c := ts.Client() req, _ := NewRequest("GET", ts.URL, nil) @@ -3950,15 +3997,13 @@ func TestTransportRangeAndGzip(t *testing.T) { } // Test for issue 10474 -func TestTransportResponseCancelRace(t *testing.T) { - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) } +func testTransportResponseCancelRace(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // important that this response has a body. var b [1024]byte w.Write(b[:]) - })) - defer ts.Close() + })).ts tr := ts.Client().Transport.(*Transport) req, err := NewRequest("GET", ts.URL, nil) @@ -3990,19 +4035,19 @@ func TestTransportResponseCancelRace(t *testing.T) { // Test for issue 19248: Content-Encoding's value is case insensitive. func TestTransportContentEncodingCaseInsensitive(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportContentEncodingCaseInsensitive) +} +func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) { for _, ce := range []string{"gzip", "GZIP"} { ce := ce t.Run(ce, func(t *testing.T) { const encodedString = "Hello Gopher" - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", ce) gz := gzip.NewWriter(w) gz.Write([]byte(encodedString)) gz.Close() - })) - defer ts.Close() + })).ts res, err := ts.Client().Get(ts.URL) if err != nil { @@ -4023,10 +4068,10 @@ func TestTransportContentEncodingCaseInsensitive(t *testing.T) { } func TestTransportDialCancelRace(t *testing.T) { - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer ts.Close() + run(t, testTransportDialCancelRace, testNotParallel, []testMode{http1Mode}) +} +func testTransportDialCancelRace(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts tr := ts.Client().Transport.(*Transport) req, err := NewRequest("GET", ts.URL, nil) @@ -4139,13 +4184,12 @@ func TestTransportFlushesBodyChunks(t *testing.T) { } // Issue 22088: flush Transport request headers if we're not sure the body won't block on read. -func TestTransportFlushesRequestHeader(t *testing.T) { - defer afterTest(t) +func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) } +func testTransportFlushesRequestHeader(t *testing.T, mode testMode) { gotReq := make(chan struct{}) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { close(gotReq) })) - defer cst.close() pr, pw := io.Pipe() req, err := NewRequest("POST", cst.ts.URL, pr) @@ -4174,20 +4218,21 @@ func TestTransportFlushesRequestHeader(t *testing.T) { // Issue 11745. func TestTransportPrefersResponseOverWriteError(t *testing.T) { + run(t, testTransportPrefersResponseOverWriteError) +} +func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - defer afterTest(t) const contentLengthLimit = 1024 * 1024 // 1MB - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.ContentLength >= contentLengthLimit { w.WriteHeader(StatusBadRequest) r.Body.Close() return } w.WriteHeader(StatusOK) - })) - defer ts.Close() + })).ts c := ts.Client() fail := 0 @@ -4295,12 +4340,13 @@ func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) { // Plus it's nice to be consistent and not have timing-dependent // behavior. func TestTransportReuseConnEmptyResponseBody(t *testing.T) { - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportReuseConnEmptyResponseBody) +} +func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Addr", r.RemoteAddr) // Empty response body. })) - defer cst.close() n := 100 if testing.Short() { n = 10 @@ -4406,31 +4452,43 @@ func TestNoCrashReturningTransportAltConn(t *testing.T) { } func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) { - testTransportReuseConnection_Gzip(t, true) + run(t, func(t *testing.T, mode testMode) { + testTransportReuseConnection_Gzip(t, mode, true) + }) } func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) { - testTransportReuseConnection_Gzip(t, false) + run(t, func(t *testing.T, mode testMode) { + testTransportReuseConnection_Gzip(t, mode, false) + }) } // Make sure we re-use underlying TCP connection for gzipped responses too. -func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { - setParallel(t) - defer afterTest(t) +func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) { addr := make(chan string, 2) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { addr <- r.RemoteAddr w.Header().Set("Content-Encoding", "gzip") if chunked { w.(Flusher).Flush() } w.Write(rgz) // arbitrary gzip response - })) - defer ts.Close() + })).ts c := ts.Client() + trace := &httptrace.ClientTrace{ + GetConn: func(hostPort string) { t.Logf("GetConn(%q)", hostPort) }, + GotConn: func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) }, + PutIdleConn: func(err error) { t.Logf("PutIdleConn(%v)", err) }, + ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) }, + ConnectDone: func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) }, + } + ctx := httptrace.WithClientTrace(context.Background(), trace) + for i := 0; i < 2; i++ { - res, err := c.Get(ts.URL) + req, _ := NewRequest("GET", ts.URL, nil) + req = req.WithContext(ctx) + res, err := c.Do(req) if err != nil { t.Fatal(err) } @@ -4448,15 +4506,16 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { } } -func TestTransportResponseHeaderLength(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) } +func testTransportResponseHeaderLength(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes") + } + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.URL.Path == "/long" { w.Header().Set("Long", strings.Repeat("a", 1<<20)) } - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10 @@ -4483,8 +4542,11 @@ func TestTransportResponseHeaderLength(t *testing.T) { } func TestTransportEventTraceTLSVerify(t *testing.T) { + run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode}) +} +func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) { var mu sync.Mutex - var buf bytes.Buffer + var buf strings.Builder logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() @@ -4492,14 +4554,14 @@ func TestTransportEventTraceTLSVerify(t *testing.T) { buf.WriteByte('\n') } - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Error("Unexpected request") - })) - defer ts.Close() - ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) { - logf("%s", p) - return len(p), nil - }), "", 0) + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) { + logf("%s", p) + return len(p), nil + }), "", 0) + }).ts certpool := x509.NewCertPool() certpool.AddCert(ts.Certificate()) @@ -4537,7 +4599,7 @@ func TestTransportEventTraceTLSVerify(t *testing.T) { wantOnce("TLSHandshakeStart") wantOnce("TLSHandshakeDone") - wantOnce("err = x509: certificate is valid for example.com") + wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com") if t.Failed() { t.Errorf("Output:\n%s", got) @@ -4583,9 +4645,10 @@ func TestTransportRejectsAlphaPort(t *testing.T) { // Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1 // connections. The http2 test is done in TestTransportEventTrace_h2 func TestTLSHandshakeTrace(t *testing.T) { - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer ts.Close() + run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode}) +} +func testTLSHandshakeTrace(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts var mu sync.Mutex var start, done bool @@ -4627,27 +4690,24 @@ func TestTLSHandshakeTrace(t *testing.T) { } } -func TestTransportIdleConnTimeout_h1(t *testing.T) { testTransportIdleConnTimeout(t, h1Mode) } -func TestTransportIdleConnTimeout_h2(t *testing.T) { testTransportIdleConnTimeout(t, h2Mode) } -func testTransportIdleConnTimeout(t *testing.T, h2 bool) { +func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) } +func testTransportIdleConnTimeout(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - defer afterTest(t) const timeout = 1 * time.Second - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // No body for convenience. })) - defer cst.close() tr := cst.tr tr.IdleConnTimeout = timeout defer tr.CloseIdleConnections() c := &Client{Transport: tr} idleConns := func() []string { - if h2 { + if mode == http2Mode { return tr.IdleConnStrsForTesting_h2() } else { return tr.IdleConnStrsForTesting() @@ -4701,12 +4761,11 @@ func testTransportIdleConnTimeout(t *testing.T, h2 bool) { // real connection until after the RoundTrip saw the error. Then we // know the successful tls.Dial from DialTLS will need to go into the // idle pool. Then we give it a of time to explode. -func TestIdleConnH2Crash(t *testing.T) { - setParallel(t) - cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) } +func testIdleConnH2Crash(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // nothing })) - defer cst.close() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -4798,9 +4857,11 @@ func TestTransportReturnsPeekError(t *testing.T) { // Issue 13290: send User-Agent in proxy CONNECT func TestTransportProxyConnectHeader(t *testing.T) { - defer afterTest(t) + run(t, testTransportProxyConnectHeader, []testMode{http1Mode}) +} +func testTransportProxyConnectHeader(t *testing.T, mode testMode) { reqc := make(chan *Request, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "CONNECT" { t.Errorf("method = %q; want CONNECT", r.Method) } @@ -4811,8 +4872,7 @@ func TestTransportProxyConnectHeader(t *testing.T) { return } c.Close() - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { @@ -4842,9 +4902,11 @@ func TestTransportProxyConnectHeader(t *testing.T) { } func TestTransportProxyGetConnectHeader(t *testing.T) { - defer afterTest(t) + run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode}) +} +func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) { reqc := make(chan *Request, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "CONNECT" { t.Errorf("method = %q; want CONNECT", r.Method) } @@ -4855,8 +4917,7 @@ func TestTransportProxyGetConnectHeader(t *testing.T) { return } c.Close() - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { @@ -5043,14 +5104,15 @@ func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, p // Issue 22330: do not allow the response body to be read when the status code // forbids a response body. func TestNoBodyOnChunked304Response(t *testing.T) { - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testNoBodyOnChunked304Response, []testMode{http1Mode}) +} +func testNoBodyOnChunked304Response(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n")) buf.Flush() conn.Close() })) - defer cst.close() // Our test server above is sending back bogus data after the // response (the "0\r\n\r\n" part), which causes the Transport @@ -5103,11 +5165,12 @@ func TestTransportCheckContextDoneEarly(t *testing.T) { // This is the test variant that times out before the server replies with // any response headers. func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode}) +} +func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) { inHandler := make(chan net.Conn, 1) handlerReadReturned := make(chan bool, 1) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) @@ -5120,7 +5183,6 @@ func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { } handlerReadReturned <- true })) - defer cst.close() const timeout = 50 * time.Millisecond cst.c.Timeout = timeout @@ -5155,11 +5217,12 @@ func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { // This is the test variant that has the server send response headers // first, and time out during the write of the response body. func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode}) +} +func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) { inHandler := make(chan net.Conn, 1) handlerResult := make(chan error, 1) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "100") w.(Flusher).Flush() conn, _, err := w.(Hijacker).Hijack() @@ -5181,7 +5244,6 @@ func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { } handlerResult <- nil })) - defer cst.close() // Set Timeout to something very long but non-zero to exercise // the codepaths that check for it. But rather than wait for it to fire @@ -5227,11 +5289,12 @@ func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { } func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode}) +} +func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) { done := make(chan struct{}) defer close(done) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) @@ -5244,7 +5307,6 @@ func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text())) <-done })) - defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, nil) req.Header.Set("Upgrade", "foo") @@ -5277,10 +5339,10 @@ func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { } } -func TestTransportCONNECTBidi(t *testing.T) { - defer afterTest(t) +func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) } +func testTransportCONNECTBidi(t *testing.T, mode testMode) { const target = "backend:443" - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "CONNECT" { t.Errorf("unexpected method %q", r.Method) w.WriteHeader(500) @@ -5311,7 +5373,6 @@ func TestTransportCONNECTBidi(t *testing.T) { brw.Flush() } })) - defer cst.close() pr, pw := io.Pipe() defer pw.Close() req, err := NewRequest("CONNECT", cst.ts.URL, pr) @@ -5408,7 +5469,8 @@ func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) { return c.TCPConn.ReadFrom(r) } -func TestTransportRequestWriteRoundTrip(t *testing.T) { +func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) } +func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) { nBytes := int64(1 << 10) newFileFunc := func() (r io.Reader, done func(), err error) { f, err := os.CreateTemp("", "net-http-newfilefunc") @@ -5502,7 +5564,7 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { cst := newClientServerTest( t, - h1Mode, + mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.Copy(io.Discard, r.Body) r.Body.Close() @@ -5510,7 +5572,6 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { }), trFunc, ) - defer cst.close() req, err := NewRequest("PUT", cst.ts.URL, r) if err != nil { @@ -5527,11 +5588,15 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { t.Fatalf("status code = %d; want 200", resp.StatusCode) } - if !tConn.ReadFromCalled && tc.expectedReadFrom { + expectedReadFrom := tc.expectedReadFrom + if mode != http1Mode { + expectedReadFrom = false + } + if !tConn.ReadFromCalled && expectedReadFrom { t.Fatalf("did not call ReadFrom") } - if tConn.ReadFromCalled && !tc.expectedReadFrom { + if tConn.ReadFromCalled && !expectedReadFrom { t.Fatalf("ReadFrom was unexpectedly invoked") } }) @@ -5540,7 +5605,10 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { func TestTransportClone(t *testing.T) { tr := &Transport{ - Proxy: func(*Request) (*url.URL, error) { panic("") }, + Proxy: func(*Request) (*url.URL, error) { panic("") }, + OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error { + return nil + }, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, Dial: func(network, addr string) (net.Conn, error) { panic("") }, DialTLS: func(network, addr string) (net.Conn, error) { panic("") }, @@ -5613,16 +5681,18 @@ func TestIs408(t *testing.T) { } func TestTransportIgnores408(t *testing.T) { + run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel) +} +func testTransportIgnores408(t *testing.T, mode testMode) { // Not parallel. Relies on mutating the log package's global Output. defer log.SetOutput(log.Writer()) - var logout bytes.Buffer + var logout strings.Builder log.SetOutput(&logout) - defer afterTest(t) const target = "backend:443" - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { nc, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) @@ -5632,7 +5702,6 @@ func TestTransportIgnores408(t *testing.T) { nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")) nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail })) - defer cst.close() req, err := NewRequest("GET", cst.ts.URL, nil) if err != nil { t.Fatal(err) @@ -5666,9 +5735,10 @@ func TestTransportIgnores408(t *testing.T) { } func TestInvalidHeaderResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testInvalidHeaderResponse, []testMode{http1Mode}) +} +func testInvalidHeaderResponse(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 200 OK\r\n" + "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + @@ -5678,7 +5748,6 @@ func TestInvalidHeaderResponse(t *testing.T) { buf.Flush() conn.Close() })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -5705,10 +5774,12 @@ func (bc *bodyCloser) Read(b []byte) (n int, err error) { // Issue 35015: ensure that Transport closes the body on any error // with an invalid request, as promised by Client.Do docs. func TestTransportClosesBodyOnInvalidRequests(t *testing.T) { - cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportClosesBodyOnInvalidRequests) +} +func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Errorf("Should not have been invoked") - })) - defer cst.Close() + })).ts u, _ := url.Parse(cst.URL) @@ -5773,7 +5844,7 @@ func TestTransportClosesBodyOnInvalidRequests(t *testing.T) { var bc bodyCloser req := tt.req req.Body = &bc - _, err := DefaultClient.Do(tt.req) + _, err := cst.Client().Do(tt.req) if err == nil { t.Fatal("Expected an error") } @@ -5810,8 +5881,10 @@ func (w *breakableConn) Write(b []byte) (n int, err error) { // Issue 34978: don't cache a broken HTTP/2 connection func TestDontCacheBrokenHTTP2Conn(t *testing.T) { - cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog) - defer cst.close() + run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode}) +} +func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog) var brokenState brokenState @@ -5873,7 +5946,9 @@ func TestDontCacheBrokenHTTP2Conn(t *testing.T) { // http.http2noCachedConnError is reported on multiple requests. There should // only be one decrement regardless of the number of failures. func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { - defer afterTest(t) + run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode}) +} +func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) { CondSkipHTTP2(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -5883,17 +5958,11 @@ func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { } }) - ts := httptest.NewUnstartedServer(h) - ts.EnableHTTP2 = true - ts.StartTLS() - defer ts.Close() + ts := newClientServerTest(t, mode, h).ts c := ts.Client() tr := c.Transport.(*Transport) tr.MaxConnsPerHost = 1 - if err := ExportHttp2ConfigureTransport(tr); err != nil { - t.Fatalf("ExportHttp2ConfigureTransport: %v", err) - } errCh := make(chan error, 300) doReq := func() { @@ -5962,14 +6031,13 @@ type roundTripFunc func(r *Request) (*Response, error) func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) } // Issue 32441: body is not reset after ErrSkipAltProtocol -func TestIssue32441(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) } +func testIssue32441(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if n, _ := io.Copy(io.Discard, r.Body); n == 0 { t.Error("body length is zero") } - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) { // Draining body to trigger failure condition on actual request to server. @@ -5986,11 +6054,13 @@ func TestIssue32441(t *testing.T) { // Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers // that contain a sign (eg. "+3"), per RFC 2616, Section 14.13. func TestTransportRejectsSignInContentLength(t *testing.T) { - cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode}) +} +func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "+3") w.Write([]byte("abc")) - })) - defer cst.Close() + })).ts c := cst.Client() res, err := c.Get(cst.URL) @@ -6104,14 +6174,16 @@ func TestErrorWriteLoopRace(t *testing.T) { // Test that a new request which uses the connection of an active request // cannot cause it to be canceled as well. func TestCancelRequestWhenSharingConnection(t *testing.T) { + run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode}) +} +func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) { reqc := make(chan chan struct{}, 2) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) { ch := make(chan struct{}, 1) reqc <- ch <-ch w.Header().Add("Content-Length", "0") - })) - defer ts.Close() + })).ts client := ts.Client() transport := client.Transport.(*Transport) @@ -6121,7 +6193,8 @@ func TestCancelRequestWhenSharingConnection(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - putidlec := make(chan chan struct{}) + putidlec := make(chan chan struct{}, 1) + reqerrc := make(chan error, 1) go func() { defer wg.Done() ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ @@ -6130,24 +6203,31 @@ func TestCancelRequestWhenSharingConnection(t *testing.T) { // and wait for the order to proceed. ch := make(chan struct{}) putidlec <- ch + close(putidlec) // panic if PutIdleConn runs twice for some reason <-ch }, }) req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil) res, err := client.Do(req) + reqerrc <- err if err == nil { res.Body.Close() } - if err != nil { - t.Errorf("request 1: got err %v, want nil", err) - } }() // Wait for the first request to receive a response and return the // connection to the idle pool. r1c := <-reqc close(r1c) - idlec := <-putidlec + var idlec chan struct{} + select { + case err := <-reqerrc: + if err != nil { + t.Fatalf("request 1: got err %v, want nil", err) + } + idlec = <-putidlec + case idlec = <-putidlec: + } wg.Add(1) cancelctx, cancel := context.WithCancel(context.Background()) @@ -6161,6 +6241,9 @@ func TestCancelRequestWhenSharingConnection(t *testing.T) { if !errors.Is(err, context.Canceled) { t.Errorf("request 2: got err %v, want Canceled", err) } + + // Unblock the first request. + close(idlec) }() // Wait for the second request to arrive at the server, and then cancel @@ -6168,23 +6251,18 @@ func TestCancelRequestWhenSharingConnection(t *testing.T) { r2c := <-reqc cancel() - // Give the cancellation a moment to take effect, and then unblock the first request. - time.Sleep(1 * time.Millisecond) - close(idlec) + <-idlec close(r2c) wg.Wait() } -func TestHandlerAbortRacesBodyRead(t *testing.T) { - setParallel(t) - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { +func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) } +func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { go io.Copy(io.Discard, req.Body) panic(ErrAbortHandler) - })) - defer ts.Close() + })).ts var wg sync.WaitGroup for i := 0; i < 2; i++ { diff --git a/triv.go b/triv.go index 83444f71..d5b562dd 100644 --- a/triv.go +++ b/triv.go @@ -7,7 +7,6 @@ package main import ( - "bytes" "expvar" "flag" "fmt" @@ -16,6 +15,7 @@ import ( "os" "os/exec" "strconv" + "strings" "sync" "github.com/ooni/oohttp" @@ -50,8 +50,8 @@ func (ctr *Counter) ServeHTTP(w http.ResponseWriter, req *http.Request) { case "GET": ctr.n++ case "POST": - buf := new(bytes.Buffer) - io.Copy(buf, req.Body) + var buf strings.Builder + io.Copy(&buf, req.Body) body := buf.String() if n, err := strconv.Atoi(body); err != nil { fmt.Fprintf(w, "bad POST: %v\nbody: [%v]\n", err, body) @@ -102,7 +102,7 @@ func (ch Chan) ServeHTTP(w http.ResponseWriter, req *http.Request) { io.WriteString(w, fmt.Sprintf("channel send #%d\n", <-ch)) } -// exec a program, redirecting output +// exec a program, redirecting output. func DateServer(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("Content-Type", "text/plain; charset=utf-8") @@ -119,7 +119,7 @@ func Logger(w http.ResponseWriter, req *http.Request) { http.Error(w, "oops", http.StatusNotFound) } -var webroot = flag.String("root", os.Getenv("HOME"), "web root directory") +var webroot = flag.String("root", "", "web root directory") func main() { flag.Parse() @@ -129,11 +129,13 @@ func main() { expvar.Publish("counter", ctr) http.Handle("/counter", ctr) http.Handle("/", http.HandlerFunc(Logger)) - http.Handle("/go/", http.StripPrefix("/go/", http.FileServer(http.Dir(*webroot)))) + if *webroot != "" { + http.Handle("/go/", http.StripPrefix("/go/", http.FileServer(http.Dir(*webroot)))) + } http.Handle("/chan", ChanCreate()) http.HandleFunc("/flags", FlagServer) http.HandleFunc("/args", ArgServer) http.HandleFunc("/go/hello", HelloServer) http.HandleFunc("/date", DateServer) - log.Fatal(http.ListenAndServe(":12345", nil)) + log.Fatal(http.ListenAndServe("localhost:12345", nil)) }