Skip to content

Commit

Permalink
feat(compress): Create compress writer only if the content is compres…
Browse files Browse the repository at this point in the history
…sible
  • Loading branch information
Neurostep committed Jul 2, 2024
1 parent 2599705 commit 625b26f
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 15 deletions.
63 changes: 50 additions & 13 deletions middleware/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,29 +193,26 @@ func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) {
// current Compressor.
func (c *Compressor) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
encoder, encoding, cleanup := c.selectEncoder(r.Header, w)
encoder, encoding := c.selectEncoder(r.Header, w)

cw := &compressResponseWriter{
ResponseWriter: w,
w: w,
contentTypes: c.allowedTypes,
contentWildcards: c.allowedWildcards,
encoding: encoding,
compressible: false, // determined in post-handler
}
if encoder != nil {
cw.w = encoder
cw.encoder = encoder
}
// Re-add the encoder to the pool if applicable.
defer cleanup()
defer cw.Close()

next.ServeHTTP(cw, r)
})
}

// selectEncoder returns the encoder, the name of the encoder, and a closer function.
func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, string, func()) {
func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (func() io.Writer, string) {
header := h.Get("Accept-Encoding")

// Parse the names of all accepted algorithms from the header.
Expand All @@ -225,23 +222,31 @@ func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, strin
for _, name := range c.encodingPrecedence {
if matchAcceptEncoding(accepted, name) {
if pool, ok := c.pooledEncoders[name]; ok {
encoder := pool.Get().(ioResetterWriter)
cleanup := func() {
pool.Put(encoder)
fn := func() io.Writer {
enc := pool.Get().(ioResetterWriter)
enc.Reset(w)
return &pooledEncoder{
Writer: enc,
pool: pool,
}
}
encoder.Reset(w)
return encoder, name, cleanup
return fn, name

}
if fn, ok := c.encoders[name]; ok {
return fn(w, c.level), name, func() {}
fn := func() io.Writer {
return &encoder{
Writer: fn(w, c.level),
}
}
return fn, name
}
}

}

// No encoder found to match the accepted encoding
return nil, "", func() {}
return nil, ""
}

func matchAcceptEncoding(accepted []string, encoding string) bool {
Expand Down Expand Up @@ -276,6 +281,8 @@ type compressResponseWriter struct {
encoding string
wroteHeader bool
compressible bool

encoder func() io.Writer
}

func (cw *compressResponseWriter) isCompressible() bool {
Expand Down Expand Up @@ -335,6 +342,9 @@ func (cw *compressResponseWriter) Write(p []byte) (int, error) {

func (cw *compressResponseWriter) writer() io.Writer {
if cw.compressible {
if cw.w == nil {
cw.w = cw.encoder()
}
return cw.w
}
return cw.ResponseWriter
Expand Down Expand Up @@ -385,6 +395,33 @@ func (cw *compressResponseWriter) Unwrap() http.ResponseWriter {
return cw.ResponseWriter
}

type (
encoder struct {
io.Writer
}

pooledEncoder struct {
io.Writer
pool *sync.Pool
}
)

func (e *encoder) Close() error {
if c, ok := e.Writer.(io.WriteCloser); ok {
return c.Close()
}
return nil
}

func (e *pooledEncoder) Close() error {
var err error
if w, ok := e.Writer.(io.WriteCloser); ok {
err = w.Close()
}
e.pool.Put(e.Writer)
return err
}

func encoderGzip(w io.Writer, level int) io.Writer {
gw, err := gzip.NewWriterLevel(w, level)
if err != nil {
Expand Down
54 changes: 52 additions & 2 deletions middleware/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@ func TestCompressor(t *testing.T) {
return w
})

if len(compressor.encoders) != 1 {
t.Errorf("nop encoder should be stored in the encoders map")
var sideEffect int
compressor.SetEncoder("test", func(w io.Writer, _ int) io.Writer {
return newSideEffectWriter(w, &sideEffect)
})

if len(compressor.encoders) != 2 {
t.Errorf("nop and test encoders should be stored in the encoders map")
}

r.Use(compressor.Handler)
Expand All @@ -48,6 +53,11 @@ func TestCompressor(t *testing.T) {
w.Write([]byte("textstring"))
})

r.Get("/getimage", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "image/png")
w.Write([]byte("textstring"))
})

ts := httptest.NewServer(r)
defer ts.Close()

Expand Down Expand Up @@ -98,6 +108,20 @@ func TestCompressor(t *testing.T) {
expectedEncoding: "nop",
checkRawResponse: true,
},
{
name: "test encoder is used",
path: "/getimage",
acceptedEncodings: []string{"test"},
expectedEncoding: "",
checkRawResponse: true,
},
{
name: "test encoder is used and Close is called",
path: "/gethtml",
acceptedEncodings: []string{"test"},
expectedEncoding: "test",
checkRawResponse: true,
},
}

for _, tc := range tests {
Expand All @@ -117,6 +141,9 @@ func TestCompressor(t *testing.T) {
})

}
if sideEffect != 0 {
t.Errorf("side effect should be cleared after close")
}
}

func TestCompressorWildcards(t *testing.T) {
Expand Down Expand Up @@ -254,3 +281,26 @@ func decodeResponseBody(t *testing.T, resp *http.Response) string {

return string(respBody)
}

type (
sideEffectWriter struct {
w io.Writer
s *int
}
)

func newSideEffectWriter(w io.Writer, sideEffect *int) io.Writer {
*sideEffect = *sideEffect + 1

return &sideEffectWriter{w: w, s: sideEffect}
}

func (w *sideEffectWriter) Write(p []byte) (n int, err error) {
return w.w.Write(p)
}

func (w *sideEffectWriter) Close() error {
*w.s = *w.s - 1

return nil
}

0 comments on commit 625b26f

Please sign in to comment.