Skip to content

Commit

Permalink
fix: add fuzzer for mp4 box decoding and fix discovered issues
Browse files Browse the repository at this point in the history
This greatly reduced the memory usage for corrupted mp4 files by checking for
sizes of the box compared to the sizes of elements..
  • Loading branch information
eric committed Jan 18, 2025
1 parent 8910b0a commit 9c3c1b9
Show file tree
Hide file tree
Showing 69 changed files with 489 additions and 78 deletions.
6 changes: 6 additions & 0 deletions av1/av1codecconfigurationrecord.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package av1

import (
"errors"
"fmt"
"io"

"github.com/Eyevinn/mp4ff/bits"
Expand Down Expand Up @@ -34,6 +35,11 @@ type CodecConfRec struct {

// DecodeAVCDecConfRec - decode an AV1CodecConfRec
func DecodeAV1CodecConfRec(data []byte) (CodecConfRec, error) {
// Minimum size is 4 bytes for the fixed header fields
if len(data) < 4 {
return CodecConfRec{}, fmt.Errorf("av1C: data size %d is too small (minimum 4 bytes)", len(data))
}

av1drc := CodecConfRec{}

Marker := data[0] >> 7
Expand Down
38 changes: 37 additions & 1 deletion avc/avcdecoderconfigurationrecord.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ func CreateAVCDecConfRec(spsNalus [][]byte, ppsNalus [][]byte, includePS bool) (

// DecodeAVCDecConfRec - decode an AVCDecConfRec
func DecodeAVCDecConfRec(data []byte) (DecConfRec, error) {
// Check minimum length for fixed header (6 bytes)
if len(data) < 6 {
return DecConfRec{}, fmt.Errorf("data too short for AVC decoder configuration record: %d bytes", len(data))
}

configurationVersion := data[0] // Should be 1
if configurationVersion != 1 {
return DecConfRec{}, fmt.Errorf("AVC decoder configuration record version %d unknown",
Expand All @@ -75,29 +80,56 @@ func DecodeAVCDecConfRec(data []byte) (DecConfRec, error) {
}
numSPS := data[5] & 0x1f // 5 bits following 3 reserved bits
pos := 6

spsNALUs := make([][]byte, 0, 1)
for i := 0; i < int(numSPS); i++ {
// Check if we have enough bytes to read NALU length
if pos+2 > len(data) {
return DecConfRec{}, fmt.Errorf("not enough data to read SPS NALU length at position %d", pos)
}
naluLength := int(binary.BigEndian.Uint16(data[pos : pos+2]))
pos += 2

// Check if we have enough bytes to read NALU
if pos+naluLength > len(data) {
return DecConfRec{}, fmt.Errorf("not enough data to read SPS NALU of length %d at position %d", naluLength, pos)
}
spsNALUs = append(spsNALUs, data[pos:pos+naluLength])
pos += naluLength
}
ppsNALUs := make([][]byte, 0, 1)

// Check if we have enough bytes to read numPPS
if pos >= len(data) {
return DecConfRec{}, fmt.Errorf("not enough data to read number of PPS at position %d", pos)
}
numPPS := data[pos]
pos++

ppsNALUs := make([][]byte, 0, 1)
for i := 0; i < int(numPPS); i++ {
// Check if we have enough bytes to read NALU length
if pos+2 > len(data) {
return DecConfRec{}, fmt.Errorf("not enough data to read PPS NALU length at position %d", pos)
}
naluLength := int(binary.BigEndian.Uint16(data[pos : pos+2]))
pos += 2

// Check if we have enough bytes to read NALU
if pos+naluLength > len(data) {
return DecConfRec{}, fmt.Errorf("not enough data to read PPS NALU of length %d at position %d", naluLength, pos)
}
ppsNALUs = append(ppsNALUs, data[pos:pos+naluLength])
pos += naluLength
}

adcr := DecConfRec{
AVCProfileIndication: AVCProfileIndication,
ProfileCompatibility: ProfileCompatibility,
AVCLevelIndication: AVCLevelIndication,
SPSnalus: spsNALUs,
PPSnalus: ppsNALUs,
}

// The rest of this structure may vary
// ISO/IEC 14496-15 2017 says that
// Compatible extensions to this record will extend it and
Expand All @@ -114,6 +146,10 @@ func DecodeAVCDecConfRec(data []byte) (DecConfRec, error) {
adcr.NoTrailingInfo = true
return adcr, nil
}
// Check if we have enough bytes for the trailing info
if pos+4 > len(data) {
return DecConfRec{}, fmt.Errorf("not enough data for trailing info at position %d", pos)
}
adcr.ChromaFormat = data[pos] & 0x03
adcr.BitDepthLumaMinus1 = data[pos+1] & 0x07
adcr.BitDepthChromaMinus1 = data[pos+2] & 0x07
Expand Down
7 changes: 7 additions & 0 deletions bits/fixedslicereader.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ func (s *FixedSliceReader) ReadZeroTerminatedString(maxLen int) string {
}
startPos := s.pos
maxPos := startPos + maxLen
if maxPos > s.len {
maxPos = s.len
}
for {
if s.pos >= maxPos {
s.err = errors.New("did not find terminating zero")
Expand Down Expand Up @@ -208,6 +211,10 @@ func (s *FixedSliceReader) ReadPossiblyZeroTerminatedString(maxLen int) (str str
// ReadBytes - read a slice of n bytes
// Return empty slice if n bytes not available
func (s *FixedSliceReader) ReadBytes(n int) []byte {
if n < 0 {
s.err = fmt.Errorf("attempt to read negative number of bytes: %d", n)
return []byte{}
}
if s.err != nil {
return []byte{}
}
Expand Down
14 changes: 10 additions & 4 deletions mp4/box.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ func DecodeHeader(r io.Reader) (BoxHeader, error) {
} else if size == 0 {
return BoxHeader{}, fmt.Errorf("Size 0, meaning to end of file, not supported")
}
if uint64(headerLen) > size {
return BoxHeader{}, fmt.Errorf("box header size %d exceeds box size %d", headerLen, size)
}
return BoxHeader{string(buf[4:8]), size, headerLen}, nil
}

Expand Down Expand Up @@ -380,14 +383,17 @@ func makebuf(b Box) []byte {

// readBoxBody reads complete box body. Returns error if not possible
func readBoxBody(r io.Reader, h BoxHeader) ([]byte, error) {
bodyLen := h.Size - uint64(h.Hdrlen)
if bodyLen == 0 {
hdrLen := uint64(h.Hdrlen)
if hdrLen == h.Size {
return nil, nil
}
body := make([]byte, bodyLen)
_, err := io.ReadFull(r, body)
bodyLen := h.Size - hdrLen
body, err := io.ReadAll(io.LimitReader(r, int64(bodyLen)))
if err != nil {
return nil, err
}
if len(body) != int(bodyLen) {
return nil, fmt.Errorf("read box body length %d does not match expected length %d", len(body), bodyLen)
}
return body, nil
}
5 changes: 4 additions & 1 deletion mp4/boxsr.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,10 @@ func DecodeHeaderSR(sr bits.SliceReader) (BoxHeader, error) {
} else if size == 0 {
return BoxHeader{}, fmt.Errorf("Size 0, meaning to end of file, not supported")
}
return BoxHeader{boxType, size, headerLen}, nil
if uint64(headerLen) > size {
return BoxHeader{}, fmt.Errorf("box header size %d exceeds box size %d", headerLen, size)
}
return BoxHeader{boxType, size, headerLen}, sr.AccError()
}

// DecodeFile - parse and decode a file from reader r with optional file options.
Expand Down
21 changes: 17 additions & 4 deletions mp4/co64.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,23 @@ func DecodeCo64(hdr BoxHeader, startPos uint64, r io.Reader) (Box, error) {
func DecodeCo64SR(hdr BoxHeader, startPos uint64, sr bits.SliceReader) (Box, error) {
versionAndFlags := sr.ReadUint32()
nrEntries := sr.ReadUint32()

b := &Co64Box{
Version: byte(versionAndFlags >> 24),
Flags: versionAndFlags & flagsMask,
ChunkOffset: make([]uint64, nrEntries),
Version: byte(versionAndFlags >> 24),
Flags: versionAndFlags & flagsMask,
}

if hdr.Size != b.expectedSize(nrEntries) {
return nil, fmt.Errorf("co64: expected size %d, got %d", b.expectedSize(nrEntries), hdr.Size)
}

b.ChunkOffset = make([]uint64, nrEntries)

for i := uint32(0); i < nrEntries; i++ {
b.ChunkOffset[i] = sr.ReadUint64()
if sr.AccError() != nil {
return nil, sr.AccError()
}
}
return b, sr.AccError()
}
Expand All @@ -49,9 +58,13 @@ func (b *Co64Box) Type() string {
return "co64"
}

func (b *Co64Box) expectedSize(nrEntries uint32) uint64 {
return uint64(boxHeaderSize + 8 + nrEntries*8)
}

// Size - box-specific size
func (b *Co64Box) Size() uint64 {
return uint64(boxHeaderSize + 8 + len(b.ChunkOffset)*8)
return b.expectedSize(uint32(len(b.ChunkOffset)))
}

// Encode - write box to w
Expand Down
20 changes: 15 additions & 5 deletions mp4/ctts.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,17 @@ func DecodeCttsSR(hdr BoxHeader, startPos uint64, sr bits.SliceReader) (Box, err
entryCount := sr.ReadUint32()

b := &CttsBox{
Version: byte(versionAndFlags >> 24),
Flags: versionAndFlags & flagsMask,
EndSampleNr: make([]uint32, entryCount+1),
SampleOffset: make([]int32, entryCount),
Version: byte(versionAndFlags >> 24),
Flags: versionAndFlags & flagsMask,
}

if hdr.Size != b.expectedSize(entryCount) {
return nil, fmt.Errorf("ctts: expected size %d, got %d", b.expectedSize(entryCount), hdr.Size)
}

b.EndSampleNr = make([]uint32, entryCount+1)
b.SampleOffset = make([]int32, entryCount)

var endSampleNr uint32 = 0
b.EndSampleNr[0] = endSampleNr
for i := 0; i < int(entryCount); i++ {
Expand All @@ -58,7 +63,12 @@ func (b *CttsBox) Type() string {

// Size - calculated size of box
func (b *CttsBox) Size() uint64 {
return uint64(boxHeaderSize + 8 + len(b.SampleOffset)*8)
return b.expectedSize(uint32(len(b.SampleOffset)))
}

// expectedSize - calculate size for a given entry count
func (b *CttsBox) expectedSize(entryCount uint32) uint64 {
return uint64(boxHeaderSize + 8 + uint64(entryCount)*8) // 8 = version + flags + entryCount, 8 = sampleCount(4) + sampleOffset(4)
}

// Encode - write box to w
Expand Down
16 changes: 13 additions & 3 deletions mp4/elst.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,14 @@ func DecodeElstSR(hdr BoxHeader, startPos uint64, sr bits.SliceReader) (Box, err
b := &ElstBox{
Version: version,
Flags: versionAndFlags & flagsMask,
Entries: make([]ElstEntry, entryCount),
}

if hdr.Size != b.expectedSize(entryCount) {
return nil, fmt.Errorf("elst: expected size %d, got %d", b.expectedSize(entryCount), hdr.Size)
}

b.Entries = make([]ElstEntry, entryCount)

if version == 1 {
for i := 0; i < int(entryCount); i++ {
b.Entries[i].SegmentDuration = sr.ReadUint64()
Expand Down Expand Up @@ -71,10 +76,15 @@ func (b *ElstBox) Type() string {

// Size - calculated size of box
func (b *ElstBox) Size() uint64 {
return b.expectedSize(uint32(len(b.Entries)))
}

// expectedSize - calculate size for a given entry count
func (b *ElstBox) expectedSize(entryCount uint32) uint64 {
if b.Version == 1 {
return uint64(boxHeaderSize + 8 + len(b.Entries)*20)
return uint64(boxHeaderSize + 8 + uint64(entryCount)*20) // 8 = version + flags + entryCount, 20 = uint64 + int64 + 2*int16
}
return uint64(boxHeaderSize + 8 + len(b.Entries)*12) // m.Version == 0
return uint64(boxHeaderSize + 8 + uint64(entryCount)*12) // 8 = version + flags + entryCount, 12 = uint32 + int32 + 2*int16
}

// Encode - write box to w
Expand Down
6 changes: 4 additions & 2 deletions mp4/eventmessage.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,17 @@ func DecodeSilbSR(hdr BoxHeader, startPos uint64, sr bits.SliceReader) (Box, err
b.Version = uint8(versionAndFlags >> 24)
b.Flags = versionAndFlags & flagsMask
nrSchemes := sr.ReadUint32()
b.Schemes = make([]SilbEntry, nrSchemes)
for i := uint32(0); i < nrSchemes; i++ {
schemeIdURI := sr.ReadZeroTerminatedString(int(hdr.payloadLen()) - 8)
value := sr.ReadZeroTerminatedString(int(hdr.payloadLen()) - 8 - len(schemeIdURI) - 1)
atLeastOneFlag := sr.ReadUint8() == 1
b.Schemes[i] = SilbEntry{
b.Schemes = append(b.Schemes, SilbEntry{
SchemeIdURI: schemeIdURI,
Value: value,
AtLeastOneFlag: atLeastOneFlag,
})
if sr.AccError() != nil {
return nil, sr.AccError()
}
}
b.OtherSchemesFlag = sr.ReadUint8() == 1
Expand Down
81 changes: 81 additions & 0 deletions mp4/fuzz_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//go:build go1.18
// +build go1.18

package mp4

import (
"bytes"
"context"
"errors"
"io"
"os"
"runtime"
"strings"
"testing"
"time"
)

func monitorMemory(ctx context.Context, t *testing.T, memoryLimit int) {
go func() {
timer := time.NewTicker(500 * time.Millisecond)
defer timer.Stop()
var m runtime.MemStats

for {
select {
case <-ctx.Done():
return
case <-timer.C:
runtime.ReadMemStats(&m)
if m.Alloc > uint64(memoryLimit) {
t.Logf("memory limit exceeded: %d > %d", m.Alloc, memoryLimit)
t.Fail()
return
}
}
}
}()
}

func FuzzDecodeBox(f *testing.F) {
entries, err := os.ReadDir("testdata")
if err != nil {
f.Fatal(err)
}

for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".mp4") {
testData, err := os.ReadFile("testdata/" + entry.Name())
if err != nil {
f.Fatal(err)
}
f.Add(testData)
}
}

f.Fuzz(func(t *testing.T, b []byte) {
if t.Name() == "FuzzDecodeBox/75565444c6c2f1dd" {
t.Skip("There is a bug in SencBox.Size() that needs to be fixed for " + t.Name())
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
monitorMemory(ctx, t, 500*1024*1024) // 500MB

r := bytes.NewReader(b)

var pos uint64 = 0
for {
box, err := DecodeBox(pos, r)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
break
}
}
if box == nil {
break
}
pos += box.Size()
}
})
}
6 changes: 4 additions & 2 deletions mp4/moof.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ type MoofBox struct {

// DecodeMoof - box-specific decode
func DecodeMoof(hdr BoxHeader, startPos uint64, r io.Reader) (Box, error) {
data := make([]byte, hdr.payloadLen())
_, err := io.ReadFull(r, data)
data, err := io.ReadAll(io.LimitReader(r, int64(hdr.payloadLen())))
if err != nil {
return nil, err
}
if len(data) != int(hdr.payloadLen()) {
return nil, fmt.Errorf("moof: expected %d bytes, got %d", hdr.payloadLen(), len(data))
}
sr := bits.NewFixedSliceReader(data)
children, err := DecodeContainerChildrenSR(hdr, startPos+8, startPos+hdr.Size, sr)
if err != nil {
Expand Down
Loading

0 comments on commit 9c3c1b9

Please sign in to comment.