diff --git a/codec.go b/codec.go index 959c2f99d4a7..e56f3620b545 100644 --- a/codec.go +++ b/codec.go @@ -69,9 +69,9 @@ type codecV0Bridge struct { func (c codecV0Bridge) Marshal(v any) (mem.BufferSlice, error) { data, err := c.codec.Marshal(v) if err != nil { - return nil, err + return mem.BufferSlice{}, err } - return mem.BufferSlice{mem.SliceBuffer(data)}, nil + return mem.BufferSlice{Bufs: []mem.Buffer{mem.SliceBuffer(data)}}, nil } func (c codecV0Bridge) Unmarshal(data mem.BufferSlice, v any) (err error) { diff --git a/encoding/encoding_test.go b/encoding/encoding_test.go index 19769146f91b..65235d9d860c 100644 --- a/encoding/encoding_test.go +++ b/encoding/encoding_test.go @@ -93,7 +93,7 @@ type errProtoCodec struct { func (c *errProtoCodec) Marshal(v any) (mem.BufferSlice, error) { if c.encodingErr != nil { - return nil, c.encodingErr + return mem.BufferSlice{}, c.encodingErr } return encoding.GetCodecV2(proto.Name).Marshal(v) } diff --git a/encoding/proto/proto.go b/encoding/proto/proto.go index ceec319dd2fb..664dda05fd24 100644 --- a/encoding/proto/proto.go +++ b/encoding/proto/proto.go @@ -43,24 +43,24 @@ type codecV2 struct{} func (c *codecV2) Marshal(v any) (data mem.BufferSlice, err error) { vv := messageV2Of(v) if vv == nil { - return nil, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v) + return mem.BufferSlice{}, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v) } size := proto.Size(vv) if mem.IsBelowBufferPoolingThreshold(size) { buf, err := proto.Marshal(vv) if err != nil { - return nil, err + return mem.BufferSlice{}, err } - data = append(data, mem.SliceBuffer(buf)) + data.Bufs = append(data.Bufs, mem.SliceBuffer(buf)) } else { pool := mem.DefaultBufferPool() buf := pool.Get(size) if _, err := (proto.MarshalOptions{}).MarshalAppend((*buf)[:0], vv); err != nil { pool.Put(buf) - return nil, err + return mem.BufferSlice{}, err } - data = append(data, mem.NewBuffer(buf, pool)) + data.Bufs = append(data.Bufs, mem.NewBuffer(buf, pool)) } return data, nil diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index d9305a65d88f..4e18cecfc731 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -345,7 +345,7 @@ func (ht *serverHandlerTransport) write(s *ServerStream, hdr []byte, data mem.Bu ht.writePendingHeaders(s) } ht.rw.Write(hdr) - for _, b := range data { + for _, b := range data.Bufs { _, _ = ht.rw.Write(b.ReadOnlyData()) } ht.rw.(http.Flusher).Flush() diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 2859b87755f0..22e0dc9c0541 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -376,7 +376,7 @@ func (s *Stream) ReadMessageHeader(header []byte) (err error) { func (s *Stream) read(n int) (data mem.BufferSlice, err error) { // Don't request a read if there was an error earlier if er := s.trReader.er; er != nil { - return nil, er + return mem.BufferSlice{}, er } s.requestRead(n) for n != 0 { @@ -394,9 +394,9 @@ func (s *Stream) read(n int) (data mem.BufferSlice, err error) { err = io.ErrUnexpectedEOF } data.Free() - return nil, err + return mem.BufferSlice{}, err } - data = append(data, buf) + data.Bufs = append(data.Bufs, buf) } return data, nil } diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index c91757b3f96e..d3371963b30c 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -76,7 +76,7 @@ func init() { } func newBufferSlice(b []byte) mem.BufferSlice { - return mem.BufferSlice{mem.SliceBuffer(b)} + return mem.BufferSlice{Bufs: []mem.Buffer{mem.SliceBuffer(b)}} } func (s *Stream) readTo(p []byte) (int, error) { @@ -879,7 +879,7 @@ func (s) TestGracefulClose(t *testing.T) { } // Confirm the existing stream still functions as expected. - s.Write(nil, nil, &WriteOptions{Last: true}) + s.Write(nil, mem.BufferSlice{}, &WriteOptions{Last: true}) if _, err := s.readTo(incomingHeader); err != io.EOF { t.Fatalf("Client expected EOF from the server. Got: %v", err) } @@ -1714,7 +1714,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) st.mu.Unlock() // Close all streams for _, stream := range clientStreams { - stream.Write(nil, nil, &WriteOptions{Last: true}) + stream.Write(nil, mem.BufferSlice{}, &WriteOptions{Last: true}) if _, err := stream.readTo(make([]byte, 5)); err != io.EOF { t.Fatalf("Client expected an EOF from the server. Got: %v", err) } @@ -2296,7 +2296,7 @@ func runPingPongTest(t *testing.T, msgSize int) { } } - stream.Write(nil, nil, &WriteOptions{Last: true}) + stream.Write(nil, mem.BufferSlice{}, &WriteOptions{Last: true}) if _, err := stream.readTo(incomingHeader); err != io.EOF { t.Fatalf("Client expected EOF from the server. Got: %v", err) } diff --git a/mem/buffer_slice.go b/mem/buffer_slice.go index 65002e2cc851..64d113d0bc7b 100644 --- a/mem/buffer_slice.go +++ b/mem/buffer_slice.go @@ -41,7 +41,10 @@ const ( // By convention, any APIs that return (mem.BufferSlice, error) should reduce // the burden on the caller by never returning a mem.BufferSlice that needs to // be freed if the error is non-nil, unless explicitly stated. -type BufferSlice []Buffer +type BufferSlice struct { + Bufs []Buffer + track bool +} // Len returns the sum of the length of all the Buffers in this slice. // @@ -51,7 +54,7 @@ type BufferSlice []Buffer // in the slice, and *not* the value returned by this function. func (s BufferSlice) Len() int { var length int - for _, b := range s { + for _, b := range s.Bufs { length += b.Len() } return length @@ -59,14 +62,14 @@ func (s BufferSlice) Len() int { // Ref invokes Ref on each buffer in the slice. func (s BufferSlice) Ref() { - for _, b := range s { + for _, b := range s.Bufs { b.Ref() } } // Free invokes Buffer.Free() on each Buffer in the slice. func (s BufferSlice) Free() { - for _, b := range s { + for _, b := range s.Bufs { b.Free() } } @@ -77,7 +80,7 @@ func (s BufferSlice) Free() { // is full or s runs out of data, returning the minimum of s.Len() and len(dst). func (s BufferSlice) CopyTo(dst []byte) int { off := 0 - for _, b := range s { + for _, b := range s.Bufs { off += copy(dst[off:], b.ReadOnlyData()) } return off @@ -102,9 +105,9 @@ func (s BufferSlice) Materialize() []byte { // function simply increases the refcount before returning said Buffer. Freeing this // buffer won't release it until the BufferSlice is itself released. func (s BufferSlice) MaterializeToBuffer(pool BufferPool) Buffer { - if len(s) == 1 { - s[0].Ref() - return s[0] + if len(s.Bufs) == 1 { + s.Bufs[0].Ref() + return s.Bufs[0] } sLen := s.Len() if sLen == 0 { @@ -115,6 +118,14 @@ func (s BufferSlice) MaterializeToBuffer(pool BufferPool) Buffer { return NewBuffer(buf, pool) } +// Track enables tracking. +func (s BufferSlice) Track() { + s.track = true + for _, b := range s.Bufs { + b.Track() + } +} + // Reader returns a new Reader for the input slice after taking references to // each underlying buffer. func (s BufferSlice) Reader() Reader { @@ -152,18 +163,18 @@ func (r *sliceReader) Remaining() int { func (r *sliceReader) Close() error { r.data.Free() - r.data = nil + r.data.Bufs = nil r.len = 0 return nil } func (r *sliceReader) freeFirstBufferIfEmpty() bool { - if len(r.data) == 0 || r.bufferIdx != len(r.data[0].ReadOnlyData()) { + if len(r.data.Bufs) == 0 || r.bufferIdx != len(r.data.Bufs[0].ReadOnlyData()) { return false } - r.data[0].Free() - r.data = r.data[1:] + r.data.Bufs[0].Free() + r.data.Bufs = r.data.Bufs[1:] r.bufferIdx = 0 return true } @@ -176,7 +187,7 @@ func (r *sliceReader) Read(buf []byte) (n int, _ error) { for len(buf) != 0 && r.len != 0 { // Copy as much as possible from the first Buffer in the slice into the // given byte slice. - data := r.data[0].ReadOnlyData() + data := r.data.Bufs[0].ReadOnlyData() copied := copy(buf, data[r.bufferIdx:]) r.len -= copied // Reduce len by the number of bytes copied. r.bufferIdx += copied // Increment the buffer index. @@ -201,7 +212,7 @@ func (r *sliceReader) ReadByte() (byte, error) { for r.freeFirstBufferIfEmpty() { } - b := r.data[0].ReadOnlyData()[r.bufferIdx] + b := r.data.Bufs[0].ReadOnlyData()[r.bufferIdx] r.len-- r.bufferIdx++ // Free the first buffer in the slice if the last byte was read @@ -218,7 +229,7 @@ type writer struct { func (w *writer) Write(p []byte) (n int, err error) { b := Copy(p, w.pool) - *w.buffers = append(*w.buffers, b) + *&w.buffers.Bufs = append(*&w.buffers.Bufs, b) return b.Len(), nil } @@ -265,7 +276,7 @@ nextBuffer: pool.Put(buf) } else { *buf = (*buf)[:usedCap] - result = append(result, NewBuffer(buf, pool)) + result.Bufs = append(result.Bufs, NewBuffer(buf, pool)) } if err == io.EOF { err = nil @@ -273,7 +284,7 @@ nextBuffer: return result, err } if len(*buf) == usedCap { - result = append(result, NewBuffer(buf, pool)) + result.Bufs = append(result.Bufs, NewBuffer(buf, pool)) continue nextBuffer } } diff --git a/mem/buffer_slice_test.go b/mem/buffer_slice_test.go index bb9303f0e9e1..e5b5566ae481 100644 --- a/mem/buffer_slice_test.go +++ b/mem/buffer_slice_test.go @@ -47,21 +47,21 @@ func (s) TestBufferSlice_Len(t *testing.T) { }{ { name: "empty", - in: nil, + in: mem.BufferSlice{}, want: 0, }, { name: "single", - in: mem.BufferSlice{newBuffer([]byte("abcd"), nil)}, + in: mem.BufferSlice{Bufs: []mem.Buffer{newBuffer([]byte("abcd"), nil)}}, want: 4, }, { name: "multiple", - in: mem.BufferSlice{ + in: mem.BufferSlice{Bufs: []mem.Buffer{ newBuffer([]byte("abcd"), nil), newBuffer([]byte("abcd"), nil), newBuffer([]byte("abcd"), nil), - }, + }}, want: 12, }, } @@ -76,10 +76,10 @@ func (s) TestBufferSlice_Len(t *testing.T) { func (s) TestBufferSlice_Ref(t *testing.T) { // Create a new buffer slice and a reference to it. - bs := mem.BufferSlice{ + bs := mem.BufferSlice{Bufs: []mem.Buffer{ newBuffer([]byte("abcd"), nil), newBuffer([]byte("abcd"), nil), - } + }} bs.Ref() // Free the original buffer slice and verify that the reference can still @@ -101,17 +101,17 @@ func (s) TestBufferSlice_MaterializeToBuffer(t *testing.T) { }{ { name: "single", - in: mem.BufferSlice{newBuffer([]byte("abcd"), nil)}, + in: mem.BufferSlice{Bufs: []mem.Buffer{newBuffer([]byte("abcd"), nil)}}, pool: nil, // MaterializeToBuffer should not use the pool in this case. wantData: []byte("abcd"), }, { name: "multiple", - in: mem.BufferSlice{ + in: mem.BufferSlice{Bufs: []mem.Buffer{ newBuffer([]byte("abcd"), nil), newBuffer([]byte("abcd"), nil), newBuffer([]byte("abcd"), nil), - }, + }}, pool: mem.DefaultBufferPool(), wantData: []byte("abcdabcdabcd"), }, @@ -129,11 +129,11 @@ func (s) TestBufferSlice_MaterializeToBuffer(t *testing.T) { } func (s) TestBufferSlice_Reader(t *testing.T) { - bs := mem.BufferSlice{ + bs := mem.BufferSlice{Bufs: []mem.Buffer{ newBuffer([]byte("abcd"), nil), newBuffer([]byte("abcd"), nil), newBuffer([]byte("abcd"), nil), - } + }} wantData := []byte("abcdabcdabcd") reader := bs.Reader() @@ -347,13 +347,13 @@ func (s) TestBufferSlice_ReadAll_Reads(t *testing.T) { if !bytes.Equal(r.read, gotData) { t.Fatalf("ReadAll() returned data %q, wanted %q", gotData, r.read) } - if len(data) != tc.wantBufs { - t.Fatalf("ReadAll() returned %d bufs, wanted %d bufs", len(data), tc.wantBufs) + if len(data.Bufs) != tc.wantBufs { + t.Fatalf("ReadAll() returned %d bufs, wanted %d bufs", len(data.Bufs), tc.wantBufs) } // all but last should be full buffers - for i := 0; i < len(data)-1; i++ { - if data[i].Len() != readAllBufSize { - t.Fatalf("ReadAll() returned data length %d, wanted %d", data[i].Len(), readAllBufSize) + for i := 0; i < len(data.Bufs)-1; i++ { + if data.Bufs[i].Len() != readAllBufSize { + t.Fatalf("ReadAll() returned data length %d, wanted %d", data.Bufs[i].Len(), readAllBufSize) } } data.Free() diff --git a/mem/buffers.go b/mem/buffers.go index ecbf0b9a73ea..9f822ba40cfe 100644 --- a/mem/buffers.go +++ b/mem/buffers.go @@ -27,6 +27,7 @@ package mem import ( "fmt" + "runtime/debug" "sync" "sync/atomic" ) @@ -54,6 +55,8 @@ type Buffer interface { // Len returns the Buffer's size. Len() int + Track() + split(n int) (left, right Buffer) read(buf []byte) (int, Buffer) } @@ -73,10 +76,13 @@ func IsBelowBufferPoolingThreshold(size int) bool { } type buffer struct { - origData *[]byte - data []byte - refs *atomic.Int32 - pool BufferPool + origData *[]byte + data []byte + refs *atomic.Int32 + pool BufferPool + tracking bool + freeCallers []string + refCallers []string } func newBuffer() *buffer { @@ -134,13 +140,23 @@ func (b *buffer) ReadOnlyData() []byte { } func (b *buffer) Ref() { + if b.tracking { + b.refCallers = append(b.refCallers, string(debug.Stack())) + } if b.refs == nil { panic("Cannot ref freed buffer") } b.refs.Add(1) } +func (b *buffer) Track() { + b.tracking = true +} + func (b *buffer) Free() { + if b.tracking { + b.freeCallers = append(b.freeCallers, string(debug.Stack())) + } if b.refs == nil { panic("Cannot free freed buffer") } @@ -161,7 +177,15 @@ func (b *buffer) Free() { b.pool = nil bufferObjectPool.Put(b) default: - panic("Cannot free freed buffer") + op := "Free callers" + for _, c := range b.freeCallers { + op = op + "\n\n" + c + } + op = op + "\nRef callers" + for _, c := range b.refCallers { + op = op + "\n\n" + c + } + panic("Cannot free freed buffer" + op) } } @@ -224,8 +248,9 @@ func (e emptyBuffer) ReadOnlyData() []byte { return nil } -func (e emptyBuffer) Ref() {} -func (e emptyBuffer) Free() {} +func (e emptyBuffer) Ref() {} +func (e emptyBuffer) Free() {} +func (e emptyBuffer) Track() {} func (e emptyBuffer) Len() int { return 0 @@ -252,6 +277,8 @@ func (s SliceBuffer) Ref() {} // Free is a noop implementation of Free. func (s SliceBuffer) Free() {} +func (s SliceBuffer) Track() {} + // Len is a noop implementation of Len. func (s SliceBuffer) Len() int { return len(s) } diff --git a/preloader.go b/preloader.go index ee0ff969af4d..bd803743f621 100644 --- a/preloader.go +++ b/preloader.go @@ -62,7 +62,7 @@ func (p *PreparedMsg) Encode(s Stream, msg any) error { materializedData := data.Materialize() data.Free() - p.encodedData = mem.BufferSlice{mem.SliceBuffer(materializedData)} + p.encodedData = mem.BufferSlice{Bufs: []mem.Buffer{mem.SliceBuffer(materializedData)}} // TODO: it should be possible to grab the bufferPool from the underlying // stream implementation with a type cast to its actual type (such as @@ -76,7 +76,7 @@ func (p *PreparedMsg) Encode(s Stream, msg any) error { if p.pf.isCompressed() { materializedCompData := compData.Materialize() compData.Free() - compData = mem.BufferSlice{mem.SliceBuffer(materializedCompData)} + compData = mem.BufferSlice{Bufs: []mem.Buffer{mem.SliceBuffer(materializedCompData)}} } p.hdr, p.payload = msgHeader(p.encodedData, compData, p.pf) diff --git a/rpc_util.go b/rpc_util.go index 9fac2b08b48b..5a5798d9be5f 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -658,17 +658,17 @@ type parser struct { func (p *parser) recvMsg(maxReceiveMessageSize int) (payloadFormat, mem.BufferSlice, error) { err := p.r.ReadMessageHeader(p.header[:]) if err != nil { - return 0, nil, err + return 0, mem.BufferSlice{}, err } pf := payloadFormat(p.header[0]) length := binary.BigEndian.Uint32(p.header[1:]) if int64(length) > int64(maxInt) { - return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", length, maxInt) + return 0, mem.BufferSlice{}, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", length, maxInt) } if int(length) > maxReceiveMessageSize { - return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize) + return 0, mem.BufferSlice{}, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize) } data, err := p.r.Read(int(length)) @@ -676,7 +676,7 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (payloadFormat, mem.BufferSl if err == io.EOF { err = io.ErrUnexpectedEOF } - return 0, nil, err + return 0, mem.BufferSlice{}, err } return pf, data, nil } @@ -686,15 +686,15 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (payloadFormat, mem.BufferSl // generates an empty message. func encode(c baseCodec, msg any) (mem.BufferSlice, error) { if msg == nil { // NOTE: typed nils will not be caught by this check - return nil, nil + return mem.BufferSlice{}, nil } b, err := c.Marshal(msg) if err != nil { - return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) + return mem.BufferSlice{}, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) } if uint(b.Len()) > math.MaxUint32 { b.Free() - return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) + return mem.BufferSlice{}, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b.Bufs)) } return b, nil } @@ -706,7 +706,7 @@ func encode(c baseCodec, msg any) (mem.BufferSlice, error) { // TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor. func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, pool mem.BufferPool) (mem.BufferSlice, payloadFormat, error) { if (compressor == nil && cp == nil) || in.Len() == 0 { - return nil, compressionNone, nil + return mem.BufferSlice{}, compressionNone, nil } var out mem.BufferSlice w := mem.NewWriter(&out, pool) @@ -717,15 +717,15 @@ func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, if compressor != nil { z, err := compressor.Compress(w) if err != nil { - return nil, 0, wrapErr(err) + return mem.BufferSlice{}, 0, wrapErr(err) } - for _, b := range in { + for _, b := range in.Bufs { if _, err := z.Write(b.ReadOnlyData()); err != nil { - return nil, 0, wrapErr(err) + return mem.BufferSlice{}, 0, wrapErr(err) } } if err := z.Close(); err != nil { - return nil, 0, wrapErr(err) + return mem.BufferSlice{}, 0, wrapErr(err) } } else { // This is obviously really inefficient since it fully materializes the data, but @@ -735,7 +735,7 @@ func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, buf := in.MaterializeToBuffer(pool) defer buf.Free() if err := cp.Do(w, buf.ReadOnlyData()); err != nil { - return nil, 0, wrapErr(err) + return mem.BufferSlice{}, 0, wrapErr(err) } } return out, compressionMade, nil @@ -803,7 +803,7 @@ type payloadInfo struct { } func (p *payloadInfo) free() { - if p != nil && p.uncompressedBytes != nil { + if p != nil && p.uncompressedBytes.Bufs != nil { p.uncompressedBytes.Free() } } @@ -818,14 +818,14 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM ) (out mem.BufferSlice, err error) { pf, compressed, err := p.recvMsg(maxReceiveMessageSize) if err != nil { - return nil, err + return mem.BufferSlice{}, err } compressedLength := compressed.Len() if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil { compressed.Free() - return nil, st.Err() + return mem.BufferSlice{}, st.Err() } var size int @@ -838,25 +838,26 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM var uncompressedBuf []byte uncompressedBuf, err = dc.Do(compressed.Reader()) if err == nil { - out = mem.BufferSlice{mem.SliceBuffer(uncompressedBuf)} + out = mem.BufferSlice{Bufs: []mem.Buffer{mem.SliceBuffer(uncompressedBuf)}} } size = len(uncompressedBuf) } else { out, size, err = decompress(compressor, compressed, maxReceiveMessageSize, p.bufferPool) } if err != nil { - return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err) + return mem.BufferSlice{}, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err) } if size > maxReceiveMessageSize { out.Free() // TODO: Revisit the error code. Currently keep it consistent with java // implementation. - return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize) + return mem.BufferSlice{}, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize) } } else { out = compressed } + out.Track() if payInfo != nil { payInfo.compressedLength = compressedLength out.Ref() @@ -871,13 +872,13 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, int, error) { dcReader, err := compressor.Decompress(d.Reader()) if err != nil { - return nil, 0, err + return mem.BufferSlice{}, 0, err } out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1), pool) if err != nil { out.Free() - return nil, 0, err + return mem.BufferSlice{}, 0, err } return out, out.Len(), nil } diff --git a/rpc_util_test.go b/rpc_util_test.go index 94f50bc24ade..c5dab0cc1c1f 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -53,23 +53,23 @@ func (f *fullReader) ReadMessageHeader(header []byte) error { func (f *fullReader) Read(n int) (mem.BufferSlice, error) { if n == 0 { - return nil, nil + return mem.BufferSlice{}, nil } if len(f.data) == 0 { - return nil, io.EOF + return mem.BufferSlice{}, io.EOF } if len(f.data) < n { data := f.data f.data = nil - return mem.BufferSlice{mem.SliceBuffer(data)}, io.ErrUnexpectedEOF + return mem.BufferSlice{Bufs: []mem.Buffer{mem.SliceBuffer(data)}}, io.ErrUnexpectedEOF } buf := f.data[:n] f.data = f.data[n:] - return mem.BufferSlice{mem.SliceBuffer(buf)}, nil + return mem.BufferSlice{Bufs: []mem.Buffer{mem.SliceBuffer(buf)}}, nil } var _ CallOption = EmptyCallOption{} // ensure EmptyCallOption implements the interface @@ -146,7 +146,7 @@ func (s) TestEncode(t *testing.T) { t.Errorf("encode(_, %v) = %v, %v; want %v, %v", test.msg, data, err, test.data, test.err) continue } - if hdr, _ := msgHeader(data, nil, compressionNone); !bytes.Equal(hdr, test.hdr) { + if hdr, _ := msgHeader(data, mem.BufferSlice{}, compressionNone); !bytes.Equal(hdr, test.hdr) { t.Errorf("msgHeader(%v, false) = %v; want %v", data, hdr, test.hdr) } } @@ -225,7 +225,7 @@ func bmEncode(b *testing.B, mSize int) { cdc := getCodec(protoenc.Name) msg := &perfpb.Buffer{Body: make([]byte, mSize)} encodeData, _ := encode(cdc, msg) - encodedSz := int64(len(encodeData)) + encodedSz := int64(len(encodeData.Bufs)) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/stream.go b/stream.go index 17e2267b3320..2d6af059b483 100644 --- a/stream.go +++ b/stream.go @@ -992,7 +992,7 @@ func (cs *clientStream) CloseSend() error { } cs.sentLast = true op := func(a *csAttempt) error { - a.s.Write(nil, nil, &transport.WriteOptions{Last: true}) + a.s.Write(nil, mem.BufferSlice{}, &transport.WriteOptions{Last: true}) // Always return nil; io.EOF is the only error that might make sense // instead, but there is no need to signal the client to call RecvMsg // as the only use left for the stream after CloseSend is to call @@ -1375,7 +1375,7 @@ func (as *addrConnStream) CloseSend() error { } as.sentLast = true - as.s.Write(nil, nil, &transport.WriteOptions{Last: true}) + as.s.Write(nil, mem.BufferSlice{}, &transport.WriteOptions{Last: true}) // Always return nil; io.EOF is the only error that might make sense // instead, but there is no need to signal the client to call RecvMsg // as the only use left for the stream after CloseSend is to call @@ -1811,12 +1811,12 @@ func prepareMsg(m any, codec baseCodec, cp Compressor, comp encoding.Compressor, // Marshal and Compress the data at this point data, err = encode(codec, m) if err != nil { - return nil, nil, nil, 0, err + return nil, mem.BufferSlice{}, mem.BufferSlice{}, 0, err } compData, pf, err := compress(data, cp, comp, pool) if err != nil { data.Free() - return nil, nil, nil, 0, err + return nil, mem.BufferSlice{}, mem.BufferSlice{}, 0, err } hdr, payload = msgHeader(data, compData, pf) return hdr, data, payload, pf, nil