diff --git a/marshal.go b/marshal.go index 0d55c6c..376d4bb 100644 --- a/marshal.go +++ b/marshal.go @@ -118,93 +118,95 @@ func marshalImpl(vo reflect.Value, w io.Writer, pos uint64, parent *Element, opt if tag.name == "" { tag.name = tn.Name } - if t, err := ElementTypeFromString(tag.name); err == nil { - e, ok := table[t] - if !ok { - return pos, ErrUnsupportedElement - } + t, err := ElementTypeFromString(tag.name) + if err != nil { + return pos, err + } + e, ok := table[t] + if !ok { + return pos, ErrUnsupportedElement + } - unknown := tag.size == sizeUnknown + unknown := tag.size == sizeUnknown - lst, ok := pealElem(vn, e.t == TypeBinary, tag.omitEmpty) - if !ok { - continue + lst, ok := pealElem(vn, e.t == TypeBinary, tag.omitEmpty) + if !ok { + continue + } + + for _, vn := range lst { + // Write element ID + var headerSize uint64 + n, err := w.Write(e.b) + if err != nil { + return pos, err } + headerSize += uint64(n) - for _, vn := range lst { - // Write element ID - var headerSize uint64 - n, err := w.Write(e.b) + var bw io.Writer + if unknown { + // Directly write length unspecified element + bsz := encodeDataSize(uint64(sizeUnknown), 0) + n, err := w.Write(bsz) if err != nil { return pos, err } headerSize += uint64(n) + bw = w + } else { + bw = &bytes.Buffer{} + } - var bw io.Writer - if unknown { - // Directly write length unspecified element - bsz := encodeDataSize(uint64(sizeUnknown), 0) - n, err := w.Write(bsz) - if err != nil { - return pos, err - } - headerSize += uint64(n) - bw = w - } else { - bw = &bytes.Buffer{} + var elem *Element + if len(options.hooks) > 0 { + elem = &Element{ + Value: vn.Interface(), + Name: tag.name, + Position: pos, + Size: sizeUnknown, + Parent: parent, } + } - var elem *Element - if len(options.hooks) > 0 { - elem = &Element{ - Value: vn.Interface(), - Name: tag.name, - Position: pos, - Size: sizeUnknown, - Parent: parent, - } + var size uint64 + if e.t == TypeMaster { + p, err := marshalImpl(vn, bw, pos+headerSize, elem, options) + if err != nil { + return pos, err } - - var size uint64 - if e.t == TypeMaster { - p, err := marshalImpl(vn, bw, pos+headerSize, elem, options) - if err != nil { - return pos, err - } - size = p - pos - headerSize - } else { - bc, err := perTypeEncoder[e.t](vn.Interface(), tag.size) - if err != nil { - return pos, err - } - n, err := bw.Write(bc) - if err != nil { - return pos, err - } - size = uint64(n) + size = p - pos - headerSize + } else { + bc, err := perTypeEncoder[e.t](vn.Interface(), tag.size) + if err != nil { + return pos, err + } + n, err := bw.Write(bc) + if err != nil { + return pos, err } + size = uint64(n) + } - // Write element with length - if !unknown { - if len(options.hooks) > 0 { - elem.Size = size - } - bsz := encodeDataSize(size, options.dataSizeLen) - n, err := w.Write(bsz) - if err != nil { - return pos, err - } - headerSize += uint64(n) - - if _, err := w.Write(bw.(*bytes.Buffer).Bytes()); err != nil { - return pos, err - } + // Write element with length + if !unknown { + if len(options.hooks) > 0 { + elem.Size = size } - for _, cb := range options.hooks { - cb(elem) + bsz := encodeDataSize(size, options.dataSizeLen) + n, err := w.Write(bsz) + if err != nil { + return pos, err + } + headerSize += uint64(n) + + if _, err := w.Write(bw.(*bytes.Buffer).Bytes()); err != nil { + return pos, err } - pos += headerSize + size } + for _, cb := range options.hooks { + cb(elem) + } + pos += headerSize + size } } return pos, nil diff --git a/marshal_test.go b/marshal_test.go index 577be73..4b3770d 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -207,6 +207,28 @@ func TestMarshal(t *testing.T) { } } +func TestMarshal_Error(t *testing.T) { + testCases := map[string]struct { + input interface{} + err error + }{ + "InvalidElementName": { + &struct { + Invalid uint64 `ebml:"Invalid"` + }{}, + ErrUnknownElementName, + }, + } + for n, c := range testCases { + t.Run(n, func(t *testing.T) { + var b bytes.Buffer + if err := Marshal(c.input, &b); err != c.err { + t.Fatalf("Unexpected error, expected: %v, got: %v", c.err, err) + } + }) + } +} + func TestMarshal_OptionError(t *testing.T) { errExpected := errors.New("an error") err := Marshal(&struct{}{}, &bytes.Buffer{}, diff --git a/mkvcore/blockwriter_test.go b/mkvcore/blockwriter_test.go index f2948e4..2c5d4ff 100644 --- a/mkvcore/blockwriter_test.go +++ b/mkvcore/blockwriter_test.go @@ -125,12 +125,9 @@ func TestBlockWriter(t *testing.T) { func TestBlockWriter_Options(t *testing.T) { buf := &bufferCloser{closed: make(chan struct{})} - tracks := []TrackDescription{ - {TrackNumber: 1}, - } - ws, err := NewSimpleBlockWriter( - buf, tracks, + buf, + []TrackDescription{{TrackNumber: 1}}, WithEBMLHeader(&struct { DocTypeVersion uint64 `ebml:"EBMLDocTypeVersion"` }{}), @@ -251,9 +248,6 @@ func (w *errorWriter) Close() error { } func TestBlockWriter_ErrorHandling(t *testing.T) { - tracks := []TrackDescription{ - {TrackNumber: 1}, - } const ( atBeginning int = iota @@ -289,7 +283,8 @@ func TestBlockWriter_ErrorHandling(t *testing.T) { } clearErr() ws, err := NewSimpleBlockWriter( - w, tracks, + w, + []TrackDescription{{TrackNumber: 1}}, WithOnErrorHandler(func(err error) { chError <- err }), WithOnFatalHandler(func(err error) { chFatal <- err }), ) @@ -398,12 +393,9 @@ func TestBlockWriter_ErrorHandling(t *testing.T) { func TestBlockWriter_WithMaxKeyframeInterval(t *testing.T) { buf := &bufferCloser{closed: make(chan struct{})} - tracks := []TrackDescription{ - {TrackNumber: 1}, - } - ws, err := NewSimpleBlockWriter( - buf, tracks, + buf, + []TrackDescription{{TrackNumber: 1}}, WithEBMLHeader(nil), WithSegmentInfo(nil), WithMaxKeyframeInterval(1, 900*0x6FFF), @@ -460,51 +452,65 @@ func TestBlockWriter_WithMaxKeyframeInterval(t *testing.T) { } func TestBlockWriter_WithSeekHead(t *testing.T) { - buf := &bufferCloser{closed: make(chan struct{})} - - tracks := []TrackDescription{ - {TrackNumber: 1}, - } - - ws, err := NewSimpleBlockWriter( - buf, tracks, - WithEBMLHeader(nil), - WithSegmentInfo(&struct { - TimecodeScale uint64 `ebml:"TimecodeScale"` - }{TimecodeScale: 1000000}), - WithSeekHead(true), - ) - if err != nil { - t.Fatalf("Failed to create BlockWriter: %v", err) - } - if len(ws) != 1 { - t.Fatalf("Number of the returned writer must be 1, but got %d", len(ws)) - } - - ws[0].Close() + t.Run("GenerateSeekHead", func(t *testing.T) { + buf := &bufferCloser{closed: make(chan struct{})} + + ws, err := NewSimpleBlockWriter( + buf, + []TrackDescription{{TrackNumber: 1}}, + WithEBMLHeader(nil), + WithSegmentInfo(&struct { + TimecodeScale uint64 `ebml:"TimecodeScale"` + }{TimecodeScale: 1000000}), + WithSeekHead(true), + ) + if err != nil { + t.Fatalf("Failed to create BlockWriter: %v", err) + } + if len(ws) != 1 { + t.Fatalf("Number of the returned writer must be 1, but got %d", len(ws)) + } - expectedBytes := []byte{ - // 1 2 3 4 5 6 7 8 9 10 11 12 - // Segment - 0x18, 0x53, 0x80, 0x67, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, - // SeekHead - 0x11, 0x4D, 0x9B, 0x74, 0xAA, - 0x4D, 0xBB, 0x92, - 0x53, 0xAB, 0x84, 0x15, 0x49, 0xA9, 0x66, // Info - 0x53, 0xAC, 0x88, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2F, - 0x4D, 0xBB, 0x92, - 0x53, 0xAB, 0x84, 0x16, 0x54, 0xAE, 0x6B, // Tracks - 0x53, 0xAC, 0x88, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3B, - // Info, pos: 47 - 0x15, 0x49, 0xA9, 0x66, 0x87, - 0x2A, 0xD7, 0xB1, 0x83, 0x0F, 0x42, 0x40, - // Tracks, pos: 59 - 0x16, 0x54, 0xAE, 0x6B, 0x80, - // Cluster - 0x1F, 0x43, 0xB6, 0x75, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, - 0xE7, 0x81, 0x00, - } - if !bytes.Equal(buf.Bytes(), expectedBytes) { - t.Errorf("Unexpected binary,\nexpected: %+v\n got: %+v", expectedBytes, buf.Bytes()) - } + ws[0].Close() + + expectedBytes := []byte{ + // 1 2 3 4 5 6 7 8 9 10 11 12 + // Segment + 0x18, 0x53, 0x80, 0x67, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + // SeekHead + 0x11, 0x4D, 0x9B, 0x74, 0xAA, + 0x4D, 0xBB, 0x92, + 0x53, 0xAB, 0x84, 0x15, 0x49, 0xA9, 0x66, // Info + 0x53, 0xAC, 0x88, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2F, + 0x4D, 0xBB, 0x92, + 0x53, 0xAB, 0x84, 0x16, 0x54, 0xAE, 0x6B, // Tracks + 0x53, 0xAC, 0x88, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3B, + // Info, pos: 47 + 0x15, 0x49, 0xA9, 0x66, 0x87, + 0x2A, 0xD7, 0xB1, 0x83, 0x0F, 0x42, 0x40, + // Tracks, pos: 59 + 0x16, 0x54, 0xAE, 0x6B, 0x80, + // Cluster + 0x1F, 0x43, 0xB6, 0x75, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xE7, 0x81, 0x00, + } + if !bytes.Equal(buf.Bytes(), expectedBytes) { + t.Errorf("Unexpected binary,\nexpected: %+v\n got: %+v", expectedBytes, buf.Bytes()) + } + }) + t.Run("InvalidHeader", func(t *testing.T) { + buf := &bufferCloser{closed: make(chan struct{})} + + _, err := NewSimpleBlockWriter( + buf, + []TrackDescription{{TrackNumber: 1}}, + WithSegmentInfo(&struct { + Invalid uint64 `ebml:"InvalidA"` + }{}), + WithSeekHead(true), + ) + if err != ebml.ErrUnknownElementName { + t.Errorf("Unexpected error, expected: %v, got: %v", ebml.ErrUnknownElementName, err) + } + }) } diff --git a/mkvcore/seekhead.go b/mkvcore/seekhead.go index 40b6836..5ff3921 100644 --- a/mkvcore/seekhead.go +++ b/mkvcore/seekhead.go @@ -52,8 +52,8 @@ func setSeekHead(header *flexHeader, opts ...ebml.MarshalOption) error { optsWithHook := append([]ebml.MarshalOption{}, opts...) optsWithHook = append(optsWithHook, ebml.WithElementWriteHooks(hook)) - buf := &bytes.Buffer{} - if err := ebml.Marshal(header, buf, optsWithHook...); err != nil { + var buf bytes.Buffer + if err := ebml.Marshal(header, &buf, optsWithHook...); err != nil { return err }