diff --git a/internal/integration/crud_prose_test.go b/internal/integration/crud_prose_test.go index 826dfb4f71..9cbe6e2da6 100644 --- a/internal/integration/crud_prose_test.go +++ b/internal/integration/crud_prose_test.go @@ -1017,3 +1017,27 @@ func TestClientBulkWriteProse(t *testing.T) { assert.Equal(mt, num, int(n), "expected %d documents, got: %d", num, n) }) } + +func TestBatchSize(t *testing.T) { + mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Pinned)) + mt.Setup() + + var hello struct { + MaxBsonObjectSize int + MaxMessageSizeBytes int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error: %v", err) + + var docs []any + limit := hello.MaxBsonObjectSize - 30 + for need := hello.MaxMessageSizeBytes - 350; need > 0; need -= limit { + if need >= limit { + docs = append(docs, bson.D{{"x", string(make([]byte, limit))}}) + } else { + docs = append(docs, bson.D{{"x", string(make([]byte, need))}}) + } + } + _, err = mt.Coll.InsertMany(context.Background(), docs) + assert.NoError(mt, err, "InsertMany error: %v", err) +} diff --git a/x/mongo/driver/batches.go b/x/mongo/driver/batches.go index 51a32bc962..ed1b2882bb 100644 --- a/x/mongo/driver/batches.go +++ b/x/mongo/driver/batches.go @@ -39,7 +39,7 @@ func (b *Batches) AppendBatchSequence(dst []byte, maxCount, totalSize int) (int, idx, dst = bsoncore.ReserveLength(dst) dst = append(dst, b.Identifier...) dst = append(dst, 0x00) - var size int + size := len(dst) var n int for i := b.offset; i < len(b.Documents); i++ { if n == maxCount { @@ -69,7 +69,7 @@ func (b *Batches) AppendBatchArray(dst []byte, maxCount, totalSize int) (int, [] } l := len(dst) aidx, dst := bsoncore.AppendArrayElementStart(dst, b.Identifier) - var size int + size := len(dst) var n int for i := b.offset; i < len(b.Documents); i++ { if n == maxCount { diff --git a/x/mongo/driver/batches_test.go b/x/mongo/driver/batches_test.go index 5a028c96dd..5efc57c7ed 100644 --- a/x/mongo/driver/batches_test.go +++ b/x/mongo/driver/batches_test.go @@ -36,9 +36,10 @@ func TestAppendBatchSequence(t *testing.T) { batches := newTestBatches(t) got := []byte{42} + sizeLimit := len(batches.Documents[0]) + len(batches.Documents[1]) var n int var err error - n, got, err = batches.AppendBatchSequence(got, 2, len(batches.Documents[0])) + n, got, err = batches.AppendBatchSequence(got, 2, sizeLimit) assert.NoError(t, err) assert.Equal(t, 1, n) @@ -57,9 +58,10 @@ func TestAppendBatchArray(t *testing.T) { batches := newTestBatches(t) got := []byte{42} + sizeLimit := len(batches.Documents[0]) + len(batches.Documents[1]) var n int var err error - n, got, err = batches.AppendBatchArray(got, 2, len(batches.Documents[0])) + n, got, err = batches.AppendBatchArray(got, 2, sizeLimit) assert.NoError(t, err) assert.Equal(t, 1, n)