Skip to content

Commit 8eb1422

Browse files
committed
pool writers (indexed by compression level, since that can't be reset)
1 parent fb07627 commit 8eb1422

File tree

1 file changed

+83
-20
lines changed

1 file changed

+83
-20
lines changed

groot/internal/rcompress/rcompress.go

+83-20
Original file line numberDiff line numberDiff line change
@@ -207,17 +207,17 @@ func compressBlock(alg Kind, lvl int, tgt, src []byte) (int, error) {
207207
hdr[0] = 'Z'
208208
hdr[1] = 'L'
209209
hdr[2] = 8 // zlib deflated
210-
w, err := zlib.NewWriterLevel(buf, lvl)
210+
w, err := zlibGetWriterLevel(buf, lvl)
211211
if err != nil {
212212
return 0, fmt.Errorf("rcompress: could not create ZLIB compressor: %w", err)
213213
}
214214

215215
_, err = w.Write(src)
216216
if err != nil {
217-
_ = w.Close()
217+
_ = zlibPutWriterLevel(w, lvl)
218218
return 0, fmt.Errorf("rcompress: could not write ZLIB compressed bytes: %w", err)
219219
}
220-
err = w.Close()
220+
err = zlibPutWriterLevel(w, lvl)
221221
switch {
222222
case err == nil:
223223
// ok.
@@ -269,7 +269,7 @@ func compressBlock(alg Kind, lvl int, tgt, src []byte) (int, error) {
269269

270270
const chksum = 8
271271
var room = int(float64(srcsz) * 2e-4) // lz4 needs some extra scratch space
272-
dst := make([]byte, HeaderSize+chksum+len(src)+room)
272+
dst := lz4GetBuffer(HeaderSize + chksum + len(src) + room)
273273
wrk := dst[HeaderSize:]
274274
var n int
275275
switch {
@@ -284,18 +284,21 @@ func compressBlock(alg Kind, lvl int, tgt, src []byte) (int, error) {
284284
n, err = lz4.CompressBlock(src, wrk[chksum:], ht)
285285
}
286286
if err != nil {
287+
lz4PutBuffer(dst)
287288
return 0, fmt.Errorf("rcompress: could not compress with LZ4: %w", err)
288289
}
289290

290291
if n == 0 {
291292
// not compressible.
293+
lz4PutBuffer(dst)
292294
return len(src), errNoCompression
293295
}
294296

295297
wrk = wrk[:n+chksum]
296298
binary.BigEndian.PutUint64(wrk[:chksum], xxHash64.Checksum(wrk[chksum:], 0))
297299
dstsz = int32(n + chksum)
298300
n = copy(buf.p, wrk)
301+
lz4PutBuffer(dst)
299302
buf.c += n
300303

301304
case ZSTD:
@@ -377,17 +380,16 @@ func Decompress(dst []byte, src io.Reader) error {
377380
return fmt.Errorf("rcompress: could not create ZLIB reader: %w", err)
378381
}
379382
_, err = io.ReadFull(rc, dst[beg:end])
380-
rc.Close()
381-
zlibReaderPool.Put(rc)
383+
zlibPutReader(rc)
382384
if err != nil {
383385
return fmt.Errorf("rcompress: could not decompress ZLIB buffer: %w", err)
384386
}
385387

386388
case LZ4:
387-
src := lz4NewBuffer(srcsz)
389+
src := lz4GetBuffer(int(srcsz))
388390
_, err = io.ReadFull(lr, src)
389391
if err != nil {
390-
lz4BufferPool.Put(src)
392+
lz4PutBuffer(src)
391393
return fmt.Errorf("rcompress: could not read LZ4 block: %w", err)
392394
}
393395
const chksum = 8
@@ -398,9 +400,9 @@ func Decompress(dst []byte, src io.Reader) error {
398400
case srcsz > tgtsz:
399401
// no compression
400402
copy(dst[beg:end], src[chksum:])
401-
lz4BufferPool.Put(src)
403+
lz4PutBuffer(src)
402404
default:
403-
lz4BufferPool.Put(src)
405+
lz4PutBuffer(src)
404406
return fmt.Errorf("rcompress: could not decompress LZ4 block: %w", err)
405407
}
406408
}
@@ -423,13 +425,12 @@ func Decompress(dst []byte, src io.Reader) error {
423425
}
424426

425427
case ZSTD:
426-
rc, err := zstdNewReader(lr)
428+
rc, err := zstdGetReader(lr)
427429
if err != nil {
428430
return fmt.Errorf("rcompress: could not create ZSTD reader: %w", err)
429431
}
430432
_, err = io.ReadFull(rc, dst[beg:end])
431-
rc.Reset(nil)
432-
zstdReaderPool.Put(rc)
433+
zstdPutReader(rc)
433434
if err != nil {
434435
return fmt.Errorf("rcompress: could not decompress ZSTD block: %w", err)
435436
}
@@ -464,24 +465,29 @@ var (
464465
_ io.Writer = (*wbuff)(nil)
465466
)
466467

467-
// TODO writers, need to index by options (e.g. compression level)
468468
var (
469-
lz4BufferPool = sync.Pool{}
470-
zlibReaderPool = sync.Pool{}
471-
zstdReaderPool = sync.Pool{}
469+
lz4BufferPool sync.Pool
470+
zlibReaderPool sync.Pool
471+
zstdReaderPool sync.Pool
472+
zlibWriterPools sync.Map // map[lvl]*pool
473+
zstdWriterPools sync.Map // map[lvl]*pool
472474
)
473475

474-
func lz4NewBuffer(size int64) []byte {
476+
func lz4GetBuffer(size int) []byte {
475477
var b []byte
476478
if bi := lz4BufferPool.Get(); bi != nil {
477479
b = bi.([]byte)
478480
}
479-
if int64(cap(b)) >= size {
481+
if cap(b) >= size {
480482
return b[:size]
481483
}
482484
return make([]byte, size)
483485
}
484486

487+
func lz4PutBuffer(b []byte) {
488+
lz4BufferPool.Put(b)
489+
}
490+
485491
func zlibNewReader(r io.Reader) (io.ReadCloser, error) {
486492
if ri := zlibReaderPool.Get(); ri != nil {
487493
ri.(zlib.Resetter).Reset(r, nil)
@@ -490,11 +496,68 @@ func zlibNewReader(r io.Reader) (io.ReadCloser, error) {
490496
return zlib.NewReader(r)
491497
}
492498

493-
func zstdNewReader(r io.Reader) (*zstd.Decoder, error) {
499+
func zlibPutReader(r io.ReadCloser) error {
500+
// Note that zlib readers should be closed (but not reset)
501+
err := r.Close()
502+
zlibReaderPool.Put(r)
503+
return err
504+
}
505+
506+
func zstdGetReader(r io.Reader) (*zstd.Decoder, error) {
494507
if ri := zstdReaderPool.Get(); ri != nil {
495508
rd := ri.(*zstd.Decoder)
496509
rd.Reset(r)
497510
return rd, nil
498511
}
499512
return zstd.NewReader(r)
500513
}
514+
515+
func zstdPutReader(r *zstd.Decoder) {
516+
// Note that zstd decoders should be reset (but not closed)
517+
r.Reset(nil)
518+
zstdReaderPool.Put(r)
519+
}
520+
521+
func zlibGetWriterLevel(w io.Writer, lvl int) (*zlib.Writer, error) {
522+
if pi, ok := zlibWriterPools.Load(lvl); ok {
523+
if wi := pi.(*sync.Pool).Get(); wi != nil {
524+
z := wi.(*zlib.Writer)
525+
z.Reset(w)
526+
return z, nil
527+
}
528+
}
529+
return zlib.NewWriterLevel(w, lvl)
530+
}
531+
532+
func zlibPutWriterLevel(w *zlib.Writer, lvl int) error {
533+
err := w.Close()
534+
if pi, ok := zlibWriterPools.Load(lvl); ok {
535+
pi.(*sync.Pool).Put(w)
536+
} else {
537+
pi, _ = zlibWriterPools.LoadOrStore(lvl, new(sync.Pool))
538+
pi.(*sync.Pool).Put(w)
539+
}
540+
return err
541+
}
542+
543+
func zstdGetWriterLevel(w io.Writer, lvl int) (*zstd.Encoder, error) {
544+
if pi, ok := zstdWriterPools.Load(lvl); ok {
545+
if wi := pi.(*sync.Pool).Get(); wi != nil {
546+
z := wi.(*zstd.Encoder)
547+
z.Reset(w)
548+
return z, nil
549+
}
550+
}
551+
return zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.EncoderLevel(lvl)))
552+
}
553+
554+
func zstdPutWriterLevel(w *zstd.Encoder, lvl int) error {
555+
err := w.Close()
556+
if pi, ok := zstdWriterPools.Load(lvl); ok {
557+
pi.(*sync.Pool).Put(w)
558+
} else {
559+
pi, _ = zstdWriterPools.LoadOrStore(lvl, new(sync.Pool))
560+
pi.(*sync.Pool).Put(w)
561+
}
562+
return err
563+
}

0 commit comments

Comments
 (0)