diff --git a/Taskfile.yml b/Taskfile.yml index fe971b2..2589db6 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -89,9 +89,9 @@ tasks: bench: desc: "Run benchmarks" cmds: - - go test -list='{{.BENCH}}' ./wirebson + - go test -list='Benchmark.*' ./... # -timeout is needed due to https://github.com/golang/go/issues/69181 - - go test -bench='{{.BENCH}}' -count={{.BENCH_COUNT}} -benchtime={{.BENCH_TIME}} -timeout=60m ./wirebson | tee new.txt + - go test -count=10 -bench=BenchmarkDocument -benchtime={{.BENCH_TIME}} -timeout=60m ./wirebson | tee -a new.txt - bin/benchstat old.txt new.txt fuzz: diff --git a/wirebson/array.go b/wirebson/array.go index 583178f..5ef81ef 100644 --- a/wirebson/array.go +++ b/wirebson/array.go @@ -15,7 +15,6 @@ package wirebson import ( - "bytes" "encoding/binary" "iter" "log/slog" @@ -150,31 +149,41 @@ func (arr *Array) SortInterface(less func(a, b any) bool) sort.Interface { } // Encode encodes non-nil Array. -// -// TODO https://github.com/FerretDB/wire/issues/21 -// This method should accept a slice of bytes, not return it. -// That would allow to avoid unnecessary allocations. func (arr *Array) Encode() (RawArray, error) { must.NotBeZero(arr) - size := sizeArray(arr) - buf := bytes.NewBuffer(make([]byte, 0, size)) - - if err := binary.Write(buf, binary.LittleEndian, uint32(size)); err != nil { + raw := make([]byte, Size(arr)) + if err := arr.EncodeTo(raw); err != nil { return nil, lazyerrors.Error(err) } - for i, v := range arr.values { - if err := encodeField(buf, strconv.Itoa(i), v); err != nil { - return nil, lazyerrors.Error(err) + return raw, nil +} + +// EncodeTo encodes non-nil Array. +// +// raw must be at least Size(arr) bytes long; otherwise, EncodeTo will panic. +// Only raw[0:Size(arr)] bytes are modified. +func (arr *Array) EncodeTo(raw RawArray) error { + must.NotBeZero(arr) + + // ensure raw length early + s := sizeArray(arr) + raw[s-1] = 0 + + binary.LittleEndian.PutUint32(raw, uint32(s)) + + i := 4 + for n, v := range arr.values { + w, err := encodeField(raw[i:], strconv.Itoa(n), v) + if err != nil { + return lazyerrors.Error(err) } - } - if err := binary.Write(buf, binary.LittleEndian, byte(0)); err != nil { - return nil, lazyerrors.Error(err) + i += w } - return buf.Bytes(), nil + return nil } // Decode returns itself to implement [AnyArray]. diff --git a/wirebson/document.go b/wirebson/document.go index 53b15a5..59514d8 100644 --- a/wirebson/document.go +++ b/wirebson/document.go @@ -15,7 +15,6 @@ package wirebson import ( - "bytes" "encoding/binary" "iter" "log/slog" @@ -226,31 +225,41 @@ func (doc *Document) Command() string { } // Encode encodes non-nil Document. -// -// TODO https://github.com/FerretDB/wire/issues/21 -// This method should accept a slice of bytes, not return it. -// That would allow to avoid unnecessary allocations. func (doc *Document) Encode() (RawDocument, error) { must.NotBeZero(doc) - size := sizeDocument(doc) - buf := bytes.NewBuffer(make([]byte, 0, size)) - - if err := binary.Write(buf, binary.LittleEndian, uint32(size)); err != nil { + raw := make([]byte, Size(doc)) + if err := doc.EncodeTo(raw); err != nil { return nil, lazyerrors.Error(err) } + return raw, nil +} + +// EncodeTo encodes non-nil Document. +// +// raw must be at least Size(doc) bytes long; otherwise, EncodeTo will panic. +// Only raw[0:Size(doc)] bytes are modified. +func (doc *Document) EncodeTo(raw RawDocument) error { + must.NotBeZero(doc) + + // ensure raw length early + s := sizeDocument(doc) + raw[s-1] = 0 + + binary.LittleEndian.PutUint32(raw, uint32(s)) + + i := 4 for _, f := range doc.fields { - if err := encodeField(buf, f.name, f.value); err != nil { - return nil, lazyerrors.Error(err) + w, err := encodeField(raw[i:], f.name, f.value) + if err != nil { + return lazyerrors.Error(err) } - } - if err := binary.Write(buf, binary.LittleEndian, byte(0)); err != nil { - return nil, lazyerrors.Error(err) + i += w } - return buf.Bytes(), nil + return nil } // Decode returns itself to implement [AnyDocument]. diff --git a/wirebson/encode.go b/wirebson/encode.go index 52a6e35..cb3891b 100644 --- a/wirebson/encode.go +++ b/wirebson/encode.go @@ -15,7 +15,6 @@ package wirebson import ( - "bytes" "fmt" "time" @@ -24,138 +23,115 @@ import ( // encodeField encodes document/array field. // +// It returns the number of bytes written. // It panics if v is not a valid type. -func encodeField(buf *bytes.Buffer, name string, v any) error { +func encodeField(b []byte, name string, v any) (int, error) { + var i int switch v := v.(type) { case *Document: - if err := buf.WriteByte(byte(tagDocument)); err != nil { - return lazyerrors.Error(err) - } - - b := make([]byte, SizeCString(name)) - EncodeCString(b, name) + b[i] = byte(tagDocument) + i++ - if _, err := buf.Write(b); err != nil { - return lazyerrors.Error(err) - } + EncodeCString(b[i:], name) + i += SizeCString(name) - b, err := v.Encode() + err := v.EncodeTo(b[i:]) if err != nil { - return lazyerrors.Error(err) + return 0, lazyerrors.Error(err) } - if _, err = buf.Write(b); err != nil { - return lazyerrors.Error(err) - } + i += sizeDocument(v) case RawDocument: - if err := buf.WriteByte(byte(tagDocument)); err != nil { - return lazyerrors.Error(err) - } + b[i] = byte(tagDocument) + i++ - b := make([]byte, SizeCString(name)) - EncodeCString(b, name) + EncodeCString(b[i:], name) + i += SizeCString(name) - if _, err := buf.Write(b); err != nil { - return lazyerrors.Error(err) + if len(v) > len(b[i:]) { + panic(fmt.Sprintf("length of b should be at least %d bytes, got %d", len(v), len(b[i:]))) } - if _, err := buf.Write(v); err != nil { - return lazyerrors.Error(err) - } + i += copy(b[i:], v) case *Array: - if err := buf.WriteByte(byte(tagArray)); err != nil { - return lazyerrors.Error(err) - } - - b := make([]byte, SizeCString(name)) - EncodeCString(b, name) + b[i] = byte(tagArray) + i++ - if _, err := buf.Write(b); err != nil { - return lazyerrors.Error(err) - } + EncodeCString(b[i:], name) + i += SizeCString(name) - b, err := v.Encode() + err := v.EncodeTo(b[i:]) if err != nil { - return lazyerrors.Error(err) + return 0, lazyerrors.Error(err) } - if _, err = buf.Write(b); err != nil { - return lazyerrors.Error(err) - } + i += sizeArray(v) case RawArray: - if err := buf.WriteByte(byte(tagArray)); err != nil { - return lazyerrors.Error(err) - } + b[i] = byte(tagArray) + i++ - b := make([]byte, SizeCString(name)) - EncodeCString(b, name) + EncodeCString(b[i:], name) + i += SizeCString(name) - if _, err := buf.Write(b); err != nil { - return lazyerrors.Error(err) + if len(v) > len(b[i:]) { + panic(fmt.Sprintf("length of b should be at least %d bytes, got %d", len(v), len(b[i:]))) } - if _, err := buf.Write(v); err != nil { - return lazyerrors.Error(err) - } + i += copy(b[i:], v) default: - return encodeScalarField(buf, name, v) + return i + encodeScalarField(b[i:], name, v), nil } - return nil + return i, nil } // encodeScalarField encodes scalar document field. // +// It returns the number of bytes written. // It panics if v is not a scalar value. -func encodeScalarField(buf *bytes.Buffer, name string, v any) error { +func encodeScalarField(b []byte, name string, v any) int { + var i int switch v := v.(type) { case float64: - buf.WriteByte(byte(tagFloat64)) + b[i] = byte(tagFloat64) case string: - buf.WriteByte(byte(tagString)) + b[i] = byte(tagString) case Binary: - buf.WriteByte(byte(tagBinary)) + b[i] = byte(tagBinary) case ObjectID: - buf.WriteByte(byte(tagObjectID)) + b[i] = byte(tagObjectID) case bool: - buf.WriteByte(byte(tagBool)) + b[i] = byte(tagBool) case time.Time: - buf.WriteByte(byte(tagTime)) + b[i] = byte(tagTime) case NullType: - buf.WriteByte(byte(tagNull)) + b[i] = byte(tagNull) case Regex: - buf.WriteByte(byte(tagRegex)) + b[i] = byte(tagRegex) case int32: - buf.WriteByte(byte(tagInt32)) + b[i] = byte(tagInt32) case Timestamp: - buf.WriteByte(byte(tagTimestamp)) + b[i] = byte(tagTimestamp) case int64: - buf.WriteByte(byte(tagInt64)) + b[i] = byte(tagInt64) case Decimal128: - buf.WriteByte(byte(tagDecimal128)) + b[i] = byte(tagDecimal128) default: panic(fmt.Sprintf("invalid BSON type %T", v)) } + i++ - b := make([]byte, SizeCString(name)) - EncodeCString(b, name) + EncodeCString(b[i:], name) + i += SizeCString(name) - if _, err := buf.Write(b); err != nil { - return lazyerrors.Error(err) - } - - b = make([]byte, sizeScalar(v)) - encodeScalarValue(b, v) - - if _, err := buf.Write(b); err != nil { - return lazyerrors.Error(err) - } + encodeScalarValue(b[i:], v) + i += sizeScalar(v) - return nil + return i } // encodeScalarValue encodes value v into b. diff --git a/wirebson/encode_test.go b/wirebson/encode_test.go new file mode 100644 index 0000000..cb9b179 --- /dev/null +++ b/wirebson/encode_test.go @@ -0,0 +1,42 @@ +package wirebson + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncodeScalarField(t *testing.T) { + t.Parallel() + + actual := make([]byte, 13) + assert.Equal(t, 13, encodeScalarField(actual[0:], "foo", "bar")) + + expected := []byte{0x02, 0x66, 0x6f, 0x6f, 0x0, 0x4, 0x0, 0x0, 0x0, 0x62, 0x61, 0x72, 0x0} + assert.Equal(t, expected, actual) +} + +func TestEncodeField(t *testing.T) { + t.Parallel() + + var i int + actual := make([]byte, 22) + w, err := encodeField(actual[i:], "foo", "bar") + require.NoError(t, err) + + assert.Equal(t, 13, w) + i += w + + expected := []byte{0x2, 0x66, 0x6f, 0x6f, 0x0, 0x4, 0x0, 0x0, 0x0, 0x62, 0x61, 0x72, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0} + assert.Equal(t, expected, actual) + + w, err = encodeField(actual[i:], "foo", int32(1)) + require.NoError(t, err) + + assert.Equal(t, 9, w) + i += w + + expected = []byte{0x2, 0x66, 0x6f, 0x6f, 0x0, 0x4, 0x0, 0x0, 0x0, 0x62, 0x61, 0x72, 0x0, 0x10, 0x66, 0x6f, 0x6f, 0x0, 0x1, 0x0, 0x0, 0x0} + assert.Equal(t, expected, actual) +} diff --git a/wirebson/objectid.go b/wirebson/objectid.go index 300d5b5..83e3d27 100644 --- a/wirebson/objectid.go +++ b/wirebson/objectid.go @@ -29,7 +29,9 @@ const sizeObjectID = 12 // b must be at least 12 ([sizeObjectID]) bytes long; otherwise, encodeObjectID will panic. // Only b[0:12] bytes are modified. func encodeObjectID(b []byte, v ObjectID) { + // ensure b length early _ = b[11] + copy(b, v[:]) }