From e57b2d6520ae27a0afb75c08ec755153c4428a68 Mon Sep 17 00:00:00 2001 From: Maksim Terekhin Date: Tue, 2 Jul 2024 15:26:38 +0200 Subject: [PATCH] feat(compress): Create compress writer only if the content is compressible --- middleware/compress.go | 63 +++++++++++++++++++++++++++++-------- middleware/compress_test.go | 54 +++++++++++++++++++++++++++++-- 2 files changed, 102 insertions(+), 15 deletions(-) diff --git a/middleware/compress.go b/middleware/compress.go index 00946613..4d209c18 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -193,21 +193,18 @@ func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) { // current Compressor. func (c *Compressor) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - encoder, encoding, cleanup := c.selectEncoder(r.Header, w) + encoder, encoding := c.selectEncoder(r.Header, w) cw := &compressResponseWriter{ ResponseWriter: w, - w: w, contentTypes: c.allowedTypes, contentWildcards: c.allowedWildcards, encoding: encoding, compressible: false, // determined in post-handler } if encoder != nil { - cw.w = encoder + cw.encoder = encoder } - // Re-add the encoder to the pool if applicable. - defer cleanup() defer cw.Close() next.ServeHTTP(cw, r) @@ -215,7 +212,7 @@ func (c *Compressor) Handler(next http.Handler) http.Handler { } // selectEncoder returns the encoder, the name of the encoder, and a closer function. -func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, string, func()) { +func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (func() io.Writer, string) { header := h.Get("Accept-Encoding") // Parse the names of all accepted algorithms from the header. @@ -225,23 +222,31 @@ func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, strin for _, name := range c.encodingPrecedence { if matchAcceptEncoding(accepted, name) { if pool, ok := c.pooledEncoders[name]; ok { - encoder := pool.Get().(ioResetterWriter) - cleanup := func() { - pool.Put(encoder) + fn := func() io.Writer { + enc := pool.Get().(ioResetterWriter) + enc.Reset(w) + return &pooledEncoder{ + Writer: enc, + pool: pool, + } } - encoder.Reset(w) - return encoder, name, cleanup + return fn, name } if fn, ok := c.encoders[name]; ok { - return fn(w, c.level), name, func() {} + fn := func() io.Writer { + return &encoder{ + Writer: fn(w, c.level), + } + } + return fn, name } } } // No encoder found to match the accepted encoding - return nil, "", func() {} + return nil, "" } func matchAcceptEncoding(accepted []string, encoding string) bool { @@ -276,6 +281,8 @@ type compressResponseWriter struct { encoding string wroteHeader bool compressible bool + + encoder func() io.Writer } func (cw *compressResponseWriter) isCompressible() bool { @@ -335,6 +342,9 @@ func (cw *compressResponseWriter) Write(p []byte) (int, error) { func (cw *compressResponseWriter) writer() io.Writer { if cw.compressible { + if cw.w == nil { + cw.w = cw.encoder() + } return cw.w } return cw.ResponseWriter @@ -385,6 +395,33 @@ func (cw *compressResponseWriter) Unwrap() http.ResponseWriter { return cw.ResponseWriter } +type ( + encoder struct { + io.Writer + } + + pooledEncoder struct { + io.Writer + pool *sync.Pool + } +) + +func (e *encoder) Close() error { + if c, ok := e.Writer.(io.WriteCloser); ok { + return c.Close() + } + return nil +} + +func (e *pooledEncoder) Close() error { + var err error + if w, ok := e.Writer.(io.WriteCloser); ok { + err = w.Close() + } + e.pool.Put(e.Writer) + return err +} + func encoderGzip(w io.Writer, level int) io.Writer { gw, err := gzip.NewWriterLevel(w, level) if err != nil { diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 43e5773f..211c9094 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -27,8 +27,13 @@ func TestCompressor(t *testing.T) { return w }) - if len(compressor.encoders) != 1 { - t.Errorf("nop encoder should be stored in the encoders map") + var sideEffect int + compressor.SetEncoder("test", func(w io.Writer, _ int) io.Writer { + return newSideEffectWriter(w, &sideEffect) + }) + + if len(compressor.encoders) != 2 { + t.Errorf("nop and test encoders should be stored in the encoders map") } r.Use(compressor.Handler) @@ -48,6 +53,11 @@ func TestCompressor(t *testing.T) { w.Write([]byte("textstring")) }) + r.Get("/getimage", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "image/png") + w.Write([]byte("textstring")) + }) + ts := httptest.NewServer(r) defer ts.Close() @@ -98,6 +108,20 @@ func TestCompressor(t *testing.T) { expectedEncoding: "nop", checkRawResponse: true, }, + { + name: "test encoder is used", + path: "/getimage", + acceptedEncodings: []string{"test"}, + expectedEncoding: "", + checkRawResponse: true, + }, + { + name: "test encoder is used and Close is called", + path: "/gethtml", + acceptedEncodings: []string{"test"}, + expectedEncoding: "test", + checkRawResponse: true, + }, } for _, tc := range tests { @@ -117,6 +141,9 @@ func TestCompressor(t *testing.T) { }) } + if sideEffect != 0 { + t.Errorf("side effect should be cleared after close") + } } func TestCompressorWildcards(t *testing.T) { @@ -254,3 +281,26 @@ func decodeResponseBody(t *testing.T, resp *http.Response) string { return string(respBody) } + +type ( + sideEffectWriter struct { + w io.Writer + s *int + } +) + +func newSideEffectWriter(w io.Writer, sideEffect *int) io.Writer { + *sideEffect = *sideEffect + 1 + + return &sideEffectWriter{w: w, s: sideEffect} +} + +func (w *sideEffectWriter) Write(p []byte) (n int, err error) { + return w.w.Write(p) +} + +func (w *sideEffectWriter) Close() error { + *w.s = *w.s - 1 + + return nil +}