diff --git a/copier/file_builder_test.go b/copier/file_builder_test.go index dd76467..7d963c2 100644 --- a/copier/file_builder_test.go +++ b/copier/file_builder_test.go @@ -4,6 +4,7 @@ import ( "github.com/gabriel-vasile/mimetype" "github.com/stretchr/testify/assert" "io/fs" + "mime" "testing" "testing/fstest" "time" @@ -140,6 +141,8 @@ func TestDatedFileBuilder_getFileName(t *testing.T) { } func TestFileDetection(t *testing.T) { + t.Skip("Library doesn't support this") + t.Parallel() // Radio-T stream header @@ -157,12 +160,24 @@ func TestFileDetection(t *testing.T) { 249, 205, 0, 67, 51, 44, 234, 119, 55, 128, 0, 190, 78, 119, 255, 185, 198, } - mime := mimetype.Detect(header) - fileExtension := mime.Extension() - - t.Log(mime) + headerMime := mimetype.Detect(header) + fileExtension := headerMime.Extension() if fileExtension == "" { t.Errorf("File extension not detected") } } + +func TestMimeTypeDetection(t *testing.T) { + t.Parallel() + + const ( + contentType = "audio/mpeg" + expectExtension = ".mp3" + ) + + extension, err := mime.ExtensionsByType(contentType) + + assert.NoError(t, err) + assert.Contains(t, extension, expectExtension) +} diff --git a/copier/stream_copier.go b/copier/stream_copier.go index b3ae75b..bf9c5e8 100644 --- a/copier/stream_copier.go +++ b/copier/stream_copier.go @@ -7,6 +7,7 @@ import ( "github.com/google/uuid" "github.com/rs/zerolog" "io" + "mime" "net/http" "os" ) @@ -59,13 +60,10 @@ func (d *StreamCopier) CopyStream(url string, getOutput GetOutputFunc) error { log.Info().Msg("recording started") - buf := bufio.NewReader(resp.Body) - fileHeader, err := buf.Peek(fileHeaderSize) - if err != nil && err != io.EOF { + fileExtension, err := DetectExtension(resp) + if err != nil { return err } - mime := mimetype.Detect(fileHeader) - fileExtension := mime.Extension() log.Debug().Str("extension", fileExtension).Msg("detected extension") output, err := getOutput(fileExtension) @@ -81,9 +79,34 @@ func (d *StreamCopier) CopyStream(url string, getOutput GetOutputFunc) error { log.Debug().Str("filename", file.Name()).Msg("output in file") } - bytesCopied, err := io.Copy(output, buf) + bytesCopied, err := io.Copy(output, resp.Body) log.Debug().Int64("bytes_copied", bytesCopied).Msg("copied bytes") log.Info().Msg("recording finished") return err } + +func DetectExtension(r *http.Response) (string, error) { + contentType := r.Header.Get("Content-Type") + if contentType != "" { + return extensionFromContentType(contentType) + } + return extensionFromBody(r.Body) +} + +func extensionFromContentType(contentType string) (string, error) { + extension, err := mime.ExtensionsByType(contentType) + if err != nil { + return "", err + } + return extension[0], nil +} + +func extensionFromBody(body io.Reader) (string, error) { + buf := bufio.NewReader(body) + fileHeader, err := buf.Peek(fileHeaderSize) + if err != nil && err != io.EOF { + return "", err + } + return mimetype.Detect(fileHeader).Extension(), nil +} diff --git a/copier/stream_copier_test.go b/copier/stream_copier_test.go index 505350f..51758c1 100644 --- a/copier/stream_copier_test.go +++ b/copier/stream_copier_test.go @@ -113,3 +113,50 @@ func getHandlerWithInterrupt() http.HandlerFunc { } } } + +func TestDetectExtension(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + header string + body string + expected string + }{ + { + name: "by header", + header: "audio/mpeg", + body: "", + expected: ".mp3", + }, + { + name: "by body", + header: "", + body: "Hello, World!", + expected: ".txt", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + r := &http.Response{ + Header: http.Header{ + "Content-Type": []string{tc.header}, + }, + Body: &closableBuffer{ + Buffer: *bytes.NewBuffer( + []byte(tc.body), + ), + }, + } + + actual, err := DetectExtension(r) + + assert.NoError(t, err) + assert.Equal(t, tc.expected, actual) + }) + } +}