From f10dc4a9cab53a47f600f518fcccaa044666e95f Mon Sep 17 00:00:00 2001 From: Maksim Terekhin Date: Sat, 8 Jun 2024 01:13:37 +0200 Subject: [PATCH] fix(middleware): Close created writer in the compressor middleware (#919) --- middleware/compress.go | 2 +- middleware/compress_test.go | 46 +++++++++++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/middleware/compress.go b/middleware/compress.go index 28240c4b..7ba95fd5 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -371,7 +371,7 @@ func (cw *compressResponseWriter) Push(target string, opts *http.PushOptions) er } func (cw *compressResponseWriter) Close() error { - if c, ok := cw.writer().(io.WriteCloser); ok { + if c, ok := cw.w.(io.WriteCloser); ok { return c.Close() } return errors.New("chi/middleware: io.WriteCloser is unavailable on the writer") diff --git a/middleware/compress_test.go b/middleware/compress_test.go index eaafc13b..992f1d40 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -26,8 +26,13 @@ func TestCompressor(t *testing.T) { return w }) - if len(compressor.encoders) != 1 { - t.Errorf("nop encoder should be stored in the encoders map") + var sideEffect int + compressor.SetEncoder("test", func(w io.Writer, _ int) io.Writer { + return newSideEffectWriter(w, &sideEffect) + }) + + if len(compressor.encoders) != 2 { + t.Errorf("nop and test encoders should be stored in the encoders map") } r.Use(compressor.Handler) @@ -47,6 +52,11 @@ func TestCompressor(t *testing.T) { w.Write([]byte("textstring")) }) + r.Get("/getimage", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "image/png") + w.Write([]byte("textstring")) + }) + ts := httptest.NewServer(r) defer ts.Close() @@ -93,6 +103,12 @@ func TestCompressor(t *testing.T) { acceptedEncodings: []string{"nop, gzip, deflate"}, expectedEncoding: "nop", }, + { + name: "test is used and side effect is cleared after close", + path: "/getimage", + acceptedEncodings: []string{"test"}, + expectedEncoding: "", + }, } for _, tc := range tests { @@ -107,7 +123,10 @@ func TestCompressor(t *testing.T) { } }) + } + if sideEffect > 1 { + t.Errorf("side effect should be cleared after close") } } @@ -217,3 +236,26 @@ func decodeResponseBody(t *testing.T, resp *http.Response) string { return string(respBody) } + +type ( + sideEffectWriter struct { + w io.Writer + s *int + } +) + +func newSideEffectWriter(w io.Writer, sideEffect *int) io.Writer { + *sideEffect = *sideEffect + 1 + + return &sideEffectWriter{w: w, s: sideEffect} +} + +func (w *sideEffectWriter) Write(p []byte) (n int, err error) { + return w.w.Write(p) +} + +func (w *sideEffectWriter) Close() error { + *w.s = *w.s - 1 + + return nil +}