Skip to content

Commit

Permalink
track buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
arjan-bal committed Jan 21, 2025
1 parent b615b35 commit 70f0fe6
Show file tree
Hide file tree
Showing 13 changed files with 128 additions and 89 deletions.
4 changes: 2 additions & 2 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ type codecV0Bridge struct {
func (c codecV0Bridge) Marshal(v any) (mem.BufferSlice, error) {
data, err := c.codec.Marshal(v)
if err != nil {
return nil, err
return mem.BufferSlice{}, err
}
return mem.BufferSlice{mem.SliceBuffer(data)}, nil
return mem.BufferSlice{Bufs: []mem.Buffer{mem.SliceBuffer(data)}}, nil
}

func (c codecV0Bridge) Unmarshal(data mem.BufferSlice, v any) (err error) {
Expand Down
2 changes: 1 addition & 1 deletion encoding/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ type errProtoCodec struct {

func (c *errProtoCodec) Marshal(v any) (mem.BufferSlice, error) {
if c.encodingErr != nil {
return nil, c.encodingErr
return mem.BufferSlice{}, c.encodingErr
}
return encoding.GetCodecV2(proto.Name).Marshal(v)
}
Expand Down
10 changes: 5 additions & 5 deletions encoding/proto/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,24 @@ type codecV2 struct{}
func (c *codecV2) Marshal(v any) (data mem.BufferSlice, err error) {
vv := messageV2Of(v)
if vv == nil {
return nil, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v)
return mem.BufferSlice{}, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v)
}

size := proto.Size(vv)
if mem.IsBelowBufferPoolingThreshold(size) {
buf, err := proto.Marshal(vv)
if err != nil {
return nil, err
return mem.BufferSlice{}, err
}
data = append(data, mem.SliceBuffer(buf))
data.Bufs = append(data.Bufs, mem.SliceBuffer(buf))
} else {
pool := mem.DefaultBufferPool()
buf := pool.Get(size)
if _, err := (proto.MarshalOptions{}).MarshalAppend((*buf)[:0], vv); err != nil {
pool.Put(buf)
return nil, err
return mem.BufferSlice{}, err
}
data = append(data, mem.NewBuffer(buf, pool))
data.Bufs = append(data.Bufs, mem.NewBuffer(buf, pool))
}

return data, nil
Expand Down
2 changes: 1 addition & 1 deletion internal/transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ func (ht *serverHandlerTransport) write(s *ServerStream, hdr []byte, data mem.Bu
ht.writePendingHeaders(s)
}
ht.rw.Write(hdr)
for _, b := range data {
for _, b := range data.Bufs {
_, _ = ht.rw.Write(b.ReadOnlyData())
}
ht.rw.(http.Flusher).Flush()
Expand Down
6 changes: 3 additions & 3 deletions internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ func (s *Stream) ReadMessageHeader(header []byte) (err error) {
func (s *Stream) read(n int) (data mem.BufferSlice, err error) {
// Don't request a read if there was an error earlier
if er := s.trReader.er; er != nil {
return nil, er
return mem.BufferSlice{}, er
}
s.requestRead(n)
for n != 0 {
Expand All @@ -394,9 +394,9 @@ func (s *Stream) read(n int) (data mem.BufferSlice, err error) {
err = io.ErrUnexpectedEOF
}
data.Free()
return nil, err
return mem.BufferSlice{}, err
}
data = append(data, buf)
data.Bufs = append(data.Bufs, buf)
}
return data, nil
}
Expand Down
8 changes: 4 additions & 4 deletions internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func init() {
}

func newBufferSlice(b []byte) mem.BufferSlice {
return mem.BufferSlice{mem.SliceBuffer(b)}
return mem.BufferSlice{Bufs: []mem.Buffer{mem.SliceBuffer(b)}}
}

func (s *Stream) readTo(p []byte) (int, error) {
Expand Down Expand Up @@ -879,7 +879,7 @@ func (s) TestGracefulClose(t *testing.T) {
}

// Confirm the existing stream still functions as expected.
s.Write(nil, nil, &WriteOptions{Last: true})
s.Write(nil, mem.BufferSlice{}, &WriteOptions{Last: true})
if _, err := s.readTo(incomingHeader); err != io.EOF {
t.Fatalf("Client expected EOF from the server. Got: %v", err)
}
Expand Down Expand Up @@ -1714,7 +1714,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
st.mu.Unlock()
// Close all streams
for _, stream := range clientStreams {
stream.Write(nil, nil, &WriteOptions{Last: true})
stream.Write(nil, mem.BufferSlice{}, &WriteOptions{Last: true})
if _, err := stream.readTo(make([]byte, 5)); err != io.EOF {
t.Fatalf("Client expected an EOF from the server. Got: %v", err)
}
Expand Down Expand Up @@ -2296,7 +2296,7 @@ func runPingPongTest(t *testing.T, msgSize int) {
}
}

stream.Write(nil, nil, &WriteOptions{Last: true})
stream.Write(nil, mem.BufferSlice{}, &WriteOptions{Last: true})
if _, err := stream.readTo(incomingHeader); err != io.EOF {
t.Fatalf("Client expected EOF from the server. Got: %v", err)
}
Expand Down
45 changes: 28 additions & 17 deletions mem/buffer_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ const (
// By convention, any APIs that return (mem.BufferSlice, error) should reduce
// the burden on the caller by never returning a mem.BufferSlice that needs to
// be freed if the error is non-nil, unless explicitly stated.
type BufferSlice []Buffer
type BufferSlice struct {
Bufs []Buffer
track bool
}

// Len returns the sum of the length of all the Buffers in this slice.
//
Expand All @@ -51,22 +54,22 @@ type BufferSlice []Buffer
// in the slice, and *not* the value returned by this function.
func (s BufferSlice) Len() int {
var length int
for _, b := range s {
for _, b := range s.Bufs {
length += b.Len()
}
return length
}

// Ref invokes Ref on each buffer in the slice.
func (s BufferSlice) Ref() {
for _, b := range s {
for _, b := range s.Bufs {
b.Ref()
}
}

// Free invokes Buffer.Free() on each Buffer in the slice.
func (s BufferSlice) Free() {
for _, b := range s {
for _, b := range s.Bufs {
b.Free()
}
}
Expand All @@ -77,7 +80,7 @@ func (s BufferSlice) Free() {
// is full or s runs out of data, returning the minimum of s.Len() and len(dst).
func (s BufferSlice) CopyTo(dst []byte) int {
off := 0
for _, b := range s {
for _, b := range s.Bufs {
off += copy(dst[off:], b.ReadOnlyData())
}
return off
Expand All @@ -102,9 +105,9 @@ func (s BufferSlice) Materialize() []byte {
// function simply increases the refcount before returning said Buffer. Freeing this
// buffer won't release it until the BufferSlice is itself released.
func (s BufferSlice) MaterializeToBuffer(pool BufferPool) Buffer {
if len(s) == 1 {
s[0].Ref()
return s[0]
if len(s.Bufs) == 1 {
s.Bufs[0].Ref()
return s.Bufs[0]
}
sLen := s.Len()
if sLen == 0 {
Expand All @@ -115,6 +118,14 @@ func (s BufferSlice) MaterializeToBuffer(pool BufferPool) Buffer {
return NewBuffer(buf, pool)
}

// Track enables tracking.
func (s BufferSlice) Track() {
s.track = true

Check failure on line 123 in mem/buffer_slice.go

View workflow job for this annotation

GitHub Actions / tests (vet, 1.22)

ineffective assignment to field BufferSlice.track (SA4005)
for _, b := range s.Bufs {
b.Track()
}
}

// Reader returns a new Reader for the input slice after taking references to
// each underlying buffer.
func (s BufferSlice) Reader() Reader {
Expand Down Expand Up @@ -152,18 +163,18 @@ func (r *sliceReader) Remaining() int {

func (r *sliceReader) Close() error {
r.data.Free()
r.data = nil
r.data.Bufs = nil
r.len = 0
return nil
}

func (r *sliceReader) freeFirstBufferIfEmpty() bool {
if len(r.data) == 0 || r.bufferIdx != len(r.data[0].ReadOnlyData()) {
if len(r.data.Bufs) == 0 || r.bufferIdx != len(r.data.Bufs[0].ReadOnlyData()) {
return false
}

r.data[0].Free()
r.data = r.data[1:]
r.data.Bufs[0].Free()
r.data.Bufs = r.data.Bufs[1:]
r.bufferIdx = 0
return true
}
Expand All @@ -176,7 +187,7 @@ func (r *sliceReader) Read(buf []byte) (n int, _ error) {
for len(buf) != 0 && r.len != 0 {
// Copy as much as possible from the first Buffer in the slice into the
// given byte slice.
data := r.data[0].ReadOnlyData()
data := r.data.Bufs[0].ReadOnlyData()
copied := copy(buf, data[r.bufferIdx:])
r.len -= copied // Reduce len by the number of bytes copied.
r.bufferIdx += copied // Increment the buffer index.
Expand All @@ -201,7 +212,7 @@ func (r *sliceReader) ReadByte() (byte, error) {
for r.freeFirstBufferIfEmpty() {
}

b := r.data[0].ReadOnlyData()[r.bufferIdx]
b := r.data.Bufs[0].ReadOnlyData()[r.bufferIdx]
r.len--
r.bufferIdx++
// Free the first buffer in the slice if the last byte was read
Expand All @@ -218,7 +229,7 @@ type writer struct {

func (w *writer) Write(p []byte) (n int, err error) {
b := Copy(p, w.pool)
*w.buffers = append(*w.buffers, b)
*&w.buffers.Bufs = append(*&w.buffers.Bufs, b)

Check failure on line 232 in mem/buffer_slice.go

View workflow job for this annotation

GitHub Actions / tests (vet, 1.22)

*&x will be simplified to x. It will not copy x. (SA4001)

Check failure on line 232 in mem/buffer_slice.go

View workflow job for this annotation

GitHub Actions / tests (vet, 1.22)

*&x will be simplified to x. It will not copy x. (SA4001)
return b.Len(), nil
}

Expand Down Expand Up @@ -265,15 +276,15 @@ nextBuffer:
pool.Put(buf)
} else {
*buf = (*buf)[:usedCap]
result = append(result, NewBuffer(buf, pool))
result.Bufs = append(result.Bufs, NewBuffer(buf, pool))
}
if err == io.EOF {
err = nil
}
return result, err
}
if len(*buf) == usedCap {
result = append(result, NewBuffer(buf, pool))
result.Bufs = append(result.Bufs, NewBuffer(buf, pool))
continue nextBuffer
}
}
Expand Down
32 changes: 16 additions & 16 deletions mem/buffer_slice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,21 @@ func (s) TestBufferSlice_Len(t *testing.T) {
}{
{
name: "empty",
in: nil,
in: mem.BufferSlice{},
want: 0,
},
{
name: "single",
in: mem.BufferSlice{newBuffer([]byte("abcd"), nil)},
in: mem.BufferSlice{Bufs: []mem.Buffer{newBuffer([]byte("abcd"), nil)}},
want: 4,
},
{
name: "multiple",
in: mem.BufferSlice{
in: mem.BufferSlice{Bufs: []mem.Buffer{
newBuffer([]byte("abcd"), nil),
newBuffer([]byte("abcd"), nil),
newBuffer([]byte("abcd"), nil),
},
}},
want: 12,
},
}
Expand All @@ -76,10 +76,10 @@ func (s) TestBufferSlice_Len(t *testing.T) {

func (s) TestBufferSlice_Ref(t *testing.T) {
// Create a new buffer slice and a reference to it.
bs := mem.BufferSlice{
bs := mem.BufferSlice{Bufs: []mem.Buffer{
newBuffer([]byte("abcd"), nil),
newBuffer([]byte("abcd"), nil),
}
}}
bs.Ref()

// Free the original buffer slice and verify that the reference can still
Expand All @@ -101,17 +101,17 @@ func (s) TestBufferSlice_MaterializeToBuffer(t *testing.T) {
}{
{
name: "single",
in: mem.BufferSlice{newBuffer([]byte("abcd"), nil)},
in: mem.BufferSlice{Bufs: []mem.Buffer{newBuffer([]byte("abcd"), nil)}},
pool: nil, // MaterializeToBuffer should not use the pool in this case.
wantData: []byte("abcd"),
},
{
name: "multiple",
in: mem.BufferSlice{
in: mem.BufferSlice{Bufs: []mem.Buffer{
newBuffer([]byte("abcd"), nil),
newBuffer([]byte("abcd"), nil),
newBuffer([]byte("abcd"), nil),
},
}},
pool: mem.DefaultBufferPool(),
wantData: []byte("abcdabcdabcd"),
},
Expand All @@ -129,11 +129,11 @@ func (s) TestBufferSlice_MaterializeToBuffer(t *testing.T) {
}

func (s) TestBufferSlice_Reader(t *testing.T) {
bs := mem.BufferSlice{
bs := mem.BufferSlice{Bufs: []mem.Buffer{
newBuffer([]byte("abcd"), nil),
newBuffer([]byte("abcd"), nil),
newBuffer([]byte("abcd"), nil),
}
}}
wantData := []byte("abcdabcdabcd")

reader := bs.Reader()
Expand Down Expand Up @@ -347,13 +347,13 @@ func (s) TestBufferSlice_ReadAll_Reads(t *testing.T) {
if !bytes.Equal(r.read, gotData) {
t.Fatalf("ReadAll() returned data %q, wanted %q", gotData, r.read)
}
if len(data) != tc.wantBufs {
t.Fatalf("ReadAll() returned %d bufs, wanted %d bufs", len(data), tc.wantBufs)
if len(data.Bufs) != tc.wantBufs {
t.Fatalf("ReadAll() returned %d bufs, wanted %d bufs", len(data.Bufs), tc.wantBufs)
}
// all but last should be full buffers
for i := 0; i < len(data)-1; i++ {
if data[i].Len() != readAllBufSize {
t.Fatalf("ReadAll() returned data length %d, wanted %d", data[i].Len(), readAllBufSize)
for i := 0; i < len(data.Bufs)-1; i++ {
if data.Bufs[i].Len() != readAllBufSize {
t.Fatalf("ReadAll() returned data length %d, wanted %d", data.Bufs[i].Len(), readAllBufSize)
}
}
data.Free()
Expand Down
Loading

0 comments on commit 70f0fe6

Please sign in to comment.