From 67be7d9cafdaeb4e04e887ff78d09e030ee43b00 Mon Sep 17 00:00:00 2001 From: Patryk Kalinowski Date: Fri, 28 Jun 2024 16:29:27 +0200 Subject: [PATCH] middleware: add Discard method to WrapResponseWriter (#926) * middleware: add Discard method to WrapResponseWriter * resolve review comments * use ioutil.Discard and deprecate the public interface * move the Discard method back to the public interface * discard calls to WriteHeader too --- middleware/wrap_writer.go | 35 ++++++++++++++----- middleware/wrap_writer_test.go | 62 ++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 8 deletions(-) diff --git a/middleware/wrap_writer.go b/middleware/wrap_writer.go index cf5c44de..bf270881 100644 --- a/middleware/wrap_writer.go +++ b/middleware/wrap_writer.go @@ -6,6 +6,7 @@ package middleware import ( "bufio" "io" + "io/ioutil" "net" "net/http" ) @@ -61,6 +62,11 @@ type WrapResponseWriter interface { Tee(io.Writer) // Unwrap returns the original proxied target. Unwrap() http.ResponseWriter + // Discard causes all writes to the original ResponseWriter be discarded, + // instead writing only to the tee'd writer if it's set. + // The caller is responsible for calling WriteHeader and Write on the + // original ResponseWriter once the processing is done. + Discard() } // basicWriter wraps a http.ResponseWriter that implements the minimal @@ -71,25 +77,34 @@ type basicWriter struct { code int bytes int tee io.Writer + discard bool } func (b *basicWriter) WriteHeader(code int) { if !b.wroteHeader { b.code = code b.wroteHeader = true - b.ResponseWriter.WriteHeader(code) + if !b.discard { + b.ResponseWriter.WriteHeader(code) + } } } -func (b *basicWriter) Write(buf []byte) (int, error) { +func (b *basicWriter) Write(buf []byte) (n int, err error) { b.maybeWriteHeader() - n, err := b.ResponseWriter.Write(buf) - if b.tee != nil { - _, err2 := b.tee.Write(buf[:n]) - // Prefer errors generated by the proxied writer. - if err == nil { - err = err2 + if !b.discard { + n, err = b.ResponseWriter.Write(buf) + if b.tee != nil { + _, err2 := b.tee.Write(buf[:n]) + // Prefer errors generated by the proxied writer. + if err == nil { + err = err2 + } } + } else if b.tee != nil { + n, err = b.tee.Write(buf) + } else { + n, err = ioutil.Discard.Write(buf) } b.bytes += n return n, err @@ -117,6 +132,10 @@ func (b *basicWriter) Unwrap() http.ResponseWriter { return b.ResponseWriter } +func (b *basicWriter) Discard() { + b.discard = true +} + // flushWriter ... type flushWriter struct { basicWriter diff --git a/middleware/wrap_writer_test.go b/middleware/wrap_writer_test.go index 2c442ada..7e8f6ab2 100644 --- a/middleware/wrap_writer_test.go +++ b/middleware/wrap_writer_test.go @@ -1,6 +1,8 @@ package middleware import ( + "bytes" + "net/http" "net/http/httptest" "testing" ) @@ -22,3 +24,63 @@ func TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) { t.Fatal("want Flush to have set wroteHeader=true") } } + +func TestBasicWritesTeesWritesWithoutDiscard(t *testing.T) { + // explicitly create the struct instead of NewRecorder to control the value of Code + original := &httptest.ResponseRecorder{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + } + wrap := &basicWriter{ResponseWriter: original} + + var buf bytes.Buffer + wrap.Tee(&buf) + + _, err := wrap.Write([]byte("hello world")) + assertNoError(t, err) + + assertEqual(t, 200, original.Code) + assertEqual(t, []byte("hello world"), original.Body.Bytes()) + assertEqual(t, []byte("hello world"), buf.Bytes()) + assertEqual(t, 11, wrap.BytesWritten()) +} + +func TestBasicWriterDiscardsWritesToOriginalResponseWriter(t *testing.T) { + t.Run("With Tee", func(t *testing.T) { + // explicitly create the struct instead of NewRecorder to control the value of Code + original := &httptest.ResponseRecorder{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + } + wrap := &basicWriter{ResponseWriter: original} + + var buf bytes.Buffer + wrap.Tee(&buf) + wrap.Discard() + + _, err := wrap.Write([]byte("hello world")) + assertNoError(t, err) + + assertEqual(t, 0, original.Code) // wrapper shouldn't call WriteHeader implicitly + assertEqual(t, 0, original.Body.Len()) + assertEqual(t, []byte("hello world"), buf.Bytes()) + assertEqual(t, 11, wrap.BytesWritten()) + }) + + t.Run("Without Tee", func(t *testing.T) { + // explicitly create the struct instead of NewRecorder to control the value of Code + original := &httptest.ResponseRecorder{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + } + wrap := &basicWriter{ResponseWriter: original} + wrap.Discard() + + _, err := wrap.Write([]byte("hello world")) + assertNoError(t, err) + + assertEqual(t, 0, original.Code) // wrapper shouldn't call WriteHeader implicitly + assertEqual(t, 0, original.Body.Len()) + assertEqual(t, 11, wrap.BytesWritten()) + }) +}