Skip to content

Commit

Permalink
Minor refactoring (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
at-wat authored Dec 17, 2019
1 parent 154310b commit 4c648e0
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 132 deletions.
142 changes: 72 additions & 70 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,93 +118,95 @@ func marshalImpl(vo reflect.Value, w io.Writer, pos uint64, parent *Element, opt
if tag.name == "" {
tag.name = tn.Name
}
if t, err := ElementTypeFromString(tag.name); err == nil {
e, ok := table[t]
if !ok {
return pos, ErrUnsupportedElement
}
t, err := ElementTypeFromString(tag.name)
if err != nil {
return pos, err
}
e, ok := table[t]
if !ok {
return pos, ErrUnsupportedElement
}

unknown := tag.size == sizeUnknown
unknown := tag.size == sizeUnknown

lst, ok := pealElem(vn, e.t == TypeBinary, tag.omitEmpty)
if !ok {
continue
lst, ok := pealElem(vn, e.t == TypeBinary, tag.omitEmpty)
if !ok {
continue
}

for _, vn := range lst {
// Write element ID
var headerSize uint64
n, err := w.Write(e.b)
if err != nil {
return pos, err
}
headerSize += uint64(n)

for _, vn := range lst {
// Write element ID
var headerSize uint64
n, err := w.Write(e.b)
var bw io.Writer
if unknown {
// Directly write length unspecified element
bsz := encodeDataSize(uint64(sizeUnknown), 0)
n, err := w.Write(bsz)
if err != nil {
return pos, err
}
headerSize += uint64(n)
bw = w
} else {
bw = &bytes.Buffer{}
}

var bw io.Writer
if unknown {
// Directly write length unspecified element
bsz := encodeDataSize(uint64(sizeUnknown), 0)
n, err := w.Write(bsz)
if err != nil {
return pos, err
}
headerSize += uint64(n)
bw = w
} else {
bw = &bytes.Buffer{}
var elem *Element
if len(options.hooks) > 0 {
elem = &Element{
Value: vn.Interface(),
Name: tag.name,
Position: pos,
Size: sizeUnknown,
Parent: parent,
}
}

var elem *Element
if len(options.hooks) > 0 {
elem = &Element{
Value: vn.Interface(),
Name: tag.name,
Position: pos,
Size: sizeUnknown,
Parent: parent,
}
var size uint64
if e.t == TypeMaster {
p, err := marshalImpl(vn, bw, pos+headerSize, elem, options)
if err != nil {
return pos, err
}

var size uint64
if e.t == TypeMaster {
p, err := marshalImpl(vn, bw, pos+headerSize, elem, options)
if err != nil {
return pos, err
}
size = p - pos - headerSize
} else {
bc, err := perTypeEncoder[e.t](vn.Interface(), tag.size)
if err != nil {
return pos, err
}
n, err := bw.Write(bc)
if err != nil {
return pos, err
}
size = uint64(n)
size = p - pos - headerSize
} else {
bc, err := perTypeEncoder[e.t](vn.Interface(), tag.size)
if err != nil {
return pos, err
}
n, err := bw.Write(bc)
if err != nil {
return pos, err
}
size = uint64(n)
}

// Write element with length
if !unknown {
if len(options.hooks) > 0 {
elem.Size = size
}
bsz := encodeDataSize(size, options.dataSizeLen)
n, err := w.Write(bsz)
if err != nil {
return pos, err
}
headerSize += uint64(n)

if _, err := w.Write(bw.(*bytes.Buffer).Bytes()); err != nil {
return pos, err
}
// Write element with length
if !unknown {
if len(options.hooks) > 0 {
elem.Size = size
}
for _, cb := range options.hooks {
cb(elem)
bsz := encodeDataSize(size, options.dataSizeLen)
n, err := w.Write(bsz)
if err != nil {
return pos, err
}
headerSize += uint64(n)

if _, err := w.Write(bw.(*bytes.Buffer).Bytes()); err != nil {
return pos, err
}
pos += headerSize + size
}
for _, cb := range options.hooks {
cb(elem)
}
pos += headerSize + size
}
}
return pos, nil
Expand Down
22 changes: 22 additions & 0 deletions marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,28 @@ func TestMarshal(t *testing.T) {
}
}

func TestMarshal_Error(t *testing.T) {
testCases := map[string]struct {
input interface{}
err error
}{
"InvalidElementName": {
&struct {
Invalid uint64 `ebml:"Invalid"`
}{},
ErrUnknownElementName,
},
}
for n, c := range testCases {
t.Run(n, func(t *testing.T) {
var b bytes.Buffer
if err := Marshal(c.input, &b); err != c.err {
t.Fatalf("Unexpected error, expected: %v, got: %v", c.err, err)
}
})
}
}

func TestMarshal_OptionError(t *testing.T) {
errExpected := errors.New("an error")
err := Marshal(&struct{}{}, &bytes.Buffer{},
Expand Down
126 changes: 66 additions & 60 deletions mkvcore/blockwriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,9 @@ func TestBlockWriter(t *testing.T) {
func TestBlockWriter_Options(t *testing.T) {
buf := &bufferCloser{closed: make(chan struct{})}

tracks := []TrackDescription{
{TrackNumber: 1},
}

ws, err := NewSimpleBlockWriter(
buf, tracks,
buf,
[]TrackDescription{{TrackNumber: 1}},
WithEBMLHeader(&struct {
DocTypeVersion uint64 `ebml:"EBMLDocTypeVersion"`
}{}),
Expand Down Expand Up @@ -251,9 +248,6 @@ func (w *errorWriter) Close() error {
}

func TestBlockWriter_ErrorHandling(t *testing.T) {
tracks := []TrackDescription{
{TrackNumber: 1},
}

const (
atBeginning int = iota
Expand Down Expand Up @@ -289,7 +283,8 @@ func TestBlockWriter_ErrorHandling(t *testing.T) {
}
clearErr()
ws, err := NewSimpleBlockWriter(
w, tracks,
w,
[]TrackDescription{{TrackNumber: 1}},
WithOnErrorHandler(func(err error) { chError <- err }),
WithOnFatalHandler(func(err error) { chFatal <- err }),
)
Expand Down Expand Up @@ -398,12 +393,9 @@ func TestBlockWriter_ErrorHandling(t *testing.T) {
func TestBlockWriter_WithMaxKeyframeInterval(t *testing.T) {
buf := &bufferCloser{closed: make(chan struct{})}

tracks := []TrackDescription{
{TrackNumber: 1},
}

ws, err := NewSimpleBlockWriter(
buf, tracks,
buf,
[]TrackDescription{{TrackNumber: 1}},
WithEBMLHeader(nil),
WithSegmentInfo(nil),
WithMaxKeyframeInterval(1, 900*0x6FFF),
Expand Down Expand Up @@ -460,51 +452,65 @@ func TestBlockWriter_WithMaxKeyframeInterval(t *testing.T) {
}

func TestBlockWriter_WithSeekHead(t *testing.T) {
buf := &bufferCloser{closed: make(chan struct{})}

tracks := []TrackDescription{
{TrackNumber: 1},
}

ws, err := NewSimpleBlockWriter(
buf, tracks,
WithEBMLHeader(nil),
WithSegmentInfo(&struct {
TimecodeScale uint64 `ebml:"TimecodeScale"`
}{TimecodeScale: 1000000}),
WithSeekHead(true),
)
if err != nil {
t.Fatalf("Failed to create BlockWriter: %v", err)
}
if len(ws) != 1 {
t.Fatalf("Number of the returned writer must be 1, but got %d", len(ws))
}

ws[0].Close()
t.Run("GenerateSeekHead", func(t *testing.T) {
buf := &bufferCloser{closed: make(chan struct{})}

ws, err := NewSimpleBlockWriter(
buf,
[]TrackDescription{{TrackNumber: 1}},
WithEBMLHeader(nil),
WithSegmentInfo(&struct {
TimecodeScale uint64 `ebml:"TimecodeScale"`
}{TimecodeScale: 1000000}),
WithSeekHead(true),
)
if err != nil {
t.Fatalf("Failed to create BlockWriter: %v", err)
}
if len(ws) != 1 {
t.Fatalf("Number of the returned writer must be 1, but got %d", len(ws))
}

expectedBytes := []byte{
// 1 2 3 4 5 6 7 8 9 10 11 12
// Segment
0x18, 0x53, 0x80, 0x67, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
// SeekHead
0x11, 0x4D, 0x9B, 0x74, 0xAA,
0x4D, 0xBB, 0x92,
0x53, 0xAB, 0x84, 0x15, 0x49, 0xA9, 0x66, // Info
0x53, 0xAC, 0x88, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2F,
0x4D, 0xBB, 0x92,
0x53, 0xAB, 0x84, 0x16, 0x54, 0xAE, 0x6B, // Tracks
0x53, 0xAC, 0x88, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3B,
// Info, pos: 47
0x15, 0x49, 0xA9, 0x66, 0x87,
0x2A, 0xD7, 0xB1, 0x83, 0x0F, 0x42, 0x40,
// Tracks, pos: 59
0x16, 0x54, 0xAE, 0x6B, 0x80,
// Cluster
0x1F, 0x43, 0xB6, 0x75, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xE7, 0x81, 0x00,
}
if !bytes.Equal(buf.Bytes(), expectedBytes) {
t.Errorf("Unexpected binary,\nexpected: %+v\n got: %+v", expectedBytes, buf.Bytes())
}
ws[0].Close()

expectedBytes := []byte{
// 1 2 3 4 5 6 7 8 9 10 11 12
// Segment
0x18, 0x53, 0x80, 0x67, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
// SeekHead
0x11, 0x4D, 0x9B, 0x74, 0xAA,
0x4D, 0xBB, 0x92,
0x53, 0xAB, 0x84, 0x15, 0x49, 0xA9, 0x66, // Info
0x53, 0xAC, 0x88, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2F,
0x4D, 0xBB, 0x92,
0x53, 0xAB, 0x84, 0x16, 0x54, 0xAE, 0x6B, // Tracks
0x53, 0xAC, 0x88, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3B,
// Info, pos: 47
0x15, 0x49, 0xA9, 0x66, 0x87,
0x2A, 0xD7, 0xB1, 0x83, 0x0F, 0x42, 0x40,
// Tracks, pos: 59
0x16, 0x54, 0xAE, 0x6B, 0x80,
// Cluster
0x1F, 0x43, 0xB6, 0x75, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xE7, 0x81, 0x00,
}
if !bytes.Equal(buf.Bytes(), expectedBytes) {
t.Errorf("Unexpected binary,\nexpected: %+v\n got: %+v", expectedBytes, buf.Bytes())
}
})
t.Run("InvalidHeader", func(t *testing.T) {
buf := &bufferCloser{closed: make(chan struct{})}

_, err := NewSimpleBlockWriter(
buf,
[]TrackDescription{{TrackNumber: 1}},
WithSegmentInfo(&struct {
Invalid uint64 `ebml:"InvalidA"`
}{}),
WithSeekHead(true),
)
if err != ebml.ErrUnknownElementName {
t.Errorf("Unexpected error, expected: %v, got: %v", ebml.ErrUnknownElementName, err)
}
})
}
4 changes: 2 additions & 2 deletions mkvcore/seekhead.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ func setSeekHead(header *flexHeader, opts ...ebml.MarshalOption) error {
optsWithHook := append([]ebml.MarshalOption{}, opts...)
optsWithHook = append(optsWithHook, ebml.WithElementWriteHooks(hook))

buf := &bytes.Buffer{}
if err := ebml.Marshal(header, buf, optsWithHook...); err != nil {
var buf bytes.Buffer
if err := ebml.Marshal(header, &buf, optsWithHook...); err != nil {
return err
}

Expand Down

0 comments on commit 4c648e0

Please sign in to comment.