Skip to content

Commit

Permalink
defer buffer pool returns
Browse files Browse the repository at this point in the history
Under some error paths we may leak the buffer and never get it back.
This defers the Put call so the buffer is always returned to the pool, even on error paths.
  • Loading branch information
thomasjungblut committed Jul 5, 2024
1 parent 414d6ad commit 422c30a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
14 changes: 8 additions & 6 deletions recordio/file_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ func (r *FileReader) ReadNext() ([]byte, error) {
}

expectedBytesRead, pooledRecordBuffer := allocateRecordBufferPooled(r.bufferPool, r.header, payloadSizeUncompressed, payloadSizeCompressed)
defer r.bufferPool.Put(pooledRecordBuffer)

numRead, err := io.ReadFull(r.reader, pooledRecordBuffer)
if err != nil {
return nil, fmt.Errorf("error while reading into record buffer of '%s': %w", r.file.Name(), err)
Expand All @@ -97,20 +99,19 @@ func (r *FileReader) ReadNext() ([]byte, error) {
var returnSlice []byte
if r.header.compressor != nil {
pooledDecompressionBuffer := r.bufferPool.Get(int(payloadSizeUncompressed))
defer r.bufferPool.Put(pooledDecompressionBuffer)

decompressedRecord, err := r.header.compressor.DecompressWithBuf(pooledRecordBuffer, pooledDecompressionBuffer)
if err != nil {
return nil, err
}
// we do a defensive copy here not to leak the pooled slice
returnSlice = make([]byte, len(decompressedRecord))
copy(returnSlice, decompressedRecord)
r.bufferPool.Put(pooledRecordBuffer)
r.bufferPool.Put(pooledDecompressionBuffer)
} else {
// we do a defensive copy here not to leak the pooled slice
returnSlice = make([]byte, len(pooledRecordBuffer))
copy(returnSlice, pooledRecordBuffer)
r.bufferPool.Put(pooledRecordBuffer)
}

// why not just r.currentOffset = r.reader.count? we could've skipped something in between which makes the counts inconsistent
Expand Down Expand Up @@ -159,6 +160,8 @@ func (r *FileReader) SkipNext() error {
// SkipNextV1 is legacy support path for non-vint compressed V1
func SkipNextV1(r *FileReader) error {
headerBuf := r.bufferPool.Get(RecordHeaderSizeBytes)
defer r.bufferPool.Put(headerBuf)

numRead, err := io.ReadFull(r.reader, headerBuf)
if err != nil {
return fmt.Errorf("error while reading record header of '%s': %w", r.file.Name(), err)
Expand All @@ -174,8 +177,6 @@ func SkipNextV1(r *FileReader) error {
return fmt.Errorf("error while parsing record header of '%s': %w", r.file.Name(), err)
}

r.bufferPool.Put(headerBuf)

expectedBytesSkipped := payloadSizeUncompressed
if r.header.compressor != nil {
expectedBytesSkipped = payloadSizeCompressed
Expand Down Expand Up @@ -207,6 +208,8 @@ func (r *FileReader) Close() error {
// legacy support path for non-vint compressed V1
func readNextV1(r *FileReader) ([]byte, error) {
headerBuf := r.bufferPool.Get(RecordHeaderSizeBytes)
defer r.bufferPool.Put(headerBuf)

numRead, err := io.ReadFull(r.reader, headerBuf)
if err != nil {
return nil, fmt.Errorf("error while reading record header of '%s': %w", r.file.Name(), err)
Expand All @@ -221,7 +224,6 @@ func readNextV1(r *FileReader) ([]byte, error) {
if err != nil {
return nil, fmt.Errorf("error while parsing record header of '%s': %w", r.file.Name(), err)
}
r.bufferPool.Put(headerBuf)

expectedBytesRead, recordBuffer := allocateRecordBuffer(r.header, payloadSizeUncompressed, payloadSizeCompressed)
numRead, err = io.ReadFull(r.reader, recordBuffer)
Expand Down
12 changes: 7 additions & 5 deletions recordio/mmap_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ func (r *MMapReader) ReadNextAt(offset uint64) ([]byte, error) {
return readNextAtV1(r, offset)
} else {
headerBufPooled := r.bufferPool.Get(RecordHeaderV2MaxSizeBytes)
defer r.bufferPool.Put(headerBufPooled)

numRead, err := r.mmapReader.ReadAt(headerBufPooled, int64(offset))
if err != nil {
Expand All @@ -77,8 +78,9 @@ func (r *MMapReader) ReadNextAt(offset uint64) ([]byte, error) {
return nil, fmt.Errorf("failed reading record header at offset %d in mmap reader for '%s': %w", offset, r.path, err)
}

r.bufferPool.Put(headerBufPooled)
expectedBytesRead, pooledRecordBuf := allocateRecordBufferPooled(r.bufferPool, r.header, payloadSizeUncompressed, payloadSizeCompressed)
defer r.bufferPool.Put(pooledRecordBuf)

numRead, err = r.mmapReader.ReadAt(pooledRecordBuf, int64(offset)+int64(headerByteReader.Count()))
if err != nil {
return nil, fmt.Errorf("failed reading record at offset %d in mmap reader for '%s': %w", offset, r.path, err)
Expand All @@ -91,27 +93,28 @@ func (r *MMapReader) ReadNextAt(offset uint64) ([]byte, error) {
var returnSlice []byte
if r.header.compressor != nil {
pooledDecompressionBuffer := r.bufferPool.Get(int(payloadSizeUncompressed))
defer r.bufferPool.Put(pooledDecompressionBuffer)

decompressedRecord, err := r.header.compressor.DecompressWithBuf(pooledRecordBuf, pooledDecompressionBuffer)
if err != nil {
return nil, fmt.Errorf("failed decompressing record at offset %d in mmap reader for '%s': %w", offset, r.path, err)
}
// we do a defensive copy here not to leak the pooled slice
returnSlice = make([]byte, len(decompressedRecord))
copy(returnSlice, decompressedRecord)
r.bufferPool.Put(pooledRecordBuf)
r.bufferPool.Put(pooledDecompressionBuffer)
} else {
// we do a defensive copy here not to leak the pooled slice
returnSlice = make([]byte, len(pooledRecordBuf))
copy(returnSlice, pooledRecordBuf)
r.bufferPool.Put(pooledRecordBuf)
}
return returnSlice, nil
}
}

func readNextAtV1(r *MMapReader, offset uint64) ([]byte, error) {
headerBufPooled := r.bufferPool.Get(RecordHeaderSizeBytes)
defer r.bufferPool.Put(headerBufPooled)

numRead, err := r.mmapReader.ReadAt(headerBufPooled, int64(offset))
if err != nil {
return nil, fmt.Errorf("failed reading at offset %d in mmap reader for '%s': %w", offset, r.path, err)
Expand All @@ -126,7 +129,6 @@ func readNextAtV1(r *MMapReader, offset uint64) ([]byte, error) {
return nil, fmt.Errorf("failed reading record header at offset %d in mmap reader for '%s': %w", offset, r.path, err)
}

r.bufferPool.Put(headerBufPooled)
expectedBytesRead, recordBuffer := allocateRecordBuffer(r.header, payloadSizeUncompressed, payloadSizeCompressed)
numRead, err = r.mmapReader.ReadAt(recordBuffer, int64(offset+RecordHeaderSizeBytes))
if err != nil {
Expand Down

0 comments on commit 422c30a

Please sign in to comment.