diff --git a/internal/integration/gridfs_test.go b/internal/integration/gridfs_test.go index 69c6191d14..80a9126886 100644 --- a/internal/integration/gridfs_test.go +++ b/internal/integration/gridfs_test.go @@ -11,7 +11,9 @@ import ( "context" "io" "math/rand" + "os" "runtime" + "sync" "testing" "time" @@ -19,6 +21,7 @@ import ( "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/integration/mtest" + "go.mongodb.org/mongo-driver/v2/internal/integtest" "go.mongodb.org/mongo-driver/v2/internal/israce" "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo" @@ -529,6 +532,41 @@ func TestGridFS(x *testing.T) { }) } +func TestOpenUploadStreamConcurrently(t *testing.T) { + t.Parallel() + + uri, err := integtest.MongoDBURI() + require.NoError(t, err, "error getting URI: %v", err) + opts := options.Client().ApplyURI(uri) + if os.Getenv("REQUIRE_API_VERSION") == "true" { + opts.SetServerAPIOptions(options.ServerAPI(options.ServerAPIVersion1)) + } + client, err := mongo.Connect(opts) + require.NoError(t, err, "Connect error: %v", err) + defer func() { + _ = client.Disconnect(context.Background()) + }() + + db := client.Database(mtest.TestDB) + bucket := db.GridFSBucket() + defer func() { + _ = bucket.Drop(context.Background()) + }() + + const size = 10_000 + + wg := sync.WaitGroup{} + wg.Add(size) + for i := 0; i < size; i++ { + go func() { + defer wg.Done() + _, err := bucket.OpenUploadStream(context.Background(), "foo") + assert.NoError(t, err, "OpenUploadStream error: %v", err) + }() + } + wg.Wait() +} + func assertGridFSCollectionState(mt *mtest.T, coll *mongo.Collection, expectedName string, expectedNumDocuments int64) { mt.Helper() diff --git a/mongo/gridfs_bucket.go b/mongo/gridfs_bucket.go index 7423077029..1da0f56041 100644 --- a/mongo/gridfs_bucket.go +++ b/mongo/gridfs_bucket.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "io" + "sync/atomic" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/csot" @@ -37,6 +38,8 @@ var ErrMissingGridFSChunkSize = errors.New("files collection document does not c // GridFSBucket represents a GridFS bucket. type GridFSBucket struct { + firstWriteDone uint32 + db *Database chunksColl *Collection // collection to store file chunks filesColl *Collection // collection to store file metadata @@ -47,9 +50,8 @@ type GridFSBucket struct { rc *readconcern.ReadConcern rp *readpref.ReadPref - firstWriteDone bool - readBuf []byte - writeBuf []byte + readBuf []byte + writeBuf []byte } // upload contains options to upload a file to a bucket. @@ -531,14 +533,10 @@ func (b *GridFSBucket) createIndexes(ctx context.Context) error { } func (b *GridFSBucket) checkFirstWrite(ctx context.Context) error { - if !b.firstWriteDone { + if atomic.CompareAndSwapUint32(&b.firstWriteDone, 0, 1) { // before the first write operation, must determine if files collection is empty // if so, create indexes if they do not already exist - - if err := b.createIndexes(ctx); err != nil { - return err - } - b.firstWriteDone = true + return b.createIndexes(ctx) } return nil