Skip to content

Commit

Permalink
Fix StreamCopier not including file headers into output (#6)
Browse files Browse the repository at this point in the history
Fix StreamCopier not including file headers into output; for #4
  • Loading branch information
gabriel-vasile authored Jul 29, 2022
1 parent 8e7cbb2 commit fca9346
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 18 deletions.
42 changes: 27 additions & 15 deletions copier/stream_copier.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package copier

import (
"bufio"
"bytes"
"errors"
"github.com/gabriel-vasile/mimetype"
"github.com/google/uuid"
"github.com/rs/zerolog"
"io"
"mime"
"net/http"
"os"

"github.com/gabriel-vasile/mimetype"
"github.com/google/uuid"
"github.com/rs/zerolog"
)

const (
Expand Down Expand Up @@ -60,7 +61,7 @@ func (d *StreamCopier) CopyStream(url string, getOutput GetOutputFunc) error {

log.Info().Msg("recording started")

fileExtension, err := DetectExtension(resp)
fileExtension, body, err := DetectExtension(resp)
if err != nil {
return err
}
Expand All @@ -79,17 +80,21 @@ func (d *StreamCopier) CopyStream(url string, getOutput GetOutputFunc) error {
log.Debug().Str("filename", file.Name()).Msg("output in file")
}

bytesCopied, err := io.Copy(output, resp.Body)
bytesCopied, err := io.Copy(output, body)
log.Debug().Int64("bytes_copied", bytesCopied).Msg("copied bytes")

log.Info().Msg("recording finished")
return err
}

func DetectExtension(r *http.Response) (string, error) {
// DetectExtension returns response extension by first looking into response headers.
// As a fallback, it looks into response body and returns the extension and a new
// body containing the original content.
func DetectExtension(r *http.Response) (string, io.Reader, error) {
contentType := r.Header.Get("Content-Type")
if contentType != "" {
return extensionFromContentType(contentType)
ext, err := extensionFromContentType(contentType)
return ext, r.Body, err
}
return extensionFromBody(r.Body)
}
Expand All @@ -112,11 +117,18 @@ func isSupportedContentType(contentType string) bool {
return err == nil
}

func extensionFromBody(body io.Reader) (string, error) {
buf := bufio.NewReader(body)
fileHeader, err := buf.Peek(fileHeaderSize)
if err != nil && err != io.EOF {
return "", err
}
return mimetype.Detect(fileHeader).Extension(), nil
// extensionFromBody returns the extension of the file contained by body and a
// new body containing the original input file.
func extensionFromBody(body io.Reader) (ext string, newBody io.Reader, err error) {
// header will store the bytes mimetype uses for detection.
header := bytes.NewBuffer(nil)

// After DetectReader, the data read from input is copied into header.
mtype, err := mimetype.DetectReader(io.TeeReader(body, header))

// Concatenate back the header to the rest of the file.
// newBody now contains the complete, original data.
newBody = io.MultiReader(header, body)

return mtype.Extension(), newBody, err
}
34 changes: 31 additions & 3 deletions copier/stream_copier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package copier

import (
"bytes"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"io"
"log"
"net/http"
"net/http/httptest"
"testing"

"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
)

var testLogger = zerolog.New(nil)
Expand Down Expand Up @@ -165,10 +166,37 @@ func TestDetectExtension(t *testing.T) {
},
}

actual, err := DetectExtension(r)
actual, _, err := DetectExtension(r)

assert.NoError(t, err)
assert.Equal(t, tc.expected, actual)
})
}
}

// Check if the files copied by StreamCopier have the same length as the original
// files.
func TestStreamCopierFileLength(t *testing.T) {
response := []byte("Hello World!")
handler := func(w http.ResponseWriter, _ *http.Request) {
// Force extension detection to look into the response body.
w.Header().Set("Content-type", "")
w.WriteHeader(http.StatusOK)
_, err := w.Write(response)
if err != nil {
log.Println(err)
}
}
server := httptest.NewServer(http.HandlerFunc(handler))
defer server.Close()

copier := NewStreamCopier(http.DefaultClient, testLogger)
output := new(closableBuffer)

err := copier.CopyStream(server.URL, func(_ string) (io.WriteCloser, error) {
return output, nil
})
assert.Equal(t, err, nil)
// Check if output length is what we expect.
assert.Equal(t, len(response), output.Len())
}

0 comments on commit fca9346

Please sign in to comment.