diff --git a/middleware/compress_test.go b/middleware/compress_test.go index eaafc13b..43e5773f 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -1,6 +1,7 @@ package middleware import ( + "bytes" "compress/flate" "compress/gzip" "fmt" @@ -43,7 +44,7 @@ func TestCompressor(t *testing.T) { }) r.Get("/getplain", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") + w.Header().Set("Content-Type", "text/plain") w.Write([]byte("textstring")) }) @@ -55,18 +56,21 @@ func TestCompressor(t *testing.T) { path string expectedEncoding string acceptedEncodings []string + checkRawResponse bool }{ { name: "no expected encodings due to no accepted encodings", path: "/gethtml", - acceptedEncodings: nil, + acceptedEncodings: []string{""}, expectedEncoding: "", + checkRawResponse: true, }, { name: "no expected encodings due to content type", path: "/getplain", acceptedEncodings: nil, expectedEncoding: "", + checkRawResponse: true, }, { name: "gzip is only encoding", @@ -92,12 +96,16 @@ func TestCompressor(t *testing.T) { path: "/getcss", acceptedEncodings: []string{"nop, gzip, deflate"}, expectedEncoding: "nop", + checkRawResponse: true, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { + if tc.checkRawResponse { + testRequestRawResponse(t, ts, "GET", tc.path, []byte("textstring"), tc.acceptedEncodings...) + } resp, respString := testRequestWithAcceptedEncodings(t, ts, "GET", tc.path, tc.acceptedEncodings...) if respString != "textstring" { t.Errorf("response text doesn't match; expected:%q, got:%q", "textstring", respString) @@ -171,6 +179,35 @@ func TestCompressorWildcards(t *testing.T) { } } +func testRequestRawResponse(t *testing.T, ts *httptest.Server, method, path string, exp []byte, encodings ...string) { + req, err := http.NewRequest(method, ts.URL+path, nil) + if err != nil { + t.Fatal(err) + return + } + if len(encodings) > 0 { + encodingsString := strings.Join(encodings, ",") + req.Header.Set("Accept-Encoding", encodingsString) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + return + } + + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + return + } + defer resp.Body.Close() + + if !bytes.Equal(respBody, exp) { + t.Errorf("expected %q but got %q", exp, respBody) + } +} + func testRequestWithAcceptedEncodings(t *testing.T, ts *httptest.Server, method, path string, encodings ...string) (*http.Response, string) { req, err := http.NewRequest(method, ts.URL+path, nil) if err != nil {