Skip to content

Commit

Permalink
pool writers (indexed by compression level, since that can't be reset)
Browse files Browse the repository at this point in the history
  • Loading branch information
bburghgr committed Nov 16, 2024
1 parent fb07627 commit 8eb1422
Showing 1 changed file with 83 additions and 20 deletions.
103 changes: 83 additions & 20 deletions groot/internal/rcompress/rcompress.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,17 +207,17 @@ func compressBlock(alg Kind, lvl int, tgt, src []byte) (int, error) {
hdr[0] = 'Z'
hdr[1] = 'L'
hdr[2] = 8 // zlib deflated
w, err := zlib.NewWriterLevel(buf, lvl)
w, err := zlibGetWriterLevel(buf, lvl)
if err != nil {
return 0, fmt.Errorf("rcompress: could not create ZLIB compressor: %w", err)
}

_, err = w.Write(src)
if err != nil {
_ = w.Close()
_ = zlibPutWriterLevel(w, lvl)
return 0, fmt.Errorf("rcompress: could not write ZLIB compressed bytes: %w", err)
}
err = w.Close()
err = zlibPutWriterLevel(w, lvl)
switch {
case err == nil:
// ok.
Expand Down Expand Up @@ -269,7 +269,7 @@ func compressBlock(alg Kind, lvl int, tgt, src []byte) (int, error) {

const chksum = 8
var room = int(float64(srcsz) * 2e-4) // lz4 needs some extra scratch space
dst := make([]byte, HeaderSize+chksum+len(src)+room)
dst := lz4GetBuffer(HeaderSize + chksum + len(src) + room)
wrk := dst[HeaderSize:]
var n int
switch {
Expand All @@ -284,18 +284,21 @@ func compressBlock(alg Kind, lvl int, tgt, src []byte) (int, error) {
n, err = lz4.CompressBlock(src, wrk[chksum:], ht)
}
if err != nil {
lz4PutBuffer(dst)
return 0, fmt.Errorf("rcompress: could not compress with LZ4: %w", err)
}

if n == 0 {
// not compressible.
lz4PutBuffer(dst)
return len(src), errNoCompression
}

wrk = wrk[:n+chksum]
binary.BigEndian.PutUint64(wrk[:chksum], xxHash64.Checksum(wrk[chksum:], 0))
dstsz = int32(n + chksum)
n = copy(buf.p, wrk)
lz4PutBuffer(dst)
buf.c += n

case ZSTD:
Expand Down Expand Up @@ -377,17 +380,16 @@ func Decompress(dst []byte, src io.Reader) error {
return fmt.Errorf("rcompress: could not create ZLIB reader: %w", err)
}
_, err = io.ReadFull(rc, dst[beg:end])
rc.Close()
zlibReaderPool.Put(rc)
zlibPutReader(rc)
if err != nil {
return fmt.Errorf("rcompress: could not decompress ZLIB buffer: %w", err)
}

case LZ4:
src := lz4NewBuffer(srcsz)
src := lz4GetBuffer(int(srcsz))
_, err = io.ReadFull(lr, src)
if err != nil {
lz4BufferPool.Put(src)
lz4PutBuffer(src)
return fmt.Errorf("rcompress: could not read LZ4 block: %w", err)
}
const chksum = 8
Expand All @@ -398,9 +400,9 @@ func Decompress(dst []byte, src io.Reader) error {
case srcsz > tgtsz:
// no compression
copy(dst[beg:end], src[chksum:])
lz4BufferPool.Put(src)
lz4PutBuffer(src)
default:
lz4BufferPool.Put(src)
lz4PutBuffer(src)
return fmt.Errorf("rcompress: could not decompress LZ4 block: %w", err)
}
}
Expand All @@ -423,13 +425,12 @@ func Decompress(dst []byte, src io.Reader) error {
}

case ZSTD:
rc, err := zstdNewReader(lr)
rc, err := zstdGetReader(lr)
if err != nil {
return fmt.Errorf("rcompress: could not create ZSTD reader: %w", err)
}
_, err = io.ReadFull(rc, dst[beg:end])
rc.Reset(nil)
zstdReaderPool.Put(rc)
zstdPutReader(rc)
if err != nil {
return fmt.Errorf("rcompress: could not decompress ZSTD block: %w", err)
}
Expand Down Expand Up @@ -464,24 +465,29 @@ var (
_ io.Writer = (*wbuff)(nil)
)

// TODO writers, need to index by options (e.g. compression level)
var (
lz4BufferPool = sync.Pool{}
zlibReaderPool = sync.Pool{}
zstdReaderPool = sync.Pool{}
lz4BufferPool sync.Pool
zlibReaderPool sync.Pool
zstdReaderPool sync.Pool
zlibWriterPools sync.Map // map[lvl]*pool
zstdWriterPools sync.Map // map[lvl]*pool
)

func lz4NewBuffer(size int64) []byte {
func lz4GetBuffer(size int) []byte {
var b []byte
if bi := lz4BufferPool.Get(); bi != nil {
b = bi.([]byte)
}
if int64(cap(b)) >= size {
if cap(b) >= size {
return b[:size]
}
return make([]byte, size)
}

func lz4PutBuffer(b []byte) {
lz4BufferPool.Put(b)
}

func zlibNewReader(r io.Reader) (io.ReadCloser, error) {
if ri := zlibReaderPool.Get(); ri != nil {
ri.(zlib.Resetter).Reset(r, nil)
Expand All @@ -490,11 +496,68 @@ func zlibNewReader(r io.Reader) (io.ReadCloser, error) {
return zlib.NewReader(r)
}

func zstdNewReader(r io.Reader) (*zstd.Decoder, error) {
func zlibPutReader(r io.ReadCloser) error {
// Note that zlib readers should be closed (but not reset)
err := r.Close()
zlibReaderPool.Put(r)
return err
}

func zstdGetReader(r io.Reader) (*zstd.Decoder, error) {
if ri := zstdReaderPool.Get(); ri != nil {
rd := ri.(*zstd.Decoder)
rd.Reset(r)
return rd, nil
}
return zstd.NewReader(r)
}

func zstdPutReader(r *zstd.Decoder) {
// Note that zstd decoders should be reset (but not closed)
r.Reset(nil)
zstdReaderPool.Put(r)
}

func zlibGetWriterLevel(w io.Writer, lvl int) (*zlib.Writer, error) {
if pi, ok := zlibWriterPools.Load(lvl); ok {
if wi := pi.(*sync.Pool).Get(); wi != nil {
z := wi.(*zlib.Writer)
z.Reset(w)
return z, nil
}
}
return zlib.NewWriterLevel(w, lvl)
}

func zlibPutWriterLevel(w *zlib.Writer, lvl int) error {
err := w.Close()
if pi, ok := zlibWriterPools.Load(lvl); ok {
pi.(*sync.Pool).Put(w)
} else {
pi, _ = zlibWriterPools.LoadOrStore(lvl, new(sync.Pool))
pi.(*sync.Pool).Put(w)
}
return err
}

func zstdGetWriterLevel(w io.Writer, lvl int) (*zstd.Encoder, error) {
if pi, ok := zstdWriterPools.Load(lvl); ok {
if wi := pi.(*sync.Pool).Get(); wi != nil {
z := wi.(*zstd.Encoder)
z.Reset(w)
return z, nil
}
}
return zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.EncoderLevel(lvl)))
}

func zstdPutWriterLevel(w *zstd.Encoder, lvl int) error {
err := w.Close()
if pi, ok := zstdWriterPools.Load(lvl); ok {
pi.(*sync.Pool).Put(w)
} else {
pi, _ = zstdWriterPools.LoadOrStore(lvl, new(sync.Pool))
pi.(*sync.Pool).Put(w)
}
return err
}

0 comments on commit 8eb1422

Please sign in to comment.